
Commit

Merge pull request #90 from masa-su/fix_init_arg_error
Explicit init arg of exponential distributions
masa-su authored Nov 19, 2019
2 parents 5dc7c4c + 20d581e commit 74fcb2c
Showing 1 changed file with 35 additions and 8 deletions.
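As a quick illustration of what the diff below enables, distribution parameters such as loc, scale, and probs can now be passed to the constructors as explicit keyword arguments instead of being collected through **kwargs. A minimal usage sketch (values and shapes are illustrative, not taken from the PR):

import torch
from pixyz.distributions import Normal, RelaxedBernoulli

# Explicit parameter keywords instead of **kwargs (illustrative values).
p_z = Normal(var=["z"], features_shape=torch.Size([64]),
             loc=torch.tensor(0.0), scale=torch.tensor(1.0))

# temperature now comes after features_shape in the signature, so passing
# it by keyword avoids any positional-order surprises.
q_x = RelaxedBernoulli(var=["x"], temperature=torch.tensor(0.5),
                       probs=torch.tensor(0.3))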
pixyz/distributions/exponential_distributions.py (43 changes: 35 additions & 8 deletions)
@@ -14,8 +14,14 @@
from .distributions import DistributionBase


def _valid_param_dict(raw_dict):
return {var_name: value for var_name, value in raw_dict.items() if value is not None}


class Normal(DistributionBase):
"""Normal distribution parameterized by :attr:`loc` and :attr:`scale`. """
def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), loc=None, scale=None):
super().__init__(cond_var, var, name, features_shape, **_valid_param_dict({'loc': loc, 'scale': scale}))

@property
def params_keys(self):
@@ -32,6 +38,8 @@ def distribution_name(self):

class Bernoulli(DistributionBase):
"""Bernoulli distribution parameterized by :attr:`probs`."""
def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), probs=None):
super().__init__(cond_var, var, name, features_shape, **_valid_param_dict({'probs': probs}))

@property
def params_keys(self):
@@ -49,11 +57,11 @@ def distribution_name(self):
class RelaxedBernoulli(Bernoulli):
"""Relaxed (re-parameterizable) Bernoulli distribution parameterized by :attr:`probs`."""

- def __init__(self, temperature=torch.tensor(0.1), cond_var=[], var=["x"], name="p", features_shape=torch.Size(),
-              **kwargs):
+ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), temperature=torch.tensor(0.1),
+              probs=None):
self._temperature = temperature

- super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape, **kwargs)
+ super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape, probs=probs)

@property
def temperature(self):
@@ -99,6 +107,8 @@ class FactorizedBernoulli(Bernoulli):
[Vedantam+ 2017] Generative Models of Visually Grounded Imagination
"""
def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), probs=None):
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape, probs=probs)

@property
def distribution_name(self):
@@ -114,6 +124,9 @@ def get_log_prob(self, x_dict):

class Categorical(DistributionBase):
"""Categorical distribution parameterized by :attr:`probs`."""
def __init__(self, cond_var=[], var=['x'], name='p', features_shape=torch.Size(), probs=None):
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'probs': probs}))

@property
def params_keys(self):
@@ -131,11 +144,11 @@ def distribution_name(self):
class RelaxedCategorical(Categorical):
"""Relaxed (re-parameterizable) categorical distribution parameterized by :attr:`probs`."""

- def __init__(self, temperature=torch.tensor(0.1), cond_var=[], var=["x"], name="p", features_shape=torch.Size(),
-              **kwargs):
+ def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), temperature=torch.tensor(0.1),
+              probs=None):
self._temperature = temperature

- super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape, **kwargs)
+ super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape, probs=probs)

@property
def temperature(self):
@@ -183,10 +196,11 @@ def sample_variance(self, x_dict={}):
class Multinomial(DistributionBase):
"""Multinomial distribution parameterized by :attr:`total_count` and :attr:`probs`."""

- def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), total_count=1, **kwargs):
+ def __init__(self, total_count=1, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), probs=None):
self._total_count = total_count

- super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape, **kwargs)
+ super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
+                  **_valid_param_dict({'probs': probs}))

@property
def total_count(self):
@@ -207,6 +221,9 @@ def distribution_name(self):

class Dirichlet(DistributionBase):
"""Dirichlet distribution parameterized by :attr:`concentration`."""
def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), concentration=None):
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'concentration': concentration}))

@property
def params_keys(self):
@@ -223,6 +240,10 @@ def distribution_name(self):

class Beta(DistributionBase):
"""Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`."""
def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(),
concentration1=None, concentration0=None):
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'concentration1': concentration1, 'concentration0': concentration0}))

@property
def params_keys(self):
Expand All @@ -241,6 +262,9 @@ class Laplace(DistributionBase):
"""
Laplace distribution parameterized by :attr:`loc` and :attr:`scale`.
"""
def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), loc=None, scale=None):
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'loc': loc, 'scale': scale}))

@property
def params_keys(self):
Expand All @@ -259,6 +283,9 @@ class Gamma(DistributionBase):
"""
Gamma distribution parameterized by :attr:`concentration` and :attr:`rate`.
"""
def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), concentration=None, rate=None):
super().__init__(cond_var=cond_var, var=var, name=name, features_shape=features_shape,
**_valid_param_dict({'concentration': concentration, 'rate': rate}))

@property
def params_keys(self):
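One design note worth calling out: the new _valid_param_dict helper drops any parameter left as None, so only explicitly supplied parameters reach DistributionBase. Distributions whose parameters are produced by a network therefore keep working unchanged: they simply omit the keyword and return the parameter from forward, as in the usual pixyz subclass pattern. A sketch with hypothetical toy sizes:

import torch
from pixyz.distributions import Bernoulli

class Decoder(Bernoulli):
    """p(x|z) whose probs are computed by a small network (toy sizes)."""
    def __init__(self):
        # probs is left unspecified here; _valid_param_dict filters the
        # None out, so nothing is forwarded to DistributionBase.
        super().__init__(cond_var=["z"], var=["x"])
        self.fc = torch.nn.Linear(64, 784)

    def forward(self, z):
        # The parameter is produced at sampling/evaluation time instead.
        return {"probs": torch.sigmoid(self.fc(z))}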
