Batch size >= 65536 in xformers.ops.memory_efficient_attention gives CUDA error. #845
Comments
Can confirm this happens to me as well (AnimateDiff) for xformers >= 0.0.21. If I run with xformers==0.0.20, things work well.
Also running into this issue. |
same error |
I do have the same issue as well! |
Hi @danthe3rd, I've traced this back a bit to the CUDA code here. The problem comes from the batch size used in the original attention layer: it determines how many threads are launched on the GPU's SMs. If the thread (batch) count is larger than what one GPU can support (an A100 can only support up to 32 x 2048 = 65536 threads), the error occurs. I also took a quick look at the PyTorch source code and found that they always enforce a constraint constant for this limit.
Hey, the relevant code is here: xformers/xformers/csrc/attention/cuda/fmha/kernel_forward.h, lines 358 to 363 (at commit 1254a16).
A proper solution would be to "flatten" these dimensions into the …
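For readers following along: the hard limit comes from CUDA itself rather than from the attention math. The forward kernel launches roughly one thread block per (query tile, head, batch) triple, and CUDA caps gridDim.y and gridDim.z at 65535 while gridDim.x allows up to 2^31 - 1 blocks. A rough sketch of that launch arithmetic, assuming the batch index is carried by one of the capped trailing grid dimensions (the function and constant names below are illustrative, not xformers code):

```python
# Illustrative sketch only -- not the actual xformers launch code.
CUDA_MAX_GRID_YZ = 65535        # hard CUDA limit on gridDim.y and gridDim.z
CUDA_MAX_GRID_X = 2**31 - 1     # gridDim.x is far larger

def launch_grid(num_queries, num_heads, num_batches, queries_per_block=64):
    # One block per (query tile, head, batch) triple; the batch index lands in a
    # trailing grid dimension, which is what overflows once it reaches 65536.
    grid_x = -(-num_queries // queries_per_block)   # ceil division over query tiles
    grid_y = num_heads
    grid_z = num_batches
    if grid_y > CUDA_MAX_GRID_YZ or grid_z > CUDA_MAX_GRID_YZ:
        raise ValueError("num_heads or batch size exceeds the CUDA grid limit of 65535")
    return grid_x, grid_y, grid_z
```

Folding the batch (and head) indices into grid_x, which is effectively unbounded for this purpose, is the kind of "flattening" fix the comment above alludes to.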
Hi, what is the workaround for this issue? |
A fast workaround is to run several smaller sub-batches, each with batch size < 65536.
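A minimal sketch of that sub-batching workaround (`chunked_attention` and `max_batch` are illustrative names, not part of the xformers API; any chunk size below 65536 works):

```python
import torch
import xformers.ops as xops

def chunked_attention(q, k, v, max_batch=32768):
    # Run memory_efficient_attention on slices of the batch dimension that stay
    # under the CUDA grid limit, then stitch the outputs back together.
    outs = []
    for start in range(0, q.shape[0], max_batch):
        chunk = slice(start, start + max_batch)
        outs.append(xops.memory_efficient_attention(q[chunk], k[chunk], v[chunk]))
    return torch.cat(outs, dim=0)
```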
🐛 Bug
xformers gives a CUDA error when the batch size is greater than or equal to 65536.
Command
To Reproduce
Steps to reproduce the behavior:
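The original reproduction script is not shown above; a minimal sketch along these lines (the tensor shapes are illustrative) hits the same limit:

```python
import torch
import xformers.ops as xops

# Any batch size >= 65536 triggers the failure; the other dimensions are arbitrary.
q = torch.randn(65536, 16, 8, 32, device="cuda", dtype=torch.float16)  # [batch, seq, heads, dim]
k = torch.randn_like(q)
v = torch.randn_like(q)
out = xops.memory_efficient_attention(q, k, v)  # CUDA error on xformers >= 0.0.21
```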
Expected behavior
Raise a NotImplementedError or a ValueError if the input sizes are not supported.
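Until such a check exists inside the library, a caller-side guard along these lines (`checked_attention` is a hypothetical helper, not an xformers function) produces the requested error instead of an opaque CUDA failure:

```python
import xformers.ops as xops

def checked_attention(q, k, v):
    # 65535 is the CUDA gridDim.y/z cap discussed in the comments above.
    if q.shape[0] > 65535:
        raise ValueError(
            f"batch size {q.shape[0]} is not supported by memory_efficient_attention; "
            "split the input into sub-batches of at most 65535"
        )
    return xops.memory_efficient_attention(q, k, v)
```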
Environment
I can reproduce this with the above code on my 3090 Ti with xformers 0.0.21, and on the T4 GPU on free Google Colab with xformers 0.0.22.dev599.