Skip to content

Commit d3c0cac

Browse files
committed
update
1 parent 5ee9884 commit d3c0cac

File tree

4 files changed

+64
-66
lines changed

4 files changed

+64
-66
lines changed

models/networks/inpaint_g.py

+12-15
Original file line numberDiff line numberDiff line change
@@ -291,13 +291,13 @@ def __init__(self, opt, return_feat=False, return_pm=False):
291291
self.conv7_atrous = gen_conv(2*cnum, 4*cnum, 3, rate=2)
292292
self.conv8_atrous = gen_conv(2*cnum, 4*cnum, 3, rate=4)
293293
self.conv9_atrous = gen_conv(2*cnum, 4*cnum, 3, rate=8)
294-
self.conv10_atrous = gen_conv(2*cnum, 4*cnum, 3, rate=16) #8
294+
self.conv10_atrous = gen_conv(2*cnum, 4*cnum, 3, rate=16)
295295
self.conv11 = gen_conv(2*cnum, 4*cnum, 3, 1)
296-
self.conv12 = gen_conv(2*cnum, 4*cnum, 3, 1) #4
296+
self.conv12 = gen_conv(2*cnum, 4*cnum, 3, 1)
297297
self.conv13_upsample_conv = gen_deconv(2*cnum, 2*cnum)
298-
self.conv14 = gen_conv(cnum, 2*cnum, 3, 1) #2
298+
self.conv14 = gen_conv(cnum, 2*cnum, 3, 1)
299299
self.conv15_upsample_conv = gen_deconv(cnum, cnum)
300-
self.conv16 = gen_conv(cnum//2, cnum//2, 3, 1) #1
300+
self.conv16 = gen_conv(cnum//2, cnum//2, 3, 1)
301301
self.conv17 = gen_conv(cnum//4, 3, 3, 1, activation=None)
302302

303303
# stage2
@@ -316,7 +316,7 @@ def __init__(self, opt, return_feat=False, return_pm=False):
316316
self.pmconv3 = gen_conv(cnum//2, 2*cnum, 3, 1)
317317
self.pmconv4_downsample = gen_conv(cnum, 4*cnum, 3, 2)
318318
self.pmconv5 = gen_conv(2*cnum, 4*cnum, 3, 1)
319-
self.pmconv6 = gen_conv(2*cnum, 4*cnum, 3, 1,
319+
self.pmconv6 = gen_conv(2*cnum, 4*cnum, 3, 1,
320320
activation=nn.ReLU())
321321
self.pmconv9 = gen_conv(2*cnum, 4*cnum, 3, 1)
322322
self.pmconv10 = gen_conv(2*cnum, 4*cnum, 3, 1)
@@ -396,28 +396,25 @@ def forward(self, x, mask):
396396
x = self.pmconv4_downsample(x)
397397
x = self.pmconv5(x)
398398
x = self.pmconv6(x)
399-
pm = x
399+
pm_return = x
400+
400401
x = self.pmconv9(x)
401402
x = self.pmconv10(x)
403+
pm = x
402404
x = torch.cat([x_hallu, pm], 1)
403-
feat = x
404405

405406
x = self.allconv11(x)
406407
x = self.allconv12(x)
407408
x = self.allconv13_upsample_conv(x)
408-
feat_x2 = x
409409
x = self.allconv14(x)
410410
x = self.allconv15_upsample_conv(x)
411-
feat_x4 = x
412411
x = self.allconv16(x)
413412
x = self.allconv17(x)
414413
x_stage2 = torch.tanh(x)
415-
if self.return_feat:
416-
return x_stage1, x_stage2, [feat, feat_x2, feat_x4]
417-
elif self.return_pm:
418-
return x_stage1, x_stage2, pm
419-
else:
420-
return x_stage1, x_stage2
414+
if self.return_pm:
415+
return x_stage1, x_stage2, pm_return
416+
417+
return x_stage1, x_stage2
421418

422419
if __name__ == "__main__":
423420
pass

test.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import cv2
23
import torch
34
import data
45
from options.test_options import TestOptions

train.sh

+50-50
Original file line numberDiff line numberDiff line change
@@ -1,55 +1,55 @@
1-
#BSIZE0=48 # stage coarse
2-
#BSIZE=96 # 96:64G
1+
##BSIZE0=48 # stage coarse
2+
##BSIZE=96 # 96:64G
33
BSIZE0=$((BSIZE/2))
44
NWK=16
5-
PREFIX="--dataset_mode_train trainimage \
6-
--gpu_ids 0,1 \
7-
--name debug \
8-
--dataset_mode_val valimage \
9-
--train_image_dir ./datasets/places/places2 \
10-
--train_image_list ./datasets/places/train_example.txt \
11-
--path_objectshape_list ./datasets/object_shapes.txt \
12-
--path_objectshape_base ./datasets/object_masks \
13-
--val_image_dir ./datasets/places2sample1k_val/places2samples1k_crop256 \
14-
--val_image_list ./datasets/places2sample1k_val/files.txt \
15-
--val_mask_dir ./datasets/places2sample1k_val/places2samples1k_256_mask_square128 \
16-
--no_vgg_loss \
17-
--no_ganFeat_loss \
18-
--load_size 640 \
19-
--crop_size 256 \
20-
--model inpaint \
21-
--netG baseconv \
22-
--netD deepfill \
23-
--preprocess_mode scale_shortside_and_crop \
24-
--validation_freq 10000 \
25-
--niter 50 "
26-
python train.py \
27-
${PREFIX} \
28-
--batchSize ${BSIZE0} \
29-
--nThreads ${NWK} \
30-
--no_fine_loss \
31-
--update_part coarse \
32-
--no_gan_loss \
33-
--freeze_D \
34-
--niter 1 \
35-
${EXTRA}
36-
python train.py \
37-
${PREFIX} \
38-
--batchSize ${BSIZE} \
39-
--nThreads ${NWK} \
40-
--update_part fine \
41-
--continue_train \
42-
--niter 2 \
43-
${EXTRA}
44-
python train.py \
45-
${PREFIX} \
46-
--batchSize ${BSIZE} \
47-
--nThreads ${NWK} \
48-
--update_part all \
49-
--continue_train \
50-
--niter 4 \
51-
${EXTRA}
52-
5+
#PREFIX="--dataset_mode_train trainimage \
6+
#--gpu_ids 0,1 \
7+
#--name debug \
8+
#--dataset_mode_val valimage \
9+
#--train_image_dir ./datasets/places/places2 \
10+
#--train_image_list ./datasets/places/train_example.txt \
11+
#--path_objectshape_list ./datasets/object_shapes.txt \
12+
#--path_objectshape_base ./datasets/object_masks \
13+
#--val_image_dir ./datasets/places2sample1k_val/places2samples1k_crop256 \
14+
#--val_image_list ./datasets/places2sample1k_val/files.txt \
15+
#--val_mask_dir ./datasets/places2sample1k_val/places2samples1k_256_mask_square128 \
16+
#--no_vgg_loss \
17+
#--no_ganFeat_loss \
18+
#--load_size 640 \
19+
#--crop_size 256 \
20+
#--model inpaint \
21+
#--netG baseconv \
22+
#--netD deepfill \
23+
#--preprocess_mode scale_shortside_and_crop \
24+
#--validation_freq 10000 \
25+
#--niter 50 "
26+
#python train.py \
27+
# ${PREFIX} \
28+
# --batchSize ${BSIZE0} \
29+
# --nThreads ${NWK} \
30+
# --no_fine_loss \
31+
# --update_part coarse \
32+
# --no_gan_loss \
33+
# --freeze_D \
34+
# --niter 1 \
35+
# ${EXTRA}
36+
#python train.py \
37+
# ${PREFIX} \
38+
# --batchSize ${BSIZE} \
39+
# --nThreads ${NWK} \
40+
# --update_part fine \
41+
# --continue_train \
42+
# --niter 2 \
43+
# ${EXTRA}
44+
#python train.py \
45+
# ${PREFIX} \
46+
# --batchSize ${BSIZE} \
47+
# --nThreads ${NWK} \
48+
# --update_part all \
49+
# --continue_train \
50+
# --niter 4 \
51+
# ${EXTRA}
52+
#
5353
PREFIX="--dataset_mode_train trainimage \
5454
--name debugarr0 \
5555
--gpu_ids 0,1 \

util/util.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def load_network(net, label, epoch, opt):
226226
if k.startswith("module."):
227227
k=k.replace("module.","")
228228
new_dict[k] = v
229-
net.load_state_dict(new_dict, strict=False)
229+
net.load_state_dict(new_dict)
230230
return net
231231

232232

0 commit comments

Comments
 (0)