Use with statement in place of try/finally for envs. #384

Merged
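
The change applied throughout this PR is to replace manual try/finally cleanup of a CompilerGym environment with a with statement. A minimal before/after sketch of the pattern, assuming the environment implements the context-manager protocol so that leaving the block calls close(); make_my_env below is a hypothetical stand-in for gym.make() or env_from_flags():

# Before: cleanup must be spelled out and is easy to miss on new code paths.
env = make_my_env()
try:
    env.reset()
    env.step(env.action_space.sample())
finally:
    env.close()

# After: close() is guaranteed when the block exits, normally or via an exception.
with make_my_env() as env:
    env.reset()
    env.step(env.action_space.sample())
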
5 changes: 1 addition & 4 deletions benchmarks/parallelization_load_test.py
@@ -42,16 +42,13 @@

def run_random_search(num_episodes, num_steps) -> None:
    """The inner loop of a load test benchmark."""
-    env = env_from_flags(benchmark=benchmark_from_flags())
-    try:
+    with env_from_flags(benchmark=benchmark_from_flags()) as env:
        for _ in range(num_episodes):
            env.reset()
            for _ in range(num_steps):
                _, _, done, _ = env.step(env.action_space.sample())
                if done:
                    break
-    finally:
-        env.close()


def main(argv):
5 changes: 1 addition & 4 deletions compiler_gym/bin/datasets.py
@@ -148,8 +148,7 @@ def main(argv):
    if len(argv) != 1:
        raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

-    env = env_from_flags()
-    try:
+    with env_from_flags() as env:
        invalidated_manifest = False

        for name_or_url in FLAGS.download:
@@ -182,8 +181,6 @@
        print(
            summarize_datasets(env.datasets),
        )
-    finally:
-        env.close()


if __name__ == "__main__":
6 changes: 3 additions & 3 deletions compiler_gym/bin/random_replay.py
@@ -36,9 +36,9 @@ def main(argv):
        output_dir / logs.METADATA_NAME
    ).is_file(), f"Invalid --output_dir: {output_dir}"

-    env = env_from_flags()
-    benchmark = benchmark_from_flags()
-    replay_actions_from_logs(env, output_dir, benchmark=benchmark)
+    with env_from_flags() as env:
+        benchmark = benchmark_from_flags()
+        replay_actions_from_logs(env, output_dir, benchmark=benchmark)


if __name__ == "__main__":
10 changes: 3 additions & 7 deletions compiler_gym/bin/random_search.py
@@ -93,21 +93,17 @@ def main(argv):
        raise app.UsageError(f"Unknown command line arguments: {argv[1:]}")

    if FLAGS.ls_reward:
-        env = env_from_flags()
-        print("\n".join(sorted(env.reward.indices.keys())))
-        env.close()
+        with env_from_flags() as env:
+            print("\n".join(sorted(env.reward.indices.keys())))
        return

    assert FLAGS.patience >= 0, "--patience must be >= 0"

    def make_env():
        return env_from_flags(benchmark=benchmark_from_flags())

-    env = make_env()
-    try:
+    with make_env() as env:
        env.reset()
-    finally:
-        env.close()

    best_reward, _ = random_search(
        make_env=make_env,
5 changes: 1 addition & 4 deletions compiler_gym/bin/service.py
@@ -229,11 +229,8 @@ def main(argv):
"""Main entry point."""
assert len(argv) == 1, f"Unrecognized flags: {argv[1:]}"

env = env_from_flags()
try:
with env_from_flags() as env:
print_service_capabilities(env)
finally:
env.close()


if __name__ == "__main__":
5 changes: 1 addition & 4 deletions compiler_gym/bin/validate.py
@@ -153,8 +153,7 @@ def main(argv):
        )

    # Determine the name of the reward space.
-    env = env_from_flags()
-    try:
+    with env_from_flags() as env:
        if FLAGS.reward_aggregation == "geomean":

            def reward_aggregation(a):
@@ -173,8 +172,6 @@ def reward_aggregation(a):
            reward_name = f"{reward_aggregation_name} {env.reward_space.id}"
        else:
            reward_name = ""
-    finally:
-        env.close()

# Determine the maximum column width required for printing tabular output.
max_state_name_length = max(
23 changes: 10 additions & 13 deletions compiler_gym/leaderboard/llvm_instcount.py
@@ -225,20 +225,19 @@ def main(argv):
        assert len(argv) == 1, f"Unknown args: {argv[:1]}"
        assert FLAGS.n > 0, "n must be > 0"

-        env = gym.make("llvm-ic-v0")
+        with gym.make("llvm-ic-v0") as env:

-        # Stream verbose CompilerGym logs to file.
-        logger = logging.getLogger("compiler_gym")
-        logger.setLevel(logging.DEBUG)
-        env.logger.setLevel(logging.DEBUG)
-        log_handler = logging.FileHandler(FLAGS.leaderboard_logfile)
-        logger.addHandler(log_handler)
-        logger.propagate = False
+            # Stream verbose CompilerGym logs to file.
+            logger = logging.getLogger("compiler_gym")
+            logger.setLevel(logging.DEBUG)
+            env.logger.setLevel(logging.DEBUG)
+            log_handler = logging.FileHandler(FLAGS.leaderboard_logfile)
+            logger.addHandler(log_handler)
+            logger.propagate = False

-        print(f"Writing results to {FLAGS.leaderboard_results}")
-        print(f"Writing logs to {FLAGS.leaderboard_logfile}")
+            print(f"Writing results to {FLAGS.leaderboard_results}")
+            print(f"Writing logs to {FLAGS.leaderboard_logfile}")

-        try:
            # Build the list of benchmarks to evaluate.
            benchmarks = env.datasets[FLAGS.test_dataset].benchmark_uris()
            if FLAGS.max_benchmarks:
@@ -301,8 +300,6 @@ def main(argv):
                worker.alive = False
                # User interrupt, don't validate.
                FLAGS.validate = False
-        finally:
-            env.close()

if FLAGS.validate:
FLAGS.env = "llvm-ic-v0"
1 change: 0 additions & 1 deletion compiler_gym/random_replay.py
@@ -79,4 +79,3 @@ def replay_actions_from_logs(env: CompilerEnv, logdir: Path, benchmark=None) ->
    env.reward_space = meta["reward"]
    env.reset(benchmark=benchmark)
    replay_actions(env, actions, logdir)
-    env.close()
23 changes: 8 additions & 15 deletions compiler_gym/random_search.py
@@ -54,22 +54,16 @@ def run(self) -> None:
"""Run episodes in an infinite loop."""
while self.should_run_one_episode:
self.total_environment_count += 1
env = self._make_env()
try:
with self._make_env() as env:
self._patience = self._patience or env.action_space.n
self.run_one_environment(env)
finally:
env.close()

def run_one_environment(self, env: CompilerEnv) -> None:
"""Run random walks in an infinite loop. Returns if the environment ends."""
try:
while self.should_run_one_episode:
self.total_episode_count += 1
if not self.run_one_episode(env):
return
finally:
env.close()
while self.should_run_one_episode:
self.total_episode_count += 1
if not self.run_one_episode(env):
return

def run_one_episode(self, env: CompilerEnv) -> bool:
"""Run a single random episode.
@@ -253,9 +247,8 @@ def random_search(
print("done")

print("Replaying actions from best solution found:")
env = make_env()
env.reset()
replay_actions(env, best_action_names, outdir)
env.close()
with make_env() as env:
env.reset()
replay_actions(env, best_action_names, outdir)

return best_returns, best_actions
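
The change to the worker's run() method above leans on the with statement closing the per-iteration environment even when run_one_environment() raises, which is exactly the guarantee the removed try/finally provided. A toy sketch of that guarantee, independent of CompilerGym:

from contextlib import contextmanager

@contextmanager
def noisy_resource():
    print("open")
    try:
        yield "resource"
    finally:
        print("close")  # runs on normal exit, early return, or exception

try:
    with noisy_resource():
        raise RuntimeError("episode failed")
except RuntimeError:
    pass
# Prints "open" then "close": cleanup ran despite the exception.
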
5 changes: 1 addition & 4 deletions compiler_gym/util/flags/env_from_flags.py
@@ -141,8 +141,5 @@ def env_from_flags(benchmark: Optional[Union[str, Benchmark]] = None) -> Compile
def env_session_from_flags(
    benchmark: Optional[Union[str, Benchmark]] = None
) -> CompilerEnv:
-    env = env_from_flags(benchmark=benchmark)
-    try:
+    with env_from_flags(benchmark=benchmark) as env:
        yield env
-    finally:
-        env.close()
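
env_session_from_flags above is a generator-based context manager: it yields the environment to the caller, and the surrounding file presumably applies contextlib.contextmanager (the decorator sits outside this hunk, so that is an assumption here). Nesting the with inside the generator keeps the close-on-exit guarantee for callers. A self-contained sketch of the same shape, using a stand-in class rather than CompilerEnv:

from contextlib import contextmanager

class FakeEnv:
    """Stand-in for an environment whose __exit__ calls close()."""
    def __enter__(self):
        return self
    def __exit__(self, *exc):
        self.close()
    def close(self):
        print("env closed")

@contextmanager
def env_session():
    # Equivalent to: env = FakeEnv(); try: yield env; finally: env.close()
    with FakeEnv() as env:
        yield env

with env_session() as env:
    pass  # "env closed" is printed as the outer with block exits
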
5 changes: 1 addition & 4 deletions compiler_gym/validate.py
@@ -16,11 +16,8 @@
def _validate_states_worker(
    make_env: Callable[[], CompilerEnv], state: CompilerEnvState
) -> ValidationResult:
-    env = make_env()
-    try:
+    with make_env() as env:
        result = env.validate(state)
-    finally:
-        env.close()
    return result


3 changes: 2 additions & 1 deletion docs/source/llvm/index.rst
@@ -108,7 +108,8 @@ Alternatively the module can be serialized to a bitcode file on disk:

.. note::
    Files generated by the :code:`BitcodeFile` observation space are put in a
-    temporary directory that is removed when :meth:`env.close() <compiler_gym.envs.CompilerEnv.close>` is called.
+    temporary directory that is removed when :meth:`env.close()
+    <compiler_gym.envs.CompilerEnv.close>` is called.


InstCount
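
Per the note above, files produced by the BitcodeFile observation space live in a temporary directory that is deleted on env.close(), which with this PR happens as soon as the with block exits. A hedged sketch of copying the bitcode out before that point; the "llvm-v0" environment id, the "BitcodeFile" observation name from the note, and the cbench-v1/crc32 benchmark are taken from elsewhere in the project and should be treated as assumptions:

import shutil

import compiler_gym  # noqa: F401  (importing registers the CompilerGym environments)
import gym

with gym.make("llvm-v0") as env:
    env.reset(benchmark="cbench-v1/crc32")
    bitcode_path = env.observation["BitcodeFile"]  # file inside the env's temp dir
    shutil.copyfile(bitcode_path, "./crc32.bc")    # copy it out before close() removes it
# The temporary directory is gone at this point; ./crc32.bc remains.
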
2 changes: 2 additions & 0 deletions examples/actor_critic.py
@@ -352,6 +352,8 @@ def make_env():

def main(argv):
    """Main entry point."""
+    del argv  # unused
+
    torch.manual_seed(FLAGS.seed)
    random.seed(FLAGS.seed)

77 changes: 37 additions & 40 deletions examples/brute_force.py
@@ -172,33 +172,32 @@ def run_brute_force(
    meta_path = outdir / "meta.json"
    results_path = outdir / "results.csv"

-    env: CompilerEnv = make_env()
-    env.reset()
-
-    action_names = action_names or env.action_space.names
-
-    if not env.reward_space:
-        raise ValueError("A reward space must be specified for random search")
-    reward_space_name = env.reward_space.id
-
-    actions = [env.action_space.names.index(a) for a in action_names]
-    benchmark_uri = env.benchmark.uri
-
-    meta = {
-        "env": env.spec.id,
-        "action_names": action_names,
-        "benchmark": benchmark_uri,
-        "reward": reward_space_name,
-        "init_reward": env.reward[reward_space_name],
-        "episode_length": episode_length,
-        "nproc": nproc,
-        "chunksize": chunksize,
-    }
-    with open(str(meta_path), "w") as f:
-        json.dump(meta, f)
-    print(f"Wrote {meta_path}")
-    print(f"Writing results to {results_path}")
-    env.close()
+    with make_env() as env:
+        env.reset()
+
+        action_names = action_names or env.action_space.names
+
+        if not env.reward_space:
+            raise ValueError("A reward space must be specified for random search")
+        reward_space_name = env.reward_space.id
+
+        actions = [env.action_space.names.index(a) for a in action_names]
+        benchmark_uri = env.benchmark.uri
+
+        meta = {
+            "env": env.spec.id,
+            "action_names": action_names,
+            "benchmark": benchmark_uri,
+            "reward": reward_space_name,
+            "init_reward": env.reward[reward_space_name],
+            "episode_length": episode_length,
+            "nproc": nproc,
+            "chunksize": chunksize,
+        }
+        with open(str(meta_path), "w") as f:
+            json.dump(meta, f)
+        print(f"Wrote {meta_path}")
+        print(f"Writing results to {results_path}")

# A queue for communicating action sequences to workers, and a queue for
# workers to report <action_sequence, reward_sequence> results.
@@ -287,14 +286,13 @@ def run_brute_force(
        worker.join()

    num_trials = sum(worker.num_trials for worker in workers)
-    env: CompilerEnv = make_env()
-    print(
-        f"completed {humanize.intcomma(num_trials)} of "
-        f"{humanize.intcomma(expected_trial_count)} trials "
-        f"({num_trials / expected_trial_count:.3%}), best sequence",
-        " ".join([env.action_space.flags[i] for i in best_action_sequence]),
-    )
-    env.close()
+    with make_env() as env:
+        print(
+            f"completed {humanize.intcomma(num_trials)} of "
+            f"{humanize.intcomma(expected_trial_count)} trials "
+            f"({num_trials / expected_trial_count:.3%}), best sequence",
+            " ".join([env.action_space.flags[i] for i in best_action_sequence]),
+        )


def main(argv):
@@ -309,11 +307,10 @@ def main(argv):
    if not benchmark:
        raise app.UsageError("No benchmark specified.")

-    env = env_from_flags(benchmark)
-    env.reset()
-    benchmark = env.benchmark
-    sanitized_benchmark_uri = "/".join(benchmark.split("/")[-2:])
-    env.close()
+    with env_from_flags(benchmark) as env:
+        env.reset()
+        benchmark = env.benchmark
+        sanitized_benchmark_uri = "/".join(benchmark.split("/")[-2:])
logs_dir = Path(
FLAGS.output_dir or create_logging_dir(f"brute_force/{sanitized_benchmark_uri}")
)
2 changes: 1 addition & 1 deletion examples/explore.py
@@ -447,8 +447,8 @@ def main(argv):

print(f"Running with {FLAGS.nproc} threads.")
assert FLAGS.nproc >= 1
envs = []
try:
envs = []
for _ in range(FLAGS.nproc):
envs.append(make_env())
compute_action_graph(envs, episode_length=FLAGS.episode_length)
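
explore.py above keeps its try/finally because it opens a list of environments rather than a single one, and this PR does not change that. For reference only, contextlib.ExitStack is one way to get the same with-style cleanup for a variable number of environments; make_env below stands in for the example's factory and is assumed to return an environment usable as a context manager:

from contextlib import ExitStack

def run_with_envs(make_env, nproc):
    with ExitStack() as stack:
        envs = [stack.enter_context(make_env()) for _ in range(nproc)]
        # ... use envs; every environment's close() runs when the block exits,
        # even if creating a later environment or using them raises.
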
10 changes: 4 additions & 6 deletions examples/random_walk.py
@@ -62,7 +62,6 @@ def run_random_walk(env: CompilerEnv, step_count: int) -> None:
            if done:
                print("Episode ended by environment")
                break
-    env.close()

    def reward_percentage(reward, rewards):
        if sum(rewards) == 0:
@@ -92,11 +91,10 @@ def main(argv):
    assert len(argv) == 1, f"Unrecognized flags: {argv[1:]}"

    benchmark = benchmark_from_flags()
-    env = env_from_flags(benchmark)
-
-    step_min = min(FLAGS.step_min, FLAGS.step_max)
-    step_max = max(FLAGS.step_min, FLAGS.step_max)
-    run_random_walk(env=env, step_count=random.randint(step_min, step_max))
+    with env_from_flags(benchmark) as env:
+        step_min = min(FLAGS.step_min, FLAGS.step_max)
+        step_max = max(FLAGS.step_min, FLAGS.step_max)
+        run_random_walk(env=env, step_count=random.randint(step_min, step_max))


if __name__ == "__main__":
7 changes: 2 additions & 5 deletions examples/random_walk_test.py
@@ -12,12 +12,9 @@

def test_run_random_walk_smoke_test():
    flags.FLAGS(["argv0"])
-    env = gym.make("llvm-autophase-ic-v0")
-    env.benchmark = "cbench-v1/crc32"
-    try:
+    with gym.make("llvm-autophase-ic-v0") as env:
+        env.benchmark = "cbench-v1/crc32"
        run_random_walk(env=env, step_count=5)
-    finally:
-        env.close()


if __name__ == "__main__":