 It illustrates how to load datasets, manipulate them, and apply transformations suitable for machine learning tasks in PyTorch.

 Features included in this script:
-- Loading a dataset and accessing its properties.
-- Filtering data by episode number.
-- Converting tensor data for visualization.
-- Saving video files from dataset frames.
+- Viewing a dataset's metadata and exploring its properties.
+- Loading an existing dataset from the hub or a subset of it.
+- Accessing frames by episode number.
 - Using advanced dataset features like timestamp-based frame selection.
 - Demonstrating compatibility with PyTorch DataLoader for batch processing.

 The script ends with examples of how to batch process data using PyTorch's DataLoader.
 """

-from pathlib import Path
 from pprint import pprint

-import imageio
 import torch
+from huggingface_hub import HfApi

 import lerobot
-from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata

+# We ported a number of existing datasets ourselves; use this to see the list:
 print("List of available datasets:")
 pprint(lerobot.available_datasets)

-# Let's take one for this example
-repo_id = "lerobot/pusht"
-
-# You can easily load a dataset from a Hugging Face repository
+# You can also browse through the datasets created/ported by the community on the hub using the hub API:
+hub_api = HfApi()
+repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
+pprint(repo_ids)
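As a side note, repo_ids above is a plain Python list, so it can be narrowed with ordinary list operations. A minimal sketch (the "aloha" substring is just an illustrative filter, not something the script requires):

# Illustrative: keep only community datasets whose repo id mentions "aloha".
aloha_ids = [rid for rid in repo_ids if "aloha" in rid]
pprint(aloha_ids)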
+
+# Or simply explore them in your web browser directly at:
+# https://huggingface.co/datasets?other=LeRobot
+
+# Let's take this one for this example:
+repo_id = "lerobot/aloha_mobile_cabinet"
+# We can have a look and fetch its metadata to know more about it:
+ds_meta = LeRobotDatasetMetadata(repo_id)
+
+# By instantiating just this class, you can quickly access useful information about the content and the
+# structure of the dataset without downloading the actual data yet (only metadata files, which are
+# lightweight).
+print(f"Total number of episodes: {ds_meta.total_episodes}")
+print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
+print(f"Frames per second used during data collection: {ds_meta.fps}")
+print(f"Robot type: {ds_meta.robot_type}")
+print(f"Keys to access images from cameras: {ds_meta.camera_keys=}\n")
+
+print("Tasks:")
+print(ds_meta.tasks)
+print("Features:")
+pprint(ds_meta.features)
+
+# You can also get a short summary by simply printing the object:
+print(ds_meta)
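The metadata alone already supports quick back-of-the-envelope checks before any download. A minimal sketch using only the attributes printed above:

# Derive rough duration statistics purely from metadata (no data download needed).
total_seconds = ds_meta.total_frames / ds_meta.fps
print(f"Total footage: {total_seconds / 60:.1f} minutes")
print(f"Average episode length: {total_seconds / ds_meta.total_episodes:.1f} seconds")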
+
+# You can then load the actual dataset from the hub.
+# Either load any subset of episodes:
+dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
+
+# And see how many frames you have:
+print(f"Selected episodes: {dataset.episodes}")
+print(f"Number of episodes selected: {dataset.num_episodes}")
+print(f"Number of frames selected: {dataset.num_frames}")
+
+# Or simply load the entire dataset:
 dataset = LeRobotDataset(repo_id)
+print(f"Number of episodes selected: {dataset.num_episodes}")
+print(f"Number of frames selected: {dataset.num_frames}")
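Since a full load should expose every frame recorded in the metadata, a simple consistency check is possible (illustrative, relying only on attributes used above):

# Sanity check: a full load matches the totals reported by the metadata.
assert dataset.num_frames == dataset.meta.total_frames
assert dataset.num_episodes == dataset.meta.total_episodes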

-# LeRobotDataset is actually a thin wrapper around an underlying Hugging Face dataset
-# (see https://huggingface.co/docs/datasets/index for more information).
-print(dataset)
-print(dataset.hf_dataset)
+# The previous metadata class is contained in the 'meta' attribute of the dataset:
+print(dataset.meta)

-# And provides additional utilities for robotics and compatibility with Pytorch
-print(f"\naverage number of frames per episode: {dataset.num_samples / dataset.num_episodes:.3f}")
-print(f"frames per second used during data collection: {dataset.fps=}")
-print(f"keys to access images from cameras: {dataset.camera_keys=}\n")
+# LeRobotDataset actually wraps an underlying Hugging Face dataset
+# (see https://huggingface.co/docs/datasets for more information).
+print(dataset.hf_dataset)
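Because the wrapped object is a regular datasets.Dataset, the usual Hugging Face `datasets` accessors apply to it. A sketch, assuming the standard `datasets` API:

# Standard `datasets` accessors work on the underlying table.
print(dataset.hf_dataset.num_rows)
print(dataset.hf_dataset.column_names)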

-# Access frame indexes associated to first episode
+# LeRobot datasets also subclass PyTorch datasets, so you can do everything you know and love from working
+# with the latter, like iterating through the dataset.
+# __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
+# episodes, you can access the frame indices of any episode using episode_data_index. Here, we access
+# the frame indices associated with the first episode:
 episode_index = 0
 from_idx = dataset.episode_data_index["from"][episode_index].item()
 to_idx = dataset.episode_data_index["to"][episode_index].item()
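These bounds are enough to compute per-episode sizes directly; for example:

# The half-open interval [from_idx, to_idx) gives the episode length in frames.
print(f"Episode {episode_index} spans frames [{from_idx}, {to_idx}), i.e. {to_idx - from_idx} frames")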

-# LeRobot datasets actually subclass PyTorch datasets so you can do everything you know and love from working
-# with the latter, like iterating through the dataset. Here we grab all the image frames.
-frames = [dataset[idx]["observation.image"] for idx in range(from_idx, to_idx)]
+# Then we grab all the image frames from the first camera:
+camera_key = dataset.meta.camera_keys[0]
+frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]

-# Video frames are now float32 in range [0,1] channel first (c,h,w) to follow pytorch convention. To visualize
-# them, we convert to uint8 in range [0,255]
-frames = [(frame * 255).type(torch.uint8) for frame in frames]
-# and to channel last (h,w,c).
-frames = [frame.permute((1, 2, 0)).numpy() for frame in frames]
+# The objects returned by the dataset are all torch.Tensors
+print(type(frames[0]))
+print(frames[0].shape)

-# Finally, we save the frames to a mp4 video for visualization.
-Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
-imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", frames, fps=dataset.fps)
+# Since we're using pytorch, the shape follows the pytorch channel-first convention (c, h, w).
+# We can compare this shape with the information available for that feature:
+pprint(dataset.features[camera_key])
+# In particular:
+print(dataset.features[camera_key]["shape"])
+# The shape is in (h, w, c), which is a more universal format.
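If you still want the mp4 export that the removed lines provided, the same uint8/channel-last conversion applies to the frames fetched via camera_key. A sketch, assuming imageio is installed:

from pathlib import Path

import imageio

# Convert float32 (c, h, w) tensors in [0, 1] to uint8 (h, w, c) arrays for visualization.
video_frames = [(f * 255).type(torch.uint8).permute(1, 2, 0).numpy() for f in frames]
Path("outputs/examples/1_load_lerobot_dataset").mkdir(parents=True, exist_ok=True)
imageio.mimsave("outputs/examples/1_load_lerobot_dataset/episode_0.mp4", video_frames, fps=dataset.fps)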

 # For many machine learning applications we need to load the history of past observations or trajectories of
 # future actions. Our datasets can load previous and future frames for each key/modality, using timestamp
 # differences with the currently loaded frame. For instance:
 delta_timestamps = {
     # loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
-    "observation.image": [-1, -0.5, -0.20, 0],
-    # loads 8 state vectors: 1.5 seconds before, 1 second before, ... 20 ms, 10 ms, and current frame
-    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, -0.02, -0.01, 0],
+    camera_key: [-1, -0.5, -0.20, 0],
+    # loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
+    "observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
     # loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
     "action": [t / dataset.fps for t in range(64)],
 }
+# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that, added to
+# any timestamp, you still get a valid timestamp.
+
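That multiple-of-(1/fps) requirement can be checked up front. A minimal sketch with a small floating-point tolerance:

# Verify every delta sits on the frame grid, i.e. is a multiple of 1/fps.
for key, deltas in delta_timestamps.items():
    for t in deltas:
        n = t * dataset.fps
        assert abs(n - round(n)) < 1e-6, f"{key}: {t} is not a multiple of 1/fps"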
 dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
-print(f"\n{dataset[0]['observation.image'].shape=}")  # (4,c,h,w)
-print(f"{dataset[0]['observation.state'].shape=}")  # (8,c)
-print(f"{dataset[0]['action'].shape=}\n")  # (64,c)
+print(f"\n{dataset[0][camera_key].shape=}")  # (4, c, h, w)
+print(f"{dataset[0]['observation.state'].shape=}")  # (6, c)
+print(f"{dataset[0]['action'].shape=}\n")  # (64, c)

 # Finally, our datasets are fully compatible with PyTorch dataloaders and samplers because they are just
 # PyTorch datasets.
 dataloader = torch.utils.data.DataLoader(
     dataset,
     num_workers=0,
     batch_size=32,
     shuffle=True,
 )
+
 for batch in dataloader:
-    print(f"{batch['observation.image'].shape=}")  # (32,4,c,h,w)
-    print(f"{batch['observation.state'].shape=}")  # (32,8,c)
-    print(f"{batch['action'].shape=}")  # (32,64,c)
+    print(f"{batch[camera_key].shape=}")  # (32, 4, c, h, w)
+    print(f"{batch['observation.state'].shape=}")  # (32, 6, c)
+    print(f"{batch['action'].shape=}")  # (32, 64, c)
     break
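In an actual training loop, the same iteration pattern typically adds a device transfer. A sketch, assuming a CUDA device may or may not be present:

# Move each tensor modality of a batch onto the training device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for batch in dataloader:
    batch = {k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
    print(batch["action"].device)
    break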