diff --git a/tests/llvm/BUILD b/tests/llvm/BUILD index 25b538c5ea..b71767354e 100644 --- a/tests/llvm/BUILD +++ b/tests/llvm/BUILD @@ -142,7 +142,6 @@ py_test( deps = [ "//compiler_gym", "//tests:test_main", - "//tests/pytest_plugins:llvm", ], ) diff --git a/tests/llvm/gym_interface_compatability.py b/tests/llvm/gym_interface_compatability.py index 5f8c8fae85..111b68f476 100644 --- a/tests/llvm/gym_interface_compatability.py +++ b/tests/llvm/gym_interface_compatability.py @@ -4,30 +4,36 @@ # LICENSE file in the root directory of this source tree. """Test that LlvmEnv is compatible with OpenAI gym interface.""" import gym +import pytest import compiler_gym # noqa Register Environments -from compiler_gym.envs import CompilerEnv +from compiler_gym.envs import LlvmEnv from tests.test_main import main -pytest_plugins = ["tests.pytest_plugins.llvm"] +@pytest.fixture(scope="function") +def env() -> LlvmEnv: + """Create an LLVM environment.""" + with gym.make("llvm-autophase-ic-v0") as env_: + yield env_ -def test_type_classes(env: CompilerEnv): + +def test_type_classes(env: LlvmEnv): assert isinstance(env, gym.Env) - assert isinstance(env, CompilerEnv) - assert isinstance(env.unwrapped, CompilerEnv) + assert isinstance(env, LlvmEnv) + assert isinstance(env.unwrapped, LlvmEnv) assert isinstance(env.action_space, gym.Space) assert isinstance(env.observation_space, gym.Space) assert isinstance(env.reward_range[0], float) assert isinstance(env.reward_range[1], float) -def test_optional_properties(env: CompilerEnv): +def test_optional_properties(env: LlvmEnv): assert "render.modes" in env.metadata assert env.spec -def test_contextmanager(env: CompilerEnv, mocker): +def test_contextmanager(env: LlvmEnv, mocker): mocker.spy(env, "close") assert env.close.call_count == 0 with env: @@ -44,7 +50,7 @@ def test_contextmanager_gym_make(mocker): assert env.close.call_count == 1 -def test_observation_wrapper(env: CompilerEnv): +def test_observation_wrapper(env: LlvmEnv): class WrappedEnv(gym.ObservationWrapper): def observation(self, observation): return "Hello" @@ -57,7 +63,7 @@ def observation(self, observation): assert observation == "Hello" -def test_reward_wrapper(env: CompilerEnv): +def test_reward_wrapper(env: LlvmEnv): class WrappedEnv(gym.RewardWrapper): def reward(self, reward): return 1