
Will changing the training decoder's image_resolution from 512 to 256 affect the network result? #1

Closed
remem123 opened this issue Jul 31, 2024 · 9 comments

Comments

@remem123

remem123 commented Jul 31, 2024

Will changing the training decoder's image_resolution from 512 to 256 affect the network result? I changed image_resolution from 512 to 256 and encountered the following error during fine-tuning:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-fs/work/WaDiff-main/guided-diffusion/scripts/image_train.py", line 124, in <module>
[rank0]:     main()
[rank0]:   File "/root/autodl-fs/work/WaDiff-main/guided-diffusion/scripts/image_train.py", line 58, in main
[rank0]:     wm_decoder.load_state_dict(th.load(args.wm_decoder_path, map_location='cpu'), strict=False).eval()
[rank0]:   File "/root/miniconda3/envs/wadiff/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
[rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for StegaStampDecoder:
[rank0]:        size mismatch for dense.0.weight: copying a param with shape torch.Size([512, 2048]) from checkpoint, the shape in current model is torch.Size([512, 8192]).

I would appreciate it if you could answer me! The above was written with a translator, so please forgive any awkward phrasing.

@rmin2000
Owner

Hi, thanks for your interest :).
You might review the --image_size flag in guided-diffusion/train.sh. Given the error above, you have likely set --image_size to 512 while attempting to load a decoder checkpoint trained at 256 resolution.
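
For illustration, here is a minimal Python sketch of why this happens. This is a toy stand-in, not the repository's exact StegaStampDecoder: the decoder's conv stack downsamples by a fixed factor and then flattens, so the input dimension of dense.0.weight grows with the square of the image resolution.

import torch.nn as nn

class TinyDecoder(nn.Module):
    # Toy StegaStamp-style decoder: conv stack -> flatten -> dense head.
    def __init__(self, image_resolution, wm_length=48):
        super().__init__()
        # Five stride-2 convs downsample the input by 2**5 = 32.
        self.convs = nn.Sequential(
            *[nn.Conv2d(3 if i == 0 else 8, 8, 3, stride=2, padding=1)
              for i in range(5)])
        flat = 8 * (image_resolution // 32) ** 2  # flattened feature count
        self.dense = nn.Sequential(nn.Linear(flat, 512), nn.ReLU(),
                                   nn.Linear(512, wm_length))

    def forward(self, x):
        return self.dense(self.convs(x).flatten(1))

# Doubling the resolution quadruples the dense layer's input features, which
# is why a decoder checkpoint saved at one resolution fails to load at another.
print(TinyDecoder(256).dense[0].weight.shape)  # torch.Size([512, 512])
print(TinyDecoder(512).dense[0].weight.shape)  # torch.Size([512, 2048])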

@remem123
Author

remem123 commented Aug 1, 2024

Thanks for your quick answer! That was careless of me; I've solved that problem. But now I have a new one: too many 'size mismatch for output_blocks' errors are shown, and I'm trying to work through them.

@rmin2000
Owner

rmin2000 commented Aug 1, 2024

No worries. Could you provide more details, such as your console output? It appears that there is a mismatch between the checkpoint and the UNet architecture.

@remem123
Author

remem123 commented Aug 1, 2024

The error log is shown below, and I think you are absolutely right. I also tried setting strict=False in dist_util.load_state_dict, but that didn't work either.

[rank0]: Traceback (most recent call last):
[rank0]:   File "/root/autodl-fs/work/WaDiff-main/guided-diffusion/scripts/image_train.py", line 124, in <module>
[rank0]:     main()
[rank0]:   File "/root/autodl-fs/work/WaDiff-main/guided-diffusion/scripts/image_train.py", line 73, in main
[rank0]:     TrainLoop(
[rank0]:   File "/root/autodl-fs/work/WaDiff-main/guided-diffusion/guided_diffusion/train_util.py", line 79, in __init__
[rank0]:     self._load_and_sync_parameters(load_wm_model=True)
[rank0]:   File "/root/autodl-fs/work/WaDiff-main/guided-diffusion/guided_diffusion/train_util.py", line 147, in _load_and_sync_parameters
[rank0]:     self.model.load_state_dict(model_dict)
[rank0]:   File "/root/miniconda3/envs/wadiff/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2215, in load_state_dict
[rank0]:     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
[rank0]: RuntimeError: Error(s) in loading state_dict for UNetModel:
[rank0]:        Missing key(s) in state_dict: "input_blocks.7.1.norm.weight", "input_blocks.7.1.norm.bias", "input_blocks.7.1.qkv.weight", "input_blocks.7.1.qkv.bias", "input_blocks.7.1.proj_out.weight", "input_blocks.7.1.proj_out.bias", "input_blocks.8.1.norm.weight", "input_blocks.8.1.norm.bias", "input_blocks.8.1.qkv.weight", "input_blocks.8.1.qkv.bias", "input_blocks.8.1.proj_out.weight", "input_blocks.8.1.proj_out.bias", "input_blocks.10.0.skip_connection.weight", "input_blocks.10.0.skip_connection.bias". 
[rank0]:        Unexpected key(s) in state_dict: "input_blocks.15.0.in_layers.0.weight", "input_blocks.15.0.in_layers.0.bias", "input_blocks.15.0.in_layers.2.weight", "input_blocks.15.0.in_layers.2.bias", "input_blocks.15.0.emb_layers.1.weight", "input_blocks.15.0.emb_layers.1.bias", "input_blocks.15.0.out_layers.0.weight", "input_blocks.15.0.out_layers.0.bias", "input_blocks.15.0.out_layers.3.weight", "input_blocks.15.0.out_layers.3.bias", "input_blocks.16.0.in_layers.0.weight", "input_blocks.16.0.in_layers.0.bias", "input_blocks.16.0.in_layers.2.weight", "input_blocks.16.0.in_layers.2.bias", "input_blocks.16.0.emb_layers.1.weight", "input_blocks.16.0.emb_layers.1.bias", "input_blocks.16.0.out_layers.0.weight", "input_blocks.16.0.out_layers.0.bias", "input_blocks.16.0.out_layers.3.weight", "input_blocks.16.0.out_layers.3.bias", "input_blocks.16.1.norm.weight", "input_blocks.16.1.norm.bias", "input_blocks.16.1.qkv.weight", "input_blocks.16.1.qkv.bias", "input_blocks.16.1.proj_out.weight", "input_blocks.16.1.proj_out.bias", "input_blocks.17.0.in_layers.0.weight", "input_blocks.17.0.in_layers.0.bias", "input_blocks.17.0.in_layers.2.weight", "input_blocks.17.0.in_layers.2.bias", "input_blocks.17.0.emb_layers.1.weight", "input_blocks.17.0.emb_layers.1.bias", "input_blocks.17.0.out_layers.0.weight", "input_blocks.17.0.out_layers.0.bias", "input_blocks.17.0.out_layers.3.weight", "input_blocks.17.0.out_layers.3.bias", "input_blocks.17.1.norm.weight", "input_blocks.17.1.norm.bias", "input_blocks.17.1.qkv.weight", "input_blocks.17.1.qkv.bias", "input_blocks.17.1.proj_out.weight", "input_blocks.17.1.proj_out.bias", "output_blocks.15.0.in_layers.0.weight", "output_blocks.15.0.in_layers.0.bias", "output_blocks.15.0.in_layers.2.weight", "output_blocks.15.0.in_layers.2.bias", "output_blocks.15.0.emb_layers.1.weight", "output_blocks.15.0.emb_layers.1.bias", "output_blocks.15.0.out_layers.0.weight", "output_blocks.15.0.out_layers.0.bias", "output_blocks.15.0.out_layers.3.weight", "output_blocks.15.0.out_layers.3.bias", "output_blocks.15.0.skip_connection.weight", "output_blocks.15.0.skip_connection.bias", "output_blocks.16.0.in_layers.0.weight", "output_blocks.16.0.in_layers.0.bias", "output_blocks.16.0.in_layers.2.weight", "output_blocks.16.0.in_layers.2.bias", "output_blocks.16.0.emb_layers.1.weight", "output_blocks.16.0.emb_layers.1.bias", "output_blocks.16.0.out_layers.0.weight", "output_blocks.16.0.out_layers.0.bias", "output_blocks.16.0.out_layers.3.weight", "output_blocks.16.0.out_layers.3.bias", "output_blocks.16.0.skip_connection.weight", "output_blocks.16.0.skip_connection.bias", "output_blocks.17.0.in_layers.0.weight", "output_blocks.17.0.in_layers.0.bias", "output_blocks.17.0.in_layers.2.weight", "output_blocks.17.0.in_layers.2.bias", "output_blocks.17.0.emb_layers.1.weight", "output_blocks.17.0.emb_layers.1.bias", "output_blocks.17.0.out_layers.0.weight", "output_blocks.17.0.out_layers.0.bias", "output_blocks.17.0.out_layers.3.weight", "output_blocks.17.0.out_layers.3.bias", "output_blocks.17.0.skip_connection.weight", "output_blocks.17.0.skip_connection.bias", "output_blocks.14.1.in_layers.0.weight", "output_blocks.14.1.in_layers.0.bias", "output_blocks.14.1.in_layers.2.weight", "output_blocks.14.1.in_layers.2.bias", "output_blocks.14.1.emb_layers.1.weight", "output_blocks.14.1.emb_layers.1.bias", "output_blocks.14.1.out_layers.0.weight", "output_blocks.14.1.out_layers.0.bias", "output_blocks.14.1.out_layers.3.weight", "output_blocks.14.1.out_layers.3.bias". 
[rank0]:        size mismatch for input_blocks.10.0.in_layers.2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([768, 512, 3, 3]).
[rank0]:        size mismatch for input_blocks.10.0.in_layers.2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
[rank0]:        size mismatch for input_blocks.10.0.emb_layers.1.weight: copying a param with shape torch.Size([1024, 1024]) from checkpoint, the shape in current model is torch.Size([1536, 1024]).
[rank0]:        size mismatch for input_blocks.10.0.emb_layers.1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1536]).
[rank0]:        size mismatch for input_blocks.10.0.out_layers.0.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
[rank0]:        size mismatch for input_blocks.10.0.out_layers.0.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([768]).
···

@rmin2000
Owner

rmin2000 commented Aug 1, 2024

Well, I think setting strict=False will not help when the mismatched parameters share the same key. Which pre-trained model are you using? I downloaded the checkpoint from this link.
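
For context, strict=False only ignores missing or unexpected keys; parameters that share a key but differ in shape still raise the size-mismatch error. A common illustrative workaround is to drop the mismatched tensors before loading (load_matching below is a hypothetical helper, not part of this repo); note that the skipped layers stay randomly initialized, so this is a diagnostic aid rather than a fix for an architecture mismatch.

import torch

def load_matching(model, ckpt_path):
    # Keep only checkpoint tensors whose key AND shape match the model.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    own = model.state_dict()
    matching = {k: v for k, v in ckpt.items()
                if k in own and v.shape == own[k].shape}
    model.load_state_dict(matching, strict=False)
    # The rejected keys show exactly where the architectures diverge.
    return sorted(set(ckpt) - set(matching))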

@remem123
Author

remem123 commented Aug 1, 2024

I'm using the 256x256 ImageNet diffusion model. I think we are using the same model; I downloaded it from the checkpoint link you gave in the README file. Due to capacity constraints, the dataset I used for decoder training was ImageNet-Sketch, which consists of 50,000 images, 50 for each of the 1,000 ImageNet classes. Here is my script command from ./guided-diffusion/train.sh:

MODEL_FLAGS="--wm_length 48 --attention_resolutions 32,16,8 --class_cond False --image_size 256 --num_channels 256 --learn_sigma True --num_head_channels 64 --num_res_blocks 2 --resblock_updown True"
DIFFUSION_FLAGS="--diffusion_steps 1000 --noise_schedule linear"
TRAIN_FLAGS="--lr 1e-4 --batch_size 4"
NUM_GPUS=1
mpiexec -n $NUM_GPUS python scripts/image_train.py --alpha 0.4 --threshold 400 --wm_decoder_path ../StegaStamp/logs/imagenet_stegastamp_48_01082024_08:49:03/checkpoints/step_7500_decoder.pth --data_dir ../data/imagenet-sketch/sketch --resume_checkpoint models/256x256_diffusion_uncond.pt $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

Thank you very much for your patient answer!!!

@rmin2000
Copy link
Owner

rmin2000 commented Aug 2, 2024

No worries. I will run experiments on my server and get back to you with the results as soon as possible.

@rmin2000
Owner

rmin2000 commented Aug 5, 2024

Hi, thank you for your patience. I ran the script on my server and did not encounter the error. Perhaps you could try loading the checkpoint with the default architecture (by setting --wm_length to 0) to check whether the checkpoint and the model architecture match.
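
If that check fails, a quick way to see where the checkpoint and the constructed model diverge is to diff their state_dicts directly. diff_state_dicts below is a hypothetical sketch, not part of the repo:

import torch

def diff_state_dicts(model, ckpt_path):
    # Report key and shape differences between a model and a checkpoint.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    own = model.state_dict()
    for k in sorted(set(own) - set(ckpt)):
        print('missing from checkpoint:', k)
    for k in sorted(set(ckpt) - set(own)):
        print('unexpected in checkpoint:', k)
    for k in sorted(set(own) & set(ckpt)):
        if own[k].shape != ckpt[k].shape:
            print('shape mismatch:', k, tuple(ckpt[k].shape), '->', tuple(own[k].shape))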

@remem123
Copy link
Author

remem123 commented Aug 7, 2024

Thank you very much for your patient answer; I will try it right away!
