forked from devsisters/neural-combinatorial-rl-tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path trainer.py
119 lines (94 loc) · 3.75 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import os
import numpy as np
from tqdm import trange
import tensorflow as tf
from tensorflow.contrib.framework.python.ops import arg_scope
from model import Model
from utils import show_all_variables
from data_loader import TSPDataLoader
class Trainer(object):
    """Build the model and session for a configured task; drive training and testing.

    Only TSP tasks (``config.task`` starting with "tsp", case-insensitive) are
    supported. The data loader's tensors are wired directly into ``Model``.
    """

    def __init__(self, config, rng):
        """Create the data loader, the model and the TensorFlow session.

        Args:
            config: Run configuration. This constructor reads ``task``,
                ``model_dir``, ``gpu_memory_fraction``, ``log_step``,
                ``max_step``, ``num_log_samples`` and ``checkpoint_secs``.
                It also passes ``config`` through to the data loader and model.
            rng: Random number generator handed to the data loader as ``rng``.

        Raises:
            ValueError: If ``config.task`` does not name a TSP task.
        """
        self.config = config
        self.rng = rng

        self.task = config.task
        self.model_dir = config.model_dir
        self.gpu_memory_fraction = config.gpu_memory_fraction

        self.log_step = config.log_step
        self.max_step = config.max_step
        self.num_log_samples = config.num_log_samples
        self.checkpoint_secs = config.checkpoint_secs

        # Summary ops looked up by tag in _inject_summary. Nothing in this
        # class registers entries, so callers must populate it before use.
        self.summary_ops = {}

        if config.task.lower().startswith('tsp'):
            self.data_loader = TSPDataLoader(config, rng=self.rng)
        else:
            # ValueError subclasses Exception, so existing
            # ``except Exception`` handlers keep working.
            raise ValueError("[!] Unknown task: {}".format(config.task))

        # NOTE(review): not used inside this class; kept in case external
        # code reads it.
        self.models = {}

        # Encoder and decoder share the data loader's sequence length.
        self.model = Model(
            config,
            inputs=self.data_loader.x,
            labels=self.data_loader.y,
            enc_seq_length=self.data_loader.seq_length,
            dec_seq_length=self.data_loader.seq_length,
            mask=self.data_loader.mask)

        self.build_session()
        show_all_variables()

    def build_session(self):
        """Create the saver, summary writer, Supervisor and session.

        The Supervisor runs as chief and logs to ``model_dir``. It
        checkpoints every ``checkpoint_secs`` seconds and tracks
        ``self.model.global_step``. It is given ``summary_op=None``, so it
        adds no default summary op. Summaries are written explicitly
        through ``self.summary_writer``.
        """
        self.saver = tf.train.Saver()
        self.summary_writer = tf.summary.FileWriter(self.model_dir)

        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_summaries_secs=300,
                                 save_model_secs=self.checkpoint_secs,
                                 global_step=self.model.global_step)

        gpu_options = tf.GPUOptions(
            per_process_gpu_memory_fraction=self.gpu_memory_fraction,
            allow_growth=True)  # NOTE: the original author noted this seems not to take effect
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)

        self.sess = sv.prepare_or_wait_for_session(config=sess_config)

    def train(self):
        """Run ``max_step`` optimisation steps, evaluating on log steps.

        Starts the data loader's input queue first. Whenever the step
        returned by ``model.train`` is a multiple of ``log_step``,
        ``_test`` runs with the summary writer attached.
        """
        tf.logging.info("Training starts...")
        self.data_loader.run_input_queue(self.sess)

        # The writer passed to model.train is chosen from the previous
        # iteration's result['step'], so the first step gets no writer.
        summary_writer = None
        for _ in trange(self.max_step, desc="train"):
            fetch = {
                'optim': self.model.optim,
            }
            result = self.model.train(self.sess, fetch, summary_writer)

            if result['step'] % self.log_step == 0:
                self._test(self.summary_writer)

            summary_writer = self._get_summary_writer(result)

    def test(self, num_batches=10):
        """Log evaluation results for several batches.

        Args:
            num_batches: Number of times ``_test`` is run (default 10, the
                previously hard-coded value). Each run gets no summary writer.
        """
        tf.logging.info("Testing starts...")
        for _ in range(num_batches):
            self._test(None)

    def _test(self, summary_writer):
        """Evaluate one batch and log the loss plus a few predictions.

        Args:
            summary_writer: Writer that receives ``result['summary']`` at
                ``result['step']``, or ``None`` to skip writing it.
        """
        fetch = {
            'loss': self.model.total_inference_loss,
            'pred': self.model.dec_inference,
            'true': self.model.dec_targets,
        }
        result = self.model.test(self.sess, fetch, summary_writer)

        tf.logging.info("")
        tf.logging.info("test loss: {}".format(result['loss']))

        # Clamp so a num_log_samples larger than the fetched batch does not
        # raise IndexError.
        num_samples = min(self.num_log_samples,
                          len(result['pred']), len(result['true']))
        for idx in range(num_samples):
            pred, true = result['pred'][idx], result['true'][idx]
            tf.logging.info("test pred: {}".format(pred))
            tf.logging.info("test true: {} ({})".format(true, np.array_equal(pred, true)))

        if summary_writer:
            summary_writer.add_summary(result['summary'], result['step'])

    def _inject_summary(self, tag, feed_dict, step):
        """Evaluate ``self.summary_ops[tag]`` and write its summary at ``step``.

        Args:
            tag: Key into ``self.summary_ops``. The fetched result is indexed
                with ``'summary'``.
            feed_dict: Feed dictionary passed to ``self.sess.run``.
            step: Step recorded with the summary.

        Raises:
            KeyError: If no summary op is registered under ``tag``.
        """
        summaries = self.sess.run(self.summary_ops[tag], feed_dict)
        self.summary_writer.add_summary(summaries['summary'], step)
        # NOTE: no image samples are saved here. The tiling/PNG helpers this
        # was adapted from (imwrite, img_tile) are not available in this
        # module and always raised NameError.

    def _get_summary_writer(self, result):
        """Return ``self.summary_writer`` when ``result['step']`` is a log step, else ``None``."""
        return self.summary_writer if result['step'] % self.log_step == 0 else None