diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index 851869a0f..c245345f4 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -50,7 +50,7 @@ jobs: uses: actions/checkout@v3 - name: Install poetry - run: pipx install poetry + run: pipx install "poetry<2.0.0" - name: Poetry check run: poetry check @@ -64,7 +64,7 @@ jobs: uses: actions/checkout@v3 - name: Install poetry - run: pipx install poetry + run: pipx install "poetry<2.0.0" - name: Install poetry-relax run: poetry self add poetry-relax diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 58eca3206..bec3b1d82 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,6 +17,7 @@ repos: rev: v3.19.0 hooks: - id: pyupgrade + exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py)$' - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.8.2 hooks: diff --git a/README.md b/README.md index 9331bdeca..849a14de5 100644 --- a/README.md +++ b/README.md @@ -68,7 +68,7 @@ ### Acknowledgment -- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). +- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io). - Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io). - Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM). - Thanks to Antonio Loquercio and Ashish Kumar for their early support. diff --git a/benchmarks/video/README.md b/benchmarks/video/README.md index 890c1142c..56cd1d1e2 100644 --- a/benchmarks/video/README.md +++ b/benchmarks/video/README.md @@ -21,7 +21,7 @@ How to decode videos? ## Variables **Image content & size** -We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution). +We don't expect the same optimal settings for a dataset of images from a simulation, or from the real world in an apartment, in a factory, or outdoors, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution). For these reasons, we run this benchmark on four representative datasets: - `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera. - `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera. @@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifies Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.
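To make the `-g` trade-off concrete, here is a minimal, hypothetical sketch (not the benchmark's exact command) of encoding a folder of extracted frames with a small keyframe interval via ffmpeg; the paths and parameter values are placeholders:

```python
import subprocess

# Encode frames named like the benchmark's "frame_000000.png" into an mp4.
# A small keyframe interval (-g 2) keeps random access cheap at the cost of a
# larger file; a movie-style encode would instead use a much larger -g.
subprocess.run(
    [
        "ffmpeg",
        "-framerate", "30",           # input frame rate (placeholder)
        "-i", "imgs/frame_%06d.png",  # input frame pattern
        "-vcodec", "libx264",
        "-g", "2",                    # keyframe every 2 frames
        "-crf", "30",                 # quality/size trade-off
        "-pix_fmt", "yuv420p",        # widely supported chroma subsampling
        "out.mp4",
    ],
    check=True,
)
```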
-Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario: +Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario: - `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`), However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded. @@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc **Average Structural Similarity Index Measure (higher is better)** `avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity. -One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes. -h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility: +One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular in web browsers, for visualization purposes. +h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility: - `yuv420p` is more widely supported across various platforms, including web browsers. - `yuv444p` offers higher color fidelity but might not be supported as broadly. @@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In - `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1. - `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.). -See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters. +See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters. Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark.
To name a few: - `torchaudio` diff --git a/benchmarks/video/run_video_benchmark.py b/benchmarks/video/run_video_benchmark.py index e90664872..21a143c22 100644 --- a/benchmarks/video/run_video_benchmark.py +++ b/benchmarks/video/run_video_benchmark.py @@ -32,7 +32,11 @@ import pandas as pd import PIL import torch -from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity +from skimage.metrics import ( + mean_squared_error, + peak_signal_noise_ratio, + structural_similarity, +) from tqdm import tqdm from lerobot.common.datasets.lerobot_dataset import LeRobotDataset @@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int: return total_size -def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor: +def load_original_frames( + imgs_dir: Path, timestamps: list[float], fps: int +) -> torch.Tensor: frames = [] for ts in timestamps: idx = int(ts * fps) @@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t def save_decoded_frames( - imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int + imgs_dir: Path, + save_dir: Path, + frames: torch.Tensor, + timestamps: list[float], + fps: int, ) -> None: if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps): return @@ -104,7 +114,10 @@ def save_decoded_frames( idx = int(ts * fps) frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy() PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png") - shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png") + shutil.copyfile( + imgs_dir / f"frame_{idx:06d}.png", + save_dir / f"frame_{idx:06d}_original.png", + ) def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: @@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: hf_dataset = dataset.hf_dataset.with_format(None) # We only save images from the first camera - img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")] + img_keys = [ + key for key in hf_dataset.features if key.startswith("observation.image") + ] imgs_dataset = hf_dataset.select_columns(img_keys[0]) for i, item in enumerate( - tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False) + tqdm( + imgs_dataset, + desc=f"saving {dataset.repo_id} first episode images", + leave=False, + ) ): img = item[img_keys[0]] img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100) @@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None: break -def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]: +def sample_timestamps( + timestamps_mode: str, ep_num_images: int, fps: int +) -> list[float]: # Start at 5 to allow for 2_frames_4_space and 6_frames idx = random.randint(5, ep_num_images - 1) match timestamps_mode: @@ -154,7 +175,9 @@ def decode_video_frames( backend: str, ) -> torch.Tensor: if backend in ["pyav", "video_reader"]: - return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend) + return decode_video_frames_torchvision( + video_path, timestamps, tolerance_s, backend + ) else: raise NotImplementedError(backend) @@ -181,7 +204,9 @@ def process_sample(sample: int): } with time_benchmark: - frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend) + frames = decode_video_frames( + 
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend + ) result["load_time_video_ms"] = time_benchmark.result_ms / num_frames with time_benchmark: @@ -190,12 +215,18 @@ def process_sample(sample: int): frames_np, original_frames_np = frames.numpy(), original_frames.numpy() for i in range(num_frames): - result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i])) + result["mse_values"].append( + mean_squared_error(original_frames_np[i], frames_np[i]) + ) result["psnr_values"].append( - peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0) + peak_signal_noise_ratio( + original_frames_np[i], frames_np[i], data_range=1.0 + ) ) result["ssim_values"].append( - structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0) + structural_similarity( + original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0 + ) ) if save_frames and sample == 0: @@ -215,7 +246,9 @@ def process_sample(sample: int): # As these samples are independent, we run them in parallel threads to speed up the benchmark. with ThreadPoolExecutor(max_workers=num_workers) as executor: futures = [executor.submit(process_sample, i) for i in range(num_samples)] - for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False): + for future in tqdm( + as_completed(futures), total=num_samples, desc="samples", leave=False + ): result = future.result() load_times_video_ms.append(result["load_time_video_ms"]) load_times_images_ms.append(result["load_time_images_ms"]) @@ -275,9 +308,13 @@ def benchmark_encoding_decoding( random.seed(seed) benchmark_table = [] for timestamps_mode in tqdm( - decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False + decoding_cfg["timestamps_modes"], + desc="decodings (timestamps_modes)", + leave=False, ): - for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False): + for backend in tqdm( + decoding_cfg["backends"], desc="decodings (backends)", leave=False + ): benchmark_row = benchmark_decoding( imgs_dir, video_path, @@ -355,14 +392,23 @@ def main( imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_") # We only use the first episode save_first_episode(imgs_dir, dataset) - for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False): + for key, values in tqdm( + encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False + ): for value in tqdm(values, desc=f"encodings ({key})", leave=False): encoding_cfg = BASE_ENCODING.copy() encoding_cfg["vcodec"] = video_codec encoding_cfg["pix_fmt"] = pixel_format encoding_cfg[key] = value - args_path = Path("_".join(str(value) for value in encoding_cfg.values())) - video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4" + args_path = Path( + "_".join(str(value) for value in encoding_cfg.values()) + ) + video_path = ( + output_dir + / "videos" + / args_path + / f"{repo_id.replace('/', '_')}.mp4" + ) benchmark_table += benchmark_encoding_decoding( dataset, video_path, @@ -388,7 +434,9 @@ def main( # Concatenate all results df_list = [pd.read_csv(csv_path) for csv_path in file_paths] concatenated_df = pd.concat(df_list, ignore_index=True) - concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv" + concatenated_path = ( + output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv" + ) concatenated_df.to_csv(concatenated_path, header=True, index=False) diff 
--git a/checkport.py b/checkport.py new file mode 100644 index 000000000..7f79af6ff --- /dev/null +++ b/checkport.py @@ -0,0 +1,18 @@ +import socket + + +def check_port(host, port): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + try: + s.connect((host, port)) + print(f"Connection successful to {host}:{port}!") + except Exception as e: + print(f"Connection failed to {host}:{port}: {e}") + finally: + s.close() + + +if __name__ == "__main__": + host = "127.0.0.1" # or "localhost" + port = 51350 + check_port(host, port) diff --git a/docker/lerobot-gpu-mani-skill/Dockerfile b/docker/lerobot-gpu-mani-skill/Dockerfile new file mode 100644 index 000000000..e45d84e82 --- /dev/null +++ b/docker/lerobot-gpu-mani-skill/Dockerfile @@ -0,0 +1,11 @@ +FROM huggingface/lerobot-gpu:latest + +RUN apt-get update && apt-get install -y --no-install-recommends \ + libvulkan1 vulkan-tools \ + && apt-get clean && rm -rf /var/lib/apt/lists/* + +RUN pip install --upgrade --no-cache-dir pip +RUN pip install --no-cache-dir ".[mani-skill]" + +# Set EGL as the rendering backend for MuJoCo +ENV MUJOCO_GL="egl" diff --git a/examples/10_use_so100.md b/examples/10_use_so100.md index 70e4ed8ba..b247f9804 100644 --- a/examples/10_use_so100.md +++ b/examples/10_use_so100.md @@ -1,25 +1,31 @@ -This tutorial explains how to use [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot. +# Using the [SO-100](https://github.com/TheRobotStudio/SO-ARM100) with LeRobot -## Source the parts + +## A. Source the parts Follow this [README](https://github.com/TheRobotStudio/SO-ARM100). It contains the bill of materials, with link to source the parts, as well as the instructions to 3D print the parts, and advices if it's your first time printing or if you don't own a 3D printer already. **Important**: Before assembling, you will first need to configure your motors. To this end, we provide a nice script, so let's first install LeRobot. After configuration, we will also guide you through assembly. -## Install LeRobot +## B. Install LeRobot On your computer: 1. [Install Miniconda](https://docs.anaconda.com/miniconda/#quick-command-line-install): ```bash mkdir -p ~/miniconda3 +# Linux: wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda3/miniconda.sh +# Mac M-series: +# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh +# Mac Intel: +# curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-x86_64.sh -o ~/miniconda3/miniconda.sh bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3 rm ~/miniconda3/miniconda.sh ~/miniconda3/bin/conda init bash ``` -2. Restart shell or `source ~/.bashrc` +2. Restart shell or `source ~/.bashrc` (*Mac*: `source ~/.bash_profile`) or `source ~/.zshrc` if you're using zshell 3. Create and activate a fresh conda environment for lerobot ```bash @@ -36,23 +42,30 @@ git clone https://github.com/huggingface/lerobot.git ~/lerobot cd ~/lerobot && pip install -e ".[feetech]" ``` -For Linux only (not Mac), install extra dependencies for recording datasets: +*For Linux only (not Mac)*: install extra dependencies for recording datasets: ```bash conda install -y -c conda-forge ffmpeg pip uninstall -y opencv-python conda install -y -c conda-forge "opencv>=4.10.0" ``` -## Configure the motors +## C. Configure the motors + +### 1. 
Find the USB ports associated with each arm + +Designate one bus servo adapter and 6 motors for your leader arm, and similarly the other bus servo adapter and 6 motors for the follower arm. -Follow steps 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the use of our scripts below. #### a. Run the script to find ports -**Find USB ports associated to your arms** -To find the correct ports for each arm, run the utility script twice: +Follow Step 1 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I), which illustrates the use of our scripts below. + +To find the port for each bus servo adapter, run the utility script: ```bash python lerobot/scripts/find_motors_bus_port.py ``` +#### b. Example outputs + Example output when identifying the leader arm's port (e.g., `/dev/tty.usbmodem575E0031751` on Mac, or possibly `/dev/ttyACM0` on Linux): ``` Finding all available ports for the MotorBus. ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] Remove the usb cable from your DynamixelMotorsBus and press Enter when done. The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0031751 Reconnect the usb cable. ``` - Example output when identifying the follower arm's port (e.g., `/dev/tty.usbmodem575E0032081`, or possibly `/dev/ttyACM1` on Linux): ``` Finding all available ports for the MotorBus. ['/dev/tty.usbmodem575E0032081', '/dev/tty.usbmodem575E0031751'] Remove the usb cable from your DynamixelMotorsBus and press Enter when done. The port of this DynamixelMotorsBus is /dev/tty.usbmodem575E0032081 Reconnect the usb cable. ``` -Troubleshooting: On Linux, you might need to give access to the USB ports by running: +#### c. Troubleshooting +On Linux, you might need to give access to the USB ports by running: ```bash sudo chmod 666 /dev/ttyACM0 sudo chmod 666 /dev/ttyACM1 ``` -**Configure your motors** +#### d. Update YAML file + +Now that you have the ports, modify the *port* sections in `so100.yaml`. + +### 2. Configure the motors + +#### a. Set IDs for all 12 motors Plug your first motor and run this script to set its ID to 1. It will also set its present position to 2048, so expect your motor to rotate: ```bash python lerobot/scripts/configure_motor.py \ --port /dev/tty.usbmodem58760432961 \ --brand feetech \ --model sts3215 \ --baudrate 1000000 \ --ID 1 ``` @@ -94,7 +113,7 @@ -Note: These motors are currently limitated. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is at the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees). +*Note: These motors are currently limited. They can take values between 0 and 4096 only, which corresponds to a full turn. They can't turn more than that. 2048 is in the middle of this range, so we can take -2048 steps (180 degrees anticlockwise) and reach the maximum range, or take +2048 steps (180 degrees clockwise) and reach the maximum range. The configuration step also sets the homing offset to 0, so that if you misassembled the arm, you can always update the homing offset to account for a shift up to ± 2048 steps (± 180 degrees).* Then unplug your motor and plug the second motor and set its ID to 2. ```bash python lerobot/scripts/configure_motor.py \ @@ -108,23 +127,25 @@ Redo the process for all your motors until ID 6. Do the same for the 6 motors of the leader arm.
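Since the same command is repeated for each ID, a small hypothetical helper can cut down on retyping. This sketch simply loops over the `configure_motor.py` invocation shown above, one motor plugged in at a time (the port is a placeholder; use the one you found earlier):

```python
import subprocess

PORT = "/dev/tty.usbmodem58760432961"  # placeholder: use your own port

# Set IDs 1 to 6, plugging in a single motor for each iteration.
for motor_id in range(1, 7):
    input(f"Plug in only the motor for ID {motor_id}, then press Enter...")
    subprocess.run(
        [
            "python", "lerobot/scripts/configure_motor.py",
            "--port", PORT,
            "--brand", "feetech",
            "--model", "sts3215",
            "--baudrate", "1000000",
            "--ID", str(motor_id),
        ],
        check=True,
    )
```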
-**Remove the gears of the 6 leader motors** -Follow step 2 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. -**Add motor horn to the motors** -Follow step 3 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30. +#### b. Remove the gears of the 6 leader motors + +Follow step 2 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=248). You need to remove the gear for the motors of the leader arm. As a result, you will only use the position encoding of the motor and reduce friction to more easily operate the leader arm. + +#### c. Add motor horn to all 12 motors +Follow step 3 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=569). For SO-100, you need to align the holes on the motor horn to the motor spline to be approximately 1:30, 4:30, 7:30 and 10:30. Try to avoid rotating the motor while doing so to keep position 2048 set during configuration. It is especially tricky for the leader motors as it is more sensible without the gears, but it's ok if it's a bit rotated. -## Assemble the arms +## D. Assemble the arms -Follow step 4 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I). The first arm should take a bit more than 1 hour to assemble, but once you get use to it, you can do it under 1 hour for the second arm. +Follow step 4 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=610). The first arm should take a bit more than 1 hour to assemble, but once you get used to it, you can do it in under 1 hour for the second arm. -## Calibrate +## E. Calibrate Next, you'll need to calibrate your SO-100 robot to ensure that the leader and follower arms have the same position values when they are in the same physical position. This calibration is essential because it allows a neural network trained on one SO-100 robot to work on another. -**Manual calibration of follower arm** -/!\ Contrarily to step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the auto calibration, we will actually do manual calibration of follower for now. +#### a. Manual calibration of follower arm +/!\ Contrary to step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724), which illustrates the auto calibration, we will do manual calibration of the follower arm for now. You will need to move the follower arm to these positions sequentially: | 1. Zero position | 2. Rotated position | 3. Rest position | |---|---|---| @@ -139,8 +160,8 @@ python lerobot/scripts/control_robot.py calibrate \ --robot-overrides '~cameras' --arms main_follower ``` -**Manual calibration of leader arm** -Follow step 6 of the [assembly video](https://www.youtube.com/watch?v=FioA2oeFZ5I) which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially: +#### b. Manual calibration of leader arm +Follow step 6 of the [assembly video](https://youtu.be/FioA2oeFZ5I?t=724), which illustrates the manual calibration. You will need to move the leader arm to these positions sequentially: | 1. Zero position | 2. Rotated position | 3. Rest position | |---|---|---| @@ -153,7 +174,7 @@ python lerobot/scripts/control_robot.py calibrate \ --robot-overrides '~cameras' --arms main_leader ``` -## Teleoperate +## F.
Teleoperate **Simple teleop** Then you are ready to teleoperate your robot! Run this simple script (it won't connect and display the cameras): @@ -165,14 +186,14 @@ python lerobot/scripts/control_robot.py teleoperate \ ``` -**Teleop with displaying cameras** +#### a. Teleop with displaying cameras Follow [this guide to setup your cameras](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#c-add-your-cameras-with-opencvcamera). Then you will be able to display the cameras on your computer while you are teleoperating by running the following code. This is useful to prepare your setup before recording your first dataset. ```bash python lerobot/scripts/control_robot.py teleoperate \ --robot-path lerobot/configs/robot/so100.yaml ``` -## Record a dataset +## G. Record a dataset Once you're familiar with teleoperation, you can record your first dataset with SO-100. @@ -201,7 +222,7 @@ python lerobot/scripts/control_robot.py record \ --push-to-hub 1 ``` -## Visualize a dataset +## H. Visualize a dataset If you uploaded your dataset to the hub with `--push-to-hub 1`, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id given by: ```bash @@ -214,7 +235,7 @@ python lerobot/scripts/visualize_dataset_html.py \ --repo-id ${HF_USER}/so100_test ``` -## Replay an episode +## I. Replay an episode Now try to replay the first episode on your robot: ```bash @@ -225,7 +246,7 @@ python lerobot/scripts/control_robot.py replay \ --episode 0 ``` -## Train a policy +## J. Train a policy To train a policy to control your robot, use the [`python lerobot/scripts/train.py`](../lerobot/scripts/train.py) script. A few arguments are required. Here is an example command: ```bash @@ -248,7 +269,7 @@ Let's explain it: Training should take several hours. You will find checkpoints in `outputs/train/act_so100_test/checkpoints`. -## Evaluate your policy +## K. Evaluate your policy You can use the `record` function from [`lerobot/scripts/control_robot.py`](../lerobot/scripts/control_robot.py) but with a policy checkpoint as input. For instance, run this command to record 10 evaluation episodes: ```bash @@ -268,7 +289,7 @@ As you can see, it's almost the same command as previously used to record your t 1. There is an additional `-p` argument which indicates the path to your policy checkpoint with (e.g. `-p outputs/train/eval_so100_test/checkpoints/last/pretrained_model`). You can also use the model repository if you uploaded a model checkpoint to the hub (e.g. `-p ${HF_USER}/act_so100_test`). 2. The name of dataset begins by `eval` to reflect that you are running inference (e.g. `--repo-id ${HF_USER}/eval_act_so100_test`). -## More +## L. More Information Follow this [previous tutorial](https://github.com/huggingface/lerobot/blob/main/examples/7_get_started_with_real_robot.md#4-train-a-policy-on-your-data) for a more in-depth tutorial on controlling real robots with LeRobot. diff --git a/examples/12_train_hilserl_classifier.md b/examples/12_train_hilserl_classifier.md index eeaf0f2bc..9f7ccf814 100644 --- a/examples/12_train_hilserl_classifier.md +++ b/examples/12_train_hilserl_classifier.md @@ -81,3 +81,14 @@ You can also log sample predictions during evaluation. Each logged sample will i - The **classifier's "confidence" (logits/probability)**. These logs can be useful for diagnosing and debugging performance issues. 
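For intuition, the "confidence" above is simply the probability derived from the classifier's raw logits. A minimal sketch with illustrative values (not the trainer's actual API):

```python
import torch

# Fake logits for three evaluated samples (illustrative values only).
logits = torch.tensor([2.9, -0.4, 1.2])

# For a binary success/failure classifier, confidence is the sigmoid of the logit.
probs = torch.sigmoid(logits)

for logit, prob in zip(logits, probs):
    print(f"logit={logit.item():+.2f} -> confidence={prob.item():.3f}")
```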
+ + +#### Generate protobuf files + +```bash +python -m grpc_tools.protoc \ + -I lerobot/scripts/server \ + --python_out=lerobot/scripts/server \ + --grpc_python_out=lerobot/scripts/server \ + lerobot/scripts/server/hilserl.proto +``` diff --git a/examples/1_load_lerobot_dataset.py b/examples/1_load_lerobot_dataset.py index 96c104b68..1eddbf4b3 100644 --- a/examples/1_load_lerobot_dataset.py +++ b/examples/1_load_lerobot_dataset.py @@ -18,7 +18,10 @@ from huggingface_hub import HfApi import lerobot -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.common.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, +) # We ported a number of existing datasets ourselves, use this to see the list: print("List of available datasets:") @@ -26,7 +29,10 @@ # You can also browse through the datasets created/ported by the community on the hub using the hub api: hub_api = HfApi() -repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])] +repo_ids = [ + info.id + for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"]) +] pprint(repo_ids) # Or simply explore them in your web browser directly at: @@ -41,7 +47,9 @@ # structure of the dataset without downloading the actual data yet (only metadata files — which are # lightweight). print(f"Total number of episodes: {ds_meta.total_episodes}") -print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}") +print( + f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}" +) print(f"Frames per second used during data collection: {ds_meta.fps}") print(f"Robot type: {ds_meta.robot_type}") print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n") diff --git a/examples/2_evaluate_pretrained_policy.py b/examples/2_evaluate_pretrained_policy.py index b2fe1dba1..85a501295 100644 --- a/examples/2_evaluate_pretrained_policy.py +++ b/examples/2_evaluate_pretrained_policy.py @@ -32,7 +32,9 @@ print("GPU is available. Device set to:", device) else: device = torch.device("cpu") - print(f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU.") + print( + f"GPU is not available. Device set to: {device}. Inference will be slower than on GPU." + ) # Decrease the number of reverse-diffusion steps (trades off a bit of quality for 10x speed) policy.diffusion.num_inference_steps = 10 diff --git a/examples/3_train_policy.py b/examples/3_train_policy.py index 935ab2dbf..821e2bfe7 100644 --- a/examples/3_train_policy.py +++ b/examples/3_train_policy.py @@ -31,7 +31,24 @@ # Load the previous action (-0.1), the next action to be executed (0.0), # and 14 future actions with a 0.1 seconds spacing. All these actions will be # used to supervise the policy. 
- "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], + "action": [ + -0.1, + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + 1.1, + 1.2, + 1.3, + 1.4, + ], } dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps) diff --git a/examples/6_add_image_transforms.py b/examples/6_add_image_transforms.py index 882710e3d..43024ac23 100644 --- a/examples/6_add_image_transforms.py +++ b/examples/6_add_image_transforms.py @@ -34,10 +34,14 @@ ) # Create another LeRobotDataset with the defined transformations -transformed_dataset = LeRobotDataset(dataset_repo_id, episodes=[0], image_transforms=transforms) +transformed_dataset = LeRobotDataset( + dataset_repo_id, episodes=[0], image_transforms=transforms +) # Get a frame from the transformed dataset -transformed_frame = transformed_dataset[first_idx][transformed_dataset.meta.camera_keys[0]] +transformed_frame = transformed_dataset[first_idx][ + transformed_dataset.meta.camera_keys[0] +] # Create a directory to store output images output_dir = Path("outputs/image_transforms") diff --git a/examples/advanced/2_calculate_validation_loss.py b/examples/advanced/2_calculate_validation_loss.py index 00ba9930f..c61aafa32 100644 --- a/examples/advanced/2_calculate_validation_loss.py +++ b/examples/advanced/2_calculate_validation_loss.py @@ -14,7 +14,10 @@ import torch from huggingface_hub import snapshot_download -from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata +from lerobot.common.datasets.lerobot_dataset import ( + LeRobotDataset, + LeRobotDatasetMetadata, +) from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy device = torch.device("cuda") @@ -37,7 +40,24 @@ # Load the previous action (-0.1), the next action to be executed (0.0), # and 14 future actions with a 0.1 seconds spacing. All these actions will be # used to calculate the loss. - "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4], + "action": [ + -0.1, + 0.0, + 0.1, + 0.2, + 0.3, + 0.4, + 0.5, + 0.6, + 0.7, + 0.8, + 0.9, + 1.0, + 1.1, + 1.2, + 1.3, + 1.4, + ], } # Load the last 10% of episodes of the dataset as a validation set. @@ -53,8 +73,12 @@ print(f"Number of episodes in training dataset (90% subset): {len(train_episodes)}") print(f"Number of episodes in validation dataset (10% subset): {len(val_episodes)}") # - Load train an val datasets -train_dataset = LeRobotDataset("lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps) -val_dataset = LeRobotDataset("lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps) +train_dataset = LeRobotDataset( + "lerobot/pusht", episodes=train_episodes, delta_timestamps=delta_timestamps +) +val_dataset = LeRobotDataset( + "lerobot/pusht", episodes=val_episodes, delta_timestamps=delta_timestamps +) print(f"Number of frames in training dataset (90% subset): {len(train_dataset)}") print(f"Number of frames in validation dataset (10% subset): {len(val_dataset)}") diff --git a/examples/port_datasets/pusht_zarr.py b/examples/port_datasets/pusht_zarr.py index 60df98405..6766ac831 100644 --- a/examples/port_datasets/pusht_zarr.py +++ b/examples/port_datasets/pusht_zarr.py @@ -69,7 +69,9 @@ def load_raw_dataset(zarr_path: Path): ReplayBuffer as DiffusionPolicyReplayBuffer, ) except ModuleNotFoundError as e: - print("`gym_pusht` is not installed. 
Please install it with `pip install 'lerobot[gym_pusht]'`") + print( + "`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`" + ) raise e zarr_data = DiffusionPolicyReplayBuffer.copy_from_path(zarr_path) @@ -81,7 +83,9 @@ def calculate_coverage(zarr_data): import pymunk from gym_pusht.envs.pusht import PushTEnv, pymunk_to_shapely except ModuleNotFoundError as e: - print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`") + print( + "`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`" + ) raise e block_pos = zarr_data["state"][:, 2:4] @@ -111,7 +115,9 @@ def calculate_coverage(zarr_data): ] space.add(*walls) - block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) + block_body, block_shapes = PushTEnv.add_tee( + space, block_pos[i].tolist(), block_angle[i].item() + ) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes) intersection_area = goal_geom.intersection(block_geom).area diff --git a/lerobot/__init__.py b/lerobot/__init__.py index 3d5bb6aaa..4540b93e5 100644 --- a/lerobot/__init__.py +++ b/lerobot/__init__.py @@ -182,7 +182,11 @@ ] available_datasets = sorted( - set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets)) + set( + itertools.chain( + *available_datasets_per_env.values(), available_real_world_datasets + ) + ) ) # lists all available policies from `lerobot/common/policies` @@ -224,9 +228,13 @@ "dora_aloha_real": ["act_aloha_real"], } -env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks] +env_task_pairs = [ + (env, task) for env, tasks in available_tasks_per_env.items() for task in tasks +] env_dataset_pairs = [ - (env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets + (env, dataset) + for env, datasets in available_datasets_per_env.items() + for dataset in datasets ] env_dataset_policy_triplets = [ (env, dataset, policy) diff --git a/lerobot/common/datasets/compute_stats.py b/lerobot/common/datasets/compute_stats.py index c62116994..4dbd1a572 100644 --- a/lerobot/common/datasets/compute_stats.py +++ b/lerobot/common/datasets/compute_stats.py @@ -45,12 +45,20 @@ def get_stats_einops_patterns(dataset, num_workers=0): if key in dataset.meta.camera_keys: # sanity check that images are channel first _, c, h, w = batch[key].shape - assert c < h and c < w, f"expect channel first images, but instead {batch[key].shape}" + assert ( + c < h and c < w + ), f"expect channel first images, but instead {batch[key].shape}" # sanity check that images are float32 in range [0,1] - assert batch[key].dtype == torch.float32, f"expect torch.float32, but instead {batch[key].dtype=}" - assert batch[key].max() <= 1, f"expect pixels lower than 1, but instead {batch[key].max()=}" - assert batch[key].min() >= 0, f"expect pixels greater than 1, but instead {batch[key].min()=}" + assert ( + batch[key].dtype == torch.float32 + ), f"expect torch.float32, but instead {batch[key].dtype=}" + assert ( + batch[key].max() <= 1 + ), f"expect pixels lower than 1, but instead {batch[key].max()=}" + assert ( + batch[key].min() >= 0 + ), f"expect pixels greater than 1, but instead {batch[key].min()=}" stats_patterns[key] = "b c h w -> c 1 1" elif batch[key].ndim == 2: @@ -98,7 +106,11 @@ def create_seeded_dataloader(dataset, batch_size, seed): running_item_count = 0 # for online 
mean computation dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) for i, batch in enumerate( - tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute mean, min, max") + tqdm.tqdm( + dataloader, + total=ceil(max_num_samples / batch_size), + desc="Compute mean, min, max", + ) ): this_batch_size = len(batch["index"]) running_item_count += this_batch_size @@ -113,9 +125,16 @@ def create_seeded_dataloader(dataset, batch_size, seed): # and x is the current batch mean. Some rearrangement is then required to avoid risking # numerical overflow. Another hint: Nₙ₋₁ = Nₙ - Bₙ. Rearrangement yields # x̄ₙ = x̄ₙ₋₁ + Bₙ * (xₙ - x̄ₙ₋₁) / Nₙ - mean[key] = mean[key] + this_batch_size * (batch_mean - mean[key]) / running_item_count - max[key] = torch.maximum(max[key], einops.reduce(batch[key], pattern, "max")) - min[key] = torch.minimum(min[key], einops.reduce(batch[key], pattern, "min")) + mean[key] = ( + mean[key] + + this_batch_size * (batch_mean - mean[key]) / running_item_count + ) + max[key] = torch.maximum( + max[key], einops.reduce(batch[key], pattern, "max") + ) + min[key] = torch.minimum( + min[key], einops.reduce(batch[key], pattern, "min") + ) if i == ceil(max_num_samples / batch_size) - 1: break @@ -124,7 +143,9 @@ def create_seeded_dataloader(dataset, batch_size, seed): running_item_count = 0 # for online std computation dataloader = create_seeded_dataloader(dataset, batch_size, seed=1337) for i, batch in enumerate( - tqdm.tqdm(dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std") + tqdm.tqdm( + dataloader, total=ceil(max_num_samples / batch_size), desc="Compute std" + ) ): this_batch_size = len(batch["index"]) running_item_count += this_batch_size @@ -138,7 +159,9 @@ def create_seeded_dataloader(dataset, batch_size, seed): # Numerically stable update step for mean computation (where the mean is over squared # residuals).See notes in the mean computation loop above. batch_std = einops.reduce((batch[key] - mean[key]) ** 2, pattern, "mean") - std[key] = std[key] + this_batch_size * (batch_std - std[key]) / running_item_count + std[key] = ( + std[key] + this_batch_size * (batch_std - std[key]) / running_item_count + ) if i == ceil(max_num_samples / batch_size) - 1: break @@ -177,13 +200,19 @@ def aggregate_stats(ls_datasets) -> dict[str, torch.Tensor]: # compute `max(dataset_0["max"], dataset_1["max"], ...)` stats[data_key][stat_key] = einops.reduce( torch.stack( - [ds.meta.stats[data_key][stat_key] for ds in ls_datasets if data_key in ds.meta.stats], + [ + ds.meta.stats[data_key][stat_key] + for ds in ls_datasets + if data_key in ds.meta.stats + ], dim=0, ), "n ... -> ...", stat_key, ) - total_samples = sum(d.num_frames for d in ls_datasets if data_key in d.meta.stats) + total_samples = sum( + d.num_frames for d in ls_datasets if data_key in d.meta.stats + ) # Compute the "sum" statistic by multiplying each mean by the number of samples in the respective # dataset, then divide by total_samples to get the overall "mean". 
# NOTE: the brackets around (d.num_frames / total_samples) are needed tor minimize the risk of diff --git a/lerobot/common/datasets/factory.py b/lerobot/common/datasets/factory.py index f6164ed1d..02ec04230 100644 --- a/lerobot/common/datasets/factory.py +++ b/lerobot/common/datasets/factory.py @@ -74,7 +74,25 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData image_transforms = None if cfg.training.image_transforms.enable: - cfg_tf = cfg.training.image_transforms + default_tf = OmegaConf.create( + { + "brightness": {"weight": 0.0, "min_max": None}, + "contrast": {"weight": 0.0, "min_max": None}, + "saturation": {"weight": 0.0, "min_max": None}, + "hue": {"weight": 0.0, "min_max": None}, + "sharpness": {"weight": 0.0, "min_max": None}, + "max_num_transforms": None, + "random_order": False, + "image_size": None, + "interpolation": None, + "image_mean": None, + "image_std": None, + } + ) + cfg_tf = OmegaConf.merge( + OmegaConf.create(default_tf), cfg.training.image_transforms + ) + image_transforms = get_image_transforms( brightness_weight=cfg_tf.brightness.weight, brightness_min_max=cfg_tf.brightness.min_max, @@ -88,6 +106,12 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData sharpness_min_max=cfg_tf.sharpness.min_max, max_num_transforms=cfg_tf.max_num_transforms, random_order=cfg_tf.random_order, + image_size=(cfg_tf.image_size.height, cfg_tf.image_size.width) + if cfg_tf.image_size + else None, + interpolation=cfg_tf.interpolation, + image_mean=cfg_tf.image_mean, + image_std=cfg_tf.image_std, ) if isinstance(cfg.dataset_repo_id, str): @@ -111,6 +135,8 @@ def make_dataset(cfg, split: str = "train") -> LeRobotDataset | MultiLeRobotData for stats_type, listconfig in stats_dict.items(): # example of stats_type: min, max, mean, std stats = OmegaConf.to_container(listconfig, resolve=True) - dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + dataset.meta.stats[key][stats_type] = torch.tensor( + stats, dtype=torch.float32 + ) return dataset diff --git a/lerobot/common/datasets/image_writer.py b/lerobot/common/datasets/image_writer.py index 85dd6830b..ba53d6fff 100644 --- a/lerobot/common/datasets/image_writer.py +++ b/lerobot/common/datasets/image_writer.py @@ -109,7 +109,9 @@ def __init__(self, num_processes: int = 0, num_threads: int = 1): self._stopped = False if num_threads <= 0 and num_processes <= 0: - raise ValueError("Number of threads and processes must be greater than zero.") + raise ValueError( + "Number of threads and processes must be greater than zero." 
+ ) if self.num_processes == 0: # Use threading @@ -123,12 +125,16 @@ def __init__(self, num_processes: int = 0, num_threads: int = 1): # Use multiprocessing self.queue = multiprocessing.JoinableQueue() for _ in range(self.num_processes): - p = multiprocessing.Process(target=worker_process, args=(self.queue, self.num_threads)) + p = multiprocessing.Process( + target=worker_process, args=(self.queue, self.num_threads) + ) p.daemon = True p.start() self.processes.append(p) - def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path): + def save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path + ): if isinstance(image, torch.Tensor): # Convert tensor to numpy array to minimize main process time image = image.cpu().numpy() diff --git a/lerobot/common/datasets/lerobot_dataset.py b/lerobot/common/datasets/lerobot_dataset.py index 232558056..1c7ae5b55 100644 --- a/lerobot/common/datasets/lerobot_dataset.py +++ b/lerobot/common/datasets/lerobot_dataset.py @@ -68,7 +68,9 @@ # For maintainers, see lerobot/common/datasets/push_dataset_to_hub/CODEBASE_VERSION.md CODEBASE_VERSION = "v2.0" -LEROBOT_HOME = Path(os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot")).expanduser() +LEROBOT_HOME = Path( + os.getenv("LEROBOT_HOME", "~/.cache/huggingface/lerobot") +).expanduser() class LeRobotDatasetMetadata: @@ -84,7 +86,8 @@ def __init__( # Load metadata (self.root / "meta").mkdir(exist_ok=True, parents=True) - self.pull_from_repo(allow_patterns="meta/") + if not self.local_files_only: + self.pull_from_repo(allow_patterns="meta/") self.info = load_info(self.root) self.stats = load_stats(self.root) self.tasks = load_tasks(self.root) @@ -107,7 +110,11 @@ def pull_from_repo( @cached_property def _hub_version(self) -> str | None: - return None if self.local_files_only else get_hub_safe_version(self.repo_id, CODEBASE_VERSION) + return ( + None + if self.local_files_only + else get_hub_safe_version(self.repo_id, CODEBASE_VERSION) + ) @property def _version(self) -> str: @@ -121,7 +128,9 @@ def get_data_file_path(self, ep_index: int) -> Path: def get_video_file_path(self, ep_index: int, vid_key: str) -> Path: ep_chunk = self.get_episode_chunk(ep_index) - fpath = self.video_path.format(episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index) + fpath = self.video_path.format( + episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_index + ) return Path(fpath) def get_episode_chunk(self, ep_index: int) -> int: @@ -165,7 +174,11 @@ def video_keys(self) -> list[str]: @property def camera_keys(self) -> list[str]: """Keys to access visual modalities (regardless of their storage method).""" - return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] + return [ + key + for key, ft in self.features.items() + if ft["dtype"] in ["video", "image"] + ] @property def names(self) -> dict[str, list | dict]: @@ -214,7 +227,9 @@ def get_task_index(self, task: str) -> int: task_index = self.task_to_task_index.get(task, None) return task_index if task_index is not None else self.total_tasks - def save_episode(self, episode_index: int, episode_length: int, task: str, task_index: int) -> None: + def save_episode( + self, episode_index: int, episode_length: int, task: str, task_index: int + ) -> None: self.info["total_episodes"] += 1 self.info["total_frames"] += episode_length @@ -256,7 +271,9 @@ def write_video_info(self) -> None: """ for key in self.video_keys: if not self.features[key].get("info", None): - video_path = self.root / 
self.get_video_file_path(ep_index=0, vid_key=key) + video_path = self.root / self.get_video_file_path( + ep_index=0, vid_key=key + ) self.info["features"][key]["info"] = get_video_info(video_path) write_json(self.info, self.root / INFO_PATH) @@ -307,7 +324,9 @@ def create( features = {**features, **DEFAULT_FEATURES} obj.tasks, obj.stats, obj.episodes = {}, {}, [] - obj.info = create_empty_dataset_info(CODEBASE_VERSION, fps, robot_type, features, use_videos) + obj.info = create_empty_dataset_info( + CODEBASE_VERSION, fps, robot_type, features, use_videos + ) if len(obj.video_keys) > 0 and not use_videos: raise ValueError() write_json(obj.info, obj.root / INFO_PATH) @@ -443,7 +462,9 @@ def __init__( self.root.mkdir(exist_ok=True, parents=True) # Load metadata - self.meta = LeRobotDatasetMetadata(self.repo_id, self.root, self.local_files_only) + self.meta = LeRobotDatasetMetadata( + self.repo_id, self.root, self.local_files_only + ) # Check version check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION) @@ -451,10 +472,14 @@ def __init__( # Load actual data self.download_episodes(download_videos) self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) + self.episode_data_index = get_episode_data_index( + self.meta.episodes, self.episodes + ) # Check timestamps - check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) + check_timestamps_sync( + self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s + ) # Setup delta_indices if self.delta_timestamps is not None: @@ -500,7 +525,9 @@ def push_to_hub( tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs ) card.push_to_hub(repo_id=self.repo_id, repo_type="dataset") - create_branch(repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset") + create_branch( + repo_id=self.repo_id, branch=CODEBASE_VERSION, repo_type="dataset" + ) def pull_from_repo( self, @@ -528,7 +555,9 @@ def download_episodes(self, download_videos: bool = True) -> None: files = None ignore_patterns = None if download_videos else "videos/" if self.episodes is not None: - files = [str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] + files = [ + str(self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes + ] if len(self.meta.video_keys) > 0 and download_videos: video_files = [ str(self.meta.get_video_file_path(ep_idx, vid_key)) @@ -537,7 +566,8 @@ def download_episodes(self, download_videos: bool = True) -> None: ] files += video_files - self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) + if not self.local_files_only: + self.pull_from_repo(allow_patterns=files, ignore_patterns=ignore_patterns) def load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" @@ -545,7 +575,10 @@ def load_hf_dataset(self) -> datasets.Dataset: path = str(self.root / "data") hf_dataset = load_dataset("parquet", data_dir=path, split="train") else: - files = [str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes] + files = [ + str(self.root / self.meta.get_data_file_path(ep_idx)) + for ep_idx in self.episodes + ] hf_dataset = load_dataset("parquet", data_files=files, split="train") # TODO(aliberts): hf_dataset.set_format("torch") @@ -561,12 +594,20 @@ def fps(self) -> int: @property def num_frames(self) -> int: """Number of frames in selected episodes.""" - return len(self.hf_dataset) if 
self.hf_dataset is not None else self.meta.total_frames + return ( + len(self.hf_dataset) + if self.hf_dataset is not None + else self.meta.total_frames + ) @property def num_episodes(self) -> int: """Number of episodes selected.""" - return len(self.episodes) if self.episodes is not None else self.meta.total_episodes + return ( + len(self.episodes) + if self.episodes is not None + else self.meta.total_episodes + ) @property def features(self) -> dict[str, dict]: @@ -580,16 +621,24 @@ def hf_features(self) -> datasets.Features: else: return get_hf_features_from_features(self.features) - def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]: + def _get_query_indices( + self, idx: int, ep_idx: int + ) -> tuple[dict[str, list[int | bool]]]: ep_start = self.episode_data_index["from"][ep_idx] ep_end = self.episode_data_index["to"][ep_idx] query_indices = { - key: [max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) for delta in delta_idx] + key: [ + max(ep_start.item(), min(ep_end.item() - 1, idx + delta)) + for delta in delta_idx + ] for key, delta_idx in self.delta_indices.items() } padding = { # Pad values outside of current episode range f"{key}_is_pad": torch.BoolTensor( - [(idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) for delta in delta_idx] + [ + (idx + delta < ep_start.item()) | (idx + delta >= ep_end.item()) + for delta in delta_idx + ] ) for key, delta_idx in self.delta_indices.items() } @@ -617,7 +666,9 @@ def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict: if key not in self.meta.video_keys } - def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict: + def _query_videos( + self, query_timestamps: dict[str, list[float]], ep_idx: int + ) -> dict: """Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a Segmentation Fault. 
This probably happens because a memory reference to the video loader is created in @@ -647,7 +698,9 @@ def __getitem__(self, idx) -> dict: query_indices = None if self.delta_indices is not None: - current_ep_idx = self.episodes.index(ep_idx) if self.episodes is not None else ep_idx + current_ep_idx = ( + self.episodes.index(ep_idx) if self.episodes is not None else ep_idx + ) query_indices, padding = self._get_query_indices(idx, current_ep_idx) query_result = self._query_hf_dataset(query_indices) item = {**item, **padding} @@ -679,19 +732,28 @@ def __repr__(self): ) def create_episode_buffer(self, episode_index: int | None = None) -> dict: - current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index + current_ep_idx = ( + self.meta.total_episodes if episode_index is None else episode_index + ) return { "size": 0, - **{key: current_ep_idx if key == "episode_index" else [] for key in self.features}, + **{ + key: current_ep_idx if key == "episode_index" else [] + for key in self.features + }, } - def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path: + def _get_image_file_path( + self, episode_index: int, image_key: str, frame_index: int + ) -> Path: fpath = DEFAULT_IMAGE_PATH.format( image_key=image_key, episode_index=episode_index, frame_index=frame_index ) return self.root / fpath - def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None: + def _save_image( + self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path + ) -> None: if self.image_writer is None: if isinstance(image, torch.Tensor): image = image.cpu().numpy() @@ -712,7 +774,9 @@ def add_frame(self, frame: dict) -> None: self.episode_buffer = self.create_episode_buffer() frame_index = self.episode_buffer["size"] - timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps + timestamp = ( + frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps + ) self.episode_buffer["frame_index"].append(frame_index) self.episode_buffer["timestamp"].append(timestamp) @@ -721,11 +785,17 @@ def add_frame(self, frame: dict) -> None: raise ValueError(key) if self.features[key]["dtype"] not in ["image", "video"]: - item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key] + item = ( + frame[key].numpy() + if isinstance(frame[key], torch.Tensor) + else frame[key] + ) self.episode_buffer[key].append(item) elif self.features[key]["dtype"] in ["image", "video"]: img_path = self._get_image_file_path( - episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index + episode_index=self.episode_buffer["episode_index"], + image_key=key, + frame_index=frame_index, ) if frame_index == 0: img_path.parent.mkdir(parents=True, exist_ok=True) @@ -734,7 +804,9 @@ def add_frame(self, frame: dict) -> None: self.episode_buffer["size"] += 1 - def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict | None = None) -> None: + def save_episode( + self, task: str, encode_videos: bool = True, episode_data: dict | None = None + ) -> None: """ This will save to disk the current episode in self.episode_buffer. 
Note that since it affects files on disk, it sets self.consolidated to False to ensure proper consolidation later on before uploading to @@ -801,7 +873,9 @@ def save_episode(self, task: str, encode_videos: bool = True, episode_data: dict def _save_episode_table(self, episode_buffer: dict, episode_index: int) -> None: episode_dict = {key: episode_buffer[key] for key in self.hf_features} - ep_dataset = datasets.Dataset.from_dict(episode_dict, features=self.hf_features, split="train") + ep_dataset = datasets.Dataset.from_dict( + episode_dict, features=self.hf_features, split="train" + ) ep_data_path = self.root / self.meta.get_data_file_path(ep_index=episode_index) ep_data_path.parent.mkdir(parents=True, exist_ok=True) write_parquet(ep_dataset, ep_data_path) @@ -873,10 +947,16 @@ def encode_episode_videos(self, episode_index: int) -> dict: return video_paths - def consolidate(self, run_compute_stats: bool = True, keep_image_files: bool = False) -> None: + def consolidate( + self, run_compute_stats: bool = True, keep_image_files: bool = False + ) -> None: self.hf_dataset = self.load_hf_dataset() - self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes) - check_timestamps_sync(self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s) + self.episode_data_index = get_episode_data_index( + self.meta.episodes, self.episodes + ) + check_timestamps_sync( + self.hf_dataset, self.episode_data_index, self.fps, self.tolerance_s + ) if len(self.meta.video_keys) > 0: self.encode_videos() @@ -981,7 +1061,9 @@ def __init__( super().__init__() self.repo_ids = repo_ids self.root = Path(root) if root else LEROBOT_HOME - self.tolerances_s = tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids} + self.tolerances_s = ( + tolerances_s if tolerances_s else {repo_id: 1e-4 for repo_id in repo_ids} + ) # Construct the underlying datasets passing everything but `transform` and `delta_timestamps` which # are handled by this class. self._datasets = [ @@ -1058,7 +1140,13 @@ def video(self) -> bool: def features(self) -> datasets.Features: features = {} for dataset in self._datasets: - features.update({k: v for k, v in dataset.hf_features.items() if k not in self.disabled_features}) + features.update( + { + k: v + for k, v in dataset.hf_features.items() + if k not in self.disabled_features + } + ) return features @property @@ -1119,7 +1207,9 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]: continue break else: - raise AssertionError("We expect the loop to break out as long as the index is within bounds.") + raise AssertionError( + "We expect the loop to break out as long as the index is within bounds." 
+ ) item = self._datasets[dataset_idx][idx - start_idx] item["dataset_index"] = torch.tensor(dataset_idx) for data_key in self.disabled_features: diff --git a/lerobot/common/datasets/online_buffer.py b/lerobot/common/datasets/online_buffer.py index d907e4687..e31206faa 100644 --- a/lerobot/common/datasets/online_buffer.py +++ b/lerobot/common/datasets/online_buffer.py @@ -131,7 +131,9 @@ def set_delta_timestamps(self, value: dict[str, list[float]] | None): else: self._delta_timestamps = None - def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> dict[str, dict[str, Any]]: + def _make_data_spec( + self, data_spec: dict[str, Any], buffer_capacity: int + ) -> dict[str, dict[str, Any]]: """Makes the data spec for np.memmap.""" if any(k.startswith("_") for k in data_spec): raise ValueError( @@ -154,14 +156,32 @@ def _make_data_spec(self, data_spec: dict[str, Any], buffer_capacity: int) -> di OnlineBuffer.NEXT_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": ()}, # Since the memmap is initialized with all-zeros, this keeps track of which indices are occupied # with real data rather than the dummy initialization. - OnlineBuffer.OCCUPANCY_MASK_KEY: {"dtype": np.dtype("?"), "shape": (buffer_capacity,)}, - OnlineBuffer.INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.FRAME_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.EPISODE_INDEX_KEY: {"dtype": np.dtype("int64"), "shape": (buffer_capacity,)}, - OnlineBuffer.TIMESTAMP_KEY: {"dtype": np.dtype("float64"), "shape": (buffer_capacity,)}, + OnlineBuffer.OCCUPANCY_MASK_KEY: { + "dtype": np.dtype("?"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.INDEX_KEY: { + "dtype": np.dtype("int64"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.FRAME_INDEX_KEY: { + "dtype": np.dtype("int64"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.EPISODE_INDEX_KEY: { + "dtype": np.dtype("int64"), + "shape": (buffer_capacity,), + }, + OnlineBuffer.TIMESTAMP_KEY: { + "dtype": np.dtype("float64"), + "shape": (buffer_capacity,), + }, } for k, v in data_spec.items(): - complete_data_spec[k] = {"dtype": v["dtype"], "shape": (buffer_capacity, *v["shape"])} + complete_data_spec[k] = { + "dtype": v["dtype"], + "shape": (buffer_capacity, *v["shape"]), + } return complete_data_spec def add_data(self, data: dict[str, np.ndarray]): @@ -188,7 +208,9 @@ def add_data(self, data: dict[str, np.ndarray]): # Shift the incoming indices if necessary. 
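        # Incoming episodes are indexed from zero, so when the buffer already holds frames,
        # offset them to continue from the last stored episode and data indices.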
         if self.num_frames > 0:
-            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][next_index - 1]
+            last_episode_index = self._data[OnlineBuffer.EPISODE_INDEX_KEY][
+                next_index - 1
+            ]
             last_data_index = self._data[OnlineBuffer.INDEX_KEY][next_index - 1]
             data[OnlineBuffer.EPISODE_INDEX_KEY] += last_episode_index + 1
             data[OnlineBuffer.INDEX_KEY] += last_data_index + 1
@@ -223,7 +245,11 @@ def fps(self) -> float | None:
     @property
     def num_episodes(self) -> int:
         return len(
-            np.unique(self._data[OnlineBuffer.EPISODE_INDEX_KEY][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
+            np.unique(
+                self._data[OnlineBuffer.EPISODE_INDEX_KEY][
+                    self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]
+                ]
+            )
         )

     @property
@@ -261,7 +287,9 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
                 self._data[OnlineBuffer.OCCUPANCY_MASK_KEY],
             )
         )[0]
-        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][episode_data_indices]
+        episode_timestamps = self._data[OnlineBuffer.TIMESTAMP_KEY][
+            episode_data_indices
+        ]

         for data_key in self.delta_timestamps:
             # Note: The logic in this loop is copied from `load_previous_and_future_frames`.
@@ -278,7 +306,8 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
             # Check violated query timestamps are all outside the episode range.
             assert (
-                (query_ts[is_pad] < episode_timestamps[0]) | (episode_timestamps[-1] < query_ts[is_pad])
+                (query_ts[is_pad] < episode_timestamps[0])
+                | (episode_timestamps[-1] < query_ts[is_pad])
             ).all(), (
                 f"One or several timestamps unexpectedly violate the tolerance ({min_} > {self.tolerance_s=}"
                 ") inside the episode range."
@@ -293,7 +322,9 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:

     def get_data_by_key(self, key: str) -> torch.Tensor:
         """Returns all data for a given data key as a Tensor."""
-        return torch.from_numpy(self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]])
+        return torch.from_numpy(
+            self._data[key][self._data[OnlineBuffer.OCCUPANCY_MASK_KEY]]
+        )


 def compute_sampler_weights(
@@ -324,13 +355,19 @@ def compute_sampler_weights(
     - Options `drop_first_n_frames` and `episode_indices_to_use` can be added easily. They were not
       included here to avoid adding complexity.
     """
-    if len(offline_dataset) == 0 and (online_dataset is None or len(online_dataset) == 0):
-        raise ValueError("At least one of `offline_dataset` or `online_dataset` should be contain data.")
+    if len(offline_dataset) == 0 and (
+        online_dataset is None or len(online_dataset) == 0
+    ):
+        raise ValueError(
+            "At least one of `offline_dataset` or `online_dataset` should contain data."
+        )
     if (online_dataset is None) ^ (online_sampling_ratio is None):
         raise ValueError(
             "`online_dataset` and `online_sampling_ratio` must be provided together or not at all."
) - offline_sampling_ratio = 0 if online_sampling_ratio is None else 1 - online_sampling_ratio + offline_sampling_ratio = ( + 0 if online_sampling_ratio is None else 1 - online_sampling_ratio + ) weights = [] diff --git a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py index 33b4c9745..8952b585f 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_diffusion_policy_replay_buffer.py @@ -37,10 +37,16 @@ def check_chunks_compatible(chunks: tuple, shape: tuple): assert c > 0 -def rechunk_recompress_array(group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp"): +def rechunk_recompress_array( + group, name, chunks=None, chunk_length=None, compressor=None, tmp_key="_temp" +): old_arr = group[name] if chunks is None: - chunks = (chunk_length,) + old_arr.chunks[1:] if chunk_length is not None else old_arr.chunks + chunks = ( + (chunk_length,) + old_arr.chunks[1:] + if chunk_length is not None + else old_arr.chunks + ) check_chunks_compatible(chunks, old_arr.shape) if compressor is None: @@ -82,13 +88,18 @@ def get_optimal_chunks(shape, dtype, target_chunk_bytes=2e6, max_chunk_length=No for i in range(len(shape) - 1): this_chunk_bytes = itemsize * np.prod(rshape[:i]) next_chunk_bytes = itemsize * np.prod(rshape[: i + 1]) - if this_chunk_bytes <= target_chunk_bytes and next_chunk_bytes > target_chunk_bytes: + if ( + this_chunk_bytes <= target_chunk_bytes + and next_chunk_bytes > target_chunk_bytes + ): split_idx = i rchunks = rshape[:split_idx] item_chunk_bytes = itemsize * np.prod(rshape[:split_idx]) this_max_chunk_length = rshape[split_idx] - next_chunk_length = min(this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes)) + next_chunk_length = min( + this_max_chunk_length, math.ceil(target_chunk_bytes / item_chunk_bytes) + ) rchunks.append(next_chunk_length) len_diff = len(shape) - len(rchunks) rchunks.extend([1] * len_diff) @@ -124,7 +135,13 @@ def create_empty_zarr(cls, storage=None, root=None): root.require_group("data", overwrite=False) meta = root.require_group("meta", overwrite=False) if "episode_ends" not in meta: - meta.zeros("episode_ends", shape=(0,), dtype=np.int64, compressor=None, overwrite=False) + meta.zeros( + "episode_ends", + shape=(0,), + dtype=np.int64, + compressor=None, + overwrite=False, + ) return cls(root=root) @classmethod @@ -193,7 +210,11 @@ def copy_from_store( root = zarr.group(store=store) # copy without recompression n_copied, n_skipped, n_bytes_copied = zarr.copy_store( - source=src_store, dest=store, source_path="/meta", dest_path="/meta", if_exists=if_exists + source=src_store, + dest=store, + source_path="/meta", + dest_path="/meta", + if_exists=if_exists, ) data_group = root.create_group("data", overwrite=True) if keys is None: @@ -201,7 +222,9 @@ def copy_from_store( for key in keys: value = src_root["data"][key] cks = cls._resolve_array_chunks(chunks=chunks, key=key, array=value) - cpr = cls._resolve_array_compressor(compressors=compressors, key=key, array=value) + cpr = cls._resolve_array_compressor( + compressors=compressors, key=key, array=value + ) if cks == value.chunks and cpr == value.compressor: # copy without recompression this_path = "/data/" + key @@ -286,13 +309,17 @@ def save_to_store( meta_group = root.create_group("meta", overwrite=True) # save meta, no chunking for key, value in self.root["meta"].items(): - _ 
= meta_group.array(name=key, data=value, shape=value.shape, chunks=value.shape) + _ = meta_group.array( + name=key, data=value, shape=value.shape, chunks=value.shape + ) # save data, chunk data_group = root.create_group("data", overwrite=True) for key, value in self.root["data"].items(): cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value) - cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value) + cpr = self._resolve_array_compressor( + compressors=compressors, key=key, array=value + ) if isinstance(value, zarr.Array): if cks == value.chunks and cpr == value.compressor: # copy without recompression @@ -339,13 +366,19 @@ def save_to_path( @staticmethod def resolve_compressor(compressor="default"): if compressor == "default": - compressor = numcodecs.Blosc(cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE) + compressor = numcodecs.Blosc( + cname="lz4", clevel=5, shuffle=numcodecs.Blosc.NOSHUFFLE + ) elif compressor == "disk": - compressor = numcodecs.Blosc("zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE) + compressor = numcodecs.Blosc( + "zstd", clevel=5, shuffle=numcodecs.Blosc.BITSHUFFLE + ) return compressor @classmethod - def _resolve_array_compressor(cls, compressors: dict | str | numcodecs.abc.Codec, key, array): + def _resolve_array_compressor( + cls, compressors: dict | str | numcodecs.abc.Codec, key, array + ): # allows compressor to be explicitly set to None cpr = "nil" if isinstance(compressors, dict): @@ -404,7 +437,11 @@ def update_meta(self, data): if self.backend == "zarr": for key, value in np_data.items(): _ = meta_group.array( - name=key, data=value, shape=value.shape, chunks=value.shape, overwrite=True + name=key, + data=value, + shape=value.shape, + chunks=value.shape, + overwrite=True, ) else: meta_group.update(np_data) @@ -514,10 +551,18 @@ def add_episode( # create array if key not in self.data: if is_zarr: - cks = self._resolve_array_chunks(chunks=chunks, key=key, array=value) - cpr = self._resolve_array_compressor(compressors=compressors, key=key, array=value) + cks = self._resolve_array_chunks( + chunks=chunks, key=key, array=value + ) + cpr = self._resolve_array_compressor( + compressors=compressors, key=key, array=value + ) arr = self.data.zeros( - name=key, shape=new_shape, chunks=cks, dtype=value.dtype, compressor=cpr + name=key, + shape=new_shape, + chunks=cks, + dtype=value.dtype, + compressor=cpr, ) else: # copy data to prevent modify @@ -544,7 +589,9 @@ def add_episode( # rechunk if is_zarr and episode_ends.chunks[0] < episode_ends.shape[0]: - rechunk_recompress_array(self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5)) + rechunk_recompress_array( + self.meta, "episode_ends", chunk_length=int(episode_ends.shape[0] * 1.5) + ) def drop_episode(self): is_zarr = self.backend == "zarr" diff --git a/lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py b/lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py index 184d79fb2..bb2bd4a8e 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_encode_datasets.py @@ -38,7 +38,9 @@ from pathlib import Path from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION -from lerobot.common.datasets.push_dataset_to_hub._download_raw import AVAILABLE_RAW_REPO_IDS +from lerobot.common.datasets.push_dataset_to_hub._download_raw import ( + AVAILABLE_RAW_REPO_IDS, +) from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id from 
lerobot.scripts.push_dataset_to_hub import push_dataset_to_hub @@ -73,7 +75,9 @@ def encode_datasets( check_repo_id(raw_repo_id) dataset_repo_id_push = get_push_repo_id_from_raw(raw_repo_id, push_repo) dataset_raw_dir = raw_dir / raw_repo_id - dataset_dir = local_dir / dataset_repo_id_push if local_dir is not None else None + dataset_dir = ( + local_dir / dataset_repo_id_push if local_dir is not None else None + ) encoding = { "vcodec": vcodec, "pix_fmt": pix_fmt, diff --git a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py index a118b7e78..a8898933a 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py +++ b/lerobot/common/datasets/push_dataset_to_hub/_umi_imagecodecs_numcodecs.py @@ -133,7 +133,9 @@ def encode(self, buf): ) def decode(self, buf, out=None): - return imagecodecs.jpeg2k_decode(buf, verbose=self.verbose, numthreads=self.numthreads, out=out) + return imagecodecs.jpeg2k_decode( + buf, verbose=self.verbose, numthreads=self.numthreads, out=out + ) class JpegXl(Codec): diff --git a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py index e2973ef81..527b31b2d 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/aloha_hdf5_format.py @@ -44,7 +44,9 @@ def get_cameras(hdf5_data): # ignore depth channel, not currently handled # TODO(rcadene): add depth - rgb_cameras = [key for key in hdf5_data["/observations/images"].keys() if "depth" not in key] # noqa: SIM118 + rgb_cameras = [ + key for key in hdf5_data["/observations/images"].keys() if "depth" not in key + ] # noqa: SIM118 return rgb_cameras @@ -73,7 +75,9 @@ def check_format(raw_dir) -> bool: else: assert data[f"/observations/images/{camera}"].ndim == 4 b, h, w, c = data[f"/observations/images/{camera}"].shape - assert c < h and c < w, f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided." + assert ( + c < h and c < w + ), f"Expect (h,w,c) image format but ({h=},{w=},{c=}) provided." 
 def load_from_raw(
@@ -134,14 +138,17 @@ def load_from_raw(
                     # encode images to a mp4 video
                     fname = f"{img_key}_episode_{ep_idx:06d}.mp4"
                     video_path = videos_dir / fname
-                    encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {}))
+                    encode_video_frames(
+                        tmp_imgs_dir, video_path, fps, **(encoding or {})
+                    )

                     # clean temporary images directory
                     shutil.rmtree(tmp_imgs_dir)

                     # store the reference to the video frame
                     ep_dict[img_key] = [
-                        {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                        {"path": f"videos/{fname}", "timestamp": i / fps}
+                        for i in range(num_frames)
                     ]
                 else:
                     ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -181,15 +188,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
             features[key] = Image()

     features["observation.state"] = Sequence(
-        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
+        length=data_dict["observation.state"].shape[1],
+        feature=Value(dtype="float32", id=None),
     )
     if "observation.velocity" in data_dict:
         features["observation.velocity"] = Sequence(
-            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
+            length=data_dict["observation.velocity"].shape[1],
+            feature=Value(dtype="float32", id=None),
         )
     if "observation.effort" in data_dict:
         features["observation.effort"] = Sequence(
-            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
+            length=data_dict["observation.effort"].shape[1],
+            feature=Value(dtype="float32", id=None),
         )
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
diff --git a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
index 95f9c0071..c90bd9297 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/dora_parquet_format.py
@@ -26,7 +28,9 @@
 from datasets import Dataset, Features, Image, Sequence, Value

 from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION
-from lerobot.common.datasets.push_dataset_to_hub.utils import calculate_episode_data_index
+from lerobot.common.datasets.push_dataset_to_hub.utils import (
+    calculate_episode_data_index,
+)
 from lerobot.common.datasets.utils import (
     hf_transform_to_torch,
 )
@@ -42,11 +44,19 @@ def check_format(raw_dir) -> bool:
     return True


-def load_from_raw(raw_dir: Path, videos_dir: Path, fps: int, video: bool, episodes: list[int] | None = None):
+def load_from_raw(
+    raw_dir: Path,
+    videos_dir: Path,
+    fps: int,
+    video: bool,
+    episodes: list[int] | None = None,
+):
     # Load data stream that will be used as reference for the timestamps synchronization
     reference_files = list(raw_dir.glob("observation.images.cam_*.parquet"))
     if len(reference_files) == 0:
-        raise ValueError(f"Missing reference files for camera, starting with in '{raw_dir}'")
+        raise ValueError(
+            f"Missing reference files for camera, starting with 'observation.images.cam_' in '{raw_dir}'"
+        )
     # select first camera in alphanumeric order
     reference_key = sorted(reference_files)[0].stem
     reference_df = pd.read_parquet(raw_dir / f"{reference_key}.parquet")
@@ -107,7 +117,9 @@ def get_episode_index(row):
     df["timestamp"] = df["timestamp_utc"].map(lambda x: x.timestamp())

     # each episode starts with timestamp 0 to match the ones from the video
-    df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(lambda x: x - x.iloc[0])
+    df["timestamp"] = df.groupby("episode_index")["timestamp"].transform(
+        lambda x: x - x.iloc[0]
+    )

     del df["timestamp_utc"]
@@ -120,7 +132,9 @@ def get_episode_index(row):
     ep_ids = [ep_idx for ep_idx, _ in df.groupby("episode_index")]
     expected_ep_ids = list(range(df["episode_index"].max() + 1))
     if ep_ids != expected_ep_ids:
-        raise ValueError(f"Episodes indices go from {ep_ids} instead of {expected_ep_ids}")
+        raise ValueError(
+            f"Episode indices go from {ep_ids} instead of {expected_ep_ids}"
+        )

     # Create symlink to raw videos directory (that needs to be absolute not relative)
     videos_dir.parent.mkdir(parents=True, exist_ok=True)
@@ -152,7 +166,9 @@ def get_episode_index(row):
             data_dict[key] = torch.from_numpy(df[key].values)
         # is vector
         elif df[key].iloc[0].shape[0] > 1:
-            data_dict[key] = torch.stack([torch.from_numpy(x.copy()) for x in df[key].values])
+            data_dict[key] = torch.stack(
+                [torch.from_numpy(x.copy()) for x in df[key].values]
+            )
         else:
             raise ValueError(key)
@@ -170,15 +186,18 @@ def to_hf_dataset(data_dict, video) -> Dataset:
             features[key] = Image()

     features["observation.state"] = Sequence(
-        length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None)
+        length=data_dict["observation.state"].shape[1],
+        feature=Value(dtype="float32", id=None),
     )
     if "observation.velocity" in data_dict:
         features["observation.velocity"] = Sequence(
-            length=data_dict["observation.velocity"].shape[1], feature=Value(dtype="float32", id=None)
+            length=data_dict["observation.velocity"].shape[1],
+            feature=Value(dtype="float32", id=None),
         )
     if "observation.effort" in data_dict:
         features["observation.effort"] = Sequence(
-            length=data_dict["observation.effort"].shape[1], feature=Value(dtype="float32", id=None)
+            length=data_dict["observation.effort"].shape[1],
+            feature=Value(dtype="float32", id=None),
         )
     features["action"] = Sequence(
         length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None)
diff --git a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py b/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
index 1f8a5d144..1c8359735 100644
--- a/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
+++ b/lerobot/common/datasets/push_dataset_to_hub/openx_rlds_format.py
@@ -143,7 +143,11 @@ def load_from_raw(
         else:
             state_keys.append(key)

-    lang_key = "language_instruction" if "language_instruction" in dataset.element_spec else None
+    lang_key = (
+        "language_instruction"
+        if "language_instruction" in dataset.element_spec
+        else None
+    )

     print(" - image_keys: ", image_keys)
     print(" - lang_key: ", lang_key)
@@ -202,7 +206,9 @@ def load_from_raw(
         # If lang_key is present, convert the entire tensor at once
         if lang_key is not None:
-            ep_dict["language_instruction"] = [x.numpy().decode("utf-8") for x in episode[lang_key]]
+            ep_dict["language_instruction"] = [
+                x.numpy().decode("utf-8") for x in episode[lang_key]
+            ]

         ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps
         ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames)
@@ -234,7 +240,8 @@ def load_from_raw(
             # store the reference to the video frame
             ep_dict[img_key] = [
-                {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)
+                {"path": f"videos/{fname}", "timestamp": i / fps}
+                for i in range(num_frames)
             ]
         else:
             ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array]
@@ -259,7 +266,9 @@ def to_hf_dataset(data_dict, video) -> Dataset:
     for key in data_dict:
         # check if vector state obs
         if key.startswith("observation.") and "observation.images."
not in key: - features[key] = Sequence(length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None)) + features[key] = Sequence( + length=data_dict[key].shape[1], feature=Value(dtype="float32", id=None) + ) # check if image obs elif "observation.images." in key: if video: diff --git a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py index 27b31ba24..22b5ea786 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/pusht_zarr_format.py @@ -56,7 +56,9 @@ def check_format(raw_dir): required_datasets.remove("meta/episode_ends") - assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) + assert all( + nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets + ) def load_from_raw( @@ -76,7 +78,9 @@ def load_from_raw( ReplayBuffer as DiffusionPolicyReplayBuffer, ) except ModuleNotFoundError as e: - print("`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`") + print( + "`gym_pusht` is not installed. Please install it with `pip install 'lerobot[gym_pusht]'`" + ) raise e # as define in gmy-pusht env: https://github.com/huggingface/gym-pusht/blob/e0684ff988d223808c0a9dcfaba9dc4991791370/gym_pusht/envs/pusht.py#L174 success_threshold = 0.95 # 95% coverage, @@ -150,7 +154,9 @@ def load_from_raw( ] space.add(*walls) - block_body, block_shapes = PushTEnv.add_tee(space, block_pos[i].tolist(), block_angle[i].item()) + block_body, block_shapes = PushTEnv.add_tee( + space, block_pos[i].tolist(), block_angle[i].item() + ) goal_geom = pymunk_to_shapely(goal_body, block_body.shapes) block_geom = pymunk_to_shapely(block_body, block_body.shapes) intersection_area = goal_geom.intersection(block_geom).area @@ -159,7 +165,9 @@ def load_from_raw( reward[i] = np.clip(coverage / success_threshold, 0, 1) success[i] = coverage > success_threshold if keypoints_instead_of_image: - keypoints[i] = torch.from_numpy(PushTEnv.get_keypoints(block_shapes).flatten()) + keypoints[i] = torch.from_numpy( + PushTEnv.get_keypoints(block_shapes).flatten() + ) # last step of demonstration is considered done done[-1] = True @@ -184,7 +192,8 @@ def load_from_raw( # store the reference to the video frame ep_dict[img_key] = [ - {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames) + {"path": f"videos/{fname}", "timestamp": i / fps} + for i in range(num_frames) ] else: ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array] @@ -193,7 +202,9 @@ def load_from_raw( if keypoints_instead_of_image: ep_dict["observation.environment_state"] = keypoints ep_dict["action"] = actions[from_idx:to_idx] - ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) + ep_dict["episode_index"] = torch.tensor( + [ep_idx] * num_frames, dtype=torch.int64 + ) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps # ep_dict["next.observation.image"] = image[1:], @@ -220,7 +231,8 @@ def to_hf_dataset(data_dict, video, keypoints_instead_of_image: bool = False): features["observation.image"] = Image() features["observation.state"] = Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + length=data_dict["observation.state"].shape[1], + feature=Value(dtype="float32", id=None), ) if keypoints_instead_of_image: features["observation.environment_state"] = Sequence( @@ 
-261,7 +273,9 @@ def from_raw_to_lerobot_format( if fps is None: fps = 10 - data_dict = load_from_raw(raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding) + data_dict = load_from_raw( + raw_dir, videos_dir, fps, video, episodes, keypoints_instead_of_image, encoding + ) hf_dataset = to_hf_dataset(data_dict, video, keypoints_instead_of_image) episode_data_index = calculate_episode_data_index(hf_dataset) info = { diff --git a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py index fec893a7f..a03cb0588 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/umi_zarr_format.py @@ -26,7 +26,9 @@ from PIL import Image as PILImage from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION -from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import register_codecs +from lerobot.common.datasets.push_dataset_to_hub._umi_imagecodecs_numcodecs import ( + register_codecs, +) from lerobot.common.datasets.push_dataset_to_hub.utils import ( calculate_episode_data_index, concatenate_episodes, @@ -61,7 +63,9 @@ def check_format(raw_dir) -> bool: nb_frames = zarr_data["data/camera0_rgb"].shape[0] required_datasets.remove("meta/episode_ends") - assert all(nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets) + assert all( + nb_frames == zarr_data[dataset].shape[0] for dataset in required_datasets + ) def load_from_raw( @@ -79,7 +83,9 @@ def load_from_raw( end_pose = torch.from_numpy(zarr_data["data/robot0_demo_end_pose"][:]) start_pos = torch.from_numpy(zarr_data["data/robot0_demo_start_pose"][:]) eff_pos = torch.from_numpy(zarr_data["data/robot0_eef_pos"][:]) - eff_rot_axis_angle = torch.from_numpy(zarr_data["data/robot0_eef_rot_axis_angle"][:]) + eff_rot_axis_angle = torch.from_numpy( + zarr_data["data/robot0_eef_rot_axis_angle"][:] + ) gripper_width = torch.from_numpy(zarr_data["data/robot0_gripper_width"][:]) states_pos = torch.cat([eff_pos, eff_rot_axis_angle], dim=1) @@ -129,24 +135,31 @@ def load_from_raw( save_images_concurrently(imgs_array, tmp_imgs_dir) # encode images to a mp4 video - encode_video_frames(tmp_imgs_dir, video_path, fps, **(encoding or {})) + encode_video_frames( + tmp_imgs_dir, video_path, fps, **(encoding or {}) + ) # clean temporary images directory shutil.rmtree(tmp_imgs_dir) # store the reference to the video frame ep_dict[img_key] = [ - {"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames) + {"path": f"videos/{fname}", "timestamp": i / fps} + for i in range(num_frames) ] else: ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array] ep_dict["observation.state"] = state - ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) + ep_dict["episode_index"] = torch.tensor( + [ep_idx] * num_frames, dtype=torch.int64 + ) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps ep_dict["episode_data_index_from"] = torch.tensor([from_idx] * num_frames) - ep_dict["episode_data_index_to"] = torch.tensor([from_idx + num_frames] * num_frames) + ep_dict["episode_data_index_to"] = torch.tensor( + [from_idx + num_frames] * num_frames + ) ep_dict["end_pose"] = end_pose[from_idx:to_idx] ep_dict["start_pos"] = start_pos[from_idx:to_idx] ep_dict["gripper_width"] = gripper_width[from_idx:to_idx] @@ -172,7 +185,8 @@ def to_hf_dataset(data_dict, video): 
features["observation.image"] = Image() features["observation.state"] = Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + length=data_dict["observation.state"].shape[1], + feature=Value(dtype="float32", id=None), ) features["episode_index"] = Value(dtype="int64", id=None) features["frame_index"] = Value(dtype="int64", id=None) @@ -192,7 +206,8 @@ def to_hf_dataset(data_dict, video): length=data_dict["start_pos"].shape[1], feature=Value(dtype="float32", id=None) ) features["gripper_width"] = Sequence( - length=data_dict["gripper_width"].shape[1], feature=Value(dtype="float32", id=None) + length=data_dict["gripper_width"].shape[1], + feature=Value(dtype="float32", id=None), ) hf_dataset = Dataset.from_dict(data_dict, features=Features(features)) diff --git a/lerobot/common/datasets/push_dataset_to_hub/utils.py b/lerobot/common/datasets/push_dataset_to_hub/utils.py index ebcf87f77..13997c81e 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/utils.py +++ b/lerobot/common/datasets/push_dataset_to_hub/utils.py @@ -45,7 +45,9 @@ def concatenate_episodes(ep_dicts): return data_dict -def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4): +def save_images_concurrently( + imgs_array: numpy.array, out_dir: Path, max_workers: int = 4 +): out_dir = Path(out_dir) out_dir.mkdir(parents=True, exist_ok=True) @@ -55,7 +57,10 @@ def save_image(img_array, i, out_dir): num_images = len(imgs_array) with ThreadPoolExecutor(max_workers=max_workers) as executor: - [executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)] + [ + executor.submit(save_image, imgs_array[i], i, out_dir) + for i in range(num_images) + ] def get_default_encoding() -> dict: @@ -64,7 +69,8 @@ def get_default_encoding() -> dict: return { k: v.default for k, v in signature.parameters.items() - if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"] + if v.default is not inspect.Parameter.empty + and k in ["vcodec", "pix_fmt", "g", "crf"] } @@ -77,7 +83,9 @@ def check_repo_id(repo_id: str) -> None: # TODO(aliberts): remove -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> Dict[str, torch.Tensor]: +def calculate_episode_data_index( + hf_dataset: datasets.Dataset, +) -> Dict[str, torch.Tensor]: """ Calculate episode data index for the provided HuggingFace Dataset. Relies on episode_index column of hf_dataset. 
diff --git a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py index 0047e48c3..f628a5f14 100644 --- a/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py +++ b/lerobot/common/datasets/push_dataset_to_hub/xarm_pkl_format.py @@ -40,7 +40,10 @@ def check_format(raw_dir): keys = {"actions", "rewards", "dones"} - nested_keys = {"observations": {"rgb", "state"}, "next_observations": {"rgb", "state"}} + nested_keys = { + "observations": {"rgb", "state"}, + "next_observations": {"rgb", "state"}, + } xarm_files = list(raw_dir.glob("*.pkl")) assert len(xarm_files) > 0 @@ -53,11 +56,17 @@ def check_format(raw_dir): # Check for consistent lengths in nested keys expected_len = len(dataset_dict["actions"]) - assert all(len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict) + assert all( + len(dataset_dict[key]) == expected_len for key in keys if key in dataset_dict + ) for key, subkeys in nested_keys.items(): nested_dict = dataset_dict.get(key, {}) - assert all(len(nested_dict[subkey]) == expected_len for subkey in subkeys if subkey in nested_dict) + assert all( + len(nested_dict[subkey]) == expected_len + for subkey in subkeys + if subkey in nested_dict + ) def load_from_raw( @@ -122,13 +131,18 @@ def load_from_raw( shutil.rmtree(tmp_imgs_dir) # store the reference to the video frame - ep_dict[img_key] = [{"path": f"videos/{fname}", "timestamp": i / fps} for i in range(num_frames)] + ep_dict[img_key] = [ + {"path": f"videos/{fname}", "timestamp": i / fps} + for i in range(num_frames) + ] else: ep_dict[img_key] = [PILImage.fromarray(x) for x in imgs_array] ep_dict["observation.state"] = state ep_dict["action"] = action - ep_dict["episode_index"] = torch.tensor([ep_idx] * num_frames, dtype=torch.int64) + ep_dict["episode_index"] = torch.tensor( + [ep_idx] * num_frames, dtype=torch.int64 + ) ep_dict["frame_index"] = torch.arange(0, num_frames, 1) ep_dict["timestamp"] = torch.arange(0, num_frames, 1) / fps # ep_dict["next.observation.image"] = next_image @@ -153,7 +167,8 @@ def to_hf_dataset(data_dict, video): features["observation.image"] = Image() features["observation.state"] = Sequence( - length=data_dict["observation.state"].shape[1], feature=Value(dtype="float32", id=None) + length=data_dict["observation.state"].shape[1], + feature=Value(dtype="float32", id=None), ) features["action"] = Sequence( length=data_dict["action"].shape[1], feature=Value(dtype="float32", id=None) diff --git a/lerobot/common/datasets/sampler.py b/lerobot/common/datasets/sampler.py index 2f6c15c15..53d0e2e4f 100644 --- a/lerobot/common/datasets/sampler.py +++ b/lerobot/common/datasets/sampler.py @@ -43,7 +43,10 @@ def __init__( ): if episode_indices_to_use is None or episode_idx in episode_indices_to_use: indices.extend( - range(start_index.item() + drop_n_first_frames, end_index.item() - drop_n_last_frames) + range( + start_index.item() + drop_n_first_frames, + end_index.item() - drop_n_last_frames, + ) ) self.indices = indices diff --git a/lerobot/common/datasets/transforms.py b/lerobot/common/datasets/transforms.py index 899f0d66c..3c7922c28 100644 --- a/lerobot/common/datasets/transforms.py +++ b/lerobot/common/datasets/transforms.py @@ -57,7 +57,9 @@ def __init__( elif not isinstance(n_subset, int): raise TypeError("n_subset should be an int or None") elif not (1 <= n_subset <= len(transforms)): - raise ValueError(f"n_subset should be in the interval [1, {len(transforms)}]") + raise ValueError( + f"n_subset 
should be in the interval [1, {len(transforms)}]"
+            )

         self.transforms = transforms
         total = sum(p)
@@ -116,16 +118,22 @@ def __init__(self, sharpness: float | Sequence[float]) -> None:
     def _check_input(self, sharpness):
         if isinstance(sharpness, (int, float)):
             if sharpness < 0:
-                raise ValueError("If sharpness is a single number, it must be non negative.")
+                raise ValueError(
+                    "If sharpness is a single number, it must be non-negative."
+                )
             sharpness = [1.0 - sharpness, 1.0 + sharpness]
             sharpness[0] = max(sharpness[0], 0.0)
         elif isinstance(sharpness, collections.abc.Sequence) and len(sharpness) == 2:
             sharpness = [float(v) for v in sharpness]
         else:
-            raise TypeError(f"{sharpness=} should be a single number or a sequence with length 2.")
+            raise TypeError(
+                f"{sharpness=} should be a single number or a sequence with length 2."
+            )

         if not 0.0 <= sharpness[0] <= sharpness[1]:
-            raise ValueError(f"sharpnesss values should be between (0., inf), but got {sharpness}.")
+            raise ValueError(
+                f"sharpness values should be between (0., inf), but got {sharpness}."
+            )

         return float(sharpness[0]), float(sharpness[1])
@@ -134,7 +142,9 @@ def _generate_value(self, left: float, right: float) -> float:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         sharpness_factor = self._generate_value(self.sharpness[0], self.sharpness[1])
-        return self._call_kernel(F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor)
+        return self._call_kernel(
+            F.adjust_sharpness, inpt, sharpness_factor=sharpness_factor
+        )


 def get_image_transforms(
@@ -150,6 +160,10 @@
     sharpness_min_max: tuple[float, float] | None = None,
     max_num_transforms: int | None = None,
     random_order: bool = False,
+    interpolation: str | None = None,
+    image_size: tuple[int, int] | None = None,
+    image_mean: list[float] | None = None,
+    image_std: list[float] | None = None,
 ):
     def check_value(name, weight, min_max):
         if min_max is not None:
@@ -170,6 +184,22 @@ def check_value(name, weight, min_max):
     weights = []
     transforms = []
+    if image_size is not None:
+        interpolations = [interpolation.value for interpolation in v2.InterpolationMode]
+        if interpolation is None:
+            # Use BICUBIC as default interpolation
+            interpolation_mode = v2.InterpolationMode.BICUBIC
+        elif interpolation in interpolations:
+            interpolation_mode = v2.InterpolationMode(interpolation)
+        else:
+            raise ValueError("The interpolation passed is not supported")
+        # Weight for resizing is always 1
+        weights.append(1.0)
+        transforms.append(
+            v2.Resize(
+                size=(image_size[0], image_size[1]), interpolation=interpolation_mode
+            )
+        )
     if brightness_min_max is not None and brightness_weight > 0.0:
         weights.append(brightness_weight)
         transforms.append(v2.ColorJitter(brightness=brightness_min_max))
@@ -185,6 +215,15 @@ def check_value(name, weight, min_max):
     if sharpness_min_max is not None and sharpness_weight > 0.0:
         weights.append(sharpness_weight)
         transforms.append(SharpnessJitter(sharpness=sharpness_min_max))
+    if image_mean is not None and image_std is not None:
+        # Weight for normalization is always 1
+        weights.append(1.0)
+        transforms.append(
+            v2.Normalize(
+                mean=image_mean,
+                std=image_std,
+            )
+        )

     n_subset = len(transforms)
     if max_num_transforms is not None:
@@ -194,4 +233,6 @@ def check_value(name, weight, min_max):
         return v2.Identity()
     else:
         # TODO(rcadene, aliberts): add v2.ToDtype float16?
-        return RandomSubsetApply(transforms, p=weights, n_subset=n_subset, random_order=random_order)
+        return RandomSubsetApply(
+            transforms, p=weights, n_subset=n_subset, random_order=random_order
+        )
diff --git a/lerobot/common/datasets/utils.py b/lerobot/common/datasets/utils.py
index af5b03cc0..1162c31e1 100644
--- a/lerobot/common/datasets/utils.py
+++ b/lerobot/common/datasets/utils.py
@@ -17,9 +17,11 @@
 import json
 import logging
 import textwrap
+from collections.abc import Iterator
 from itertools import accumulate
 from pathlib import Path
 from pprint import pformat
+from types import SimpleNamespace
 from typing import Any

 import datasets
@@ -41,9 +43,15 @@
 STATS_PATH = "meta/stats.json"
 TASKS_PATH = "meta/tasks.jsonl"

-DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
-DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
-DEFAULT_IMAGE_PATH = "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+DEFAULT_VIDEO_PATH = (
+    "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
+)
+DEFAULT_PARQUET_PATH = (
+    "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
+)
+DEFAULT_IMAGE_PATH = (
+    "images/{image_key}/episode_{episode_index:06d}/frame_{frame_index:06d}.png"
+)

 DATASET_CARD_TEMPLATE = """
 ---
@@ -97,7 +105,9 @@ def unflatten_dict(d: dict, sep: str = "/") -> dict:


 def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
-    serialized_dict = {key: value.tolist() for key, value in flatten_dict(stats).items()}
+    serialized_dict = {
+        key: value.tolist() for key, value in flatten_dict(stats).items()
+    }
     return unflatten_dict(serialized_dict)
@@ -155,14 +165,19 @@ def load_stats(local_dir: Path) -> dict:

 def load_tasks(local_dir: Path) -> dict:
     tasks = load_jsonlines(local_dir / TASKS_PATH)
-    return {item["task_index"]: item["task"] for item in sorted(tasks, key=lambda x: x["task_index"])}
+    return {
+        item["task_index"]: item["task"]
+        for item in sorted(tasks, key=lambda x: x["task_index"])
+    }


 def load_episodes(local_dir: Path) -> dict:
     return load_jsonlines(local_dir / EPISODES_PATH)


-def load_image_as_numpy(fpath: str | Path, dtype="float32", channel_first: bool = True) -> np.ndarray:
+def load_image_as_numpy(
+    fpath: str | Path, dtype="float32", channel_first: bool = True
+) -> np.ndarray:
     img = PILImage.open(fpath).convert("RGB")
     img_array = np.array(img, dtype=dtype)
     if channel_first:  # (H, W, C) -> (C, H, W)
@@ -220,7 +235,10 @@ def __init__(self, repo_id, version):


 def check_version_compatibility(
-    repo_id: str, version_to_check: str, current_version: str, enforce_breaking_major: bool = True
+    repo_id: str,
+    version_to_check: str,
+    current_version: str,
+    enforce_breaking_major: bool = True,
 ) -> None:
     current_major, _ = _get_major_minor(current_version)
     major_to_check, _ = _get_major_minor(version_to_check)
@@ -273,6 +291,7 @@ def get_hf_features_from_features(features: dict) -> datasets.Features:
             hf_features[key] = datasets.Sequence(
                 length=ft["shape"][0], feature=datasets.Value(dtype=ft["dtype"])
             )
+    # TODO(aliberts, azouitine): Add support for ft["shape"] == 0 as Value

     return datasets.Features(hf_features)
@@ -314,7 +333,9 @@ def create_empty_dataset_info(


 def get_episode_data_index(
     episode_dicts: list[dict], episodes: list[int] | None = None
 ) -> dict[str, torch.Tensor]:
-    episode_lengths = {ep_idx: ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts)}
+    episode_lengths = {
+        ep_idx:
ep_dict["length"] for ep_idx, ep_dict in enumerate(episode_dicts) + } if episodes is not None: episode_lengths = {ep_idx: episode_lengths[ep_idx] for ep_idx in episodes} @@ -335,7 +356,9 @@ def calculate_total_episode( return total_episodes -def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]: +def calculate_episode_data_index( + hf_dataset: datasets.Dataset, +) -> dict[str, torch.Tensor]: episode_lengths = [] table = hf_dataset.data.table total_episodes = calculate_total_episode(hf_dataset) @@ -377,7 +400,9 @@ def check_timestamps_sync( # Track original indices before masking original_indices = torch.arange(len(diffs)) filtered_indices = original_indices[mask] - outside_tolerance_filtered_indices = torch.nonzero(~filtered_within_tolerance) # .squeeze() + outside_tolerance_filtered_indices = torch.nonzero( + ~filtered_within_tolerance + ) # .squeeze() outside_tolerance_indices = filtered_indices[outside_tolerance_filtered_indices] episode_indices = torch.stack(hf_dataset["episode_index"]) @@ -402,7 +427,10 @@ def check_timestamps_sync( def check_delta_timestamps( - delta_timestamps: dict[str, list[float]], fps: int, tolerance_s: float, raise_value_error: bool = True + delta_timestamps: dict[str, list[float]], + fps: int, + tolerance_s: float, + raise_value_error: bool = True, ) -> bool: """This will check if all the values in delta_timestamps are multiples of 1/fps +/- tolerance. This is to ensure that these delta_timestamps added to any timestamp from a dataset will themselves be @@ -410,10 +438,14 @@ def check_delta_timestamps( """ outside_tolerance = {} for key, delta_ts in delta_timestamps.items(): - within_tolerance = [abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts] + within_tolerance = [ + abs(ts * fps - round(ts * fps)) / fps <= tolerance_s for ts in delta_ts + ] if not all(within_tolerance): outside_tolerance[key] = [ - ts for ts, is_within in zip(delta_ts, within_tolerance, strict=True) if not is_within + ts + for ts, is_within in zip(delta_ts, within_tolerance, strict=True) + if not is_within ] if len(outside_tolerance) > 0: @@ -431,7 +463,9 @@ def check_delta_timestamps( return True -def get_delta_indices(delta_timestamps: dict[str, list[float]], fps: int) -> dict[str, list[int]]: +def get_delta_indices( + delta_timestamps: dict[str, list[float]], fps: int +) -> dict[str, list[int]]: delta_indices = {} for key, delta_ts in delta_timestamps.items(): delta_indices[key] = (torch.tensor(delta_ts) * fps).long().tolist() @@ -477,7 +511,6 @@ def create_lerobot_dataset_card( Note: If specified, license must be one of https://huggingface.co/docs/hub/repositories-licenses. """ card_tags = ["LeRobot"] - card_template_path = importlib.resources.path("lerobot.common.datasets", "card_template.md") if tags: card_tags += tags @@ -497,8 +530,67 @@ def create_lerobot_dataset_card( ], ) + card_template = ( + importlib.resources.files("lerobot.common.datasets") / "card_template.md" + ).read_text() + return DatasetCard.from_template( card_data=card_data, - template_path=str(card_template_path), + template_str=card_template, **kwargs, ) + + +class IterableNamespace(SimpleNamespace): + """ + A namespace object that supports both dictionary-like iteration and dot notation access. + Automatically converts nested dictionaries into IterableNamespaces. 
+ + This class extends SimpleNamespace to provide: + - Dictionary-style iteration over keys + - Access to items via both dot notation (obj.key) and brackets (obj["key"]) + - Dictionary-like methods: items(), keys(), values() + - Recursive conversion of nested dictionaries + + Args: + dictionary: Optional dictionary to initialize the namespace + **kwargs: Additional keyword arguments passed to SimpleNamespace + + Examples: + >>> data = {"name": "Alice", "details": {"age": 25}} + >>> ns = IterableNamespace(data) + >>> ns.name + 'Alice' + >>> ns.details.age + 25 + >>> list(ns.keys()) + ['name', 'details'] + >>> for key, value in ns.items(): + ... print(f"{key}: {value}") + name: Alice + details: IterableNamespace(age=25) + """ + + def __init__(self, dictionary: dict[str, Any] = None, **kwargs): + super().__init__(**kwargs) + if dictionary is not None: + for key, value in dictionary.items(): + if isinstance(value, dict): + setattr(self, key, IterableNamespace(value)) + else: + setattr(self, key, value) + + def __iter__(self) -> Iterator[str]: + return iter(vars(self)) + + def __getitem__(self, key: str) -> Any: + return vars(self)[key] + + def items(self): + return vars(self).items() + + def values(self): + return vars(self).values() + + def keys(self): + return vars(self).keys() diff --git a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py index c8da2fe14..9f0fda41d 100644 --- a/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/batch_convert_dataset_v1_to_v2.py @@ -26,7 +26,10 @@ from textwrap import dedent from lerobot import available_datasets -from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import convert_dataset, parse_robot_config +from lerobot.common.datasets.v2.convert_dataset_v1_to_v2 import ( + convert_dataset, + parse_robot_config, +) LOCAL_DIR = Path("data/") @@ -117,7 +120,10 @@ "single_task": "Place the battery into the slot of the remote controller.", **ALOHA_STATIC_INFO, }, - "aloha_static_candy": {"single_task": "Pick up the candy and unwrap it.", **ALOHA_STATIC_INFO}, + "aloha_static_candy": { + "single_task": "Pick up the candy and unwrap it.", + **ALOHA_STATIC_INFO, + }, "aloha_static_coffee": { "single_task": "Place the coffee capsule inside the capsule container, then place the cup onto the center of the cup tray, then push the 'Hot Water' and 'Travel Mug' buttons.", **ALOHA_STATIC_INFO, @@ -159,20 +165,29 @@ **ALOHA_STATIC_INFO, }, "aloha_static_vinh_cup": { - "single_task": "Pick up the platic cup with the right arm, then pop its lid open with the left arm.", + "single_task": "Pick up the plastic cup with the right arm, then pop its lid open with the left arm.", **ALOHA_STATIC_INFO, }, "aloha_static_vinh_cup_left": { - "single_task": "Pick up the platic cup with the left arm, then pop its lid open with the right arm.", + "single_task": "Pick up the plastic cup with the left arm, then pop its lid open with the right arm.", + **ALOHA_STATIC_INFO, + }, + "aloha_static_ziploc_slide": { + "single_task": "Slide open the ziploc bag.", + **ALOHA_STATIC_INFO, + }, + "aloha_sim_insertion_scripted": { + "single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO, }, - "aloha_static_ziploc_slide": {"single_task": "Slide open the ziploc bag.", **ALOHA_STATIC_INFO}, - "aloha_sim_insertion_scripted": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, "aloha_sim_insertion_scripted_image": { "single_task": "Insert the peg into 
the socket.", **ALOHA_STATIC_INFO, }, - "aloha_sim_insertion_human": {"single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO}, + "aloha_sim_insertion_human": { + "single_task": "Insert the peg into the socket.", + **ALOHA_STATIC_INFO, + }, "aloha_sim_insertion_human_image": { "single_task": "Insert the peg into the socket.", **ALOHA_STATIC_INFO, @@ -193,10 +208,19 @@ "single_task": "Pick up the cube with the right arm and transfer it to the left arm.", **ALOHA_STATIC_INFO, }, - "pusht": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, - "pusht_image": {"single_task": "Push the T-shaped block onto the T-shaped target.", **PUSHT_INFO}, + "pusht": { + "single_task": "Push the T-shaped block onto the T-shaped target.", + **PUSHT_INFO, + }, + "pusht_image": { + "single_task": "Push the T-shaped block onto the T-shaped target.", + **PUSHT_INFO, + }, "unitreeh1_fold_clothes": {"single_task": "Fold the sweatshirt.", **UNITREEH_INFO}, - "unitreeh1_rearrange_objects": {"single_task": "Put the object into the bin.", **UNITREEH_INFO}, + "unitreeh1_rearrange_objects": { + "single_task": "Put the object into the bin.", + **UNITREEH_INFO, + }, "unitreeh1_two_robot_greeting": { "single_task": "Greet the other robot with a high five.", **UNITREEH_INFO, @@ -206,13 +230,31 @@ **UNITREEH_INFO, }, "xarm_lift_medium": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_replay": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, - "xarm_lift_medium_replay_image": {"single_task": "Pick up the cube and lift it.", **XARM_INFO}, + "xarm_lift_medium_image": { + "single_task": "Pick up the cube and lift it.", + **XARM_INFO, + }, + "xarm_lift_medium_replay": { + "single_task": "Pick up the cube and lift it.", + **XARM_INFO, + }, + "xarm_lift_medium_replay_image": { + "single_task": "Pick up the cube and lift it.", + **XARM_INFO, + }, "xarm_push_medium": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_replay": {"single_task": "Push the cube onto the target.", **XARM_INFO}, - "xarm_push_medium_replay_image": {"single_task": "Push the cube onto the target.", **XARM_INFO}, + "xarm_push_medium_image": { + "single_task": "Push the cube onto the target.", + **XARM_INFO, + }, + "xarm_push_medium_replay": { + "single_task": "Push the cube onto the target.", + **XARM_INFO, + }, + "xarm_push_medium_replay_image": { + "single_task": "Push the cube onto the target.", + **XARM_INFO, + }, "umi_cup_in_the_wild": { "single_task": "Put the cup on the plate.", "license": "apache-2.0", diff --git a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py index bf135043b..74fe931f0 100644 --- a/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py +++ b/lerobot/common/datasets/v2/convert_dataset_v1_to_v2.py @@ -152,7 +152,9 @@ V1_STATS_PATH = "meta_data/stats.safetensors" -def parse_robot_config(config_path: Path, config_overrides: list[str] | None = None) -> tuple[str, dict]: +def parse_robot_config( + config_path: Path, config_overrides: list[str] | None = None +) -> tuple[str, dict]: robot_cfg = init_hydra_config(config_path, config_overrides) if robot_cfg["robot_type"] in ["aloha", "koch"]: state_names = [ @@ -203,7 +205,9 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: 
Path) -> None: torch.testing.assert_close(stats_json[key], stats[key]) -def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = None) -> dict[str, list]: +def get_features_from_hf_dataset( + dataset: Dataset, robot_config: dict | None = None +) -> dict[str, list]: features = {} for key, ft in dataset.features.items(): if isinstance(ft, datasets.Value): @@ -215,7 +219,9 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N dtype = ft.feature.dtype shape = (ft.length,) motor_names = ( - robot_config["names"][key] if robot_config else [f"motor_{i}" for i in range(ft.length)] + robot_config["names"][key] + if robot_config + else [f"motor_{i}" for i in range(ft.length)] ) assert len(motor_names) == shape[0] names = {"motors": motor_names} @@ -239,11 +245,15 @@ def get_features_from_hf_dataset(dataset: Dataset, robot_config: dict | None = N return features -def add_task_index_by_episodes(dataset: Dataset, tasks_by_episodes: dict) -> tuple[Dataset, list[str]]: +def add_task_index_by_episodes( + dataset: Dataset, tasks_by_episodes: dict +) -> tuple[Dataset, list[str]]: df = dataset.to_pandas() tasks = list(set(tasks_by_episodes.values())) tasks_to_task_index = {task: task_idx for task_idx, task in enumerate(tasks)} - episodes_to_task_index = {ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items()} + episodes_to_task_index = { + ep_idx: tasks_to_task_index[task] for ep_idx, task in tasks_by_episodes.items() + } df["task_index"] = df["episode_index"].map(episodes_to_task_index).astype(int) features = dataset.features @@ -260,10 +270,19 @@ def add_task_index_from_tasks_col( # HACK: This is to clean some of the instructions in our version of Open X datasets prefix_to_clean = "tf.Tensor(b'" suffix_to_clean = "', shape=(), dtype=string)" - df[tasks_col] = df[tasks_col].str.removeprefix(prefix_to_clean).str.removesuffix(suffix_to_clean) + df[tasks_col] = ( + df[tasks_col] + .str.removeprefix(prefix_to_clean) + .str.removesuffix(suffix_to_clean) + ) # Create task_index col - tasks_by_episode = df.groupby("episode_index")[tasks_col].unique().apply(lambda x: x.tolist()).to_dict() + tasks_by_episode = ( + df.groupby("episode_index")[tasks_col] + .unique() + .apply(lambda x: x.tolist()) + .to_dict() + ) tasks = df[tasks_col].unique().tolist() tasks_to_task_index = {task: idx for idx, task in enumerate(tasks)} df["task_index"] = df[tasks_col].map(tasks_to_task_index).astype(int) @@ -288,7 +307,9 @@ def split_parquet_by_episodes( for ep_chunk in range(total_chunks): ep_chunk_start = DEFAULT_CHUNK_SIZE * ep_chunk ep_chunk_end = min(DEFAULT_CHUNK_SIZE * (ep_chunk + 1), total_episodes) - chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format(episode_chunk=ep_chunk) + chunk_dir = "/".join(DEFAULT_PARQUET_PATH.split("/")[:-1]).format( + episode_chunk=ep_chunk + ) (output_dir / chunk_dir).mkdir(parents=True, exist_ok=True) for ep_idx in range(ep_chunk_start, ep_chunk_end): ep_table = table.filter(pc.equal(table["episode_index"], ep_idx)) @@ -320,7 +341,9 @@ def move_videos( videos_moved = False video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*.mp4")] if len(video_files) == 0: - video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")] + video_files = [ + str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4") + ] videos_moved = True # Videos have already been moved assert len(video_files) == total_episodes * len(video_keys) @@ -351,7 +374,9 @@ def 
move_videos( target_path = DEFAULT_VIDEO_PATH.format( episode_chunk=ep_chunk, video_key=vid_key, episode_index=ep_idx ) - video_file = V1_VIDEO_FILE.format(video_key=vid_key, episode_index=ep_idx) + video_file = V1_VIDEO_FILE.format( + video_key=vid_key, episode_index=ep_idx + ) if len(video_dirs) == 1: video_path = video_dirs[0] / video_file else: @@ -368,7 +393,9 @@ def move_videos( subprocess.run(["git", "push"], cwd=work_dir, check=True) -def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str]) -> None: +def fix_lfs_video_files_tracking( + work_dir: Path, lfs_untracked_videos: list[str] +) -> None: """ HACK: This function fixes the tracking by git lfs which was not properly set on some repos. In that case, there's no other option than to download the actual files and reupload them with lfs tracking. @@ -376,7 +403,12 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str] for i in range(0, len(lfs_untracked_videos), 100): files = lfs_untracked_videos[i : i + 100] try: - subprocess.run(["git", "rm", "--cached", *files], cwd=work_dir, capture_output=True, check=True) + subprocess.run( + ["git", "rm", "--cached", *files], + cwd=work_dir, + capture_output=True, + check=True, + ) except subprocess.CalledProcessError as e: print("git rm --cached ERROR:") print(e.stderr) @@ -387,10 +419,14 @@ def fix_lfs_video_files_tracking(work_dir: Path, lfs_untracked_videos: list[str] subprocess.run(["git", "push"], cwd=work_dir, check=True) -def fix_gitattributes(work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path) -> None: +def fix_gitattributes( + work_dir: Path, current_gittatributes: Path, clean_gittatributes: Path +) -> None: shutil.copyfile(clean_gittatributes, current_gittatributes) subprocess.run(["git", "add", ".gitattributes"], cwd=work_dir, check=True) - subprocess.run(["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True) + subprocess.run( + ["git", "commit", "-m", "Fix .gitattributes"], cwd=work_dir, check=True + ) subprocess.run(["git", "push"], cwd=work_dir, check=True) @@ -399,7 +435,17 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None: repo_url = f"https://huggingface.co/datasets/{repo_id}" env = {"GIT_LFS_SKIP_SMUDGE": "1"} # Prevent downloading LFS files subprocess.run( - ["git", "clone", "--branch", branch, "--single-branch", "--depth", "1", repo_url, str(work_dir)], + [ + "git", + "clone", + "--branch", + branch, + "--single-branch", + "--depth", + "1", + repo_url, + str(work_dir), + ], check=True, env=env, ) @@ -407,13 +453,19 @@ def _lfs_clone(repo_id: str, work_dir: Path, branch: str) -> None: def _get_lfs_untracked_videos(work_dir: Path, video_files: list[str]) -> list[str]: lfs_tracked_files = subprocess.run( - ["git", "lfs", "ls-files", "-n"], cwd=work_dir, capture_output=True, text=True, check=True + ["git", "lfs", "ls-files", "-n"], + cwd=work_dir, + capture_output=True, + text=True, + check=True, ) lfs_tracked_files = set(lfs_tracked_files.stdout.splitlines()) return [f for f in video_files if f not in lfs_tracked_files] -def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch: str) -> dict: +def get_videos_info( + repo_id: str, local_dir: Path, video_keys: list[str], branch: str +) -> dict: # Assumes first episode video_files = [ DEFAULT_VIDEO_PATH.format(episode_chunk=0, video_key=vid_key, episode_index=0) @@ -421,7 +473,11 @@ def get_videos_info(repo_id: str, local_dir: Path, video_keys: list[str], branch ] hub_api = HfApi() 
hub_api.snapshot_download( - repo_id=repo_id, repo_type="dataset", local_dir=local_dir, revision=branch, allow_patterns=video_files + repo_id=repo_id, + repo_type="dataset", + local_dir=local_dir, + revision=branch, + allow_patterns=video_files, ) videos_info_dict = {} for vid_key, vid_path in zip(video_keys, video_files, strict=True): @@ -448,7 +504,11 @@ def convert_dataset( hub_api = HfApi() hub_api.snapshot_download( - repo_id=repo_id, repo_type="dataset", revision=v1, local_dir=v1x_dir, ignore_patterns="videos*/" + repo_id=repo_id, + repo_type="dataset", + revision=v1, + local_dir=v1x_dir, + ignore_patterns="videos*/", ) branch = "main" if test_branch: @@ -480,19 +540,31 @@ def convert_dataset( if single_task: tasks_by_episodes = {ep_idx: single_task for ep_idx in episode_indices} dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} + tasks_by_episodes = { + ep_idx: [task] for ep_idx, task in tasks_by_episodes.items() + } elif tasks_path: tasks_by_episodes = load_json(tasks_path) - tasks_by_episodes = {int(ep_idx): task for ep_idx, task in tasks_by_episodes.items()} + tasks_by_episodes = { + int(ep_idx): task for ep_idx, task in tasks_by_episodes.items() + } dataset, tasks = add_task_index_by_episodes(dataset, tasks_by_episodes) - tasks_by_episodes = {ep_idx: [task] for ep_idx, task in tasks_by_episodes.items()} + tasks_by_episodes = { + ep_idx: [task] for ep_idx, task in tasks_by_episodes.items() + } elif tasks_col: - dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col(dataset, tasks_col) + dataset, tasks, tasks_by_episodes = add_task_index_from_tasks_col( + dataset, tasks_col + ) else: raise ValueError - assert set(tasks) == {task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks} - tasks = [{"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks)] + assert set(tasks) == { + task for ep_tasks in tasks_by_episodes.values() for task in ep_tasks + } + tasks = [ + {"task_index": task_idx, "task": task} for task_idx, task in enumerate(tasks) + ] write_jsonlines(tasks, v20_dir / TASKS_PATH) features["task_index"] = { "dtype": "int64", @@ -506,14 +578,25 @@ def convert_dataset( dataset = dataset.remove_columns(video_keys) clean_gitattr = Path( hub_api.hf_hub_download( - repo_id=GITATTRIBUTES_REF, repo_type="dataset", local_dir=local_dir, filename=".gitattributes" + repo_id=GITATTRIBUTES_REF, + repo_type="dataset", + local_dir=local_dir, + filename=".gitattributes", ) ).absolute() with tempfile.TemporaryDirectory() as tmp_video_dir: move_videos( - repo_id, video_keys, total_episodes, total_chunks, Path(tmp_video_dir), clean_gitattr, branch + repo_id, + video_keys, + total_episodes, + total_chunks, + Path(tmp_video_dir), + clean_gitattr, + branch, ) - videos_info = get_videos_info(repo_id, v1x_dir, video_keys=video_keys, branch=branch) + videos_info = get_videos_info( + repo_id, v1x_dir, video_keys=video_keys, branch=branch + ) for key in video_keys: features[key]["shape"] = ( videos_info[key].pop("video.height"), @@ -521,15 +604,22 @@ def convert_dataset( videos_info[key].pop("video.channels"), ) features[key]["video_info"] = videos_info[key] - assert math.isclose(videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3) + assert math.isclose( + videos_info[key]["video.fps"], metadata_v1["fps"], rel_tol=1e-3 + ) if "encoding" in metadata_v1: - assert videos_info[key]["video.pix_fmt"] == metadata_v1["encoding"]["pix_fmt"] + assert ( + 
videos_info[key]["video.pix_fmt"] + == metadata_v1["encoding"]["pix_fmt"] + ) else: assert metadata_v1.get("video", 0) == 0 videos_info = None # Split data into 1 parquet file by episode - episode_lengths = split_parquet_by_episodes(dataset, total_episodes, total_chunks, v20_dir) + episode_lengths = split_parquet_by_episodes( + dataset, total_episodes, total_chunks, v20_dir + ) if robot_config is not None: robot_type = robot_config["robot_type"] @@ -540,7 +630,11 @@ def convert_dataset( # Episodes episodes = [ - {"episode_index": ep_idx, "tasks": tasks_by_episodes[ep_idx], "length": episode_lengths[ep_idx]} + { + "episode_index": ep_idx, + "tasks": tasks_by_episodes[ep_idx], + "length": episode_lengths[ep_idx], + } for ep_idx in episode_indices ] write_jsonlines(episodes, v20_dir / EPISODES_PATH) @@ -563,16 +657,27 @@ def convert_dataset( } write_json(metadata_v2_0, v20_dir / INFO_PATH) convert_stats_to_json(v1x_dir, v20_dir) - card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs) + card = create_lerobot_dataset_card( + tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs + ) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch) + hub_api.delete_folder( + repo_id=repo_id, path_in_repo="data", repo_type="dataset", revision=branch + ) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta_data", repo_type="dataset", revision=branch) + hub_api.delete_folder( + repo_id=repo_id, + path_in_repo="meta_data", + repo_type="dataset", + revision=branch, + ) with contextlib.suppress(EntryNotFoundError, HfHubHTTPError): - hub_api.delete_folder(repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch) + hub_api.delete_folder( + repo_id=repo_id, path_in_repo="meta", repo_type="dataset", revision=branch + ) hub_api.upload_folder( repo_id=repo_id, @@ -655,7 +760,11 @@ def main(): if not args.local_dir: args.local_dir = Path("/tmp/lerobot_dataset_v2") - robot_config = parse_robot_config(args.robot_config, args.robot_overrides) if args.robot_config else None + robot_config = ( + parse_robot_config(args.robot_config, args.robot_overrides) + if args.robot_config + else None + ) del args.robot_config, args.robot_overrides convert_dataset(**vars(args), robot_config=robot_config) diff --git a/lerobot/common/datasets/video_utils.py b/lerobot/common/datasets/video_utils.py index 8ed3318dd..d63bbf8c6 100644 --- a/lerobot/common/datasets/video_utils.py +++ b/lerobot/common/datasets/video_utils.py @@ -227,7 +227,9 @@ def get_audio_info(video_path: Path | str) -> dict: "json", str(video_path), ] - result = subprocess.run(ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + ffprobe_audio_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) if result.returncode != 0: raise RuntimeError(f"Error running ffprobe: {result.stderr}") @@ -241,7 +243,9 @@ def get_audio_info(video_path: Path | str) -> dict: "has_audio": True, "audio.channels": audio_stream_info.get("channels", None), "audio.codec": audio_stream_info.get("codec_name", None), - "audio.bit_rate": int(audio_stream_info["bit_rate"]) if audio_stream_info.get("bit_rate") else None, + "audio.bit_rate": int(audio_stream_info["bit_rate"]) + if audio_stream_info.get("bit_rate") + else None, "audio.sample_rate": int(audio_stream_info["sample_rate"]) if 
audio_stream_info.get("sample_rate") else None, @@ -263,7 +267,9 @@ def get_video_info(video_path: Path | str) -> dict: "json", str(video_path), ] - result = subprocess.run(ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + ffprobe_video_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True + ) if result.returncode != 0: raise RuntimeError(f"Error running ffprobe: {result.stderr}") diff --git a/lerobot/common/envs/factory.py b/lerobot/common/envs/factory.py index 54f24ea84..457b7af67 100644 --- a/lerobot/common/envs/factory.py +++ b/lerobot/common/envs/factory.py @@ -14,9 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import importlib +from collections import deque import gymnasium as gym +import numpy as np +import torch from omegaconf import DictConfig +# from mani_skill.utils import common def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv | None: @@ -30,6 +34,12 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv if cfg.env.name == "real_world": return + if "maniskill" in cfg.env.name: + env = make_maniskill_env( + cfg, n_envs if n_envs is not None else cfg.eval.batch_size + ) + return env + package_name = f"gym_{cfg.env.name}" try: @@ -47,7 +57,11 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv gym_kwgs["max_episode_steps"] = cfg.env.episode_length # batched version of the env that returns an observation of shape (b, c) - env_cls = gym.vector.AsyncVectorEnv if cfg.eval.use_async_envs else gym.vector.SyncVectorEnv + env_cls = ( + gym.vector.AsyncVectorEnv + if cfg.eval.use_async_envs + else gym.vector.SyncVectorEnv + ) env = env_cls( [ lambda: gym.make(gym_handle, disable_env_checker=True, **gym_kwgs) @@ -56,3 +70,98 @@ def make_env(cfg: DictConfig, n_envs: int | None = None) -> gym.vector.VectorEnv ) return env + + +def make_maniskill_env( + cfg: DictConfig, n_envs: int | None = None +) -> gym.vector.VectorEnv | None: + """Make ManiSkill3 gym environment""" + from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv + + env = gym.make( + cfg.env.task, + obs_mode=cfg.env.obs, + control_mode=cfg.env.control_mode, + render_mode=cfg.env.render_mode, + sensor_configs=dict(width=cfg.env.image_size, height=cfg.env.image_size), + num_envs=n_envs, + ) + # cfg.env_cfg.control_mode = cfg.eval_env_cfg.control_mode = env.control_mode + env = ManiSkillVectorEnv(env, ignore_terminations=True) + # state should have the size of 25 + # env = ConvertToLeRobotEnv(env, n_envs) + # env = PixelWrapper(cfg, env, n_envs) + env._max_episode_steps = env.max_episode_steps = ( + 50 # gym_utils.find_max_episode_steps_value(env) + ) + env.unwrapped.metadata["render_fps"] = 20 + + return env + + +class PixelWrapper(gym.Wrapper): + """ + Wrapper for pixel observations. 
Works with ManiSkill vectorized environments
+    """
+
+    def __init__(self, cfg, env, num_envs, num_frames=3):
+        super().__init__(env)
+        self.cfg = cfg
+        self.env = env
+        self.observation_space = gym.spaces.Box(
+            low=0,
+            high=255,
+            shape=(num_envs, num_frames * 3, cfg.env.render_size, cfg.env.render_size),
+            dtype=np.uint8,
+        )
+        self._frames = deque([], maxlen=num_frames)
+        self._render_size = cfg.env.render_size
+
+    def _get_obs(self, obs):
+        frame = obs["sensor_data"]["base_camera"]["rgb"].cpu().permute(0, 3, 1, 2)
+        self._frames.append(frame)
+        return {
+            "pixels": torch.from_numpy(np.concatenate(self._frames, axis=1)).to(
+                self.env.device
+            )
+        }
+
+    def reset(self, seed=None):
+        # Pass the seed through so resets are reproducible.
+        obs, info = self.env.reset(seed=seed)
+        for _ in range(self._frames.maxlen):
+            obs_frames = self._get_obs(obs)
+        return obs_frames, info
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        return self._get_obs(obs), reward, terminated, truncated, info
+
+
+class ConvertToLeRobotEnv(gym.Wrapper):
+    def __init__(self, env, num_envs):
+        super().__init__(env)
+
+    def reset(self, seed=None, options=None):
+        obs, info = self.env.reset(seed=seed, options={})
+        return self._get_obs(obs), info
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        return self._get_obs(obs), reward, terminated, truncated, info
+
+    def _get_obs(self, observation):
+        # Local import because the module-level import is commented out above.
+        from mani_skill.utils import common
+
+        sensor_data = observation.pop("sensor_data")
+        del observation["sensor_param"]
+        images = []
+        for cam_data in sensor_data.values():
+            images.append(cam_data["rgb"])
+
+        images = torch.concat(images, axis=-1)
+        # flatten the rest of the data, which should just be state data
+        observation = common.flatten_state_dict(
+            observation, use_torch=True, device=self.env.device
+        )
+        ret = dict()
+        ret["state"] = observation
+        ret["pixels"] = images
+        return ret
diff --git a/lerobot/common/envs/utils.py b/lerobot/common/envs/utils.py
index 001973bc1..a163e6f8b 100644
--- a/lerobot/common/envs/utils.py
+++ b/lerobot/common/envs/utils.py
@@ -28,28 +28,32 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
     """
     # map to expected inputs for the policy
     return_observations = {}
-    if "pixels" in observations:
-        if isinstance(observations["pixels"], dict):
-            imgs = {f"observation.images.{key}": img for key, img in observations["pixels"].items()}
-        else:
-            imgs = {"observation.image": observations["pixels"]}
+    # TODO: Merge all tensors from the "agent" and "extra" keys, drop the
+    # "sensor_param" key, and keep the rgb images from "sensor_data".
+    for key, img in observations.items():
+        if "images" not in key:
+            continue
 
-        for imgkey, img in imgs.items():
-            img = torch.from_numpy(img)
+        if img.ndim == 3:
+            img = img.unsqueeze(0)
+        # sanity check that images are channel last
+        _, h, w, c = img.shape
+        assert (
+            c < h and c < w
+        ), f"expect channel last images, but instead got {img.shape=}"
 
-            # sanity check that images are channel last
-            _, h, w, c = img.shape
-            assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+        # sanity check that images are uint8
+        assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
 
-            # sanity check that images are uint8
-            assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+        # convert to channel first of type float32 in range [0,1]
+        img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
+        img = img.type(torch.float32)
+        img /= 255
 
-            # convert to channel first of type float32 in range [0,1]
-            img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
-            img = img.type(torch.float32)
-            img /= 255
-
-            return_observations[imgkey] = img
+        return_observations[key] = img
+    # The remaining keys carry state (e.g. agent qpos/qvel) and are handled below.
 
     if "environment_state" in observations:
         return_observations["observation.environment_state"] = torch.from_numpy(
@@ -58,5 +62,43 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
     # TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
     # requirement for "agent_pos"
-    return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    # return_observations["observation.state"] = torch.from_numpy(observations["agent_pos"]).float()
+    return_observations["observation.state"] = observations["observation.state"].float()
     return return_observations
+
+
+def preprocess_maniskill_observation(
+    observations: dict[str, np.ndarray],
+) -> dict[str, Tensor]:
+    """Convert environment observation to LeRobot format observation.
+    Args:
+        observations: Dictionary of observation batches from a Gym vector environment.
+    Returns:
+        Dictionary of observation batches with keys renamed to LeRobot format and values as tensors.
+    """
+    # map to expected inputs for the policy
+    return_observations = {}
+    # Merge the state tensors from the "agent" and "extra" keys, drop "sensor_param",
+    # and keep the rgb image from "sensor_data".
+    q_pos = observations["agent"]["qpos"]
+    q_vel = observations["agent"]["qvel"]
+    tcp_pos = observations["extra"]["tcp_pose"]
+    img = observations["sensor_data"]["base_camera"]["rgb"]
+
+    _, h, w, c = img.shape
+    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+
+    # sanity check that images are uint8
+    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+
+    # convert to channel first of type float32 in range [0,1]
+    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
+    img = img.type(torch.float32)
+    img /= 255
+
+    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
+
+    return_observations["observation.image"] = img
+    return_observations["observation.state"] = state
     return return_observations
diff --git a/lerobot/common/logger.py b/lerobot/common/logger.py
index 4015492de..b140270bc 100644
--- a/lerobot/common/logger.py
+++ b/lerobot/common/logger.py
@@ -84,7 +84,9 @@ class Logger:
     pretrained_model_dir_name = "pretrained_model"
     training_state_file_name = "training_state.pth"
 
-    def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None):
+    def __init__(
+        self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = None
+    ):
         """
         Args:
             log_dir: The directory to save all logs and training outputs to.
@@ -104,7 +106,9 @@ def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = N
         enable_wandb = cfg.get("wandb", {}).get("enable", False)
         run_offline = not enable_wandb or not project
         if run_offline:
-            logging.info(colored("Logs will be saved locally.", "yellow", attrs=["bold"]))
+            logging.info(
+                colored("Logs will be saved locally.", "yellow", attrs=["bold"])
+            )
             self._wandb = None
         else:
             os.environ["WANDB_SILENT"] = "true"
@@ -127,8 +131,12 @@ def __init__(self, cfg: DictConfig, log_dir: str, wandb_job_name: str | None = N
                 job_type="train_eval",
                 resume="must" if cfg.resume else None,
             )
+            # Handle the custom step key for asynchronous RL training.
+        self._wandb_custom_step_key: set[str] | None = None
             print(colored("Logs will be synced with wandb.", "blue", attrs=["bold"]))
-        logging.info(f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}")
+        logging.info(
+            f"Track this run --> {colored(wandb.run.get_url(), 'yellow', attrs=['bold'])}"
+        )
         self._wandb = wandb
 
     @classmethod
@@ -149,7 +157,9 @@ def get_last_pretrained_model_dir(cls, log_dir: str | Path) -> Path:
         """
         return cls.get_last_checkpoint_dir(log_dir) / cls.pretrained_model_dir_name
 
-    def save_model(self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None):
+    def save_model(
+        self, save_dir: Path, policy: Policy, wandb_artifact_name: str | None = None
+    ):
         """Save the weights of the Policy model using PyTorchModelHubMixin.
 
         The weights are saved in a folder called "pretrained_model" under the checkpoint directory.
 
@@ -172,18 +182,32 @@ def save_training_state(
         self,
         save_dir: Path,
         train_step: int,
-        optimizer: Optimizer,
+        optimizer: Optimizer | dict,
         scheduler: LRScheduler | None,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the global training_step, optimizer state, scheduler state, and random state.
 
         All of these are saved as "training_state.pth" under the checkpoint directory.
         """
+        # In SAC, for example, we have a dictionary of torch.optim.Optimizer instances.
+        if type(optimizer) is dict:
+            optimizer_state_dict = {}
+            for k in optimizer:
+                optimizer_state_dict[k] = optimizer[k].state_dict()
+        else:
+            optimizer_state_dict = optimizer.state_dict()
+
         training_state = {
             "step": train_step,
-            "optimizer": optimizer.state_dict(),
+            "optimizer": optimizer_state_dict,
             **get_global_random_state(),
         }
+        # The interaction step belongs to the distributed (asynchronous) training setup,
+        # which has two kinds of steps: the online environment step and the optimization step.
+        # We save both so that optimization resumes properly and the logs that depend on
+        # the interaction step are not broken.
+        if interaction_step is not None:
+            training_state["interaction_step"] = interaction_step
         if scheduler is not None:
             training_state["scheduler"] = scheduler.state_dict()
         torch.save(training_state, save_dir / self.training_state_file_name)
 
@@ -195,6 +219,7 @@ def save_checkpoint(
         optimizer: Optimizer,
         scheduler: LRScheduler | None,
         identifier: str,
+        interaction_step: int | None = None,
     ):
         """Checkpoint the model weights and the training state."""
         checkpoint_dir = self.checkpoints_dir / str(identifier)
         wandb_artifact_name = (
             None
             if self._wandb is None
             else f"{self._group.replace(':', '_').replace('/', '_')}-{self._cfg.seed}-{identifier}"
         )
         self.save_model(
-            checkpoint_dir / self.pretrained_model_dir_name, policy, wandb_artifact_name=wandb_artifact_name
+            checkpoint_dir / self.pretrained_model_dir_name,
+            policy,
+            wandb_artifact_name=wandb_artifact_name,
+        )
+        self.save_training_state(
+            checkpoint_dir, train_step, optimizer, scheduler, interaction_step
         )
-        self.save_training_state(checkpoint_dir, train_step, optimizer, scheduler)
         os.symlink(checkpoint_dir.absolute(), self.last_checkpoint_dir)
 
-    def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler | None) -> int:
+    def load_last_training_state(
+        self, optimizer: Optimizer | dict, scheduler: LRScheduler | None
+    ) -> int:
         """
         Given the last checkpoint in the logging directory, load the optimizer state, scheduler state, and random
         state, and return the global training step.
""" - training_state = torch.load(self.last_checkpoint_dir / self.training_state_file_name) - optimizer.load_state_dict(training_state["optimizer"]) + training_state = torch.load( + self.last_checkpoint_dir / self.training_state_file_name + ) + # For the case where the optimizer is a dictionary of optimizers (e.g., sac) + if type(training_state["optimizer"]) is dict: + assert set(training_state["optimizer"].keys()) == set( + optimizer.keys() + ), "Optimizer dictionaries do not have the same keys during resume!" + for k, v in training_state["optimizer"].items(): + optimizer[k].load_state_dict(v) + else: + optimizer.load_state_dict(training_state["optimizer"]) if scheduler is not None: scheduler.load_state_dict(training_state["scheduler"]) elif "scheduler" in training_state: @@ -223,20 +264,63 @@ def load_last_training_state(self, optimizer: Optimizer, scheduler: LRScheduler "The checkpoint contains a scheduler state_dict, but no LRScheduler was provided." ) # Small hack to get the expected keys: use `get_global_random_state`. - set_global_random_state({k: training_state[k] for k in get_global_random_state()}) + set_global_random_state( + {k: training_state[k] for k in get_global_random_state()} + ) return training_state["step"] - def log_dict(self, d, step, mode="train"): + def log_dict( + self, + d, + step: int | None = None, + mode="train", + custom_step_key: str | None = None, + ): + """Log a dictionary of metrics to WandB.""" assert mode in {"train", "eval"} # TODO(alexander-soare): Add local text log. + if step is None and custom_step_key is None: + raise ValueError("Either step or custom_step_key must be provided.") + if self._wandb is not None: + # NOTE: This is not simple. Wandb step is it must always monotonically increase and it + # increases with each wandb.log call, but in the case of asynchronous RL for example, + # multiple time steps is possible for example, the interaction step with the environment, + # the training step, the evaluation step, etc. So we need to define a custom step key + # to log the correct step for each metric. + if custom_step_key is not None: + if self._wandb_custom_step_key is None: + self._wandb_custom_step_key = set() + new_custom_key = f"{mode}/{custom_step_key}" + if new_custom_key not in self._wandb_custom_step_key: + self._wandb_custom_step_key.add(new_custom_key) + self._wandb.define_metric(new_custom_key, hidden=True) + for k, v in d.items(): if not isinstance(v, (int, float, str, wandb.Table)): logging.warning( f'WandB logging of key "{k}" was ignored as its type is not handled by this wrapper.' ) continue - self._wandb.log({f"{mode}/{k}": v}, step=step) + + # Do not log the custom step key itself. 
+ if ( + self._wandb_custom_step_key is not None + and k in self._wandb_custom_step_key + ): + continue + + if custom_step_key is not None: + value_custom_step = d[custom_step_key] + self._wandb.log( + { + f"{mode}/{k}": v, + f"{mode}/{custom_step_key}": value_custom_step, + } + ) + continue + + self._wandb.log(data={f"{mode}/{k}": v}, step=step) def log_video(self, video_path: str, step: int, mode: str = "train"): assert mode in {"train", "eval"} diff --git a/lerobot/common/policies/act/configuration_act.py b/lerobot/common/policies/act/configuration_act.py index a86c359c9..51c95097d 100644 --- a/lerobot/common/policies/act/configuration_act.py +++ b/lerobot/common/policies/act/configuration_act.py @@ -168,4 +168,6 @@ def __post_init__(self): not any(k.startswith("observation.image") for k in self.input_shapes) and "observation.environment_state" not in self.input_shapes ): - raise ValueError("You must provide at least one image or the environment state among the inputs.") + raise ValueError( + "You must provide at least one image or the environment state among the inputs." + ) diff --git a/lerobot/common/policies/act/modeling_act.py b/lerobot/common/policies/act/modeling_act.py index 418863a14..5eee2201d 100644 --- a/lerobot/common/policies/act/modeling_act.py +++ b/lerobot/common/policies/act/modeling_act.py @@ -81,10 +81,14 @@ def __init__( self.model = ACT(config) - self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + self.expected_image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] if config.temporal_ensemble_coeff is not None: - self.temporal_ensembler = ACTTemporalEnsembler(config.temporal_ensemble_coeff, config.chunk_size) + self.temporal_ensembler = ACTTemporalEnsembler( + config.temporal_ensemble_coeff, config.chunk_size + ) self.reset() @@ -107,8 +111,12 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[k] for k in self.expected_image_keys], dim=-4 + ) # If we are doing temporal ensembling, do online updates where we keep track of the number of actions # we are ensembling over. 
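The custom-step logic in `Logger.log_dict` above reduces to a small wandb pattern: declare a hidden metric for the custom step and log it alongside each value. A minimal, self-contained sketch; the project and metric names are made up for illustration:

```python
# Sketch of the custom-step pattern used by Logger.log_dict (hypothetical names).
import wandb

run = wandb.init(project="demo-project", name="async-rl-logging")

# Hide the custom step key from the dashboard; it only serves as an x-axis.
wandb.define_metric("train/interaction_step", hidden=True)

for interaction_step in range(0, 500, 100):
    reward = interaction_step * 0.01  # placeholder metric
    # Log each metric together with its own step value instead of relying on
    # wandb's monotonic internal step, so asynchronous producers (interaction,
    # optimization, evaluation) can each keep their own x-axis.
    wandb.log(
        {
            "train/reward": reward,
            "train/interaction_step": interaction_step,
        }
    )

run.finish()
```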
@@ -135,13 +143,18 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[k] for k in self.expected_image_keys], dim=-4 + ) batch = self.normalize_targets(batch) actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch) l1_loss = ( - F.l1_loss(batch["action"], actions_hat, reduction="none") * ~batch["action_is_pad"].unsqueeze(-1) + F.l1_loss(batch["action"], actions_hat, reduction="none") + * ~batch["action_is_pad"].unsqueeze(-1) ).mean() loss_dict = {"l1_loss": l1_loss.item()} @@ -151,7 +164,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # KL-divergence per batch element, then take the mean over the batch. # (See App. B of https://arxiv.org/abs/1312.6114 for more details). mean_kld = ( - (-0.5 * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp())).sum(-1).mean() + ( + -0.5 + * (1 + log_sigma_x2_hat - mu_hat.pow(2) - (log_sigma_x2_hat).exp()) + ) + .sum(-1) + .mean() ) loss_dict["kld_loss"] = mean_kld.item() loss_dict["loss"] = l1_loss + mean_kld * self.config.kl_weight @@ -205,7 +223,9 @@ def __init__(self, temporal_ensemble_coeff: float, chunk_size: int) -> None: ``` """ self.chunk_size = chunk_size - self.ensemble_weights = torch.exp(-temporal_ensemble_coeff * torch.arange(chunk_size)) + self.ensemble_weights = torch.exp( + -temporal_ensemble_coeff * torch.arange(chunk_size) + ) self.ensemble_weights_cumsum = torch.cumsum(self.ensemble_weights, dim=0) self.reset() @@ -221,7 +241,9 @@ def update(self, actions: Tensor) -> Tensor: time steps, and pop/return the next batch of actions in the sequence. """ self.ensemble_weights = self.ensemble_weights.to(device=actions.device) - self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to(device=actions.device) + self.ensemble_weights_cumsum = self.ensemble_weights_cumsum.to( + device=actions.device + ) if self.ensembled_actions is None: # Initializes `self._ensembled_action` to the sequence of actions predicted during the first # time step of the episode. @@ -229,19 +251,34 @@ def update(self, actions: Tensor) -> Tensor: # Note: The last dimension is unsqueeze to make sure we can broadcast properly for tensor # operations later. self.ensembled_actions_count = torch.ones( - (self.chunk_size, 1), dtype=torch.long, device=self.ensembled_actions.device + (self.chunk_size, 1), + dtype=torch.long, + device=self.ensembled_actions.device, ) else: # self.ensembled_actions will have shape (batch_size, chunk_size - 1, action_dim). Compute # the online update for those entries. 
- self.ensembled_actions *= self.ensemble_weights_cumsum[self.ensembled_actions_count - 1] - self.ensembled_actions += actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] - self.ensembled_actions /= self.ensemble_weights_cumsum[self.ensembled_actions_count] - self.ensembled_actions_count = torch.clamp(self.ensembled_actions_count + 1, max=self.chunk_size) + self.ensembled_actions *= self.ensemble_weights_cumsum[ + self.ensembled_actions_count - 1 + ] + self.ensembled_actions += ( + actions[:, :-1] * self.ensemble_weights[self.ensembled_actions_count] + ) + self.ensembled_actions /= self.ensemble_weights_cumsum[ + self.ensembled_actions_count + ] + self.ensembled_actions_count = torch.clamp( + self.ensembled_actions_count + 1, max=self.chunk_size + ) # The last action, which has no prior online average, needs to get concatenated onto the end. - self.ensembled_actions = torch.cat([self.ensembled_actions, actions[:, -1:]], dim=1) + self.ensembled_actions = torch.cat( + [self.ensembled_actions, actions[:, -1:]], dim=1 + ) self.ensembled_actions_count = torch.cat( - [self.ensembled_actions_count, torch.ones_like(self.ensembled_actions_count[-1:])] + [ + self.ensembled_actions_count, + torch.ones_like(self.ensembled_actions_count[-1:]), + ] ) # "Consume" the first action. action, self.ensembled_actions, self.ensembled_actions_count = ( @@ -293,7 +330,9 @@ def __init__(self, config: ACTConfig): # BERT style VAE encoder with input tokens [cls, robot_state, *action_sequence]. # The cls token forms parameters of the latent's distribution (like this [*means, *log_variances]). self.use_robot_state = "observation.state" in config.input_shapes - self.use_images = any(k.startswith("observation.image") for k in config.input_shapes) + self.use_images = any( + k.startswith("observation.image") for k in config.input_shapes + ) self.use_env_state = "observation.environment_state" in config.input_shapes if self.config.use_vae: self.vae_encoder = ACTEncoder(config, is_vae_encoder=True) @@ -308,7 +347,9 @@ def __init__(self, config: ACTConfig): config.output_shapes["action"][0], config.dim_model ) # Projection layer from the VAE encoder's output to the latent distribution's parameter space. - self.vae_encoder_latent_output_proj = nn.Linear(config.dim_model, config.latent_dim * 2) + self.vae_encoder_latent_output_proj = nn.Linear( + config.dim_model, config.latent_dim * 2 + ) # Fixed sinusoidal positional embedding for the input to the VAE encoder. Unsqueeze for batch # dimension. num_input_token_encoder = 1 + config.chunk_size @@ -316,20 +357,28 @@ def __init__(self, config: ACTConfig): num_input_token_encoder += 1 self.register_buffer( "vae_encoder_pos_enc", - create_sinusoidal_pos_embedding(num_input_token_encoder, config.dim_model).unsqueeze(0), + create_sinusoidal_pos_embedding( + num_input_token_encoder, config.dim_model + ).unsqueeze(0), ) # Backbone for image feature extraction. if self.use_images: backbone_model = getattr(torchvision.models, config.vision_backbone)( - replace_stride_with_dilation=[False, False, config.replace_final_stride_with_dilation], + replace_stride_with_dilation=[ + False, + False, + config.replace_final_stride_with_dilation, + ], weights=config.pretrained_backbone_weights, norm_layer=FrozenBatchNorm2d, ) # Note: The assumption here is that we are using a ResNet model (and hence layer4 is the final # feature map). # Note: The forward method of this returns a dict: {"feature_map": output}. 
- self.backbone = IntermediateLayerGetter(backbone_model, return_layers={"layer4": "feature_map"}) + self.backbone = IntermediateLayerGetter( + backbone_model, return_layers={"layer4": "feature_map"} + ) # Transformer (acts as VAE decoder when training with the variational objective). self.encoder = ACTEncoder(config) @@ -343,7 +392,8 @@ def __init__(self, config: ACTConfig): ) if self.use_env_state: self.encoder_env_state_input_proj = nn.Linear( - config.input_shapes["observation.environment_state"][0], config.dim_model + config.input_shapes["observation.environment_state"][0], + config.dim_model, ) self.encoder_latent_input_proj = nn.Linear(config.latent_dim, config.dim_model) if self.use_images: @@ -358,14 +408,18 @@ def __init__(self, config: ACTConfig): n_1d_tokens += 1 self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model) if self.use_images: - self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2) + self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d( + config.dim_model // 2 + ) # Transformer decoder. # Learnable positional embedding for the transformer's decoder (in the style of DETR object queries). self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model) # Final action regression head on the output of the transformer's decoder. - self.action_head = nn.Linear(config.dim_model, config.output_shapes["action"][0]) + self.action_head = nn.Linear( + config.dim_model, config.output_shapes["action"][0] + ) self._reset_parameters() @@ -375,7 +429,9 @@ def _reset_parameters(self): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: + def forward( + self, batch: dict[str, Tensor] + ) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]: """A forward pass through the Action Chunking Transformer (with optional VAE encoder). `batch` should have the following structure: @@ -412,12 +468,20 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size ) # (B, 1, D) if self.use_robot_state: - robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"]) + robot_state_embed = self.vae_encoder_robot_state_input_proj( + batch["observation.state"] + ) robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D) - action_embed = self.vae_encoder_action_input_proj(batch["action"]) # (B, S, D) + action_embed = self.vae_encoder_action_input_proj( + batch["action"] + ) # (B, S, D) if self.use_robot_state: - vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D) + vae_encoder_input = [ + cls_embed, + robot_state_embed, + action_embed, + ] # (B, S+2, D) else: vae_encoder_input = [cls_embed, action_embed] vae_encoder_input = torch.cat(vae_encoder_input, axis=1) @@ -455,20 +519,26 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso # When not using the VAE encoder, we set the latent to be all zeros. mu = log_sigma_x2 = None # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use buffer - latent_sample = torch.zeros([batch_size, self.config.latent_dim], dtype=torch.float32).to( - batch["observation.state"].device - ) + latent_sample = torch.zeros( + [batch_size, self.config.latent_dim], dtype=torch.float32 + ).to(batch["observation.state"].device) # Prepare transformer encoder inputs. 
encoder_in_tokens = [self.encoder_latent_input_proj(latent_sample)] - encoder_in_pos_embed = list(self.encoder_1d_feature_pos_embed.weight.unsqueeze(1)) + encoder_in_pos_embed = list( + self.encoder_1d_feature_pos_embed.weight.unsqueeze(1) + ) # Robot state token. if self.use_robot_state: - encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"])) + encoder_in_tokens.append( + self.encoder_robot_state_input_proj(batch["observation.state"]) + ) # Environment state token. if self.use_env_state: encoder_in_tokens.append( - self.encoder_env_state_input_proj(batch["observation.environment_state"]) + self.encoder_env_state_input_proj( + batch["observation.environment_state"] + ) ) # Camera observation features and positional embeddings. @@ -477,19 +547,29 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tenso all_cam_pos_embeds = [] for cam_index in range(batch["observation.images"].shape[-4]): - cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"] + cam_features = self.backbone(batch["observation.images"][:, cam_index])[ + "feature_map" + ] # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use # buffer - cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype) - cam_features = self.encoder_img_feat_input_proj(cam_features) # (B, C, h, w) + cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to( + dtype=cam_features.dtype + ) + cam_features = self.encoder_img_feat_input_proj( + cam_features + ) # (B, C, h, w) all_cam_features.append(cam_features) all_cam_pos_embeds.append(cam_pos_embed) # Concatenate camera observation feature maps and positional embeddings along the width dimension, # and move to (sequence, batch, dim). all_cam_features = torch.cat(all_cam_features, axis=-1) - encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c")) + encoder_in_tokens.extend( + einops.rearrange(all_cam_features, "b c h w -> (h w) b c") + ) all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1) - encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c")) + encoder_in_pos_embed.extend( + einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c") + ) # Stack all tokens along the sequence dimension. 
encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0) @@ -524,12 +604,21 @@ class ACTEncoder(nn.Module): def __init__(self, config: ACTConfig, is_vae_encoder: bool = False): super().__init__() self.is_vae_encoder = is_vae_encoder - num_layers = config.n_vae_encoder_layers if self.is_vae_encoder else config.n_encoder_layers - self.layers = nn.ModuleList([ACTEncoderLayer(config) for _ in range(num_layers)]) + num_layers = ( + config.n_vae_encoder_layers + if self.is_vae_encoder + else config.n_encoder_layers + ) + self.layers = nn.ModuleList( + [ACTEncoderLayer(config) for _ in range(num_layers)] + ) self.norm = nn.LayerNorm(config.dim_model) if config.pre_norm else nn.Identity() def forward( - self, x: Tensor, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + self, + x: Tensor, + pos_embed: Tensor | None = None, + key_padding_mask: Tensor | None = None, ) -> Tensor: for layer in self.layers: x = layer(x, pos_embed=pos_embed, key_padding_mask=key_padding_mask) @@ -540,7 +629,9 @@ def forward( class ACTEncoderLayer(nn.Module): def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.self_attn = nn.MultiheadAttention( + config.dim_model, config.n_heads, dropout=config.dropout + ) # Feed forward layers. self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) @@ -555,7 +646,9 @@ def __init__(self, config: ACTConfig): self.activation = get_activation_fn(config.feedforward_activation) self.pre_norm = config.pre_norm - def forward(self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None) -> Tensor: + def forward( + self, x, pos_embed: Tensor | None = None, key_padding_mask: Tensor | None = None + ) -> Tensor: skip = x if self.pre_norm: x = self.norm1(x) @@ -580,7 +673,9 @@ class ACTDecoder(nn.Module): def __init__(self, config: ACTConfig): """Convenience module for running multiple decoder layers followed by normalization.""" super().__init__() - self.layers = nn.ModuleList([ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)]) + self.layers = nn.ModuleList( + [ACTDecoderLayer(config) for _ in range(config.n_decoder_layers)] + ) self.norm = nn.LayerNorm(config.dim_model) def forward( @@ -592,7 +687,10 @@ def forward( ) -> Tensor: for layer in self.layers: x = layer( - x, encoder_out, decoder_pos_embed=decoder_pos_embed, encoder_pos_embed=encoder_pos_embed + x, + encoder_out, + decoder_pos_embed=decoder_pos_embed, + encoder_pos_embed=encoder_pos_embed, ) if self.norm is not None: x = self.norm(x) @@ -602,8 +700,12 @@ def forward( class ACTDecoderLayer(nn.Module): def __init__(self, config: ACTConfig): super().__init__() - self.self_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) - self.multihead_attn = nn.MultiheadAttention(config.dim_model, config.n_heads, dropout=config.dropout) + self.self_attn = nn.MultiheadAttention( + config.dim_model, config.n_heads, dropout=config.dropout + ) + self.multihead_attn = nn.MultiheadAttention( + config.dim_model, config.n_heads, dropout=config.dropout + ) # Feed forward layers. 
self.linear1 = nn.Linear(config.dim_model, config.dim_feedforward) @@ -644,7 +746,9 @@ def forward( if self.pre_norm: x = self.norm1(x) q = k = self.maybe_add_pos_embed(x, decoder_pos_embed) - x = self.self_attn(q, k, value=x)[0] # select just the output, not the attention weights + x = self.self_attn(q, k, value=x)[ + 0 + ] # select just the output, not the attention weights x = skip + self.dropout1(x) if self.pre_norm: skip = x @@ -681,9 +785,14 @@ def create_sinusoidal_pos_embedding(num_positions: int, dimension: int) -> Tenso """ def get_position_angle_vec(position): - return [position / np.power(10000, 2 * (hid_j // 2) / dimension) for hid_j in range(dimension)] - - sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(num_positions)]) + return [ + position / np.power(10000, 2 * (hid_j // 2) / dimension) + for hid_j in range(dimension) + ] + + sinusoid_table = np.array( + [get_position_angle_vec(pos_i) for pos_i in range(num_positions)] + ) sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 return torch.from_numpy(sinusoid_table).float() @@ -728,7 +837,9 @@ def forward(self, x: Tensor) -> Tensor: x_range = x_range / (x_range[:, :, -1:] + self._eps) * self._two_pi inverse_frequency = self._temperature ** ( - 2 * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) / self.dimension + 2 + * (torch.arange(self.dimension, dtype=torch.float32, device=x.device) // 2) + / self.dimension ) x_range = x_range.unsqueeze(-1) / inverse_frequency # (1, H, W, 1) @@ -736,9 +847,15 @@ def forward(self, x: Tensor) -> Tensor: # Note: this stack then flatten operation results in interleaved sine and cosine terms. # pos_embed_x and pos_embed_y are (1, H, W, C // 2). - pos_embed_x = torch.stack((x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1).flatten(3) - pos_embed_y = torch.stack((y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1).flatten(3) - pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute(0, 3, 1, 2) # (1, C, H, W) + pos_embed_x = torch.stack( + (x_range[..., 0::2].sin(), x_range[..., 1::2].cos()), dim=-1 + ).flatten(3) + pos_embed_y = torch.stack( + (y_range[..., 0::2].sin(), y_range[..., 1::2].cos()), dim=-1 + ).flatten(3) + pos_embed = torch.cat((pos_embed_y, pos_embed_x), dim=3).permute( + 0, 3, 1, 2 + ) # (1, C, H, W) return pos_embed diff --git a/lerobot/common/policies/diffusion/configuration_diffusion.py b/lerobot/common/policies/diffusion/configuration_diffusion.py index 531f49e4d..4ee53c866 100644 --- a/lerobot/common/policies/diffusion/configuration_diffusion.py +++ b/lerobot/common/policies/diffusion/configuration_diffusion.py @@ -121,7 +121,9 @@ class DiffusionConfig: "observation.state": "min_max", } ) - output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) + output_normalization_modes: dict[str, str] = field( + default_factory=lambda: {"action": "min_max"} + ) # Architecture / modeling. # Vision backbone. 
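As a cross-check of the positional-embedding hunks above, the table produced by `create_sinusoidal_pos_embedding` can be rebuilt in a few lines. A standalone sketch with arbitrary sizes:

```python
# Rebuild the 1D sinusoidal table from create_sinusoidal_pos_embedding above.
import numpy as np
import torch

num_positions, dimension = 4, 8
angles = np.array(
    [
        [pos / np.power(10000, 2 * (j // 2) / dimension) for j in range(dimension)]
        for pos in range(num_positions)
    ]
)
angles[:, 0::2] = np.sin(angles[:, 0::2])  # even dims -> sine
angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd dims -> cosine
table = torch.from_numpy(angles).float()
print(table.shape)  # torch.Size([4, 8])
```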
@@ -163,8 +165,13 @@ def __post_init__(self): image_keys = {k for k in self.input_shapes if k.startswith("observation.image")} - if len(image_keys) == 0 and "observation.environment_state" not in self.input_shapes: - raise ValueError("You must provide at least one image or the environment state among the inputs.") + if ( + len(image_keys) == 0 + and "observation.environment_state" not in self.input_shapes + ): + raise ValueError( + "You must provide at least one image or the environment state among the inputs." + ) if len(image_keys) > 0: if self.crop_shape is not None: diff --git a/lerobot/common/policies/diffusion/modeling_diffusion.py b/lerobot/common/policies/diffusion/modeling_diffusion.py index 9ba562600..7f6858bed 100644 --- a/lerobot/common/policies/diffusion/modeling_diffusion.py +++ b/lerobot/common/policies/diffusion/modeling_diffusion.py @@ -88,7 +88,9 @@ def __init__( self.diffusion = DiffusionModel(config) - self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + self.expected_image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] self.use_env_state = "observation.environment_state" in config.input_shapes self.reset() @@ -102,7 +104,9 @@ def reset(self): if len(self.expected_image_keys) > 0: self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps) if self.use_env_state: - self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps) + self._queues["observation.environment_state"] = deque( + maxlen=self.config.n_obs_steps + ) @torch.no_grad def select_action(self, batch: dict[str, Tensor]) -> Tensor: @@ -128,14 +132,22 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: """ batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[k] for k in self.expected_image_keys], dim=-4 + ) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) if len(self._queues["action"]) == 0: # stack n latest observations from the queue - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + batch = { + k: torch.stack(list(self._queues[k]), dim=1) + for k in batch + if k in self._queues + } actions = self.diffusion.generate_actions(batch) # TODO(rcadene): make above methods return output dictionary? 
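The queueing in `DiffusionPolicy.select_action` above implements receding-horizon control: observations accumulate in bounded deques, and a new action chunk is only sampled once the action queue runs dry. A toy sketch of that control flow; `generate_actions` is a stand-in for the diffusion sampler, and the sizes are illustrative:

```python
from collections import deque

n_obs_steps, n_action_steps = 2, 4
obs_queue = deque(maxlen=n_obs_steps)  # rolling window of observations
action_queue = deque(maxlen=n_action_steps)  # actions left from the last chunk


def generate_actions(stacked_obs):
    # Stand-in for the diffusion model: returns one chunk of actions.
    return [f"action({stacked_obs[-1]}, {i})" for i in range(n_action_steps)]


for t in range(6):
    obs_queue.append(f"obs_{t}")
    if len(action_queue) == 0:
        # Only query the model when the previous chunk is exhausted.
        action_queue.extend(generate_actions(list(obs_queue)))
    print(t, action_queue.popleft())
```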
@@ -150,8 +162,12 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) if len(self.expected_image_keys) > 0: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[k] for k in self.expected_image_keys], dim=-4 + ) batch = self.normalize_targets(batch) loss = self.diffusion.compute_loss(batch) return {"loss": loss} @@ -177,7 +193,9 @@ def __init__(self, config: DiffusionConfig): # Build observation encoders (depending on which observations are provided). global_cond_dim = config.input_shapes["observation.state"][0] - num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) + num_images = len( + [k for k in config.input_shapes if k.startswith("observation.image")] + ) self._use_images = False self._use_env_state = False if num_images > 0: @@ -193,7 +211,9 @@ def __init__(self, config: DiffusionConfig): self._use_env_state = True global_cond_dim += config.input_shapes["observation.environment_state"][0] - self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps) + self.unet = DiffusionConditionalUnet1d( + config, global_cond_dim=global_cond_dim * config.n_obs_steps + ) self.noise_scheduler = _make_noise_scheduler( config.noise_scheduler_type, @@ -213,14 +233,21 @@ def __init__(self, config: DiffusionConfig): # ========= inference ============ def conditional_sample( - self, batch_size: int, global_cond: Tensor | None = None, generator: torch.Generator | None = None + self, + batch_size: int, + global_cond: Tensor | None = None, + generator: torch.Generator | None = None, ) -> Tensor: device = get_device_from_parameters(self) dtype = get_dtype_from_parameters(self) # Sample prior. sample = torch.randn( - size=(batch_size, self.config.horizon, self.config.output_shapes["action"][0]), + size=( + batch_size, + self.config.horizon, + self.config.output_shapes["action"][0], + ), dtype=dtype, device=device, generator=generator, @@ -236,7 +263,9 @@ def conditional_sample( global_cond=global_cond, ) # Compute previous image: x_t -> x_t-1 - sample = self.noise_scheduler.step(model_output, t, sample, generator=generator).prev_sample + sample = self.noise_scheduler.step( + model_output, t, sample, generator=generator + ).prev_sample return sample @@ -248,27 +277,39 @@ def _prepare_global_conditioning(self, batch: dict[str, Tensor]) -> Tensor: if self._use_images: if self.config.use_separate_rgb_encoder_per_camera: # Combine batch and sequence dims while rearranging to make the camera index dimension first. - images_per_camera = einops.rearrange(batch["observation.images"], "b s n ... -> n (b s) ...") + images_per_camera = einops.rearrange( + batch["observation.images"], "b s n ... -> n (b s) ..." + ) img_features_list = torch.cat( [ encoder(images) - for encoder, images in zip(self.rgb_encoder, images_per_camera, strict=True) + for encoder, images in zip( + self.rgb_encoder, images_per_camera, strict=True + ) ] ) # Separate batch and sequence dims back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). 
img_features = einops.rearrange( - img_features_list, "(n b s) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + img_features_list, + "(n b s) ... -> b s (n ...)", + b=batch_size, + s=n_obs_steps, ) else: # Combine batch, sequence, and "which camera" dims before passing to shared encoder. img_features = self.rgb_encoder( - einops.rearrange(batch["observation.images"], "b s n ... -> (b s n) ...") + einops.rearrange( + batch["observation.images"], "b s n ... -> (b s n) ..." + ) ) # Separate batch dim and sequence dim back out. The camera index dim gets absorbed into the # feature dim (effectively concatenating the camera features). img_features = einops.rearrange( - img_features, "(b s n) ... -> b s (n ...)", b=batch_size, s=n_obs_steps + img_features, + "(b s n) ... -> b s (n ...)", + b=batch_size, + s=n_obs_steps, ) global_cond_feats.append(img_features) @@ -354,7 +395,9 @@ def compute_loss(self, batch: dict[str, Tensor]) -> Tensor: elif self.config.prediction_type == "sample": target = batch["action"] else: - raise ValueError(f"Unsupported prediction type {self.config.prediction_type}") + raise ValueError( + f"Unsupported prediction type {self.config.prediction_type}" + ) loss = F.mse_loss(pred, target, reduction="none") @@ -414,7 +457,9 @@ def __init__(self, input_shape, num_kp=None): # we could use torch.linspace directly but that seems to behave slightly differently than numpy # and causes a small degradation in pc_success of pre-trained models. - pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x, pos_y = np.meshgrid( + np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) + ) pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() # register as buffer so it's moved to the correct device. @@ -456,7 +501,9 @@ def __init__(self, config: DiffusionConfig): # Always use center crop for eval self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) if config.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + self.maybe_random_crop = torchvision.transforms.RandomCrop( + config.crop_shape + ) else: self.maybe_random_crop = self.center_crop else: @@ -477,7 +524,9 @@ def __init__(self, config: DiffusionConfig): self.backbone = _replace_submodules( root_module=self.backbone, predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + func=lambda x: nn.GroupNorm( + num_groups=x.num_features // 16, num_channels=x.num_features + ), ) # Set up pooling and final layers. @@ -485,17 +534,25 @@ def __init__(self, config: DiffusionConfig): # The dummy input should take the number of image channels from `config.input_shapes` and it should # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the # height and width from `config.input_shapes`. - image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] # Note: we have a check in the config class to make sure all images have the same shape. 
image_key = image_keys[0] dummy_input_h_w = ( - config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:] + config.crop_shape + if config.crop_shape is not None + else config.input_shapes[image_key][1:] + ) + dummy_input = torch.zeros( + size=(1, config.input_shapes[image_key][0], *dummy_input_h_w) ) - dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:]) - self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.pool = SpatialSoftmax( + feature_map_shape, num_kp=config.spatial_softmax_num_keypoints + ) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() @@ -522,7 +579,9 @@ def forward(self, x: Tensor) -> Tensor: def _replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] + root_module: nn.Module, + predicate: Callable[[nn.Module], bool], + func: Callable[[nn.Module], nn.Module], ) -> nn.Module: """ Args: @@ -535,7 +594,11 @@ def _replace_submodules( if predicate(root_module): return func(root_module) - replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + replace_list = [ + k.split(".") + for k, m in root_module.named_modules(remove_duplicate=True) + if predicate(m) + ] for *parents, k in replace_list: parent_module = root_module if len(parents) > 0: @@ -550,7 +613,9 @@ def _replace_submodules( else: setattr(parent_module, k, tgt_module) # verify that all BN are replaced - assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + assert not any( + predicate(m) for _, m in root_module.named_modules(remove_duplicate=True) + ) return root_module @@ -578,7 +643,9 @@ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): super().__init__() self.block = nn.Sequential( - nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2), + nn.Conv1d( + inp_channels, out_channels, kernel_size, padding=kernel_size // 2 + ), nn.GroupNorm(n_groups, out_channels), nn.Mish(), ) @@ -601,9 +668,13 @@ def __init__(self, config: DiffusionConfig, global_cond_dim: int): # Encoder for the diffusion timestep. self.diffusion_step_encoder = nn.Sequential( DiffusionSinusoidalPosEmb(config.diffusion_step_embed_dim), - nn.Linear(config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4), + nn.Linear( + config.diffusion_step_embed_dim, config.diffusion_step_embed_dim * 4 + ), nn.Mish(), - nn.Linear(config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim), + nn.Linear( + config.diffusion_step_embed_dim * 4, config.diffusion_step_embed_dim + ), ) # The FiLM conditioning dimension. @@ -628,10 +699,16 @@ def __init__(self, config: DiffusionConfig, global_cond_dim: int): self.down_modules.append( nn.ModuleList( [ - DiffusionConditionalResidualBlock1d(dim_in, dim_out, **common_res_block_kwargs), - DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d( + dim_in, dim_out, **common_res_block_kwargs + ), + DiffusionConditionalResidualBlock1d( + dim_out, dim_out, **common_res_block_kwargs + ), # Downsample as long as it is not the last block. 
- nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity(), + nn.Conv1d(dim_out, dim_out, 3, 2, 1) + if not is_last + else nn.Identity(), ] ) ) @@ -640,10 +717,14 @@ def __init__(self, config: DiffusionConfig, global_cond_dim: int): self.mid_modules = nn.ModuleList( [ DiffusionConditionalResidualBlock1d( - config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + config.down_dims[-1], + config.down_dims[-1], + **common_res_block_kwargs, ), DiffusionConditionalResidualBlock1d( - config.down_dims[-1], config.down_dims[-1], **common_res_block_kwargs + config.down_dims[-1], + config.down_dims[-1], + **common_res_block_kwargs, ), ] ) @@ -656,16 +737,24 @@ def __init__(self, config: DiffusionConfig, global_cond_dim: int): nn.ModuleList( [ # dim_in * 2, because it takes the encoder's skip connection as well - DiffusionConditionalResidualBlock1d(dim_in * 2, dim_out, **common_res_block_kwargs), - DiffusionConditionalResidualBlock1d(dim_out, dim_out, **common_res_block_kwargs), + DiffusionConditionalResidualBlock1d( + dim_in * 2, dim_out, **common_res_block_kwargs + ), + DiffusionConditionalResidualBlock1d( + dim_out, dim_out, **common_res_block_kwargs + ), # Upsample as long as it is not the last block. - nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) if not is_last else nn.Identity(), + nn.ConvTranspose1d(dim_out, dim_out, 4, 2, 1) + if not is_last + else nn.Identity(), ] ) ) self.final_conv = nn.Sequential( - DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size), + DiffusionConv1dBlock( + config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size + ), nn.Conv1d(config.down_dims[0], config.output_shapes["action"][0], 1), ) @@ -733,17 +822,23 @@ def __init__( self.use_film_scale_modulation = use_film_scale_modulation self.out_channels = out_channels - self.conv1 = DiffusionConv1dBlock(in_channels, out_channels, kernel_size, n_groups=n_groups) + self.conv1 = DiffusionConv1dBlock( + in_channels, out_channels, kernel_size, n_groups=n_groups + ) # FiLM modulation (https://arxiv.org/abs/1709.07871) outputs per-channel bias and (maybe) scale. cond_channels = out_channels * 2 if use_film_scale_modulation else out_channels self.cond_encoder = nn.Sequential(nn.Mish(), nn.Linear(cond_dim, cond_channels)) - self.conv2 = DiffusionConv1dBlock(out_channels, out_channels, kernel_size, n_groups=n_groups) + self.conv2 = DiffusionConv1dBlock( + out_channels, out_channels, kernel_size, n_groups=n_groups + ) # A final convolution for dimension matching the residual (if needed). 
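+        # A kernel-size-1 conv is a per-timestep linear projection over channels, so it can match
+        # in_channels to out_channels without mixing information across time steps.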
        self.residual_conv = (
-            nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
+            nn.Conv1d(in_channels, out_channels, 1)
+            if in_channels != out_channels
+            else nn.Identity()
        )

    def forward(self, x: Tensor, cond: Tensor) -> Tensor:
diff --git a/lerobot/common/policies/factory.py b/lerobot/common/policies/factory.py
index 7f550d909..814a4d0ab 100644
--- a/lerobot/common/policies/factory.py
+++ b/lerobot/common/policies/factory.py
@@ -52,7 +52,9 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
        return TDMPCPolicy, TDMPCConfig
    elif name == "diffusion":
-        from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig
+        from lerobot.common.policies.diffusion.configuration_diffusion import (
+            DiffusionConfig,
+        )
        from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy

        return DiffusionPolicy, DiffusionConfig
@@ -71,13 +73,16 @@ def get_policy_and_config_classes(name: str) -> tuple[Policy, object]:
        from lerobot.common.policies.sac.modeling_sac import SACPolicy

        return SACPolicy, SACConfig
-
    else:
        raise NotImplementedError(f"Policy with name {name} is not implemented.")

def make_policy(
-    hydra_cfg: DictConfig, pretrained_policy_name_or_path: str | None = None, dataset_stats=None
+    hydra_cfg: DictConfig,
+    pretrained_policy_name_or_path: str | None = None,
+    dataset_stats=None,
+    *args,
+    **kwargs,
) -> Policy:
    """Make an instance of a policy class.

@@ -91,17 +96,19 @@ def make_policy(
        be provided when initializing a new policy, and must not be provided when loading a pretrained
        policy. Therefore, this argument is mutually exclusive with `pretrained_policy_name_or_path`.
    """
-    if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
-        raise ValueError(
-            "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
-        )
+    # if not (pretrained_policy_name_or_path is None) ^ (dataset_stats is None):
+    #     raise ValueError(
+    #         "Exactly one of `pretrained_policy_name_or_path` and `dataset_stats` must be provided."
+    #     )

    policy_cls, policy_cfg_class = get_policy_and_config_classes(hydra_cfg.policy.name)

    policy_cfg = _policy_cfg_from_hydra_cfg(policy_cfg_class, hydra_cfg)
    if pretrained_policy_name_or_path is None:
        # Make a fresh policy.
-        policy = policy_cls(policy_cfg, dataset_stats)
+        # HACK: We accept *args and **kwargs in the signature to allow additional arguments
+        # to be passed through, for example the device for the SAC policy.
+        policy = policy_cls(config=policy_cfg, dataset_stats=dataset_stats)
    else:
        # Load a pretrained policy and override the config if needed (for example, if there are inference-time
        # hyperparameters that we want to vary).
@@ -110,7 +117,9 @@ def make_policy(
        # huggingface_hub should make it possible to avoid the hack:
        # https://github.com/huggingface/huggingface_hub/pull/2274.
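+        # In short: instantiate the policy from the (possibly overridden) config first, then copy
+        # the pretrained weights in via load_state_dict below.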
policy = policy_cls(policy_cfg) - policy.load_state_dict(policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict()) + policy.load_state_dict( + policy_cls.from_pretrained(pretrained_policy_name_or_path).state_dict() + ) policy.to(get_safe_torch_device(hydra_cfg.device)) diff --git a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py index de3742ecf..fe7eb1425 100644 --- a/lerobot/common/policies/hilserl/classifier/configuration_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/configuration_classifier.py @@ -10,7 +10,7 @@ class ClassifierConfig: num_classes: int = 2 hidden_dim: int = 256 dropout_rate: float = 0.1 - model_name: str = "microsoft/resnet-50" + model_name: str = "helper2424/resnet10" device: str = "cpu" model_type: str = "cnn" # "transformer" or "cnn" num_cameras: int = 2 diff --git a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py index 4a0223357..eb023f9fc 100644 --- a/lerobot/common/policies/hilserl/classifier/modeling_classifier.py +++ b/lerobot/common/policies/hilserl/classifier/modeling_classifier.py @@ -7,7 +7,9 @@ from .configuration_classifier import ClassifierConfig -logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) logger = logging.getLogger(__name__) @@ -15,7 +17,10 @@ class ClassifierOutput: """Wrapper for classifier outputs with additional metadata.""" def __init__( - self, logits: Tensor, probabilities: Optional[Tensor] = None, hidden_states: Optional[Tensor] = None + self, + logits: Tensor, + probabilities: Optional[Tensor] = None, + hidden_states: Optional[Tensor] = None, ): self.logits = logits self.probabilities = probabilities @@ -43,12 +48,14 @@ class Classifier( name = "classifier" def __init__(self, config: ClassifierConfig): - from transformers import AutoImageProcessor, AutoModel + from transformers import AutoModel super().__init__() self.config = config - self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) - encoder = AutoModel.from_pretrained(self.config.model_name, trust_remote_code=True) + # self.processor = AutoImageProcessor.from_pretrained(self.config.model_name, trust_remote_code=True) + encoder = AutoModel.from_pretrained( + self.config.model_name, trust_remote_code=True + ) # Extract vision model if we're given a multimodal model if hasattr(encoder, "vision_model"): logging.info("Multimodal model detected - using vision encoder only") @@ -74,7 +81,9 @@ def _setup_cnn_backbone(self): self.feature_dim = self.encoder.fc.in_features self.encoder = nn.Sequential(*list(self.encoder.children())[:-1]) elif hasattr(self.encoder.config, "hidden_sizes"): - self.feature_dim = self.encoder.config.hidden_sizes[-1] # Last channel dimension + self.feature_dim = self.encoder.config.hidden_sizes[ + -1 + ] # Last channel dimension else: raise ValueError("Unsupported CNN architecture") @@ -94,25 +103,31 @@ def _build_classifier_head(self) -> None: if hasattr(self.encoder.config, "hidden_size"): input_dim = self.encoder.config.hidden_size else: - raise ValueError("Unsupported transformer architecture since hidden_size is not found") + raise ValueError( + "Unsupported transformer architecture since hidden_size is not found" + ) self.classifier_head = 
nn.Sequential( nn.Linear(input_dim * self.config.num_cameras, self.config.hidden_dim), nn.Dropout(self.config.dropout_rate), nn.LayerNorm(self.config.hidden_dim), nn.ReLU(), - nn.Linear(self.config.hidden_dim, 1 if self.config.num_classes == 2 else self.config.num_classes), + nn.Linear( + self.config.hidden_dim, + 1 if self.config.num_classes == 2 else self.config.num_classes, + ), ) self.classifier_head = self.classifier_head.to(self.config.device) def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: """Extract the appropriate output from the encoder.""" # Process images with the processor (handles resizing and normalization) - processed = self.processor( - images=x, # LeRobotDataset already provides proper tensor format - return_tensors="pt", - ) - processed = processed["pixel_values"].to(x.device) + # processed = self.processor( + # images=x, # LeRobotDataset already provides proper tensor format + # return_tensors="pt", + # ) + # processed = processed["pixel_values"].to(x.device) + processed = x with torch.no_grad(): if self.is_cnn: @@ -126,7 +141,10 @@ def _get_encoder_output(self, x: torch.Tensor) -> torch.Tensor: return features else: # Transformer models outputs = self.encoder(processed) - if hasattr(outputs, "pooler_output") and outputs.pooler_output is not None: + if ( + hasattr(outputs, "pooler_output") + and outputs.pooler_output is not None + ): return outputs.pooler_output return outputs.last_hidden_state[:, 0, :] @@ -142,10 +160,14 @@ def forward(self, xs: torch.Tensor) -> ClassifierOutput: else: probabilities = torch.softmax(logits, dim=-1) - return ClassifierOutput(logits=logits, probabilities=probabilities, hidden_states=encoder_outputs) + return ClassifierOutput( + logits=logits, probabilities=probabilities, hidden_states=encoder_outputs + ) - def predict_reward(self, x): + def predict_reward(self, x, threshold=0.6): if self.config.num_classes == 2: - return (self.forward(x).probabilities > 0.5).float() + probs = self.forward(x).probabilities + logging.debug(f"Predicted reward images: {probs}") + return (probs > threshold).float() else: return torch.argmax(self.forward(x).probabilities, dim=1) diff --git a/lerobot/common/policies/normalize.py b/lerobot/common/policies/normalize.py index f2e1179c0..8dbe048d6 100644 --- a/lerobot/common/policies/normalize.py +++ b/lerobot/common/policies/normalize.py @@ -130,7 +130,7 @@ def __init__( setattr(self, "buffer_" + key.replace(".", "_"), buffer) # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad + # @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch for key, mode in self.modes.items(): @@ -196,7 +196,7 @@ def __init__( setattr(self, "buffer_" + key.replace(".", "_"), buffer) # TODO(rcadene): should we remove torch.no_grad? - @torch.no_grad + # @torch.no_grad def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) # shallow copy avoids mutating the input batch for key, mode in self.modes.items(): diff --git a/lerobot/common/policies/sac/configuration_sac.py b/lerobot/common/policies/sac/configuration_sac.py index 4ae6e5d42..b834896e7 100644 --- a/lerobot/common/policies/sac/configuration_sac.py +++ b/lerobot/common/policies/sac/configuration_sac.py @@ -16,6 +16,7 @@ # limitations under the License. 
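+# Note: the default normalization statistics are embedded directly in this config (see the
+# *_normalization_params fields below), so modeling_sac.py can fall back on them when no
+# dataset_stats are provided at construction time.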
from dataclasses import dataclass, field
+from typing import Any

@dataclass
@@ -28,39 +29,78 @@ class SACConfig:
    )
    output_shapes: dict[str, list[int]] = field(
        default_factory=lambda: {
-            "action": [4],
+            "action": [2],
+        }
+    )
+    input_normalization_modes: dict[str, str] = field(
+        default_factory=lambda: {
+            "observation.image": "mean_std",
+            "observation.state": "min_max",
+            "observation.environment_state": "min_max",
+        }
+    )
+    input_normalization_params: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "observation.image": {
+                "mean": [[0.485, 0.456, 0.406]],
+                "std": [[0.229, 0.224, 0.225]],
+            },
+            "observation.state": {"min": [-1, -1, -1, -1], "max": [1, 1, 1, 1]},
        }
    )
-
-    # Normalization / Unnormalization
-    input_normalization_modes: dict[str, str] | None = None
    output_normalization_modes: dict[str, str] = field(
-        default_factory=lambda: {"action": "min_max"},
+        default_factory=lambda: {"action": "min_max"}
+    )
+    output_normalization_params: dict[str, dict[str, list[float]]] = field(
+        default_factory=lambda: {
+            "action": {"min": [-1, -1], "max": [1, 1]},
+        }
+    )
+    # TODO: Move it outside of the config
+    actor_learner_config: dict[str, str | int] = field(
+        default_factory=lambda: {
+            "learner_host": "127.0.0.1",
+            "learner_port": 50051,
+        }
    )
+    camera_number: int = 1

-    discount = 0.99
-    temperature_init = 1.0
-    num_critics = 2
-    num_subsample_critics = None
-    critic_lr = 3e-4
-    actor_lr = 3e-4
-    temperature_lr = 3e-4
-    critic_target_update_weight = 0.005
-    utd_ratio = 2
-    state_encoder_hidden_dim = 256
-    latent_dim = 128
-    target_entropy = None
-    backup_entropy = True
-    critic_network_kwargs = {
-        "hidden_dims": [256, 256],
-        "activate_final": True,
-    }
-    actor_network_kwargs = {
-        "hidden_dims": [256, 256],
-        "activate_final": True,
-    }
-    policy_kwargs = {
-        "use_tanh_squash": True,
-        "log_std_min": -5,
-        "log_std_max": 2,
-    }
+    storage_device: str = "cpu"
+    vision_encoder_name: str | None = field(default="helper2424/resnet10")
+    freeze_vision_encoder: bool = True
+    image_encoder_hidden_dim: int = 32
+    shared_encoder: bool = True
+    discount: float = 0.99
+    temperature_init: float = 1.0
+    num_critics: int = 2
+    num_subsample_critics: int | None = None
+    critic_lr: float = 3e-4
+    actor_lr: float = 3e-4
+    temperature_lr: float = 3e-4
+    critic_target_update_weight: float = 0.005
+    utd_ratio: int = 1  # If you want to enable UTD, set utd_ratio to a value > 1
+    state_encoder_hidden_dim: int = 256
+    latent_dim: int = 256
+    target_entropy: float | None = None
+    use_backup_entropy: bool = True
+    critic_network_kwargs: dict[str, Any] = field(
+        default_factory=lambda: {
+            "hidden_dims": [256, 256],
+            "activate_final": True,
+        }
+    )
+    actor_network_kwargs: dict[str, Any] = field(
+        default_factory=lambda: {
+            "hidden_dims": [256, 256],
+            "activate_final": True,
+        }
+    )
+    policy_kwargs: dict[str, Any] = field(
+        default_factory=lambda: {
+            "use_tanh_squash": True,
+            "log_std_min": -5,
+            "log_std_max": 2,
+            "init_final": 0.005,
+        }
+    )
diff --git a/lerobot/common/policies/sac/modeling_sac.py b/lerobot/common/policies/sac/modeling_sac.py
index 62725ce1d..afbbc9451 100644
--- a/lerobot/common/policies/sac/modeling_sac.py
+++ b/lerobot/common/policies/sac/modeling_sac.py
@@ -17,12 +17,13 @@

# TODO: (1) better device management

-from collections import deque
from copy import deepcopy
-from typing import Callable, Optional, Sequence, Tuple
+from typing import Callable, Optional, Tuple, Union, Dict
+from pathlib import Path

import
einops import numpy as np +from tensordict import from_modules import torch import torch.nn as nn import torch.nn.functional as F # noqa: N812 @@ -31,6 +32,7 @@ from lerobot.common.policies.normalize import Normalize, Unnormalize from lerobot.common.policies.sac.configuration_sac import SACConfig +from lerobot.common.policies.utils import get_device_from_parameters class SACPolicy( @@ -43,7 +45,9 @@ class SACPolicy( name = "sac" def __init__( - self, config: SACConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None + self, + config: SACConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): super().__init__() @@ -52,200 +56,367 @@ def __init__( self.config = config if config.input_normalization_modes is not None: + input_normalization_params = _convert_normalization_params_to_tensor( + config.input_normalization_params + ) self.normalize_inputs = Normalize( - config.input_shapes, config.input_normalization_modes, dataset_stats + config.input_shapes, + config.input_normalization_modes, + input_normalization_params, ) else: self.normalize_inputs = nn.Identity() + + output_normalization_params = _convert_normalization_params_to_tensor( + config.output_normalization_params + ) + + # HACK: This is hacky and should be removed + dataset_stats = dataset_stats or output_normalization_params self.normalize_targets = Normalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) self.unnormalize_outputs = Unnormalize( config.output_shapes, config.output_normalization_modes, dataset_stats ) - encoder_critic = SACObservationEncoder(config) - encoder_actor = SACObservationEncoder(config) - # Define networks - critic_nets = [] - for _ in range(config.num_critics): - critic_net = Critic( - encoder=encoder_critic, - network=MLP( - input_dim=encoder_critic.output_dim + config.output_shapes["action"][0], - **config.critic_network_kwargs - ) - ) - critic_nets.append(critic_net) - self.critic_ensemble = create_critic_ensemble(critic_nets, config.num_critics) - self.critic_target = deepcopy(self.critic_ensemble) + # NOTE: For images the encoder should be shared between the actor and critic + if config.shared_encoder: + encoder_critic = SACObservationEncoder(config, self.normalize_inputs) + encoder_actor: SACObservationEncoder = encoder_critic + else: + encoder_critic = SACObservationEncoder(config, self.normalize_inputs) + encoder_actor = SACObservationEncoder(config, self.normalize_inputs) + + self.critic_ensemble = CriticEnsemble( + encoder=encoder_critic, + ensemble=Ensemble( + [ + CriticHead( + input_dim=encoder_critic.output_dim + + config.output_shapes["action"][0], + **config.critic_network_kwargs, + ) + for _ in range(config.num_critics) + ] + ), + output_normalization=self.normalize_targets, + ) + + self.critic_target = CriticEnsemble( + encoder=encoder_critic, + ensemble=Ensemble( + [ + CriticHead( + input_dim=encoder_critic.output_dim + + config.output_shapes["action"][0], + **config.critic_network_kwargs, + ) + for _ in range(config.num_critics) + ] + ), + output_normalization=self.normalize_targets, + ) + + self.critic_target.load_state_dict(self.critic_ensemble.state_dict()) + + self.critic_ensemble = torch.compile(self.critic_ensemble) + self.critic_target = torch.compile(self.critic_target) self.actor = Policy( encoder=encoder_actor, network=MLP( - input_dim=encoder_actor.output_dim, - **config.actor_network_kwargs + input_dim=encoder_actor.output_dim, **config.actor_network_kwargs ), 
            action_dim=config.output_shapes["action"][0],
-            **config.policy_kwargs
+            encoder_is_shared=config.shared_encoder,
+            **config.policy_kwargs,
        )

        if config.target_entropy is None:
-            config.target_entropy = -np.prod(config.output_shapes["action"][0])  # (-dim(A))
-        self.temperature = LagrangeMultiplier(init_value=config.temperature_init)
+            config.target_entropy = (
+                -np.prod(config.output_shapes["action"][0]) / 2
+            )  # (-dim(A)/2)
+
+        # TODO (azouitine): Handle the case where the temperature is a fixed parameter
+        # TODO (michel-aractingi): Put the log_alpha in cuda by default because otherwise
+        # it triggers "can't optimize a non-leaf Tensor"
+        self.log_alpha = nn.Parameter(torch.tensor([0.0]))
+        self.temperature = self.log_alpha.exp().item()
+
+    def _save_pretrained(self, save_directory):
+        """Custom save method to handle TensorDict properly"""
+        import os
+        import json
+        from dataclasses import asdict
+        from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
+        from safetensors.torch import save_file
+
+        # NOTE: Using tensordict.from_modules in the model to batch the inference using torch.vmap
+        # has one side effect: the __batch_size parameters are saved in the state_dict.
+        # __batch_size is a torch.Size, but safetensors can only save torch.Tensor,
+        # so we need to filter these entries out before saving.
+        simplified_state_dict = {}
+
+        for name, param in self.named_parameters():
+            simplified_state_dict[name] = param
+        save_file(
+            simplified_state_dict, os.path.join(save_directory, SAFETENSORS_SINGLE_FILE)
+        )
+
+        # Save config
+        config_dict = asdict(self.config)
+        with open(os.path.join(save_directory, CONFIG_NAME), "w") as f:
+            json.dump(config_dict, f, indent=2)
+        print(f"Saved config to {os.path.join(save_directory, CONFIG_NAME)}")
+
+    @classmethod
+    def _from_pretrained(
+        cls,
+        *,
+        model_id: str,
+        revision: Optional[str],
+        cache_dir: Optional[Union[str, Path]],
+        force_download: bool,
+        proxies: Optional[Dict],
+        resume_download: Optional[bool],
+        local_files_only: bool,
+        token: Optional[Union[str, bool]],
+        map_location: str = "cpu",
+        strict: bool = False,
+        **model_kwargs,
+    ) -> "SACPolicy":
+        """Custom load method to handle loading SAC policy from saved files"""
+        import os
+        import json
+        from pathlib import Path
+        from huggingface_hub import hf_hub_download
+        from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE, CONFIG_NAME
+        from safetensors.torch import load_file
+        from lerobot.common.policies.sac.configuration_sac import SACConfig
+
+        # Check if model_id is a local path or a hub model ID
+        if os.path.isdir(model_id):
+            model_path = Path(model_id)
+            safetensors_file = os.path.join(model_path, SAFETENSORS_SINGLE_FILE)
+            config_file = os.path.join(model_path, CONFIG_NAME)
+        else:
+            # Download the safetensors file from the hub
+            safetensors_file = hf_hub_download(
+                repo_id=model_id,
+                filename=SAFETENSORS_SINGLE_FILE,
+                revision=revision,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                resume_download=resume_download,
+                token=token,
+                local_files_only=local_files_only,
+            )
+            # Download the config file
+            try:
+                config_file = hf_hub_download(
+                    repo_id=model_id,
+                    filename=CONFIG_NAME,
+                    revision=revision,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    token=token,
+                    local_files_only=local_files_only,
+                )
+            except Exception:
+                config_file = None
+
+        # Load or create config
+        if config_file and os.path.exists(config_file):
+            # Load config from file
+            with open(config_file) as f:
+                config_dict = json.load(f)
+            config = SACConfig(**config_dict)
+        else:
+            # Use the provided config or create a default one
+            config = model_kwargs.get("config", SACConfig())
+
+        # Create a new instance with the loaded config
+        model = cls(config=config)
+
+        # Load state dict from safetensors file
+        if os.path.exists(safetensors_file):
+            # Note: The load_file function returns a dict with the parameters, but __batch_size
+            # is not loaded, so we need to copy it from the model state_dict
+            # Load the parameters only
+            loaded_state_dict = load_file(safetensors_file, device=map_location)
+
+            # Copy batch size parameters
+            find_and_copy_params(
+                original_state_dict=model.state_dict(),
+                loaded_state_dict=loaded_state_dict,
+                pattern="__batch_size",
+                match_type="endswith",
+            )
+
+            # Copy normalization buffer parameters
+            find_and_copy_params(
+                original_state_dict=model.state_dict(),
+                loaded_state_dict=loaded_state_dict,
+                pattern="_orig_mod.output_normalization.buffer_action",
+                match_type="contains",
+            )
+
+            model.load_state_dict(loaded_state_dict, strict=False)
+
+        return model

    def reset(self):
-        """
-        Clear observation and action queues. Should be called on `env.reset()`
-        queues are populated during rollout of the policy, they contain the n latest observations and actions
-        """
+        """Reset the policy"""
+        pass

-        self._queues = {
-            "observation.state": deque(maxlen=1),
-            "action": deque(maxlen=1),
-        }
-        if "observation.image" in self.config.input_shapes:
-            self._queues["observation.image"] = deque(maxlen=1)
-        if "observation.environment_state" in self.config.input_shapes:
-            self._queues["observation.environment_state"] = deque(maxlen=1)
+    def to(self, *args, **kwargs):
+        """Override .to(device) so that non-parameter tensors, such as the actor's fixed_std, are moved as well"""
+        if self.actor.fixed_std is not None:
+            self.actor.fixed_std = self.actor.fixed_std.to(*args, **kwargs)
+        # self.log_alpha = self.log_alpha.to(*args, **kwargs)
+        super().to(*args, **kwargs)

    @torch.no_grad()
    def select_action(self, batch: dict[str, Tensor]) -> Tensor:
        """Select action for inference/evaluation"""
-        actions, _ = self.actor(batch)
+        actions, _, _ = self.actor(batch)
        actions = self.unnormalize_outputs({"action": actions})["action"]
        return actions
-
-    def critic_forward(self, observations: dict[str, Tensor], actions: Tensor, use_target: bool = False) -> Tensor:
+
+    def critic_forward(
+        self,
+        observations: dict[str, Tensor],
+        actions: Tensor,
+        use_target: bool = False,
+        observation_features: Tensor | None = None,
+    ) -> Tensor:
        """Forward pass through a critic network ensemble
-
+
        Args:
            observations: Dictionary of observations
            actions: Action tensor
            use_target: If True, use target critics, otherwise use ensemble critics
-
+
        Returns:
            Tensor of Q-values from all critics
        """
        critics = self.critic_target if use_target else self.critic_ensemble
-        q_values = torch.stack([critic(observations, actions) for critic in critics])
+        q_values = critics(observations, actions, observation_features)
        return q_values

+    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: ...
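+    # The soft (Polyak) update below computes theta_target <- w * theta + (1 - w) * theta_target,
+    # where w = config.critic_target_update_weight (0.005 by default here).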
+    def update_target_networks(self):
+        """Update target networks with exponential moving average"""
+        for target_param, param in zip(
+            self.critic_target.parameters(),
+            self.critic_ensemble.parameters(),
+            strict=False,
+        ):
+            target_param.data.copy_(
+                param.data * self.config.critic_target_update_weight
+                + target_param.data * (1.0 - self.config.critic_target_update_weight)
+            )

-    def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]:
-        """Run the batch through the model and compute the loss.
-
-        Returns a dictionary with loss as a tensor, and other information as native floats.
-        """
-        batch = self.normalize_inputs(batch)
-        # batch shape is (b, 2, ...) where index 1 returns the current observation and
-        # the next observation for calculating the right td index.
-        actions = batch["action"][:, 0]
-        rewards = batch["next.reward"][:, 0]
-        observations = {}
-        next_observations = {}
-        for k in batch:
-            if k.startswith("observation."):
-                observations[k] = batch[k][:, 0]
-                next_observations[k] = batch[k][:, 1]
-
-        # perform image augmentation
-
-        # reward bias from HIL-SERL code base
-        # add_or_replace={"rewards": batch["rewards"] + self.config["reward_bias"]} in reward_batch
-
-        # calculate critics loss
-        # 1- compute actions from policy
-        action_preds, log_probs = self.actor(next_observations)
-
-        # 2- compute q targets
-        q_targets = self.critic_forward(next_observations, action_preds, use_target=True)
-
-        # subsample critics to prevent overfitting if use high UTD (update to date)
-        if self.config.num_subsample_critics is not None:
-            indices = torch.randperm(self.config.num_critics)
-            indices = indices[:self.config.num_subsample_critics]
-            q_targets = q_targets[indices]
-
-        # critics subsample size
-        min_q, _ = q_targets.min(dim=0)  # Get values from min operation
-
-        # compute td target
-        td_target = rewards + self.config.discount * min_q  #+ self.config.discount * self.temperature() * log_probs # add entropy term
+    def compute_loss_critic(
+        self,
+        observations,
+        actions,
+        rewards,
+        next_observations,
+        done,
+        observation_features: Tensor | None = None,
+        next_observation_features: Tensor | None = None,
+    ) -> Tensor:
+        temperature = self.log_alpha.exp().item()
+        with torch.no_grad():
+            next_action_preds, next_log_probs, _ = self.actor(
+                next_observations, next_observation_features
+            )
+
+            # TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do this
+            next_action_preds = self.unnormalize_outputs({"action": next_action_preds})[
+                "action"
+            ]
+
+            # 2- compute q targets
+            q_targets = self.critic_forward(
+                observations=next_observations,
+                actions=next_action_preds,
+                use_target=True,
+                observation_features=next_observation_features,
+            )
+
+            # subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
+            if self.config.num_subsample_critics is not None:
+                indices = torch.randperm(self.config.num_critics)
+                indices = indices[: self.config.num_subsample_critics]
+                q_targets = q_targets[indices]
+
+            # critics subsample size
+            min_q, _ = q_targets.min(dim=0)  # Get values from min operation
+            if self.config.use_backup_entropy:
+                min_q = min_q - (temperature * next_log_probs)
+
+            td_target = rewards + (1 - done) * self.config.discount * min_q

        # 3- compute predicted qs
-        q_preds = self.critic_forward(observations, actions, use_target=False)
+        q_preds = self.critic_forward(
+            observations,
+            actions,
+            use_target=False,
+            observation_features=observation_features,
+        )

        # 4- Calculate loss
        # Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
-        critics_loss = F.mse_loss(
-            q_preds,  # shape: [num_critics, batch_size]
-            einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),  # expand td_target to match q_preds shape
-            reduction="none"
-        ).sum(0).mean()
-
-        # critics_loss = (
-        #     F.mse_loss(
-        #         q_preds,
-        #         einops.repeat(td_target, "b -> e b", e=q_preds.shape[0]),
-        #         reduction="none",
-        #     ).sum(0)  # sum over ensemble
-        #     # `q_preds_ensemble` depends on the first observation and the actions.
-        #     * ~batch["observation.state_is_pad"][0]
-        #     * ~batch["action_is_pad"]
-        #     # q_targets depends on the reward and the next observations.
-        #     * ~batch["next.reward_is_pad"]
-        #     * ~batch["observation.state_is_pad"][1:]
-        # ).sum(0).mean()
-
-        # calculate actors loss
-        # 1- temperature
-        temperature = self.temperature()
-        # 2- get actions (batch_size, action_dim) and log probs (batch_size,)
-        actions, log_probs = self.actor(observations)
-        # 3- get q-value predictions
-        with torch.inference_mode():
-            q_preds = self.critic_forward(observations, actions, use_target=False)
-        actor_loss = (
-            -(q_preds - temperature * log_probs).mean()
-            # * ~batch["observation.state_is_pad"][0]
-            # * ~batch["action_is_pad"]
+        td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
+        # Compute the mean loss over the batch for each critic, then sum over the critics to get the final loss
+        critics_loss = (
+            F.mse_loss(
+                input=q_preds,
+                target=td_target_duplicate,
+                reduction="none",
+            ).mean(1)
+        ).sum()
+        return critics_loss
+
+    def compute_loss_temperature(
+        self, observations, observation_features: Tensor | None = None
+    ) -> Tensor:
+        """Compute the temperature loss"""
+        # calculate temperature loss
+        with torch.no_grad():
+            _, log_probs, _ = self.actor(observations, observation_features)
+        temperature_loss = (
+            -self.log_alpha.exp() * (log_probs + self.config.target_entropy)
        ).mean()
+        return temperature_loss

+    def compute_loss_actor(
+        self, observations, observation_features: Tensor | None = None
+    ) -> Tensor:
+        temperature = self.log_alpha.exp().item()

-        # calculate temperature loss
-        # 1- calculate entropy
-        entropy = -log_probs.mean()
-        temperature_loss = self.temperature(
-            lhs=entropy,
-            rhs=self.config.target_entropy
+        actions_pi, log_probs, _ = self.actor(observations, observation_features)
+
+        # TODO: (maractingi, azouitine) This is too slow; we should find a more efficient way to do this
+        actions_pi = self.unnormalize_outputs({"action": actions_pi})["action"]
+
+        q_preds = self.critic_forward(
+            observations,
+            actions_pi,
+            use_target=False,
+            observation_features=observation_features,
        )
+        min_q_preds = q_preds.min(dim=0)[0]
+
+        actor_loss = ((temperature * log_probs) - min_q_preds).mean()
+        return actor_loss
+

-        loss = critics_loss + actor_loss + temperature_loss
-
-        return {
-            "critics_loss": critics_loss.item(),
-            "actor_loss": actor_loss.item(),
-            "temperature_loss": temperature_loss.item(),
-            "temperature": temperature.item(),
-            "entropy": entropy.item(),
-            "loss": loss,
-        }
-
-    def update(self):
-        # TODO: implement UTD update
-        # First update only critics for utd_ratio-1 times
-        #for critic_step in range(self.config.utd_ratio - 1):
-            # only update critic and critic target
-        # Then update critic, critic target, actor and temperature
-        """Update target networks with exponential moving average"""
-        with torch.no_grad():
-            for target_critic, critic in zip(self.critic_target, self.critic_ensemble, strict=False):
-                for target_param, param in
zip(target_critic.parameters(), critic.parameters(), strict=False): - target_param.data.copy_( - target_param.data * self.config.critic_target_update_weight + - param.data * (1.0 - self.config.critic_target_update_weight) - ) - class MLP(nn.Module): def __init__( self, @@ -258,80 +429,196 @@ def __init__( super().__init__() self.activate_final = activate_final layers = [] - + # First layer uses input_dim layers.append(nn.Linear(input_dim, hidden_dims[0])) - + # Add activation after first layer if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[0])) - layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) - + layers.append( + activations + if isinstance(activations, nn.Module) + else getattr(nn, activations)() + ) + # Rest of the layers for i in range(1, len(hidden_dims)): - layers.append(nn.Linear(hidden_dims[i-1], hidden_dims[i])) - + layers.append(nn.Linear(hidden_dims[i - 1], hidden_dims[i])) + if i + 1 < len(hidden_dims) or activate_final: if dropout_rate is not None and dropout_rate > 0: layers.append(nn.Dropout(p=dropout_rate)) layers.append(nn.LayerNorm(hidden_dims[i])) - layers.append(activations if isinstance(activations, nn.Module) else getattr(nn, activations)()) - + layers.append( + activations + if isinstance(activations, nn.Module) + else getattr(nn, activations)() + ) + self.net = nn.Sequential(*layers) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.net(x) - - -class Critic(nn.Module): + + +def find_and_copy_params( + original_state_dict: dict[str, torch.Tensor], + loaded_state_dict: dict[str, torch.Tensor], + pattern: str, + match_type: str = "contains", +) -> list[str]: + """Find and copy parameters from original state dict to loaded state dict based on a pattern. 
+ + This function can search for keys in different ways based on the match_type: + - "exact": The key must exactly match the pattern + - "contains": The key must contain the pattern anywhere + - "startswith": The key must start with the pattern + - "endswith": The key must end with the pattern + + Args: + original_state_dict: The source state dictionary + loaded_state_dict: The target state dictionary + pattern: The pattern to search for in keys + match_type: How to match the pattern (exact, contains, startswith, endswith) + + Returns: + list[str]: List of keys that were copied + """ + copied_keys = [] + + for key in original_state_dict: + should_copy = False + + if match_type == "exact": + should_copy = key == pattern + elif match_type == "contains": + should_copy = pattern in key + elif match_type == "startswith": + should_copy = key.startswith(pattern) + elif match_type == "endswith": + should_copy = key.endswith(pattern) + + if should_copy: + loaded_state_dict[key] = original_state_dict[key] + copied_keys.append(key) + + return copied_keys + + +class CriticHead(nn.Module): def __init__( self, - encoder: Optional[nn.Module], - network: nn.Module, + input_dim: int, + hidden_dims: list[int], + activations: Callable[[torch.Tensor], torch.Tensor] | str = nn.SiLU(), + activate_final: bool = False, + dropout_rate: Optional[float] = None, init_final: Optional[float] = None, - device: str = "cuda" ): super().__init__() - self.device = torch.device(device) - self.encoder = encoder - self.network = network - self.init_final = init_final - - # Find the last Linear layer's output dimension - for layer in reversed(network.net): - if isinstance(layer, nn.Linear): - out_features = layer.out_features - break - - # Output layer + self.net = MLP( + input_dim=input_dim, + hidden_dims=hidden_dims, + activations=activations, + activate_final=activate_final, + dropout_rate=dropout_rate, + ) + self.output_layer = nn.Linear(in_features=hidden_dims[-1], out_features=1) if init_final is not None: - self.output_layer = nn.Linear(out_features, 1) nn.init.uniform_(self.output_layer.weight, -init_final, init_final) nn.init.uniform_(self.output_layer.bias, -init_final, init_final) else: - self.output_layer = nn.Linear(out_features, 1) orthogonal_init()(self.output_layer.weight) - - self.to(self.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.output_layer(self.net(x)) + + +class CriticEnsemble(nn.Module): + """ + ┌──────────────────┬─────────────────────────────────────────────────────────┐ + │ Critic Ensemble │ │ + ├──────────────────┘ │ + │ │ + │ ┌────┐ ┌────┐ ┌────┐ │ + │ │ Q1 │ │ Q2 │ │ Qn │ │ + │ └────┘ └────┘ └────┘ │ + │ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ + │ │ │ │ │ │ │ │ + │ │ MLP 1 │ │ MLP 2 │ │ MLP │ │ + │ │ │ │ │ ... 
│ num_critics  │                       │
+    │            │                │   │                │   │                │ │
+    │            └────────────────┘   └────────────────┘   └────────────────┘ │
+    │                    ▲                    ▲                    ▲          │
+    │                    └────────────────────┴───────┬────────────┘          │
+    │                                                 │                       │
+    │                                                 │                       │
+    │                                     ┌───────────────────┐               │
+    │                                     │     Embedding     │               │
+    │                                     │                   │               │
+    │                                     └───────────────────┘               │
+    │                                                ▲                        │
+    │                                                │                        │
+    │                                  ┌─────────────┴────────────┐           │
+    │                                  │                          │           │
+    │                                  │  SACObservationEncoder   │           │
+    │                                  │                          │           │
+    │                                  └──────────────────────────┘           │
+    │                                                ▲                        │
+    │                                                │                        │
+    │                                                │                        │
+    │                                                │                        │
+    └───────────────────────────┬────────────────────┬───────────────────────┘
+                                │    Observation     │
+                                └────────────────────┘
+    """
+
+    def __init__(
+        self,
+        encoder: Optional[nn.Module],
+        ensemble: "Ensemble[CriticHead]",
+        output_normalization: nn.Module,
+        init_final: Optional[float] = None,
+    ):
+        super().__init__()
+        self.encoder = encoder
+        self.ensemble = ensemble
+        self.init_final = init_final
+        self.output_normalization = output_normalization
+
+        self.parameters_to_optimize = []
+        # Handle the case where a part of the encoder is frozen
+        if self.encoder is not None:
+            self.parameters_to_optimize += list(self.encoder.parameters_to_optimize)
+        self.parameters_to_optimize += list(self.ensemble.parameters())

    def forward(
-        self,
-        observations: dict[str, torch.Tensor],
+        self,
+        observations: dict[str, torch.Tensor],
        actions: torch.Tensor,
+        observation_features: torch.Tensor | None = None,
    ) -> torch.Tensor:
+        device = get_device_from_parameters(self)
        # Move each tensor in observations to device
-        observations = {
-            k: v.to(self.device) for k, v in observations.items()
-        }
-        actions = actions.to(self.device)
-
-        obs_enc = observations if self.encoder is None else self.encoder(observations)
-
+        observations = {k: v.to(device) for k, v in observations.items()}
+        # NOTE: We normalize actions; it helps with sample efficiency
+        actions: dict[str, torch.Tensor] = {"action": actions}
+        # NOTE: The normalization layer takes a dict as input and returns a dict, hence the wrapping and indexing here
+        actions = self.output_normalization(actions)["action"]
+        actions = actions.to(device)
+
+        obs_enc = (
+            observation_features
+            if observation_features is not None
+            else (observations if self.encoder is None else self.encoder(observations))
+        )
+
        inputs = torch.cat([obs_enc, actions], dim=-1)
-        x = self.network(inputs)
-        value = self.output_layer(x)
-        return value.squeeze(-1)
+        q_values = self.ensemble(inputs)  # [num_critics, B, 1]
+        return q_values.squeeze(-1)  # [num_critics, B]
+

class Policy(nn.Module):
    def __init__(
@@ -344,24 +631,27 @@ def __init__(
        fixed_std: Optional[torch.Tensor] = None,
        init_final: Optional[float] = None,
        use_tanh_squash: bool = False,
-        device: str = "cuda"
+        encoder_is_shared: bool = False,
    ):
        super().__init__()
-        self.device = torch.device(device)
        self.encoder = encoder
        self.network = network
        self.action_dim = action_dim
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max
-        self.fixed_std = fixed_std.to(self.device) if fixed_std is not None else None
+        self.fixed_std = fixed_std
        self.use_tanh_squash = use_tanh_squash
-
+        self.parameters_to_optimize = []
+
+        self.parameters_to_optimize += list(self.network.parameters())
+
+        if self.encoder is not None and not encoder_is_shared:
+            self.parameters_to_optimize += list(self.encoder.parameters())
        # Find the last Linear layer's output dimension
        for layer in reversed(network.net):
            if isinstance(layer, nn.Linear):
                out_features = layer.out_features
                break
-
        # Mean layer
        self.mean_layer = nn.Linear(out_features, action_dim)
        if init_final is not None:
@@ -369,7 +659,8 @@ def __init__(
            nn.init.uniform_(self.mean_layer.bias,
-init_final, init_final) else: orthogonal_init()(self.mean_layer.weight) - + + self.parameters_to_optimize += list(self.mean_layer.parameters()) # Standard deviation layer or parameter if fixed_std is None: self.std_layer = nn.Linear(out_features, action_dim) @@ -378,44 +669,62 @@ def __init__( nn.init.uniform_(self.std_layer.bias, -init_final, init_final) else: orthogonal_init()(self.std_layer.weight) - - self.to(self.device) + self.parameters_to_optimize += list(self.std_layer.parameters()) def forward( - self, + self, observations: torch.Tensor, + observation_features: torch.Tensor | None = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - # Encode observations if encoder exists - obs_enc = observations if self.encoder is None else self.encoder(observations) + obs_enc = ( + observation_features + if observation_features is not None + else (observations if self.encoder is None else self.encoder(observations)) + ) # Get network outputs outputs = self.network(obs_enc) means = self.mean_layer(outputs) - + # Compute standard deviations if self.fixed_std is None: log_std = self.std_layer(outputs) + assert not torch.isnan( + log_std + ).any(), "[ERROR] log_std became NaN after std_layer!" + if self.use_tanh_squash: log_std = torch.tanh(log_std) - log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) + log_std = self.log_std_min + 0.5 * ( + self.log_std_max - self.log_std_min + ) * (log_std + 1.0) + else: + log_std = torch.clamp(log_std, self.log_std_min, self.log_std_max) else: log_std = self.fixed_std.expand_as(means) - - # uses tahn activation function to squash the action to be in the range of [-1, 1] + + # uses tanh activation function to squash the action to be in the range of [-1, 1] normal = torch.distributions.Normal(means, torch.exp(log_std)) - x_t = normal.rsample() # for reparameterization trick (mean + std * N(0,1)) - log_probs = normal.log_prob(x_t) + x_t = normal.rsample() # Reparameterization trick (mean + std * N(0,1)) + log_probs = normal.log_prob(x_t) # Base log probability before Tanh + if self.use_tanh_squash: actions = torch.tanh(x_t) - log_probs -= torch.log((1 - actions.pow(2)) + 1e-6) - log_probs = log_probs.sum(-1) # sum over action dim + log_probs -= torch.log( + (1 - actions.pow(2)) + 1e-6 + ) # Adjust log-probs for Tanh + else: + actions = x_t # No Tanh; raw Gaussian sample + + log_probs = log_probs.sum(-1) # Sum over action dimensions + means = torch.tanh(means) if self.use_tanh_squash else means + return actions, log_probs, means - return actions, log_probs - def get_features(self, observations: torch.Tensor) -> torch.Tensor: """Get encoded features from observations""" - observations = observations.to(self.device) + device = get_device_from_parameters(self) + observations = observations.to(device) if self.encoder is not None: with torch.inference_mode(): return self.encoder(observations) @@ -423,59 +732,67 @@ def get_features(self, observations: torch.Tensor) -> torch.Tensor: class SACObservationEncoder(nn.Module): - """Encode image and/or state vector observations. - TODO(ke-wang): The original work allows for (1) stacking multiple history frames and (2) using pretrained resnet encoders. - """ + """Encode image and/or state vector observations.""" - def __init__(self, config: SACConfig): + def __init__(self, config: SACConfig, input_normalizer: nn.Module): """ Creates encoders for pixel and/or state modalities. 
""" super().__init__() self.config = config + self.input_normalization = input_normalizer + self.has_pretrained_vision_encoder = False + self.parameters_to_optimize = [] + + self.aggregation_size: int = 0 + if any("observation.image" in key for key in config.input_shapes): + self.camera_number = config.camera_number + + if self.config.vision_encoder_name is not None: + self.image_enc_layers = PretrainedImageEncoder(config) + self.has_pretrained_vision_encoder = True + else: + self.image_enc_layers = DefaultImageEncoder(config) + + self.aggregation_size += config.latent_dim * self.camera_number + + if config.freeze_vision_encoder: + freeze_image_encoder(self.image_enc_layers) + else: + self.parameters_to_optimize += list(self.image_enc_layers.parameters()) + self.all_image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] - if "observation.image" in config.input_shapes: - self.image_enc_layers = nn.Sequential( - nn.Conv2d( - config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 - ), - nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), - nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), - nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), - nn.ReLU(), - ) - dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) - with torch.inference_mode(): - out_shape = self.image_enc_layers(dummy_batch).shape[1:] - self.image_enc_layers.extend( - nn.Sequential( - nn.Flatten(), - nn.Linear(np.prod(out_shape), config.latent_dim), - nn.LayerNorm(config.latent_dim), - nn.Tanh(), - ) - ) if "observation.state" in config.input_shapes: self.state_enc_layers = nn.Sequential( - nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), - nn.ELU(), - nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), - nn.LayerNorm(config.latent_dim), + nn.Linear( + in_features=config.input_shapes["observation.state"][0], + out_features=config.latent_dim, + ), + nn.LayerNorm(normalized_shape=config.latent_dim), nn.Tanh(), ) + self.aggregation_size += config.latent_dim + + self.parameters_to_optimize += list(self.state_enc_layers.parameters()) + if "observation.environment_state" in config.input_shapes: self.env_state_enc_layers = nn.Sequential( nn.Linear( - config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim + in_features=config.input_shapes["observation.environment_state"][0], + out_features=config.latent_dim, ), - nn.ELU(), - nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), - nn.LayerNorm(config.latent_dim), + nn.LayerNorm(normalized_shape=config.latent_dim), nn.Tanh(), ) + self.aggregation_size += config.latent_dim + self.parameters_to_optimize += list(self.env_state_enc_layers.parameters()) + + self.aggregation_layer = nn.Linear( + in_features=self.aggregation_size, out_features=config.latent_dim + ) + self.parameters_to_optimize += list(self.aggregation_layer.parameters()) def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: """Encode the image and/or state vector. @@ -484,84 +801,187 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: over all features. """ feat = [] - # Concatenate all images along the channel dimension. 
- image_keys = [k for k in self.config.input_shapes if k.startswith("observation.image")] - for image_key in image_keys: - feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict[image_key])) + obs_dict = self.input_normalization(obs_dict) + # Batch all images along the batch dimension, then encode them. + if len(self.all_image_keys) > 0: + images_batched = torch.cat( + [obs_dict[key] for key in self.all_image_keys], dim=0 + ) + images_batched = self.image_enc_layers(images_batched) + embeddings_chunks = torch.chunk( + images_batched, dim=0, chunks=len(self.all_image_keys) + ) + feat.extend(embeddings_chunks) + if "observation.environment_state" in self.config.input_shapes: - feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + feat.append( + self.env_state_enc_layers(obs_dict["observation.environment_state"]) + ) if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) - # TODO(ke-wang): currently average over all features, concatenate all features maybe a better way - return torch.stack(feat, dim=0).mean(0) - + + features = torch.cat(tensors=feat, dim=-1) + features = self.aggregation_layer(features) + + return features + @property def output_dim(self) -> int: """Returns the dimension of the encoder output""" return self.config.latent_dim -class LagrangeMultiplier(nn.Module): - def __init__( - self, - init_value: float = 1.0, - constraint_shape: Sequence[int] = (), - device: str = "cuda" - ): +class DefaultImageEncoder(nn.Module): + def __init__(self, config): super().__init__() - self.device = torch.device(device) - init_value = torch.log(torch.exp(torch.tensor(init_value, device=self.device)) - 1) - - # Initialize the Lagrange multiplier as a parameter - self.lagrange = nn.Parameter( - torch.full(constraint_shape, init_value, dtype=torch.float32, device=self.device) + self.image_enc_layers = nn.Sequential( + nn.Conv2d( + in_channels=config.input_shapes["observation.image"][0], + out_channels=config.image_encoder_hidden_dim, + kernel_size=7, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=5, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + nn.Conv2d( + in_channels=config.image_encoder_hidden_dim, + out_channels=config.image_encoder_hidden_dim, + kernel_size=3, + stride=2, + ), + nn.ReLU(), + ) + dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) + with torch.inference_mode(): + self.image_enc_out_shape = self.image_enc_layers(dummy_batch).shape[1:] + self.image_enc_layers.extend( + nn.Sequential( + nn.Flatten(), + nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) ) - - self.to(self.device) - def forward( - self, - lhs: Optional[torch.Tensor | float | int] = None, - rhs: Optional[torch.Tensor | float | int] = None - ) -> torch.Tensor: - # Get the multiplier value based on parameterization - multiplier = torch.nn.functional.softplus(self.lagrange) - - # Return the raw multiplier if no constraint values provided - if lhs is None: - return multiplier - - # Convert inputs to tensors and move to device - lhs = torch.tensor(lhs, device=self.device) if not isinstance(lhs, torch.Tensor) else lhs.to(self.device) - if rhs is not None: - rhs = torch.tensor(rhs, 
device=self.device) if not isinstance(rhs, torch.Tensor) else rhs.to(self.device) + def forward(self, x): + return self.image_enc_layers(x) + + +class PretrainedImageEncoder(nn.Module): + def __init__(self, config): + super().__init__() + + self.image_enc_layers, self.image_enc_out_shape = ( + self._load_pretrained_vision_encoder(config) + ) + self.image_enc_proj = nn.Sequential( + nn.Linear(np.prod(self.image_enc_out_shape), config.latent_dim), + nn.LayerNorm(config.latent_dim), + nn.Tanh(), + ) + + def _load_pretrained_vision_encoder(self, config): + """Set up CNN encoder""" + from transformers import AutoModel + + self.image_enc_layers = AutoModel.from_pretrained( + config.vision_encoder_name, trust_remote_code=True + ) + # self.image_enc_layers.pooler = Identity() + + if hasattr(self.image_enc_layers.config, "hidden_sizes"): + self.image_enc_out_shape = self.image_enc_layers.config.hidden_sizes[ + -1 + ] # Last channel dimension + elif hasattr(self.image_enc_layers, "fc"): + self.image_enc_out_shape = self.image_enc_layers.fc.in_features else: - rhs = torch.zeros_like(lhs, device=self.device) - - diff = lhs - rhs - - assert diff.shape == multiplier.shape, f"Shape mismatch: {diff.shape} vs {multiplier.shape}" - - return multiplier * diff + raise ValueError( + "Unsupported vision encoder architecture, make sure you are using a CNN" + ) + return self.image_enc_layers, self.image_enc_out_shape + + def forward(self, x): + # TODO: (maractingi, azouitine) check the forward pass of the pretrained model + # doesn't reach the classifier layer because we don't need it + enc_feat = self.image_enc_layers(x).pooler_output + enc_feat = self.image_enc_proj(enc_feat.view(enc_feat.shape[0], -1)) + return enc_feat + + +def freeze_image_encoder(image_encoder: nn.Module): + """Freeze all parameters in the encoder""" + for param in image_encoder.parameters(): + param.requires_grad = False def orthogonal_init(): return lambda x: torch.nn.init.orthogonal_(x, gain=1.0) -def create_critic_ensemble(critics: list[nn.Module], num_critics: int, device: str = "cuda") -> nn.ModuleList: - """Creates an ensemble of critic networks""" - assert len(critics) == num_critics, f"Expected {num_critics} critics, got {len(critics)}" - return nn.ModuleList(critics).to(device) +class Identity(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return x + + +class Ensemble(nn.Module): + """ + Vectorized ensemble of modules. 
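+    Parameters of all member modules are stacked with tensordict's from_modules, and a single
+    stateless copy of the module is evaluated under torch.vmap, so the whole ensemble runs in
+    one batched call instead of a Python loop over its members.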
+    """
+
+    def __init__(self, modules, **kwargs):
+        super().__init__()
+        # combine_state_for_ensemble causes graph breaks
+        self.params = from_modules(*modules, as_module=True)
+        with self.params[0].data.to("meta").to_module(modules[0]):
+            self.module = deepcopy(modules[0])
+        self._repr = str(modules[0])
+        self._n = len(modules)
+
+    def __len__(self):
+        return self._n
+
+    def _call(self, params, *args, **kwargs):
+        with params.to_module(self.module):
+            return self.module(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return torch.vmap(self._call, (0, None), randomness="different")(
+            self.params, *args, **kwargs
+        )
+
+    def __repr__(self):
+        return f"Vectorized {len(self)}x " + self._repr
+
+
+# TODO (azouitine): I think in our case this function is not useful; we should remove it
+# after some investigation
# borrowed from tdmpc
-def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor:
+def flatten_forward_unflatten(
+    fn: Callable[[Tensor], Tensor], image_tensor: Tensor
+) -> Tensor:
    """Helper to temporarily flatten extra dims at the start of the image tensor.

    Args:
        fn: Callable that the image tensor will be passed to. It should accept (B, C, H, W) and return
            (B, *), where * is any number of dimensions.
-        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
+        image_tensor: An image tensor of shape (**, C, H, W), where ** is any number of dimensions and
            can be more than 1 dimension, generally different from *.
    Returns:
        A return value from the callable reshaped to (**, *).
@@ -571,4 +991,91 @@ def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tens
    start_dims = image_tensor.shape[:-3]
    inp = torch.flatten(image_tensor, end_dim=-4)
    flat_out = fn(inp)
-    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
\ No newline at end of file
+    return torch.reshape(flat_out, (*start_dims, *flat_out.shape[1:]))
+
+
+def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
+    converted_params = {}
+    for outer_key, inner_dict in normalization_params.items():
+        converted_params[outer_key] = {}
+        for key, value in inner_dict.items():
+            converted_params[outer_key][key] = torch.tensor(value)
+            if "image" in outer_key:
+                converted_params[outer_key][key] = converted_params[outer_key][
+                    key
+                ].view(3, 1, 1)
+
+    return converted_params
+
+
+if __name__ == "__main__":
+    # Test the SACObservationEncoder
+    import time
+
+    config = SACConfig()
+    config.num_critics = 10
+    config.vision_encoder_name = None
+    encoder = SACObservationEncoder(config, nn.Identity())
+    # actor_encoder = SACObservationEncoder(config)
+    # encoder = torch.compile(encoder)
+    critic_ensemble = CriticEnsemble(
+        encoder=encoder,
+        ensemble=Ensemble(
+            [
+                CriticHead(
+                    input_dim=encoder.output_dim + config.output_shapes["action"][0],
+                    **config.critic_network_kwargs,
+                )
+                for _ in range(config.num_critics)
+            ]
+        ),
+        output_normalization=nn.Identity(),
+    )
+    # actor = Policy(
+    #     encoder=actor_encoder,
+    #     network=MLP(input_dim=actor_encoder.output_dim, **config.actor_network_kwargs),
+    #     action_dim=config.output_shapes["action"][0],
+    #     encoder_is_shared=config.shared_encoder,
+    #     **config.policy_kwargs,
+    # )
+    # encoder = encoder.to("cuda:0")
+    # critic_ensemble = torch.compile(critic_ensemble)
+    critic_ensemble = critic_ensemble.to("cuda:0")
+    # actor = torch.compile(actor)
+    # actor = actor.to("cuda:0")
+    obs_dict = {
+        "observation.image": torch.randn(8, 3, 84, 84),
+        "observation.state": torch.randn(8, 4),
torch.randn(8, 4), + } + actions = torch.randn(8, 2).to("cuda:0") + # obs_dict = {k: v.to("cuda:0") for k, v in obs_dict.items()} + # print("compiling...") + q_value = critic_ensemble(obs_dict, actions) + print(q_value.size()) + # action = actor(obs_dict) + # print("compiled") + # start = time.perf_counter() + # for _ in range(1000): + # # features = encoder(obs_dict) + # action = actor(obs_dict) + # # q_value = critic_ensemble(obs_dict, actions) + # print("Time taken:", time.perf_counter() - start) + # Compare the performance of the ensemble vs a for loop of 16 MLPs + ensemble = Ensemble([CriticHead(256, [256, 256]) for _ in range(2)]) + ensemble = ensemble.to("cuda:0") + critic = CriticHead(256, [256, 256]) + critic = critic.to("cuda:0") + data_ensemble = torch.randn(8, 256).to("cuda:0") + ensemble = torch.compile(ensemble) + # critic = torch.compile(critic) + print(ensemble(data_ensemble).size()) + print(critic(data_ensemble).size()) + start = time.perf_counter() + for _ in range(1000): + ensemble(data_ensemble) + print("Time taken:", time.perf_counter() - start) + start = time.perf_counter() + for _ in range(1000): + for i in range(2): + critic(data_ensemble) + print("Time taken:", time.perf_counter() - start) diff --git a/lerobot/common/policies/tdmpc/configuration_tdmpc.py b/lerobot/common/policies/tdmpc/configuration_tdmpc.py index 4a5415a15..8f4683a1b 100644 --- a/lerobot/common/policies/tdmpc/configuration_tdmpc.py +++ b/lerobot/common/policies/tdmpc/configuration_tdmpc.py @@ -191,6 +191,10 @@ def __post_init__(self): "If `n_action_steps > 1`, `n_action_repeats` must be left to its default value of 1." ) if not self.use_mpc: - raise ValueError("If `n_action_steps > 1`, `use_mpc` must be set to `True`.") + raise ValueError( + "If `n_action_steps > 1`, `use_mpc` must be set to `True`." + ) if self.n_action_steps > self.horizon: - raise ValueError("`n_action_steps` must be less than or equal to `horizon`.") + raise ValueError( + "`n_action_steps` must be less than or equal to `horizon`." + ) diff --git a/lerobot/common/policies/tdmpc/modeling_tdmpc.py b/lerobot/common/policies/tdmpc/modeling_tdmpc.py index d97c4824c..fdccbe23f 100644 --- a/lerobot/common/policies/tdmpc/modeling_tdmpc.py +++ b/lerobot/common/policies/tdmpc/modeling_tdmpc.py @@ -68,7 +68,9 @@ class TDMPCPolicy( name = "tdmpc" def __init__( - self, config: TDMPCConfig | None = None, dataset_stats: dict[str, dict[str, Tensor]] | None = None + self, + config: TDMPCConfig | None = None, + dataset_stats: dict[str, dict[str, Tensor]] | None = None, ): """ Args: @@ -100,7 +102,9 @@ def __init__( config.output_shapes, config.output_normalization_modes, dataset_stats ) - image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] # Note: This check is covered in the post-init of the config but have a sanity check just in case. 
self._use_image = False self._use_env_state = False @@ -120,7 +124,9 @@ def reset(self): """ self._queues = { "observation.state": deque(maxlen=1), - "action": deque(maxlen=max(self.config.n_action_steps, self.config.n_action_repeats)), + "action": deque( + maxlen=max(self.config.n_action_steps, self.config.n_action_repeats) + ), } if self._use_image: self._queues["observation.image"] = deque(maxlen=1) @@ -135,7 +141,9 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" batch = self.normalize_inputs(batch) if self._use_image: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[self.input_image_key] self._queues = populate_queues(self._queues, batch) @@ -209,13 +217,20 @@ def plan(self, z: Tensor) -> Tensor: # In the CEM loop we will need this for a call to estimate_value with the gaussian sampled # trajectories. - z = einops.repeat(z, "b d -> n b d", n=self.config.n_gaussian_samples + self.config.n_pi_samples) + z = einops.repeat( + z, + "b d -> n b d", + n=self.config.n_gaussian_samples + self.config.n_pi_samples, + ) # Model Predictive Path Integral (MPPI) with the cross-entropy method (CEM) as the optimization # algorithm. # The initial mean and standard deviation for the cross-entropy method (CEM). mean = torch.zeros( - self.config.horizon, batch_size, self.config.output_shapes["action"][0], device=device + self.config.horizon, + batch_size, + self.config.output_shapes["action"][0], + device=device, ) # Maybe warm start CEM with the mean from the previous step. if self._prev_mean is not None: @@ -231,35 +246,47 @@ def plan(self, z: Tensor) -> Tensor: self.config.output_shapes["action"][0], device=std.device, ) - gaussian_actions = torch.clamp(mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1) + gaussian_actions = torch.clamp( + mean.unsqueeze(1) + std.unsqueeze(1) * std_normal_noise, -1, 1 + ) # Compute elite actions. actions = torch.cat([gaussian_actions, pi_actions], dim=1) value = self.estimate_value(z, actions).nan_to_num_(0) - elite_idxs = torch.topk(value, self.config.n_elites, dim=0).indices # (n_elites, batch) + elite_idxs = torch.topk( + value, self.config.n_elites, dim=0 + ).indices # (n_elites, batch) elite_value = value.take_along_dim(elite_idxs, dim=0) # (n_elites, batch) # (horizon, n_elites, batch, action_dim) - elite_actions = actions.take_along_dim(einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1) + elite_actions = actions.take_along_dim( + einops.rearrange(elite_idxs, "n b -> 1 n b 1"), dim=1 + ) # Update gaussian PDF parameters to be the (weighted) mean and standard deviation of the elites. max_value = elite_value.max(0, keepdim=True)[0] # (1, batch) # The weighting is a softmax over trajectory values. Note that this is not the same as the usage # of Ω in eqn 4 of the TD-MPC paper. Instead it is the normalized version of it: s = Ω/ΣΩ. This # makes the equations: μ = Σ(s⋅Γ), σ = Σ(s⋅(Γ-μ)²). 
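         # (Editor's note, illustrative only; not part of the patch.) A tiny worked example of this
         # weighting with two elites and temperature 1: for elite_value = [0, ln 2] the scores are
         # exp([0, ln 2] - ln 2) = [0.5, 1], which normalize to s = [1/3, 2/3], so the new mean is
         # mu = (1/3)*Γ_1 + (2/3)*Γ_2, exactly what `score` and `_mean` compute below.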
- score = torch.exp(self.config.elite_weighting_temperature * (elite_value - max_value)) + score = torch.exp( + self.config.elite_weighting_temperature * (elite_value - max_value) + ) score /= score.sum(axis=0, keepdim=True) # (horizon, batch, action_dim) - _mean = torch.sum(einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1) + _mean = torch.sum( + einops.rearrange(score, "n b -> n b 1") * elite_actions, dim=1 + ) _std = torch.sqrt( torch.sum( einops.rearrange(score, "n b -> n b 1") - * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) ** 2, + * (elite_actions - einops.rearrange(_mean, "h b d -> h 1 b d")) + ** 2, dim=1, ) ) # Update mean with an exponential moving average, and std with a direct replacement. mean = ( - self.config.gaussian_mean_momentum * mean + (1 - self.config.gaussian_mean_momentum) * _mean + self.config.gaussian_mean_momentum * mean + + (1 - self.config.gaussian_mean_momentum) * _mean ) std = _std.clamp_(self.config.min_std, self.config.max_std) @@ -268,7 +295,9 @@ def plan(self, z: Tensor) -> Tensor: # Randomly select one of the elite actions from the last iteration of MPPI/CEM using the softmax # scores from the last iteration. - actions = elite_actions[:, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size)] + actions = elite_actions[ + :, torch.multinomial(score.T, 1).squeeze(), torch.arange(batch_size) + ] return actions @@ -291,7 +320,8 @@ def estimate_value(self, z: Tensor, actions: Tensor): # of the FOWM paper. if self.config.uncertainty_regularizer_coeff > 0: regularization = -( - self.config.uncertainty_regularizer_coeff * self.model.Qs(z, actions[t]).std(0) + self.config.uncertainty_regularizer_coeff + * self.model.Qs(z, actions[t]).std(0) ) else: regularization = 0 @@ -311,15 +341,22 @@ def estimate_value(self, z: Tensor, actions: Tensor): if self.config.q_ensemble_size > 2: G += ( running_discount - * torch.min(terminal_values[torch.randint(0, self.config.q_ensemble_size, size=(2,))], dim=0)[ - 0 - ] + * torch.min( + terminal_values[ + torch.randint(0, self.config.q_ensemble_size, size=(2,)) + ], + dim=0, + )[0] ) else: G += running_discount * torch.min(terminal_values, dim=0)[0] # Finally, also regularize the terminal value. if self.config.uncertainty_regularizer_coeff > 0: - G -= running_discount * self.config.uncertainty_regularizer_coeff * terminal_values.std(0) + G -= ( + running_discount + * self.config.uncertainty_regularizer_coeff + * terminal_values.std(0) + ) return G def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: @@ -331,7 +368,9 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: batch = self.normalize_inputs(batch) if self._use_image: - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original batch["observation.image"] = batch[self.input_image_key] batch = self.normalize_targets(batch) @@ -349,7 +388,10 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: # Apply random image augmentations. 
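         # (Editor's note, not part of the patch.) `flatten_forward_unflatten` lets an image tensor
         # with extra leading dims (e.g. time and batch) pass through `random_shifts_aug`, which
         # expects (B, C, H, W): the leading dims are flattened for the call and restored afterwards.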
if self._use_image and self.config.max_random_shift_ratio > 0: observations["observation.image"] = flatten_forward_unflatten( - partial(random_shifts_aug, max_random_shift_ratio=self.config.max_random_shift_ratio), + partial( + random_shifts_aug, + max_random_shift_ratio=self.config.max_random_shift_ratio, + ), observations["observation.image"], ) @@ -367,14 +409,20 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: # Note this has shape `horizon+1` because there are `horizon` actions and a current `z`. Each action # gives us a next `z`. batch_size = batch["index"].shape[0] - z_preds = torch.empty(horizon + 1, batch_size, self.config.latent_dim, device=device) + z_preds = torch.empty( + horizon + 1, batch_size, self.config.latent_dim, device=device + ) z_preds[0] = self.model.encode(current_observation) reward_preds = torch.empty_like(reward, device=device) for t in range(horizon): - z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward(z_preds[t], action[t]) + z_preds[t + 1], reward_preds[t] = self.model.latent_dynamics_and_reward( + z_preds[t], action[t] + ) # Compute Q and V value predictions based on the latent rollout. - q_preds_ensemble = self.model.Qs(z_preds[:-1], action) # (ensemble, horizon, batch) + q_preds_ensemble = self.model.Qs( + z_preds[:-1], action + ) # (ensemble, horizon, batch) v_preds = self.model.V(z_preds[:-1]) info.update({"Q": q_preds_ensemble.mean().item(), "V": v_preds.mean().item()}) @@ -388,10 +436,14 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: # actions (not actions estimated by π). # Note: Here we do not use self.model_target, but self.model. This is to follow the original code # and the FOWM paper. - q_targets = reward + self.config.discount * self.model.V(self.model.encode(next_observations)) + q_targets = reward + self.config.discount * self.model.V( + self.model.encode(next_observations) + ) # From eqn 3 of FOWM. These appear as Q(z, a). Here we call them v_targets to emphasize that we # are using them to compute loss for V. - v_targets = self.model_target.Qs(z_preds[:-1].detach(), action, return_min=True) + v_targets = self.model_target.Qs( + z_preds[:-1].detach(), action, return_min=True + ) # Compute losses. # Exponentially decay the loss weight with respect to the timestep. Steps that are more distant in the @@ -434,7 +486,9 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: temporal_loss_coeffs * F.mse_loss( q_preds_ensemble, - einops.repeat(q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0]), + einops.repeat( + q_targets, "t b -> e t b", e=q_preds_ensemble.shape[0] + ), reduction="none", ).sum(0) # sum over ensemble # `q_preds_ensemble` depends on the first observation and the actions. @@ -472,12 +526,14 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor | float]: z_preds = z_preds.detach() # Use stopgrad for the advantage calculation. with torch.no_grad(): - advantage = self.model_target.Qs(z_preds[:-1], action, return_min=True) - self.model.V( - z_preds[:-1] - ) + advantage = self.model_target.Qs( + z_preds[:-1], action, return_min=True + ) - self.model.V(z_preds[:-1]) info["advantage"] = advantage[0] # (t, b) - exp_advantage = torch.clamp(torch.exp(advantage * self.config.advantage_scaling), max=100.0) + exp_advantage = torch.clamp( + torch.exp(advantage * self.config.advantage_scaling), max=100.0 + ) action_preds = self.model.pi(z_preds[:-1]) # (t, b, a) # Calculate the MSE between the actions and the action predictions. 
# Note: FOWM's original code calculates the log probability (wrt to a unit standard deviation @@ -532,7 +588,9 @@ def update(self): # Note a minor variation with respect to the original FOWM code. Here they do this based on an EMA # update frequency parameter which is set to 2 (every 2 steps an update is done). To simplify the code # we update every step and adjust the decay parameter `alpha` accordingly (0.99 -> 0.995) - update_ema_parameters(self.model_target, self.model, self.config.target_model_momentum) + update_ema_parameters( + self.model_target, self.model, self.config.target_model_momentum + ) class TDMPCTOLD(nn.Module): @@ -543,7 +601,9 @@ def __init__(self, config: TDMPCConfig): self.config = config self._encoder = TDMPCObservationEncoder(config) self._dynamics = nn.Sequential( - nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.Linear( + config.latent_dim + config.output_shapes["action"][0], config.mlp_dim + ), nn.LayerNorm(config.mlp_dim), nn.Mish(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -554,7 +614,9 @@ def __init__(self, config: TDMPCConfig): nn.Sigmoid(), ) self._reward = nn.Sequential( - nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.Linear( + config.latent_dim + config.output_shapes["action"][0], config.mlp_dim + ), nn.LayerNorm(config.mlp_dim), nn.Mish(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -574,7 +636,10 @@ def __init__(self, config: TDMPCConfig): self._Qs = nn.ModuleList( [ nn.Sequential( - nn.Linear(config.latent_dim + config.output_shapes["action"][0], config.mlp_dim), + nn.Linear( + config.latent_dim + config.output_shapes["action"][0], + config.mlp_dim, + ), nn.LayerNorm(config.mlp_dim), nn.Tanh(), nn.Linear(config.mlp_dim, config.mlp_dim), @@ -619,7 +684,9 @@ def _apply_fn(m): m[-1], nn.Linear ), "Sanity check. The last linear layer needs 0 initialization on weights." 
nn.init.zeros_(m[-1].weight) - nn.init.zeros_(m[-1].bias) # this has already been done, but keep this line here for good measure + nn.init.zeros_( + m[-1].bias + ) # this has already been done, but keep this line here for good measure def encode(self, obs: dict[str, Tensor]) -> Tensor: """Encodes an observation into its latent representation.""" @@ -717,14 +784,32 @@ def __init__(self, config: TDMPCConfig): if "observation.image" in config.input_shapes: self.image_enc_layers = nn.Sequential( nn.Conv2d( - config.input_shapes["observation.image"][0], config.image_encoder_hidden_dim, 7, stride=2 + config.input_shapes["observation.image"][0], + config.image_encoder_hidden_dim, + 7, + stride=2, ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 5, stride=2), + nn.Conv2d( + config.image_encoder_hidden_dim, + config.image_encoder_hidden_dim, + 5, + stride=2, + ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.Conv2d( + config.image_encoder_hidden_dim, + config.image_encoder_hidden_dim, + 3, + stride=2, + ), nn.ReLU(), - nn.Conv2d(config.image_encoder_hidden_dim, config.image_encoder_hidden_dim, 3, stride=2), + nn.Conv2d( + config.image_encoder_hidden_dim, + config.image_encoder_hidden_dim, + 3, + stride=2, + ), nn.ReLU(), ) dummy_batch = torch.zeros(1, *config.input_shapes["observation.image"]) @@ -740,7 +825,10 @@ def __init__(self, config: TDMPCConfig): ) if "observation.state" in config.input_shapes: self.state_enc_layers = nn.Sequential( - nn.Linear(config.input_shapes["observation.state"][0], config.state_encoder_hidden_dim), + nn.Linear( + config.input_shapes["observation.state"][0], + config.state_encoder_hidden_dim, + ), nn.ELU(), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), nn.LayerNorm(config.latent_dim), @@ -749,7 +837,8 @@ def __init__(self, config: TDMPCConfig): if "observation.environment_state" in config.input_shapes: self.env_state_enc_layers = nn.Sequential( nn.Linear( - config.input_shapes["observation.environment_state"][0], config.state_encoder_hidden_dim + config.input_shapes["observation.environment_state"][0], + config.state_encoder_hidden_dim, ), nn.ELU(), nn.Linear(config.state_encoder_hidden_dim, config.latent_dim), @@ -766,9 +855,15 @@ def forward(self, obs_dict: dict[str, Tensor]) -> Tensor: feat = [] # NOTE: Order of observations matters here. 
if "observation.image" in self.config.input_shapes: - feat.append(flatten_forward_unflatten(self.image_enc_layers, obs_dict["observation.image"])) + feat.append( + flatten_forward_unflatten( + self.image_enc_layers, obs_dict["observation.image"] + ) + ) if "observation.environment_state" in self.config.input_shapes: - feat.append(self.env_state_enc_layers(obs_dict["observation.environment_state"])) + feat.append( + self.env_state_enc_layers(obs_dict["observation.environment_state"]) + ) if "observation.state" in self.config.input_shapes: feat.append(self.state_enc_layers(obs_dict["observation.state"])) return torch.stack(feat, dim=0).mean(0) @@ -811,12 +906,17 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): """Update EMA parameters in place with ema_param <- alpha * ema_param + (1 - alpha) * param.""" for ema_module, module in zip(ema_net.modules(), net.modules(), strict=True): for (n_p_ema, p_ema), (n_p, p) in zip( - ema_module.named_parameters(recurse=False), module.named_parameters(recurse=False), strict=True + ema_module.named_parameters(recurse=False), + module.named_parameters(recurse=False), + strict=True, ): assert n_p_ema == n_p, "Parameter names don't match for EMA model update" if isinstance(p, dict): raise RuntimeError("Dict parameter not supported") - if isinstance(module, nn.modules.batchnorm._BatchNorm) or not p.requires_grad: + if ( + isinstance(module, nn.modules.batchnorm._BatchNorm) + or not p.requires_grad + ): # Copy BatchNorm parameters, and non-trainable parameters directly. p_ema.copy_(p.to(dtype=p_ema.dtype).data) with torch.no_grad(): @@ -824,7 +924,9 @@ def update_ema_parameters(ema_net: nn.Module, net: nn.Module, alpha: float): p_ema.add_(p.to(dtype=p_ema.dtype).data, alpha=1 - alpha) -def flatten_forward_unflatten(fn: Callable[[Tensor], Tensor], image_tensor: Tensor) -> Tensor: +def flatten_forward_unflatten( + fn: Callable[[Tensor], Tensor], image_tensor: Tensor +) -> Tensor: """Helper to temporarily flatten extra dims at the start of the image tensor. Args: diff --git a/lerobot/common/policies/vqbet/configuration_vqbet.py b/lerobot/common/policies/vqbet/configuration_vqbet.py index dfe4684d2..e92c269e5 100644 --- a/lerobot/common/policies/vqbet/configuration_vqbet.py +++ b/lerobot/common/policies/vqbet/configuration_vqbet.py @@ -109,7 +109,9 @@ class VQBeTConfig: "observation.state": "min_max", } ) - output_normalization_modes: dict[str, str] = field(default_factory=lambda: {"action": "min_max"}) + output_normalization_modes: dict[str, str] = field( + default_factory=lambda: {"action": "min_max"} + ) # Architecture / modeling. # Vision backbone. 
diff --git a/lerobot/common/policies/vqbet/modeling_vqbet.py b/lerobot/common/policies/vqbet/modeling_vqbet.py index 98adce00b..25af6a7d6 100644 --- a/lerobot/common/policies/vqbet/modeling_vqbet.py +++ b/lerobot/common/policies/vqbet/modeling_vqbet.py @@ -79,7 +79,9 @@ def __init__( self.vqbet = VQBeTModel(config) - self.expected_image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + self.expected_image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] self.reset() @@ -104,8 +106,12 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: """ batch = self.normalize_inputs(batch) - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[k] for k in self.expected_image_keys], dim=-4 + ) # Note: It's important that this happens after stacking the images into a single key. self._queues = populate_queues(self._queues, batch) @@ -116,8 +122,14 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: ) if len(self._queues["action"]) == 0: - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.vqbet(batch, rollout=True)[:, : self.config.action_chunk_size] + batch = { + k: torch.stack(list(self._queues[k]), dim=1) + for k in batch + if k in self._queues + } + actions = self.vqbet(batch, rollout=True)[ + :, : self.config.action_chunk_size + ] # the dimension of returned action is (batch_size, action_chunk_size, action_dim) actions = self.unnormalize_outputs({"action": actions})["action"] @@ -130,8 +142,12 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: """Run the batch through the model and compute the loss for training or validation.""" batch = self.normalize_inputs(batch) - batch = dict(batch) # shallow copy so that adding a key doesn't modify the original - batch["observation.images"] = torch.stack([batch[k] for k in self.expected_image_keys], dim=-4) + batch = dict( + batch + ) # shallow copy so that adding a key doesn't modify the original + batch["observation.images"] = torch.stack( + [batch[k] for k in self.expected_image_keys], dim=-4 + ) batch = self.normalize_targets(batch) # VQ-BeT discretizes action using VQ-VAE before training BeT (please refer to section 3.2 in the VQ-BeT paper https://arxiv.org/pdf/2403.03181) if not self.vqbet.action_head.vqvae_model.discretized.item(): @@ -139,7 +155,9 @@ def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # n_different_codes: how many of the total possible VQ codes are being used in single batch (how many of them have at least one encoder embedding as a nearest neighbor). This can be at most `vqvae_n_embed * number of layers of RVQ (=2)`. # n_different_combinations: how many different code combinations are being used out of all possible combinations in single batch. This can be at most `vqvae_n_embed ^ number of layers of RVQ (=2)` (hint consider the RVQ as a decision tree). 
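            # (Editor's note, illustrative numbers only.) E.g. with `vqvae_n_embed = 16` and the
            # 2 RVQ layers above, `n_different_codes` is at most 16 * 2 = 32 and
            # `n_different_combinations` is at most 16 ** 2 = 256.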
loss, n_different_codes, n_different_combinations, recon_l1_error = ( - self.vqbet.action_head.discretize(self.config.n_vqvae_training_steps, batch["action"]) + self.vqbet.action_head.discretize( + self.config.n_vqvae_training_steps, batch["action"] + ) ) return { "loss": loss, @@ -196,7 +214,9 @@ def __init__(self, input_shape, num_kp=None): # we could use torch.linspace directly but that seems to behave slightly differently than numpy # and causes a small degradation in pc_success of pre-trained models. - pos_x, pos_y = np.meshgrid(np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h)) + pos_x, pos_y = np.meshgrid( + np.linspace(-1.0, 1.0, self._in_w), np.linspace(-1.0, 1.0, self._in_h) + ) pos_x = torch.from_numpy(pos_x.reshape(self._in_h * self._in_w, 1)).float() pos_y = torch.from_numpy(pos_y.reshape(self._in_h * self._in_w, 1)).float() # register as buffer so it's moved to the correct device. @@ -288,14 +308,17 @@ def __init__(self, config: VQBeTConfig): self.config = config self.rgb_encoder = VQBeTRgbEncoder(config) - self.num_images = len([k for k in config.input_shapes if k.startswith("observation.image")]) + self.num_images = len( + [k for k in config.input_shapes if k.startswith("observation.image")] + ) # This action query token is used as a prompt for querying action chunks. Please refer to "A_Q" in the image above. # Note: During the forward pass, this token is repeated as many times as needed. The authors also experimented with initializing the necessary number of tokens independently and observed inferior results. self.action_token = nn.Parameter(torch.randn(1, 1, self.config.gpt_input_dim)) # To input state and observation features into GPT layers, we first project the features to fit the shape of input size of GPT. self.state_projector = MLP( - config.input_shapes["observation.state"][0], hidden_channels=[self.config.gpt_input_dim] + config.input_shapes["observation.state"][0], + hidden_channels=[self.config.gpt_input_dim], ) self.rgb_feature_projector = MLP( self.rgb_encoder.feature_dim, hidden_channels=[self.config.gpt_input_dim] @@ -310,7 +333,12 @@ def __init__(self, config: VQBeTConfig): num_tokens = self.config.n_action_pred_token + self.config.n_obs_steps - 1 self.register_buffer( "select_target_actions_indices", - torch.row_stack([torch.arange(i, i + self.config.action_chunk_size) for i in range(num_tokens)]), + torch.row_stack( + [ + torch.arange(i, i + self.config.action_chunk_size) + for i in range(num_tokens) + ] + ), ) def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: @@ -325,7 +353,11 @@ def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: ) # Separate batch and sequence dims. img_features = einops.rearrange( - img_features, "(b s n) ... -> b s n ...", b=batch_size, s=n_obs_steps, n=self.num_images + img_features, + "(b s n) ... -> b s n ...", + b=batch_size, + s=n_obs_steps, + n=self.num_images, ) # Arrange prior and current observation step tokens as shown in the class docstring. @@ -337,13 +369,19 @@ def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: input_tokens.append( self.state_projector(batch["observation.state"]) ) # (batch, obs_step, projection dims) - input_tokens.append(einops.repeat(self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps)) + input_tokens.append( + einops.repeat( + self.action_token, "1 1 d -> b n d", b=batch_size, n=n_obs_steps + ) + ) # Interleave tokens by stacking and rearranging. 
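        # (Editor's note, not part of the patch.) E.g. with two observation steps and per-step
        # tokens [img, state, A_Q], the stack on dim=2 plus the rearrange below produce the
        # sequence img_1, state_1, A_Q, img_2, state_2, A_Q.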
input_tokens = torch.stack(input_tokens, dim=2) input_tokens = einops.rearrange(input_tokens, "b n t d -> b (n t) d") len_additional_action_token = self.config.n_action_pred_token - 1 - future_action_tokens = self.action_token.repeat(batch_size, len_additional_action_token, 1) + future_action_tokens = self.action_token.repeat( + batch_size, len_additional_action_token, 1 + ) # add additional action query tokens for predicting future action chunks input_tokens = torch.cat([input_tokens, future_action_tokens], dim=1) @@ -352,9 +390,9 @@ def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: features = self.policy(input_tokens) # len(self.config.input_shapes) is the number of different observation modes. # this line gets the index of action prompt tokens. - historical_act_pred_index = np.arange(0, n_obs_steps) * (len(self.config.input_shapes) + 1) + len( - self.config.input_shapes - ) + historical_act_pred_index = np.arange(0, n_obs_steps) * ( + len(self.config.input_shapes) + 1 + ) + len(self.config.input_shapes) # only extract the output tokens at the position of action query: # Behavior Transformer (BeT), and VQ-BeT are both sequence-to-sequence prediction models, @@ -362,7 +400,11 @@ def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: # Thus, it predicts a historical action sequence, in addition to current and future actions (predicting future actions : optional). if len_additional_action_token > 0: features = torch.cat( - [features[:, historical_act_pred_index], features[:, -len_additional_action_token:]], dim=1 + [ + features[:, historical_act_pred_index], + features[:, -len_additional_action_token:], + ], + dim=1, ) else: features = features[:, historical_act_pred_index] @@ -370,13 +412,15 @@ def forward(self, batch: dict[str, Tensor], rollout: bool) -> Tensor: action_head_output = self.action_head(features) # if rollout, VQ-BeT don't calculate loss if rollout: - return action_head_output["predicted_action"][:, n_obs_steps - 1, :].reshape( - batch_size, self.config.action_chunk_size, -1 - ) + return action_head_output["predicted_action"][ + :, n_obs_steps - 1, : + ].reshape(batch_size, self.config.action_chunk_size, -1) # else, it calculate overall loss (bin prediction loss, and offset loss) else: output = batch["action"][:, self.select_target_actions_indices] - loss = self.action_head.loss_fn(action_head_output, output, reduction="mean") + loss = self.action_head.loss_fn( + action_head_output, output, reduction="mean" + ) return action_head_output, loss @@ -411,7 +455,9 @@ def __init__(self, config: VQBeTConfig): else: self.map_to_cbet_preds_bin = MLP( in_channels=config.gpt_output_dim, - hidden_channels=[self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed], + hidden_channels=[ + self.vqvae_model.vqvae_num_layers * self.config.vqvae_n_embed + ], ) self.map_to_cbet_preds_offset = MLP( in_channels=config.gpt_output_dim, @@ -438,7 +484,10 @@ def discretize(self, n_vqvae_training_steps, actions): loss, metric = self.vqvae_model.vqvae_forward(actions) n_different_codes = sum( - [len(torch.unique(metric[2][:, i])) for i in range(self.vqvae_model.vqvae_num_layers)] + [ + len(torch.unique(metric[2][:, i])) + for i in range(self.vqvae_model.vqvae_num_layers) + ] ) n_different_combinations = len(torch.unique(metric[2], dim=0)) recon_l1_error = metric[0].detach().cpu().item() @@ -485,7 +534,13 @@ def forward(self, x, **kwargs): cbet_secondary_logits = self.map_to_cbet_preds_secondary_bin( torch.cat( - (x, F.one_hot(sampled_primary_centers, 
num_classes=self.config.vqvae_n_embed)), + ( + x, + F.one_hot( + sampled_primary_centers, + num_classes=self.config.vqvae_n_embed, + ), + ), axis=1, ) ) @@ -493,19 +548,29 @@ def forward(self, x, **kwargs): cbet_secondary_logits / self.config.bet_softmax_temperature, dim=-1 ) sampled_secondary_centers = einops.rearrange( - torch.multinomial(cbet_secondary_probs.view(-1, choices), num_samples=1), + torch.multinomial( + cbet_secondary_probs.view(-1, choices), num_samples=1 + ), "(NT) 1 -> NT", NT=NT, ) - sampled_centers = torch.stack((sampled_primary_centers, sampled_secondary_centers), axis=1) - cbet_logits = torch.stack([cbet_primary_logits, cbet_secondary_logits], dim=1) + sampled_centers = torch.stack( + (sampled_primary_centers, sampled_secondary_centers), axis=1 + ) + cbet_logits = torch.stack( + [cbet_primary_logits, cbet_secondary_logits], dim=1 + ) # if self.config.sequentially_select is False, bin prediction head samples primary and secondary code at once. else: cbet_logits = self.map_to_cbet_preds_bin(x) cbet_logits = einops.rearrange( - cbet_logits, "(NT) (G C) -> (NT) G C", G=self.vqvae_model.vqvae_num_layers + cbet_logits, + "(NT) (G C) -> (NT) G C", + G=self.vqvae_model.vqvae_num_layers, + ) + cbet_probs = torch.softmax( + cbet_logits / self.config.bet_softmax_temperature, dim=-1 ) - cbet_probs = torch.softmax(cbet_logits / self.config.bet_softmax_temperature, dim=-1) NT, G, choices = cbet_probs.shape sampled_centers = einops.rearrange( torch.multinomial(cbet_probs.view(-1, choices), num_samples=1), @@ -525,9 +590,17 @@ def forward(self, x, **kwargs): sampled_offsets = sampled_offsets.sum(dim=1) with torch.no_grad(): # Get the centroids (= vectors corresponding to the codes) of each layer to pass it through RVQ decoder - return_decoder_input = self.vqvae_model.get_embeddings_from_code(sampled_centers).clone().detach() + return_decoder_input = ( + self.vqvae_model.get_embeddings_from_code(sampled_centers) + .clone() + .detach() + ) # pass the centroids through decoder to get actions. - decoded_action = self.vqvae_model.get_action_from_latent(return_decoder_input).clone().detach() + decoded_action = ( + self.vqvae_model.get_action_from_latent(return_decoder_input) + .clone() + .detach() + ) # reshaped extracted offset to match with decoded centroids sampled_offsets = einops.rearrange( sampled_offsets, "NT (W A) -> NT W A", W=self.config.action_chunk_size @@ -576,7 +649,9 @@ def loss_fn(self, pred, target, **kwargs): # Figure out the loss for the actions. # First, we need to find the closest cluster center for each ground truth action. with torch.no_grad(): - state_vq, action_bins = self.vqvae_model.get_code(action_seq) # action_bins: NT, G + state_vq, action_bins = self.vqvae_model.get_code( + action_seq + ) # action_bins: NT, G # Now we can compute the loss. 
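(Editor's aside, not part of the patch.) The sampling pattern used throughout this head is
"temperature-scaled softmax, then one multinomial draw per row". A minimal sketch with made-up
sizes (6 tokens, 4 codebook entries) and a stand-in temperature of 0.1:

    import torch

    logits = torch.randn(6, 4)
    probs = torch.softmax(logits / 0.1, dim=-1)        # sharpen with a low temperature
    sampled = torch.multinomial(probs, num_samples=1)  # one code index per row
    print(sampled.squeeze(-1).shape)                   # torch.Size([6])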
@@ -599,8 +674,12 @@ def loss_fn(self, pred, target, **kwargs): + cbet_loss2 * self.config.secondary_code_loss_weight ) - equal_primary_code_rate = torch.sum((action_bins[:, 0] == sampled_centers[:, 0]).int()) / (NT) - equal_secondary_code_rate = torch.sum((action_bins[:, 1] == sampled_centers[:, 1]).int()) / (NT) + equal_primary_code_rate = torch.sum( + (action_bins[:, 0] == sampled_centers[:, 0]).int() + ) / (NT) + equal_secondary_code_rate = torch.sum( + (action_bins[:, 1] == sampled_centers[:, 1]).int() + ) / (NT) action_mse_error = torch.mean((action_seq - predicted_action) ** 2) vq_action_error = torch.mean(torch.abs(action_seq - decoded_action)) @@ -614,7 +693,9 @@ def loss_fn(self, pred, target, **kwargs): "classification_loss": cbet_loss.detach().cpu().item(), "offset_loss": offset_loss.detach().cpu().item(), "equal_primary_code_rate": equal_primary_code_rate.detach().cpu().item(), - "equal_secondary_code_rate": equal_secondary_code_rate.detach().cpu().item(), + "equal_secondary_code_rate": equal_secondary_code_rate.detach() + .cpu() + .item(), "vq_action_error": vq_action_error.detach().cpu().item(), "offset_action_error": offset_action_error.detach().cpu().item(), "action_error_max": action_error_max.detach().cpu().item(), @@ -643,11 +724,17 @@ def __init__(self, policy, cfg): if cfg.policy.sequentially_select: decay_params = ( decay_params - + list(policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters()) - + list(policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters()) + + list( + policy.vqbet.action_head.map_to_cbet_preds_primary_bin.parameters() + ) + + list( + policy.vqbet.action_head.map_to_cbet_preds_secondary_bin.parameters() + ) ) else: - decay_params = decay_params + list(policy.vqbet.action_head.map_to_cbet_preds_bin.parameters()) + decay_params = decay_params + list( + policy.vqbet.action_head.map_to_cbet_preds_bin.parameters() + ) optim_groups = [ { @@ -693,7 +780,11 @@ def lr_lambda(current_step): progress = float(current_step - num_warmup_steps) / float( max(1, num_training_steps - num_warmup_steps) ) - return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))) + return max( + 0.0, + 0.5 + * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) self.lr_scheduler = LambdaLR(optimizer, lr_lambda, -1) @@ -717,7 +808,9 @@ def __init__(self, config: VQBeTConfig): # Always use center crop for eval self.center_crop = torchvision.transforms.CenterCrop(config.crop_shape) if config.crop_is_random: - self.maybe_random_crop = torchvision.transforms.RandomCrop(config.crop_shape) + self.maybe_random_crop = torchvision.transforms.RandomCrop( + config.crop_shape + ) else: self.maybe_random_crop = self.center_crop else: @@ -738,7 +831,9 @@ def __init__(self, config: VQBeTConfig): self.backbone = _replace_submodules( root_module=self.backbone, predicate=lambda x: isinstance(x, nn.BatchNorm2d), - func=lambda x: nn.GroupNorm(num_groups=x.num_features // 16, num_channels=x.num_features), + func=lambda x: nn.GroupNorm( + num_groups=x.num_features // 16, num_channels=x.num_features + ), ) # Set up pooling and final layers. @@ -746,17 +841,25 @@ def __init__(self, config: VQBeTConfig): # The dummy input should take the number of image channels from `config.input_shapes` and it should # use the height and width from `config.crop_shape` if it is provided, otherwise it should use the # height and width from `config.input_shapes`. 
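        # (Editor's note, not part of the patch.) Probing with a zeros tensor, as done below,
        # avoids hand-deriving the conv output size: a single `backbone(dummy_input)` call under
        # `torch.inference_mode()` yields the feature map shape for any backbone choice.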
- image_keys = [k for k in config.input_shapes if k.startswith("observation.image")] + image_keys = [ + k for k in config.input_shapes if k.startswith("observation.image") + ] assert len(image_keys) == 1 image_key = image_keys[0] dummy_input_h_w = ( - config.crop_shape if config.crop_shape is not None else config.input_shapes[image_key][1:] + config.crop_shape + if config.crop_shape is not None + else config.input_shapes[image_key][1:] + ) + dummy_input = torch.zeros( + size=(1, config.input_shapes[image_key][0], *dummy_input_h_w) ) - dummy_input = torch.zeros(size=(1, config.input_shapes[image_key][0], *dummy_input_h_w)) with torch.inference_mode(): dummy_feature_map = self.backbone(dummy_input) feature_map_shape = tuple(dummy_feature_map.shape[1:]) - self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints) + self.pool = SpatialSoftmax( + feature_map_shape, num_kp=config.spatial_softmax_num_keypoints + ) self.feature_dim = config.spatial_softmax_num_keypoints * 2 self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim) self.relu = nn.ReLU() @@ -783,7 +886,9 @@ def forward(self, x: Tensor) -> Tensor: def _replace_submodules( - root_module: nn.Module, predicate: Callable[[nn.Module], bool], func: Callable[[nn.Module], nn.Module] + root_module: nn.Module, + predicate: Callable[[nn.Module], bool], + func: Callable[[nn.Module], nn.Module], ) -> nn.Module: """ Args: @@ -796,7 +901,11 @@ def _replace_submodules( if predicate(root_module): return func(root_module) - replace_list = [k.split(".") for k, m in root_module.named_modules(remove_duplicate=True) if predicate(m)] + replace_list = [ + k.split(".") + for k, m in root_module.named_modules(remove_duplicate=True) + if predicate(m) + ] for *parents, k in replace_list: parent_module = root_module if len(parents) > 0: @@ -811,7 +920,9 @@ def _replace_submodules( else: setattr(parent_module, k, tgt_module) # verify that all BN are replaced - assert not any(predicate(m) for _, m in root_module.named_modules(remove_duplicate=True)) + assert not any( + predicate(m) for _, m in root_module.named_modules(remove_duplicate=True) + ) return root_module @@ -844,7 +955,8 @@ def __init__( ) self.encoder = MLP( - in_channels=self.config.output_shapes["action"][0] * self.config.action_chunk_size, + in_channels=self.config.output_shapes["action"][0] + * self.config.action_chunk_size, hidden_channels=[ config.vqvae_enc_hidden_dim, config.vqvae_enc_hidden_dim, @@ -872,9 +984,13 @@ def get_action_from_latent(self, latent): # given latent vector, this function outputs the decoded action. output = self.decoder(latent) if self.config.action_chunk_size == 1: - return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]) + return einops.rearrange( + output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0] + ) else: - return einops.rearrange(output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0]) + return einops.rearrange( + output, "N (T A) -> N T A", A=self.config.output_shapes["action"][0] + ) def get_code(self, state): # in phase 2 of VQ-BeT training, we need a `ground truth labels of action data` to calculate the Focal loss for code prediction head. 
(please refer to section 3.3 in the paper https://arxiv.org/pdf/2403.03181) diff --git a/lerobot/common/policies/vqbet/vqbet_utils.py b/lerobot/common/policies/vqbet/vqbet_utils.py index 90a2cfda3..acbe9ade8 100644 --- a/lerobot/common/policies/vqbet/vqbet_utils.py +++ b/lerobot/common/policies/vqbet/vqbet_utils.py @@ -123,9 +123,15 @@ def forward(self, x): # calculate query, key, values for all heads in batch and move head forward to be the batch dim q, k, v = self.c_attn(x).split(self.gpt_hidden_dim, dim=2) - k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) - q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) - v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose(1, 2) # (B, nh, T, hs) + k = k.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + q = q.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) + v = v.view(B, T, self.gpt_n_head, C // self.gpt_n_head).transpose( + 1, 2 + ) # (B, nh, T, hs) # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) @@ -133,7 +139,9 @@ def forward(self, x): att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side + y = ( + y.transpose(1, 2).contiguous().view(B, T, C) + ) # re-assemble all head outputs side by side # output projection y = self.resid_dropout(self.c_proj(y)) @@ -189,12 +197,16 @@ def __init__(self, config: VQBeTConfig): "ln_f": nn.LayerNorm(config.gpt_hidden_dim), } ) - self.lm_head = nn.Linear(config.gpt_hidden_dim, config.gpt_output_dim, bias=False) + self.lm_head = nn.Linear( + config.gpt_hidden_dim, config.gpt_output_dim, bias=False + ) # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper self.apply(self._init_weights) for pn, p in self.named_parameters(): if pn.endswith("c_proj.weight"): - torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer)) + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * config.gpt_n_layer) + ) # report number of parameters n_params = sum(p.numel() for p in self.parameters()) @@ -208,11 +220,17 @@ def forward(self, input, targets=None): ), f"Cannot forward sequence of length {t}, block size is only {self.config.gpt_block_size}" # positional encodings that are added to the input embeddings - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) + pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze( + 0 + ) # shape (1, t) # forward the GPT model itself - tok_emb = self.transformer.wte(input) # token embeddings of shape (b, t, gpt_hidden_dim) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, gpt_hidden_dim) + tok_emb = self.transformer.wte( + input + ) # token embeddings of shape (b, t, gpt_hidden_dim) + pos_emb = self.transformer.wpe( + pos + ) # position embeddings of shape (1, t, gpt_hidden_dim) x = self.transformer.drop(tok_emb + pos_emb) for block in self.transformer.h: x = block(x) @@ -237,7 +255,9 @@ def crop_block_size(self, gpt_block_size): # but want to use a smaller block size for some smaller, simpler model assert gpt_block_size <= self.config.gpt_block_size self.config.gpt_block_size = gpt_block_size - self.transformer.wpe.weight = 
nn.Parameter(self.transformer.wpe.weight[:gpt_block_size]) + self.transformer.wpe.weight = nn.Parameter( + self.transformer.wpe.weight[:gpt_block_size] + ) for block in self.transformer.h: block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size] @@ -270,7 +290,9 @@ def configure_parameters(self): param_dict = dict(self.named_parameters()) inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format( + assert ( + len(inter_params) == 0 + ), "parameters {} made it into both decay/no_decay sets!".format( str(inter_params) ) assert ( @@ -368,8 +390,12 @@ def __init__( codebook_input_dim = codebook_dim * heads requires_projection = codebook_input_dim != dim - self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() - self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + self.project_in = ( + nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + ) self.num_quantizers = num_quantizers @@ -377,7 +403,10 @@ def __init__( self.layers = nn.ModuleList( [ VectorQuantize( - dim=codebook_dim, codebook_dim=codebook_dim, accept_image_fmap=accept_image_fmap, **kwargs + dim=codebook_dim, + codebook_dim=codebook_dim, + accept_image_fmap=accept_image_fmap, + **kwargs, ) for _ in range(num_quantizers) ] @@ -448,7 +477,9 @@ def get_codebook_vector_from_indices(self, indices): return all_codes - def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp=None): + def forward( + self, x, indices=None, return_all_codes=False, sample_codebook_temp=None + ): """ For given input tensor x, this function will return the quantized output, the indices of the quantized output, and the loss. First, the input tensor x is projected to the codebook dimension. Then, the input tensor x is passed through Nq layers of VectorQuantize. @@ -477,13 +508,17 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp= ), "some of the residual vq indices were dropped out. 
please use indices derived when the module is in eval mode to derive cross entropy loss" ce_losses = [] - should_quantize_dropout = self.training and self.quantize_dropout and not return_loss + should_quantize_dropout = ( + self.training and self.quantize_dropout and not return_loss + ) # sample a layer index at which to dropout further residual quantization # also prepare null indices and loss if should_quantize_dropout: - rand_quantize_dropout_index = randrange(self.quantize_dropout_cutoff_index, num_quant) + rand_quantize_dropout_index = randrange( + self.quantize_dropout_cutoff_index, num_quant + ) if quant_dropout_multiple_of != 1: rand_quantize_dropout_index = ( @@ -492,14 +527,23 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp= - 1 ) - null_indices_shape = (x.shape[0], *x.shape[-2:]) if self.accept_image_fmap else tuple(x.shape[:2]) - null_indices = torch.full(null_indices_shape, -1.0, device=device, dtype=torch.long) + null_indices_shape = ( + (x.shape[0], *x.shape[-2:]) + if self.accept_image_fmap + else tuple(x.shape[:2]) + ) + null_indices = torch.full( + null_indices_shape, -1.0, device=device, dtype=torch.long + ) null_loss = torch.full((1,), 0.0, device=device, dtype=x.dtype) # go through the layers for quantizer_index, layer in enumerate(self.layers): - if should_quantize_dropout and quantizer_index > rand_quantize_dropout_index: + if ( + should_quantize_dropout + and quantizer_index > rand_quantize_dropout_index + ): all_indices.append(null_indices) all_losses.append(null_loss) continue @@ -539,7 +583,9 @@ def forward(self, x, indices=None, return_all_codes=False, sample_codebook_temp= # stack all losses and indices - all_losses, all_indices = map(partial(torch.stack, dim=-1), (all_losses, all_indices)) + all_losses, all_indices = map( + partial(torch.stack, dim=-1), (all_losses, all_indices) + ) ret = (quantized_out, all_indices, all_losses) @@ -599,8 +645,12 @@ def __init__( codebook_input_dim = codebook_dim * heads requires_projection = codebook_input_dim != dim - self.project_in = nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() - self.project_out = nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + self.project_in = ( + nn.Linear(dim, codebook_input_dim) if requires_projection else nn.Identity() + ) + self.project_out = ( + nn.Linear(codebook_input_dim, dim) if requires_projection else nn.Identity() + ) self.eps = eps self.commitment_weight = commitment_weight @@ -614,10 +664,14 @@ def __init__( self.orthogonal_reg_active_codes_only = orthogonal_reg_active_codes_only self.orthogonal_reg_max_codes = orthogonal_reg_max_codes - assert not (ema_update and learnable_codebook), "learnable codebook not compatible with EMA update" + assert not ( + ema_update and learnable_codebook + ), "learnable codebook not compatible with EMA update" assert 0 <= sync_update_v <= 1.0 - assert not (sync_update_v > 0.0 and not learnable_codebook), "learnable codebook must be turned on" + assert not ( + sync_update_v > 0.0 and not learnable_codebook + ), "learnable codebook must be turned on" self.sync_update_v = sync_update_v @@ -629,7 +683,9 @@ def __init__( ) if sync_codebook is None: - sync_codebook = distributed.is_initialized() and distributed.get_world_size() > 1 + sync_codebook = ( + distributed.is_initialized() and distributed.get_world_size() > 1 + ) codebook_kwargs = { "dim": codebook_dim, @@ -794,11 +850,17 @@ def forward( # quantize again - quantize, embed_ind, distances = self._codebook(x, 
**codebook_forward_kwargs) + quantize, embed_ind, distances = self._codebook( + x, **codebook_forward_kwargs + ) if self.training: # determine code to use for commitment loss - maybe_detach = torch.detach if not self.learnable_codebook or freeze_codebook else identity + maybe_detach = ( + torch.detach + if not self.learnable_codebook or freeze_codebook + else identity + ) commit_quantize = maybe_detach(quantize) @@ -808,7 +870,9 @@ def forward( if self.sync_update_v > 0.0: # (21) in https://minyoungg.github.io/vqtorch/assets/draft_050523.pdf - quantize = quantize + self.sync_update_v * (quantize - quantize.detach()) + quantize = quantize + self.sync_update_v * ( + quantize - quantize.detach() + ) # function for calculating cross entropy loss to distance matrix # used for (1) naturalspeech2 training residual vq latents to be close to the correct codes and (2) cross-entropy based commitment loss @@ -841,7 +905,9 @@ def calculate_ce_loss(codes): embed_ind = rearrange(embed_ind, "1 (b h) n -> b n h", h=heads) if self.accept_image_fmap: - embed_ind = rearrange(embed_ind, "b (h w) ... -> b h w ...", h=height, w=width) + embed_ind = rearrange( + embed_ind, "b (h w) ... -> b h w ...", h=height, w=width + ) if only_one: embed_ind = rearrange(embed_ind, "b 1 -> b") @@ -895,8 +961,12 @@ def calculate_ce_loss(codes): num_codes = codebook.shape[-2] - if (self.orthogonal_reg_max_codes is not None) and num_codes > self.orthogonal_reg_max_codes: - rand_ids = torch.randperm(num_codes, device=device)[: self.orthogonal_reg_max_codes] + if ( + self.orthogonal_reg_max_codes is not None + ) and num_codes > self.orthogonal_reg_max_codes: + rand_ids = torch.randperm(num_codes, device=device)[ + : self.orthogonal_reg_max_codes + ] codebook = codebook[:, rand_ids] orthogonal_reg_loss = orthogonal_loss_fn(codebook) @@ -928,7 +998,9 @@ def calculate_ce_loss(codes): # if masking, only return quantized for where mask has True if mask is not None: - quantize = torch.where(rearrange(mask, "... -> ... 1"), quantize, orig_input) + quantize = torch.where( + rearrange(mask, "... -> ... 
1"), quantize, orig_input + ) return quantize, embed_ind, loss @@ -1038,7 +1110,9 @@ def sample_vectors(samples, num): def batched_sample_vectors(samples, num): - return torch.stack([sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0) + return torch.stack( + [sample_vectors(sample, num) for sample in samples.unbind(dim=0)], dim=0 + ) def pad_shape(shape, size, dim=0): @@ -1089,7 +1163,9 @@ def sample_vectors_distributed(local_samples, num): all_num_samples = all_gather_sizes(local_samples, dim=0) if rank == 0: - samples_per_rank = sample_multinomial(num, all_num_samples / all_num_samples.sum()) + samples_per_rank = sample_multinomial( + num, all_num_samples / all_num_samples.sum() + ) else: samples_per_rank = torch.empty_like(all_num_samples) @@ -1202,7 +1278,9 @@ def __init__( self.eps = eps self.threshold_ema_dead_code = threshold_ema_dead_code self.reset_cluster_size = ( - reset_cluster_size if (reset_cluster_size is not None) else threshold_ema_dead_code + reset_cluster_size + if (reset_cluster_size is not None) + else threshold_ema_dead_code ) assert callable(gumbel_sample) @@ -1213,8 +1291,14 @@ def __init__( use_ddp and num_codebooks > 1 and kmeans_init ), "kmeans init is not compatible with multiple codebooks in distributed environment for now" - self.sample_fn = sample_vectors_distributed if use_ddp and sync_kmeans else batched_sample_vectors - self.kmeans_all_reduce_fn = distributed.all_reduce if use_ddp and sync_kmeans else noop + self.sample_fn = ( + sample_vectors_distributed + if use_ddp and sync_kmeans + else batched_sample_vectors + ) + self.kmeans_all_reduce_fn = ( + distributed.all_reduce if use_ddp and sync_kmeans else noop + ) self.all_reduce_fn = distributed.all_reduce if use_ddp else noop self.register_buffer("initted", torch.Tensor([not kmeans_init])) @@ -1353,7 +1437,9 @@ def update_affine(self, data, embed, mask=None): distributed.all_reduce(variance_numer) batch_variance = variance_numer / num_vectors - self.update_with_decay("batch_variance", batch_variance, self.affine_param_batch_decay) + self.update_with_decay( + "batch_variance", batch_variance, self.affine_param_batch_decay + ) def replace(self, batch_samples, batch_mask): for ind, (samples, mask) in enumerate( @@ -1362,7 +1448,9 @@ def replace(self, batch_samples, batch_mask): if not torch.any(mask): continue - sampled = self.sample_fn(rearrange(samples, "... -> 1 ..."), mask.sum().item()) + sampled = self.sample_fn( + rearrange(samples, "... -> 1 ..."), mask.sum().item() + ) sampled = rearrange(sampled, "1 ... 
-> ...") self.embed.data[ind][mask] = sampled @@ -1386,7 +1474,9 @@ def expire_codes_(self, batch_samples): def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False): needs_codebook_dim = x.ndim < 4 sample_codebook_temp = ( - sample_codebook_temp if (sample_codebook_temp is not None) else self.sample_codebook_temp + sample_codebook_temp + if (sample_codebook_temp is not None) + else self.sample_codebook_temp ) x = x.float() @@ -1414,7 +1504,9 @@ def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False if self.affine_param: codebook_std = self.codebook_variance.clamp(min=1e-5).sqrt() batch_std = self.batch_variance.clamp(min=1e-5).sqrt() - embed = (embed - self.codebook_mean) * (batch_std / codebook_std) + self.batch_mean + embed = (embed - self.codebook_mean) * ( + batch_std / codebook_std + ) + self.batch_mean dist = -cdist(flatten, embed) @@ -1432,7 +1524,9 @@ def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False if self.training and self.ema_update and not freeze_codebook: if self.affine_param: - flatten = (flatten - self.batch_mean) * (codebook_std / batch_std) + self.codebook_mean + flatten = (flatten - self.batch_mean) * ( + codebook_std / batch_std + ) + self.codebook_mean if mask is not None: embed_onehot[~mask] = 0.0 @@ -1455,7 +1549,9 @@ def forward(self, x, sample_codebook_temp=None, mask=None, freeze_codebook=False self.expire_codes_(x) if needs_codebook_dim: - quantize, embed_ind = tuple(rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind)) + quantize, embed_ind = tuple( + rearrange(t, "1 ... -> ...") for t in (quantize, embed_ind) + ) dist = unpack_one(dist, ps, "h * d") diff --git a/lerobot/common/robot_devices/cameras/intelrealsense.py b/lerobot/common/robot_devices/cameras/intelrealsense.py index 84ac540f2..cda24169e 100644 --- a/lerobot/common/robot_devices/cameras/intelrealsense.py +++ b/lerobot/common/robot_devices/cameras/intelrealsense.py @@ -65,7 +65,9 @@ def save_image(img_array, serial_number, frame_index, images_dir): img.save(str(path), quality=100) logging.info(f"Saved image: {path}") except Exception as e: - logging.error(f"Failed to save image for camera {serial_number} frame {frame_index}: {e}") + logging.error( + f"Failed to save image for camera {serial_number} frame {frame_index}: {e}" + ) def save_images_from_cameras( @@ -94,7 +96,9 @@ def save_images_from_cameras( cameras = [] for cam_sn in serial_numbers: print(f"{cam_sn=}") - camera = IntelRealSenseCamera(cam_sn, fps=fps, width=width, height=height, mock=mock) + camera = IntelRealSenseCamera( + cam_sn, fps=fps, width=width, height=height, mock=mock + ) camera.connect() print( f"IntelRealSenseCamera({camera.serial_number}, fps={camera.fps}, width={camera.width}, height={camera.height}, color_mode={camera.color_mode})" @@ -140,7 +144,9 @@ def save_images_from_cameras( if time.perf_counter() - start_time > record_time_s: break - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") + print( + f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}" + ) frame_index += 1 finally: @@ -182,8 +188,12 @@ def __post_init__(self): self.channels = 3 - at_least_one_is_not_none = self.fps is not None or self.width is not None or self.height is not None - at_least_one_is_none = self.fps is None or self.width is None or self.height is None + at_least_one_is_not_none = ( + self.fps is not None or self.width is not None or self.height is not None + ) + at_least_one_is_none = ( 
+ self.fps is None or self.width is None or self.height is None + ) if at_least_one_is_not_none and at_least_one_is_none: raise ValueError( "For `fps`, `width` and `height`, either all of them need to be set, or none of them, " @@ -191,7 +201,9 @@ def __post_init__(self): ) if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") + raise ValueError( + f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})" + ) class IntelRealSenseCamera: @@ -286,7 +298,9 @@ def __init__( self.rotation = cv2.ROTATE_180 @classmethod - def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs): + def init_from_name( + cls, name: str, config: IntelRealSenseCameraConfig | None = None, **kwargs + ): camera_infos = find_cameras() camera_names = [cam["name"] for cam in camera_infos] this_name_count = Counter(camera_names)[name] @@ -296,7 +310,9 @@ def init_from_name(cls, name: str, config: IntelRealSenseCameraConfig | None = N f"Multiple {name} cameras have been detected. Please use their serial number to instantiate them." ) - name_to_serial_dict = {cam["name"]: cam["serial_number"] for cam in camera_infos} + name_to_serial_dict = { + cam["name"]: cam["serial_number"] for cam in camera_infos + } cam_sn = name_to_serial_dict[name] if config is None: @@ -323,13 +339,17 @@ def connect(self): if self.fps and self.width and self.height: # TODO(rcadene): can we set rgb8 directly? - config.enable_stream(rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps) + config.enable_stream( + rs.stream.color, self.width, self.height, rs.format.rgb8, self.fps + ) else: config.enable_stream(rs.stream.color) if self.use_depth: if self.fps and self.width and self.height: - config.enable_stream(rs.stream.depth, self.width, self.height, rs.format.z16, self.fps) + config.enable_stream( + rs.stream.depth, self.width, self.height, rs.format.z16, self.fps + ) else: config.enable_stream(rs.stream.depth) @@ -362,7 +382,9 @@ def connect(self): actual_height = color_profile.height() # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + if self.fps is not None and not math.isclose( + self.fps, actual_fps, rel_tol=1e-3 + ): # Using `OSError` since it's a broad that encompasses issues related to device communication raise OSError( f"Can't set {self.fps=} for IntelRealSenseCamera({self.serial_number}). Actual value is {actual_fps}." @@ -382,7 +404,9 @@ def connect(self): self.is_connected = True - def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + def read( + self, temporary_color: str | None = None + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """Read a frame from the camera returned in the format height x width x channels (e.g. 480 x 640 x 3) of type `np.uint8`, contrarily to the pytorch format which is float channel first. @@ -409,11 +433,15 @@ def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndar color_frame = frame.get_color_frame() if not color_frame: - raise OSError(f"Can't capture color image from IntelRealSenseCamera({self.serial_number}).") + raise OSError( + f"Can't capture color image from IntelRealSenseCamera({self.serial_number})." 
+ ) color_image = np.asanyarray(color_frame.get_data()) - requested_color_mode = self.color_mode if temporary_color is None else temporary_color + requested_color_mode = ( + self.color_mode if temporary_color is None else temporary_color + ) if requested_color_mode not in ["rgb", "bgr"]: raise ValueError( f"Expected color values are 'rgb' or 'bgr', but {requested_color_mode} is provided." @@ -441,7 +469,9 @@ def read(self, temporary_color: str | None = None) -> np.ndarray | tuple[np.ndar if self.use_depth: depth_frame = frame.get_depth_frame() if not depth_frame: - raise OSError(f"Can't capture depth image from IntelRealSenseCamera({self.serial_number}).") + raise OSError( + f"Can't capture depth image from IntelRealSenseCamera({self.serial_number})." + ) depth_map = np.asanyarray(depth_frame.get_data()) @@ -483,7 +513,9 @@ def async_read(self): # TODO(rcadene, aliberts): intelrealsense has diverged compared to opencv over here num_tries += 1 time.sleep(1 / self.fps) - if num_tries > self.fps and (self.thread.ident is None or not self.thread.is_alive()): + if num_tries > self.fps and ( + self.thread.ident is None or not self.thread.is_alive() + ): raise Exception( "The thread responsible for `self.async_read()` took too much time to start. There might be an issue. Verify that `self.thread.start()` has been called." ) diff --git a/lerobot/common/robot_devices/cameras/opencv.py b/lerobot/common/robot_devices/cameras/opencv.py index d284cf55a..3b46c8f55 100644 --- a/lerobot/common/robot_devices/cameras/opencv.py +++ b/lerobot/common/robot_devices/cameras/opencv.py @@ -31,10 +31,14 @@ MAX_OPENCV_INDEX = 60 -def find_cameras(raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False) -> list[dict]: +def find_cameras( + raise_when_empty=False, max_index_search_range=MAX_OPENCV_INDEX, mock=False +) -> list[dict]: cameras = [] if platform.system() == "Linux": - print("Linux detected. Finding available camera indices through scanning '/dev/video*' ports") + print( + "Linux detected. 
Finding available camera indices through scanning '/dev/video*' ports" + ) possible_ports = [str(port) for port in Path("/dev").glob("video*")] ports = _find_cameras(possible_ports, mock=mock) for port in ports: @@ -165,7 +169,9 @@ def save_images_from_cameras( dt_s = time.perf_counter() - now busy_wait(1 / fps - dt_s) - print(f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}") + print( + f"Frame: {frame_index:04d}\tLatency (ms): {(time.perf_counter() - now) * 1000:.2f}" + ) if time.perf_counter() - start_time > record_time_s: break @@ -205,7 +211,9 @@ def __post_init__(self): self.channels = 3 if self.rotation not in [-90, None, 90, 180]: - raise ValueError(f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})") + raise ValueError( + f"`rotation` must be in [-90, None, 90, 180] (got {self.rotation})" + ) class OpenCVCamera: @@ -247,7 +255,12 @@ class OpenCVCamera: ``` """ - def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = None, **kwargs): + def __init__( + self, + camera_index: int | str, + config: OpenCVCameraConfig | None = None, + **kwargs, + ): if config is None: config = OpenCVCameraConfig() @@ -261,12 +274,16 @@ def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = if platform.system() == "Linux": if isinstance(self.camera_index, int): self.port = Path(f"/dev/video{self.camera_index}") - elif isinstance(self.camera_index, str) and is_valid_unix_path(self.camera_index): + elif isinstance(self.camera_index, str) and is_valid_unix_path( + self.camera_index + ): self.port = Path(self.camera_index) # Retrieve the camera index from a potentially symlinked path self.camera_index = get_camera_index_from_unix_port(self.port) else: - raise ValueError(f"Please check the provided camera_index: {camera_index}") + raise ValueError( + f"Please check the provided camera_index: {camera_index}" + ) self.fps = config.fps self.width = config.width @@ -298,7 +315,9 @@ def __init__(self, camera_index: int | str, config: OpenCVCameraConfig | None = def connect(self): if self.is_connected: - raise RobotDeviceAlreadyConnectedError(f"OpenCVCamera({self.camera_index}) is already connected.") + raise RobotDeviceAlreadyConnectedError( + f"OpenCVCamera({self.camera_index}) is already connected." + ) if self.mock: import tests.mock_cv2 as cv2 @@ -309,7 +328,11 @@ def connect(self): # when other threads are used to save the images. cv2.setNumThreads(1) - camera_idx = f"/dev/video{self.camera_index}" if platform.system() == "Linux" else self.camera_index + camera_idx = ( + f"/dev/video{self.camera_index}" + if platform.system() == "Linux" + else self.camera_index + ) # First create a temporary camera trying to access `camera_index`, # and verify it is a valid camera by calling `isOpened`. tmp_camera = cv2.VideoCapture(camera_idx) @@ -349,16 +372,22 @@ def connect(self): actual_height = self.camera.get(cv2.CAP_PROP_FRAME_HEIGHT) # Using `math.isclose` since actual fps can be a float (e.g. 29.9 instead of 30) - if self.fps is not None and not math.isclose(self.fps, actual_fps, rel_tol=1e-3): + if self.fps is not None and not math.isclose( + self.fps, actual_fps, rel_tol=1e-3 + ): # Using `OSError` since it's a broad that encompasses issues related to device communication raise OSError( f"Can't set {self.fps=} for OpenCVCamera({self.camera_index}). Actual value is {actual_fps}." 
) - if self.width is not None and not math.isclose(self.width, actual_width, rel_tol=1e-3): + if self.width is not None and not math.isclose( + self.width, actual_width, rel_tol=1e-3 + ): raise OSError( f"Can't set {self.width=} for OpenCVCamera({self.camera_index}). Actual value is {actual_width}." ) - if self.height is not None and not math.isclose(self.height, actual_height, rel_tol=1e-3): + if self.height is not None and not math.isclose( + self.height, actual_height, rel_tol=1e-3 + ): raise OSError( f"Can't set {self.height=} for OpenCVCamera({self.camera_index}). Actual value is {actual_height}." ) @@ -388,7 +417,9 @@ def read(self, temporary_color_mode: str | None = None) -> np.ndarray: if not ret: raise OSError(f"Can't capture color image from camera {self.camera_index}.") - requested_color_mode = self.color_mode if temporary_color_mode is None else temporary_color_mode + requested_color_mode = ( + self.color_mode if temporary_color_mode is None else temporary_color_mode + ) if requested_color_mode not in ["rgb", "bgr"]: raise ValueError( diff --git a/lerobot/common/robot_devices/control_utils.py b/lerobot/common/robot_devices/control_utils.py index 10cb9f5c0..ae25f7ae0 100644 --- a/lerobot/common/robot_devices/control_utils.py +++ b/lerobot/common/robot_devices/control_utils.py @@ -23,11 +23,17 @@ from lerobot.common.policies.factory import make_policy from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.utils import busy_wait -from lerobot.common.utils.utils import get_safe_torch_device, init_hydra_config, set_global_seed +from lerobot.common.utils.utils import ( + get_safe_torch_device, + init_hydra_config, + set_global_seed, +) from lerobot.scripts.eval import get_pretrained_policy_path -def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None): +def log_control_info( + robot: Robot, dt_s, episode_index=None, frame_index=None, fps=None +): log_items = [] if episode_index is not None: log_items.append(f"ep:{episode_index}") @@ -36,7 +42,7 @@ def log_control_info(robot: Robot, dt_s, episode_index=None, frame_index=None, f def log_dt(shortname, dt_val_s): nonlocal log_items, fps - info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1/ dt_val_s:3.1f}hz)" + info_str = f"{shortname}:{dt_val_s * 1000:5.2f} ({1 / dt_val_s:3.1f}hz)" if fps is not None: actual_fps = 1 / dt_val_s if actual_fps < fps - 1: @@ -98,7 +104,9 @@ def predict_action(observation, policy, device, use_amp): observation = copy(observation) with ( torch.inference_mode(), - torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(), + torch.autocast(device_type=device.type) + if device.type == "cuda" and use_amp + else nullcontext(), ): # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension for name in observation: @@ -154,7 +162,9 @@ def on_press(key): print("Right arrow key pressed. Exiting loop...") events["exit_early"] = True elif key == keyboard.Key.left: - print("Left arrow key pressed. Exiting loop and rerecord the last episode...") + print( + "Left arrow key pressed. Exiting loop and rerecord the last episode..." 
+            )
             events["rerecord_episode"] = True
             events["exit_early"] = True
         elif key == keyboard.Key.esc:
@@ -180,8 +190,12 @@ def on_press(key):

 def init_policy(pretrained_policy_name_or_path, policy_overrides):
     """Instantiate the policy and load fps, device and use_amp from config yaml"""
     pretrained_policy_path = get_pretrained_policy_path(pretrained_policy_name_or_path)
-    hydra_cfg = init_hydra_config(pretrained_policy_path / "config.yaml", policy_overrides)
-    policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path)
+    hydra_cfg = init_hydra_config(
+        pretrained_policy_path / "config.yaml", policy_overrides
+    )
+    policy = make_policy(
+        hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=pretrained_policy_path
+    )

     # Check device is available
     device = get_safe_torch_device(hydra_cfg.device, log=True)
@@ -225,6 +239,7 @@ def record_episode(
     device,
     use_amp,
     fps,
+    record_delta_actions,
 ):
     control_loop(
         robot=robot,
@@ -236,6 +251,7 @@ def record_episode(
         device=device,
         use_amp=use_amp,
         fps=fps,
+        record_delta_actions=record_delta_actions,
         teleoperate=policy is None,
     )

@@ -252,6 +268,7 @@ def control_loop(
     device=None,
     use_amp=None,
     fps=None,
+    record_delta_actions=False,
 ):
     # TODO(rcadene): Add option to record logs
     if not robot.is_connected:
@@ -267,15 +284,21 @@ def control_loop(
         raise ValueError("When `teleoperate` is True, `policy` should be None.")

     if dataset is not None and fps is not None and dataset.fps != fps:
-        raise ValueError(f"The dataset fps should be equal to requested fps ({dataset['fps']} != {fps}).")
+        raise ValueError(
+            f"The dataset fps should be equal to requested fps ({dataset.fps} != {fps})."
+        )

     timestamp = 0
     start_episode_t = time.perf_counter()
     while timestamp < control_time_s:
         start_loop_t = time.perf_counter()

+        current_joint_positions = robot.follower_arms["main"].read("Present_Position")
+
         if teleoperate:
             observation, action = robot.teleop_step(record_data=True)
+            if record_delta_actions:
+                action["action"] = action["action"] - current_joint_positions
         else:
             observation = robot.capture_observation()
@@ -290,12 +313,20 @@ def control_loop(
             frame = {**observation, **action}
             if "next.reward" in events:
                 frame["next.reward"] = events["next.reward"]
+                frame["next.done"] = (events["next.reward"] == 1) or (
+                    events["exit_early"]
+                )
             dataset.add_frame(frame)

+            # if frame["next.done"]:
+            #     break
+
         if display_cameras and not is_headless():
             image_keys = [key for key in observation if "image" in key]
             for key in image_keys:
-                cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+                cv2.imshow(
+                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
+                )
             cv2.waitKey(1)

         if fps is not None:
@@ -335,7 +366,9 @@ def reset_environment(robot, events, reset_time_s):

 def reset_follower_position(robot: Robot, target_position):
     current_position = robot.follower_arms["main"].read("Present_Position")
-    trajectory = torch.from_numpy(np.linspace(current_position, target_position, 30))  # NOTE: 30 is just an aribtrary number
+    trajectory = torch.from_numpy(
+        np.linspace(current_position, target_position, 30)
+    )  # NOTE: 30 is just an arbitrary number
     for pose in trajectory:
         robot.send_action(pose)
         busy_wait(0.015)
@@ -371,7 +404,11 @@ def sanity_check_dataset_name(repo_id, policy):

 def sanity_check_dataset_robot_compatibility(
-    dataset: LeRobotDataset, robot: Robot, fps: int, use_videos: bool, extra_features: dict = None
+    dataset: LeRobotDataset,
+    robot: Robot,
+    fps: int,
+    use_videos: bool,
+    extra_features: dict = None,
+)
-> None: features_from_robot = get_features_from_robot(robot, use_videos) if extra_features is not None: @@ -385,11 +422,14 @@ def sanity_check_dataset_robot_compatibility( mismatches = [] for field, dataset_value, present_value in fields: - diff = DeepDiff(dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"]) + diff = DeepDiff( + dataset_value, present_value, exclude_regex_paths=[r".*\['info'\]$"] + ) if diff: mismatches.append(f"{field}: expected {present_value}, got {dataset_value}") if mismatches: raise ValueError( - "Dataset metadata compatibility check failed with mismatches:\n" + "\n".join(mismatches) + "Dataset metadata compatibility check failed with mismatches:\n" + + "\n".join(mismatches) ) diff --git a/lerobot/common/robot_devices/motors/dynamixel.py b/lerobot/common/robot_devices/motors/dynamixel.py index 1e1396f76..8a5ad19bd 100644 --- a/lerobot/common/robot_devices/motors/dynamixel.py +++ b/lerobot/common/robot_devices/motors/dynamixel.py @@ -8,7 +8,10 @@ import numpy as np import tqdm -from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError +from lerobot.common.robot_devices.utils import ( + RobotDeviceAlreadyConnectedError, + RobotDeviceNotConnectedError, +) from lerobot.common.utils.utils import capture_timestamp_utc PROTOCOL_VERSION = 2.0 @@ -143,7 +146,9 @@ NUM_WRITE_RETRY = 10 -def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray: +def convert_degrees_to_steps( + degrees: float | np.ndarray, models: str | list[str] +) -> np.ndarray: """This function converts the degree range to the step range for indicating motors rotation. It assumes a motor achieves a full rotation by going from -180 degree position to +180. The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation. @@ -378,7 +383,9 @@ def find_motor_indices(self, possible_ids=None, num_retry=2): indices = [] for idx in tqdm.tqdm(possible_ids): try: - present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0] + present_idx = self.read_with_motor_ids( + self.motor_models, [idx], "ID", num_retry=num_retry + )[0] except ConnectionError: continue @@ -394,7 +401,9 @@ def find_motor_indices(self, possible_ids=None, num_retry=2): def set_bus_baudrate(self, baudrate): present_bus_baudrate = self.port_handler.getBaudRate() if present_bus_baudrate != baudrate: - print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.") + print( + f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}." + ) self.port_handler.setBaudRate(baudrate) if self.port_handler.getBaudRate() != baudrate: @@ -415,7 +424,9 @@ def motor_indices(self) -> list[int]: def set_calibration(self, calibration: dict[str, list]): self.calibration = calibration - def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None): + def apply_calibration_autocorrect( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct. For more info, see docstring of `apply_calibration` and `autocorrect_calibration`. 
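The degree-to-step conversion performed by `convert_degrees_to_steps` (here and in the Feetech bus below) follows directly from its docstring: half a turn (180 degrees) maps to half the motor resolution. A minimal sketch, with an assumed resolution table standing in for the per-model constants:

```python
import numpy as np

# Sketch of the conversion described in the docstring above. The resolution
# table is illustrative; the real values live in the motor model constants.
MODEL_RESOLUTION = {"xl430-w250": 4096, "xl330-m288": 4096}

def convert_degrees_to_steps(degrees, models):
    if isinstance(models, str):
        models = [models]
    resolutions = np.array([MODEL_RESOLUTION[m] for m in models])
    # 180 degrees (half a turn) corresponds to half the resolution (e.g. 2048 steps).
    steps = np.asarray(degrees) / 180 * (resolutions // 2)
    return steps.astype(np.int64)

print(convert_degrees_to_steps(90, "xl330-m288"))  # [1024]
```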
@@ -428,7 +439,9 @@ def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: values = self.apply_calibration(values, motor_names) return values - def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def apply_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with a "zero position" at 0 degree. @@ -503,7 +516,9 @@ def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | return values - def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def autocorrect_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """This function automatically detects issues with values of motors after calibration, and correct for these issues. Some motors might have values outside of expected maximum bounds after calibration. @@ -545,15 +560,23 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s values[i] *= -1 # Convert from initial range to range [-180, 180] degrees - calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) + calib_val = ( + (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + ) + in_range = (calib_val > LOWER_BOUND_DEGREE) and ( + calib_val < UPPER_BOUND_DEGREE + ) # Solve this inequality to find the factor to shift the range into [-180, 180] degrees # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE # (- (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= ((resolution // 2) - values[i] - homing_offset) / resolution - low_factor = (-(resolution // 2) - values[i] - homing_offset) / resolution - upp_factor = ((resolution // 2) - values[i] - homing_offset) / resolution + low_factor = ( + -(resolution // 2) - values[i] - homing_offset + ) / resolution + upp_factor = ( + (resolution // 2) - values[i] - homing_offset + ) / resolution elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: start_pos = self.calibration["start_pos"][calib_idx] @@ -561,7 +584,9 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s # Convert from initial range to range [0, 100] in % calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) + in_range = (calib_val > LOWER_BOUND_LINEAR) and ( + calib_val < UPPER_BOUND_LINEAR + ) # Solve this inequality to find the factor to shift the range into [0, 100] % # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 @@ -577,19 +602,27 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s factor = math.ceil(low_factor) if factor > upp_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No integer found between bounds [{low_factor=}, {upp_factor=}]" + ) else: factor = math.ceil(upp_factor) if factor > low_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No 
integer found between bounds [{low_factor=}, {upp_factor=}]" + ) if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" - in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + out_of_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) + in_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) logging.warning( f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " @@ -599,7 +632,9 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. self.calibration["homing_offset"][calib_idx] += resolution * factor - def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def revert_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """Inverse of `apply_calibration`.""" if motor_names is None: motor_names = self.motor_names @@ -638,7 +673,9 @@ def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | values = np.round(values).astype(np.int32) return values - def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): + def read_with_motor_ids( + self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY + ): if self.mock: import tests.mock_dynamixel_sdk as dxl else: @@ -740,7 +777,9 @@ def read(self, data_name, motor_names: str | list[str] | None = None): values = self.apply_calibration_autocorrect(values, motor_names) # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) + delta_ts_name = get_log_name( + "delta_timestamp_s", "read", data_name, motor_names + ) self.logs[delta_ts_name] = time.perf_counter() - start_time # log the utc time at which the data was received @@ -749,7 +788,9 @@ def read(self, data_name, motor_names: str | list[str] | None = None): return values - def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): + def write_with_motor_ids( + self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY + ): if self.mock: import tests.mock_dynamixel_sdk as dxl else: @@ -778,7 +819,12 @@ def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_r f"{self.packet_handler.getTxRxResult(comm)}" ) - def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): + def write( + self, + data_name, + values: int | float | np.ndarray, + motor_names: str | list[str] | None = None, + ): if not self.is_connected: raise RobotDeviceNotConnectedError( f"DynamixelMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." 
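The inequality solved in `autocorrect_calibration` above boils down to finding an integer number of full turns that shifts a reading back into the nominal degree range. A simplified sketch of that arithmetic for the DEGREE mode (the function name and defaults are illustrative, and the branch structure of the real method is collapsed into one case):

```python
import math

def find_turn_correction(raw_value, homing_offset, resolution=4096):
    half_turn_degree = 180
    calib_val = (raw_value + homing_offset) / (resolution // 2) * half_turn_degree
    if -half_turn_degree < calib_val < half_turn_degree:
        return 0  # already in range, no correction needed
    # Solve -180 <= (raw + offset + resolution * factor) / (res // 2) * 180 <= 180
    low_factor = (-(resolution // 2) - raw_value - homing_offset) / resolution
    upp_factor = ((resolution // 2) - raw_value - homing_offset) / resolution
    factor = math.ceil(low_factor)
    if factor > upp_factor:
        raise ValueError("no integer number of turns shifts the value into range")
    return factor

# A reading roughly one full turn too high gets corrected by -1 turn:
print(find_turn_correction(raw_value=6000, homing_offset=0))  # -1
```

Applying the returned factor as `homing_offset += resolution * factor` is exactly the "shift value by N full turns" step logged by the warning in the real code.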
@@ -839,7 +885,9 @@ def write(self, data_name, values: int | float | np.ndarray, motor_names: str |
             )

         # log the number of seconds it took to write the data to the motors
-        delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names)
+        delta_ts_name = get_log_name(
+            "delta_timestamp_s", "write", data_name, motor_names
+        )
         self.logs[delta_ts_name] = time.perf_counter() - start_time

         # TODO(rcadene): should we log the time before sending the write command?
diff --git a/lerobot/common/robot_devices/motors/feetech.py b/lerobot/common/robot_devices/motors/feetech.py
index 0d5480f7a..51a770f64 100644
--- a/lerobot/common/robot_devices/motors/feetech.py
+++ b/lerobot/common/robot_devices/motors/feetech.py
@@ -8,7 +8,10 @@
 import numpy as np
 import tqdm

-from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceNotConnectedError,
+)
 from lerobot.common.utils.utils import capture_timestamp_utc

 PROTOCOL_VERSION = 0
@@ -122,7 +125,9 @@

 NUM_WRITE_RETRY = 20


-def convert_degrees_to_steps(degrees: float | np.ndarray, models: str | list[str]) -> np.ndarray:
+def convert_degrees_to_steps(
+    degrees: float | np.ndarray, models: str | list[str]
+) -> np.ndarray:
     """This function converts the degree range to the step range for indicating motors rotation.
     It assumes a motor achieves a full rotation by going from -180 degree position to +180.
     The motor resolution (e.g. 4096) corresponds to the number of steps needed to achieve a full rotation.
@@ -358,7 +363,9 @@ def find_motor_indices(self, possible_ids=None, num_retry=2):
         indices = []
         for idx in tqdm.tqdm(possible_ids):
             try:
-                present_idx = self.read_with_motor_ids(self.motor_models, [idx], "ID", num_retry=num_retry)[0]
+                present_idx = self.read_with_motor_ids(
+                    self.motor_models, [idx], "ID", num_retry=num_retry
+                )[0]
             except ConnectionError:
                 continue
@@ -374,7 +381,9 @@ def find_motor_indices(self, possible_ids=None, num_retry=2):
     def set_bus_baudrate(self, baudrate):
         present_bus_baudrate = self.port_handler.getBaudRate()
         if present_bus_baudrate != baudrate:
-            print(f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}.")
+            print(
+                f"Setting bus baud rate to {baudrate}. Previously {present_bus_baudrate}."
+            )
             self.port_handler.setBaudRate(baudrate)

             if self.port_handler.getBaudRate() != baudrate:
@@ -395,7 +404,9 @@ def motor_indices(self) -> list[int]:
     def set_calibration(self, calibration: dict[str, list]):
         self.calibration = calibration

-    def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names: list[str] | None):
+    def apply_calibration_autocorrect(
+        self, values: np.ndarray | list, motor_names: list[str] | None
+    ):
         """This function applies the calibration, automatically detects out of range errors for motors values and attempts to correct.

         For more info, see docstring of `apply_calibration` and `autocorrect_calibration`.
@@ -408,7 +419,9 @@ def apply_calibration_autocorrect(self, values: np.ndarray | list, motor_names:
         values = self.apply_calibration(values, motor_names)
         return values

-    def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | None):
+    def apply_calibration(
+        self, values: np.ndarray | list, motor_names: list[str] | None
+    ):
         """Convert from unsigned int32 joint position range [0, 2**32[ to the universal float32 nominal degree range ]-180.0, 180.0[ with
         a "zero position" at 0 degree.
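For a single DEGREE-mode motor, the mapping this docstring describes reduces to one line once the drive mode and homing offset are known. A hedged sketch (the per-motor calibration tables and LINEAR mode are elided):

```python
# Minimal sketch of the steps -> degrees mapping described above, for one motor.
def apply_calibration_single(raw_steps, homing_offset, drive_mode, resolution=4096):
    value = -raw_steps if drive_mode else raw_steps  # invert direction if drive_mode == 1
    # Half the resolution (2048 steps) maps to half a turn (180 degrees).
    return (value + homing_offset) / (resolution // 2) * 180.0

print(apply_calibration_single(2048, homing_offset=-2048, drive_mode=0))  # 0.0
print(apply_calibration_single(3072, homing_offset=-2048, drive_mode=0))  # 90.0
```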
@@ -482,7 +495,9 @@ def apply_calibration(self, values: np.ndarray | list, motor_names: list[str] | return values - def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def autocorrect_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """This function automatically detects issues with values of motors after calibration, and correct for these issues. Some motors might have values outside of expected maximum bounds after calibration. @@ -521,18 +536,26 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s values[i] *= -1 # Convert from initial range to range [-180, 180] degrees - calib_val = (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE - in_range = (calib_val > LOWER_BOUND_DEGREE) and (calib_val < UPPER_BOUND_DEGREE) + calib_val = ( + (values[i] + homing_offset) / (resolution // 2) * HALF_TURN_DEGREE + ) + in_range = (calib_val > LOWER_BOUND_DEGREE) and ( + calib_val < UPPER_BOUND_DEGREE + ) # Solve this inequality to find the factor to shift the range into [-180, 180] degrees # values[i] = (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE # - HALF_TURN_DEGREE <= (values[i] + homing_offset + resolution * factor) / (resolution // 2) * HALF_TURN_DEGREE <= HALF_TURN_DEGREE # (- HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset) / resolution <= factor <= (HALF_TURN_DEGREE / 180 * (resolution // 2) - values[i] - homing_offset) / resolution low_factor = ( - -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset + -HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) + - values[i] + - homing_offset ) / resolution upp_factor = ( - HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) - values[i] - homing_offset + HALF_TURN_DEGREE / HALF_TURN_DEGREE * (resolution // 2) + - values[i] + - homing_offset ) / resolution elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: @@ -541,7 +564,9 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s # Convert from initial range to range [0, 100] in % calib_val = (values[i] - start_pos) / (end_pos - start_pos) * 100 - in_range = (calib_val > LOWER_BOUND_LINEAR) and (calib_val < UPPER_BOUND_LINEAR) + in_range = (calib_val > LOWER_BOUND_LINEAR) and ( + calib_val < UPPER_BOUND_LINEAR + ) # Solve this inequality to find the factor to shift the range into [0, 100] % # values[i] = (values[i] - start_pos + resolution * factor) / (end_pos + resolution * factor - start_pos - resolution * factor) * 100 @@ -557,19 +582,27 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s factor = math.ceil(low_factor) if factor > upp_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No integer found between bounds [{low_factor=}, {upp_factor=}]" + ) else: factor = math.ceil(upp_factor) if factor > low_factor: - raise ValueError(f"No integer found between bounds [{low_factor=}, {upp_factor=}]") + raise ValueError( + f"No integer found between bounds [{low_factor=}, {upp_factor=}]" + ) if CalibrationMode[calib_mode] == CalibrationMode.DEGREE: out_of_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" in_range_str = f"{LOWER_BOUND_DEGREE} < {calib_val} < {UPPER_BOUND_DEGREE} degrees" elif CalibrationMode[calib_mode] == CalibrationMode.LINEAR: - out_of_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} 
< {UPPER_BOUND_LINEAR} %" - in_range_str = f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + out_of_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) + in_range_str = ( + f"{LOWER_BOUND_LINEAR} < {calib_val} < {UPPER_BOUND_LINEAR} %" + ) logging.warning( f"Auto-correct calibration of motor '{name}' by shifting value by {abs(factor)} full turns, " @@ -579,7 +612,9 @@ def autocorrect_calibration(self, values: np.ndarray | list, motor_names: list[s # A full turn corresponds to 360 degrees but also to 4096 steps for a motor resolution of 4096. self.calibration["homing_offset"][calib_idx] += resolution * factor - def revert_calibration(self, values: np.ndarray | list, motor_names: list[str] | None): + def revert_calibration( + self, values: np.ndarray | list, motor_names: list[str] | None + ): """Inverse of `apply_calibration`.""" if motor_names is None: motor_names = self.motor_names @@ -655,7 +690,9 @@ def avoid_rotation_reset(self, values, motor_names, data_name): return values - def read_with_motor_ids(self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY): + def read_with_motor_ids( + self, motor_models, motor_ids, data_name, num_retry=NUM_READ_RETRY + ): if self.mock: import tests.mock_scservo_sdk as scs else: @@ -760,7 +797,9 @@ def read(self, data_name, motor_names: str | list[str] | None = None): values = self.apply_calibration_autocorrect(values, motor_names) # log the number of seconds it took to read the data from the motors - delta_ts_name = get_log_name("delta_timestamp_s", "read", data_name, motor_names) + delta_ts_name = get_log_name( + "delta_timestamp_s", "read", data_name, motor_names + ) self.logs[delta_ts_name] = time.perf_counter() - start_time # log the utc time at which the data was received @@ -769,7 +808,9 @@ def read(self, data_name, motor_names: str | list[str] | None = None): return values - def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY): + def write_with_motor_ids( + self, motor_models, motor_ids, data_name, values, num_retry=NUM_WRITE_RETRY + ): if self.mock: import tests.mock_scservo_sdk as scs else: @@ -798,7 +839,12 @@ def write_with_motor_ids(self, motor_models, motor_ids, data_name, values, num_r f"{self.packet_handler.getTxRxResult(comm)}" ) - def write(self, data_name, values: int | float | np.ndarray, motor_names: str | list[str] | None = None): + def write( + self, + data_name, + values: int | float | np.ndarray, + motor_names: str | list[str] | None = None, + ): if not self.is_connected: raise RobotDeviceNotConnectedError( f"FeetechMotorsBus({self.port}) is not connected. You need to run `motors_bus.connect()`." @@ -859,7 +905,9 @@ def write(self, data_name, values: int | float | np.ndarray, motor_names: str | ) # log the number of seconds it took to write the data to the motors - delta_ts_name = get_log_name("delta_timestamp_s", "write", data_name, motor_names) + delta_ts_name = get_log_name( + "delta_timestamp_s", "write", data_name, motor_names + ) self.logs[delta_ts_name] = time.perf_counter() - start_time # TODO(rcadene): should we log the time before sending the write command? 
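Both motor buses bracket every serial transaction with a `delta_timestamp_s_*` log entry, which is what `log_control_info` reads back when reporting loop latency. A small sketch of the pattern (the dictionary and key naming here are illustrative stand-ins for `self.logs` and `get_log_name`):

```python
import time

logs = {}

def timed_read(data_name, read_fn):
    # Mirror the bookkeeping above: stamp how long the bus transaction took.
    start = time.perf_counter()
    values = read_fn()
    logs[f"delta_timestamp_s_read_{data_name}"] = time.perf_counter() - start
    return values

timed_read("Present_Position", lambda: [2048, 2048])  # stand-in for a bus read
print(logs)
```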
diff --git a/lerobot/common/robot_devices/robots/dynamixel_calibration.py b/lerobot/common/robot_devices/robots/dynamixel_calibration.py index 5c4932d2e..b6aa976ae 100644 --- a/lerobot/common/robot_devices/robots/dynamixel_calibration.py +++ b/lerobot/common/robot_devices/robots/dynamixel_calibration.py @@ -10,9 +10,7 @@ ) from lerobot.common.robot_devices.motors.utils import MotorsBus -URL_TEMPLATE = ( - "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" -) +URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" # The following positions are provided in nominal degree range ]-180, +180[ # For more info on these constants, see comments in the code where they get used. @@ -23,7 +21,9 @@ def assert_drive_mode(drive_mode): # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") + raise ValueError( + f"`drive_mode` contains values other than 0 or 1: ({drive_mode})" + ) def apply_drive_mode(position, drive_mode): @@ -64,12 +64,16 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type ``` """ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to zero position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) + print( + "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero") + ) input("Press Enter to continue...") # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. @@ -90,10 +94,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view # of the previous motor in the kinetic chain. print("\nMove arm to rotated target position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated") + ) input("Press Enter to continue...") - rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) + rotated_target_pos = convert_degrees_to_steps( + ROTATED_POSITION_DEGREE, arm.motor_models + ) # Find drive mode by rotating each motor by a quarter of a turn. # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). 
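The "move to zero position" step above is what yields the homing offset: whatever shift maps the arm's current reading onto the zero target. A sketch under the assumption that `compute_nearest_rounded_position` snaps readings to the nearest quarter turn (its exact behaviour is not shown in this diff):

```python
import numpy as np

def compute_homing_offset(zero_target_steps, present_steps, resolution=4096):
    # Assumed behaviour of compute_nearest_rounded_position: snap to quarter turns.
    quarter = resolution // 4
    nearest = np.round(present_steps / quarter) * quarter
    return zero_target_steps - nearest.astype(np.int64)

# A motor read at 2060 steps while physically held at the 2048-step zero pose:
print(compute_homing_offset(np.array([2048]), np.array([2060])))  # [0]
```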
@@ -102,11 +111,15 @@ def run_arm_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type # Re-compute homing offset to take into account drive mode rotated_drived_pos = apply_drive_mode(rotated_pos, drive_mode) - rotated_nearest_pos = compute_nearest_rounded_position(rotated_drived_pos, arm.motor_models) + rotated_nearest_pos = compute_nearest_rounded_position( + rotated_drived_pos, arm.motor_models + ) homing_offset = rotated_target_pos - rotated_nearest_pos print("\nMove arm to rest position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")) + print( + "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest") + ) input("Press Enter to continue...") print() diff --git a/lerobot/common/robot_devices/robots/feetech_calibration.py b/lerobot/common/robot_devices/robots/feetech_calibration.py index b015951a0..f702e6d82 100644 --- a/lerobot/common/robot_devices/robots/feetech_calibration.py +++ b/lerobot/common/robot_devices/robots/feetech_calibration.py @@ -12,9 +12,7 @@ ) from lerobot.common.robot_devices.motors.utils import MotorsBus -URL_TEMPLATE = ( - "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" -) +URL_TEMPLATE = "https://raw.githubusercontent.com/huggingface/lerobot/main/media/{robot}/{arm}_{position}.webp" # The following positions are provided in nominal degree range ]-180, +180[ # For more info on these constants, see comments in the code where they get used. @@ -25,7 +23,9 @@ def assert_drive_mode(drive_mode): # `drive_mode` is in [0,1] with 0 means original rotation direction for the motor, and 1 means inverted. if not np.all(np.isin(drive_mode, [0, 1])): - raise ValueError(f"`drive_mode` contains values other than 0 or 1: ({drive_mode})") + raise ValueError( + f"`drive_mode` contains values other than 0 or 1: ({drive_mode})" + ) def apply_drive_mode(position, drive_mode): @@ -126,7 +126,9 @@ def apply_offset(calib, offset): return calib -def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_auto_calibration( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): if robot_type == "so100": return run_arm_auto_calibration_so100(arm, robot_type, arm_name, arm_type) elif robot_type == "moss": @@ -135,18 +137,27 @@ def run_arm_auto_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm raise ValueError(robot_type) -def run_arm_auto_calibration_so100(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_auto_calibration_so100( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) if not (robot_type == "so100" and arm_type == "follower"): - raise NotImplementedError("Auto calibration only supports the follower of so100 arms for now.") + raise NotImplementedError( + "Auto calibration only supports the follower of so100 arms for now." 
+ ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to initial position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial") + ) input("Press Enter to continue...") # Lower the acceleration of the motors (in [0,254]) @@ -193,11 +204,16 @@ def in_between_move_hook(): print("Calibrate elbow_flex") calib["elbow_flex"] = move_to_calibrate( - arm, "elbow_flex", positive_first=False, in_between_move_hook=in_between_move_hook + arm, + "elbow_flex", + positive_first=False, + in_between_move_hook=in_between_move_hook, ) calib["elbow_flex"] = apply_offset(calib["elbow_flex"], offset=80 - 1024) - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex") + arm.write( + "Goal_Position", calib["elbow_flex"]["zero_pos"] + 1024 + 512, "elbow_flex" + ) time.sleep(1) def in_between_move_hook(): @@ -225,18 +241,30 @@ def while_move_hook(): } arm.write("Goal_Position", list(positions.values()), list(positions.keys())) - arm.write("Goal_Position", round(calib["shoulder_lift"]["zero_pos"] - 1600), "shoulder_lift") + arm.write( + "Goal_Position", + round(calib["shoulder_lift"]["zero_pos"] - 1600), + "shoulder_lift", + ) time.sleep(2) - arm.write("Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex") + arm.write( + "Goal_Position", round(calib["elbow_flex"]["zero_pos"] + 1700), "elbow_flex" + ) time.sleep(2) - arm.write("Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex") + arm.write( + "Goal_Position", round(calib["wrist_flex"]["zero_pos"] + 800), "wrist_flex" + ) time.sleep(2) arm.write("Goal_Position", round(calib["gripper"]["end_pos"]), "gripper") time.sleep(2) print("Calibrate wrist_roll") calib["wrist_roll"] = move_to_calibrate( - arm, "wrist_roll", invert_drive_mode=True, positive_first=False, while_move_hook=while_move_hook + arm, + "wrist_roll", + invert_drive_mode=True, + positive_first=False, + while_move_hook=while_move_hook, ) arm.write("Goal_Position", calib["wrist_roll"]["zero_pos"], "wrist_roll") @@ -246,7 +274,9 @@ def while_move_hook(): arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"], "wrist_flex") time.sleep(1) arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] + 2048, "elbow_flex") - arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift") + arm.write( + "Goal_Position", calib["shoulder_lift"]["zero_pos"] - 2048, "shoulder_lift" + ) time.sleep(1) arm.write("Goal_Position", calib["shoulder_pan"]["zero_pos"], "shoulder_pan") time.sleep(1) @@ -275,18 +305,27 @@ def while_move_hook(): return calib_dict -def run_arm_auto_calibration_moss(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_auto_calibration_moss( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): """All the offsets and magic numbers are hand tuned, and are unique to SO-100 follower arms""" if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) if not (robot_type == "moss" and arm_type == "follower"): - raise NotImplementedError("Auto calibration only supports the follower of moss arms for now.") + raise NotImplementedError( + "Auto calibration only supports the follower of moss arms for now." 
+ ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to initial position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="initial") + ) input("Press Enter to continue...") # Lower the acceleration of the motors (in [0,254]) @@ -370,8 +409,12 @@ def in_between_move_shoulder_lift_hook(): arm.write("Goal_Position", calib["wrist_flex"]["zero_pos"] - 1024, "wrist_flex") time.sleep(1) - arm.write("Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift") - arm.write("Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex") + arm.write( + "Goal_Position", calib["shoulder_lift"]["zero_pos"] + 2048, "shoulder_lift" + ) + arm.write( + "Goal_Position", calib["elbow_flex"]["zero_pos"] - 1024 - 400, "elbow_flex" + ) time.sleep(2) calib_modes = [] @@ -398,7 +441,9 @@ def in_between_move_shoulder_lift_hook(): return calib_dict -def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str): +def run_arm_manual_calibration( + arm: MotorsBus, robot_type: str, arm_name: str, arm_type: str +): """This function ensures that a neural network trained on data collected on a given robot can work on another robot. For instance before calibration, setting a same goal position for each motor of two different robots will get two very different positions. But after calibration, @@ -421,12 +466,16 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a ``` """ if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run calibration, the torque must be disabled on all motors.") + raise ValueError( + "To run calibration, the torque must be disabled on all motors." + ) print(f"\nRunning calibration of {robot_type} {arm_name} {arm_type}...") print("\nMove arm to zero position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero")) + print( + "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="zero") + ) input("Press Enter to continue...") # We arbitrarily chose our zero target position to be a straight horizontal position with gripper upwards and closed. @@ -446,10 +495,15 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a # corresponds to opening the gripper. When the rotation direction is ambiguous, we arbitrarely rotate clockwise from the point of view # of the previous motor in the kinetic chain. print("\nMove arm to rotated target position") - print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated")) + print( + "See: " + + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rotated") + ) input("Press Enter to continue...") - rotated_target_pos = convert_degrees_to_steps(ROTATED_POSITION_DEGREE, arm.motor_models) + rotated_target_pos = convert_degrees_to_steps( + ROTATED_POSITION_DEGREE, arm.motor_models + ) # Find drive mode by rotating each motor by a quarter of a turn. # Drive mode indicates if the motor rotation direction should be inverted (=1) or not (=0). 
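The quarter-turn move referenced in the comment above is what exposes each motor's drive mode: if a reading moved opposite to the commanded rotation, that motor's direction is inverted. A minimal sketch (array names are illustrative):

```python
import numpy as np

def detect_drive_mode(zero_pos, rotated_pos):
    # drive_mode == 1 when a motor counted down during the "positive" rotation.
    return (rotated_pos < zero_pos).astype(np.int32)

zero_pos = np.array([2048, 2048])
rotated_pos = np.array([3072, 1024])  # second motor moved the opposite way
print(detect_drive_mode(zero_pos, rotated_pos))  # [0 1]
```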
@@ -461,7 +515,9 @@ def run_arm_manual_calibration(arm: MotorsBus, robot_type: str, arm_name: str, a
     homing_offset = rotated_target_pos - rotated_drived_pos

     print("\nMove arm to rest position")
-    print("See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest"))
+    print(
+        "See: " + URL_TEMPLATE.format(robot=robot_type, arm=arm_type, position="rest")
+    )
     input("Press Enter to continue...")

     print()
diff --git a/lerobot/common/robot_devices/robots/manipulator.py b/lerobot/common/robot_devices/robots/manipulator.py
index 618105064..8671a02b9 100644
--- a/lerobot/common/robot_devices/robots/manipulator.py
+++ b/lerobot/common/robot_devices/robots/manipulator.py
@@ -18,11 +18,16 @@
 from lerobot.common.robot_devices.cameras.utils import Camera
 from lerobot.common.robot_devices.motors.utils import MotorsBus
 from lerobot.common.robot_devices.robots.utils import get_arm_id
-from lerobot.common.robot_devices.utils import RobotDeviceAlreadyConnectedError, RobotDeviceNotConnectedError
+from lerobot.common.robot_devices.utils import (
+    RobotDeviceAlreadyConnectedError,
+    RobotDeviceNotConnectedError,
+)


 def ensure_safe_goal_position(
-    goal_pos: torch.Tensor, present_pos: torch.Tensor, max_relative_target: float | list[float]
+    goal_pos: torch.Tensor,
+    present_pos: torch.Tensor,
+    max_relative_target: float | list[float],
 ):
     # Cap relative action target magnitude for safety.
     diff = goal_pos - present_pos
@@ -32,7 +37,7 @@ def ensure_safe_goal_position(
     safe_goal_pos = present_pos + safe_diff

     if not torch.allclose(goal_pos, safe_goal_pos):
-        logging.warning(
+        logging.debug(
             "Relative goal position magnitude had to be clamped to be safe.\n"
             f"  requested relative goal position target: {diff}\n"
             f"    clamped relative goal position target: {safe_diff}"
@@ -67,8 +72,14 @@ class ManipulatorRobotConfig:
     # gripper is not put in torque mode.
     gripper_open_degree: float | None = None

+    joint_position_relative_bounds: dict[str, np.ndarray] | None = None
+
     def __setattr__(self, prop: str, val):
-        if prop == "max_relative_target" and val is not None and isinstance(val, Sequence):
+        if (
+            prop == "max_relative_target"
+            and val is not None
+            and isinstance(val, Sequence)
+        ):
             for name in self.follower_arms:
                 if len(self.follower_arms[name].motors) != len(val):
                     raise ValueError(
@@ -78,11 +89,16 @@ def __setattr__(self, prop: str, val):
                         "Note: This feature does not yet work with robots where different follower arms have "
                         "different numbers of motors."
                     )
+        if prop == "joint_position_relative_bounds" and val is not None:
+            for key in val:
+                val[key] = torch.tensor(val[key])
         super().__setattr__(prop, val)

     def __post_init__(self):
         if self.robot_type not in ["koch", "koch_bimanual", "aloha", "so100", "moss"]:
-            raise ValueError(f"Provided robot type ({self.robot_type}) is not supported.")
+            raise ValueError(
+                f"Provided robot type ({self.robot_type}) is not supported."
+            )


 class ManipulatorRobot:
@@ -336,7 +352,9 @@ def connect(self):
         # to squeeze the gripper and have it spring back to an open position on its own.
for name in self.leader_arms: self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") + self.leader_arms[name].write( + "Goal_Position", self.config.gripper_open_degree, "gripper" + ) # Check both arms can be read for name in self.follower_arms: @@ -368,18 +386,26 @@ def load_or_run_calibration_(name, arm, arm_type): print(f"Missing calibration file '{arm_calib_path}'") if self.robot_type in ["koch", "koch_bimanual", "aloha"]: - from lerobot.common.robot_devices.robots.dynamixel_calibration import run_arm_calibration + from lerobot.common.robot_devices.robots.dynamixel_calibration import ( + run_arm_calibration, + ) - calibration = run_arm_calibration(arm, self.robot_type, name, arm_type) + calibration = run_arm_calibration( + arm, self.robot_type, name, arm_type + ) elif self.robot_type in ["so100", "moss"]: from lerobot.common.robot_devices.robots.feetech_calibration import ( run_arm_manual_calibration, ) - calibration = run_arm_manual_calibration(arm, self.robot_type, name, arm_type) + calibration = run_arm_manual_calibration( + arm, self.robot_type, name, arm_type + ) - print(f"Calibration is done! Saving calibration file '{arm_calib_path}'") + print( + f"Calibration is done! Saving calibration file '{arm_calib_path}'" + ) arm_calib_path.parent.mkdir(parents=True, exist_ok=True) with open(arm_calib_path, "w") as f: json.dump(calibration, f) @@ -398,13 +424,17 @@ def set_operating_mode_(arm): from lerobot.common.robot_devices.motors.dynamixel import TorqueMode if (arm.read("Torque_Enable") != TorqueMode.DISABLED.value).any(): - raise ValueError("To run set robot preset, the torque must be disabled on all motors.") + raise ValueError( + "To run set robot preset, the torque must be disabled on all motors." + ) # Use 'extended position mode' for all motors except gripper, because in joint mode the servos can't # rotate more than 360 degrees (from 0 to 4095) And some mistake can happen while assembling the arm, # you could end up with a servo with a position 0 or 4095 at a crucial point See [ # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] - all_motors_except_gripper = [name for name in arm.motor_names if name != "gripper"] + all_motors_except_gripper = [ + name for name in arm.motor_names if name != "gripper" + ] if len(all_motors_except_gripper) > 0: # 4 corresponds to Extended Position on Koch motors arm.write("Operating_Mode", 4, all_motors_except_gripper) @@ -433,7 +463,9 @@ def set_operating_mode_(arm): # Enable torque on the gripper of the leader arms, and move it to 45 degrees, # so that we can use it as a trigger to close the gripper of the follower arms. 
self.leader_arms[name].write("Torque_Enable", 1, "gripper") - self.leader_arms[name].write("Goal_Position", self.config.gripper_open_degree, "gripper") + self.leader_arms[name].write( + "Goal_Position", self.config.gripper_open_degree, "gripper" + ) def set_aloha_robot_preset(self): def set_shadow_(arm): @@ -463,11 +495,15 @@ def set_shadow_(arm): # you could end up with a servo with a position 0 or 4095 at a crucial point See [ # https://emanual.robotis.com/docs/en/dxl/x/x_series/#operating-mode11] all_motors_except_gripper = [ - name for name in self.follower_arms[name].motor_names if name != "gripper" + name + for name in self.follower_arms[name].motor_names + if name != "gripper" ] if len(all_motors_except_gripper) > 0: # 4 corresponds to Extended Position on Aloha motors - self.follower_arms[name].write("Operating_Mode", 4, all_motors_except_gripper) + self.follower_arms[name].write( + "Operating_Mode", 4, all_motors_except_gripper + ) # Use 'position control current based' for follower gripper to be limited by the limit of the current. # It can grasp an object without forcing too much even tho, @@ -515,7 +551,9 @@ def teleop_step( before_lread_t = time.perf_counter() leader_pos[name] = self.leader_arms[name].read("Present_Position") leader_pos[name] = torch.from_numpy(leader_pos[name]) - self.logs[f"read_leader_{name}_pos_dt_s"] = time.perf_counter() - before_lread_t + self.logs[f"read_leader_{name}_pos_dt_s"] = ( + time.perf_counter() - before_lread_t + ) # Send goal position to the follower follower_goal_pos = {} @@ -523,19 +561,31 @@ def teleop_step( before_fwrite_t = time.perf_counter() goal_pos = leader_pos[name] + # If specified, clip the goal positions within predefined bounds specified in the config of the robot + if self.config.joint_position_relative_bounds is not None: + goal_pos = torch.clamp( + goal_pos, + self.config.joint_position_relative_bounds["min"], + self.config.joint_position_relative_bounds["max"], + ) + # Cap goal position when too far away from present position. # Slower fps expected due to reading from the follower. 
if self.config.max_relative_target is not None: present_pos = self.follower_arms[name].read("Present_Position") present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + goal_pos = ensure_safe_goal_position( + goal_pos, present_pos, self.config.max_relative_target + ) # Used when record_data=True follower_goal_pos[name] = goal_pos goal_pos = goal_pos.numpy().astype(np.int32) self.follower_arms[name].write("Goal_Position", goal_pos) - self.logs[f"write_follower_{name}_goal_pos_dt_s"] = time.perf_counter() - before_fwrite_t + self.logs[f"write_follower_{name}_goal_pos_dt_s"] = ( + time.perf_counter() - before_fwrite_t + ) # Early exit when recording data is not requested if not record_data: @@ -548,7 +598,9 @@ def teleop_step( before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t + self.logs[f"read_follower_{name}_pos_dt_s"] = ( + time.perf_counter() - before_fread_t + ) # Create state by concatenating follower current position state = [] @@ -570,8 +622,12 @@ def teleop_step( before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionnaries obs_dict, action_dict = {}, {} @@ -595,7 +651,9 @@ def capture_observation(self): before_fread_t = time.perf_counter() follower_pos[name] = self.follower_arms[name].read("Present_Position") follower_pos[name] = torch.from_numpy(follower_pos[name]) - self.logs[f"read_follower_{name}_pos_dt_s"] = time.perf_counter() - before_fread_t + self.logs[f"read_follower_{name}_pos_dt_s"] = ( + time.perf_counter() - before_fread_t + ) # Create state by concatenating follower current position state = [] @@ -610,8 +668,12 @@ def capture_observation(self): before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionnaries and format to pytorch obs_dict = {} @@ -644,18 +706,29 @@ def send_action(self, action: torch.Tensor) -> torch.Tensor: goal_pos = action[from_idx:to_idx] from_idx = to_idx + # If specified, clip the goal positions within predefined bounds specified in the config of the robot + if self.config.joint_position_relative_bounds is not None: + goal_pos = torch.clamp( + goal_pos, + self.config.joint_position_relative_bounds["min"], + self.config.joint_position_relative_bounds["max"], + ) + # Cap goal position when too far away from present position. # Slower fps expected due to reading from the follower. 
if self.config.max_relative_target is not None: present_pos = self.follower_arms[name].read("Present_Position") present_pos = torch.from_numpy(present_pos) - goal_pos = ensure_safe_goal_position(goal_pos, present_pos, self.config.max_relative_target) + goal_pos = ensure_safe_goal_position( + goal_pos, present_pos, self.config.max_relative_target + ) # Save tensor to concat and return action_sent.append(goal_pos) # Send goal position to each follower goal_pos = goal_pos.numpy().astype(np.int32) + self.follower_arms[name].write("Goal_Position", goal_pos) return torch.cat(action_sent) diff --git a/lerobot/common/robot_devices/robots/stretch.py b/lerobot/common/robot_devices/robots/stretch.py index ff86b6d80..13209715b 100644 --- a/lerobot/common/robot_devices/robots/stretch.py +++ b/lerobot/common/robot_devices/robots/stretch.py @@ -60,7 +60,9 @@ def __init__(self, config: StretchRobotConfig | None = None, **kwargs): def connect(self) -> None: self.is_connected = self.startup() if not self.is_connected: - print("Another process is already using Stretch. Try running 'stretch_free_robot_process.py'") + print( + "Another process is already using Stretch. Try running 'stretch_free_robot_process.py'" + ) raise ConnectionError() for name in self.cameras: @@ -68,7 +70,9 @@ def connect(self) -> None: self.is_connected = self.is_connected and self.cameras[name].is_connected if not self.is_connected: - print("Could not connect to the cameras, check that all cameras are plugged-in.") + print( + "Could not connect to the cameras, check that all cameras are plugged in." + ) raise ConnectionError() self.run_calibration() @@ -113,8 +117,12 @@ def teleop_step( before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionaries obs_dict, action_dict = {}, {} @@ -158,8 +166,12 @@ def capture_observation(self) -> dict: before_camread_t = time.perf_counter() images[name] = self.cameras[name].async_read() images[name] = torch.from_numpy(images[name]) - self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs["delta_timestamp_s"] - self.logs[f"async_read_camera_{name}_dt_s"] = time.perf_counter() - before_camread_t + self.logs[f"read_camera_{name}_dt_s"] = self.cameras[name].logs[ + "delta_timestamp_s" + ] + self.logs[f"async_read_camera_{name}_dt_s"] = ( + time.perf_counter() - before_camread_t + ) # Populate output dictionaries obs_dict = {} diff --git a/lerobot/common/robot_devices/utils.py b/lerobot/common/robot_devices/utils.py index 19bb637e5..fe9c4f42b 100644 --- a/lerobot/common/robot_devices/utils.py +++ b/lerobot/common/robot_devices/utils.py @@ -34,7 +34,8 @@ class RobotDeviceNotConnectedError(Exception): """Exception raised when the robot device is not connected.""" def __init__( - self, message="This robot device is not connected. Try calling `robot_device.connect()` first." + self, + message="This robot device is not connected. 
Try calling `robot_device.connect()` first.", ): self.message = message super().__init__(self.message) diff --git a/lerobot/common/utils/import_utils.py b/lerobot/common/utils/import_utils.py index cd5f82450..e2ce5a87d 100644 --- a/lerobot/common/utils/import_utils.py +++ b/lerobot/common/utils/import_utils.py @@ -17,7 +17,9 @@ import logging -def is_package_available(pkg_name: str, return_version: bool = False) -> tuple[bool, str] | bool: +def is_package_available( + pkg_name: str, return_version: bool = False +) -> tuple[bool, str] | bool: """Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/utils/import_utils.py Check if the package spec exists and grab its version to avoid importing a local directory. **Note:** this doesn't work for all packages. diff --git a/lerobot/common/utils/io_utils.py b/lerobot/common/utils/io_utils.py index b85f17c7a..664b8a0d4 100644 --- a/lerobot/common/utils/io_utils.py +++ b/lerobot/common/utils/io_utils.py @@ -22,6 +22,8 @@ def write_video(video_path, stacked_frames, fps): # Filter out DeprecationWarnings raised from pkg_resources with warnings.catch_warnings(): warnings.filterwarnings( - "ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning + "ignore", + "pkg_resources is deprecated as an API", + category=DeprecationWarning, ) imageio.mimsave(video_path, stacked_frames, fps=fps) diff --git a/lerobot/common/utils/utils.py b/lerobot/common/utils/utils.py index 4e276e169..fecf88f98 100644 --- a/lerobot/common/utils/utils.py +++ b/lerobot/common/utils/utils.py @@ -18,6 +18,7 @@ import os.path as osp import platform import random +import time from contextlib import contextmanager from datetime import datetime, timezone from pathlib import Path @@ -115,11 +116,11 @@ def seeded_context(seed: int) -> Generator[None, None, None]: set_global_random_state(random_state_dict) -def init_logging(): +def init_logging(log_file=None): def custom_format(record): dt = datetime.now().strftime("%Y-%m-%d %H:%M:%S") fnameline = f"{record.pathname}:{record.lineno}" - message = f"{record.levelname} {dt} {fnameline[-15:]:>15} {record.msg}" + message = f"{record.levelname} [PID: {os.getpid()}] {dt} {fnameline[-15:]:>15} {record.msg}" return message logging.basicConfig(level=logging.INFO) @@ -133,6 +134,12 @@ def custom_format(record): console_handler.setFormatter(formatter) logging.getLogger().addHandler(console_handler) + if log_file is not None: + # File handler + file_handler = logging.FileHandler(log_file) + file_handler.setFormatter(formatter) + logging.getLogger().addHandler(file_handler) + def format_big_number(num, precision=0): suffixes = ["", "K", "M", "B", "T", "Q"] @@ -155,11 +162,16 @@ def _relative_path_between(path1: Path, path2: Path) -> Path: except ValueError: # most likely because path1 is not a subpath of path2 common_parts = Path(osp.commonpath([path1, path2])).parts return Path( - "/".join([".."] * (len(path2.parts) - len(common_parts)) + list(path1.parts[len(common_parts) :])) + "/".join( + [".."] * (len(path2.parts) - len(common_parts)) + + list(path1.parts[len(common_parts) :]) + ) ) -def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> DictConfig: +def init_hydra_config( + config_path: str, overrides: list[str] | None = None +) -> DictConfig: """Initialize a Hydra config given only the path to the relevant config file. For config resolution, it is assumed that the config file's parent is the Hydra config dir. 
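The `log_file` parameter added to `init_logging` above attaches a `FileHandler` alongside the console handler, and the new format string tags every record with the emitting process's PID, which matters once the actor/learner split later in this patch runs several processes at once. A small usage sketch (the file name is illustrative):

import logging
from lerobot.common.utils.utils import init_logging

init_logging(log_file="train.log")  # file handler is only added when a path is given
logging.info("training started")    # goes to the console and to train.log, PID-tagged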
@@ -168,7 +180,11 @@ def init_hydra_config(config_path: str, overrides: list[str] | None = None) -> D hydra.core.global_hydra.GlobalHydra.instance().clear() # Hydra needs a path relative to this file. hydra.initialize( - str(_relative_path_between(Path(config_path).absolute().parent, Path(__file__).absolute().parent)), + str( + _relative_path_between( + Path(config_path).absolute().parent, Path(__file__).absolute().parent + ) + ), version_base="1.2", ) cfg = hydra.compose(Path(config_path).stem, overrides) @@ -182,10 +198,26 @@ def print_cuda_memory_usage(): gc.collect() # Also clear the cache if you want to fully release the memory torch.cuda.empty_cache() - print("Current GPU Memory Allocated: {:.2f} MB".format(torch.cuda.memory_allocated(0) / 1024**2)) - print("Maximum GPU Memory Allocated: {:.2f} MB".format(torch.cuda.max_memory_allocated(0) / 1024**2)) - print("Current GPU Memory Reserved: {:.2f} MB".format(torch.cuda.memory_reserved(0) / 1024**2)) - print("Maximum GPU Memory Reserved: {:.2f} MB".format(torch.cuda.max_memory_reserved(0) / 1024**2)) + print( + "Current GPU Memory Allocated: {:.2f} MB".format( + torch.cuda.memory_allocated(0) / 1024**2 + ) + ) + print( + "Maximum GPU Memory Allocated: {:.2f} MB".format( + torch.cuda.max_memory_allocated(0) / 1024**2 + ) + ) + print( + "Current GPU Memory Reserved: {:.2f} MB".format( + torch.cuda.memory_reserved(0) / 1024**2 + ) + ) + print( + "Maximum GPU Memory Reserved: {:.2f} MB".format( + torch.cuda.max_memory_reserved(0) / 1024**2 + ) + ) def capture_timestamp_utc(): @@ -217,3 +249,33 @@ def log_say(text, play_sounds, blocking=False): if play_sounds: say(text, blocking) + + +class TimerManager: + def __init__( + self, + elapsed_time_list: list[float] | None = None, + label="Elapsed time", + log=True, + ): + self.label = label + self.elapsed_time_list = elapsed_time_list + self.log = log + self.elapsed = 0.0 + + def __enter__(self): + self.start = time.perf_counter() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.elapsed: float = time.perf_counter() - self.start + + if self.elapsed_time_list is not None: + self.elapsed_time_list.append(self.elapsed) + + if self.log: + print(f"{self.label}: {self.elapsed:.6f} seconds") + + @property + def elapsed_seconds(self): + return self.elapsed diff --git a/lerobot/configs/default.yaml b/lerobot/configs/default.yaml index a3ff1d41b..7750ba3a3 100644 --- a/lerobot/configs/default.yaml +++ b/lerobot/configs/default.yaml @@ -2,6 +2,7 @@ defaults: - _self_ - env: pusht - policy: diffusion + - robot: so100 hydra: run: diff --git a/lerobot/configs/env/maniskill_example.yaml b/lerobot/configs/env/maniskill_example.yaml new file mode 100644 index 000000000..3df23b2ec --- /dev/null +++ b/lerobot/configs/env/maniskill_example.yaml @@ -0,0 +1,30 @@ +# @package _global_ + +fps: 20 + +env: + name: maniskill/pushcube + task: PushCube-v1 + image_size: 64 + control_mode: pd_ee_delta_pose + state_dim: 25 + action_dim: 7 + fps: ${fps} + obs: rgb + render_mode: rgb_array + render_size: 64 + device: cuda + + reward_classifier: + pretrained_path: null + config_path: null + + wrapper: + joint_masking_action_space: null + delta_action: null + + video_record: + enabled: false + record_dir: maniskill_videos + trajectory_name: trajectory + fps: ${fps} diff --git a/lerobot/configs/env/so100_real.yaml b/lerobot/configs/env/so100_real.yaml index 8e65d72f4..dc30224c1 100644 --- a/lerobot/configs/env/so100_real.yaml +++ b/lerobot/configs/env/so100_real.yaml @@ -1,6 +1,6 @@ # @package _global_ 
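The `TimerManager` context manager added above gives a reusable way to time a block and optionally accumulate measurements across iterations, e.g. for the `profile_inference_time` option that appears in the classifier config further below. A sketch of the intended pattern (the timed body is a stand-in for a real forward pass):

from statistics import mean

from lerobot.common.utils.utils import TimerManager

inference_times: list[float] = []
for _ in range(3):
    with TimerManager(elapsed_time_list=inference_times, label="inference", log=False):
        sum(i * i for i in range(100_000))  # stand-in for policy inference
print(f"mean inference time: {mean(inference_times):.6f} s")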
-fps: 30 +fps: 10 env: name: real_world @@ -8,3 +8,23 @@ env: state_dim: 6 action_dim: 6 fps: ${fps} + device: mps + + wrapper: + crop_params_dict: + observation.images.front: [102, 43, 358, 523] + observation.images.side: [92, 123, 379, 349] + # observation.images.front: [109, 37, 361, 557] + # observation.images.side: [94, 161, 372, 315] + resize_size: [128, 128] + control_time_s: 20 + reset_follower_pos: true + use_relative_joint_positions: true + reset_time_s: 5 + display_cameras: false + delta_action: 0.1 + joint_masking_action_space: [1, 1, 1, 1, 0, 0] # disable wrist and gripper + + reward_classifier: + pretrained_path: outputs/classifier/13-02-random-sample-resnet10-frozen/checkpoints/best/pretrained_model + config_path: lerobot/configs/policy/hilserl_classifier.yaml diff --git a/lerobot/configs/policy/hilserl_classifier.yaml b/lerobot/configs/policy/hilserl_classifier.yaml index f8137b696..9ab181d53 100644 --- a/lerobot/configs/policy/hilserl_classifier.yaml +++ b/lerobot/configs/policy/hilserl_classifier.yaml @@ -3,8 +3,19 @@ defaults: - _self_ +hydra: + run: + # Set `dir` to where you would like to save all of the run outputs. If you run another training session + # with the same value for `dir` its contents will be overwritten unless you set `resume` to true. + dir: outputs/train_hilserl_classifier/${now:%Y-%m-%d}/${now:%H-%M-%S}_${env.name}_${hydra.job.name} + job: + name: default + seed: 13 -dataset_repo_id: aractingi/pick_place_lego_cube_1 +dataset_repo_id: aractingi/push_cube_square_light_reward_cropped_resized +# aractingi/push_cube_square_reward_1_cropped_resized +dataset_root: data/aractingi/push_cube_square_light_reward_cropped_resized +local_files_only: true train_split_proportion: 0.8 # Required by logger @@ -14,7 +25,7 @@ env: training: - num_epochs: 5 + num_epochs: 6 batch_size: 16 learning_rate: 1e-4 num_workers: 4 @@ -24,16 +35,18 @@ training: eval_freq: 1 # How often to run validation (in epochs) save_freq: 1 # How often to save checkpoints (in epochs) save_checkpoint: true - image_keys: ["observation.images.top", "observation.images.wrist"] + image_keys: ["observation.images.front", "observation.images.side"] label_key: "next.reward" + profile_inference_time: false + profile_inference_time_iters: 20 eval: batch_size: 16 num_samples_to_log: 30 # Number of validation samples to log in the table policy: - name: "hilserl/classifier/pick_place_lego_cube_1" - model_name: "facebook/convnext-base-224" + name: "hilserl/classifier" + model_name: "helper2424/resnet10" # "facebook/convnext-base-224 model_type: "cnn" num_cameras: 2 # Has to be len(training.image_keys) @@ -45,4 +58,4 @@ wandb: device: "mps" resume: false -output_dir: "outputs/classifier" +output_dir: "outputs/classifier/old_trainer_resnet10_frozen" diff --git a/lerobot/configs/policy/sac_maniskill.yaml b/lerobot/configs/policy/sac_maniskill.yaml new file mode 100644 index 000000000..c9bbca44f --- /dev/null +++ b/lerobot/configs/policy/sac_maniskill.yaml @@ -0,0 +1,117 @@ +# @package _global_ + +# Train with: +# +# python lerobot/scripts/train.py \ +# +dataset=lerobot/pusht_keypoints +# env=pusht \ +# env.gym.obs_type=environment_state_agent_pos \ + +seed: 1 +dataset_repo_id: "AdilZtn/Maniskill-Pushcube-demonstration-medium" + +training: + # Offline training dataloader + num_workers: 4 + + batch_size: 512 + grad_clip_norm: 10.0 + lr: 3e-4 + + + storage_device: "cpu" + + eval_freq: 2500 + log_freq: 10 + save_freq: 2000000 + + online_steps: 1000000 + online_rollout_n_episodes: 10 + online_rollout_batch_size: 10 
+ online_steps_between_rollouts: 1000 + online_sampling_ratio: 1.0 + online_env_seed: 10000 + online_buffer_capacity: 200000 + online_buffer_seed_size: 0 + online_step_before_learning: 500 + do_online_rollout_async: false + policy_update_freq: 1 + + # delta_timestamps: + # observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + # observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + # action: "[i / ${fps} for i in range(${policy.horizon})]" + # next.reward: "[i / ${fps} for i in range(${policy.horizon})]" + +policy: + name: sac + + pretrained_model_path: + + # Input / output structure. + n_action_repeats: 1 + horizon: 1 + n_action_steps: 1 + + shared_encoder: true + # vision_encoder_name: "helper2424/resnet10" + vision_encoder_name: null + # freeze_vision_encoder: true + freeze_vision_encoder: false + input_shapes: + # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? + observation.state: ["${env.state_dim}"] + observation.image: [3, 64, 64] + output_shapes: + action: [7] + + camera_number: 1 + + # Normalization / Unnormalization + input_normalization_modes: null + # input_normalization_modes: + # observation.state: min_max + input_normalization_params: null + # observation.state: + # min: [-1.9361e+00, -7.7640e-01, -7.7094e-01, -2.9709e+00, -8.5656e-01, + # 1.0764e+00, -1.2680e+00, 0.0000e+00, 0.0000e+00, -9.3448e+00, + # -3.3828e+00, -3.8420e+00, -5.2553e+00, -3.4154e+00, -6.5082e+00, + # -6.0500e+00, -8.7193e+00, -8.2337e+00, -3.4650e-01, -4.9441e-01, + # 8.3516e-03, -3.1114e-01, -9.9700e-01, -2.3471e-01, -2.7137e-01] + + # max: [ 0.8644, 1.4306, 1.8520, -0.7578, 0.9508, 3.4901, 1.9381, 0.0400, + # 0.0400, 5.0885, 4.7156, 7.9393, 7.9100, 2.9796, 5.7720, 4.7163, + # 7.8145, 9.7415, 0.2422, 0.4505, 0.6306, 0.2622, 1.0000, 0.5135, + # 0.4001] + + output_normalization_modes: + action: min_max + output_normalization_params: + action: + min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0] + max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + output_normalization_shapes: + action: [7] + + # Architecture / modeling. + # Neural networks. 
+ image_encoder_hidden_dim: 32 + # discount: 0.99 + discount: 0.80 + temperature_init: 1.0 + num_critics: 10 #10 + num_subsample_critics: 2 + critic_lr: 3e-4 + actor_lr: 3e-4 + temperature_lr: 3e-4 + # critic_target_update_weight: 0.005 + critic_target_update_weight: 0.01 + utd_ratio: 2 # 10 + +actor_learner_config: + learner_host: "127.0.0.1" + learner_port: 50051 + policy_parameters_push_frequency: 1 + concurrency: + actor: 'processes' + learner: 'processes' diff --git a/lerobot/configs/policy/sac_real.yaml b/lerobot/configs/policy/sac_real.yaml new file mode 100644 index 000000000..139463f98 --- /dev/null +++ b/lerobot/configs/policy/sac_real.yaml @@ -0,0 +1,128 @@ +# @package _global_ + +# Train with: +# +# python lerobot/scripts/train.py \ +# +dataset=lerobot/pusht_keypoints +# env=pusht \ +# env.gym.obs_type=environment_state_agent_pos \ + +seed: 1 +dataset_repo_id: aractingi/push_cube_overfit_cropped_resized +#aractingi/push_cube_square_offline_demo_cropped_resized + +training: + # Offline training dataloader + num_workers: 4 + + # batch_size: 256 + batch_size: 512 + grad_clip_norm: 10.0 + lr: 3e-4 + + eval_freq: 2500 + log_freq: 1 + save_freq: 2000000 + + online_steps: 1000000 + online_rollout_n_episodes: 10 + online_rollout_batch_size: 10 + online_steps_between_rollouts: 1000 + online_sampling_ratio: 1.0 + online_env_seed: 10000 + online_buffer_capacity: 1000000 + online_buffer_seed_size: 0 + online_step_before_learning: 100 #5000 + do_online_rollout_async: false + policy_update_freq: 1 + + # delta_timestamps: + # observation.environment_state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + # observation.state: "[i / ${fps} for i in range(${policy.horizon} + 1)]" + # action: "[i / ${fps} for i in range(${policy.horizon})]" + # next.reward: "[i / ${fps} for i in range(${policy.horizon})]" + +policy: + name: sac + + pretrained_model_path: + + # Input / output structure. + n_action_repeats: 1 + horizon: 1 + n_action_steps: 1 + + shared_encoder: true + vision_encoder_name: "helper2424/resnet10" + freeze_vision_encoder: true + input_shapes: + # # TODO(rcadene, alexander-soare): add variables for height and width from the dataset/env? 
+ observation.state: ["${env.state_dim}"] + observation.images.front: [3, 128, 128] + observation.images.side: [3, 128, 128] + # observation.image: [3, 128, 128] + output_shapes: + action: [4] # ["${env.action_dim}"] + + # Normalization / Unnormalization + input_normalization_modes: + observation.images.front: mean_std + observation.images.side: mean_std + observation.state: min_max + input_normalization_params: + observation.images.front: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + observation.images.side: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + observation.state: + min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786] + max: [ 7.215820e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01] + + # min: [-87.09961, 62.402344, 67.23633, 36.035156, 77.34375,0.53691274] + # max: [58.183594, 131.83594, 145.98633, 82.08984, 78.22266, 0.60402685] + # min: [-88.50586, 23.81836, 0.87890625, -32.16797, 78.66211, 0.53691274] + # max: [84.55078, 187.11914, 145.98633, 101.60156, 146.60156, 88.18792] + + output_normalization_modes: + action: min_max + output_normalization_params: + # action: + # min: [-1.0, -1.0, -1.0, -1.0, -1.0, -1.0] + # max: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] + action: + min: [-149.23828125, -97.734375, -100.1953125, -73.740234375] + max: [149.23828125, 97.734375, 100.1953125, 73.740234375] + + # Architecture / modeling. + # Neural networks. + image_encoder_hidden_dim: 32 + # discount: 0.99 + discount: 0.97 + temperature_init: 1.0 + num_critics: 2 #10 + camera_number: 2 + num_subsample_critics: null + critic_lr: 3e-4 + actor_lr: 3e-4 + temperature_lr: 3e-4 + # critic_target_update_weight: 0.005 + critic_target_update_weight: 0.01 + utd_ratio: 2 # 10 + +actor_learner_config: + learner_host: "127.0.0.1" + learner_port: 50051 + policy_parameters_push_frequency: 15 + + # # Loss coefficients. + # reward_coeff: 0.5 + # expectile_weight: 0.9 + # value_coeff: 0.1 + # consistency_coeff: 20.0 + # advantage_scaling: 3.0 + # pi_coeff: 0.5 + # temporal_decay_coeff: 0.5 + # # Target model. + # target_model_momentum: 0.995 diff --git a/lerobot/configs/robot/so100.yaml b/lerobot/configs/robot/so100.yaml index 0978de64e..459308aea 100644 --- a/lerobot/configs/robot/so100.yaml +++ b/lerobot/configs/robot/so100.yaml @@ -14,6 +14,9 @@ calibration_dir: .cache/calibration/so100 # Set this to a positive scalar to have the same value for all motors, or a list that is the same length as # the number of motors in your follower arms. 
max_relative_target: null +joint_position_relative_bounds: + max: [ 7.2158203e+01, 1.5398438e+02, 1.6075195e+02, 9.3251953e+01, 0., -1.4184397e-01] + min: [-77.08008, 56.25, 60.55664, 19.511719, 0., -0.63829786] leader_arms: main: @@ -31,7 +34,7 @@ leader_arms: follower_arms: main: _target_: lerobot.common.robot_devices.motors.feetech.FeetechMotorsBus - port: /dev/tty.usbmodem585A0080971 + port: /dev/tty.usbmodem58760431631 motors: # name: (index, model) shoulder_pan: [1, "sts3215"] @@ -42,13 +45,13 @@ follower_arms: gripper: [6, "sts3215"] cameras: - laptop: + front: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera camera_index: 0 fps: 30 width: 640 height: 480 - phone: + side: _target_: lerobot.common.robot_devices.cameras.opencv.OpenCVCamera camera_index: 1 fps: 30 diff --git a/lerobot/scripts/configure_motor.py b/lerobot/scripts/configure_motor.py index 18707397f..1a53ab6c7 100644 --- a/lerobot/scripts/configure_motor.py +++ b/lerobot/scripts/configure_motor.py @@ -22,13 +22,17 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): from lerobot.common.robot_devices.motors.feetech import ( SCS_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE, ) - from lerobot.common.robot_devices.motors.feetech import FeetechMotorsBus as MotorsBusClass + from lerobot.common.robot_devices.motors.feetech import ( + FeetechMotorsBus as MotorsBusClass, + ) elif brand == "dynamixel": from lerobot.common.robot_devices.motors.dynamixel import MODEL_BAUDRATE_TABLE from lerobot.common.robot_devices.motors.dynamixel import ( X_SERIES_BAUDRATE_TABLE as SERIES_BAUDRATE_TABLE, ) - from lerobot.common.robot_devices.motors.dynamixel import DynamixelMotorsBus as MotorsBusClass + from lerobot.common.robot_devices.motors.dynamixel import ( + DynamixelMotorsBus as MotorsBusClass, + ) else: raise ValueError( f"Currently we do not support this motor brand: {brand}. We currently support feetech and dynamixel motors." @@ -46,7 +50,9 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): motor_model = model # Use the motor model passed via argument # Initialize the MotorBus with the correct port and motor configurations - motor_bus = MotorsBusClass(port=port, motors={motor_name: (motor_index_arbitrary, motor_model)}) + motor_bus = MotorsBusClass( + port=port, motors={motor_name: (motor_index_arbitrary, motor_model)} + ) # Try to connect to the motor bus and handle any connection-specific errors try: @@ -78,20 +84,26 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): motor_index = present_ids[0] if motor_index == -1: - raise ValueError("No motors detected. Please ensure you have one motor connected.") + raise ValueError( + "No motors detected. Please ensure you have one motor connected." 
+ ) print(f"Motor index found at: {motor_index}") if brand == "feetech": # Allows ID and BAUDRATE to be written in memory - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) + motor_bus.write_with_motor_ids( + motor_bus.motor_models, motor_index, "Lock", 0 + ) if baudrate != baudrate_des: print(f"Setting its baudrate to {baudrate_des}") baudrate_idx = list(SERIES_BAUDRATE_TABLE.values()).index(baudrate_des) # The write can fail, so we allow retries - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx) + motor_bus.write_with_motor_ids( + motor_bus.motor_models, motor_index, "Baud_Rate", baudrate_idx + ) time.sleep(0.5) motor_bus.set_bus_baudrate(baudrate_des) present_baudrate_idx = motor_bus.read_with_motor_ids( @@ -103,9 +115,13 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): print(f"Setting its index to desired index {motor_idx_des}") motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "Lock", 0) - motor_bus.write_with_motor_ids(motor_bus.motor_models, motor_index, "ID", motor_idx_des) + motor_bus.write_with_motor_ids( + motor_bus.motor_models, motor_index, "ID", motor_idx_des + ) - present_idx = motor_bus.read_with_motor_ids(motor_bus.motor_models, motor_idx_des, "ID", num_retry=2) + present_idx = motor_bus.read_with_motor_ids( + motor_bus.motor_models, motor_idx_des, "ID", num_retry=2 + ) if present_idx != motor_idx_des: raise OSError("Failed to write index.") @@ -133,12 +149,29 @@ def configure_motor(port, brand, model, motor_idx_des, baudrate_des): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--port", type=str, required=True, help="Motors bus port (e.g. dynamixel,feetech)") - parser.add_argument("--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)") - parser.add_argument("--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)") - parser.add_argument("--ID", type=int, required=True, help="Desired ID of the current motor (e.g. 1,2,3)") parser.add_argument( - "--baudrate", type=int, default=1000000, help="Desired baudrate for the motor (default: 1000000)" + "--port", + type=str, + required=True, + help="Motors bus port (e.g. dynamixel,feetech)", + ) + parser.add_argument( + "--brand", type=str, required=True, help="Motor brand (e.g. dynamixel,feetech)" + ) + parser.add_argument( + "--model", type=str, required=True, help="Motor model (e.g. xl330-m077,sts3215)" + ) + parser.add_argument( + "--ID", + type=int, + required=True, + help="Desired ID of the current motor (e.g. 
1,2,3)", + ) + parser.add_argument( + "--baudrate", + type=int, + default=1000000, + help="Desired baudrate for the motor (default: 1000000)", ) args = parser.parse_args() diff --git a/lerobot/scripts/control_robot.py b/lerobot/scripts/control_robot.py index 8187e8a34..599df1f63 100644 --- a/lerobot/scripts/control_robot.py +++ b/lerobot/scripts/control_robot.py @@ -118,7 +118,12 @@ from lerobot.common.robot_devices.robots.factory import make_robot from lerobot.common.robot_devices.robots.utils import Robot from lerobot.common.robot_devices.utils import busy_wait, safe_disconnect -from lerobot.common.utils.utils import init_hydra_config, init_logging, log_say, none_or_int +from lerobot.common.utils.utils import ( + init_hydra_config, + init_logging, + log_say, + none_or_int, +) ######################################################################################## # Control modes @@ -173,7 +178,10 @@ def calibrate(robot: Robot, arms: list[str] | None): @safe_disconnect def teleoperate( - robot: Robot, fps: int | None = None, teleop_time_s: float | None = None, display_cameras: bool = False + robot: Robot, + fps: int | None = None, + teleop_time_s: float | None = None, + display_cameras: bool = False, ): control_loop( robot, @@ -206,7 +214,8 @@ def record( num_image_writer_threads_per_camera: int = 4, display_cameras: bool = True, play_sounds: bool = True, - reset_follower: bool = False, + reset_follower: bool = False, + record_delta_actions: bool = False, resume: bool = False, # TODO(rcadene, aliberts): remove local_files_only when refactor with dataset as argument local_files_only: bool = False, @@ -218,7 +227,12 @@ def record( device = None use_amp = None extra_features = ( - {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None + { + "next.reward": {"dtype": "int64", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + } + if assign_rewards + else None ) if single_task: @@ -228,11 +242,15 @@ def record( # Load pretrained policy if pretrained_policy_name_or_path is not None: - policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) + policy, policy_fps, device, use_amp = init_policy( + pretrained_policy_name_or_path, policy_overrides + ) if fps is None: fps = policy_fps - logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") + logging.warning( + f"No fps provided, so using the fps from policy config ({policy_fps})." + ) elif fps != policy_fps: logging.warning( f"There is a mismatch between the provided fps ({fps}) and the one from policy config ({policy_fps})." 
@@ -248,7 +266,9 @@ def record( num_processes=num_image_writer_processes, num_threads=num_image_writer_threads_per_camera * len(robot.cameras), ) - sanity_check_dataset_robot_compatibility(dataset, robot, fps, video, extra_features) + sanity_check_dataset_robot_compatibility( + dataset, robot, fps, video, extra_features + ) else: # Create empty dataset or load existing saved episodes sanity_check_dataset_name(repo_id, policy) @@ -259,7 +279,8 @@ def record( robot=robot, use_videos=video, image_writer_processes=num_image_writer_processes, - image_writer_threads=num_image_writer_threads_per_camera * len(robot.cameras), + image_writer_threads=num_image_writer_threads_per_camera + * len(robot.cameras), features=extra_features, ) @@ -269,14 +290,16 @@ def record( if reset_follower: initial_position = robot.follower_arms["main"].read("Present_Position") - + # Execute a few seconds without recording to: # 1. teleoperate the robot to move it into the starting position if no policy is provided, # 2. give time to the robot devices to connect and start synchronizing, # 3. place the camera windows on screen enable_teleoperation = policy is None log_say("Warmup record", play_sounds) - warmup_record(robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps) + warmup_record( + robot, events, enable_teleoperation, warmup_time_s, display_cameras, fps + ) if has_method(robot, "teleop_safety_stop"): robot.teleop_safety_stop() @@ -302,6 +325,7 @@ def record( device=device, use_amp=use_amp, fps=fps, + record_delta_actions=record_delta_actions, ) # Execute a few seconds without recording to give time to manually reset the environment # TODO(rcadene): add an option to enable teleoperation during reset # Skip reset for the last episode to be recorded if not events["stop_recording"] and ( - (dataset.num_episodes < num_episodes - 1) or events["rerecord_episode"] + (recorded_episodes < num_episodes - 1) or events["rerecord_episode"] ): log_say("Reset the environment", play_sounds) if reset_follower: @@ -353,21 +377,26 @@ def replay( fps: int | None = None, play_sounds: bool = True, local_files_only: bool = False, + replay_delta_actions: bool = False, ): # TODO(rcadene, aliberts): refactor with control_loop, once `dataset` is an instance of LeRobotDataset # TODO(rcadene): Add option to record logs - dataset = LeRobotDataset(repo_id, root=root, episodes=[episode], local_files_only=local_files_only) + dataset = LeRobotDataset( + repo_id, root=root, episodes=[episode], local_files_only=local_files_only + ) actions = dataset.hf_dataset.select_columns("action") - if not robot.is_connected: robot.connect() log_say("Replaying episode", play_sounds, blocking=True) for idx in range(dataset.num_frames): + current_joint_positions = robot.follower_arms["main"].read("Present_Position") start_episode_t = time.perf_counter() action = actions[idx]["action"] + if replay_delta_actions: + action = action + current_joint_positions robot.send_action(action) dt_s = time.perf_counter() - start_episode_t @@ -406,7 +435,10 @@ def replay( parser_teleop = subparsers.add_parser("teleoperate", parents=[base_parser]) parser_teleop.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + "--fps", + type=none_or_int, + default=None, + help="Frames per second (set to None to disable)", ) parser_teleop.add_argument( "--display-cameras", @@ -418,7 +450,10 @@ def replay( parser_record = subparsers.add_parser("record", parents=[base_parser]) task_args = 
parser_record.add_mutually_exclusive_group(required=True) parser_record.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + "--fps", + type=none_or_int, + default=None, + help="Frames per second (set to None to disable)", ) task_args.add_argument( "--single-task", @@ -467,7 +502,9 @@ def replay( default=60, help="Number of seconds for resetting the environment after each episode.", ) - parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") + parser_record.add_argument( + "--num-episodes", type=int, default=50, help="Number of episodes to record." + ) parser_record.add_argument( "--run-compute-stats", type=int, @@ -534,6 +571,12 @@ def replay( default=0, help="Enables the assignment of rewards to frames (by default, no assignment). When enabled, a 0 reward is assigned to frames until the space bar is pressed, which assigns a 1 reward. Press the space bar a second time to assign a 0 reward. The reward assigned is reset to 0 when the episode ends.", ) + parser_record.add_argument( + "--record-delta-actions", + type=int, + default=0, + help="Enables the recording of delta actions instead of absolute actions.", + ) parser_record.add_argument( "--reset-follower", type=int, @@ -543,7 +586,10 @@ def replay( parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + "--fps", + type=none_or_int, + default=None, + help="Frames per second (set to None to disable)", ) parser_replay.add_argument( "--root", @@ -563,7 +609,15 @@ def replay( default=0, help="Use local files only. By default, this script will try to fetch the dataset from the hub if it exists.", ) - parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episode to replay.") + parser_replay.add_argument( + "--replay-delta-actions", + type=int, + default=0, + help="Enables the replay of delta actions instead of absolute actions.", + ) + parser_replay.add_argument( + "--episode", type=int, default=0, help="Index of the episode to replay." 
+ ) args = parser.parse_args() diff --git a/lerobot/scripts/control_sim_robot.py b/lerobot/scripts/control_sim_robot.py index 67bdfb856..36bd16706 100644 --- a/lerobot/scripts/control_sim_robot.py +++ b/lerobot/scripts/control_sim_robot.py @@ -135,7 +135,11 @@ def init_sim_calibration(robot, cfg): axis_directions = np.array(cfg.get("axis_directions", [1])) offsets = np.array(cfg.get("offsets", [0])) * np.pi - return {"start_pos": start_pos, "axis_directions": axis_directions, "offsets": offsets} + return { + "start_pos": start_pos, + "axis_directions": axis_directions, + "offsets": offsets, + } def real_positions_to_sim(real_positions, axis_directions, start_pos, offsets): @@ -156,7 +160,10 @@ def teleoperate(env, robot: Robot, process_action_fn, teleop_time_s=None): leader_pos = robot.leader_arms.main.read("Present_Position") action = process_action_fn(leader_pos) env.step(np.expand_dims(action, 0)) - if teleop_time_s is not None and time.perf_counter() - start_teleop_t > teleop_time_s: + if ( + teleop_time_s is not None + and time.perf_counter() - start_teleop_t > teleop_time_s + ): print("Teleoperation processes finished.") break @@ -188,19 +195,27 @@ def record( # Load pretrained policy extra_features = ( - {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} if assign_rewards else None + {"next.reward": {"dtype": "int64", "shape": (1,), "names": None}} + if assign_rewards + else None ) policy = None if pretrained_policy_name_or_path is not None: - policy, policy_fps, device, use_amp = init_policy(pretrained_policy_name_or_path, policy_overrides) + policy, policy_fps, device, use_amp = init_policy( + pretrained_policy_name_or_path, policy_overrides + ) if fps is None: fps = policy_fps - logging.warning(f"No fps provided, so using the fps from policy config ({policy_fps}).") + logging.warning( + f"No fps provided, so using the fps from policy config ({policy_fps})." + ) if policy is None and process_action_from_leader is None: - raise ValueError("Either policy or process_action_fn has to be set to enable control in sim.") + raise ValueError( + "Either policy or process_action_fn has to be set to enable control in sim." + ) # initialize listener before sim env listener, events = init_keyboard_listener(assign_rewards=assign_rewards) @@ -233,7 +248,11 @@ def record( shape = env.observation_space[key].shape if not key.startswith("observation.image."): key = "observation.image." + key - features[key] = {"dtype": "video", "names": ["channel", "height", "width"], "shape": shape} + features[key] = { + "dtype": "video", + "names": ["channel", "height", "width"], + "shape": shape, + } for key, obs_key in state_keys_dict.items(): features[key] = { @@ -242,7 +261,11 @@ def record( "shape": env.observation_space[obs_key].shape, } - features["action"] = {"dtype": "float32", "shape": env.action_space.shape, "names": None} + features["action"] = { + "dtype": "float32", + "shape": env.action_space.shape, + "names": None, + } features = {**features, **extra_features} # Create empty dataset or load existing saved episodes @@ -343,7 +366,9 @@ def record( if events["stop_recording"] or recorded_episodes >= num_episodes: break else: - logging.info("Waiting for a few seconds before starting next episode recording...") + logging.info( + "Waiting for a few seconds before starting next episode recording..." 
+ ) busy_wait(3) log_say("Stop recording", play_sounds, blocking=True) @@ -361,7 +386,12 @@ def record( def replay( - env, root: Path, repo_id: str, episode: int, fps: int | None = None, local_files_only: bool = True + env, + root: Path, + repo_id: str, + episode: int, + fps: int | None = None, + local_files_only: bool = True, ): env = env() @@ -408,7 +438,10 @@ def replay( parser_record = subparsers.add_parser("record", parents=[base_parser]) parser_record.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + "--fps", + type=none_or_int, + default=None, + help="Frames per second (set to None to disable)", ) parser_record.add_argument( "--root", @@ -434,7 +467,9 @@ def replay( required=True, help="A description of the task performed during recording that can be used as a language instruction.", ) - parser_record.add_argument("--num-episodes", type=int, default=50, help="Number of episodes to record.") + parser_record.add_argument( + "--num-episodes", type=int, default=50, help="Number of episodes to record." + ) parser_record.add_argument( "--run-compute-stats", type=int, @@ -495,7 +530,10 @@ def replay( parser_replay = subparsers.add_parser("replay", parents=[base_parser]) parser_replay.add_argument( - "--fps", type=none_or_int, default=None, help="Frames per second (set to None to disable)" + "--fps", + type=none_or_int, + default=None, + help="Frames per second (set to None to disable)", ) parser_replay.add_argument( "--root", @@ -509,7 +547,9 @@ def replay( default="lerobot/test", help="Dataset identifier. By convention it should match '{hf_username}/{dataset_name}' (e.g. `lerobot/test`).", ) - parser_replay.add_argument("--episode", type=int, default=0, help="Index of the episodes to replay.") + parser_replay.add_argument( + "--episode", type=int, default=0, help="Index of the episode to replay." 
+ ) args = parser.parse_args() diff --git a/lerobot/scripts/display_sys_info.py b/lerobot/scripts/display_sys_info.py index 4d3cc291f..2d844990f 100644 --- a/lerobot/scripts/display_sys_info.py +++ b/lerobot/scripts/display_sys_info.py @@ -59,7 +59,11 @@ torch_version = torch.__version__ if HAS_TORCH else "N/A" torch_cuda_available = torch.cuda.is_available() if HAS_TORCH else "N/A" -cuda_version = torch._C._cuda_getCompiledVersion() if HAS_TORCH and torch.version.cuda is not None else "N/A" +cuda_version = ( + torch._C._cuda_getCompiledVersion() + if HAS_TORCH and torch.version.cuda is not None + else "N/A" +) # TODO(aliberts): refactor into an actual command `lerobot env` @@ -77,7 +81,9 @@ def display_sys_info() -> dict: "Using GPU in script?": "", # "Using distributed or parallel set-up in script?": "", } - print("\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n") + print( + "\nCopy-and-paste the text below in your GitHub issue and FILL OUT the last point.\n" + ) print(format_dict(info)) return info diff --git a/lerobot/scripts/eval.py b/lerobot/scripts/eval.py index 040f92d96..8b7b9e80e 100644 --- a/lerobot/scripts/eval.py +++ b/lerobot/scripts/eval.py @@ -149,7 +149,9 @@ def rollout( if return_observations: all_observations.append(deepcopy(observation)) - observation = {key: observation[key].to(device, non_blocking=True) for key in observation} + observation = { + key: observation[key].to(device, non_blocking=True) for key in observation + } with torch.inference_mode(): action = policy.select_action(observation) @@ -166,7 +168,10 @@ def rollout( # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't # available of none of the envs finished. if "final_info" in info: - successes = [info["is_success"] if info is not None else False for info in info["final_info"]] + successes = [ + info["is_success"] if info is not None else False + for info in info["final_info"] + ] else: successes = [False] * env.num_envs @@ -180,9 +185,13 @@ def rollout( step += 1 running_success_rate = ( - einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() + einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any") + .numpy() + .mean() + ) + progbar.set_postfix( + {"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"} ) - progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}) progbar.update() # Track the final observation. @@ -200,7 +209,9 @@ def rollout( if return_observations: stacked_observations = {} for key in all_observations[0]: - stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) + stacked_observations[key] = torch.stack( + [obs[key] for obs in all_observations], dim=1 + ) ret["observation"] = stacked_observations return ret @@ -255,7 +266,9 @@ def render_frame(env: gym.vector.VectorEnv): return n_to_render_now = min(max_episodes_rendered - n_episodes_rendered, env.num_envs) if isinstance(env, gym.vector.SyncVectorEnv): - ep_frames.append(np.stack([env.envs[i].render() for i in range(n_to_render_now)])) # noqa: B023 + ep_frames.append( + np.stack([env.envs[i].render() for i in range(n_to_render_now)]) + ) # noqa: B023 elif isinstance(env, gym.vector.AsyncVectorEnv): # Here we must render all frames and discard any we don't need. 
ep_frames.append(np.stack(env.call("render")[:n_to_render_now])) @@ -267,7 +280,9 @@ def render_frame(env: gym.vector.VectorEnv): episode_data: dict | None = None # we don't want a progress bar when we use slurm, since it clutters the logs - progbar = trange(n_batches, desc="Stepping through eval batches", disable=inside_slurm()) + progbar = trange( + n_batches, desc="Stepping through eval batches", disable=inside_slurm() + ) for batch_ix in progbar: # Cache frames for rendering videos. Each item will be (b, h, w, c), and the list indexes the rollout # step. @@ -278,7 +293,8 @@ def render_frame(env: gym.vector.VectorEnv): seeds = None else: seeds = range( - start_seed + (batch_ix * env.num_envs), start_seed + ((batch_ix + 1) * env.num_envs) + start_seed + (batch_ix * env.num_envs), + start_seed + ((batch_ix + 1) * env.num_envs), ) rollout_data = rollout( env, @@ -296,13 +312,22 @@ def render_frame(env: gym.vector.VectorEnv): # Make a mask with shape (batch, n_steps) to mask out rollout data after the first done # (batch-element-wise). Note the `done_indices + 1` to make sure to keep the data from the done step. - mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int() + mask = ( + torch.arange(n_steps) + <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps) + ).int() # Extend metrics. - batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum") + batch_sum_rewards = einops.reduce( + (rollout_data["reward"] * mask), "b n -> b", "sum" + ) sum_rewards.extend(batch_sum_rewards.tolist()) - batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max") + batch_max_rewards = einops.reduce( + (rollout_data["reward"] * mask), "b n -> b", "max" + ) max_rewards.extend(batch_max_rewards.tolist()) - batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any") + batch_successes = einops.reduce( + (rollout_data["success"] * mask), "b n -> b", "any" + ) all_successes.extend(batch_successes.tolist()) if seeds: all_seeds.extend(seeds) @@ -315,17 +340,27 @@ def render_frame(env: gym.vector.VectorEnv): rollout_data, done_indices, start_episode_index=batch_ix * env.num_envs, - start_data_index=(0 if episode_data is None else (episode_data["index"][-1].item() + 1)), + start_data_index=( + 0 + if episode_data is None + else (episode_data["index"][-1].item() + 1) + ), fps=env.unwrapped.metadata["render_fps"], ) if episode_data is None: episode_data = this_episode_data else: # Some sanity checks to make sure we are correctly compiling the data. - assert episode_data["episode_index"][-1] + 1 == this_episode_data["episode_index"][0] + assert ( + episode_data["episode_index"][-1] + 1 + == this_episode_data["episode_index"][0] + ) assert episode_data["index"][-1] + 1 == this_episode_data["index"][0] # Concatenate the episode data. - episode_data = {k: torch.cat([episode_data[k], this_episode_data[k]]) for k in episode_data} + episode_data = { + k: torch.cat([episode_data[k], this_episode_data[k]]) + for k in episode_data + } # Maybe render video for visualization. 
if max_episodes_rendered > 0 and len(ep_frames) > 0: @@ -343,7 +378,9 @@ def render_frame(env: gym.vector.VectorEnv): target=write_video, args=( str(video_path), - stacked_frames[: done_index + 1], # + 1 to capture the last observation + stacked_frames[ + : done_index + 1 + ], # + 1 to capture the last observation env.unwrapped.metadata["render_fps"], ), ) @@ -352,7 +389,9 @@ def render_frame(env: gym.vector.VectorEnv): n_episodes_rendered += 1 progbar.set_postfix( - {"running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%"} + { + "running_success_rate": f"{np.mean(all_successes[:n_episodes]).item() * 100:.1f}%" + } ) # Wait till all video rendering threads are done. @@ -398,7 +437,11 @@ def render_frame(env: gym.vector.VectorEnv): def _compile_episode_data( - rollout_data: dict, done_indices: Tensor, start_episode_index: int, start_data_index: int, fps: float + rollout_data: dict, + done_indices: Tensor, + start_episode_index: int, + start_data_index: int, + fps: float, ) -> dict: """Convenience function for `eval_policy(return_episode_data=True)` @@ -416,12 +459,16 @@ def _compile_episode_data( # Here we do `num_frames - 1` as we don't want to include the last observation frame just yet. ep_dict = { "action": rollout_data["action"][ep_ix, : num_frames - 1], - "episode_index": torch.tensor([start_episode_index + ep_ix] * (num_frames - 1)), + "episode_index": torch.tensor( + [start_episode_index + ep_ix] * (num_frames - 1) + ), "frame_index": torch.arange(0, num_frames - 1, 1), "timestamp": torch.arange(0, num_frames - 1, 1) / fps, "next.done": rollout_data["done"][ep_ix, : num_frames - 1], "next.success": rollout_data["success"][ep_ix, : num_frames - 1], - "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type(torch.float32), + "next.reward": rollout_data["reward"][ep_ix, : num_frames - 1].type( + torch.float32 + ), } # For the last observation frame, all other keys will just be copy padded. @@ -437,7 +484,9 @@ def _compile_episode_data( for key in ep_dicts[0]: data_dict[key] = torch.cat([x[key] for x in ep_dicts]) - data_dict["index"] = torch.arange(start_data_index, start_data_index + total_frames, 1) + data_dict["index"] = torch.arange( + start_data_index, start_data_index + total_frames, 1 + ) return data_dict @@ -450,7 +499,9 @@ def main( ): assert (pretrained_policy_path is None) ^ (hydra_cfg_path is None) if pretrained_policy_path is not None: - hydra_cfg = init_hydra_config(str(pretrained_policy_path / "config.yaml"), config_overrides) + hydra_cfg = init_hydra_config( + str(pretrained_policy_path / "config.yaml"), config_overrides + ) else: hydra_cfg = init_hydra_config(hydra_cfg_path, config_overrides) @@ -481,15 +532,23 @@ def main( logging.info("Making policy.") if hydra_cfg_path is None: - policy = make_policy(hydra_cfg=hydra_cfg, pretrained_policy_name_or_path=str(pretrained_policy_path)) + policy = make_policy( + hydra_cfg=hydra_cfg, + pretrained_policy_name_or_path=str(pretrained_policy_path), + ) else: # Note: We need the dataset stats to pass to the policy's normalization modules. 
- policy = make_policy(hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats) + policy = make_policy( + hydra_cfg=hydra_cfg, dataset_stats=make_dataset(hydra_cfg).meta.stats + ) assert isinstance(policy, nn.Module) policy.eval() - with torch.no_grad(), torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(): + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) if hydra_cfg.use_amp else nullcontext(), + ): info = eval_policy( env, policy, @@ -511,16 +570,14 @@ def main( def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None): try: - pretrained_policy_path = Path(snapshot_download(pretrained_policy_name_or_path, revision=revision)) + pretrained_policy_path = Path( + snapshot_download(pretrained_policy_name_or_path, revision=revision) + ) except (HFValidationError, RepositoryNotFoundError) as e: if isinstance(e, HFValidationError): - error_message = ( - "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID." - ) + error_message = "The provided pretrained_policy_name_or_path is not a valid Hugging Face Hub repo ID." else: - error_message = ( - "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub." - ) + error_message = "The provided pretrained_policy_name_or_path was not found on the Hugging Face Hub." logging.warning(f"{error_message} Treating it as a local directory.") pretrained_policy_path = Path(pretrained_policy_name_or_path) @@ -555,7 +612,9 @@ def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None): "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." ), ) - parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") + parser.add_argument( + "--revision", help="Optionally provide the Hugging Face Hub revision ID." 
+ ) parser.add_argument( "--out-dir", help=( @@ -571,7 +630,11 @@ def get_pretrained_policy_path(pretrained_policy_name_or_path, revision=None): args = parser.parse_args() if args.pretrained_policy_name_or_path is None: - main(hydra_cfg_path=args.config, out_dir=args.out_dir, config_overrides=args.overrides) + main( + hydra_cfg_path=args.config, + out_dir=args.out_dir, + config_overrides=args.overrides, + ) else: pretrained_policy_path = get_pretrained_policy_path( args.pretrained_policy_name_or_path, revision=args.revision diff --git a/lerobot/scripts/eval_on_robot.py b/lerobot/scripts/eval_on_robot.py index 842c1a281..8a7062e79 100644 --- a/lerobot/scripts/eval_on_robot.py +++ b/lerobot/scripts/eval_on_robot.py @@ -46,7 +46,11 @@ from tqdm import trange from lerobot.common.policies.policy_protocol import Policy -from lerobot.common.robot_devices.control_utils import busy_wait, is_headless, reset_follower_position +from lerobot.common.robot_devices.control_utils import ( + busy_wait, + is_headless, + reset_follower_position, +) from lerobot.common.robot_devices.robots.factory import Robot, make_robot from lerobot.common.utils.utils import ( init_hydra_config, @@ -60,13 +64,19 @@ def get_classifier(pretrained_path, config_path): return from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg - from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig - from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier + from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( + ClassifierConfig, + ) + from lerobot.common.policies.hilserl.classifier.modeling_classifier import ( + Classifier, + ) cfg = init_hydra_config(config_path) classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) - classifier_config.num_cameras = len(cfg.training.image_keys) # TODO automate these paths + classifier_config.num_cameras = len( + cfg.training.image_keys + ) # TODO automate these paths model = Classifier(classifier_config) model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict()) model = model.to("mps") @@ -151,11 +161,17 @@ def rollout( images = [] for key in image_keys: if display_cameras: - cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)) + cv2.imshow( + key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR) + ) cv2.waitKey(1) images.append(observation[key].to("mps")) - reward = reward_classifier.predict_reward(images) if reward_classifier is not None else 0.0 + reward = ( + reward_classifier.predict_reward(images) + if reward_classifier is not None + else 0.0 + ) all_rewards.append(reward) # print("REWARD : ", reward) @@ -219,11 +235,19 @@ def eval_policy( start_eval = time.perf_counter() progbar = trange(n_episodes, desc="Evaluating policy on real robot") - reward_classifier = get_classifier(reward_classifier_pretrained_path, reward_classifier_config_file) + reward_classifier = get_classifier( + reward_classifier_pretrained_path, reward_classifier_config_file + ) for _ in progbar: rollout_data = rollout( - robot, policy, reward_classifier, fps, control_time_s, use_amp, display_cameras + robot, + policy, + reward_classifier, + fps, + control_time_s, + use_amp, + display_cameras, ) rollouts.append(rollout_data) @@ -289,7 +313,9 @@ def on_press(key): print("Right arrow key pressed. Exiting loop...") events["exit_early"] = True elif key == keyboard.Key.left: - print("Left arrow key pressed. 
Exiting loop and rerecord the last episode...") + print( + "Left arrow key pressed. Exiting loop and rerecord the last episode..." + ) events["rerecord_episode"] = True events["exit_early"] = True elif key == keyboard.Key.space: @@ -301,7 +327,10 @@ def on_press(key): "Place the leader in similar pose to the follower and press space again." ) events["pause_policy"] = True - log_say("Human intervention stage. Get ready to take over.", play_sounds=True) + log_say( + "Human intervention stage. Get ready to take over.", + play_sounds=True, + ) else: events["human_intervention_step"] = True print("Space key pressed. Human intervention starting.") @@ -351,7 +380,9 @@ def on_press(key): "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." ), ) - parser.add_argument("--revision", help="Optionally provide the Hugging Face Hub revision ID.") + parser.add_argument( + "--revision", help="Optionally provide the Hugging Face Hub revision ID." + ) parser.add_argument( "--out-dir", help=( @@ -360,7 +391,8 @@ def on_press(key): ), ) parser.add_argument( - "--display-cameras", help=("Whether to display the camera feed while the rollout is happening") + "--display-cameras", + help=("Whether to display the camera feed while the rollout is happening"), ) parser.add_argument( "--reward-classifier-pretrained-path", diff --git a/lerobot/scripts/find_motors_bus_port.py b/lerobot/scripts/find_motors_bus_port.py index 67b92ad7d..b4dcbe4ed 100644 --- a/lerobot/scripts/find_motors_bus_port.py +++ b/lerobot/scripts/find_motors_bus_port.py @@ -32,9 +32,13 @@ def find_port(): print(f"The port of this MotorsBus is '{port}'") print("Reconnect the USB cable.") elif len(ports_diff) == 0: - raise OSError(f"Could not detect the port. No difference was found ({ports_diff}).") + raise OSError( + f"Could not detect the port. No difference was found ({ports_diff})." + ) else: - raise OSError(f"Could not detect the port. More than one port was found ({ports_diff}).") + raise OSError( + f"Could not detect the port. More than one port was found ({ports_diff})." 
+ ) if __name__ == "__main__": diff --git a/lerobot/scripts/push_dataset_to_hub.py b/lerobot/scripts/push_dataset_to_hub.py index 0233ede69..85e1be403 100644 --- a/lerobot/scripts/push_dataset_to_hub.py +++ b/lerobot/scripts/push_dataset_to_hub.py @@ -56,24 +56,42 @@ from lerobot.common.datasets.compute_stats import compute_stats from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset from lerobot.common.datasets.push_dataset_to_hub.utils import check_repo_id -from lerobot.common.datasets.utils import create_branch, create_lerobot_dataset_card, flatten_dict +from lerobot.common.datasets.utils import ( + create_branch, + create_lerobot_dataset_card, + flatten_dict, +) def get_from_raw_to_lerobot_format_fn(raw_format: str): if raw_format == "pusht_zarr": - from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format + from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import ( + from_raw_to_lerobot_format, + ) elif raw_format == "umi_zarr": - from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format + from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import ( + from_raw_to_lerobot_format, + ) elif raw_format == "aloha_hdf5": - from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format + from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import ( + from_raw_to_lerobot_format, + ) elif raw_format in ["rlds", "openx"]: - from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import from_raw_to_lerobot_format + from lerobot.common.datasets.push_dataset_to_hub.openx_rlds_format import ( + from_raw_to_lerobot_format, + ) elif raw_format == "dora_parquet": - from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format + from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import ( + from_raw_to_lerobot_format, + ) elif raw_format == "xarm_pkl": - from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format + from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import ( + from_raw_to_lerobot_format, + ) elif raw_format == "cam_png": - from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format + from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import ( + from_raw_to_lerobot_format, + ) else: raise ValueError( f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?" 
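Each branch of the dispatcher above imports its converter lazily, so optional per-format dependencies (zarr, hdf5, rlds, ...) are only pulled in when that raw format is actually requested. Usage is a lookup followed by a call, along these lines (sketch):

from lerobot.scripts.push_dataset_to_hub import get_from_raw_to_lerobot_format_fn

from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn("aloha_hdf5")  # triggers the hdf5 import only here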
@@ -83,7 +101,10 @@ def get_from_raw_to_lerobot_format_fn(raw_format: str):
 
 
 def save_meta_data(
-    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
+    info: dict[str, Any],
+    stats: dict,
+    episode_data_index: dict[str, list],
+    meta_data_dir: Path,
 ):
     meta_data_dir.mkdir(parents=True, exist_ok=True)
 
@@ -97,12 +118,16 @@ def save_meta_data(
     save_file(flatten_dict(stats), stats_path)
 
     # save episode_data_index
-    episode_data_index = {key: torch.tensor(episode_data_index[key]) for key in episode_data_index}
+    episode_data_index = {
+        key: torch.tensor(episode_data_index[key]) for key in episode_data_index
+    }
     ep_data_idx_path = meta_data_dir / "episode_data_index.safetensors"
     save_file(episode_data_index, ep_data_idx_path)
 
 
-def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
+def push_meta_data_to_hub(
+    repo_id: str, meta_data_dir: str | Path, revision: str | None
+):
     """Expect all metadata files to be stored in a single "meta_data" directory.
     On the Hugging Face repository, they will be uploaded in a "meta_data" directory at the root.
     """
@@ -187,7 +212,9 @@ def push_dataset_to_hub(
         if force_override:
             shutil.rmtree(local_dir)
         elif not resume:
-            raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")
+            raise ValueError(
+                f"`local_dir` already exists ({local_dir}). Use `--force-override 1`."
+            )
 
     meta_data_dir = local_dir / "meta_data"
     videos_dir = local_dir / "videos"
@@ -223,7 +250,9 @@ def push_dataset_to_hub(
         stats = compute_stats(lerobot_dataset, batch_size, num_workers)
 
     if local_dir:
-        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
+        hf_dataset = hf_dataset.with_format(
+            None
+        )  # to remove transforms that can't be saved
         hf_dataset.save_to_disk(str(local_dir / "train"))
 
     if push_to_hub or local_dir:
diff --git a/lerobot/scripts/server/actor_server.py b/lerobot/scripts/server/actor_server.py
new file mode 100644
index 000000000..24d8356d6
--- /dev/null
+++ b/lerobot/scripts/server/actor_server.py
@@ -0,0 +1,631 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
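+
+# Overview: this actor process runs the policy on the robot and talks to a
+# separate learner process over gRPC. Transitions and interaction statistics
+# are streamed to the learner through queues (`send_transitions`,
+# `send_interactions`), while updated policy parameters are streamed back
+# (`receive_policy`) and loaded into the policy between episodes.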
+import logging +from statistics import mean, quantiles +from functools import lru_cache +from lerobot.scripts.server.utils import setup_process_handlers + +# from lerobot.scripts.eval import eval_policy + +import grpc +import hydra +import torch +from omegaconf import DictConfig +from torch import nn +import time + +# TODO: Remove the import of maniskill +# from lerobot.common.envs.factory import make_maniskill_env +# from lerobot.common.envs.utils import preprocess_maniskill_observation +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.robot_devices.robots.factory import make_robot +from lerobot.common.robot_devices.robots.utils import Robot +from lerobot.common.utils.utils import ( + TimerManager, + get_safe_torch_device, + set_global_seed, +) +from lerobot.scripts.server import hilserl_pb2, hilserl_pb2_grpc +from lerobot.scripts.server.buffer import ( + Transition, + move_state_dict_to_device, + move_transition_to_device, + python_object_to_bytes, + transitions_to_bytes, + bytes_to_state_dict, +) +from lerobot.scripts.server.network_utils import ( + receive_bytes_in_chunks, + send_bytes_in_chunks, +) +from lerobot.scripts.server.gym_manipulator import get_classifier, make_robot_env +from lerobot.scripts.server import learner_service + +from torch.multiprocessing import Queue, Event +from queue import Empty + +from lerobot.common.utils.utils import init_logging + +from lerobot.scripts.server.utils import get_last_item_from_queue + +ACTOR_SHUTDOWN_TIMEOUT = 30 + + +def receive_policy( + cfg: DictConfig, + parameters_queue: Queue, + shutdown_event: any, # Event, + learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +): + logging.info("[ACTOR] Start receiving parameters from the Learner") + + if not use_threads(cfg): + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + setup_process_handlers(False) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.actor_learner_config.learner_host, + port=cfg.actor_learner_config.learner_port, + ) + + try: + iterator = learner_client.StreamParameters(hilserl_pb2.Empty()) + receive_bytes_in_chunks( + iterator, + parameters_queue, + shutdown_event, + log_prefix="[ACTOR] parameters", + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Received policy loop stopped") + + +def transitions_stream( + shutdown_event: Event, transitions_queue: Queue +) -> hilserl_pb2.Empty: + while not shutdown_event.is_set(): + try: + message = transitions_queue.get(block=True, timeout=5) + except Empty: + logging.debug("[ACTOR] Transition queue is empty") + continue + + yield from send_bytes_in_chunks( + message, hilserl_pb2.Transition, log_prefix="[ACTOR] Send transitions" + ) + + return hilserl_pb2.Empty() + + +def interactions_stream( + shutdown_event: any, # Event, + interactions_queue: Queue, +) -> hilserl_pb2.Empty: + while not shutdown_event.is_set(): + try: + message = interactions_queue.get(block=True, timeout=5) + except Empty: + logging.debug("[ACTOR] Interaction queue is empty") + continue + + yield from send_bytes_in_chunks( + message, + hilserl_pb2.InteractionMessage, + log_prefix="[ACTOR] Send interactions", + ) + + return hilserl_pb2.Empty() + + +def send_transitions( + cfg: DictConfig, + 
transitions_queue: Queue, + shutdown_event: any, # Event, + learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> hilserl_pb2.Empty: + """ + Sends transitions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - **Transition Data:** + - A batch of transitions (observation, action, reward, next observation) is collected. + - Transitions are moved to the CPU and serialized using PyTorch. + - The serialized data is wrapped in a `hilserl_pb2.Transition` message and sent to the learner. + """ + + if not use_threads(cfg): + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + setup_process_handlers(False) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.actor_learner_config.learner_host, + port=cfg.actor_learner_config.learner_port, + ) + + try: + learner_client.SendTransitions( + transitions_stream(shutdown_event, transitions_queue) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming transitions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Transitions process stopped") + + +def send_interactions( + cfg: DictConfig, + interactions_queue: Queue, + shutdown_event: any, # Event, + learner_client: hilserl_pb2_grpc.LearnerServiceStub | None = None, + grpc_channel: grpc.Channel | None = None, +) -> hilserl_pb2.Empty: + """ + Sends interactions to the learner. + + This function continuously retrieves messages from the queue and processes: + + - **Interaction Messages:** + - Contains useful statistics about episodic rewards and policy timings. + - The message is serialized using `pickle` and sent to the learner. + """ + + if not use_threads(cfg): + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + setup_process_handlers(False) + + if grpc_channel is None or learner_client is None: + learner_client, grpc_channel = learner_service_client( + host=cfg.actor_learner_config.learner_host, + port=cfg.actor_learner_config.learner_port, + ) + + try: + learner_client.SendInteractions( + interactions_stream(shutdown_event, interactions_queue) + ) + except grpc.RpcError as e: + logging.error(f"[ACTOR] gRPC error: {e}") + + logging.info("[ACTOR] Finished streaming interactions") + + if not use_threads(cfg): + grpc_channel.close() + logging.info("[ACTOR] Interactions process stopped") + + +@lru_cache(maxsize=1) +def learner_service_client( + host="127.0.0.1", port=50051 +) -> tuple[hilserl_pb2_grpc.LearnerServiceStub, grpc.Channel]: + import json + + """ + Returns a client for the learner service. + + GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection. + So we need to create only one client and reuse it. 
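+
+    Illustrative usage (the host/port shown are the defaults used elsewhere in
+    this file, not requirements):
+        stub, channel = learner_service_client(
+            host=cfg.actor_learner_config.learner_host,
+            port=cfg.actor_learner_config.learner_port,
+        )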
+ """ + + service_config = { + "methodConfig": [ + { + "name": [{}], # Applies to ALL methods in ALL services + "retryPolicy": { + "maxAttempts": 5, # Max retries (total attempts = 5) + "initialBackoff": "0.1s", # First retry after 0.1s + "maxBackoff": "2s", # Max wait time between retries + "backoffMultiplier": 2, # Exponential backoff factor + "retryableStatusCodes": [ + "UNAVAILABLE", + "DEADLINE_EXCEEDED", + ], # Retries on network failures + }, + } + ] + } + + service_config_json = json.dumps(service_config) + + channel = grpc.insecure_channel( + f"{host}:{port}", + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.enable_retries", 1), + ("grpc.service_config", service_config_json), + ], + ) + stub = hilserl_pb2_grpc.LearnerServiceStub(channel) + logging.info("[ACTOR] Learner service client created") + return stub, channel + + +def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device): + if not parameters_queue.empty(): + logging.info("[ACTOR] Load new parameters from Learner.") + bytes_state_dict = get_last_item_from_queue(parameters_queue) + state_dict = bytes_to_state_dict(bytes_state_dict) + state_dict = move_state_dict_to_device(state_dict, device=device) + policy.load_state_dict(state_dict) + + +def act_with_policy( + cfg: DictConfig, + robot: Robot, + reward_classifier: nn.Module, + shutdown_event: any, # Event, + parameters_queue: Queue, + transitions_queue: Queue, + interactions_queue: Queue, +): + """ + Executes policy interaction within the environment. + + This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner. + Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network. + + Args: + cfg (DictConfig): Configuration settings for the interaction process. + """ + + logging.info("make_env online") + + online_env = make_robot_env( + robot=robot, reward_classifier=reward_classifier, cfg=cfg + ) + + set_global_seed(cfg.seed) + device = get_safe_torch_device(cfg.device, log=True) + + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info("make_policy") + + # HACK: This is an ugly hack to pass the normalization parameters to the policy + # Because the action space is dynamic so we override the output normalization parameters + # it's ugly, we know ... 
and we will fix it
+    min_action_space: list = online_env.action_space.spaces[0].low.tolist()
+    max_action_space: list = online_env.action_space.spaces[0].high.tolist()
+    output_normalization_params: dict[str, dict[str, list]] = {
+        "action": {"min": min_action_space, "max": max_action_space}
+    }
+    cfg.policy.output_normalization_params = output_normalization_params
+    cfg.policy.output_shapes["action"] = online_env.action_space.spaces[0].shape
+
+    ### Instantiate the policy in both the actor and learner processes
+    ### To avoid sending a SACPolicy object through the port, we create a policy instance
+    ### on both sides; the learner sends the updated parameters every n steps to update the actor's parameters
+    # TODO: At some point we should just need to make the SAC policy directly
+    policy: SACPolicy = make_policy(
+        hydra_cfg=cfg,
+        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+        # Hack: But if we do online training, we do not need dataset_stats
+        dataset_stats=None,
+        # TODO: Handle resume training
+        device=device,
+    )
+    policy = torch.compile(policy)
+    assert isinstance(policy, nn.Module)
+
+    obs, info = online_env.reset()
+
+    # NOTE: For the moment we will solely handle the case of a single environment
+    sum_reward_episode = 0.0
+    list_transition_to_send_to_learner = []
+    list_policy_time = []
+    episode_intervention = False
+
+    for interaction_step in range(cfg.training.online_steps):
+        if shutdown_event.is_set():
+            logging.info("[ACTOR] Shutting down act_with_policy")
+            return
+
+        if interaction_step >= cfg.training.online_step_before_learning:
+            # Time policy inference and check if it meets FPS requirement
+            with TimerManager(
+                elapsed_time_list=list_policy_time,
+                label="Policy inference time",
+                log=False,
+            ) as timer:  # noqa: F841
+                action = policy.select_action(batch=obs)
+            policy_fps = 1.0 / (list_policy_time[-1] + 1e-9)
+
+            log_policy_frequency_issue(
+                policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step
+            )
+
+            next_obs, reward, done, truncated, info = online_env.step(
+                action.squeeze(dim=0).cpu().numpy()
+            )
+        else:
+            # TODO (azouitine): Make a custom space for torch tensor
+            action = online_env.action_space.sample()
+            next_obs, reward, done, truncated, info = online_env.step(action)
+
+            # HACK: We have only one env but we want to batch it, it will be resolved with the torch box
+            action = (
+                torch.from_numpy(action[0])
+                .to(device, non_blocking=device.type == "cuda")
+                .unsqueeze(dim=0)
+            )
+
+        sum_reward_episode += float(reward)
+
+        # NOTE: We override the action if the intervention is True, because the action applied is the intervention action
+        if "is_intervention" in info and info["is_intervention"]:
+            # TODO: Check the shape
+            # NOTE: The action space for demonstrations beforehand is the full action space,
+            # but sometimes, for example, we want to deactivate the gripper
+            action = info["action_intervention"]
+            episode_intervention = True
+
+        # Check for NaN values in observations
+        for key, tensor in obs.items():
+            if torch.isnan(tensor).any():
+                logging.error(
+                    f"[ACTOR] NaN values found in obs[{key}] at step {interaction_step}"
+                )
+
+        list_transition_to_send_to_learner.append(
+            Transition(
+                state=obs,
+                action=action,
+                reward=reward,
+                next_state=next_obs,
+                done=done,
+                truncated=truncated,  # TODO: (azouitine) Handle truncation properly
+                complementary_info=info,  # TODO: Handle information for the transition, is_demonstration: bool
+            )
+        )
+
+        # Assign obs to the next obs and continue the rollout
+        obs = next_obs
+
+        # HACK: We have only
one env but we want to batch it, it will be resolved with the torch box + # Because we are using a single environment we can index at zero + if done or truncated: + # TODO: Handle logging for episode information + logging.info( + f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}" + ) + + update_policy_parameters( + policy=policy.actor, parameters_queue=parameters_queue, device=device + ) + + if len(list_transition_to_send_to_learner) > 0: + push_transitions_to_transport_queue( + transitions=list_transition_to_send_to_learner, + transitions_queue=transitions_queue, + ) + list_transition_to_send_to_learner = [] + + stats = get_frequency_stats(list_policy_time) + list_policy_time.clear() + + # Send episodic reward to the learner + interactions_queue.put( + python_object_to_bytes( + { + "Episodic reward": sum_reward_episode, + "Interaction step": interaction_step, + "Episode intervention": int(episode_intervention), + **stats, + } + ) + ) + sum_reward_episode = 0.0 + episode_intervention = False + obs, info = online_env.reset() + + +def push_transitions_to_transport_queue(transitions: list, transitions_queue): + """Send transitions to learner in smaller chunks to avoid network issues. + + Args: + transitions: List of transitions to send + message_queue: Queue to send messages to learner + chunk_size: Size of each chunk to send + """ + transition_to_send_to_learner = [] + for transition in transitions: + tr = move_transition_to_device(transition=transition, device="cpu") + for key, value in tr["state"].items(): + if torch.isnan(value).any(): + logging.warning(f"Found NaN values in transition {key}") + + transition_to_send_to_learner.append(tr) + + transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner)) + + +def get_frequency_stats(list_policy_time: list[float]) -> dict[str, float]: + stats = {} + list_policy_fps = [1.0 / t for t in list_policy_time] + if len(list_policy_fps) > 1: + policy_fps = mean(list_policy_fps) + quantiles_90 = quantiles(list_policy_fps, n=10)[-1] + logging.debug(f"[ACTOR] Average policy frame rate: {policy_fps}") + logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {quantiles_90}") + stats = { + "Policy frequency [Hz]": policy_fps, + "Policy frequency 90th-p [Hz]": quantiles_90, + } + return stats + + +def log_policy_frequency_issue( + policy_fps: float, cfg: DictConfig, interaction_step: int +): + if policy_fps < cfg.fps: + logging.warning( + f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.fps} at step {interaction_step}" + ) + + +def establish_learner_connection( + stub, + shutdown_event: any, # Event, + attempts=30, +): + for _ in range(attempts): + if shutdown_event.is_set(): + logging.info("[ACTOR] Shutting down establish_learner_connection") + return False + + # Force a connection attempt and check state + try: + logging.info("[ACTOR] Send ready message to Learner") + if stub.Ready(hilserl_pb2.Empty()) == hilserl_pb2.Empty(): + return True + except grpc.RpcError as e: + logging.error(f"[ACTOR] Waiting for Learner to be ready... 
{e}") + time.sleep(2) + return False + + +def use_threads(cfg: DictConfig) -> bool: + return cfg.actor_learner_config.concurrency.actor == "threads" + + +@hydra.main(version_base="1.2", config_name="default", config_path="../../configs") +def actor_cli(cfg: dict): + if not use_threads(cfg): + import torch.multiprocessing as mp + + mp.set_start_method("spawn") + + init_logging(log_file="actor.log") + robot = make_robot(cfg=cfg.robot) + + shutdown_event = setup_process_handlers(use_threads(cfg)) + + learner_client, grpc_channel = learner_service_client( + host=cfg.actor_learner_config.learner_host, + port=cfg.actor_learner_config.learner_port, + ) + + logging.info("[ACTOR] Establishing connection with Learner") + if not establish_learner_connection(learner_client, shutdown_event): + logging.error("[ACTOR] Failed to establish connection with Learner") + return + + if not use_threads(cfg): + # If we use multithreading, we can reuse the channel + grpc_channel.close() + grpc_channel = None + + logging.info("[ACTOR] Connection with Learner established") + + parameters_queue = Queue() + transitions_queue = Queue() + interactions_queue = Queue() + + concurrency_entity = None + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from multiprocessing import Process + + concurrency_entity = Process + + receive_policy_process = concurrency_entity( + target=receive_policy, + args=(cfg, parameters_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process = concurrency_entity( + target=send_transitions, + args=(cfg, transitions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + interactions_process = concurrency_entity( + target=send_interactions, + args=(cfg, interactions_queue, shutdown_event, grpc_channel), + daemon=True, + ) + + transitions_process.start() + interactions_process.start() + receive_policy_process.start() + + # HACK: FOR MANISKILL we do not have a reward classifier + # TODO: Remove this once we merge into main + reward_classifier = None + if ( + cfg.env.reward_classifier.pretrained_path is not None + and cfg.env.reward_classifier.config_path is not None + ): + reward_classifier = get_classifier( + pretrained_path=cfg.env.reward_classifier.pretrained_path, + config_path=cfg.env.reward_classifier.config_path, + ) + + act_with_policy( + cfg, + robot, + reward_classifier, + shutdown_event, + parameters_queue, + transitions_queue, + interactions_queue, + ) + logging.info("[ACTOR] Policy process joined") + + logging.info("[ACTOR] Closing queues") + transitions_queue.close() + interactions_queue.close() + parameters_queue.close() + + transitions_process.join() + logging.info("[ACTOR] Transitions process joined") + interactions_process.join() + logging.info("[ACTOR] Interactions process joined") + receive_policy_process.join() + logging.info("[ACTOR] Receive policy process joined") + + logging.info("[ACTOR] join queues") + transitions_queue.cancel_join_thread() + interactions_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[ACTOR] queues closed") + + +if __name__ == "__main__": + actor_cli() diff --git a/lerobot/scripts/server/buffer.py b/lerobot/scripts/server/buffer.py new file mode 100644 index 000000000..80834eac0 --- /dev/null +++ b/lerobot/scripts/server/buffer.py @@ -0,0 +1,1253 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+from typing import Any, Callable, Optional, Sequence, TypedDict
+
+import io
+import torch
+import torch.nn.functional as F  # noqa: N812
+from tqdm import tqdm
+
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+import os
+import pickle
+
+
+class Transition(TypedDict):
+    state: dict[str, torch.Tensor]
+    action: torch.Tensor
+    reward: float
+    next_state: dict[str, torch.Tensor]
+    done: bool
+    truncated: bool
+    # NOTE: TypedDict fields cannot take default values; this is an optional field.
+    complementary_info: dict[str, Any] | None
+
+
+class BatchTransition(TypedDict):
+    state: dict[str, torch.Tensor]
+    action: torch.Tensor
+    reward: torch.Tensor
+    next_state: dict[str, torch.Tensor]
+    done: torch.Tensor
+    truncated: torch.Tensor
+
+
+def move_transition_to_device(
+    transition: Transition, device: str = "cpu"
+) -> Transition:
+    # Move state tensors to the target device
+    device = torch.device(device)
+    transition["state"] = {
+        key: val.to(device, non_blocking=device.type == "cuda")
+        for key, val in transition["state"].items()
+    }
+
+    # Move the action tensor to the target device
+    transition["action"] = transition["action"].to(
+        device, non_blocking=device.type == "cuda"
+    )
+
+    # Reward, done and truncated are plain floats/bools unless they were stored
+    # as tensors; move them only in the tensor case.
+    if isinstance(transition["reward"], torch.Tensor):
+        transition["reward"] = transition["reward"].to(
+            device=device, non_blocking=device.type == "cuda"
+        )
+
+    if isinstance(transition["done"], torch.Tensor):
+        transition["done"] = transition["done"].to(
+            device, non_blocking=device.type == "cuda"
+        )
+
+    if isinstance(transition["truncated"], torch.Tensor):
+        transition["truncated"] = transition["truncated"].to(
+            device, non_blocking=device.type == "cuda"
+        )
+
+    # Move next_state tensors to the target device
+    transition["next_state"] = {
+        key: val.to(device, non_blocking=device.type == "cuda")
+        for key, val in transition["next_state"].items()
+    }
+
+    # If complementary_info is present, move its tensors to the target device
+    # if transition["complementary_info"] is not None:
+    #     transition["complementary_info"] = {
+    #         key: val.to(device, non_blocking=True) for key, val in transition["complementary_info"].items()
+    #     }
+    return transition
+
+
+def move_state_dict_to_device(state_dict, device="cpu"):
+    """
+    Recursively move all tensors in a (potentially) nested
+    dict/list/tuple structure to the given device.
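+
+    Illustrative example (assumes `policy` is any nn.Module):
+        cpu_state = move_state_dict_to_device(policy.state_dict(), device="cpu")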
+ """ + if isinstance(state_dict, torch.Tensor): + return state_dict.to(device) + elif isinstance(state_dict, dict): + return { + k: move_state_dict_to_device(v, device=device) + for k, v in state_dict.items() + } + elif isinstance(state_dict, list): + return [move_state_dict_to_device(v, device=device) for v in state_dict] + elif isinstance(state_dict, tuple): + return tuple(move_state_dict_to_device(v, device=device) for v in state_dict) + else: + return state_dict + + +def state_to_bytes(state_dict: dict[str, torch.Tensor]) -> bytes: + """Convert model state dict to flat array for transmission""" + buffer = io.BytesIO() + + torch.save(state_dict, buffer) + + return buffer.getvalue() + + +def bytes_to_state_dict(buffer: bytes) -> dict[str, torch.Tensor]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return torch.load(buffer) + + +def python_object_to_bytes(python_object: Any) -> bytes: + return pickle.dumps(python_object) + + +def bytes_to_python_object(buffer: bytes) -> Any: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return pickle.load(buffer) + + +def bytes_to_transitions(buffer: bytes) -> list[Transition]: + buffer = io.BytesIO(buffer) + buffer.seek(0) + return torch.load(buffer) + + +def transitions_to_bytes(transitions: list[Transition]) -> bytes: + buffer = io.BytesIO() + torch.save(transitions, buffer) + return buffer.getvalue() + + +def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: + """ + Perform a per-image random crop over a batch of images in a vectorized way. + (Same as shown previously.) + """ + B, C, H, W = images.shape # noqa: N806 + crop_h, crop_w = output_size + + if crop_h > H or crop_w > W: + raise ValueError( + f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})." + ) + + tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device) + lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device) + + rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1) + cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1) + + rows = rows.unsqueeze(2).expand(-1, -1, crop_w) # (B, crop_h, crop_w) + cols = cols.unsqueeze(1).expand(-1, crop_h, -1) # (B, crop_h, crop_w) + + images_hwcn = images.permute(0, 2, 3, 1) # (B, H, W, C) + + # Gather pixels + cropped_hwcn = images_hwcn[ + torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, : + ] + # cropped_hwcn => (B, crop_h, crop_w, C) + + cropped = cropped_hwcn.permute(0, 3, 1, 2) # (B, C, crop_h, crop_w) + return cropped + + +def random_shift(images: torch.Tensor, pad: int = 4): + """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels""" + _, _, h, w = images.shape + images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate") + return random_crop_vectorized(images=images, output_size=(h, w)) + + +class ReplayBuffer: + def __init__( + self, + capacity: int, + device: str = "cuda:0", + state_keys: Optional[Sequence[str]] = None, + image_augmentation_function: Optional[Callable] = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ): + """ + Args: + capacity (int): Maximum number of transitions to store in the buffer. + device (str): The device where the tensors will be moved when sampling ("cuda:0" or "cpu"). + state_keys (List[str]): The list of keys that appear in `state` and `next_state`. + image_augmentation_function (Optional[Callable]): A function that takes a batch of images + and returns a batch of augmented images. 
If None, a default augmentation function is used. + use_drq (bool): Whether to use the default DRQ image augmentation style, when sampling in the buffer. + storage_device: The device (e.g. "cpu" or "cuda:0") where the data will be stored. + Using "cpu" can help save GPU memory. + optimize_memory (bool): If True, optimizes memory by not storing duplicate next_states when + they can be derived from states. This is useful for large datasets where next_state[i] = state[i+1]. + """ + self.capacity = capacity + self.device = device + self.storage_device = storage_device + self.position = 0 + self.size = 0 + self.initialized = False + self.optimize_memory = optimize_memory + + # Track episode boundaries for memory optimization + self.episode_ends = torch.zeros( + capacity, dtype=torch.bool, device=storage_device + ) + + # If no state_keys provided, default to an empty list + self.state_keys = state_keys if state_keys is not None else [] + + if image_augmentation_function is None: + base_function = functools.partial(random_shift, pad=4) + self.image_augmentation_function = torch.compile(base_function) + self.use_drq = use_drq + + def _initialize_storage(self, state: dict[str, torch.Tensor], action: torch.Tensor): + """Initialize the storage tensors based on the first transition.""" + # Determine shapes from the first transition + state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} + action_shape = action.squeeze(0).shape + + # Pre-allocate tensors for storage + self.states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + self.actions = torch.empty( + (self.capacity, *action_shape), device=self.storage_device + ) + self.rewards = torch.empty((self.capacity,), device=self.storage_device) + + if not self.optimize_memory: + # Standard approach: store states and next_states separately + self.next_states = { + key: torch.empty((self.capacity, *shape), device=self.storage_device) + for key, shape in state_shapes.items() + } + else: + # Memory-optimized approach: don't allocate next_states buffer + # Just create a reference to states for consistent API + self.next_states = self.states # Just a reference for API consistency + + self.dones = torch.empty( + (self.capacity,), dtype=torch.bool, device=self.storage_device + ) + self.truncateds = torch.empty( + (self.capacity,), dtype=torch.bool, device=self.storage_device + ) + self.initialized = True + + def __len__(self): + return self.size + + def add( + self, + state: dict[str, torch.Tensor], + action: torch.Tensor, + reward: float, + next_state: dict[str, torch.Tensor], + done: bool, + truncated: bool, + complementary_info: Optional[dict[str, torch.Tensor]] = None, + ): + """Saves a transition, ensuring tensors are stored on the designated storage device.""" + # Initialize storage if this is the first transition + if not self.initialized: + self._initialize_storage(state=state, action=action) + + # Store the transition in pre-allocated tensors + for key in self.states: + self.states[key][self.position].copy_(state[key].squeeze(dim=0)) + + if not self.optimize_memory: + # Only store next_states if not optimizing memory + self.next_states[key][self.position].copy_( + next_state[key].squeeze(dim=0) + ) + + self.actions[self.position].copy_(action.squeeze(dim=0)) + self.rewards[self.position] = reward + self.dones[self.position] = done + self.truncateds[self.position] = truncated + + self.position = (self.position + 1) % self.capacity + self.size = min(self.size + 1, 
self.capacity) + + def sample(self, batch_size: int) -> BatchTransition: + """Sample a random batch of transitions and collate them into batched tensors.""" + if not self.initialized: + raise RuntimeError( + "Cannot sample from an empty buffer. Add transitions first." + ) + + batch_size = min(batch_size, self.size) + + # Random indices for sampling - create on the same device as storage + idx = torch.randint( + low=0, high=self.size, size=(batch_size,), device=self.storage_device + ) + + # Identify image keys that need augmentation + image_keys = ( + [k for k in self.states if k.startswith("observation.image")] + if self.use_drq + else [] + ) + + # Create batched state and next_state + batch_state = {} + batch_next_state = {} + + # First pass: load all state tensors to target device + for key in self.states: + batch_state[key] = self.states[key][idx].to(self.device) + + if not self.optimize_memory: + # Standard approach - load next_states directly + batch_next_state[key] = self.next_states[key][idx].to(self.device) + else: + # Memory-optimized approach - get next_state from the next index + next_idx = (idx + 1) % self.capacity + batch_next_state[key] = self.states[key][next_idx].to(self.device) + + # Apply image augmentation in a batched way if needed + if self.use_drq and image_keys: + # Concatenate all images from state and next_state + all_images = [] + for key in image_keys: + all_images.append(batch_state[key]) + all_images.append(batch_next_state[key]) + + # Batch all images and apply augmentation once + all_images_tensor = torch.cat(all_images, dim=0) + augmented_images = self.image_augmentation_function(all_images_tensor) + + # Split the augmented images back to their sources + for i, key in enumerate(image_keys): + # State images are at even indices (0, 2, 4...) + batch_state[key] = augmented_images[ + i * 2 * batch_size : (i * 2 + 1) * batch_size + ] + # Next state images are at odd indices (1, 3, 5...) + batch_next_state[key] = augmented_images[ + (i * 2 + 1) * batch_size : (i + 1) * 2 * batch_size + ] + + # Sample other tensors + batch_actions = self.actions[idx].to(self.device) + batch_rewards = self.rewards[idx].to(self.device) + batch_dones = self.dones[idx].to(self.device).float() + batch_truncateds = self.truncateds[idx].to(self.device).float() + + return BatchTransition( + state=batch_state, + action=batch_actions, + reward=batch_rewards, + next_state=batch_next_state, + done=batch_dones, + truncated=batch_truncateds, + ) + + @classmethod + def from_lerobot_dataset( + cls, + lerobot_dataset: LeRobotDataset, + device: str = "cuda:0", + state_keys: Optional[Sequence[str]] = None, + capacity: Optional[int] = None, + action_mask: Optional[Sequence[int]] = None, + action_delta: Optional[float] = None, + image_augmentation_function: Optional[Callable] = None, + use_drq: bool = True, + storage_device: str = "cpu", + optimize_memory: bool = False, + ) -> "ReplayBuffer": + """ + Convert a LeRobotDataset into a ReplayBuffer. + + Args: + lerobot_dataset (LeRobotDataset): The dataset to convert. + device (str): The device for sampling tensors. Defaults to "cuda:0". + state_keys (Optional[Sequence[str]]): The list of keys that appear in `state` and `next_state`. + capacity (Optional[int]): Buffer capacity. If None, uses dataset length. + action_mask (Optional[Sequence[int]]): Indices of action dimensions to keep. + action_delta (Optional[float]): Factor to divide actions by. + image_augmentation_function (Optional[Callable]): Function for image augmentation. 
+ If None, uses default random shift with pad=4. + use_drq (bool): Whether to use DrQ image augmentation when sampling. + storage_device (str): Device for storing tensor data. Using "cpu" saves GPU memory. + optimize_memory (bool): If True, reduces memory usage by not duplicating state data. + + Returns: + ReplayBuffer: The replay buffer with dataset transitions. + """ + if capacity is None: + capacity = len(lerobot_dataset) + + if capacity < len(lerobot_dataset): + raise ValueError( + "The capacity of the ReplayBuffer must be greater than or equal to the length of the LeRobotDataset." + ) + + # Create replay buffer with image augmentation and DrQ settings + replay_buffer = cls( + capacity=capacity, + device=device, + state_keys=state_keys, + image_augmentation_function=image_augmentation_function, + use_drq=use_drq, + storage_device=storage_device, + optimize_memory=optimize_memory, + ) + + # Convert dataset to transitions + list_transition = cls._lerobotdataset_to_transitions( + dataset=lerobot_dataset, state_keys=state_keys + ) + + # Initialize the buffer with the first transition to set up storage tensors + if list_transition: + first_transition = list_transition[0] + first_state = { + k: v.to(device) for k, v in first_transition["state"].items() + } + first_action = first_transition["action"].to(device) + + # Apply action mask/delta if needed + if action_mask is not None: + if first_action.dim() == 1: + first_action = first_action[action_mask] + else: + first_action = first_action[:, action_mask] + + if action_delta is not None: + first_action = first_action / action_delta + + replay_buffer._initialize_storage(state=first_state, action=first_action) + + # Fill the buffer with all transitions + for data in list_transition: + for k, v in data.items(): + if isinstance(v, dict): + for key, tensor in v.items(): + v[key] = tensor.to(device) + elif isinstance(v, torch.Tensor): + data[k] = v.to(device) + + action = data["action"] + if action_mask is not None: + if action.dim() == 1: + action = action[action_mask] + else: + action = action[:, action_mask] + + if action_delta is not None: + action = action / action_delta + + replay_buffer.add( + state=data["state"], + action=action, + reward=data["reward"], + next_state=data["next_state"], + done=data["done"], + truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset + ) + + return replay_buffer + + def to_lerobot_dataset( + self, + repo_id: str, + fps=1, + root=None, + task_name="from_replay_buffer", + ) -> LeRobotDataset: + """ + Converts all transitions in this ReplayBuffer into a single LeRobotDataset object. + """ + if self.size == 0: + raise ValueError("The replay buffer is empty. 
Cannot convert to a dataset.") + + # Create features dictionary for the dataset + features = { + "index": {"dtype": "int64", "shape": [1]}, # global index across episodes + "episode_index": {"dtype": "int64", "shape": [1]}, # which episode + "frame_index": {"dtype": "int64", "shape": [1]}, # index inside an episode + "timestamp": {"dtype": "float32", "shape": [1]}, # for now we store dummy + "task_index": {"dtype": "int64", "shape": [1]}, + } + + # Add "action" + sample_action = self.actions[0] + act_info = guess_feature_info(t=sample_action, name="action") + features["action"] = act_info + + # Add "reward" and "done" + features["next.reward"] = {"dtype": "float32", "shape": (1,)} + features["next.done"] = {"dtype": "bool", "shape": (1,)} + + # Add state keys + for key in self.states: + sample_val = self.states[key][0] + f_info = guess_feature_info(t=sample_val, name=key) + features[key] = f_info + + # Create an empty LeRobotDataset + lerobot_dataset = LeRobotDataset.create( + repo_id=repo_id, + fps=fps, + root=root, + robot=None, # TODO: (azouitine) Handle robot + robot_type=None, + features=features, + use_videos=True, + ) + + # Start writing images if needed + lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) + + # Convert transitions into episodes and frames + episode_index = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( + episode_index=episode_index + ) + + frame_idx_in_episode = 0 + for idx in range(self.size): + actual_idx = (self.position - self.size + idx) % self.capacity + + frame_dict = {} + + # Fill the data for state keys + for key in self.states: + frame_dict[key] = self.states[key][actual_idx].cpu() + + # Fill action, reward, done + frame_dict["action"] = self.actions[actual_idx].cpu() + frame_dict["next.reward"] = torch.tensor( + [self.rewards[actual_idx]], dtype=torch.float32 + ).cpu() + frame_dict["next.done"] = torch.tensor( + [self.dones[actual_idx]], dtype=torch.bool + ).cpu() + + # Add to the dataset's buffer + lerobot_dataset.add_frame(frame_dict) + + # Move to next frame + frame_idx_in_episode += 1 + + # If we reached an episode boundary, call save_episode, reset counters + if self.dones[actual_idx] or self.truncateds[actual_idx]: + lerobot_dataset.save_episode(task=task_name) + episode_index += 1 + frame_idx_in_episode = 0 + lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( + episode_index=episode_index + ) + + # Save any remaining frames in the buffer + if lerobot_dataset.episode_buffer["size"] > 0: + lerobot_dataset.save_episode(task=task_name) + + lerobot_dataset.stop_image_writer() + lerobot_dataset.consolidate(run_compute_stats=False, keep_image_files=False) + + return lerobot_dataset + + @staticmethod + def _lerobotdataset_to_transitions( + dataset: LeRobotDataset, + state_keys: Optional[Sequence[str]] = None, + ) -> list[Transition]: + """ + Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions. + + Args: + dataset (LeRobotDataset): + The dataset to convert. Each item in the dataset is expected to have + at least the following keys: + { + "action": ... + "next.reward": ... + "next.done": ... + "episode_index": ... + } + plus whatever your 'state_keys' specify. + + state_keys (Optional[Sequence[str]]): + The dataset keys to include in 'state' and 'next_state'. Their names + will be kept as-is in the output transitions. E.g. + ["observation.state", "observation.environment_state"]. + If None, you must handle or define default keys. 
+ + Returns: + transitions (List[Transition]): + A list of Transition dictionaries with the same length as `dataset`. + """ + if state_keys is None: + raise ValueError( + "State keys must be provided when converting LeRobotDataset to Transitions." + ) + + transitions = [] + num_frames = len(dataset) + + # Check if the dataset has "next.done" key + sample = dataset[0] + has_done_key = "next.done" in sample + + # If not, we need to infer it from episode boundaries + if not has_done_key: + print( + "'next.done' key not found in dataset. Inferring from episode boundaries..." + ) + + for i in tqdm(range(num_frames)): + current_sample = dataset[i] + + # ----- 1) Current state ----- + current_state: dict[str, torch.Tensor] = {} + for key in state_keys: + val = current_sample[key] + current_state[key] = val.unsqueeze(0) # Add batch dimension + + # ----- 2) Action ----- + action = current_sample["action"].unsqueeze(0) # Add batch dimension + + # ----- 3) Reward and done ----- + reward = float(current_sample["next.reward"].item()) # ensure float + + # Determine done flag - use next.done if available, otherwise infer from episode boundaries + if has_done_key: + done = bool(current_sample["next.done"].item()) # ensure bool + else: + # If this is the last frame or if next frame is in a different episode, mark as done + done = False + if i == num_frames - 1: + done = True + elif i < num_frames - 1: + next_sample = dataset[i + 1] + if next_sample["episode_index"] != current_sample["episode_index"]: + done = True + + # TODO: (azouitine) Handle truncation (using the same value as done for now) + truncated = done + + # ----- 4) Next state ----- + # If not done and the next sample is in the same episode, we pull the next sample's state. + # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state. + next_state = current_state # default + if not done and (i < num_frames - 1): + next_sample = dataset[i + 1] + if next_sample["episode_index"] == current_sample["episode_index"]: + # Build next_state from the same keys + next_state_data: dict[str, torch.Tensor] = {} + for key in state_keys: + val = next_sample[key] + next_state_data[key] = val.unsqueeze(0) # Add batch dimension + next_state = next_state_data + + # ----- Construct the Transition ----- + transition = Transition( + state=current_state, + action=action, + reward=reward, + next_state=next_state, + done=done, + truncated=truncated, + ) + transitions.append(transition) + + return transitions + + +# Utility function to guess shapes/dtypes from a tensor +def guess_feature_info(t: torch.Tensor, name: str): + """ + Return a dictionary with the 'dtype' and 'shape' for a given tensor or array. + If it looks like a 3D (C,H,W) shape, we might consider it an 'image'. + Otherwise default to 'float32' for numeric. You can customize as needed. 
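+
+    Illustrative example:
+        guess_feature_info(torch.rand(3, 84, 84), "observation.image")
+        # -> {"dtype": "image", "shape": (3, 84, 84)}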
+    """
+    shape = tuple(t.shape)
+    # Basic guess: if we have exactly 3 dims and shape[0] in {1, 3}, guess 'image'
+    if len(shape) == 3 and shape[0] in [1, 3]:
+        return {
+            "dtype": "image",
+            "shape": shape,
+        }
+    else:
+        # Otherwise treat as numeric
+        return {
+            "dtype": "float32",
+            "shape": shape,
+        }
+
+
+def concatenate_batch_transitions(
+    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
+) -> BatchTransition:
+    """NOTE: Be careful, this function changes left_batch_transitions in place."""
+    left_batch_transitions["state"] = {
+        key: torch.cat(
+            [
+                left_batch_transitions["state"][key],
+                right_batch_transition["state"][key],
+            ],
+            dim=0,
+        )
+        for key in left_batch_transitions["state"]
+    }
+    left_batch_transitions["action"] = torch.cat(
+        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
+    )
+    left_batch_transitions["reward"] = torch.cat(
+        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
+    )
+    left_batch_transitions["next_state"] = {
+        key: torch.cat(
+            [
+                left_batch_transitions["next_state"][key],
+                right_batch_transition["next_state"][key],
+            ],
+            dim=0,
+        )
+        for key in left_batch_transitions["next_state"]
+    }
+    left_batch_transitions["done"] = torch.cat(
+        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
+    )
+    left_batch_transitions["truncated"] = torch.cat(
+        [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
+        dim=0,
+    )
+    return left_batch_transitions
+
+
+if __name__ == "__main__":
+    from tempfile import TemporaryDirectory
+
+    # ===== Test 1: Create and use a synthetic ReplayBuffer =====
+    print("Testing synthetic ReplayBuffer...")
+
+    # Create sample data dimensions
+    batch_size = 32
+    state_dims = {"observation.image": (3, 84, 84), "observation.state": (10,)}
+    action_dim = (6,)
+
+    # Create a buffer
+    buffer = ReplayBuffer(
+        capacity=1000,
+        device="cpu",
+        state_keys=list(state_dims.keys()),
+        use_drq=True,
+        storage_device="cpu",
+    )
+
+    # Add some random transitions
+    for i in range(100):
+        # Create dummy transition data
+        state = {
+            "observation.image": torch.rand(1, 3, 84, 84),
+            "observation.state": torch.rand(1, 10),
+        }
+        action = torch.rand(1, 6)
+        reward = 0.5
+        next_state = {
+            "observation.image": torch.rand(1, 3, 84, 84),
+            "observation.state": torch.rand(1, 10),
+        }
+        done = i == 99  # only the last transition is terminal
+        truncated = False
+
+        buffer.add(
+            state=state,
+            action=action,
+            reward=reward,
+            next_state=next_state,
+            done=done,
+            truncated=truncated,
+        )
+
+    # Test sampling
+    batch = buffer.sample(batch_size)
+    print(f"Buffer size: {len(buffer)}")
+    print(
+        f"Sampled batch state shapes: {batch['state']['observation.image'].shape}, {batch['state']['observation.state'].shape}"
+    )
+    print(f"Sampled batch action shape: {batch['action'].shape}")
+    print(f"Sampled batch reward shape: {batch['reward'].shape}")
+    print(f"Sampled batch done shape: {batch['done'].shape}")
+    print(f"Sampled batch truncated shape: {batch['truncated'].shape}")
+
+    # ===== Test for state-action-reward alignment =====
+    print("\nTesting state-action-reward alignment...")
+
+    # Create a buffer with controlled transitions where we know the relationships
+    aligned_buffer = ReplayBuffer(
+        capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu"
+    )
+
+    # Create transitions with known relationships
+    # - Each state has a unique signature value
+    # - Action is 2x the state signature
+    # - Reward is 3x the state signature
+    # - Next state is signature +
0.01 (unless at episode end) + for i in range(100): + # Create a state with a signature value that encodes the transition number + signature = float(i) / 100.0 + state = {"state_value": torch.tensor([[signature]]).float()} + + # Action is 2x the signature + action = torch.tensor([[2.0 * signature]]).float() + + # Reward is 3x the signature + reward = 3.0 * signature + + # Next state is signature + 0.01, unless end of episode + # End episode every 10 steps + is_end = (i + 1) % 10 == 0 + + if is_end: + # At episode boundaries, next_state repeats current state (as per your implementation) + next_state = {"state_value": torch.tensor([[signature]]).float()} + done = True + else: + # Within episodes, next_state has signature + 0.01 + next_signature = float(i + 1) / 100.0 + next_state = {"state_value": torch.tensor([[next_signature]]).float()} + done = False + + aligned_buffer.add(state, action, reward, next_state, done, False) + + # Sample from this buffer + aligned_batch = aligned_buffer.sample(50) + + # Verify alignments in sampled batch + correct_relationships = 0 + total_checks = 0 + + # For each transition in the batch + for i in range(50): + # Extract signature from state + state_sig = aligned_batch["state"]["state_value"][i].item() + + # Check action is 2x signature (within reasonable precision) + action_val = aligned_batch["action"][i].item() + action_check = abs(action_val - 2.0 * state_sig) < 1e-4 + + # Check reward is 3x signature (within reasonable precision) + reward_val = aligned_batch["reward"][i].item() + reward_check = abs(reward_val - 3.0 * state_sig) < 1e-4 + + # Check next_state relationship matches our pattern + next_state_sig = aligned_batch["next_state"]["state_value"][i].item() + is_done = aligned_batch["done"][i].item() > 0.5 + + # Calculate expected next_state value based on done flag + if is_done: + # For episodes that end, next_state should equal state + next_state_check = abs(next_state_sig - state_sig) < 1e-4 + else: + # For continuing episodes, check if next_state is approximately state + 0.01 + # We need to be careful because we don't know the original index + # So we check if the increment is roughly 0.01 + next_state_check = ( + abs(next_state_sig - state_sig - 0.01) < 1e-4 + or abs(next_state_sig - state_sig) < 1e-4 + ) + + # Count correct relationships + if action_check: + correct_relationships += 1 + if reward_check: + correct_relationships += 1 + if next_state_check: + correct_relationships += 1 + + total_checks += 3 + + alignment_accuracy = 100.0 * correct_relationships / total_checks + print( + f"State-action-reward-next_state alignment accuracy: {alignment_accuracy:.2f}%" + ) + if alignment_accuracy > 99.0: + print( + "✅ All relationships verified! Buffer maintains correct temporal relationships." + ) + else: + print( + "⚠️ Some relationships don't match expected patterns. Buffer may have alignment issues." 
+ ) + + # Print some debug information about failures + print("\nDebug information for failed checks:") + for i in range(5): # Print first 5 transitions for debugging + state_sig = aligned_batch["state"]["state_value"][i].item() + action_val = aligned_batch["action"][i].item() + reward_val = aligned_batch["reward"][i].item() + next_state_sig = aligned_batch["next_state"]["state_value"][i].item() + is_done = aligned_batch["done"][i].item() > 0.5 + + print(f"Transition {i}:") + print(f" State: {state_sig:.6f}") + print(f" Action: {action_val:.6f} (expected: {2.0 * state_sig:.6f})") + print(f" Reward: {reward_val:.6f} (expected: {3.0 * state_sig:.6f})") + print(f" Done: {is_done}") + print(f" Next state: {next_state_sig:.6f}") + + # Calculate expected next state + if is_done: + expected_next = state_sig + else: + # This approximation might not be perfect + state_idx = round(state_sig * 100) + expected_next = (state_idx + 1) / 100.0 + + print(f" Expected next state: {expected_next:.6f}") + print() + + # ===== Test 2: Convert to LeRobotDataset and back ===== + with TemporaryDirectory() as temp_dir: + print("\nTesting conversion to LeRobotDataset and back...") + # Convert buffer to dataset + repo_id = "test/replay_buffer_conversion" + # Create a subdirectory to avoid the "directory exists" error + dataset_dir = os.path.join(temp_dir, "dataset1") + dataset = buffer.to_lerobot_dataset(repo_id=repo_id, root=dataset_dir) + + print(f"Dataset created with {len(dataset)} frames") + print(f"Dataset features: {list(dataset.features.keys())}") + + # Check a random sample from the dataset + sample = dataset[0] + print( + f"Dataset sample types: {[(k, type(v)) for k, v in sample.items() if k.startswith('observation')]}" + ) + + # Convert dataset back to buffer + reconverted_buffer = ReplayBuffer.from_lerobot_dataset( + dataset, state_keys=list(state_dims.keys()), device="cpu" + ) + + print(f"Reconverted buffer size: {len(reconverted_buffer)}") + + # Sample from the reconverted buffer + reconverted_batch = reconverted_buffer.sample(batch_size) + print( + f"Reconverted batch state shapes: {reconverted_batch['state']['observation.image'].shape}, {reconverted_batch['state']['observation.state'].shape}" + ) + + # Verify consistency before and after conversion + original_states = batch["state"]["observation.image"].mean().item() + reconverted_states = ( + reconverted_batch["state"]["observation.image"].mean().item() + ) + print(f"Original buffer state mean: {original_states:.4f}") + print(f"Reconverted buffer state mean: {reconverted_states:.4f}") + + if abs(original_states - reconverted_states) < 1.0: + print("Values are reasonably similar - conversion works as expected") + else: + print( + "WARNING: Significant difference between original and reconverted values" + ) + + print("\nAll previous tests completed!") + + # ===== Test for memory optimization ===== + print("\n===== Testing Memory Optimization =====") + + # Create two buffers, one with memory optimization and one without + standard_buffer = ReplayBuffer( + capacity=1000, + device="cpu", + state_keys=["observation.image", "observation.state"], + storage_device="cpu", + optimize_memory=False, + use_drq=True, + ) + + optimized_buffer = ReplayBuffer( + capacity=1000, + device="cpu", + state_keys=["observation.image", "observation.state"], + storage_device="cpu", + optimize_memory=True, + use_drq=True, + ) + + # Generate sample data with larger state dimensions for better memory impact + print("Generating test data...") + num_episodes = 10 + steps_per_episode 
= 50 + total_steps = num_episodes * steps_per_episode + + for episode in range(num_episodes): + for step in range(steps_per_episode): + # Index in the overall sequence + i = episode * steps_per_episode + step + + # Create state with identifiable values + img = torch.ones((3, 84, 84)) * (i / total_steps) + state_vec = torch.ones((10,)) * (i / total_steps) + + state = { + "observation.image": img.unsqueeze(0), + "observation.state": state_vec.unsqueeze(0), + } + + # Create next state (i+1 or same as current if last in episode) + is_last_step = step == steps_per_episode - 1 + + if is_last_step: + # At episode end, next state = current state + next_img = img.clone() + next_state_vec = state_vec.clone() + done = True + truncated = False + else: + # Within episode, next state has incremented value + next_val = (i + 1) / total_steps + next_img = torch.ones((3, 84, 84)) * next_val + next_state_vec = torch.ones((10,)) * next_val + done = False + truncated = False + + next_state = { + "observation.image": next_img.unsqueeze(0), + "observation.state": next_state_vec.unsqueeze(0), + } + + # Action and reward + action = torch.tensor([[i / total_steps]]) + reward = float(i / total_steps) + + # Add to both buffers + standard_buffer.add(state, action, reward, next_state, done, truncated) + optimized_buffer.add(state, action, reward, next_state, done, truncated) + + # Verify episode boundaries with our simplified approach + print("\nVerifying simplified memory optimization...") + + # Test with a new buffer with a small sequence + test_buffer = ReplayBuffer( + capacity=20, + device="cpu", + state_keys=["value"], + storage_device="cpu", + optimize_memory=True, + use_drq=False, + ) + + # Add a simple sequence with known episode boundaries + for i in range(20): + val = float(i) + state = {"value": torch.tensor([[val]]).float()} + next_val = float(i + 1) if i % 5 != 4 else val # Episode ends every 5 steps + next_state = {"value": torch.tensor([[next_val]]).float()} + + # Set done=True at every 5th step + done = (i % 5) == 4 + action = torch.tensor([[0.0]]) + reward = 1.0 + truncated = False + + test_buffer.add(state, action, reward, next_state, done, truncated) + + # Get sequential batch for verification + sequential_batch_size = test_buffer.size + all_indices = torch.arange(sequential_batch_size, device=test_buffer.storage_device) + + # Get state tensors + batch_state = { + "value": test_buffer.states["value"][all_indices].to(test_buffer.device) + } + + # Get next_state using memory-optimized approach (simply index+1) + next_indices = (all_indices + 1) % test_buffer.capacity + batch_next_state = { + "value": test_buffer.states["value"][next_indices].to(test_buffer.device) + } + + # Get other tensors + batch_dones = test_buffer.dones[all_indices].to(test_buffer.device) + + # Print sequential values + print("State, Next State, Done (Sequential values with simplified optimization):") + state_values = batch_state["value"].squeeze().tolist() + next_values = batch_next_state["value"].squeeze().tolist() + done_flags = batch_dones.tolist() + + # Print all values + for i in range(len(state_values)): + print(f" {state_values[i]:.1f} → {next_values[i]:.1f}, Done: {done_flags[i]}") + + # Explain the memory optimization tradeoff + print("\nWith simplified memory optimization:") + print("- We always use the next state in the buffer (index+1) as next_state") + print("- For terminal states, this means using the first state of the next episode") + print("- This is a common tradeoff in RL implementations for memory efficiency") 
+ print( + "- Since we track done flags, the algorithm can handle these transitions correctly" + ) + + # Test random sampling + print("\nVerifying random sampling with simplified memory optimization...") + random_samples = test_buffer.sample(20) # Sample all transitions + + # Extract values + random_state_values = random_samples["state"]["value"].squeeze().tolist() + random_next_values = random_samples["next_state"]["value"].squeeze().tolist() + random_done_flags = random_samples["done"].bool().tolist() + + # Print a few samples + print("Random samples - State, Next State, Done (First 10):") + for i in range(10): + print( + f" {random_state_values[i]:.1f} → {random_next_values[i]:.1f}, Done: {random_done_flags[i]}" + ) + + # Calculate memory savings + # Assume optimized_buffer and standard_buffer have already been initialized and filled + std_mem = ( + sum( + standard_buffer.states[key].nelement() + * standard_buffer.states[key].element_size() + for key in standard_buffer.states + ) + * 2 + ) + opt_mem = sum( + optimized_buffer.states[key].nelement() + * optimized_buffer.states[key].element_size() + for key in optimized_buffer.states + ) + + savings_percent = (std_mem - opt_mem) / std_mem * 100 + + print("\nMemory optimization result:") + print(f"- Standard buffer state memory: {std_mem / (1024 * 1024):.2f} MB") + print(f"- Optimized buffer state memory: {opt_mem / (1024 * 1024):.2f} MB") + print(f"- Memory savings for state tensors: {savings_percent:.1f}%") + + print("\nAll memory optimization tests completed!") + + # # ===== Test real dataset conversion ===== + # print("\n===== Testing Real LeRobotDataset Conversion =====") + # try: + # # Try to use a real dataset if available + # dataset_name = "AdilZtn/Maniskill-Pushcube-demonstration-small" + # dataset = LeRobotDataset(repo_id=dataset_name) + + # # Print available keys to debug + # sample = dataset[0] + # print("Available keys in dataset:", list(sample.keys())) + + # # Check for required keys + # if "action" not in sample or "next.reward" not in sample: + # print("Dataset missing essential keys. 
Cannot convert.") + # raise ValueError("Missing required keys in dataset") + + # # Auto-detect appropriate state keys + # image_keys = [] + # state_keys = [] + # for k, v in sample.items(): + # # Skip metadata keys and action/reward keys + # if k in { + # "index", + # "episode_index", + # "frame_index", + # "timestamp", + # "task_index", + # "action", + # "next.reward", + # "next.done", + # }: + # continue + + # # Infer key type from tensor shape + # if isinstance(v, torch.Tensor): + # if len(v.shape) == 3 and (v.shape[0] == 3 or v.shape[0] == 1): + # # Likely an image (channels, height, width) + # image_keys.append(k) + # else: + # # Likely state or other vector + # state_keys.append(k) + + # print(f"Detected image keys: {image_keys}") + # print(f"Detected state keys: {state_keys}") + + # if not image_keys and not state_keys: + # print("No usable keys found in dataset, skipping further tests") + # raise ValueError("No usable keys found in dataset") + + # # Test with standard and memory-optimized buffers + # for optimize_memory in [False, True]: + # buffer_type = "Standard" if not optimize_memory else "Memory-optimized" + # print(f"\nTesting {buffer_type} buffer with real dataset...") + + # # Convert to ReplayBuffer with detected keys + # replay_buffer = ReplayBuffer.from_lerobot_dataset( + # lerobot_dataset=dataset, + # state_keys=image_keys + state_keys, + # device="cpu", + # optimize_memory=optimize_memory, + # ) + # print(f"Loaded {len(replay_buffer)} transitions from {dataset_name}") + + # # Test sampling + # real_batch = replay_buffer.sample(32) + # print(f"Sampled batch from real dataset ({buffer_type}), state shapes:") + # for key in real_batch["state"]: + # print(f" {key}: {real_batch['state'][key].shape}") + + # # Convert back to LeRobotDataset + # with TemporaryDirectory() as temp_dir: + # dataset_name = f"test/real_dataset_converted_{buffer_type}" + # replay_buffer_converted = replay_buffer.to_lerobot_dataset( + # repo_id=dataset_name, + # root=os.path.join(temp_dir, f"dataset_{buffer_type}"), + # ) + # print( + # f"Successfully converted back to LeRobotDataset with {len(replay_buffer_converted)} frames" + # ) + + # except Exception as e: + # print(f"Real dataset test failed: {e}") + # print("This is expected if running offline or if the dataset is not available.") + + # print("\nAll tests completed!") diff --git a/lerobot/scripts/server/crop_dataset_roi.py b/lerobot/scripts/server/crop_dataset_roi.py new file mode 100644 index 000000000..8bb414feb --- /dev/null +++ b/lerobot/scripts/server/crop_dataset_roi.py @@ -0,0 +1,291 @@ +import argparse # noqa: I001 +import json +from copy import deepcopy +from typing import Dict, Tuple +from pathlib import Path +import cv2 + +# import torch.nn.functional as F # noqa: N812 +import torchvision.transforms.functional as F # type: ignore # noqa: N812 +from tqdm import tqdm # type: ignore + +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset + + +def select_rect_roi(img): + """ + Allows the user to draw a rectangular ROI on the image. + + The user must click and drag to draw the rectangle. + - While dragging, the rectangle is dynamically drawn. + - On mouse button release, the rectangle is fixed. + - Press 'c' to confirm the selection. + - Press 'r' to reset the selection. + - Press ESC to cancel. + + Returns: + A tuple (top, left, height, width) representing the rectangular ROI, + or None if no valid ROI is selected. 
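+
+    Example (hypothetical values):
+        roi = select_rect_roi(img)  # e.g. (120, 80, 200, 300): a 200x300 box with top-left corner at (120, 80)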
+ """ + # Create a working copy of the image + clone = img.copy() + working_img = clone.copy() + + roi = None # Will store the final ROI as (top, left, height, width) + drawing = False + ix, iy = -1, -1 # Initial click coordinates + + def mouse_callback(event, x, y, flags, param): + nonlocal ix, iy, drawing, roi, working_img + + if event == cv2.EVENT_LBUTTONDOWN: + # Start drawing: record starting coordinates + drawing = True + ix, iy = x, y + + elif event == cv2.EVENT_MOUSEMOVE: + if drawing: + # Compute the top-left and bottom-right corners regardless of drag direction + top = min(iy, y) + left = min(ix, x) + bottom = max(iy, y) + right = max(ix, x) + # Show a temporary image with the current rectangle drawn + temp = working_img.copy() + cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", temp) + + elif event == cv2.EVENT_LBUTTONUP: + # Finish drawing + drawing = False + top = min(iy, y) + left = min(ix, x) + bottom = max(iy, y) + right = max(ix, x) + height = bottom - top + width = right - left + roi = (top, left, height, width) # (top, left, height, width) + # Draw the final rectangle on the working image and display it + working_img = clone.copy() + cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2) + cv2.imshow("Select ROI", working_img) + + # Create the window and set the callback + cv2.namedWindow("Select ROI") + cv2.setMouseCallback("Select ROI", mouse_callback) + cv2.imshow("Select ROI", working_img) + + print("Instructions for ROI selection:") + print(" - Click and drag to draw a rectangular ROI.") + print(" - Press 'c' to confirm the selection.") + print(" - Press 'r' to reset and draw again.") + print(" - Press ESC to cancel the selection.") + + # Wait until the user confirms with 'c', resets with 'r', or cancels with ESC + while True: + key = cv2.waitKey(1) & 0xFF + # Confirm ROI if one has been drawn + if key == ord("c") and roi is not None: + break + # Reset: clear the ROI and restore the original image + elif key == ord("r"): + working_img = clone.copy() + roi = None + cv2.imshow("Select ROI", working_img) + # Cancel selection for this image + elif key == 27: # ESC key + roi = None + break + + cv2.destroyWindow("Select ROI") + return roi + + +def select_square_roi_for_images(images: dict) -> dict: + """ + For each image in the provided dictionary, open a window to allow the user + to select a rectangular ROI. Returns a dictionary mapping each key to a tuple + (top, left, height, width) representing the ROI. + + Parameters: + images (dict): Dictionary where keys are identifiers and values are OpenCV images. + + Returns: + dict: Mapping of image keys to the selected rectangular ROI. + """ + selected_rois = {} + + for key, img in images.items(): + if img is None: + print(f"Image for key '{key}' is None, skipping.") + continue + + print(f"\nSelect rectangular ROI for image with key: '{key}'") + roi = select_rect_roi(img) + + if roi is None: + print(f"No valid ROI selected for '{key}'.") + else: + selected_rois[key] = roi + print(f"ROI for '{key}': {roi}") + + return selected_rois + + +def get_image_from_lerobot_dataset(dataset: LeRobotDataset): + """ + Find the first row in the dataset and extract the image in order to be used for the crop. 
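+
+    Returns:
+        dict: Mapping of image observation keys (any key containing "image") to their tensors.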
+ """ + row = dataset[0] + image_dict = {} + for k in row: + if "image" in k: + image_dict[k] = deepcopy(row[k]) + return image_dict + + +def convert_lerobot_dataset_to_cropper_lerobot_dataset( + original_dataset: LeRobotDataset, + crop_params_dict: Dict[str, Tuple[int, int, int, int]], + new_repo_id: str, + new_dataset_root: str, + resize_size: Tuple[int, int] = (128, 128), +) -> LeRobotDataset: + """ + Converts an existing LeRobotDataset by iterating over its episodes and frames, + applying cropping and resizing to image observations, and saving a new dataset + with the transformed data. + + Args: + original_dataset (LeRobotDataset): The source dataset. + crop_params_dict (Dict[str, Tuple[int, int, int, int]]): + A dictionary mapping observation keys to crop parameters (top, left, height, width). + new_repo_id (str): Repository id for the new dataset. + new_dataset_root (str): The root directory where the new dataset will be written. + resize_size (Tuple[int, int], optional): The target size (height, width) after cropping. + Defaults to (128, 128). + + Returns: + LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped + and resized. + """ + # 1. Create a new (empty) LeRobotDataset for writing. + new_dataset = LeRobotDataset.create( + repo_id=new_repo_id, + fps=original_dataset.fps, + root=new_dataset_root, + robot_type=original_dataset.meta.robot_type, + features=original_dataset.meta.info["features"], + use_videos=len(original_dataset.meta.video_keys) > 0, + ) + + # Update the metadata for every image key that will be cropped: + # (Here we simply set the shape to be the final resize_size.) + for key in crop_params_dict: + if key in new_dataset.meta.info["features"]: + new_dataset.meta.info["features"][key]["shape"] = list(resize_size) + + # 2. Process each episode in the original dataset. + episodes_info = original_dataset.meta.episodes + # (Sort episodes by episode_index for consistency.) + + episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"]) + # Use the first task from the episode metadata (or "unknown" if not provided) + task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown" + + last_episode_index = 0 + for sample in tqdm(original_dataset): + episode_index = sample.pop("episode_index") + if episode_index != last_episode_index: + new_dataset.save_episode(task, encode_videos=True) + last_episode_index = episode_index + sample.pop("frame_index") + # Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable) + new_sample = sample.copy() + # Loop over each observation key that should be cropped/resized. + for key, params in crop_params_dict.items(): + if key in new_sample: + top, left, height, width = params + # Apply crop then resize. + cropped = F.crop(new_sample[key], top, left, height, width) + resized = F.resize(cropped, resize_size) + new_sample[key] = resized + # Add the transformed frame to the new dataset. + new_dataset.add_frame(new_sample) + + # save last episode + new_dataset.save_episode(task, encode_videos=True) + + # Optionally, consolidate the new dataset to compute statistics and update video info. + new_dataset.consolidate(run_compute_stats=True, keep_image_files=True) + + new_dataset.push_to_hub(tags=None) + + return new_dataset + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Crop rectangular ROIs from a LeRobot dataset." 
+    )
+    parser.add_argument(
+        "--repo-id",
+        type=str,
+        default="lerobot",
+        help="The repository id of the LeRobot dataset to process.",
+    )
+    parser.add_argument(
+        "--root",
+        type=str,
+        default=None,
+        help="The root directory of the LeRobot dataset.",
+    )
+    parser.add_argument(
+        "--crop-params-path",
+        type=str,
+        default=None,
+        help="The path to the JSON file containing the ROIs.",
+    )
+    args = parser.parse_args()
+
+    local_files_only = args.root is not None
+    dataset = LeRobotDataset(
+        repo_id=args.repo_id, root=args.root, local_files_only=local_files_only
+    )
+
+    images = get_image_from_lerobot_dataset(dataset)
+    images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
+    images = {k: (v * 255).astype("uint8") for k, v in images.items()}
+
+    if args.crop_params_path is None:
+        rois = select_square_roi_for_images(images)
+    else:
+        with open(args.crop_params_path) as f:
+            rois = json.load(f)
+
+    # rois = {
+    #     "observation.images.front": [102, 43, 358, 523],
+    #     "observation.images.side": [92, 123, 379, 349],
+    # }
+
+    # Print the selected rectangular ROIs
+    print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
+    for key, roi in rois.items():
+        print(f"{key}: {roi}")
+
+    new_repo_id = args.repo_id + "_cropped_resized"
+    new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
+
+    cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
+        original_dataset=dataset,
+        crop_params_dict=rois,
+        new_repo_id=new_repo_id,
+        new_dataset_root=new_dataset_root,
+        resize_size=(128, 128),
+    )
+
+    meta_dir = new_dataset_root / "meta"
+    meta_dir.mkdir(exist_ok=True)
+
+    with open(meta_dir / "crop_params.json", "w") as f:
+        json.dump(rois, f, indent=4)
diff --git a/lerobot/scripts/server/find_joint_limits.py b/lerobot/scripts/server/find_joint_limits.py
new file mode 100644
index 000000000..d5870027e
--- /dev/null
+++ b/lerobot/scripts/server/find_joint_limits.py
@@ -0,0 +1,72 @@
+import argparse
+import time
+
+import cv2
+import numpy as np
+
+from lerobot.common.robot_devices.control_utils import is_headless
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.utils.utils import init_hydra_config
+
+
+def find_joint_bounds(
+    robot,
+    control_time_s=20,
+    display_cameras=False,
+):
+    # TODO(rcadene): Add option to record logs
+    if not robot.is_connected:
+        robot.connect()
+
+    timestamp = 0
+    start_episode_t = time.perf_counter()
+    pos_list = []
+    while timestamp < control_time_s:
+        observation, action = robot.teleop_step(record_data=True)
+
+        pos_list.append(robot.follower_arms["main"].read("Present_Position"))
+
+        if display_cameras and not is_headless():
+            image_keys = [key for key in observation if "image" in key]
+            for key in image_keys:
+                cv2.imshow(
+                    key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR)
+                )
+            cv2.waitKey(1)
+
+        timestamp = time.perf_counter() - start_episode_t
+
+    # Report the recorded joint range (named max_pos/min_pos to avoid shadowing
+    # the built-in max/min).
+    max_pos = np.max(np.stack(pos_list), 0)
+    min_pos = np.min(np.stack(pos_list), 0)
+    print(f"Max angle position per joint {max_pos}")
+    print(f"Min angle position per joint {min_pos}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--robot-path",
+        type=str,
+        default="lerobot/configs/robot/koch.yaml",
+        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
+    )
+    parser.add_argument(
+        "--robot-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for nested.overrides)",
+    )
+    parser.add_argument(
+        "--control-time-s",
+        type=float,
+        default=20,
+        help="Maximum episode length in seconds",
+    )
+    args = parser.parse_args()
+    robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides)
+
+    robot = make_robot(robot_cfg)
+    find_joint_bounds(robot, control_time_s=args.control_time_s)
diff --git a/lerobot/scripts/server/gym_manipulator.py b/lerobot/scripts/server/gym_manipulator.py
new file mode 100644
index 000000000..c1a7c88c9
--- /dev/null
+++ b/lerobot/scripts/server/gym_manipulator.py
@@ -0,0 +1,987 @@
+import argparse
+import logging
+import time
+from threading import Lock
+from typing import Annotated, Any, Callable, Dict, Optional, Tuple
+
+import gymnasium as gym
+import numpy as np
+import torch
+import torchvision.transforms.functional as F  # noqa: N812
+
+from lerobot.common.envs.utils import preprocess_observation
+from lerobot.common.robot_devices.control_utils import busy_wait, is_headless
+from lerobot.common.robot_devices.robots.factory import make_robot
+from lerobot.common.utils.utils import init_hydra_config, log_say
+
+logging.basicConfig(level=logging.INFO)
+
+
+class HILSerlRobotEnv(gym.Env):
+    """
+    Gym-compatible environment for evaluating robotic control policies with integrated human intervention.
+
+    This environment wraps a robot interface to provide a consistent API for policy evaluation. It supports both relative (delta)
+    and absolute joint position commands and automatically configures its observation and action spaces based on the robot's
+    sensors and configuration.
+
+    The environment can switch between executing actions from a policy or using teleoperated actions (human intervention) during
+    each step. When teleoperation is used, the override action is captured and returned in the `info` dict along with a flag
+    `is_intervention`.
+    """
+
+    def __init__(
+        self,
+        robot,
+        use_delta_action_space: bool = True,
+        delta: float | None = None,
+        display_cameras: bool = False,
+    ):
+        """
+        Initialize the HILSerlRobotEnv environment.
+
+        The environment is set up with a robot interface, which is used to capture observations and send joint commands. The setup
+        supports both relative (delta) adjustments and absolute joint positions for controlling the robot.
+
+        Args:
+            robot: The robot interface object used to connect and interact with the physical robot.
+            use_delta_action_space (bool): If True, uses a delta (relative) action space for joint control. Otherwise, absolute
+                joint positions are used.
+            delta (float or None): A scaling factor for the relative adjustments applied to joint positions. Should be a value between
+                0 and 1 when using a delta action space.
+            display_cameras (bool): If True, the robot's camera feeds will be displayed during execution.
+        """
+        super().__init__()
+
+        self.robot = robot
+        self.display_cameras = display_cameras
+
+        # Connect to the robot if not already connected.
+        if not self.robot.is_connected:
+            self.robot.connect()
+
+        self.initial_follower_position = robot.follower_arms["main"].read(
+            "Present_Position"
+        )
+
+        # Episode tracking.
+        self.current_step = 0
+        self.episode_data = None
+
+        self.delta = delta
+        self.use_delta_action_space = use_delta_action_space
+        self.current_joint_positions = self.robot.follower_arms["main"].read(
+            "Present_Position"
+        )
+
+        # Retrieve the size of the joint position interval bound.
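+        # The per-step motion is then capped at delta * (max - min) per joint:
+        # with hypothetical bounds spanning 40 degrees and delta = 0.1, a single
+        # policy step can move a joint by at most 4 degrees.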
+ self.relative_bounds_size = ( + self.robot.config.joint_position_relative_bounds["max"] + - self.robot.config.joint_position_relative_bounds["min"] + ) + + self.delta_relative_bounds_size = self.relative_bounds_size * self.delta + + self.robot.config.max_relative_target = self.delta_relative_bounds_size.float() + + # Dynamically configure the observation and action spaces. + self._setup_spaces() + + def _setup_spaces(self): + """ + Dynamically configure the observation and action spaces based on the robot's capabilities. + + Observation Space: + - For keys with "image": A Box space with pixel values ranging from 0 to 255. + - For non-image keys: A nested Dict space is created under 'observation.state' with a suitable range. + + Action Space: + - The action space is defined as a Tuple where: + • The first element is a Box space representing joint position commands. It is defined as relative (delta) + or absolute, based on the configuration. + • The second element is a Discrete space (with 2 values) serving as a flag for intervention (teleoperation). + """ + example_obs = self.robot.capture_observation() + + # Define observation spaces for images and other states. + image_keys = [key for key in example_obs if "image" in key] + state_keys = [key for key in example_obs if "image" not in key] + observation_spaces = { + key: gym.spaces.Box( + low=0, high=255, shape=example_obs[key].shape, dtype=np.uint8 + ) + for key in image_keys + } + observation_spaces["observation.state"] = gym.spaces.Dict( + { + key: gym.spaces.Box( + low=0, high=10, shape=example_obs[key].shape, dtype=np.float32 + ) + for key in state_keys + } + ) + + self.observation_space = gym.spaces.Dict(observation_spaces) + + # Define the action space for joint positions along with setting an intervention flag. + action_dim = len(self.robot.follower_arms["main"].read("Present_Position")) + if self.use_delta_action_space: + action_space_robot = gym.spaces.Box( + low=-self.relative_bounds_size.cpu().numpy(), + high=self.relative_bounds_size.cpu().numpy(), + shape=(action_dim,), + dtype=np.float32, + ) + else: + action_space_robot = gym.spaces.Box( + low=self.robot.config.joint_position_relative_bounds["min"] + .cpu() + .numpy(), + high=self.robot.config.joint_position_relative_bounds["max"] + .cpu() + .numpy(), + shape=(action_dim,), + dtype=np.float32, + ) + + self.action_space = gym.spaces.Tuple( + ( + action_space_robot, + gym.spaces.Discrete(2), + ), + ) + + def reset( + self, seed=None, options=None + ) -> Tuple[Dict[str, np.ndarray], Dict[str, Any]]: + """ + Reset the environment to its initial state. + This method resets the step counter and clears any episodic data. + + Args: + seed (Optional[int]): A seed for random number generation to ensure reproducibility. + options (Optional[dict]): Additional options to influence the reset behavior. + + Returns: + A tuple containing: + - observation (dict): The initial sensor observation. + - info (dict): A dictionary with supplementary information, including the key "initial_position". + """ + super().reset(seed=seed, options=options) + + # Capture the initial observation. + observation = self.robot.capture_observation() + + # Reset episode tracking variables. 
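+        # Note: this reset only clears bookkeeping; it does not physically move
+        # the arm back. Repositioning between episodes is handled by the
+        # ResetWrapper defined further down in this file.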
+        self.current_step = 0
+        self.episode_data = None
+
+        return observation, {"initial_position": self.initial_follower_position}
+
+    def step(
+        self, action: Tuple[np.ndarray, bool]
+    ) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict[str, Any]]:
+        """
+        Execute a single step within the environment using the specified action.
+
+        The provided action is a tuple comprised of:
+            • A policy action (joint position commands) that may be either in absolute values or as a delta.
+            • A boolean flag indicating whether teleoperation (human intervention) should be used for this step.
+
+        Behavior:
+        - When the intervention flag is False, the environment processes and sends the policy action to the robot.
+        - When True, a teleoperation step is executed. If using a delta action space, an absolute teleop action is converted
+          to a relative change based on the current joint positions.
+
+        Args:
+            action (tuple): A tuple with two elements:
+                - policy_action (np.ndarray or torch.Tensor): The commanded joint positions.
+                - intervention_bool (bool): True if the human operator intervenes by providing a teleoperation input.
+
+        Returns:
+            tuple: A tuple containing:
+                - observation (dict): The new sensor observation after taking the step.
+                - reward (float): The step reward (default is 0.0 within this wrapper).
+                - terminated (bool): True if the episode has reached a terminal state.
+                - truncated (bool): True if the episode was truncated (e.g., time constraints).
+                - info (dict): Additional debugging information including:
+                    ◦ "action_intervention": The teleop action if intervention was used.
+                    ◦ "is_intervention": Flag indicating whether teleoperation was employed.
+        """
+        policy_action, intervention_bool = action
+        teleop_action = None
+        self.current_joint_positions = self.robot.follower_arms["main"].read(
+            "Present_Position"
+        )
+        if isinstance(policy_action, torch.Tensor):
+            policy_action = policy_action.cpu().numpy()
+            policy_action = np.clip(
+                policy_action, self.action_space[0].low, self.action_space[0].high
+            )
+        if not intervention_bool:
+            if self.use_delta_action_space:
+                target_joint_positions = (
+                    self.current_joint_positions + self.delta * policy_action
+                )
+            else:
+                target_joint_positions = policy_action
+            self.robot.send_action(torch.from_numpy(target_joint_positions))
+            observation = self.robot.capture_observation()
+        else:
+            observation, teleop_action = self.robot.teleop_step(record_data=True)
+            teleop_action = teleop_action[
+                "action"
+            ]  # Extract the action tensor from the teleop data dict
+
+            # When applying the delta action space, convert teleop absolute values to relative differences.
+            if self.use_delta_action_space:
+                teleop_action = (
+                    teleop_action - self.current_joint_positions
+                ) / self.delta
+                # Log when the relative teleop delta exceeds the bounds on either side.
+                if torch.any(teleop_action < -self.relative_bounds_size) or torch.any(
+                    teleop_action > self.relative_bounds_size
+                ):
+                    logging.debug(
+                        f"Relative teleop delta exceeded bounds {self.relative_bounds_size}, teleop_action {teleop_action}\n"
+                        f"lower bounds condition {teleop_action < -self.relative_bounds_size}\n"
+                        f"upper bounds condition {teleop_action > self.relative_bounds_size}"
+                    )
+
+                teleop_action = torch.clamp(
+                    teleop_action,
+                    -self.relative_bounds_size,
+                    self.relative_bounds_size,
+                )
+            # NOTE: To mimic the shape of a neural network output, we add a batch dimension to the teleop action.
+            if teleop_action.dim() == 1:
+                teleop_action = teleop_action.unsqueeze(0)
+
+        # self.render()
+
+        self.current_step += 1
+
+        reward = 0.0
+        terminated = False
+        truncated = False
+
+        return (
+            observation,
+            reward,
+            terminated,
+            truncated,
+            {
+                "action_intervention": teleop_action,
+                "is_intervention": teleop_action is not None,
+            },
+        )
+
+    def render(self):
+        """
+        Render the current state of the environment by displaying the robot's camera feeds.
+        """
+        import cv2
+
+        observation = self.robot.capture_observation()
+        image_keys = [key for key in observation if "image" in key]
+
+        for key in image_keys:
+            cv2.imshow(key, cv2.cvtColor(observation[key].numpy(), cv2.COLOR_RGB2BGR))
+            cv2.waitKey(1)
+
+    def close(self):
+        """
+        Close the environment and clean up resources by disconnecting the robot.
+
+        If the robot is currently connected, this method properly terminates the connection to ensure that all
+        associated resources are released.
+        """
+        if self.robot.is_connected:
+            self.robot.disconnect()
+
+
+class ActionRepeatWrapper(gym.Wrapper):
+    def __init__(self, env, nb_repeat: int = 1):
+        super().__init__(env)
+        self.nb_repeat = nb_repeat
+
+    def step(self, action):
+        for _ in range(self.nb_repeat):
+            obs, reward, done, truncated, info = self.env.step(action)
+            if done or truncated:
+                break
+        return obs, reward, done, truncated, info
+
+
+class RewardWrapper(gym.Wrapper):
+    def __init__(self, env, reward_classifier, device: torch.device = "cuda"):
+        """
+        Wrapper that adds reward prediction to the environment using a trained classifier.
+
+        Args:
+            env: The environment to wrap
+            reward_classifier: The reward classifier model
+            device: The device to run the model on
+        """
+        super().__init__(env)
+
+        # NOTE: We got 15% speedup by compiling the model.
+        # Guard against a missing classifier so a None reward model still works.
+        self.reward_classifier = (
+            torch.compile(reward_classifier) if reward_classifier is not None else None
+        )
+
+        if isinstance(device, str):
+            device = torch.device(device)
+        self.device = device
+
+    def step(self, action):
+        observation, _, terminated, truncated, info = self.env.step(action)
+        images = [
+            observation[key].to(self.device, non_blocking=self.device.type == "cuda")
+            for key in observation
+            if "image" in key
+        ]
+        start_time = time.perf_counter()
+        with torch.inference_mode():
+            reward = (
+                self.reward_classifier.predict_reward(images, threshold=0.8)
+                if self.reward_classifier is not None
+                else 0.0
+            )
+        info["Reward classifier frequency"] = 1 / (time.perf_counter() - start_time)
+
+        # logging.info(f"Reward: {reward}")
+
+        if reward == 1.0:
+            terminated = True
+        return observation, reward, terminated, truncated, info
+
+    def reset(self, seed=None, options=None):
+        return self.env.reset(seed=seed, options=options)
+
+
+class JointMaskingActionSpace(gym.Wrapper):
+    def __init__(self, env, mask):
+        """
+        Wrapper to mask out dimensions of the action space.
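+        Useful, for example, when some follower joints should not be commanded
+        by the policy at all.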
+ + Args: + env: The environment to wrap + mask: Binary mask array where 0 indicates dimensions to remove + """ + super().__init__(env) + + # Validate mask matches action space + + # Keep only dimensions where mask is 1 + self.active_dims = np.where(mask)[0] + + if isinstance(env.action_space, gym.spaces.Box): + if len(mask) != env.action_space.shape[0]: + raise ValueError("Mask length must match action space dimensions") + low = env.action_space.low[self.active_dims] + high = env.action_space.high[self.active_dims] + self.action_space = gym.spaces.Box( + low=low, high=high, dtype=env.action_space.dtype + ) + + if isinstance(env.action_space, gym.spaces.Tuple): + if len(mask) != env.action_space[0].shape[0]: + raise ValueError("Mask length must match action space 0 dimensions") + + low = env.action_space[0].low[self.active_dims] + high = env.action_space[0].high[self.active_dims] + action_space_masked = gym.spaces.Box( + low=low, high=high, dtype=env.action_space[0].dtype + ) + self.action_space = gym.spaces.Tuple( + (action_space_masked, env.action_space[1]) + ) + # Create new action space with masked dimensions + + def action(self, action): + """ + Convert masked action back to full action space. + + Args: + action: Action in masked space. For Tuple spaces, the first element is masked. + + Returns: + Action in original space with masked dims set to 0. + """ + + # Determine whether we are handling a Tuple space or a Box. + if isinstance(self.env.action_space, gym.spaces.Tuple): + # Extract the masked component from the tuple. + masked_action = action[0] if isinstance(action, tuple) else action + # Create a full action for the Box element. + full_box_action = np.zeros( + self.env.action_space[0].shape, dtype=self.env.action_space[0].dtype + ) + full_box_action[self.active_dims] = masked_action + # Return a tuple with the reconstructed Box action and the unchanged remainder. + return (full_box_action, action[1]) + else: + # For Box action spaces. 
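+            # Hypothetical example: with mask=[1, 1, 1, 0, 0, 0], active_dims is
+            # [0, 1, 2]; a masked action of shape (3,) is scattered back into a
+            # zero vector of shape (6,), so masked-out joints get a 0 command.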
+ masked_action = action if not isinstance(action, tuple) else action[0] + full_action = np.zeros( + self.env.action_space.shape, dtype=self.env.action_space.dtype + ) + full_action[self.active_dims] = masked_action + return full_action + + def step(self, action): + action = self.action(action) + obs, reward, terminated, truncated, info = self.env.step(action) + if "action_intervention" in info and info["action_intervention"] is not None: + if info["action_intervention"].dim() == 1: + info["action_intervention"] = info["action_intervention"][ + self.active_dims + ] + else: + info["action_intervention"] = info["action_intervention"][ + :, self.active_dims + ] + return obs, reward, terminated, truncated, info + + +class TimeLimitWrapper(gym.Wrapper): + def __init__(self, env, control_time_s, fps): + self.env = env + self.control_time_s = control_time_s + self.fps = fps + + self.last_timestamp = 0.0 + self.episode_time_in_s = 0.0 + + self.max_episode_steps = int(self.control_time_s * self.fps) + + self.current_step = 0 + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + time_since_last_step = time.perf_counter() - self.last_timestamp + self.episode_time_in_s += time_since_last_step + self.last_timestamp = time.perf_counter() + self.current_step += 1 + # check if last timestep took more time than the expected fps + if 1.0 / time_since_last_step < self.fps: + logging.debug(f"Current timestep exceeded expected fps {self.fps}") + + if self.episode_time_in_s > self.control_time_s: + # if self.current_step >= self.max_episode_steps: + # Terminated = True + terminated = True + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + self.episode_time_in_s = 0.0 + self.last_timestamp = time.perf_counter() + self.current_step = 0 + return self.env.reset(seed=seed, options=options) + + +class ImageCropResizeWrapper(gym.Wrapper): + def __init__( + self, + env, + crop_params_dict: Dict[str, Annotated[Tuple[int], 4]], + resize_size=None, + ): + super().__init__(env) + self.env = env + self.crop_params_dict = crop_params_dict + print(f"obs_keys , {self.env.observation_space}") + print(f"crop params dict {crop_params_dict.keys()}") + for key_crop in crop_params_dict: + if key_crop not in self.env.observation_space.keys(): # noqa: SIM118 + raise ValueError(f"Key {key_crop} not in observation space") + for key in crop_params_dict: + top, left, height, width = crop_params_dict[key] + new_shape = (top + height, left + width) + self.observation_space[key] = gym.spaces.Box( + low=0, high=255, shape=new_shape + ) + + self.resize_size = resize_size + if self.resize_size is None: + self.resize_size = (128, 128) + + def step(self, action): + obs, reward, terminated, truncated, info = self.env.step(action) + for k in self.crop_params_dict: + device = obs[k].device + + # Check for NaNs before processing + if torch.isnan(obs[k]).any(): + logging.error( + f"NaN values detected in observation {k} before crop and resize" + ) + + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + + # Check for NaNs after processing + if torch.isnan(obs[k]).any(): + logging.error( + f"NaN values detected in observation {k} after crop and resize" + ) + + obs[k] = obs[k].to(device) + + return obs, reward, terminated, truncated, info + + def reset(self, seed=None, options=None): + obs, info = self.env.reset(seed=seed, options=options) + for k in 
self.crop_params_dict: + device = obs[k].device + if device == torch.device("mps:0"): + obs[k] = obs[k].cpu() + obs[k] = F.crop(obs[k], *self.crop_params_dict[k]) + obs[k] = F.resize(obs[k], self.resize_size) + obs[k] = obs[k].to(device) + return obs, info + + +class ConvertToLeRobotObservation(gym.ObservationWrapper): + def __init__(self, env, device): + super().__init__(env) + + if isinstance(device, str): + device = torch.device(device) + self.device = device + + def observation(self, observation): + observation = preprocess_observation(observation) + + observation = { + key: observation[key].to( + self.device, non_blocking=self.device.type == "cuda" + ) + for key in observation + } + observation = { + k: torch.tensor(v, device=self.device) for k, v in observation.items() + } + return observation + + +class KeyboardInterfaceWrapper(gym.Wrapper): + def __init__(self, env): + super().__init__(env) + self.listener = None + self.events = { + "exit_early": False, + "pause_policy": False, + "reset_env": False, + "human_intervention_step": False, + "episode_success": False, + } + self.event_lock = Lock() # Thread-safe access to events + self._init_keyboard_listener() + + def _init_keyboard_listener(self): + """Initialize keyboard listener if not in headless mode""" + + if is_headless(): + logging.warning( + "Headless environment detected. On-screen cameras display and keyboard inputs will not be available." + ) + return + try: + from pynput import keyboard + + def on_press(key): + with self.event_lock: + try: + if key == keyboard.Key.right or key == keyboard.Key.esc: + print("Right arrow key pressed. Exiting loop...") + self.events["exit_early"] = True + return + if hasattr(key, "char") and key.char == "s": + print("Key 's' pressed. Episode success triggered.") + self.events["episode_success"] = True + return + if key == keyboard.Key.space and not self.events["exit_early"]: + if not self.events["pause_policy"]: + print( + "Space key pressed. Human intervention required.\n" + "Place the leader in similar pose to the follower and press space again." + ) + self.events["pause_policy"] = True + log_say( + "Human intervention stage. Get ready to take over.", + play_sounds=True, + ) + return + if ( + self.events["pause_policy"] + and not self.events["human_intervention_step"] + ): + self.events["human_intervention_step"] = True + print("Space key pressed. Human intervention starting.") + log_say( + "Starting human intervention.", play_sounds=True + ) + return + if ( + self.events["pause_policy"] + and self.events["human_intervention_step"] + ): + self.events["pause_policy"] = False + self.events["human_intervention_step"] = False + print("Space key pressed for a third time.") + log_say( + "Continuing with policy actions.", play_sounds=True + ) + return + except Exception as e: + print(f"Error handling key press: {e}") + + self.listener = keyboard.Listener(on_press=on_press) + self.listener.start() + except ImportError: + logging.warning( + "Could not import pynput. Keyboard interface will not be available." + ) + self.listener = None + + def step(self, action: Any) -> Tuple[Any, float, bool, bool, Dict]: + is_intervention = False + terminated_by_keyboard = False + + # Extract policy_action if needed + if isinstance(self.env.action_space, gym.spaces.Tuple): + policy_action = action[0] + + # Check the event flags without holding the lock for too long. 
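+        # Recap of the on_press state machine above: space pauses the policy,
+        # a second space starts human intervention, a third resumes the policy;
+        # 's' marks episode success, right arrow or ESC ends the episode early.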
+        with self.event_lock:
+            if self.events["exit_early"]:
+                terminated_by_keyboard = True
+            pause_policy = self.events["pause_policy"]
+
+        if pause_policy:
+            # Now, wait for human_intervention_step without holding the lock
+            while True:
+                with self.event_lock:
+                    if self.events["human_intervention_step"]:
+                        is_intervention = True
+                        break
+                time.sleep(0.1)  # Check more frequently if desired
+
+        # Execute the step in the underlying environment
+        obs, reward, terminated, truncated, info = self.env.step(
+            (policy_action, is_intervention)
+        )
+
+        # Override reward and termination if episode success event triggered
+        with self.event_lock:
+            if self.events["episode_success"]:
+                reward = 1
+                terminated_by_keyboard = True
+
+        return obs, reward, terminated or terminated_by_keyboard, truncated, info
+
+    def reset(self, **kwargs) -> Tuple[Any, Dict]:
+        """
+        Reset the environment and clear any pending events
+        """
+        with self.event_lock:
+            self.events = {k: False for k in self.events}
+        return self.env.reset(**kwargs)
+
+    def close(self):
+        """
+        Properly clean up the keyboard listener when the environment is closed
+        """
+        if self.listener is not None:
+            self.listener.stop()
+        super().close()
+
+
+class ResetWrapper(gym.Wrapper):
+    def __init__(
+        self,
+        env: HILSerlRobotEnv,
+        reset_fn: Optional[Callable[[], None]] = None,
+        reset_time_s: float = 5,
+    ):
+        super().__init__(env)
+        self.reset_fn = reset_fn
+        self.reset_time_s = reset_time_s
+
+        self.robot = self.unwrapped.robot
+        self.init_pos = self.unwrapped.initial_follower_position
+
+    def reset(self, *, seed=None, options=None):
+        if self.reset_fn is not None:
+            self.reset_fn(self.env)
+        else:
+            log_say(
+                f"Manually reset the environment for {self.reset_time_s} seconds.",
+                play_sounds=True,
+            )
+            start_time = time.perf_counter()
+            while time.perf_counter() - start_time < self.reset_time_s:
+                self.robot.teleop_step()
+
+            log_say("Manual resetting of the environment done.", play_sounds=True)
+        return super().reset(seed=seed, options=options)
+
+
+class BatchCompatibleWrapper(gym.ObservationWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+
+    def observation(
+        self, observation: dict[str, torch.Tensor]
+    ) -> dict[str, torch.Tensor]:
+        for key in observation:
+            if "image" in key and observation[key].dim() == 3:
+                observation[key] = observation[key].unsqueeze(0)
+            if "state" in key and observation[key].dim() == 1:
+                observation[key] = observation[key].unsqueeze(0)
+        return observation
+
+
+# TODO: REMOVE TH
+
+
+def make_robot_env(
+    robot,
+    reward_classifier,
+    cfg,
+    n_envs: int = 1,
+) -> gym.vector.VectorEnv:
+    """
+    Factory function to create a vectorized robot environment.
+
+    Args:
+        robot: Robot instance to control
+        reward_classifier: Classifier model for computing rewards
+        cfg: Configuration object containing environment parameters
+        n_envs: Number of environments to create in parallel. Defaults to 1.
+
+    Returns:
+        A vectorized gym environment with all the necessary wrappers applied.
+    """
+    if "maniskill" in cfg.env.name:
+        from lerobot.scripts.server.maniskill_manipulator import make_maniskill

+        logging.warning("WE SHOULD REMOVE THE MANISKILL BEFORE THE MERGE INTO MAIN")
+        env = make_maniskill(
+            cfg=cfg,
+            n_envs=1,
+        )
+        return env
+    # Create base environment
+    env = HILSerlRobotEnv(
+        robot=robot,
+        display_cameras=cfg.env.wrapper.display_cameras,
+        delta=cfg.env.wrapper.delta_action,
+        use_delta_action_space=cfg.env.wrapper.use_relative_joint_positions,
+    )
+
+    # Add observation and image processing
+    env = ConvertToLeRobotObservation(env=env, device=cfg.device)
+    if cfg.env.wrapper.crop_params_dict is not None:
+        env = ImageCropResizeWrapper(
+            env=env,
+            crop_params_dict=cfg.env.wrapper.crop_params_dict,
+            resize_size=cfg.env.wrapper.resize_size,
+        )
+
+    # Add reward computation and control wrappers
+    env = RewardWrapper(env=env, reward_classifier=reward_classifier, device=cfg.device)
+    env = TimeLimitWrapper(
+        env=env, control_time_s=cfg.env.wrapper.control_time_s, fps=cfg.fps
+    )
+    env = KeyboardInterfaceWrapper(env=env)
+    env = ResetWrapper(
+        env=env, reset_fn=None, reset_time_s=cfg.env.wrapper.reset_time_s
+    )
+    env = JointMaskingActionSpace(
+        env=env, mask=cfg.env.wrapper.joint_masking_action_space
+    )
+    env = BatchCompatibleWrapper(env=env)
+
+    return env
+
+    # batched version of the env that returns an observation of shape (b, c)
+
+
+def get_classifier(pretrained_path, config_path, device="mps"):
+    if pretrained_path is None or config_path is None:
+        return None
+
+    from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg
+    from lerobot.common.policies.hilserl.classifier.configuration_classifier import (
+        ClassifierConfig,
+    )
+    from lerobot.common.policies.hilserl.classifier.modeling_classifier import (
+        Classifier,
+    )
+
+    cfg = init_hydra_config(config_path)
+
+    classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg)
+    classifier_config.num_cameras = len(
+        cfg.training.image_keys
+    )  # TODO automate these paths
+    model = Classifier(classifier_config)
+    model.load_state_dict(Classifier.from_pretrained(pretrained_path).state_dict())
+    model = model.to(device)
+    return model
+
+
+def replay_episode(env, repo_id, root=None, episode=0):
+    from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+
+    local_files_only = root is not None
+    dataset = LeRobotDataset(
+        repo_id, root=root, episodes=[episode], local_files_only=local_files_only
+    )
+    actions = dataset.hf_dataset.select_columns("action")
+
+    for idx in range(dataset.num_frames):
+        start_episode_t = time.perf_counter()
+
+        action = actions[idx]["action"][:4]
+        print(action)
+        env.step((action / env.unwrapped.delta, False))
+
+        dt_s = time.perf_counter() - start_episode_t
+        busy_wait(1 / 10 - dt_s)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--fps", type=int, default=30, help="control frequency")
+    parser.add_argument(
+        "--robot-path",
+        type=str,
+        default="lerobot/configs/robot/koch.yaml",
+        help="Path to robot yaml file used to instantiate the robot using `make_robot` factory function.",
+    )
+    parser.add_argument(
+        "--robot-overrides",
+        type=str,
+        nargs="*",
+        help="Any key=value arguments to override config values (use dots for nested.overrides)",
+    )
+    parser.add_argument(
+        "-p",
+        "--pretrained-policy-name-or-path",
+        help=(
+            "Either the repo ID of a model hosted on the Hub or a path to a directory containing weights "
+            "saved using `Policy.save_pretrained`. 
If not provided, the policy is initialized from scratch " + "(useful for debugging). This argument is mutually exclusive with `--config`." + ), + ) + parser.add_argument( + "--config", + help=( + "Path to a yaml config you want to use for initializing a policy from scratch (useful for " + "debugging). This argument is mutually exclusive with `--pretrained-policy-name-or-path` (`-p`)." + ), + ) + parser.add_argument( + "--display-cameras", + help=("Whether to display the camera feed while the rollout is happening"), + ) + parser.add_argument( + "--reward-classifier-pretrained-path", + type=str, + default=None, + help="Path to the pretrained classifier weights.", + ) + parser.add_argument( + "--reward-classifier-config-file", + type=str, + default=None, + help="Path to a yaml config file that is necessary to build the reward classifier model.", + ) + parser.add_argument( + "--env-path", type=str, default=None, help="Path to the env yaml file" + ) + parser.add_argument( + "--env-overrides", + type=str, + default=None, + help="Overrides for the env yaml file", + ) + parser.add_argument( + "--control-time-s", + type=float, + default=20, + help="Maximum episode length in seconds", + ) + parser.add_argument( + "--reset-follower-pos", + type=int, + default=1, + help="Reset follower between episodes", + ) + parser.add_argument( + "--replay-repo-id", + type=str, + default=None, + help="Repo ID of the episode to replay", + ) + parser.add_argument( + "--replay-root", type=str, default=None, help="Root of the dataset to replay" + ) + parser.add_argument( + "--replay-episode", type=int, default=0, help="Episode to replay" + ) + args = parser.parse_args() + + robot_cfg = init_hydra_config(args.robot_path, args.robot_overrides) + robot = make_robot(robot_cfg) + + reward_classifier = get_classifier( + args.reward_classifier_pretrained_path, args.reward_classifier_config_file + ) + user_relative_joint_positions = True + + cfg = init_hydra_config(args.env_path, args.env_overrides) + env = make_robot_env( + robot, + reward_classifier, + cfg.env, # .wrapper, + ) + + env.reset() + + if args.replay_repo_id is not None: + replay_episode( + env, args.replay_repo_id, root=args.replay_root, episode=args.replay_episode + ) + exit() + + # Retrieve the robot's action space for joint commands. + action_space_robot = env.action_space.spaces[0] + + # Initialize the smoothed action as a random sample. + smoothed_action = action_space_robot.sample() + + # Smoothing coefficient (alpha) defines how much of the new random sample to mix in. + # A value close to 0 makes the trajectory very smooth (slow to change), while a value close to 1 is less smooth. + alpha = 0.4 + + while True: + start_loop_s = time.perf_counter() + # Sample a new random action from the robot's action space. + new_random_action = action_space_robot.sample() + # Update the smoothed action using an exponential moving average. + smoothed_action = alpha * new_random_action + (1 - alpha) * smoothed_action + + # Execute the step: wrap the NumPy action in a torch tensor. 
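+        # (The smoothed_action above follows an exponential moving average,
+        # a_t = alpha * r_t + (1 - alpha) * a_{t-1}; with alpha = 0.4 each new
+        # random target is only partially mixed in, keeping the motion smooth.)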
+ obs, reward, terminated, truncated, info = env.step( + (torch.from_numpy(smoothed_action), False) + ) + if terminated or truncated: + env.reset() + + dt_s = time.perf_counter() - start_loop_s + busy_wait(1 / args.fps - dt_s) diff --git a/lerobot/scripts/server/hilserl.proto b/lerobot/scripts/server/hilserl.proto new file mode 100644 index 000000000..dec2117b2 --- /dev/null +++ b/lerobot/scripts/server/hilserl.proto @@ -0,0 +1,55 @@ +// !/usr/bin/env python + +// Copyright 2024 The HuggingFace Inc. team. +// All rights reserved. + +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +syntax = "proto3"; + +package hil_serl; + +// LearnerService: the Actor calls this to push transitions. +// The Learner implements this service. +service LearnerService { + // Actor -> Learner to store transitions + rpc SendInteractionMessage(InteractionMessage) returns (Empty); + rpc StreamParameters(Empty) returns (stream Parameters); + rpc SendTransitions(stream Transition) returns (Empty); + rpc SendInteractions(stream InteractionMessage) returns (Empty); + rpc Ready(Empty) returns (Empty); +} + +enum TransferState { + TRANSFER_UNKNOWN = 0; + TRANSFER_BEGIN = 1; + TRANSFER_MIDDLE = 2; + TRANSFER_END = 3; +} + +// Messages +message Transition { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Parameters { + TransferState transfer_state = 1; + bytes data = 2; +} + +message InteractionMessage { + TransferState transfer_state = 1; + bytes data = 2; +} + +message Empty {} diff --git a/lerobot/scripts/server/hilserl_pb2.py b/lerobot/scripts/server/hilserl_pb2.py new file mode 100644 index 000000000..4a4cbea76 --- /dev/null +++ b/lerobot/scripts/server/hilserl_pb2.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! 
+# NO CHECKED-IN PROTOBUF GENCODE +# source: hilserl.proto +# Protobuf Python Version: 5.29.0 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import runtime_version as _runtime_version +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +_runtime_version.ValidateProtobufRuntimeVersion( + _runtime_version.Domain.PUBLIC, + 5, + 29, + 0, + '', + 'hilserl.proto' +) +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\rhilserl.proto\x12\x08hil_serl\"K\n\nTransition\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"K\n\nParameters\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"S\n\x12InteractionMessage\x12/\n\x0etransfer_state\x18\x01 \x01(\x0e\x32\x17.hil_serl.TransferState\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\x07\n\x05\x45mpty*`\n\rTransferState\x12\x14\n\x10TRANSFER_UNKNOWN\x10\x00\x12\x12\n\x0eTRANSFER_BEGIN\x10\x01\x12\x13\n\x0fTRANSFER_MIDDLE\x10\x02\x12\x10\n\x0cTRANSFER_END\x10\x03\x32\xc2\x02\n\x0eLearnerService\x12G\n\x16SendInteractionMessage\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty\x12;\n\x10StreamParameters\x12\x0f.hil_serl.Empty\x1a\x14.hil_serl.Parameters0\x01\x12:\n\x0fSendTransitions\x12\x14.hil_serl.Transition\x1a\x0f.hil_serl.Empty(\x01\x12\x43\n\x10SendInteractions\x12\x1c.hil_serl.InteractionMessage\x1a\x0f.hil_serl.Empty(\x01\x12)\n\x05Ready\x12\x0f.hil_serl.Empty\x1a\x0f.hil_serl.Emptyb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'hilserl_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_TRANSFERSTATE']._serialized_start=275 + _globals['_TRANSFERSTATE']._serialized_end=371 + _globals['_TRANSITION']._serialized_start=27 + _globals['_TRANSITION']._serialized_end=102 + _globals['_PARAMETERS']._serialized_start=104 + _globals['_PARAMETERS']._serialized_end=179 + _globals['_INTERACTIONMESSAGE']._serialized_start=181 + _globals['_INTERACTIONMESSAGE']._serialized_end=264 + _globals['_EMPTY']._serialized_start=266 + _globals['_EMPTY']._serialized_end=273 + _globals['_LEARNERSERVICE']._serialized_start=374 + _globals['_LEARNERSERVICE']._serialized_end=696 +# @@protoc_insertion_point(module_scope) diff --git a/lerobot/scripts/server/hilserl_pb2_grpc.py b/lerobot/scripts/server/hilserl_pb2_grpc.py new file mode 100644 index 000000000..1fa96e81a --- /dev/null +++ b/lerobot/scripts/server/hilserl_pb2_grpc.py @@ -0,0 +1,276 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! 
+"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +import hilserl_pb2 as hilserl__pb2 + +GRPC_GENERATED_VERSION = '1.70.0' +GRPC_VERSION = grpc.__version__ +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + raise RuntimeError( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in hilserl_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + ) + + +class LearnerServiceStub(object): + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.SendInteractionMessage = channel.unary_unary( + '/hil_serl.LearnerService/SendInteractionMessage', + request_serializer=hilserl__pb2.InteractionMessage.SerializeToString, + response_deserializer=hilserl__pb2.Empty.FromString, + _registered_method=True) + self.StreamParameters = channel.unary_stream( + '/hil_serl.LearnerService/StreamParameters', + request_serializer=hilserl__pb2.Empty.SerializeToString, + response_deserializer=hilserl__pb2.Parameters.FromString, + _registered_method=True) + self.SendTransitions = channel.stream_unary( + '/hil_serl.LearnerService/SendTransitions', + request_serializer=hilserl__pb2.Transition.SerializeToString, + response_deserializer=hilserl__pb2.Empty.FromString, + _registered_method=True) + self.SendInteractions = channel.stream_unary( + '/hil_serl.LearnerService/SendInteractions', + request_serializer=hilserl__pb2.InteractionMessage.SerializeToString, + response_deserializer=hilserl__pb2.Empty.FromString, + _registered_method=True) + self.Ready = channel.unary_unary( + '/hil_serl.LearnerService/Ready', + request_serializer=hilserl__pb2.Empty.SerializeToString, + response_deserializer=hilserl__pb2.Empty.FromString, + _registered_method=True) + + +class LearnerServiceServicer(object): + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. 
+ """ + + def SendInteractionMessage(self, request, context): + """Actor -> Learner to store transitions + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def StreamParameters(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendTransitions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def SendInteractions(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Ready(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_LearnerServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'SendInteractionMessage': grpc.unary_unary_rpc_method_handler( + servicer.SendInteractionMessage, + request_deserializer=hilserl__pb2.InteractionMessage.FromString, + response_serializer=hilserl__pb2.Empty.SerializeToString, + ), + 'StreamParameters': grpc.unary_stream_rpc_method_handler( + servicer.StreamParameters, + request_deserializer=hilserl__pb2.Empty.FromString, + response_serializer=hilserl__pb2.Parameters.SerializeToString, + ), + 'SendTransitions': grpc.stream_unary_rpc_method_handler( + servicer.SendTransitions, + request_deserializer=hilserl__pb2.Transition.FromString, + response_serializer=hilserl__pb2.Empty.SerializeToString, + ), + 'SendInteractions': grpc.stream_unary_rpc_method_handler( + servicer.SendInteractions, + request_deserializer=hilserl__pb2.InteractionMessage.FromString, + response_serializer=hilserl__pb2.Empty.SerializeToString, + ), + 'Ready': grpc.unary_unary_rpc_method_handler( + servicer.Ready, + request_deserializer=hilserl__pb2.Empty.FromString, + response_serializer=hilserl__pb2.Empty.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'hil_serl.LearnerService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('hil_serl.LearnerService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class LearnerService(object): + """LearnerService: the Actor calls this to push transitions. + The Learner implements this service. 
+ """ + + @staticmethod + def SendInteractionMessage(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/hil_serl.LearnerService/SendInteractionMessage', + hilserl__pb2.InteractionMessage.SerializeToString, + hilserl__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def StreamParameters(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/hil_serl.LearnerService/StreamParameters', + hilserl__pb2.Empty.SerializeToString, + hilserl__pb2.Parameters.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendTransitions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/hil_serl.LearnerService/SendTransitions', + hilserl__pb2.Transition.SerializeToString, + hilserl__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def SendInteractions(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/hil_serl.LearnerService/SendInteractions', + hilserl__pb2.InteractionMessage.SerializeToString, + hilserl__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Ready(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/hil_serl.LearnerService/Ready', + hilserl__pb2.Empty.SerializeToString, + hilserl__pb2.Empty.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/lerobot/scripts/server/learner_server.py b/lerobot/scripts/server/learner_server.py new file mode 100644 index 000000000..7bd4aee05 --- /dev/null +++ b/lerobot/scripts/server/learner_server.py @@ -0,0 +1,738 @@ +#!/usr/bin/env python + +# Copyright 2024 The HuggingFace Inc. team. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import shutil +import time +from pprint import pformat +from concurrent.futures import ThreadPoolExecutor + +# from torch.multiprocessing import Event, Queue, Process +# from threading import Event, Thread +# from torch.multiprocessing import Queue, Event +from torch.multiprocessing import Queue + +from lerobot.scripts.server.utils import setup_process_handlers + +import grpc + +# Import generated stubs +import hilserl_pb2_grpc # type: ignore +import hydra +import torch +from deepdiff import DeepDiff +from omegaconf import DictConfig, OmegaConf +from termcolor import colored +from torch import nn +from torch.optim.optimizer import Optimizer + +from lerobot.common.datasets.factory import make_dataset + +# TODO: Remove the import of maniskill +from lerobot.common.datasets.lerobot_dataset import LeRobotDataset +from lerobot.common.logger import Logger, log_output_dir +from lerobot.common.policies.factory import make_policy +from lerobot.common.policies.sac.modeling_sac import SACPolicy +from lerobot.common.utils.utils import ( + format_big_number, + get_global_random_state, + get_safe_torch_device, + init_hydra_config, + init_logging, + set_global_random_state, + set_global_seed, +) + +from lerobot.scripts.server.buffer import ( + ReplayBuffer, + concatenate_batch_transitions, + move_transition_to_device, + move_state_dict_to_device, + bytes_to_transitions, + state_to_bytes, + bytes_to_python_object, +) + +from lerobot.scripts.server import learner_service + + +def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig: + if not cfg.resume: + if Logger.get_last_checkpoint_dir(out_dir).exists(): + raise RuntimeError( + f"Output directory {Logger.get_last_checkpoint_dir(out_dir)} already exists. " + "Use `resume=true` to resume training." + ) + return cfg + + # if resume == True + checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir) + if not checkpoint_dir.exists(): + raise RuntimeError( + f"No model checkpoint found in {checkpoint_dir} for resume=True" + ) + + checkpoint_cfg_path = str( + Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml" + ) + logging.info( + colored( + "Resume=True detected, resuming previous run", + color="yellow", + attrs=["bold"], + ) + ) + + checkpoint_cfg = init_hydra_config(checkpoint_cfg_path) + diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) + + if "values_changed" in diff and "root['resume']" in diff["values_changed"]: + del diff["values_changed"]["root['resume']"] + + if len(diff) > 0: + logging.warning( + f"Differences between the checkpoint config and the provided config detected: \n{pformat(diff)}\n" + "Checkpoint configuration takes precedence." 
+ ) + + checkpoint_cfg.resume = True + return checkpoint_cfg + + +def load_training_state( + cfg: DictConfig, + logger: Logger, + optimizers: Optimizer | dict, +): + if not cfg.resume: + return None, None + + training_state = torch.load( + logger.last_checkpoint_dir / logger.training_state_file_name + ) + + if isinstance(training_state["optimizer"], dict): + assert set(training_state["optimizer"].keys()) == set(optimizers.keys()) + for k, v in training_state["optimizer"].items(): + optimizers[k].load_state_dict(v) + else: + optimizers.load_state_dict(training_state["optimizer"]) + + set_global_random_state({k: training_state[k] for k in get_global_random_state()}) + return training_state["step"], training_state["interaction_step"] + + +def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None: + num_learnable_params = sum( + p.numel() for p in policy.parameters() if p.requires_grad + ) + num_total_params = sum(p.numel() for p in policy.parameters()) + + log_output_dir(out_dir) + logging.info(f"{cfg.env.task=}") + logging.info(f"{cfg.training.online_steps=}") + logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") + logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") + + +def initialize_replay_buffer( + cfg: DictConfig, logger: Logger, device: str, storage_device: str +) -> ReplayBuffer: + if not cfg.resume: + return ReplayBuffer( + capacity=cfg.training.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_shapes.keys(), + storage_device=storage_device, + optimize_memory=True, + ) + + dataset = LeRobotDataset( + repo_id=cfg.dataset_repo_id, + local_files_only=True, + root=logger.log_dir / "dataset", + ) + return ReplayBuffer.from_lerobot_dataset( + lerobot_dataset=dataset, + capacity=cfg.training.online_buffer_capacity, + device=device, + state_keys=cfg.policy.input_shapes.keys(), + optimize_memory=True, + ) + + +def get_observation_features( + policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if ( + policy.config.vision_encoder_name is None + or not policy.config.freeze_vision_encoder + ): + return None, None + + with torch.no_grad(): + observation_features = ( + policy.actor.encoder(observations) + if policy.actor.encoder is not None + else None + ) + next_observation_features = ( + policy.actor.encoder(next_observations) + if policy.actor.encoder is not None + else None + ) + + return observation_features, next_observation_features + + +def use_threads(cfg: DictConfig) -> bool: + return cfg.actor_learner_config.concurrency.learner == "threads" + + +def start_learner_threads( + cfg: DictConfig, + logger: Logger, + out_dir: str, + shutdown_event: any, # Event, +) -> None: + # Create multiprocessing queues + transition_queue = Queue() + interaction_message_queue = Queue() + parameters_queue = Queue() + + concurrency_entity = None + + if use_threads(cfg): + from threading import Thread + + concurrency_entity = Thread + else: + from torch.multiprocessing import Process + + concurrency_entity = Process + + communication_process = concurrency_entity( + target=start_learner_server, + args=( + parameters_queue, + transition_queue, + interaction_message_queue, + shutdown_event, + cfg, + ), + daemon=True, + ) + communication_process.start() + + add_actor_information_and_train( + cfg, + logger, + out_dir, + shutdown_event, + transition_queue, + interaction_message_queue, + parameters_queue, + ) + logging.info("[LEARNER] 
Training process stopped") + + logging.info("[LEARNER] Closing queues") + transition_queue.close() + interaction_message_queue.close() + parameters_queue.close() + + communication_process.join() + logging.info("[LEARNER] Communication process joined") + + logging.info("[LEARNER] join queues") + transition_queue.cancel_join_thread() + interaction_message_queue.cancel_join_thread() + parameters_queue.cancel_join_thread() + + logging.info("[LEARNER] queues closed") + + +def start_learner_server( + parameters_queue: Queue, + transition_queue: Queue, + interaction_message_queue: Queue, + shutdown_event: any, # Event, + cfg: DictConfig, +): + if not use_threads(cfg): + # We need init logging for MP separataly + init_logging() + + # Setup process handlers to handle shutdown signal + # But use shutdown event from the main process + # Return back for MP + setup_process_handlers(False) + + service = learner_service.LearnerService( + shutdown_event, + parameters_queue, + cfg.actor_learner_config.policy_parameters_push_frequency, + transition_queue, + interaction_message_queue, + ) + + server = grpc.server( + ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS), + options=[ + ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE), + ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE), + ], + ) + + hilserl_pb2_grpc.add_LearnerServiceServicer_to_server( + service, + server, + ) + + host = cfg.actor_learner_config.learner_host + port = cfg.actor_learner_config.learner_port + + server.add_insecure_port(f"{host}:{port}") + server.start() + logging.info("[LEARNER] gRPC server started") + + shutdown_event.wait() + logging.info("[LEARNER] Stopping gRPC server...") + server.stop(learner_service.STUTDOWN_TIMEOUT) + logging.info("[LEARNER] gRPC server stopped") + + +def check_nan_in_transition( + observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor +): + for k in observations: + if torch.isnan(observations[k]).any(): + logging.error(f"observations[{k}] contains NaN values") + for k in next_state: + if torch.isnan(next_state[k]).any(): + logging.error(f"next_state[{k}] contains NaN values") + if torch.isnan(actions).any(): + logging.error("actions contains NaN values") + + +def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module): + logging.debug("[LEARNER] Pushing actor policy to the queue") + state_dict = move_state_dict_to_device(policy.actor.state_dict(), device="cpu") + state_bytes = state_to_bytes(state_dict) + parameters_queue.put(state_bytes) + + +def add_actor_information_and_train( + cfg, + logger: Logger, + out_dir: str, + shutdown_event: any, # Event, + transition_queue: Queue, + interaction_message_queue: Queue, + parameters_queue: Queue, +): + """ + Handles data transfer from the actor to the learner, manages training updates, + and logs training progress in an online reinforcement learning setup. + + This function continuously: + - Transfers transitions from the actor to the replay buffer. + - Logs received interaction messages. + - Ensures training begins only when the replay buffer has a sufficient number of transitions. + - Samples batches from the replay buffer and performs multiple critic updates. + - Periodically updates the actor, critic, and temperature optimizers. + - Logs training statistics, including loss values and optimization frequency. + + **NOTE:** + - This function performs multiple responsibilities (data transfer, training, and logging). + It should ideally be split into smaller functions in the future. 
+    - Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
+      significantly reduces performance. Instead, this function executes all operations in a single thread.
+
+    Args:
+        cfg: Configuration object containing hyperparameters.
+        logger (Logger): Logger instance for tracking training progress.
+        out_dir (str): The output directory for storing training checkpoints and logs.
+        shutdown_event (Event): Event to signal shutdown.
+        transition_queue (Queue): Queue for receiving transitions from the actor.
+        interaction_message_queue (Queue): Queue for receiving interaction messages from the actor.
+        parameters_queue (Queue): Queue for sending policy parameters to the actor.
+    """
+
+    device = get_safe_torch_device(cfg.device, log=True)
+    storage_device = get_safe_torch_device(cfg_device=cfg.training.storage_device)
+
+    logging.info("Initializing policy")
+    ### Instantiate the policy in both the actor and learner processes.
+    ### To avoid sending a SACPolicy object through the port, we create a policy instance
+    ### on both sides; the learner sends the updated parameters every n steps to refresh the actor's copy.
+    # TODO: At some point we should just need make_policy for SAC
+
+    policy: SACPolicy = make_policy(
+        hydra_cfg=cfg,
+        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+        # Hack: if we do online training, we do not need dataset_stats
+        dataset_stats=None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
+        if cfg.resume
+        else None,
+    )
+    # compile policy
+    policy = torch.compile(policy)
+    assert isinstance(policy, nn.Module)
+
+    push_actor_policy_to_queue(parameters_queue, policy)
+
+    last_time_policy_pushed = time.time()
+
+    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
+    resume_optimization_step, resume_interaction_step = load_training_state(
+        cfg, logger, optimizers
+    )
+
+    log_training_info(cfg, out_dir, policy)
+
+    replay_buffer = initialize_replay_buffer(cfg, logger, device, storage_device)
+    batch_size = cfg.training.batch_size
+    offline_replay_buffer = None
+
+    if cfg.dataset_repo_id is not None:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Converting to an offline replay buffer")
+        active_action_dims = None
+        if cfg.env.wrapper.joint_masking_action_space is not None:
+            active_action_dims = [
+                i
+                for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+                if mask
+            ]
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset,
+            device=device,
+            state_keys=cfg.policy.input_shapes.keys(),
+            action_mask=active_action_dims,
+            action_delta=cfg.env.wrapper.delta_action,
+            storage_device=storage_device,
+            optimize_memory=True,
+        )
+        batch_size: int = batch_size // 2  # We will sample from both replay buffers
+
+    # NOTE: This function doesn't have a single responsibility; it should be split into multiple
+    # functions in the future. The reason is Python's GIL: with separate worker threads,
+    # performance drops by roughly 200x, so we keep all the work in a single thread.
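+    # The loop below interleaves everything in this one thread:
+    #   1. Drain the transition queue into the replay buffer (interventions are
+    #      also added to the offline buffer).
+    #   2. Drain the interaction-message queue and forward the messages to the logger.
+    #   3. Once the buffer holds `online_step_before_learning` transitions, run
+    #      `utd_ratio - 1` critic-only updates, then one full update that also
+    #      refreshes the actor and temperature, pushes fresh parameters to the
+    #      actor when due, and periodically checkpoints.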
+    logging.info("Starting learner thread")
+    interaction_message, transition = None, None
+    optimization_step = (
+        resume_optimization_step if resume_optimization_step is not None else 0
+    )
+    interaction_step_shift = (
+        resume_interaction_step if resume_interaction_step is not None else 0
+    )
+
+    while True:
+        if shutdown_event is not None and shutdown_event.is_set():
+            logging.info("[LEARNER] Shutdown signal received. Exiting...")
+            break
+
+        logging.debug("[LEARNER] Waiting for transitions")
+        while not transition_queue.empty() and not shutdown_event.is_set():
+            transition_list = transition_queue.get()
+            transition_list = bytes_to_transitions(transition_list)
+
+            for transition in transition_list:
+                transition = move_transition_to_device(transition, device=device)
+                replay_buffer.add(**transition)
+                if transition.get("complementary_info", {}).get("is_intervention"):
+                    offline_replay_buffer.add(**transition)
+        logging.debug("[LEARNER] Received transitions")
+        logging.debug("[LEARNER] Waiting for interactions")
+        while not interaction_message_queue.empty() and not shutdown_event.is_set():
+            interaction_message = interaction_message_queue.get()
+            interaction_message = bytes_to_python_object(interaction_message)
+            # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
+            interaction_message["Interaction step"] += interaction_step_shift
+            logger.log_dict(
+                interaction_message, mode="train", custom_step_key="Interaction step"
+            )
+
+        logging.debug("[LEARNER] Received interactions")
+
+        if len(replay_buffer) < cfg.training.online_step_before_learning:
+            continue
+
+        logging.debug("[LEARNER] Starting optimization loop")
+        time_for_one_optimization_step = time.time()
+        for _ in range(cfg.policy.utd_ratio - 1):
+            batch = replay_buffer.sample(batch_size)
+
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)
+
+            actions = batch["action"]
+            rewards = batch["reward"]
+            observations = batch["state"]
+            next_observations = batch["next_state"]
+            done = batch["done"]
+            check_nan_in_transition(
+                observations=observations, actions=actions, next_state=next_observations
+            )
+
+            observation_features, next_observation_features = get_observation_features(
+                policy, observations, next_observations
+            )
+            loss_critic = policy.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+                observation_features=observation_features,
+                next_observation_features=next_observation_features,
+            )
+            optimizers["critic"].zero_grad()
+            loss_critic.backward()
+            optimizers["critic"].step()
+
+        batch = replay_buffer.sample(batch_size)
+
+        if cfg.dataset_repo_id is not None:
+            batch_offline = offline_replay_buffer.sample(batch_size)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )
+
+        actions = batch["action"]
+        rewards = batch["reward"]
+        observations = batch["state"]
+        next_observations = batch["next_state"]
+        done = batch["done"]
+
+        check_nan_in_transition(
+            observations=observations, actions=actions, next_state=next_observations
+        )
+
+        observation_features, next_observation_features = get_observation_features(
+            policy, observations, next_observations
+        )
+        loss_critic = policy.compute_loss_critic(
+            observations=observations,
+            actions=actions,
+            rewards=rewards,
+            next_observations=next_observations,
+            done=done,
+            observation_features=observation_features,
+            next_observation_features=next_observation_features,
+        )
+        optimizers["critic"].zero_grad()
+        loss_critic.backward()
+        optimizers["critic"].step()
+
+        training_infos = {}
+        training_infos["loss_critic"] = loss_critic.item()
+
+        if optimization_step % cfg.training.policy_update_freq == 0:
+            for _ in range(cfg.training.policy_update_freq):
+                loss_actor = policy.compute_loss_actor(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+
+                optimizers["actor"].zero_grad()
+                loss_actor.backward()
+                optimizers["actor"].step()
+
+                training_infos["loss_actor"] = loss_actor.item()
+
+                loss_temperature = policy.compute_loss_temperature(
+                    observations=observations,
+                    observation_features=observation_features,
+                )
+                optimizers["temperature"].zero_grad()
+                loss_temperature.backward()
+                optimizers["temperature"].step()
+
+                training_infos["loss_temperature"] = loss_temperature.item()
+
+        if (
+            time.time() - last_time_policy_pushed
+            > cfg.actor_learner_config.policy_parameters_push_frequency
+        ):
+            push_actor_policy_to_queue(parameters_queue, policy)
+            last_time_policy_pushed = time.time()
+
+        policy.update_target_networks()
+        if optimization_step % cfg.training.log_freq == 0:
+            training_infos["Optimization step"] = optimization_step
+            logger.log_dict(
+                d=training_infos, mode="train", custom_step_key="Optimization step"
+            )
+            # logging.info(f"Training infos: {training_infos}")
+
+        time_for_one_optimization_step = time.time() - time_for_one_optimization_step
+        frequency_for_one_optimization_step = 1 / (
+            time_for_one_optimization_step + 1e-9
+        )
+
+        logging.info(
+            f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
+        )
+
+        logger.log_dict(
+            {
+                "Optimization frequency loop [Hz]": frequency_for_one_optimization_step,
+                "Optimization step": optimization_step,
+            },
+            mode="train",
+            custom_step_key="Optimization step",
+        )
+
+        optimization_step += 1
+        if optimization_step % cfg.training.log_freq == 0:
+            logging.info(f"[LEARNER] Number of optimization steps: {optimization_step}")
+
+        if cfg.training.save_checkpoint and (
+            optimization_step % cfg.training.save_freq == 0
+            or optimization_step == cfg.training.online_steps
+        ):
+            logging.info(f"Checkpoint policy after step {optimization_step}")
+            # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
+            # needed (choose 6 as a minimum for consistency without being overkill).
+            _num_digits = max(6, len(str(cfg.training.online_steps)))
+            step_identifier = f"{optimization_step:0{_num_digits}d}"
+            interaction_step = (
+                interaction_message["Interaction step"]
+                if interaction_message is not None
+                else 0
+            )
+            logger.save_checkpoint(
+                optimization_step,
+                policy,
+                optimizers,
+                scheduler=None,
+                identifier=step_identifier,
+                interaction_step=interaction_step,
+            )
+
+            # TODO: temporarily save the replay buffer here; remove later when on the robot
+            # We want to control this with the keyboard inputs
+            dataset_dir = logger.log_dir / "dataset"
+            if dataset_dir.exists() and dataset_dir.is_dir():
+                shutil.rmtree(
+                    dataset_dir,
+                )
+            replay_buffer.to_lerobot_dataset(
+                cfg.dataset_repo_id, fps=cfg.fps, root=logger.log_dir / "dataset"
+            )
+
+    logging.info("[LEARNER] Training loop finished")
+
+
+def make_optimizers_and_scheduler(cfg, policy: nn.Module):
+    """
+    Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
+
+    This function sets up Adam optimizers for:
+    - The **actor network**, ensuring that only relevant parameters are optimized.
+    - The **critic ensemble**, which evaluates the value function.
+    - The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
+
+    It also initializes a learning rate scheduler, though it is currently set to `None`.
+
+    **NOTE:**
+    - If the encoder is shared, its parameters are excluded from the actor's optimization process.
+    - The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
+
+    Args:
+        cfg: Configuration object containing hyperparameters.
+        policy (nn.Module): The policy model containing the actor, critic, and temperature components.
+
+    Returns:
+        Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
+        A tuple containing:
+        - `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
+        - `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
+
+    """
+    optimizer_actor = torch.optim.Adam(
+        # NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
+        params=policy.actor.parameters_to_optimize,
+        lr=policy.config.actor_lr,
+    )
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
+    )
+    optimizer_temperature = torch.optim.Adam(
+        params=[policy.log_alpha], lr=policy.config.critic_lr
+    )
+    lr_scheduler = None
+    optimizers = {
+        "actor": optimizer_actor,
+        "critic": optimizer_critic,
+        "temperature": optimizer_temperature,
+    }
+    return optimizers, lr_scheduler
+
+
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
+    init_logging()
+    logging.info(pformat(OmegaConf.to_container(cfg)))
+
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+    cfg = handle_resume_logic(cfg, out_dir)
+
+    set_global_seed(cfg.seed)
+
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cuda.matmul.allow_tf32 = True
+
+    shutdown_event = setup_process_handlers(use_threads(cfg))
+
+    start_learner_threads(
+        cfg,
+        logger,
+        out_dir,
+        shutdown_event,
+    )
+
+
+@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")
+def train_cli(cfg: dict):
+    if not use_threads(cfg):
+        import torch.multiprocessing as mp
+
+        mp.set_start_method("spawn")
+
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+    logging.info("[LEARNER] train_cli finished")
+
+
+if __name__ == "__main__":
+    train_cli()
+
+    logging.info("[LEARNER] main finished")
diff --git a/lerobot/scripts/server/learner_service.py b/lerobot/scripts/server/learner_service.py
new file mode 100644
index 000000000..b1f91cdc7
--- /dev/null
+++ b/lerobot/scripts/server/learner_service.py
@@ -0,0 +1,82 @@
+import hilserl_pb2  # type: ignore
+import hilserl_pb2_grpc  # type: ignore
+import logging
+from multiprocessing import Event, Queue
+
+from lerobot.scripts.server.network_utils import receive_bytes_in_chunks
+from lerobot.scripts.server.network_utils import send_bytes_in_chunks
+
+MAX_MESSAGE_SIZE = 4 * 1024 * 1024  # 4 MB
+MAX_WORKERS = 3  # Stream parameters, send transitions and interactions
+SHUTDOWN_TIMEOUT = 10
+
+
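+# Each RPC below bridges gRPC and the learner's queues: StreamParameters pops
+# serialized actor weights from `parameters_queue` and streams them out in
+# chunks, SendTransitions / SendInteractions reassemble chunked uploads into
+# `transition_queue` / `interaction_message_queue`, and Ready is a handshake.
+#
+# A minimal actor-side sketch (hypothetical host/port, assuming the generated
+# stubs in hilserl_pb2_grpc):
+#
+#     channel = grpc.insecure_channel(
+#         "127.0.0.1:50051",
+#         options=[("grpc.max_receive_message_length", MAX_MESSAGE_SIZE)],
+#     )
+#     stub = hilserl_pb2_grpc.LearnerServiceStub(channel)
+#     stub.Ready(hilserl_pb2.Empty())
+#     for msg in stub.StreamParameters(hilserl_pb2.Empty()):
+#         ...  # accumulate msg.data until msg.transfer_state == TRANSFER_END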
+class LearnerService(hilserl_pb2_grpc.LearnerServiceServicer): + def __init__( + self, + shutdown_event: Event, + parameters_queue: Queue, + seconds_between_pushes: float, + transition_queue: Queue, + interaction_message_queue: Queue, + ): + self.shutdown_event = shutdown_event + self.parameters_queue = parameters_queue + self.seconds_between_pushes = seconds_between_pushes + self.transition_queue = transition_queue + self.interaction_message_queue = interaction_message_queue + + def StreamParameters(self, request, context): + # TODO: authorize the request + logging.info("[LEARNER] Received request to stream parameters from the Actor") + + while not self.shutdown_event.is_set(): + logging.info("[LEARNER] Push parameters to the Actor") + buffer = self.parameters_queue.get() + + yield from send_bytes_in_chunks( + buffer, + hilserl_pb2.Parameters, + log_prefix="[LEARNER] Sending parameters", + silent=True, + ) + + logging.info("[LEARNER] Parameters sent") + + self.shutdown_event.wait(self.seconds_between_pushes) + + logging.info("[LEARNER] Stream parameters finished") + return hilserl_pb2.Empty() + + def SendTransitions(self, request_iterator, _context): + # TODO: authorize the request + logging.info("[LEARNER] Received request to receive transitions from the Actor") + + receive_bytes_in_chunks( + request_iterator, + self.transition_queue, + self.shutdown_event, + log_prefix="[LEARNER] transitions", + ) + + logging.debug("[LEARNER] Finished receiving transitions") + return hilserl_pb2.Empty() + + def SendInteractions(self, request_iterator, _context): + # TODO: authorize the request + logging.info( + "[LEARNER] Received request to receive interactions from the Actor" + ) + + receive_bytes_in_chunks( + request_iterator, + self.interaction_message_queue, + self.shutdown_event, + log_prefix="[LEARNER] interactions", + ) + + logging.debug("[LEARNER] Finished receiving interactions") + return hilserl_pb2.Empty() + + def Ready(self, request, context): + return hilserl_pb2.Empty() diff --git a/lerobot/scripts/server/maniskill_manipulator.py b/lerobot/scripts/server/maniskill_manipulator.py new file mode 100644 index 000000000..e4d55955c --- /dev/null +++ b/lerobot/scripts/server/maniskill_manipulator.py @@ -0,0 +1,192 @@ +import einops +import numpy as np +import gymnasium as gym +import torch + +from omegaconf import DictConfig +from typing import Any +from mani_skill.vector.wrappers.gymnasium import ManiSkillVectorEnv +from mani_skill.utils.wrappers.record import RecordEpisode + + +def preprocess_maniskill_observation( + observations: dict[str, np.ndarray], +) -> dict[str, torch.Tensor]: + """Convert environment observation to LeRobot format observation. + Args: + observation: Dictionary of observation batches from a Gym vector environment. + Returns: + Dictionary of observation batches with keys renamed to LeRobot format and values as tensors. 
+    """
+    # map to expected inputs for the policy
+    return_observations = {}
+    # TODO: merge all tensors from the "agent" and "extra" keys; drop the
+    # sensor param key from the observation and keep the rgb sensor data
+    q_pos = observations["agent"]["qpos"]
+    q_vel = observations["agent"]["qvel"]
+    tcp_pos = observations["extra"]["tcp_pose"]
+    img = observations["sensor_data"]["base_camera"]["rgb"]
+
+    _, h, w, c = img.shape
+    assert c < h and c < w, f"expect channel last images, but instead got {img.shape=}"
+
+    # sanity check that images are uint8
+    assert img.dtype == torch.uint8, f"expect torch.uint8, but instead {img.dtype=}"
+
+    # convert to channel first of type float32 in range [0,1]
+    img = einops.rearrange(img, "b h w c -> b c h w").contiguous()
+    img = img.type(torch.float32)
+    img /= 255
+
+    state = torch.cat([q_pos, q_vel, tcp_pos], dim=-1)
+
+    return_observations["observation.image"] = img
+    return_observations["observation.state"] = state
+    return return_observations
+
+
+class ManiSkillObservationWrapper(gym.ObservationWrapper):
+    def __init__(self, env, device: torch.device = "cuda"):
+        super().__init__(env)
+        self.device = device
+
+    def observation(self, observation):
+        observation = preprocess_maniskill_observation(observation)
+        observation = {k: v.to(self.device) for k, v in observation.items()}
+        return observation
+
+
+class ManiSkillCompat(gym.Wrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        new_action_space_shape = env.action_space.shape[-1]
+        new_low = np.squeeze(env.action_space.low, axis=0)
+        new_high = np.squeeze(env.action_space.high, axis=0)
+        self.action_space = gym.spaces.Box(
+            low=new_low, high=new_high, shape=(new_action_space_shape,)
+        )
+
+    def reset(
+        self, *, seed: int | None = None, options: dict[str, Any] | None = None
+    ) -> tuple[Any, dict[str, Any]]:
+        options = {}
+        return super().reset(seed=seed, options=options)
+
+    def step(self, action):
+        obs, reward, terminated, truncated, info = self.env.step(action)
+        reward = reward.item()
+        terminated = terminated.item()
+        truncated = truncated.item()
+        return obs, reward, terminated, truncated, info
+
+
+class ManiSkillActionWrapper(gym.ActionWrapper):
+    def __init__(self, env):
+        super().__init__(env)
+        self.action_space = gym.spaces.Tuple(
+            spaces=(env.action_space, gym.spaces.Discrete(2))
+        )
+
+    def action(self, action):
+        action, teleop = action
+        return action
+
+
+class ManiSkillMultiplyActionWrapper(gym.Wrapper):
+    def __init__(self, env, multiply_factor: float = 1):
+        super().__init__(env)
+        self.multiply_factor = multiply_factor
+        action_space_agent: gym.spaces.Box = env.action_space[0]
+        action_space_agent.low = action_space_agent.low * multiply_factor
+        action_space_agent.high = action_space_agent.high * multiply_factor
+        self.action_space = gym.spaces.Tuple(
+            spaces=(action_space_agent, gym.spaces.Discrete(2))
+        )
+
+    def step(self, action):
+        if isinstance(action, tuple):
+            action, teleop = action
+        else:
+            teleop = 0
+        action = action / self.multiply_factor
+        obs, reward, terminated, truncated, info = self.env.step((action, teleop))
+        return obs, reward, terminated, truncated, info
+
+
+def make_maniskill(
+    cfg: DictConfig,
+    n_envs: int | None = None,
+) -> gym.Env:
+    """
+    Factory function to create a ManiSkill environment with standard wrappers.
+
+    Args:
+        cfg: Hydra config whose `env` section provides the task name, obs mode,
+            control mode, render mode, image size, device, and video recording settings
+        n_envs: Number of parallel environments
+
+    Returns:
+        A wrapped ManiSkill environment
+    """
+
+    env = gym.make(
+        cfg.env.task,
+        obs_mode=cfg.env.obs,
+        control_mode=cfg.env.control_mode,
+        render_mode=cfg.env.render_mode,
+        sensor_configs={"width": cfg.env.image_size, "height": cfg.env.image_size},
+        num_envs=n_envs,
+    )
+
+    if cfg.env.video_record.enabled:
+        env = RecordEpisode(
+            env,
+            output_dir=cfg.env.video_record.record_dir,
+            save_trajectory=True,
+            trajectory_name=cfg.env.video_record.trajectory_name,
+            save_video=True,
+            video_fps=30,
+        )
+    env = ManiSkillObservationWrapper(env, device=cfg.env.device)
+    env = ManiSkillVectorEnv(env, ignore_terminations=True, auto_reset=False)
+    env._max_episode_steps = env.max_episode_steps = (
+        50  # gym_utils.find_max_episode_steps_value(env)
+    )
+    env.unwrapped.metadata["render_fps"] = 20
+    env = ManiSkillCompat(env)
+    env = ManiSkillActionWrapper(env)
+    env = ManiSkillMultiplyActionWrapper(env, multiply_factor=1)
+
+    return env
+
+
+if __name__ == "__main__":
+    import argparse
+    import hydra
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--config", type=str, default="lerobot/configs/env/maniskill_example.yaml"
+    )
+    args = parser.parse_args()
+
+    # Initialize config
+    with hydra.initialize(version_base=None, config_path="../../configs"):
+        cfg = hydra.compose(config_name="env/maniskill_example.yaml")
+
+    env = make_maniskill(cfg, n_envs=1)
+
+    print("env done")
+    obs, info = env.reset()
+    random_action = env.action_space.sample()
+    obs, reward, terminated, truncated, info = env.step(random_action)
diff --git a/lerobot/scripts/server/network_utils.py b/lerobot/scripts/server/network_utils.py
new file mode 100644
index 000000000..f5e8973b1
--- /dev/null
+++ b/lerobot/scripts/server/network_utils.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
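+# Payloads larger than the gRPC message limit are streamed as a sequence of
+# CHUNK_SIZE pieces tagged TRANSFER_BEGIN / TRANSFER_MIDDLE / TRANSFER_END
+# (a payload that fits in one chunk is tagged TRANSFER_END directly).
+# CHUNK_SIZE is kept below the 4 MB MAX_MESSAGE_SIZE configured on the gRPC
+# server in learner_service.py, so every message fits the transport limit.
+# For example, `send_bytes_in_chunks(payload, hilserl_pb2.Parameters)` yields
+# one Parameters message per chunk, and `receive_bytes_in_chunks` reassembles
+# the chunks and puts each complete payload on a queue.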
+
+from lerobot.scripts.server import hilserl_pb2
+import logging
+import io
+from multiprocessing import Queue, Event
+from typing import Any
+
+CHUNK_SIZE = 2 * 1024 * 1024  # 2 MB
+
+
+def bytes_buffer_size(buffer: io.BytesIO) -> int:
+    buffer.seek(0, io.SEEK_END)
+    result = buffer.tell()
+    buffer.seek(0)
+    return result
+
+
+def send_bytes_in_chunks(
+    buffer: bytes, message_class: Any, log_prefix: str = "", silent: bool = True
+):
+    buffer = io.BytesIO(buffer)
+    size_in_bytes = bytes_buffer_size(buffer)
+
+    sent_bytes = 0
+
+    logging_method = logging.info if not silent else logging.debug
+
+    logging_method(f"{log_prefix} Buffer size {size_in_bytes/1024/1024} MB")
+
+    while sent_bytes < size_in_bytes:
+        transfer_state = hilserl_pb2.TransferState.TRANSFER_MIDDLE
+
+        if sent_bytes + CHUNK_SIZE >= size_in_bytes:
+            transfer_state = hilserl_pb2.TransferState.TRANSFER_END
+        elif sent_bytes == 0:
+            transfer_state = hilserl_pb2.TransferState.TRANSFER_BEGIN
+
+        size_to_read = min(CHUNK_SIZE, size_in_bytes - sent_bytes)
+        chunk = buffer.read(size_to_read)
+
+        yield message_class(transfer_state=transfer_state, data=chunk)
+        sent_bytes += size_to_read
+        logging_method(
+            f"{log_prefix} Sent {sent_bytes}/{size_in_bytes} bytes with state {transfer_state}"
+        )
+
+    logging_method(f"{log_prefix} Published {sent_bytes/1024/1024} MB")
+
+
+def receive_bytes_in_chunks(
+    iterator, queue: Queue, shutdown_event: Event, log_prefix: str = ""
+):
+    bytes_buffer = io.BytesIO()
+    step = 0
+
+    logging.info(f"{log_prefix} Starting receiver")
+    for item in iterator:
+        logging.debug(f"{log_prefix} Received item")
+        if shutdown_event.is_set():
+            logging.info(f"{log_prefix} Shutting down receiver")
+            return
+
+        if item.transfer_state == hilserl_pb2.TransferState.TRANSFER_BEGIN:
+            bytes_buffer.seek(0)
+            bytes_buffer.truncate(0)
+            bytes_buffer.write(item.data)
+            logging.debug(f"{log_prefix} Received data at step 0")
+            step = 0
+            continue
+        elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_MIDDLE:
+            bytes_buffer.write(item.data)
+            step += 1
+            logging.debug(f"{log_prefix} Received data at step {step}")
+        elif item.transfer_state == hilserl_pb2.TransferState.TRANSFER_END:
+            bytes_buffer.write(item.data)
+            logging.debug(
+                f"{log_prefix} Received final chunk, total size {bytes_buffer_size(bytes_buffer)}"
+            )
+
+            queue.put(bytes_buffer.getvalue())
+
+            bytes_buffer.seek(0)
+            bytes_buffer.truncate(0)
+            step = 0
+
+            logging.debug(f"{log_prefix} Queue updated")
diff --git a/lerobot/scripts/server/utils.py b/lerobot/scripts/server/utils.py
new file mode 100644
index 000000000..699717e4a
--- /dev/null
+++ b/lerobot/scripts/server/utils.py
@@ -0,0 +1,72 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
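+# setup_process_handlers installs a shared shutdown event: the first
+# SIGINT/SIGTERM/SIGHUP/SIGQUIT sets the event so the training and
+# communication loops can exit cleanly, and a second signal forces
+# sys.exit(1). get_last_item_from_queue drains a queue so that a slow
+# consumer (e.g. the actor pulling parameters) only ever sees the
+# freshest item.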
+
+import logging
+import signal
+import sys
+from torch.multiprocessing import Queue
+from queue import Empty
+
+shutdown_event_counter = 0
+
+
+def setup_process_handlers(use_threads: bool) -> any:
+    if use_threads:
+        from threading import Event
+    else:
+        from multiprocessing import Event
+
+    shutdown_event = Event()
+
+    # Define signal handler
+    def signal_handler(signum, frame):
+        logging.info("Shutdown signal received. Cleaning up...")
+        shutdown_event.set()
+        global shutdown_event_counter
+        shutdown_event_counter += 1
+
+        if shutdown_event_counter > 1:
+            logging.info("Force shutdown")
+            sys.exit(1)
+
+    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Termination request (kill)
+    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed/Hangup
+    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+
+    return shutdown_event
+
+
+def get_last_item_from_queue(queue: Queue):
+    item = queue.get()
+    counter = 1
+
+    # Drain queue and keep only the most recent parameters
+    try:
+        while True:
+            item = queue.get_nowait()
+            counter += 1
+    except Empty:
+        pass
+
+    logging.debug(f"Drained {counter} items from queue")
+
+    return item
diff --git a/lerobot/scripts/train.py b/lerobot/scripts/train.py
index a4eb35286..120895c48 100644
--- a/lerobot/scripts/train.py
+++ b/lerobot/scripts/train.py
@@ -71,7 +71,9 @@ def make_optimizer_and_scheduler(cfg, policy):
             },
         ]
         optimizer = torch.optim.AdamW(
-            optimizer_params_dicts, lr=cfg.training.lr, weight_decay=cfg.training.weight_decay
+            optimizer_params_dicts,
+            lr=cfg.training.lr,
+            weight_decay=cfg.training.weight_decay,
         )
         lr_scheduler = None
     elif cfg.policy.name == "diffusion":
@@ -98,14 +100,23 @@ def make_optimizer_and_scheduler(cfg, policy):
         optimizer = torch.optim.Adam(
             [
                 {"params": policy.actor.parameters(), "lr": policy.config.actor_lr},
-                {"params": policy.critic_ensemble.parameters(), "lr": policy.config.critic_lr},
-                {"params": policy.temperature.parameters(), "lr": policy.config.temperature_lr},
+                {
+                    "params": policy.critic_ensemble.parameters(),
+                    "lr": policy.config.critic_lr,
+                },
+                {
+                    "params": policy.temperature.parameters(),
+                    "lr": policy.config.temperature_lr,
+                },
             ]
         )
         lr_scheduler = None
     elif cfg.policy.name == "vqbet":
-        from lerobot.common.policies.vqbet.modeling_vqbet import VQBeTOptimizer, VQBeTScheduler
+        from lerobot.common.policies.vqbet.modeling_vqbet import (
+            VQBeTOptimizer,
+            VQBeTScheduler,
+        )
 
         optimizer = VQBeTOptimizer(policy, cfg)
         lr_scheduler = VQBeTScheduler(optimizer, cfg)
@@ -255,7 +266,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
     logging.info(pformat(OmegaConf.to_container(cfg)))
 
     if cfg.training.online_steps > 0 and isinstance(cfg.dataset_repo_id, ListConfig):
-        raise NotImplementedError("Online training with LeRobotMultiDataset is not implemented.")
+        raise NotImplementedError(
+            "Online training with LeRobotMultiDataset is not implemented."
+        )
 
     # If we are resuming a run, we need to check that a checkpoint exists in the log directory, and we need
     # to check for any differences between the provided config and the checkpoint's config.
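The resume check referenced in the comment above compares the checkpoint config with the provided config via DeepDiff and ignores only the `resume` flag itself. A minimal standalone sketch of that comparison, with toy configs standing in for the real Hydra ones:

    from deepdiff import DeepDiff
    from omegaconf import OmegaConf

    checkpoint_cfg = OmegaConf.create({"resume": False, "seed": 1000})
    cfg = OmegaConf.create({"resume": True, "seed": 1000})

    diff = DeepDiff(
        OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
    )
    # Drop the `resume` flag itself before deciding whether anything changed.
    if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
        del diff["values_changed"]["root['resume']"]
    if len(diff) > 0:
        print(f"Config drift detected: {diff}")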
@@ -265,7 +278,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No "You have set resume=True, but there is no model checkpoint in " f"{Logger.get_last_checkpoint_dir(out_dir)}" ) - checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml") + checkpoint_cfg_path = str( + Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml" + ) logging.info( colored( "You have set resume=True, indicating that you wish to resume a run", @@ -278,7 +293,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # Check for differences between the checkpoint configuration and provided configuration. # Hack to resolve the delta_timestamps ahead of time in order to properly diff. resolve_delta_timestamps(cfg) - diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)) + diff = DeepDiff( + OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg) + ) # Ignore the `resume` and parameters. if "values_changed" in diff and "root['resume']" in diff["values_changed"]: del diff["values_changed"]["root['resume']"] @@ -325,7 +342,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No # TODO (michel-aractingi): temporary fix to avoid datasets with task_index key that doesn't exist in online environment # i.e., pusht if "task_index" in offline_dataset.hf_dataset[0]: - offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns(["task_index"]) + offline_dataset.hf_dataset = offline_dataset.hf_dataset.remove_columns( + ["task_index"] + ) if isinstance(offline_dataset, MultiLeRobotDataset): logging.info( @@ -345,7 +364,9 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No policy = make_policy( hydra_cfg=cfg, dataset_stats=offline_dataset.meta.stats if not cfg.resume else None, - pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None, + pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) + if cfg.resume + else None, ) assert isinstance(policy, nn.Module) # Create optimizer and scheduler @@ -358,36 +379,58 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No if cfg.resume: step = logger.load_last_training_state(optimizer, lr_scheduler) - num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) + num_learnable_params = sum( + p.numel() for p in policy.parameters() if p.requires_grad + ) num_total_params = sum(p.numel() for p in policy.parameters()) log_output_dir(out_dir) logging.info(f"{cfg.env.task=}") - logging.info(f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})") + logging.info( + f"{cfg.training.offline_steps=} ({format_big_number(cfg.training.offline_steps)})" + ) logging.info(f"{cfg.training.online_steps=}") - logging.info(f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})") + logging.info( + f"{offline_dataset.num_frames=} ({format_big_number(offline_dataset.num_frames)})" + ) logging.info(f"{offline_dataset.num_episodes=}") logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})") logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})") # Note: this helper will be used in offline and online training loops. 
def evaluate_and_checkpoint_if_needed(step, is_online): - _num_digits = max(6, len(str(cfg.training.offline_steps + cfg.training.online_steps))) + _num_digits = max( + 6, len(str(cfg.training.offline_steps + cfg.training.online_steps)) + ) step_identifier = f"{step:0{_num_digits}d}" if cfg.training.eval_freq > 0 and step % cfg.training.eval_freq == 0: logging.info(f"Eval policy at step {step}") - with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.use_amp else nullcontext(): + with ( + torch.no_grad(), + torch.autocast(device_type=device.type) + if cfg.use_amp + else nullcontext(), + ): assert eval_env is not None eval_info = eval_policy( eval_env, policy, cfg.eval.n_episodes, - videos_dir=Path(out_dir) / "eval" / f"videos_step_{step_identifier}", + videos_dir=Path(out_dir) + / "eval" + / f"videos_step_{step_identifier}", max_episodes_rendered=4, start_seed=cfg.seed, ) - log_eval_info(logger, eval_info["aggregated"], step, cfg, offline_dataset, is_online=is_online) + log_eval_info( + logger, + eval_info["aggregated"], + step, + cfg, + offline_dataset, + is_online=is_online, + ) if cfg.wandb.enable: logger.log_video(eval_info["video_paths"][0], step, mode="eval") logging.info("Resume training") @@ -456,7 +499,9 @@ def evaluate_and_checkpoint_if_needed(step, is_online): train_info["dataloading_s"] = dataloading_s if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, offline_dataset, is_online=False) + log_train_info( + logger, train_info, step, cfg, offline_dataset, is_online=False + ) # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, # so we pass in step + 1. @@ -489,8 +534,14 @@ def evaluate_and_checkpoint_if_needed(step, is_online): online_dataset = OnlineBuffer( online_buffer_path, data_spec={ - **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.input_shapes.items()}, - **{k: {"shape": v, "dtype": np.dtype("float32")} for k, v in policy.config.output_shapes.items()}, + **{ + k: {"shape": v, "dtype": np.dtype("float32")} + for k, v in policy.config.input_shapes.items() + }, + **{ + k: {"shape": v, "dtype": np.dtype("float32")} + for k, v in policy.config.output_shapes.items() + }, "next.reward": {"shape": (), "dtype": np.dtype("float32")}, "next.done": {"shape": (), "dtype": np.dtype("?")}, "next.success": {"shape": (), "dtype": np.dtype("?")}, @@ -502,7 +553,9 @@ def evaluate_and_checkpoint_if_needed(step, is_online): # If we are doing online rollouts asynchronously, deepcopy the policy to use for online rollouts (this # makes it possible to do online rollouts in parallel with training updates). - online_rollout_policy = deepcopy(policy) if cfg.training.do_online_rollout_async else policy + online_rollout_policy = ( + deepcopy(policy) if cfg.training.do_online_rollout_async else policy + ) # Create dataloader for online training. concat_dataset = torch.utils.data.ConcatDataset([offline_dataset, online_dataset]) @@ -539,7 +592,9 @@ def evaluate_and_checkpoint_if_needed(step, is_online): online_step = 0 online_rollout_s = 0 # time take to do online rollout - update_online_buffer_s = 0 # time taken to update the online buffer with the online rollout data + update_online_buffer_s = ( + 0 # time taken to update the online buffer with the online rollout data + ) # Time taken waiting for the online buffer to finish being updated. This is relevant when using the async # online rollout option. 
await_update_online_buffer_s = 0 @@ -563,11 +618,16 @@ def sample_trajectory_and_update_buffer(): online_env, online_rollout_policy, n_episodes=cfg.training.online_rollout_n_episodes, - max_episodes_rendered=min(10, cfg.training.online_rollout_n_episodes), + max_episodes_rendered=min( + 10, cfg.training.online_rollout_n_episodes + ), videos_dir=logger.log_dir / "online_rollout_videos", return_episode_data=True, start_seed=( - rollout_start_seed := (rollout_start_seed + cfg.training.batch_size) % 1000000 + rollout_start_seed := ( + rollout_start_seed + cfg.training.batch_size + ) + % 1000000 ), ) online_rollout_s = time.perf_counter() - start_rollout_time @@ -577,16 +637,21 @@ def sample_trajectory_and_update_buffer(): online_dataset.add_data(eval_info["episodes"]) # Update the concatenated dataset length used during sampling. - concat_dataset.cumulative_sizes = concat_dataset.cumsum(concat_dataset.datasets) + concat_dataset.cumulative_sizes = concat_dataset.cumsum( + concat_dataset.datasets + ) # Update the sampling weights. sampler.weights = compute_sampler_weights( offline_dataset, - offline_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0), + offline_drop_n_last_frames=cfg.training.get( + "drop_n_last_frames", 0 + ), online_dataset=online_dataset, # +1 because online rollouts return an extra frame for the "final observation". Note: we don't have # this final observation in the offline datasets, but we might add them in future. - online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + 1, + online_drop_n_last_frames=cfg.training.get("drop_n_last_frames", 0) + + 1, online_sampling_ratio=cfg.training.online_sampling_ratio, ) sampler.num_frames = len(concat_dataset) @@ -639,7 +704,9 @@ def sample_trajectory_and_update_buffer(): train_info["online_buffer_size"] = len(online_dataset) if step % cfg.training.log_freq == 0: - log_train_info(logger, train_info, step, cfg, online_dataset, is_online=True) + log_train_info( + logger, train_info, step, cfg, online_dataset, is_online=True + ) # Note: evaluate_and_checkpoint_if_needed happens **after** the `step`th training update has completed, # so we pass in step + 1. @@ -672,7 +739,9 @@ def train_cli(cfg: dict): ) -def train_notebook(out_dir=None, job_name=None, config_name="default", config_path="../configs"): +def train_notebook( + out_dir=None, job_name=None, config_name="default", config_path="../configs" +): from hydra import compose, initialize hydra.core.global_hydra.GlobalHydra.instance().clear() diff --git a/lerobot/scripts/train_hilserl_classifier.py b/lerobot/scripts/train_hilserl_classifier.py index 458e3ff14..1cae01836 100644 --- a/lerobot/scripts/train_hilserl_classifier.py +++ b/lerobot/scripts/train_hilserl_classifier.py @@ -1,5 +1,3 @@ -#!/usr/bin/env python - # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -16,10 +14,10 @@ import logging import time from contextlib import nullcontext -from pathlib import Path from pprint import pformat import hydra +import numpy as np import torch import torch.nn as nn import wandb @@ -27,41 +25,67 @@ from omegaconf import DictConfig, OmegaConf from termcolor import colored from torch import optim +from torch.autograd import profiler from torch.cuda.amp import GradScaler -from torch.utils.data import DataLoader, WeightedRandomSampler, random_split +from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler from tqdm import tqdm from lerobot.common.datasets.factory import resolve_delta_timestamps from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.logger import Logger from lerobot.common.policies.factory import _policy_cfg_from_hydra_cfg -from lerobot.common.policies.hilserl.classifier.configuration_classifier import ClassifierConfig +from lerobot.common.policies.hilserl.classifier.configuration_classifier import ( + ClassifierConfig, +) from lerobot.common.policies.hilserl.classifier.modeling_classifier import Classifier from lerobot.common.utils.utils import ( format_big_number, get_safe_torch_device, init_hydra_config, + init_logging, set_global_seed, ) +from lerobot.scripts.server.buffer import random_shift def get_model(cfg, logger): # noqa I001 classifier_config = _policy_cfg_from_hydra_cfg(ClassifierConfig, cfg) model = Classifier(classifier_config) if cfg.resume: - model.load_state_dict(Classifier.from_pretrained(str(logger.last_pretrained_model_dir)).state_dict()) + model.load_state_dict( + Classifier.from_pretrained( + str(logger.last_pretrained_model_dir) + ).state_dict() + ) return model def create_balanced_sampler(dataset, cfg): - # Creates a weighted sampler to handle class imbalance + # Get underlying dataset if using Subset + original_dataset = ( + dataset.dataset if isinstance(dataset, torch.utils.data.Subset) else dataset + ) + + # Get indices if using Subset (for slicing) + indices = dataset.indices if isinstance(dataset, torch.utils.data.Subset) else None + + # Get labels from Hugging Face dataset + if indices is not None: + # Get subset of labels using Hugging Face's select() + hf_subset = original_dataset.hf_dataset.select(indices) + labels = hf_subset[cfg.training.label_key] + else: + # Get all labels directly + labels = original_dataset.hf_dataset[cfg.training.label_key] - labels = torch.tensor([item[cfg.training.label_key] for item in dataset]) + labels = torch.stack(labels) _, counts = torch.unique(labels, return_counts=True) class_weights = 1.0 / counts.float() sample_weights = class_weights[labels] - return WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True) + return WeightedRandomSampler( + weights=sample_weights, num_samples=len(sample_weights), replacement=True + ) def support_amp(device: torch.device, cfg: DictConfig) -> bool: @@ -70,7 +94,9 @@ def support_amp(device: torch.device, cfg: DictConfig) -> bool: return cfg.training.use_amp and device.type in ("cuda", "cpu") -def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg): +def train_epoch( + model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg +): # Single epoch training loop with AMP support and progress tracking model.train() correct = 0 @@ -80,10 +106,15 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, 
device, for batch_idx, batch in enumerate(pbar): start_time = time.perf_counter() images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] + images = [random_shift(img, 4) for img in images] labels = batch[cfg.training.label_key].float().to(device) # Forward pass with optional AMP - with torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(): + with ( + torch.autocast(device_type=device.type) + if support_amp(device, cfg) + else nullcontext() + ): outputs = model(images) loss = criterion(outputs.logits, labels) @@ -116,7 +147,7 @@ def train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, pbar.set_postfix({"loss": f"{loss.item():.4f}", "acc": f"{current_acc:.2f}%"}) -def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_log=8): +def validate(model, val_loader, criterion, device, logger, cfg): # Validation loop with metric tracking and sample logging model.eval() correct = 0 @@ -124,16 +155,32 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l batch_start_time = time.perf_counter() samples = [] running_loss = 0 + inference_times = [] with ( torch.no_grad(), - torch.autocast(device_type=device.type) if support_amp(device, cfg) else nullcontext(), + torch.autocast(device_type=device.type) + if support_amp(device, cfg) + else nullcontext(), ): for batch in tqdm(val_loader, desc="Validation"): images = [batch[img_key].to(device) for img_key in cfg.training.image_keys] labels = batch[cfg.training.label_key].float().to(device) - outputs = model(images) + if cfg.training.profile_inference_time and logger._cfg.wandb.enable: + with ( + profiler.profile(record_shapes=True) as prof, + profiler.record_function("model_inference"), + ): + outputs = model(images) + inference_times.append( + next( + x for x in prof.key_averages() if x.key == "model_inference" + ).cpu_time + ) + else: + outputs = model(images) + loss = criterion(outputs.logits, labels) # Track metrics @@ -146,15 +193,26 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l running_loss += loss.item() # Log sample predictions for visualization - if len(samples) < num_samples_to_log: - for i in range(min(num_samples_to_log - len(samples), len(images))): + if len(samples) < cfg.eval.num_samples_to_log: + for i in range( + min(cfg.eval.num_samples_to_log - len(samples), len(images)) + ): if model.config.num_classes == 2: confidence = round(outputs.probabilities[i].item(), 3) else: - confidence = [round(prob, 3) for prob in outputs.probabilities[i].tolist()] + confidence = [ + round(prob, 3) for prob in outputs.probabilities[i].tolist() + ] samples.append( { - "image": wandb.Image(images[i].cpu()), + **{ + f"image_{img_key}": wandb.Image( + images[img_idx][i].cpu() + ) + for img_idx, img_key in enumerate( + cfg.training.image_keys + ) + }, "true_label": labels[i].item(), "predicted": predictions[i].item(), "confidence": confidence, @@ -170,36 +228,122 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l "accuracy": accuracy, "eval_s": time.perf_counter() - batch_start_time, "eval/prediction_samples": wandb.Table( - data=[[s["image"], s["true_label"], s["predicted"], f"{s['confidence']}"] for s in samples], - columns=["Image", "True Label", "Predicted", "Confidence"], + data=[list(s.values()) for s in samples], + columns=list(samples[0].keys()), ) if logger._cfg.wandb.enable else None, } + if len(inference_times) > 0: + eval_info["inference_time_avg"] = 
+    if len(inference_times) > 0:
+        eval_info["inference_time_avg"] = np.mean(inference_times)
+        eval_info["inference_time_median"] = np.median(inference_times)
+        eval_info["inference_time_std"] = np.std(inference_times)
+        eval_info["inference_time_batch_size"] = val_loader.batch_size
+
+        print(
+            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
+        )
+
     return accuracy, eval_info
 
 
-@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
-def train(cfg: DictConfig) -> None:
+def benchmark_inference_time(model, dataset, logger, cfg, device, step):
+    if not cfg.training.profile_inference_time:
+        return
+
+    iters = cfg.training.profile_inference_time_iters
+    inference_times = []
+
+    loader = DataLoader(
+        dataset,
+        batch_size=1,
+        num_workers=cfg.training.num_workers,
+        sampler=RandomSampler(dataset),
+        pin_memory=True,
+    )
+
+    model.eval()
+    with torch.no_grad():
+        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
+            x = next(iter(loader))
+            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]
+
+            # Warm up
+            for _ in range(10):
+                _ = model(x)
+
+            # Sync the device
+            if device.type == "cuda":
+                torch.cuda.synchronize()
+            elif device.type == "mps":
+                torch.mps.synchronize()
+
+            with (
+                profiler.profile(record_shapes=True) as prof,
+                profiler.record_function("model_inference"),
+            ):
+                _ = model(x)
+
+            inference_times.append(
+                next(
+                    x for x in prof.key_averages() if x.key == "model_inference"
+                ).cpu_time
+            )
+
+    inference_times = np.array(inference_times)
+    avg, median, std = (
+        inference_times.mean(),
+        np.median(inference_times),
+        inference_times.std(),
+    )
+    print(
+        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
+    )
+    if logger._cfg.wandb.enable:
+        logger.log_dict(
+            {
+                "inference_time_benchmark_avg": avg,
+                "inference_time_benchmark_median": median,
+                "inference_time_benchmark_std": std,
+            },
+            step + 1,
+            mode="eval",
+        )
+
+    return avg, median, std
+
+
+def train(
+    cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None
+) -> None:
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
     # Main training pipeline with support for resuming training
+    init_logging()
     logging.info(OmegaConf.to_yaml(cfg))
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
 
     # Initialize training environment
     device = get_safe_torch_device(cfg.device, log=True)
     set_global_seed(cfg.seed)
 
-    out_dir = Path(cfg.output_dir)
-    out_dir.mkdir(parents=True, exist_ok=True)
-    logger = Logger(cfg, out_dir, cfg.wandb.job_name if cfg.wandb.enable else None)
-
-    # Setup dataset and dataloaders
-    dataset = LeRobotDataset(cfg.dataset_repo_id)
+    dataset = LeRobotDataset(
+        cfg.dataset_repo_id,
+        root=cfg.dataset_root,
+        local_files_only=cfg.local_files_only,
+    )
     logging.info(f"Dataset size: {len(dataset)}")
 
-    train_size = int(cfg.train_split_proportion * len(dataset))
-    val_size = len(dataset) - train_size
-    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+    n_total = len(dataset)
+    n_train = int(cfg.train_split_proportion * len(dataset))
+    train_dataset = torch.utils.data.Subset(dataset, range(0, n_train))
+    val_dataset = torch.utils.data.Subset(dataset, range(n_train, n_total))
     sampler = create_balanced_sampler(train_dataset, cfg)
 
     train_loader = DataLoader(
@@ -207,7 +351,7 @@ def train(cfg: DictConfig) -> None:
         batch_size=cfg.training.batch_size,
         num_workers=cfg.training.num_workers,
         sampler=sampler,
-        pin_memory=True,
+        pin_memory=device.type == "cuda",
     )
 
     val_loader = DataLoader(
@@ -215,7 +359,7 @@ def train(cfg: DictConfig) -> None:
         batch_size=cfg.eval.batch_size,
         shuffle=False,
         num_workers=cfg.training.num_workers,
-        pin_memory=True,
+        pin_memory=device.type == "cuda",
     )
 
     # Resume training if requested
@@ -228,7 +372,9 @@ def train(cfg: DictConfig) -> None:
                 "You have set resume=True, but there is no model checkpoint in "
                 f"{Logger.get_last_checkpoint_dir(out_dir)}"
             )
-        checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+        checkpoint_cfg_path = str(
+            Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
+        )
         logging.info(
             colored(
                 "You have set resume=True, indicating that you wish to resume a run",
@@ -241,7 +387,9 @@ def train(cfg: DictConfig) -> None:
         # Check for differences between the checkpoint configuration and provided configuration.
         # Hack to resolve the delta_timestamps ahead of time in order to properly diff.
         resolve_delta_timestamps(cfg)
-        diff = DeepDiff(OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg))
+        diff = DeepDiff(
+            OmegaConf.to_container(checkpoint_cfg), OmegaConf.to_container(cfg)
+        )
         # Ignore the `resume` parameter.
         if "values_changed" in diff and "root['resume']" in diff["values_changed"]:
             del diff["values_changed"]["root['resume']"]
@@ -260,7 +408,11 @@ def train(cfg: DictConfig) -> None:
     optimizer = optim.AdamW(model.parameters(), lr=cfg.training.learning_rate)
 
     # Use BCEWithLogitsLoss for binary classification and CrossEntropyLoss for multi-class
-    criterion = nn.BCEWithLogitsLoss() if model.config.num_classes == 2 else nn.CrossEntropyLoss()
+    criterion = (
+        nn.BCEWithLogitsLoss()
+        if model.config.num_classes == 2
+        else nn.CrossEntropyLoss()
+    )
     grad_scaler = GradScaler(enabled=cfg.training.use_amp)
 
     # Log model parameters
@@ -276,7 +428,17 @@ def train(cfg: DictConfig) -> None:
     for epoch in range(cfg.training.num_epochs):
         logging.info(f"\nEpoch {epoch+1}/{cfg.training.num_epochs}")
 
-        train_epoch(model, train_loader, criterion, optimizer, grad_scaler, device, logger, step, cfg)
+        train_epoch(
+            model,
+            train_loader,
+            criterion,
+            optimizer,
+            grad_scaler,
+            device,
+            logger,
+            step,
+            cfg,
+        )
 
         # Periodic validation
         if cfg.training.eval_freq > 0 and (epoch + 1) % cfg.training.eval_freq == 0:
@@ -313,8 +475,37 @@ def train(cfg: DictConfig) -> None:
 
     step += len(train_loader)
 
+    benchmark_inference_time(model, dataset, logger, cfg, device, step)
+
     logging.info("Training completed")
 
 
+@hydra.main(
+    version_base="1.2",
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+)
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(
+    out_dir=None,
+    job_name=None,
+    config_name="hilserl_classifier",
+    config_path="../configs/policy",
+):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
 if __name__ == "__main__":
-    train()
+    train_cli()
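The new `train_notebook` wrapper exists so the same pipeline can run outside the Hydra CLI. A hedged sketch of calling it from a notebook cell follows; the module path `lerobot.scripts.train_hilserl_classifier` and the `out_dir` value are assumptions, not fixed by the diff:

```python
# Hypothetical notebook usage of the train_notebook() entry point added above.
from lerobot.scripts.train_hilserl_classifier import train_notebook

train_notebook(
    out_dir="outputs/train/hilserl_classifier",  # required: train() raises if None
    job_name="hilserl_classifier",               # required: train() raises if None
)
```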
diff --git a/lerobot/scripts/train_sac.py b/lerobot/scripts/train_sac.py
new file mode 100644
index 000000000..cfd05f62f
--- /dev/null
+++ b/lerobot/scripts/train_sac.py
@@ -0,0 +1,626 @@
+#!/usr/bin/env python
+
+# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import functools
+import logging
+import random
+from pprint import pformat
+from typing import Callable, Optional, Sequence, TypedDict
+
+import hydra
+import torch
+import torch.nn.functional as F
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from tqdm import tqdm
+
+# TODO: Remove the import of maniskill
+from lerobot.common.datasets.factory import make_dataset
+from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.envs.factory import make_maniskill_env
+from lerobot.common.envs.utils import preprocess_maniskill_observation
+from lerobot.common.logger import Logger, log_output_dir
+from lerobot.common.policies.factory import make_policy
+from lerobot.common.policies.sac.modeling_sac import SACPolicy
+from lerobot.common.utils.utils import (
+    format_big_number,
+    get_safe_torch_device,
+    init_logging,
+    set_global_seed,
+)
+
+
+def make_optimizers_and_scheduler(cfg, policy):
+    optimizer_actor = torch.optim.Adam(
+        # NOTE: Handle the case of a shared encoder, where the encoder weights are not optimized with the gradient of the actor
+        params=policy.actor.parameters_to_optimize,
+        lr=policy.config.actor_lr,
+    )
+    optimizer_critic = torch.optim.Adam(
+        params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
+    )
+    # We wrap the policy's log temperature in a list because it is a torch tensor, not an nn.Module
+    optimizer_temperature = torch.optim.Adam(
+        params=[policy.log_alpha], lr=policy.config.critic_lr
+    )
+    lr_scheduler = None
+    optimizers = {
+        "actor": optimizer_actor,
+        "critic": optimizer_critic,
+        "temperature": optimizer_temperature,
+    }
+    return optimizers, lr_scheduler
+
+
+class Transition(TypedDict):
+    state: dict[str, torch.Tensor]
+    action: torch.Tensor
+    reward: float
+    next_state: dict[str, torch.Tensor]
+    done: bool
+    complementary_info: dict[str, torch.Tensor] | None
+
+
+class BatchTransition(TypedDict):
+    state: dict[str, torch.Tensor]
+    action: torch.Tensor
+    reward: torch.Tensor
+    next_state: dict[str, torch.Tensor]
+    done: torch.Tensor
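As a quick illustration of the data layout these TypedDicts imply, here is a minimal sketch of a single transition; the key names and tensor shapes are examples, since the buffer only requires that `state` and `next_state` share the same keys:

```python
# Sketch: one transition with an image key and a proprioceptive key.
import torch

t = Transition(
    state={
        "observation.image": torch.zeros(1, 3, 84, 84),
        "observation.state": torch.zeros(1, 7),
    },
    action=torch.zeros(1, 4),
    reward=0.0,
    next_state={
        "observation.image": torch.zeros(1, 3, 84, 84),
        "observation.state": torch.zeros(1, 7),
    },
    done=False,
    complementary_info=None,
)
```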
+
+
+def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
+    """
+    Perform a per-image random crop over a batch of images in a vectorized way.
+    (Same as shown previously.)
+    """
+    B, C, H, W = images.shape
+    crop_h, crop_w = output_size
+
+    if crop_h > H or crop_w > W:
+        raise ValueError(
+            f"Requested crop size ({crop_h}, {crop_w}) is bigger than the image size ({H}, {W})."
+        )
+
+    tops = torch.randint(0, H - crop_h + 1, (B,), device=images.device)
+    lefts = torch.randint(0, W - crop_w + 1, (B,), device=images.device)
+
+    rows = torch.arange(crop_h, device=images.device).unsqueeze(0) + tops.unsqueeze(1)
+    cols = torch.arange(crop_w, device=images.device).unsqueeze(0) + lefts.unsqueeze(1)
+
+    rows = rows.unsqueeze(2).expand(-1, -1, crop_w)  # (B, crop_h, crop_w)
+    cols = cols.unsqueeze(1).expand(-1, crop_h, -1)  # (B, crop_h, crop_w)
+
+    images_hwcn = images.permute(0, 2, 3, 1)  # (B, H, W, C)
+
+    # Gather pixels
+    cropped_hwcn = images_hwcn[
+        torch.arange(B, device=images.device).view(B, 1, 1), rows, cols, :
+    ]
+    # cropped_hwcn => (B, crop_h, crop_w, C)
+
+    cropped = cropped_hwcn.permute(0, 3, 1, 2)  # (B, C, crop_h, crop_w)
+    return cropped
+
+
+def random_shift(images: torch.Tensor, pad: int = 4):
+    """Vectorized random shift, imgs: (B,C,H,W), pad: #pixels"""
+    _, _, h, w = images.shape
+    images = F.pad(input=images, pad=(pad, pad, pad, pad), mode="replicate")
+    return random_crop_vectorized(images=images, output_size=(h, w))
+
+
+class ReplayBuffer:
+    def __init__(
+        self,
+        capacity: int,
+        device: str = "cuda:0",
+        state_keys: Optional[Sequence[str]] = None,
+        image_augmentation_function: Optional[Callable] = None,
+        use_drq: bool = True,
+    ):
+        """
+        Args:
+            capacity (int): Maximum number of transitions to store in the buffer.
+            device (str): The device where the tensors will be moved ("cuda:0" or "cpu").
+            state_keys (List[str]): The list of keys that appear in `state` and `next_state`.
+            image_augmentation_function (Optional[Callable]): A function that takes a batch of images
+                and returns a batch of augmented images. If None, a default augmentation function is used.
+            use_drq (bool): Whether to apply the default DrQ-style image augmentation when sampling from the buffer.
+        """
+        self.capacity = capacity
+        self.device = device
+        self.memory: list[Transition] = []
+        self.position = 0
+
+        # If no state_keys provided, default to an empty list
+        # (you can handle this differently if needed)
+        self.state_keys = state_keys if state_keys is not None else []
+        if image_augmentation_function is None:
+            self.image_augmentation_function = functools.partial(random_shift, pad=4)
+        else:
+            self.image_augmentation_function = image_augmentation_function
+        self.use_drq = use_drq
+
+    def add(
+        self,
+        state: dict[str, torch.Tensor],
+        action: torch.Tensor,
+        reward: float,
+        next_state: dict[str, torch.Tensor],
+        done: bool,
+        complementary_info: Optional[dict[str, torch.Tensor]] = None,
+    ):
+        """Saves a transition."""
+        if len(self.memory) < self.capacity:
+            self.memory.append(None)
+
+        # Create and store the Transition
+        self.memory[self.position] = Transition(
+            state=state,
+            action=action,
+            reward=reward,
+            next_state=next_state,
+            done=done,
+            complementary_info=complementary_info,
+        )
+        self.position: int = (self.position + 1) % self.capacity
+
+    # TODO: ADD image_augmentation and use_drq arguments in this function in order to instantiate the class with them
+    @classmethod
+    def from_lerobot_dataset(
+        cls,
+        lerobot_dataset: LeRobotDataset,
+        device: str = "cuda:0",
+        state_keys: Optional[Sequence[str]] = None,
+    ) -> "ReplayBuffer":
+        """
+        Convert a LeRobotDataset into a ReplayBuffer.
+
+        Args:
+            lerobot_dataset (LeRobotDataset): The dataset to convert.
+            device (str): The device to load tensors on. Defaults to "cuda:0".
+            state_keys (Optional[Sequence[str]], optional): The list of keys that appear in `state` and `next_state`.
+                Defaults to None.
+
+        Returns:
+            ReplayBuffer: The replay buffer with offline dataset transitions.
+ """ + # We convert the LeRobotDataset into a replay buffer, because it is more efficient to sample from + # a replay buffer than from a lerobot dataset. + replay_buffer = cls( + capacity=len(lerobot_dataset), device=device, state_keys=state_keys + ) + list_transition = cls._lerobotdataset_to_transitions( + dataset=lerobot_dataset, state_keys=state_keys + ) + # Fill the replay buffer with the lerobot dataset transitions + for data in list_transition: + replay_buffer.add( + state=data["state"], + action=data["action"], + reward=data["reward"], + next_state=data["next_state"], + done=data["done"], + ) + return replay_buffer + + @staticmethod + def _lerobotdataset_to_transitions( + dataset: LeRobotDataset, + state_keys: Optional[Sequence[str]] = None, + ) -> list[Transition]: + """ + Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions. + + Args: + dataset (LeRobotDataset): + The dataset to convert. Each item in the dataset is expected to have + at least the following keys: + { + "action": ... + "next.reward": ... + "next.done": ... + "episode_index": ... + } + plus whatever your 'state_keys' specify. + + state_keys (Optional[Sequence[str]]): + The dataset keys to include in 'state' and 'next_state'. Their names + will be kept as-is in the output transitions. E.g. + ["observation.state", "observation.environment_state"]. + If None, you must handle or define default keys. + + Returns: + transitions (List[Transition]): + A list of Transition dictionaries with the same length as `dataset`. + """ + + # If not provided, you can either raise an error or define a default: + if state_keys is None: + raise ValueError( + "You must provide a list of keys in `state_keys` that define your 'state'." + ) + + transitions: list[Transition] = [] + num_frames = len(dataset) + + for i in tqdm(range(num_frames)): + current_sample = dataset[i] + + # ----- 1) Current state ----- + current_state: dict[str, torch.Tensor] = {} + for key in state_keys: + val = current_sample[key] + current_state[key] = val.unsqueeze(0) # Add batch dimension + + # ----- 2) Action ----- + action = current_sample["action"].unsqueeze(0) # Add batch dimension + + # ----- 3) Reward and done ----- + reward = float(current_sample["next.reward"].item()) # ensure float + done = bool(current_sample["next.done"].item()) # ensure bool + + # ----- 4) Next state ----- + # If not done and the next sample is in the same episode, we pull the next sample's state. + # Otherwise (done=True or next sample crosses to a new episode), next_state = current_state. 
+            next_state = current_state  # default
+            if not done and (i < num_frames - 1):
+                next_sample = dataset[i + 1]
+                if next_sample["episode_index"] == current_sample["episode_index"]:
+                    # Build next_state from the same keys
+                    next_state_data: dict[str, torch.Tensor] = {}
+                    for key in state_keys:
+                        val = next_sample[key]
+                        next_state_data[key] = val.unsqueeze(0)  # Add batch dimension
+                    next_state = next_state_data
+
+            # ----- Construct the Transition -----
+            transition = Transition(
+                state=current_state,
+                action=action,
+                reward=reward,
+                next_state=next_state,
+                done=done,
+            )
+            transitions.append(transition)
+
+        return transitions
+
+    def sample(self, batch_size: int) -> BatchTransition:
+        """Sample a random batch of transitions and collate them into batched tensors."""
+        list_of_transitions = random.sample(self.memory, batch_size)
+
+        # -- Build batched states --
+        batch_state = {}
+        for key in self.state_keys:
+            batch_state[key] = torch.cat(
+                [t["state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
+            if key.startswith("observation.image") and self.use_drq:
+                batch_state[key] = self.image_augmentation_function(batch_state[key])
+
+        # -- Build batched actions --
+        batch_actions = torch.cat([t["action"] for t in list_of_transitions]).to(
+            self.device
+        )
+
+        # -- Build batched rewards --
+        batch_rewards = torch.tensor(
+            [t["reward"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
+
+        # -- Build batched next states --
+        batch_next_state = {}
+        for key in self.state_keys:
+            batch_next_state[key] = torch.cat(
+                [t["next_state"][key] for t in list_of_transitions], dim=0
+            ).to(self.device)
+            if key.startswith("observation.image") and self.use_drq:
+                batch_next_state[key] = self.image_augmentation_function(
+                    batch_next_state[key]
+                )
+
+        # -- Build batched dones --
+        batch_dones = torch.tensor(
+            [t["done"] for t in list_of_transitions], dtype=torch.float32
+        ).to(self.device)
+
+        # Return a BatchTransition typed dict
+        return BatchTransition(
+            state=batch_state,
+            action=batch_actions,
+            reward=batch_rewards,
+            next_state=batch_next_state,
+            done=batch_dones,
+        )
+
+
+def concatenate_batch_transitions(
+    left_batch_transitions: BatchTransition, right_batch_transition: BatchTransition
+) -> BatchTransition:
+    """NOTE: Be careful, this modifies left_batch_transitions in place."""
+    left_batch_transitions["state"] = {
+        key: torch.cat(
+            [
+                left_batch_transitions["state"][key],
+                right_batch_transition["state"][key],
+            ],
+            dim=0,
+        )
+        for key in left_batch_transitions["state"]
+    }
+    left_batch_transitions["action"] = torch.cat(
+        [left_batch_transitions["action"], right_batch_transition["action"]], dim=0
+    )
+    left_batch_transitions["reward"] = torch.cat(
+        [left_batch_transitions["reward"], right_batch_transition["reward"]], dim=0
+    )
+    left_batch_transitions["next_state"] = {
+        key: torch.cat(
+            [
+                left_batch_transitions["next_state"][key],
+                right_batch_transition["next_state"][key],
+            ],
+            dim=0,
+        )
+        for key in left_batch_transitions["next_state"]
+    }
+    left_batch_transitions["done"] = torch.cat(
+        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
+    )
+    return left_batch_transitions
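A minimal sketch of how these pieces compose when mixing offline and online data, as the training loop below does; the buffer names and the batch size are illustrative, assuming both are `ReplayBuffer` instances sharing the same `state_keys`:

```python
# Sketch: draw half a batch from each buffer and merge them.
half = 128
online_batch = online_buffer.sample(half)
offline_batch = offline_buffer.sample(half)

# Note: concatenate_batch_transitions modifies online_batch in place
# and returns it.
mixed = concatenate_batch_transitions(online_batch, offline_batch)
assert mixed["action"].shape[0] == 2 * half
```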
+
+
+def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
+    if out_dir is None:
+        raise NotImplementedError()
+    if job_name is None:
+        raise NotImplementedError()
+
+    init_logging()
+    logging.info(pformat(OmegaConf.to_container(cfg)))
+
+    # Create an env dedicated to online episodes collection from policy rollout.
+    # online_env = make_env(cfg, n_envs=cfg.training.online_rollout_batch_size)
+    # NOTE: Off-policy algorithms are efficient enough to use a single environment
+    logging.info("make_env online")
+    # online_env = make_env(cfg, n_envs=1)
+    # TODO: Remove the import of maniskill and unify with make_env
+    online_env = make_maniskill_env(cfg, n_envs=1)
+    if cfg.training.eval_freq > 0:
+        logging.info("make_env eval")
+        # eval_env = make_env(cfg, n_envs=1)
+        # TODO: Remove the import of maniskill and unify with make_env
+        eval_env = make_maniskill_env(cfg, n_envs=1)
+
+    # TODO: Add a way to resume training
+
+    # log metrics to terminal and wandb
+    logger = Logger(cfg, out_dir, wandb_job_name=job_name)
+
+    set_global_seed(cfg.seed)
+
+    # Check device is available
+    device = get_safe_torch_device(cfg.device, log=True)
+
+    torch.backends.cudnn.benchmark = True
+    torch.backends.cuda.matmul.allow_tf32 = True
+
+    logging.info("make_policy")
+    # TODO: At some point we should just need to make a SAC policy
+    policy: SACPolicy = make_policy(
+        hydra_cfg=cfg,
+        # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
+        # Hack: if we do online training, we do not need dataset_stats
+        dataset_stats=None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
+        if cfg.resume
+        else None,
+        device=device,
+    )
+    assert isinstance(policy, nn.Module)
+
+    optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
+
+    # TODO: Handle resume
+
+    num_learnable_params = sum(
+        p.numel() for p in policy.parameters() if p.requires_grad
+    )
+    num_total_params = sum(p.numel() for p in policy.parameters())
+
+    log_output_dir(out_dir)
+    logging.info(f"{cfg.env.task=}")
+    logging.info(f"{cfg.training.online_steps=}")
+    logging.info(f"{num_learnable_params=} ({format_big_number(num_learnable_params)})")
+    logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
+
+    obs, info = online_env.reset()
+
+    # HACK for maniskill
+    # obs = preprocess_observation(obs)
+    obs = preprocess_maniskill_observation(obs)
+    obs = {key: obs[key].to(device, non_blocking=True) for key in obs}
+
+    replay_buffer = ReplayBuffer(
+        capacity=cfg.training.online_buffer_capacity,
+        device=device,
+        state_keys=cfg.policy.input_shapes.keys(),
+    )
+
+    batch_size = cfg.training.batch_size
+
+    if cfg.dataset_repo_id is not None:
+        logging.info("make_dataset offline buffer")
+        offline_dataset = make_dataset(cfg)
+        logging.info("Conversion to an offline replay buffer")
+        offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
+            offline_dataset, device=device, state_keys=cfg.policy.input_shapes.keys()
+        )
+        batch_size: int = batch_size // 2  # We will sample from both replay buffers
+
+    # NOTE: For the moment we will solely handle the case of a single environment
+    sum_reward_episode = 0
+
+    for interaction_step in range(cfg.training.online_steps):
+        # NOTE: At some point we should use a wrapper to handle the observation
+
+        if interaction_step >= cfg.training.online_step_before_learning:
+            action = policy.select_action(batch=obs)
+            next_obs, reward, done, truncated, info = online_env.step(
+                action.cpu().numpy()
+            )
+        else:
+            action = online_env.action_space.sample()
+            next_obs, reward, done, truncated, info = online_env.step(action)
+            # HACK
+            action = torch.tensor(action, dtype=torch.float32).to(
+                device, non_blocking=True
+            )
+
+        # HACK: For maniskill
+        # next_obs = preprocess_observation(next_obs)
+        next_obs = preprocess_maniskill_observation(next_obs)
+        next_obs = {key: next_obs[key].to(device, non_blocking=True) for key in obs}
+        sum_reward_episode += float(reward[0])
+
+        # Because we are using a single environment
+        # we can safely assume that the episode is done
+        if done[0] or truncated[0]:
+            logging.info(
+                f"Global step {interaction_step}: Episode reward: {sum_reward_episode}"
+            )
+            logger.log_dict(
+                {"Sum episode reward": sum_reward_episode}, interaction_step
+            )
+            sum_reward_episode = 0
+            # HACK: This is for maniskill
+            logging.info(
+                f"global step {interaction_step}: episode success: {info['success'].float().item()} \n"
+            )
+            logger.log_dict(
+                {"Episode success": info["success"].float().item()}, interaction_step
+            )
+
+        replay_buffer.add(
+            state=obs,
+            action=action,
+            reward=float(reward[0]),
+            next_state=next_obs,
+            done=done[0],
+        )
+        obs = next_obs
+
+        if interaction_step < cfg.training.online_step_before_learning:
+            continue
+
+        for _ in range(cfg.policy.utd_ratio - 1):
+            batch = replay_buffer.sample(batch_size)
+            if cfg.dataset_repo_id is not None:
+                batch_offline = offline_replay_buffer.sample(batch_size)
+                batch = concatenate_batch_transitions(batch, batch_offline)
+
+            actions = batch["action"]
+            rewards = batch["reward"]
+            observations = batch["state"]
+            next_observations = batch["next_state"]
+            done = batch["done"]
+
+            loss_critic = policy.compute_loss_critic(
+                observations=observations,
+                actions=actions,
+                rewards=rewards,
+                next_observations=next_observations,
+                done=done,
+            )
+            optimizers["critic"].zero_grad()
+            loss_critic.backward()
+            optimizers["critic"].step()
+
+        batch = replay_buffer.sample(batch_size)
+        if cfg.dataset_repo_id is not None:
+            batch_offline = offline_replay_buffer.sample(batch_size)
+            batch = concatenate_batch_transitions(
+                left_batch_transitions=batch, right_batch_transition=batch_offline
+            )
+
+        actions = batch["action"]
+        rewards = batch["reward"]
+        observations = batch["state"]
+        next_observations = batch["next_state"]
+        done = batch["done"]
+
+        loss_critic = policy.compute_loss_critic(
+            observations=observations,
+            actions=actions,
+            rewards=rewards,
+            next_observations=next_observations,
+            done=done,
+        )
+        optimizers["critic"].zero_grad()
+        loss_critic.backward()
+        optimizers["critic"].step()
+
+        training_infos = {}
+        training_infos["loss_critic"] = loss_critic.item()
+
+        if interaction_step % cfg.training.policy_update_freq == 0:
+            # TD3 Trick
+            for _ in range(cfg.training.policy_update_freq):
+                loss_actor = policy.compute_loss_actor(observations=observations)
+
+                optimizers["actor"].zero_grad()
+                loss_actor.backward()
+                optimizers["actor"].step()
+
+                training_infos["loss_actor"] = loss_actor.item()
+
+                loss_temperature = policy.compute_loss_temperature(
+                    observations=observations
+                )
+                optimizers["temperature"].zero_grad()
+                loss_temperature.backward()
+                optimizers["temperature"].step()
+
+                training_infos["loss_temperature"] = loss_temperature.item()
+
+        if interaction_step % cfg.training.log_freq == 0:
+            logger.log_dict(training_infos, interaction_step, mode="train")
+
+        policy.update_target_networks()
+
+
+@hydra.main(version_base="1.2", config_name="default", config_path="../configs")
+def train_cli(cfg: dict):
+    train(
+        cfg,
+        out_dir=hydra.core.hydra_config.HydraConfig.get().run.dir,
+        job_name=hydra.core.hydra_config.HydraConfig.get().job.name,
+    )
+
+
+def train_notebook(
+    out_dir=None, job_name=None, config_name="default", config_path="../configs"
+):
+    from hydra import compose, initialize
+
+    hydra.core.global_hydra.GlobalHydra.instance().clear()
+    initialize(config_path=config_path)
+    cfg = compose(config_name=config_name)
+    train(cfg, out_dir=out_dir, job_name=job_name)
+
+
+if __name__ == "__main__":
+    train_cli()
diff --git a/lerobot/scripts/visualize_dataset.py b/lerobot/scripts/visualize_dataset.py
index cdd5ce605..25bac4d3d 100644
--- a/lerobot/scripts/visualize_dataset.py
+++ b/lerobot/scripts/visualize_dataset.py
@@ -94,8 +94,12 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
     assert chw_float32_torch.dtype == torch.float32
     assert chw_float32_torch.ndim == 3
     c, h, w = chw_float32_torch.shape
-    assert c < h and c < w, f"expect channel first images, but instead {chw_float32_torch.shape}"
-    hwc_uint8_numpy = (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    assert (
+        c < h and c < w
+    ), f"expect channel first images, but instead {chw_float32_torch.shape}"
+    hwc_uint8_numpy = (
+        (chw_float32_torch * 255).type(torch.uint8).permute(1, 2, 0).numpy()
+    )
     return hwc_uint8_numpy
diff --git a/lerobot/scripts/visualize_dataset_html.py b/lerobot/scripts/visualize_dataset_html.py
index 2c81fbfc5..5f7e371ca 100644
--- a/lerobot/scripts/visualize_dataset_html.py
+++ b/lerobot/scripts/visualize_dataset_html.py
@@ -53,33 +53,86 @@
 """
 
 import argparse
+import csv
+import json
 import logging
+import re
 import shutil
+import tempfile
+from io import StringIO
 from pathlib import Path
 
-import tqdm
-from flask import Flask, redirect, render_template, url_for
+import numpy as np
+import pandas as pd
+import requests
+from flask import Flask, redirect, render_template, request, url_for
 
+from lerobot import available_datasets
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.common.datasets.utils import IterableNamespace
 from lerobot.common.utils.utils import init_logging
 
 
 def run_server(
-    dataset: LeRobotDataset,
-    episodes: list[int],
+    dataset: LeRobotDataset | IterableNamespace | None,
+    episodes: list[int] | None,
     host: str,
     port: str,
     static_folder: Path,
     template_folder: Path,
 ):
-    app = Flask(__name__, static_folder=static_folder.resolve(), template_folder=template_folder.resolve())
+    app = Flask(
+        __name__,
+        static_folder=static_folder.resolve(),
+        template_folder=template_folder.resolve(),
+    )
     app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 0  # specifying not to cache
 
     @app.route("/")
-    def index():
-        # home page redirects to the first episode page
-        [dataset_namespace, dataset_name] = dataset.repo_id.split("/")
-        first_episode_id = episodes[0]
+    def homepage(dataset=dataset):
+        if dataset:
+            dataset_namespace, dataset_name = dataset.repo_id.split("/")
+            return redirect(
+                url_for(
+                    "show_episode",
+                    dataset_namespace=dataset_namespace,
+                    dataset_name=dataset_name,
+                    episode_id=0,
+                )
+            )
+
+        dataset_param, episode_param = None, None
+        all_params = request.args
+        if "dataset" in all_params:
+            dataset_param = all_params["dataset"]
+        if "episode" in all_params:
+            episode_param = int(all_params["episode"])
+
+        if dataset_param:
+            dataset_namespace, dataset_name = dataset_param.split("/")
+            return redirect(
+                url_for(
+                    "show_episode",
+                    dataset_namespace=dataset_namespace,
+                    dataset_name=dataset_name,
+                    episode_id=episode_param if episode_param is not None else 0,
+                )
+            )
+
+        featured_datasets = [
+            "lerobot/aloha_static_cups_open",
+            "lerobot/columbia_cairlab_pusht_real",
+            "lerobot/taco_play",
+        ]
+        return render_template(
+            "visualize_dataset_homepage.html",
+            featured_datasets=featured_datasets,
+            lerobot_datasets=available_datasets,
+        )
+
+    @app.route("/<string:dataset_namespace>/<string:dataset_name>")
+    def show_first_episode(dataset_namespace, dataset_name):
+        first_episode_id = 0
         return redirect(
             url_for(
                 "show_episode",
@@ -89,31 +142,106 @@ def index():
             )
         )
 
-    @app.route("/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>")
-    def show_episode(dataset_namespace, dataset_name, episode_id):
+    @app.route(
+        "/<string:dataset_namespace>/<string:dataset_name>/episode_<int:episode_id>"
+    )
+    def show_episode(
+        dataset_namespace, dataset_name, episode_id, dataset=dataset, episodes=episodes
+    ):
+        repo_id = f"{dataset_namespace}/{dataset_name}"
+        try:
+            if dataset is None:
+                dataset = get_dataset_info(repo_id)
+        except FileNotFoundError:
+            return (
+                "Make sure to convert your LeRobotDataset to v2 & above. See how to convert your dataset at https://github.com/huggingface/lerobot/pull/461",
+                400,
+            )
+        dataset_version = (
+            dataset.meta._version
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.codebase_version
+        )
+        match = re.search(r"v(\d+)\.", dataset_version)
+        if match:
+            major_version = int(match.group(1))
+            if major_version < 2:
+                return "Make sure to convert your LeRobotDataset to v2 & above."
+
+        episode_data_csv_str, columns = get_episode_data(dataset, episode_id)
         dataset_info = {
-            "repo_id": dataset.repo_id,
-            "num_samples": dataset.num_frames,
-            "num_episodes": dataset.num_episodes,
+            "repo_id": f"{dataset_namespace}/{dataset_name}",
+            "num_samples": dataset.num_frames
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.total_frames,
+            "num_episodes": dataset.num_episodes
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.total_episodes,
             "fps": dataset.fps,
         }
-        video_paths = [dataset.meta.get_video_file_path(episode_id, key) for key in dataset.meta.video_keys]
-        tasks = dataset.meta.episodes[episode_id]["tasks"]
-        videos_info = [
-            {"url": url_for("static", filename=video_path), "filename": video_path.name}
-            for video_path in video_paths
-        ]
+        if isinstance(dataset, LeRobotDataset):
+            video_paths = [
+                dataset.meta.get_video_file_path(episode_id, key)
+                for key in dataset.meta.video_keys
+            ]
+            videos_info = [
+                {
+                    "url": url_for("static", filename=video_path),
+                    "filename": video_path.parent.name,
+                }
+                for video_path in video_paths
+            ]
+            tasks = dataset.meta.episodes[episode_id]["tasks"]
+        else:
+            video_keys = [
+                key for key, ft in dataset.features.items() if ft["dtype"] == "video"
+            ]
+            videos_info = [
+                {
+                    "url": f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+                    + dataset.video_path.format(
+                        episode_chunk=int(episode_id) // dataset.chunks_size,
+                        video_key=video_key,
+                        episode_index=episode_id,
+                    ),
+                    "filename": video_key,
+                }
+                for video_key in video_keys
+            ]
+
+            response = requests.get(
+                f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/episodes.jsonl"
+            )
+            response.raise_for_status()
+            # Split into lines and parse each line as JSON
+            tasks_jsonl = [
+                json.loads(line) for line in response.text.splitlines() if line.strip()
+            ]
+
+            filtered_tasks_jsonl = [
+                row for row in tasks_jsonl if row["episode_index"] == episode_id
+            ]
+            tasks = filtered_tasks_jsonl[0]["tasks"]
+
+            videos_info[0]["language_instruction"] = tasks
 
-        ep_csv_url = url_for("static", filename=get_ep_csv_fname(episode_id))
+        if episodes is None:
+            episodes = list(
+                range(
+                    dataset.num_episodes
+                    if isinstance(dataset, LeRobotDataset)
+                    else dataset.total_episodes
+                )
+            )
+
         return render_template(
             "visualize_dataset_template.html",
             episode_id=episode_id,
             episodes=episodes,
             dataset_info=dataset_info,
             videos_info=videos_info,
-            ep_csv_url=ep_csv_url,
-            has_policy=False,
+            episode_data_csv_str=episode_data_csv_str,
+            columns=columns,
         )
 
     app.run(host=host, port=port)
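For reference, a sketch of how the new homepage query parameters behave once the server is running; the host and port are example values, not defaults taken from the diff:

```python
# Sketch: ask the running server to jump straight to an episode.
# Equivalent to opening http://127.0.0.1:9090/?dataset=lerobot/pusht&episode=0
import requests

resp = requests.get(
    "http://127.0.0.1:9090/",
    params={"dataset": "lerobot/pusht", "episode": 0},
    allow_redirects=False,
)
print(resp.status_code)          # 302: redirect to the episode page
print(resp.headers["Location"])  # /lerobot/pusht/episode_0
```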
@@ -124,46 +252,78 @@ def get_ep_csv_fname(episode_id: int):
     return ep_csv_fname
 
 
-def write_episode_data_csv(output_dir, file_name, episode_index, dataset):
-    """Write a csv file containg timeseries data of an episode (e.g. state and action).
+def get_episode_data(dataset: LeRobotDataset | IterableNamespace, episode_index):
+    """Get a CSV string containing the timeseries data of an episode (e.g. state and action).
     This file will be loaded by Dygraph javascript to plot data in real time."""
-    from_idx = dataset.episode_data_index["from"][episode_index]
-    to_idx = dataset.episode_data_index["to"][episode_index]
+    columns = []
 
-    has_state = "observation.state" in dataset.features
-    has_action = "action" in dataset.features
+    selected_columns = [
+        col for col, ft in dataset.features.items() if ft["dtype"] == "float32"
+    ]
+    selected_columns.remove("timestamp")
 
     # init header of csv with state and action names
     header = ["timestamp"]
-    if has_state:
-        dim_state = dataset.meta.shapes["observation.state"][0]
-        header += [f"state_{i}" for i in range(dim_state)]
-    if has_action:
-        dim_action = dataset.meta.shapes["action"][0]
-        header += [f"action_{i}" for i in range(dim_action)]
-
-    columns = ["timestamp"]
-    if has_state:
-        columns += ["observation.state"]
-    if has_action:
-        columns += ["action"]
-
-    rows = []
-    data = dataset.hf_dataset.select_columns(columns)
-    for i in range(from_idx, to_idx):
-        row = [data[i]["timestamp"].item()]
-        if has_state:
-            row += data[i]["observation.state"].tolist()
-        if has_action:
-            row += data[i]["action"].tolist()
-        rows.append(row)
-
-    output_dir.mkdir(parents=True, exist_ok=True)
-    with open(output_dir / file_name, "w") as f:
-        f.write(",".join(header) + "\n")
-        for row in rows:
-            row_str = [str(col) for col in row]
-            f.write(",".join(row_str) + "\n")
+    for column_name in selected_columns:
+        dim_state = (
+            dataset.meta.shapes[column_name][0]
+            if isinstance(dataset, LeRobotDataset)
+            else dataset.features[column_name].shape[0]
+        )
+        header += [f"{column_name}_{i}" for i in range(dim_state)]
+
+        if (
+            "names" in dataset.features[column_name]
+            and dataset.features[column_name]["names"]
+        ):
+            column_names = dataset.features[column_name]["names"]
+            while not isinstance(column_names, list):
+                column_names = list(column_names.values())[0]
+        else:
+            column_names = [f"motor_{i}" for i in range(dim_state)]
+        columns.append({"key": column_name, "value": column_names})
+
+    selected_columns.insert(0, "timestamp")
+
+    if isinstance(dataset, LeRobotDataset):
+        from_idx = dataset.episode_data_index["from"][episode_index]
+        to_idx = dataset.episode_data_index["to"][episode_index]
+        data = (
+            dataset.hf_dataset.select(range(from_idx, to_idx))
+            .select_columns(selected_columns)
+            .with_format("pandas")
+        )
+    else:
+        repo_id = dataset.repo_id
+
+        url = (
+            f"https://huggingface.co/datasets/{repo_id}/resolve/main/"
+            + dataset.data_path.format(
+                episode_chunk=int(episode_index) // dataset.chunks_size,
+                episode_index=episode_index,
+            )
+        )
+        df = pd.read_parquet(url)
+        data = df[selected_columns]  # Select specific columns
+
+    rows = np.hstack(
+        (
+            np.expand_dims(data["timestamp"], axis=1),
+            *[np.vstack(data[col]) for col in selected_columns[1:]],
+        )
+    ).tolist()
+
+    # Convert data to CSV string
+    csv_buffer = StringIO()
+    csv_writer = csv.writer(csv_buffer)
+    # Write header
+    csv_writer.writerow(header)
+    # Write data rows
+    csv_writer.writerows(rows)
+    csv_string = csv_buffer.getvalue()
+
+    return csv_string, columns
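To make the shape of that CSV concrete, here is a hedged sketch of how a client could load it; using pandas here is illustrative, since the page actually hands the string to Dygraphs, and `dataset` is assumed to be an already-constructed `LeRobotDataset`:

```python
# Sketch: parse the CSV string returned by get_episode_data().
# The column layout is: timestamp, then one column per dimension of each
# selected float32 feature (e.g. observation.state_0, ..., action_0, ...).
from io import StringIO

import pandas as pd

csv_string, columns = get_episode_data(dataset, episode_index=0)
df = pd.read_csv(StringIO(csv_string))
print(df.columns.tolist())  # ["timestamp", "observation.state_0", ...]
```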
 
 
 def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
@@ -175,9 +335,37 @@ def get_episode_video_paths(dataset: LeRobotDataset, ep_index: int) -> list[str]:
     ]
 
 
+def get_episode_language_instruction(
+    dataset: LeRobotDataset, ep_index: int
+) -> list[str]:
+    # check if the dataset has language instructions
+    if "language_instruction" not in dataset.features:
+        return None
+
+    # get first frame index
+    first_frame_idx = dataset.episode_data_index["from"][ep_index].item()
+
+    language_instruction = dataset.hf_dataset[first_frame_idx]["language_instruction"]
+    # TODO (michel-aractingi) hack to get the sentence, some strings in openx are badly stored
+    # with the tf.tensor appearing in the string
+    return language_instruction.removeprefix("tf.Tensor(b'").removesuffix(
+        "', shape=(), dtype=string)"
+    )
+
+
+def get_dataset_info(repo_id: str) -> IterableNamespace:
+    response = requests.get(
+        f"https://huggingface.co/datasets/{repo_id}/resolve/main/meta/info.json"
+    )
+    response.raise_for_status()  # Raises an HTTPError for bad responses
+    dataset_info = response.json()
+    dataset_info["repo_id"] = repo_id
+    return IterableNamespace(dataset_info)
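A rough sketch of what the hub-only path returns; the fields shown follow the `meta/info.json` schema as it is used elsewhere in this diff (`fps`, `total_episodes`), so treat any field not referenced above as an assumption:

```python
# Sketch: inspect a dataset's metadata without downloading any data files.
info = get_dataset_info("lerobot/pusht")
print(info.repo_id)         # "lerobot/pusht"
print(info.fps)             # playback rate used by the viewer
print(info.total_episodes)  # episode count reported by meta/info.json
```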
 
 
 def visualize_dataset_html(
-    dataset: LeRobotDataset,
-    episodes: list[int] = None,
+    dataset: LeRobotDataset | None,
+    episodes: list[int] | None = None,
     output_dir: Path | None = None,
     serve: bool = True,
     host: str = "127.0.0.1",
@@ -186,43 +374,46 @@
 ) -> Path | None:
     init_logging()
 
-    if len(dataset.meta.image_keys) > 0:
-        raise NotImplementedError(f"Image keys ({dataset.meta.image_keys=}) are currently not supported.")
+    template_dir = Path(__file__).resolve().parent.parent / "templates"
 
     if output_dir is None:
-        output_dir = f"outputs/visualize_dataset_html/{dataset.repo_id}"
+        # Create a temporary directory that will be automatically cleaned up
+        output_dir = tempfile.mkdtemp(prefix="lerobot_visualize_dataset_")
 
     output_dir = Path(output_dir)
     if output_dir.exists():
         if force_override:
             shutil.rmtree(output_dir)
         else:
-            logging.info(f"Output directory already exists. Loading from it: '{output_dir}'")
+            logging.info(
+                f"Output directory already exists. Loading from it: '{output_dir}'"
+            )
 
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    # Create a simlink from the dataset video folder containg mp4 files to the output directory
-    # so that the http server can get access to the mp4 files.
     static_dir = output_dir / "static"
     static_dir.mkdir(parents=True, exist_ok=True)
-    ln_videos_dir = static_dir / "videos"
-    if not ln_videos_dir.exists():
-        ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
-
-    template_dir = Path(__file__).resolve().parent.parent / "templates"
-
-    if episodes is None:
-        episodes = list(range(dataset.num_episodes))
-
-    logging.info("Writing CSV files")
-    for episode_index in tqdm.tqdm(episodes):
-        # write states and actions in a csv (it can be slow for big datasets)
-        ep_csv_fname = get_ep_csv_fname(episode_index)
-        # TODO(rcadene): speedup script by loading directly from dataset, pyarrow, parquet, safetensors?
-        write_episode_data_csv(static_dir, ep_csv_fname, episode_index, dataset)
+    if dataset is None:
+        if serve:
+            run_server(
+                dataset=None,
+                episodes=None,
+                host=host,
+                port=port,
+                static_folder=static_dir,
+                template_folder=template_dir,
+            )
+    else:
+        # Create a symlink from the dataset video folder containing mp4 files to the output directory
+        # so that the http server can get access to the mp4 files.
+        if isinstance(dataset, LeRobotDataset):
+            ln_videos_dir = static_dir / "videos"
+            if not ln_videos_dir.exists():
+                ln_videos_dir.symlink_to((dataset.root / "videos").resolve())
 
-    if serve:
-        run_server(dataset, episodes, host, port, static_dir, template_dir)
+        if serve:
+            run_server(dataset, episodes, host, port, static_dir, template_dir)
 
 
 def main():
@@ -231,7 +422,7 @@ def main():
     parser.add_argument(
         "--repo-id",
         type=str,
-        required=True,
+        default=None,
         help="Name of hugging face repository containing a LeRobotDataset dataset (e.g. `lerobot/pusht` for https://huggingface.co/datasets/lerobot/pusht).",
     )
     parser.add_argument(
@@ -246,6 +437,12 @@ def main():
         default=None,
         help="Root directory for a dataset stored locally (e.g. `--root data`). By default, the dataset will be loaded from hugging face cache folder, or downloaded from the hub if available.",
     )
+    parser.add_argument(
+        "--load-from-hf-hub",
+        type=int,
+        default=0,
+        help="Load videos and parquet files from HF Hub rather than local system.",
+    )
     parser.add_argument(
         "--episodes",
         type=int,
@@ -287,11 +484,19 @@ def main():
     args = parser.parse_args()
     kwargs = vars(args)
     repo_id = kwargs.pop("repo_id")
+    load_from_hf_hub = kwargs.pop("load_from_hf_hub")
     root = kwargs.pop("root")
     local_files_only = kwargs.pop("local_files_only")
 
-    dataset = LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
-    visualize_dataset_html(dataset, **kwargs)
+    dataset = None
+    if repo_id:
+        dataset = (
+            LeRobotDataset(repo_id, root=root, local_files_only=local_files_only)
+            if not load_from_hf_hub
+            else get_dataset_info(repo_id)
+        )
+
+    visualize_dataset_html(dataset, **vars(args))
 
 
 if __name__ == "__main__":
diff --git a/lerobot/scripts/visualize_image_transforms.py b/lerobot/scripts/visualize_image_transforms.py
index f9fb5c08a..a4ae4b5f3 100644
--- a/lerobot/scripts/visualize_image_transforms.py
+++ b/lerobot/scripts/visualize_image_transforms.py
@@ -162,8 +162,12 @@ def visualize_transforms(cfg, output_dir: Path, n_examples: int = 5):
     print("\nOriginal frame saved to:")
     print(f"  {output_dir / 'original_frame.png'}.")
 
-    save_config_all_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
-    save_config_single_transforms(cfg.training.image_transforms, original_frame, output_dir, n_examples)
+    save_config_all_transforms(
+        cfg.training.image_transforms, original_frame, output_dir, n_examples
+    )
+    save_config_single_transforms(
+        cfg.training.image_transforms, original_frame, output_dir, n_examples
+    )
 
 
 @hydra.main(version_base="1.2", config_name="default", config_path="../configs")
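Putting the new hub-backed path together, a hedged sketch of serving a dataset without any local copy, mirroring what `main()` does when `--load-from-hf-hub 1` is passed; the repo id and port are example values:

```python
# Sketch: serve the visualizer for a hub dataset. No parquet or mp4 files
# are downloaded up front; the page fetches them from the Hub on demand.
from lerobot.scripts.visualize_dataset_html import (
    get_dataset_info,
    visualize_dataset_html,
)

dataset = get_dataset_info("lerobot/pusht")  # metadata only, via meta/info.json
visualize_dataset_html(dataset, episodes=None, serve=True, host="127.0.0.1", port=9090)
```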
diff --git a/lerobot/templates/visualize_dataset_homepage.html b/lerobot/templates/visualize_dataset_homepage.html
new file mode 100644
index 000000000..19613afb5
--- /dev/null
+++ b/lerobot/templates/visualize_dataset_homepage.html
@@ -0,0 +1,68 @@
[New Jinja template; the HTML markup was lost in extraction. Recoverable content: a page titled "Interactive Video Background Page" with a "LeRobot Dataset Visualizer" header, a "create & train your own robots" link, an "Example Datasets:" list rendered from `featured_datasets`, and a collapsible "More example datasets" list rendered from `lerobot_datasets`.]
diff --git a/lerobot/templates/visualize_dataset_template.html b/lerobot/templates/visualize_dataset_template.html
index 0fa1e713e..08de3e3dd 100644
--- a/lerobot/templates/visualize_dataset_template.html
+++ b/lerobot/templates/visualize_dataset_template.html
[Template diff; the HTML markup was lost in extraction. Recoverable changes: the "{{ dataset_info.repo_id }}" heading becomes a link; the info list now reads "Number of samples/frames: {{ dataset_info.num_samples }}" (previously {{ dataset_info.num_frames }}) alongside "Number of episodes: {{ dataset_info.num_episodes }}"; a "filter videos" dropdown (🔽) is added above the video grid; and the {% for video_info in videos_info %} loop is restructured, still labeling each video with "{{ video_info.filename }}".]