Commit b637386

[HIL-SERL port] Add Reward classifier benchmark tracking to choose the best visual encoder (#688)
1 parent 1252524 commit b637386

File tree

lerobot/configs/policy/hilserl_classifier.yaml
lerobot/scripts/train_hilserl_classifier.py

2 files changed (+88 -3 lines)

2 files changed

+88
-3
lines changed

lerobot/configs/policy/hilserl_classifier.yaml (+2)

@@ -27,6 +27,8 @@ training:
   # image_keys: ["observation.images.top", "observation.images.wrist"]
   image_keys: ["observation.images.laptop", "observation.images.phone"]
   label_key: "next.reward"
+  profile_inference_time: false
+  profile_inference_time_iters: 20

 eval:
   batch_size: 16
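Both new keys sit under training: and leave profiling off by default. Because the script is a Hydra entry point, they can also be flipped per run with standard Hydra override syntax (e.g. appending training.profile_inference_time=true to the launch command). A minimal sketch of the flags as the script sees them through OmegaConf; the values below are illustrative, not the repo defaults:

from omegaconf import OmegaConf

# Illustrative config fragment mirroring the two new keys
# (the shipped default is profile_inference_time: false).
cfg = OmegaConf.create(
    {
        "training": {
            "profile_inference_time": True,
            "profile_inference_time_iters": 20,
        }
    }
)

if cfg.training.profile_inference_time:
    print(f"Profiling inference for {cfg.training.profile_inference_time_iters} iterations")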

lerobot/scripts/train_hilserl_classifier.py (+86 -3)
@@ -20,17 +20,19 @@
 from pprint import pformat

 import hydra
+import numpy as np
 import torch
 import torch.nn as nn
-import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import optim
+from torch.autograd import profiler
 from torch.cuda.amp import GradScaler
-from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
+from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
 from tqdm import tqdm

+import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger
@@ -124,6 +126,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
     batch_start_time = time.perf_counter()
     samples = []
     running_loss = 0
+    inference_times = []

     with (
         torch.no_grad(),
@@ -133,7 +136,18 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
             images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
             labels = batch[cfg.training.label_key].float().to(device)

-            outputs = model(images)
+            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
+                with (
+                    profiler.profile(record_shapes=True) as prof,
+                    profiler.record_function("model_inference"),
+                ):
+                    outputs = model(images)
+                inference_times.append(
+                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+                )
+            else:
+                outputs = model(images)
+
             loss = criterion(outputs.logits, labels)

             # Track metrics
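The measurement pattern introduced above is: profiler.profile captures events, record_function labels the forward pass, and key_averages() yields the aggregated timing for that label. It can be exercised in isolation; in this self-contained sketch the linear layer is a stand-in, not the repo's reward classifier:

import torch
import torch.nn as nn
from torch.autograd import profiler

model = nn.Linear(128, 2)  # stand-in for the reward classifier
x = torch.randn(16, 128)

with torch.no_grad():
    with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
        _ = model(x)

# key_averages() aggregates recorded events by name; cpu_time is the average
# CPU time per call, reported in microseconds (the "us" in the script's logs).
event = next(e for e in prof.key_averages() if e.key == "model_inference")
print(f"model_inference: {event.cpu_time:.2f} us")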
@@ -177,9 +191,76 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
         else None,
     }

+    if len(inference_times) > 0:
+        eval_info["inference_time_avg"] = np.mean(inference_times)
+        eval_info["inference_time_median"] = np.median(inference_times)
+        eval_info["inference_time_std"] = np.std(inference_times)
+        eval_info["inference_time_batch_size"] = val_loader.batch_size
+
+        print(
+            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
+        )
+
     return accuracy, eval_info


+def benchmark_inference_time(model, dataset, logger, cfg, device, step):
+    if not cfg.training.profile_inference_time:
+        return
+
+    iters = cfg.training.profile_inference_time_iters
+    inference_times = []
+
+    loader = DataLoader(
+        dataset,
+        batch_size=1,
+        num_workers=cfg.training.num_workers,
+        sampler=RandomSampler(dataset),
+        pin_memory=True,
+    )
+
+    model.eval()
+    with torch.no_grad():
+        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
+            x = next(iter(loader))
+            x = [x[img_key].to(device) for img_key in cfg.training.image_keys]
+
+            # Warm up
+            for _ in range(10):
+                _ = model(x)
+
+            # sync the device
+            if device.type == "cuda":
+                torch.cuda.synchronize()
+            elif device.type == "mps":
+                torch.mps.synchronize()
+
+            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
+                _ = model(x)
+
+            inference_times.append(
+                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+            )
+
+    inference_times = np.array(inference_times)
+    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
+    print(
+        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
+    )
+    if logger._cfg.wandb.enable:
+        logger.log_dict(
+            {
+                "inference_time_benchmark_avg": avg,
+                "inference_time_benchmark_median": median,
+                "inference_time_benchmark_std": std,
+            },
+            step + 1,
+            mode="eval",
+        )
+
+    return avg, median, std
+
+
 @hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
 def train(cfg: DictConfig) -> None:
     # Main training pipeline with support for resuming training
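The warm-up loop and the explicit synchronize() calls are what make these numbers comparable across encoders: warm-up absorbs one-time costs (allocator growth, cuDNN autotuning), and syncing drains queued asynchronous GPU/MPS work before the profiled forward pass, which matters because the profiler reports CPU-side time. The same discipline expressed with plain wall-clock timing, as a hypothetical helper that is not part of this commit:

import time

import torch


def wall_clock_forward_us(model, x, device, warmup=10):
    # Hypothetical helper: time a single forward pass in microseconds,
    # with the same warm-up + sync discipline as benchmark_inference_time.
    def sync():
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()

    with torch.no_grad():
        for _ in range(warmup):
            _ = model(x)  # warm-up passes are not timed
        sync()  # drain queued work before starting the clock
        start = time.perf_counter()
        _ = model(x)
        sync()  # make sure the timed pass has actually finished
    return (time.perf_counter() - start) * 1e6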
@@ -313,6 +394,8 @@ def train(cfg: DictConfig) -> None:

         step += len(train_loader)

+    benchmark_inference_time(model, dataset, logger, cfg, device, step)
+
     logging.info("Training completed")
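With wandb enabled, every run now logs inference_time_benchmark_avg/median/std, which is the hook for the commit's goal: train one classifier per candidate visual encoder and compare the logged timings alongside accuracy. A hedged sketch using the public wandb API; entity/project are placeholders, and the exact metric name depends on how Logger.log_dict prefixes eval-mode keys (the "eval/" prefix below is an assumption):

import wandb

# Assumed metric name; adjust to whatever prefix Logger.log_dict actually applies.
KEY = "eval/inference_time_benchmark_avg"

api = wandb.Api()
timings = {}
for run in api.runs("my-entity/my-project"):  # placeholder entity/project
    avg_us = run.summary.get(KEY)
    if avg_us is not None:
        timings[run.name] = avg_us

# Fastest encoder first.
for name, avg_us in sorted(timings.items(), key=lambda kv: kv[1]):
    print(f"{name}: {avg_us:.1f} us")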
