GemmEpilogueOp with series of CUTLASS kernel #61925
Merged: zhoutianzi666 merged 63 commits into PaddlePaddle:develop from YKTian-x2b:my-cool-stuff on May 11, 2024
Commits
24470f3 split seq_len to improve mmha perf (YKTian-x2b)
e428611 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
c3c8e1d update postProcessKernel and some HyperParam (YKTian-x2b)
536831c Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
f63e995 Update mmha Kernel (YKTian-x2b)
f93b21c Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
e1b65fe add cutlass fused fc ops (YKTian-x2b)
cf5264f Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
6f5cd32 mod of mmha was restored to the state of three months ago, and unites… (YKTian-x2b)
5f8e222 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
f959b87 update fc ops (YKTian-x2b)
3f24f7f refine some code (yuanlehome)
a8deb08 update (yuanlehome)
7ff49ba update (yuanlehome)
e3c1981 fix drr rewrite (yuanlehome)
e071f57 update fc_fuse_pass with 2D_elementwiseAdd, will fix a bug with new pr (YKTian-x2b)
8fd8d3e merge yuanlehome (YKTian-x2b)
519a02b merge upstream/develop (YKTian-x2b)
d04df38 llm perf test completed (YKTian-x2b)
5ddee1f Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
6d2f530 a little bit of mod (YKTian-x2b)
f486821 rm non-related files (YKTian-x2b)
c48a0a6 try to recover mmha files that were deleted by mistake, my intention … (YKTian-x2b)
f8cc4eb recover non-related mod (YKTian-x2b)
3f813a1 Revert "recover non-related mod" (YKTian-x2b)
a3fe38b recover non-related mod (YKTian-x2b)
a499960 recover non-related mod (YKTian-x2b)
85ce927 unitest succ for two path(cublas and cutlass) (YKTian-x2b)
d59a483 clean dual-path (YKTian-x2b)
89c8b91 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
586a6b1 Specification code and comments (YKTian-x2b)
9619d99 a mod default sm_version in fc_decl.h (YKTian-x2b)
f3c5c91 yuanlehome review mod (YKTian-x2b)
28050ee Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
f8b6805 yuanlehome review mod, rename gemm_epilogue files (YKTian-x2b)
333f445 yuanlehome review mod, rename gemm_epilogue again (YKTian-x2b)
18a03c3 push for merge new pr (YKTian-x2b)
8667893 unit test for fc_fuse_pass is available now (YKTian-x2b)
a20925f rename fc_fuse_pass (YKTian-x2b)
7ead96d rename all fc to gemm_epilogue (YKTian-x2b)
3bf1c93 update README (YKTian-x2b)
3850cab update unittest and swizzle config in cutlass kernel (YKTian-x2b)
ee43d5b Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
cc2b2a1 var type error fixed (YKTian-x2b)
2549b3b conflict fixed (YKTian-x2b)
d4c732b undo comment out acts (YKTian-x2b)
d8a27eb comment out acts (YKTian-x2b)
3abce89 code style fixed (YKTian-x2b)
952ddf4 code style fixed (YKTian-x2b)
191d969 Tuning unit tests (YKTian-x2b)
a3f1e86 merge conflict for pr merge (YKTian-x2b)
a0a1e3b to avoid ci timeout, update unittest (YKTian-x2b)
f52173c to avoid ci-converage timeout, update unittest (YKTian-x2b)
bef2ec4 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
f50fbdc conflict fixed(LLee233) (YKTian-x2b)
008e268 with cutlass download (YKTian-x2b)
b31b878 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
9bc0e06 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
2efeb6f Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
5531403 use PADDLE_ENFORCE && unitest conflict fix (YKTian-x2b)
7cc423e for unitest timeout (YKTian-x2b)
cc6a938 Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
2b657ad Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into… (YKTian-x2b)
216 changes: 216 additions & 0 deletions
paddle/fluid/pir/transforms/gpu/matmul_add_act_fuse_pass.cc
@@ -0,0 +1,216 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/gpu/matmul_add_act_fuse_pass.h"

#include <memory>
#include <set>
#include <string>
#include <unordered_map>

#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/utils/general_functions.h"

#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"

namespace {

// Activations that can be folded into the fused op's epilogue.
std::set<std::string> act_ops = {
    "gelu",
    "relu",
};
std::unordered_map<std::string, std::string> activation_type = {
    {"gelu", paddle::dialect::GeluOp::name()},
    {"relu", paddle::dialect::ReluOp::name()},
};

// Rewrites matmul(x, w) + y into the fused op (FcOp or GemmEpilogueOp).
class MatmulAddPattern : public paddle::drr::DrrPatternBase {
 private:
  std::string fused_op_name_;
  bool reverse_add_;

 public:
  explicit MatmulAddPattern(const std::string &fused_op_name,
                            const bool reverse_add)
      : fused_op_name_(fused_op_name), reverse_add_(reverse_add) {}

  std::string name() const override { return "MatmulAddPattern"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &matmul = pat.Op(paddle::dialect::MatmulOp::name(),
                                {{"transpose_x", pat.Attr("transpose_x")},
                                 {"transpose_y", pat.Attr("transpose_y")}});
    const auto &add = pat.Op(paddle::dialect::AddOp::name());
    matmul({&pat.Tensor("x"), &pat.Tensor("w")}, {&pat.Tensor("matmul_out")});
    pat.Tensor("add_out") =
        reverse_add_ ? add(pat.Tensor("y"), pat.Tensor("matmul_out"))
                     : add(pat.Tensor("matmul_out"), pat.Tensor("y"));

    pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
      auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w"));
      auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x"));
      auto y_dims = pir::GetShapeFromValue(match_ctx.Tensor("y"));
      if (w_dims.size() != 2 || x_dims.size() < 2) {
        return false;
      }
      // Currently, FcOp and GemmEpilogueOp support only the RRR layout
      // (row-major x, row-major w, row-major out), so transposes are rejected.
      if (x_dims.at(x_dims.size() - 1) != w_dims.at(0) ||
          match_ctx.Attr<bool>("transpose_x") == true ||
          match_ctx.Attr<bool>("transpose_y") == true) {
        return false;
      }

      if (y_dims.size() == 1) {
        return y_dims.at(0) == w_dims.at(1);
      }

      if (fused_op_name_ == paddle::dialect::FcOp::name()) {
        if (y_dims.size() == 2) {
          return y_dims.at(0) == 1 && y_dims.at(1) == w_dims.at(1);
        }
      } else {
        if (y_dims.size() == x_dims.size()) {
          if (y_dims.size() == 2) {
            return ((y_dims.at(0) == 1) || (y_dims.at(0) == x_dims.at(0))) &&
                   y_dims.at(1) == w_dims.at(1);
          }
          for (size_t ii = 0; ii < x_dims.size() - 1; ii++) {
            if (y_dims.at(ii) != x_dims.at(ii)) {
              return false;
            }
          }
          return y_dims.at(y_dims.size() - 1) == w_dims.at(1);
        }
      }
      return false;
    });

    paddle::drr::ResultPattern res = pat.ResultPattern();
    const auto &in_num_col_dims_attr =
        res.ComputeAttr([](const paddle::drr::MatchContext &match_ctx) -> int {
          auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x"));
          return static_cast<int>(x_dims.size()) - 1;
        });
    const auto &gemm_epilogue =
        res.Op(fused_op_name_,
               {{
                   {"in_num_col_dims", in_num_col_dims_attr},
                   {"activation_type", res.StrAttr("")},
                   {"padding_weights", res.BoolAttr(false)},
               }});
    gemm_epilogue({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},
                  {&res.Tensor("add_out")});
  }
};

// Act supports [relu, gelu]
class MatmulAddActPattern : public paddle::drr::DrrPatternBase {
 private:
  std::string act_type_;
  std::string fused_op_name_;

 public:
  explicit MatmulAddActPattern(const std::string &act_type,
                               const std::string &fused_op_name)
      : act_type_(act_type), fused_op_name_(fused_op_name) {}

  std::string name() const override { return "MatmulAddActPattern"; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &gemm_epilogue =
        pat.Op(fused_op_name_,
               {{
                   {"in_num_col_dims", pat.Attr("in_num_col_dims")},
                   {"activation_type", pat.Attr("activation_type")},
                   {"padding_weights", pat.Attr("padding_weights")},
               }});
    std::unordered_map<std::string, paddle::drr::Attribute> act_attrs;
    if (act_type_ == "gelu") {
      act_attrs.emplace("approximate", pat.Attr("approximate"));
    }
    const auto &act = pat.Op(activation_type[act_type_], act_attrs);

    gemm_epilogue({&pat.Tensor("x"), &pat.Tensor("w"), &pat.Tensor("y")},
                  {&pat.Tensor("gemm_epilogue_out")});
    act({&pat.Tensor("gemm_epilogue_out")}, {&pat.Tensor("act_out")});

    pat.AddConstraint([&](const paddle::drr::MatchContext &match_ctx) {
      const std::string &act_type =
          match_ctx.Attr<std::string>("activation_type");
      if (!act_type.empty()) return false;
      if (act_type_ == "gelu") {
        bool Attr_approx = match_ctx.Attr<bool>("approximate");
        if (Attr_approx) return false;
      }
      return true;
    });

    paddle::drr::ResultPattern res = pat.ResultPattern();
    std::unordered_map<std::string, paddle::drr::Attribute> fused_attrs{
        {"in_num_col_dims", pat.Attr("in_num_col_dims")},
        {"activation_type", res.StrAttr(act_type_)},
        {"padding_weights", pat.Attr("padding_weights")},
    };
    const auto &gemm_epilogue_with_act = res.Op(fused_op_name_, fused_attrs);
    gemm_epilogue_with_act(
        {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")},
        {&res.Tensor("act_out")});
  }
};

class MatmulAddActFusePass : public pir::PatternRewritePass {
 public:
  MatmulAddActFusePass()
      : pir::PatternRewritePass("matmul_add_act_fuse_pass", 2) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);

    bool use_cutlass = false;
    if (Has(std::string("use_cutlass"))) {
      use_cutlass = Get<bool>(std::string("use_cutlass"));
    }
    if (use_cutlass) {
      /// MatmulAddPattern
      ps.Add(paddle::drr::Create<MatmulAddPattern>(
          context, paddle::dialect::GemmEpilogueOp::name(), true));
      ps.Add(paddle::drr::Create<MatmulAddPattern>(
          context, paddle::dialect::GemmEpilogueOp::name(), false));
      /// MatmulAddActPattern
      for (const auto &act_op : act_ops) {
        ps.Add(paddle::drr::Create<MatmulAddActPattern>(
            context, act_op, paddle::dialect::GemmEpilogueOp::name()));
      }
    } else {
      /// MatmulAddPattern
      ps.Add(paddle::drr::Create<MatmulAddPattern>(
          context, paddle::dialect::FcOp::name(), false));
      /// MatmulAddActPattern
      ps.Add(paddle::drr::Create<MatmulAddActPattern>(
          context, "relu", paddle::dialect::FcOp::name()));
    }
    return ps;
  }
};

}  // namespace

namespace pir {

std::unique_ptr<Pass> CreateMatmulAddActFusePass() {
  return std::make_unique<MatmulAddActFusePass>();
}

}  // namespace pir

REGISTER_IR_PASS(matmul_add_act_fuse_pass, MatmulAddActFusePass);
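
For orientation, here is a minimal usage sketch (not part of this diff) of how the pass above might be created and run. The "use_cutlass" attribute name matches the Has()/Get() lookup in InitializePatterns; the driver function and the pir::Pass attribute Set() call are assumptions about the surrounding API:

#include <memory>
#include <utility>

#include "paddle/fluid/pir/transforms/gpu/matmul_add_act_fuse_pass.h"
#include "paddle/pir/include/pass/pass_manager.h"

// Sketch: build a pass manager, enable the CUTLASS path, run the pass.
void ApplyMatmulAddActFuse(pir::IrContext *ctx, pir::Program *program) {
  pir::PassManager pm(ctx);
  auto pass = pir::CreateMatmulAddActFusePass();
  // true selects the GemmEpilogueOp (CUTLASS) patterns; false falls back
  // to FcOp. Set() is assumed to mirror the Has()/Get() attribute API
  // used inside InitializePatterns above.
  pass->Set("use_cutlass", new bool(true));
  pm.AddPass(std::move(pass));
  pm.Run(program);
}

Note that when use_cutlass is false (the default inside the pass), only the non-reversed MatmulAddPattern and the relu MatmulAddActPattern are registered, matching the cuBLAS-backed FcOp.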
Review comment: I see that gemm_epilogue supports more activation functions than this; why does the pass support only two?
Reply: Just these two for now; as for whether to extend it later, I may still need to ask 康康哥.
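
For illustration, a hypothetical sketch of what extending the pass could look like; on the pass side it would only touch the two tables at the top of the file. The sigmoid entries below are assumptions (they presuppose both the paddle::dialect op and a matching epilogue in the CUTLASS GemmEpilogueOp kernels):

// Hypothetical extension, not part of this PR: each added activation needs
// an entry in both tables plus kernel-side epilogue support.
std::set<std::string> act_ops = {
    "gelu",
    "relu",
    "sigmoid",  // assumed: a CUTLASS epilogue exists for sigmoid
};
std::unordered_map<std::string, std::string> activation_type = {
    {"gelu", paddle::dialect::GeluOp::name()},
    {"relu", paddle::dialect::ReluOp::name()},
    {"sigmoid", paddle::dialect::SigmoidOp::name()},  // assumed dialect op
};

Activations that carry attributes would also need a branch analogous to the gelu "approximate" handling in MatmulAddActPattern.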