fix #35: fix batched sampling, fix tests. #36

Merged
merged 4 commits into from Aug 20, 2024
Changes from 1 commit
fix #35: fix batched sampling, fix tests.
janfb committed Aug 19, 2024
commit 1c8251a9689ee10a21bdf028c1097c9498d7a00c
38 changes: 5 additions & 33 deletions pyknos/mdn/mdn.py
@@ -10,7 +10,6 @@
 
 import numpy as np
 import torch
-from nflows.utils import torchutils
 from torch import Tensor, nn
 from torch.nn import functional as F
 
@@ -277,24 +276,18 @@ def sample_mog(
         Returns:
             Tensor: Samples from the MoG.
         """
-        batch_size, n_mixtures, output_dim = means.shape
-
-        # We need (batch_size * num_samples) samples in total.
-        means, precision_factors = (
-            torchutils.repeat_rows(means, num_samples),
-            torchutils.repeat_rows(precision_factors, num_samples),
-        )
+        batch_size, _, output_dim = means.shape
 
         # Normalize the logits for the coefficients.
         coefficients = F.softmax(logits, dim=-1)  # [batch_size, num_components]
 
         # Choose num_samples mixture components per example in the batch.
         choices = torch.multinomial(
             coefficients, num_samples=num_samples, replacement=True
-        ).view(-1)  # [batch_size, num_samples]
+        )  # [batch_size, num_samples]
 
         # Create dummy index for indexing means and precision factors.
-        ix = torchutils.repeat_rows(torch.arange(batch_size), num_samples)
+        ix = torch.arange(batch_size).unsqueeze(1).expand(batch_size, num_samples)
 
         # Select means and precision factors.
         chosen_means = means[ix, choices, :]
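
For reference, the batched component selection above can be sketched in isolation. This is a minimal, self-contained illustration with made-up shapes, not code from the PR; it shows how the two [batch_size, num_samples] index tensors from torch.multinomial and torch.arange(...).expand(...) select per-sample components without the old .view(-1) flattening:

import torch

batch_size, num_samples, num_components, output_dim = 3, 5, 4, 2

# Dummy mixture parameters, one mixture per batch entry.
logits = torch.randn(batch_size, num_components)
means = torch.randn(batch_size, num_components, output_dim)

coefficients = torch.softmax(logits, dim=-1)  # [batch_size, num_components]

# One component index per (batch entry, sample); stays batched.
choices = torch.multinomial(coefficients, num_samples=num_samples, replacement=True)

# Matching row index so each sample indexes its own batch entry.
ix = torch.arange(batch_size).unsqueeze(1).expand(batch_size, num_samples)

# Advanced indexing broadcasts the two [batch_size, num_samples] index
# tensors, yielding [batch_size, num_samples, output_dim].
chosen_means = means[ix, choices, :]
assert chosen_means.shape == (batch_size, num_samples, output_dim)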
@@ -305,7 +298,8 @@ def sample_mog(
         zero_mean_samples = torch.linalg.solve_triangular(
             chosen_precision_factors,
             torch.randn(
-                batch_size * num_samples,
+                batch_size,
+                num_samples,
                 output_dim,
                 1,
                 device=chosen_precision_factors.device,
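
The reshaped noise keeps the triangular solve batched over [batch_size, num_samples] instead of one flattened leading dimension. As a standalone check of the underlying transform (my own example, assuming the precision factors are upper-triangular Cholesky factors U with P = U.T @ U, as the surrounding code suggests): solving U x = eps for standard-normal eps yields samples with covariance P^-1.

import torch

# Build a well-conditioned precision matrix P and its upper factor U.
D = 3
A = torch.randn(D, D)
P = A @ A.T + D * torch.eye(D)
U = torch.linalg.cholesky(P, upper=True)  # P = U.T @ U

# Solve U x = eps for many standard-normal eps columns;
# x = U^-1 eps then has covariance (U.T @ U)^-1 = P^-1.
eps = torch.randn(D, 200_000)
x = torch.linalg.solve_triangular(U, eps, upper=True)

empirical_cov = x @ x.T / eps.shape[1]
assert torch.allclose(empirical_cov, torch.linalg.inv(P), atol=0.05)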
@@ -348,25 +342,3 @@ def _initialize(self) -> None:
         self._upper_layer.bias.data = self._epsilon * torch.randn(
             self._num_components * self._num_upper_params
         )
-
-
-# XXX This -> tests
-def main():
-    # probs = torch.Tensor([[1, 0], [0, 1]])
-    # samples = torch.multinomial(probs, num_samples=5, replacement=True)
-    # print(samples)
-    # quit()
-    mdn = MultivariateGaussianMDN(
-        features=2,
-        context_features=3,
-        hidden_features=16,
-        hidden_net=nn.Linear(3, 16),
-        num_components=4,
-    )
-    inputs = torch.randn(1, 3)
-    samples = mdn.sample(9, inputs)
-    print(samples.shape)
-
-
-if __name__ == "__main__":
-    main()
105 changes: 65 additions & 40 deletions tests/mdn_test.py
@@ -3,7 +3,7 @@
 import pytest
 import torch
 import torch.nn as nn
-from torch import Tensor, eye
+from torch import Tensor
 
 from pyknos.mdn.mdn import MultivariateGaussianMDN

@@ -16,44 +16,69 @@ def linear_gaussian(
     return likelihood_shift + theta + torch.mm(chol_factor, torch.randn_like(theta).T).T
 
 
-@pytest.mark.parametrize("dim", ([1, 5, 10]))
-@pytest.mark.parametrize("device", ("cpu", "cuda:0"))
-@pytest.mark.parametrize("hidden_features", (50, None))
+def get_mdn(
+    features: int,
+    context_features: int,
+    num_components: int = 10,
+    hidden_features: Optional[int] = None,
+) -> MultivariateGaussianMDN:
+    if hidden_features is None:
+        hidden_features = 50
+    return MultivariateGaussianMDN(
+        features=features,
+        context_features=context_features,
+        hidden_features=hidden_features,
+        hidden_net=nn.Sequential(
+            nn.Linear(context_features, hidden_features),
+            nn.ReLU(),
+            nn.Linear(hidden_features, hidden_features),
+            nn.ReLU(),
+        ),
+        num_components=num_components,
+        custom_initialization=True,
+    )
+
+
+@pytest.mark.parametrize("dim_input", ([1, 2]))
+@pytest.mark.parametrize("dim_context", ([1, 2]))
+@pytest.mark.parametrize(
+    "device",
+    (
+        "cpu",
+        pytest.param(
+            "cuda:0",
+            marks=pytest.mark.skipif(
+                not torch.cuda.is_available(), reason="CUDA not available"
+            ),
+        ),
+    ),
+)
+@pytest.mark.parametrize("hidden_features", (10, None))
 def test_mdn_for_diff_dimension_data(
-    dim: int, device: str, hidden_features: Optional[int], num_components: int = 10
+    dim_input: int,
+    dim_context: int,
+    device: str,
+    hidden_features: Optional[int],
+    num_components: int = 2,
 ) -> None:
-    if device == "cuda:0" and not torch.cuda.is_available():
-        pass
-    else:
-        theta = torch.rand(3, dim)
-        likelihood_shift = torch.zeros(theta.shape)
-        likelihood_cov = eye(dim)
-        context = linear_gaussian(theta, likelihood_shift, likelihood_cov)
-
-        x_numel = theta[0].numel()
-        y_numel = context[0].numel()
-
-        net_features = hidden_features if hidden_features is not None else 50
-        distribution = MultivariateGaussianMDN(
-            features=x_numel,
-            context_features=y_numel,
-            hidden_features=hidden_features,
-            hidden_net=nn.Sequential(
-                nn.Linear(y_numel, net_features),
-                nn.ReLU(),
-                nn.Linear(net_features, net_features),
-                nn.ReLU(),
-            ),
-            num_components=num_components,
-            custom_initialization=True,
-        )
-        distribution = distribution.to(device)
-
-        logits, means, precisions, _, _ = distribution.get_mixture_components(
-            theta.to(device)
-        )
-
-        # Test evaluation and sampling.
-        distribution.log_prob(context.to(device), theta.to(device))
-        distribution.sample(100, theta.to(device))
-        distribution.sample_mog(10, logits, means, precisions)
+    num_samples = 5
+    num_context = 1
+    context = torch.randn(num_context, dim_context)
+
+    net_features = hidden_features if hidden_features is not None else 20
+    distribution = get_mdn(
+        features=dim_input,
+        context_features=dim_context,
+        num_components=num_components,
+        hidden_features=net_features,
+    )
+    distribution = distribution.to(device)
+
+    # Test evaluation and sampling.
+    samples = distribution.sample(num_samples, context=context.to(device))
+    assert samples.shape == (num_context, num_samples, dim_input)
+
+    log_probs = distribution.log_prob(
+        samples.squeeze(0), context.to(device).repeat(num_samples, 1)
+    )
+    assert log_probs.shape == (num_samples,)
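
Taken together with the main() function deleted from pyknos/mdn/mdn.py above, a minimal usage sketch of the fixed batched sampling looks as follows. The hyperparameters are the same toy values as in the deleted main(); the shape comment follows the assertion in the new test:

import torch
import torch.nn as nn

from pyknos.mdn.mdn import MultivariateGaussianMDN

mdn = MultivariateGaussianMDN(
    features=2,
    context_features=3,
    hidden_features=16,
    hidden_net=nn.Linear(3, 16),
    num_components=4,
)

context = torch.randn(1, 3)       # one context row
samples = mdn.sample(9, context)  # [num_context, num_samples, features] == [1, 9, 2]
print(samples.shape)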