Skip to content

Commit

Permalink
Merge pull request #135 from masa-su/develop/v0.2.1
Browse files Browse the repository at this point in the history
Develop/v0.2.1
  • Loading branch information
masa-su authored Oct 13, 2020
2 parents 44ea5ca + ca48483 commit d7f126c
Show file tree
Hide file tree
Showing 30 changed files with 442 additions and 61 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ target/

# Jupyter Notebook
.ipynb_checkpoints
data/

# pyenv
.python-version
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

[![pypi](https://img.shields.io/pypi/v/pixyz.svg)](https://pypi.python.org/pypi/pixyz)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
[![Python Version](https://img.shields.io/pypi/pyversions/Django.svg)](https://github.com/masa-su/pixyz)
[![Python Version](https://img.shields.io/badge/python-3.6%20%7C%203.7%20%7C%203.8%20-blue)](https://github.com/masa-su/pixyz)
[![Pytorch Version](https://img.shields.io/badge/pytorch-1.0-yellow.svg)](https://github.com/masa-su/pixyz)
[![Read the Docs](https://readthedocs.org/projects/pixyz/badge/?version=latest)](http://docs.pixyz.io)
[![TravisCI](https://travis-ci.org/masa-su/pixyz.svg?branch=master)](https://github.com/masa-su/pixyz)
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@
]

autodoc_member_order = 'bysource'
autodoc_default_flags = ['show-inheritance']
autodoc_default_options = {'show-inheritance': True}

napoleon_numpy_docstring = True
napoleon_include_init_with_doc = True
9 changes: 8 additions & 1 deletion examples/cvae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'cvae'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"plot_number = 1\n",
"\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/gan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1788,7 +1788,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'gan'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"z_sample = torch.randn(64, z_dim).to(device)\n",
"_x, _y = iter(test_loader).next()\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/glow.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -813,7 +813,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'glow'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"z_sample = torch.randn(64, 3, 32, 32).to(device)\n",
"_x, _ = iter(test_loader).next()\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/hierarchical_variational_inference.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'hierarchical_variational_inference'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"z_sample = 0.5 * torch.randn(64, z_dim).to(device)\n",
"_x, _ = iter(test_loader).next()\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/jmvae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'jmvae'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"plot_number = 1\n",
"\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/jmvae_poe.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'jmvae_poe'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"plot_number = 1\n",
"\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/m2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -816,7 +816,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'm2'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
" train_loss = train(epoch)\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/maximum_likelihood.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'maximum_likelihood'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"for epoch in range(1, epochs + 1):\n",
" train_loss = train(epoch)\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/mmd_vae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'mmd_vae'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"z_sample = 0.5 * torch.randn(64, z_dim).to(device)\n",
"_x, _ = iter(test_loader).next()\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/mvae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3395,7 +3395,14 @@
],
"source": [
"# for visualising in TensorBoard\n",
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'mvae'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"plot_number = 1\n",
"\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/real_nvp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -536,7 +536,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'real_nvp'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"z_sample = torch.randn(64, z_dim).to(device)\n",
"_x, _ = iter(test_loader).next()\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/real_nvp_cifar.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'real_nvp_cifar'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"z_sample = torch.randn(64, 3, 32, 32).to(device)\n",
"_x, _ = iter(test_loader).next()\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/real_nvp_cond.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'real_nvp_cond'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"plot_number = 5\n",
"\n",
Expand Down
2 changes: 2 additions & 0 deletions examples/real_nvp_toy.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install sklearn\n",
"\n",
"from __future__ import print_function\n",
"import torch\n",
"import torch.utils.data\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/vae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -731,7 +731,14 @@
],
"source": [
"# for visualising in TensorBoard\n",
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'vae'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"# fix latent variable z for watching generative model improvement \n",
"z_sample = 0.5 * torch.randn(64, z_dim).to(device)\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/vae_with_vae_class.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,14 @@
],
"source": [
"# for visualising in TensorBoard\n",
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'vae_with_vae_class'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"# fix latent variable z for watching generative model improvement \n",
"z_sample = 0.5 * torch.randn(64, z_dim).to(device)\n",
Expand Down
9 changes: 8 additions & 1 deletion examples/vi.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,14 @@
}
],
"source": [
"writer = SummaryWriter()\n",
"import pixyz\n",
"import datetime\n",
"\n",
"dt_now = datetime.datetime.now()\n",
"exp_time = dt_now.strftime('%Y%m%d_%H:%M:%S')\n",
"v = pixyz.__version__\n",
"nb_name = 'vi'\n",
"writer = SummaryWriter(\"runs/\" + v + \".\" + nb_name + exp_time)\n",
"\n",
"z_sample = 0.5 * torch.randn(64, z_dim).to(device)\n",
"_x, _ = iter(test_loader).next()\n",
Expand Down
2 changes: 1 addition & 1 deletion pixyz/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
name = "pixyz"
__version__ = "0.2.0"
__version__ = "0.2.1"
22 changes: 15 additions & 7 deletions pixyz/distributions/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from torch import nn
from copy import deepcopy

from ..utils import get_dict_values, replace_dict_keys, replace_dict_keys_split, delete_dict_values,\
from ..utils import get_dict_values, replace_dict_keys, delete_dict_values,\
tolist, sum_samples, convert_latex_name, lru_cache_for_sample_dict
from ..losses import LogProb, Prob

Expand Down Expand Up @@ -653,7 +653,9 @@ def _set_buffers(self, **params_dict):
for key in params_dict.keys():
if type(params_dict[key]) is str:
if params_dict[key] in self._cond_var:
self.replace_params_dict[params_dict[key]] = key
if params_dict[key] not in self.replace_params_dict:
self.replace_params_dict[params_dict[key]] = []
self.replace_params_dict[params_dict[key]].append(key)
else:
raise ValueError("parameter setting {}:{} is not valid because cond_var does not contains {}."
.format(key, params_dict[key], params_dict[key]))
Expand Down Expand Up @@ -772,10 +774,16 @@ def get_log_prob(self, x_dict, sum_features=True, feature_dims=None):

@lru_cache_for_sample_dict()
def get_params(self, params_dict={}, **kwargs):
params_dict, vars_dict = replace_dict_keys_split(params_dict, self.replace_params_dict)
output_dict = self.forward(**vars_dict)
replaced_params_dict = {}
for key, value in params_dict.items():
if key in self.replace_params_dict:
for replaced_key in self.replace_params_dict[key]:
replaced_params_dict[replaced_key] = value

output_dict.update(params_dict)
vars_dict = {key: value for key, value in params_dict.items() if key not in self.replace_params_dict}
output_dict = self(**vars_dict)

output_dict.update(replaced_params_dict)

# append constant parameters to output_dict
constant_params_dict = get_dict_values(dict(self.named_buffers()), self.params_keys,
Expand Down Expand Up @@ -1067,7 +1075,7 @@ def __init__(self, p, replace_dict):
self._input_var = _input_var

def forward(self, *args, **kwargs):
return self.p.forward(*args, **kwargs)
return self.p(*args, **kwargs)

def get_params(self, params_dict={}):
params_dict = replace_dict_keys(params_dict, self._replace_inv_cond_var_dict)
Expand Down Expand Up @@ -1204,7 +1212,7 @@ def __init__(self, p, marginalize_list):
self._marginalize_list = marginalize_list

def forward(self, *args, **kwargs):
return self.p.forward(*args, **kwargs)
return self.p(*args, **kwargs)

def get_params(self, params_dict={}):
return self.p.get_params(params_dict)
Expand Down
Loading

0 comments on commit d7f126c

Please sign in to comment.