|
42 | 42 | CatFrames,
|
43 | 43 | CatTensors,
|
44 | 44 | ChessEnv,
|
| 45 | + ConditionalSkip, |
45 | 46 | DoubleToFloat,
|
46 | 47 | EnvBase,
|
47 | 48 | EnvCreator,
|
|
72 | 73 | check_marl_grouping,
|
73 | 74 | make_composite_from_td,
|
74 | 75 | MarlGroupMapType,
|
| 76 | + RandomPolicy, |
75 | 77 | step_mdp,
|
76 | 78 | )
|
77 | 79 | from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
|
|
134 | 136 | EnvWithTensorClass,
|
135 | 137 | HeterogeneousCountingEnv,
|
136 | 138 | HeterogeneousCountingEnvPolicy,
|
| 139 | + HistoryTransform, |
137 | 140 | MockBatchedLockedEnv,
|
138 | 141 | MockBatchedUnLockedEnv,
|
139 | 142 | MockSerialEnv,
|
|
174 | 177 | EnvWithTensorClass,
|
175 | 178 | HeterogeneousCountingEnv,
|
176 | 179 | HeterogeneousCountingEnvPolicy,
|
| 180 | + HistoryTransform, |
177 | 181 | MockBatchedLockedEnv,
|
178 | 182 | MockBatchedUnLockedEnv,
|
179 | 183 | MockSerialEnv,
|
@@ -3634,8 +3638,11 @@ def test_serial(self, bwad, use_buffers):
|
3634 | 3638 | def test_parallel(self, bwad, use_buffers):
|
3635 | 3639 | N = 50
|
3636 | 3640 | env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
|
3637 |
| - r = env.rollout(N, break_when_any_done=bwad) |
3638 |
| - assert r.get("non_tensor").tolist() == [list(range(N))] * 2 |
| 3641 | + try: |
| 3642 | + r = env.rollout(N, break_when_any_done=bwad) |
| 3643 | + assert r.get("non_tensor").tolist() == [list(range(N))] * 2 |
| 3644 | + finally: |
| 3645 | + env.close(raise_if_closed=False) |
3639 | 3646 |
|
3640 | 3647 | class AddString(Transform):
|
3641 | 3648 | def __init__(self):
|
@@ -3667,19 +3674,22 @@ def test_partial_reset(self, batched):
|
3667 | 3674 | env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
|
3668 | 3675 | else:
|
3669 | 3676 | env = SerialEnv(2, [env0, env1])
|
3670 |
| - s = env.reset() |
3671 |
| - i = 0 |
3672 |
| - for i in range(10): # noqa: B007 |
3673 |
| - s, s_ = env.step_and_maybe_reset( |
3674 |
| - s.set("action", torch.ones(2, 1, dtype=torch.int)) |
3675 |
| - ) |
3676 |
| - if s.get(("next", "done")).any(): |
3677 |
| - break |
3678 |
| - s = s_ |
3679 |
| - assert i == 5 |
3680 |
| - assert (s["next", "done"] == torch.tensor([[True], [False]])).all() |
3681 |
| - assert s_["string"] == ["0", "6"] |
3682 |
| - assert s["next", "string"] == ["6", "6"] |
| 3677 | + try: |
| 3678 | + s = env.reset() |
| 3679 | + i = 0 |
| 3680 | + for i in range(10): # noqa: B007 |
| 3681 | + s, s_ = env.step_and_maybe_reset( |
| 3682 | + s.set("action", torch.ones(2, 1, dtype=torch.int)) |
| 3683 | + ) |
| 3684 | + if s.get(("next", "done")).any(): |
| 3685 | + break |
| 3686 | + s = s_ |
| 3687 | + assert i == 5 |
| 3688 | + assert (s["next", "done"] == torch.tensor([[True], [False]])).all() |
| 3689 | + assert s_["string"] == ["0", "6"] |
| 3690 | + assert s["next", "string"] == ["6", "6"] |
| 3691 | + finally: |
| 3692 | + env.close(raise_if_closed=False) |
3683 | 3693 |
|
3684 | 3694 | @pytest.mark.skipif(not _has_transformers, reason="transformers required")
|
3685 | 3695 | def test_str2str_env_tokenizer(self):
|
@@ -4398,6 +4408,124 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
|
4398 | 4408 | assert (td[3].get("next") != 0).any()
|
4399 | 4409 |
|
4400 | 4410 |
|
| 4411 | +class TestEnvWithHistory: |
| 4412 | + @pytest.fixture(autouse=True, scope="class") |
| 4413 | + def set_capture(self): |
| 4414 | + with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env( |
| 4415 | + False |
| 4416 | + ): |
| 4417 | + yield |
| 4418 | + return |
| 4419 | + |
| 4420 | + def _make_env(self, device, max_steps=10): |
| 4421 | + return CountingEnv(device=device, max_steps=max_steps).append_transform( |
| 4422 | + HistoryTransform() |
| 4423 | + ) |
| 4424 | + |
| 4425 | + def _make_skipping_env(self, device, max_steps=10): |
| 4426 | + env = self._make_env(device=device, max_steps=max_steps) |
| 4427 | + # skip every 3 steps |
| 4428 | + env = env.append_transform( |
| 4429 | + ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2)) |
| 4430 | + ) |
| 4431 | + env = TransformedEnv(env, StepCounter()) |
| 4432 | + return env |
| 4433 | + |
| 4434 | + @pytest.mark.parametrize("device", [None, "cpu"]) |
| 4435 | + def test_env_history_base(self, device): |
| 4436 | + env = self._make_env(device) |
| 4437 | + env.check_env_specs() |
| 4438 | + |
| 4439 | + @pytest.mark.parametrize("device", [None, "cpu"]) |
| 4440 | + def test_skipping_history_env(self, device): |
| 4441 | + env = self._make_skipping_env(device) |
| 4442 | + env.check_env_specs() |
| 4443 | + r = env.rollout(100) |
| 4444 | + |
| 4445 | + @pytest.mark.parametrize("device_env", [None, "cpu"]) |
| 4446 | + @pytest.mark.parametrize("device", [None, "cpu"]) |
| 4447 | + @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"]) |
| 4448 | + @pytest.mark.parametrize("consolidate", [False, True]) |
| 4449 | + def test_env_history_base_batched( |
| 4450 | + self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate |
| 4451 | + ): |
| 4452 | + if batch_cls == "parallel": |
| 4453 | + batch_cls = maybe_fork_ParallelEnv |
| 4454 | + env = batch_cls( |
| 4455 | + 2, |
| 4456 | + lambda: self._make_env(device_env), |
| 4457 | + device=device, |
| 4458 | + consolidate=consolidate, |
| 4459 | + ) |
| 4460 | + try: |
| 4461 | + assert not env._use_buffers |
| 4462 | + env.check_env_specs(break_when_any_done="both") |
| 4463 | + finally: |
| 4464 | + env.close(raise_if_closed=False) |
| 4465 | + |
| 4466 | + @pytest.mark.parametrize("device_env", [None, "cpu"]) |
| 4467 | + @pytest.mark.parametrize("device", [None, "cpu"]) |
| 4468 | + @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"]) |
| 4469 | + @pytest.mark.parametrize("consolidate", [False, True]) |
| 4470 | + def test_skipping_history_env_batched( |
| 4471 | + self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate |
| 4472 | + ): |
| 4473 | + if batch_cls == "parallel": |
| 4474 | + batch_cls = maybe_fork_ParallelEnv |
| 4475 | + env = batch_cls( |
| 4476 | + 2, |
| 4477 | + lambda: self._make_skipping_env(device_env), |
| 4478 | + device=device, |
| 4479 | + consolidate=consolidate, |
| 4480 | + ) |
| 4481 | + try: |
| 4482 | + env.check_env_specs() |
| 4483 | + finally: |
| 4484 | + env.close(raise_if_closed=False) |
| 4485 | + |
| 4486 | + @pytest.mark.parametrize("device_env", [None, "cpu"]) |
| 4487 | + @pytest.mark.parametrize("collector_cls", [SyncDataCollector]) |
| 4488 | + def test_env_history_base_collector(self, device_env, collector_cls): |
| 4489 | + env = self._make_env(device_env) |
| 4490 | + collector = collector_cls( |
| 4491 | + env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5 |
| 4492 | + ) |
| 4493 | + for d in collector: |
| 4494 | + for i in range(d.shape[0] - 1): |
| 4495 | + assert ( |
| 4496 | + d[i + 1]["history"].content[0] == d[i]["next", "history"].content[0] |
| 4497 | + ) |
| 4498 | + |
| 4499 | + @pytest.mark.parametrize("device_env", [None, "cpu"]) |
| 4500 | + @pytest.mark.parametrize("collector_cls", [SyncDataCollector]) |
| 4501 | + def test_skipping_history_env_collector(self, device_env, collector_cls): |
| 4502 | + env = self._make_skipping_env(device_env, max_steps=10) |
| 4503 | + collector = collector_cls( |
| 4504 | + env, |
| 4505 | + lambda td: td.update(env.full_action_spec.one()), |
| 4506 | + total_frames=35, |
| 4507 | + frames_per_batch=5, |
| 4508 | + ) |
| 4509 | + length = None |
| 4510 | + count = 1 |
| 4511 | + for d in collector: |
| 4512 | + for k in range(1, 5): |
| 4513 | + if len(d[k]["history"].content) == 2: |
| 4514 | + count = 1 |
| 4515 | + continue |
| 4516 | + if count % 3 == 2: |
| 4517 | + assert ( |
| 4518 | + d[k]["next", "history"].content |
| 4519 | + == d[k - 1]["next", "history"].content |
| 4520 | + ), (d["next", "history"].content, k, count) |
| 4521 | + else: |
| 4522 | + assert d[k]["next", "history"].content[-1] == str( |
| 4523 | + int(d[k - 1]["next", "history"].content[-1]) + 1 |
| 4524 | + ), (d["next", "history"].content, k, count) |
| 4525 | + count += 1 |
| 4526 | + count += 1 |
| 4527 | + |
| 4528 | + |
4401 | 4529 | if __name__ == "__main__":
|
4402 | 4530 | args, unknown = argparse.ArgumentParser().parse_known_args()
|
4403 | 4531 | pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
|
0 commit comments