config.py
import argparse
import numpy as np
import os
import math
import cv2
import sys
import random
import tensorflow as tf
from tqdm import tqdm
from tensorflow.keras import Model
from src.metrics import mu_tonemap_tf
# operation phase: 'train' or 'validation'
op_phase = 'train'
# image mini-batch size
img_mini_b = 8
#########################################################################
# READ & WRITE DATA PATHS #
#########################################################################
# model checkpoint paths
path_best_model = 'model_legacy/att_39.14.h5'
path_save_model = 'weights/att.h5'
path_save_model_finetune = 'weights/att_finetune.h5'
path_save_model_mu = 'weights/att_mu.h5'
path_save_model_all_data = 'weights/att_all.h5'
# paths to read data
path_read_train = 'train_jpg_2/'
path_read_val_test = 'valid_jpg_2/'
#########################################################################
# NUMBER OF IMAGES IN THE TRAINING, VALIDATION, AND TESTING SETS #
#########################################################################
# file extensions that identify the short-exposure input frames
short_exts = ('_short.jpg', '_short.JPG', '_short.png', '_short.PNG', '_short.TIF')
if op_phase == 'train':
    total_nb_train = len([path_read_train + f for f
                          in os.listdir(path_read_train)
                          if f.endswith(short_exts)])
    total_nb_val = len([path_read_val_test + f for f
                        in os.listdir(path_read_val_test)
                        if f.endswith(short_exts)])
    # number of training image batches
    nb_train = int(math.ceil(total_nb_train / img_mini_b))
    # number of validation image batches
    nb_val = int(math.ceil(total_nb_val / img_mini_b))
elif op_phase == 'validation':
    total_nb_test = len([path_read_val_test + f for f
                         in os.listdir(path_read_val_test)
                         if f.endswith(short_exts)])
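# Illustrative example (numbers assumed, not from the original project): with 1000
# '_short' training frames and img_mini_b = 8, nb_train = ceil(1000 / 8) = 125 batches.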
#########################################################################
# MODEL PARAMETERS & TRAINING SETTINGS #
#########################################################################
# input image size
img_w = 1900
img_h = 1060
# input patch size
patch_w = 320
patch_h = 320
# number of epochs
nb_epoch = 300
# number of input channels
nb_ch_all = 6
# number of output channels
nb_ch = 3 # change conv9 in the model and the following variable
# number of epochs between learning-rate changes
scheduling_rate = 30
dropout_rate = 0.4
# generate learning rate array
lr_ = []
lr_.append(1e-4) # initial learning rate
for i in range(int(nb_epoch / scheduling_rate)):
    lr_.append(lr_[i] * 0.5)
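# Usage sketch (assumption; the training script is not part of this file): the lr_
# table could drive a per-epoch scheduler, holding each value for scheduling_rate epochs.
# lr_table_callback = tf.keras.callbacks.LearningRateScheduler(
#     lambda epoch: lr_[min(epoch // scheduling_rate, len(lr_) - 1)])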
train_set, val_set, test_set, comp_set = [], [], [], []
size_set, portrait_orientation_set = [], []
mse_list, psnr_list, ssim_list, mae_list = [], [], [], []
# Constants assumed for lrfn (they are not defined in this file); adjust to the training setup.
LR_START = 1e-5
LR_MAX = 1e-4
LR_MIN = 1e-6
LR_RAMPUP_EPOCHS = 5
LR_SUSTAIN_EPOCHS = 0
LR_EXP_DECAY = 0.8
def lrfn(epoch):
    # linear ramp-up, optional sustain at LR_MAX, then exponential decay
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        lr = LR_MAX
    else:
        lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY ** (epoch - LR_RAMPUP_EPOCHS
                                                  - LR_SUSTAIN_EPOCHS)
    return lr
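# Example usage (assumption, not part of this config file): lrfn can be passed to a
# Keras callback so the rate is recomputed at the start of every epoch.
# lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)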
def step_decay_schedule(epoch):
    '''
    Step-decay schedule: halve the learning rate every `step_size` epochs.
    Intended to be passed to tf.keras.callbacks.LearningRateScheduler.
    '''
    initial_lr = 1e-5
    decay_factor = 0.5
    step_size = 2
    return initial_lr * (decay_factor ** np.floor(epoch / step_size))
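# Worked example: floor(epoch / 2) increments every two epochs, so epochs 0-1 use
# 1e-5, epochs 2-3 use 5e-6, epochs 4-5 use 2.5e-6, and so on.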
def loss_function(y_true, y_pred):
    squared_difference = tf.square(y_true - y_pred)
    mse = tf.reduce_mean(squared_difference, axis=-1)
    ssim = SSIMLoss(y_true, y_pred)
    total_loss = 100 * mse + 5 * ssim
    return total_loss
def MAE(y_true, y_pred):
    absolute_difference = tf.abs(y_true - y_pred)
    mae = tf.reduce_mean(absolute_difference, axis=-1)
    return mae
def MAE_mu(y_true, y_pred):
    # MAE computed in the mu-tonemapped domain
    y_true, y_pred = mu_tonemap_tf(y_true), mu_tonemap_tf(y_pred)
    absolute_difference = tf.abs(y_true - y_pred)
    mae = tf.reduce_mean(absolute_difference, axis=-1)
    return mae
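# Background note (assumption about src.metrics.mu_tonemap_tf, which is not shown here):
# mu-law tonemapping is conventionally defined as
#     T(x) = log(1 + mu * x) / log(1 + mu), typically with mu = 5000,
# which compresses HDR values before the error is measured.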
def loss_function_2(y_true, y_pred):
    absolute_difference = tf.abs(y_true - y_pred)
    mae = tf.reduce_mean(absolute_difference, axis=-1)
    ssim = SSIMLossMS(y_true, y_pred)
    total_loss = 5 * mae + 1 * ssim
    return total_loss
def SSIMLoss(y_true, y_pred):
    return 1 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 1.0))
def SSIMLossMS(y_true, y_pred):
    return 1 - tf.reduce_mean(tf.image.ssim_multiscale(y_true, y_pred, 1.0))
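# Usage sketch (assumption; model construction and compilation are not part of this file):
# these losses and metrics would typically be attached to the Keras model like so.
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr_[0]),
#               loss=loss_function,
#               metrics=[MAE, SSIMLoss])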
lr__ = []
lr__.append(1e-5)