#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
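"""Remove episodes from a LeRobotDataset.

This script deletes the data and video files of the selected episodes and updates the dataset
metadata (info, episodes, episode stats and tasks) accordingly. The update is performed safely:
if anything fails mid-way, the already-modified files are rolled back, and a full backup of the
original dataset can optionally be created beforehand.

Example invocation (a minimal sketch; the script path, repo id and episode indices below are
illustrative placeholders, not fixed by this file):

    python remove_episodes.py \
        --repo-id lerobot/pusht \
        --episodes "1-5,7,10-12" \
        --backup
"""
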
import argparse
import logging
import shutil
import sys
import tempfile
import time
from copy import deepcopy
from pathlib import Path

from lerobot.common.datasets.compute_stats import aggregate_stats
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import (
    EPISODES_PATH,
    EPISODES_STATS_PATH,
    INFO_PATH,
    TASKS_PATH,
    append_jsonlines,
    write_episode,
    write_episode_stats,
    write_info,
)
from lerobot.common.utils.utils import init_logging


def remove_episodes(
    dataset: LeRobotDataset,
    episodes_to_remove: list[int],
    backup: str | Path | bool = False,
) -> LeRobotDataset:
    """
    Removes specified episodes from a LeRobotDataset and updates all metadata and files accordingly.

    Args:
        dataset: The LeRobotDataset to modify
        episodes_to_remove: List of episode indices to remove
        backup: Controls backup behavior:
            - False: No backup is created
            - True: Create backup at default location next to dataset
            - str/Path: Create backup at the specified location

    Returns:
        Updated LeRobotDataset with specified episodes removed
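
    Example (an illustrative sketch; the repo id and episode indices are placeholders):
        dataset = LeRobotDataset("lerobot/pusht")
        dataset = remove_episodes(dataset, episodes_to_remove=[0, 2], backup=True)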
    """
    if not episodes_to_remove:
        return dataset

    if not all(ep_idx in dataset.meta.episodes for ep_idx in episodes_to_remove):
        raise ValueError("Episodes to remove must be valid episode indices in the dataset")

    # Calculate the new metadata
    new_meta = deepcopy(dataset.meta)
    new_meta.info["total_episodes"] -= len(episodes_to_remove)
    new_meta.info["total_frames"] -= sum(
        dataset.meta.episodes[ep_idx]["length"] for ep_idx in episodes_to_remove
    )

    for ep_idx in episodes_to_remove:
        new_meta.episodes.pop(ep_idx)
        new_meta.episodes_stats.pop(ep_idx)
    new_meta.stats = aggregate_stats(list(new_meta.episodes_stats.values()))

    tasks = {task for ep in new_meta.episodes.values() if "tasks" in ep for task in ep["tasks"]}
    new_meta.tasks = {new_meta.get_task_index(task): task for task in tasks}
    new_meta.task_to_task_index = {task: idx for idx, task in new_meta.tasks.items()}
    new_meta.info["total_tasks"] = len(new_meta.tasks)

    new_meta.info["total_videos"] = (
        new_meta.info["total_episodes"] * len(dataset.meta.video_keys) if dataset.meta.video_keys else 0
    )

    if "splits" in new_meta.info:
        new_meta.info["splits"] = {"train": f"0:{new_meta.info['total_episodes']}"}

    # Now that the metadata is recalculated, we update the dataset files by
    # removing the files related to the specified episodes. We perform a safe
    # update such that if an error occurs, any changes are rolled back and the
    # dataset files are left in their original state. Optionally, a full
    # non-temporary backup can also be made so that the original dataset is preserved.
    if backup:
        backup_path = (
            Path(backup)
            if isinstance(backup, (str, Path))
            else dataset.root.parent / f"{dataset.root.name}_backup_{int(time.time())}"
        )
        _backup_folder(dataset.root, backup_path)

    _update_dataset_files(
        new_meta,
        episodes_to_remove,
    )

    updated_dataset = LeRobotDataset(
        repo_id=dataset.repo_id,
        root=dataset.root,
        episodes=None,  # Load all episodes
        image_transforms=dataset.image_transforms,
        delta_timestamps=dataset.delta_timestamps,
        tolerance_s=dataset.tolerance_s,
        revision=dataset.revision,
        download_videos=False,  # No need to download, we just saved them
        video_backend=dataset.video_backend,
    )

    return updated_dataset


def _move_file(src: Path, dest: Path) -> None:
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.move(src, dest)

def _update_dataset_files(new_meta: LeRobotDatasetMetadata, episodes_to_remove: list[int]):
    """Update dataset files.

    This function performs a safe update for dataset files. It moves modified or removed
    episode files to a temporary directory. Once all changes are made, the temporary
    directory is deleted. If an error occurs during the update, all changes are rolled
    back and the original dataset files are restored.

    Args:
        new_meta (LeRobotDatasetMetadata): Updated metadata object containing the new
            dataset state after removing episodes
        episodes_to_remove (list[int]): List of episode indices to remove from the dataset

    Raises:
        Exception: If any operation fails, rolls back all changes and re-raises the original exception
    """
    with tempfile.TemporaryDirectory(prefix="lerobot_backup_temp_") as backup_path:
        backup_dir = Path(backup_path)

        # Init empty containers s.t. they are guaranteed to exist in the except block
        metadata_files = {}
        rel_data_paths = []
        rel_video_paths = []

        try:
            # Step 1: Update metadata files
            metadata_files = {
                INFO_PATH: lambda: write_info(new_meta.info, new_meta.root),
                EPISODES_PATH: lambda: [
                    write_episode(ep, new_meta.root) for ep in new_meta.episodes.values()
                ],
                TASKS_PATH: lambda: [
                    append_jsonlines({"task_index": idx, "task": task}, new_meta.root / TASKS_PATH)
                    for idx, task in new_meta.tasks.items()
                ],
                EPISODES_STATS_PATH: lambda: [
                    write_episode_stats(idx, stats, new_meta.root)
                    for idx, stats in new_meta.episodes_stats.items()
                ],
            }
            for file_path, update_func in metadata_files.items():
                _move_file(new_meta.root / file_path, backup_dir / file_path)
                update_func()

            # Step 2: Update data and video
            rel_data_paths = [new_meta.get_data_file_path(ep_idx) for ep_idx in episodes_to_remove]
            rel_video_paths = [
                new_meta.get_video_file_path(ep_idx, vid_key)
                for ep_idx in episodes_to_remove
                for vid_key in new_meta.video_keys
            ]
            for rel_path in rel_data_paths + rel_video_paths:
                if (new_meta.root / rel_path).exists():
                    _move_file(new_meta.root / rel_path, backup_dir / rel_path)

        except Exception as e:
            logging.error(f"Error updating dataset files: {e}. Rolling back changes.")

            # Restore metadata files
            for file_path in metadata_files:
                if (backup_dir / file_path).exists():
                    _move_file(backup_dir / file_path, new_meta.root / file_path)

            # Restore data and video files
            for rel_file_path in rel_data_paths + rel_video_paths:
                if (backup_dir / rel_file_path).exists():
                    _move_file(backup_dir / rel_file_path, new_meta.root / rel_file_path)

            raise e


def _backup_folder(target_dir: Path, backup_path: Path) -> None:
    if backup_path.resolve() == target_dir.resolve() or backup_path.resolve().is_relative_to(
        target_dir.resolve()
    ):
        raise ValueError(
            f"Backup directory '{backup_path}' cannot be inside the dataset "
            f"directory '{target_dir}' as this would cause infinite recursion"
        )

    backup_path.parent.mkdir(parents=True, exist_ok=True)
    logging.info(f"Creating backup at: {backup_path}")
    shutil.copytree(target_dir, backup_path)

def _parse_episodes_list(episodes_str: str) -> list[int]:
    """
    Parse a string of episode indices, ranges, and comma-separated lists into a list of integers.
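
    For example (an illustrative call):
        >>> _parse_episodes_list("1-5,7,10-12")
        [1, 2, 3, 4, 5, 7, 10, 11, 12]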
    """
    episodes = []
    for ep in episodes_str.split(","):
        if "-" in ep:
            start, end = ep.split("-")
            episodes.extend(range(int(start), int(end) + 1))
        else:
            episodes.append(int(ep))
    return episodes

def main():
    parser = argparse.ArgumentParser(description="Remove episodes from a LeRobot dataset")
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="Name of the Hugging Face repository containing a LeRobotDataset (e.g. `lerobot/pusht`).",
    )
    parser.add_argument(
        "--root",
        type=Path,
        default=None,
        help="Root directory for the dataset stored locally. By default, the dataset is loaded from the Hugging Face cache folder, or downloaded from the hub if available.",
    )
    parser.add_argument(
        "-e",
        "--episodes",
        type=str,
        required=True,
        help="Episodes to remove. Can be a single index, comma-separated indices, or ranges (e.g., '1-5,7,10-12').",
    )
    parser.add_argument(
        "-b",
        "--backup",
        nargs="?",
        const=True,
        default=False,
        help="Create a backup before modifying the dataset. Without a value, creates a backup in the default location. "
        "With a value, either 'true'/'false' or a path to store the backup.",
    )
    args = parser.parse_args()

    # Parse the backup argument
    backup_value = args.backup
    if isinstance(backup_value, str):
        if backup_value.lower() == "true":
            backup_value = True
        elif backup_value.lower() == "false":
            backup_value = False
        # Otherwise, it's treated as a path

    # Parse the episodes to remove, dropping duplicates so each episode is only removed once
    episodes_to_remove = sorted(set(_parse_episodes_list(args.episodes)))
    if not episodes_to_remove:
        logging.warning("No episodes specified to remove")
        sys.exit(0)

    # Load the dataset
    logging.info(f"Loading dataset '{args.repo_id}'...")
    dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
    logging.info(f"Dataset has {dataset.meta.total_episodes} episodes")

    # Modify the dataset
    logging.info(f"Removing {len(episodes_to_remove)} episodes: {episodes_to_remove}")
    updated_dataset = remove_episodes(
        dataset=dataset,
        episodes_to_remove=episodes_to_remove,
        backup=backup_value,
    )
    logging.info(
        f"Successfully removed episodes. Dataset now has {updated_dataset.meta.total_episodes} episodes."
    )


if __name__ == "__main__":
    init_logging()
    main()