Skip to content

Commit

Permalink
[tests] Fix gym compatibility test.
Browse files Browse the repository at this point in the history
Regression introduced in facebookresearch#384.
  • Loading branch information
ChrisCummins committed Sep 10, 2021
1 parent 58f4587 commit 671dadc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tests/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ py_test(
timeout = "short",
srcs = ["gym_interface_compatability.py"],
deps = [
"//compiler_gym",
"//compiler_gym/envs/llvm",
"//tests:test_main",
],
)
Expand Down
23 changes: 12 additions & 11 deletions tests/llvm/gym_interface_compatability.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,32 +6,33 @@
import gym
import pytest

import compiler_gym # noqa Register Environments
from compiler_gym.envs import CompilerEnv
from compiler_gym.envs.llvm import LlvmEnv
from tests.test_main import main


@pytest.fixture(scope="function")
def env() -> CompilerEnv:
return gym.make("llvm-autophase-ic-v0")
def env() -> LlvmEnv:
"""Create an LLVM environment."""
with gym.make("llvm-autophase-ic-v0") as env_:
yield env_


def test_type_classes(env: CompilerEnv):
def test_type_classes(env: LlvmEnv):
assert isinstance(env, gym.Env)
assert isinstance(env, CompilerEnv)
assert isinstance(env.unwrapped, CompilerEnv)
assert isinstance(env, LlvmEnv)
assert isinstance(env.unwrapped, LlvmEnv)
assert isinstance(env.action_space, gym.Space)
assert isinstance(env.observation_space, gym.Space)
assert isinstance(env.reward_range[0], float)
assert isinstance(env.reward_range[1], float)


def test_optional_properties(env: CompilerEnv):
def test_optional_properties(env: LlvmEnv):
assert "render.modes" in env.metadata
assert env.spec


def test_contextmanager(env: CompilerEnv, mocker):
def test_contextmanager(env: LlvmEnv, mocker):
mocker.spy(env, "close")
assert env.close.call_count == 0
with env:
Expand All @@ -48,7 +49,7 @@ def test_contextmanager_gym_make(mocker):
assert env.close.call_count == 1


def test_observation_wrapper(env: CompilerEnv):
def test_observation_wrapper(env: LlvmEnv):
class WrappedEnv(gym.ObservationWrapper):
def observation(self, observation):
return "Hello"
Expand All @@ -61,7 +62,7 @@ def observation(self, observation):
assert observation == "Hello"


def test_reward_wrapper(env: CompilerEnv):
def test_reward_wrapper(env: LlvmEnv):
class WrappedEnv(gym.RewardWrapper):
def reward(self, reward):
return 1
Expand Down

0 comments on commit 671dadc

Please sign in to comment.