diff --git a/src/pymgrid/envs/discrete/discrete.py b/src/pymgrid/envs/discrete/discrete.py index b2593a8d..e4f09972 100644 --- a/src/pymgrid/envs/discrete/discrete.py +++ b/src/pymgrid/envs/discrete/discrete.py @@ -49,11 +49,11 @@ def _get_action_space(self): """ # n_actions = 2**(self.modules.fixed.) fixed_sources = [(module.name, module.action_spaces['unnormalized'].shape[0], n_actions) - for module in self.fixed_modules.sources.iterlist() + for module in self.modules.fixed.sources.iterlist() for n_actions in range(module.action_spaces['unnormalized'].shape[0])] fixed_sources.extend([ (module.name, module.action_spaces['unnormalized'].shape[0], n_actions) - for module in self.fixed_modules.source_and_sinks.iterlist() + for module in self.modules.fixed.source_and_sinks.iterlist() for n_actions in range(module.action_spaces['unnormalized'].shape[0])]) priority_lists = list(permutations(fixed_sources)) diff --git a/src/pymgrid/microgrid/microgrid.py b/src/pymgrid/microgrid/microgrid.py index 1b46d430..c1cef1b7 100644 --- a/src/pymgrid/microgrid/microgrid.py +++ b/src/pymgrid/microgrid/microgrid.py @@ -477,10 +477,6 @@ def flex(self): """ return self._modules.flex - @property - def flat_modules(self): - raise AttributeError('Getting attribute flat_modules has been deprecated. Call .modules_dict() instead.') - @property def module_list(self): """ diff --git a/tests/envs/discrete.py b/tests/envs/discrete.py index c1a2e6fb..17db6ea4 100644 --- a/tests/envs/discrete.py +++ b/tests/envs/discrete.py @@ -5,8 +5,25 @@ class TestDiscreteEnv(TestCase): - def test_init(self): + def test_init_from_microgrid(self): microgrid = get_modular_microgrid() - env_2 = DiscreteMicrogridEnv(microgrid) + env = DiscreteMicrogridEnv(microgrid) + + self.assertEqual(env.modules, microgrid.modules) + self.assertIsNot(env.modules.module_tuples(), microgrid.modules.module_tuples()) + + n_obs = sum([x.observation_spaces['normalized'].shape[0] for x in microgrid.module_list]) + + self.assertEqual(env.observation_space.shape, (n_obs,)) + + def test_init_from_modules(self): + microgrid = get_modular_microgrid() + env = DiscreteMicrogridEnv(microgrid.modules.module_tuples(), add_unbalanced_module=False) + + self.assertEqual(env.modules, microgrid.modules) + self.assertIsNot(env.modules.module_tuples(), microgrid.modules.module_tuples()) + + n_obs = sum([x.observation_spaces['normalized'].shape[0] for x in microgrid.module_list]) + + self.assertEqual(env.observation_space.shape, (n_obs,)) - print("Here") \ No newline at end of file