From def1f5bc30716571a7a0b6f0965d6bb354c7a72c Mon Sep 17 00:00:00 2001 From: NeroLoh <745827440@qq.com> Date: Mon, 26 Feb 2024 19:08:42 +0800 Subject: [PATCH] [XPU] add roformer relative embedding pass & kernel and support in multi_encoder_xpu --- paddle/fluid/framework/ir/CMakeLists.txt | 2 + ...i_encoder_xpu_adaptive_seqlen_fuse_pass.cc | 48 +-- ...ti_encoder_xpu_adaptive_seqlen_fuse_pass.h | 6 +- .../ir/xpu/multi_encoder_xpu_fuse_pass.cc | 300 +++++++++++++++-- .../ir/xpu/multi_encoder_xpu_fuse_pass.h | 4 +- .../ir/xpu/roformer_relative_pos_fuse_pass.cc | 301 ++++++++++++++++++ .../inference/api/paddle_pass_builder.cc | 1 + paddle/phi/api/yaml/fused_ops.yaml | 11 +- paddle/phi/backends/xpu/xpu2_op_list.cc | 2 + paddle/phi/infermeta/fusion.cc | 54 ++++ paddle/phi/infermeta/fusion.h | 7 + .../fusion/xpu/multi_encoder_xpu_kernel.cc | 35 +- .../xpu/roformer_relative_embedding_kernel.cc | 78 +++++ .../test_xpu_roformer_relative_pos_pass.py | 167 ++++++++++ 14 files changed, 969 insertions(+), 47 deletions(-) create mode 100644 paddle/fluid/framework/ir/xpu/roformer_relative_pos_fuse_pass.cc create mode 100644 paddle/phi/kernels/fusion/xpu/roformer_relative_embedding_kernel.cc create mode 100644 test/ir/inference/test_xpu_roformer_relative_pos_pass.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 765fa1779b0e51..cb8093298d9bb3 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -322,6 +322,8 @@ if(WITH_XPU) ${XPU_PASS_DEPS}) pass_library(sine_pos_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(quant_dequant_xpu_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(roformer_relative_pos_fuse_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) endif() cc_library( diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc index e20320e29a9593..fa75f29ae9187f 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.cc @@ -25,7 +25,9 @@ namespace ir { namespace patterns { struct AdaptiveSeqlenPatternV1 : public PatternBase { - AdaptiveSeqlenPatternV1(PDPattern* pattern, const std::string& name_scope); + AdaptiveSeqlenPatternV1(PDPattern* pattern, + const std::string& name_scope, + const std::string& matmul_type); // declare operator node's name PATTERN_DECL_NODE(embedding_xpu); @@ -44,7 +46,8 @@ struct AdaptiveSeqlenPatternV1 : public PatternBase { }; AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern, - const std::string& name_scope) + const std::string& name_scope, + const std::string& matmul_type) : PatternBase(pattern, name_scope, name_scope) { auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr()) ->assert_is_op("embedding_with_eltwise_add_xpu"); @@ -59,11 +62,11 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern, ->assert_is_op_input("multi_encoder_xpu", "x"); auto* mask = pattern->NewNode(mask_repr()) - ->assert_is_op_input("matmul", "X") - ->assert_is_op_input("matmul", "Y"); - auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul"); + ->assert_is_op_input(matmul_type, "X") + ->assert_is_op_input(matmul_type, "Y"); + auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type); auto* matmul_out = pattern->NewNode(matmul_out_repr()) - ->assert_is_op_output("matmul", "Out") + 
->assert_is_op_output(matmul_type, "Out") ->assert_is_op_input("scale", "X"); auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale"); auto* scale_out = pattern->NewNode(scale_out_repr()) @@ -88,9 +91,10 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern, } // namespace patterns int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1( - ir::Graph* graph) const { + ir::Graph* graph, const std::string& matmul_type) const { GraphPatternDetector gpd; - patterns::AdaptiveSeqlenPatternV1 pattern(gpd.mutable_pattern(), name_scope_); + patterns::AdaptiveSeqlenPatternV1 pattern( + gpd.mutable_pattern(), name_scope_, matmul_type); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -143,7 +147,9 @@ int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1( namespace patterns { struct AdaptiveSeqlenPatternV2 : public PatternBase { - AdaptiveSeqlenPatternV2(PDPattern* pattern, const std::string& name_scope); + AdaptiveSeqlenPatternV2(PDPattern* pattern, + const std::string& name_scope, + const std::string& matmul_type); // declare operator node's name PATTERN_DECL_NODE(embedding_xpu); @@ -172,7 +178,8 @@ struct AdaptiveSeqlenPatternV2 : public PatternBase { }; AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern, - const std::string& name_scope) + const std::string& name_scope, + const std::string& matmul_type) : PatternBase(pattern, name_scope, name_scope) { auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr()) ->assert_is_op("embedding_with_eltwise_add_xpu"); @@ -201,11 +208,11 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern, pattern->NewNode(unsqueeze_0_repr())->assert_is_op("unsqueeze2"); auto* unsqueeze_0_out = pattern->NewNode(unsqueeze_0_out_repr()) ->assert_is_op_output("unsqueeze2", "Out") - ->assert_is_op_input("matmul_v2", "X") - ->assert_is_op_input("matmul_v2", "Y"); - auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul_v2"); + ->assert_is_op_input(matmul_type, "X") + ->assert_is_op_input(matmul_type, "Y"); + auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type); auto* matmul_out = pattern->NewNode(matmul_out_repr()) - ->assert_is_op_output("matmul_v2", "Out") + ->assert_is_op_output(matmul_type, "Out") ->assert_is_op_input("scale", "X"); auto* scale_0 = pattern->NewNode(scale_0_repr())->assert_is_op("scale"); auto* scale_0_out = pattern->NewNode(scale_0_out_repr()) @@ -244,9 +251,10 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern, } // namespace patterns int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV2( - ir::Graph* graph) const { + ir::Graph* graph, const std::string& matmul_type) const { GraphPatternDetector gpd; - patterns::AdaptiveSeqlenPatternV2 pattern(gpd.mutable_pattern(), name_scope_); + patterns::AdaptiveSeqlenPatternV2 pattern( + gpd.mutable_pattern(), name_scope_, matmul_type); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -324,9 +332,13 @@ void MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); + std::vector matmul_types{"matmul", "matmul_v2"}; + int found_subgraph_count = 0; + for (auto& matmul_type : matmul_types) { + found_subgraph_count += ApplyAdaptiveSeqlenPassV1(graph, matmul_type); + found_subgraph_count += 
ApplyAdaptiveSeqlenPassV2(graph, matmul_type); + } - int found_subgraph_count = ApplyAdaptiveSeqlenPassV1(graph); - found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph); AddStatis(found_subgraph_count); } diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h index 22910c21205300..ea3b52bf35a24a 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass.h @@ -76,7 +76,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase { | out_var* */ - int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph) const; + int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph, + const std::string& matmul_type) const; /* adaptive seqlen V2, before: @@ -132,7 +133,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase { | out_var* */ - int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph) const; + int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph, + const std::string& matmul_type) const; private: const std::string name_scope_{"multi_encoder_xpu_adaptive_seqlen_fuse_pass"}; diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index 8e126df64ad417..e7a5acac2bae24 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -38,7 +38,8 @@ struct SingleEncoderXPUPattern : public PatternBase { bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant); + bool is_smooth_quant, + const std::string& relative_type); // declare operator node's name // If norm_before, use ln_0 & ln_1. 
@@ -141,6 +142,16 @@ struct SingleEncoderXPUPattern : public PatternBase { PATTERN_DECL_NODE(smooth_scale_1_out); PATTERN_DECL_NODE(smooth_scale_2_out); + // roformer_relative_embedding_xpu + PATTERN_DECL_NODE(q_relative_emb); + PATTERN_DECL_NODE(q_cos_embedding); + PATTERN_DECL_NODE(q_sin_embedding); + PATTERN_DECL_NODE(q_relative_emb_out); + PATTERN_DECL_NODE(k_relative_emb); + PATTERN_DECL_NODE(k_cos_embedding); + PATTERN_DECL_NODE(k_sin_embedding); + PATTERN_DECL_NODE(k_relative_emb_out); + private: std::string act_type_; std::string matmul_type_0_; @@ -150,6 +161,7 @@ struct SingleEncoderXPUPattern : public PatternBase { bool with_q_scale_{false}; bool with_mask_{true}; bool is_smooth_quant_{false}; + std::string relative_type_ = ""; }; SingleEncoderXPUPattern::SingleEncoderXPUPattern( @@ -162,7 +174,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant) + bool is_smooth_quant, + const std::string& relative_type) : PatternBase(pattern, name_scope, name_scope), act_type_(act_type), matmul_type_0_(matmul_type_0), @@ -171,7 +184,8 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( norm_before_(norm_before), with_q_scale_(with_q_scale), with_mask_(with_mask), - is_smooth_quant_(is_smooth_quant) { + is_smooth_quant_(is_smooth_quant), + relative_type_(relative_type) { // layer_norm 0 PDNode* ln_0_x = pattern->NewNode(ln_0_x_repr()); PDNode* ln_0_bias = nullptr; @@ -244,14 +258,38 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( ->assert_var_not_persistable(); PDNode* q_scale = nullptr; PDNode* q_scale_out = nullptr; + std::string target_op_type = matmul_type_1_; if (with_q_scale_) { q_scale = pattern->NewNode(q_scale_repr())->assert_is_op("scale"); q_scale_out = pattern->NewNode(q_scale_out_repr()) ->assert_is_op_output("scale", "Out") ->assert_is_op_input(matmul_type_1_, "X") ->assert_var_not_persistable(); + target_op_type = "scale"; } else { - q_transpose_out->assert_is_op_input(matmul_type_1_, "X"); + if (relative_type_.empty()) { + q_transpose_out->assert_is_op_input(target_op_type, "X"); + } else { + q_transpose_out->assert_is_op_input(relative_type_, "x"); + } + } + PDNode* q_relative_emb = nullptr; + PDNode* q_cos_embedding = nullptr; + PDNode* q_sin_embedding = nullptr; + PDNode* q_relative_emb_out = nullptr; + if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build q_relative_emb"; + q_relative_emb = + pattern->NewNode(q_relative_emb_repr())->assert_is_op(relative_type_); + q_sin_embedding = pattern->NewNode(q_sin_embedding_repr()) + ->assert_is_op_input(relative_type_, "sin_emb") + ->AsInput(); + q_cos_embedding = pattern->NewNode(q_cos_embedding_repr()) + ->assert_is_op_input(relative_type_, "cos_emb") + ->AsInput(); + q_relative_emb_out = pattern->NewNode(q_relative_emb_out_repr()) + ->assert_is_op_output(relative_type_, "out") + ->assert_is_op_input(target_op_type, "X"); } // k: matmul + add + reshape + transpose @@ -279,9 +317,23 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( pattern->NewNode(k_transpose_repr())->assert_is_op("transpose2"); auto* k_transpose_out = pattern->NewNode(k_transpose_out_repr()) ->assert_is_op_output("transpose2", "Out") - ->assert_is_op_input(matmul_type_1_, "Y") ->assert_var_not_persistable(); + PDNode* k_relative_emb = nullptr; + PDNode* k_sin_embedding = q_sin_embedding; + PDNode* k_cos_embedding = q_cos_embedding; + PDNode* k_relative_emb_out = nullptr; + if (relative_type_.empty()) { + 
k_transpose_out->assert_is_op_input(matmul_type_1_, "Y"); + } else if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build k_relative_emb"; + k_transpose_out->assert_is_op_input(relative_type_, "x"); + k_relative_emb = + pattern->NewNode(k_relative_emb_repr())->assert_is_op(relative_type_); + k_relative_emb_out = pattern->NewNode(k_relative_emb_out_repr()) + ->assert_is_op_output(relative_type_, "out") + ->assert_is_op_input(matmul_type_1_, "Y"); + } // qk: matmul + add + softmax auto* qk_matmul = pattern->NewNode(qk_matmul_repr())->assert_is_op(matmul_type_1_); @@ -482,18 +534,31 @@ SingleEncoderXPUPattern::SingleEncoderXPUPattern( q_add->LinksFrom({q_matmul_out, q_add_bias}).LinksTo({q_add_out}); q_reshape->LinksFrom({q_add_out}).LinksTo({q_reshape_out}); q_transpose->LinksFrom({q_reshape_out}).LinksTo({q_transpose_out}); - PDNode* qk_matmul_x = q_transpose_out; + PDNode* last_node = q_transpose_out; + if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build q_relative_emb link"; + q_relative_emb->LinksFrom({last_node, q_sin_embedding, q_cos_embedding}) + .LinksTo({q_relative_emb_out}); + last_node = q_relative_emb_out; + } if (with_q_scale_) { - q_scale->LinksFrom({q_transpose_out}).LinksTo({q_scale_out}); - qk_matmul_x = q_scale_out; + q_scale->LinksFrom({last_node}).LinksTo({q_scale_out}); + last_node = q_scale_out; } + PDNode* qk_matmul_x = last_node; k_matmul->LinksFrom({q_matmul_x, k_matmul_w}).LinksTo({k_matmul_out}); k_add->LinksFrom({k_matmul_out, k_add_bias}).LinksTo({k_add_out}); k_reshape->LinksFrom({k_add_out}).LinksTo({k_reshape_out}); k_transpose->LinksFrom({k_reshape_out}).LinksTo({k_transpose_out}); - - qk_matmul->LinksFrom({qk_matmul_x, k_transpose_out}).LinksTo({qk_matmul_out}); + last_node = k_transpose_out; + if (relative_type_ == "roformer_relative_embedding_xpu") { + VLOG(3) << "build k_relative_emb link"; + k_relative_emb->LinksFrom({last_node, k_sin_embedding, k_cos_embedding}) + .LinksTo({k_relative_emb_out}); + last_node = k_relative_emb_out; + } + qk_matmul->LinksFrom({qk_matmul_x, last_node}).LinksTo({qk_matmul_out}); PDNode* qk_softmax_x = qk_matmul_out; if (with_mask_) { qk_add->LinksFrom({qk_matmul_out, qk_add_mask}).LinksTo({qk_add_out}); @@ -571,7 +636,8 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const { pattern_param.norm_before, pattern_param.with_q_scale, pattern_param.with_mask, - pattern_param.is_smooth_quant); + pattern_param.is_smooth_quant, + pattern_param.relative_type); while (ApplyMultiEncoderXPUFuse(graph)) { multi_encoder_fused_counts++; } @@ -950,7 +1016,8 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant) const { + bool is_smooth_quant, + const std::string& relative_type) const { bool local_quant = false; if (std::getenv("XPU_LOCAL_QUANT")) { local_quant = atoi(std::getenv("XPU_LOCAL_QUANT")); @@ -965,7 +1032,8 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( norm_before, with_q_scale, with_mask, - is_smooth_quant); + is_smooth_quant, + relative_type); int found_subgraph_count = 0; auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, @@ -1068,6 +1136,16 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( GET_IR_NODE(smooth_scale_1_out); GET_IR_NODE(smooth_scale_2_out); + // roformer_relative_embedding_xpu + GET_IR_NODE(q_relative_emb); + GET_IR_NODE(q_cos_embedding); + GET_IR_NODE(q_sin_embedding); + GET_IR_NODE(q_relative_emb_out); + 
GET_IR_NODE(k_relative_emb); + GET_IR_NODE(k_cos_embedding); + GET_IR_NODE(k_sin_embedding); + GET_IR_NODE(k_relative_emb_out); + auto* block = q_matmul->Op()->Block(); auto* scope = param_scope(); auto weight_dtype = @@ -1275,6 +1353,24 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( op_desc.SetAttr("relative_type", static_cast<int>(0)); op_desc.SetAttr("use_precision", use_precision); op_desc.SetAttr("is_per_channel", is_per_channel); + if (relative_type == "roformer_relative_embedding_xpu") { + // q/k share the rotary embedding + op_desc.SetInput("roformer_embedding", + {q_cos_embedding->Name(), q_sin_embedding->Name()}); + op_desc.SetAttr("relative_type", 1); + auto q_cos_emb_shape = q_cos_embedding->Var()->GetShape(); + CHECK_GE(static_cast<int>(q_cos_emb_shape.size()), 2) + << q_cos_emb_shape.size(); + auto size_per_head = q_reshape_out->Var()->GetShape()[3]; + CHECK_EQ(size_per_head, q_cos_emb_shape[q_cos_emb_shape.size() - 1]); + int max_pos_len = q_cos_emb_shape[q_cos_emb_shape.size() - 2]; + VLOG(3) << "relative embedding max sequence len: " << max_pos_len; + op_desc.SetAttr("max_pos_len", max_pos_len); + } else { + op_desc.SetInput("roformer_embedding", {}); + op_desc.SetAttr("max_pos_len", 0); + } + // if quant, skip softmax, and use qk_matmul out_threshold as softmax_max auto softmax_max_name = qk_matmul->Op()->Output("Out")[0]; if (var_quant_scales.find(softmax_max_name) != var_quant_scales.end()) { @@ -1320,6 +1416,10 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( IR_NODE_LINK_TO(smooth_scale_1_weight, single_encoder_xpu); IR_NODE_LINK_TO(smooth_scale_2_weight, single_encoder_xpu); } + if (relative_type == "roformer_relative_embedding_xpu") { + IR_NODE_LINK_TO(q_cos_embedding, single_encoder_xpu); + IR_NODE_LINK_TO(q_sin_embedding, single_encoder_xpu); + } // Delete nodes std::unordered_set<const Node*> delete_nodes{ln_1, @@ -1405,6 +1505,12 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( delete_nodes.insert(smooth_scale_1_out); delete_nodes.insert(smooth_scale_2_out); } + if (relative_type == "roformer_relative_embedding_xpu") { + delete_nodes.insert(q_relative_emb); + delete_nodes.insert(q_relative_emb_out); + delete_nodes.insert(k_relative_emb); + delete_nodes.insert(k_relative_emb_out); + } GraphSafeRemoveNodes(graph, delete_nodes); found_subgraph_count++; }; @@ -1453,7 +1559,8 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const { "fc_bias", "ln_scale", "ln_bias", - "smooth_scale_weight"}; + "smooth_scale_weight", + "roformer_embedding"}; std::map<std::string, std::vector<std::string>> arg_names_map; std::string mask_name = single_encoders[0]->Op()->Inputs().count("mask") > 0 ? 
single_encoders[0]->Op()->Inputs().at("mask")[0] @@ -1556,6 +1663,11 @@ bool MultiEncoderXPUFusePass::ApplyMultiEncoderXPUFuse(ir::Graph* graph) const { quant_types.end(), per_quant_types.begin(), per_quant_types.end()); } op_desc.SetAttr("quant_types", quant_types); + if (single_encoders[0]->Op()->HasAttr("max_pos_len")) { + op_desc.SetAttr("max_pos_len", + PADDLE_GET_CONST( + int, single_encoders[0]->Op()->GetAttr("max_pos_len"))); + } op_desc.SetOutput("out", {out_name}); op_desc.SetOutput("x_fp16", {x_fp16_name}); op_desc.SetOutput("out_fp16", {out_fp16_name}); @@ -1642,15 +1754,157 @@ std::vector MultiEncoderXPUFusePass::GeneratePatternParams() const { return std::vector{ // Params are arranged in alphabetic order - {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true, false}, - {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true, false}, - {"gelu", "mul", "matmul", "matmul", false, true, true, false}, - {"relu", "mul", "matmul", "matmul", false, true, true, false}, - - {"gelu", "matmul_v2", "matmul", "matmul_v2", false, false, true, true}, - {"gelu", "matmul_v2", "matmul_v2", "matmul_v2", false, true, true, true}, - {"gelu", "mul", "matmul", "matmul", false, true, true, true}, - {"relu", "mul", "matmul", "matmul", false, true, true, true}, + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + false, + ""}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + ""}, + {"gelu", "mul", "matmul", "matmul", false, true, true, false, ""}, + {"relu", "mul", "matmul", "matmul", false, true, true, false, ""}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + ""}, + + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + true, + ""}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + ""}, + {"gelu", "mul", "matmul", "matmul", false, true, true, true, ""}, + {"relu", "mul", "matmul", "matmul", false, true, true, true, ""}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + ""}, + + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + false, + "roformer_relative_embedding_xpu"}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + {"gelu", + "mul", + "matmul", + "matmul", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + {"relu", + "mul", + "matmul", + "matmul", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + false, + "roformer_relative_embedding_xpu"}, + + {"gelu", + "matmul_v2", + "matmul", + "matmul_v2", + false, + false, + true, + true, + "roformer_relative_embedding_xpu"}, + {"gelu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, + {"gelu", + "mul", + "matmul", + "matmul", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, + {"relu", + "mul", + "matmul", + "matmul", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, + {"relu", + "matmul_v2", + "matmul_v2", + "matmul_v2", + false, + true, + true, + true, + "roformer_relative_embedding_xpu"}, }; } diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h index 6c45838073af68..238f7d8d419c56 100644 
--- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.h @@ -129,6 +129,7 @@ struct PatternParam { bool with_q_scale; bool with_mask; bool is_smooth_quant; + std::string relative_type; }; class MultiEncoderXPUFusePass : public FusePassBase { @@ -144,7 +145,8 @@ class MultiEncoderXPUFusePass : public FusePassBase { bool norm_before, bool with_q_scale, bool with_mask, - bool is_smooth_quant) const; + bool is_smooth_quant, + const std::string& relative_type) const; bool ApplyMultiEncoderXPUFuse(ir::Graph* graph) const; diff --git a/paddle/fluid/framework/ir/xpu/roformer_relative_pos_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/roformer_relative_pos_fuse_pass.cc new file mode 100644 index 00000000000000..2c50c77cad8d7f --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/roformer_relative_pos_fuse_pass.cc @@ -0,0 +1,301 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "glog/logging.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/quantize_helper.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/platform/enforce.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { +namespace patterns { +/* +fuse block in vis model to roformer_relative_pos_xpu op +------------------------------------------------------ */ +/* support xpu roformer relative pos */ +/* x --------------- */ +/* | \ | */ +/* | \ | */ +/* split shape | */ +/* / | \ | */ +/* / | \ | */ +/* | scale slice | */ +/* \ | / \ | */ +/* \ | / \ | */ +/* concat slice slice | */ +/* | / \ | */ +/* | / \ | */ +/* elementwise_mul elementwise_mul */ +/* | / */ +/* | / */ +/* elementwise_add */ +/* | */ +/* | */ +/* out */ +/*-------------------------------------------*/ +/* After the pass apply: */ +/* x */ +/* cos_emb | sin_emb */ +/* \ | / */ +/* xpu_roformer_relative */ +/* | */ +/* | */ +/* out */ +/*-------------------------------------------*/ + +struct RoformerRelativePosXPUPattern : public PatternBase { + RoformerRelativePosXPUPattern(PDPattern* pattern, + const std::string& name_scope); + // declare operator node's name + PATTERN_DECL_NODE(split); + PATTERN_DECL_NODE(scale); + PATTERN_DECL_NODE(concat); + PATTERN_DECL_NODE(mul1); + + PATTERN_DECL_NODE(shape); + PATTERN_DECL_NODE(slice1); + PATTERN_DECL_NODE(slice_sin); + PATTERN_DECL_NODE(slice_cos); + + PATTERN_DECL_NODE(mul2); + PATTERN_DECL_NODE(add); + // declare variable node's name + PATTERN_DECL_NODE(x); + PATTERN_DECL_NODE(sin_emb); + 
PATTERN_DECL_NODE(cos_emb); + PATTERN_DECL_NODE(split_out1); + PATTERN_DECL_NODE(split_out2); + PATTERN_DECL_NODE(scale_out); + PATTERN_DECL_NODE(concat_out); + PATTERN_DECL_NODE(mul1_out); + PATTERN_DECL_NODE(shape_out); + PATTERN_DECL_NODE(slice1_out); + PATTERN_DECL_NODE(slice_sin_out); + PATTERN_DECL_NODE(slice_cos_out); + PATTERN_DECL_NODE(mul2_out); + PATTERN_DECL_NODE(add_out); +}; + +RoformerRelativePosXPUPattern::RoformerRelativePosXPUPattern( + PDPattern* pattern, const std::string& name_scope) + : PatternBase(pattern, name_scope, name_scope) { + auto* x = pattern->NewNode(x_repr()) + ->assert_is_op_input("split", "X") + ->assert_is_op_input("elementwise_mul", "X") + ->assert_is_op_input("shape", "Input") + ->AsInput(); + + auto* split = pattern->NewNode(split_repr()) + ->assert_is_op("split") + ->assert_op_attr("axis", 3) + ->assert_op_attr("num", 2); // do we really need it + + auto* split_out1 = pattern->NewNode(split_out1_repr()) + ->assert_is_op_input("scale", "X") + ->assert_is_op_nth_output("split", "Out", 1); + auto* split_out2 = pattern->NewNode(split_out2_repr()) + ->assert_is_op_nth_input("concat", "X", 1) + ->assert_is_op_nth_output("split", "Out", 0); + split->LinksFrom({x}).LinksTo({split_out1, split_out2}); + + auto* scale = pattern->NewNode(scale_repr()) + ->assert_is_op("scale") + ->assert_more([&](Node* node) { + auto* op_desc = node->Op(); + auto scale = op_desc->GetAttrIfExists("scale"); + return (std::fabs(scale + 1.0) < 1e-5); + }); + auto* scale_out = pattern->NewNode(scale_out_repr()) + ->assert_is_op_input("concat", "X") + ->assert_is_op_output("scale", "Out"); + scale->LinksFrom({split_out1}).LinksTo({scale_out}); + auto* concat = pattern->NewNode(concat_repr())->assert_is_op("concat"); + auto* concat_out = pattern->NewNode(concat_out_repr()) + ->assert_is_op_input("elementwise_mul", "X") + ->assert_is_op_output("concat", "Out"); + concat->LinksFrom({scale_out, split_out2}).LinksTo({concat_out}); + auto* shape = pattern->NewNode(shape_repr())->assert_is_op("shape"); + auto* shape_out = pattern->NewNode(shape_out_repr()) + ->assert_is_op_input("slice", "Input") + ->assert_is_op_output("shape", "Out"); + shape->LinksFrom({x}).LinksTo({shape_out}); + auto* slice1 = pattern->NewNode(slice1_repr())->assert_is_op("slice"); + auto* slice1_out = pattern->NewNode(slice1_out_repr()) + ->assert_is_op_input("slice", "EndsTensorList") + ->assert_is_op_output("slice", "Out"); + slice1->LinksFrom({shape_out}).LinksTo({slice1_out}); + auto* sin_emb = pattern->NewNode(sin_emb_repr()) + ->assert_is_op_input("slice", "Input") + ->AsInput(); + auto* cos_emb = pattern->NewNode(cos_emb_repr()) + ->assert_is_op_input("slice", "Input") + ->AsInput(); + auto* slice_sin = pattern->NewNode(slice_sin_repr())->assert_is_op("slice"); + auto* slice_sin_out = pattern->NewNode(slice_sin_out_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_is_op_output("slice", "Out"); + slice_sin->LinksFrom({sin_emb, slice1_out}).LinksTo({slice_sin_out}); + auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("elementwise_mul"); + auto* mul1_out = pattern->NewNode(mul1_out_repr()) + ->assert_is_op_input("elementwise_add", "Y") + ->assert_is_op_output("elementwise_mul", "Out"); + mul1->LinksFrom({concat_out, slice_sin_out}).LinksTo({mul1_out}); + auto* add = pattern->NewNode(add_repr())->assert_is_op("elementwise_add"); + auto* add_out = pattern->NewNode(add_out_repr()) + ->assert_is_op_output("elementwise_add", "Out") + ->AsOutput(); + auto* slice_cos = 
pattern->NewNode(slice_cos_repr())->assert_is_op("slice"); + auto* slice_cos_out = pattern->NewNode(slice_cos_out_repr()) + ->assert_is_op_input("elementwise_mul", "Y") + ->assert_is_op_output("slice", "Out"); + slice_cos->LinksFrom({cos_emb, slice1_out}).LinksTo({slice_cos_out}); + auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("elementwise_mul"); + auto* mul2_out = pattern->NewNode(mul2_out_repr()) + ->assert_is_op_input("elementwise_add", "X") + ->assert_is_op_output("elementwise_mul", "Out"); + mul2->LinksFrom({x, slice_cos_out}).LinksTo({mul2_out}); + add->LinksFrom({mul2_out, mul1_out}).LinksTo({add_out}); +} + +} // namespace patterns + +class RoformerRelativePosFusePass : public FusePassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; + + private: + const std::string name_scope_{"roformer_relative_pos_fuse_pass"}; +}; + +void RoformerRelativePosFusePass::ApplyImpl(ir::Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + Init(name_scope_, graph); + + GraphPatternDetector gpd; + patterns::RoformerRelativePosXPUPattern pattern(gpd.mutable_pattern(), + name_scope_); + int found_subgraph_count = 0; + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* graph) { + VLOG(4) << "handle RoformerRelativePosFusePass fuse"; + /* declare operator node's name */ + // declare variable node's name + GET_IR_NODE(split); + GET_IR_NODE(scale); + GET_IR_NODE(concat); + GET_IR_NODE(mul1); + GET_IR_NODE(shape); + GET_IR_NODE(slice1); + GET_IR_NODE(slice_sin); + GET_IR_NODE(slice_cos); + GET_IR_NODE(mul2); + GET_IR_NODE(add); + // declare variable node's name + GET_IR_NODE(x); + GET_IR_NODE(sin_emb); + GET_IR_NODE(cos_emb); + GET_IR_NODE(split_out1); + GET_IR_NODE(split_out2); + GET_IR_NODE(scale_out); + GET_IR_NODE(concat_out); + GET_IR_NODE(mul1_out); + GET_IR_NODE(shape_out); + GET_IR_NODE(slice1_out); + GET_IR_NODE(slice_sin_out); + GET_IR_NODE(slice_cos_out); + GET_IR_NODE(mul2_out); + GET_IR_NODE(add_out); + auto* block = add->Op()->Block(); + auto* scope = param_scope(); + PADDLE_ENFORCE_NOT_NULL( + scope, platform::errors::InvalidArgument("Scope cannot be nullptr.")); + // Generate roformer_relative_embedding_xpu fused op + framework::OpDesc fused_op_desc(block); + fused_op_desc.SetType("roformer_relative_embedding_xpu"); + // set attrs for fused op + fused_op_desc.SetInput("x", {x->Name()}); + fused_op_desc.SetInput("sin_emb", {sin_emb->Name()}); + fused_op_desc.SetInput("cos_emb", {cos_emb->Name()}); + + fused_op_desc.SetOutput("out", {add_out->Name()}); + fused_op_desc.SetAttr("max_pos_len", + static_cast(cos_emb->Var()->GetShape()[2])); + + // relink fused op + auto* fused_op = graph->CreateOpNode(&fused_op_desc); + IR_NODE_LINK_TO(x, fused_op); + IR_NODE_LINK_TO(sin_emb, fused_op); + IR_NODE_LINK_TO(cos_emb, fused_op); + IR_NODE_LINK_TO(fused_op, add_out); + // delete useless node + std::unordered_set delete_nodes = {split, + scale, + concat, + mul1, + shape, + slice1, + slice_sin, + slice_cos, + mul2, + add, + split_out1, + split_out2, + scale_out, + concat_out, + shape_out, + slice1_out, + slice_sin_out, + slice_cos_out, + mul2_out}; + GraphSafeRemoveNodes(graph, delete_nodes); + found_subgraph_count++; + }; + + gpd(graph, handler); + + AddStatis(found_subgraph_count); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(roformer_relative_pos_fuse_pass, + paddle::framework::ir::RoformerRelativePosFusePass); + 
+REGISTER_PASS_CAPABILITY(roformer_relative_pos_fuse_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination().EQ( + "roformer_relative_embedding_xpu", 0)); diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 0684064df81e8a..508381dc3a3104 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -528,6 +528,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "delete_dropout_op_pass", "delete_concat_op_pass", "gather_squeeze_pass", + "roformer_relative_pos_fuse_pass", "delete_repeated_ops_pass", "identity_op_clean_pass", "fused_continuous_same_ops_pass", diff --git a/paddle/phi/api/yaml/fused_ops.yaml b/paddle/phi/api/yaml/fused_ops.yaml index 2ca0a32be59f52..c7b0b14606b98a 100644 --- a/paddle/phi/api/yaml/fused_ops.yaml +++ b/paddle/phi/api/yaml/fused_ops.yaml @@ -399,7 +399,7 @@ backward : max_pool2d_v2_grad - op : multi_encoder_xpu - args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, float[] softmax_max_value, str[] quant_types) + args : (Tensor x, Tensor[] fc_input_max, Tensor[] fc_weight, Tensor[] fc_weight_max, Tensor[] fc_bias, Tensor[] ln_scale, Tensor[] ln_bias, Tensor[] smooth_scale_weight, Tensor[] roformer_embedding, Tensor mask, Tensor seq_lod, Tensor max_seq_len, int layer_num, bool norm_before, int hidden_dim, int head_num, int size_per_head, int ffn_hidden_dim_scale, int act_type, int relative_type, int slice_idx, bool is_per_channel, int max_pos_len, float[] softmax_max_value, str[] quant_types) output : Tensor(out), Tensor(x_fp16), Tensor(out_fp16) infer_meta : func : MultiEncoderXPUInferMeta @@ -437,6 +437,15 @@ func : quantize_xpu data_type : x +- op : roformer_relative_embedding_xpu + args : (Tensor x, Tensor sin_emb, Tensor cos_emb, int max_pos_len) + output : Tensor(out) + infer_meta : + func : RoformerRelativePosXPUInferMeta + kernel : + func : roformer_relative_embedding_xpu + data_type : x + - op : self_dp_attention args : (Tensor x, float alpha = 1.0f, int head_number = 1) output : Tensor(out) diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 55aae9f24c1a61..14d761a1f1479e 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -1196,6 +1196,8 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT32})}, {"sine_pos_xpu", XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, + {"roformer_relative_embedding_xpu", + XPUKernelSet({phi::DataType::FLOAT16, phi::DataType::FLOAT32})}, }; return s_xpu2_kernels; diff --git a/paddle/phi/infermeta/fusion.cc b/paddle/phi/infermeta/fusion.cc index 6e85754335ce9e..af280b44d6501d 100644 --- a/paddle/phi/infermeta/fusion.cc +++ b/paddle/phi/infermeta/fusion.cc @@ -1447,6 +1447,7 @@ void MultiEncoderXPUInferMeta( const std::vector& ln_scale, const std::vector& ln_bias, const std::vector& smooth_scale_weight, + const std::vector& roformer_embedding, const MetaTensor& mask, const MetaTensor& seq_lod, const MetaTensor& max_seq_len, @@ -1460,6 +1461,7 @@ void MultiEncoderXPUInferMeta( int relative_type, int slice_idx, bool 
is_per_channel, + int max_pos_len, const std::vector& softmax_max_value, const std::vector& quant_types, MetaTensor* out, @@ -3829,4 +3831,56 @@ void MultiGruInferMeta( hidden->set_dims(out_dims); hidden->share_lod(x); } + +void RoformerRelativePosXPUInferMeta(const MetaTensor& x, + const MetaTensor& sin_emb, + const MetaTensor& cos_emb, + int max_pos_len, + MetaTensor* out) { + auto x_dims = x.dims(); + auto x_dims_size = x_dims.size(); + auto sin_emb_dims = sin_emb.dims(); + auto sin_emb_dims_size = sin_emb_dims.size(); + auto cos_emb_dims = cos_emb.dims(); + auto cos_emb_dims_size = cos_emb_dims.size(); + PADDLE_ENFORCE_EQ( + x_dims_size, + 4, + phi::errors::InvalidArgument( + "x_dims_size should be 4, but received x_dims_size is %d", + x_dims_size)); + PADDLE_ENFORCE_EQ( + sin_emb_dims_size, + 4, + phi::errors::InvalidArgument( + "sin_emb_dims_size should be 4, but received sin_emb_dims_size is %d", + sin_emb_dims_size)); + PADDLE_ENFORCE_EQ( + cos_emb_dims_size, + 4, + phi::errors::InvalidArgument( + "cos_emb_dims_size should be 4, but received cos_emb_dims_size is %d", + cos_emb_dims_size)); + for (int i = 0; i < sin_emb_dims_size; i++) { + PADDLE_ENFORCE_EQ( + sin_emb_dims[i], + cos_emb_dims[i], + phi::errors::InvalidArgument( + "sin_emb_dims[i] should be equal to cos_emb_dims[i], index i is " + "%d, sin_emb_dims[i] is %d, cos_emb_dims[i] is %d", + i, + sin_emb_dims[i], + cos_emb_dims[i])); + } + PADDLE_ENFORCE_EQ( + x_dims[3], + cos_emb_dims[3], + phi::errors::InvalidArgument("x_dims[3] should be equal to cos_dims[3], " + "but sin_dims[3] is %d, cos_dims[3] is %d", + x_dims[3], + cos_emb_dims[3])); + out->set_dims(x_dims); + out->set_dtype(x.dtype()); +} + } // namespace phi diff --git a/paddle/phi/infermeta/fusion.h b/paddle/phi/infermeta/fusion.h index 767f22fd245f4d..87999ab2b45641 100644 --- a/paddle/phi/infermeta/fusion.h +++ b/paddle/phi/infermeta/fusion.h @@ -151,6 +151,7 @@ void MultiEncoderXPUInferMeta( const std::vector& ln_scale, const std::vector& ln_bias, const std::vector& smooth_scale_weight, + const std::vector& roformer_embedding, const MetaTensor& mask, const MetaTensor& seq_lod, const MetaTensor& max_seq_len, @@ -164,6 +165,7 @@ void MultiEncoderXPUInferMeta( int relative_type, int slice_idx, bool is_per_channel, + int max_pos_len, const std::vector& softmax_max_value, const std::vector& quant_types, MetaTensor* out, @@ -838,6 +840,11 @@ void QKVAttentionXPUInferMeta(const MetaTensor& q, void SinePosXPUInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out); +void RoformerRelativePosXPUInferMeta(const MetaTensor& x, + const MetaTensor& sin_emb, + const MetaTensor& cos_emb, + int max_pos_len, + MetaTensor* out); void MultiGruInferMeta( const MetaTensor& x, diff --git a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc index 1f76fc3ef02d84..0b311eb0e65f71 100644 --- a/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc +++ b/paddle/phi/kernels/fusion/xpu/multi_encoder_xpu_kernel.cc @@ -47,6 +47,7 @@ void MultiEncoderXPUKernel( const std::vector& ln_scale, const std::vector& ln_bias, const std::vector& smooth_scale_weight, + const std::vector& roformer_embedding, const paddle::optional& mask, const paddle::optional& seq_lod, const paddle::optional& max_seq_len, @@ -60,6 +61,7 @@ void MultiEncoderXPUKernel( int relative_type, int slice_idx, bool is_per_channel, + int max_pos_len, const std::vector& softmax_max_value, const std::vector& quant_types, DenseTensor* out, @@ -150,7 
+152,6 @@ void MultiEncoderXPUKernel( } } - std::vector test_data(6, 0); for (size_t i = 0; i < fc_input_max.size(); i++) { fc_input_max_data.push_back(fc_input_max[i]->data()); } @@ -199,6 +200,16 @@ void MultiEncoderXPUKernel( qkv_attn_param.quant_type_.assign(set_quant_types.begin(), set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; + if (!roformer_embedding.empty()) { + std::vector roformer_embedding_data; + for (size_t i = 0; i < roformer_embedding.size(); i++) { + roformer_embedding_data.push_back(roformer_embedding[i]->data()); + } + qkv_attn_param.relative_type = relative_type; + qkv_attn_param.max_pos_len = max_pos_len; + qkv_attn_param.relative_pos.assign(roformer_embedding_data.begin(), + roformer_embedding_data.end()); + } if (!enable_int8) { if (local_quant) { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) @@ -242,6 +253,16 @@ void MultiEncoderXPUKernel( qkv_attn_param.quant_type_.assign(set_quant_types.begin(), set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; + if (!roformer_embedding.empty()) { + std::vector roformer_embedding_data; + for (size_t i = 0; i < roformer_embedding.size(); i++) { + roformer_embedding_data.push_back(roformer_embedding[i]->data()); + } + qkv_attn_param.relative_type = relative_type; + qkv_attn_param.max_pos_len = max_pos_len; + qkv_attn_param.relative_pos.assign(roformer_embedding_data.begin(), + roformer_embedding_data.end()); + } if (!enable_int8) { if (local_quant) { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) @@ -288,6 +309,16 @@ void MultiEncoderXPUKernel( qkv_attn_param.quant_type_.assign(set_quant_types.begin(), set_quant_types.end()); qkv_attn_param.scale_of_hidden_units = ffn_hidden_dim_scale; + if (!roformer_embedding.empty()) { + std::vector roformer_embedding_data; + for (size_t i = 0; i < roformer_embedding.size(); i++) { + roformer_embedding_data.push_back(roformer_embedding[i]->data()); + } + qkv_attn_param.relative_type = relative_type; + qkv_attn_param.max_pos_len = max_pos_len; + qkv_attn_param.relative_pos.assign(roformer_embedding_data.begin(), + roformer_embedding_data.end()); + } if (!enable_int8) { if (local_quant) { TRANSFORMER_ENCODER_KERNEL_IMPL(XPUTypeFP16, XPUTypeFP16, float) @@ -319,6 +350,6 @@ PD_REGISTER_KERNEL(multi_encoder_xpu, phi::fusion::MultiEncoderXPUKernel, float, phi::dtype::float16) { - kernel->InputAt(9).SetBackend(phi::Backend::CPU); kernel->InputAt(10).SetBackend(phi::Backend::CPU); + kernel->InputAt(11).SetBackend(phi::Backend::CPU); } diff --git a/paddle/phi/kernels/fusion/xpu/roformer_relative_embedding_kernel.cc b/paddle/phi/kernels/fusion/xpu/roformer_relative_embedding_kernel.cc new file mode 100644 index 00000000000000..ae42b0eabc614e --- /dev/null +++ b/paddle/phi/kernels/fusion/xpu/roformer_relative_embedding_kernel.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { +namespace fusion { + +template +void RoformerRelativePosXPUKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& sin_emb, + const DenseTensor& cos_emb, + int max_pos_len, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + + auto* x_data = reinterpret_cast(x.data()); + auto* sin_emb_data = sin_emb.data(); + auto* cos_emb_data = cos_emb.data(); + auto* out_data = reinterpret_cast(ctx.template Alloc(out)); + xpu::ctx_guard RAII_GUARD(ctx.x_context()); + auto x_dims = x.dims(); + int batch = x_dims[0]; + int head_num = x_dims[1]; + int seqlen = x_dims[2]; + int head_dim = x_dims[3]; + if (seqlen > max_pos_len) { + PADDLE_THROW(phi::errors::InvalidArgument( + "The input sequence length should be less than or equal to the " + "maximum position length. But received seqlen: %d, max_pos_len: %d", + seqlen, + max_pos_len)); + } + std::vector lod; + lod.resize(batch + 1); + for (int i = 0; i < batch + 1; i++) { + lod[i] = i * seqlen; + } + int r = + xpu::rope(ctx.x_context(), + x_data, + out_data, + cos_emb_data, + sin_emb_data, + batch, + head_num, + head_dim, + head_num * head_dim, + lod, + max_pos_len, + false, // no vsl + true); // transpose to [n, seql, head_num, head_dim] + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "roformer_relative_embedding_xpu"); +} + +} // namespace fusion +} // namespace phi + +PD_REGISTER_KERNEL(roformer_relative_embedding_xpu, + XPU, + ALL_LAYOUT, + phi::fusion::RoformerRelativePosXPUKernel, + float, + phi::dtype::float16) {} diff --git a/test/ir/inference/test_xpu_roformer_relative_pos_pass.py b/test/ir/inference/test_xpu_roformer_relative_pos_pass.py new file mode 100644 index 00000000000000..93c448463af9ca --- /dev/null +++ b/test/ir/inference/test_xpu_roformer_relative_pos_pass.py @@ -0,0 +1,167 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from functools import partial + +import hypothesis.strategies as st +import numpy as np +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestRoformerRelativePosXPUPass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + # config.switch_ir_optim(True) + # config.switch_ir_debug(True) + yield config, ["roformer_relative_embedding_xpu"], (1e-3, 1e-3) + + def sample_program_config(self, draw): + x_shape = draw( + st.lists( + st.integers(min_value=1, max_value=10), min_size=4, max_size=4 + ) + ) + x_shape[1] = draw(st.integers(min_value=12, max_value=12)) + x_shape[2] = draw(st.integers(min_value=512, max_value=512)) + x_shape[3] = draw(st.integers(min_value=32, max_value=32)) + sin_emb_shape = draw( + st.lists( + st.integers(min_value=1, max_value=1), + min_size=4, + max_size=4, + ) + ) + sin_emb_shape[1] = draw(st.integers(min_value=1, max_value=1)) + sin_emb_shape[2] = draw(st.integers(min_value=512, max_value=512)) + sin_emb_shape[3] = draw(st.integers(min_value=32, max_value=32)) + cos_emb_shape = sin_emb_shape + + def generate_data(shape): + return np.random.random(shape).astype(np.float32) + + # Here we will compose a program + # Still has some risks that the program is invalid or causes bugs while running + # Use function `is_program_valid` to filter the invalid programs before running + # Use function `add_skip_pass_case` to ignore the programs even if they cause bugs while running + split_op = OpConfig( + "split", + inputs={"X": ["x"]}, + outputs={"Out": ["split_out1", "split_out2"]}, + axis=3, + num=2, + ) + scale_op = OpConfig( + "scale", + inputs={"X": ["split_out2"]}, + outputs={"Out": ["scale_out"]}, + scale=-1, + ) + concat_op = OpConfig( + "concat", + inputs={"X": ["scale_out", "split_out1"]}, + outputs={"Out": ["concat_out"]}, + axis=-1, + ) + shape_op = OpConfig( + "shape", + inputs={"Input": ["x"]}, + outputs={"Out": ["shape_out"]}, + ) + slice1_op = OpConfig( + "slice", + inputs={"Input": ["shape_out"]}, + outputs={"Out": ["slice1_out"]}, + axes=[0], + starts=[-2], + ends=[-1], + infer_flags=[1], + decrease_axis=[0], + ) + slice_sin_op = OpConfig( + "slice", + inputs={"Input": ["sin_emb"], "EndsTensorList": ["slice1_out"]}, + outputs={"Out": ["slice_sin_out"]}, + axes=[2], + starts=[0], + ends=[-1], + infer_flags=[-1], + decrease_axis=[], + ) + slice_cos_op = OpConfig( + "slice", + inputs={"Input": ["cos_emb"], "EndsTensorList": ["slice1_out"]}, + outputs={"Out": ["slice_cos_out"]}, + axes=[2], + starts=[0], + ends=[-1], + infer_flags=[-1], + decrease_axis=[], + ) + mul1_op = OpConfig( + "elementwise_mul", + inputs={"X": ["concat_out"], "Y": ["slice_sin_out"]}, + outputs={"Out": ["mul1_out"]}, + ) + mul2_op = OpConfig( + "elementwise_mul", + inputs={"X": ["x"], "Y": ["slice_cos_out"]}, + outputs={"Out": ["mul2_out"]}, + ) + add_op = OpConfig( + "elementwise_add", + inputs={"X": ["mul2_out"], "Y": ["mul1_out"]}, + outputs={"Out": ["add_out"]}, + ) + + ops = [ + split_op, + scale_op, + concat_op, + shape_op, + slice1_op, + slice_sin_op, + slice_cos_op, + mul1_op, + mul2_op, + add_op, + ] + + program_config = ProgramConfig( + ops=ops, + inputs={ + "x": TensorConfig(data_gen=partial(generate_data, x_shape)), + "sin_emb": TensorConfig( + data_gen=partial(generate_data, sin_emb_shape) + ), + "cos_emb": TensorConfig( + data_gen=partial(generate_data, cos_emb_shape) + ), + }, + weights={}, 
outputs=ops[-1].outputs["Out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["roformer_relative_pos_fuse_pass"], + ) + + +if __name__ == "__main__": + unittest.main()