Commit bbf4636

Commit message: updated
1 parent 804bc32 commit bbf4636

210 files changed, +10837 -14 lines changed

5 binary files changed (contents not shown).

GANs/3dGAN/src/main.py

+50

@@ -0,0 +1,50 @@

'''
main.py

Entry point for 3dgan: parse the command-line flags, print the
hyper-parameters, then dispatch to the trainer or the tester.
'''
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import argparse
import torch

from trainer import trainer
from tester import tester
import params


def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')


def main():
    # add arguments
    parser = argparse.ArgumentParser()

    # logging parameters
    parser.add_argument('--logs', type=str, default=None, help='logs by tensorboardX')
    parser.add_argument('--local_test', type=str2bool, default=False, help='local test verbose')
    parser.add_argument('--model_name', type=str, default="dcgan", help='model name for saving')
    parser.add_argument('--test', type=str2bool, default=False, help='call tester.py')
    parser.add_argument('--use_visdom', type=str2bool, default=False, help='visualization by visdom')
    args = parser.parse_args()

    # list params
    params.print_params()

    # run program
    if not args.test:
        trainer(args)
    else:
        tester(args)


if __name__ == '__main__':
    main()
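
For reference, the boolean flags above go through str2bool, so they can be given as strings on the command line, e.g. python main.py --test true switches from training to testing. Below is a minimal, self-contained sketch of that parsing behaviour; the parser lines are copied from main.py above, and the dispatch is only indicated in a comment so the snippet runs without trainer.py or tester.py:

import argparse

def str2bool(v):
    # same helper as in main.py above
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')

parser = argparse.ArgumentParser()
parser.add_argument('--test', type=str2bool, default=False, help='call tester.py')
parser.add_argument('--model_name', type=str, default="dcgan", help='model name for saving')

args = parser.parse_args(['--test', 'true'])
print(args.test, args.model_name)  # True dcgan -> main() would dispatch to tester(args)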

GANs/3dGAN/src/model.py

+120

@@ -0,0 +1,120 @@

import torch
import params

'''
model.py

Define our GAN model.

The voxel grid is 32x32x32 (cube_len = 32) and the maximum number of feature
maps is 256, so the results may be inconsistent with the paper.
'''

class net_G(torch.nn.Module):
    def __init__(self, args):
        super(net_G, self).__init__()
        self.args = args
        self.cube_len = params.cube_len
        self.bias = params.bias
        self.z_dim = params.z_dim
        self.f_dim = 32

        padd = (0, 0, 0)
        if self.cube_len == 32:
            padd = (1, 1, 1)

        self.layer1 = self.conv_layer(self.z_dim, self.f_dim*8, kernel_size=4, stride=2, padding=padd, bias=self.bias)
        self.layer2 = self.conv_layer(self.f_dim*8, self.f_dim*4, kernel_size=4, stride=2, padding=(1, 1, 1), bias=self.bias)
        self.layer3 = self.conv_layer(self.f_dim*4, self.f_dim*2, kernel_size=4, stride=2, padding=(1, 1, 1), bias=self.bias)
        self.layer4 = self.conv_layer(self.f_dim*2, self.f_dim, kernel_size=4, stride=2, padding=(1, 1, 1), bias=self.bias)

        self.layer5 = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(self.f_dim, 1, kernel_size=4, stride=2, bias=self.bias, padding=(1, 1, 1)),
            torch.nn.Sigmoid()
            # torch.nn.Tanh()
        )

    def conv_layer(self, input_dim, output_dim, kernel_size=4, stride=2, padding=(1, 1, 1), bias=False):
        layer = torch.nn.Sequential(
            torch.nn.ConvTranspose3d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding),
            torch.nn.BatchNorm3d(output_dim),
            torch.nn.ReLU(True)
            # torch.nn.LeakyReLU(self.leak_value, True)
        )
        return layer

    def forward(self, x):
        out = x.view(-1, self.z_dim, 1, 1, 1)
        # print(out.size())  # torch.Size([32, 200, 1, 1, 1])
        out = self.layer1(out)
        # print(out.size())  # torch.Size([32, 256, 2, 2, 2])
        out = self.layer2(out)
        # print(out.size())  # torch.Size([32, 128, 4, 4, 4])
        out = self.layer3(out)
        # print(out.size())  # torch.Size([32, 64, 8, 8, 8])
        out = self.layer4(out)
        # print(out.size())  # torch.Size([32, 32, 16, 16, 16])
        out = self.layer5(out)
        # print(out.size())  # torch.Size([32, 1, 32, 32, 32])
        out = torch.squeeze(out)
        return out


class net_D(torch.nn.Module):
    def __init__(self, args):
        super(net_D, self).__init__()
        self.args = args
        self.cube_len = params.cube_len
        self.leak_value = params.leak_value
        self.bias = params.bias

        padd = (0, 0, 0)
        if self.cube_len == 32:
            padd = (1, 1, 1)

        self.f_dim = 32

        self.layer1 = self.conv_layer(1, self.f_dim, kernel_size=4, stride=2, padding=(1, 1, 1), bias=self.bias)
        self.layer2 = self.conv_layer(self.f_dim, self.f_dim*2, kernel_size=4, stride=2, padding=(1, 1, 1), bias=self.bias)
        self.layer3 = self.conv_layer(self.f_dim*2, self.f_dim*4, kernel_size=4, stride=2, padding=(1, 1, 1), bias=self.bias)
        self.layer4 = self.conv_layer(self.f_dim*4, self.f_dim*8, kernel_size=4, stride=2, padding=(1, 1, 1), bias=self.bias)

        self.layer5 = torch.nn.Sequential(
            torch.nn.Conv3d(self.f_dim*8, 1, kernel_size=4, stride=2, bias=self.bias, padding=padd),
            torch.nn.Sigmoid()
        )

        # self.layer5 = torch.nn.Sequential(
        #     torch.nn.Linear(256*2*2*2, 1),
        #     torch.nn.Sigmoid()
        # )

    def conv_layer(self, input_dim, output_dim, kernel_size=4, stride=2, padding=(1, 1, 1), bias=False):
        layer = torch.nn.Sequential(
            torch.nn.Conv3d(input_dim, output_dim, kernel_size=kernel_size, stride=stride, bias=bias, padding=padding),
            torch.nn.BatchNorm3d(output_dim),
            torch.nn.LeakyReLU(self.leak_value, inplace=True)
        )
        return layer

    def forward(self, x):
        # out = torch.unsqueeze(x, dim=1)
        out = x.view(-1, 1, self.cube_len, self.cube_len, self.cube_len)
        # print(out.size())  # torch.Size([32, 1, 32, 32, 32])
        out = self.layer1(out)
        # print(out.size())  # torch.Size([32, 32, 16, 16, 16])
        out = self.layer2(out)
        # print(out.size())  # torch.Size([32, 64, 8, 8, 8])
        out = self.layer3(out)
        # print(out.size())  # torch.Size([32, 128, 4, 4, 4])
        out = self.layer4(out)
        # print(out.size())  # torch.Size([32, 256, 2, 2, 2])
        # out = out.view(-1, 256*2*2*2)
        # print(out.size())
        out = self.layer5(out)
        # print(out.size())  # torch.Size([32, 1, 1, 1, 1])
        out = torch.squeeze(out)
        return out
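
A minimal shape check for the two networks above, assuming it is run from GANs/3dGAN/src/ with params.py available (z_dim = 200, cube_len = 32); the expected sizes match the commented print statements in the forward passes:

import argparse
import torch
import params
from model import net_G, net_D

args = argparse.Namespace()        # neither network reads attributes from args in this file
G, D = net_G(args), net_D(args)

z = torch.randn(4, params.z_dim)   # small batch of latent vectors
voxels = G(z)                      # squeezed to torch.Size([4, 32, 32, 32])
scores = D(voxels)                 # squeezed to torch.Size([4])
print(voxels.size(), scores.size())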

GANs/3dGAN/src/params.py

+59

@@ -0,0 +1,59 @@

'''
params.py

Manager of all hyper-parameters.
'''

import torch

epochs = 500
batch_size = 32
soft_label = False
adv_weight = 0
d_thresh = 0.8
z_dim = 200
z_dis = "norm"
model_save_step = 1
g_lr = 0.0025
d_lr = 0.00001
beta = (0.5, 0.999)
cube_len = 32
leak_value = 0.2
bias = False

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_dir = '../volumetric_data/'
model_dir = 'airplane/'  # change it to train on other data models
output_dir = '../outputs'
images_dir = '../images'

def print_params():
    l = 16
    print(l*'*' + 'hyper-parameters' + l*'*')

    print('epochs =', epochs)
    print('batch_size =', batch_size)
    print('soft_labels =', soft_label)
    print('adv_weight =', adv_weight)
    print('d_thresh =', d_thresh)
    print('z_dim =', z_dim)
    print('z_dis =', z_dis)
    print('model_images_save_step =', model_save_step)
    print('data =', model_dir)
    print('device =', device)
    print('g_lr =', g_lr)
    print('d_lr =', d_lr)
    print('cube_len =', cube_len)
    print('leak_value =', leak_value)
    print('bias =', bias)

    print(l*'*' + 'hyper-parameters' + l*'*')
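
The g_lr, d_lr and beta values above are in the form torch.optim.Adam expects; trainer.py is not among the files shown in this excerpt, so the wiring below is only an assumed sketch of how these settings would typically be consumed, not a copy of the actual trainer:

import argparse
import torch
import params
from model import net_G, net_D

# Hypothetical optimizer setup (trainer.py is not shown above); it only
# illustrates how g_lr, d_lr and beta from params.py plug into Adam.
args = argparse.Namespace()
G = net_G(args).to(params.device)
D = net_D(args).to(params.device)

G_solver = torch.optim.Adam(G.parameters(), lr=params.g_lr, betas=params.beta)
D_solver = torch.optim.Adam(D.parameters(), lr=params.d_lr, betas=params.beta)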

GANs/3dGAN/src/tester.py

+102

@@ -0,0 +1,102 @@

'''
tester.py

Test the trained 3dgan models.
'''

import torch
from torch import optim
from torch import nn
from collections import OrderedDict
from utils import *
import os
from model import net_G, net_D
# from lr_sh import MultiStepLR

# added
import datetime
from tensorboardX import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
import params
import visdom

# def test_gen(args):
#     test_z = []
#     test_num = 1000
#     for i in range(test_num):
#         z = generateZ(args, 1)
#         z = z.numpy()
#         test_z.append(z)
#
#     test_z = np.array(test_z)
#     print(test_z.shape)
#     np.save("test_z", test_z)

def tester(args):
    print('Evaluation Mode...')

    # image_saved_path = '../images'
    image_saved_path = params.images_dir
    if not os.path.exists(image_saved_path):
        os.makedirs(image_saved_path)

    if args.use_visdom:
        vis = visdom.Visdom()

    save_file_path = params.output_dir + '/' + args.model_name
    pretrained_file_path_G = save_file_path + '/' + 'G.pth'
    pretrained_file_path_D = save_file_path + '/' + 'D.pth'

    print(pretrained_file_path_G)

    D = net_D(args)
    G = net_G(args)

    if not torch.cuda.is_available():
        G.load_state_dict(torch.load(pretrained_file_path_G, map_location={'cuda:0': 'cpu'}))
        D.load_state_dict(torch.load(pretrained_file_path_D, map_location={'cuda:0': 'cpu'}))
    else:
        G.load_state_dict(torch.load(pretrained_file_path_G))
        D.load_state_dict(torch.load(pretrained_file_path_D))

    print('visualizing model')

    # test generator
    # test_gen(args)
    G.to(params.device)
    D.to(params.device)
    G.eval()
    D.eval()

    # test_z = np.load("test_z.npy")
    # print(test_z.shape)
    # N = test_z.shape[0]

    N = 8

    for i in range(N):
        # z = test_z[i, :]
        # z = torch.FloatTensor(z)

        z = generateZ(args, 1)

        # print(z.size())
        fake = G(z)
        samples = fake.unsqueeze(dim=0).detach().cpu().numpy()
        # print(samples.shape)
        # print(fake)
        y_prob = D(fake)
        y_real = torch.ones_like(y_prob)
        # criterion = nn.BCELoss()
        # print(y_prob.item(), criterion(y_prob, y_real).item())

        ### visualization
        if not args.use_visdom:
            SavePloat_Voxels(samples, image_saved_path, 'tester_norm_' + str(i))
        else:
            plotVoxelVisdom(samples[0, :], vis, "tester_" + str(i))
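
Through main.py the tester is reached with --test true; programmatically it only needs an object carrying the flags main.py defines. A minimal sketch, assuming pretrained weights already exist at ../outputs/<model_name>/G.pth and D.pth as the paths above require (generateZ comes from utils.py, which is not shown here, so it may read further attributes from args):

import argparse
from tester import tester

# Namespace mirrors the flags defined in main.py; model_name picks the
# ../outputs/<model_name>/ directory that must already hold G.pth and D.pth.
args = argparse.Namespace(logs=None, local_test=False, model_name='dcgan',
                          test=True, use_visdom=False)
tester(args)  # saves voxel plots to params.images_dir ('../images')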
