Skip to content

Commit

Permalink
[env] Add a fallback fork() implementation.
Browse files Browse the repository at this point in the history
Supporting the Fork() operator is optional for CompilationSessions, so
provide a fallback implementation that creates a new environment and
replays the action sequence by hand.
  • Loading branch information
ChrisCummins committed May 17, 2021
1 parent 90b36dc commit fbd2af0
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 26 deletions.
70 changes: 44 additions & 26 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,39 +519,57 @@ def fork(self) -> "CompilerEnv":
:return: A new environment instance.
"""
if not self.in_episode:
if self.actions and not self.in_episode:
self.reset()
if self.actions:
self.logger.warning(
"Parent service of fork() has died, replaying state"
)
self.apply(self.state)
else:
self.reset()
actions = self.actions.copy()
_, _, done, _ = self.step(actions)
assert not done, "Failed to replay action sequence"

request = ForkSessionRequest(session_id=self._session_id)
reply: ForkSessionReply = self.service(self.service.stub.ForkSession, request)

# Create a new environment that shares the connection.
new_env = type(self)(
service=self._service_endpoint,
action_space=self.action_space,
connection_settings=self._connection_settings,
service_connection=self.service,
)

# Set the session ID.
new_env._session_id = reply.session_id # pylint: disable=protected-access
new_env.observation.session_id = reply.session_id
try:
reply: ForkSessionReply = self.service(
self.service.stub.ForkSession, request
)

# Now that we have initialized the environment with the current state,
# set the benchmark so that calls to new_env.reset() will correctly
# revert the environment to the initial benchmark state.
#
# pylint: disable=protected-access
new_env._next_benchmark = self._benchmark_in_use
# Create a new environment that shares the connection.
new_env = type(self)(
service=self._service_endpoint,
action_space=self.action_space,
connection_settings=self._connection_settings,
service_connection=self.service,
)

# Set the "visible" name of the current benchmark to hide the fact that
# we loaded from a custom bitcode file.
new_env._benchmark_in_use = self._benchmark_in_use
# Set the session ID.
new_env._session_id = reply.session_id # pylint: disable=protected-access
new_env.observation.session_id = reply.session_id

# Now that we have initialized the environment with the current state,
# set the benchmark so that calls to new_env.reset() will correctly
# revert the environment to the initial benchmark state.
#
# pylint: disable=protected-access
new_env._next_benchmark = self._benchmark_in_use

# Set the "visible" name of the current benchmark to hide the fact that
# we loaded from a custom bitcode file.
new_env._benchmark_in_use = self._benchmark_in_use
except NotImplementedError:
# Fallback implementation. If the compiler service does not support
# the Fork() operator then we create a new independent environment
# and apply the sequence of actions in the current environment to
# replay the state.
new_env = type(self)(
service=self._service_endpoint,
action_space=self.action_space,
benchmark=self.benchmark,
connection_settings=self._connection_settings,
)
new_env.reset()
_, _, done, _ = new_env.step(self.actions)
assert not done, "Failed to replay action sequence in forked environment"

# Create copies of the mutable reward and observation spaces. This
# is required to correctly calculate incremental updates.
Expand Down
12 changes: 12 additions & 0 deletions examples/example_compiler_gym_service/env_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,5 +207,17 @@ def test_benchmarks(env: CompilerEnv):
]


def test_fork(env: CompilerEnv):
env.reset()
env.step(0)
env.step(1)
other_env = env.fork()
try:
assert env.benchmark == other_env.benchmark
assert other_env.actions == [0, 1]
finally:
other_env.close()


if __name__ == "__main__":
main()

0 comments on commit fbd2af0

Please sign in to comment.