[XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU #64003

Merged · 2 commits · May 7, 2024
66 changes: 43 additions & 23 deletions paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc
@@ -69,9 +69,28 @@ void FlashAttnGradKernel(const Context& ctx,
const XPUType* out_data = reinterpret_cast<const XPUType*>(out.data<T>());
const float* softmax_lse_data = softmax_lse.data<float>();
const XPUType* dout_data = reinterpret_cast<const XPUType*>(dout.data<T>());

+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
const float* bias_data = nullptr;
if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast<XPUType, float>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data<T>()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported.");
+    }
}
// output
XPUType* dq_data = reinterpret_cast<XPUType*>(dq->data<T>());
@@ -92,6 +111,7 @@ void FlashAttnGradKernel(const Context& ctx,

// get seed offset
const int64_t* seed_offset_data = seed_offset.data<int64_t>();

Contributor:
Is the seed_offset here definitely int32_t? The GPU implementation appears to use int64.

Contributor:
(screenshot attachment)

Contributor Author:
The interface provided by mha_varlen_fwd takes int32, so int32 is used directly here.

Contributor:
It would be worth confirming the data type the framework actually passes in; if it can be int64, a cast is needed.

Contributor Author:
The framework passes int64; I will update this.
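
A minimal sketch of the narrowing discussed in this thread (the function name and signature below are hypothetical stand-ins, not the actual mha_varlen_fwd interface): the framework stores the seed and offset as int64_t on the host, while the underlying attention call is assumed to take a 32-bit seed, so the call site needs an explicit cast.

// Illustrative sketch only; launch_attention is a hypothetical stand-in for
// the real xdnn entry point, assumed here to take a 32-bit seed.
#include <cstdint>

void launch_attention(int32_t seed);  // hypothetical 32-bit seed parameter

void run(const int64_t* seed_offset_data) {
  // seed_offset_data[0] holds the RNG seed as int64_t on the host; narrow it
  // explicitly before passing it to the 32-bit interface (values outside the
  // int32 range would be truncated).
  launch_attention(static_cast<int32_t>(seed_offset_data[0]));
}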

// template<typename T, typename TACCUM, typename TGEMM, typename TID = int>
// int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T*
// k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T*
@@ -106,28 +126,28 @@ void FlashAttnGradKernel(const Context& ctx,
// dv_maxptr = nullptr, const float* do_maxptr = nullptr);
int r = baidu::xpu::xfa::mha_varlen_bwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
-      dout_data,  // dout
-      q_data,  // q
-      k_data,  // k
-      v_data,  // v
-      out_data,  // out
-      softmax_lse_data,  // softmax_lse
-      dq_data,  // dq
-      dk_data,  // dk
-      dv_data,  // dv
-      qlod,  // lod_seqlens_q
-      kvlod,  // lod_seqlens_k
-      seqlen_q,  // max_seqlen_q
-      seqlen_k,  // max_seqlen_k
-      num_heads,  // head_num
-      num_heads_k,  // head_num_k
-      head_size,  // head_dim
-      1.0f / std::sqrt(head_size),  // softmax_scale
-      dropout,  // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,  // is_causal
-      nullptr,  // attn_mask
-      bias_data  // bias
+      dout_data,  // dout
+      q_data,  // q
+      k_data,  // k
+      v_data,  // v
+      out_data,  // out
+      softmax_lse_data,  // softmax_lse
+      dq_data,  // dq
+      dk_data,  // dk
+      dv_data,  // dv
+      qlod,  // lod_seqlens_q
+      kvlod,  // lod_seqlens_k
+      seqlen_q,  // max_seqlen_q
+      seqlen_k,  // max_seqlen_k
+      num_heads,  // head_num
+      num_heads_k,  // head_num_k
+      head_size,  // head_dim
+      1.0f / std::sqrt(head_size),  // softmax_scale
+      dropout,  // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,  // is_causal
+      nullptr,  // attn_mask
+      bias_data  // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
#else
76 changes: 51 additions & 25 deletions paddle/phi/kernels/xpu/flash_attn_kernel.cc
@@ -14,7 +14,7 @@

#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"

#ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
seed_offset->Resize({2});
int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
} else {
std::pair<uint64_t, uint64_t> seed_offset_pair;
uint64_t inc = batch_size * num_heads * 32;
@@ -263,11 +271,29 @@
const XPUType* k_data = reinterpret_cast<const XPUType*>(k.data<T>());
const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
-  float* softmax_lse_data = softmax_lse->data<float>();
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  float* softmax_lse_data = softmax_lse->data<float>();
const float* bias_data = nullptr;
if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast<XPUType, float>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data<T>()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported.");
+    }
}
// template <typename T, typename TACCUM, typename TGEMM, typename TID> int
// mha_varlen_fwd(xdnn::Context* ctx, const T* q, const T* k, const T* v, T*
@@ -281,24 +307,24 @@
// nullptr);
int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
-      q_data,  // q
-      k_data,  // k
-      v_data,  // v
-      out_data,  // out
-      softmax_lse_data,  // softmax_lse
-      qlod,  // lod_seqlens_q
-      kvlod,  // lod_seqlens_k
-      seqlen_q,  // max_seqlen_q
-      seqlen_k,  // max_seqlen_k
-      num_heads,  // head_num
-      num_heads_k,  // head_num_k
-      head_size,  // head_dim
-      1.0f / std::sqrt(head_size),  // softmax_scale
-      dropout,  // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,  // is_causal
-      nullptr,  // attn_mask
-      bias_data  // bias
+      q_data,  // q
+      k_data,  // k
+      v_data,  // v
+      out_data,  // out
+      softmax_lse_data,  // softmax_lse
+      qlod,  // lod_seqlens_q
+      kvlod,  // lod_seqlens_k
+      seqlen_q,  // max_seqlen_q
+      seqlen_k,  // max_seqlen_k
+      num_heads,  // head_num
+      num_heads_k,  // head_num_k
+      head_size,  // head_dim
+      1.0f / std::sqrt(head_size),  // softmax_scale
+      dropout,  // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,  // is_causal
+      nullptr,  // attn_mask
+      bias_data  // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
#else