Skip to content

Commit 20d581e

Browse files
committed
arrange the order of args of RelaxedDist
1 parent 80d0a8e commit 20d581e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pixyz/distributions/exponential_distributions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def distribution_name(self):
5757
class RelaxedBernoulli(Bernoulli):
5858
"""Relaxed (re-parameterizable) Bernoulli distribution parameterized by :attr:`probs`."""
5959

60-
def __init__(self, temperature=torch.tensor(0.1), cond_var=[], var=["x"], name="p", features_shape=torch.Size(),
60+
def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), temperature=torch.tensor(0.1),
6161
probs=None):
6262
self._temperature = temperature
6363

@@ -144,7 +144,7 @@ def distribution_name(self):
144144
class RelaxedCategorical(Categorical):
145145
"""Relaxed (re-parameterizable) categorical distribution parameterized by :attr:`probs`."""
146146

147-
def __init__(self, temperature=torch.tensor(0.1), cond_var=[], var=["x"], name="p", features_shape=torch.Size(),
147+
def __init__(self, cond_var=[], var=["x"], name="p", features_shape=torch.Size(), temperature=torch.tensor(0.1),
148148
probs=None):
149149
self._temperature = temperature
150150

0 commit comments

Comments
 (0)