fix #35: fix batched sampling, fix tests. #36

Merged · 4 commits · Aug 20, 2024
Changes from 2 commits
54 changes: 13 additions & 41 deletions pyknos/mdn/mdn.py
@@ -10,7 +10,6 @@

import numpy as np
import torch
-from nflows.utils import torchutils
from torch import Tensor, nn
from torch.nn import functional as F

@@ -34,8 +33,8 @@ def __init__(
        hidden_net: nn.Module,
        num_components: int,
        hidden_features: Optional[int],
-       custom_initialization=False,
-       embedding_net=None,
+       custom_initialization: bool = False,
+       embedding_net: Optional[nn.Module] = None,
    ):
        """Mixture of multivariate Gaussians with full (rather than diagonal) covariance.

@@ -46,7 +45,8 @@ def __init__(
            hidden_net: A Module which outputs the final hidden representation before
                the parameterization layers (i.e. logits, means, and log precisions).
            num_components: Number of mixture components.
-           custom_initialization: XXX
+           custom_initialization: If True, initialize mixture coefficients to be
+               approximately uniform and covariances to be approximately the identity.
        """

        # Infer hidden_features from hidden_net if not provided.
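A side note on the new `custom_initialization` docstring: it matches what the epsilon-scaled initialization in `_initialize` (further down, together with `self._epsilon = 1e-4`) actually does, since the softmax of near-zero logits is close to uniform. A minimal illustration, independent of pyknos:

```python
# Illustration only: epsilon-scaled logits give near-uniform mixture
# coefficients after the softmax, as the new docstring describes.
import torch

num_components = 4
epsilon = 1e-4  # matches self._epsilon in mdn.py
logits = epsilon * torch.randn(num_components)
coefficients = torch.softmax(logits, dim=-1)

assert torch.allclose(
    coefficients, torch.full((num_components,), 1 / num_components), atol=1e-3
)
```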
@@ -97,12 +97,11 @@ def __init__(
            hidden_features, num_components * self._num_upper_params
        )

-       # XXX docstring text
-       # embedding_net: NOT IMPLEMENTED
-       # A `nn.Module` which has trainable parameters to encode the
-       # context (conditioning). It is trained jointly with the MDN.
        if embedding_net is not None:
-           raise NotImplementedError
+           raise NotImplementedError(
+               "embedding net is not implemented yet. We recommend using MDN in "
+               "conjunction with nflows as done in the sbi package."
+           )

        # Constant for numerical stability.
        self._epsilon = 1e-4
@@ -277,24 +276,18 @@ def sample_mog(
        Returns:
            Tensor: Samples from the MoG.
        """
-       batch_size, n_mixtures, output_dim = means.shape
-
-       # We need (batch_size * num_samples) samples in total.
-       means, precision_factors = (
-           torchutils.repeat_rows(means, num_samples),
-           torchutils.repeat_rows(precision_factors, num_samples),
-       )
+       batch_size, _, output_dim = means.shape

        # Normalize the logits for the coefficients.
        coefficients = F.softmax(logits, dim=-1)  # [batch_size, num_components]

        # Choose num_samples mixture components per example in the batch.
        choices = torch.multinomial(
            coefficients, num_samples=num_samples, replacement=True
-       ).view(-1)  # [batch_size, num_samples]
+       )  # [batch_size, num_samples]

        # Create dummy index for indexing means and precision factors.
-       ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples)
+       ix = torch.arange(batch_size).unsqueeze(1).expand(batch_size, num_samples)

        # Select means and precision factors.
        chosen_means = means[ix, choices, :]
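This is the heart of the fix: instead of flattening everything with `torchutils.repeat_rows` and losing the batch dimension, the per-sample component choices are gathered with batched advanced indexing. A self-contained sketch of the indexing trick (toy shapes assumed):

```python
# Toy shapes, for illustration: gather one mixture component per
# (batch row, sample) pair without flattening the batch dimension.
import torch

batch_size, num_samples, num_components, output_dim = 3, 5, 4, 2
means = torch.randn(batch_size, num_components, output_dim)
coefficients = torch.softmax(torch.randn(batch_size, num_components), dim=-1)

# One component index per (batch row, sample): [batch_size, num_samples].
choices = torch.multinomial(coefficients, num_samples=num_samples, replacement=True)

# Row index of the same shape, with ix[b, s] == b for every s.
ix = torch.arange(batch_size).unsqueeze(1).expand(batch_size, num_samples)

# Advanced indexing pairs ix and choices elementwise.
chosen_means = means[ix, choices, :]
assert chosen_means.shape == (batch_size, num_samples, output_dim)
```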
@@ -305,7 +298,8 @@ def sample_mog(
        zero_mean_samples = torch.linalg.solve_triangular(
            chosen_precision_factors,
            torch.randn(
-               batch_size * num_samples,
+               batch_size,
+               num_samples,
                output_dim,
                1,
                device=chosen_precision_factors.device,
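The noise tensor now keeps separate batch and sample dimensions, and `solve_triangular` broadcasts over them. The reason solving the precision factor against standard normal noise yields the right covariance: if the precision is P = AᵀA with A triangular, then x = A⁻¹z has covariance A⁻¹A⁻ᵀ = (AᵀA)⁻¹ = P⁻¹. A quick numerical check (the factorization convention P = AᵀA is an assumption of this sketch):

```python
import torch

dim = 3
# Well-conditioned upper-triangular factor; assume precision P = A.T @ A.
A = torch.eye(dim) + torch.triu(0.1 * torch.randn(dim, dim))
z = torch.randn(200_000, dim, 1)

# x = A^{-1} z, so Cov(x) = (A.T @ A)^{-1} = P^{-1}.
x = torch.linalg.solve_triangular(A, z, upper=True).squeeze(-1)

empirical_cov = x.T @ x / x.shape[0]
expected_cov = torch.linalg.inv(A.T @ A)
assert torch.allclose(empirical_cov, expected_cov, atol=0.05)
```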
@@ -348,25 +342,3 @@ def _initialize(self) -> None:
        self._upper_layer.bias.data = self._epsilon * torch.randn(
            self._num_components * self._num_upper_params
        )
-
-
-# XXX This -> tests
-def main():
-    # probs = torch.Tensor([[1, 0], [0, 1]])
-    # samples = torch.multinomial(probs, num_samples=5, replacement=True)
-    # print(samples)
-    # quit()
-    mdn = MultivariateGaussianMDN(
-        features=2,
-        context_features=3,
-        hidden_features=16,
-        hidden_net=nn.Linear(3, 16),
-        num_components=4,
-    )
-    inputs = torch.randn(1, 3)
-    samples = mdn.sample(9, inputs)
-    print(samples.shape)
-
-
-if __name__ == "__main__":
-    main()
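The ad-hoc `main()` smoke test is removed here in favor of the proper tests below. For reference, a minimal standalone sketch of the batched behavior this PR fixes, reusing the constructor call from the deleted `main()` (the expected output shape follows the new test's assertion):

```python
import torch
from torch import nn
from pyknos.mdn.mdn import MultivariateGaussianMDN

mdn = MultivariateGaussianMDN(
    features=2,
    context_features=3,
    hidden_features=16,
    hidden_net=nn.Linear(3, 16),
    num_components=4,
)

context = torch.randn(5, 3)       # batch of 5 conditioning inputs
samples = mdn.sample(9, context)  # 9 draws per context row
print(samples.shape)              # torch.Size([5, 9, 2]) with this fix
```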
105 changes: 65 additions & 40 deletions tests/mdn_test.py
@@ -3,7 +3,7 @@
import pytest
import torch
import torch.nn as nn
-from torch import Tensor, eye
+from torch import Tensor

from pyknos.mdn.mdn import MultivariateGaussianMDN

@@ -16,44 +16,69 @@ def linear_gaussian(
    return likelihood_shift + theta + torch.mm(chol_factor, torch.randn_like(theta).T).T


-@pytest.mark.parametrize("dim", ([1, 5, 10]))
-@pytest.mark.parametrize("device", ("cpu", "cuda:0"))
-@pytest.mark.parametrize("hidden_features", (50, None))
+def get_mdn(
+    features: int,
+    context_features: int,
+    num_components: int = 10,
+    hidden_features: Optional[int] = None,
+) -> MultivariateGaussianMDN:
+    if hidden_features is None:
+        hidden_features = 50
+    return MultivariateGaussianMDN(
+        features=features,
+        context_features=context_features,
+        hidden_features=hidden_features,
+        hidden_net=nn.Sequential(
+            nn.Linear(context_features, hidden_features),
+            nn.ReLU(),
+            nn.Linear(hidden_features, hidden_features),
+            nn.ReLU(),
+        ),
+        num_components=num_components,
+        custom_initialization=True,
+    )
+
+
+@pytest.mark.parametrize("dim_input", ([1, 2]))
+@pytest.mark.parametrize("dim_context", ([1, 2]))
+@pytest.mark.parametrize(
+    "device",
+    (
+        "cpu",
+        pytest.param(
+            "cuda:0",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ),
+)
+@pytest.mark.parametrize("hidden_features", (10, None))
def test_mdn_for_diff_dimension_data(
-    dim: int, device: str, hidden_features: Optional[int], num_components: int = 10
+    dim_input: int,
+    dim_context: int,
+    device: str,
+    hidden_features: Optional[int],
+    num_components: int = 2,
) -> None:
-    if device == "cuda:0" and not torch.cuda.is_available():
-        pass
-    else:
-        theta = torch.rand(3, dim)
-        likelihood_shift = torch.zeros(theta.shape)
-        likelihood_cov = eye(dim)
-        context = linear_gaussian(theta, likelihood_shift, likelihood_cov)
-
-        x_numel = theta[0].numel()
-        y_numel = context[0].numel()
-
-        net_features = hidden_features if hidden_features is not None else 50
-        distribution = MultivariateGaussianMDN(
-            features=x_numel,
-            context_features=y_numel,
-            hidden_features=hidden_features,
-            hidden_net=nn.Sequential(
-                nn.Linear(y_numel, net_features),
-                nn.ReLU(),
-                nn.Linear(net_features, net_features),
-                nn.ReLU(),
-            ),
-            num_components=num_components,
-            custom_initialization=True,
-        )
-        distribution = distribution.to(device)
-
-        logits, means, precisions, _, _ = distribution.get_mixture_components(
-            theta.to(device)
-        )
-
-        # Test evaluation and sampling.
-        distribution.log_prob(context.to(device), theta.to(device))
-        distribution.sample(100, theta.to(device))
-        distribution.sample_mog(10, logits, means, precisions)
+    num_samples = 5
+    num_context = 1
+    context = torch.randn(num_context, dim_context)
+
+    net_features = hidden_features if hidden_features is not None else 20
+    distribution = get_mdn(
+        features=dim_input,
+        context_features=dim_context,
+        num_components=num_components,
+        hidden_features=net_features,
+    )
+    distribution = distribution.to(device)
+
+    # Test evaluation and sampling.
+    samples = distribution.sample(num_samples, context=context.to(device))
+    assert samples.shape == (num_context, num_samples, dim_input)
+
+    log_probs = distribution.log_prob(
+        samples.squeeze(0), context.to(device).repeat(num_samples, 1)
+    )
+    assert log_probs.shape == (num_samples,)
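One note on the `log_prob` call above: it flattens the sample batch and repeats the context so that both leading dimensions match. With `num_context == 1`, `.repeat(num_samples, 1)` works; for several contexts, each context has to be repeated alongside its own samples. A sketch of the general pattern (shapes assumed):

```python
import torch

num_context, num_samples, dim_input, dim_context = 2, 5, 3, 4
samples = torch.randn(num_context, num_samples, dim_input)
context = torch.randn(num_context, dim_context)

# Flatten samples; repeat_interleave keeps each context aligned with its
# own block of samples (for num_context == 1 this equals .repeat()).
flat_samples = samples.reshape(num_context * num_samples, dim_input)
flat_context = context.repeat_interleave(num_samples, dim=0)

assert flat_samples.shape[0] == flat_context.shape[0]
# distribution.log_prob(flat_samples, flat_context) would then return a
# tensor of shape (num_context * num_samples,).
```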