From ab2a2457ef05de533578229bffec31b1423e3a61 Mon Sep 17 00:00:00 2001 From: Chris Cummins Date: Thu, 8 Apr 2021 16:16:00 +0100 Subject: [PATCH] WIP test --- compiler_gym/datasets/tar_dataset.py | 7 +++-- compiler_gym/envs/compiler_env.py | 41 ++++++++++++++------------ compiler_gym/envs/llvm/service/Cost.cc | 8 ++--- tests/llvm/fork_env_test.py | 13 ++++---- 4 files changed, 35 insertions(+), 34 deletions(-) diff --git a/compiler_gym/datasets/tar_dataset.py b/compiler_gym/datasets/tar_dataset.py index ed94f9df06..80681862cb 100644 --- a/compiler_gym/datasets/tar_dataset.py +++ b/compiler_gym/datasets/tar_dataset.py @@ -127,6 +127,9 @@ def __init__(self, manifest_url: str, manifest_sha256: str, **dataset_args): self.manifest_sha256 = manifest_sha256 self._manifest_path = self.site_data_path / f"manifest-{manifest_sha256}.txt" + self._manifest_lock = Lock() + self._manifest_lockfile = self.site_data_path / "manifest.LOCK" + def _read_manifest_file(self) -> List[str]: with open(self._manifest_path) as f: uris = f.read().rstrip().split("\n") @@ -138,8 +141,8 @@ def _benchmark_uris(self) -> List[str]: if self._manifest_path.is_file(): return self._read_manifest_file() - with self._tar_lock: - with fasteners.InterProcessLock(self._tar_lockfile): + with self._manifest_lock: + with fasteners.InterProcessLock(self._manifest_lockfile): # Now that we have acquired the lock, repeat the check. if self._manifest_path.is_file(): return self._read_manifest_file() diff --git a/compiler_gym/envs/compiler_env.py b/compiler_gym/envs/compiler_env.py index be96700929..b5efff31e7 100644 --- a/compiler_gym/envs/compiler_env.py +++ b/compiler_gym/envs/compiler_env.py @@ -536,18 +536,11 @@ def fork(self) -> "CompilerEnv": :return: A new environment instance. """ - if not self.in_episode: - if self.actions: - state_to_replay = self.state - self.logger.warning( - "Parent service of fork() has died, replaying state" - ) - else: - state_to_replay = None - if state_to_replay: - self.apply(state_to_replay) - else: - self.reset() + if self._benchmark_in_use and not self.in_episode: + self.logger.warning("Parent service of fork() has died, replaying state") + self.apply(self.state) + elif not self.in_episode: + raise ValueError("Must call reset() before fork()") request = ForkSessionRequest(session_id=self._session_id) reply: ForkSessionReply = self.service(self.service.stub.ForkSession, request) @@ -988,8 +981,21 @@ def apply(self, state: CompilerEnvState) -> None: # noqa ) def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult: - in_place = state is not None - state = state or self.state + """Validate an environment's state. + + :param state: A state to environment. If not provided, the current state + is validated. + + :returns: A :class:`ValidationResult`. + """ + if state: + self.reset(benchmark=state.benchmark) + in_place = False + else: + state = self.state + in_place = True + + assert self.in_episode errors: ValidationError = [] validation = { @@ -1057,13 +1063,10 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult ) ) - # TODO(https://github.com/facebookresearch/CompilerGym/issues/45): - # Call the new self.benchmark.validation_callback() method - # once implemented. - benchmark = self.get_benchmark() + benchmark = replay_target.get_benchmark() if benchmark.is_validatable(): validation["benchmark_semantics_validated"] = True - semantics_errors = benchmark.validate(self) + semantics_errors = benchmark.validate(replay_target) if semantics_errors: validation["benchmark_semantics_validation_failed"] = True errors += semantics_errors diff --git a/compiler_gym/envs/llvm/service/Cost.cc b/compiler_gym/envs/llvm/service/Cost.cc index da2ceae679..fe92f9d920 100644 --- a/compiler_gym/envs/llvm/service/Cost.cc +++ b/compiler_gym/envs/llvm/service/Cost.cc @@ -72,12 +72,8 @@ Status getTextSizeInBytes(llvm::Module& module, int64_t* value, const fs::path& #endif const auto clangPath = util::getSiteDataPath("llvm/10.0.0/bin/clang"); const auto llvmSizePath = util::getSiteDataPath("llvm/10.0.0/bin/llvm-size"); - if (!fs::exists(clangPath)) { - return Status(StatusCode::INTERNAL, fmt::format("File not found: {}", clangPath.string())); - } - if (!fs::exists(llvmSizePath)) { - return Status(StatusCode::INTERNAL, fmt::format("File not found: {}", llvmSizePath.string())); - } + DCHECK(fs::exists(clangPath)) << fmt::format("File not found: {}", clangPath.string()); + DCHECK(fs::exists(llvmSizePath)) << fmt::format("File not found: {}", llvmSizePath.string()); // Lower the module to an object file using clang and extract the .text // section size using llvm-size. diff --git a/tests/llvm/fork_env_test.py b/tests/llvm/fork_env_test.py index b42b7a5d3b..863c93a94f 100644 --- a/tests/llvm/fork_env_test.py +++ b/tests/llvm/fork_env_test.py @@ -5,6 +5,8 @@ """Tests for LlvmEnv.fork().""" import subprocess +import pytest + from compiler_gym.envs import LlvmEnv from compiler_gym.util.runfiles_path import runfiles_path from tests.test_main import main @@ -89,13 +91,10 @@ def test_fork_chain_child_processes_are_not_orphaned(env: LlvmEnv): def test_fork_before_reset(env: LlvmEnv): """Test that fork() before reset() starts an episode.""" assert not env.in_episode - fkd = env.fork() - try: - assert env.in_episode - assert fkd.in_episode - assert env.benchmark == fkd.benchmark - finally: - fkd.close() + with pytest.raises(ValueError) as e_ctx: + env.fork() + + assert str(e_ctx.value) == "Must call reset() before fork()" def test_fork_closed_service(env: LlvmEnv):