-- train.lua
require 'nn'
require 'nngraph'
require 'optim'
display = require 'display'
require './plot'
-- Parse command line arguments
cmd = torch.CmdLine()
cmd:text()
cmd:option('-hidden_size', 200, 'Hidden size of LSTM layer')
cmd:option('-glove_size', 100, 'Glove embedding size')
cmd:option('-dropout', 0.1, 'Dropout')
cmd:option('-learning_rate', 0.0002, 'Learning rate')
cmd:option('-learning_rate_decay', 1e-6, 'Learning rate decay')
cmd:option('-max_length', 20, 'Maximum output length')
cmd:option('-n_epochs', 10000, 'Number of epochs to train')
cmd:option('-series', 's1', 'Name of series')
opt = cmd:parse(arg)
require 'data'
require 'model'
-- Training
--------------------------------------------------------------------------------
-- Run a loop of optimization
n_epoch = 1
glove = torch.load('glove.t7')
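-- sample() generates one example with makeSentence() (from data.lua) and
-- decodes it greedily, printing the result. The flow below assumes the
-- networks built in model.lua: an encoder RNN that reads the input token by
-- token, a command decoder that attends over the (zero-padded) encoder
-- outputs and emits command tokens until command_EOS, and, for every emitted
-- token matching token_re, an argument decoder that makes a per-input-token
-- binary "copy this word" decision.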
function sample()
    -- Inputs and targets
    local sentence, encoder_inputs, command_decoder_inputs, _, argument_targets = unpack(makeSentence())
    local sentence_tokens = tokenize(sentence)
    print('--------\n[sample]', sentence)

    -- Forward pass
    -- -------------------------------------------------------------------------

    -- Forward through encoder
    local encoder_outputs = {[0] = torch.zeros(opt.hidden_size):double()}
    for t = 1, #encoder_inputs do
        encoder_outputs[t] = clones.encoder[t]:forward({encoder_inputs[t], encoder_outputs[t-1]})
    end

    -- Pad encoder outputs with 0s
    local encoder_outputs_padded = {}
    for t = 1, #encoder_outputs do
        encoder_outputs_padded[t] = encoder_outputs[t]
        -- encoder_outputs_padded[t]:add(encoder_outputs_reverse[t])
    end
    for t = 1, (opt.max_length - #encoder_outputs) do
        table.insert(encoder_outputs_padded, torch.zeros(opt.hidden_size))
    end
    last_encoder_output = encoder_outputs_padded[#encoder_inputs]

    -- Through command decoder
    local command_decoder_inputs = {[1] = torch.LongTensor({command_EOS})}
    local command_decoder_hidden_outputs = {[0] = last_encoder_output}
    local command_decoder_outputs = {}
    local sampled = ''
    local sampled_full = ''
    local sampled_tokens = {}

    for t = 1, opt.max_length do
        local command_decoder_output = clones.command_decoder[t]:forward(
            {command_decoder_inputs[t], command_decoder_hidden_outputs[t-1], encoder_outputs_padded})
        command_decoder_outputs[t] = command_decoder_output[1]
        command_decoder_hidden_outputs[t] = command_decoder_output[2][1]

        -- Choose most likely output
        out_max, out_max_index = command_decoder_outputs[t]:max(1)
        if out_max_index[1] == command_EOS then
            break
        end

        local output_argument_name = command_index_to_word[out_max_index[1]]
        table.insert(sampled_tokens, output_argument_name)
        sampled = sampled .. ' ' .. output_argument_name

        -- Next decoder input is current output
        command_decoder_inputs[t + 1] = out_max_index
    end

    -- Get arguments from command output
    local command_argument_ts = {}
    local command_argument_indexes = {}
    for t = 1, #command_decoder_outputs - 1 do
        local _, command_index = command_decoder_outputs[t]:max(1)
        command_index = command_index[1]
        local command_word = command_index_to_word[command_index]
        if command_word:match(token_re) then
            command_argument_ts[command_word] = t
            command_argument_indexes[command_word] = torch.LongTensor({command_index})
        end
        sampled_full = sampled_full .. ' ' .. command_word
    end

    -- For each argument output
    for arg_name, arg_value in pairs(command_argument_indexes) do
        arg_t = command_argument_ts[arg_name]
        sampled_tokens = ''
        arg_out_s = '{'

        -- Forward through argument decoder
        local argument_decoder_hidden_outputs = {[0] = last_encoder_output}
        -- local argument_decoder_hidden_outputs = {[0] = torch.zeros(opt.hidden_size)}
        local argument_decoder_attention_outputs = {}
        local argument_decoder_outputs = {}
        local argument_decoder_output_indexes = {}

        for t = 1, #encoder_inputs do
            argument_decoder_outputs[t], argument_decoder_gru_outputs = unpack(clones.argument_decoder[t]:forward({
                command_decoder_hidden_outputs[arg_t],
                encoder_inputs[t],
                argument_decoder_hidden_outputs[t-1],
                encoder_outputs_padded
            }))
            argument_decoder_hidden_outputs[t], argument_decoder_attention_outputs[t] = unpack(argument_decoder_gru_outputs)

            -- Copy current token if > 0.5
            if argument_decoder_outputs[t][1] > 0.5 then
                arg_out_s = arg_out_s .. ' ' .. '1'
                sampled_tokens = sampled_tokens .. sentence_tokens[t] .. ' '
            else
                arg_out_s = arg_out_s .. ' ' .. '0'
            end
        end
        arg_out_s = arg_out_s .. ' }'

        sampled_full = sampled_full:gsub(arg_name, '( ' .. arg_name .. ' = ' .. sampled_tokens .. ')')
        print(arg_name, '~>', arg_out_s)

        local target = argument_targets[arg_name]
        if target ~= nil then
            print(arg_name, '=>', asString(target[{{1, #encoder_inputs}}]))
        else
            print('! no target')
        end
    end

    print('sampled', sampled_full)
end
function asString(t)
    local s = ''
    for i = 1, t:size()[1] do
        s = s .. ' ' .. t[i]
    end
    return '{' .. s .. ' }'
end
sample()
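-- feval is the closure passed to optim.adam below: given a flat parameter
-- vector it copies it into the model parameters, zeroes the gradient buffer,
-- runs a full forward/backward pass on one freshly generated sentence, and
-- returns (loss, grad_params). params and grad_params are assumed to be the
-- flattened parameter/gradient tensors created in model.lua. The backward
-- pass mirrors the forward pass in reverse: argument decoders first, then
-- the command decoder, then the encoder, with attention gradients
-- accumulated into d_encoder along the way.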
function feval(params_)
    if params_ ~= params then
        params:copy(params_)
    end
    grad_params:zero()

    local loss = 0

    -- Inputs and targets
    local input_sentence, encoder_inputs, command_decoder_inputs, command_decoder_targets, argument_decoder_targets = unpack(makeSentence())

    -- Forward through encoder
    local encoder_outputs = {[0] = torch.zeros(opt.hidden_size)}
    for t = 1, #encoder_inputs do
        encoder_outputs[t] = clones.encoder[t]:forward({
            encoder_inputs[t],
            encoder_outputs[t-1]
        })
    end

    -- Pad encoder outputs with 0s
    last_encoder_output = encoder_outputs[#encoder_inputs]
    local encoder_outputs_padded = {}
    for t = 1, #encoder_outputs do
        table.insert(encoder_outputs_padded, encoder_outputs[t])
    end
    for t = 1, (opt.max_length - #encoder_outputs) do
        table.insert(encoder_outputs_padded, torch.zeros(opt.hidden_size))
    end

    -- Forward through command decoder
    local command_decoder_hidden_outputs = {[0] = last_encoder_output}
    local command_decoder_outputs = {}
    local command_decoder_output_indexes = {}
    local command_decoder_output_argument_indexes = {}
    for t = 1, #command_decoder_inputs do
        local command_decoder_output = clones.command_decoder[t]:forward({
            command_decoder_inputs[t],
            command_decoder_hidden_outputs[t-1],
            encoder_outputs_padded
        })
        command_decoder_outputs[t] = command_decoder_output[1]
        command_decoder_hidden_outputs[t] = command_decoder_output[2][1]
        _, command_decoder_output_indexes[t] = command_decoder_outputs[t]:max(1)
        loss = loss + clones.command_decoder_criterion[t]:forward(command_decoder_outputs[t], command_decoder_targets[t])
    end
    -- Arguments
    -- =========================================================================

    -- Set up gradients
    local d_encoder = torch.zeros(#encoder_inputs, opt.hidden_size)
    local d_command_decoder_hidden = {[#command_decoder_inputs] = torch.zeros(opt.hidden_size)}
    local d_command_decoder_in = {}
    local d_command_decoder_outputs = torch.zeros(opt.max_length, n_command_words + 1)

    -- First index command argument words to get related context
    local command_argument_ts = {}
    local command_argument_indexes = {}
    for t = 1, #command_decoder_targets - 1 do
        local command_index = command_decoder_targets[t][1]
        local command_word = command_index_to_word[command_index]
        if command_word:match(token_re) then
            command_argument_ts[command_word] = t
            command_argument_indexes[command_word] = torch.LongTensor({command_index})
        else
            -- We know they will have no gradient otherwise
            d_command_decoder_outputs[t] = torch.zeros(n_command_words + 1)
        end
    end

    -- For each command argument output ...
    for arg_name, arg_value in pairs(argument_decoder_targets) do
        local arg_t = command_argument_ts[arg_name]
        local arg_index = command_argument_indexes[arg_name]

        -- Forward through argument decoder
        local argument_decoder_hidden_outputs = {[0] = last_encoder_output}
        local argument_decoder_attention_outputs = {}
        local argument_decoder_outputs = {}
        local argument_decoder_output_indexes = {}
        for t = 1, #encoder_inputs do
            argument_decoder_outputs[t], argument_decoder_gru_outputs = unpack(clones.argument_decoder[t]:forward({
                command_decoder_hidden_outputs[arg_t],
                encoder_inputs[t],
                argument_decoder_hidden_outputs[t-1],
                encoder_outputs_padded
            }))
            argument_decoder_hidden_outputs[t], argument_decoder_attention_outputs[t] = unpack(argument_decoder_gru_outputs)
            loss = loss + clones.argument_decoder_criterion[t]:forward(
                argument_decoder_outputs[t],
                arg_value[{{t}}]
            )
        end

        -- Backward through argument decoder
        local d_argument_decoder_out = {}
        local d_argument_decoder_all = {}
        local d_argument_decoder_hidden_command = {}
        local d_argument_decoder_hidden_encoder = {}
        local d_argument_decoder_hidden = {[#encoder_inputs] = torch.zeros(opt.hidden_size)}
        local d_argument_decoder_in = torch.zeros(opt.max_length, opt.hidden_size)
        for t = #encoder_inputs, 1, -1 do
            -- decoder out < targets
            d_argument_decoder_out[t] = clones.argument_decoder_criterion[t]:backward(
                argument_decoder_outputs[t],
                arg_value[{{t}}]
            )
            -- < decoder
            d_argument_decoder_all[t] = clones.argument_decoder[t]:backward(
                {
                    command_decoder_hidden_outputs[arg_t],
                    encoder_inputs[t],
                    argument_decoder_hidden_outputs[t-1],
                    encoder_outputs_padded
                },
                {
                    d_argument_decoder_out[t],
                    {
                        d_argument_decoder_hidden[t],
                        torch.zeros(opt.max_length)
                    }
                }
            )
            d_argument_decoder_in[t], d_encoder_out_t, d_argument_decoder_hidden[t-1], d_encoder_all = unpack(d_argument_decoder_all[t])

            -- Attention -> encoder gradients
            for tt = 1, #encoder_inputs do
                d_encoder[tt]:add(d_encoder_all[tt])
            end
        end

        -- Last encoder output was initial hidden state
        d_encoder[#encoder_inputs]:add(d_argument_decoder_hidden[0])
    end
    -- Backward through command decoder
    for t = #command_decoder_inputs, 1, -1 do
        -- decoder out < targets
        d_command_decoder_outputs_t = clones.command_decoder_criterion[t]:backward(
            command_decoder_outputs[t], command_decoder_targets[t])
        if d_command_decoder_outputs[t] == nil then
            d_command_decoder_outputs[t] = d_command_decoder_outputs_t
        else
            d_command_decoder_outputs[t]:add(d_command_decoder_outputs_t)
        end

        -- < decoder
        d_command_decoder = clones.command_decoder[t]:backward(
            {
                command_decoder_inputs[t],
                command_decoder_hidden_outputs[t-1],
                encoder_outputs_padded
            },
            {
                d_command_decoder_outputs[t],
                {
                    d_command_decoder_hidden[t],
                    torch.zeros(opt.max_length)
                }
            }
        )
        d_command_decoder_in[t], d_command_decoder_hidden[t-1], d_encoder_all = unpack(d_command_decoder)
        for tt = 1, #encoder_inputs do
            d_encoder[tt]:add(d_encoder_all[tt])
        end
    end

    -- Last encoder output was initial hidden state
    d_encoder[#encoder_inputs]:add(d_command_decoder_hidden[0])

    -- Backward through encoder
    for t = #encoder_inputs, 1, -1 do
        local _, d_encoder_t_1 = unpack(clones.encoder[t]:backward(
            {encoder_inputs[t], encoder_outputs[t-1] or torch.zeros(opt.hidden_size)},
            d_encoder[t]
        ))
        if t > 1 then
            d_encoder[t-1] = d_encoder_t_1
        end
    end

    return loss, grad_params
end
losses = {}
loss_sofar = 0
learning_rates = {}
plot_every = 100
sample_every = 50
save_every = 5000
optim_state = {
learningRate = opt.learning_rate,
learningRateDecay = opt.learning_rate_decay
}
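-- optim.adam keeps its running moment estimates in this same table (when no
-- separate state table is passed, the config table doubles as the state),
-- so optim_state must be reused across iterations.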
function save()
    print('Saving...')
    torch.save('models.t7', models)
    torch.save('opt.t7', opt)
end
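-- To reuse a checkpoint later (e.g. from a separate evaluation script, not
-- part of this file), the saved files can be read back with torch.load:
--   models = torch.load('models.t7')
--   opt = torch.load('opt.t7')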
LOSS_PLOT_CUTOFF = 5 * plot_every
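-- loss_sofar sums the per-epoch loss over plot_every epochs, so the plotted
-- value loss_sofar/plot_every is an average; the cutoff (an average loss of
-- 5) keeps early spikes from blowing up the plot's y-axis.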
print(string.format("Training for %s epochs...", opt.n_epochs))

for n_epoch = 1, opt.n_epochs do
    local _, loss = optim.adam(feval, params, optim_state)
    loss_sofar = loss_sofar + loss[1]

    -- Plot every plot_every
    if n_epoch % plot_every == 0 then
        if loss_sofar > 0 and loss_sofar < LOSS_PLOT_CUTOFF then
            plot({x=n_epoch, y=loss_sofar/plot_every, series=opt.series})
        end
        loss_sofar = 0
    end

    -- Sample every sample_every
    if n_epoch % sample_every == 0 then
        sample()
        print(n_epoch, loss[1])
    end

    -- Save every save_every
    if n_epoch % save_every == 0 then
        save()
    end
end