Use with statement in place of try/finally for envs.
This patch refactors the code pattern `try: ...; finally: env.close()`
to use the `with gym.make(...):` pattern instead. This is preferred
because the context manager calls `close()` automatically, even when
the block exits via an exception.
ChrisCummins committed Sep 9, 2021
1 parent 58b5831 commit 12b4414
Showing 43 changed files with 169 additions and 300 deletions.
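
The refactoring in miniature (a sketch, using the `llvm-ic-v0` environment that appears in the diffs below; CompilerGym environments support the context manager protocol, which is what makes the `with` form possible):

```python
import gym
import compiler_gym  # noqa: registers the CompilerGym environments

# Before: cleanup is manual and easy to miss on error paths.
env = gym.make("llvm-ic-v0")
try:
    env.reset()
finally:
    env.close()

# After: close() runs automatically when the block exits,
# whether normally or via an exception.
with gym.make("llvm-ic-v0") as env:
    env.reset()
```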
5 changes: 1 addition & 4 deletions benchmarks/parallelization_load_test.py
@@ -42,16 +42,13 @@

 def run_random_search(num_episodes, num_steps) -> None:
     """The inner loop of a load test benchmark."""
-    env = env_from_flags(benchmark=benchmark_from_flags())
-    try:
+    with env_from_flags(benchmark=benchmark_from_flags()) as env:
         for _ in range(num_episodes):
             env.reset()
             for _ in range(num_steps):
                 _, _, done, _ = env.step(env.action_space.sample())
                 if done:
                     break
-    finally:
-        env.close()


 def main(argv):
5 changes: 1 addition & 4 deletions compiler_gym/bin/datasets.py
@@ -148,8 +148,7 @@ def main(argv):
     if len(argv) != 1:
         raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

-    env = env_from_flags()
-    try:
+    with env_from_flags() as env:
         invalidated_manifest = False

         for name_or_url in FLAGS.download:
@@ -182,8 +181,6 @@ def main(argv):
         print(
             summarize_datasets(env.datasets),
         )
-    finally:
-        env.close()


 if __name__ == "__main__":
6 changes: 3 additions & 3 deletions compiler_gym/bin/random_replay.py
@@ -36,9 +36,9 @@ def main(argv):
         output_dir / logs.METADATA_NAME
     ).is_file(), f"Invalid --output_dir: {output_dir}"

-    env = env_from_flags()
-    benchmark = benchmark_from_flags()
-    replay_actions_from_logs(env, output_dir, benchmark=benchmark)
+    with env_from_flags() as env:
+        benchmark = benchmark_from_flags()
+        replay_actions_from_logs(env, output_dir, benchmark=benchmark)


 if __name__ == "__main__":
10 changes: 3 additions & 7 deletions compiler_gym/bin/random_search.py
@@ -93,21 +93,17 @@ def main(argv):
         raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

     if FLAGS.ls_reward:
-        env = env_from_flags()
-        print("\n".join(sorted(env.reward.indices.keys())))
-        env.close()
+        with env_from_flags() as env:
+            print("\n".join(sorted(env.reward.indices.keys())))
         return

     assert FLAGS.patience >= 0, "--patience must be >= 0"

     def make_env():
         return env_from_flags(benchmark=benchmark_from_flags())

-    env = make_env()
-    try:
+    with make_env() as env:
         env.reset()
-    finally:
-        env.close()

     best_reward, _ = random_search(
         make_env=make_env,
5 changes: 1 addition & 4 deletions compiler_gym/bin/service.py
@@ -229,11 +229,8 @@ def main(argv):
     """Main entry point."""
     assert len(argv) == 1, f"Unrecognized flags: {argv[1:]}"

-    env = env_from_flags()
-    try:
+    with env_from_flags() as env:
         print_service_capabilities(env)
-    finally:
-        env.close()


 if __name__ == "__main__":
5 changes: 1 addition & 4 deletions compiler_gym/bin/validate.py
@@ -153,8 +153,7 @@ def main(argv):
     )

     # Determine the name of the reward space.
-    env = env_from_flags()
-    try:
+    with env_from_flags() as env:
         if FLAGS.reward_aggregation == "geomean":

             def reward_aggregation(a):
@@ -173,8 +172,6 @@ def reward_aggregation(a):
             reward_name = f"{reward_aggregation_name} {env.reward_space.id}"
         else:
             reward_name = ""
-    finally:
-        env.close()

     # Determine the maximum column width required for printing tabular output.
     max_state_name_length = max(
23 changes: 10 additions & 13 deletions compiler_gym/leaderboard/llvm_instcount.py
@@ -225,20 +225,19 @@ def main(argv):
     assert len(argv) == 1, f"Unknown args: {argv[:1]}"
     assert FLAGS.n > 0, "n must be > 0"

-    env = gym.make("llvm-ic-v0")
+    with gym.make("llvm-ic-v0") as env:

-    # Stream verbose CompilerGym logs to file.
-    logger = logging.getLogger("compiler_gym")
-    logger.setLevel(logging.DEBUG)
-    env.logger.setLevel(logging.DEBUG)
-    log_handler = logging.FileHandler(FLAGS.leaderboard_logfile)
-    logger.addHandler(log_handler)
-    logger.propagate = False
+        # Stream verbose CompilerGym logs to file.
+        logger = logging.getLogger("compiler_gym")
+        logger.setLevel(logging.DEBUG)
+        env.logger.setLevel(logging.DEBUG)
+        log_handler = logging.FileHandler(FLAGS.leaderboard_logfile)
+        logger.addHandler(log_handler)
+        logger.propagate = False

-    print(f"Writing results to {FLAGS.leaderboard_results}")
-    print(f"Writing logs to {FLAGS.leaderboard_logfile}")
+        print(f"Writing results to {FLAGS.leaderboard_results}")
+        print(f"Writing logs to {FLAGS.leaderboard_logfile}")

-    try:
         # Build the list of benchmarks to evaluate.
         benchmarks = env.datasets[FLAGS.test_dataset].benchmark_uris()
         if FLAGS.max_benchmarks:
@@ -301,8 +300,6 @@ def main(argv):
                 worker.alive = False
                 # User interrupt, don't validate.
                 FLAGS.validate = False
-    finally:
-        env.close()

     if FLAGS.validate:
         FLAGS.env = "llvm-ic-v0"
1 change: 0 additions & 1 deletion compiler_gym/random_replay.py
@@ -79,4 +79,3 @@ def replay_actions_from_logs(env: CompilerEnv, logdir: Path, benchmark=None) ->
     env.reward_space = meta["reward"]
     env.reset(benchmark=benchmark)
     replay_actions(env, actions, logdir)
-    env.close()
23 changes: 8 additions & 15 deletions compiler_gym/random_search.py
@@ -54,22 +54,16 @@ def run(self) -> None:
         """Run episodes in an infinite loop."""
         while self.should_run_one_episode:
             self.total_environment_count += 1
-            env = self._make_env()
-            try:
+            with self._make_env() as env:
                 self._patience = self._patience or env.action_space.n
                 self.run_one_environment(env)
-            finally:
-                env.close()

     def run_one_environment(self, env: CompilerEnv) -> None:
         """Run random walks in an infinite loop. Returns if the environment ends."""
-        try:
-            while self.should_run_one_episode:
-                self.total_episode_count += 1
-                if not self.run_one_episode(env):
-                    return
-        finally:
-            env.close()
+        while self.should_run_one_episode:
+            self.total_episode_count += 1
+            if not self.run_one_episode(env):
+                return

     def run_one_episode(self, env: CompilerEnv) -> bool:
         """Run a single random episode.
@@ -253,9 +247,8 @@ def random_search(
     print("done")

     print("Replaying actions from best solution found:")
-    env = make_env()
-    env.reset()
-    replay_actions(env, best_action_names, outdir)
-    env.close()
+    with make_env() as env:
+        env.reset()
+        replay_actions(env, best_action_names, outdir)

     return best_returns, best_actions
5 changes: 1 addition & 4 deletions compiler_gym/util/flags/env_from_flags.py
@@ -141,8 +141,5 @@ def env_from_flags(benchmark: Optional[Union[str, Benchmark]] = None) -> CompilerEnv:
 def env_session_from_flags(
     benchmark: Optional[Union[str, Benchmark]] = None
 ) -> CompilerEnv:
-    env = env_from_flags(benchmark=benchmark)
-    try:
+    with env_from_flags(benchmark=benchmark) as env:
         yield env
-    finally:
-        env.close()
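
Note that `env_session_from_flags` remains usable as a `with` target after this change: the `yield` now happens inside the inner `with` block, so `close()` runs when the consuming block exits. A minimal sketch of the same pattern, with hypothetical `session`/`make_env` names and assuming the function is wrapped with `contextlib.contextmanager`, as its `yield` suggests:

```python
from contextlib import contextmanager

@contextmanager
def session(make_env):
    # The yield happens inside the inner `with`, so env.close() is
    # guaranteed to run when the caller's block exits, normally or
    # via an exception raised inside it.
    with make_env() as env:
        yield env
```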
5 changes: 1 addition & 4 deletions compiler_gym/validate.py
@@ -16,11 +16,8 @@
 def _validate_states_worker(
     make_env: Callable[[], CompilerEnv], state: CompilerEnvState
 ) -> ValidationResult:
-    env = make_env()
-    try:
+    with make_env() as env:
         result = env.validate(state)
-    finally:
-        env.close()
     return result


3 changes: 2 additions & 1 deletion docs/source/llvm/index.rst
@@ -108,7 +108,8 @@ Alternatively the module can be serialized to a bitcode file on disk:

 .. note::
     Files generated by the :code:`BitcodeFile` observation space are put in a
-    temporary directory that is removed when :meth:`env.close() <compiler_gym.envs.CompilerEnv.close>` is called.
+    temporary directory that is removed when :meth:`env.close()
+    <compiler_gym.envs.CompilerEnv.close>` is called.


 InstCount
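
Given that note, a small sketch of how the cleanup interacts with the `with` pattern this commit adopts (assuming the standard `llvm-v0` environment and observation API; the returned path is only valid while the environment is open):

```python
import gym
import compiler_gym  # noqa: registers the CompilerGym environments

with gym.make("llvm-v0") as env:
    env.reset(benchmark="cbench-v1/crc32")
    bitcode_path = env.observation["BitcodeFile"]
    print(bitcode_path)  # the file exists here
# close() has now run, so the temporary directory holding the
# bitcode file may already have been deleted.
```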
2 changes: 2 additions & 0 deletions examples/actor_critic.py
@@ -352,6 +352,8 @@ def make_env():

 def main(argv):
     """Main entry point."""
+    del argv  # unused
+
     torch.manual_seed(FLAGS.seed)
     random.seed(FLAGS.seed)

77 changes: 37 additions & 40 deletions examples/brute_force.py
@@ -172,33 +172,32 @@ def run_brute_force(
     meta_path = outdir / "meta.json"
     results_path = outdir / "results.csv"

-    env: CompilerEnv = make_env()
-    env.reset()
-
-    action_names = action_names or env.action_space.names
-
-    if not env.reward_space:
-        raise ValueError("A reward space must be specified for random search")
-    reward_space_name = env.reward_space.id
-
-    actions = [env.action_space.names.index(a) for a in action_names]
-    benchmark_uri = env.benchmark.uri
-
-    meta = {
-        "env": env.spec.id,
-        "action_names": action_names,
-        "benchmark": benchmark_uri,
-        "reward": reward_space_name,
-        "init_reward": env.reward[reward_space_name],
-        "episode_length": episode_length,
-        "nproc": nproc,
-        "chunksize": chunksize,
-    }
-    with open(str(meta_path), "w") as f:
-        json.dump(meta, f)
-    print(f"Wrote {meta_path}")
-    print(f"Writing results to {results_path}")
-    env.close()
+    with make_env() as env:
+        env.reset()
+
+        action_names = action_names or env.action_space.names
+
+        if not env.reward_space:
+            raise ValueError("A reward space must be specified for random search")
+        reward_space_name = env.reward_space.id
+
+        actions = [env.action_space.names.index(a) for a in action_names]
+        benchmark_uri = env.benchmark.uri
+
+        meta = {
+            "env": env.spec.id,
+            "action_names": action_names,
+            "benchmark": benchmark_uri,
+            "reward": reward_space_name,
+            "init_reward": env.reward[reward_space_name],
+            "episode_length": episode_length,
+            "nproc": nproc,
+            "chunksize": chunksize,
+        }
+        with open(str(meta_path), "w") as f:
+            json.dump(meta, f)
+        print(f"Wrote {meta_path}")
+        print(f"Writing results to {results_path}")

     # A queue for communicating action sequences to workers, and a queue for
     # workers to report <action_sequence, reward_sequence> results.
@@ -287,14 +286,13 @@ def run_brute_force(
         worker.join()

     num_trials = sum(worker.num_trials for worker in workers)
-    env: CompilerEnv = make_env()
-    print(
-        f"completed {humanize.intcomma(num_trials)} of "
-        f"{humanize.intcomma(expected_trial_count)} trials "
-        f"({num_trials / expected_trial_count:.3%}), best sequence",
-        " ".join([env.action_space.flags[i] for i in best_action_sequence]),
-    )
-    env.close()
+    with make_env() as env:
+        print(
+            f"completed {humanize.intcomma(num_trials)} of "
+            f"{humanize.intcomma(expected_trial_count)} trials "
+            f"({num_trials / expected_trial_count:.3%}), best sequence",
+            " ".join([env.action_space.flags[i] for i in best_action_sequence]),
+        )


 def main(argv):
@@ -309,11 +307,10 @@ def main(argv):
     if not benchmark:
         raise app.UsageError("No benchmark specified.")

-    env = env_from_flags(benchmark)
-    env.reset()
-    benchmark = env.benchmark
-    sanitized_benchmark_uri = "/".join(benchmark.split("/")[-2:])
-    env.close()
+    with env_from_flags(benchmark) as env:
+        env.reset()
+        benchmark = env.benchmark
+        sanitized_benchmark_uri = "/".join(benchmark.split("/")[-2:])
     logs_dir = Path(
         FLAGS.output_dir or create_logging_dir(f"brute_force/{sanitized_benchmark_uri}")
     )
2 changes: 1 addition & 1 deletion examples/explore.py
@@ -447,8 +447,8 @@ def main(argv):

     print(f"Running with {FLAGS.nproc} threads.")
     assert FLAGS.nproc >= 1
-    envs = []
     try:
+        envs = []
         for _ in range(FLAGS.nproc):
             envs.append(make_env())
         compute_action_graph(envs, episode_length=FLAGS.episode_length)
10 changes: 4 additions & 6 deletions examples/random_walk.py
@@ -62,7 +62,6 @@ def run_random_walk(env: CompilerEnv, step_count: int) -> None:
         if done:
             print("Episode ended by environment")
             break
-    env.close()

 def reward_percentage(reward, rewards):
     if sum(rewards) == 0:
@@ -92,11 +91,10 @@ def main(argv):
     assert len(argv) == 1, f"Unrecognized flags: {argv[1:]}"

     benchmark = benchmark_from_flags()
-    env = env_from_flags(benchmark)
-
-    step_min = min(FLAGS.step_min, FLAGS.step_max)
-    step_max = max(FLAGS.step_min, FLAGS.step_max)
-    run_random_walk(env=env, step_count=random.randint(step_min, step_max))
+    with env_from_flags(benchmark) as env:
+        step_min = min(FLAGS.step_min, FLAGS.step_max)
+        step_max = max(FLAGS.step_min, FLAGS.step_max)
+        run_random_walk(env=env, step_count=random.randint(step_min, step_max))


 if __name__ == "__main__":
7 changes: 2 additions & 5 deletions examples/random_walk_test.py
@@ -12,12 +12,9 @@

 def test_run_random_walk_smoke_test():
     flags.FLAGS(["argv0"])
-    env = gym.make("llvm-autophase-ic-v0")
-    env.benchmark = "cbench-v1/crc32"
-    try:
+    with gym.make("llvm-autophase-ic-v0") as env:
+        env.benchmark = "cbench-v1/crc32"
         run_random_walk(env=env, step_count=5)
-    finally:
-        env.close()


 if __name__ == "__main__":
