[refactor] remove conv_cache from CogVideoX VAE #9524
Conversation
I hope there's nothing too backward-breaking here, since the return types of the blocks have now changed, but I guess that is unavoidable, no? The blocks aren't fully public in the sense that you cannot import them with `from diffusers import CogVideoXResnetBlock3D`. So, I guess okay?
Or can we prevent it by introducing a class attribute like `use_conv_caching`, and only do the caching and return the cache when this is True?
Or am I just being paranoid?
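For illustration, here is a rough sketch of what that opt-in attribute could look like (a toy module with hypothetical names, not code from this PR): the block only changes its return type when the flag is enabled.

```python
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union


class ToyBlock(nn.Module):
    # Hypothetical class attribute: when False, forward keeps the old single-tensor
    # return type; when True, it additionally returns a cache for the caller.
    use_conv_caching: bool = False

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv = nn.Conv3d(channels, channels, kernel_size=1)

    def forward(
        self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        out = self.conv(x)
        if self.use_conv_caching:
            # Cache the last temporal slice for the next frame-batch iteration.
            return out, x[:, :, -1:].clone()
        return out
```

(As the discussion below notes, the cache has to be returned anyway for frame-batchwise decoding, so this route was not taken.)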
```python
    self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[List[torch.Tensor]] = None
) -> torch.Tensor:
    if conv_cache is None:
        conv_cache = [None] * 2
```
What is with the 2?
Sorry, this is just for the time being. It is effectively the number of child modules that either are a CausalConv3d directly or contain a CausalConv3d. In the first iteration of the frame-wise decoding, the conv_cache values passed to every layer across all blocks need to be None, and doing it this way keeps that a bit simpler to handle. From the second to the n'th iteration, the conv_cache is populated with the previous temporal slice. Since the SpatialNorm3D layer has 2 CausalConv3d layers, we create a list of 2 elements here. This leads to a slightly cleaner implementation where we can just pop entries instead of maintaining conditional statements.
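To make the pop-based bookkeeping concrete, here is a minimal, hypothetical sketch (toy modules; the real `CogVideoXCausalConv3d` and SpatialNorm3D implementations differ): a parent that owns two causal convs defaults its cache to `[None] * 2` on the first iteration, pops one entry per child, and returns the fresh slices for the next iteration.

```python
import torch
import torch.nn as nn
from typing import List, Optional, Tuple


class ToyCausalConv3d(nn.Module):
    """Toy causal conv: pads along time with the cached slice from the previous iteration."""

    def __init__(self, channels: int):
        super().__init__()
        self.conv = nn.Conv3d(channels, channels, kernel_size=(2, 1, 1))

    def forward(
        self, x: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # First iteration: no previous temporal slice exists yet, so pad with the first frame.
        pad = x[:, :, :1] if conv_cache is None else conv_cache
        out = self.conv(torch.cat([pad, x], dim=2))
        # The last temporal slice becomes the cache for the next frame batch.
        return out, x[:, :, -1:].clone()


class ToySpatialNorm3D(nn.Module):
    """Owns two causal convs, hence the `[None] * 2` default cache."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv_y = ToyCausalConv3d(channels)
        self.conv_b = ToyCausalConv3d(channels)

    def forward(
        self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[List[torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        if conv_cache is None:
            conv_cache = [None] * 2  # one slot per causal-conv child

        # Pop in the same order the children run: no per-child conditionals needed.
        out_y, cache_y = self.conv_y(zq, conv_cache=conv_cache.pop(0))
        out_b, cache_b = self.conv_b(zq, conv_cache=conv_cache.pop(0))

        # Collect the children's caches so the caller can pass them back next time.
        return f * out_y + out_b, [cache_y, cache_b]


block = ToySpatialNorm3D()
f = zq = torch.randn(1, 4, 3, 8, 8)
out, cache = block(f, zq)                    # first frame batch: cache defaults to [None, None]
out, cache = block(f, zq, conv_cache=cache)  # later batches reuse the returned slices
```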
Ah cool. Perhaps this could be added as a comment.
There shouldn't, hopefully, be any backward-breaking change introduced here. It simply moves the conv_cache that was saved per CausalConv3d block outside, by returning the required cached values up the stack for the frame-batch iteration. Unfortunately, it leads to slightly ugly code. We need to maintain a tape that is used to collect all intermediate conv_cache values and propagate them back for the next iteration. This is effectively what the list of tensors is doing. At every intermediate layer that uses a CausalConv3d, we create a nested list that contains all the conv_cache outputs from the child modules. This is just a draft to show the required changes for making what we discussed work. It is very inefficient due to using ...
I don't think we can do this. We always have to return the conv_cache, or store it like we were doing before. As we discussed internally, we could allow one-shot decoding instead of frame-batchwise decoding (which is what we do currently). The former would not require a conv_cache but would come at a significant memory cost. This is something we can make configurable, but that's for another PR. For this PR, we need to ensure the same behaviour as before, and therefore must keep the conv_cache.
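For illustration, a rough sketch of what the frame-batchwise decoding loop looks like once the cache is returned up the stack instead of being stored on the conv modules (names such as `decoder` and `frame_batch_size` are illustrative, not this PR's exact code):

```python
import torch
from typing import Callable, Optional, Tuple


def decode_frame_batchwise(
    decoder: Callable[..., Tuple[torch.Tensor, list]],
    z: torch.Tensor,
    frame_batch_size: int = 2,
) -> torch.Tensor:
    """Decode `z` a few frames at a time, threading the conv cache between batches."""
    conv_cache = None  # first batch: every causal conv sees None
    outputs = []
    for start in range(0, z.shape[2], frame_batch_size):
        tile = z[:, :, start : start + frame_batch_size]
        # The decoder consumes the previous batch's cache and returns the new one,
        # rather than mutating state stored on the conv modules themselves.
        out, conv_cache = decoder(tile, conv_cache=conv_cache)
        outputs.append(out)
    return torch.cat(outputs, dim=2)


# e.g. with a dummy decoder that just echoes its input and caches the last slice:
dummy = lambda t, conv_cache=None: (t, [t[:, :, -1:]])
video = decode_frame_batchwise(dummy, torch.randn(1, 4, 6, 8, 8))
```

One-shot decoding would be the degenerate case of a single batch covering all frames, which skips the cache entirely but holds everything in memory at once.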
ohh thanks!
so this is actually a lot more complex than I had thought! lol I made some suggestions but it is just a rough proposal.
Feel free to test it out to see if it would work, or brainstorm better solutions.
A few pieces of feedback (see the rough sketch after this list):
- I think if the cache is None, we can just pass that default value all the way down to CogVideoXCausalConv3d; we don't need to make sure it has the same length.
- I think putting all the caches into a list would work, because they should be added and used in the same order - but with a dict it's more explicit and easier to follow.
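For illustration, a rough sketch of the dict-based variant (a toy function with hypothetical key names, not this PR's code): each stage reads its cache under an explicit key, so a `None` cache can flow all the way down on the first iteration and the ordering stays easy to follow.

```python
import torch
from typing import Dict, Optional, Tuple


def toy_block_forward(
    x: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
    """Toy block with two causal stages, each caching its last temporal slice by name."""
    conv_cache = conv_cache or {}
    new_conv_cache = {}

    # A missing key simply yields None, so no length bookkeeping is needed upfront.
    prev = conv_cache.get("conv1")
    h = torch.cat([prev, x], dim=2) if prev is not None else x
    new_conv_cache["conv1"] = x[:, :, -1:].clone()

    prev = conv_cache.get("conv2")
    h = torch.cat([prev, h], dim=2) if prev is not None else h
    new_conv_cache["conv2"] = h[:, :, -1:].clone()

    return h, new_conv_cache


x = torch.randn(1, 4, 2, 8, 8)
out, cache = toy_block_forward(x)                    # first iteration: empty cache
out, cache = toy_block_forward(x, conv_cache=cache)  # later iterations reuse the named entries
```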
Can you share some words for those not on the internal Slack on what this patch will change for users of CogVideoX through Diffusers?
This patch will hopefully enable compilation of the VAE, resulting in better speedups. Plus, it turns the caching code into something that respects the Diffusers philosophy.
This update also makes it possible to compile the VAE. We need the following patch first:

```diff
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 2d6e20ac8..a882a0aa9 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -41,7 +41,7 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3

         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:
```

This is needed because `torch.prod(torch.tensor(input.shape)).item()` builds a tensor from the input shape and calls `.item()`, which breaks the graph under `torch.compile` with `fullgraph=True`.

Compilation code (thanks to @a-r-r-o-w):

```python
import os

# os.environ["TORCH_LOGS"] = "+dynamo,graph_breaks,recompiles"

import numpy as np
import torch

from diffusers.utils import load_video
from diffusers import AutoencoderKLCogVideoX

with torch.no_grad():
    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
    ).to("cuda")
    vae.encode = torch.compile(vae.encode, mode="max-autotune", fullgraph=True)
    vae.decode = torch.compile(vae.decode, mode="max-autotune", fullgraph=True)
    # vae = torch.compile(vae, mode="max-autotune", fullgraph=True)

    for i in range(5):
        video = load_video(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
        )
        video = np.array(video).transpose(3, 0, 1, 2) / 127.5 - 1
        video = torch.from_numpy(video).unsqueeze(0).to("cuda", dtype=torch.float16)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        encoded = vae.encode(video, return_dict=False)[0].sample()
        decoded = (vae.decode(encoded, return_dict=False)[0].float() + 1) * 255
        end.record()
        torch.cuda.synchronize()
        time_elapsed = start.elapsed_time(end) / 1000

        decoded = torch.clamp(decoded, 0, 255)
        video = decoded.permute(0, 2, 3, 4, 1)[0].cpu().numpy().astype(np.uint8)

        print("Inference time:", i, " - ", round(time_elapsed, 3))
```

Vanilla compile on `vae`:

```
Inference time: 0  -  9.357
Inference time: 1  -  7.066
Inference time: 2  -  7.084
Inference time: 3  -  7.057
Inference time: 4  -  7.053
```

Separate compile of `encode`/`decode`:

```
Inference time: 0  -  1286.903
Inference time: 1  -  13.297
Inference time: 2  -  4.504
Inference time: 3  -  4.514
Inference time: 4  -  4.522
```
Co-Authored-By: YiYi Xu <yixu310@gmail.com>
Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>
thanks! let's also run the slow tests on CogVideoX to make sure it is backward compatible on the VAE
I didn't find any usage of the individual blocks on GitHub, so I think we don't have to worry about breaking changes for these blocks
feel free to merge after that!
My local testing script works fine and produces the same video as before these changes. I just realized that our CogVideoXIntegrationTests fail on the CI due to OOM, pt -> np conversion, and shape mismatch. I'll fix this in a follow-up PR.
* remove conv cache from the layer and pass as arg instead
* make style
* yiyi's cleaner implementation

  Co-Authored-By: YiYi Xu <yixu310@gmail.com>
* sayak's compiled implementation

  Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
What does this PR do?
As discussed internally in https://huggingface.slack.com/archives/C068ZAHJZCZ/p1725402937927749?thread_ts=1725360687.955529&cid=C068ZAHJZCZ
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @yiyixuxu