Skip to content

Commit fe55b1c

Browse files
Fixes current density estimator bug (#1155)
* fixes current bug! * Added tests
1 parent 9a8c7c0 commit fe55b1c

File tree

2 files changed

+71
-6
lines changed

2 files changed

+71
-6
lines changed

sbi/neural_nets/density_estimators/nflows_flow.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -135,12 +135,8 @@ def sample(self, sample_shape: Shape, condition: Tensor) -> Tensor:
135135
num_samples = torch.Size(sample_shape).numel()
136136

137137
samples = self.net.sample(num_samples, context=condition)
138-
139-
return samples.reshape((
140-
*sample_shape,
141-
condition_batch_dim,
142-
-1,
143-
))
138+
samples = samples.transpose(0, 1)
139+
return samples.reshape((*sample_shape, condition_batch_dim, *self.input_shape))
144140

145141
def sample_and_log_prob(
146142
self, sample_shape: torch.Size, condition: Tensor, **kwargs

tests/density_estimator_test.py

+69
Original file line numberDiff line numberDiff line change
@@ -283,6 +283,75 @@ def test_correctness_of_density_estimator_log_prob(
283283
assert torch.allclose(log_probs[0, :], log_probs[1, :])
284284

285285

286+
@pytest.mark.parametrize(
287+
"density_estimator_build_fn",
288+
(
289+
build_mdn,
290+
build_maf,
291+
build_maf_rqs,
292+
build_nsf,
293+
build_zuko_bpf,
294+
build_zuko_gf,
295+
build_zuko_maf,
296+
build_zuko_naf,
297+
build_zuko_ncsf,
298+
build_zuko_nice,
299+
build_zuko_nsf,
300+
build_zuko_sospf,
301+
build_zuko_unaf,
302+
build_categoricalmassestimator,
303+
build_mnle,
304+
),
305+
)
306+
@pytest.mark.parametrize("input_event_shape", ((1,), (4,)))
307+
@pytest.mark.parametrize("condition_event_shape", ((1,), (7,)))
308+
def test_correctness_of_batched_vs_seperate_sample_and_log_prob(
309+
density_estimator_build_fn, input_event_shape, condition_event_shape
310+
):
311+
input_sample_dim = 2
312+
batch_dim = 2
313+
density_estimator, inputs, condition = _build_density_estimator_and_tensors(
314+
density_estimator_build_fn,
315+
input_event_shape,
316+
condition_event_shape,
317+
batch_dim,
318+
input_sample_dim,
319+
)
320+
# Batched vs separate sampling
321+
samples = density_estimator.sample((1000,), condition=condition)
322+
samples_separate1 = density_estimator.sample(
323+
(1000,), condition=condition[0][None, ...]
324+
)
325+
samples_separate2 = density_estimator.sample(
326+
(1000,), condition=condition[1][None, ...]
327+
)
328+
329+
# Check if means are approx. same
330+
samples_m = torch.mean(samples, dim=0, dtype=torch.float32)
331+
samples_separate1_m = torch.mean(samples_separate1, dim=0, dtype=torch.float32)
332+
samples_separate2_m = torch.mean(samples_separate2, dim=0, dtype=torch.float32)
333+
samples_sep_m = torch.cat([samples_separate1_m, samples_separate2_m], dim=0)
334+
335+
assert torch.allclose(
336+
samples_m, samples_sep_m, atol=0.5, rtol=0.5
337+
), "Batched sampling is not consistent with separate sampling."
338+
339+
# Batched vs separate log_prob
340+
log_probs = density_estimator.log_prob(inputs, condition=condition)
341+
342+
log_probs_separate1 = density_estimator.log_prob(
343+
inputs[:, :1], condition=condition[0][None, ...]
344+
)
345+
log_probs_separate2 = density_estimator.log_prob(
346+
inputs[:, 1:], condition=condition[1][None, ...]
347+
)
348+
log_probs_sep = torch.hstack([log_probs_separate1, log_probs_separate2])
349+
350+
assert torch.allclose(
351+
log_probs, log_probs_sep, atol=1e-2, rtol=1e-2
352+
), "Batched log_prob is not consistent with separate log_prob."
353+
354+
286355
def _build_density_estimator_and_tensors(
287356
density_estimator_build_fn: str,
288357
input_event_shape: Tuple[int],

0 commit comments

Comments
 (0)