Skip to content

Commit b538c66

Browse files
committed
[BugFix] Test and fix life cycle of env with dynamic non-tensor spec
ghstack-source-id: 77da3a6baf0cb42525dd3a564b36ac03a531d17a Pull Request resolved: #2812
1 parent a3a1ebe commit b538c66

11 files changed

+330
-53
lines changed

test/mocking_classes.py

+69-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import string
99
from typing import Dict, List, Optional
1010

11+
import numpy as np
12+
1113
import torch
1214
import torch.nn as nn
1315
from tensordict import tensorclass, TensorDict, TensorDictBase
@@ -26,6 +28,7 @@
2628
Unbounded,
2729
)
2830
from torchrl.data.utils import consolidate_spec
31+
from torchrl.envs import Transform
2932
from torchrl.envs.common import EnvBase
3033
from torchrl.envs.model_based.common import ModelBasedEnvBase
3134
from torchrl.envs.utils import (
@@ -34,7 +37,6 @@
3437
MarlGroupMapType,
3538
)
3639

37-
3840
spec_dict = {
3941
"bounded": Bounded,
4042
"one_hot": OneHot,
@@ -2395,3 +2397,69 @@ def _step(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
23952397
f1 + 1,
23962398
)
23972399
return td
2400+
2401+
2402+
@tensorclass
class History:
    """One record of a conversation-like history used by the mocking transforms."""

    # Who produced the entry; values used in this file are "system", "user"
    # and "assistant".
    role: str
    # Payload of the entry; in this file it is a stringified integer counter.
    content: str
2406+
2407+
2408+
class HistoryTransform(Transform):
    """A mocking class to record history.

    The transform adds a ``"history"`` entry to the observation spec, seeds it
    with two ``History`` records on reset and appends one record per step, so
    the entry grows along its first dimension during a rollout.

    Every hook asserts that the data/spec it handles lives on
    ``self.parent.device``; these assertions are the point of the mock, since
    it is used to test device handling of dynamic non-tensor specs.
    """

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        """Register the dynamic-shape ``"history"`` spec and return the spec.

        Args:
            observation_spec: the observation spec of the parent env; it is
                modified in place.

        Returns:
            The same ``observation_spec`` with a ``"history"`` entry added.
        """
        # A shape of -1 marks the leading dimension as dynamic: the history
        # length changes at every step.
        defaults = {
            "role": NonTensor(
                example_data="a role!",
                shape=(-1,),
            ),
            "content": NonTensor(
                example_data="a content!",
                shape=(-1,),
            ),
        }
        observation_spec["history"] = Composite(
            defaults,
            shape=(-1,),
            data_cls=History,
        )
        assert observation_spec.device == self.parent.device
        assert observation_spec["history"].device == self.parent.device
        return observation_spec

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        """Seed ``tensordict_reset["history"]`` with two initial records.

        Args:
            tensordict: the tensordict passed to ``reset`` (unused here).
            tensordict_reset: the reset output, updated in place.

        Returns:
            ``tensordict_reset`` with a two-element ``"history"`` entry whose
            contents are ``"0"`` and ``"1"``.
        """
        assert tensordict_reset.device == self.parent.device
        tensordict_reset["history"] = torch.stack(
            [
                History(role="system", content="0"),
                History(role="user", content="1"),
            ]
        )
        assert tensordict_reset["history"].device == self.parent.device
        return tensordict_reset

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        """Append one record to the history and write it to ``next_tensordict``.

        Args:
            tensordict: the input of the step; its ``"history"`` entry is read.
            next_tensordict: the step output, updated in place.

        Returns:
            ``next_tensordict`` whose ``"history"`` entry is the previous
            history plus one new record.
        """
        assert next_tensordict.device == self.parent.device
        history = tensordict["history"]
        # The new record has a random role and a content equal to the last
        # recorded content incremented by one.
        local_history = History(
            role=np.random.choice(["user", "system", "assistant"]),
            content=str(int(history.content[-1]) + 1),
            device=history.device,
        )
        # Rebuild the stack from its elements plus the new record.
        history = torch.stack([*history.unbind(0), local_history])
        assert isinstance(history, History)
        next_tensordict["history"] = history
        assert next_tensordict["history"].device == self.parent.device, (
            next_tensordict["history"],
            self.parent.device,
        )
        return next_tensordict

test/test_env.py

+143-15
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
CatFrames,
4343
CatTensors,
4444
ChessEnv,
45+
ConditionalSkip,
4546
DoubleToFloat,
4647
EnvBase,
4748
EnvCreator,
@@ -72,6 +73,7 @@
7273
check_marl_grouping,
7374
make_composite_from_td,
7475
MarlGroupMapType,
76+
RandomPolicy,
7577
step_mdp,
7678
)
7779
from torchrl.modules import Actor, ActorCriticOperator, MLP, SafeModule, ValueOperator
@@ -134,6 +136,7 @@
134136
EnvWithTensorClass,
135137
HeterogeneousCountingEnv,
136138
HeterogeneousCountingEnvPolicy,
139+
HistoryTransform,
137140
MockBatchedLockedEnv,
138141
MockBatchedUnLockedEnv,
139142
MockSerialEnv,
@@ -174,6 +177,7 @@
174177
EnvWithTensorClass,
175178
HeterogeneousCountingEnv,
176179
HeterogeneousCountingEnvPolicy,
180+
HistoryTransform,
177181
MockBatchedLockedEnv,
178182
MockBatchedUnLockedEnv,
179183
MockSerialEnv,
@@ -3634,8 +3638,11 @@ def test_serial(self, bwad, use_buffers):
36343638
def test_parallel(self, bwad, use_buffers):
36353639
N = 50
36363640
env = ParallelEnv(2, EnvWithMetadata, use_buffers=use_buffers)
3637-
r = env.rollout(N, break_when_any_done=bwad)
3638-
assert r.get("non_tensor").tolist() == [list(range(N))] * 2
3641+
try:
3642+
r = env.rollout(N, break_when_any_done=bwad)
3643+
assert r.get("non_tensor").tolist() == [list(range(N))] * 2
3644+
finally:
3645+
env.close(raise_if_closed=False)
36393646

