[WIP] Fix SAC and port HIL SERL #644

Open
wants to merge 70 commits into
base: user/michel-aractingi/2024-11-27-port-hil-serl
70 commits
0a4e9e2
[vizualizer] for LeRobodDataset V2 (#576)
mishig25 Dec 20, 2024
4a43c83
Fix broken `create_lerobot_dataset_card` (#590)
helper2424 Dec 23, 2024
b1cfb6a
Update README.md (#612)
CharlesCNorton Jan 3, 2025
31c34a4
Fix Quality workflow (#622)
aliberts Jan 8, 2025
d649815
fix(docs): typos in benchmark readme.md (#614)
CharlesCNorton Jan 9, 2025
a1b5d0f
fix(visualise): use correct language description for each episode id …
villekuosmanen Jan 9, 2025
c2f7af3
typo fix: batch_convert_dataset_v1_to_v2.py (#615)
CharlesCNorton Jan 9, 2025
100f54e
[viz] Fixes & updates to html visualizer (#617)
mishig25 Jan 9, 2025
df7310e
fixes to SO-100 readme (#600)
philfung Jan 10, 2025
068efce
Fix for the issue https://github.com/huggingface/lerobot/issues/638 (…
PradeepKadubandi Jan 15, 2025
472a7f5
[WIP] correct sac implementation
AdilZouitine Jan 13, 2025
c86dace
remove breakpoint
AdilZouitine Jan 13, 2025
a0a50de
SAC works
AdilZouitine Jan 14, 2025
be96501
Add rlpd tricks
AdilZouitine Jan 15, 2025
956c547
[WIP] correct sac implementation
AdilZouitine Jan 13, 2025
86df8a4
remove breakpoint
AdilZouitine Jan 13, 2025
c1d4bf4
SAC works
AdilZouitine Jan 14, 2025
8105efb
Add rlpd tricks
AdilZouitine Jan 15, 2025
7d2970f
Change SAC policy implementation with configuration and modeling classes
AdilZouitine Jan 17, 2025
1fb03d4
Add type annotations and restructure SACConfig class fields
AdilZouitine Jan 21, 2025
d75b44f
Stable version of rlpd + drq
AdilZouitine Jan 22, 2025
322a78a
Added server directory in `lerobot/scripts` that contains scripts and…
michel-aractingi Jan 28, 2025
36576c9
FREEDOM, added back the optimization loop code in `learner_server.py`
michel-aractingi Jan 28, 2025
42618f4
- Added additional logging information in wandb around the timings of…
michel-aractingi Jan 29, 2025
9aabe21
Added missing config files `env/maniskill_example.yaml` and `policy/s…
michel-aractingi Jan 29, 2025
e856ffc
Removed unnecessary time.sleep in the streaming server on the learner…
michel-aractingi Jan 29, 2025
367dfe5
Added support for checkpointing the policy. We can save and load the …
michel-aractingi Jan 30, 2025
7c89bd1
Cleaned `learner_server.py`. Added several block function to improve …
michel-aractingi Jan 31, 2025
f1c8bfe
[Port HIL-SERL] Add HF vision encoder option in SAC (#651)
ChorntonYoel Jan 31, 2025
506821c
- Refactor observation encoder in `modeling_sac.py`
michel-aractingi Jan 31, 2025
2211209
- Added base gym env class for the real robot environment.
michel-aractingi Feb 3, 2025
efb1982
Added crop_dataset_roi.py that allows you to load a lerobotdataset ->…
michel-aractingi Feb 3, 2025
e0527b4
Added additional wrappers for the environment: Action repeat, keyboar…
michel-aractingi Feb 4, 2025
7d5a953
fixed bug in crop_dataset_roi.py
michel-aractingi Feb 5, 2025
1252524
- Added `lerobot/scripts/server/gym_manipulator.py` that contains all…
michel-aractingi Feb 6, 2025
b637386
[HIL-SERL port] Add Reward classifier benchmark tracking to chose bes…
helper2424 Feb 6, 2025
d51374c
Several fixes to move the actor_server and learner_server code from t…
michel-aractingi Feb 10, 2025
b5f8943
Added sac_real config file in the policym configs dir.
michel-aractingi Feb 10, 2025
a7db395
- Added JointMaskingActionSpace wrapper in `gym_manipulator` in order…
michel-aractingi Feb 11, 2025
a1d16fb
[Port HIL-SERL] Add resnet-10 as default encoder for HIL-SERL (#696)
helper2424 Feb 11, 2025
6868c88
[PORT-Hilserl] classifier fixes (#695)
ChorntonYoel Feb 11, 2025
b9217b0
Added possiblity to record and replay delta actions during teleoperat…
michel-aractingi Feb 12, 2025
dc086dc
Added logging for interventions to monitor the rate of interventions …
michel-aractingi Feb 13, 2025
459f22e
fix log_alpha in modeling_sac: change to nn.parameter
michel-aractingi Feb 13, 2025
c462a47
Hardcoded some normalization parameters. TODO refactor
michel-aractingi Feb 13, 2025
0c32008
Changed bounds for a new so100 robot
michel-aractingi Feb 13, 2025
d9a7037
Changed the init_final value to center the starting mean and std of t…
michel-aractingi Feb 13, 2025
b07d95f
removed uncomment in actor server
michel-aractingi Feb 13, 2025
95de8e2
nit
michel-aractingi Feb 13, 2025
c9e50bb
Optimized the replay buffer from the memory side to store data on cpu…
michel-aractingi Feb 13, 2025
36711d7
Modified crop_dataset_roi interface to automatically write the croppe…
michel-aractingi Feb 14, 2025
7ae368e
Fixed bug in the action scale of the intervention actions and offline…
michel-aractingi Feb 14, 2025
2f3370e
Add maniskill support.
AdilZouitine Feb 14, 2025
446f434
Improve wandb logging and custom step tracking in logger
AdilZouitine Feb 17, 2025
befa1fe
Re-enable parameter push thread in learner server
AdilZouitine Feb 17, 2025
ff47c0b
- Fixed big issue in the loading of the policy parameters sent by the…
michel-aractingi Feb 19, 2025
ff82367
Refactor SAC policy with performance optimizations and multi-camera s…
AdilZouitine Feb 20, 2025
3ffe0cf
[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependen…
helper2424 Feb 21, 2025
5467191
Added caching function in the learner_server and modeling sac in orde…
michel-aractingi Feb 21, 2025
42a0381
Update ManiSkill configuration and replay buffer to support truncatio…
AdilZouitine Feb 24, 2025
ef8d943
Refactor ReplayBuffer with tensor-based storage and improved sampling…
AdilZouitine Feb 25, 2025
5b4a7aa
Add storage device parameter to replay buffer initialization
AdilZouitine Feb 25, 2025
1df9ee4
Add memory optimization option to ReplayBuffer
AdilZouitine Feb 25, 2025
d8a1758
Add storage device configuration for SAC policy and replay buffer
AdilZouitine Mar 4, 2025
584cad8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 4, 2025
700f00c
[HIL-SERL] Migrate threading to multiprocessing (#759)
helper2424 Mar 5, 2025
d711e20
[Port HIL-SERL] Balanced sampler function speed up and refactor to al…
s1lent4gnt Mar 12, 2025
25b88f3
Remove torch.no_grad decorator and optimize next action prediction in…
AdilZouitine Mar 10, 2025
5081c14
Add custom save and load methods for SAC policy
AdilZouitine Mar 12, 2025
41219fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 12, 2025
4 changes: 2 additions & 2 deletions .github/workflows/quality.yml
@@ -50,7 +50,7 @@ jobs:
uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry
run: pipx install "poetry<2.0.0"

- name: Poetry check
run: poetry check
@@ -64,7 +64,7 @@ jobs:
uses: actions/checkout@v3

- name: Install poetry
run: pipx install poetry
run: pipx install "poetry<2.0.0"

- name: Install poetry-relax
run: poetry self add poetry-relax
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -17,6 +17,7 @@ repos:
rev: v3.19.0
hooks:
- id: pyupgrade
exclude: '^(.*_pb2_grpc\.py|.*_pb2\.py$)'
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.2
hooks:
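To see which paths the new `exclude` pattern on the pyupgrade hook skips, here is a minimal sketch. It assumes pre-commit applies `exclude` with Python's `re.search` against the repo-relative path, which is my understanding of its behavior; the file names are made up for illustration.

```python
import re

# Pattern copied from the .pre-commit-config.yaml hunk above.
EXCLUDE = re.compile(r"^(.*_pb2_grpc\.py|.*_pb2\.py$)")


def is_excluded(path: str) -> bool:
    """Return True if the (hypothetical) path would be skipped by the hook."""
    return EXCLUDE.search(path) is not None


if __name__ == "__main__":
    for path in [
        "pkg/foo_pb2.py",        # generated protobuf module
        "pkg/foo_pb2_grpc.py",   # generated gRPC stubs
        "pkg/foo_pb2.pyi",       # not matched: `$` anchors the second alternative
        "pkg/foo_pb2_grpc.pyi",  # matched: the first alternative has no `$`
        "pkg/learner.py",        # ordinary source file
    ]:
        print(f"{path}: excluded={is_excluded(path)}")
```

Note that the `$` sits inside the group and therefore only anchors the `_pb2\.py` alternative, so the `_pb2_grpc\.py` alternative also matches longer suffixes.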
2 changes: 1 addition & 1 deletion README.md
@@ -68,7 +68,7 @@

### Acknowledgment

- Thanks to Tony Zaho, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Tony Zhao, Zipeng Fu and colleagues for open sourcing ACT policy, ALOHA environments and datasets. Ours are adapted from [ALOHA](https://tonyzhaozh.github.io/aloha) and [Mobile ALOHA](https://mobile-aloha.github.io).
- Thanks to Cheng Chi, Zhenjia Xu and colleagues for open sourcing Diffusion policy, Pusht environment and datasets, as well as UMI datasets. Ours are adapted from [Diffusion Policy](https://diffusion-policy.cs.columbia.edu) and [UMI Gripper](https://umi-gripper.github.io).
- Thanks to Nicklas Hansen, Yunhai Feng and colleagues for open sourcing TDMPC policy, Simxarm environments and datasets. Ours are adapted from [TDMPC](https://github.com/nicklashansen/tdmpc) and [FOWM](https://www.yunhaifeng.com/FOWM).
- Thanks to Antonio Loquercio and Ashish Kumar for their early support.
10 changes: 5 additions & 5 deletions benchmarks/video/README.md
@@ -21,7 +21,7 @@ How to decode videos?

## Variables
**Image content & size**
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an appartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
We don't expect the same optimal settings for a dataset of images from a simulation, or from real-world in an apartment, or in a factory, or outdoor, or with lots of moving objects in the scene, etc. Similarly, loading times might not vary linearly with the image size (resolution).
For these reasons, we run this benchmark on four representative datasets:
- `lerobot/pusht_image`: (96 x 96 pixels) simulation with simple geometric shapes, fixed camera.
- `aliberts/aloha_mobile_shrimp_image`: (480 x 640 pixels) real-world indoor, moving camera.
@@ -63,7 +63,7 @@ This of course is affected by the `-g` parameter during encoding, which specifie

Note that this differs significantly from a typical use case like watching a movie, in which every frame is loaded sequentially from the beginning to the end and it's acceptable to have big values for `-g`.

Additionally, because some policies might request single timestamps that are a few frames appart, we also have the following scenario:
Additionally, because some policies might request single timestamps that are a few frames apart, we also have the following scenario:
- `2_frames_4_space`: 2 frames with 4 consecutive frames of spacing in between (e.g `[t, t + 5 / fps]`),

However, due to how video decoding is implemented with `pyav`, we don't have access to an accurate seek so in practice this scenario is essentially the same as `6_frames` since all 6 frames between `t` and `t + 5 / fps` will be decoded.
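The `2_frames_4_space` case can be made concrete with a small sketch. It assumes, as the paragraph above describes, that decoding runs through every frame between the first and last requested timestamps; `frame_indices` and `frames_decoded` are hypothetical helpers, not part of the benchmark script, and frames decoded between the preceding keyframe and the first requested frame are ignored here.

```python
def frame_indices(timestamps, fps):
    """Map timestamps in seconds to frame indices, rounding to the nearest frame."""
    return [round(ts * fps) for ts in timestamps]


def frames_decoded(timestamps, fps):
    """Count frames a sequential decoder touches from the first to the last request."""
    idx = frame_indices(timestamps, fps)
    return max(idx) - min(idx) + 1


fps = 30
t = 2.0
two_frames_4_space = [t, t + 5 / fps]         # 2 requested frames, 4 skipped in between
six_frames = [t + i / fps for i in range(6)]  # 6 consecutive requested frames

if __name__ == "__main__":
    print(frames_decoded(two_frames_4_space, fps))
    print(frames_decoded(six_frames, fps))
```

Both scenarios span the same six frame indices, which is why their decoding cost is expected to be about the same without an accurate seek.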
@@ -85,8 +85,8 @@ However, due to how video decoding is implemented with `pyav`, we don't have acc
**Average Structural Similarity Index Measure (higher is better)**
`avg_ssim` evaluates the perceived quality of images by comparing luminance, contrast, and structure. SSIM values range from -1 to 1, where 1 indicates perfect similarity.
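As a rough illustration of how SSIM combines luminance, contrast, and structure, the sketch below computes a single-window ("global") SSIM over two flattened images in pure Python. This is a simplification: the benchmark script uses `skimage.metrics.structural_similarity`, which evaluates the index over local windows and averages the result, so the numbers will differ. `global_ssim` is a hypothetical name, and the constants use the commonly cited defaults K1 = 0.01 and K2 = 0.03.

```python
def global_ssim(x, y, data_range=1.0):
    """Single-window SSIM between two equal-length sequences of pixel values."""
    if len(x) != len(y) or not x:
        raise ValueError("inputs must be non-empty and of equal length")
    n = len(x)
    mu_x = sum(x) / n
    mu_y = sum(y) / n
    var_x = sum((a - mu_x) ** 2 for a in x) / n
    var_y = sum((b - mu_y) ** 2 for b in y) / n
    cov_xy = sum((a - mu_x) * (b - mu_y) for a, b in zip(x, y)) / n
    # Small constants that stabilise the ratio when means or variances are near zero.
    c1 = (0.01 * data_range) ** 2
    c2 = (0.03 * data_range) ** 2
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x**2 + mu_y**2 + c1) * (var_x + var_y + c2)
    return numerator / denominator


ramp = [0.0, 0.25, 0.5, 0.75, 1.0]
inverted = [1.0 - v for v in ramp]

if __name__ == "__main__":
    print(global_ssim(ramp, ramp))      # identical inputs: close to 1.0
    print(global_ssim(ramp, inverted))  # inverted structure: negative
```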

One aspect that can't be measured here with those metrics is the compatibility of the encoding accross platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not be pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
One aspect that can't be measured here with those metrics is the compatibility of the encoding across platforms, in particular on web browser, for visualization purposes.
h264, h265 and AV1 are all commonly used codecs and should not pose an issue. However, the chroma subsampling (`pix_fmt`) format might affect compatibility:
- `yuv420p` is more widely supported across various platforms, including web browsers.
- `yuv444p` offers higher color fidelity but might not be supported as broadly.
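The fidelity difference between the two pixel formats comes from how much chroma data they keep. A minimal sketch, assuming 8-bit planar formats where `yuv420p` stores each chroma plane at half resolution in both dimensions and `yuv444p` stores them at full resolution; `raw_frame_bytes` is a hypothetical helper and describes uncompressed frames, not encoded file size.

```python
def raw_frame_bytes(width: int, height: int, pix_fmt: str) -> int:
    """Uncompressed size in bytes of one 8-bit planar YUV frame."""
    luma = width * height
    if pix_fmt == "yuv444p":
        chroma_plane = width * height  # full-resolution U and V planes
    elif pix_fmt == "yuv420p":
        # U and V are subsampled by 2 horizontally and vertically (rounded up).
        chroma_plane = ((width + 1) // 2) * ((height + 1) // 2)
    else:
        raise ValueError(f"unsupported pix_fmt: {pix_fmt}")
    return luma + 2 * chroma_plane


if __name__ == "__main__":
    for fmt in ("yuv420p", "yuv444p"):
        print(fmt, raw_frame_bytes(640, 480, fmt))
```

For the same resolution, `yuv444p` carries twice the raw data of `yuv420p`, all of it extra chroma.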

@@ -116,7 +116,7 @@ Additional encoding parameters exist that are not included in this benchmark. In
- `-preset` which allows for selecting encoding presets. This represents a collection of options that will provide a certain encoding speed to compression ratio. By leaving this parameter unspecified, it is considered to be `medium` for libx264 and libx265 and `8` for libsvtav1.
- `-tune` which allows to optimize the encoding for certains aspects (e.g. film quality, fast decoding, etc.).
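For readers who want to experiment with these extra parameters, here is a sketch of how an `ffmpeg` command line could be assembled with optional `-preset` and `-tune` flags. `build_ffmpeg_cmd` is a hypothetical helper and is not how the benchmark script encodes videos; the frame pattern, output name, and parameter values are placeholders.

```python
from typing import List, Optional


def build_ffmpeg_cmd(
    frames_pattern: str,
    output_path: str,
    fps: int,
    vcodec: str,
    pix_fmt: str,
    g: Optional[int] = None,
    crf: Optional[int] = None,
    preset: Optional[str] = None,
    tune: Optional[str] = None,
) -> List[str]:
    """Build an ffmpeg argument list; options left as None are omitted."""
    cmd = ["ffmpeg", "-y", "-framerate", str(fps), "-i", frames_pattern]
    cmd += ["-vcodec", vcodec, "-pix_fmt", pix_fmt]
    for flag, value in (("-g", g), ("-crf", crf), ("-preset", preset), ("-tune", tune)):
        if value is not None:
            cmd += [flag, str(value)]
    cmd.append(output_path)
    return cmd


if __name__ == "__main__":
    print(" ".join(build_ffmpeg_cmd(
        "frame_%06d.png", "out.mp4", fps=30,
        vcodec="libx264", pix_fmt="yuv420p",
        g=2, crf=30, preset="medium", tune="fastdecode",
    )))
```

Leaving `preset` unset simply omits the flag, so the encoder falls back to its own default as described in the list above.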

See the documentation mentioned above for more detailled info on these settings and for a more comprehensive list of other parameters.
See the documentation mentioned above for more detailed info on these settings and for a more comprehensive list of other parameters.

Similarly on the decoding side, other decoders exist but are not implemented in our current benchmark. To name a few:
- `torchaudio`
86 changes: 67 additions & 19 deletions benchmarks/video/run_video_benchmark.py
@@ -32,7 +32,11 @@
import pandas as pd
import PIL
import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from skimage.metrics import (
mean_squared_error,
peak_signal_noise_ratio,
structural_similarity,
)
from tqdm import tqdm

from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
@@ -81,7 +85,9 @@ def get_directory_size(directory: Path) -> int:
return total_size


def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> torch.Tensor:
def load_original_frames(
imgs_dir: Path, timestamps: list[float], fps: int
) -> torch.Tensor:
frames = []
for ts in timestamps:
idx = int(ts * fps)
@@ -94,7 +100,11 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t


def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
imgs_dir: Path,
save_dir: Path,
frames: torch.Tensor,
timestamps: list[float],
fps: int,
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
return
@@ -104,7 +114,10 @@ def save_decoded_frames(
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
shutil.copyfile(
imgs_dir / f"frame_{idx:06d}.png",
save_dir / f"frame_{idx:06d}_original.png",
)


def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
@@ -116,11 +129,17 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
hf_dataset = dataset.hf_dataset.with_format(None)

# We only save images from the first camera
img_keys = [key for key in hf_dataset.features if key.startswith("observation.image")]
img_keys = [
key for key in hf_dataset.features if key.startswith("observation.image")
]
imgs_dataset = hf_dataset.select_columns(img_keys[0])

for i, item in enumerate(
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
tqdm(
imgs_dataset,
desc=f"saving {dataset.repo_id} first episode images",
leave=False,
)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
@@ -129,7 +148,9 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
break


def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> list[float]:
def sample_timestamps(
timestamps_mode: str, ep_num_images: int, fps: int
) -> list[float]:
# Start at 5 to allow for 2_frames_4_space and 6_frames
idx = random.randint(5, ep_num_images - 1)
match timestamps_mode:
@@ -154,7 +175,9 @@ def decode_video_frames(
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
return decode_video_frames_torchvision(
video_path, timestamps, tolerance_s, backend
)
else:
raise NotImplementedError(backend)

@@ -181,7 +204,9 @@ def process_sample(sample: int):
}

with time_benchmark:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
frames = decode_video_frames(
video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend
)
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames

with time_benchmark:
@@ -190,12 +215,18 @@ def process_sample(sample: int):

frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):
result["mse_values"].append(mean_squared_error(original_frames_np[i], frames_np[i]))
result["mse_values"].append(
mean_squared_error(original_frames_np[i], frames_np[i])
)
result["psnr_values"].append(
peak_signal_noise_ratio(original_frames_np[i], frames_np[i], data_range=1.0)
peak_signal_noise_ratio(
original_frames_np[i], frames_np[i], data_range=1.0
)
)
result["ssim_values"].append(
structural_similarity(original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0)
structural_similarity(
original_frames_np[i], frames_np[i], data_range=1.0, channel_axis=0
)
)

if save_frames and sample == 0:
@@ -215,7 +246,9 @@ def process_sample(sample: int):
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
for future in tqdm(
as_completed(futures), total=num_samples, desc="samples", leave=False
):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])
load_times_images_ms.append(result["load_time_images_ms"])
@@ -275,9 +308,13 @@ def benchmark_encoding_decoding(
random.seed(seed)
benchmark_table = []
for timestamps_mode in tqdm(
decoding_cfg["timestamps_modes"], desc="decodings (timestamps_modes)", leave=False
decoding_cfg["timestamps_modes"],
desc="decodings (timestamps_modes)",
leave=False,
):
for backend in tqdm(decoding_cfg["backends"], desc="decodings (backends)", leave=False):
for backend in tqdm(
decoding_cfg["backends"], desc="decodings (backends)", leave=False
):
benchmark_row = benchmark_decoding(
imgs_dir,
video_path,
@@ -355,14 +392,23 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for key, values in tqdm(
encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False
):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
args_path = Path(
"_".join(str(value) for value in encoding_cfg.values())
)
video_path = (
output_dir
/ "videos"
/ args_path
/ f"{repo_id.replace('/', '_')}.mp4"
)
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
@@ -388,7 +434,9 @@ def main(
# Concatenate all results
df_list = [pd.read_csv(csv_path) for csv_path in file_paths]
concatenated_df = pd.concat(df_list, ignore_index=True)
concatenated_path = output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
concatenated_path = (
output_dir / f"{now:%Y-%m-%d}_{now:%H-%M-%S}_all_{num_samples}-samples.csv"
)
concatenated_df.to_csv(concatenated_path, header=True, index=False)


18 changes: 18 additions & 0 deletions checkport.py
@@ -0,0 +1,18 @@
import socket


def check_port(host, port):
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        s.connect((host, port))
        print(f"Connection successful to {host}:{port}!")
    except Exception as e:
        print(f"Connection failed to {host}:{port}: {e}")
    finally:
        s.close()


if __name__ == "__main__":
    host = "127.0.0.1"  # or "localhost"
    port = 51350
    check_port(host, port)
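A possible variant of the helper above, sketched under the assumption that a boolean result and a connection timeout would be useful to callers. `port_is_open` is a hypothetical name and is not part of this PR; it relies on the standard-library `socket.create_connection`.

```python
import socket


def port_is_open(host: str, port: int, timeout: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within `timeout` seconds."""
    try:
        # The context manager closes the socket even if an error occurs later.
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        # Covers refused connections, timeouts, and unreachable hosts.
        return False


if __name__ == "__main__":
    # Listen on an ephemeral port so the demo does not depend on port 51350 being in use.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server:
        server.bind(("127.0.0.1", 0))
        server.listen(1)
        demo_port = server.getsockname()[1]
        print(port_is_open("127.0.0.1", demo_port))
```

Returning a boolean instead of printing makes the check easy to reuse, for example when waiting for a server process to come up before connecting a client.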
11 changes: 11 additions & 0 deletions docker/lerobot-gpu-mani-skill/Dockerfile
@@ -0,0 +1,11 @@
FROM huggingface/lerobot-gpu:latest

RUN apt-get update && apt-get install -y --no-install-recommends \
    libvulkan1 vulkan-tools \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

RUN pip install --upgrade --no-cache-dir pip
RUN pip install --no-cache-dir ".[mani-skill]"

# Set EGL as the rendering backend for MuJoCo
ENV MUJOCO_GL="egl"