[Bug] ControlNet execution fails with CUDNN_STATUS_NOT_SUPPORTED error due to CUDNN tensor contiguity issue in convolutions #169

Kinyugo opened this issue Jan 28, 2025 · 5 comments

Kinyugo commented Jan 28, 2025

Summary

The stable-fast library fails to run the ControlNet module in a Stable Diffusion Image-to-Image pipeline. The error occurs specifically during cuDNN convolution operations.

Steps to Reproduce

  1. Load a Stable Diffusion Image-to-Image pipeline with ControlNet
  2. Configure compilation settings with do_compile_controlnet=True
  3. Attempt to compile the pipeline
  4. Observe a cuDNN error during ControlNet execution with some input sizes (a reproduction sketch follows this list)
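
A minimal reproduction sketch along these lines (the model IDs and config flags here are illustrative placeholders, not the exact ones I used; compile() refers to the helper shown further down):

import torch
from diffusers import ControlNetModel, StableDiffusionControlNetImg2ImgPipeline
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig

# Placeholder checkpoints; any SD 1.5 model with a matching ControlNet should do
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

config = CompilationConfig.Default()
config.enable_cuda_graph = True

# Compile everything, including the ControlNet
pipe = compile(pipe, config, do_compile_controlnet=True)
# Running the compiled pipeline with large batches then fails inside sfast::cudnn_convolution_bias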

Error

The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):

graph(%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13):
    %x = sfast::cudnn_convolution_bias(%1, %2, %3, %4, %5, %6, %7, %8, %9)
         ~~~~~ <--- HERE
    return (%x)
RuntimeError: cuDNN error: CUDNN_STATUS_NOT_SUPPORTED. This error may appear if you passed in a non-contiguous input.

Investigation Findings

  1. Error only occurs when compiling the ControlNet module
  2. Other components (VAE, UNet) compile successfully
  3. Adding .contiguous() to ControlNet inputs did not resolve the issue
  4. The error appears to be related to tensor memory layout during cuDNN operations (a diagnostic sketch for checking this follows this list)
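
For item 4, a small diagnostic that can be hung off the ControlNet's convolutions before compiling (purely illustrative; the helper name is mine):

import torch


def report_conv_input_layout(model, tag=""):
    """Print the memory layout of every Conv2d input the model sees."""

    def make_hook(name):
        def hook(module, inputs):
            x = inputs[0]
            print(
                f"{tag}{name}: shape={tuple(x.shape)} "
                f"contiguous={x.is_contiguous()} "
                f"channels_last={x.is_contiguous(memory_format=torch.channels_last)}"
            )

        return hook

    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            module.register_forward_pre_hook(make_hook(name))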

Here is the code I used to selectively compile different parts of the pipeline:

import torch
from sfast.compilers.diffusion_pipeline_compiler import (
    _build_lazy_trace,
    compile_unet,
    compile_vae,
    make_dynamic_graphed_callable,
)


def compile(m, config, do_compile_vae=True, do_compile_unet=True, do_compile_controlnet=True):
    # attribute `device` is not generally available
    device = m.device if hasattr(m, 'device') else torch.device(
        'cuda' if torch.cuda.is_available() else 'cpu')

    enable_cuda_graph = config.enable_cuda_graph and device.type == 'cuda'

    if do_compile_unet:
        m.unet = compile_unet(m.unet, config)
    if do_compile_controlnet and hasattr(m, 'controlnet'):
        m.controlnet = compile_unet(m.controlnet, config)
    if do_compile_vae:
        m.vae = compile_vae(m.vae, config)

    if config.enable_jit:
        lazy_trace_ = _build_lazy_trace(config)

        if getattr(m, 'text_encoder', None) is not None:
            m.text_encoder.forward = lazy_trace_(m.text_encoder.forward)
        # for SDXL
        if getattr(m, 'text_encoder_2', None) is not None:
            m.text_encoder_2.forward = lazy_trace_(m.text_encoder_2.forward)
        # for SVD
        if getattr(m, 'image_encoder', None) is not None:
            m.image_encoder.forward = lazy_trace_(m.image_encoder.forward)
        if config.trace_scheduler:
            m.scheduler.scale_model_input = lazy_trace_(
                m.scheduler.scale_model_input)
            m.scheduler.step = lazy_trace_(m.scheduler.step)

    if enable_cuda_graph:
        if getattr(m, 'text_encoder', None) is not None:
            m.text_encoder.forward = make_dynamic_graphed_callable(
                m.text_encoder.forward)
        if getattr(m, 'text_encoder_2', None) is not None:
            m.text_encoder_2.forward = make_dynamic_graphed_callable(
                m.text_encoder_2.forward)
        if getattr(m, 'image_encoder', None) is not None:
            m.image_encoder.forward = make_dynamic_graphed_callable(
                m.image_encoder.forward)

    if hasattr(m, 'image_processor'):
        from sfast.libs.diffusers.image_processor import patch_image_prcessor
        patch_image_prcessor(m.image_processor)

    return m
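
Typical usage, as a sketch (pipe and config set up as in the reproduction sketch above):

# Compiling only the UNet and VAE works fine:
pipe = compile(pipe, config, do_compile_controlnet=False)
# Including the ControlNet triggers the cuDNN error for some input sizes:
pipe = compile(pipe, config, do_compile_controlnet=True)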

Questions

  1. Are there known issues with CUDNN operations and specific tensor shapes?
  2. Are there any recommended workarounds?

Environment Details

Collecting environment information...
PyTorch version: 2.2.2+cu118
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.10.16 | packaged by conda-forge | (main, Dec  5 2024, 14:16:10) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-5.15.0-86-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 535.104.12
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                       x86_64
CPU op-mode(s):                     32-bit, 64-bit
Address sizes:                      46 bits physical, 57 bits virtual
Byte Order:                         Little Endian
CPU(s):                             12
On-line CPU(s) list:                0-11
Vendor ID:                          GenuineIntel
Model name:                         Intel(R) Xeon(R) Gold 6338 CPU @ 2.00GHz
CPU family:                         6
Model:                              106
Thread(s) per core:                 2
Core(s) per socket:                 6
Socket(s):                          1
Stepping:                           6
BogoMIPS:                           3999.99
Flags:                              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves wbnoinvd arat avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid fsrm md_clear arch_capabilities
Virtualization:                     VT-x
Hypervisor vendor:                  KVM
Virtualization type:                full
L1d cache:                          288 KiB (6 instances)
L1i cache:                          192 KiB (6 instances)
L2 cache:                           7.5 MiB (6 instances)
L3 cache:                           48 MiB (1 instance)
NUMA node(s):                       1
NUMA node0 CPU(s):                  0-11
Vulnerability Gather data sampling: Unknown: Dependent on hypervisor status
Vulnerability Itlb multihit:        Not affected
Vulnerability L1tf:                 Not affected
Vulnerability Mds:                  Not affected
Vulnerability Meltdown:             Not affected
Vulnerability Mmio stale data:      Mitigation; Clear CPU buffers; SMT Host state unknown
Vulnerability Retbleed:             Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:           Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:                Not affected
Vulnerability Tsx async abort:      Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.3
[pip3] nvidia-cublas-cu11==11.11.3.6
[pip3] nvidia-cuda-cupti-cu11==11.8.87
[pip3] nvidia-cuda-nvrtc-cu11==11.8.89
[pip3] nvidia-cuda-runtime-cu11==11.8.89
[pip3] nvidia-cudnn-cu11==8.7.0.84
[pip3] nvidia-cufft-cu11==10.9.0.58
[pip3] nvidia-curand-cu11==10.3.0.86
[pip3] nvidia-cusolver-cu11==11.4.1.48
[pip3] nvidia-cusparse-cu11==11.7.5.86
[pip3] nvidia-nccl-cu11==2.19.3
[pip3] nvidia-nvtx-cu11==11.8.86
[pip3] stable-fast==1.0.5+torch222cu118
[pip3] torch==2.2.2+cu118
[pip3] torchaudio==2.2.2+cu118
[pip3] torchvision==0.17.2+cu118
[pip3] triton==2.2.0
[conda] numpy                     1.26.3                   pypi_0    pypi
[conda] nvidia-cublas-cu11        11.11.3.6                pypi_0    pypi
[conda] nvidia-cuda-cupti-cu11    11.8.87                  pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu11    11.8.89                  pypi_0    pypi
[conda] nvidia-cuda-runtime-cu11  11.8.89                  pypi_0    pypi
[conda] nvidia-cudnn-cu11         8.7.0.84                 pypi_0    pypi
[conda] nvidia-cufft-cu11         10.9.0.58                pypi_0    pypi
[conda] nvidia-curand-cu11        10.3.0.86                pypi_0    pypi
[conda] nvidia-cusolver-cu11      11.4.1.48                pypi_0    pypi
[conda] nvidia-cusparse-cu11      11.7.5.86                pypi_0    pypi
[conda] nvidia-nccl-cu11          2.19.3                   pypi_0    pypi
[conda] nvidia-nvtx-cu11          11.8.86                  pypi_0    pypi
[conda] stable-fast               1.0.5+torch222cu118          pypi_0    pypi
[conda] torch                     2.2.2+cu118              pypi_0    pypi
[conda] torchaudio                2.2.2+cu118              pypi_0    pypi
[conda] torchvision               0.17.2+cu118             pypi_0    pypi
[conda] triton                    2.2.0                    pypi_0    pypi
[conda] diffusers                 0.32.2                   pypi_0    pypi
[conda] transformers              4.48.1                   pypi_0    pypi

Kinyugo commented Jan 29, 2025

Additional Information

I have attempted to ensure input and output contiguity, but the error still occurs. Strangely, it only occurs with some input sizes and not others, e.g.:

Failing shapes:

latent_shape torch.Size([64, 4, 128, 128])
prompt_embeds_shape torch.Size([128, 77, 768])
negative_prompt_embeds_shape torch.Size([1, 77, 768])
control_image_shape torch.Size([128, 3, 1024, 1024])

Working shapes:

latent_shape torch.Size([16, 4, 128, 128])
prompt_embeds_shape torch.Size([32, 77, 768])
negative_prompt_embeds_shape torch.Size([1, 77, 768])
control_image_shape torch.Size([32, 3, 1024, 1024])

Here is my code to enforce contiguity:

import torch


def enforce_contiguous_hook(layer_name: str):
    """
    Creates hooks that enforce contiguous tensors for both inputs and outputs
    of Conv2d layers.

    Args:
        layer_name (str): Name of the layer being hooked

    Returns:
        pre_hook_fn: Forward pre-hook function for inputs
        hook_fn: Forward hook function for outputs
    """

    def pre_hook_fn(module, inputs):
        inputs = list(inputs)
        for idx, input_tensor in enumerate(inputs):
            if (
                isinstance(input_tensor, torch.Tensor)
                and not input_tensor.is_contiguous()
            ):
                print(f"🔄 Making input contiguous in {layer_name}")
                inputs[idx] = input_tensor.contiguous()

        return tuple(inputs)

    def hook_fn(module, inputs, outputs):
        if isinstance(outputs, torch.Tensor):
            if not outputs.is_contiguous():
                print(f"🔄 Making output contiguous in {layer_name}")
                outputs = outputs.contiguous()
        else:
            # Handle case where outputs is a tuple
            outputs = list(outputs)
            for idx, output_tensor in enumerate(outputs):
                if (
                    isinstance(output_tensor, torch.Tensor)
                    and not output_tensor.is_contiguous()
                ):
                    print(f"🔄 Making output {idx} contiguous in {layer_name}")
                    outputs[idx] = output_tensor.contiguous()
            outputs = tuple(outputs)

        return outputs

    return pre_hook_fn, hook_fn


def register_contiguous_enforcement(model):
    """
    Register hooks that enforce contiguous tensors for Conv2d layers.

    Args:
        model (nn.Module): PyTorch model to enforce contiguity
    """
    for name, module in model.named_modules():
        # Restrict to Conv2d layers, matching the docstrings above
        if not isinstance(module, torch.nn.Conv2d):
            continue
        pre_hook_fn, hook_fn = enforce_contiguous_hook(name)
        module.register_forward_pre_hook(pre_hook_fn)
        module.register_forward_hook(hook_fn)
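
The hooks are registered on the ControlNet before compiling, roughly like this (sketch):

register_contiguous_enforcement(pipe.controlnet)
pipe = compile(pipe, config, do_compile_controlnet=True)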

chengzeyi (Owner) commented

@Kinyugo Can you try setting enable_cnn_optimization = False when compiling controlnet?
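
Something along these lines (a sketch, reusing the names from your compile helper):

import copy

controlnet_config = copy.deepcopy(config)
controlnet_config.enable_cnn_optimization = False
m.controlnet = compile_unet(m.controlnet, controlnet_config)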


Kinyugo commented Jan 30, 2025

@chengzeyi thanks for the fix; compilation works when enable_cnn_optimization=False.

Do you have an idea of how much performance is lost? ControlNet usually has large image inputs.

chengzeyi (Owner) commented

> @chengzeyi thanks for the fix; compilation works when enable_cnn_optimization=False.
>
> Do you have an idea of how much performance is lost? ControlNet usually has large image inputs.

It should not be very large, I guess. Besides, my major focus is currently on newer models, so if you are interested you could check out my other projects like ParaAttention or Comfy-WaveSpeed.


Kinyugo commented Jan 30, 2025

Thank you. I will take a look at the projects.


Here is the updated code. Let me know if you would like to add an option to skip CNN optimizations for the ControlNet; I could add it to the config and open a PR with the changes.

import copy

import torch
from sfast.compilers.diffusion_pipeline_compiler import (
    _build_lazy_trace,
    compile_unet,
    compile_vae,
    make_dynamic_graphed_callable,
)


def sfast_compile(m, config):
    # attribute `device` is not generally available
    device = (
        m.device
        if hasattr(m, "device")
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )

    enable_cuda_graph = config.enable_cuda_graph and device.type == "cuda"

    m.unet = compile_unet(m.unet, config)
    if hasattr(m, "controlnet"):
        controlnet_config = copy.deepcopy(config)
        controlnet_config.enable_cnn_optimization = False
        m.controlnet = compile_unet(m.controlnet, controlnet_config)
    m.vae = compile_vae(m.vae, config)

    if config.enable_jit:
        lazy_trace_ = _build_lazy_trace(config)

        if getattr(m, "text_encoder", None) is not None:
            m.text_encoder.forward = lazy_trace_(m.text_encoder.forward)
        # for SDXL
        if getattr(m, "text_encoder_2", None) is not None:
            m.text_encoder_2.forward = lazy_trace_(m.text_encoder_2.forward)
        # for SVD
        if getattr(m, "image_encoder", None) is not None:
            m.image_encoder.forward = lazy_trace_(m.image_encoder.forward)
        if config.trace_scheduler:
            m.scheduler.scale_model_input = lazy_trace_(m.scheduler.scale_model_input)
            m.scheduler.step = lazy_trace_(m.scheduler.step)

    if enable_cuda_graph:
        if getattr(m, "text_encoder", None) is not None:
            m.text_encoder.forward = make_dynamic_graphed_callable(
                m.text_encoder.forward
            )
        if getattr(m, "text_encoder_2", None) is not None:
            m.text_encoder_2.forward = make_dynamic_graphed_callable(
                m.text_encoder_2.forward
            )
        if getattr(m, "image_encoder", None) is not None:
            m.image_encoder.forward = make_dynamic_graphed_callable(
                m.image_encoder.forward
            )

    if hasattr(m, "image_processor"):
        from sfast.libs.diffusers.image_processor import patch_image_prcessor

        patch_image_prcessor(m.image_processor)

    return m
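
Usage stays the same as before, e.g. (sketch):

from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig

config = CompilationConfig.Default()
config.enable_cuda_graph = True
pipe = sfast_compile(pipe, config)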
