[XPU] add roformer relative embedding pass & kernel and support in multi_encoder_xpu
NeroLoh committed Feb 26, 2024
1 parent 044dfe1 commit d6b0308
Showing 14 changed files with 975 additions and 47 deletions.
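For context, the "roformer relative embedding" in the title is the rotary position-embedding scheme: query/key vectors are rotated pairwise by a position-dependent angle before the attention matmul, so the dot product depends only on the relative offset between tokens. Below is a minimal illustrative sketch of that math in plain C++; the function and variable names are assumptions for illustration only and this is not the fused XPU kernel added by the commit.

// Illustrative sketch only: applies rotary (RoFormer-style) relative position
// embedding to one query or key vector. Assumes a flat, even-length vector
// whose adjacent (even, odd) elements form the rotation pairs.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> ApplyRotaryEmbedding(const std::vector<float>& x,
                                        int position,
                                        float base = 10000.0f) {
  std::vector<float> out(x.size());
  for (std::size_t i = 0; i + 1 < x.size(); i += 2) {
    // Lower dimensions rotate faster; the angle grows linearly with position,
    // which is what makes q·k depend only on the relative distance.
    const float inv_freq =
        1.0f / std::pow(base, static_cast<float>(i) / x.size());
    const float theta = static_cast<float>(position) * inv_freq;
    const float c = std::cos(theta);
    const float s = std::sin(theta);
    out[i] = x[i] * c - x[i + 1] * s;
    out[i + 1] = x[i] * s + x[i + 1] * c;
  }
  return out;
}

The visible part of the diff below registers the new pass in CMake and generalizes the adaptive-seqlen fuse pass so it matches both matmul and matmul_v2 mask products.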
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
@@ -322,6 +322,8 @@ if(WITH_XPU)
                ${XPU_PASS_DEPS})
   pass_library(sine_pos_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(quant_dequant_xpu_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(roformer_relative_pos_fuse_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()
 
 cc_library(

@@ -25,7 +25,9 @@ namespace ir {
 namespace patterns {
 
 struct AdaptiveSeqlenPatternV1 : public PatternBase {
-  AdaptiveSeqlenPatternV1(PDPattern* pattern, const std::string& name_scope);
+  AdaptiveSeqlenPatternV1(PDPattern* pattern,
+                          const std::string& name_scope,
+                          const std::string& matmul_type);
 
   // declare operator node's name
   PATTERN_DECL_NODE(embedding_xpu);
@@ -44,7 +46,8 @@ struct AdaptiveSeqlenPatternV1 : public PatternBase {
 };
 
 AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern,
-                                                 const std::string& name_scope)
+                                                 const std::string& name_scope,
+                                                 const std::string& matmul_type)
     : PatternBase(pattern, name_scope, name_scope) {
   auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr())
                             ->assert_is_op("embedding_with_eltwise_add_xpu");
@@ -59,11 +62,11 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern,
->assert_is_op_input("multi_encoder_xpu", "x");

auto* mask = pattern->NewNode(mask_repr())
->assert_is_op_input("matmul", "X")
->assert_is_op_input("matmul", "Y");
auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul");
->assert_is_op_input(matmul_type, "X")
->assert_is_op_input(matmul_type, "Y");
auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type);
auto* matmul_out = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul", "Out")
->assert_is_op_output(matmul_type, "Out")
->assert_is_op_input("scale", "X");
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out = pattern->NewNode(scale_out_repr())
@@ -88,9 +91,10 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern,
 }  // namespace patterns
 
 int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1(
-    ir::Graph* graph) const {
+    ir::Graph* graph, const std::string& matmul_type) const {
   GraphPatternDetector gpd;
-  patterns::AdaptiveSeqlenPatternV1 pattern(gpd.mutable_pattern(), name_scope_);
+  patterns::AdaptiveSeqlenPatternV1 pattern(
+      gpd.mutable_pattern(), name_scope_, matmul_type);
 
   int found_subgraph_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -143,7 +147,9 @@ int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1(
 namespace patterns {
 
 struct AdaptiveSeqlenPatternV2 : public PatternBase {
-  AdaptiveSeqlenPatternV2(PDPattern* pattern, const std::string& name_scope);
+  AdaptiveSeqlenPatternV2(PDPattern* pattern,
+                          const std::string& name_scope,
+                          const std::string& matmul_type);
 
   // declare operator node's name
   PATTERN_DECL_NODE(embedding_xpu);
@@ -172,7 +178,8 @@ struct AdaptiveSeqlenPatternV2 : public PatternBase {
 };
 
 AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern,
-                                                 const std::string& name_scope)
+                                                 const std::string& name_scope,
+                                                 const std::string& matmul_type)
     : PatternBase(pattern, name_scope, name_scope) {
   auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr())
                             ->assert_is_op("embedding_with_eltwise_add_xpu");
@@ -201,11 +208,11 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern,
       pattern->NewNode(unsqueeze_0_repr())->assert_is_op("unsqueeze2");
   auto* unsqueeze_0_out = pattern->NewNode(unsqueeze_0_out_repr())
                               ->assert_is_op_output("unsqueeze2", "Out")
-                              ->assert_is_op_input("matmul_v2", "X")
-                              ->assert_is_op_input("matmul_v2", "Y");
-  auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul_v2");
+                              ->assert_is_op_input(matmul_type, "X")
+                              ->assert_is_op_input(matmul_type, "Y");
+  auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type);
   auto* matmul_out = pattern->NewNode(matmul_out_repr())
-                         ->assert_is_op_output("matmul_v2", "Out")
+                         ->assert_is_op_output(matmul_type, "Out")
                          ->assert_is_op_input("scale", "X");
   auto* scale_0 = pattern->NewNode(scale_0_repr())->assert_is_op("scale");
   auto* scale_0_out = pattern->NewNode(scale_0_out_repr())
@@ -244,9 +251,10 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern,
 }  // namespace patterns
 
 int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV2(
-    ir::Graph* graph) const {
+    ir::Graph* graph, const std::string& matmul_type) const {
   GraphPatternDetector gpd;
-  patterns::AdaptiveSeqlenPatternV2 pattern(gpd.mutable_pattern(), name_scope_);
+  patterns::AdaptiveSeqlenPatternV2 pattern(
+      gpd.mutable_pattern(), name_scope_, matmul_type);
 
   int found_subgraph_count = 0;
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
@@ -324,9 +332,13 @@ void MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
+  std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
+  int found_subgraph_count = 0;
+  for (auto& matmul_type : matmul_types) {
+    found_subgraph_count += ApplyAdaptiveSeqlenPassV1(graph, matmul_type);
+    found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph, matmul_type);
+  }
 
-  int found_subgraph_count = ApplyAdaptiveSeqlenPassV1(graph);
-  found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph);
   AddStatis(found_subgraph_count);
 }

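For orientation, the ApplyImpl change above amounts to running the same V1/V2 fusions once per supported matmul flavor ("matmul" and "matmul_v2") and summing the match counts, so graphs produced with either op are fused. A standalone sketch of that dispatch pattern, with a stub standing in for the real passes (plain C++, not Paddle's actual types):

#include <iostream>
#include <string>
#include <vector>

// Stub standing in for ApplyAdaptiveSeqlenPassV1/V2: pretends to count how
// many subgraphs were fused for a given matmul op type (dummy values).
int RunAdaptiveSeqlenPass(const std::string& matmul_type) {
  return matmul_type == "matmul_v2" ? 2 : 1;
}

int main() {
  const std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
  int found_subgraph_count = 0;
  for (const auto& matmul_type : matmul_types) {
    // Each op flavor gets its own pattern instance, mirroring how the pass
    // now parameterizes AdaptiveSeqlenPatternV1/V2 with matmul_type.
    found_subgraph_count += RunAdaptiveSeqlenPass(matmul_type);
  }
  std::cout << "fused subgraphs: " << found_subgraph_count << "\n";
  return 0;
}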

@@ -76,7 +76,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase {
         |
      out_var*
   */
-  int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph) const;
+  int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph,
+                                const std::string& matmul_type) const;
 
   /*
   adaptive seqlen V2, before:
@@ -132,7 +133,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase {
         |
      out_var*
   */
-  int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph) const;
+  int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph,
+                                const std::string& matmul_type) const;
 
  private:
   const std::string name_scope_{"multi_encoder_xpu_adaptive_seqlen_fuse_pass"};
