Skip to content

Commit 6e40548

Browse files
committed
[BugFix] Fix PEnv device copies
ghstack-source-id: df39fd2e4cd72f24c645b0ac32b46ab3e8d847fc Pull Request resolved: #2840
1 parent ba8be9c commit 6e40548

File tree

2 files changed

+47
-35
lines changed

2 files changed

+47
-35
lines changed

test/test_env.py

+28
Original file line numberDiff line numberDiff line change
@@ -1692,6 +1692,34 @@ def test_parallel_env_device(
16921692
env_serial.close(raise_if_closed=False)
16931693
env0.close(raise_if_closed=False)
16941694

1695+
@pytest.mark.skipif(not _has_gym, reason="no gym")
@pytest.mark.parametrize("env_device", [None, "cpu"])
def test_parallel_env_device_vs_no_device(self, maybe_fork_ParallelEnv, env_device):
    """Check that a ParallelEnv produces identical rollouts whether its
    ``device`` is left as ``None`` or explicitly set to ``"cpu"``.

    Regression test for device-copy handling in batched envs: the two
    configurations are seeded identically, so their rollouts must match
    element-wise.
    """

    def make_env() -> GymEnv:
        env = GymEnv(PENDULUM_VERSIONED(), device=env_device)
        return env.append_transform(DoubleToFloat())

    # Reference rollout: ParallelEnv with no explicit device.
    parallel_env = maybe_fork_ParallelEnv(
        num_workers=1, create_env_fn=make_env, device=None
    )
    parallel_env.reset()
    parallel_env.set_seed(0)
    torch.manual_seed(0)

    parallel_rollout = parallel_env.rollout(max_steps=10)

    # Same rollout with ParallelEnv pinned to "cpu" — must be identical
    # to the reference given identical seeding.
    parallel_env = maybe_fork_ParallelEnv(
        num_workers=1, create_env_fn=make_env, device="cpu"
    )
    parallel_env.reset()
    parallel_env.set_seed(0)
    torch.manual_seed(0)

    parallel_rollout_cpu = parallel_env.rollout(max_steps=10)
    assert_allclose_td(parallel_rollout, parallel_rollout_cpu)
1722+
16951723
@pytest.mark.skipif(not _has_gym, reason="no gym")
16961724
@pytest.mark.flaky(reruns=3, reruns_delay=1)
16971725
@pytest.mark.parametrize(

torchrl/envs/batched_envs.py

+19-35
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,14 @@ def __init__(
379379

380380
is_spec_locked = EnvBase.is_spec_locked
381381

382+
def select_and_clone(self, name, tensor, selected_keys=None):
    """Filter-and-copy helper for ``named_apply`` over a shared tensordict.

    Args:
        name: the (possibly nested) key of ``tensor`` within the tensordict.
        tensor: the leaf tensor to copy.
        selected_keys: the set of keys to keep; defaults to
            ``self._selected_step_keys`` when ``None``.

    Returns:
        A copy of ``tensor`` — moved to ``self.device`` when a device is set
        and differs from the tensor's, otherwise a plain clone — or ``None``
        when ``name`` is not among the selected keys (which makes
        ``named_apply`` drop the entry).
    """
    if selected_keys is None:
        selected_keys = self._selected_step_keys
    if name not in selected_keys:
        # Dropped entry: named_apply filters out None results.
        return None
    target = self.device
    if target is not None and tensor.device != target:
        # Cross-device copy; non_blocking enables async transfer when safe.
        return tensor.to(target, non_blocking=self.non_blocking)
    return tensor.clone()
389+
382390
@property
383391
def non_blocking(self):
384392
nb = self._non_blocking
@@ -1072,12 +1080,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
10721080
selected_output_keys = self._selected_reset_keys_filt
10731081

10741082
# select + clone creates 2 tds, but we can create one only
1075-
def select_and_clone(name, tensor):
1076-
if name in selected_output_keys:
1077-
return tensor.clone()
1078-
10791083
out = self.shared_tensordict_parent.named_apply(
1080-
select_and_clone,
1084+
lambda *args: self.select_and_clone(
1085+
*args, selected_keys=selected_output_keys
1086+
),
10811087
nested_keys=True,
10821088
filter_empty=True,
10831089
)
@@ -1150,14 +1156,14 @@ def _step(
11501156
# will be modified in-place at further steps
11511157
device = self.device
11521158

1153-
def select_and_clone(name, tensor):
1154-
if name in self._selected_step_keys:
1155-
return tensor.clone()
1159+
selected_keys = self._selected_step_keys
11561160

11571161
if partial_steps is not None:
11581162
next_td = TensorDict.lazy_stack([next_td[i] for i in workers_range])
11591163
out = next_td.named_apply(
1160-
select_and_clone, nested_keys=True, filter_empty=True
1164+
lambda *args: self.select_and_clone(*args, selected_keys),
1165+
nested_keys=True,
1166+
filter_empty=True,
11611167
)
11621168
if out_tds is not None:
11631169
out.update(
@@ -2010,20 +2016,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
20102016
next_td = shared_tensordict_parent.get("next")
20112017
device = self.device
20122018

2013-
if next_td.device != device and device is not None:
2014-
2015-
def select_and_clone(name, tensor):
2016-
if name in self._selected_step_keys:
2017-
return tensor.to(device, non_blocking=self.non_blocking)
2018-
2019-
else:
2020-
2021-
def select_and_clone(name, tensor):
2022-
if name in self._selected_step_keys:
2023-
return tensor.clone()
2024-
20252019
out = next_td.named_apply(
2026-
select_and_clone,
2020+
self.select_and_clone,
20272021
nested_keys=True,
20282022
filter_empty=True,
20292023
device=device,
@@ -2203,20 +2197,10 @@ def tentative_update(val, other):
22032197
selected_output_keys = self._selected_reset_keys_filt
22042198
device = self.device
22052199

2206-
if self.shared_tensordict_parent.device != device and device is not None:
2207-
2208-
def select_and_clone(name, tensor):
2209-
if name in selected_output_keys:
2210-
return tensor.to(device, non_blocking=self.non_blocking)
2211-
2212-
else:
2213-
2214-
def select_and_clone(name, tensor):
2215-
if name in selected_output_keys:
2216-
return tensor.clone()
2217-
22182200
out = self.shared_tensordict_parent.named_apply(
2219-
select_and_clone,
2201+
lambda *args: self.select_and_clone(
2202+
*args, selected_keys=selected_output_keys
2203+
),
22202204
nested_keys=True,
22212205
filter_empty=True,
22222206
device=device,

0 commit comments

Comments
 (0)