[HIL-SERL port] Add Reward classifier benchmark tracking to choose best visual encoder #688

Merged
2 changes: 2 additions & 0 deletions lerobot/configs/policy/hilserl_classifier.yaml
@@ -26,6 +26,8 @@ training:
save_checkpoint: true
image_keys: ["observation.images.top", "observation.images.wrist"]
label_key: "next.reward"
profile_inference_time: false
profile_inference_time_iters: 20

eval:
batch_size: 16
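The two keys added above gate the new benchmark: `profile_inference_time` switches per-batch inference timing on during evaluation, and `profile_inference_time_iters` sets how many single-sample iterations the standalone benchmark runs after training. A minimal sketch of how the flags are consumed (using OmegaConf directly rather than the full Hydra entry point; the values here are illustrative, not part of this PR):

```python
# Minimal sketch: the two new training keys as plain OmegaConf config.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "training": {
            "profile_inference_time": False,  # default in hilserl_classifier.yaml
            "profile_inference_time_iters": 20,
        }
    }
)

if cfg.training.profile_inference_time:
    print(f"Benchmarking inference over {cfg.training.profile_inference_time_iters} iterations")
else:
    print("Inference-time profiling disabled")
```

In the actual script, Hydra builds `cfg` from `hilserl_classifier.yaml`, so the same dotted access (`cfg.training.profile_inference_time`) is used inside `validate` and `benchmark_inference_time`.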
89 changes: 86 additions & 3 deletions lerobot/scripts/train_hilserl_classifier.py
@@ -20,17 +20,19 @@
from pprint import pformat

import hydra
import numpy as np
import torch
import torch.nn as nn
import wandb
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from termcolor import colored
from torch import optim
from torch.autograd import profiler
from torch.cuda.amp import GradScaler
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
from tqdm import tqdm

import wandb
from lerobot.common.datasets.factory import resolve_delta_timestamps
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.logger import Logger
@@ -124,6 +126,7 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
batch_start_time = time.perf_counter()
samples = []
running_loss = 0
inference_times = []

with (
torch.no_grad(),
@@ -133,7 +136,18 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
labels = batch[cfg.training.label_key].float().to(device)

outputs = model(images)
if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
with (
profiler.profile(record_shapes=True) as prof,
profiler.record_function("model_inference"),
):
outputs = model(images)
inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)
else:
outputs = model(images)

loss = criterion(outputs.logits, labels)

# Track metrics
@@ -177,9 +191,76 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
else None,
}

if len(inference_times) > 0:
eval_info["inference_time_avg"] = np.mean(inference_times)
eval_info["inference_time_median"] = np.median(inference_times)
eval_info["inference_time_std"] = np.std(inference_times)
eval_info["inference_time_batch_size"] = val_loader.batch_size

print(
f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
)

return accuracy, eval_info
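The per-batch path above relies on `torch.autograd.profiler`: the forward pass is wrapped in a named `record_function`, and its averaged CPU time (reported in microseconds) is read back from `key_averages()`. A minimal, self-contained sketch of that pattern, with a dummy model standing in for the classifier (shapes and names here are illustrative only):

```python
# Sketch of the timing pattern used in validate(): profile a forward pass
# under a named record_function and extract its cpu_time (microseconds).
import torch
from torch.autograd import profiler

model = torch.nn.Linear(8, 2)
x = torch.randn(4, 8)

with torch.no_grad():
    with (
        profiler.profile(record_shapes=True) as prof,
        profiler.record_function("model_inference"),
    ):
        _ = model(x)

event = next(e for e in prof.key_averages() if e.key == "model_inference")
print(f"model_inference: {event.cpu_time:.2f} us")
```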


def benchmark_inference_time(model, dataset, logger, cfg, device, step):
if not cfg.training.profile_inference_time:
return

iters = cfg.training.profile_inference_time_iters
inference_times = []

loader = DataLoader(
dataset,
batch_size=1,
num_workers=cfg.training.num_workers,
sampler=RandomSampler(dataset),
pin_memory=True,
)

model.eval()
with torch.no_grad():
for _ in tqdm(range(iters), desc="Benchmarking inference time"):
x = next(iter(loader))
x = [x[img_key].to(device) for img_key in cfg.training.image_keys]

# Warm up
for _ in range(10):
_ = model(x)

# sync the device
if device.type == "cuda":
torch.cuda.synchronize()
elif device.type == "mps":
torch.mps.synchronize()

with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
_ = model(x)

inference_times.append(
next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
)

inference_times = np.array(inference_times)
avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
print(
f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
)
if logger._cfg.wandb.enable:
logger.log_dict(
{
"inference_time_benchmark_avg": avg,
"inference_time_benchmark_median": median,
"inference_time_benchmark_std": std,
},
step + 1,
mode="eval",
)

return avg, median, std
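Before each timed pass, `benchmark_inference_time` warms the model up and synchronizes the device so queued CUDA/MPS work does not bleed into the measurement. The sketch below isolates just that warm-up/synchronize step with a dummy model; unlike the PR, it times with `time.perf_counter()` rather than the autograd profiler, purely to keep the example short:

```python
# Sketch of the warm-up + synchronize pattern from benchmark_inference_time(),
# timed with perf_counter for brevity (the PR uses torch.autograd.profiler).
import time

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = torch.nn.Linear(8, 2).to(device).eval()
x = torch.randn(1, 8, device=device)

def sync() -> None:
    # Drain queued device work so the timing covers only the forward pass.
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()

with torch.no_grad():
    for _ in range(10):  # warm-up: exclude one-time costs (cache fills, kernel setup)
        _ = model(x)
    sync()

    start = time.perf_counter()
    _ = model(x)
    sync()
    elapsed_us = (time.perf_counter() - start) * 1e6

print(f"single forward pass: {elapsed_us:.2f} us on {device.type}")
```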


@hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
def train(cfg: DictConfig) -> None:
# Main training pipeline with support for resuming training
@@ -313,6 +394,8 @@ def train(cfg: DictConfig) -> None:

step += len(train_loader)

benchmark_inference_time(model, dataset, logger, cfg, device, step)

logging.info("Training completed")

