
Commit feac375

AdilZouitine and Cadene authored
Refactor the download and publication of the datasets and convert it into CLI script (huggingface#95)
Co-authored-by: Remi <re.cadene@gmail.com>
1 parent 97f2df3 commit feac375

15 files changed, +1419 -836 lines

download_and_upload_dataset.py (-779 lines)
This file was deleted.

lerobot/__init__.py (+7 lines)

@@ -61,6 +61,13 @@
     itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
 )
 
+# TODO(rcadene, aliberts, alexander-soare): Add real-world env with a gym API
+available_datasets_without_env = ["lerobot/umi_cup_in_the_wild"]
+
+available_datasets = list(
+    itertools.chain(*available_datasets_per_env.values(), available_datasets_without_env)
+)
+
 available_policies = [
     "act",
     "diffusion",
New file (+179 lines)

"""
This file contains all obsolete download scripts. They are centralized here to not have to load
useless dependencies when using datasets.
"""

import io
from pathlib import Path

import tqdm


def download_raw(root, dataset_id) -> Path:
    if "pusht" in dataset_id:
        return download_pusht(root=root, dataset_id=dataset_id)
    elif "xarm" in dataset_id:
        return download_xarm(root=root, dataset_id=dataset_id)
    elif "aloha" in dataset_id:
        return download_aloha(root=root, dataset_id=dataset_id)
    elif "umi" in dataset_id:
        return download_umi(root=root, dataset_id=dataset_id)
    else:
        raise ValueError(dataset_id)


def download_and_extract_zip(url: str, destination_folder: Path) -> bool:
    import zipfile

    import requests

    print(f"downloading from {url}")
    response = requests.get(url, stream=True)
    if response.status_code == 200:
        total_size = int(response.headers.get("content-length", 0))
        progress_bar = tqdm.tqdm(total=total_size, unit="B", unit_scale=True)

        zip_file = io.BytesIO()
        for chunk in response.iter_content(chunk_size=1024):
            if chunk:
                zip_file.write(chunk)
                progress_bar.update(len(chunk))

        progress_bar.close()

        zip_file.seek(0)

        with zipfile.ZipFile(zip_file, "r") as zip_ref:
            zip_ref.extractall(destination_folder)
        return True
    else:
        return False


def download_pusht(root: str, dataset_id: str = "pusht", fps: int = 10) -> Path:
    pusht_url = "https://diffusion-policy.cs.columbia.edu/data/training/pusht.zip"
    pusht_zarr = Path("pusht/pusht_cchi_v7_replay.zarr")

    root = Path(root)
    raw_dir: Path = root / f"{dataset_id}_raw"
    zarr_path: Path = (raw_dir / pusht_zarr).resolve()
    if not zarr_path.is_dir():
        raw_dir.mkdir(parents=True, exist_ok=True)
        download_and_extract_zip(pusht_url, raw_dir)
    return zarr_path


def download_xarm(root: str, dataset_id: str, fps: int = 15) -> Path:
    root = Path(root)
    raw_dir: Path = root / "xarm_datasets_raw"
    if not raw_dir.exists():
        import zipfile

        import gdown

        raw_dir.mkdir(parents=True, exist_ok=True)
        # from https://github.com/fyhMer/fowm/blob/main/scripts/download_datasets.py
        url = "https://drive.google.com/uc?id=1nhxpykGtPDhmQKm-_B8zBSywVRdgeVya"
        zip_path = raw_dir / "data.zip"
        gdown.download(url, str(zip_path), quiet=False)
        print("Extracting...")
        with zipfile.ZipFile(str(zip_path), "r") as zip_f:
            for member in zip_f.namelist():
                if member.startswith("data/xarm") and member.endswith(".pkl"):
                    print(member)
                    zip_f.extract(member=member)
        zip_path.unlink()

    dataset_path: Path = root / f"{dataset_id}"
    return dataset_path


def download_aloha(root: str, dataset_id: str) -> Path:
    folder_urls = {
        "aloha_sim_insertion_human": "https://drive.google.com/drive/folders/1RgyD0JgTX30H4IM5XZn8I3zSV_mr8pyF",
        "aloha_sim_insertion_scripted": "https://drive.google.com/drive/folders/1TsojQQSXtHEoGnqgJ3gmpPQR2DPLtS2N",
        "aloha_sim_transfer_cube_human": "https://drive.google.com/drive/folders/1sc-E4QYW7A0o23m1u2VWNGVq5smAsfCo",
        "aloha_sim_transfer_cube_scripted": "https://drive.google.com/drive/folders/1aRyoOhQwxhyt1J8XgEig4s6kzaw__LXj",
    }

    ep48_urls = {
        "aloha_sim_insertion_human": "https://drive.google.com/file/d/18Cudl6nikDtgRolea7je8iF_gGKzynOP/view?usp=drive_link",
        "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/1wfMSZ24oOh5KR_0aaP3Cnu_c4ZCveduB/view?usp=drive_link",
        "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/18smMymtr8tIxaNUQ61gW6dG50pt3MvGq/view?usp=drive_link",
        "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1pnGIOd-E4-rhz2P3VxpknMKRZCoKt6eI/view?usp=drive_link",
    }

    ep49_urls = {
        "aloha_sim_insertion_human": "https://drive.google.com/file/d/1C1kZYyROzs-PrLc0SkDgUgMi4-L3lauE/view?usp=drive_link",
        "aloha_sim_insertion_scripted": "https://drive.google.com/file/d/17EuCUWS6uCCr6yyNzpXdcdE-_TTNCKtf/view?usp=drive_link",
        "aloha_sim_transfer_cube_human": "https://drive.google.com/file/d/1Nk7l53d9sJoGDBKAOnNrExX5nLacATc6/view?usp=drive_link",
        "aloha_sim_transfer_cube_scripted": "https://drive.google.com/file/d/1GKReZHrXU73NMiC5zKCq_UtqPVtYq8eo/view?usp=drive_link",
    }
    num_episodes = {  # noqa: F841  # we keep this for reference
        "aloha_sim_insertion_human": 50,
        "aloha_sim_insertion_scripted": 50,
        "aloha_sim_transfer_cube_human": 50,
        "aloha_sim_transfer_cube_scripted": 50,
    }

    episode_len = {  # noqa: F841  # we keep this for reference
        "aloha_sim_insertion_human": 500,
        "aloha_sim_insertion_scripted": 400,
        "aloha_sim_transfer_cube_human": 400,
        "aloha_sim_transfer_cube_scripted": 400,
    }

    cameras = {  # noqa: F841  # we keep this for reference
        "aloha_sim_insertion_human": ["top"],
        "aloha_sim_insertion_scripted": ["top"],
        "aloha_sim_transfer_cube_human": ["top"],
        "aloha_sim_transfer_cube_scripted": ["top"],
    }
    root = Path(root)
    raw_dir: Path = root / f"{dataset_id}_raw"
    if not raw_dir.is_dir():
        import gdown

        assert dataset_id in folder_urls
        assert dataset_id in ep48_urls
        assert dataset_id in ep49_urls

        raw_dir.mkdir(parents=True, exist_ok=True)

        gdown.download_folder(folder_urls[dataset_id], output=str(raw_dir))

        # because of the 50 files limit per directory, two files episode 48 and 49 were missing
        gdown.download(ep48_urls[dataset_id], output=str(raw_dir / "episode_48.hdf5"), fuzzy=True)
        gdown.download(ep49_urls[dataset_id], output=str(raw_dir / "episode_49.hdf5"), fuzzy=True)
    return raw_dir


def download_umi(root: str, dataset_id: str) -> Path:
    url_cup_in_the_wild = "https://real.stanford.edu/umi/data/zarr_datasets/cup_in_the_wild.zarr.zip"
    cup_in_the_wild_zarr = Path("umi/cup_in_the_wild/cup_in_the_wild.zarr")

    root = Path(root)
    raw_dir: Path = root / f"{dataset_id}_raw"
    zarr_path: Path = (raw_dir / cup_in_the_wild_zarr).resolve()
    if not zarr_path.is_dir():
        raw_dir.mkdir(parents=True, exist_ok=True)
        download_and_extract_zip(url_cup_in_the_wild, zarr_path)
    return zarr_path


if __name__ == "__main__":
    root = "data"
    dataset_ids = [
        "pusht",
        "xarm_lift_medium",
        "xarm_lift_medium_replay",
        "xarm_push_medium",
        "xarm_push_medium_replay",
        "aloha_sim_insertion_human",
        "aloha_sim_insertion_scripted",
        "aloha_sim_transfer_cube_human",
        "aloha_sim_transfer_cube_scripted",
        "umi_cup_in_the_wild",
    ]
    for dataset_id in dataset_ids:
        download_raw(root=root, dataset_id=dataset_id)
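
The `__main__` block above downloads every dataset in one pass, but `download_raw` can also be called for a single dataset. A minimal sketch, assuming the functions above are in scope and using the same "data" root as the script:

from pathlib import Path

# One-off download of a single raw dataset; "pusht" is one of the ids handled by download_raw.
raw_path: Path = download_raw(root="data", dataset_id="pusht")
print(f"Raw PushT data available at: {raw_path}")
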
New file (+199 lines)

import re
from pathlib import Path

import h5py
import torch
import tqdm
from datasets import Dataset, Features, Image, Sequence, Value
from PIL import Image as PILImage

from lerobot.common.datasets.push_dataset_to_hub.utils import concatenate_episodes
from lerobot.common.datasets.utils import (
    hf_transform_to_torch,
)


class AlohaProcessor:
    """
    Process HDF5 files formatted like in: https://github.com/tonyzhaozh/act

    Attributes:
        folder_path (Path): Path to the directory containing HDF5 files.
        cameras (list[str]): List of camera identifiers to check in the files.
        fps (int): Frames per second used in timestamp calculations.

    Methods:
        is_valid() -> bool:
            Validates if each HDF5 file within the folder contains all required datasets.
        preprocess() -> tuple[dict, dict]:
            Processes the files and returns structured data suitable for further analysis.
        to_hf_dataset(data_dict: dict) -> Dataset:
            Converts processed data into a Hugging Face Dataset object.
    """

    def __init__(self, folder_path: Path, cameras: list[str] | None = None, fps: int | None = None):
        """
        Initializes the AlohaProcessor with a specified directory path containing HDF5 files,
        an optional list of cameras, and a frame rate.

        Args:
            folder_path (Path): The directory path where HDF5 files are stored.
            cameras (list[str] | None): Optional list of cameras to validate within the files. Defaults to ['top'] if None.
            fps (int): Frame rate for the datasets, used in time calculations. Default is 50.

        Examples:
            >>> processor = AlohaProcessor(Path("path_to_hdf5_directory"), ["camera1", "camera2"])
            >>> processor.is_valid()
            True
        """
        self.folder_path = folder_path
        if cameras is None:
            cameras = ["top"]
        self.cameras = cameras
        if fps is None:
            fps = 50
        self._fps = fps

    @property
    def fps(self) -> int:
        return self._fps

    def is_valid(self) -> bool:
        """
        Validates the HDF5 files in the specified folder to ensure they contain the required datasets
        for actions, positions, and images for each specified camera.

        Returns:
            bool: True if all files are valid HDF5 files with all required datasets, False otherwise.
        """
        hdf5_files: list[Path] = list(self.folder_path.glob("episode_*.hdf5"))
        if len(hdf5_files) == 0:
            return False
        try:
            hdf5_files = sorted(
                hdf5_files, key=lambda x: int(re.search(r"episode_(\d+).hdf5", x.name).group(1))
            )
        except AttributeError:
            # All file names must contain a numerical identifier matching 'episode_(\d+).hdf5'
            return False

        # Check if the sequence is consecutive, e.g. episode_0, episode_1, episode_2, etc.
        # If not, return False
        previous_number = None
        for file in hdf5_files:
            current_number = int(re.search(r"episode_(\d+).hdf5", file.name).group(1))
            if previous_number is not None and current_number - previous_number != 1:
                return False
            previous_number = current_number

        for file in hdf5_files:
            try:
                with h5py.File(file, "r") as file:
                    # Check for the expected datasets within the HDF5 file
                    required_datasets = ["/action", "/observations/qpos"]
                    # Add camera-specific image datasets to the required datasets
                    camera_datasets = [f"/observations/images/{cam}" for cam in self.cameras]
                    required_datasets.extend(camera_datasets)

                    if not all(dataset in file for dataset in required_datasets):
                        return False
            except OSError:
                return False
        return True

    def preprocess(self):
        """
        Collects episode data from each HDF5 file in the folder.

        Returns:
            tuple[dict, dict]: The concatenated episode data dictionary and the
            episode_data_index dictionary mapping each episode to its frame range.

        Raises:
            ValueError: If the folder does not contain a valid set of HDF5 files.
        """
        if not self.is_valid():
            raise ValueError("The HDF5 file is invalid or does not contain the required datasets.")

        hdf5_files = list(self.folder_path.glob("*.hdf5"))
        hdf5_files = sorted(hdf5_files, key=lambda x: int(re.search(r"episode_(\d+)", x.name).group(1)))
        ep_dicts = []
        episode_data_index = {"from": [], "to": []}

        id_from = 0

        for ep_path in tqdm.tqdm(hdf5_files):
            with h5py.File(ep_path, "r") as ep:
                ep_id = int(re.search(r"episode_(\d+)", ep_path.name).group(1))
                num_frames = ep["/action"].shape[0]

                # last step of demonstration is considered done
                done = torch.zeros(num_frames, dtype=torch.bool)
                done[-1] = True

                state = torch.from_numpy(ep["/observations/qpos"][:])
                action = torch.from_numpy(ep["/action"][:])

                ep_dict = {}

                for cam in self.cameras:
                    image = torch.from_numpy(ep[f"/observations/images/{cam}"][:])  # b h w c
                    ep_dict[f"observation.images.{cam}"] = [PILImage.fromarray(x.numpy()) for x in image]

                ep_dict.update(
                    {
                        "observation.state": state,
                        "action": action,
                        "episode_index": torch.tensor([ep_id] * num_frames),
                        "frame_index": torch.arange(0, num_frames, 1),
                        "timestamp": torch.arange(0, num_frames, 1) / self.fps,
                        # TODO(rcadene): compute reward and success
                        # "next.reward": reward,
                        "next.done": done,
                        # "next.success": success,
                    }
                )

                assert isinstance(ep_id, int)
                ep_dicts.append(ep_dict)

                episode_data_index["from"].append(id_from)
                episode_data_index["to"].append(id_from + num_frames)

            id_from += num_frames

        data_dict = concatenate_episodes(ep_dicts)
        return data_dict, episode_data_index

    def to_hf_dataset(self, data_dict) -> Dataset:
        """
        Converts a dictionary of data into a Hugging Face Dataset object.

        Args:
            data_dict (dict): A dictionary containing the data to be converted.

        Returns:
            Dataset: The converted Hugging Face Dataset object.
        """
        image_features = {f"observation.images.{cam}": Image() for cam in self.cameras}
        features = {
            "observation.state": Sequence(
                length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
            ),
            "action": Sequence(length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)),
            "episode_index": Value(dtype="int64", id=None),
            "frame_index": Value(dtype="int64", id=None),
            "timestamp": Value(dtype="float32", id=None),
            # "next.reward": Value(dtype="float32", id=None),
            "next.done": Value(dtype="bool", id=None),
            # "next.success": Value(dtype="bool", id=None),
            "index": Value(dtype="int64", id=None),
        }
        update_features = {**image_features, **features}
        features = Features(update_features)
        hf_dataset = Dataset.from_dict(data_dict, features=features)
        hf_dataset.set_transform(hf_transform_to_torch)

        return hf_dataset

    def cleanup(self):
        pass
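
Putting the processor together, a typical conversion pass validates the raw episode folder, preprocesses it into tensors, and builds the Hugging Face dataset. A minimal sketch, assuming a hypothetical raw-data folder and using only the methods defined above:

from pathlib import Path

# Hypothetical location of the raw ALOHA episodes (episode_0.hdf5, episode_1.hdf5, ...).
raw_dir = Path("data/aloha_sim_insertion_human_raw")

processor = AlohaProcessor(folder_path=raw_dir)  # defaults to cameras=["top"], fps=50
if not processor.is_valid():
    raise ValueError(f"{raw_dir} does not contain a valid, consecutive set of episode_*.hdf5 files")

data_dict, episode_data_index = processor.preprocess()
hf_dataset = processor.to_hf_dataset(data_dict)
print(hf_dataset)
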
