Skip to content

Commit

Permalink
[tests] Add a gym interface compatability test.
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisCummins committed Apr 28, 2021
1 parent 8b9a620 commit be08674
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 0 deletions.
10 changes: 10 additions & 0 deletions tests/llvm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,16 @@ py_test(
],
)

py_test(
name = "gym_interface_compatability",
timeout = "short",
srcs = ["gym_interface_compatability.py"],
deps = [
"//compiler_gym",
"//tests:test_main",
],
)

py_test(
name = "llvm_benchmarks_test",
srcs = ["llvm_benchmarks_test.py"],
Expand Down
81 changes: 81 additions & 0 deletions tests/llvm/gym_interface_compatability.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Test that LlvmEnv is compatible with OpenAI gym interface."""
import gym
import pytest

import compiler_gym # noqa Register Environments
from compiler_gym.envs import CompilerEnv
from tests.test_main import main


@pytest.fixture(scope="function")
def env() -> CompilerEnv:
env = gym.make("llvm-autophase-ic-v0")
try:
yield env
finally:
env.close()


def test_type_classes(env: CompilerEnv):
assert isinstance(env, gym.Env)
assert isinstance(env, CompilerEnv)
assert isinstance(env.unwrapped, CompilerEnv)
assert isinstance(env.action_space, gym.Space)
assert isinstance(env.observation_space, gym.Space)
assert isinstance(env.reward_range[0], float)
assert isinstance(env.reward_range[1], float)


def test_optional_properties(env: CompilerEnv):
assert "render.modes" in env.metadata
assert env.spec


def test_contextmanager(env: CompilerEnv, mocker):
mocker.spy(env, "close")
assert env.close.call_count == 0
with env:
pass
assert env.close.call_count == 1


def test_contextmanager_gym_make(mocker):
with gym.make("llvm-v0") as env:
mocker.spy(env, "close")
assert env.close.call_count == 0
with env:
pass
assert env.close.call_count == 1


def test_observation_wrapper(env: CompilerEnv):
class WrappedEnv(gym.ObservationWrapper):
def observation(self, observation):
return "Hello"

wrapped = WrappedEnv(env)
observation = wrapped.reset()
assert observation == "Hello"

observation, _, _, _ = wrapped.step(0)
assert observation == "Hello"


def test_reward_wrapper(env: CompilerEnv):
class WrappedEnv(gym.RewardWrapper):
def reward(self, reward):
return 1

wrapped = WrappedEnv(env)
wrapped.reset()

_, reward, _, _ = wrapped.step(0)
assert reward == 1


if __name__ == "__main__":
main()

0 comments on commit be08674

Please sign in to comment.