Commit 426fbc4
feat: add remove_episodes utility
This commit introduces a remove_episodes function and CLI tool that removes specific episodes from a dataset and automatically updates all affected data, video, and metadata files. Removal is performed safely: if a failure occurs at any point during the process, all changes are rolled back and the original dataset is preserved. Optionally, the original dataset can also be backed up in case it is needed later.
1 parent f994feb commit 426fbc4
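
Based on the CLI defined in lerobot/scripts/remove_episodes.py below, a typical invocation looks like this (the repo id and episode list are illustrative):

    python lerobot/scripts/remove_episodes.py \
        --repo-id lerobot/pusht \
        --episodes "1-5,7" \
        --backup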

File tree: 3 files changed (+329 −1)

lerobot/common/datasets/lerobot_dataset.py (+1 −1)

@@ -600,7 +600,7 @@ def download_episodes(self, download_videos: bool = True) -> None:
         self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns)

     def get_episodes_file_paths(self) -> list[Path]:
-        episodes = self.episodes if self.episodes is not None else list(range(self.meta.total_episodes))
+        episodes = self.episodes if self.episodes is not None else list(self.meta.episodes)
         fpaths = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in episodes]
         if len(self.meta.video_keys) > 0:
             video_files = [
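
This one-line fix matters once episodes can be removed: the surviving episode indices are no longer guaranteed to be contiguous, so iterating over range(total_episodes) can reference files that no longer exist and miss ones that do. A minimal Python sketch of the difference, assuming meta.episodes is a dict keyed by episode index (the lengths are hypothetical):

    # Hypothetical state after removing episode 1 from a 3-episode dataset
    episodes_meta = {0: {"length": 10}, 2: {"length": 12}}
    total_episodes = 2

    list(range(total_episodes))  # [0, 1]: index 1 is gone and index 2 is missed
    list(episodes_meta)          # [0, 2]: the episodes that actually remain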

lerobot/scripts/remove_episodes.py (new file, +292)
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import shutil
import sys
import tempfile
import time
from copy import deepcopy
from pathlib import Path

from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
    EPISODES_PATH,
    EPISODES_STATS_PATH,
    INFO_PATH,
    TASKS_PATH,
    append_jsonlines,
    write_episode,
    write_episode_stats,
    write_info,
)
from lerobot.common.utils.utils import init_logging


def remove_episodes(
    dataset: LeRobotDataset,
    episodes_to_remove: list[int],
    backup: str | Path | bool = False,
) -> LeRobotDataset:
    """
    Removes specified episodes from a LeRobotDataset and updates all metadata and files accordingly.

    Args:
        dataset: The LeRobotDataset to modify
        episodes_to_remove: List of episode indices to remove
        backup: Controls backup behavior:
            - False: No backup is created
            - True: Create backup at default location next to dataset
            - str/Path: Create backup at the specified location

    Returns:
        Updated LeRobotDataset with specified episodes removed
    """
    if not episodes_to_remove:
        return dataset

    if not all(ep_idx in dataset.meta.episodes for ep_idx in episodes_to_remove):
        raise ValueError("Episodes to remove must be valid episode indices in the dataset")

    # Calculate the new metadata
    new_meta = deepcopy(dataset.meta)
    new_meta.info["total_episodes"] -= len(episodes_to_remove)
    new_meta.info["total_frames"] -= sum(
        dataset.meta.episodes[ep_idx]["length"] for ep_idx in episodes_to_remove
    )

    for ep_idx in episodes_to_remove:
        new_meta.episodes.pop(ep_idx)
        new_meta.episodes_stats.pop(ep_idx)
    new_meta.stats = aggregate_stats(list(new_meta.episodes_stats.values()))

    tasks = {task for ep in new_meta.episodes.values() if "tasks" in ep for task in ep["tasks"]}
    new_meta.tasks = {new_meta.get_task_index(task): task for task in tasks}
    new_meta.task_to_task_index = {task: idx for idx, task in new_meta.tasks.items()}
    new_meta.info["total_tasks"] = len(new_meta.tasks)

    new_meta.info["total_videos"] = (
        (new_meta.info["total_episodes"]) * len(dataset.meta.video_keys) if dataset.meta.video_keys else 0
    )

    if "splits" in new_meta.info:
        new_meta.info["splits"] = {"train": f"0:{new_meta.info['total_episodes']}"}

    # Now that the metadata is recalculated, we update the dataset files by
    # removing the files related to the specified episodes. We perform a safe
    # update such that if an error occurs, any changes are rolled back and the
    # dataset files are left in their original state. Optionally, a non-temporary
    # full backup can be made so that we also keep the dataset in its original state.
    if backup:
        backup_path = (
            Path(backup)
            if isinstance(backup, (str, Path))
            else dataset.root.parent / f"{dataset.root.name}_backup_{int(time.time())}"
        )
        _backup_folder(dataset.root, backup_path)

    _update_dataset_files(
        new_meta,
        episodes_to_remove,
    )

    updated_dataset = LeRobotDataset(
        repo_id=dataset.repo_id,
        root=dataset.root,
        episodes=None,  # Load all episodes
        image_transforms=dataset.image_transforms,
        delta_timestamps=dataset.delta_timestamps,
        tolerance_s=dataset.tolerance_s,
        revision=dataset.revision,
        download_videos=False,  # No need to download, we just saved them
        video_backend=dataset.video_backend,
    )

    return updated_dataset


def _move_file(src: Path, dest: Path) -> None:
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(src, dest)


def _update_dataset_files(new_meta: LeRobotDatasetMetadata, episodes_to_remove: list[int]):
    """Update dataset files.

    This function performs a safe update for dataset files. It moves modified or removed
    episode files to a temporary directory. Once all changes are made, the temporary
    directory is deleted. If an error occurs during the update, all changes are rolled
    back and the original dataset files are restored.

    Args:
        new_meta (LeRobotDatasetMetadata): Updated metadata object containing the new
            dataset state after removing episodes
        episodes_to_remove (list[int]): List of episode indices to remove from the dataset

    Raises:
        Exception: If any operation fails, rolls back all changes and re-raises the original exception
    """
    with tempfile.TemporaryDirectory(prefix="lerobot_backup_temp_") as backup_path:
        backup_dir = Path(backup_path)

        # Init empty containers s.t. they are guaranteed to exist in the except block
        metadata_files = {}
        rel_data_paths = []
        rel_video_paths = []

        try:
            # Step 1: Update metadata files
            metadata_files = {
                INFO_PATH: lambda: write_info(new_meta.info, new_meta.root),
                EPISODES_PATH: lambda: [
                    write_episode(ep, new_meta.root) for ep in new_meta.episodes.values()
                ],
                TASKS_PATH: lambda: [
                    append_jsonlines({"task_index": idx, "task": task}, new_meta.root / TASKS_PATH)
                    for idx, task in new_meta.tasks.items()
                ],
                EPISODES_STATS_PATH: lambda: [
                    write_episode_stats(idx, stats, new_meta.root)
                    for idx, stats in new_meta.episodes_stats.items()
                ],
            }
            for file_path, update_func in metadata_files.items():
                _move_file(new_meta.root / file_path, backup_dir / file_path)
                update_func()

            # Step 2: Update data and video
            rel_data_paths = [new_meta.get_data_file_path(ep_idx) for ep_idx in episodes_to_remove]
            rel_video_paths = [
                new_meta.get_video_file_path(ep_idx, vid_key)
                for ep_idx in episodes_to_remove
                for vid_key in new_meta.video_keys
            ]
            for rel_path in rel_data_paths + rel_video_paths:
                if (new_meta.root / rel_path).exists():
                    _move_file(new_meta.root / rel_path, backup_dir / rel_path)

        except Exception as e:
            logging.error(f"Error updating dataset files: {str(e)}. Rolling back changes.")

            # Restore metadata files
            for file_path in metadata_files:
                if (backup_dir / file_path).exists():
                    _move_file(backup_dir / file_path, new_meta.root / file_path)

            # Restore data and video files
            for rel_file_path in rel_data_paths + rel_video_paths:
                if (backup_dir / rel_file_path).exists():
                    _move_file(backup_dir / rel_file_path, new_meta.root / rel_file_path)

            raise e


def _backup_folder(target_dir: Path, backup_path: Path) -> None:
    if backup_path.resolve() == target_dir.resolve() or backup_path.resolve().is_relative_to(
        target_dir.resolve()
    ):
        raise ValueError(
            f"Backup directory '{backup_path}' cannot be inside the dataset "
            f"directory '{target_dir}' as this would cause infinite recursion"
        )

    backup_path.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f"Creating backup at: {backup_path}")
    shutil.copytree(target_dir, backup_path)


def _parse_episodes_list(episodes_str: str) -> list[int]:
    """
    Parse a string of episode indices, ranges, and comma-separated lists into a list of integers.
    """
    episodes = []
    for ep in episodes_str.split(","):
        if "-" in ep:
            start, end = ep.split("-")
            episodes.extend(range(int(start), int(end) + 1))
        else:
            episodes.append(int(ep))
    return episodes


def main():
    parser = argparse.ArgumentParser(description="Remove episodes from a LeRobot dataset")
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of the Hugging Face repository containing a LeRobotDataset (e.g. `lerobot/pusht`).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for the dataset stored locally. By default, the dataset is loaded from the Hugging Face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "-e",
        "--episodes",
        type=str,
        required=True,
        help="Episodes to remove. Can be a single index, comma-separated indices, or ranges (e.g., '1-5,7,10-12')",
    )
    parser.add_argument(
        "-b",
        "--backup",
        nargs="?",
        const=True,
        default=False,
        help="Create a backup before modifying the dataset. Without a value, creates a backup in the default location. "
        "With a value, either 'true'/'false' or a path to store the backup.",
    )
    args = parser.parse_args()

    # Parse the backup argument
    backup_value = args.backup
    if isinstance(backup_value, str):
        if backup_value.lower() == "true":
            backup_value = True
        elif backup_value.lower() == "false":
            backup_value = False
        # Otherwise, it's treated as a path

    # Parse episodes to remove
    episodes_to_remove = _parse_episodes_list(args.episodes)
    if not episodes_to_remove:
        logging.warning("No episodes specified to remove")
        sys.exit(0)

    # Load the dataset
    logging.info(f"Loading dataset '{args.repo_id}'...")
    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
    logging.info(f"Dataset has {dataset.meta.total_episodes} episodes")

    # Modify the dataset
    logging.info(f"Removing {len(set(episodes_to_remove))} episodes: {sorted(set(episodes_to_remove))}")
    updated_dataset = remove_episodes(
        dataset=dataset,
        episodes_to_remove=episodes_to_remove,
        backup=backup_value,
    )
    logging.info(
        f"Successfully removed episodes. Dataset now has {updated_dataset.meta.total_episodes} episodes."
    )


if __name__ == "__main__":
    init_logging()
    main()
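
For programmatic use, the function can also be called directly on a loaded dataset. A minimal sketch assuming the module path added in this commit (the repo id and episode indices are illustrative; note that the test below imports remove_episodes from lerobot.common.datasets.episode_utils instead):

    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
    from lerobot.scripts.remove_episodes import remove_episodes, _parse_episodes_list

    episodes = _parse_episodes_list("1-3,7")  # -> [1, 2, 3, 7]

    dataset = LeRobotDataset(repo_id="lerobot/pusht")  # illustrative repo id
    # backup=True copies the dataset to <root>_backup_<timestamp> before modifying it
    updated = remove_episodes(dataset, episodes_to_remove=episodes, backup=True)
    print(updated.meta.total_episodes)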

tests/test_datasets.py (+36)
@@ -19,6 +19,7 @@
 from copy import deepcopy
 from itertools import chain
 from pathlib import Path
+from unittest.mock import patch

 import numpy as np
 import pytest
@@ -28,6 +29,7 @@
 from safetensors.torch import load_file

 import lerobot
+from lerobot.common.datasets.episode_utils import remove_episodes
 from lerobot.common.datasets.factory import make_dataset
 from lerobot.common.datasets.image_writer import image_array_to_pil_image
 from lerobot.common.datasets.lerobot_dataset import (
@@ -580,3 +582,37 @@ def test_dataset_feature_with_forward_slash_raises_error():
         fps=30,
         features={"a/b": {"dtype": "float32", "shape": 2, "names": None}},
     )
+
+
+@pytest.mark.parametrize(
+    "total_episodes, total_frames, episodes_to_remove",
+    [
+        (3, 30, [1]),
+        (3, 30, [0, 2]),
+        (4, 50, [1, 2]),
+    ],
+)
+def test_remove_episodes(tmp_path, lerobot_dataset_factory, total_episodes, total_frames, episodes_to_remove):
+    dataset = lerobot_dataset_factory(
+        root=tmp_path / "test",
+        total_episodes=total_episodes,
+        total_frames=total_frames,
+    )
+    num_frames_to_remove = 0
+    for ep in episodes_to_remove:
+        num_frames_to_remove += (
+            dataset.episode_data_index["to"][ep].item() - dataset.episode_data_index["from"][ep].item()
+        )
+
+    with (
+        patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
+        patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
+    ):
+        mock_get_safe_version.side_effect = lambda repo_id, version: version
+        mock_snapshot_download.side_effect = lambda repo_id, **kwargs: str(dataset.root)
+        updated_dataset = remove_episodes(dataset, episodes_to_remove)
+
+    assert updated_dataset.meta.total_episodes == total_episodes - len(episodes_to_remove)
+    assert updated_dataset.meta.total_frames == total_frames - num_frames_to_remove
+    for i, ep_meta in enumerate(updated_dataset.meta.episodes.values()):
+        assert ep_meta["episode_index"] == i
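
The new parametrized test can be run on its own with pytest's keyword filter (standard pytest usage, not part of this commit):

    pytest tests/test_datasets.py -k test_remove_episodes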
