[datasets] Allow benchmark to be None at constructor time.
Issue #45.
ChrisCummins committed Apr 28, 2021
1 parent 9bf8a72 commit 0260132
Showing 2 changed files with 50 additions and 30 deletions.
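
In practice, the change means a CompilerEnv can now be constructed before any benchmark is known: the constructor falls back to the first benchmark in env.datasets if one exists, and otherwise defers the choice until the user sets a benchmark before or during the first reset(). A minimal usage sketch of the resulting behaviour follows; the "llvm-v0" environment id and the "benchmark://cbench-v1/qsort" URI are illustrative assumptions taken from the CompilerGym documentation, not from this diff.

import gym

import compiler_gym  # noqa: F401 -- importing registers the CompilerGym environments

# Construct without naming a benchmark; after this commit the constructor
# tolerates there being no benchmark chosen (and even no datasets) yet.
env = gym.make("llvm-v0")

# Choose the benchmark any time before the first episode starts...
env.benchmark = "benchmark://cbench-v1/qsort"
env.reset()

# ...or pass it directly to reset().
env.reset(benchmark="benchmark://cbench-v1/qsort")

env.close()
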
56 changes: 27 additions & 29 deletions compiler_gym/envs/compiler_env.py
@@ -229,11 +229,11 @@ def __init__(
]

# The benchmark that is currently being used, and the benchmark that
# the user requested. Those do not always correlate since env.benchmark
# can be set in an episode, but will not take effect until env.reset()
# is called.
# will be used on the next call to reset(). These are equal except in
# the gap between the user setting the env.benchmark property while in
# an episode and the next call to env.reset().
self._benchmark_in_use: Optional[Benchmark] = None
self._user_specified_benchmark: Optional[Benchmark] = None
self._next_benchmark: Optional[Benchmark] = None
# Normally when the benchmark is changed the updated value is not
# reflected until the next call to reset(). We make an exception for the
# constructor-time benchmark as otherwise the behavior of the benchmark
@@ -248,8 +248,15 @@ def __init__(
#
# By forcing the _benchmark_in_use URI at constructor time, the first
# env.benchmark above returns the benchmark as expected.
self.benchmark = benchmark or next(self.datasets.benchmarks())
self._benchmark_in_use = self._user_specified_benchmark
try:
self.benchmark = benchmark or next(self.datasets.benchmarks())
self._benchmark_in_use = self._next_benchmark
except StopIteration:
# StopIteration raised on next(self.datasets.benchmarks()) if there
# are no benchmarks available. This is to allow CompilerEnv to be
# used without any datasets by setting a benchmark before/during the
# first reset() call.
pass

# Process the available action, observation, and reward spaces.
self.action_spaces = [
@@ -387,29 +394,10 @@ def benchmark(self) -> Benchmark:
or the URI of a benchmark as in :meth:`env.datasets.benchmark_uris()
<compiler_gym.datasets.Datasets.benchmark_uris>`.
By default, a benchmark will be selected randomly by the service from
the available benchmarks on a call to :func:`reset`. To force a specific
benchmark to be chosen, set this property (or pass the benchmark as an
argument to :func:`reset`):
>>> env.benchmark = "benchmark://foo"
>>> env.reset()
>>> env.benchmark
benchmark://foo
Once set, all subsequent calls to :func:`env.reset()
<compiler_gym.envs.CompilerEnv.reset>` will select the same benchmark.
>>> env.benchmark = "*"
>>> env.reset() # random benchmark is chosen
.. note::
Setting a new benchmark has no effect until
:func:`env.reset() <compiler_gym.envs.CompilerEnv.reset>` is called.
To return to random benchmark selection, set this property to
:code:`None`:
"""
return self._benchmark_in_use

@@ -422,10 +410,10 @@ def benchmark(self, benchmark: Union[str, Benchmark]):
if isinstance(benchmark, str):
benchmark_object = self.datasets.benchmark(benchmark)
self.logger.debug("Setting benchmark by name: %s", benchmark_object)
self._user_specified_benchmark = benchmark_object
self._next_benchmark = benchmark_object
elif isinstance(benchmark, Benchmark):
self.logger.debug("Setting benchmark: %s", benchmark.uri)
self._user_specified_benchmark = benchmark
self._next_benchmark = benchmark
else:
raise TypeError(
f"Expected a Benchmark or str, received: '{type(benchmark).__name__}'"
@@ -555,7 +543,7 @@ def fork(self) -> "CompilerEnv":
# revert the environment to the initial benchmark state.
#
# pylint: disable=protected-access
new_env._user_specified_benchmark = self._benchmark_in_use
new_env._next_benchmark = self._benchmark_in_use

# Set the "visible" name of the current benchmark to hide the fact that
# we loaded from a custom bitcode file.
@@ -639,7 +627,17 @@ def reset( # pylint: disable=arguments-differ
:raises BenchmarkInitError: If the benchmark is invalid. In this case,
another benchmark must be used.
:raises TypeError: If no benchmark has been set, and the environment
does not have a default benchmark to select from.
"""
if not self._next_benchmark:
raise TypeError(
"No benchmark set. Set a benchmark using "
"`env.reset(benchmark=benchmark)`. Use `env.datasets` to "
"access the available benchmarks."
)

# Start a new service if required.
if self.service is None:
self.service = CompilerGymServiceConnection(
@@ -659,7 +657,7 @@ def reset( # pylint: disable=arguments-differ
# Update the user requested benchmark, if provided.
if benchmark:
self.benchmark = benchmark
self._benchmark_in_use = self._user_specified_benchmark
self._benchmark_in_use = self._next_benchmark

start_session_request = StartSessionRequest(
benchmark=self._benchmark_in_use.uri,
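The rename of _user_specified_benchmark to _next_benchmark above mirrors the semantics spelled out in the updated comment: the property setter records the benchmark to use on the next call to reset(), while the getter keeps reporting the benchmark of the current episode (_benchmark_in_use). A short sketch of that behaviour, again assuming the "llvm-v0" environment id and two cbench-v1 benchmark URIs that are not part of this diff:

import gym

import compiler_gym  # noqa: F401 -- importing registers the CompilerGym environments

env = gym.make("llvm-v0")
env.reset(benchmark="benchmark://cbench-v1/qsort")

# Setting the property mid-episode only updates _next_benchmark; the running
# episode, and the value returned by the getter, are unchanged for now.
env.benchmark = "benchmark://cbench-v1/dijkstra"
assert env.benchmark.uri == "benchmark://cbench-v1/qsort"

# The pending benchmark takes effect on the next reset(), when
# _benchmark_in_use is updated from _next_benchmark.
env.reset()
assert env.benchmark.uri == "benchmark://cbench-v1/dijkstra"

env.close()
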
24 changes: 23 additions & 1 deletion tests/compiler_env_test.py
@@ -8,7 +8,8 @@
import gym
import pytest

from compiler_gym.envs import CompilerEnv
from compiler_gym.envs import CompilerEnv, llvm
from compiler_gym.service.connection import CompilerGymServiceConnection
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]
@@ -141,5 +142,26 @@ def test_step_session_id_not_found(env: CompilerEnv):
assert not env.in_episode


@pytest.fixture(scope="function")
def remote_env() -> CompilerEnv:
"""A test fixture that yields a connection to a remote service."""
service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY)
env = CompilerEnv(service=service.connection.url)
try:
yield env
finally:
env.close()
service.close()


def test_base_class_has_no_benchmark(remote_env: CompilerEnv):
"""Test that when instantiating the base CompilerEnv class there are no
datasets available.
"""
assert remote_env.benchmark is None
with pytest.raises(TypeError, match="No benchmark set"):
remote_env.reset()


if __name__ == "__main__":
main()
