@@ -57,58 +57,6 @@ def conv(x, channels, kernel=4, stride=2, pad=0, pad_type='zero', use_bias=True,
 
         return x
 
-
-def partial_conv(x, channels, kernel=3, stride=2, use_bias=True, padding='SAME', sn=False, scope='conv_0'):
-    with tf.variable_scope(scope):
-        if padding.lower() == 'SAME'.lower():
-            with tf.variable_scope('mask'):
-                _, h, w, _ = x.get_shape().as_list()
-
-                slide_window = kernel * kernel
-                mask = tf.ones(shape=[1, h, w, 1])
-
-                update_mask = tf.layers.conv2d(mask, filters=1,
-                                               kernel_size=kernel, kernel_initializer=tf.constant_initializer(1.0),
-                                               strides=stride, padding=padding, use_bias=False, trainable=False)
-
-                mask_ratio = slide_window / (update_mask + 1e-8)
-                update_mask = tf.clip_by_value(update_mask, 0.0, 1.0)
-                mask_ratio = mask_ratio * update_mask
-
-            with tf.variable_scope('x'):
-                if sn:
-                    w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
-                                        initializer=weight_init, regularizer=weight_regularizer)
-                    x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding=padding)
-                else:
-                    x = tf.layers.conv2d(x, filters=channels,
-                                         kernel_size=kernel, kernel_initializer=weight_init,
-                                         kernel_regularizer=weight_regularizer,
-                                         strides=stride, padding=padding, use_bias=False)
-                x = x * mask_ratio
-
-                if use_bias:
-                    bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
-
-                    x = tf.nn.bias_add(x, bias)
-                    x = x * update_mask
-        else:
-            if sn:
-                w = tf.get_variable("kernel", shape=[kernel, kernel, x.get_shape()[-1], channels],
-                                    initializer=weight_init, regularizer=weight_regularizer)
-                x = tf.nn.conv2d(input=x, filter=spectral_norm(w), strides=[1, stride, stride, 1], padding=padding)
-                if use_bias:
-                    bias = tf.get_variable("bias", [channels], initializer=tf.constant_initializer(0.0))
-
-                    x = tf.nn.bias_add(x, bias)
-            else:
-                x = tf.layers.conv2d(x, filters=channels,
-                                     kernel_size=kernel, kernel_initializer=weight_init,
-                                     kernel_regularizer=weight_regularizer,
-                                     strides=stride, padding=padding, use_bias=use_bias)
-
-        return x
-
 def fully_connected(x, units, use_bias=True, sn=False, scope='linear'):
     with tf.variable_scope(scope):
         x = flatten(x)
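Note: the deleted partial_conv follows Liu et al., "Image Inpainting for Irregular Holes Using Partial Convolutions" (arXiv:1804.07723). With padding='SAME' it convolves a constant all-ones mask with a frozen all-ones kernel to count how many in-bounds taps each output window sees (update_mask), then rescales the feature convolution by slide_window / update_mask, so border positions that see fewer valid taps are renormalized. Because the mask is built internally as all ones rather than passed in as a hole mask, the rescaling only corrects the zero padding at image borders. A toy 1-D NumPy sketch of just that renormalization step (hypothetical names, not code from this repo):

import numpy as np

def same_pad_renormalize(signal, kernel_size=3):
    # Count the in-bounds taps per window with a 'SAME'-padded box filter,
    # mirroring update_mask in the deleted partial_conv.
    counts = np.convolve(np.ones_like(signal), np.ones(kernel_size), mode='same')
    # mask_ratio = slide_window / update_mask (counts > 0 everywhere here).
    mask_ratio = kernel_size / counts
    conv_out = np.convolve(signal, np.ones(kernel_size), mode='same')
    return conv_out * mask_ratio

print(same_pad_renormalize(np.ones(5)))  # [3. 3. 3. 3. 3.]: border windows scaled up to the interior response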
@@ -259,19 +207,6 @@ def no_norm_resblock(x_init, channels, use_bias=True, sn=False, scope='resblock'
 
         return x + x_init
 
-def group_resblock(x_init, channels, groups, use_bias=True, sn=False, scope='resblock'):
-    with tf.variable_scope(scope):
-        with tf.variable_scope('res1'):
-            x = conv(x_init, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
-            x = group_norm(x, groups)
-            x = relu(x)
-
-        with tf.variable_scope('res2'):
-            x = conv(x, channels, kernel=3, stride=1, pad=1, pad_type='reflect', use_bias=use_bias, sn=sn)
-            x = group_norm(x, groups)
-
-        return x + x_init
-
 ##################################################################################
 # Sampling
 ##################################################################################
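Note: the deleted group_resblock mirrors the resblock variants that remain (two 3x3 reflect-padded convs plus a skip connection), with group normalization (Wu & He, "Group Normalization", arXiv:1803.08494) in place of the usual norm layer. group_norm itself is defined elsewhere in ops.py; as a reference point only, a minimal NHWC group norm in the same TF 1.x style might look like this (the repo's actual signature and details may differ):

import tensorflow as tf

def group_norm_sketch(x, groups, eps=1e-5, scope='group_norm'):
    # Minimal NHWC group norm: split C into `groups` groups and normalize
    # each group over (H, W, C/groups). Assumes C is divisible by `groups`.
    with tf.variable_scope(scope):
        _, h, w, c = x.get_shape().as_list()
        x = tf.reshape(x, [-1, h, w, groups, c // groups])
        mean, var = tf.nn.moments(x, axes=[1, 2, 4], keep_dims=True)
        x = (x - mean) / tf.sqrt(var + eps)
        x = tf.reshape(x, [-1, h, w, c])
        gamma = tf.get_variable('gamma', [1, 1, 1, c], initializer=tf.ones_initializer())
        beta = tf.get_variable('beta', [1, 1, 1, c], initializer=tf.zeros_initializer())
        return x * gamma + beta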
@@ -431,16 +366,3 @@ def regularization_loss(scope_name) :
             loss.append(item)
 
     return tf.reduce_sum(loss)
-
-def z_sample(mean, logvar):
-    eps = tf.random_normal(tf.shape(mean), mean=0.0, stddev=1.0, dtype=tf.float32)
-
-    return mean + tf.exp(logvar * 0.5) * eps
-
-
-def kl_loss(mean, logvar):
-    # shape : [batch_size, channel]
-    loss = 0.5 * tf.reduce_sum(tf.square(mean) + tf.exp(logvar) - 1 - logvar, axis=-1)
-    loss = tf.reduce_mean(loss)
-
-    return loss
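Note: the two deleted helpers are the standard VAE building blocks. z_sample is the reparameterization trick, z = mean + exp(0.5 * logvar) * eps with eps ~ N(0, I), which keeps sampling differentiable with respect to mean and logvar. kl_loss is the closed-form KL divergence to a standard normal, KL(N(mu, sigma^2) || N(0, 1)) = 0.5 * sum(mu^2 + sigma^2 - 1 - log sigma^2) with sigma^2 = exp(logvar), summed over channels and averaged over the batch. A quick NumPy restatement of that closed form (illustration only, not repo code):

import numpy as np

def kl_standard_normal(mean, logvar):
    # Same closed form as the deleted kl_loss: sum over channels, mean over batch.
    per_sample = 0.5 * np.sum(np.square(mean) + np.exp(logvar) - 1.0 - logvar, axis=-1)
    return np.mean(per_sample)

mean = np.zeros((4, 8))
logvar = np.zeros((4, 8))
print(kl_standard_normal(mean, logvar))  # 0.0: N(0, 1) has zero divergence from itself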