Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[XPU] add roformer relative embedding pass & kernel and support in multi_encoder_xpu #62089

Merged
merged 1 commit into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(sine_pos_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(quant_dequant_xpu_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(roformer_relative_pos_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()

cc_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ namespace ir {
namespace patterns {

struct AdaptiveSeqlenPatternV1 : public PatternBase {
AdaptiveSeqlenPatternV1(PDPattern* pattern, const std::string& name_scope);
AdaptiveSeqlenPatternV1(PDPattern* pattern,
const std::string& name_scope,
const std::string& matmul_type);

// declare operator node's name
PATTERN_DECL_NODE(embedding_xpu);
Expand All @@ -44,7 +46,8 @@ struct AdaptiveSeqlenPatternV1 : public PatternBase {
};

AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern,
const std::string& name_scope)
const std::string& name_scope,
const std::string& matmul_type)
: PatternBase(pattern, name_scope, name_scope) {
auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr())
->assert_is_op("embedding_with_eltwise_add_xpu");
Expand All @@ -59,11 +62,11 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern,
->assert_is_op_input("multi_encoder_xpu", "x");

auto* mask = pattern->NewNode(mask_repr())
->assert_is_op_input("matmul", "X")
->assert_is_op_input("matmul", "Y");
auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul");
->assert_is_op_input(matmul_type, "X")
->assert_is_op_input(matmul_type, "Y");
auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type);
auto* matmul_out = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul", "Out")
->assert_is_op_output(matmul_type, "Out")
->assert_is_op_input("scale", "X");
auto* scale = pattern->NewNode(scale_repr())->assert_is_op("scale");
auto* scale_out = pattern->NewNode(scale_out_repr())
Expand All @@ -88,9 +91,10 @@ AdaptiveSeqlenPatternV1::AdaptiveSeqlenPatternV1(PDPattern* pattern,
} // namespace patterns

int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1(
ir::Graph* graph) const {
ir::Graph* graph, const std::string& matmul_type) const {
GraphPatternDetector gpd;
patterns::AdaptiveSeqlenPatternV1 pattern(gpd.mutable_pattern(), name_scope_);
patterns::AdaptiveSeqlenPatternV1 pattern(
gpd.mutable_pattern(), name_scope_, matmul_type);

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Expand Down Expand Up @@ -143,7 +147,9 @@ int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV1(
namespace patterns {

struct AdaptiveSeqlenPatternV2 : public PatternBase {
AdaptiveSeqlenPatternV2(PDPattern* pattern, const std::string& name_scope);
AdaptiveSeqlenPatternV2(PDPattern* pattern,
const std::string& name_scope,
const std::string& matmul_type);

// declare operator node's name
PATTERN_DECL_NODE(embedding_xpu);
Expand Down Expand Up @@ -172,7 +178,8 @@ struct AdaptiveSeqlenPatternV2 : public PatternBase {
};

AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern,
const std::string& name_scope)
const std::string& name_scope,
const std::string& matmul_type)
: PatternBase(pattern, name_scope, name_scope) {
auto* embedding_xpu = pattern->NewNode(embedding_xpu_repr())
->assert_is_op("embedding_with_eltwise_add_xpu");
Expand Down Expand Up @@ -201,11 +208,11 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern,
pattern->NewNode(unsqueeze_0_repr())->assert_is_op("unsqueeze2");
auto* unsqueeze_0_out = pattern->NewNode(unsqueeze_0_out_repr())
->assert_is_op_output("unsqueeze2", "Out")
->assert_is_op_input("matmul_v2", "X")
->assert_is_op_input("matmul_v2", "Y");
auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op("matmul_v2");
->assert_is_op_input(matmul_type, "X")
->assert_is_op_input(matmul_type, "Y");
auto* matmul = pattern->NewNode(matmul_repr())->assert_is_op(matmul_type);
auto* matmul_out = pattern->NewNode(matmul_out_repr())
->assert_is_op_output("matmul_v2", "Out")
->assert_is_op_output(matmul_type, "Out")
->assert_is_op_input("scale", "X");
auto* scale_0 = pattern->NewNode(scale_0_repr())->assert_is_op("scale");
auto* scale_0_out = pattern->NewNode(scale_0_out_repr())
Expand Down Expand Up @@ -244,9 +251,10 @@ AdaptiveSeqlenPatternV2::AdaptiveSeqlenPatternV2(PDPattern* pattern,
} // namespace patterns

int MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyAdaptiveSeqlenPassV2(
ir::Graph* graph) const {
ir::Graph* graph, const std::string& matmul_type) const {
GraphPatternDetector gpd;
patterns::AdaptiveSeqlenPatternV2 pattern(gpd.mutable_pattern(), name_scope_);
patterns::AdaptiveSeqlenPatternV2 pattern(
gpd.mutable_pattern(), name_scope_, matmul_type);

int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Expand Down Expand Up @@ -324,9 +332,13 @@ void MultiEncoderXPUAdaptiveSeqlenFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
std::vector<std::string> matmul_types{"matmul", "matmul_v2"};
int found_subgraph_count = 0;
for (auto& matmul_type : matmul_types) {
found_subgraph_count += ApplyAdaptiveSeqlenPassV1(graph, matmul_type);
found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph, matmul_type);
}

int found_subgraph_count = ApplyAdaptiveSeqlenPassV1(graph);
found_subgraph_count += ApplyAdaptiveSeqlenPassV2(graph);
AddStatis(found_subgraph_count);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase {
|
out_var*
*/
int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph) const;
int ApplyAdaptiveSeqlenPassV1(ir::Graph* graph,
const std::string& matmul_type) const;

/*
adaptive seqlen V2, before:
Expand Down Expand Up @@ -132,7 +133,8 @@ class MultiEncoderXPUAdaptiveSeqlenFusePass : public FusePassBase {
|
out_var*
*/
int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph) const;
int ApplyAdaptiveSeqlenPassV2(ir::Graph* graph,
const std::string& matmul_type) const;

private:
const std::string name_scope_{"multi_encoder_xpu_adaptive_seqlen_fuse_pass"};
Expand Down
Loading