Skip to content

Commit

Permalink
WIP test
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Apr 8, 2021
1 parent 4b80031 commit ab2a245
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 34 deletions.
7 changes: 5 additions & 2 deletions compiler_gym/datasets/tar_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def __init__(self, manifest_url: str, manifest_sha256: str, **dataset_args):
self.manifest_sha256 = manifest_sha256
self._manifest_path = self.site_data_path / f"manifest-{manifest_sha256}.txt"

self._manifest_lock = Lock()
self._manifest_lockfile = self.site_data_path / "manifest.LOCK"

def _read_manifest_file(self) -> List[str]:
with open(self._manifest_path) as f:
uris = f.read().rstrip().split("\n")
Expand All @@ -138,8 +141,8 @@ def _benchmark_uris(self) -> List[str]:
if self._manifest_path.is_file():
return self._read_manifest_file()

with self._tar_lock:
with fasteners.InterProcessLock(self._tar_lockfile):
with self._manifest_lock:
with fasteners.InterProcessLock(self._manifest_lockfile):
# Now that we have acquired the lock, repeat the check.
if self._manifest_path.is_file():
return self._read_manifest_file()
Expand Down
41 changes: 22 additions & 19 deletions compiler_gym/envs/compiler_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,18 +536,11 @@ def fork(self) -> "CompilerEnv":
:return: A new environment instance.
"""
if not self.in_episode:
if self.actions:
state_to_replay = self.state
self.logger.warning(
"Parent service of fork() has died, replaying state"
)
else:
state_to_replay = None
if state_to_replay:
self.apply(state_to_replay)
else:
self.reset()
if self._benchmark_in_use and not self.in_episode:
self.logger.warning("Parent service of fork() has died, replaying state")
self.apply(self.state)
elif not self.in_episode:
raise ValueError("Must call reset() before fork()")

request = ForkSessionRequest(session_id=self._session_id)
reply: ForkSessionReply = self.service(self.service.stub.ForkSession, request)
Expand Down Expand Up @@ -988,8 +981,21 @@ def apply(self, state: CompilerEnvState) -> None: # noqa
)

def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult:
in_place = state is not None
state = state or self.state
"""Validate an environment's state.
:param state: A state to environment. If not provided, the current state
is validated.
:returns: A :class:`ValidationResult`.
"""
if state:
self.reset(benchmark=state.benchmark)
in_place = False
else:
state = self.state
in_place = True

assert self.in_episode

errors: ValidationError = []
validation = {
Expand Down Expand Up @@ -1057,13 +1063,10 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult
)
)

# TODO(https://github.com/facebookresearch/CompilerGym/issues/45):
# Call the new self.benchmark.validation_callback() method
# once implemented.
benchmark = self.get_benchmark()
benchmark = replay_target.get_benchmark()
if benchmark.is_validatable():
validation["benchmark_semantics_validated"] = True
semantics_errors = benchmark.validate(self)
semantics_errors = benchmark.validate(replay_target)
if semantics_errors:
validation["benchmark_semantics_validation_failed"] = True
errors += semantics_errors
Expand Down
8 changes: 2 additions & 6 deletions compiler_gym/envs/llvm/service/Cost.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,8 @@ Status getTextSizeInBytes(llvm::Module& module, int64_t* value, const fs::path&
#endif
const auto clangPath = util::getSiteDataPath("llvm/10.0.0/bin/clang");
const auto llvmSizePath = util::getSiteDataPath("llvm/10.0.0/bin/llvm-size");
if (!fs::exists(clangPath)) {
return Status(StatusCode::INTERNAL, fmt::format("File not found: {}", clangPath.string()));
}
if (!fs::exists(llvmSizePath)) {
return Status(StatusCode::INTERNAL, fmt::format("File not found: {}", llvmSizePath.string()));
}
DCHECK(fs::exists(clangPath)) << fmt::format("File not found: {}", clangPath.string());
DCHECK(fs::exists(llvmSizePath)) << fmt::format("File not found: {}", llvmSizePath.string());

// Lower the module to an object file using clang and extract the .text
// section size using llvm-size.
Expand Down
13 changes: 6 additions & 7 deletions tests/llvm/fork_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
"""Tests for LlvmEnv.fork()."""
import subprocess

import pytest

from compiler_gym.envs import LlvmEnv
from compiler_gym.util.runfiles_path import runfiles_path
from tests.test_main import main
Expand Down Expand Up @@ -89,13 +91,10 @@ def test_fork_chain_child_processes_are_not_orphaned(env: LlvmEnv):
def test_fork_before_reset(env: LlvmEnv):
"""Test that fork() before reset() starts an episode."""
assert not env.in_episode
fkd = env.fork()
try:
assert env.in_episode
assert fkd.in_episode
assert env.benchmark == fkd.benchmark
finally:
fkd.close()
with pytest.raises(ValueError) as e_ctx:
env.fork()

assert str(e_ctx.value) == "Must call reset() before fork()"


def test_fork_closed_service(env: LlvmEnv):
Expand Down

0 comments on commit ab2a245

Please sign in to comment.