Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support GQA and MQA #60550

Merged
merged 1 commit into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion cmake/external/flashattn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ set(FLASHATTN_PREFIX_DIR ${THIRD_PARTY_PATH}/flashattn)
set(FLASHATTN_SOURCE_SUBDIR csrc)
set(FLASHATTN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/flashattn)
set(FLASHATTN_REPOSITORY ${GIT_URL}/PaddlePaddle/flash-attention.git)
set(FLASHATTN_TAG 0598fa245bbfb8c4462002600864518c0e37e714)
set(FLASHATTN_TAG fd6890c7ef6e53380b9eddc0a12b5acc641eb57d)

set(FLASHATTN_INCLUDE_DIR
"${FLASHATTN_INSTALL_DIR}/include"
Expand Down Expand Up @@ -67,6 +67,20 @@ else()
set(FLASHATTN_CXX_FLAGS_DEBUG ${CMAKE_CXX_FLAGS_DEBUG})
endif()

set(FA_NVCC_ARCH_BIN "")
foreach(arch ${NVCC_ARCH_BIN})
string(STRIP ${arch} arch)
if(arch STREQUAL "")
continue()
endif()

if(FA_NVCC_ARCH_BIN STREQUAL "")
set(FA_NVCC_ARCH_BIN "${arch}")
else()
set(FA_NVCC_ARCH_BIN "${FA_NVCC_ARCH_BIN}-${arch}")
endif()
endforeach()

ExternalProject_Add(
extern_flashattn
${EXTERNAL_PROJECT_LOG_ARGS} ${SHALLOW_CLONE}
Expand Down Expand Up @@ -94,6 +108,7 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE=ON
-DCMAKE_BUILD_TYPE=${THIRD_PARTY_BUILD_TYPE}
-DCMAKE_JOB_POOL_COMPILE:STRING=compile
-DNVCC_ARCH_BIN=${FA_NVCC_ARCH_BIN}
-DCMAKE_JOB_POOLS:STRING=compile=4
${EXTERNAL_OPTIONAL_ARGS}
CMAKE_CACHE_ARGS
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/infermeta/backward.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1206,15 +1206,15 @@ void FusedRopeGradInferMeta(const MetaTensor& dout_q,
"[batch_size, seq_len, num_heads, head_dim],"
"but got %u.",
input_dims.size()));
if (dout_q) {
if (dout_q && dq) {
dq->set_dims(dout_q.dims());
dq->set_dtype(dout_q.dtype());
}
if (dout_k) {
if (dout_k && dk) {
dk->set_dims(dout_k.dims());
dk->set_dtype(dout_k.dtype());
}
if (dout_v) {
if (dout_v && dv) {
dv->set_dims(dout_v.dims());
dv->set_dtype(dout_v.dtype());
}
Expand Down
29 changes: 25 additions & 4 deletions paddle/phi/kernels/gpu/flash_attn_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/kernels/empty_kernel.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/kernels/reshape_kernel.h"
#ifdef PADDLE_WITH_FLASHATTN
#include "paddle/phi/backends/dynload/flashattn.h"
Expand Down Expand Up @@ -196,8 +197,23 @@ void FlashAttnGradKernel(const Context& ctx,
seed_offset.data<int64_t>());

ctx.template Alloc<T>(dq);
ctx.template Alloc<T>(dk);
ctx.template Alloc<T>(dv);

bool is_mha = (num_heads == num_heads_k);

void* dk_data = nullptr;
void* dv_data = nullptr;
phi::DenseTensor dk_expanded, dv_expanded;
if (is_mha) {
dk_data = ctx.template Alloc<T>(dk);
dv_data = ctx.template Alloc<T>(dv);
} else {
std::initializer_list<int64_t> dk_dv_shape = {
batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size};
dk_expanded.Resize(dk_dv_shape);
dv_expanded.Resize(dk_dv_shape);
dk_data = ctx.template Alloc<T>(&dk_expanded);
dv_data = ctx.template Alloc<T>(&dv_expanded);
}

cudaStream_t stream = ctx.stream();

Expand All @@ -216,8 +232,8 @@ void FlashAttnGradKernel(const Context& ctx,
softmax_lse.data(),
params.rng_state.data(),
dq->data(),
dk->data(),
dv->data(),
dk_data,
dv_data,
params.dq_accum.data(),
params.batch_size,
params.max_seqlen_q,
Expand All @@ -240,6 +256,11 @@ void FlashAttnGradKernel(const Context& ctx,
params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr,
params.mask_dims.data());
CheckFlashAttnStatus(succ);

if (!is_mha) {
phi::SumKernel<T, Context>(ctx, dk_expanded, {3}, dk->type(), false, dk);
phi::SumKernel<T, Context>(ctx, dv_expanded, {3}, dv->type(), false, dv);
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FlashAttention is unsupported, please set use_flash_attn to false."));
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/flash_attn_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ struct FlashAttnFwdParamsV2 {
max_seqlen_q(_max_seqlen_q),
max_seqlen_k(_max_seqlen_k),
num_heads(_num_heads),
num_heads_k(_num_heads),
num_heads_k(_num_heads_k),
head_size(_head_size),
scale(_scale),
dropout(_dropout),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -610,7 +610,8 @@ def send_forward_backward_recv_forward_backward(
if _timers is not None:
_timers("send_forward_backward_recv_forward_backward").start()

self._send_meta(output_tensor)
if output_tensor is not None:
self._send_meta(output_tensor)
if recv_prev:
self._recv_meta()

Expand Down