@@ -283,6 +283,75 @@ def test_correctness_of_density_estimator_log_prob(
283
283
assert torch .allclose (log_probs [0 , :], log_probs [1 , :])
284
284
285
285
286
+ @pytest .mark .parametrize (
287
+ "density_estimator_build_fn" ,
288
+ (
289
+ build_mdn ,
290
+ build_maf ,
291
+ build_maf_rqs ,
292
+ build_nsf ,
293
+ build_zuko_bpf ,
294
+ build_zuko_gf ,
295
+ build_zuko_maf ,
296
+ build_zuko_naf ,
297
+ build_zuko_ncsf ,
298
+ build_zuko_nice ,
299
+ build_zuko_nsf ,
300
+ build_zuko_sospf ,
301
+ build_zuko_unaf ,
302
+ build_categoricalmassestimator ,
303
+ build_mnle ,
304
+ ),
305
+ )
306
+ @pytest .mark .parametrize ("input_event_shape" , ((1 ,), (4 ,)))
307
+ @pytest .mark .parametrize ("condition_event_shape" , ((1 ,), (7 ,)))
308
+ def test_correctness_of_batched_vs_seperate_sample_and_log_prob (
309
+ density_estimator_build_fn , input_event_shape , condition_event_shape
310
+ ):
311
+ input_sample_dim = 2
312
+ batch_dim = 2
313
+ density_estimator , inputs , condition = _build_density_estimator_and_tensors (
314
+ density_estimator_build_fn ,
315
+ input_event_shape ,
316
+ condition_event_shape ,
317
+ batch_dim ,
318
+ input_sample_dim ,
319
+ )
320
+ # Batched vs separate sampling
321
+ samples = density_estimator .sample ((1000 ,), condition = condition )
322
+ samples_separate1 = density_estimator .sample (
323
+ (1000 ,), condition = condition [0 ][None , ...]
324
+ )
325
+ samples_separate2 = density_estimator .sample (
326
+ (1000 ,), condition = condition [1 ][None , ...]
327
+ )
328
+
329
+ # Check if means are approx. same
330
+ samples_m = torch .mean (samples , dim = 0 , dtype = torch .float32 )
331
+ samples_separate1_m = torch .mean (samples_separate1 , dim = 0 , dtype = torch .float32 )
332
+ samples_separate2_m = torch .mean (samples_separate2 , dim = 0 , dtype = torch .float32 )
333
+ samples_sep_m = torch .cat ([samples_separate1_m , samples_separate2_m ], dim = 0 )
334
+
335
+ assert torch .allclose (
336
+ samples_m , samples_sep_m , atol = 0.5 , rtol = 0.5
337
+ ), "Batched sampling is not consistent with separate sampling."
338
+
339
+ # Batched vs separate log_prob
340
+ log_probs = density_estimator .log_prob (inputs , condition = condition )
341
+
342
+ log_probs_separate1 = density_estimator .log_prob (
343
+ inputs [:, :1 ], condition = condition [0 ][None , ...]
344
+ )
345
+ log_probs_separate2 = density_estimator .log_prob (
346
+ inputs [:, 1 :], condition = condition [1 ][None , ...]
347
+ )
348
+ log_probs_sep = torch .hstack ([log_probs_separate1 , log_probs_separate2 ])
349
+
350
+ assert torch .allclose (
351
+ log_probs , log_probs_sep , atol = 1e-2 , rtol = 1e-2
352
+ ), "Batched log_prob is not consistent with separate log_prob."
353
+
354
+
286
355
def _build_density_estimator_and_tensors (
287
356
density_estimator_build_fn : str ,
288
357
input_event_shape : Tuple [int ],
0 commit comments