
Commit 5c5e46d

runzhechco63oc authored and committed
[XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (PaddlePaddle#64003)
* [XPU] fix segfault caused by setting fix_seed_offset on XPU
* cast attention_mask to float32 when necessary
1 parent 2ed82cd commit 5c5e46d

File tree

2 files changed: +94 -48 lines


paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc (+43 -23)

@@ -69,9 +69,28 @@ void FlashAttnGradKernel(const Context& ctx,
   const XPUType* out_data = reinterpret_cast<const XPUType*>(out.data<T>());
   const float* softmax_lse_data = softmax_lse.data<float>();
   const XPUType* dout_data = reinterpret_cast<const XPUType*>(dout.data<T>());
+
+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
   const float* bias_data = nullptr;
   if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast<XPUType, float>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data<T>()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported.");
+    }
   }
   // output
   XPUType* dq_data = reinterpret_cast<XPUType*>(dq->data<T>());
@@ -92,6 +111,7 @@ void FlashAttnGradKernel(const Context& ctx,

   // get seed offset
   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
+
   // template<typename T, typename TACCUM, typename TGEMM, typename TID = int>
   // int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T*
   // k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T*
@@ -106,28 +126,28 @@ void FlashAttnGradKernel(const Context& ctx,
   // dv_maxptr = nullptr, const float* do_maxptr = nullptr);
   int r = baidu::xpu::xfa::mha_varlen_bwd<XPUType, float, tfloat32, int>(
       ctx.x_context(),
-      dout_data,                                  // dout
-      q_data,                                     // q
-      k_data,                                     // k
-      v_data,                                     // v
-      out_data,                                   // out
-      softmax_lse_data,                           // softmax_lse
-      dq_data,                                    // dq
-      dk_data,                                    // dk
-      dv_data,                                    // dv
-      qlod,                                       // lod_seqlens_q
-      kvlod,                                      // lod_seqlens_k
-      seqlen_q,                                   // max_seqlen_q
-      seqlen_k,                                   // max_seqlen_k
-      num_heads,                                  // head_num
-      num_heads_k,                                // head_num_k
-      head_size,                                  // head_dim
-      1.0f / std::sqrt(head_size),                // softmax_scale
-      dropout,                                    // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]), // seed
-      causal,                                     // is_causal
-      nullptr,                                    // attn_mask
-      bias_data                                   // bias
+      dout_data,                                 // dout
+      q_data,                                    // q
+      k_data,                                    // k
+      v_data,                                    // v
+      out_data,                                  // out
+      softmax_lse_data,                          // softmax_lse
+      dq_data,                                   // dq
+      dk_data,                                   // dk
+      dv_data,                                   // dv
+      qlod,                                      // lod_seqlens_q
+      kvlod,                                     // lod_seqlens_k
+      seqlen_q,                                  // max_seqlen_q
+      seqlen_k,                                  // max_seqlen_k
+      num_heads,                                 // head_num
+      num_heads_k,                               // head_num_k
+      head_size,                                 // head_dim
+      1.0f / std::sqrt(head_size),               // softmax_scale
+      dropout,                                   // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]), // seed
+      causal,                                    // is_causal
+      nullptr,                                   // attn_mask
+      bias_data                                  // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
 #else
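The new branch above only changes how the attention_mask bias pointer is obtained: float32 masks are passed through, float16/bfloat16 masks are cast into a scratch float buffer owned by the ctx_guard, and other dtypes fall into the Unimplemented error. As a minimal sketch of the same dtype-dispatch pattern pulled out into a helper — assuming the same headers and helpers already used in these kernel files (XPUTypeTrait, xpu::ctx_guard, xpu::cast, PADDLE_ENFORCE_XDNN_SUCCESS); the function name CastAttnMaskToFloat is hypothetical and not part of the commit — it could look like:

template <typename T, typename Context>
const float* CastAttnMaskToFloat(const Context& ctx,
                                 const phi::DenseTensor& attn_mask,
                                 xpu::ctx_guard* guard) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  if (attn_mask.dtype() == phi::DataType::FLOAT32) {
    // Already float32: hand the mask buffer to the attention kernel as-is.
    return attn_mask.data<float>();
  }
  // float16 / bfloat16: cast into a scratch float buffer owned by the guard,
  // so the temporary is released when the kernel invocation finishes.
  float* bias_tmp = guard->alloc_l3_or_gm<float>(attn_mask.numel());
  int r = xpu::cast<XPUType, float>(
      ctx.x_context(),
      reinterpret_cast<const XPUType*>(attn_mask.data<T>()),
      bias_tmp,
      attn_mask.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
  return bias_tmp;
}

Tying the scratch buffer to the ctx_guard is why both kernels now construct RAII_GUARD before touching the mask: the cast result only needs to outlive the mha_varlen_* call.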

paddle/phi/kernels/xpu/flash_attn_kernel.cc (+51 -25)

@@ -14,7 +14,7 @@

 #include "paddle/phi/kernels/flash_attn_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"

 #ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
   seed_offset->Resize({2});
   int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
   } else {
     std::pair<uint64_t, uint64_t> seed_offset_pair;
     uint64_t inc = batch_size * num_heads * 32;
@@ -263,11 +271,29 @@ void FlashAttnKernel(const Context& ctx,
   const XPUType* k_data = reinterpret_cast<const XPUType*>(k.data<T>());
   const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
   XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
-  float* softmax_lse_data = softmax_lse->data<float>();

+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  float* softmax_lse_data = softmax_lse->data<float>();
   const float* bias_data = nullptr;
   if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast<XPUType, float>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data<T>()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported.");
+    }
   }
   // template <typename T, typename TACCUM, typename TGEMM, typename TID> int
   // mha_varlen_fwd(xdnn::Context* ctx, const T* q, const T* k, const T* v, T*
@@ -281,24 +307,24 @@ void FlashAttnKernel(const Context& ctx,
   // nullptr);
   int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
       ctx.x_context(),
-      q_data,                                     // q
-      k_data,                                     // k
-      v_data,                                     // v
-      out_data,                                   // out
-      softmax_lse_data,                           // softmax_lse
-      qlod,                                       // lod_seqlens_q
-      kvlod,                                      // lod_seqlens_k
-      seqlen_q,                                   // max_seqlen_q
-      seqlen_k,                                   // max_seqlen_k
-      num_heads,                                  // head_num
-      num_heads_k,                                // head_num_k
-      head_size,                                  // head_dim
-      1.0f / std::sqrt(head_size),                // softmax_scale
-      dropout,                                    // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]), // seed
-      causal,                                     // is_causal
-      nullptr,                                    // attn_mask
-      bias_data                                   // bias
+      q_data,                                    // q
+      k_data,                                    // k
+      v_data,                                    // v
+      out_data,                                  // out
+      softmax_lse_data,                          // softmax_lse
+      qlod,                                      // lod_seqlens_q
+      kvlod,                                     // lod_seqlens_k
+      seqlen_q,                                  // max_seqlen_q
+      seqlen_k,                                  // max_seqlen_k
+      num_heads,                                 // head_num
+      num_heads_k,                               // head_num_k
+      head_size,                                 // head_dim
+      1.0f / std::sqrt(head_size),               // softmax_scale
+      dropout,                                   // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]), // seed
+      causal,                                    // is_causal
+      nullptr,                                   // attn_mask
+      bias_data                                  // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
 #else
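The other change in the forward kernel is the place-aware read of fixed_seed_offset: the removed lines dereferenced fixed_seed_offset.get_ptr()->data<int64_t>() on the host even when the tensor lives in XPU device memory, which is the segfault the commit message refers to; the new code stages the two int64 values through memory_utils::Copy first. A minimal sketch of that read as a standalone helper — the name ReadSeedOffsetToHost is hypothetical, and it assumes the same phi/memory_utils headers already included in the kernel file — might look like:

// Copy the (seed, offset) pair into host memory, regardless of where the
// fixed_seed_offset tensor is allocated.
void ReadSeedOffsetToHost(const phi::DenseTensor& fixed_seed_offset,
                          int64_t* seed_offset_data /* host buffer, size 2 */) {
  if (fixed_seed_offset.place().GetType() == phi::AllocationType::XPU) {
    // Device-resident tensor: dereferencing data<int64_t>() on the host would
    // fault, so stage the two values through an explicit device-to-host copy.
    memory_utils::Copy(phi::CPUPlace(),
                       seed_offset_data,
                       fixed_seed_offset.place(),
                       fixed_seed_offset.data<int64_t>(),
                       sizeof(int64_t) * 2);
  } else {
    // Host-resident tensor: a direct read is safe.
    const int64_t* src = fixed_seed_offset.data<int64_t>();
    seed_offset_data[0] = src[0];
    seed_offset_data[1] = src[1];
  }
}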
