Commit 8cf862d

[XPU] fix segfault caused by setting fixed_seed_offset on XPU

1 parent a13f7dc

File tree

2 files changed: +54 -46 lines changed

paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc (+23 -22)
@@ -92,6 +92,7 @@ void FlashAttnGradKernel(const Context& ctx,
 
   // get seed offset
   const int64_t* seed_offset_data = seed_offset.data<int64_t>();
+
   // template<typename T, typename TACCUM, typename TGEMM, typename TID = int>
   // int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T*
   // k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T*
@@ -106,28 +107,28 @@ void FlashAttnGradKernel(const Context& ctx,
   // dv_maxptr = nullptr, const float* do_maxptr = nullptr);
   int r = baidu::xpu::xfa::mha_varlen_bwd<XPUType, float, tfloat32, int>(
       ctx.x_context(),
-      dout_data,  // dout
-      q_data,  // q
-      k_data,  // k
-      v_data,  // v
-      out_data,  // out
-      softmax_lse_data,  // softmax_lse
-      dq_data,  // dq
-      dk_data,  // dk
-      dv_data,  // dv
-      qlod,  // lod_seqlens_q
-      kvlod,  // lod_seqlens_k
-      seqlen_q,  // max_seqlen_q
-      seqlen_k,  // max_seqlen_k
-      num_heads,  // head_num
-      num_heads_k,  // head_num_k
-      head_size,  // head_dim
-      1.0f / std::sqrt(head_size),  // softmax_scale
-      dropout,  // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,  // is_causal
-      nullptr,  // attn_mask
-      bias_data  // bias
+      dout_data,  // dout
+      q_data,  // q
+      k_data,  // k
+      v_data,  // v
+      out_data,  // out
+      softmax_lse_data,  // softmax_lse
+      dq_data,  // dq
+      dk_data,  // dk
+      dv_data,  // dv
+      qlod,  // lod_seqlens_q
+      kvlod,  // lod_seqlens_k
+      seqlen_q,  // max_seqlen_q
+      seqlen_k,  // max_seqlen_k
+      num_heads,  // head_num
+      num_heads_k,  // head_num_k
+      head_size,  // head_dim
+      1.0f / std::sqrt(head_size),  // softmax_scale
+      dropout,  // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,  // is_causal
+      nullptr,  // attn_mask
+      bias_data  // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
 #else
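A note on the seed argument in this hunk: seed_offset carries two int64_t values, with the seed in element 0 (the backward call only consumes that element), and the cast is changed from uint64_t to int32_t, presumably to match the integer type the XHPC mha_varlen_bwd interface takes. The standalone sketch below is hypothetical (not Paddle code) and only illustrates that layout and the effect of the narrowing cast:

#include <cstdint>
#include <cstdio>

int main() {
  // Layout used by the kernels in this commit: {seed, offset} as two int64_t.
  int64_t seed_offset_data[2] = {9876543210LL, 0};
  // The call sites now pass the seed as a 32-bit integer; seeds that do not
  // fit in int32_t are reduced modulo 2^32 on typical implementations.
  int32_t seed = static_cast<int32_t>(seed_offset_data[0]);
  std::printf("seed as int64 = %lld, seed as int32 = %d, offset = %lld\n",
              static_cast<long long>(seed_offset_data[0]), seed,
              static_cast<long long>(seed_offset_data[1]));
  return 0;
}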

paddle/phi/kernels/xpu/flash_attn_kernel.cc (+31 -24)
@@ -14,7 +14,7 @@
 
 #include "paddle/phi/kernels/flash_attn_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 #ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
   seed_offset->Resize({2});
   int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
   } else {
     std::pair<uint64_t, uint64_t> seed_offset_pair;
     uint64_t inc = batch_size * num_heads * 32;
@@ -264,7 +272,6 @@ void FlashAttnKernel(const Context& ctx,
   const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
   XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
   float* softmax_lse_data = softmax_lse->data<float>();
-
   const float* bias_data = nullptr;
   if (attn_mask.get_ptr() != nullptr) {
     bias_data = attn_mask->data<float>();
@@ -281,24 +288,24 @@ void FlashAttnKernel(const Context& ctx,
   // nullptr);
   int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
       ctx.x_context(),
-      q_data,  // q
-      k_data,  // k
-      v_data,  // v
-      out_data,  // out
-      softmax_lse_data,  // softmax_lse
-      qlod,  // lod_seqlens_q
-      kvlod,  // lod_seqlens_k
-      seqlen_q,  // max_seqlen_q
-      seqlen_k,  // max_seqlen_k
-      num_heads,  // head_num
-      num_heads_k,  // head_num_k
-      head_size,  // head_dim
-      1.0f / std::sqrt(head_size),  // softmax_scale
-      dropout,  // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,  // is_causal
-      nullptr,  // attn_mask
-      bias_data  // bias
+      q_data,  // q
+      k_data,  // k
+      v_data,  // v
+      out_data,  // out
+      softmax_lse_data,  // softmax_lse
+      qlod,  // lod_seqlens_q
+      kvlod,  // lod_seqlens_k
+      seqlen_q,  // max_seqlen_q
+      seqlen_k,  // max_seqlen_k
+      num_heads,  // head_num
+      num_heads_k,  // head_num_k
+      head_size,  // head_dim
+      1.0f / std::sqrt(head_size),  // softmax_scale
+      dropout,  // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,  // is_causal
+      nullptr,  // attn_mask
+      bias_data  // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
 #else
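The forward-kernel hunk above is the actual segfault fix: seed_offset_data points to host memory (it comes from HostAlloc), but when fixed_seed_offset is an XPU tensor its data pointer refers to device memory, so indexing it directly from host code faults. The new branch stages the two values on the host with memory_utils::Copy before they are read. Below is a minimal standalone sketch of that pattern; it is not Paddle code, and xpu_memcpy_to_host() is an invented stand-in for the device-to-host copy, simulated here with std::memcpy so the sketch compiles and runs on its own:

#include <cstdint>
#include <cstdio>
#include <cstring>

// Invented stand-in for a device-to-host copy; real code would issue an XPU
// DMA transfer (in Paddle, memory_utils::Copy with a CPUPlace destination).
void xpu_memcpy_to_host(void* dst, const void* src, std::size_t bytes) {
  std::memcpy(dst, src, bytes);
}

// Mirrors the fixed branch: copy {seed, offset} to host if the source buffer
// lives on the device, otherwise read it directly.
void read_seed_offset(const int64_t* src, bool src_on_device,
                      int64_t host_out[2]) {
  if (src_on_device) {
    // Device pointer: dereferencing it on the host would fault, so copy first.
    xpu_memcpy_to_host(host_out, src, sizeof(int64_t) * 2);
  } else {
    // Host pointer: safe to index directly.
    host_out[0] = src[0];
    host_out[1] = src[1];
  }
}

int main() {
  int64_t fixed_seed_offset[2] = {1234, 0};  // pretend this lives on the XPU
  int64_t seed_offset[2] = {0, 0};           // host buffer (like HostAlloc)
  read_seed_offset(fixed_seed_offset, /*src_on_device=*/true, seed_offset);
  std::printf("seed = %lld, offset = %lld\n",
              static_cast<long long>(seed_offset[0]),
              static_cast<long long>(seed_offset[1]));
  return 0;
}

The else branch in the real kernel is kept for CPU-resident fixed_seed_offset tensors, where direct element-wise reads remain safe.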
