main.py
import torch
from torch import nn, optim

import trm
import utils
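# trm and utils are project-local modules: trm defines the Transformer
# model, and utils provides the config/vocab loading, sentence/tensor
# conversion, and loss-curve plotting used below.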


def train(model, enc_x, dec_x, target,
          epochs=1000, lr=5e-5,
          device='cpu', show_loss=True,
          save=True, path=None):
    """Train with cross-entropy loss and Adam, optionally plotting the
    loss curve and saving the trained model to `path`."""
    model = model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses = []
    for i in range(epochs):
        optimizer.zero_grad()
        output = model(enc_x, dec_x)
        # The model returns logits of shape (batch * seq_len, n_vocab),
        # so the target is flattened to match.
        loss = criterion(output, target.contiguous().view(-1))
        if (i + 1) % 10 == 0:
            print(f"Epoch {i + 1}/{epochs} Loss: {loss.item()}")
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
    if show_loss:
        utils.draw_loss(losses)
    if save:
        torch.save(model, path)
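

# A minimal sketch of how greedy decoding presumably works here; the real
# implementation is trm.Transformer.greedy_decoder (defined in trm.py) and
# may differ. It assumes the model returns logits of shape
# (batch * seq_len, n_vocab), as the training loop above expects; the name
# greedy_decode_sketch is illustrative and not part of trm.
def greedy_decode_sketch(model, enc_x, sta_token, end_token, max_len):
    batch = enc_x.size(0)
    # Start every sequence with the <sta> token.
    dec_x = torch.full((batch, 1), sta_token,
                       dtype=torch.long, device=enc_x.device)
    for _ in range(max_len - 1):
        logits = model(enc_x, dec_x).view(batch, dec_x.size(1), -1)
        # Greedy choice: take the argmax over the vocabulary at the last
        # position and append it to the decoder input.
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
        dec_x = torch.cat([dec_x, next_tok], dim=1)
        if bool((next_tok == end_token).all()):
            break
    return dec_x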


if __name__ == '__main__':
    config = utils.load_config()
    vocab = utils.load_vocab()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # device = 'cpu'

    # Toy parallel corpus: encoder inputs, decoder inputs (prefixed with
    # the <sta> start token), and targets (suffixed with the <end> token).
    e_x = [
        'I very like you',
        'I very very very like you',
    ]
    d_x = [
        '<sta> I love you',
        '<sta> I very love you',
    ]
    t_x = [
        'I love you <end>',
        'I very love you <end>',
    ]

    # Convert each sentence into a padded tensor of vocabulary indices.
    e_x = utils.sen2vec(e_x, vocab, config['max_len']).to(device)
    d_x = utils.sen2vec(d_x, vocab, config['max_len']).to(device)
    t_x = utils.sen2vec(t_x, vocab, config['max_len']).to(device)

    model = trm.Transformer(config['n'],
                            config['n_vocab'],
                            config['d_model'],
                            config['d_k'],
                            config['d_v'],
                            config['n_head'],
                            config['d_ff'],
                            config['pad_token'],
                            config['max_len'],
                            config['dropout'],
                            device)

    train(model, e_x, d_x, t_x, device=device, path=config['save_path'])
    # To reuse a previously saved model instead of retraining:
    # model = torch.load(config['save_path'])

    # Greedy-decode the encoder inputs token by token until <end>.
    pred = model.greedy_decoder(e_x,
                                vocab['<sta>'],
                                vocab['<end>'],
                                vocab['<pad>'])
    print(utils.vec2sen(pred, vocab))
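    # If training converged on the toy corpus above, the decoded sentences
    # should match the targets, e.g. 'I love you <end>' for the first input.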