"""Run the trainig pipelie.
Example:
$ python main.py --batch_size 100
"""
import argparse
import time

from loguru import logger

from model import Mach
from utils import TBLogger
from utils import next_run_prefix
from utils import load_data_and_constants
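
# NOTE: the Mach component interface assumed below is inferred from its usage
# in this file: start() launches training, stop() tears it down, the `running`
# flag reports health, and set_child() links a component to the one below it.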
def run(hparams, components):
    """Start every component, then block until one stops or we are interrupted."""
    try:
        for c in components:
            c.start()
        logger.info('Begin wait on main...')
        running = True
        while running:
            # Poll component health; exit once any component stops running.
            for c in components:
                if not c.running:
                    running = False
            time.sleep(5)
    except KeyboardInterrupt:
        logger.debug('Interrupted by user.')
    finally:
        # Always tear the components down, whatever caused the exit.
        logger.debug('Tearing down components.')
        for c in components:
            c.stop()
def main(hparams):
    # Get a unique id for this training run.
    run_prefix = next_run_prefix()

    # Build components.
    components = []
    for i in range(hparams.n_components):
        # Load a unique dataset for each component.
        mnist_i, hparams = load_data_and_constants(hparams)

        # Tensorboard logger for this component.
        logdir_i = hparams.log_dir + "/" + run_prefix + "/" + 'c' + str(i)
        tblogger_i = TBLogger(logdir_i)

        # Component.
        mach_i = Mach(i, mnist_i, hparams, tblogger_i)
        components.append(mach_i)

    # Connect the components into a chain: component i's child is component i-1.
    for i in range(1, hparams.n_components):
        components[i].set_child(components[i - 1])

    # Run the experiment.
    run(hparams, components)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--batch_size',
        default=50,
        type=int,
        help='The number of examples per batch. Default batch_size=50')
    parser.add_argument(
        '--learning_rate',
        default=1e-5,
        type=float,
        help='Component learning rate. Default learning_rate=1e-5')
    parser.add_argument(
        '--n_embedding',
        default=128,
        type=int,
        help='Size of embedding between components. Default n_embedding=128')
    parser.add_argument(
        '--n_components',
        default=2,
        type=int,
        help='The number of chained components. Default n_components=2')
    parser.add_argument(
        '--n_iterations',
        default=10000,
        type=int,
        help='The number of training iterations. Default n_iterations=10000')
    parser.add_argument(
        '--n_hidden1',
        default=512,
        type=int,
        help='Size of layer 1. Default n_hidden1=512')
    parser.add_argument(
        '--n_hidden2',
        default=512,
        type=int,
        help='Size of layer 2. Default n_hidden2=512')
    parser.add_argument(
        '--n_shidden1',
        default=512,
        type=int,
        help='Size of synthetic model hidden layer 1. Default n_shidden1=512')
    parser.add_argument(
        '--n_shidden2',
        default=512,
        type=int,
        help='Size of synthetic model hidden layer 2. Default n_shidden2=512')
    parser.add_argument(
        '--max_depth',
        default=1,
        type=int,
        help='Depth at which the synthetic inputs are used. Default max_depth=1')
    parser.add_argument(
        '--n_print',
        default=100,
        type=int,
        help='The number of iterations between print statements. Default n_print=100')
    parser.add_argument(
        '--log_dir',
        default='logs',
        type=str,
        help='Location of tensorboard logs. Default log_dir=logs')
    parser.add_argument(
        '--n_train_steps',
        default=10000000,
        type=int,
        help='Training steps. Default n_train_steps=10000000')
    parser.add_argument(
        '--dataset',
        default='mnist',
        type=str,
        help='Dataset on which to run the network. Default is mnist.')
    hparams = parser.parse_args()
    main(hparams)
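
# Example invocation (flags as defined above):
#   $ python main.py --n_components 3 --batch_size 100
# builds three chained components and blocks in run() until one of them
# stops or the process is interrupted.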