Skip to content

Commit b104be0

Browse files
Merge branch 'main' into fix/lint_warnings
2 parents f9e4a1f + 074f0ac commit b104be0

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

examples/2_evaluate_pretrained_policy.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
# OR a path to a local outputs/train folder.
4545
# pretrained_policy_path = Path("outputs/train/example_pusht_diffusion")
4646

47-
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path, map_location=device)
47+
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
4848

4949
# Initialize evaluation environment to render two observation types:
5050
# an image of the scene and state/position of the agent. The environment

tests/test_policies.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,12 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
252252
policy_cfg.input_features = {
253253
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
254254
}
255-
policy = policy_cls(policy_cfg) # config.device = gpu
255+
policy = policy_cls(policy_cfg)
256+
policy.to(policy_cfg.device)
256257
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
257258
policy.save_pretrained(save_dir)
258-
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
259-
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
259+
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
260+
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
260261

261262

262263
@pytest.mark.parametrize("insert_temporal_dim", [False, True])

0 commit comments

Comments
 (0)