Skip to content

Commit 5ec9bc5

Browse files
Fixed VideoRecorder crash when passing fps (#2827)
1 parent 6e40548 commit 5ec9bc5

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

test/test_transforms.py

+9
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@
131131
from torchrl.envs.utils import check_env_specs, MarlGroupMapType, step_mdp
132132
from torchrl.modules import GRUModule, LSTMModule, MLP, ProbabilisticActor, TanhNormal
133133
from torchrl.modules.utils import get_primers_from_module
134+
from torchrl.record.recorder import VideoRecorder
134135

135136
if os.getenv("PYTORCH_TEST_FBCODE"):
136137
from pytorch.rl.test._utils_internal import ( # noqa
@@ -13978,6 +13979,14 @@ def test_transform_inverse(self):
1397813979
raise pytest.skip("Tested elsewhere")
1397913980

1398013981

13982+
class TestVideoRecorder:
13983+
# TODO: add more tests
13984+
def test_can_init_with_fps(self):
13985+
recorder = VideoRecorder(None, None, fps=30)
13986+
13987+
assert recorder is not None
13988+
13989+
1398113990
if __name__ == "__main__":
1398213991
args, unknown = argparse.ArgumentParser().parse_known_args()
1398313992
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/record/recorder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def __init__(
121121
video_kwargs = {}
122122
video_kwargs.update(kwargs)
123123
if fps is not None:
124-
self.video_kwargs["fps"] = fps
124+
video_kwargs["fps"] = fps
125125
self.video_kwargs = video_kwargs
126126
self.iter = 0
127127
self.skip = skip

0 commit comments

Comments
 (0)