Skip to content

Commit f0b27bb

Browse files
authored
Add files via upload
1 parent aaf6c50 commit f0b27bb

18 files changed

+3477
-1
lines changed

DBCNN/DBCNN_train_attack.py

+402
Large diffs are not rendered by default.

DBCNN/SCNN.py

+289
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
import os
2+
3+
import torch
4+
import torchvision
5+
import torch.nn.functional as F
6+
import torch.nn as nn
7+
import numpy as np
8+
import DBCNN.WPFolder
9+
from PIL import Image
10+
11+
torch.manual_seed(0)
12+
torch.cuda.manual_seed_all(0)
13+
14+
15+
def pil_loader(path):
16+
# open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
17+
with open(path, 'rb') as f:
18+
img = Image.open(f)
19+
return img.convert('RGB')
20+
21+
22+
def accimage_loader(path):
23+
import accimage
24+
try:
25+
return accimage.Image(path)
26+
except IOError:
27+
# Potentially a decoding problem, fall back to PIL.Image
28+
return pil_loader(path)
29+
30+
31+
def default_loader(path):
32+
from torchvision import get_image_backend
33+
if get_image_backend() == 'accimage':
34+
return accimage_loader(path)
35+
else:
36+
return pil_loader(path)
37+
38+
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
39+
40+
def weight_init(net):
41+
for m in net.modules():
42+
if isinstance(m, nn.Conv2d):
43+
nn.init.kaiming_normal_(m.weight.data,nonlinearity='relu')
44+
m.bias.data.zero_()
45+
elif isinstance(m, nn.Linear):
46+
nn.init.kaiming_normal_(m.weight.data,nonlinearity='relu')
47+
m.bias.data.zero_()
48+
elif isinstance(m, nn.BatchNorm2d):
49+
m.weight.data.fill_(1)
50+
m.bias.data.zero_()
51+
52+
53+
54+
class SCNN(nn.Module):
55+
56+
def __init__(self):
57+
"""Declare all needed layers."""
58+
super(SCNN, self).__init__()
59+
60+
# Linear classifier.
61+
62+
self.num_class = 39
63+
# self.features = nn.Sequential(nn.Conv2d(3,48,3,1,1),nn.ReLU(inplace=True),
64+
# nn.Conv2d(48,48,3,2,1),nn.ReLU(inplace=True),
65+
# nn.Conv2d(48,64,3,1,1),nn.ReLU(inplace=True),
66+
# nn.Conv2d(64,64,3,2,1),nn.ReLU(inplace=True),
67+
# nn.Conv2d(64,64,3,1,1),nn.ReLU(inplace=True),
68+
# nn.Conv2d(64,64,3,2,1),nn.ReLU(inplace=True),
69+
# nn.Conv2d(64,128,3,1,1),nn.ReLU(inplace=True),
70+
# nn.Conv2d(128,128,3,1,1),nn.ReLU(inplace=True),
71+
# nn.Conv2d(128,128,3,2,1),nn.ReLU(inplace=True))
72+
self.features = nn.Sequential(nn.Conv2d(3,48,3,1,1),nn.BatchNorm2d(48),nn.ReLU(inplace=True),
73+
nn.Conv2d(48,48,3,2,1),nn.BatchNorm2d(48),nn.ReLU(inplace=True),
74+
nn.Conv2d(48,64,3,1,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),
75+
nn.Conv2d(64,64,3,2,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),
76+
nn.Conv2d(64,64,3,1,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),
77+
nn.Conv2d(64,64,3,2,1),nn.BatchNorm2d(64),nn.ReLU(inplace=True),
78+
nn.Conv2d(64,128,3,1,1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),
79+
nn.Conv2d(128,128,3,1,1),nn.BatchNorm2d(128),nn.ReLU(inplace=True),
80+
nn.Conv2d(128,128,3,2,1),nn.BatchNorm2d(128),nn.ReLU(inplace=True))
81+
weight_init(self.features)
82+
self.pooling = nn.AvgPool2d(14,1)
83+
self.projection = nn.Sequential(nn.Conv2d(128,256,1,1,0), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
84+
nn.Conv2d(256,256,1,1,0), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
85+
weight_init(self.projection)
86+
self.classifier = nn.Linear(256,self.num_class)
87+
weight_init(self.classifier)
88+
89+
def forward(self, X):
90+
# return X
91+
N = X.size()[0]
92+
assert X.size() == (N, 3, 224, 224)
93+
X = self.features(X)
94+
assert X.size() == (N, 128, 14, 14)
95+
X = self.pooling(X)
96+
assert X.size() == (N, 128, 1, 1)
97+
X = self.projection(X)
98+
X = X.view(X.size(0), -1)
99+
X = self.classifier(X)
100+
assert X.size() == (N, self.num_class)
101+
return X
102+
103+
class SCNNManager(object):
104+
"""Manager class to train S-CNN.
105+
"""
106+
def __init__(self, options, path):
107+
"""Prepare the network, criterion, solver, and data.
108+
Args:
109+
options, dict: Hyperparameters.
110+
"""
111+
print('Prepare the network and data.')
112+
self._options = options
113+
self._path = path
114+
self._epoch = 0
115+
# Network.
116+
network = SCNN()
117+
weight_init(network)
118+
#self._net = network.cuda()
119+
self._net = torch.nn.DataParallel(network).cuda()
120+
121+
logspaced_LR = np.logspace(-1,-4, self._options['epochs'])
122+
# Load the model from disk.
123+
checkpoints_list = os.listdir(self._path['model'])
124+
if len(checkpoints_list) != 0:
125+
self._net.load_state_dict(torch.load(os.path.join(self._path['model'],'%s%s%s' % ('net_params', str(len(checkpoints_list)-1), '.pkl'))))
126+
self._epoch = len(checkpoints_list)
127+
self._options['base_lr'] = logspaced_LR[len(checkpoints_list)]
128+
#self._net.load_state_dict(torch.load(self._path['model']))
129+
print(self._net)
130+
# Criterion.
131+
self._criterion = torch.nn.CrossEntropyLoss().cuda()
132+
# Solver.
133+
self._solver = torch.optim.SGD(
134+
self._net.parameters(), lr=self._options['base_lr'],
135+
momentum=0.9, weight_decay=self._options['weight_decay'])
136+
# self._solver = torch.optim.Adam(
137+
# self._net.parameters(), lr=self._options['base_lr'],
138+
# weight_decay=self._options['weight_decay'])
139+
lambda1 = lambda epoch: logspaced_LR[epoch]
140+
self._scheduler = torch.optim.lr_scheduler.LambdaLR(self._solver,lr_lambda=lambda1)
141+
142+
train_transforms = torchvision.transforms.Compose([
143+
torchvision.transforms.Resize(size=256), # Let smaller edge match
144+
torchvision.transforms.RandomHorizontalFlip(),
145+
torchvision.transforms.RandomCrop(size=224),
146+
torchvision.transforms.ToTensor(),
147+
torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
148+
std=(0.229, 0.224, 0.225))
149+
])
150+
test_transforms = torchvision.transforms.Compose([
151+
torchvision.transforms.Resize(size=256),
152+
torchvision.transforms.CenterCrop(size=224),
153+
torchvision.transforms.ToTensor(),
154+
torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
155+
std=(0.229, 0.224, 0.225))
156+
])
157+
train_data = WPFolder.WPFolder(
158+
root=self._path['waterloo_pascal'], loader = default_loader, extensions = IMG_EXTENSIONS,
159+
transform=train_transforms,train = True, ratio = 0.8)
160+
test_data = WPFolder.WPFolder(
161+
root=self._path['waterloo_pascal'], loader = default_loader, extensions = IMG_EXTENSIONS,
162+
transform=test_transforms, train = False, ratio = 0.8)
163+
self._train_loader = torch.utils.data.DataLoader(
164+
train_data, batch_size=self._options['batch_size'],
165+
shuffle=True, num_workers=0, pin_memory=True)
166+
self._test_loader = torch.utils.data.DataLoader(
167+
test_data, batch_size=self._options['batch_size'],
168+
shuffle=False, num_workers=0, pin_memory=True)
169+
170+
def train(self):
171+
"""Train the network."""
172+
print('Training.')
173+
best_acc = 0.0
174+
best_epoch = None
175+
print('Epoch\tTrain loss\tTrain acc\tTest acc')
176+
for t in range(self._epoch,self._options['epochs']):
177+
epoch_loss = []
178+
num_correct = 0.0
179+
num_total = 0.0
180+
batchindex = 0
181+
for X, y in self._train_loader:
182+
X = torch.tensor(X.cuda())
183+
y = torch.tensor(y.cuda(non_blocking=True)) #async=True
184+
#y = torch.tensor(y.to(device))
185+
186+
# Clear the existing gradients.
187+
self._solver.zero_grad()
188+
# Forward pass.
189+
score = self._net(X)
190+
loss = self._criterion(score, y.detach())
191+
epoch_loss.append(loss.item())
192+
193+
# Prediction.
194+
_, prediction = torch.max(F.softmax(score.data), 1)
195+
num_total += y.size(0)
196+
num_correct += torch.sum(prediction == y)
197+
# Backward pass.
198+
loss.backward()
199+
self._solver.step()
200+
batchindex = batchindex + 1
201+
print('%d epoch done' % (t+1))
202+
train_acc = 100 * num_correct.float() / num_total
203+
if (t < 2) | (t > 20):
204+
with torch.no_grad():
205+
test_acc = self._accuracy(self._test_loader)
206+
if test_acc > best_acc:
207+
best_acc = test_acc
208+
best_epoch = t + 1
209+
print('*', end='')
210+
print('%d\t%4.3f\t\t%4.2f%%\t\t%4.2f%%' %
211+
(t+1, sum(epoch_loss) / len(epoch_loss), train_acc, test_acc))
212+
pwd = os.getcwd()
213+
modelpath = os.path.join(pwd,'models',('net_params' + str(t) + '.pkl'))
214+
torch.save(self._net.state_dict(), modelpath)
215+
self._scheduler.step(t)
216+
print('Best at epoch %d, test accuaray %f' % (best_epoch, best_acc))
217+
218+
def _accuracy(self, data_loader):
219+
"""Compute the train/test accuracy.
220+
Args:
221+
data_loader: Train/Test DataLoader.
222+
Returns:
223+
Train/Test accuracy in percentage.
224+
"""
225+
self._net.eval()
226+
num_correct = 0.0
227+
num_total = 0.0
228+
batchindex = 0
229+
for X, y in data_loader:
230+
# Data.
231+
batchindex = batchindex + 1
232+
X = torch.tensor(X.cuda())
233+
y = torch.tensor(y.cuda(non_blocking=True)) #async=True
234+
235+
# Prediction.
236+
score = self._net(X)
237+
_, prediction = torch.max(score.data, 1)
238+
num_total += y.size(0)
239+
num_correct += torch.sum(prediction == y.data)
240+
self._net.train() # Set the model to training phase
241+
return 100 * num_correct.float() / num_total
242+
243+
244+
245+
def main():
246+
"""The main function."""
247+
import argparse
248+
parser = argparse.ArgumentParser(
249+
description='Train DB-CNN for BIQA.')
250+
parser.add_argument('--base_lr', dest='base_lr', type=float, default=1e-1,
251+
help='Base learning rate for training.')
252+
parser.add_argument('--batch_size', dest='batch_size', type=int,
253+
default=128, help='Batch size.')
254+
parser.add_argument('--epochs', dest='epochs', type=int,
255+
default=30, help='Epochs for training.')
256+
parser.add_argument('--weight_decay', dest='weight_decay', type=float,
257+
default=5e-4, help='Weight decay.')
258+
259+
args = parser.parse_args()
260+
if args.base_lr <= 0:
261+
raise AttributeError('--base_lr parameter must >0.')
262+
if args.batch_size <= 0:
263+
raise AttributeError('--batch_size parameter must >0.')
264+
if args.epochs < 0:
265+
raise AttributeError('--epochs parameter must >=0.')
266+
if args.weight_decay <= 0:
267+
raise AttributeError('--weight_decay parameter must >0.')
268+
269+
270+
options = {
271+
'base_lr': args.base_lr,
272+
'batch_size': args.batch_size,
273+
'epochs': args.epochs,
274+
'weight_decay': args.weight_decay,
275+
}
276+
277+
278+
path = {
279+
'waterloo_pascal': 'Z:\Waterloo\exploration_database_and_code\image',
280+
'model': 'D:\zwx_Project\dbcnn_pytorch\models'
281+
}
282+
283+
manager = SCNNManager(options, path)
284+
# manager.getStat()
285+
manager.train()
286+
287+
288+
if __name__ == '__main__':
289+
main()

0 commit comments

Comments
 (0)