Skip to content

Commit 5346bcf

Browse files
committed
Adjust deconvolution filter size for different scaling factors
1 parent f8af70e commit 5346bcf

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ def __init__(self, sess, config):
4848
# Different model layer counts and filter sizes for FSRCNN vs FSRCNN-s (fast), (d, s, m) in paper
4949
model_params = [[56, 12, 4], [32, 8, 1]]
5050
self.model_params = model_params[self.fast]
51+
52+
self.deconv_radius = [3, 5, 7][self.scale - 2]
5153

5254
self.checkpoint_dir = config.checkpoint_dir
5355
self.output_dir = config.output_dir
@@ -65,11 +67,12 @@ def build_model(self):
6567
d, s, m = self.model_params
6668

6769
expand_weight, deconv_weight = 'w{}'.format(m + 3), 'w{}'.format(m + 4)
70+
deconv_size = self.deconv_radius * 2 + 1
6871
self.weights = {
6972
'w1': tf.Variable(tf.random_normal([5, 5, 1, d], stddev=0.0378, dtype=tf.float32), name='w1'),
7073
'w2': tf.Variable(tf.random_normal([1, 1, d, s], stddev=0.3536, dtype=tf.float32), name='w2'),
7174
expand_weight: tf.Variable(tf.random_normal([1, 1, s, d], stddev=0.189, dtype=tf.float32), name=expand_weight),
72-
deconv_weight: tf.Variable(tf.random_normal([9, 9, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
75+
deconv_weight: tf.Variable(tf.random_normal([deconv_size, deconv_size, 1, d], stddev=0.0001, dtype=tf.float32), name=deconv_weight)
7376
}
7477

7578
expand_bias, deconv_bias = 'b{}'.format(m + 3), 'b{}'.format(m + 4)

0 commit comments

Comments
 (0)