
[refactor] remove conv_cache from CogVideoX VAE #9524

Merged
a-r-r-o-w merged 8 commits into main from cogvideox/vae-remove-conv-cache on Sep 28, 2024

Conversation

a-r-r-o-w (Member)

What does this PR do?

As discussed internally in https://huggingface.slack.com/archives/C068ZAHJZCZ/p1725402937927749?thread_ts=1725360687.955529&cid=C068ZAHJZCZ

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul @yiyixuxu


@sayakpaul (Member) left a comment

I hope there's nothing too backward-breaking here, since the return types of the blocks have now changed, but I guess that is unavoidable, no? The blocks aren't fully public, in the sense that you cannot import them with from diffusers import CogVideoXResnetBlock3D. So, I guess that's okay?

Or can we prevent it by introducing a class attribute like "use_conv_caching" and doing caching and returning the cache only when this is True?

Or am I just being paranoid?
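For concreteness, a gate like that could look roughly like the following toy sketch (the module, channel sizes, and flag are made up for illustration; this is not what the PR ends up doing):

import torch
from torch import nn


class ToyBlock(nn.Module):
    # Hypothetical opt-in flag: when False the block keeps its old single-tensor
    # return type, and only opted-in callers see the (output, cache) tuple.
    use_conv_caching: bool = False

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv = nn.Conv3d(channels, channels, kernel_size=1)

    def forward(self, hidden_states: torch.Tensor, conv_cache=None):
        if conv_cache is not None:
            hidden_states = hidden_states + conv_cache  # placeholder for reusing the cached slice
        new_conv_cache = hidden_states[:, :, -1:]       # stand-in for the cached temporal slice
        hidden_states = self.conv(hidden_states)
        if self.use_conv_caching:
            return hidden_states, new_conv_cache
        return hidden_states


block = ToyBlock()
out = block(torch.randn(1, 4, 2, 8, 8))          # old behaviour: a single tensor
block.use_conv_caching = True
out, cache = block(torch.randn(1, 4, 2, 8, 8))   # opted-in: (tensor, cache) tuple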

    self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[List[torch.Tensor]] = None
) -> torch.Tensor:
    if conv_cache is None:
        conv_cache = [None] * 2
Member

What is with the 2?

a-r-r-o-w (Member, Author)

Sorry, this is just for the time being. It is effectively the number of child modules that are directly a CausalConv3d or contain a CausalConv3d. In the first iteration of the frame-wise decoding, the values passed as conv_cache across all blocks need to be None; doing this just makes it a bit simpler to handle. For the second to n'th iteration, the conv_cache will be populated with the previous temporal slice. Since the SpatialNorm3D layer has 2 CausalConv3d layers, we create a list of 2 elements here. This leads to a slightly cleaner implementation where we can simply pop items instead of maintaining conditional statements.
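For illustration, a minimal toy analogue of that seeding-and-pop pattern (the module, its convs, and the channel sizes are made up; this is not the actual diffusers SpatialNorm3D):

import torch
from torch import nn


class ToyNorm3D(nn.Module):
    """Toy stand-in for a block with two inner convs, mirroring the [None] * 2 seeding above."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv_y = nn.Conv3d(channels, channels, kernel_size=1)
        self.conv_b = nn.Conv3d(channels, channels, kernel_size=1)

    def forward(self, f: torch.Tensor, conv_cache=None):
        # Iteration 0 of frame-wise decoding: no cache exists yet, so seed one slot
        # per inner conv instead of sprinkling None checks further down.
        if conv_cache is None:
            conv_cache = [None] * 2

        new_conv_cache = []
        for conv in (self.conv_y, self.conv_b):
            prev_slice = conv_cache.pop(0)       # cache from iteration i - 1 (or None)
            if prev_slice is not None:
                f = f + prev_slice               # placeholder for reusing the cached slice
            new_conv_cache.append(f[:, :, -1:])  # save the last temporal slice for iteration i + 1
            f = conv(f)
        return f, new_conv_cache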

Member

Ah cool. Perhaps this could be added as a comment.

@a-r-r-o-w (Member, Author)

There shouldn't, hopefully, be any backward-breaking change introduced here. It simply moves the conv_cache saved per CausalConv3d block outside, by returning the required cached values up the stack for frame batch iteration i and then back down the stack for iteration i + 1, repeatedly. Basically, we no longer have a mutating state variable inside the block, as discussed internally.

Unfortunately, it leads to slightly ugly code. We need to maintain a tape that is used to collect all intermediate conv_cache values and propagate them back for the next iteration. This is effectively what the list of tensors is doing. At every intermediate layer that uses a CausalConv3d, we create a nested list that contains all the conv_cache outputs from child modules.

This is just a draft to show the required changes for making what we discussed work. It is very inefficient due to using .pop(0). If we decide that we're okay with this design, I will convert the implementation to pop from the end, or use a deque, which is much more efficient, so please ignore that bit for now.
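As a rough sketch of the contract described above, with a toy block standing in for the real ones: each block holds no mutable cache state; instead, every call returns the new cache, and the frame-batch loop threads it back in for the next iteration.

import torch
from torch import nn


class ToyCausalBlock(nn.Module):
    """Stateless toy block: instead of mutating an internal cache, it returns the new one."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv = nn.Conv3d(channels, channels, kernel_size=1)

    def forward(self, hidden_states: torch.Tensor, conv_cache=None):
        if conv_cache is not None:
            hidden_states = hidden_states + conv_cache   # placeholder for reusing the cached slice
        new_conv_cache = hidden_states[:, :, -1:]        # last temporal slice, kept for the next call
        return self.conv(hidden_states), new_conv_cache


# Frame-batch-wise decoding: the cache returned for iteration i is passed back in
# for iteration i + 1, so no block keeps mutable state between calls.
block = ToyCausalBlock()
conv_cache = None
for frame_batch in torch.randn(4, 1, 4, 2, 8, 8):       # 4 toy batches of 2 frames each
    out, conv_cache = block(frame_batch, conv_cache=conv_cache)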

Or can we prevent it by introducing a class attribute like "use_conv_caching" and doing caching and returning the cache only when this is True?

I don't think we can do this. We always have to return the conv_cache, or store it like we were doing before. As we discussed internally, we could allow one-shot decoding instead of frame-batch-wise decoding (which is what we do currently). The former would not require a conv_cache but would come at a significant memory cost. This is something we can make configurable, but that's for another PR. For this PR, we need to ensure the same behaviour as before, and therefore must have the conv_cache.

@yiyixuxu (Collaborator) left a comment

ohh thanks!
so this is actually a lot more complex than I had thought! lol I made some suggestions, but it is just a rough proposal.
Feel free to test it out to see if it would work, or brainstorm better solutions.

A few pieces of feedback:

  1. I think if the cache is None, we can just pass that default value all the way down to CogVideoXCausalConv3d; we don't need to make sure it has the same length.
  2. I think putting all the cache into a list would work, because it should be added and used in the same order, but with a dict it's more explicit and easier to follow (a rough sketch of this follows below).
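For illustration, a rough sketch of what dict-keyed bookkeeping could look like (toy module, channel sizes, and key names; not the actual diffusers block):

import torch
from torch import nn


class ToyResnetBlock3D(nn.Module):
    """Toy block where each inner conv reads and writes its own cache key."""

    def __init__(self, channels: int = 4):
        super().__init__()
        self.conv1 = nn.Conv3d(channels, channels, kernel_size=1)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size=1)

    def forward(self, hidden_states: torch.Tensor, conv_cache=None):
        conv_cache = conv_cache or {}  # no cache on the first frame-batch iteration
        new_conv_cache = {}

        for name, conv in (("conv1", self.conv1), ("conv2", self.conv2)):
            prev_slice = conv_cache.get(name)                 # explicit lookup instead of list order
            if prev_slice is not None:
                hidden_states = hidden_states + prev_slice    # placeholder for reusing the cached slice
            new_conv_cache[name] = hidden_states[:, :, -1:]   # save the last temporal slice under its key
            hidden_states = conv(hidden_states)

        return hidden_states, new_conv_cache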

@tin2tin commented on Sep 26, 2024

Can you share some words for those not on the internal Slack on what this patch will change for users of CogVideoX through Diffusers?

@sayakpaul (Member) commented on Sep 26, 2024

Can you share some words for those not on the internal Slack on what this patch will change for users of CogVideoX through Diffusers?

This patch will hopefully enable compilation of the VAE, resulting in (hopefully) better speedups. Plus, this patch turns the caching code into something that respects the Diffusers philosophy.

@sayakpaul (Member)

This update also makes torch.compile() possible on the VAE.

We need the following patch first:

diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 2d6e20ac8..a882a0aa9 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -41,7 +41,7 @@ class CogVideoXSafeConv3d(nn.Conv3d):
     """
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
+        memory_count = (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
 
         # Set to 2GB, suitable for CuDNN
         if memory_count > 2:

This is needed because item() isn't compatible with torch.compile(), and other approaches like torch.prod(input.shape) render the conditional flow afterward dynamic, which is also incompatible with torch.compile().

Compilation code (thanks to @a-r-r-o-w):

import os
# os.environ["TORCH_LOGS"] = "+dynamo,graph_breaks,recompiles"

import numpy as np
import torch
from diffusers.utils import load_video
from diffusers import AutoencoderKLCogVideoX

with torch.no_grad():
    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16
    ).to("cuda")

    vae.encode = torch.compile(vae.encode, mode="max-autotune", fullgraph=True)
    vae.decode = torch.compile(vae.decode, mode="max-autotune", fullgraph=True)
    # vae = torch.compile(vae, mode="max-autotune", fullgraph=True)

    for i in range(5):
        video = load_video(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/hiker.mp4"
        )
        video = np.array(video).transpose(3, 0, 1, 2) / 127.5 - 1
        video = torch.from_numpy(video).unsqueeze(0).to("cuda", dtype=torch.float16)

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)

        start.record()
        encoded = vae.encode(video, return_dict=False)[0].sample()
        decoded = (vae.decode(encoded, return_dict=False)[0].float() + 1) * 255
        end.record()
        torch.cuda.synchronize()
        time_elapsed = start.elapsed_time(end) / 1000

        decoded = torch.clamp(decoded, 0, 255)
        video = decoded.permute(0, 2, 3, 4, 1)[0].cpu().numpy().astype(np.uint8)

        print("Inference time:", i, " - ", round(time_elapsed, 3))

Vanilla compile on vae:

Inference time: 0  -  9.357
Inference time: 1  -  7.066
Inference time: 2  -  7.084
Inference time: 3  -  7.057
Inference time: 4  -  7.053

Separate torch.compile() on encode() and decode():

Inference time: 0  -  1286.903
Inference time: 1  -  13.297
Inference time: 2  -  4.504
Inference time: 3  -  4.514
Inference time: 4  -  4.522

a-r-r-o-w marked this pull request as ready for review on September 26, 2024 at 13:26
@yiyixuxu (Collaborator) left a comment

thanks! let's also run the slow tests on CogVideoX to make sure the VAE is backward compatible.

I didn't find any usage of the individual blocks on GitHub, so I think we don't have to worry about breaking those blocks.

feel free to merge after that!

@a-r-r-o-w (Member, Author)

My local testing script works fine and produces the same video as before these changes. I just realized that our CogVideoXIntegrationTests fail on the CI due to OOM, pt -> np conversion, and shape-mismatch errors. I'll fix these in a follow-up PR.

a-r-r-o-w merged commit bd4df28 into main on Sep 28, 2024. 18 checks passed.
a-r-r-o-w deleted the cogvideox/vae-remove-conv-cache branch on Sep 28, 2024 at 11:39.
leisuzz pushed a commit to leisuzz/diffusers that referenced this pull request on Oct 11, 2024
* remove conv cache from the layer and pass as arg instead

* make style

* yiyi's cleaner implementation

Co-Authored-By: YiYi Xu <yixu310@gmail.com>

* sayak's compiled implementation

Co-Authored-By: Sayak Paul <spsayakpaul@gmail.com>

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
sayakpaul added a commit that referenced this pull request on Dec 23, 2024, with the same commit message as above.