@@ -379,6 +379,14 @@ def __init__(
379
379
380
380
is_spec_locked = EnvBase .is_spec_locked
381
381
382
+ def select_and_clone (self , name , tensor , selected_keys = None ):
383
+ if selected_keys is None :
384
+ selected_keys = self ._selected_step_keys
385
+ if name in selected_keys :
386
+ if self .device is not None and tensor .device != self .device :
387
+ return tensor .to (self .device , non_blocking = self .non_blocking )
388
+ return tensor .clone ()
389
+
382
390
@property
383
391
def non_blocking (self ):
384
392
nb = self ._non_blocking
@@ -1072,12 +1080,10 @@ def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
1072
1080
selected_output_keys = self ._selected_reset_keys_filt
1073
1081
1074
1082
# select + clone creates 2 tds, but we can create one only
1075
- def select_and_clone (name , tensor ):
1076
- if name in selected_output_keys :
1077
- return tensor .clone ()
1078
-
1079
1083
out = self .shared_tensordict_parent .named_apply (
1080
- select_and_clone ,
1084
+ lambda * args : self .select_and_clone (
1085
+ * args , selected_keys = selected_output_keys
1086
+ ),
1081
1087
nested_keys = True ,
1082
1088
filter_empty = True ,
1083
1089
)
@@ -1150,14 +1156,14 @@ def _step(
1150
1156
# will be modified in-place at further steps
1151
1157
device = self .device
1152
1158
1153
- def select_and_clone (name , tensor ):
1154
- if name in self ._selected_step_keys :
1155
- return tensor .clone ()
1159
+ selected_keys = self ._selected_step_keys
1156
1160
1157
1161
if partial_steps is not None :
1158
1162
next_td = TensorDict .lazy_stack ([next_td [i ] for i in workers_range ])
1159
1163
out = next_td .named_apply (
1160
- select_and_clone , nested_keys = True , filter_empty = True
1164
+ lambda * args : self .select_and_clone (* args , selected_keys ),
1165
+ nested_keys = True ,
1166
+ filter_empty = True ,
1161
1167
)
1162
1168
if out_tds is not None :
1163
1169
out .update (
@@ -2010,20 +2016,8 @@ def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
2010
2016
next_td = shared_tensordict_parent .get ("next" )
2011
2017
device = self .device
2012
2018
2013
- if next_td .device != device and device is not None :
2014
-
2015
- def select_and_clone (name , tensor ):
2016
- if name in self ._selected_step_keys :
2017
- return tensor .to (device , non_blocking = self .non_blocking )
2018
-
2019
- else :
2020
-
2021
- def select_and_clone (name , tensor ):
2022
- if name in self ._selected_step_keys :
2023
- return tensor .clone ()
2024
-
2025
2019
out = next_td .named_apply (
2026
- select_and_clone ,
2020
+ self . select_and_clone ,
2027
2021
nested_keys = True ,
2028
2022
filter_empty = True ,
2029
2023
device = device ,
@@ -2203,20 +2197,10 @@ def tentative_update(val, other):
2203
2197
selected_output_keys = self ._selected_reset_keys_filt
2204
2198
device = self .device
2205
2199
2206
- if self .shared_tensordict_parent .device != device and device is not None :
2207
-
2208
- def select_and_clone (name , tensor ):
2209
- if name in selected_output_keys :
2210
- return tensor .to (device , non_blocking = self .non_blocking )
2211
-
2212
- else :
2213
-
2214
- def select_and_clone (name , tensor ):
2215
- if name in selected_output_keys :
2216
- return tensor .clone ()
2217
-
2218
2200
out = self .shared_tensordict_parent .named_apply (
2219
- select_and_clone ,
2201
+ lambda * args : self .select_and_clone (
2202
+ * args , selected_keys = selected_output_keys
2203
+ ),
2220
2204
nested_keys = True ,
2221
2205
filter_empty = True ,
2222
2206
device = device ,
0 commit comments