@@ -20,17 +20,19 @@
 from pprint import pformat
 
 import hydra
+import numpy as np
 import torch
 import torch.nn as nn
-import wandb
 from deepdiff import DeepDiff
 from omegaconf import DictConfig, OmegaConf
 from termcolor import colored
 from torch import optim
+from torch.autograd import profiler
 from torch.cuda.amp import GradScaler
-from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
+from torch.utils.data import DataLoader, RandomSampler, WeightedRandomSampler, random_split
 from tqdm import tqdm
 
+import wandb
 from lerobot.common.datasets.factory import resolve_delta_timestamps
 from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
 from lerobot.common.logger import Logger
@@ -124,6 +126,8 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
     batch_start_time = time.perf_counter()
     samples = []
     running_loss = 0
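+    # Per-batch forward-pass times (in us), recorded by the optional profiler below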
+    inference_times = []
 
     with (
         torch.no_grad(),
@@ -133,7 +137,21 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
             images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
             labels = batch[cfg.training.label_key].float().to(device)
 
-            outputs = model(images)
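+            # Optionally wrap the forward pass in the autograd profiler; the
+            # average CPU time (in us) of the "model_inference" range is read
+            # back from prof.key_averages() and collected for the eval summary.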
+            if cfg.training.profile_inference_time and logger._cfg.wandb.enable:
+                with (
+                    profiler.profile(record_shapes=True) as prof,
+                    profiler.record_function("model_inference"),
+                ):
+                    outputs = model(images)
+                inference_times.append(
+                    next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+                )
+            else:
+                outputs = model(images)
+
             loss = criterion(outputs.logits, labels)
 
             # Track metrics
@@ -177,9 +195,81 @@ def validate(model, val_loader, criterion, device, logger, cfg, num_samples_to_l
         else None,
     }
 
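+    # Fold the collected profiler timings into the eval log (values are in us)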
+    if len(inference_times) > 0:
+        eval_info["inference_time_avg"] = np.mean(inference_times)
+        eval_info["inference_time_median"] = np.median(inference_times)
+        eval_info["inference_time_std"] = np.std(inference_times)
+        eval_info["inference_time_batch_size"] = val_loader.batch_size
+
+        print(
+            f"Inference mean time: {eval_info['inference_time_avg']:.2f} us, median: {eval_info['inference_time_median']:.2f} us, std: {eval_info['inference_time_std']:.2f} us, with {len(inference_times)} iterations on {device.type} device, batch size: {eval_info['inference_time_batch_size']}"
+        )
+
     return accuracy, eval_info
 
 
+def benchmark_inference_time(model, dataset, logger, cfg, device, step):
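+    # Benchmark single-sample (batch_size=1) forward-pass latency over randomly
+    # drawn inputs, printing and, when wandb is enabled, logging mean/median/std.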
+    if not cfg.training.profile_inference_time:
+        return
+
+    iters = cfg.training.profile_inference_time_iters
+    inference_times = []
+
+    loader = DataLoader(
+        dataset,
+        batch_size=1,
+        num_workers=cfg.training.num_workers,
+        sampler=RandomSampler(dataset),
+        pin_memory=True,
+    )
+
+    model.eval()
+    with torch.no_grad():
+        for _ in tqdm(range(iters), desc="Benchmarking inference time"):
+            batch = next(iter(loader))
+            images = [batch[img_key].to(device) for img_key in cfg.training.image_keys]
+
+            # Warm up so one-time setup costs do not skew the measurement
+            for _ in range(10):
+                _ = model(images)
+
+            # Sync the device so queued kernels finish before timing starts
+            if device.type == "cuda":
+                torch.cuda.synchronize()
+            elif device.type == "mps":
+                torch.mps.synchronize()
+
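+            # Time only the forward pass; cpu_time from key_averages() is the
+            # average CPU time (in us) of the "model_inference" range below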
+            with profiler.profile(record_shapes=True) as prof, profiler.record_function("model_inference"):
+                _ = model(images)
+
+            inference_times.append(
+                next(x for x in prof.key_averages() if x.key == "model_inference").cpu_time
+            )
+
+    inference_times = np.array(inference_times)
+    avg, median, std = inference_times.mean(), np.median(inference_times), inference_times.std()
+    print(
+        f"Inference time mean: {avg:.2f} us, median: {median:.2f} us, std: {std:.2f} us, with {iters} iterations on {device.type} device"
+    )
+    if logger._cfg.wandb.enable:
+        logger.log_dict(
+            {
+                "inference_time_benchmark_avg": avg,
+                "inference_time_benchmark_median": median,
+                "inference_time_benchmark_std": std,
+            },
+            step + 1,
+            mode="eval",
+        )
+
+    return avg, median, std
+
+
 @hydra.main(version_base="1.2", config_path="../configs/policy", config_name="hilserl_classifier")
 def train(cfg: DictConfig) -> None:
     # Main training pipeline with support for resuming training
@@ -313,6 +403,9 @@ def train(cfg: DictConfig) -> None:
 
         step += len(train_loader)
 
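+    # Run the one-off latency benchmark once training is done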
+    benchmark_inference_time(model, dataset, logger, cfg, device, step)
+
     logging.info("Training completed")
 
 