Skip to content

Commit

Permalink
Use with statement in place of try/finally for envs.
Browse files Browse the repository at this point in the history
This patch refactors the code pattern `try: ...; finally: env.close()`
to instead use the `with gym.make(...):` pattern. This is preferred
because it automatically handles calling `close()`.
  • Loading branch information
ChrisCummins committed Sep 9, 2021
1 parent 58b5831 commit 50e6377
Show file tree
Hide file tree
Showing 24 changed files with 91 additions and 183 deletions.
23 changes: 10 additions & 13 deletions compiler_gym/leaderboard/llvm_instcount.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,20 +225,19 @@ def main(argv):
assert len(argv) == 1, f"Unknown args: {argv[:1]}"
assert FLAGS.n > 0, "n must be > 0"

env = gym.make("llvm-ic-v0")
with gym.make("llvm-ic-v0") as env:

# Stream verbose CompilerGym logs to file.
logger = logging.getLogger("compiler_gym")
logger.setLevel(logging.DEBUG)
env.logger.setLevel(logging.DEBUG)
log_handler = logging.FileHandler(FLAGS.leaderboard_logfile)
logger.addHandler(log_handler)
logger.propagate = False
# Stream verbose CompilerGym logs to file.
logger = logging.getLogger("compiler_gym")
logger.setLevel(logging.DEBUG)
env.logger.setLevel(logging.DEBUG)
log_handler = logging.FileHandler(FLAGS.leaderboard_logfile)
logger.addHandler(log_handler)
logger.propagate = False

print(f"Writing results to {FLAGS.leaderboard_results}")
print(f"Writing logs to {FLAGS.leaderboard_logfile}")
print(f"Writing results to {FLAGS.leaderboard_results}")
print(f"Writing logs to {FLAGS.leaderboard_logfile}")

try:
# Build the list of benchmarks to evaluate.
benchmarks = env.datasets[FLAGS.test_dataset].benchmark_uris()
if FLAGS.max_benchmarks:
Expand Down Expand Up @@ -301,8 +300,6 @@ def main(argv):
worker.alive = False
# User interrupt, don't validate.
FLAGS.validate = False
finally:
env.close()

if FLAGS.validate:
FLAGS.env = "llvm-ic-v0"
Expand Down
2 changes: 2 additions & 0 deletions examples/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,8 @@ def make_env():

def main(argv):
"""Main entry point."""
del argv # unused

torch.manual_seed(FLAGS.seed)
random.seed(FLAGS.seed)

Expand Down
7 changes: 2 additions & 5 deletions examples/random_walk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,9 @@

def test_run_random_walk_smoke_test():
flags.FLAGS(["argv0"])
env = gym.make("llvm-autophase-ic-v0")
env.benchmark = "cbench-v1/crc32"
try:
with gym.make("llvm-autophase-ic-v0") as env:
env.benchmark = "cbench-v1/crc32"
run_random_walk(env=env, step_count=5)
finally:
env.close()


if __name__ == "__main__":
Expand Down
9 changes: 3 additions & 6 deletions examples/tabular_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,20 +189,17 @@ def main(argv):
q_table: Dict[StateActionTuple, float] = {}
benchmark = benchmark_from_flags()
assert benchmark, "You must specify a benchmark using the --benchmark flag"
env = gym.make("llvm-ic-v0", benchmark=benchmark)
env.observation_space = "Autophase"

try:
with gym.make("llvm-ic-v0", benchmark=benchmark) as env:
env.observation_space = "Autophase"

# Train a Q-table.
with Timer("Constructing Q-table"):
train(q_table, env)

# Rollout resulting policy.
rollout(q_table, env, printout=True)

finally:
env.close()


if __name__ == "__main__":
app.run(main)
5 changes: 2 additions & 3 deletions tests/bin/service_bin_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
@pytest.mark.parametrize("env_name", compiler_gym.COMPILER_GYM_ENVS)
def test_print_service_capabilities_smoke_test(env_name: str):
flags.FLAGS(["argv0"])
env = gym.make(env_name)
print_service_capabilities(env)
env.close()
with gym.make(env_name) as env:
print_service_capabilities(env)


if __name__ == "__main__":
Expand Down
24 changes: 7 additions & 17 deletions tests/compiler_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@
def test_benchmark_constructor_arg(env: CompilerEnv):
env.close() # Fixture only required to pull in dataset.

env = gym.make("llvm-v0", benchmark="cbench-v1/dijkstra")
try:
with gym.make("llvm-v0", benchmark="cbench-v1/dijkstra") as env:
assert env.benchmark == "benchmark://cbench-v1/dijkstra"
finally:
env.close()


def test_benchmark_setter(env: CompilerEnv):
Expand All @@ -39,14 +36,10 @@ def test_benchmark_set_in_reset(env: CompilerEnv):

def test_logger_forced():
logger = logging.getLogger("test_logger")
env_a = gym.make("llvm-v0")
env_b = gym.make("llvm-v0", logger=logger)
try:
assert env_a.logger != logger
assert env_b.logger == logger
finally:
env_a.close()
env_b.close()
with gym.make("llvm-v0") as env_a:
with gym.make("llvm-v0", logger=logger) as env_b:
assert env_a.logger != logger
assert env_b.logger == logger


