fix GQA bug
sneaxiy committed Jan 12, 2024
1 parent 114c152 commit badd90a
Showing 1 changed file with 5 additions and 5 deletions.
paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu: 5 additions & 5 deletions
@@ -59,7 +59,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   const int64_t num_heads = dims[1];
   const int64_t head_size_og = dout.dims()[2];
   const int64_t head_size = dims[2];
-  const int64_t total_k = k.dims[0];
+  const int64_t total_k = k.dims()[0];
   const int64_t num_heads_k = k.dims()[1];
 
   bool is_mha = (num_heads == num_heads_k);
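
The first hunk is a plain accessor fix: dims is a member function on phi::DenseTensor, so the leading extent of k must be read as k.dims()[0]; the deleted line indexed the member name itself. A minimal standalone sketch of the distinction, with a hypothetical Tensor type standing in for DenseTensor:

#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for phi::DenseTensor, for illustration only:
// dims() is a method returning the shape, not a data member.
struct Tensor {
  std::array<int64_t, 2> shape;
  const std::array<int64_t, 2>& dims() const { return shape; }
};

int main() {
  Tensor k{{4096, 8}};                  // [total_k, num_heads_k]
  const int64_t total_k = k.dims()[0];  // call dims() first, then index
  // const int64_t bad = k.dims[0];     // ill-formed: dims is a function
  assert(total_k == 4096);
  return 0;
}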
@@ -80,7 +80,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
       total_k, num_heads_k, num_heads / num_heads_k, head_size};
 
   DenseTensor dk_tmp;
-  if (dk) {
+  if (dk && is_mha) {
     ctx.template Alloc<T>(dk);
     dk_ptr = dk->data();
   } else {
@@ -89,7 +89,7 @@ void FlashAttnUnpaddedGradKernel(const Context& ctx,
   }
 
   DenseTensor dv_tmp;
-  if (dv) {
+  if (dv && is_mha) {
     ctx.template Alloc<T>(dv);
     dv_ptr = dv->data();
   } else {
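
The remaining hunks are the GQA fix proper. With grouped-query attention, num_heads is a multiple of num_heads_k, and the backward pass produces key/value gradients per query head, i.e. with the expanded shape {total_k, num_heads_k, num_heads / num_heads_k, head_size} declared just above. Writing those straight into dk/dv of shape {total_k, num_heads_k, head_size} is only valid when the two head counts match, so the direct allocation is now gated on is_mha; on the GQA path the gradients land in dk_tmp/dv_tmp and presumably get summed over the group axis afterwards (that reduction is outside this diff). A self-contained sketch of such a group-axis sum, using plain vectors rather than DenseTensor; the names and layout here are assumptions for illustration, not Paddle's code:

#include <cstdint>
#include <vector>

// Collapse per-query-head key gradients onto the shared key heads:
//   dk_tmp: [total_k, num_heads_k, group, head_size], group = num_heads / num_heads_k
//   dk:     [total_k, num_heads_k, head_size]
// Assumed layout for illustration; the actual kernel does this on the GPU.
void SumOverGroupAxis(const std::vector<float>& dk_tmp, std::vector<float>& dk,
                      int64_t total_k, int64_t num_heads_k, int64_t group,
                      int64_t head_size) {
  dk.assign(static_cast<size_t>(total_k * num_heads_k * head_size), 0.0f);
  for (int64_t t = 0; t < total_k; ++t)
    for (int64_t h = 0; h < num_heads_k; ++h)
      for (int64_t g = 0; g < group; ++g)
        for (int64_t d = 0; d < head_size; ++d)
          dk[(t * num_heads_k + h) * head_size + d] +=
              dk_tmp[((t * num_heads_k + h) * group + g) * head_size + d];
}

int main() {
  // num_heads = 8 query heads sharing num_heads_k = 2 key heads => group = 4.
  const int64_t total_k = 16, num_heads_k = 2, group = 4, head_size = 64;
  std::vector<float> dk_tmp(total_k * num_heads_k * group * head_size, 1.0f);
  std::vector<float> dk;
  SumOverGroupAxis(dk_tmp, dk, total_k, num_heads_k, group, head_size);
  // Every output element is now group * 1.0f == 4.0f.
  return 0;
}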
@@ -219,7 +219,7 @@ void FlashAttnGradKernel(const Context& ctx,
   DenseTensor dk_tmp;
   std::initializer_list<int64_t> dk_dv_shape = {
       batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
-  if (dk) {
+  if (dk && is_mha) {
     ctx.template Alloc<T>(dk);
     dk_ptr = dk->data();
   } else {
@@ -228,7 +228,7 @@ void FlashAttnGradKernel(const Context& ctx,
   }
 
   DenseTensor dv_tmp;
-  if (dv) {
+  if (dv && is_mha) {
     ctx.template Alloc<T>(dv);
     dv_ptr = dv->data();
   } else {
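
The last two hunks apply the identical gating to FlashAttnGradKernel, the fixed-length variant of the same backward pass; there the temporary keeps the batch and sequence axes explicit, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}, and the same group-axis sum collapses it into dk/dv.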