36403647
class AddString(Transform):
36413648
def __init__(self):
@@ -3667,19 +3674,22 @@ def test_partial_reset(self, batched):
36673674
env = ParallelEnv(2, [env0, env1], mp_start_method=mp_ctx)
36683675
else:
36693676
env = SerialEnv(2, [env0, env1])
3670-
s = env.reset()
3671-
i = 0
3672-
for i in range(10): # noqa: B007
3673-
s, s_ = env.step_and_maybe_reset(
3674-
s.set("action", torch.ones(2, 1, dtype=torch.int))
3675-
)
3676-
if s.get(("next", "done")).any():
3677-
break
3678-
s = s_
3679-
assert i == 5
3680-
assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
3681-
assert s_["string"] == ["0", "6"]
3682-
assert s["next", "string"] == ["6", "6"]
3677+
try:
3678+
s = env.reset()
3679+
i = 0
3680+
for i in range(10): # noqa: B007
3681+
s, s_ = env.step_and_maybe_reset(
3682+
s.set("action", torch.ones(2, 1, dtype=torch.int))
3683+
)
3684+
if s.get(("next", "done")).any():
3685+
break
3686+
s = s_
3687+
assert i == 5
3688+
assert (s["next", "done"] == torch.tensor([[True], [False]])).all()
3689+
assert s_["string"] == ["0", "6"]
3690+
assert s["next", "string"] == ["6", "6"]
3691+
finally:
3692+
env.close(raise_if_closed=False)
36833693

