-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathoption_train_DistillationIQA.py
116 lines (97 loc) · 5.5 KB
/
option_train_DistillationIQA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
# import template
import argparse
import os
"""
Configuration file
"""
def check_args(args, rank=0):
    """Echo every parsed option to stdout and record it in ``args.setting_file``.

    Nothing happens unless ``rank`` is 0; in that case the settings file is
    opened in ``'w'`` mode (so it is overwritten on every call) and each
    attribute of ``args`` is written as ``name: value`` between a header and a
    footer banner, while the same lines are printed.

    Args:
        args: Parsed options (an ``argparse.Namespace`` or any object with a
            ``__dict__``); its ``setting_file`` attribute names the output file.
        rank: Only ``0`` triggers the dump (presumably the process rank when
            several training processes are launched — confirm with callers).

    Returns:
        The very same ``args`` object, unmodified.
    """
    if rank != 0:
        return args

    options = args.__dict__
    with open(args.setting_file, 'w') as settings:
        settings.write('------------ Options -------------\n')
        print('------------ Options -------------')
        for name, value in options.items():
            line = '%s: %s' % (str(name), str(value))
            settings.write(line + '\n')
            print(line)
        settings.write('-------------- End ----------------\n')
        # NOTE: the printed footer differs from the one written to the file;
        # kept as-is to preserve existing output.
        print('------------ End -------------')
    return args
def str2bool(v):
    """Convert a yes/no style command-line value to a ``bool``.

    Intended as an ``argparse`` ``type=`` converter. Matching is
    case-insensitive and ignores surrounding whitespace. A value that is
    already a ``bool`` is returned unchanged, so the function is also safe to
    call on non-string defaults or programmatic input.

    Args:
        v: ``'yes'``/``'true'``/``'t'``/``'y'``/``'1'`` for True,
            ``'no'``/``'false'``/``'f'``/``'n'``/``'0'`` for False (any case),
            or a ``bool``.

    Returns:
        The corresponding ``bool``.

    Raises:
        argparse.ArgumentTypeError: For any other string, so argparse reports
            a clean usage error instead of a traceback.
    """
    # Previously v.lower() raised AttributeError when v was already a bool.
    if isinstance(v, bool):
        return v
    lowered = v.strip().lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Unsupported value encountered.')
def _ensure_dir(path, purpose):
    """Create directory ``path`` (including parents) unless it already exists.

    Args:
        path: Directory to create.
        purpose: Phrase inserted into the error message, e.g.
            ``'checkpoint model saving'``.

    Raises:
        IOError: If ``path`` exists as a regular file.
    """
    # os.path.isfile() is already False for missing paths, so no exists() check.
    if os.path.isfile(path):
        raise IOError('Required dst path {} as a directory for {}, got a file'.format(
            path, purpose))
    if not os.path.exists(path):
        # exist_ok avoids a crash if the directory appears between the check
        # and the call (e.g. created by another process).
        os.makedirs(path, exist_ok=True)
        print('%s created successfully!' % path)


def set_args(argv=None):
    """Parse the DistillationIQA training options and prepare output folders.

    Args:
        argv: Optional list of command-line tokens to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]`` as before.

    Returns:
        ``argparse.Namespace`` with the parsed options plus derived paths:
        ``setting_file`` (re-rooted under ``checkpoint_dir``),
        ``ref_train_dataset_path`` / ``ref_test_dataset_path``, and
        ``model_checkpoint_dir`` / ``result_checkpoint_dir`` /
        ``log_checkpoint_dir`` (each ending with a trailing slash).

    Side effects:
        Creates ``./dataset/`` and the checkpoint directory tree if missing,
        printing a message for each checkpoint directory it creates.

    Raises:
        IOError: If a checkpoint directory path exists as a regular file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu_ids', type=str, default='0')
    parser.add_argument('--test_dataset', type=str, default='live', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k')
    parser.add_argument('--train_dataset', type=str, default='kadid10k', help='Support datasets: pipal|livec|koniq-10k|bid|live|csiq|tid2013|kadid10k')
    parser.add_argument('--train_patch_num', type=int, default=1, help='Number of sample patches from training image')
    parser.add_argument('--test_patch_num', type=int, default=1, help='Number of sample patches from testing image')
    parser.add_argument('--lr', dest='lr', type=float, default=2e-5, help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=5e-4, help='Weight decay')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--epochs', type=int, default=100, help='Epochs for training')
    parser.add_argument('--patch_size', type=int, default=224, help='Crop size for training & testing image patches')
    parser.add_argument('--self_patch_num', type=int, default=10, help='number of training & testing image self patches')
    parser.add_argument('--train_test_num', type=int, default=1, help='Train-test times')
    parser.add_argument('--update_opt_epoch', type=int, default=30)
    # Reference images
    parser.add_argument('--use_refHQ', type=str2bool, default=True)
    parser.add_argument('--distillation_layer', type=int, default=18, help='last xth layers of HQ-MLP for distillation')
    parser.add_argument('--net_print', type=int, default=2000)
    parser.add_argument('--setting_file', type=str, default='setting.txt')
    parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_DistillationIQA/')
    parser.add_argument('--use_fitting_prcc_srcc', type=str2bool, default=True)
    parser.add_argument('--print_netC', type=str2bool, default=False)
    parser.add_argument('--teacherNet_model_path', type=str, default='./model_zoo/FR_teacher_cross_dataset.pth')
    parser.add_argument('--studentNet_model_path', type=str, default=None, help='./model_zoo/NAR_student_cross_dataset.pth')
    # Distillation
    parser.add_argument('--distillation_loss', type=str, default='l1', help='mse|l1|kldiv')
    args = parser.parse_args(argv)

    # Dataset / settings paths
    args.setting_file = os.path.join(args.checkpoint_dir, args.setting_file)
    os.makedirs('./dataset/', exist_ok=True)
    ref_dataset_path = './dataset/DIV2K_ref/'
    args.ref_train_dataset_path = os.path.join(ref_dataset_path, 'train_HR/')
    args.ref_test_dataset_path = os.path.join(ref_dataset_path, 'val_HR/')

    # Checkpoint sub-directories. os.path.join (not '+') so a --checkpoint_dir
    # given without a trailing slash still yields '<dir>/models/', not
    # '<dir>models/'. The trailing slash is kept for callers that append
    # file names directly.
    args.model_checkpoint_dir = os.path.join(args.checkpoint_dir, 'models/')
    args.result_checkpoint_dir = os.path.join(args.checkpoint_dir, 'results/')
    args.log_checkpoint_dir = os.path.join(args.checkpoint_dir, 'log/')

    # Parent first, so its "created" message is printed before the children.
    _ensure_dir(args.checkpoint_dir, 'checkpoint saving')
    _ensure_dir(args.model_checkpoint_dir, 'checkpoint model saving')
    _ensure_dir(args.result_checkpoint_dir, 'checkpoint results saving')
    _ensure_dir(args.log_checkpoint_dir, 'checkpoint log saving')
    return args
if __name__ == '__main__':
    # Script entry point: parse the command line and prepare the dataset and
    # checkpoint directories; the parsed namespace is bound to module-level `args`.
    args = set_args()