import os

import torch
import torchvision
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from DBCNN import WPFolder
from PIL import Image

torch.manual_seed(0)
torch.cuda.manual_seed_all(0)


def pil_loader(path):
    # Open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)


IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']
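# The loader and extension list above follow torchvision's folder-dataset
# convention; WPFolder (used below) is assumed to expose a compatible API.
# A minimal sketch with the stock torchvision.datasets.DatasetFolder, where
# 'data_root' is only an illustrative path:
#
#     from torchvision.datasets import DatasetFolder
#     dataset = DatasetFolder(root='data_root', loader=default_loader,
#                             extensions=tuple(IMG_EXTENSIONS))
#     img, label = dataset[0]  # PIL image and integer class index
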
def weight_init(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_normal_(m.weight.data, nonlinearity='relu')
            m.bias.data.zero_()
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()


class SCNN(nn.Module):

    def __init__(self):
        """Declare all needed layers."""
        super(SCNN, self).__init__()

        # Linear classifier.
        self.num_class = 39
        # self.features = nn.Sequential(nn.Conv2d(3,48,3,1,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(48,48,3,2,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(48,64,3,1,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(64,64,3,2,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(64,64,3,1,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(64,64,3,2,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(64,128,3,1,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(128,128,3,1,1),nn.ReLU(inplace=True),
        #                               nn.Conv2d(128,128,3,2,1),nn.ReLU(inplace=True))
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, 3, 1, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True),
            nn.Conv2d(48, 48, 3, 2, 1), nn.BatchNorm2d(48), nn.ReLU(inplace=True),
            nn.Conv2d(48, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 2, 1), nn.BatchNorm2d(64), nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 2, 1), nn.BatchNorm2d(128), nn.ReLU(inplace=True))
        weight_init(self.features)
        self.pooling = nn.AvgPool2d(14, 1)
        self.projection = nn.Sequential(
            nn.Conv2d(128, 256, 1, 1, 0), nn.BatchNorm2d(256), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 1, 1, 0), nn.BatchNorm2d(256), nn.ReLU(inplace=True))
        weight_init(self.projection)
        self.classifier = nn.Linear(256, self.num_class)
        weight_init(self.classifier)

    def forward(self, X):
        N = X.size()[0]
        assert X.size() == (N, 3, 224, 224)
        X = self.features(X)
        assert X.size() == (N, 128, 14, 14)
        X = self.pooling(X)
        assert X.size() == (N, 128, 1, 1)
        X = self.projection(X)
        X = X.view(X.size(0), -1)
        X = self.classifier(X)
        assert X.size() == (N, self.num_class)
        return X
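
# A minimal smoke test of the SCNN classifier, assuming a CPU-only run (the
# training code below expects CUDA); shapes follow the asserts in forward():
#
#     model = SCNN()
#     dummy = torch.randn(2, 3, 224, 224)
#     logits = model(dummy)              # -> torch.Size([2, 39])
#     probs = F.softmax(logits, dim=1)   # class probabilities over 39 distortion classes
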
class SCNNManager(object):
    """Manager class to train S-CNN."""

    def __init__(self, options, path):
        """Prepare the network, criterion, solver, and data.

        Args:
            options, dict: Hyperparameters.
            path, dict: Paths to the image folder and the checkpoint directory.
        """
        print('Prepare the network and data.')
        self._options = options
        self._path = path
        self._epoch = 0
        # Network.
        network = SCNN()
        weight_init(network)
        # self._net = network.cuda()
        self._net = torch.nn.DataParallel(network).cuda()

        logspaced_LR = np.logspace(-1, -4, self._options['epochs'])
        # Load the model from disk.
        checkpoints_list = os.listdir(self._path['model'])
        if len(checkpoints_list) != 0:
            self._net.load_state_dict(torch.load(os.path.join(
                self._path['model'],
                '%s%s%s' % ('net_params', str(len(checkpoints_list) - 1), '.pkl'))))
            self._epoch = len(checkpoints_list)
            self._options['base_lr'] = logspaced_LR[len(checkpoints_list)]
            # self._net.load_state_dict(torch.load(self._path['model']))
        print(self._net)
        # Criterion.
        self._criterion = torch.nn.CrossEntropyLoss().cuda()
        # Solver.
        self._solver = torch.optim.SGD(
            self._net.parameters(), lr=self._options['base_lr'],
            momentum=0.9, weight_decay=self._options['weight_decay'])
        # self._solver = torch.optim.Adam(
        #     self._net.parameters(), lr=self._options['base_lr'],
        #     weight_decay=self._options['weight_decay'])
        lambda1 = lambda epoch: logspaced_LR[epoch]
        self._scheduler = torch.optim.lr_scheduler.LambdaLR(self._solver, lr_lambda=lambda1)

        train_transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(size=256),  # Let smaller edge match
            torchvision.transforms.RandomHorizontalFlip(),
            torchvision.transforms.RandomCrop(size=224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                             std=(0.229, 0.224, 0.225))
        ])
        test_transforms = torchvision.transforms.Compose([
            torchvision.transforms.Resize(size=256),
            torchvision.transforms.CenterCrop(size=224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                             std=(0.229, 0.224, 0.225))
        ])
        train_data = WPFolder.WPFolder(
            root=self._path['waterloo_pascal'], loader=default_loader, extensions=IMG_EXTENSIONS,
            transform=train_transforms, train=True, ratio=0.8)
        test_data = WPFolder.WPFolder(
            root=self._path['waterloo_pascal'], loader=default_loader, extensions=IMG_EXTENSIONS,
            transform=test_transforms, train=False, ratio=0.8)
        self._train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=self._options['batch_size'],
            shuffle=True, num_workers=0, pin_memory=True)
        self._test_loader = torch.utils.data.DataLoader(
            test_data, batch_size=self._options['batch_size'],
            shuffle=False, num_workers=0, pin_memory=True)

    def train(self):
        """Train the network."""
        print('Training.')
        best_acc = 0.0
        best_epoch = None
        print('Epoch\tTrain loss\tTrain acc\tTest acc')
        for t in range(self._epoch, self._options['epochs']):
            epoch_loss = []
            num_correct = 0.0
            num_total = 0.0
            batchindex = 0
            for X, y in self._train_loader:
                X = X.cuda()
                y = y.cuda(non_blocking=True)

                # Clear the existing gradients.
                self._solver.zero_grad()
                # Forward pass.
                score = self._net(X)
                loss = self._criterion(score, y.detach())
                epoch_loss.append(loss.item())

                # Prediction.
                _, prediction = torch.max(F.softmax(score.data, dim=1), 1)
                num_total += y.size(0)
                num_correct += torch.sum(prediction == y)
                # Backward pass.
                loss.backward()
                self._solver.step()
                batchindex = batchindex + 1
            print('%d epoch done' % (t + 1))
            train_acc = 100 * num_correct.float() / num_total
            if (t < 2) or (t > 20):
                with torch.no_grad():
                    test_acc = self._accuracy(self._test_loader)
                if test_acc > best_acc:
                    best_acc = test_acc
                    best_epoch = t + 1
                    print('*', end='')
                print('%d\t%4.3f\t\t%4.2f%%\t\t%4.2f%%' %
                      (t + 1, sum(epoch_loss) / len(epoch_loss), train_acc, test_acc))
            pwd = os.getcwd()
            modelpath = os.path.join(pwd, 'models', ('net_params' + str(t) + '.pkl'))
            torch.save(self._net.state_dict(), modelpath)
            self._scheduler.step(t)
        print('Best at epoch %d, test accuracy %f' % (best_epoch, best_acc))

    def _accuracy(self, data_loader):
        """Compute the train/test accuracy.

        Args:
            data_loader: Train/Test DataLoader.

        Returns:
            Train/Test accuracy in percentage.
        """
        self._net.eval()
        num_correct = 0.0
        num_total = 0.0
        batchindex = 0
        for X, y in data_loader:
            # Data.
            batchindex = batchindex + 1
            X = X.cuda()
            y = y.cuda(non_blocking=True)

            # Prediction.
            score = self._net(X)
            _, prediction = torch.max(score.data, 1)
            num_total += y.size(0)
            num_correct += torch.sum(prediction == y.data)
        self._net.train()  # Set the model back to training phase
        return 100 * num_correct.float() / num_total


def main():
    """The main function."""
    import argparse
    parser = argparse.ArgumentParser(
        description='Train DB-CNN for BIQA.')
    parser.add_argument('--base_lr', dest='base_lr', type=float, default=1e-1,
                        help='Base learning rate for training.')
    parser.add_argument('--batch_size', dest='batch_size', type=int,
                        default=128, help='Batch size.')
    parser.add_argument('--epochs', dest='epochs', type=int,
                        default=30, help='Epochs for training.')
    parser.add_argument('--weight_decay', dest='weight_decay', type=float,
                        default=5e-4, help='Weight decay.')

    args = parser.parse_args()
    if args.base_lr <= 0:
        raise AttributeError('--base_lr parameter must be > 0.')
    if args.batch_size <= 0:
        raise AttributeError('--batch_size parameter must be > 0.')
    if args.epochs < 0:
        raise AttributeError('--epochs parameter must be >= 0.')
    if args.weight_decay <= 0:
        raise AttributeError('--weight_decay parameter must be > 0.')

    options = {
        'base_lr': args.base_lr,
        'batch_size': args.batch_size,
        'epochs': args.epochs,
        'weight_decay': args.weight_decay,
    }

    path = {
        'waterloo_pascal': r'Z:\Waterloo\exploration_database_and_code\image',
        'model': r'D:\zwx_Project\dbcnn_pytorch\models'
    }

    manager = SCNNManager(options, path)
    # manager.getStat()
    manager.train()


if __name__ == '__main__':
    main()
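
# Example invocation, assuming this file is saved as SCNN.py (an assumption;
# the hard-coded paths above are machine-specific and must be edited first):
#
#     python SCNN.py --base_lr 0.1 --batch_size 128 --epochs 30 --weight_decay 5e-4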