|
18 | 18 | # TODO: (1) better device management
|
19 | 19 |
|
20 | 20 | from copy import deepcopy
|
21 |
| -from typing import Callable, Optional, Tuple |
| 21 | +from typing import Callable, Optional, Tuple, Union, Dict |
| 22 | +from pathlib import Path |
22 | 23 |
|
23 | 24 | import einops
|
24 | 25 | import numpy as np
|
@@ -142,6 +143,131 @@ def __init__(
|
142 | 143 | self.log_alpha = nn.Parameter(torch.tensor([0.0]))
|
143 | 144 | self.temperature = self.log_alpha.exp().item()
|
144 | 145 |
|
| 146 | + def _save_pretrained(self, save_directory): |
| 147 | + """Custom save method to handle TensorDict properly""" |
| 148 | + import os |
| 149 | + import json |
| 150 | + from dataclasses import asdict |
| 151 | + from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME |
| 152 | + from safetensors.torch import save_file |
| 153 | + |
| 154 | + # NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap |
| 155 | + # implies one side effect: the __batch_size parameters are saved in the state_dict |
| 156 | + # __batch_size is torch.Size or safetensor save only torch.Tensor |
| 157 | + # so we need to filter them out before saving |
| 158 | + simplified_state_dict = {} |
| 159 | + |
| 160 | + for name, param in self.named_parameters(): |
| 161 | + simplified_state_dict[name] = param |
| 162 | + save_file( |
| 163 | + simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE) |
| 164 | + ) |
| 165 | + |
| 166 | + # Save config |
| 167 | + config_dict = asdict(self.config) |
| 168 | + with open(os.path.join(save_directory, CONFIG_NAME), "w") as f: |
| 169 | + json.dump(config_dict, f, indent=2) |
| 170 | + print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}") |
| 171 | + |
| 172 | + @classmethod |
| 173 | + def _from_pretrained( |
| 174 | + cls, |
| 175 | + *, |
| 176 | + model_id: str, |
| 177 | + revision: Optional[str], |
| 178 | + cache_dir: Optional[Union[str, Path]], |
| 179 | + force_download: bool, |
| 180 | + proxies: Optional[Dict], |
| 181 | + resume_download: Optional[bool], |
| 182 | + local_files_only: bool, |
| 183 | + token: Optional[Union[str, bool]], |
| 184 | + map_location: str = "cpu", |
| 185 | + strict: bool = False, |
| 186 | + **model_kwargs, |
| 187 | + ) -> "SACPolicy": |
| 188 | + """Custom load method to handle loading SAC policy from saved files""" |
| 189 | + import os |
| 190 | + import json |
| 191 | + from pathlib import Path |
| 192 | + from huggingface_hub import hf_hub_download |
| 193 | + from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME |
| 194 | + from safetensors.torch import load_file |
| 195 | + from lerobot.common.policies.sac.configuration_sac import SACConfig |
| 196 | + |
| 197 | + # Check if model_id is a local path or a hub model ID |
| 198 | + if os.path.isdir(model_id): |
| 199 | + model_path = Path(model_id) |
| 200 | + safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE) |
| 201 | + config_file = os.path.join(model_path, CONFIG_NAME) |
| 202 | + else: |
| 203 | + # Download the safetensors file from the hub |
| 204 | + safetensors_file = hf_hub_download( |
| 205 | + repo_id=model_id, |
| 206 | + filename=SAFETENSORS_SINGLE_FILE, |
| 207 | + revision=revision, |
| 208 | + cache_dir=cache_dir, |
| 209 | + force_download=force_download, |
| 210 | + proxies=proxies, |
| 211 | + resume_download=resume_download, |
| 212 | + token=token, |
| 213 | + local_files_only=local_files_only, |
| 214 | + ) |
| 215 | + # Download the config file |
| 216 | + try: |
| 217 | + config_file = hf_hub_download( |
| 218 | + repo_id=model_id, |
| 219 | + filename=CONFIG_NAME, |
| 220 | + revision=revision, |
| 221 | + cache_dir=cache_dir, |
| 222 | + force_download=force_download, |
| 223 | + proxies=proxies, |
| 224 | + resume_download=resume_download, |
| 225 | + token=token, |
| 226 | + local_files_only=local_files_only, |
| 227 | + ) |
| 228 | + except Exception: |
| 229 | + config_file = None |
| 230 | + |
| 231 | + # Load or create config |
| 232 | + if config_file and os.path.exists(config_file): |
| 233 | + # Load config from file |
| 234 | + with open(config_file, "r") as f: |
| 235 | + config_dict = json.load(f) |
| 236 | + config = SACConfig(**config_dict) |
| 237 | + else: |
| 238 | + # Use the provided config or create a default one |
| 239 | + config = model_kwargs.get("config", SACConfig()) |
| 240 | + |
| 241 | + # Create a new instance with the loaded config |
| 242 | + model = cls(config=config) |
| 243 | + |
| 244 | + # Load state dict from safetensors file |
| 245 | + if os.path.exists(safetensors_file): |
| 246 | + # Note: The load_file function returns a dict with the parameters, but __batch_size |
| 247 | + # is not loaded so we need to copy it from the model state_dict |
| 248 | + # Load the parameters only |
| 249 | + loaded_state_dict = load_file(safetensors_file, device=map_location) |
| 250 | + |
| 251 | + # Copy batch size parameters |
| 252 | + find_and_copy_params( |
| 253 | + original_state_dict=model.state_dict(), |
| 254 | + loaded_state_dict=loaded_state_dict, |
| 255 | + pattern="__batch_size", |
| 256 | + match_type="endswith", |
| 257 | + ) |
| 258 | + |
| 259 | + # Copy normalization buffer parameters |
| 260 | + find_and_copy_params( |
| 261 | + original_state_dict=model.state_dict(), |
| 262 | + loaded_state_dict=loaded_state_dict, |
| 263 | + pattern="_orig_mod.output_normalization.buffer_action", |
| 264 | + match_type="contains", |
| 265 | + ) |
| 266 | + |
| 267 | + model.load_state_dict(loaded_state_dict, strict=False) |
| 268 | + |
| 269 | + return model |
| 270 | + |
145 | 271 | def reset(self):
|
146 | 272 | """Reset the policy"""
|
147 | 273 | pass
|
@@ -276,6 +402,9 @@ def compute_loss_actor(
|
276 | 402 |
|
277 | 403 | actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
278 | 404 |
|
| 405 | + # TODO: (maractingi, azouitine) This is to slow, we should find a way to do this in a more efficient way |
| 406 | + actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"] |
| 407 | + |
279 | 408 | q_preds = self.critic_forward(
|
280 | 409 | observations,
|
281 | 410 | actions_pi,
|
@@ -334,6 +463,50 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
|
334 | 463 | return self.net(x)
|
335 | 464 |
|
336 | 465 |
|
| 466 | +def find_and_copy_params( |
| 467 | + original_state_dict: dict[str, torch.Tensor], |
| 468 | + loaded_state_dict: dict[str, torch.Tensor], |
| 469 | + pattern: str, |
| 470 | + match_type: str = "contains", |
| 471 | +) -> list[str]: |
| 472 | + """Find and copy parameters from original state dict to loaded state dict based on a pattern. |
| 473 | +
|
| 474 | + This function can search for keys in different ways based on the match_type: |
| 475 | + - "exact": The key must exactly match the pattern |
| 476 | + - "contains": The key must contain the pattern anywhere |
| 477 | + - "startswith": The key must start with the pattern |
| 478 | + - "endswith": The key must end with the pattern |
| 479 | +
|
| 480 | + Args: |
| 481 | + original_state_dict: The source state dictionary |
| 482 | + loaded_state_dict: The target state dictionary |
| 483 | + pattern: The pattern to search for in keys |
| 484 | + match_type: How to match the pattern (exact, contains, startswith, endswith) |
| 485 | +
|
| 486 | + Returns: |
| 487 | + list[str]: List of keys that were copied |
| 488 | + """ |
| 489 | + copied_keys = [] |
| 490 | + |
| 491 | + for key in original_state_dict: |
| 492 | + should_copy = False |
| 493 | + |
| 494 | + if match_type == "exact": |
| 495 | + should_copy = key == pattern |
| 496 | + elif match_type == "contains": |
| 497 | + should_copy = pattern in key |
| 498 | + elif match_type == "startswith": |
| 499 | + should_copy = key.startswith(pattern) |
| 500 | + elif match_type == "endswith": |
| 501 | + should_copy = key.endswith(pattern) |
| 502 | + |
| 503 | + if should_copy: |
| 504 | + loaded_state_dict[key] = original_state_dict[key] |
| 505 | + copied_keys.append(key) |
| 506 | + |
| 507 | + return copied_keys |
| 508 | + |
| 509 | + |
337 | 510 | class CriticHead(nn.Module):
|
338 | 511 | def __init__(
|
339 | 512 | self,
|
|
0 commit comments