36843694
@pytest.mark.skipif(not _has_transformers, reason="transformers required")
36853695
def test_str2str_env_tokenizer(self):
@@ -4398,6 +4408,124 @@ def test_serial_partial_step_and_maybe_reset(self, use_buffers, device, env_devi
43984408
assert (td[3].get("next") != 0).any()
43994409

44004410

4411+
class TestEnvWithHistory:
    """Life-cycle tests for envs carrying a dynamic non-tensor ``"history"`` entry.

    The envs are built from ``CountingEnv`` and ``HistoryTransform`` and are
    exercised on their own, inside batched envs and through a data collector,
    with and without a ``ConditionalSkip`` transform.
    """

    @pytest.fixture(autouse=True, scope="class")
    def set_capture(self):
        # Both context managers are held for the duration of the class' tests.
        with set_capture_non_tensor_stack(False), set_auto_unwrap_transformed_env(
            False
        ):
            yield

    def _make_env(self, device, max_steps=10):
        """Return a ``CountingEnv`` with a ``HistoryTransform`` appended."""
        return CountingEnv(device=device, max_steps=max_steps).append_transform(
            HistoryTransform()
        )

    def _make_skipping_env(self, device, max_steps=10):
        """Return the history env with ``ConditionalSkip`` and ``StepCounter`` added."""
        env = self._make_env(device=device, max_steps=max_steps)
        # skip every 3 steps
        env = env.append_transform(
            ConditionalSkip(lambda td: ((td["step_count"] % 3) == 2))
        )
        env = TransformedEnv(env, StepCounter())
        return env

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_env_history_base(self, device):
        env = self._make_env(device)
        env.check_env_specs()

    @pytest.mark.parametrize("device", [None, "cpu"])
    def test_skipping_history_env(self, device):
        env = self._make_skipping_env(device)
        env.check_env_specs()
        # Only checks that a long rollout runs; the result is not inspected.
        env.rollout(100)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_env_history_base_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            assert not env._use_buffers
            env.check_env_specs(break_when_any_done="both")
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("device", [None, "cpu"])
    @pytest.mark.parametrize("batch_cls", [SerialEnv, "parallel"])
    @pytest.mark.parametrize("consolidate", [False, True])
    def test_skipping_history_env_batched(
        self, device, device_env, batch_cls, maybe_fork_ParallelEnv, consolidate
    ):
        if batch_cls == "parallel":
            batch_cls = maybe_fork_ParallelEnv
        env = batch_cls(
            2,
            lambda: self._make_skipping_env(device_env),
            device=device,
            consolidate=consolidate,
        )
        try:
            env.check_env_specs()
        finally:
            env.close(raise_if_closed=False)

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_env_history_base_collector(self, device_env, collector_cls):
        env = self._make_env(device_env)
        collector = collector_cls(
            env, RandomPolicy(env.full_action_spec), total_frames=35, frames_per_batch=5
        )
        try:
            for d in collector:
                # Within a batch, the first history content at step i + 1 must
                # match the first content of the "next" history at step i.
                for i in range(d.shape[0] - 1):
                    assert (
                        d[i + 1]["history"].content[0]
                        == d[i]["next", "history"].content[0]
                    )
        finally:
            # Release the collector even if an assertion fails.
            collector.shutdown()

    @pytest.mark.parametrize("device_env", [None, "cpu"])
    @pytest.mark.parametrize("collector_cls", [SyncDataCollector])
    def test_skipping_history_env_collector(self, device_env, collector_cls):
        env = self._make_skipping_env(device_env, max_steps=10)
        collector = collector_cls(
            env,
            lambda td: td.update(env.full_action_spec.one()),
            total_frames=35,
            frames_per_batch=5,
        )
        try:
            count = 1
            for d in collector:
                for k in range(1, 5):
                    # A history of length 2 is what ``_reset`` produces:
                    # restart the counter and skip the comparison.
                    if len(d[k]["history"].content) == 2:
                        count = 1
                        continue
                    if count % 3 == 2:
                        # Skipped step: the history must be unchanged.
                        assert (
                            d[k]["next", "history"].content
                            == d[k - 1]["next", "history"].content
                        ), (d["next", "history"].content, k, count)
                    else:
                        # Regular step: last content is incremented by one.
                        assert d[k]["next", "history"].content[-1] == str(
                            int(d[k - 1]["next", "history"].content[-1]) + 1
                        ), (d["next", "history"].content, k, count)
                    count += 1
                # Account for index 0 of the batch, which the inner loop
                # does not visit.
                count += 1
        finally:
            # Release the collector even if an assertion fails.
            collector.shutdown()
4527+
4528+
44014529
if __name__ == "__main__":
44024530
args, unknown = argparse.ArgumentParser().parse_known_args()
44034531
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_specs.py

+7
Original file line numberDiff line numberDiff line change
@@ -3912,6 +3912,13 @@ def test_example_data_ineq(self):
39123912
nts1 = NonTensor(shape=(3, 4), example_data="example_data 2")
39133913
assert nts0 != nts1
39143914

3915+
def test_device_cast(self):
    """A ``NonTensor`` set in a cpu ``Composite`` reports the cpu device."""
    expected_device = torch.device("cpu")
    comp = Composite(device="cpu")
    # Same container, two leaves: one created without a device, one with
    # an explicit matching device.
    for leaf_device in (None, "cpu"):
        comp["nontensor"] = NonTensor(device=leaf_device)
        assert comp["nontensor"].device == expected_device
3921+
39153922

39163923
@pytest.mark.skipif(not torch.cuda.is_available(), reason="not cuda device")
39173924
def test_device_ordinal():

test/test_transforms.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -13498,20 +13498,21 @@ def check_non_tensor_match(self, td):
1349813498

1349913499
class ToString(Transform):
1350013500
def _apply_transform(self, obs: torch.Tensor) -> None:
13501-
return NonTensorData(str(obs), device=obs.device)
13501+
return NonTensorData(str(obs), device=self.parent.device)
1350213502

1350313503
def _reset(
1350413504
self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
1350513505
) -> TensorDictBase:
13506-
return self._call(tensordict_reset)
13506+
reset_data = self._call(tensordict_reset)
13507+
return reset_data
1350713508

1350813509
def transform_observation_spec(
1350913510
self, observation_spec: TensorSpec
1351013511
) -> TensorSpec:
1351113512
observation_spec["obs_str"] = NonTensor(
1351213513
example_data="a string!",
1351313514
shape=observation_spec.shape,
13514-
device=observation_spec.device,
13515+
device=self.parent.device,
1351513516
)
1351613517
return observation_spec
1351713518

@@ -13545,7 +13546,8 @@ def test_single_trans_env_check(self, bwad):
1354513546
self.check_non_tensor_match(r)
1354613547

1354713548
@pytest.mark.parametrize("bwad", [False, True])
13548-
def test_serial_trans_env_check(self, bwad):
13549+
@pytest.mark.parametrize("device", [None])
13550+
def test_serial_trans_env_check(self, bwad, device):
1354913551
def make_env(i):
1355013552
env = TestConditionalSkip.CountinEnvWithString()
1355113553
base_env = TransformedEnv(
@@ -13561,7 +13563,9 @@ def make_env(i):
1356113563
auto_unwrap=False,
1356213564
)
1356313565

13564-
env = SerialEnv(2, [partial(make_env, i=0), partial(make_env, i=1)])
13566+
env = SerialEnv(
13567+
2, [partial(make_env, i=0), partial(make_env, i=1)], device=device
13568+
)
1356513569
env.check_env_specs()
1356613570
policy = lambda td: td.set("action", torch.ones((2, 1)))
1356713571
r = env.rollout(10, policy, break_when_any_done=bwad)
@@ -13571,7 +13575,8 @@ def make_env(i):
1357113575
self.check_non_tensor_match(r)
1357213576

1357313577
@pytest.mark.parametrize("bwad", [False, True])
13574-
def test_parallel_trans_env_check(self, bwad):
13578+
@pytest.mark.parametrize("device", [None])
13579+
def test_parallel_trans_env_check(self, bwad, device):
1357513580
def make_env(i):
1357613581
env = TestConditionalSkip.CountinEnvWithString()
1357713582
base_env = TransformedEnv(
@@ -13588,7 +13593,10 @@ def make_env(i):
1358813593
)
1358913594

1359013595
env = ParallelEnv(
13591-
2, [partial(make_env, i=0), partial(make_env, i=1)], mp_start_method=mp_ctx
13596+
2,
13597+
[partial(make_env, i=0), partial(make_env, i=1)],
13598+
mp_start_method=mp_ctx,
13599+
device=device,
1359213600
)
1359313601
try:
1359413602
env.check_env_specs()

0 commit comments

Comments
 (0)