"""
Main training script for the various architectures considered in the paper.
The BondiNet and AABondiNet architectures are trained on the custom block dataset, while the other architectures are trained on the 8x8 grid dataset.
Please note that the BondiNet-based architectures allow varying the stride of their first layer.
Authors:
Edoardo Daniele Cannas - edoardodanielecannas@polimi.it
Sara Mandelli - sara.mandelli@polimi.it
Paolo Bestagini - paolo.bestagini@polimi.it
Stefano Tubaro - stefano.tubaro@polimi.it
"""
# TODO: fix the data in the new folders for sharing
# --- Library imports
from tqdm import tqdm
import os
import argparse
import shutil
import warnings
import numpy as np
import torch
torch.manual_seed(21)
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')
from isplutils.data import CustomBlockJPEGBalancedDataset, JPEG8x8BalancedDataset, balanced_collate_fn
from albumentations.pytorch import ToTensorV2
import albumentations as A
from architectures.fornet import create_model
from architectures.utils import save_model, batch_forward
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import SubsetRandomSampler, DataLoader
from isplutils.utils import make_train_tag
import sys
# TODO: debug the new dataloader
# --- Main script
def main(args):
# --- Parse arguments
gpu = args.gpu
model_name = args.model
batch_size = args.batch_size
lr = args.lr
min_lr = args.min_lr
es_patience = args.es_patience
sched_patience = args.sched_patience
init_period = args.init_period
epochs = args.epochs
workers = args.workers
p_train_val = args.perc_train_val
p_train_test = args.perc_train_test
weights_folder = args.models_dir
logs_folder = args.log_dir
data_root = args.data_root
debug = args.debug
suffix = args.suffix
initial_model = args.init
train_from_scratch = args.scratch
in_channels = 1 if args.grayscale else 3
fl_stride = args.first_layer_stride
aa_pool_only = args.aa_pool_only
    jpeg_bs = args.jpeg_bs if 'BondiNet' in model_name else 8  # Block size fixed to 8 for the SOTA models, user-specified for the BondiNet models
random_crop = args.random_crop
if (fl_stride is not None) and ('BondiNet' in model_name):
# For the BondiNet models, the patch size must be a multiple of 64*fl_stride
# PLEASE NOTE that the Dataset will then adjust the patch size to be aligned to the JPEG grid!
patch_size = 64*fl_stride
else:
# For the SOTA models, we are considering a 224x224 patch size
patch_size = 224
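    # With the default first_layer_stride of 2, the BondiNet-type models therefore train on
    # 128x128 patches (64*2), while the SOTA models always use 224x224 patches.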
# --- GPU configuration
device = 'cuda:{}'.format(gpu) if torch.cuda.is_available() else 'cpu'
# --- Instantiate network
params = {'in_channels': in_channels, 'num_classes': 1, 'first_layer_stride': fl_stride,
'pool_only': aa_pool_only} # Create a dictionary of parameters
net = create_model(model_name, params, device) # call the factory function to create the model
# --- Instantiate Dataset and DataLoader
if 'BondiNet' in model_name:
# Using the custom block dataset for the BondiNet experiments
transforms = [A.RandomCrop(patch_size, patch_size), ToTensorV2()] if random_crop else [ToTensorV2()]
transforms = A.Compose(transforms)
dataset = CustomBlockJPEGBalancedDataset(data_root=data_root, patch_size=patch_size, transforms=transforms,
grayscale=args.grayscale, jpeg_bs=jpeg_bs)
else:
# Using the 8x8 grid dataset for the other experiments
net_normalizer = net.get_normalizer()
transforms = [A.Normalize(mean=net_normalizer.mean, std=net_normalizer.std), ToTensorV2()]
transforms = A.Compose(transforms)
dataset = JPEG8x8BalancedDataset(data_root=data_root, patch_size=patch_size, transforms=transforms,
grayscale=args.grayscale, disaligned_grid_patch=random_crop, )
# Split in training and validation
dataset_idxs = list(range(len(dataset)))
np.random.seed(args.split_seed) # setting the seed for training-val split
np.random.shuffle(dataset_idxs)
test_split_index = int(np.floor((1 - p_train_test) * len(dataset)))
train_val_idxs, test_idxs = dataset_idxs[test_split_index:], dataset_idxs[:test_split_index]
val_split_index = int(np.floor((1 - p_train_val) * len(train_val_idxs)))
train_idx, val_idx = train_val_idxs[val_split_index:], train_val_idxs[:val_split_index]
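    # With the default fractions (perc_train_test=0.75, perc_train_val=0.75) this assigns roughly
    # 56% of the samples to training, 19% to validation and 25% to testing.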
# --- Create Samplers
train_sampler = SubsetRandomSampler(train_idx)
val_sampler = SubsetRandomSampler(val_idx)
# --- Create DataLoaders
train_dl = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=workers, shuffle=False, drop_last=True,
sampler=train_sampler, collate_fn=balanced_collate_fn,)
val_dl = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=workers, shuffle=False, drop_last=True,
sampler=val_sampler, collate_fn=balanced_collate_fn,)
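    # shuffle must stay False here: DataLoader forbids shuffle=True together with a custom sampler,
    # and the SubsetRandomSampler already draws the indices in random order.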
# --- DEBUG DATASET --- #
if debug:
for batch_data in tqdm(train_dl, desc='Training loader', leave=False, total=len(train_dl)):
img, label = batch_data
# --- Optimization
optimizer = torch.optim.Adam(net.get_trainable_parameters(), lr=lr)
criterion = torch.nn.BCEWithLogitsLoss()
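    # Scheduler selection: ReduceLROnPlateau when a plateau patience is given on the command line,
    # otherwise CosineAnnealingWarmRestarts with the requested initial period and minimum LR.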
if sched_patience is not None:
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='min',
patience=sched_patience, verbose=True)
else:
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer,
T_0=init_period, eta_min=min_lr,
verbose=True)
# --- Checkpoint paths
train_tag = make_train_tag(net_class=model_name, lr=lr, batch_size=batch_size, p_train_val=p_train_val,
p_train_test=p_train_test, split_seed=args.split_seed, suffix=suffix, debug=debug,
in_channels=in_channels, init_period=init_period,
jpeg_bs=jpeg_bs, random_crop=random_crop,
fl_stride=fl_stride if 'BondiNet' in model_name else None,
aa_pool_only=aa_pool_only if 'AA' in model_name else None)
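    # The tag encodes the main hyper-parameters, so every configuration gets its own checkpoint and log folder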
os.makedirs(os.path.join(weights_folder, train_tag), exist_ok=True)
bestval_path = os.path.join(weights_folder, train_tag, 'bestval.pth')
last_path = os.path.join(weights_folder, train_tag, 'last.pth')
# --- Load model from checkpoint
    min_val_loss = float('inf')  # any real validation loss will improve on this
epoch = 0
net_state = None
opt_state = None
if initial_model is not None:
# If given load initial model
        print('Loading model from: {}'.format(initial_model))
state = torch.load(initial_model, map_location='cpu')
net_state = state['net']
elif not train_from_scratch and os.path.exists(last_path):
        print('Loading model from: {}'.format(last_path))
state = torch.load(last_path, map_location='cpu')
net_state = state['net']
opt_state = state['opt']
epoch = state['epoch']
if not train_from_scratch and os.path.exists(bestval_path):
state = torch.load(bestval_path, map_location='cpu')
min_val_loss = state['val_loss']
if net_state is not None:
incomp_keys = net.load_state_dict(net_state, strict=False)
print(incomp_keys)
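    # Restore the optimizer state, overriding the learning rate stored in the checkpoint
    # with the one requested on the command line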
if opt_state is not None:
for param_group in opt_state['param_groups']:
param_group['lr'] = lr
optimizer.load_state_dict(opt_state)
# --- Initialize Tensorboard
logdir = os.path.join(logs_folder, train_tag)
if epoch == 0:
        # If training from scratch or from an initialization file, remove any existing history
shutil.rmtree(logdir, ignore_errors=True)
# --- Tensorboard instance
tb = SummaryWriter(log_dir=logdir)
if epoch == 0:
        # Align the dummy input to the JPEG grid when random cropping is disabled, as the Dataset does
        patch_size = patch_size if random_crop else jpeg_bs * round(patch_size / jpeg_bs)
        dummy = torch.randn((1, in_channels if 'BondiNet' in model_name else 3, patch_size, patch_size), device=device)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
# Dry-run first
net(dummy)
            # Adding the model graph to TensorBoard is currently disabled
#tb.add_graph(net, [dummy, ], verbose=False)
# --- Training-validation loop
train_tot_it = 0
val_tot_it = 0
es_counter = 0
cur_lr = lr
epochs = 1 if debug else epochs
train_len = len(train_dl)
    for e in range(epoch, epochs):  # resume from the checkpoint epoch, if one was loaded
# Training
net.train()
optimizer.zero_grad()
train_loss = train_acc = train_num = 0
for batch_idx, batch_data in enumerate(tqdm(train_dl, desc='Training epoch {}'.format(e), leave=False, total=len(train_dl))):
# Fetch data
batch_img, batch_label = batch_data
# Forward pass
batch_loss, batch_acc = batch_forward(net, device, criterion, batch_img, batch_label)
# Backpropagation
batch_loss.backward()
optimizer.step()
optimizer.zero_grad()
# Statistics
batch_num = len(batch_label)
train_num += batch_num
train_tot_it += batch_num
train_loss += batch_loss.item() * batch_num
train_acc += batch_acc.item() * batch_num
# Iteration logging
tb.add_scalar('train/it-loss', batch_loss.item(), train_tot_it)
tb.add_scalar('train/it-acc', batch_acc.item(), train_tot_it)
# Validation
net.eval()
val_loss = val_acc = val_num = 0
for batch_data in tqdm(val_dl, desc='Validating epoch {}'.format(e), leave=False, total=len(val_dl)):
# Fetch data
batch_img, batch_label = batch_data
with torch.no_grad():
# Forward pass
batch_loss, batch_acc = batch_forward(net, device, criterion, batch_img, batch_label)
# Statistics
batch_num = len(batch_label)
val_num += batch_num
val_tot_it += batch_num
val_loss += batch_loss.item() * batch_num
val_acc += batch_acc.item() * batch_num
# Iteration logging
tb.add_scalar('validation/it-loss', batch_loss.item(), val_tot_it)
tb.add_scalar('validation/it-acc', batch_acc.item(), val_tot_it)
print('\nEpoch {}:\nTraining loss:{:.4f}, acc:{:.4f}\nValidation loss:{:.4f}, acc:{:.4f}'
.format(e, train_loss / train_num, train_acc / train_num, val_loss / val_num, val_acc / val_num))
# Logging
train_loss /= train_num
train_acc /= train_num
val_loss /= val_num
val_acc /= val_num
tb.add_scalar('lr', optimizer.param_groups[0]['lr'], e)
tb.add_scalar('train/epoch-loss', train_loss, e)
tb.add_scalar('train/epoch-accuracy', train_acc, e)
tb.add_scalar('validation/epoch-loss', val_loss, e)
tb.add_scalar('validation/epoch-accuracy', val_acc, e)
tb.flush()
# Scheduler step
if isinstance(lr_scheduler, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts):
# If it's CosineAnnealingWarmRestarts, pass the epoch fraction
lr_scheduler.step(e + batch_idx / train_len)
else:
# Otherwise, pass the validation loss
lr_scheduler.step(val_loss)
# Epoch checkpoint
        save_model(net, optimizer, train_loss, val_loss, batch_size, e + 1, last_path)  # store the number of completed epochs
if val_loss < min_val_loss:
min_val_loss = val_loss
            save_model(net, optimizer, train_loss, val_loss, batch_size, e + 1, bestval_path)
es_counter = 0
else:
es_counter += 1
if optimizer.param_groups[0]['lr'] <= min_lr:
print('Reached minimum learning rate. Stopping.')
break
elif es_counter == es_patience:
print('Early stopping patience reached. Stopping.')
break
# Needed to flush out last events for the logger
tb.close()
print('Training completed! Bye!')
if __name__ == '__main__':
# --- Parse arguments
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--model', help='Model name', type=str, default='BondiNet',
                        choices=['BondiNet', 'AABondiNet', 'DenseNet121', 'AADenseNet121', 'ResNet50', 'AAResNet50',])
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--min_lr', type=float, default=1e-8)
parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--es_patience', type=int, default=10,
                        help='Patience (in epochs) for stopping the training if no improvement '
                             'on the validation loss is seen')
# Add mutually exclusive group for the two schedulers
scheduler_args = parser.add_mutually_exclusive_group(required=False)
scheduler_args.add_argument('--init_period', type=int, default=10, help='Period for the CosineAnnealingWarmRestart')
scheduler_args.add_argument('--sched_patience', type=int, default=None, help='Patience for the ReduceLROnPlateau scheduler')
parser.add_argument('--workers', type=int, default=os.cpu_count() // 2)
    parser.add_argument('--perc_train_test', type=float, help='Fraction of the dataset used for training+validation (the rest is the test set)', default=0.75)
    parser.add_argument('--perc_train_val', type=float, help='Fraction of the train+validation set used for training (the rest is the validation set)', default=0.75)
parser.add_argument('--split_seed', type=int, help='Random seed for training/validation split', default=42)
parser.add_argument('--data_root', type=str, required=True,
help='Path to the folder containing the datasets')
parser.add_argument('--jpeg_bs', type=int, help='Block size for the JPEG compression', default=8)
parser.add_argument('--log_dir', type=str, help='Directory for saving the training logs',
default='./logs')
parser.add_argument('--models_dir', type=str, help='Directory for saving the models weights',
default='./models')
parser.add_argument('--init', type=str, help='Weight initialization file')
    parser.add_argument('--scratch', action='store_true',
                        help='Train the model from scratch instead of resuming from the last available checkpoint')
parser.add_argument('--debug', action='store_true', help='Activate debug')
parser.add_argument('--suffix', type=str, help='Suffix to default tag')
parser.add_argument('--grayscale', action='store_true', help='Whether to work on grayscale images or not')
parser.add_argument('--first_layer_stride', type=int,
help='Stride of the first layer for the BondiNet', default=2)
parser.add_argument('--aa_pool_only', action='store_true', help='Whether to use the BlurPool only in '
'the maxPool layers of the AABondiNets')
parser.add_argument('--random_crop', action='store_true', help='Whether to use random crop or not')
args = parser.parse_args()
# --- CALL MAIN FUNCTION --- #
    try:
        main(args)
    except Exception as e:
        print(e)
        sys.exit(1)
    # --- Exit the script
    sys.exit(0)