@@ -14,7 +14,7 @@
 
 #include "paddle/phi/kernels/flash_attn_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 #ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
   seed_offset->Resize({2});
   int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
   } else {
     std::pair<uint64_t, uint64_t> seed_offset_pair;
     uint64_t inc = batch_size * num_heads * 32;
@@ -264,7 +272,6 @@ void FlashAttnKernel(const Context& ctx,
   const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
   XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
   float* softmax_lse_data = softmax_lse->data<float>();
-
   const float* bias_data = nullptr;
   if (attn_mask.get_ptr() != nullptr) {
     bias_data = attn_mask->data<float>();
@@ -281,24 +288,24 @@ void FlashAttnKernel(const Context& ctx,
   // nullptr);
   int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
       ctx.x_context(),
-      q_data,                                      // q
-      k_data,                                      // k
-      v_data,                                      // v
-      out_data,                                    // out
-      softmax_lse_data,                            // softmax_lse
-      qlod,                                        // lod_seqlens_q
-      kvlod,                                       // lod_seqlens_k
-      seqlen_q,                                    // max_seqlen_q
-      seqlen_k,                                    // max_seqlen_k
-      num_heads,                                   // head_num
-      num_heads_k,                                 // head_num_k
-      head_size,                                   // head_dim
-      1.0f / std::sqrt(head_size),                 // softmax_scale
-      dropout,                                     // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,                                      // is_causal
-      nullptr,                                     // attn_mask
-      bias_data                                    // bias
+      q_data,                                     // q
+      k_data,                                     // k
+      v_data,                                     // v
+      out_data,                                   // out
+      softmax_lse_data,                           // softmax_lse
+      qlod,                                       // lod_seqlens_q
+      kvlod,                                      // lod_seqlens_k
+      seqlen_q,                                   // max_seqlen_q
+      seqlen_k,                                   // max_seqlen_k
+      num_heads,                                  // head_num
+      num_heads_k,                                // head_num_k
+      head_size,                                  // head_dim
+      1.0f / std::sqrt(head_size),                // softmax_scale
+      dropout,                                    // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,                                     // is_causal
+      nullptr,                                    // attn_mask
+      bias_data                                   // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
 #else
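
For context, the new branch in the second hunk stages the seed/offset pair through host memory before reading it. Below is a minimal self-contained sketch of that pattern, assuming the memory_utils::Copy(dst_place, dst, src_place, src, num_bytes) helper from paddle/phi/common/memory_utils.h; the ReadSeedOffsetToHost wrapper is a hypothetical name for illustration and is not part of this patch. The point is that an XPU-resident tensor's buffer cannot be dereferenced from host code, so its two int64 values must be copied to the host first, while a CPU-resident tensor can be read in place.

// Sketch only: mirrors the dispatch added by this change.
// "ReadSeedOffsetToHost" is a hypothetical helper name.
#include <cstdint>

#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"

namespace {

void ReadSeedOffsetToHost(const phi::DenseTensor& fixed_seed_offset,
                          int64_t* seed_offset_data) {  // host buffer, 2 elems
  if (fixed_seed_offset.place().GetType() == phi::AllocationType::XPU) {
    // Device buffer: stage {seed, offset} into host memory.
    phi::memory_utils::Copy(phi::CPUPlace(),
                            seed_offset_data,
                            fixed_seed_offset.place(),
                            fixed_seed_offset.data<int64_t>(),
                            sizeof(int64_t) * 2);
  } else {
    // Host-accessible buffer: plain loads are safe.
    const int64_t* src = fixed_seed_offset.data<int64_t>();
    seed_offset_data[0] = src[0];  // seed
    seed_offset_data[1] = src[1];  // offset
  }
}

}  // namespace

The else branch matches the pre-patch behavior, which presumably misbehaved whenever callers passed fixed_seed_offset on the XPU device, since the old code dereferenced the tensor's data pointer directly from host code.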