Remove the Q->ne[1] > 8 check
hjc4869 committed Feb 25, 2025
1 parent 29debe1 · commit 5d4ab04
Showing 2 changed files with 3 additions and 3 deletions.
ggml/src/ggml-cuda/fattn-wmma-f16.cu (2 additions, 2 deletions)
@@ -578,10 +578,10 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
         return;
     }
 
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
     if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
         constexpr int cols_per_block = 8;
         switch (Q->ne[0]) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             case 64:
                 ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
                 break;
@@ -594,13 +594,13 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
             case 256:
                 ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
                 break;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
             default:
                 GGML_ABORT("fatal error");
                 break;
         }
         return;
     }
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
 
     if (Q->ne[1] <= 32) {
         constexpr int cols_per_block = 16;
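The guard move is the companion to the fattn.cu change below: with the #if now wrapping the entire Q->ne[1] <= 8 block, HIP/AMD builds skip the 8-column fast path entirely and fall through to the wider kernels, instead of entering the switch and hitting GGML_ABORT in its default branch. A minimal sketch of the resulting column-width selection, assuming the unshown tail of the function selects 32 columns (choose_cols_per_block and is_hip_amd are illustrative stand-ins for the real function and its preprocessor guard, not code from the repository):

#include <cstdio>

// Illustrative stand-in for the post-commit dispatch in
// ggml_cuda_flash_attn_ext_wmma_f16 (not the real implementation).
static int choose_cols_per_block(bool is_hip_amd, int ne0, int ne1, int warp_size) {
    // The 8-column fast path is compiled out of HIP/AMD builds, so the
    // preprocessor guard is modeled here as a runtime flag.
    if (!is_hip_amd && ne1 <= 8 && ne0 % warp_size == 0) {
        return 8;  // head sizes 64..256 handled by the guarded switch
    }
    if (ne1 <= 32) {
        return 16; // the path visible at the end of the second hunk
    }
    return 32;     // assumed tail of the function, not shown in this diff
}

int main() {
    // CUDA build: a batch of 8 query columns keeps the 8-column kernel.
    printf("CUDA, ne1=8: %d cols per block\n", choose_cols_per_block(false, 128, 8, 32));
    // HIP/AMD build: the same shape now falls through to the 16-column
    // kernel instead of being filtered out by the old ne[1] > 8 check.
    printf("HIP,  ne1=8: %d cols per block\n", choose_cols_per_block(true, 128, 8, 64));
    return 0;
}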
ggml/src/ggml-cuda/fattn.cu (1 addition, 1 deletion)
@@ -254,7 +254,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
 
     if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
 #if defined(GGML_HIP_ROCWMMA_FATTN)
-        if (fp16_mma_available(cc) && dst->src[0]->ne[1] > 8) {
+        if (fp16_mma_available(cc)) {
             ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
             return;
         }
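Here dst->src[0] is the Q tensor, so ne[1] is the number of query columns in flight. With the extra condition gone, a rocWMMA-enabled HIP build takes the wmma path whenever fp16 MMA is available and relies on fattn-wmma-f16.cu itself to pick a supported column width. A one-line sketch of the condition before and after (hypothetical helper names, for illustration only):

// Hypothetical predicates modeling the AMD dispatch condition; fp16_mma and
// ne1 stand in for fp16_mma_available(cc) and dst->src[0]->ne[1].
static bool use_wmma_before(bool fp16_mma, long ne1) { return fp16_mma && ne1 > 8; }
static bool use_wmma_after (bool fp16_mma, long /*ne1*/) { return fp16_mma; }

Previously, batches of 8 or fewer query columns on AMD fell through to whatever non-wmma kernel the remainder of ggml_cuda_flash_attn_ext selects; now they reach the wmma dispatcher, where the relocated guard above routes them to the 16-column kernels.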
