Skip to content

Commit 074f0ac

Browse files
authoredMar 7, 2025··
Fix gpu nightly (#829)
1 parent 25c63cc commit 074f0ac

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed
 

‎tests/test_policies.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -252,10 +252,11 @@ def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name:
252252
key: ft for key, ft in features.items() if key not in policy_cfg.output_features
253253
}
254254
policy = policy_cls(policy_cfg)
255+
policy.to(policy_cfg.device)
255256
save_dir = tmp_path / f"test_save_and_load_pretrained_{policy_cls.__name__}"
256257
policy.save_pretrained(save_dir)
257-
policy_ = policy_cls.from_pretrained(save_dir, config=policy_cfg)
258-
assert all(torch.equal(p, p_) for p, p_ in zip(policy.parameters(), policy_.parameters(), strict=True))
258+
loaded_policy = policy_cls.from_pretrained(save_dir, config=policy_cfg)
259+
torch.testing.assert_close(list(policy.parameters()), list(loaded_policy.parameters()), rtol=0, atol=0)
259260

260261

261262
@pytest.mark.parametrize("insert_temporal_dim", [False, True])

0 commit comments

Comments
 (0)