-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcl_pnn.py
117 lines (108 loc) · 5.66 KB
/
cl_pnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""
Continuous learning with MLP-PNN.
Reference: Progressive Neural Networks (Google DeepMind)
@author huangyz0918
@date 16/09/2021
"""
import time
import torch
import neptune
import argparse
import numpy as np
from model import MLP_PNN, parameter_number
from model import Trainer, Evaluator, get_dataloader_keyword
def _build_arg_parser():
    """Create the command-line parser for the PNN continual-learning run.

    Returns:
        argparse.ArgumentParser: parser whose parsed namespace is read by this
        script (``log``, ``ek``, ``lc``, ``lr``, ``batch``, ``dpath``) and is
        also handed as a whole to ``Trainer``.
    """
    parser = argparse.ArgumentParser(description="Input optional guidance for training")
    parser.add_argument("--epoch", default=10, type=int, help="The number of training epoch")
    parser.add_argument("--lc", default=False, action='store_true',
                        help="Test on the task with/without lateral connections")
    parser.add_argument("--lr", default=0.001, type=float, help="Learning rate")
    parser.add_argument("--batch", default=128, type=int, help="Training batch size")
    parser.add_argument("--step", default=30, type=int, help="Training step size")
    parser.add_argument("--gpu", default=4, type=int, help="Number of GPU device")
    parser.add_argument("--log", default=False, action='store_true',
                        help="record the experiment into web neptune.ai")
    parser.add_argument("--ek", default=False, action='store_true', help="evaluate the CL by keywords")
    parser.add_argument("--dpath", default="./dataset", type=str, help="The path of dataset")
    parser.add_argument("--tqdm", default=False, action='store_true', help="enable terminal tqdm output.")
    parser.add_argument("--freq", default=30, type=int, help="Model saving frequency (in step)")
    parser.add_argument("--save", default="stft", type=str, help="The save name")
    return parser


def main():
    """Train a progressive network task by task and print summary metrics.

    For every keyword task, in order: a new column is added to the PNN (the
    first column is added before the loop), the model is trained on that task,
    and then it is evaluated on every task seen so far. After all tasks the
    script prints the total/average wall time, the mean and std of the
    per-task average accuracy (ACC), of the accuracy on the task just learned
    (LA), of the running mean of consecutive ACC differences (BWT), and the
    parameter count of the final model.
    """
    parameters = _build_arg_parser().parse_args()

    # One entry per task, learned in this order. Task 0 is the large base
    # task; each later task contributes three new keywords.
    learning_tasks = [
        ["yes", "no", "nine",
         "three", "bed", "up",
         "down", "wow", "happy",
         "four", "stop", "go",
         "dog", "cat", "five"],
        ["tree", "one", "eight"],
        ["left", "right", "bird"],
        ["seven", "six", "two"],
        ["marvin", "on", "sheila"],
        ["off", "house", "zero"],
    ]
    # Flat list of every keyword across all tasks, in task order.
    class_list = [keyword for task in learning_tasks for keyword in task]

    # initialize and setup Neptune
    if parameters.log:
        neptune.init('huangyz0918/kws')
        neptune.create_experiment(name='kws_model', tags=['pytorch', 'KWS', 'GSC', 'PNN'],
                                  params=vars(parameters))

    # initializing the PNN model; the meaning of the three positional sizes is
    # defined by MLP_PNN, which is not visible from this script.
    model = MLP_PNN(256, 129, 129 * 125)
    # start continuous learning.
    model.add_column(len(learning_tasks[0]))  # add the first column for the PNN.
    trainer = Trainer(parameters, model)

    la_list = []   # accuracy on the task that was just learned
    acc_list = []  # average accuracy over all tasks learned so far
    bwt_list = []  # mean of consecutive acc_list differences, one entry per task after the first
    learned_class_list = []

    # Optional keyword arguments are only passed when the matching flag is set,
    # so the callee defaults still apply otherwise.
    lateral_kwargs = {"with_lateral_con": True} if parameters.lc else {}
    # NOTE(review): with --ek, evaluate_list is the flat keyword list, so
    # evaluate_list[val_id] below is a single keyword string: the encoding is
    # then built over its characters and only the first task_id + 1 keywords
    # are visited. Confirm this matches what get_dataloader_keyword expects.
    evaluate_list = class_list if parameters.ek else learning_tasks

    start_time = time.time()
    for task_id, task_class in enumerate(learning_tasks):
        print(">>> Learned Class: ", learned_class_list, " To Learn: ", task_class)
        learned_class_list += task_class
        class_encoding = {category: index for index, category in enumerate(task_class)}
        train_loader, test_loader = get_dataloader_keyword(parameters.dpath, task_class, class_encoding,
                                                           parameters.batch)
        # smaller column sizes from 2nd task onwards to limit expansion.
        if task_id > 0:
            trainer.model.add_column(len(task_class), hsize=32)
        # A fresh optimizer is built for every task so that it covers the
        # parameters of the column that was just added.
        optimizer = torch.optim.SGD(model.parameters(), lr=parameters.lr, momentum=0.9)
        # fine-tune the whole model.
        train_kwargs = {"tag": f'task{task_id}'} if parameters.log else {}
        trainer.model_train(task_id, optimizer, train_loader, test_loader, is_pnn=True, **train_kwargs)

        # start evaluating the CL on previous tasks.
        total_learned_acc = 0
        for val_id in range(task_id + 1):
            val_keywords = evaluate_list[val_id]
            print(f">>> Testing on task {val_id}, Keywords: {val_keywords}")
            test_encoding = {category: index for index, category in enumerate(val_keywords)}
            _, val_loader = get_dataloader_keyword(parameters.dpath, val_keywords, test_encoding,
                                                   parameters.batch)
            evaluator_kwargs = {"tag": f't{task_id}v{val_id}'} if parameters.log else {}
            evaluator = Evaluator(trainer.model, **evaluator_kwargs)
            log_data = evaluator.pnn_evaluate(val_id, val_loader, **lateral_kwargs)
            accuracy = log_data["test_accuracy"]
            if parameters.log:
                neptune.log_metric(f'TASK-{task_id}-acc', accuracy)
            # Every val_id in this loop is a learned task, so all of them count.
            total_learned_acc += accuracy
            if val_id == task_id:
                la_list.append(accuracy)

        average_acc = total_learned_acc / (task_id + 1)
        acc_list.append(average_acc)
        print(f'ACC on TASK {task_id}: {average_acc}')
        if task_id > 0:
            bwt_list.append(np.mean([later - earlier
                                     for earlier, later in zip(acc_list, acc_list[1:])]))

    duration = time.time() - start_time
    print(f'Total time {duration}, Avg: {duration / len(learning_tasks)}s')
    print(f'ACC: {np.mean(acc_list)}, std: {np.std(acc_list)}')
    print(f'LA: {np.mean(la_list)}, std: {np.std(la_list)}')
    print(f'BWT: {np.mean(bwt_list)}, std: {np.std(bwt_list)}')
    print(f'Parameter: {parameter_number(trainer.model) / 1024} K')


if __name__ == "__main__":
    main()