def test_uri_substring_no_match(env: CompilerEnv):
Expand Down Expand Up @@ -120,14 +113,11 @@ def test_gym_make_kwargs():
"""Test that passing kwargs to gym.make() are forwarded to environment
constructor.
"""
env = gym.make(
with gym.make(
"llvm-v0", observation_space="Autophase", reward_space="IrInstructionCount"
)
try:
) as env:
assert env.observation_space_spec.id == "Autophase"
assert env.reward_space.id == "IrInstructionCount"
finally:
env.close()


def test_step_session_id_not_found(env: CompilerEnv):
Expand Down
9 changes: 3 additions & 6 deletions tests/fuzzing/llvm_random_actions_fuzz_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,14 @@
@pytest.mark.timeout(600)
def test_fuzz(benchmark_name: str):
"""Run randomly selected actions on a benchmark until a minimum amount of time has elapsed."""
env = gym.make(
with gym.make(
"llvm-v0",
reward_space="IrInstructionCount",
observation_space="Autophase",
benchmark=benchmark_name,
)
env.reset()
) as env:
env.reset()

try:
# Take a random step until a predetermined amount of time has elapsed.
end_time = time() + FUZZ_TIME_SECONDS
while time() < end_time:
Expand All @@ -52,8 +51,6 @@ def test_fuzz(benchmark_name: str):
assert isinstance(observation, np.ndarray)
assert observation.shape == (AUTOPHASE_FEATURE_DIM,)
assert isinstance(reward, float)
finally:
env.close()


if __name__ == "__main__":
Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/custom_benchmarks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,9 @@ def test_custom_benchmark(env: LlvmEnv):

def test_custom_benchmark_constructor():
benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE)
env = gym.make("llvm-v0", benchmark=benchmark)
try:
with gym.make("llvm-v0", benchmark=benchmark) as env:
env.reset()
assert env.benchmark == "benchmark://new"
finally:
env.close()


def test_make_benchmark_single_bitcode(env: LlvmEnv):
Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/anghabench_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@

@pytest.fixture(scope="module")
def anghabench_dataset() -> AnghaBenchDataset:
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
ds = env.datasets["anghabench-v1"]
finally:
env.close()
yield ds


Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/chstone_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,8 @@

@pytest.fixture(scope="module")
def chstone_dataset() -> CHStoneDataset:
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
ds = env.datasets["chstone-v0"]
finally:
env.close()
yield ds


Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/clgen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@

@pytest.fixture(scope="module")
def clgen_dataset() -> CLgenDataset:
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
ds = env.datasets["benchmark://clgen-v0"]
finally:
env.close()
yield ds


Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/csmith_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@

@pytest.fixture(scope="module")
def csmith_dataset() -> CsmithDataset:
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
ds = env.datasets["generator://csmith-v0"]
finally:
env.close()
yield ds


Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/github_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,8 @@

@pytest.fixture(scope="module")
def github_dataset() -> GitHubDataset:
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
ds = env.datasets["github-v0"]
finally:
env.close()
yield ds


Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/llvm_datasets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@


def test_default_dataset_list():
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
assert list(d.name for d in env.datasets) == [
"benchmark://cbench-v1",
"benchmark://anghabench-v1",
Expand All @@ -28,8 +27,6 @@ def test_default_dataset_list():
"generator://csmith-v0",
"generator://llvm-stress-v0",
]
finally:
env.close()


if __name__ == "__main__":
Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/llvm_stress_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,8 @@

@pytest.fixture(scope="module")
def llvm_stress_dataset() -> LlvmStressDataset:
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
ds = env.datasets["generator://llvm-stress-v0"]
finally:
env.close()
yield ds


Expand Down
5 changes: 1 addition & 4 deletions tests/llvm/datasets/poj104_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,8 @@

@pytest.fixture(scope="module")
def poj104_dataset() -> POJ104Dataset:
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
ds = env.datasets["poj104-v1"]
finally:
env.close()
yield ds


Expand Down
6 changes: 1 addition & 5 deletions tests/llvm/gym_interface_compatability.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,7 @@

@pytest.fixture(scope="function")
def env() -> CompilerEnv:
env = gym.make("llvm-autophase-ic-v0")
try:
yield env
finally:
env.close()
return gym.make("llvm-autophase-ic-v0")


def test_type_classes(env: CompilerEnv):
Expand Down
16 changes: 4 additions & 12 deletions tests/llvm/llvm_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,14 @@
def env(request) -> CompilerEnv:
"""Create an LLVM environment."""
if request.param == "local":
env = gym.make("llvm-v0")
try:
with gym.make("llvm-v0") as env:
yield env
finally:
env.close()
else:
service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY)
env = LlvmEnv(service=service.connection.url)
try:
yield env
with LlvmEnv(service=service.connection.url) as env:
yield env
finally:
env.close()
service.close()


Expand Down Expand Up @@ -165,13 +161,9 @@ def test_apply_state(env: LlvmEnv):
env.reset(benchmark="cbench-v1/crc32")
env.step(env.action_space.flags.index("-mem2reg"))

other = gym.make("llvm-v0", reward_space="IrInstructionCount")
try:
with gym.make("llvm-v0", reward_space="IrInstructionCount") as other:
other.apply(env.state)

assert other.state == env.state
finally:
other.close()


def test_set_observation_space_from_spec(env: LlvmEnv):
Expand Down
Loading

0 comments on commit 50e6377

Please sign in to comment.