Skip to content

Commit 5081c14

Browse files
committed
Add custom save and load methods for SAC policy
- Implement `_save_pretrained` method to handle TensorDict state saving - Add `_from_pretrained` class method for loading SAC policy from files - Create utility function `find_and_copy_params` to handle parameter copying
1 parent 25b88f3 commit 5081c14

File tree

1 file changed

+174
-1
lines changed

1 file changed

+174
-1
lines changed

lerobot/common/policies/sac/modeling_sac.py

+174-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
# TODO: (1) better device management
1919

2020
from copy import deepcopy
21-
from typing import Callable, Optional, Tuple
21+
from typing import Callable, Optional, Tuple, Union, Dict
22+
from pathlib import Path
2223

2324
import einops
2425
import numpy as np
@@ -142,6 +143,131 @@ def __init__(
142143
self.log_alpha = nn.Parameter(torch.tensor([0.0]))
143144
self.temperature = self.log_alpha.exp().item()
144145

146+
def _save_pretrained(self, save_directory):
147+
"""Custom save method to handle TensorDict properly"""
148+
import os
149+
import json
150+
from dataclasses import asdict
151+
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
152+
from safetensors.torch import save_file
153+
154+
# NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
155+
# implies one side effect: the __batch_size parameters are saved in the state_dict
156+
# __batch_size is torch.Size or safetensor save only torch.Tensor
157+
# so we need to filter them out before saving
158+
simplified_state_dict = {}
159+
160+
for name, param in self.named_parameters():
161+
simplified_state_dict[name] = param
162+
save_file(
163+
simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
164+
)
165+
166+
# Save config
167+
config_dict = asdict(self.config)
168+
with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
169+
json.dump(config_dict, f, indent=2)
170+
print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
171+
172+
@classmethod
173+
def _from_pretrained(
174+
cls,
175+
*,
176+
model_id: str,
177+
revision: Optional[str],
178+
cache_dir: Optional[Union[str, Path]],
179+
force_download: bool,
180+
proxies: Optional[Dict],
181+
resume_download: Optional[bool],
182+
local_files_only: bool,
183+
token: Optional[Union[str, bool]],
184+
map_location: str = "cpu",
185+
strict: bool = False,
186+
**model_kwargs,
187+
) -> "SACPolicy":
188+
"""Custom load method to handle loading SAC policy from saved files"""
189+
import os
190+
import json
191+
from pathlib import Path
192+
from huggingface_hub import hf_hub_download
193+
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
194+
from safetensors.torch import load_file
195+
from lerobot.common.policies.sac.configuration_sac import SACConfig
196+
197+
# Check if model_id is a local path or a hub model ID
198+
if os.path.isdir(model_id):
199+
model_path = Path(model_id)
200+
safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
201+
config_file = os.path.join(model_path, CONFIG_NAME)
202+
else:
203+
# Download the safetensors file from the hub
204+
safetensors_file = hf_hub_download(
205+
repo_id=model_id,
206+
filename=SAFETENSORS_SINGLE_FILE,
207+
revision=revision,
208+
cache_dir=cache_dir,
209+
force_download=force_download,
210+
proxies=proxies,
211+
resume_download=resume_download,
212+
token=token,
213+
local_files_only=local_files_only,
214+
)
215+
# Download the config file
216+
try:
217+
config_file = hf_hub_download(
218+
repo_id=model_id,
219+
filename=CONFIG_NAME,
220+
revision=revision,
221+
cache_dir=cache_dir,
222+
force_download=force_download,
223+
proxies=proxies,
224+
resume_download=resume_download,
225+
token=token,
226+
local_files_only=local_files_only,
227+
)
228+
except Exception:
229+
config_file = None
230+
231+
# Load or create config
232+
if config_file and os.path.exists(config_file):
233+
# Load config from file
234+
with open(config_file, "r") as f:
235+
config_dict = json.load(f)
236+
config = SACConfig(**config_dict)
237+
else:
238+
# Use the provided config or create a default one
239+
config = model_kwargs.get("config", SACConfig())
240+
241+
# Create a new instance with the loaded config
242+
model = cls(config=config)
243+
244+
# Load state dict from safetensors file
245+
if os.path.exists(safetensors_file):
246+
# Note: The load_file function returns a dict with the parameters, but __batch_size
247+
# is not loaded so we need to copy it from the model state_dict
248+
# Load the parameters only
249+
loaded_state_dict = load_file(safetensors_file, device=map_location)
250+
251+
# Copy batch size parameters
252+
find_and_copy_params(
253+
original_state_dict=model.state_dict(),
254+
loaded_state_dict=loaded_state_dict,
255+
pattern="__batch_size",
256+
match_type="endswith",
257+
)
258+
259+
# Copy normalization buffer parameters
260+
find_and_copy_params(
261+
original_state_dict=model.state_dict(),
262+
loaded_state_dict=loaded_state_dict,
263+
pattern="_orig_mod.output_normalization.buffer_action",
264+
match_type="contains",
265+
)
266+
267+
model.load_state_dict(loaded_state_dict, strict=False)
268+
269+
return model
270+
145271
def reset(self):
146272
"""Reset the policy"""
147273
pass
@@ -276,6 +402,9 @@ def compute_loss_actor(
276402

277403
actions_pi, log_probs, _ = self.actor(observations, observation_features)
278404

405+
# TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way
406+
actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
407+
279408
q_preds = self.critic_forward(
280409
observations,
281410
actions_pi,
@@ -334,6 +463,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
334463
return self.net(x)
335464

336465

466+
def find_and_copy_params(
467+
original_state_dict: dict[str, torch.Tensor],
468+
loaded_state_dict: dict[str, torch.Tensor],
469+
pattern: str,
470+
match_type: str = "contains",
471+
) -> list[str]:
472+
"""Find and copy parameters from original state dict to loaded state dict based on a pattern.
473+
474+
This function can search for keys in different ways based on the match_type:
475+
- "exact": The key must exactly match the pattern
476+
- "contains": The key must contain the pattern anywhere
477+
- "startswith": The key must start with the pattern
478+
- "endswith": The key must end with the pattern
479+
480+
Args:
481+
original_state_dict: The source state dictionary
482+
loaded_state_dict: The target state dictionary
483+
pattern: The pattern to search for in keys
484+
match_type: How to match the pattern (exact, contains, startswith, endswith)
485+
486+
Returns:
487+
list[str]: List of keys that were copied
488+
"""
489+
copied_keys = []
490+
491+
for key in original_state_dict:
492+
should_copy = False
493+
494+
if match_type == "exact":
495+
should_copy = key == pattern
496+
elif match_type == "contains":
497+
should_copy = pattern in key
498+
elif match_type == "startswith":
499+
should_copy = key.startswith(pattern)
500+
elif match_type == "endswith":
501+
should_copy = key.endswith(pattern)
502+
503+
if should_copy:
504+
loaded_state_dict[key] = original_state_dict[key]
505+
copied_keys.append(key)
506+
507+
return copied_keys
508+
509+
337510
class CriticHead(nn.Module):
338511
def __init__(
339512
self,

0 commit comments

Comments
 (0)