
Batch size >= 65536 in xformers.ops.memory_efficient_attention gives CUDA error. #845

Open
comfyanonymous opened this issue Sep 3, 2023 · 9 comments
Labels: bug (Something isn't working)

Comments

@comfyanonymous

🐛 Bug

xformers raises a CUDA error like the following when the batch size is greater than or equal to 65536:

RuntimeError: CUDA error: invalid configuration argument
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


To Reproduce

Steps to reproduce the behavior:

import xformers
import xformers.ops
import torch

q = torch.zeros([65536, 16, 80]).cuda()
k = torch.zeros([65536, 16, 80]).cuda()
v = torch.zeros([65536, 16, 80]).cuda()
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

Expected behavior

Raise a NotImplementedError or a ValueError if the input sizes are not supported.

Environment

I can reproduce this with the above code on my 3090 Ti with xformers 0.0.21, and on the T4 GPU on free Google Colab with xformers 0.0.22.dev599.

@danthe3rd added the bug (Something isn't working) label on Sep 4, 2023
@danthe3rd (Contributor)

Hi,
Thanks for reporting this bug! We'll try to get this fixed asap.

@continue-revolution

Can confirm this happens to me as well (AnimateDiff) for xformers >= 0.0.21

If I run with xformers==0.0.20, things work well

@Yard1 commented Oct 5, 2023

Also running into this issue.

@xzqjack commented Oct 18, 2023

Same error here.

@samiede commented Oct 25, 2023

I do have the same issue as well!

@dianyo commented Dec 12, 2023

Hi @danthe3rd,

I've traced this back a bit into the CUDA code here. The problem seems to be that the batch size used in the attention layer determines how many blocks/threads are launched on the GPU, and when that exceeds what the GPU can support (by my reading, an A100 supports at most 32 x 2048 = 65536), the error occurs.

I also took a quick look at the PyTorch source code and found that it keeps a constraint constant (MAX_BLOCK_SIZE) to deal with large resource requests. Using similar logic might solve this issue.
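
As a rough illustration of that idea, a guard along these lines (the helper name `check_batch_size`, the constant, and the exact limit are assumptions, not existing xformers code) would turn the CUDA launch failure into the ValueError requested in the issue:

import torch

# Hypothetical pre-launch check; MAX_BATCH mirrors the 65535 cap on
# gridDim.z, which is the grid dimension the batch is mapped onto.
MAX_BATCH = 65535

def check_batch_size(q: torch.Tensor) -> None:
    # xformers expects the batch dimension first ([B, M, K] or [B, M, H, K]).
    if q.shape[0] > MAX_BATCH:
        raise ValueError(
            f"batch size {q.shape[0]} exceeds the supported maximum of "
            f"{MAX_BATCH}; split the input into smaller sub-batches"
        )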

@danthe3rd (Contributor)

Hey,
So if you want to have a look, this is because we run many blocks in parallel across 3 dimensions (x, y, z), and there is a limit of 65535 for dimensions y and z (source).
As you can see, we use dimension x for the number of queries, dimension y for the number of heads, and dimension z for the batch size.

__host__ dim3 getBlocksGrid() const {
  return dim3(
      ceil_div(num_queries, (int32_t)kQueriesPerBlock),
      num_heads,
      num_batches);
}

A proper solution would be to "flatten" these dimensions into the x axis and replace each occurrence of blockIdx.[x,y,z] and gridDim.[x,y,z] in the code. You would also have to do the same for Flash-Attention, so this would be a bit more complicated...
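
To make the numbers concrete, here is a small Python model of the mapping described above (queries_per_block and the example shapes are illustrative, not the exact kernel constants); the flattened variant shows how collapsing everything into the x dimension stays within the hardware limits:

# CUDA caps gridDim.y and gridDim.z at 65535, while gridDim.x allows 2**31 - 1
# on current GPUs.
CUDA_MAX_GRID_YZ = 65535

def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b

def blocks_grid(num_queries, num_heads, num_batches, queries_per_block=64):
    # Current mapping: x = query blocks, y = heads, z = batch.
    return (ceil_div(num_queries, queries_per_block), num_heads, num_batches)

def flattened_grid(num_queries, num_heads, num_batches, queries_per_block=64):
    # Proposed fix: fold all three axes into x; the kernel would recover the
    # (query block, head, batch) indices from blockIdx.x via div/mod.
    x = ceil_div(num_queries, queries_per_block) * num_heads * num_batches
    return (x, 1, 1)

grid = blocks_grid(num_queries=16, num_heads=16, num_batches=65536)
print(grid)                                         # (1, 16, 65536) -> gridDim.z too large
print(any(d > CUDA_MAX_GRID_YZ for d in grid[1:]))  # True
print(flattened_grid(16, 16, 65536))                # (1048576, 1, 1) -> fits in gridDim.x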

@Ir1d commented Apr 18, 2024

Hi, what is the workaround for this issue?

@guolinke commented Jun 1, 2024

A fast workaround is to split the input into several smaller sub-batches, each with a batch size below 65536.
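
For reference, a minimal sketch of that workaround (the helper name and chunk size are illustrative; it assumes attn_bias is None or has no batch dimension, so it can be passed through unchanged):

import torch
import xformers.ops

def chunked_memory_efficient_attention(q, k, v, attn_bias=None, max_batch=32768):
    # Run attention on sub-batches small enough for the kernel's grid limits,
    # then stitch the outputs back together along the batch axis.
    outs = []
    for start in range(0, q.shape[0], max_batch):
        end = start + max_batch
        outs.append(
            xformers.ops.memory_efficient_attention(
                q[start:end], k[start:end], v[start:end], attn_bias=attn_bias
            )
        )
    return torch.cat(outs, dim=0)

# Usage with the tensors from the reproduction above:
# out = chunked_memory_efficient_attention(q, k, v)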
