Code for initialize CausalConv3d from pretrained Conv2D. #168

Open
Sutongtong233 opened this issue Mar 29, 2024 · 10 comments

Comments

@Sutongtong233

Hi, I see in CausalVideoVAE.md that you use a special initialization (tail initialization) for CausalConv3d training. I am interested in this trick and would be sincerely grateful if you could share the specific initialization code.

@vivym

vivym commented Mar 29, 2024

w = vae_2d_ckpt["state_dict"][key_2d]            # conv2d weight
new_w = torch.zeros(shape_3d, dtype=w.dtype)
new_w[:, :, -1, :, :] = w

https://github.com/vivym/OmniGen/blob/main/scripts/inflate_conv_for_video_vae.py
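A minimal, self-contained sketch of what the snippet above does, for readers who want to check it without the checkpoint. The layer sizes, the temporal kernel size `kt`, and the manual causal padding are illustrative assumptions and are not taken from the linked script. It checks that a tail-initialized 3D convolution with past-only temporal padding reproduces the 2D convolution frame by frame:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for a pretrained 2D layer (hypothetical sizes, not from the checkpoint).
conv2d = nn.Conv2d(4, 8, kernel_size=3, padding=1)

# 3D counterpart with temporal kernel size kt. Spatial padding is done by the
# layer; temporal padding is applied manually below so that it is causal.
kt = 3
conv3d = nn.Conv3d(4, 8, kernel_size=(kt, 3, 3), padding=(0, 1, 1))

with torch.no_grad():
    w = conv2d.weight                         # (out, in, kH, kW)
    new_w = torch.zeros_like(conv3d.weight)   # (out, in, kT, kH, kW)
    new_w[:, :, -1, :, :] = w                 # tail initialization
    conv3d.weight.copy_(new_w)
    conv3d.bias.copy_(conv2d.bias)

video = torch.randn(2, 4, 5, 16, 16)          # (batch, channels, frames, H, W)

# Causal padding: kt - 1 zero frames on the past side only, none on the future side.
padded = F.pad(video, (0, 0, 0, 0, kt - 1, 0))
out_3d = conv3d(padded)

# Reference: apply the 2D conv to every frame independently.
out_2d = torch.stack(
    [conv2d(video[:, :, t]) for t in range(video.shape[2])], dim=2
)

max_diff = (out_3d - out_2d).abs().max().item()
print("max abs difference:", max_diff)
```

With past-only padding, the last temporal slice of the kernel is the one aligned with the current frame, which is why the 2D weights go into index `-1`.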

@Birdylx

Birdylx commented Mar 30, 2024

@vivym Thanks. I have another question about the temporal upsampling at this line: https://github.com/vivym/OmniGen/blob/4f0bf7d7f7dcb6b1b79b50c90153f7477151e139/src/omni_gen/models/video_vae/upsamplers.py#L87. It isn't a true 2x upsample; the output always has an odd number of frames.

@vivym

vivym commented Mar 30, 2024

@Birdylx It is indeed an odd number of frames. You can refer to the paper https://arxiv.org/abs/2310.05737
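To illustrate the frame arithmetic, here is a small sketch. It assumes the first frame is treated as a standalone image and that the upsampler repeats every frame twice and then drops the leading duplicate. The function `temporal_upsample_keep_first` is hypothetical and is not copied from the linked `upsamplers.py`:

```python
import torch


def temporal_upsample_keep_first(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical causal temporal upsample.

    x has shape (batch, channels, frames, H, W). Every frame is repeated
    twice along the time axis, then the first copy is dropped so the first
    frame stays a single frame.
    """
    x = x.repeat_interleave(2, dim=2)
    return x[:, :, 1:]


frame_counts = {}
for t in (1, 2, 3, 5):
    latent = torch.zeros(1, 4, t, 8, 8)
    frame_counts[t] = temporal_upsample_keep_first(latent).shape[2]

print(frame_counts)  # {1: 1, 2: 3, 3: 5, 5: 9}
```

Under this assumption, T input frames become 2T - 1 output frames, which is odd for every T. Equivalently, 1 + n frames become 1 + 2n.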

@Birdylx

Birdylx commented Mar 30, 2024

@vivym Thanks for your quick reply! I will read the paper for more details.

@Birdylx

Birdylx commented Mar 30, 2024

@vivym Do you train the full model, or do you freeze the model and train only the temporal block?

@Sutongtong233
Author

Thanks:) I will have a try.

@Sutongtong233
Author

It works. Thanks a lot!

@Sutongtong233
Author

I see "Despite the VAE in Diffusion training being frozen" mentioned in your latest doc. Does that mean you've found that freezing the 2D VAE weights (the "tail" of the causal 3D conv) performs better?

Sutongtong233 reopened this Apr 8, 2024
@Sutongtong233
Author

@vivym Do you train the full model? or freeze the model, just train the temporal block?

I've tried training the full model: the motion blur is alleviated, but single-frame reconstruction degrades.

@Catpp01

Catpp01 commented Oct 7, 2024

w = vae_2d_ckpt["state_dict"][key_2d]  # conv2d weight, shape (out_channels, in_channels, kH, kW)
new_w = torch.zeros(shape_3d, dtype=w.dtype)  # shape_3d = (out_channels, in_channels, T, kH, kW), T = temporal kernel size
new_w[:, :, -1, :, :] = w  # tail initialization
# center   : new_w[:, :, T // 2, :, :]
# average  : new_w[:, :, :, :, :]  (all temporal slices, e.g. w / T in each)
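The three variants listed above can be sketched as one helper. The function name `inflate_conv2d_weight` and the random stand-in weight are hypothetical. The "average" branch assumes the 2D weight is divided equally across the temporal slices, which the thread does not spell out:

```python
import torch


def inflate_conv2d_weight(w2d: torch.Tensor, kt: int, method: str = "tail") -> torch.Tensor:
    """Inflate a Conv2d weight (out, in, kH, kW) into a Conv3d weight
    (out, in, kt, kH, kW). Hypothetical helper for illustration only."""
    out_c, in_c, kh, kw = w2d.shape
    w3d = torch.zeros(out_c, in_c, kt, kh, kw, dtype=w2d.dtype)
    if method == "tail":
        # Only the last temporal slice (the current frame under causal padding).
        w3d[:, :, -1] = w2d
    elif method == "center":
        # Only the middle temporal slice.
        w3d[:, :, kt // 2] = w2d
    elif method == "average":
        # Assumption: spread the 2D weight evenly over all temporal slices.
        w3d[:] = (w2d / kt).unsqueeze(2)
    else:
        raise ValueError(f"unknown method: {method}")
    return w3d


w2d = torch.randn(8, 4, 3, 3)  # stand-in for a pretrained Conv2d weight
inflated = {m: inflate_conv2d_weight(w2d, kt=3, method=m) for m in ("tail", "center", "average")}
for name, w3d in inflated.items():
    print(name, tuple(w3d.shape), torch.allclose(w3d.sum(dim=2), w2d, atol=1e-6))
```

In all three cases the kernel sums to the original 2D weight along the temporal axis, so a video of identical frames passed through the inflated layer matches the 2D output wherever the temporal window lies fully inside the clip. Only the tail variant keeps this property for arbitrary input under past-only padding.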
