GemmEpilogueOp with series of CUTLASS kernel #61925

Merged 63 commits on May 11, 2024

Commits
24470f3
split seq_len to improve mmha perf
YKTian-x2b Feb 21, 2024
e428611
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Feb 21, 2024
c3c8e1d
update postProcessKernel and some HyperParam
YKTian-x2b Mar 1, 2024
536831c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Mar 1, 2024
f63e995
Update mmha Kernel
YKTian-x2b Mar 4, 2024
f93b21c
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Mar 4, 2024
e1b65fe
add cutlass fused fc ops
YKTian-x2b Mar 12, 2024
cf5264f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Mar 12, 2024
6f5cd32
mod of mmha was restored to the state of three months ago, and unites…
YKTian-x2b Mar 19, 2024
5f8e222
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Mar 19, 2024
f959b87
update fc ops
YKTian-x2b Mar 19, 2024
3f24f7f
refine some code
yuanlehome Mar 27, 2024
a8deb08
update
yuanlehome Mar 27, 2024
7ff49ba
update
yuanlehome Mar 28, 2024
e3c1981
fix drr rewrite
yuanlehome Mar 28, 2024
e071f57
update fc_fuse_pass with 2D_elementwiseAdd, will fix a bug with new pr
YKTian-x2b Mar 28, 2024
8fd8d3e
merge yuanlehome
YKTian-x2b Mar 29, 2024
519a02b
merge upstream/develop
YKTian-x2b Mar 29, 2024
d04df38
llm perf test completed
YKTian-x2b Apr 8, 2024
5ddee1f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Apr 8, 2024
6d2f530
a little bit of mod
YKTian-x2b Apr 8, 2024
f486821
rm non-related files
YKTian-x2b Apr 8, 2024
c48a0a6
try to recover mmha files that were deleted by mistake, my intention …
YKTian-x2b Apr 8, 2024
f8cc4eb
recover non-related mod
YKTian-x2b Apr 8, 2024
3f813a1
Revert "recover non-related mod"
YKTian-x2b Apr 8, 2024
a3fe38b
recover non-related mod
YKTian-x2b Apr 8, 2024
a499960
recover non-related mod
YKTian-x2b Apr 9, 2024
85ce927
unitest succ for two path(cublas and cutlass)
YKTian-x2b Apr 9, 2024
d59a483
clean dual-path
YKTian-x2b Apr 12, 2024
89c8b91
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Apr 12, 2024
586a6b1
Specification code and comments
YKTian-x2b Apr 15, 2024
9619d99
a mod default sm_version in fc_decl.h
YKTian-x2b Apr 15, 2024
f3c5c91
yuanlehome review mod
YKTian-x2b Apr 16, 2024
28050ee
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Apr 16, 2024
f8b6805
yuanlehome review mod, rename gemm_epilogue files
YKTian-x2b Apr 16, 2024
333f445
yuanlehome review mod, rename gemm_epilogue again
YKTian-x2b Apr 16, 2024
18a03c3
push for merge new pr
YKTian-x2b Apr 16, 2024
8667893
unit test for fc_fuse_pass is available now
YKTian-x2b Apr 16, 2024
a20925f
rename fc_fuse_pass
YKTian-x2b Apr 17, 2024
7ead96d
rename all fc to gemm_epilogue
YKTian-x2b Apr 17, 2024
3bf1c93
update README
YKTian-x2b Apr 17, 2024
3850cab
update unittest and swizzle config in cutlass kernel
YKTian-x2b Apr 18, 2024
ee43d5b
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Apr 18, 2024
cc2b2a1
var type error fixed
YKTian-x2b Apr 18, 2024
2549b3b
conflict fixed
YKTian-x2b Apr 18, 2024
d4c732b
undo comment out acts
YKTian-x2b Apr 18, 2024
d8a27eb
comment out acts
YKTian-x2b Apr 19, 2024
3abce89
code style fixed
YKTian-x2b Apr 19, 2024
952ddf4
code style fixed
YKTian-x2b Apr 19, 2024
191d969
Tuning unit tests
YKTian-x2b Apr 20, 2024
a3f1e86
merge conflict for pr merge
YKTian-x2b Apr 22, 2024
a0a1e3b
to avoid ci timeout, update unittest
YKTian-x2b Apr 22, 2024
f52173c
to avoid ci-converage timeout, update unittest
YKTian-x2b Apr 23, 2024
bef2ec4
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b Apr 24, 2024
f50fbdc
conflict fixed(LLee233)
YKTian-x2b Apr 24, 2024
008e268
with cutlass download
YKTian-x2b Apr 24, 2024
b31b878
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b May 6, 2024
9bc0e06
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b May 8, 2024
2efeb6f
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b May 8, 2024
5531403
use PADDLE_ENFORCE && unitest conflict fix
YKTian-x2b May 8, 2024
7cc423e
for unitest timeout
YKTian-x2b May 9, 2024
cc6a938
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b May 10, 2024
2b657ad
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
YKTian-x2b May 10, 2024
6 changes: 5 additions & 1 deletion paddle/fluid/inference/api/analysis_predictor.cc
@@ -931,7 +931,11 @@ bool AnalysisPredictor::PrepareExecutor() {
       // gpu
       if (!config_.custom_pass_only_) {
         for (const auto &gpu_pass : kPirGpuPasses) {
-          pass_pm.AddPass(pir::PassRegistry::Instance().Get(gpu_pass));
+          auto pass = pir::PassRegistry::Instance().Get(gpu_pass);
+          if (pass->name() == "matmul_add_act_fuse_pass") {
+            pass->Set("use_cutlass", new bool(config_.use_cutlass_));
+          }
+          pass_pm.AddPass(std::move(pass));
         }
       }

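As context for the hunk above, a minimal sketch of how the use_cutlass_ flag would be set from user code so that matmul_add_act_fuse_pass fuses into GemmEpilogueOp rather than FcOp. This is a hedged sketch: Exp_EnableUseCutlass() and EnableNewIR() are assumed to be available on the inference Config for the Paddle version this PR targets; check your local headers.

#include "paddle_inference_api.h"

int main() {
  // Schematic model paths; replace with real files.
  paddle_infer::Config config("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/256, /*device_id=*/0);
  config.EnableNewIR(true);       // run the PIR pipeline (kPirGpuPasses)
  config.Exp_EnableUseCutlass();  // assumed setter for use_cutlass_ above
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}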
4 changes: 2 additions & 2 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -604,7 +604,7 @@ const std::vector<std::string> kPirGpuPasses{
     "conv2d_add_fuse_pass",
     "embedding_eltwise_layernorm_fuse_pass",
     "multihead_matmul_fuse_pass",
-    "fc_fuse_pass",
+    "matmul_add_act_fuse_pass",
     "fc_elementwise_layernorm_fuse_pass",
     "matmul_scale_fuse_pass",
     "matmul_transpose_fuse_pass",
@@ -631,7 +631,7 @@ const std::vector<std::string> kPirMkldnnPasses{
     "matmul_transpose_reshape_fuse_pass",
     "matmul_elementwise_add_fuse_pass",
     "matmul_activation_fuse_pass",
-    "fc_fuse_pass",
+    "matmul_add_act_fuse_pass",
     "fc_onednn_enable_pass",
     "fc_activation_fuse_pass",
     "self_attention_fuse_pass",
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -207,6 +207,7 @@
     'partial_allgather_',
     'nop',
     'nop_',
+    'gemm_epilogue',
     'push_dense',
     'limit_by_capacity',
     'global_scatter',
138 changes: 0 additions & 138 deletions paddle/fluid/pir/transforms/gpu/fc_fuse_pass.cc

This file was deleted.

216 changes: 216 additions & 0 deletions paddle/fluid/pir/transforms/gpu/matmul_add_act_fuse_pass.cc
@@ -0,0 +1,216 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/gpu/matmul_add_act_fuse_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/utils/general_functions.h"

#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"

namespace {

std::set<std::string> act_ops = {
    "gelu",
    "relu",
};
std::unordered_map<std::string, std::string> activation_type = {
    {"gelu", paddle::dialect::GeluOp::name()},
    {"relu", paddle::dialect::ReluOp::name()},
};

class MatmulAddPattern : public paddle::drr::DrrPatternBase {
 private:
  std::string fused_op_name_;
  bool reverse_add_;

 public:
  explicit MatmulAddPattern(const std::string &fused_op_name,
                            const bool reverse_add)
      : fused_op_name_(fused_op_name), reverse_add_(reverse_add) {}

  std::string name() const override { return "MatmulAddPattern"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
                                {{"transpose_x", pat.Attr("transpose_x")},
                                 {"transpose_y", pat.Attr("transpose_y")}});
    const auto &add = pat.Op(paddle::dialect::AddOp::name());
    matmul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("matmul_out")});
    pat.Tensor("add_out") =
        reverse_add_ ? add(pat.Tensor("y"), pat.Tensor("matmul_out"))
                     : add(pat.Tensor("matmul_out"), pat.Tensor("y"));

    pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
      auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w"));
      auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x"));
      auto y_dims = pir::GetShapeFromValue(match_ctx.Tensor("y"));
      if (w_dims.size() != 2 || x_dims.size() < 2) {
        return false;
      }
      // Currently, FcOp and GemmEpilogueOp support only the RRR format.
      if (x_dims.at(x_dims.size() - 1) != w_dims.at(0) ||
          match_ctx.Attr<bool>("transpose_x") == true ||
          match_ctx.Attr<bool>("transpose_y") == true) {
        return false;
      }

      if (y_dims.size() == 1) {
        return y_dims.at(0) == w_dims.at(1);
      }

      if (fused_op_name_ == paddle::dialect::FcOp::name()) {
        if (y_dims.size() == 2) {
          return y_dims.at(0) == 1 && y_dims.at(1) == w_dims.at(1);
        }
      } else {
        if (y_dims.size() == x_dims.size()) {
          if (y_dims.size() == 2) {
            return ((y_dims.at(0) == 1) || (y_dims.at(0) == x_dims.at(0))) &&
                   y_dims.at(1) == w_dims.at(1);
          }
          for (size_t ii = 0; ii < x_dims.size() - 1; ii++) {
            if (y_dims.at(ii) != x_dims.at(ii)) {
              return false;
            }
          }
          return y_dims.at(y_dims.size() - 1) == w_dims.at(1);
        }
      }
      return false;
    });

    paddle::drr::ResultPattern res = pat.ResultPattern();
    const auto &in_num_col_dims_attr =
        res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int {
          auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x"));
          return static_cast<int>(x_dims.size()) - 1;
        });
    const auto &gemm_epilogue =
        res.Op(fused_op_name_,
               {{
                   {"in_num_col_dims", in_num_col_dims_attr},
                   {"activation_type", res.StrAttr("")},
                   {"padding_weights", res.BoolAttr(false)},
               }});
    gemm_epilogue({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},
                  {&res.Tensor("add_out")});
  }
};
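To make the shape constraint above concrete, the following is a standalone re-statement of the bias ("y") rule on the GemmEpilogueOp path. BiasShapeOk and all shapes here are hypothetical illustrations, not part of the pass or the DRR API.

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative restatement: given x [*, K] and w [K, N] (RRR, no transpose),
// the bias y may be [N], [1, N] or [M, N] for 2-D x, or match every leading
// dim of x with a trailing N for higher ranks.
bool BiasShapeOk(const std::vector<int64_t> &x_dims,
                 const std::vector<int64_t> &w_dims,
                 const std::vector<int64_t> &y_dims) {
  if (w_dims.size() != 2 || x_dims.size() < 2) return false;
  if (x_dims.back() != w_dims[0]) return false;  // inner dims must agree
  const int64_t n = w_dims[1];
  if (y_dims.size() == 1) return y_dims[0] == n;  // plain 1-D bias
  if (y_dims.size() != x_dims.size()) return false;
  if (y_dims.size() == 2) {  // broadcast row [1, N] or full [M, N]
    return (y_dims[0] == 1 || y_dims[0] == x_dims[0]) && y_dims[1] == n;
  }
  for (size_t i = 0; i + 1 < x_dims.size(); ++i) {  // leading dims match
    if (y_dims[i] != x_dims[i]) return false;
  }
  return y_dims.back() == n;
}

int main() {
  printf("%d\n", BiasShapeOk({8, 16}, {16, 32}, {32}));           // 1
  printf("%d\n", BiasShapeOk({8, 16}, {16, 32}, {1, 32}));        // 1
  printf("%d\n", BiasShapeOk({4, 8, 16}, {16, 32}, {4, 8, 32}));  // 1
  printf("%d\n", BiasShapeOk({8, 16}, {16, 32}, {2, 32}));        // 0
  return 0;
}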

// Act supports [relu, gelu]
Contributor:
I see that gemm_epilogue supports more activation functions; why does the pass support only two?

Contributor Author:
Just these two for now; as for whether to extend it later, I may still need to ask Kangkang.
class MatmulAddActPattern : public paddle::drr::DrrPatternBase {
 private:
  std::string act_type_;
  std::string fused_op_name_;

 public:
  explicit MatmulAddActPattern(const std::string &act_type,
                               const std::string &fused_op_name)
      : act_type_(act_type), fused_op_name_(fused_op_name) {}

  std::string name() const override { return "MatmulAddActPattern"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &gemm_epilogue =
        pat.Op(fused_op_name_,
               {{
                   {"in_num_col_dims", pat.Attr("in_num_col_dims")},
                   {"activation_type", pat.Attr("activation_type")},
                   {"padding_weights", pat.Attr("padding_weights")},
               }});
    std::unordered_map<std::string, paddle::drr::Attribute> act_attrs;
    if (act_type_ == "gelu") {
      act_attrs.emplace("approximate", pat.Attr("approximate"));
    }
    const auto &act = pat.Op(activation_type[act_type_], act_attrs);

    gemm_epilogue({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("y")},
                  {&pat.Tensor("gemm_epilogue_out")});
    act({&pat.Tensor("gemm_epilogue_out")}, {&pat.Tensor("act_out")});

    pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
      const std::string &act_type =
          match_ctx.Attr<std::string>("activation_type");
      if (!act_type.empty()) return false;
      if (act_type_ == "gelu") {
        bool Attr_approx = match_ctx.Attr<bool>("approximate");
        if (Attr_approx) return false;
      }
      return true;
    });

    paddle::drr::ResultPattern res = pat.ResultPattern();
    std::unordered_map<std::string, paddle::drr::Attribute> fused_attrs{
        {"in_num_col_dims", pat.Attr("in_num_col_dims")},
        {"activation_type", res.StrAttr(act_type_)},
        {"padding_weights", pat.Attr("padding_weights")},
    };
    const auto &gemm_epilogue_with_act = res.Op(fused_op_name_, fused_attrs);
    gemm_epilogue_with_act(
        {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},
        {&res.Tensor("act_out")});
  }
};

class MatmulAddActFusePass : public pir::PatternRewritePass {
 public:
  MatmulAddActFusePass()
      : pir::PatternRewritePass("matmul_add_act_fuse_pass", 2) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);

    bool use_cutlass = false;
    if (Has(std::string("use_cutlass"))) {
      use_cutlass = Get<bool>(std::string("use_cutlass"));
    }
    if (use_cutlass) {
      /// MatmulAddPattern
      ps.Add(paddle::drr::Create<MatmulAddPattern>(
          context, paddle::dialect::GemmEpilogueOp::name(), true));
      ps.Add(paddle::drr::Create<MatmulAddPattern>(
          context, paddle::dialect::GemmEpilogueOp::name(), false));
      /// MatmulAddActPattern
      for (const auto &act_op : act_ops) {
        ps.Add(paddle::drr::Create<MatmulAddActPattern>(
            context, act_op, paddle::dialect::GemmEpilogueOp::name()));
      }
    } else {
      /// MatmulAddPattern
      ps.Add(paddle::drr::Create<MatmulAddPattern>(
          context, paddle::dialect::FcOp::name(), false));
      /// MatmulAddActPattern
      ps.Add(paddle::drr::Create<MatmulAddActPattern>(
          context, "relu", paddle::dialect::FcOp::name()));
    }
    return ps;
  }
};

} // namespace

namespace pir {

std::unique_ptr<Pass> CreateMatmulAddActFusePass() {
  return std::make_unique<MatmulAddActFusePass>();
}

} // namespace pir

REGISTER_IR_PASS(matmul_add_act_fuse_pass, MatmulAddActFusePass);
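Finally, a sketch of driving the registered pass directly, mirroring how AnalysisPredictor wires it up in the first hunk. The pass-manager calls below follow the PIR API as used elsewhere in Paddle, but the include path and driver function are assumptions for illustration.

#include "paddle/fluid/pir/transforms/gpu/matmul_add_act_fuse_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"

// Hypothetical driver: choose GemmEpilogueOp (use_cutlass=true) or FcOp
// (use_cutlass=false) as the fusion target, as in PrepareExecutor above.
void RunMatmulAddActFuse(pir::IrContext *ctx, pir::Program *program,
                         bool use_cutlass) {
  pir::PassManager pm(ctx);
  auto pass = pir::CreateMatmulAddActFusePass();
  pass->Set("use_cutlass", new bool(use_cutlass));
  pm.AddPass(std::move(pass));
  pm.Run(program);
}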