-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathtrain_TransCAM.py
107 lines (86 loc) · 4.48 KB
/
train_TransCAM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import numpy as np
import torch
import os
from torch import optim
from torch.utils.data import DataLoader
from torchvision import transforms
import voc12.data
from tool import pyutils, imutils
import argparse
import importlib
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
if __name__ == '__main__':
    # Train the TransCAM classifier on PASCAL VOC12 multi-label image-level
    # labels, starting from a pretrained Conformer backbone whose
    # classification heads are re-initialized.
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", default=8, type=int)
    parser.add_argument("--max_epoches", default=20, type=int)
    parser.add_argument("--network", default="network.conformer_CAM", type=str)
    parser.add_argument("--lr", default=5e-5, type=float)
    parser.add_argument("--num_workers", default=0, type=int)
    parser.add_argument("--wt_dec", default=5e-4, type=float)
    parser.add_argument("--train_list", default="voc12/train_aug.txt", type=str)
    parser.add_argument("--arch", default='sm', type=str)
    parser.add_argument("--val_list", default="voc12/val.txt", type=str)
    parser.add_argument("--session_name", default="TransCAM", type=str)
    parser.add_argument("--crop_size", default=512, type=int)
    parser.add_argument("--weights", required=True, type=str)
    parser.add_argument("--voc12_root", default='VOCdevkit/VOC2012', type=str)
    parser.add_argument("--tblog_dir", default='./tblog', type=str)
    parser.add_argument("--save_dir", default='./', type=str)
    args = parser.parse_args()

    pyutils.Logger(args.session_name + '.log')
    print(vars(args))

    # Resolve the model class dynamically, e.g. network.conformer_CAM.Net_sm.
    model = getattr(importlib.import_module(args.network), 'Net_' + args.arch)()
    print(model)

    # NOTE(review): the SummaryWriter is created but never written to in this
    # script — presumably kept for parity with sibling scripts; confirm.
    tblogger = SummaryWriter(args.tblog_dir)

    # PIL-based augmentations run first; np.asarray converts so the
    # imutils transforms (which expect ndarrays) can follow.
    train_dataset = voc12.data.VOC12ClsDataset(args.train_list, voc12_root=args.voc12_root,
                                               transform=transforms.Compose([
                                                   imutils.RandomResizeLong(320, 640),
                                                   transforms.RandomHorizontalFlip(),
                                                   transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3,
                                                                          hue=0.1),
                                                   np.asarray,
                                                   imutils.Normalize(),
                                                   imutils.RandomCrop(args.crop_size),
                                                   imutils.HWC_to_CHW,
                                                   torch.from_numpy
                                               ]))

    # Fix: honor --num_workers (the original hard-coded num_workers=0,
    # silently ignoring the flag).
    train_data_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                   shuffle=True, num_workers=args.num_workers,
                                   pin_memory=True, drop_last=True)

    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wt_dec, eps=1e-8)

    # Load pretrained weights; some checkpoints wrap the state dict in a
    # 'model' key, plain state dicts are used as-is.
    checkpoint = torch.load(args.weights, map_location='cpu')
    if 'model' in checkpoint:
        checkpoint = checkpoint['model']

    model_dict = model.state_dict()
    # Drop both classifier heads from the pretrained weights so they are
    # trained from scratch for the VOC class count. Guarded membership test
    # avoids a KeyError on checkpoints that lack these keys.
    for k in ('trans_cls_head.weight', 'trans_cls_head.bias',
              'conv_cls_head.weight', 'conv_cls_head.bias'):
        if k in checkpoint:
            print(f"Removing key {k} from pretrained checkpoint")
            del checkpoint[k]

    # Keep only keys the current model actually has, then load the merged dict.
    pretrained_dict = {k: v for k, v in checkpoint.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)

    model = torch.nn.DataParallel(model).cuda()
    model.train()

    avg_meter = pyutils.AverageMeter('loss')

    for ep in range(args.max_epoches):
        for step, pack in enumerate(train_data_loader):
            _, img, label = pack
            N, C, H, W = img.size()

            # Prepend an always-on background channel, then broadcast the
            # label to NxCx1x1 so it matches the spatialized logits below.
            bg_score = torch.ones((N, 1))
            label = torch.cat((bg_score, label), dim=1)
            label = label.cuda(non_blocking=True).unsqueeze(2).unsqueeze(3)

            logits_conv, logits_trans, cams = model('transcam', img)

            # Multi-label loss on the summed conv+transformer logits,
            # foreground channels only (channel 0 is background).
            loss = F.multilabel_soft_margin_loss(
                (logits_conv + logits_trans).unsqueeze(2).unsqueeze(3)[:, 1:, :, :],
                label[:, 1:, :, :])

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            avg_meter.add({'loss': loss.item()})

        # Per-epoch summary. (The original used a no-break for/else, whose
        # else branch always ran — equivalent to running after the loop.)
        print('epoch: %5d' % ep,
              'loss: %.4f' % avg_meter.get('loss'), flush=True)
        avg_meter.pop()

    # DataParallel wraps the model; save the underlying module's weights.
    torch.save(model.module.state_dict(),
               os.path.join(args.save_dir, args.session_name + '.pth'))