model.py
import tensorflow as tf
from typing import Callable

slim = tf.contrib.slim

__leaky_relu_alpha__ = 0.2


def __leaky_relu__(x, alpha=__leaky_relu_alpha__, name='Leaky_ReLU'):
    """Leaky ReLU: max(x, alpha * x), i.e. the identity for x >= 0 and a
    small negative slope of `alpha` for x < 0."""
    return tf.maximum(x, alpha * x, name=name)
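
# Example behaviour (a sketch, not in the original file): for
# x = [-1.0, 2.0] and the default alpha = 0.2, __leaky_relu__ returns
# [-0.2, 2.0]; negative inputs are scaled by alpha instead of being zeroed.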

class Model(object):
    """A fully connected multi-layer perceptron built with tf.contrib.slim.

    Stacks `n_hidden_layers` hidden layers of `n_hidden_neurons` units each,
    followed by a linear output layer of `n_out_dim` units. All variables are
    created under `variable_scope_name` and collected in `self.var_list`.
    """

    def __init__(self,
                 input_tensor: tf.Tensor,
                 variable_scope_name: str,
                 n_hidden_neurons: int,
                 n_hidden_layers: int,
                 n_out_dim: int,
                 activation_fn: Callable,
                 reuse: bool):
        self.input = input_tensor
        self.variable_scope_name = variable_scope_name
        self.n_hidden_neurons = n_hidden_neurons
        self.n_hidden_layers = n_hidden_layers
        self.n_out_dim = n_out_dim
        self.activation_fn = activation_fn
        self.reuse = reuse
        self.output_tensor = None
        self.var_list = None
        self.define_model()

    def define_model(self):
        with tf.variable_scope(self.variable_scope_name, reuse=self.reuse) as vs:
            x = self.input
            # arg_scope supplies the default width and activation for every
            # slim.fully_connected call inside the block.
            with slim.arg_scope([slim.fully_connected],
                                num_outputs=self.n_hidden_neurons,
                                activation_fn=self.activation_fn):
                for _ in range(self.n_hidden_layers):
                    x = slim.fully_connected(inputs=x)
            # Linear (no activation) output layer.
            self.output_tensor = slim.fully_connected(inputs=x,
                                                      num_outputs=self.n_out_dim,
                                                      activation_fn=None)
            # Keep a handle on this scope's variables, e.g. for an optimizer's
            # var_list argument.
            self.var_list = tf.contrib.framework.get_variables(vs)

class Generator(Model):
    """Generator network: by default a 3-layer, 512-unit MLP with a 2-D
    output."""

    def __init__(self,
                 input_tensor: tf.Tensor,
                 variable_scope_name: str = 'Generator',
                 n_hidden_neurons: int = 512,
                 n_hidden_layers: int = 3,
                 n_out_dim: int = 2,
                 activation_fn: Callable = __leaky_relu__,
                 reuse: bool = False):
        super(Generator, self).__init__(input_tensor,
                                        variable_scope_name,
                                        n_hidden_neurons,
                                        n_hidden_layers,
                                        n_out_dim,
                                        activation_fn,
                                        reuse)

class Critic(Model):
    """Critic network: the same MLP architecture with a scalar (1-D)
    output."""

    def __init__(self,
                 input_tensor: tf.Tensor,
                 variable_scope_name: str = 'Critic',
                 n_hidden_neurons: int = 512,
                 n_hidden_layers: int = 3,
                 n_out_dim: int = 1,
                 activation_fn: Callable = __leaky_relu__,
                 reuse: bool = False):
        super(Critic, self).__init__(input_tensor,
                                     variable_scope_name,
                                     n_hidden_neurons,
                                     n_hidden_layers,
                                     n_out_dim,
                                     activation_fn,
                                     reuse)
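
# --- Usage sketch (an assumption, not part of the original file) ---
# A minimal WGAN-style wiring of the classes above. The placeholder shapes
# (100-d noise, 2-d data points) are illustrative guesses; the original
# training code may differ.
if __name__ == '__main__':
    noise = tf.placeholder(tf.float32, shape=[None, 100], name='noise')
    real_data = tf.placeholder(tf.float32, shape=[None, 2], name='real_data')

    generator = Generator(noise)
    # The critic is applied twice, once to real samples and once to generated
    # ones, sharing weights via reuse=True on the second instantiation.
    critic_real = Critic(real_data)
    critic_fake = Critic(generator.output_tensor, reuse=True)

    # Wasserstein-style critic difference (sign conventions vary by paper).
    critic_diff = (tf.reduce_mean(critic_real.output_tensor)
                   - tf.reduce_mean(critic_fake.output_tensor))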