@@ -14,7 +14,7 @@

 #include "paddle/phi/kernels/flash_attn_kernel.h"
 #include "paddle/phi/backends/xpu/enforce_xpu.h"
-#include "paddle/phi/core/enforce.h"
+#include "paddle/phi/common/memory_utils.h"
 #include "paddle/phi/core/kernel_registry.h"

 #ifdef PADDLE_WITH_XPU_XHPC
@@ -239,10 +239,18 @@ void FlashAttnKernel(const Context& ctx,
   seed_offset->Resize({2});
   int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
   if (fixed_seed_offset.get_ptr()) {
-    const int64_t* fixed_seed_offset_data =
-        fixed_seed_offset.get_ptr()->data<int64_t>();
-    seed_offset_data[0] = fixed_seed_offset_data[0];
-    seed_offset_data[1] = fixed_seed_offset_data[1];
+    if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+      memory_utils::Copy(phi::CPUPlace(),
+                         seed_offset_data,
+                         fixed_seed_offset->place(),
+                         fixed_seed_offset->data<int64_t>(),
+                         sizeof(int64_t) * 2);
+    } else {
+      const int64_t* fixed_seed_offset_data =
+          fixed_seed_offset->data<int64_t>();
+      seed_offset_data[0] = fixed_seed_offset_data[0];
+      seed_offset_data[1] = fixed_seed_offset_data[1];
+    }
   } else {
     std::pair<uint64_t, uint64_t> seed_offset_pair;
     uint64_t inc = batch_size * num_heads * 32;
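A minimal standalone sketch of the seed/offset handling this hunk adds, assuming the phi::memory_utils::Copy overload used above; ReadSeedOffsetToHost is a hypothetical helper name, not part of this patch. When the fixed seed/offset tensor lives on the XPU device, its two int64 values must be staged to host memory before they can be read.

// Minimal sketch (not part of this patch): read a 2-element int64
// seed/offset tensor into a host buffer, whether it resides on the
// XPU device or on the host.
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"

void ReadSeedOffsetToHost(const phi::DenseTensor& fixed_seed_offset,
                          int64_t* seed_offset_data /* host buffer of 2 */) {
  if (fixed_seed_offset.place().GetType() == phi::AllocationType::XPU) {
    // Device-resident tensor: stage the two values through a D2H copy.
    phi::memory_utils::Copy(phi::CPUPlace(),
                            seed_offset_data,
                            fixed_seed_offset.place(),
                            fixed_seed_offset.data<int64_t>(),
                            sizeof(int64_t) * 2);
  } else {
    // Host-resident tensor: read the values directly.
    const int64_t* src = fixed_seed_offset.data<int64_t>();
    seed_offset_data[0] = src[0];
    seed_offset_data[1] = src[1];
  }
}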
@@ -263,11 +271,29 @@ void FlashAttnKernel(const Context& ctx,
   const XPUType* k_data = reinterpret_cast<const XPUType*>(k.data<T>());
   const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
   XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
-  float* softmax_lse_data = softmax_lse->data<float>();

+  xpu::ctx_guard RAII_GUARD(ctx.x_context());
+  float* softmax_lse_data = softmax_lse->data<float>();
   const float* bias_data = nullptr;
   if (attn_mask.get_ptr() != nullptr) {
-    bias_data = attn_mask->data<float>();
+    if (attn_mask->dtype() == phi::DataType::FLOAT32) {
+      bias_data = attn_mask->data<float>();
+    } else if (attn_mask->dtype() == phi::DataType::FLOAT16 ||
+               attn_mask->dtype() == phi::DataType::BFLOAT16) {
+      float* bias_tmp = RAII_GUARD.alloc_l3_or_gm<float>(attn_mask->numel());
+      int r = xpu::cast<XPUType, float>(
+          ctx.x_context(),
+          reinterpret_cast<const XPUType*>(attn_mask->data<T>()),
+          bias_tmp,
+          attn_mask->numel());
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
+      bias_data = bias_tmp;
+    } else {
+      errors::Unimplemented(
+          "Unsupported dtype for attention_mask in xpu flash attention, only "
+          "float32, float16 and "
+          "bfloat16 are supported.");
+    }
   }
   // template <typename T, typename TACCUM, typename TGEMM, typename TID> int
   // mha_varlen_fwd(xdnn::Context* ctx, const T* q, const T* k, const T* v, T*
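The mask handling added above can be viewed in isolation as the following sketch, assuming the XDNN xpu::cast and xpu::ctx_guard APIs used in the diff; MaskAsFloat is a hypothetical helper name. A float32 mask is consumed in place, while a float16/bfloat16 mask is cast into ctx_guard-managed scratch memory so mha_varlen_fwd can use it as a float bias.

// Minimal sketch (not part of this patch): produce a float* view of the
// attention mask for the fused kernel. T is the kernel element type
// (float, float16, bfloat16); like the patch, this assumes the mask
// dtype matches T when it is not float32.
template <typename T>
const float* MaskAsFloat(const phi::XPUContext& ctx,
                         xpu::ctx_guard& guard,
                         const phi::DenseTensor& mask) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  if (mask.dtype() == phi::DataType::FLOAT32) {
    return mask.data<float>();  // already float32, no cast needed
  }
  // Allocate scratch in L3 (falling back to global memory) and cast on device.
  float* tmp = guard.alloc_l3_or_gm<float>(mask.numel());
  int r = xpu::cast<XPUType, float>(
      ctx.x_context(),
      reinterpret_cast<const XPUType*>(mask.data<T>()),
      tmp,
      mask.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
  return tmp;
}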
@@ -281,24 +307,24 @@ void FlashAttnKernel(const Context& ctx,
   //     nullptr);
   int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
       ctx.x_context(),
-      q_data,  // q
-      k_data,  // k
-      v_data,  // v
-      out_data,  // out
-      softmax_lse_data,  // softmax_lse
-      qlod,  // lod_seqlens_q
-      kvlod,  // lod_seqlens_k
-      seqlen_q,  // max_seqlen_q
-      seqlen_k,  // max_seqlen_k
-      num_heads,  // head_num
-      num_heads_k,  // head_num_k
-      head_size,  // head_dim
-      1.0f / std::sqrt(head_size),  // softmax_scale
-      dropout,  // p_dropout
-      static_cast<uint64_t>(seed_offset_data[0]),  // seed
-      causal,  // is_causal
-      nullptr,  // attn_mask
-      bias_data  // bias
+      q_data,  // q
+      k_data,  // k
+      v_data,  // v
+      out_data,  // out
+      softmax_lse_data,  // softmax_lse
+      qlod,  // lod_seqlens_q
+      kvlod,  // lod_seqlens_k
+      seqlen_q,  // max_seqlen_q
+      seqlen_k,  // max_seqlen_k
+      num_heads,  // head_num
+      num_heads_k,  // head_num_k
+      head_size,  // head_dim
+      1.0f / std::sqrt(head_size),  // softmax_scale
+      dropout,  // p_dropout
+      static_cast<int32_t>(seed_offset_data[0]),  // seed
+      causal,  // is_causal
+      nullptr,  // attn_mask
+      bias_data  // bias
   );
   PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
 #else