Skip to content

Commit c3ee94f

Browse files
committed
Third commit
1 parent 92766c6 commit c3ee94f

File tree

6 files changed

+86
-11
lines changed

6 files changed

+86
-11
lines changed

.idea/GenZoo.iml

+3-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/dictionaries/ayushtues.xml

+3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

VAE/data_loader.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import torch
22
import torch.utils.data
33
from torchvision import datasets, transforms
4+
import torch.nn as nn
5+
import torch.nn.functional as F
46

57

68
def load_mnist(batch_size):
@@ -10,6 +12,3 @@ def load_mnist(batch_size):
1012
batch_size=batch_size, shuffle=True)
1113

1214
return train_loader
13-
14-
15-

VAE/main.py

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import torch
2+
import torch.utils.data
3+
from torchvision import datasets, transforms
4+
from torch import nn, optim
5+
import model as VAE
6+
import torch.nn.functional as F
7+
import data_loader as load
8+
from tensorboardX import SummaryWriter
9+
import train.py as train_model
10+
import matplotlib
11+
import matplotlib.pyplot as plt
12+
13+
14+
trainloader = load.load_mnist(batch_size=60)
15+
x = trainloader.next()
16+
plt.imshow(x.numpy()[0], cmap='gray')
17+
18+
model = VAE.make_model()
19+
optimizer = optim.Adam(model.parameters(), lr=1e-3)
20+
device = ('cuda' if torch.cuda.is_available() else 'cpu')
21+
epoch = 10
22+
23+
train_model.train(trainloader, epoch, optimizer, model, device)
24+
25+
26+

VAE/model.py

+25-7
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44

5+
__name__ = "model.py"
6+
57

68
class VAE(nn.Module):
79
def __init__(self, z_dim=20, keep_prob=0.2):
@@ -14,11 +16,11 @@ def __init__(self, z_dim=20, keep_prob=0.2):
1416
self.fc3 = nn.Linear(z_dim, 256)
1517
self.fc4 = nn.Linear(256, 512)
1618
self.fc5 = nn.Linear(512, 784)
17-
self.decode = nn.Sequential(nn.ConvTranspose2d(16, 16, 3, padding=1),
18-
nn.BatchNorm2d(16),
19-
nn.ConvTranspose2d(16, 3, 8),
19+
self.decode = nn.Sequential(nn.ConvTranspose2d(16, 10, 3, padding=1),
20+
nn.BatchNorm2d(10),
21+
nn.ConvTranspose2d(10, 5, 8),
2022
nn.BatchNorm2d(3),
21-
nn.ConvTranspose2d(3, 1, 15))
23+
nn.ConvTranspose2d(5, 1, 15))
2224

2325
self.mean = nn.Linear(256, z_dim)
2426
self.logvar = nn.Linear(256, z_dim)
@@ -37,13 +39,29 @@ def encoder(self, x):
3739
return mean, logvar
3840

3941
def reparameterize(self, mean, logvar):
40-
std = logvar.mul(0.5).exp_()
41-
eps = torch.randn_like(std)
42-
return mean + eps * std
42+
if self.training:
43+
std = logvar.mul(0.5).exp_()
44+
eps = torch.randn_like(std)
45+
return mean + eps * std
46+
else:
47+
return mean
4348

4449
def decoder(self, z):
4550
z = self.fc3(z)
4651
z = self.fc4(z)
4752
z = self.fc5(z)
4853
z = z.view([-1, 16, 7, 7])
54+
z = self.decode(z)
55+
z = F.sigmoid(z)
56+
return z
57+
58+
def forward(self, x):
59+
mean, logvar = self.encoder(x)
60+
z = self.reparameterize(mean, logvar)
61+
x_output = self.decoder(z)
62+
return x_output, mean, logvar
63+
4964

65+
def make_model():
66+
model = VAE()
67+
return model

VAE/train.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import torch
2+
import torchvision
3+
import torch.nn as nn
4+
import torch.nn.functional as F
5+
from torch import nn, optim
6+
import model as VAE
7+
8+
batch_size = 20
9+
10+
11+
def loss_function(x_output, x, mean, logvar):
12+
bce = F.binary_cross_entropy(x_output, x.view(-1, 784))
13+
kld = -0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp())
14+
return (bce + kld) / (784 * batch_size)
15+
16+
17+
def train(trainloader, epoch, optimiser, model, device):
18+
for i in range(epoch):
19+
model.train()
20+
model.to(device)
21+
for images, _ in trainloader:
22+
images = images.to(device)
23+
optimiser.zero_grad()
24+
x_output, mean, logvar = model(images)
25+
loss = loss_function(x_output, images, mean, logvar)
26+
loss.backward()
27+
optimiser.step()

0 commit comments

Comments
 (0)