From b236f21a2ac3e89fc6d3e75eecd99c2829d3c240 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 12 May 2021 03:15:36 +0900 Subject: [PATCH] Update tests for trainer.fit returning None in PL 1.3.0 (#641) * Update tests for trainer.fit returning None * Remove unused refs * Update tests for trainer.fit returning None --- .../rl/integration/test_policy_models.py | 8 ++----- .../rl/integration/test_value_models.py | 22 +++++-------------- tests/models/test_autoencoders.py | 6 ++--- 3 files changed, 9 insertions(+), 27 deletions(-) diff --git a/tests/models/rl/integration/test_policy_models.py b/tests/models/rl/integration/test_policy_models.py index 23c8b510d2..440c1465c4 100644 --- a/tests/models/rl/integration/test_policy_models.py +++ b/tests/models/rl/integration/test_policy_models.py @@ -30,13 +30,9 @@ def test_reinforce(self): """Smoke test that the reinforce model runs""" model = Reinforce(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) def test_policy_gradient(self): """Smoke test that the policy gradient model runs""" model = VanillaPolicyGradient(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) diff --git a/tests/models/rl/integration/test_value_models.py b/tests/models/rl/integration/test_value_models.py index a723a0a8f0..c127b81aa3 100644 --- a/tests/models/rl/integration/test_value_models.py +++ b/tests/models/rl/integration/test_value_models.py @@ -37,41 +37,29 @@ def setUp(self) -> None: def test_dqn(self): """Smoke test that the DQN model runs""" model = DQN(self.hparams.env, num_envs=5) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) def test_double_dqn(self): """Smoke test that the Double DQN model runs""" model = DoubleDQN(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) def test_dueling_dqn(self): """Smoke test that the Dueling DQN model runs""" model = DuelingDQN(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) def test_noisy_dqn(self): """Smoke test that the Noisy DQN model runs""" model = NoisyDQN(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) def test_per_dqn(self): """Smoke test that the PER DQN model runs""" model = PERDQN(self.hparams.env) - result = self.trainer.fit(model) - - self.assertEqual(result, 1) + self.trainer.fit(model) # def test_n_step_dqn(self): # """Smoke test that the N Step DQN model runs""" # model = DQN(self.hparams.env, n_steps=self.hparams.n_steps) # result = self.trainer.fit(model) - # - # self.assertEqual(result, 1) diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index 322cb28774..36bfb7b1cb 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -19,8 +19,7 @@ def test_vae(tmpdir, datadir, dm_cls): gpus=None, ) - result = trainer.fit(model, datamodule=dm) - assert result == 1 + trainer.fit(model, datamodule=dm) @pytest.mark.parametrize("dm_cls", [pytest.param(CIFAR10DataModule, id="cifar10")]) @@ -35,8 +34,7 @@ def test_ae(tmpdir, datadir, dm_cls): gpus=None, ) - result = trainer.fit(model, datamodule=dm) - assert result == 1 + trainer.fit(model, datamodule=dm) @torch.no_grad()