From 12b4414beb2107e29269a22c460cc6dd3b6262d4 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 9 Sep 2021 12:42:04 +0100 Subject: [PATCH] Use `with` statement in place of try/finally for envs. This patch refactors the code pattern `try: ...; finally: env.close()` to instead use the `with gym.make(...):` pattern. This is preferred because it automatically handles calling `close()`. --- benchmarks/parallelization_load_test.py | 5 +- compiler_gym/bin/datasets.py | 5 +- compiler_gym/bin/random_replay.py | 6 +- compiler_gym/bin/random_search.py | 10 +-- compiler_gym/bin/service.py | 5 +- compiler_gym/bin/validate.py | 5 +- compiler_gym/leaderboard/llvm_instcount.py | 23 +++--- compiler_gym/random_replay.py | 1 - compiler_gym/random_search.py | 23 ++---- compiler_gym/util/flags/env_from_flags.py | 5 +- compiler_gym/validate.py | 5 +- docs/source/llvm/index.rst | 3 +- examples/actor_critic.py | 2 + examples/brute_force.py | 77 +++++++++---------- examples/explore.py | 2 +- examples/random_walk.py | 10 +-- examples/random_walk_test.py | 7 +- .../action_sensitivity_analysis.py | 8 +- .../benchmark_sensitivity_analysis.py | 6 +- examples/tabular_q.py | 9 +-- tests/bin/service_bin_test.py | 5 +- tests/compiler_env_test.py | 29 +++---- .../fuzzing/llvm_random_actions_fuzz_test.py | 9 +-- tests/llvm/all_benchmarks_init_close_test.py | 1 - tests/llvm/custom_benchmarks_test.py | 5 +- tests/llvm/datasets/anghabench_test.py | 5 +- tests/llvm/datasets/chstone_test.py | 5 +- tests/llvm/datasets/clgen_test.py | 5 +- tests/llvm/datasets/csmith_test.py | 5 +- tests/llvm/datasets/github_test.py | 5 +- tests/llvm/datasets/llvm_datasets_test.py | 5 +- tests/llvm/datasets/llvm_stress_test.py | 5 +- tests/llvm/datasets/poj104_test.py | 5 +- tests/llvm/gym_interface_compatability.py | 6 +- tests/llvm/llvm_env_test.py | 16 +--- tests/llvm/multiprocessing_test.py | 42 +++++----- tests/llvm/service_connection_test.py | 10 +-- tests/llvm/threading_test.py | 34 ++++---- tests/llvm/validate_test.py | 25 ++---- tests/pytest_plugins/llvm.py | 7 +- tests/random_search_test.py | 5 +- tests/service/connection_test.py | 10 +-- tests/wrappers/core_wrappers_test.py | 8 +- 43 files changed, 169 insertions(+), 300 deletions(-) diff --git a/benchmarks/parallelization_load_test.py b/benchmarks/parallelization_load_test.py index 6f84f7d49..635c4af2b 100644 --- a/benchmarks/parallelization_load_test.py +++ b/benchmarks/parallelization_load_test.py @@ -42,16 +42,13 @@ def run_random_search(num_episodes, num_steps) -> None: """The inner loop of a load test benchmark.""" - env = env_from_flags(benchmark=benchmark_from_flags()) - try: + with env_from_flags(benchmark=benchmark_from_flags()) as env: for _ in range(num_episodes): env.reset() for _ in range(num_steps): _, _, done, _ = env.step(env.action_space.sample()) if done: break - finally: - env.close() def main(argv): diff --git a/compiler_gym/bin/datasets.py b/compiler_gym/bin/datasets.py index cde0dffce..582f31cc0 100644 --- a/compiler_gym/bin/datasets.py +++ b/compiler_gym/bin/datasets.py @@ -148,8 +148,7 @@ def main(argv): if len(argv) != 1: raise app.UsageError(f"Unknown command line arguments: {argv[1:]}") - env = env_from_flags() - try: + with env_from_flags() as env: invalidated_manifest = False for name_or_url in FLAGS.download: @@ -182,8 +181,6 @@ def main(argv): print( summarize_datasets(env.datasets), ) - finally: - env.close() if __name__ == "__main__": diff --git a/compiler_gym/bin/random_replay.py b/compiler_gym/bin/random_replay.py index 394145843..c5079246f 100644 --- a/compiler_gym/bin/random_replay.py +++ b/compiler_gym/bin/random_replay.py @@ -36,9 +36,9 @@ def main(argv): output_dir / logs.METADATA_NAME ).is_file(), f"Invalid --output_dir: {output_dir}" - env = env_from_flags() - benchmark = benchmark_from_flags() - replay_actions_from_logs(env, output_dir, benchmark=benchmark) + with env_from_flags() as env: + benchmark = benchmark_from_flags() + replay_actions_from_logs(env, output_dir, benchmark=benchmark) if __name__ == "__main__": diff --git a/compiler_gym/bin/random_search.py b/compiler_gym/bin/random_search.py index c0e3f4927..baa57cc01 100644 --- a/compiler_gym/bin/random_search.py +++ b/compiler_gym/bin/random_search.py @@ -93,9 +93,8 @@ def main(argv): raise app.UsageError(f"Unknown command line arguments: {argv[1:]}") if FLAGS.ls_reward: - env = env_from_flags() - print("\n".join(sorted(env.reward.indices.keys()))) - env.close() + with env_from_flags() as env: + print("\n".join(sorted(env.reward.indices.keys()))) return assert FLAGS.patience >= 0, "--patience must be >= 0" @@ -103,11 +102,8 @@ def main(argv): def make_env(): return env_from_flags(benchmark=benchmark_from_flags()) - env = make_env() - try: + with make_env() as env: env.reset() - finally: - env.close() best_reward, _ = random_search( make_env=make_env, diff --git a/compiler_gym/bin/service.py b/compiler_gym/bin/service.py index c9c10ad41..44527bfee 100644 --- a/compiler_gym/bin/service.py +++ b/compiler_gym/bin/service.py @@ -229,11 +229,8 @@ def main(argv): """Main entry point.""" assert len(argv) == 1, f"Unrecognized flags: {argv[1:]}" - env = env_from_flags() - try: + with env_from_flags() as env: print_service_capabilities(env) - finally: - env.close() if __name__ == "__main__": diff --git a/compiler_gym/bin/validate.py b/compiler_gym/bin/validate.py index ba498b349..131158172 100644 --- a/compiler_gym/bin/validate.py +++ b/compiler_gym/bin/validate.py @@ -153,8 +153,7 @@ def main(argv): ) # Determine the name of the reward space. - env = env_from_flags() - try: + with env_from_flags() as env: if FLAGS.reward_aggregation == "geomean": def reward_aggregation(a): @@ -173,8 +172,6 @@ def reward_aggregation(a): reward_name = f"{reward_aggregation_name} {env.reward_space.id}" else: reward_name = "" - finally: - env.close() # Determine the maximum column width required for printing tabular output. max_state_name_length = max( diff --git a/compiler_gym/leaderboard/llvm_instcount.py b/compiler_gym/leaderboard/llvm_instcount.py index bbda9baa2..0816c9767 100644 --- a/compiler_gym/leaderboard/llvm_instcount.py +++ b/compiler_gym/leaderboard/llvm_instcount.py @@ -225,20 +225,19 @@ def main(argv): assert len(argv) == 1, f"Unknown args: {argv[:1]}" assert FLAGS.n > 0, "n must be > 0" - env = gym.make("llvm-ic-v0") + with gym.make("llvm-ic-v0") as env: - # Stream verbose CompilerGym logs to file. - logger = logging.getLogger("compiler_gym") - logger.setLevel(logging.DEBUG) - env.logger.setLevel(logging.DEBUG) - log_handler = logging.FileHandler(FLAGS.leaderboard_logfile) - logger.addHandler(log_handler) - logger.propagate = False + # Stream verbose CompilerGym logs to file. + logger = logging.getLogger("compiler_gym") + logger.setLevel(logging.DEBUG) + env.logger.setLevel(logging.DEBUG) + log_handler = logging.FileHandler(FLAGS.leaderboard_logfile) + logger.addHandler(log_handler) + logger.propagate = False - print(f"Writing results to {FLAGS.leaderboard_results}") - print(f"Writing logs to {FLAGS.leaderboard_logfile}") + print(f"Writing results to {FLAGS.leaderboard_results}") + print(f"Writing logs to {FLAGS.leaderboard_logfile}") - try: # Build the list of benchmarks to evaluate. benchmarks = env.datasets[FLAGS.test_dataset].benchmark_uris() if FLAGS.max_benchmarks: @@ -301,8 +300,6 @@ def main(argv): worker.alive = False # User interrupt, don't validate. FLAGS.validate = False - finally: - env.close() if FLAGS.validate: FLAGS.env = "llvm-ic-v0" diff --git a/compiler_gym/random_replay.py b/compiler_gym/random_replay.py index c12398470..f00cc63f6 100644 --- a/compiler_gym/random_replay.py +++ b/compiler_gym/random_replay.py @@ -79,4 +79,3 @@ def replay_actions_from_logs(env: CompilerEnv, logdir: Path, benchmark=None) -> env.reward_space = meta["reward"] env.reset(benchmark=benchmark) replay_actions(env, actions, logdir) - env.close() diff --git a/compiler_gym/random_search.py b/compiler_gym/random_search.py index 7cb025443..460e46f4e 100644 --- a/compiler_gym/random_search.py +++ b/compiler_gym/random_search.py @@ -54,22 +54,16 @@ def run(self) -> None: """Run episodes in an infinite loop.""" while self.should_run_one_episode: self.total_environment_count += 1 - env = self._make_env() - try: + with self._make_env() as env: self._patience = self._patience or env.action_space.n self.run_one_environment(env) - finally: - env.close() def run_one_environment(self, env: CompilerEnv) -> None: """Run random walks in an infinite loop. Returns if the environment ends.""" - try: - while self.should_run_one_episode: - self.total_episode_count += 1 - if not self.run_one_episode(env): - return - finally: - env.close() + while self.should_run_one_episode: + self.total_episode_count += 1 + if not self.run_one_episode(env): + return def run_one_episode(self, env: CompilerEnv) -> bool: """Run a single random episode. @@ -253,9 +247,8 @@ def random_search( print("done") print("Replaying actions from best solution found:") - env = make_env() - env.reset() - replay_actions(env, best_action_names, outdir) - env.close() + with make_env() as env: + env.reset() + replay_actions(env, best_action_names, outdir) return best_returns, best_actions diff --git a/compiler_gym/util/flags/env_from_flags.py b/compiler_gym/util/flags/env_from_flags.py index 67cf33d21..da984ee69 100644 --- a/compiler_gym/util/flags/env_from_flags.py +++ b/compiler_gym/util/flags/env_from_flags.py @@ -141,8 +141,5 @@ def env_from_flags(benchmark: Optional[Union[str, Benchmark]] = None) -> Compile def env_session_from_flags( benchmark: Optional[Union[str, Benchmark]] = None ) -> CompilerEnv: - env = env_from_flags(benchmark=benchmark) - try: + with env_from_flags(benchmark=benchmark) as env: yield env - finally: - env.close() diff --git a/compiler_gym/validate.py b/compiler_gym/validate.py index 9a8ed5b8d..b9ac7bb4b 100644 --- a/compiler_gym/validate.py +++ b/compiler_gym/validate.py @@ -16,11 +16,8 @@ def _validate_states_worker( make_env: Callable[[], CompilerEnv], state: CompilerEnvState ) -> ValidationResult: - env = make_env() - try: + with make_env() as env: result = env.validate(state) - finally: - env.close() return result diff --git a/docs/source/llvm/index.rst b/docs/source/llvm/index.rst index 3706fb46f..024483d1d 100644 --- a/docs/source/llvm/index.rst +++ b/docs/source/llvm/index.rst @@ -108,7 +108,8 @@ Alternatively the module can be serialized to a bitcode file on disk: .. note:: Files generated by the :code:`BitcodeFile` observation space are put in a - temporary directory that is removed when :meth:`env.close() ` is called. + temporary directory that is removed when :meth:`env.close() + ` is called. InstCount diff --git a/examples/actor_critic.py b/examples/actor_critic.py index d17c7cb3a..c77cb8a94 100644 --- a/examples/actor_critic.py +++ b/examples/actor_critic.py @@ -352,6 +352,8 @@ def make_env(): def main(argv): """Main entry point.""" + del argv # unused + torch.manual_seed(FLAGS.seed) random.seed(FLAGS.seed) diff --git a/examples/brute_force.py b/examples/brute_force.py index c967e81ca..0d256895f 100644 --- a/examples/brute_force.py +++ b/examples/brute_force.py @@ -172,33 +172,32 @@ def run_brute_force( meta_path = outdir / "meta.json" results_path = outdir / "results.csv" - env: CompilerEnv = make_env() - env.reset() - - action_names = action_names or env.action_space.names - - if not env.reward_space: - raise ValueError("A reward space must be specified for random search") - reward_space_name = env.reward_space.id - - actions = [env.action_space.names.index(a) for a in action_names] - benchmark_uri = env.benchmark.uri - - meta = { - "env": env.spec.id, - "action_names": action_names, - "benchmark": benchmark_uri, - "reward": reward_space_name, - "init_reward": env.reward[reward_space_name], - "episode_length": episode_length, - "nproc": nproc, - "chunksize": chunksize, - } - with open(str(meta_path), "w") as f: - json.dump(meta, f) - print(f"Wrote {meta_path}") - print(f"Writing results to {results_path}") - env.close() + with make_env() as env: + env.reset() + + action_names = action_names or env.action_space.names + + if not env.reward_space: + raise ValueError("A reward space must be specified for random search") + reward_space_name = env.reward_space.id + + actions = [env.action_space.names.index(a) for a in action_names] + benchmark_uri = env.benchmark.uri + + meta = { + "env": env.spec.id, + "action_names": action_names, + "benchmark": benchmark_uri, + "reward": reward_space_name, + "init_reward": env.reward[reward_space_name], + "episode_length": episode_length, + "nproc": nproc, + "chunksize": chunksize, + } + with open(str(meta_path), "w") as f: + json.dump(meta, f) + print(f"Wrote {meta_path}") + print(f"Writing results to {results_path}") # A queue for communicating action sequences to workers, and a queue for # workers to report results. @@ -287,14 +286,13 @@ def run_brute_force( worker.join() num_trials = sum(worker.num_trials for worker in workers) - env: CompilerEnv = make_env() - print( - f"completed {humanize.intcomma(num_trials)} of " - f"{humanize.intcomma(expected_trial_count)} trials " - f"({num_trials / expected_trial_count:.3%}), best sequence", - " ".join([env.action_space.flags[i] for i in best_action_sequence]), - ) - env.close() + with make_env() as env: + print( + f"completed {humanize.intcomma(num_trials)} of " + f"{humanize.intcomma(expected_trial_count)} trials " + f"({num_trials / expected_trial_count:.3%}), best sequence", + " ".join([env.action_space.flags[i] for i in best_action_sequence]), + ) def main(argv): @@ -309,11 +307,10 @@ def main(argv): if not benchmark: raise app.UsageError("No benchmark specified.") - env = env_from_flags(benchmark) - env.reset() - benchmark = env.benchmark - sanitized_benchmark_uri = "/".join(benchmark.split("/")[-2:]) - env.close() + with env_from_flags(benchmark) as env: + env.reset() + benchmark = env.benchmark + sanitized_benchmark_uri = "/".join(benchmark.split("/")[-2:]) logs_dir = Path( FLAGS.output_dir or create_logging_dir(f"brute_force/{sanitized_benchmark_uri}") ) diff --git a/examples/explore.py b/examples/explore.py index 77f75c3c9..bb5043119 100644 --- a/examples/explore.py +++ b/examples/explore.py @@ -447,8 +447,8 @@ def main(argv): print(f"Running with {FLAGS.nproc} threads.") assert FLAGS.nproc >= 1 + envs = [] try: - envs = [] for _ in range(FLAGS.nproc): envs.append(make_env()) compute_action_graph(envs, episode_length=FLAGS.episode_length) diff --git a/examples/random_walk.py b/examples/random_walk.py index 965d968e1..bba6bdc8a 100644 --- a/examples/random_walk.py +++ b/examples/random_walk.py @@ -62,7 +62,6 @@ def run_random_walk(env: CompilerEnv, step_count: int) -> None: if done: print("Episode ended by environment") break - env.close() def reward_percentage(reward, rewards): if sum(rewards) == 0: @@ -92,11 +91,10 @@ def main(argv): assert len(argv) == 1, f"Unrecognized flags: {argv[1:]}" benchmark = benchmark_from_flags() - env = env_from_flags(benchmark) - - step_min = min(FLAGS.step_min, FLAGS.step_max) - step_max = max(FLAGS.step_min, FLAGS.step_max) - run_random_walk(env=env, step_count=random.randint(step_min, step_max)) + with env_from_flags(benchmark) as env: + step_min = min(FLAGS.step_min, FLAGS.step_max) + step_max = max(FLAGS.step_min, FLAGS.step_max) + run_random_walk(env=env, step_count=random.randint(step_min, step_max)) if __name__ == "__main__": diff --git a/examples/random_walk_test.py b/examples/random_walk_test.py index d810ab12d..c4d169ded 100644 --- a/examples/random_walk_test.py +++ b/examples/random_walk_test.py @@ -12,12 +12,9 @@ def test_run_random_walk_smoke_test(): flags.FLAGS(["argv0"]) - env = gym.make("llvm-autophase-ic-v0") - env.benchmark = "cbench-v1/crc32" - try: + with gym.make("llvm-autophase-ic-v0") as env: + env.benchmark = "cbench-v1/crc32" run_random_walk(env=env, step_count=5) - finally: - env.close() if __name__ == "__main__": diff --git a/examples/sensitivity_analysis/action_sensitivity_analysis.py b/examples/sensitivity_analysis/action_sensitivity_analysis.py index 868e1f743..95f57b9d4 100644 --- a/examples/sensitivity_analysis/action_sensitivity_analysis.py +++ b/examples/sensitivity_analysis/action_sensitivity_analysis.py @@ -36,7 +36,7 @@ from compiler_gym.envs import CompilerEnv from compiler_gym.util.flags.benchmark_from_flags import benchmark_from_flags -from compiler_gym.util.flags.env_from_flags import env_session_from_flags +from compiler_gym.util.flags.env_from_flags import env_from_flags from compiler_gym.util.logs import create_logging_dir from compiler_gym.util.timer import Timer from examples.sensitivity_analysis.sensitivity_analysis_eval import ( @@ -93,7 +93,7 @@ def get_rewards( and len(rewards) < num_trials ): num_attempts += 1 - with env_session_from_flags(benchmark=benchmark) as env: + with env_from_flags(benchmark=benchmark) as env: env.observation_space = None env.reward_space = None env.reset(benchmark=benchmark) @@ -143,7 +143,7 @@ def run_action_sensitivity_analysis( max_attempts_multiplier: int = 5, ): """Estimate the reward delta of a given list of actions.""" - with env_session_from_flags() as env: + with env_from_flags() as env: action_names = env.action_space.names with ThreadPoolExecutor(max_workers=nproc) as executor: @@ -172,7 +172,7 @@ def main(argv): if len(argv) != 1: raise app.UsageError(f"Unknown command line arguments: {argv[1:]}") - with env_session_from_flags() as env: + with env_from_flags() as env: action_names = env.action_space.names if FLAGS.action: diff --git a/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py b/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py index 628d2286a..3d9482992 100644 --- a/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py +++ b/examples/sensitivity_analysis/benchmark_sensitivity_analysis.py @@ -38,7 +38,7 @@ from compiler_gym.envs import CompilerEnv from compiler_gym.service.proto import Benchmark from compiler_gym.util.flags.benchmark_from_flags import benchmark_from_flags -from compiler_gym.util.flags.env_from_flags import env_session_from_flags +from compiler_gym.util.flags.env_from_flags import env_from_flags from compiler_gym.util.logs import create_logging_dir from compiler_gym.util.timer import Timer from examples.sensitivity_analysis.sensitivity_analysis_eval import ( @@ -94,7 +94,7 @@ def get_rewards( and len(rewards) < num_trials ): num_attempts += 1 - with env_session_from_flags(benchmark=benchmark) as env: + with env_from_flags(benchmark=benchmark) as env: env.observation_space = None env.reward_space = None env.reset(benchmark=benchmark) @@ -173,7 +173,7 @@ def main(argv): if benchmark: benchmarks = [benchmark] else: - with env_session_from_flags() as env: + with env_from_flags() as env: benchmarks = islice(env.benchmarks, 100) logs_dir = Path( diff --git a/examples/tabular_q.py b/examples/tabular_q.py index 7f11a03ec..f48ef50f3 100644 --- a/examples/tabular_q.py +++ b/examples/tabular_q.py @@ -189,10 +189,10 @@ def main(argv): q_table: Dict[StateActionTuple, float] = {} benchmark = benchmark_from_flags() assert benchmark, "You must specify a benchmark using the --benchmark flag" - env = gym.make("llvm-ic-v0", benchmark=benchmark) - env.observation_space = "Autophase" - try: + with gym.make("llvm-ic-v0", benchmark=benchmark) as env: + env.observation_space = "Autophase" + # Train a Q-table. with Timer("Constructing Q-table"): train(q_table, env) @@ -200,9 +200,6 @@ def main(argv): # Rollout resulting policy. rollout(q_table, env, printout=True) - finally: - env.close() - if __name__ == "__main__": app.run(main) diff --git a/tests/bin/service_bin_test.py b/tests/bin/service_bin_test.py index ffffb1c89..043b51031 100644 --- a/tests/bin/service_bin_test.py +++ b/tests/bin/service_bin_test.py @@ -15,9 +15,8 @@ @pytest.mark.parametrize("env_name", compiler_gym.COMPILER_GYM_ENVS) def test_print_service_capabilities_smoke_test(env_name: str): flags.FLAGS(["argv0"]) - env = gym.make(env_name) - print_service_capabilities(env) - env.close() + with gym.make(env_name) as env: + print_service_capabilities(env) if __name__ == "__main__": diff --git a/tests/compiler_env_test.py b/tests/compiler_env_test.py index d663c9971..babc28fea 100644 --- a/tests/compiler_env_test.py +++ b/tests/compiler_env_test.py @@ -18,11 +18,8 @@ def test_benchmark_constructor_arg(env: CompilerEnv): env.close() # Fixture only required to pull in dataset. - env = gym.make("llvm-v0", benchmark="cbench-v1/dijkstra") - try: + with gym.make("llvm-v0", benchmark="cbench-v1/dijkstra") as env: assert env.benchmark == "benchmark://cbench-v1/dijkstra" - finally: - env.close() def test_benchmark_setter(env: CompilerEnv): @@ -39,14 +36,10 @@ def test_benchmark_set_in_reset(env: CompilerEnv): def test_logger_forced(): logger = logging.getLogger("test_logger") - env_a = gym.make("llvm-v0") - env_b = gym.make("llvm-v0", logger=logger) - try: - assert env_a.logger != logger - assert env_b.logger == logger - finally: - env_a.close() - env_b.close() + with gym.make("llvm-v0") as env_a: + with gym.make("llvm-v0", logger=logger) as env_b: + assert env_a.logger != logger + assert env_b.logger == logger def test_uri_substring_no_match(env: CompilerEnv): @@ -120,14 +113,11 @@ def test_gym_make_kwargs(): """Test that passing kwargs to gym.make() are forwarded to environment constructor. """ - env = gym.make( + with gym.make( "llvm-v0", observation_space="Autophase", reward_space="IrInstructionCount" - ) - try: + ) as env: assert env.observation_space_spec.id == "Autophase" assert env.reward_space.id == "IrInstructionCount" - finally: - env.close() def test_step_session_id_not_found(env: CompilerEnv): @@ -146,11 +136,10 @@ def test_step_session_id_not_found(env: CompilerEnv): def remote_env() -> CompilerEnv: """A test fixture that yields a connection to a remote service.""" service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) - env = CompilerEnv(service=service.connection.url) try: - yield env + with CompilerEnv(service=service.connection.url) as env: + yield env finally: - env.close() service.close() diff --git a/tests/fuzzing/llvm_random_actions_fuzz_test.py b/tests/fuzzing/llvm_random_actions_fuzz_test.py index aabff3455..814527ccc 100644 --- a/tests/fuzzing/llvm_random_actions_fuzz_test.py +++ b/tests/fuzzing/llvm_random_actions_fuzz_test.py @@ -22,15 +22,14 @@ @pytest.mark.timeout(600) def test_fuzz(benchmark_name: str): """Run randomly selected actions on a benchmark until a minimum amount of time has elapsed.""" - env = gym.make( + with gym.make( "llvm-v0", reward_space="IrInstructionCount", observation_space="Autophase", benchmark=benchmark_name, - ) - env.reset() + ) as env: + env.reset() - try: # Take a random step until a predetermined amount of time has elapsed. end_time = time() + FUZZ_TIME_SECONDS while time() < end_time: @@ -52,8 +51,6 @@ def test_fuzz(benchmark_name: str): assert isinstance(observation, np.ndarray) assert observation.shape == (AUTOPHASE_FEATURE_DIM,) assert isinstance(reward, float) - finally: - env.close() if __name__ == "__main__": diff --git a/tests/llvm/all_benchmarks_init_close_test.py b/tests/llvm/all_benchmarks_init_close_test.py index 6c21f4ca6..aecbd4c49 100644 --- a/tests/llvm/all_benchmarks_init_close_test.py +++ b/tests/llvm/all_benchmarks_init_close_test.py @@ -14,7 +14,6 @@ def test_init_benchmark(env: CompilerEnv, benchmark_name: str): """Create an environment for each benchmark and close it.""" env.reset(benchmark=benchmark_name) assert env.benchmark == benchmark_name - env.close() if __name__ == "__main__": diff --git a/tests/llvm/custom_benchmarks_test.py b/tests/llvm/custom_benchmarks_test.py index 48b79e4a8..a0938bc0b 100644 --- a/tests/llvm/custom_benchmarks_test.py +++ b/tests/llvm/custom_benchmarks_test.py @@ -111,12 +111,9 @@ def test_custom_benchmark(env: LlvmEnv): def test_custom_benchmark_constructor(): benchmark = Benchmark.from_file("benchmark://new", EXAMPLE_BITCODE_FILE) - env = gym.make("llvm-v0", benchmark=benchmark) - try: + with gym.make("llvm-v0", benchmark=benchmark) as env: env.reset() assert env.benchmark == "benchmark://new" - finally: - env.close() def test_make_benchmark_single_bitcode(env: LlvmEnv): diff --git a/tests/llvm/datasets/anghabench_test.py b/tests/llvm/datasets/anghabench_test.py index bb4149dca..1b5a8c19d 100644 --- a/tests/llvm/datasets/anghabench_test.py +++ b/tests/llvm/datasets/anghabench_test.py @@ -21,11 +21,8 @@ @pytest.fixture(scope="module") def anghabench_dataset() -> AnghaBenchDataset: - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: ds = env.datasets["anghabench-v1"] - finally: - env.close() yield ds diff --git a/tests/llvm/datasets/chstone_test.py b/tests/llvm/datasets/chstone_test.py index 738fabdd7..23342c6c4 100644 --- a/tests/llvm/datasets/chstone_test.py +++ b/tests/llvm/datasets/chstone_test.py @@ -16,11 +16,8 @@ @pytest.fixture(scope="module") def chstone_dataset() -> CHStoneDataset: - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: ds = env.datasets["chstone-v0"] - finally: - env.close() yield ds diff --git a/tests/llvm/datasets/clgen_test.py b/tests/llvm/datasets/clgen_test.py index d725a8b4a..f9b125cfd 100644 --- a/tests/llvm/datasets/clgen_test.py +++ b/tests/llvm/datasets/clgen_test.py @@ -20,11 +20,8 @@ @pytest.fixture(scope="module") def clgen_dataset() -> CLgenDataset: - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: ds = env.datasets["benchmark://clgen-v0"] - finally: - env.close() yield ds diff --git a/tests/llvm/datasets/csmith_test.py b/tests/llvm/datasets/csmith_test.py index 7625cf614..d0e232ed2 100644 --- a/tests/llvm/datasets/csmith_test.py +++ b/tests/llvm/datasets/csmith_test.py @@ -21,11 +21,8 @@ @pytest.fixture(scope="module") def csmith_dataset() -> CsmithDataset: - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: ds = env.datasets["generator://csmith-v0"] - finally: - env.close() yield ds diff --git a/tests/llvm/datasets/github_test.py b/tests/llvm/datasets/github_test.py index ba7af2e60..1f556a62d 100644 --- a/tests/llvm/datasets/github_test.py +++ b/tests/llvm/datasets/github_test.py @@ -20,11 +20,8 @@ @pytest.fixture(scope="module") def github_dataset() -> GitHubDataset: - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: ds = env.datasets["github-v0"] - finally: - env.close() yield ds diff --git a/tests/llvm/datasets/llvm_datasets_test.py b/tests/llvm/datasets/llvm_datasets_test.py index 69b4305ec..697d26de5 100644 --- a/tests/llvm/datasets/llvm_datasets_test.py +++ b/tests/llvm/datasets/llvm_datasets_test.py @@ -10,8 +10,7 @@ def test_default_dataset_list(): - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: assert list(d.name for d in env.datasets) == [ "benchmark://cbench-v1", "benchmark://anghabench-v1", @@ -28,8 +27,6 @@ def test_default_dataset_list(): "generator://csmith-v0", "generator://llvm-stress-v0", ] - finally: - env.close() if __name__ == "__main__": diff --git a/tests/llvm/datasets/llvm_stress_test.py b/tests/llvm/datasets/llvm_stress_test.py index 303616be4..8afeab8c8 100644 --- a/tests/llvm/datasets/llvm_stress_test.py +++ b/tests/llvm/datasets/llvm_stress_test.py @@ -22,11 +22,8 @@ @pytest.fixture(scope="module") def llvm_stress_dataset() -> LlvmStressDataset: - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: ds = env.datasets["generator://llvm-stress-v0"] - finally: - env.close() yield ds diff --git a/tests/llvm/datasets/poj104_test.py b/tests/llvm/datasets/poj104_test.py index 084d89ad4..ce3ee3ed8 100644 --- a/tests/llvm/datasets/poj104_test.py +++ b/tests/llvm/datasets/poj104_test.py @@ -21,11 +21,8 @@ @pytest.fixture(scope="module") def poj104_dataset() -> POJ104Dataset: - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: ds = env.datasets["poj104-v1"] - finally: - env.close() yield ds diff --git a/tests/llvm/gym_interface_compatability.py b/tests/llvm/gym_interface_compatability.py index 95a02058f..f2c658acc 100644 --- a/tests/llvm/gym_interface_compatability.py +++ b/tests/llvm/gym_interface_compatability.py @@ -13,11 +13,7 @@ @pytest.fixture(scope="function") def env() -> CompilerEnv: - env = gym.make("llvm-autophase-ic-v0") - try: - yield env - finally: - env.close() + return gym.make("llvm-autophase-ic-v0") def test_type_classes(env: CompilerEnv): diff --git a/tests/llvm/llvm_env_test.py b/tests/llvm/llvm_env_test.py index 22a6dcaca..b98a1b016 100644 --- a/tests/llvm/llvm_env_test.py +++ b/tests/llvm/llvm_env_test.py @@ -31,18 +31,14 @@ def env(request) -> CompilerEnv: """Create an LLVM environment.""" if request.param == "local": - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: yield env - finally: - env.close() else: service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) - env = LlvmEnv(service=service.connection.url) try: - yield env + with LlvmEnv(service=service.connection.url) as env: + yield env finally: - env.close() service.close() @@ -165,13 +161,9 @@ def test_apply_state(env: LlvmEnv): env.reset(benchmark="cbench-v1/crc32") env.step(env.action_space.flags.index("-mem2reg")) - other = gym.make("llvm-v0", reward_space="IrInstructionCount") - try: + with gym.make("llvm-v0", reward_space="IrInstructionCount") as other: other.apply(env.state) - assert other.state == env.state - finally: - other.close() def test_set_observation_space_from_spec(env: LlvmEnv): diff --git a/tests/llvm/multiprocessing_test.py b/tests/llvm/multiprocessing_test.py index fd8b3027a..e241dcd82 100644 --- a/tests/llvm/multiprocessing_test.py +++ b/tests/llvm/multiprocessing_test.py @@ -17,15 +17,14 @@ def process_worker(env_name: str, benchmark: str, actions: List[int], queue: mp.Queue): assert actions - env = gym.make(env_name) - env.reset(benchmark=benchmark) + with gym.make(env_name) as env: + env.reset(benchmark=benchmark) - for action in actions: - observation, reward, done, info = env.step(action) - assert not done + for action in actions: + observation, reward, done, info = env.step(action) + assert not done - queue.put((observation, reward, done, info)) - env.close() + queue.put((observation, reward, done, info)) def process_worker_with_env(env: LlvmEnv, actions: List[int], queue: mp.Queue): @@ -70,28 +69,27 @@ def test_moving_environment_to_background_process_macos(): """Test moving an LLVM environment to a background process.""" queue = mp.Queue(maxsize=3) - env = gym.make("llvm-autophase-ic-v0") - env.reset(benchmark="cbench-v1/crc32") + with gym.make("llvm-autophase-ic-v0") as env: + env.reset(benchmark="cbench-v1/crc32") - process = mp.Process(target=process_worker_with_env, args=(env, [0, 0, 0], queue)) + process = mp.Process( + target=process_worker_with_env, args=(env, [0, 0, 0], queue) + ) - # Moving an environment to a background process is not supported because - # we are using a subprocess.Popen() to manage the service binary, which - # doesn't support pickling. - with pytest.raises(TypeError): - process.start() + # Moving an environment to a background process is not supported because + # we are using a subprocess.Popen() to manage the service binary, which + # doesn't support pickling. + with pytest.raises(TypeError): + process.start() def test_port_collision_test(): """Test that attempting to connect to a port that is already in use succeeds.""" - env_a = gym.make("llvm-autophase-ic-v0") - env_a.reset(benchmark="cbench-v1/crc32") - - env_b = LlvmEnv(service=env_a.service.connection.url) - env_b.reset(benchmark="cbench-v1/crc32") + with gym.make("llvm-autophase-ic-v0") as env_a: + env_a.reset(benchmark="cbench-v1/crc32") - env_b.close() - env_a.close() + with LlvmEnv(service=env_a.service.connection.url) as env_b: + env_b.reset(benchmark="cbench-v1/crc32") if __name__ == "__main__": diff --git a/tests/llvm/service_connection_test.py b/tests/llvm/service_connection_test.py index 3cc0eeab8..54b7d2d8d 100644 --- a/tests/llvm/service_connection_test.py +++ b/tests/llvm/service_connection_test.py @@ -23,18 +23,14 @@ def env(request) -> CompilerEnv: # Redefine fixture to test both gym.make(...) and unmanaged service # connections. if request.param == "local": - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: yield env - finally: - env.close() else: service = CompilerGymServiceConnection(llvm.LLVM_SERVICE_BINARY) - env = LlvmEnv(service=service.connection.url) try: - yield env + with LlvmEnv(service=service.connection.url) as env: + yield env finally: - env.close() service.close() diff --git a/tests/llvm/threading_test.py b/tests/llvm/threading_test.py index 20cc15d76..b40fe559a 100644 --- a/tests/llvm/threading_test.py +++ b/tests/llvm/threading_test.py @@ -24,15 +24,14 @@ def __init__(self, env_name: str, benchmark: str, actions: List[int]): assert actions def run(self) -> None: - env = gym.make(self.env_name, benchmark=self.benchmark) - env.reset() + with gym.make(self.env_name, benchmark=self.benchmark) as env: + env.reset() - for action in self.actions: - self.observation, self.reward, done, self.info = env.step(action) - assert not done, self.info["error_details"] + for action in self.actions: + self.observation, self.reward, done, self.info = env.step(action) + assert not done, self.info["error_details"] - self.done = True - env.close() + self.done = True class ThreadedWorkerWithEnv(Thread): @@ -73,20 +72,19 @@ def test_moving_environment_to_background_thread(): """Test running an LLVM environment from a background thread. The environment is made in the main thread and used in the background thread. """ - env = gym.make("llvm-autophase-ic-v0") - env.reset(benchmark="cbench-v1/crc32") + with gym.make("llvm-autophase-ic-v0") as env: + env.reset(benchmark="cbench-v1/crc32") - thread = ThreadedWorkerWithEnv(env=env, actions=[0, 0, 0]) - thread.start() - thread.join(timeout=10) + thread = ThreadedWorkerWithEnv(env=env, actions=[0, 0, 0]) + thread.start() + thread.join(timeout=10) - assert thread.done - assert thread.observation is not None - assert isinstance(thread.reward, float) - assert thread.info + assert thread.done + assert thread.observation is not None + assert isinstance(thread.reward, float) + assert thread.info - assert env.in_episode - env.close() + assert env.in_episode if __name__ == "__main__": diff --git a/tests/llvm/validate_test.py b/tests/llvm/validate_test.py index 1099ba8db..dc65b34e5 100644 --- a/tests/llvm/validate_test.py +++ b/tests/llvm/validate_test.py @@ -22,11 +22,8 @@ def test_validate_state_no_reward(): walltime=1, commandline="opt input.bc -o output.bc", ) - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: result = env.validate(state) - finally: - env.close() assert result.okay() assert not result.reward_validated @@ -40,11 +37,8 @@ def test_validate_state_with_reward(): reward=0, commandline="opt input.bc -o output.bc", ) - env = gym.make("llvm-v0", reward_space="IrInstructionCount") - try: + with gym.make("llvm-v0", reward_space="IrInstructionCount") as env: result = env.validate(state) - finally: - env.close() assert result.okay() assert result.reward_validated @@ -59,11 +53,8 @@ def test_validate_state_invalid_reward(): reward=1, commandline="opt input.bc -o output.bc", ) - env = gym.make("llvm-v0", reward_space="IrInstructionCount") - try: + with gym.make("llvm-v0", reward_space="IrInstructionCount") as env: result = env.validate(state) - finally: - env.close() assert not result.okay() assert result.reward_validated @@ -80,11 +71,8 @@ def test_validate_state_without_state_reward(): walltime=1, commandline="opt input.bc -o output.bc", ) - env = gym.make("llvm-v0", reward_space="IrInstructionCount") - try: + with gym.make("llvm-v0", reward_space="IrInstructionCount") as env: result = env.validate(state) - finally: - env.close() assert result.okay() assert not result.reward_validated @@ -99,8 +87,7 @@ def test_validate_state_without_env_reward(): reward=0, commandline="opt input.bc -o output.bc", ) - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: with pytest.warns( UserWarning, match=( @@ -109,8 +96,6 @@ def test_validate_state_without_env_reward(): ), ): result = env.validate(state) - finally: - env.close() assert result.okay() assert not result.reward_validated diff --git a/tests/pytest_plugins/llvm.py b/tests/pytest_plugins/llvm.py index a0768c67e..aeeb2de99 100644 --- a/tests/pytest_plugins/llvm.py +++ b/tests/pytest_plugins/llvm.py @@ -98,11 +98,8 @@ def non_validatable_cbench_uri(request) -> str: @pytest.fixture(scope="function") def env() -> LlvmEnv: """Create an LLVM environment.""" - env = gym.make("llvm-v0") - try: - yield env - finally: - env.close() + with gym.make("llvm-v0") as env_: + yield env_ @pytest.fixture(scope="module") diff --git a/tests/random_search_test.py b/tests/random_search_test.py index d29a78483..d730eb510 100644 --- a/tests/random_search_test.py +++ b/tests/random_search_test.py @@ -38,13 +38,10 @@ def test_random_search_smoke_test(): assert (outdir / "random_search_best_actions.txt").is_file() assert (outdir / "optimized.bc").is_file() - env = make_env() - try: + with make_env() as env: replay_actions_from_logs(env, Path(outdir)) assert (outdir / "random_search_best_actions_progress.csv").is_file() assert (outdir / "random_search_best_actions_commandline.txt").is_file() - finally: - env.close() if __name__ == "__main__": diff --git a/tests/service/connection_test.py b/tests/service/connection_test.py index 1203f3be4..3cd71fd08 100644 --- a/tests/service/connection_test.py +++ b/tests/service/connection_test.py @@ -19,25 +19,19 @@ @pytest.fixture(scope="function") def connection() -> CompilerGymServiceConnection: """Yields a connection to a local service.""" - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: yield env.service - finally: - env.close() @pytest.fixture(scope="function") def dead_connection() -> CompilerGymServiceConnection: """Yields a connection to a dead local service service.""" - env = gym.make("llvm-v0") - try: + with gym.make("llvm-v0") as env: # Kill the service. env.service.connection.process.terminate() env.service.connection.process.communicate() yield env.service - finally: - env.close() def test_create_invalid_options(): diff --git a/tests/wrappers/core_wrappers_test.py b/tests/wrappers/core_wrappers_test.py index b6b814824..6e47d9fa6 100644 --- a/tests/wrappers/core_wrappers_test.py +++ b/tests/wrappers/core_wrappers_test.py @@ -35,10 +35,10 @@ def test_wrapped_close(env: LlvmEnv, wrapper_type): def test_wrapped_properties(env: LlvmEnv, wrapper_type): """Test accessing the non-standard properties.""" - env = wrapper_type(env) - assert env.actions == [] - assert env.benchmark - assert isinstance(env.datasets, Datasets) + with wrapper_type(env) as env: + assert env.actions == [] + assert env.benchmark + assert isinstance(env.datasets, Datasets) def test_wrapped_fork_type(env: LlvmEnv, wrapper_type):