From a5db941040eaf0086669922731f24f4c1b0945ce Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Fri, 12 Jan 2024 08:21:38 +0000 Subject: [PATCH 1/5] reconstruct drr --- .../add_broadcast_to_elementwise_pass.cc | 1 - .../transforms/fully_insert_broadcast_pass.cc | 1 - ...e_shape_ops_into_generate_shape_op_pass.cc | 1 - .../operator/transforms/pd_to_cinn_pass.cc | 21 ++++++--- ...plit_generate_shape_into_shape_ops_pass.cc | 1 - paddle/fluid/pir/drr/README.md | 22 ++++++--- paddle/fluid/pir/drr/README_cn.md | 16 ++++--- paddle/fluid/pir/drr/api/drr_pattern_base.h | 36 +++++++++------ .../fluid/pir/drr/api/drr_pattern_context.h | 12 ++--- .../pir/drr/{ => api}/drr_rewrite_pattern.h | 37 +++++---------- paddle/fluid/pir/drr/{ => api}/ir_operation.h | 4 +- paddle/fluid/pir/drr/{ => api}/ir_value.h | 4 +- paddle/fluid/pir/drr/api/match_context.h | 1 - paddle/fluid/pir/drr/api/tensor_interface.cc | 2 +- .../pir/drr/{api => }/drr_pattern_context.cc | 12 +---- paddle/fluid/pir/drr/drr_rewrite_pattern.cc | 46 ++++++++++++++++++- .../fluid/pir/drr/{api => }/match_context.cc | 2 +- paddle/fluid/pir/drr/match_context_impl.h | 4 +- .../transforms/fusion/attention_fuse_pass.cc | 7 ++- .../transforms/fusion/conv2d_add_fuse_pass.cc | 5 +- .../fc_elementwise_layernorm_fuse_pass.cc | 14 ++++-- .../pir/transforms/fusion/fc_fuse_pass.cc | 9 ++-- .../fused_dot_product_attention_pass.cc | 30 ++++++++---- .../fusion/fused_dropout_add_pass.cc | 14 ++++-- .../fusion/fused_gemm_epilogue_pass.cc | 34 +++++++++----- .../fused_linear_param_grad_add_pass.cc | 40 ++++++++++------ .../fusion/fused_weight_only_linear_pass.cc | 7 ++- .../fusion/matmul_scale_fuse_pass.cc | 5 +- .../pir/transforms/identity_op_clean_pass.cc | 44 ++++++++++++------ paddle/pir/pattern_rewrite/pattern_match.h | 2 +- .../drr_same_type_binding_test.cc | 13 ++++-- test/cpp/pir/pattern_rewrite/drr_test.cc | 35 ++++++++++---- 32 files changed, 312 insertions(+), 170 deletions(-) rename paddle/fluid/pir/drr/{ => api}/drr_rewrite_pattern.h (75%) rename paddle/fluid/pir/drr/{ => api}/ir_operation.h (95%) rename paddle/fluid/pir/drr/{ => api}/ir_value.h (98%) rename paddle/fluid/pir/drr/{api => }/drr_pattern_context.cc (93%) rename paddle/fluid/pir/drr/{api => }/match_context.cc (97%) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc index a887b035852a3..185800c623ffc 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc @@ -20,7 +20,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pattern_rewrite/pattern_applicator.h" diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc index e5347281e009a..79eee23eedc61 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fully_insert_broadcast_pass.cc @@ -21,7 +21,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" 
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pattern_rewrite/pattern_applicator.h" diff --git a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc index 9a9057e993be7..824749968fe13 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/fuse_shape_ops_into_generate_shape_op_pass.cc @@ -25,7 +25,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/dialect/shape/utils/shape_utils.h" #include "paddle/pir/pass/pass.h" diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 8c2becde5d990..45c7c3900b166 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -20,7 +20,6 @@ #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/api/drr_pattern_base.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/pass.h" @@ -31,7 +30,7 @@ namespace cinn { namespace dialect { namespace ir { -class SumOpPattern : public paddle::drr::DrrPatternBase { +class SumOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -55,9 +54,11 @@ class SumOpPattern : public paddle::drr::DrrPatternBase { {"keep_dim", pattern.Attr("keep_dim")}}); res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0")); } + + std::string pattern_name() const override { return "SumOpPattern"; } }; -class MaxOpPattern : public paddle::drr::DrrPatternBase { +class MaxOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -80,9 +81,11 @@ class MaxOpPattern : public paddle::drr::DrrPatternBase { {"keep_dim", pattern.Attr("keep_dim")}}); res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } + + std::string pattern_name() const override { return "MaxOpPattern"; } }; -class MinOpPattern : public paddle::drr::DrrPatternBase { +class MinOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -105,9 +108,11 @@ class MinOpPattern : public paddle::drr::DrrPatternBase { {"keep_dim", pattern.Attr("keep_dim")}}); res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } + + std::string pattern_name() const override { return "MinOpPattern"; } }; -class ProdOpPattern : public paddle::drr::DrrPatternBase { +class ProdOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -130,6 +135,8 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase { {"keep_dim", pattern.Attr("keep_dim")}}); 
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } + + std::string pattern_name() const override { return "ProdOpPattern"; } }; class ScaleOpPattern : public pir::OpRewritePattern { @@ -586,7 +593,7 @@ class ExpandOpPattern } }; -class UniformOpPattern : public paddle::drr::DrrPatternBase { +class UniformOpPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -632,6 +639,8 @@ class UniformOpPattern : public paddle::drr::DrrPatternBase { {"diag_val", pattern.Attr("min_value")}}); res.Tensor("ret") = cinn_uniform(); } + + std::string pattern_name() const override { return "ProdOpPattern"; } }; PdOpToCinnOpPass::PdOpToCinnOpPass() diff --git a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc index cec66f7c70e2e..749e042bbf47b 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/split_generate_shape_into_shape_ops_pass.cc @@ -24,7 +24,6 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/match_context.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/dialect/shape/utils/dim_expr.h" #include "paddle/pir/pass/pass.h" diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 6fbac0756ae86..3a8e69584b68a 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -8,9 +8,8 @@ DRR can reduce the development cost of PASS, allowing developers to focus on pro Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows: ~~~ c++ -// 1. Inherit specialized template class from DrPatternBase -class RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase { +// 1. Inherit class from DrPatternBase +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { // 2. Overload operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. 
Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute @@ -32,6 +31,10 @@ class RemoveRedundentCastPattern res.Op(paddle::dialect::CastOp::name(), {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string pattern_name() const override { + return "RemoveRedundentCastPattern"; + } }; ~~~ @@ -165,7 +168,7 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 Example Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public paddle::drr::DrrPatternBase { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern @@ -193,13 +196,16 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase Full ~~~ c++ -class FoldExpandToConstantPattern - : public paddle::drr::DrrPatternBase { +class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Define SourcePattern @@ -225,6 +231,10 @@ class FoldExpandToConstantPattern {"dtype", pat.Attr("dtype_1")}, {"place", pat.Attr("place_1")}}); res.Tensor("ret") = full2(); + } + + std::string pattern_name() const override { + return "FoldExpandToConstantPattern"; } }; ~~~ diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index 1291bec2954c4..57cf8e23050a1 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -8,9 +8,8 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P 以消除冗余 CastOp 的 PASS 为例,使用 DRR 的代码开发示例如下: ~~~ c++ -// 1. 继承 DrrPatternBase 的特化模板类 -class RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase { +// 1. 继承 DrrPatternBase 类 +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { // 2. 重载 operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 3. 
使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern @@ -32,6 +31,8 @@ class RemoveRedundentCastPattern res.Op(paddle::dialect::CastOp::name(), {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string pattern_name() const override { return "RemoveRedundentCastPattern"; } }; ~~~ @@ -168,7 +169,7 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const ## 3 使用示例 Example 1: Matmul + Add -> FusedGemmEpilogue ~~~ c++ -class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern @@ -196,13 +197,14 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> Full ~~~ c++ -class FoldExpandToConstantPattern - : public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> { +class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // 定义 Source Pattern @@ -229,5 +231,7 @@ class FoldExpandToConstantPattern {"place", pat.Attr("place_1")}}); res.Tensor("ret") = full2(); } + + std::string pattern_name() const override { return "FoldExpandToConstantPattern"; } }; ~~~ diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/api/drr_pattern_base.h index 18252d536869f..f5060361e7289 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_base.h +++ b/paddle/fluid/pir/drr/api/drr_pattern_base.h @@ -14,28 +14,38 @@ #pragma once +#include <memory> +#include <string> + #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" +#include "paddle/fluid/pir/drr/api/drr_rewrite_pattern.h" +#include "paddle/fluid/pir/drr/api/match_context.h" + +namespace pir { +class IrContext; +} namespace paddle { namespace drr { -template <typename DrrPattern> +class DrrRewritePattern; +class DrrPatternContext; + class DrrPatternBase { public: virtual ~DrrPatternBase() = default; - // Define the Drr Pattern. - virtual void operator()(paddle::drr::DrrPatternContext* ctx) const = 0; - - std::unique_ptr<DrrRewritePattern> Build( - pir::IrContext* ir_context, pir::PatternBenefit benefit = 1) const { - DrrPatternContext drr_context; - this->operator()(&drr_context); - std::string pattern_name = pir::get_type_name<DrrPattern>(); - return std::make_unique<DrrRewritePattern>( - pattern_name, drr_context, ir_context, benefit); - } + // Define the drr pattern. + virtual void operator()(drr::DrrPatternContext* ctx) const = 0; + + // Give the drr pattern name. + virtual std::string pattern_name() const = 0; + + // Give the drr pattern benefit. + virtual uint32_t pattern_benefit() const { return 1; } + + // Build the Drr Pattern. + std::unique_ptr<DrrRewritePattern> Build(pir::IrContext* ir_context) const; }; } // namespace drr diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/api/drr_pattern_context.h index feb0e988aa882..bb864b85acc70 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/api/drr_pattern_context.h @@ -85,17 +85,17 @@ class TensorDataType { std::string tensor_name_; }; +using ConstraintFunction = std::function<bool(const MatchContext&)>; class Constraint { public: - explicit Constraint( - const std::function<bool(const MatchContext&)>& constrain_fn) + explicit Constraint(const ConstraintFunction& constrain_fn) : IsContextMatchConstraint_(constrain_fn) {} bool operator()(const MatchContext& match_context) const { return IsContextMatchConstraint_(match_context); } private: - std::function<bool(const MatchContext&)> IsContextMatchConstraint_; + ConstraintFunction IsContextMatchConstraint_; }; class DrrPatternContext { @@ -132,8 +132,7 @@ class DrrPatternContext { // void RequireEqual(const Attribute& first, const Attribute& second); void RequireEqual(const TensorShape& first, const TensorShape& second); void RequireEqual(const TensorDataType& first, const TensorDataType& second); - void RequireNativeCall( - const std::function<bool(const MatchContext&)>& custom_fn); + void RequireNativeCall(const ConstraintFunction& custom_fn); std::shared_ptr<SourcePatternGraph> source_pattern_graph_; std::vector<Constraint> constraints_; @@ -322,8 +321,7 @@ class SourcePattern { ctx_->RequireEqual(first, second); } - void RequireNativeCall( - const std::function<bool(const MatchContext&)>& custom_fn) { + void RequireNativeCall(const ConstraintFunction& custom_fn) { ctx_->RequireNativeCall(custom_fn); } diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/api/drr_rewrite_pattern.h similarity index 75% rename from paddle/fluid/pir/drr/drr_rewrite_pattern.h rename to paddle/fluid/pir/drr/api/drr_rewrite_pattern.h index 6163c6d9d0193..7a166f59013cc 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.h +++ b/paddle/fluid/pir/drr/api/drr_rewrite_pattern.h @@ -21,41 +21,28 @@ #include #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -#include "paddle/fluid/pir/drr/api/match_context.h" -#include "paddle/fluid/pir/drr/ir_operation.h" -#include "paddle/fluid/pir/drr/ir_operation_factory.h" -#include "paddle/fluid/pir/drr/match_context_impl.h" -#include "paddle/fluid/pir/drr/pattern_graph.h" -#include "paddle/phi/core/enforce.h" -#include "paddle/pir/core/operation.h" -#include "paddle/pir/core/type_name.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" +namespace pir { +class IrContext; +} namespace paddle { namespace drr { +class OpCall; +class Constraint; +class DrrPatternContext; +class MatchContextImpl; +class SourcePatternGraph; +class ResultPatternGraph; + class DrrRewritePattern : public pir::RewritePattern { public: explicit DrrRewritePattern(const std::string& pattern_name, const DrrPatternContext& drr_context, pir::IrContext* context, - pir::PatternBenefit benefit = 1) - : pir::RewritePattern( - drr_context.source_pattern_graph()->AnchorNode()->name(), - benefit, - context, - {}), - pattern_name_(pattern_name), - source_pattern_graph_(drr_context.source_pattern_graph()), - constraints_(drr_context.constraints()), - result_pattern_graph_(drr_context.result_pattern_graph()) { - PADDLE_ENFORCE_NE( - source_pattern_graph_->owned_op_call().empty(), - true, - phi::errors::InvalidArgument("Source pattern graph is empty."
- "Suggested fix: Please check the DRR " - "source pattern definition code.")); - } + pir::PatternBenefit benefit); bool MatchAndRewrite( pir::Operation* op, diff --git a/paddle/fluid/pir/drr/ir_operation.h b/paddle/fluid/pir/drr/api/ir_operation.h similarity index 95% rename from paddle/fluid/pir/drr/ir_operation.h rename to paddle/fluid/pir/drr/api/ir_operation.h index a88bb3bfff97c..b13b1b6c8395a 100644 --- a/paddle/fluid/pir/drr/ir_operation.h +++ b/paddle/fluid/pir/drr/api/ir_operation.h @@ -14,7 +14,9 @@ #pragma once -#include "paddle/pir/core/operation.h" +namespace pir { +class Operation; +} namespace paddle { namespace drr { diff --git a/paddle/fluid/pir/drr/ir_value.h b/paddle/fluid/pir/drr/api/ir_value.h similarity index 98% rename from paddle/fluid/pir/drr/ir_value.h rename to paddle/fluid/pir/drr/api/ir_value.h index ae99fd8c1964e..e21b610fc8739 100644 --- a/paddle/fluid/pir/drr/ir_value.h +++ b/paddle/fluid/pir/drr/api/ir_value.h @@ -35,7 +35,7 @@ class IrShape { int64_t at(int idx) const { return dims_.at(idx); } private: - const phi::DDim dims_; + const common::DDim dims_; }; class IrDtype { @@ -109,7 +109,5 @@ class IrValue : public TensorInterface { const IrDtype dtype_; }; -class IrAttr; - } // namespace drr } // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.h b/paddle/fluid/pir/drr/api/match_context.h index 762c86cf8a8e6..7821e9c73e365 100644 --- a/paddle/fluid/pir/drr/api/match_context.h +++ b/paddle/fluid/pir/drr/api/match_context.h @@ -18,7 +18,6 @@ #include #include "paddle/fluid/pir/drr/api/tensor_interface.h" -#include "paddle/fluid/pir/drr/ir_operation.h" namespace paddle { namespace drr { diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc index 335f95214887a..35be9de499750 100644 --- a/paddle/fluid/pir/drr/api/tensor_interface.cc +++ b/paddle/fluid/pir/drr/api/tensor_interface.cc @@ -13,7 +13,7 @@ // limitations under the License. #include "paddle/fluid/pir/drr/api/tensor_interface.h" -#include "paddle/fluid/pir/drr/ir_value.h" +#include "paddle/fluid/pir/drr/api/ir_value.h" namespace paddle { namespace drr { diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.cc b/paddle/fluid/pir/drr/drr_pattern_context.cc similarity index 93% rename from paddle/fluid/pir/drr/api/drr_pattern_context.cc rename to paddle/fluid/pir/drr/drr_pattern_context.cc index 7f98f0b34cbeb..65d72a9f58175 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.cc +++ b/paddle/fluid/pir/drr/drr_pattern_context.cc @@ -13,6 +13,7 @@ // limitations under the License. 
#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include #include "paddle/fluid/pir/drr/pattern_graph.h" #include "paddle/phi/core/enforce.h" @@ -60,14 +61,6 @@ std::vector DrrPatternContext::constraints() const { return constraints_; } -// void DrrPatternContext::RequireEqual(const Attribute& first, const Attribute& -// second) { -// auto constrain_fn = [&](const MatchContext& match_context) { -// return match_context.Attr(first.id()) == match_context.Attr(second.id()); -// }; -// constraints_.emplace_back(constrain_fn); -// } - void DrrPatternContext::RequireEqual(const TensorShape& first, const TensorShape& second) { // Note: we capture the datas by value for constrain_fn @@ -90,8 +83,7 @@ void DrrPatternContext::RequireEqual(const TensorDataType& first, constraints_.emplace_back(constrain_fn); } -void DrrPatternContext::RequireNativeCall( - const std::function& custom_fn) { +void DrrPatternContext::RequireNativeCall(const ConstraintFunction& custom_fn) { constraints_.emplace_back(custom_fn); } diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc index d408c1aab1349..9c32354932510 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc @@ -12,11 +12,46 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/drr/drr_rewrite_pattern.h" +// #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +// #include "paddle/fluid/pir/drr/api/ir_operation.h" +// #include "paddle/fluid/pir/drr/api/match_context.h" + +// #include "paddle/phi/core/enforce.h" +// #include "paddle/pir/core/operation.h" +// #include "paddle/pir/core/type_name.h" +// #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" + +#include "paddle/fluid/pir/drr/api/drr_rewrite_pattern.h" + +#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/ir_operation_factory.h" +#include "paddle/fluid/pir/drr/match_context_impl.h" +#include "paddle/fluid/pir/drr/pattern_graph.h" namespace paddle { namespace drr { +DrrRewritePattern::DrrRewritePattern(const std::string& pattern_name, + const DrrPatternContext& drr_context, + pir::IrContext* context, + pir::PatternBenefit benefit) + : pir::RewritePattern( + drr_context.source_pattern_graph()->AnchorNode()->name(), + benefit, + context, + {}), + pattern_name_(pattern_name), + source_pattern_graph_(drr_context.source_pattern_graph()), + constraints_(drr_context.constraints()), + result_pattern_graph_(drr_context.result_pattern_graph()) { + PADDLE_ENFORCE_NE( + source_pattern_graph_->owned_op_call().empty(), + true, + phi::errors::InvalidArgument("Source pattern graph is empty." 
+ "Suggested fix: Please check the DRR " + "source pattern definition code.")); +} + bool DrrRewritePattern::MatchAndRewrite( pir::Operation* op, pir::PatternRewriter& rewriter) const { // NOLINT @@ -25,6 +60,7 @@ bool DrrRewritePattern::MatchAndRewrite( if (PatternGraphMatch(op, src_match_ctx.get())) { VLOG(4) << "DRR pattern (" << pattern_name_ << ") is matched in program."; PatternGraphRewrite(*src_match_ctx, rewriter); + VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewrited in program."; return true; } return false; @@ -516,5 +552,13 @@ void DrrRewritePattern::DeleteSourcePatternOp( } } +std::unique_ptr DrrPatternBase::Build( + pir::IrContext* ir_context) const { + DrrPatternContext drr_context; + this->operator()(&drr_context); + return std::make_unique( + pattern_name(), drr_context, ir_context, pattern_benefit()); +} + } // namespace drr } // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.cc b/paddle/fluid/pir/drr/match_context.cc similarity index 97% rename from paddle/fluid/pir/drr/api/match_context.cc rename to paddle/fluid/pir/drr/match_context.cc index e5f15adf72e75..c171482ed2b0a 100644 --- a/paddle/fluid/pir/drr/api/match_context.cc +++ b/paddle/fluid/pir/drr/match_context.cc @@ -16,7 +16,7 @@ #include -#include "paddle/fluid/pir/drr/ir_operation.h" +#include "paddle/fluid/pir/drr/api/ir_operation.h" #include "paddle/fluid/pir/drr/match_context_impl.h" namespace paddle { diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h index b1234d8129936..21cca6ead2c3a 100644 --- a/paddle/fluid/pir/drr/match_context_impl.h +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -19,10 +19,10 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/api/ir_operation.h" +#include "paddle/fluid/pir/drr/api/ir_value.h" #include "paddle/fluid/pir/drr/api/tensor_interface.h" #include "paddle/fluid/pir/drr/attr_type_uilts.h" -#include "paddle/fluid/pir/drr/ir_operation.h" -#include "paddle/fluid/pir/drr/ir_value.h" #include "paddle/pir/core/builtin_attribute.h" namespace paddle { diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index ab19247de4b26..acebc1d854d99 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -21,8 +21,7 @@ namespace { -class MultiHeadMatmulFusePattern - : public paddle::drr::DrrPatternBase { +class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // @@ -216,6 +215,10 @@ class MultiHeadMatmulFusePattern &res.Tensor("add_4_in_2")}, {&res.Tensor("reshape_4_out")}); } + + std::string pattern_name() const override { + return "MultiHeadMatmulFusePattern"; + } }; class AttentionFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index e86dc04037fa0..39b5d7cf087df 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -28,8 +28,7 @@ namespace { -class Conv2dAddFusePattern - : public paddle::drr::DrrPatternBase { +class Conv2dAddFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const 
override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -83,6 +82,8 @@ class Conv2dAddFusePattern &res.NoneTensor()}, {&res.Tensor("add_out")}); } + + std::string pattern_name() const override { return "Conv2dAddFusePattern"; } }; class Conv2dAddFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index 7e5c4bbe8ea18..c0ca75fcb5a8f 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -21,8 +21,7 @@ namespace { -class FcElementwiseLayerNormFusePattern - : public paddle::drr::DrrPatternBase { +class FcElementwiseLayerNormFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -89,10 +88,13 @@ class FcElementwiseLayerNormFusePattern &res.Tensor("layernorm_mean"), &res.Tensor("layernorm_variance")}); } + + std::string pattern_name() const override { + return "FcElementwiseLayerNormFusePattern"; + } }; -class FcElementwiseLayerNormFuse2Pattern - : public paddle::drr::DrrPatternBase { +class FcElementwiseLayerNormFuse2Pattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -150,6 +152,10 @@ class FcElementwiseLayerNormFuse2Pattern &res.Tensor("layernorm_mean"), &res.Tensor("layernorm_variance")}); } + + std::string pattern_name() const override { + return "FcElementwiseLayerNormFuse2Pattern"; + } }; class FcElementwiseLayerNormFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index b49ab9ff4ac77..749a54383d95a 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -21,7 +21,7 @@ namespace { -class MatmulAddPattern : public paddle::drr::DrrPatternBase { +class MatmulAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -79,10 +79,11 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase { fc({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")}, {&res.Tensor("add_out")}); } + + std::string pattern_name() const override { return "MatmulAddPattern"; } }; -class FcWithReluPattern - : public paddle::drr::DrrPatternBase { +class FcWithReluPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -117,6 +118,8 @@ class FcWithReluPattern fc_with_relu({&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("y")}, {&res.Tensor("relu_out")}); } + + std::string pattern_name() const override { return "FcWithReluPattern"; } }; class FcFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index 9b2e7f2f3f2e7..c935b66d2fe3d 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -21,8 +21,7 @@ namespace 
{ -class FusedDotProductAttentionPattern - : public paddle::drr::DrrPatternBase { +class FusedDotProductAttentionPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -137,10 +136,13 @@ class FusedDotProductAttentionPattern &res.Tensor("softmax_aux"), &res.Tensor("rng_state")}); } + + std::string pattern_name() const override { + return "FusedDotProductAttentionPattern"; + } }; -class FusedDotProductAttentionGradPattern - : public paddle::drr::DrrPatternBase { +class FusedDotProductAttentionGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -314,11 +316,14 @@ class FusedDotProductAttentionGradPattern &res.Tensor("out_grad")}, {&res.Tensor("q_grad"), &res.Tensor("k_grad"), &res.Tensor("v_grad")}); } + + std::string pattern_name() const override { + return "FusedDotProductAttentionGradPattern"; + } }; class FusedDotProductAttentionWithDropoutPattern - : public paddle::drr::DrrPatternBase< - FusedDotProductAttentionWithDropoutPattern> { + : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -441,11 +446,14 @@ class FusedDotProductAttentionWithDropoutPattern &res.Tensor("softmax_aux"), &res.Tensor("rng_state")}); } + + std::string pattern_name() const override { + return "FusedDotProductAttentionWithDropoutPattern"; + } }; class FusedDotProductAttentionGradWithDropoutPattern - : public paddle::drr::DrrPatternBase< - FusedDotProductAttentionGradWithDropoutPattern> { + : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -629,12 +637,16 @@ class FusedDotProductAttentionGradWithDropoutPattern &res.Tensor("out_grad")}, {&res.Tensor("q_grad"), &res.Tensor("k_grad"), &res.Tensor("v_grad")}); } + + std::string pattern_name() const override { + return "FusedDotProductAttentionGradWithDropoutPattern"; + } }; class FusedDotProductAttentionPass : public pir::PatternRewritePass { public: FusedDotProductAttentionPass() - : pir::PatternRewritePass("fused_dot_product_attention_pass", 1) {} + : pir::PatternRewritePass("fused_dot_product_attention_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc index df8b39cfc8676..f728b888d2461 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -21,8 +21,7 @@ namespace { -class FusedDropoutAddPattern - : public paddle::drr::DrrPatternBase { +class FusedDropoutAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -50,10 +49,11 @@ class FusedDropoutAddPattern {&res.Tensor("x"), &res.Tensor("y"), &res.Tensor("seed_tensor")}, {&res.Tensor("add_out"), &res.Tensor("mask")}); } + + std::string pattern_name() const override { return "FusedDropoutAddPattern"; } }; -class FusedDropoutGradAddGradPattern - : public paddle::drr::DrrPatternBase { +class 
FusedDropoutGradAddGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -103,12 +103,16 @@ class FusedDropoutGradAddGradPattern fused_dropout_add_grad({&res.Tensor("mask"), &res.Tensor("add_out_grad")}, {&res.Tensor("x_grad"), &res.Tensor("y_grad")}); } + + std::string pattern_name() const override { + return "FusedDropoutGradAddGradPattern"; + } }; class FusedDropoutAddPass : public pir::PatternRewritePass { public: FusedDropoutAddPass() - : pir::PatternRewritePass("fused_dropout_add_pass", 1) {} + : pir::PatternRewritePass("fused_dropout_add_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 02a6b4744cdcb..36404dae16381 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -21,8 +21,7 @@ namespace { -class FusedLinearPattern - : public paddle::drr::DrrPatternBase<FusedLinearPattern> { +class FusedLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -54,10 +53,11 @@ class FusedLinearPattern {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out")}); } + + std::string pattern_name() const override { return "FusedLinearPattern"; } }; -class FusedLinearGradPattern - : public paddle::drr::DrrPatternBase<FusedLinearGradPattern> { +class FusedLinearGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -109,10 +109,11 @@ class FusedLinearGradPattern &res.Tensor("w_grad"), &res.Tensor("bias_grad")}); } + + std::string pattern_name() const override { return "FusedLinearGradPattern"; } }; -class FusedLinearGeluPattern - : public paddle::drr::DrrPatternBase<FusedLinearGeluPattern> { +class FusedLinearGeluPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -148,9 +149,10 @@ class FusedLinearGeluPattern {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out"), &res.Tensor("reserve_space")}); } + + std::string pattern_name() const override { return "FusedLinearGeluPattern"; } }; -class FusedLinearReluPattern - : public paddle::drr::DrrPatternBase<FusedLinearReluPattern> { +class FusedLinearReluPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -186,10 +188,11 @@ class FusedLinearReluPattern {&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")}, {&res.Tensor("out"), &res.Tensor("reserve_space")}); } + + std::string pattern_name() const override { return "FusedLinearReluPattern"; } }; -class FusedLinearGeluGradPattern - : public paddle::drr::DrrPatternBase<FusedLinearGeluGradPattern> { +class FusedLinearGeluGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -254,10 +257,13 @@ class FusedLinearGeluGradPattern &res.Tensor("w1_grad"), &res.Tensor("bias1_grad")}); } + +
std::string pattern_name() const override { + return "FusedLinearGeluGradPattern"; + } }; -class FusedLinearReluGradPattern - : public paddle::drr::DrrPatternBase { +class FusedLinearReluGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -322,6 +328,10 @@ class FusedLinearReluGradPattern &res.Tensor("w1_grad"), &res.Tensor("bias1_grad")}); } + + std::string pattern_name() const override { + return "FusedLinearReluGradPattern"; + } }; class FusedGemmEpiloguePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index 8c93ff9822675..737fdf5b3f396 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -21,8 +21,7 @@ namespace { // add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add -class FusedMatmulAddGradAddPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddGradAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -89,11 +88,14 @@ class FusedMatmulAddGradAddPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias")}); } + + std::string pattern_name() const override { + return "FusedMatmulAddGradAddPattern"; + } }; // matmul_grad + add_ -> matmul + fused_liner_param_gard_add -class FusedMatmulGradAddPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulGradAddPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -149,11 +151,14 @@ class FusedMatmulGradAddPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } + + std::string pattern_name() const override { + return "FusedMatmulGradAddPattern"; + } }; // matmul + 0 = add_(0,1) -> fused_liner_param_gard_add -class FusedMatmulAddaPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddaPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -198,11 +203,12 @@ class FusedMatmulAddaPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } + + std::string pattern_name() const override { return "FusedMatmulAddaPattern"; } }; // matmul + 1 = add_(1,0) -> fused_liner_param_gard_add -class FusedMatmulAddbPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddbPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -247,11 +253,12 @@ class FusedMatmulAddbPattern &res.NoneTensor()}, {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } + + std::string pattern_name() const override { return "FusedMatmulAddbPattern"; } }; // add_grad + matmul + 0 = add_(0,1) -> fused_liner_param_gard_add -class FusedMatmulAddGradAddaPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddGradAddaPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { 
paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -304,11 +311,14 @@ class FusedMatmulAddGradAddaPattern &res.NoneTensor()}, {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); } + + std::string pattern_name() const override { + return "FusedMatmulAddGradAddaPattern"; + } }; // add_grad + matmul + 1 = add_(1,0) -> fused_liner_param_gard_add -class FusedMatmulAddGradAddbPattern - : public paddle::drr::DrrPatternBase { +class FusedMatmulAddGradAddbPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -361,12 +371,16 @@ class FusedMatmulAddGradAddbPattern &res.NoneTensor()}, {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); } + + std::string pattern_name() const override { + return "FusedMatmulAddGradAddbPattern"; + } }; class FusedLinearParamGradAddPass : public pir::PatternRewritePass { public: FusedLinearParamGradAddPass() - : pir::PatternRewritePass("fused_linear_param_grad_add_pass", 1) {} + : pir::PatternRewritePass("fused_linear_param_grad_add_pass", 2) {} pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index 82864f3d80e88..d324e1f2ece23 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -35,8 +35,7 @@ int getSMVersion() { return sm_version; } -class FusedWeightOnlyLinearPattern - : public paddle::drr::DrrPatternBase { +class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // @@ -126,6 +125,10 @@ class FusedWeightOnlyLinearPattern &res.Tensor("weight_scale_tensor")}, {&res.Tensor("add_out")}); } + + std::string pattern_name() const override { + return "FusedWeightOnlyLinearPattern"; + } }; class FusedWeightOnlyLinearPass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index 0bced0b8ec823..4dbb6cab4b5bc 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -27,8 +27,7 @@ namespace { -class MatmulScaleFusePattern - : public paddle::drr::DrrPatternBase { +class MatmulScaleFusePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -76,6 +75,8 @@ class MatmulScaleFusePattern matmul_op_res({&res.Tensor("x"), &res.Tensor("scale_res_out")}, {&res.Tensor("scale_out")}); } + + std::string pattern_name() const override { return "MatmulScaleFusePattern"; } }; class MatmulScaleFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index ac49d494d1c73..2e2e1845d2b9b 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/api/drr_pattern_base.h" -#include 
"paddle/fluid/pir/drr/ir_value.h" +#include "paddle/fluid/pir/drr/api/ir_value.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" @@ -31,8 +31,7 @@ namespace { -class RemoveUselessScalePattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessScalePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -55,10 +54,13 @@ class RemoveUselessScalePattern paddle::drr::ResultPattern res = pat.ResultPattern(); res.Tensor("scale_out").Assign(res.Tensor("x")); } + + std::string pattern_name() const override { + return "RemoveUselessScalePattern"; + } }; -class RemoveRedundentScalePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -126,10 +128,13 @@ class RemoveRedundentScalePattern scale_op_res({&res.Tensor("x"), &full_op_res()}, {&res.Tensor("scale_2_out")}); } + + std::string pattern_name() const override { + return "RemoveRedundentScalePattern"; + } }; -class RemoveUselessCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -138,10 +143,13 @@ class RemoveUselessCastPattern auto res = pat.ResultPattern(); res.Tensor("ret").Assign(res.Tensor("arg0")); } + + std::string pattern_name() const override { + return "RemoveUselessCastPattern"; + } }; -class RemoveUselessConcatPattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -158,10 +166,13 @@ class RemoveUselessConcatPattern auto res = pat.ResultPattern(); res.Tensor("out").Assign(res.Tensor("x")); } + + std::string pattern_name() const override { + return "RemoveUselessConcatPattern"; + } }; -class RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("tmp") = pat.Op( @@ -172,10 +183,13 @@ class RemoveRedundentCastPattern res.Tensor("ret") = res.Op( "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string pattern_name() const override { + return "RemoveRedundentCastPattern"; + } }; -class RemoveRedundentTransposePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -202,6 +216,10 @@ class RemoveRedundentTransposePattern res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); } + + std::string pattern_name() const override { + return "RemoveRedundentTransposePattern"; + } }; class IdentityOpCleanPass : public pir::PatternRewritePass { diff --git a/paddle/pir/pattern_rewrite/pattern_match.h b/paddle/pir/pattern_rewrite/pattern_match.h index a0c34d8f58f07..475779f99cb28 100644 --- 
a/paddle/pir/pattern_rewrite/pattern_match.h +++ b/paddle/pir/pattern_rewrite/pattern_match.h @@ -37,7 +37,7 @@ namespace pir { // This class reprensents the benefit of a pattern. The most common -// unit to use is the `numver of operations` in the pattern. +// unit to use is the `number of operations` in the pattern. class IR_API PatternBenefit { public: PatternBenefit() = default; diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc index 1a938e7f600b7..6cc0a16611904 100644 --- a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -49,11 +49,10 @@ output0 output1 output2 output3 output4 output5 output6 */ -class SameTypeBindingTestPattern - // This class is for test cases of the same type of OP. - // (without considering the computational logic between OPs, - // only focusing on the process of matching and replacing) - : public paddle::drr::DrrPatternBase { +// This class is for test cases of the same type of OP. +// (without considering the computational logic between OPs, +// only focusing on the process of matching and replacing) +class SameTypeBindingTestPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern src = ctx->SourcePattern(); @@ -179,6 +178,10 @@ class SameTypeBindingTestPattern res.Tensor("output5") = full_5(); res.Tensor("output6") = full_6(); } + + std::string pattern_name() const override { + return "SameTypeBindingTestPattern"; + } }; void BuildProgram(pir::Builder &builder) { // NOLINT diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc index 54b5ff2025e49..f2f032a0181df 100644 --- a/test/cpp/pir/pattern_rewrite/drr_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -23,8 +23,7 @@ #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass_manager.h" -class RemoveRedundentReshapePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentReshapePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source patterns @@ -42,10 +41,13 @@ class RemoveRedundentReshapePattern res.Op("pd_op.reshape")({&res.Tensor("arg0"), &res.Tensor("shape1")}, {&res.Tensor("ret"), &res.Tensor("xshape_1")}); } + + std::string pattern_name() const override { + return "RemoveRedundentReshapePattern"; + } }; -class FoldExpandToConstantPattern - : public paddle::drr::DrrPatternBase { +class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { // Source Pattern @@ -79,10 +81,13 @@ class FoldExpandToConstantPattern {"place", pat.Attr("place_1")}}); res.Tensor("ret") = full2(); } + + std::string pattern_name() const override { + return "FoldExpandToConstantPattern"; + } }; -class RemoveRedundentTransposePattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -109,10 +114,13 @@ class RemoveRedundentTransposePattern res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); } + + std::string pattern_name() const override { + return "RemoveRedundentTransposePattern"; + } }; -class 
RemoveRedundentCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); pat.Tensor("tmp") = pat.Op( @@ -123,10 +131,13 @@ class RemoveRedundentCastPattern res.Tensor("ret") = res.Op( "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } + + std::string pattern_name() const override { + return "RemoveRedundentCastPattern"; + } }; -class RemoveUselessCastPattern - : public paddle::drr::DrrPatternBase { +class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -135,6 +146,10 @@ class RemoveUselessCastPattern auto res = pat.ResultPattern(); res.Tensor("ret").Assign(res.Tensor("arg0")); } + + std::string pattern_name() const override { + return "RemoveUselessCastPattern"; + } }; void BuildProgram(pir::Builder &builder) { // NOLINT From 078fa99ed8746d27c92e0a086bee2741a91c36f1 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Fri, 12 Jan 2024 11:44:22 +0000 Subject: [PATCH 2/5] reconstruct drr v2 --- .../operator/transforms/pd_to_cinn_pass.cc | 2 +- paddle/fluid/pir/drr/CMakeLists.txt | 2 +- paddle/fluid/pir/drr/api/ir_operation.h | 35 ------ paddle/fluid/pir/drr/api/ir_value.h | 113 ------------------ paddle/fluid/pir/drr/api/tensor_interface.cc | 36 ------ paddle/fluid/pir/drr/api/tensor_interface.h | 63 ---------- .../drr_match_context.h} | 6 +- .../drr/{api => include}/drr_pattern_base.h | 6 +- .../{api => include}/drr_pattern_context.h | 4 +- .../{api => include}/drr_rewrite_pattern.h | 2 +- paddle/fluid/pir/drr/ir_operation_factory.cc | 8 +- paddle/fluid/pir/drr/ir_operation_factory.h | 2 +- paddle/fluid/pir/drr/match_context.cc | 7 +- paddle/fluid/pir/drr/match_context_impl.h | 40 +++---- ..._pattern_context.cc => pattern_context.cc} | 13 +- paddle/fluid/pir/drr/pattern_graph.cc | 2 +- ..._rewrite_pattern.cc => rewrite_pattern.cc} | 48 +++----- .../transforms/fusion/attention_fuse_pass.cc | 6 +- .../transforms/fusion/conv2d_add_fuse_pass.cc | 2 +- .../fc_elementwise_layernorm_fuse_pass.cc | 19 +-- .../pir/transforms/fusion/fc_fuse_pass.cc | 27 ++--- .../fused_dot_product_attention_pass.cc | 2 +- .../fusion/fused_dropout_add_pass.cc | 2 +- .../fusion/fused_gemm_epilogue_pass.cc | 23 ++-- .../fused_linear_param_grad_add_pass.cc | 81 ++++++++----- .../fusion/fused_weight_only_linear_pass.cc | 15 +-- .../fusion/matmul_scale_fuse_pass.cc | 2 +- .../pir/transforms/identity_op_clean_pass.cc | 10 +- test/cpp/pir/cinn/dialect_convert_test.cc | 2 +- .../drr_same_type_binding_test.cc | 2 +- test/cpp/pir/pattern_rewrite/drr_test.cc | 2 +- 31 files changed, 171 insertions(+), 413 deletions(-) delete mode 100644 paddle/fluid/pir/drr/api/ir_operation.h delete mode 100644 paddle/fluid/pir/drr/api/ir_value.h delete mode 100644 paddle/fluid/pir/drr/api/tensor_interface.cc delete mode 100644 paddle/fluid/pir/drr/api/tensor_interface.h rename paddle/fluid/pir/drr/{api/match_context.h => include/drr_match_context.h} (89%) rename paddle/fluid/pir/drr/{api => include}/drr_pattern_base.h (87%) rename paddle/fluid/pir/drr/{api => include}/drr_pattern_context.h (99%) rename paddle/fluid/pir/drr/{api => include}/drr_rewrite_pattern.h (98%) rename paddle/fluid/pir/drr/{drr_pattern_context.cc => pattern_context.cc} (92%) rename paddle/fluid/pir/drr/{drr_rewrite_pattern.cc => rewrite_pattern.cc} (93%) diff 
--git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 45c7c3900b166..3c503be702410 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -19,7 +19,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h" #include "paddle/cinn/hlir/framework/pir/utils.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/pass.h" diff --git a/paddle/fluid/pir/drr/CMakeLists.txt b/paddle/fluid/pir/drr/CMakeLists.txt index fa43d828d05bc..d35693c674c61 100644 --- a/paddle/fluid/pir/drr/CMakeLists.txt +++ b/paddle/fluid/pir/drr/CMakeLists.txt @@ -1,4 +1,4 @@ -file(GLOB DRR_SRCS "*.cc" "api/*.cc") +file(GLOB DRR_SRCS "*.cc" "include/*.cc") set(op_creator_gen_file ${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py diff --git a/paddle/fluid/pir/drr/api/ir_operation.h b/paddle/fluid/pir/drr/api/ir_operation.h deleted file mode 100644 index b13b1b6c8395a..0000000000000 --- a/paddle/fluid/pir/drr/api/ir_operation.h +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -namespace pir { -class Operation; -} - -namespace paddle { -namespace drr { - -class IrOperation { - public: - explicit IrOperation(pir::Operation* op) : op_(op) {} - - pir::Operation* get() const { return op_; } - - private: - pir::Operation* op_; -}; - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/ir_value.h b/paddle/fluid/pir/drr/api/ir_value.h deleted file mode 100644 index e21b610fc8739..0000000000000 --- a/paddle/fluid/pir/drr/api/ir_value.h +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#pragma once - -#include - -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" -#include "paddle/fluid/pir/drr/api/tensor_interface.h" -#include "paddle/pir/core/type.h" -#include "paddle/pir/core/value.h" - -namespace paddle { -namespace drr { - -class IrShape { - public: - explicit IrShape(const phi::DDim& dims) : dims_(dims) {} - - bool operator==(const IrShape& other) const { return dims_ == other.dims_; } - - int size() const { return dims_.size(); } - - int64_t at(int idx) const { return dims_.at(idx); } - - private: - const common::DDim dims_; -}; - -class IrDtype { - public: - explicit IrDtype(pir::Type dtype) : dtype_(dtype) {} - - bool operator==(IrDtype other) const { return dtype_ == other.dtype_; } - - template - bool isa() const { - return dtype_.isa(); - } - - template - T dyn_cast() const { - return dtype_.dyn_cast(); - } - - private: - const pir::Type dtype_; -}; - -class IrValue : public TensorInterface { - public: - explicit IrValue(const pir::Value& value) - : value_(value), - shape_((value && value.type() && - value.type().dyn_cast()) - ? value.type() - .dyn_cast() - .dims() - : phi::DDim{}), - dtype_((value && value.type() && - value.type().dyn_cast()) - ? value.type() - .dyn_cast() - .dtype() - : pir::Type{}) {} - - ShapeInterface Shape() const override { return ShapeInterface(&shape_); } - DtypeInterface Dtype() const override { return DtypeInterface(&dtype_); } - - explicit operator bool() const { return value_.operator bool(); } - - template - bool isa() const { - return value_.isa(); - } - - template - T dyn_cast() const { - return value_.dyn_cast(); - } - - template - bool type_isa() const { - return value_.type().isa(); - } - - template - T type_dyn_cast() const { - return value_.type().dyn_cast(); - } - - // Don't use it in drr pass! - const pir::Value& get() const { return value_; } - - private: - const pir::Value value_; - const IrShape shape_; - const IrDtype dtype_; -}; - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/tensor_interface.cc b/paddle/fluid/pir/drr/api/tensor_interface.cc deleted file mode 100644 index 35be9de499750..0000000000000 --- a/paddle/fluid/pir/drr/api/tensor_interface.cc +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
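With ir_value.h deleted, the null-guarded shape/dtype extraction that IrValue's constructor performed moves to call sites. A standalone sketch of that logic, mirroring the deleted constructor; the helper name GetDenseTensorDims is ours, not part of the patch:

```cpp
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/pir/core/value.h"

// What IrValue::Shape() used to wrap: guard against a null value/type, then
// read dims off the DenseTensorType; fall back to an empty DDim otherwise.
phi::DDim GetDenseTensorDims(pir::Value value) {
  if (value && value.type() &&
      value.type().isa<paddle::dialect::DenseTensorType>()) {
    return value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
  }
  return phi::DDim{};
}
```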
- -#include "paddle/fluid/pir/drr/api/tensor_interface.h" -#include "paddle/fluid/pir/drr/api/ir_value.h" - -namespace paddle { -namespace drr { - -bool ShapeInterface::operator==(const ShapeInterface& other) const { - return *shape_ == *other.shape_; -} - -int ShapeInterface::size() const { return shape_->size(); } - -int64_t ShapeInterface::at(int idx) const { return shape_->at(idx); } - -bool DtypeInterface::operator==(const DtypeInterface& other) const { - return *dtype_ == *other.dtype_; -} - -IrDtype DtypeInterface::get() const { return *(this->dtype_); } - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/tensor_interface.h b/paddle/fluid/pir/drr/api/tensor_interface.h deleted file mode 100644 index 24774f00d5a29..0000000000000 --- a/paddle/fluid/pir/drr/api/tensor_interface.h +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include - -namespace paddle { -namespace drr { - -class IrValue; -class IrShape; -class IrDtype; - -class ShapeInterface final { - public: - bool operator==(const ShapeInterface& other) const; - - int size() const; - - int64_t at(int idx) const; - - private: - explicit ShapeInterface(const IrShape* shape) : shape_(shape) {} - - friend class IrValue; - - const IrShape* shape_; -}; - -class DtypeInterface final { - public: - bool operator==(const DtypeInterface& other) const; - - IrDtype get() const; - - private: - explicit DtypeInterface(const IrDtype* dtype) : dtype_(dtype) {} - - friend class IrValue; - - const IrDtype* dtype_; -}; - -class TensorInterface { - public: - virtual ShapeInterface Shape() const = 0; - virtual DtypeInterface Dtype() const = 0; -}; - -} // namespace drr -} // namespace paddle diff --git a/paddle/fluid/pir/drr/api/match_context.h b/paddle/fluid/pir/drr/include/drr_match_context.h similarity index 89% rename from paddle/fluid/pir/drr/api/match_context.h rename to paddle/fluid/pir/drr/include/drr_match_context.h index 7821e9c73e365..4339595b710d4 100644 --- a/paddle/fluid/pir/drr/api/match_context.h +++ b/paddle/fluid/pir/drr/include/drr_match_context.h @@ -17,7 +17,9 @@ #include #include -#include "paddle/fluid/pir/drr/api/tensor_interface.h" +namespace pir { +class Value; +} namespace paddle { namespace drr { @@ -29,7 +31,7 @@ class MatchContext final { public: MatchContext(std::shared_ptr impl); - const TensorInterface& Tensor(const std::string& tensor_name) const; + const pir::Value& Tensor(const std::string& tensor_name) const; template T Attr(const std::string& attr_name) const; diff --git a/paddle/fluid/pir/drr/api/drr_pattern_base.h b/paddle/fluid/pir/drr/include/drr_pattern_base.h similarity index 87% rename from paddle/fluid/pir/drr/api/drr_pattern_base.h rename to paddle/fluid/pir/drr/include/drr_pattern_base.h index f5060361e7289..f4dfc2ad12747 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_base.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_base.h @@ -17,9 +17,9 @@ 
#include #include -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -#include "paddle/fluid/pir/drr/api/drr_rewrite_pattern.h" -#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/fluid/pir/drr/include/drr_match_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/include/drr_rewrite_pattern.h" namespace pir { class IrContext; diff --git a/paddle/fluid/pir/drr/api/drr_pattern_context.h b/paddle/fluid/pir/drr/include/drr_pattern_context.h similarity index 99% rename from paddle/fluid/pir/drr/api/drr_pattern_context.h rename to paddle/fluid/pir/drr/include/drr_pattern_context.h index bb864b85acc70..ec0508b05f592 100644 --- a/paddle/fluid/pir/drr/api/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_context.h @@ -22,7 +22,7 @@ #include #include -#include "paddle/fluid/pir/drr/api/match_context.h" +#include "paddle/fluid/pir/drr/include/drr_match_context.h" namespace paddle { namespace drr { @@ -190,8 +190,6 @@ class Tensor { public: static const char NONE_TENSOR_NAME[]; - const std::string& DebugName() const; - TensorShape shape() const { return TensorShape(name()); } TensorDataType dtype() const { return TensorDataType(name()); } diff --git a/paddle/fluid/pir/drr/api/drr_rewrite_pattern.h b/paddle/fluid/pir/drr/include/drr_rewrite_pattern.h similarity index 98% rename from paddle/fluid/pir/drr/api/drr_rewrite_pattern.h rename to paddle/fluid/pir/drr/include/drr_rewrite_pattern.h index 7a166f59013cc..11d07b7fca269 100644 --- a/paddle/fluid/pir/drr/api/drr_rewrite_pattern.h +++ b/paddle/fluid/pir/drr/include/drr_rewrite_pattern.h @@ -20,7 +20,7 @@ #include #include -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" namespace pir { diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc index bbc31e9df7c25..c552550b98c2a 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/ir_operation_factory.cc @@ -135,7 +135,7 @@ pir::Value GetIrValueByDrrTensor(const Tensor& tensor, if (tensor.is_none()) { return pir::Value{}; } - return res_match_ctx.GetIrValue(tensor.name()).get(); + return res_match_ctx.GetIrValue(tensor.name()); } std::vector GetIrValuesByDrrTensors( @@ -153,11 +153,7 @@ void BindIrOutputs(const OpCall& op_call, pir::Operation* op, MatchContextImpl* match_ctx) { for (size_t i = 0; i < op_call.outputs().size(); ++i) { - std::shared_ptr ir_value = nullptr; - if (op->result(i)) { - ir_value = std::make_shared(op->result(i)); - } - match_ctx->BindIrValue(op_call.outputs()[i]->name(), ir_value); + match_ctx->BindIrValue(op_call.outputs()[i]->name(), op->result(i)); } } diff --git a/paddle/fluid/pir/drr/ir_operation_factory.h b/paddle/fluid/pir/drr/ir_operation_factory.h index 40682904df62a..ac59a0310b63f 100644 --- a/paddle/fluid/pir/drr/ir_operation_factory.h +++ b/paddle/fluid/pir/drr/ir_operation_factory.h @@ -16,7 +16,7 @@ #include -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/fluid/pir/drr/match_context_impl.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" diff --git a/paddle/fluid/pir/drr/match_context.cc b/paddle/fluid/pir/drr/match_context.cc index c171482ed2b0a..3da7b24e5df4a 100644 --- a/paddle/fluid/pir/drr/match_context.cc +++ b/paddle/fluid/pir/drr/match_context.cc @@ -12,11 +12,9 
@@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/drr/api/match_context.h" - #include -#include "paddle/fluid/pir/drr/api/ir_operation.h" +#include "paddle/fluid/pir/drr/include/drr_match_context.h" #include "paddle/fluid/pir/drr/match_context_impl.h" namespace paddle { @@ -25,8 +23,7 @@ namespace drr { MatchContext::MatchContext(std::shared_ptr impl) : impl_(impl) {} -const TensorInterface& MatchContext::Tensor( - const std::string& tensor_name) const { +const pir::Value& MatchContext::Tensor(const std::string& tensor_name) const { return impl_->Tensor(tensor_name); } diff --git a/paddle/fluid/pir/drr/match_context_impl.h b/paddle/fluid/pir/drr/match_context_impl.h index 21cca6ead2c3a..26c043384069b 100644 --- a/paddle/fluid/pir/drr/match_context_impl.h +++ b/paddle/fluid/pir/drr/match_context_impl.h @@ -18,12 +18,12 @@ #include #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -#include "paddle/fluid/pir/drr/api/ir_operation.h" -#include "paddle/fluid/pir/drr/api/ir_value.h" -#include "paddle/fluid/pir/drr/api/tensor_interface.h" #include "paddle/fluid/pir/drr/attr_type_uilts.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/pir/core/builtin_attribute.h" +#include "paddle/pir/core/operation.h" +#include "paddle/pir/core/operation_utils.h" +#include "paddle/pir/core/value.h" namespace paddle { namespace drr { @@ -33,7 +33,7 @@ class MatchContextImpl final { MatchContextImpl() = default; ~MatchContextImpl() = default; - const TensorInterface& Tensor(const std::string& tensor_name) const { + const pir::Value& Tensor(const std::string& tensor_name) const { PADDLE_ENFORCE_NE( tensor_map_.count(tensor_name), 0, @@ -41,10 +41,10 @@ class MatchContextImpl final { "Not found tensor." 
"The Drr tensor [%s] must exist in pattern graph to be obtained.", tensor_name)); - return *tensor_map_.at(tensor_name); + return tensor_map_.at(tensor_name); } - const IrOperation& Operation(const OpCall* op_call) const { + pir::Operation* IrOperation(const OpCall* op_call) const { PADDLE_ENFORCE_NE( operation_map_.count(op_call), 0, @@ -52,7 +52,7 @@ class MatchContextImpl final { "The Drr operation [%s] must exist in the " "pattern graph to be obtained.", op_call->name())); - return *operation_map_.at(op_call); + return operation_map_.at(op_call); } template @@ -60,7 +60,7 @@ class MatchContextImpl final { return IrAttrTypeCast::To(GetIrAttr(attr_name)); } - const IrValue& GetIrValue(const std::string& tensor_name) const { + pir::Value GetIrValue(const std::string& tensor_name) const { auto iter = tensor_map_.find(tensor_name); PADDLE_ENFORCE_NE( iter, @@ -69,7 +69,7 @@ class MatchContextImpl final { "The Drr tensor [%s] is not found in the map, " "unable to obtain the corresponding IrValue.", tensor_name)); - return *iter->second; + return iter->second; } pir::Attribute GetIrAttr(const std::string& attr_name) const { @@ -84,8 +84,8 @@ class MatchContextImpl final { return iter->second; } - const std::unordered_map>& - operation_map() const { + const std::unordered_map& operation_map() + const { return operation_map_; } @@ -93,18 +93,15 @@ class MatchContextImpl final { return attr_map_; } - const std::unordered_map>& tensor_map() - const { + const std::unordered_map& tensor_map() const { return tensor_map_; } - void BindIrValue(const std::string& value_name, - const std::shared_ptr& value) { + void BindIrValue(const std::string& value_name, const pir::Value& value) { tensor_map_.emplace(value_name, value); } - void BindIrOperation(const OpCall* op_call, - const std::shared_ptr& op) { + void BindIrOperation(const OpCall* op_call, pir::Operation* op) { operation_map_.emplace(op_call, op); const auto& attrs = op_call->attributes(); for (const auto& kv : attrs) { @@ -112,7 +109,7 @@ class MatchContextImpl final { [&](auto&& arg) { if constexpr (std::is_same_v, NormalAttribute>) { - BindIrAttr(arg.name(), op->get()->attribute(kv.first)); + BindIrAttr(arg.name(), op->attribute(kv.first)); } }, kv.second); @@ -124,9 +121,8 @@ class MatchContextImpl final { attr_map_.emplace(attr_name, attr); } - std::unordered_map> tensor_map_; - std::unordered_map> - operation_map_; + std::unordered_map tensor_map_; + std::unordered_map operation_map_; std::unordered_map attr_map_; }; diff --git a/paddle/fluid/pir/drr/drr_pattern_context.cc b/paddle/fluid/pir/drr/pattern_context.cc similarity index 92% rename from paddle/fluid/pir/drr/drr_pattern_context.cc rename to paddle/fluid/pir/drr/pattern_context.cc index 65d72a9f58175..a3823ab0e1810 100644 --- a/paddle/fluid/pir/drr/drr_pattern_context.cc +++ b/paddle/fluid/pir/drr/pattern_context.cc @@ -12,10 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" #include +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/phi/core/enforce.h" namespace paddle { @@ -66,8 +67,8 @@ void DrrPatternContext::RequireEqual(const TensorShape& first, // Note: we capture the datas by value for constrain_fn // because the datas are destructed before running constrain_fn. 
auto constrain_fn = [=](const MatchContext& match_context) { - return match_context.Tensor(first.tensor_name()).Shape() == - match_context.Tensor(second.tensor_name()).Shape(); + return pir::GetShapeFromValue(match_context.Tensor(first.tensor_name())) == + pir::GetShapeFromValue(match_context.Tensor(second.tensor_name())); }; constraints_.emplace_back(constrain_fn); } @@ -77,8 +78,10 @@ void DrrPatternContext::RequireEqual(const TensorDataType& first, // Note: we capture the datas by value for constrain_fn // because the datas are destructed before running constrain_fn. auto constrain_fn = [=](const MatchContext& match_context) { - return match_context.Tensor(first.tensor_name()).Dtype() == - match_context.Tensor(second.tensor_name()).Dtype(); + return pir::GetDataTypeFromValue( + match_context.Tensor(first.tensor_name())) == + pir::GetDataTypeFromValue( + match_context.Tensor(second.tensor_name())); }; constraints_.emplace_back(constrain_fn); } diff --git a/paddle/fluid/pir/drr/pattern_graph.cc b/paddle/fluid/pir/drr/pattern_graph.cc index 58c79c65acf2f..5409133b7480b 100644 --- a/paddle/fluid/pir/drr/pattern_graph.cc +++ b/paddle/fluid/pir/drr/pattern_graph.cc @@ -16,7 +16,7 @@ #include -#include "paddle/fluid/pir/drr/api/drr_pattern_context.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_context.h" #include "paddle/phi/core/enforce.h" namespace paddle { diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/rewrite_pattern.cc similarity index 93% rename from paddle/fluid/pir/drr/drr_rewrite_pattern.cc rename to paddle/fluid/pir/drr/rewrite_pattern.cc index 9c32354932510..52fdada10dc52 100644 --- a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/rewrite_pattern.cc @@ -12,22 +12,15 @@ // See the License for the specific language governing permissions and // limitations under the License. 
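The two RequireEqual overloads above now lower onto the same shape/dtype helpers. Rendered as standalone lambdas, with placeholder tensor names captured by value per the lifetime note in the hunk above:

```cpp
std::string first = "a", second = "b";  // placeholder tensor names
auto shape_eq = [=](const paddle::drr::MatchContext &match_ctx) {
  return pir::GetShapeFromValue(match_ctx.Tensor(first)) ==
         pir::GetShapeFromValue(match_ctx.Tensor(second));
};
auto dtype_eq = [=](const paddle::drr::MatchContext &match_ctx) {
  return pir::GetDataTypeFromValue(match_ctx.Tensor(first)) ==
         pir::GetDataTypeFromValue(match_ctx.Tensor(second));
};
```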
-// #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" -// #include "paddle/fluid/pir/drr/api/ir_operation.h" -// #include "paddle/fluid/pir/drr/api/match_context.h" - -// #include "paddle/phi/core/enforce.h" -// #include "paddle/pir/core/operation.h" -// #include "paddle/pir/core/type_name.h" -// #include "paddle/fluid/pir/drr/api/drr_pattern_context.h" - -#include "paddle/fluid/pir/drr/api/drr_rewrite_pattern.h" - -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_rewrite_pattern.h" #include "paddle/fluid/pir/drr/ir_operation_factory.h" #include "paddle/fluid/pir/drr/match_context_impl.h" #include "paddle/fluid/pir/drr/pattern_graph.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/pir/core/operation.h" + namespace paddle { namespace drr { @@ -296,13 +289,11 @@ bool DrrRewritePattern::MatchFromOutputToInput( << ir_node->num_results() << ")."; break; } - source_pattern_match_ctx->BindIrOperation( - drr_node, std::make_shared(ir_node)); + source_pattern_match_ctx->BindIrOperation(drr_node, ir_node); // binding input_tensor of current_op for (size_t i = 0; i < drr_input_tensors.size(); ++i) { - source_pattern_match_ctx->BindIrValue( - drr_input_tensors[i]->name(), - std::make_shared(ir_node->operand(i).source())); + source_pattern_match_ctx->BindIrValue(drr_input_tensors[i]->name(), + ir_node->operand(i).source()); if (ir_node->operand_source(i).isa()) { matched = false; VLOG(8) << drr_node->name() @@ -348,9 +339,8 @@ bool DrrRewritePattern::MatchFromOutputToInput( // binding output tensor of current_op auto drr_op_output_tensor = drr_node->outputs(); for (size_t j = 0; j < drr_op_output_tensor.size(); j++) { - source_pattern_match_ctx->BindIrValue( - drr_op_output_tensor[j]->name(), - std::make_shared(ir_node->result(j))); + source_pattern_match_ctx->BindIrValue(drr_op_output_tensor[j]->name(), + ir_node->result(j)); } ++step; } @@ -415,9 +405,7 @@ MatchContextImpl DrrRewritePattern::CreateOperations( "pattern graph to be obtained.", in_tensor)); if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { - res_match_ctx.BindIrValue( - in_tensor, - std::make_shared(src_match_ctx.GetIrValue(in_tensor))); + res_match_ctx.BindIrValue(in_tensor, src_match_ctx.GetIrValue(in_tensor)); } } @@ -467,9 +455,8 @@ MatchContextImpl DrrRewritePattern::CreateOperations( } if (max_input_op_index == 0UL) { VLOG(6) << "Not found producer op for (" << op_call.name() << ")"; - pir::Operation* source_patter_first_op = - src_match_ctx.Operation(source_pattern_graph.owned_op_call()[0].get()) - .get(); + pir::Operation* source_patter_first_op = src_match_ctx.IrOperation( + source_pattern_graph.owned_op_call()[0].get()); max_input_op_index = op_2_temp_program_index[source_patter_first_op]; rewriter.set_insertion_point(source_patter_first_op); } else { @@ -495,9 +482,8 @@ void DrrRewritePattern::RebindIrTensorForAssignTensor( for (const auto& kv : tensor_assign_map) { const auto& src_tensor_name = kv.first; const auto& dst_tensor_name = kv.second; - res_match_ctx->BindIrValue( - src_tensor_name, - std::make_shared(res_match_ctx->GetIrValue(dst_tensor_name))); + res_match_ctx->BindIrValue(src_tensor_name, + res_match_ctx->GetIrValue(dst_tensor_name)); } } @@ -509,7 +495,7 @@ void DrrRewritePattern::ReplaceOutputTensor( if (source_pattern_graph_->id2owend_tensor().count(output_name)) { const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); const auto& res_ir_tensor = 
res_match_ctx.GetIrValue(output_name); - rewriter.ReplaceAllUsesWith(src_ir_tensor.get(), res_ir_tensor.get()); + rewriter.ReplaceAllUsesWith(src_ir_tensor, res_ir_tensor); } else { LOG(WARNING) << "The output tensor (" << output_name << ") in the result_pattern_graph is not the tensor" @@ -527,7 +513,7 @@ void DrrRewritePattern::DeleteSourcePatternOp( std::unordered_set delete_ops_set; GraphTopo graph_topo_visit(&source_pattern_graph); graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) { - pir::Operation* op = src_match_ctx.Operation(&op_call).get(); + pir::Operation* op = src_match_ctx.IrOperation(&op_call); VLOG(5) << "DRR delete op: " << op->name() << " pointer: " << op; if (delete_ops_set.count(op) == 0 && op->use_empty()) { delete_ops_que.push(op); diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index acebc1d854d99..c33da0e7286e3 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -14,7 +14,8 @@ #include "paddle/fluid/pir/transforms/fusion/attention_fuse_pass.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" @@ -158,7 +159,8 @@ class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { res.Tensor("concat_1_out") = concat_1(res.Tensor("combine_1_out")); const auto &reshape_5_shape = res.Attr( [](const paddle::drr::MatchContext &match_ctx) -> std::vector { - auto matmul_1_in_2 = match_ctx.Tensor("matmul_1_in_2").Shape(); + auto matmul_1_in_2 = + pir::GetShapeFromValue(match_ctx.Tensor("matmul_1_in_2")); return {-1, 3, matmul_1_in_2.at(1)}; }); const auto &reshape_5 = diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index 39b5d7cf087df..9e71255c30e6a 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -23,7 +23,7 @@ #include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/common/ddim.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/pass/pass.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index c0ca75fcb5a8f..79f573ec18d99 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -14,7 +14,8 @@ #include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" @@ -48,12 +49,14 @@ class FcElementwiseLayerNormFusePattern : public paddle::drr::DrrPatternBase { // Constrains the activation is none 
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; + auto fc_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("fc_out")); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); for (int i = match_ctx.Attr("begin_norm_axis"); - i < match_ctx.Tensor("fc_out").Shape().size(); + i < fc_out_dims.size(); i++) { - layer_norm_x *= match_ctx.Tensor("fc_out").Shape().at(i); + layer_norm_x *= fc_out_dims.at(i); } - if (layer_norm_x == match_ctx.Tensor("w").Shape().at(1)) { + if (layer_norm_x == w_dims.at(1)) { return true; } return false; @@ -121,12 +124,14 @@ class FcElementwiseLayerNormFuse2Pattern : public paddle::drr::DrrPatternBase { // Constrains the activation is none pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { int64_t layer_norm_x = 1; + auto fc_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("fc_out")); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); for (int i = match_ctx.Attr("begin_norm_axis"); - i < match_ctx.Tensor("fc_out").Shape().size(); + i < fc_out_dims.size(); i++) { - layer_norm_x *= match_ctx.Tensor("fc_out").Shape().at(i); + layer_norm_x *= fc_out_dims.at(i); } - if (layer_norm_x == match_ctx.Tensor("w").Shape().at(1)) { + if (layer_norm_x == w_dims.at(1)) { return true; } return false; diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index 749a54383d95a..a48488f46728f 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -14,7 +14,8 @@ #include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" @@ -33,25 +34,22 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase { pat.Tensor("add_out") = add(pat.Tensor("matmul_out"), pat.Tensor("y")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - if (match_ctx.Tensor("w").Shape().size() != 2 || - match_ctx.Tensor("x").Shape().size() < 2) { + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto y_dims = pir::GetShapeFromValue(match_ctx.Tensor("y")); + if (w_dims.size() != 2 || x_dims.size() < 2) { return false; } - if (match_ctx.Tensor("x").Shape().at( - match_ctx.Tensor("x").Shape().size() - 1) != - match_ctx.Tensor("w").Shape().at(0) || + if (x_dims.at(x_dims.size() - 1) != w_dims.at(0) || match_ctx.Attr("transpose_x") == true || match_ctx.Attr("transpose_y") == true) { return false; } - if (match_ctx.Tensor("y").Shape().size() == 1) { - return match_ctx.Tensor("y").Shape().at(0) == - match_ctx.Tensor("w").Shape().at(1); + if (y_dims.size() == 1) { + return y_dims.at(0) == w_dims.at(1); } - if (match_ctx.Tensor("y").Shape().size() == 2) { - return match_ctx.Tensor("y").Shape().at(0) == 1 && - match_ctx.Tensor("y").Shape().at(1) == - match_ctx.Tensor("w").Shape().at(1); + if (y_dims.size() == 2) { + return y_dims.at(0) == 1 && y_dims.at(1) == w_dims.at(1); } return false; }); @@ -60,7 +58,8 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase { const auto &in_num_col_dims_attr = res.Attr([](const 
paddle::drr::MatchContext &match_ctx) -> std::any { - return match_ctx.Tensor("x").Shape().size() - 1; + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + return x_dims.size() - 1; }); const auto &false_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index c935b66d2fe3d..9911263dedbd1 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc index f728b888d2461..3227a57cd2318 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -14,7 +14,7 @@ #include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 36404dae16381..c1747b6312488 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -13,8 +13,10 @@ // limitations under the License. 
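The MatmulAddPattern constraint just above encodes fc's fusion contract; written as a plain function the shape algebra is easier to audit. A sketch using the pattern's tensor names (the transpose_x/transpose_y == false requirement is checked separately in the pattern):

```cpp
#include <cstdint>
#include <vector>

// w must be rank-2, x at least rank-2 with its inner dim equal to w's rows,
// and y must be an fc-style bias: shape [n] or [1, n] with n == w's cols.
bool CanFuseMatmulAddToFc(const std::vector<int64_t> &x_dims,
                          const std::vector<int64_t> &w_dims,
                          const std::vector<int64_t> &y_dims) {
  if (w_dims.size() != 2 || x_dims.size() < 2) return false;
  if (x_dims.back() != w_dims[0]) return false;
  if (y_dims.size() == 1) return y_dims[0] == w_dims[1];
  if (y_dims.size() == 2) return y_dims[0] == 1 && y_dims[1] == w_dims[1];
  return false;
}
```

The in_num_col_dims attribute above then falls out as x_dims.size() - 1: every leading axis of x is folded into fc's batch dimension.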
#include "paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" @@ -34,9 +36,11 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { pat.Tensor("out") = add(pat.Tensor("tmp"), pat.Tensor("bias")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_dims = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + return (w_dims.size() == 2 && x_dims.size() >= 2 && + bias_dims.size() == 1); }); paddle::drr::ResultPattern res = pat.ResultPattern(); @@ -57,7 +61,7 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { std::string pattern_name() const override { return "FusedLinearPattern"; } }; -class FusedLinearGradPattern : public paddle::drr::DrrPatternBase < { +class FusedLinearGradPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -78,9 +82,11 @@ class FusedLinearGradPattern : public paddle::drr::DrrPatternBase < { {&pat.Tensor("x_grad"), &pat.Tensor("w_grad")}); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1); + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_dims = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + return (w_dims.size() == 2 && x_dims.size() >= 2 && + bias_dims.size() == 1); }); paddle::drr::ResultPattern res = pat.ResultPattern(); @@ -152,6 +158,7 @@ class FusedLinearGeluPattern : public paddle::drr::DrrPatternBase { std::string pattern_name() const override { return "FusedLinearGeluPattern"; } }; + class FusedLinearReluPattern : public paddle::drr::DrrPatternBase { public: void operator()(paddle::drr::DrrPatternContext *ctx) const override { diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index 737fdf5b3f396..d3df0f49d25fd 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -14,7 +14,8 @@ #include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" #include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" @@ -50,18 +51,22 @@ class FusedMatmulAddGradAddPattern : public paddle::drr::DrrPatternBase { 
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - match_ctx.Tensor("out").Shape() == - match_ctx.Tensor("fwd_add_out_grad").Shape() && - x_trans == false && y_trans == false); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + auto out_dims = pir::GetShapeFromValue(match_ctx.Tensor("out")); + auto fwd_add_out_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("fwd_add_out_grad")); + return (weight_grad_dims == dweight_dims && + out_dims == fwd_add_out_grad_dims && x_trans == false && + y_trans == false); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -113,17 +118,19 @@ class FusedMatmulGradAddPattern : public paddle::drr::DrrPatternBase { pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { const auto &x_trans = match_ctx.Attr("trans_x"); const auto &y_trans = match_ctx.Attr("trans_y"); - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - x_trans == false && y_trans == false); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + return (weight_grad_dims == dweight_dims && x_trans == false && + y_trans == false); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -173,15 +180,17 @@ class FusedMatmulAddaPattern : public paddle::drr::DrrPatternBase { add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape()); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + return (weight_grad_dims == dweight_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -223,15 +232,17 @@ class FusedMatmulAddbPattern : public paddle::drr::DrrPatternBase { add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape()); + auto 
weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + return (weight_grad_dims == dweight_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = @@ -285,17 +296,19 @@ class FusedMatmulAddGradAddaPattern : public paddle::drr::DrrPatternBase { add_(pat.Tensor("dweight"), pat.Tensor("weight_grad")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - match_ctx.Tensor("out").Shape() == - match_ctx.Tensor("dadd_out").Shape()); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + auto out_dims = pir::GetShapeFromValue(match_ctx.Tensor("out")); + auto dadd_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("dadd_out")); + return (weight_grad_dims == dweight_dims && out_dims == dadd_out_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { @@ -345,17 +358,19 @@ class FusedMatmulAddGradAddbPattern : public paddle::drr::DrrPatternBase { add_(pat.Tensor("weight_grad"), pat.Tensor("dweight")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - return (match_ctx.Tensor("weight_grad").Shape() == - match_ctx.Tensor("dweight").Shape() && - match_ctx.Tensor("out").Shape() == - match_ctx.Tensor("dadd_out").Shape()); + auto weight_grad_dims = + pir::GetShapeFromValue(match_ctx.Tensor("weight_grad")); + auto dweight_dims = pir::GetShapeFromValue(match_ctx.Tensor("dweight")); + auto out_dims = pir::GetShapeFromValue(match_ctx.Tensor("out")); + auto dadd_out_dims = pir::GetShapeFromValue(match_ctx.Tensor("dadd_out")); + return (weight_grad_dims == dweight_dims && out_dims == dadd_out_dims); }); paddle::drr::ResultPattern res = pat.ResultPattern(); const auto &muti_precision_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { - return !(match_ctx.Tensor("dweight").Dtype() == - match_ctx.Tensor("weight_grad").Dtype()); + return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) == + pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad"))); }); const auto &true_attr = res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool { diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index d324e1f2ece23..e3732c24db3ce 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -14,7 +14,8 @@ #include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" 
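All five FusedMatmul*Pattern variants in fused_linear_param_grad_add_pass.cc above derive the same multi-precision flag the same way. Factored out for clarity, keeping the pass's muti_precision spelling and its !(a == b) form verbatim:

```cpp
// The fused op runs in multi-precision mode exactly when the gradient
// accumulator (dweight) and the incoming gradient (weight_grad) disagree
// on dtype.
const auto &muti_precision_attr =
    res.Attr([](const paddle::drr::MatchContext &match_ctx) -> bool {
      return !(pir::GetDataTypeFromValue(match_ctx.Tensor("dweight")) ==
               pir::GetDataTypeFromValue(match_ctx.Tensor("weight_grad")));
    });
```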
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" +#include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/place.h" #include "paddle/pir/pass/pass.h" @@ -62,21 +63,21 @@ class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { bool matmul_trans_y = match_ctx.Attr("matmul_transpose_y"); if (matmul_trans_x || matmul_trans_y) return false; - if (!(match_ctx.Tensor("w").Shape().size() == 2 && - match_ctx.Tensor("x").Shape().size() >= 2 && - match_ctx.Tensor("bias").Shape().size() == 1)) { + auto w_dims = pir::GetShapeFromValue(match_ctx.Tensor("w")); + auto x_dims = pir::GetShapeFromValue(match_ctx.Tensor("x")); + auto bias_dims = pir::GetShapeFromValue(match_ctx.Tensor("bias")); + if (!(w_dims.size() == 2 && x_dims.size() >= 2 && + bias_dims.size() == 1)) { return false; } - auto w_dims = match_ctx.Tensor("w").Shape(); if (w_dims.at(0) % 64 != 0 || w_dims.at(1) % 16 != 0) return false; - auto w_dtype = match_ctx.Tensor("w").Dtype().get(); + auto w_dtype = pir::GetDataTypeFromValue(match_ctx.Tensor("w")); if (!w_dtype.isa() && !w_dtype.isa()) return false; - auto x_dims = match_ctx.Tensor("x").Shape(); if (x_dims.at(x_dims.size() - 1) != w_dims.at(1)) return false; return true; diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index 4dbb6cab4b5bc..059e00175ee21 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -16,7 +16,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/common/ddim.h" diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index 2e2e1845d2b9b..4bd0d19c96cc6 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -17,8 +17,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" -#include "paddle/fluid/pir/drr/api/ir_value.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" @@ -158,10 +157,9 @@ class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase { pat.Tensor("out") = pat.Op(paddle::dialect::ConcatOp::name())( pat.Tensor("combine_out"), pat.Tensor("axis")); pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { - auto combine_out = dynamic_cast( - match_ctx.Tensor("combine_out")); - return combine_out.type_isa() && - combine_out.type_dyn_cast().size() == 1; + auto combine_out = match_ctx.Tensor("combine_out"); + return combine_out.type().isa() && + combine_out.type().dyn_cast().size() == 1; }); auto res = pat.ResultPattern(); 
res.Tensor("out").Assign(res.Tensor("x")); diff --git a/test/cpp/pir/cinn/dialect_convert_test.cc b/test/cpp/pir/cinn/dialect_convert_test.cc index 398c089268830..f67e55cade1f3 100644 --- a/test/cpp/pir/cinn/dialect_convert_test.cc +++ b/test/cpp/pir/cinn/dialect_convert_test.cc @@ -21,7 +21,7 @@ #include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_manager.h" diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc index 6cc0a16611904..fcffa97d4084d 100644 --- a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass.h" diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc index f2f032a0181df..bfb59e39fe1b8 100644 --- a/test/cpp/pir/pattern_rewrite/drr_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -18,7 +18,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" -#include "paddle/fluid/pir/drr/api/drr_pattern_base.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/dead_code_elimination_pass.h" #include "paddle/pir/core/builtin_dialect.h" #include "paddle/pir/pass/pass_manager.h" From fe3aae1f85d85ccaab24760abd34ef2d5af5a2b2 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 15 Jan 2024 04:51:58 +0000 Subject: [PATCH 3/5] fix include for all pass --- .../pir/drr/include/drr_pattern_context.h | 1 + .../transforms/fusion/attention_fuse_pass.cc | 1 - .../fusion/conv2d_add_act_fuse_pass.cc | 10 ++++++---- .../transforms/fusion/conv2d_add_fuse_pass.cc | 12 +++-------- .../transforms/fusion/conv2d_bn_fuse_pass.cc | 11 +++++----- .../fc_elementwise_layernorm_fuse_pass.cc | 3 ++- .../pir/transforms/fusion/fc_fuse_pass.cc | 3 ++- .../fused_dot_product_attention_pass.cc | 3 ++- .../fusion/fused_dropout_add_pass.cc | 3 ++- .../fusion/fused_gemm_epilogue_pass.cc | 2 +- .../fused_linear_param_grad_add_pass.cc | 4 +++- .../fusion/fused_weight_only_linear_pass.cc | 3 ++- .../fusion/matmul_scale_fuse_pass.cc | 6 +----- .../pir/transforms/identity_op_clean_pass.cc | 9 +-------- .../replace_fetch_with_shadow_output_pass.cc | 20 ++++--------------- paddle/pir/core/visitors.h | 4 ++-- paddle/pir/pass/pass.cc | 2 ++ paddle/pir/pass/pass.h | 5 ++--- .../pir/pattern_rewrite/pattern_applicator.cc | 2 +- .../pir/pattern_rewrite/pattern_applicator.h | 5 ++++- .../pattern_rewrite/pattern_rewrite_driver.cc | 16 +++++++++++++++ .../pattern_rewrite/pattern_rewrite_driver.h | 19 ++++-------------- 22 files changed, 66 insertions(+), 78 deletions(-) diff --git a/paddle/fluid/pir/drr/include/drr_pattern_context.h b/paddle/fluid/pir/drr/include/drr_pattern_context.h 
index ec0508b05f592..0539708300ac7 100644 --- a/paddle/fluid/pir/drr/include/drr_pattern_context.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_context.h @@ -21,6 +21,7 @@ #include #include #include +#include #include "paddle/fluid/pir/drr/include/drr_match_context.h" diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index c33da0e7286e3..73355875c40be 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -18,7 +18,6 @@ #include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc index a3c71cc90e60d..8ef58da2b4bad 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.cc @@ -11,17 +11,19 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_act_fuse_pass.h" -#include "paddle/common/ddim.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" + +#include "paddle/common/ddim.h" + namespace { class Conv2dAddActFusePattern diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index 9e71255c30e6a..34d84dfc17b73 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -12,19 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" - -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" - -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" -#include "paddle/common/ddim.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + #include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc index 42129852bc8bc..eefed9493d58e 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.cc @@ -11,17 +11,16 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" -#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/fusion/conv2d_bn_fuse_pass.h" -#include "paddle/common/ddim.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" +#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" +#include "paddle/pir/pass/pass.h" +#include "paddle/pir/pass/pass_registry.h" + namespace { class Conv2dBnFusePattern diff --git a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index 79f573ec18d99..ab74552849121 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -13,12 +13,13 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index a48488f46728f..b3fbd578da6cf 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -13,12 +13,13 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fc_fuse_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index 9911263dedbd1..b7b4fc45327f5 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -13,11 +13,12 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc index 3227a57cd2318..ee26351124add 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -13,11 +13,12 @@ // limitations under the License. 
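Patch 3 is mostly include hygiene, and the fusion passes it touches converge on one layout. A sketch of that convention; some_fuse_pass.h is a hypothetical stand-in for the pass's own header:

```cpp
// 1. The pass's own header first.
#include "paddle/fluid/pir/transforms/fusion/some_fuse_pass.h"

// 2. fluid-level dependencies: dialect ops, DRR, shared helpers.
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

// 3. Core pir headers last; pattern_rewrite_driver.h is dropped because the
//    greedy-rewrite loop now lives behind pir::PatternRewritePass.
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
```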
#include "paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index c1747b6312488..4284dd126e248 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -17,9 +17,9 @@ #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index d3df0f49d25fd..63c76404b6136 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -13,12 +13,14 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" + namespace { // add_grad + matmul_grad + add_ -> matmul + fused_liner_param_gard_add diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index e3732c24db3ce..2cabca6a07b6c 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -13,14 +13,15 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" #include "paddle/fluid/platform/device/gpu/gpu_info.h" #include "paddle/fluid/platform/place.h" + #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index 059e00175ee21..0eeff6c6598e3 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -13,17 +13,13 @@ // limitations under the License. 
#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" #include "paddle/fluid/pir/transforms/transform_general_functions.h" -#include "paddle/common/ddim.h" - #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index 4bd0d19c96cc6..cde25aaefe4ad 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -13,20 +13,13 @@ // limitations under the License. #include "paddle/fluid/pir/transforms/identity_op_clean_pass.h" -#include -#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" -#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" + #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/drr/include/drr_pattern_base.h" -#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h" -#include "paddle/fluid/pir/transforms/transform_general_functions.h" - -#include "paddle/common/ddim.h" #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { diff --git a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc index 5e499436ec7f6..8029cfc9ddbf5 100644 --- a/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc +++ b/paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.cc @@ -18,9 +18,6 @@ #include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/pass.h" #include "paddle/pir/pass/pass_registry.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" -#include "paddle/pir/pattern_rewrite/pattern_match.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace { @@ -39,24 +36,15 @@ class ReplaceFetchWithShadowOutputPattern } }; -class ReplaceFetchWithShadowOutputPass : public pir::Pass { +class ReplaceFetchWithShadowOutputPass : public pir::PatternRewritePass { public: ReplaceFetchWithShadowOutputPass() - : pir::Pass("replace_fetch_with_shadow_output_pass", 0) {} + : pir::PatternRewritePass("replace_fetch_with_shadow_output_pass", 0) {} - bool Initialize(pir::IrContext* context) override { + pir::RewritePatternSet InitializePatterns(pir::IrContext* context) override { pir::RewritePatternSet ps(context); ps.Add(context); - patterns_ = pir::FrozenRewritePatternSet(std::move(ps)); - return true; - } - - void Run(pir::Operation* op) override { - pir::GreedyRewriteConfig cfg; - cfg.use_top_down_traversal = true; - cfg.max_iterations = 10; - auto [_, num_rewrites] = pir::ApplyPatternsGreedily(op, patterns_, cfg); - AddStatistics(num_rewrites); + return ps; } bool CanApplyOn(pir::Operation* op) const override { diff --git a/paddle/pir/core/visitors.h b/paddle/pir/core/visitors.h index 3fdcb71bff9b9..7d9e9eacf4394 100644 --- a/paddle/pir/core/visitors.h +++ b/paddle/pir/core/visitors.h @@ -41,8 +41,8 @@ void Walk(Operation *op, template void Walk(Operation *op, FuncTy &&callback) { - return detail::Walk(op, callback, Order); + return 
Walk(op, callback, Order); } - } // namespace detail + } // namespace pir diff --git a/paddle/pir/pass/pass.cc b/paddle/pir/pass/pass.cc index 2f9cb896215dd..c04669317ef16 100644 --- a/paddle/pir/pass/pass.cc +++ b/paddle/pir/pass/pass.cc @@ -23,6 +23,8 @@ #include "paddle/pir/pass/pass_adaptor.h" #include "paddle/pir/pass/pass_instrumentation.h" #include "paddle/pir/pass/pass_manager.h" +#include "paddle/pir/pattern_rewrite/pattern_match.h" +#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" namespace pir { diff --git a/paddle/pir/pass/pass.h b/paddle/pir/pass/pass.h index 6c2c565322bf8..f85c7519cbe19 100644 --- a/paddle/pir/pass/pass.h +++ b/paddle/pir/pass/pass.h @@ -21,9 +21,8 @@ #include #include "paddle/common/enforce.h" -#include "paddle/pir/core/builtin_op.h" #include "paddle/pir/pass/analysis_manager.h" -#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h" +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" namespace pir { @@ -197,7 +196,7 @@ class IR_API Pass { std::unordered_map> attr_dels_; }; -class PatternRewritePass : public Pass { +class IR_API PatternRewritePass : public Pass { public: PatternRewritePass(const std::string& name, uint8_t opt_level, diff --git a/paddle/pir/pattern_rewrite/pattern_applicator.cc b/paddle/pir/pattern_rewrite/pattern_applicator.cc index 6e45768542061..f67e41255a33e 100644 --- a/paddle/pir/pattern_rewrite/pattern_applicator.cc +++ b/paddle/pir/pattern_rewrite/pattern_applicator.cc @@ -14,8 +14,8 @@ #include +#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/pir/pattern_rewrite/pattern_applicator.h" - #include "paddle/pir/pattern_rewrite/pattern_match.h" namespace pir { diff --git a/paddle/pir/pattern_rewrite/pattern_applicator.h b/paddle/pir/pattern_rewrite/pattern_applicator.h index a0fdf58fd57e0..37c0a42cbf974 100644 --- a/paddle/pir/pattern_rewrite/pattern_applicator.h +++ b/paddle/pir/pattern_rewrite/pattern_applicator.h @@ -21,11 +21,14 @@ #include "paddle/pir/core/op_info.h" #include "paddle/pir/core/operation.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" #include "paddle/pir/pattern_rewrite/pattern_match.h" namespace pir { +class FrozenRewritePatternSet; +class RewritePattern; +class Pattern; + class PatternApplicator { public: using CostModel = std::function; diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc index 3788c63273ffa..c138785038d5a 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.cc @@ -227,4 +227,20 @@ std::pair<bool, int64_t> ApplyPatternsGreedily( return std::make_pair(converged, num_rewrites); } +IR_API std::pair<bool, int64_t> ApplyPatternsGreedily( + Operation* op, + const FrozenRewritePatternSet& patterns, + GreedyRewriteConfig config) { + bool sum_converged = true; + int64_t sum_num_rewrites = 0; + for (uint32_t i = 0; i < op->num_regions(); ++i) { + Region& region = op->region(i); + auto [converged, num_rewrites] = + ApplyPatternsGreedily(region, patterns, config); + sum_converged &= converged; + sum_num_rewrites += num_rewrites; + } + return std::make_pair(sum_converged, sum_num_rewrites); +} + } // namespace pir diff --git a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h index 8186e3cadb195..8ed55843e2adb 100644 --- a/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h +++ b/paddle/pir/pattern_rewrite/pattern_rewrite_driver.h @@ -16,11 +16,11 @@
#include "paddle/pir/core/dll_decl.h" #include "paddle/pir/core/region.h" -#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h" -#include "paddle/pir/pattern_rewrite/pattern_match.h" namespace pir { +class FrozenRewritePatternSet; + /// This enum will control which ops will be added to the worklist during the /// match rewrite process enum class IR_API GreedyRewriteStrictness { @@ -73,20 +73,9 @@ ApplyPatternsGreedily(Region& region, // NOLINT GreedyRewriteConfig config = GreedyRewriteConfig()); /// Perform a match and rewrite process for all regions of a given op. -inline IR_API std::pair ApplyPatternsGreedily( +IR_API std::pair ApplyPatternsGreedily( Operation* op, const FrozenRewritePatternSet& patterns, - GreedyRewriteConfig config = GreedyRewriteConfig()) { - bool sum_converged = true; - int64_t sum_num_rewrites = 0; - for (uint32_t i = 0; i < op->num_regions(); ++i) { - Region& region = op->region(i); - auto [converged, num_rewrites] = - ApplyPatternsGreedily(region, patterns, config); - sum_converged &= converged; - sum_num_rewrites += num_rewrites; - } - return std::make_pair(sum_converged, sum_num_rewrites); -} + GreedyRewriteConfig config = GreedyRewriteConfig()); } // namespace pir From e4692cac83a010a3f1a2a8c2a45100e8fca8ca10 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 15 Jan 2024 06:33:21 +0000 Subject: [PATCH 4/5] pattern_name to name, pattern_benefit to benefit --- paddle/fluid/pir/drr/README.md | 14 +++++------- paddle/fluid/pir/drr/README_cn.md | 6 ++--- .../fluid/pir/drr/include/drr_pattern_base.h | 4 ++-- paddle/fluid/pir/drr/rewrite_pattern.cc | 2 +- .../transforms/fusion/attention_fuse_pass.cc | 4 +--- .../transforms/fusion/conv2d_add_fuse_pass.cc | 2 +- .../fc_elementwise_layernorm_fuse_pass.cc | 4 ++-- .../pir/transforms/fusion/fc_fuse_pass.cc | 4 ++-- .../fused_dot_product_attention_pass.cc | 8 +++---- .../fusion/fused_dropout_add_pass.cc | 6 ++--- .../fusion/fused_gemm_epilogue_pass.cc | 16 +++++--------- .../fused_linear_param_grad_add_pass.cc | 20 +++++------------ .../fusion/fused_weight_only_linear_pass.cc | 4 +--- .../fusion/matmul_scale_fuse_pass.cc | 2 +- .../pir/transforms/identity_op_clean_pass.cc | 22 +++++-------------- .../drr_same_type_binding_test.cc | 4 +--- test/cpp/pir/pattern_rewrite/drr_test.cc | 18 +++++---------- 17 files changed, 49 insertions(+), 91 deletions(-) diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 3a8e69584b68a..6d320e61fb857 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -32,9 +32,9 @@ class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } - std::string pattern_name() const override { - return "RemoveRedundentCastPattern"; - } + std::string pattern_name() const override { + return "RemoveRedundentCastPattern"; + } }; ~~~ @@ -197,9 +197,7 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("out")}); } - std::string pattern_name() const override { - return "FusedLinearPattern"; - } + std::string name() const override { return "FusedLinearPattern"; } }; ~~~ @@ -233,8 +231,6 @@ class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = full2(); } - std::string pattern_name() const override { - return "FoldExpandToConstantPattern"; - } + std::string name() const override { return "FoldExpandToConstantPattern"; } }; ~~~ diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index 
57cf8e23050a1..4051a5e547f31 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -32,7 +32,7 @@ class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } - std::string pattern_name() const override { return "RemoveRedundentCastPattern"; } + std::string name() const override { return "RemoveRedundentCastPattern"; } }; ~~~ @@ -198,7 +198,7 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("out")}); } - std::string pattern_name() const override { return "FusedLinearPattern"; } + std::string name() const override { return "FusedLinearPattern"; } }; ~~~ @@ -232,6 +232,6 @@ class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = full2(); } - std::string pattern_name() const override { return "FoldExpandToConstantPattern"; } + std::string name() const override { return "FoldExpandToConstantPattern"; } }; ~~~ diff --git a/paddle/fluid/pir/drr/include/drr_pattern_base.h b/paddle/fluid/pir/drr/include/drr_pattern_base.h index f4dfc2ad12747..e079fed999a13 100644 --- a/paddle/fluid/pir/drr/include/drr_pattern_base.h +++ b/paddle/fluid/pir/drr/include/drr_pattern_base.h @@ -39,10 +39,10 @@ class DrrPatternBase { virtual void operator()(drr::DrrPatternContext* ctx) const = 0; // Give the drr pattern name. - virtual std::string pattern_name() const = 0; + virtual std::string name() const = 0; // Give the drr pattern benefit. - virtual uint32_t pattern_benefit() const { return 1; } + virtual uint32_t benefit() const { return 1; } // Build the Drr Pattern. std::unique_ptr<DrrRewritePattern> Build(pir::IrContext* ir_context) const; diff --git a/paddle/fluid/pir/drr/rewrite_pattern.cc b/paddle/fluid/pir/drr/rewrite_pattern.cc index 52fdada10dc52..5d3726246b36b 100644 --- a/paddle/fluid/pir/drr/rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/rewrite_pattern.cc @@ -543,7 +543,7 @@ std::unique_ptr<DrrRewritePattern> DrrPatternBase::Build( DrrPatternContext drr_context; this->operator()(&drr_context); return std::make_unique<DrrRewritePattern>( - pattern_name(), drr_context, ir_context, pattern_benefit()); + name(), drr_context, ir_context, benefit()); } } // namespace drr diff --git a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc index 73355875c40be..616ff6f607c58 100644 --- a/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/attention_fuse_pass.cc @@ -217,9 +217,7 @@ class MultiHeadMatmulFusePattern : public paddle::drr::DrrPatternBase { {&res.Tensor("reshape_4_out")}); } - std::string pattern_name() const override { - return "MultiHeadMatmulFusePattern"; - } + std::string name() const override { return "MultiHeadMatmulFusePattern"; } }; class AttentionFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc index 34d84dfc17b73..fbfa9c6891a55 100644 --- a/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.cc @@ -77,7 +77,7 @@ class Conv2dAddFusePattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out")}); } - std::string pattern_name() const override { return "Conv2dAddFusePattern"; } + std::string name() const override { return "Conv2dAddFusePattern"; } }; class Conv2dAddFusePass : public pir::PatternRewritePass { diff --git
a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc index ab74552849121..e57e9b1bef727 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_elementwise_layernorm_fuse_pass.cc @@ -93,7 +93,7 @@ class FcElementwiseLayerNormFusePattern : public paddle::drr::DrrPatternBase { &res.Tensor("layernorm_variance")}); } - std::string pattern_name() const override { + std::string name() const override { return "FcElementwiseLayerNormFusePattern"; } }; @@ -159,7 +159,7 @@ class FcElementwiseLayerNormFuse2Pattern : public paddle::drr::DrrPatternBase { &res.Tensor("layernorm_variance")}); } - std::string pattern_name() const override { + std::string name() const override { return "FcElementwiseLayerNormFuse2Pattern"; } }; diff --git a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc index b3fbd578da6cf..18200f2e6b4e2 100644 --- a/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fc_fuse_pass.cc @@ -80,7 +80,7 @@ class MatmulAddPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out")}); } - std::string pattern_name() const override { return "MatmulAddPattern"; } + std::string name() const override { return "MatmulAddPattern"; } }; class FcWithReluPattern : public paddle::drr::DrrPatternBase { @@ -119,7 +119,7 @@ class FcWithReluPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("relu_out")}); } - std::string pattern_name() const override { return "FcWithReluPattern"; } + std::string name() const override { return "FcWithReluPattern"; } }; class FcFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc index b7b4fc45327f5..0b5737ecf69d6 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_dot_product_attention_pass.cc @@ -138,7 +138,7 @@ class FusedDotProductAttentionPattern : public paddle::drr::DrrPatternBase { &res.Tensor("rng_state")}); } - std::string pattern_name() const override { + std::string name() const override { return "FusedDotProductAttentionPattern"; } }; @@ -318,7 +318,7 @@ class FusedDotProductAttentionGradPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("q_grad"), &res.Tensor("k_grad"), &res.Tensor("v_grad")}); } - std::string pattern_name() const override { + std::string name() const override { return "FusedDotProductAttentionGradPattern"; } }; @@ -448,7 +448,7 @@ class FusedDotProductAttentionWithDropoutPattern &res.Tensor("rng_state")}); } - std::string pattern_name() const override { + std::string name() const override { return "FusedDotProductAttentionWithDropoutPattern"; } }; @@ -639,7 +639,7 @@ class FusedDotProductAttentionGradWithDropoutPattern {&res.Tensor("q_grad"), &res.Tensor("k_grad"), &res.Tensor("v_grad")}); } - std::string pattern_name() const override { + std::string name() const override { return "FusedDotProductAttentionGradWithDropoutPattern"; } }; diff --git a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc index ee26351124add..0041c70488ffa 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc +++ 
b/paddle/fluid/pir/transforms/fusion/fused_dropout_add_pass.cc @@ -51,7 +51,7 @@ class FusedDropoutAddPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out"), &res.Tensor("mask")}); } - std::string pattern_name() const override { return "FusedDropoutAddPattern"; } + std::string name() const override { return "FusedDropoutAddPattern"; } }; class FusedDropoutGradAddGradPattern : public paddle::drr::DrrPatternBase { @@ -105,9 +105,7 @@ class FusedDropoutGradAddGradPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("x_grad"), &res.Tensor("y_grad")}); } - std::string pattern_name() const override { - return "FusedDropoutGradAddGradPattern"; - } + std::string name() const override { return "FusedDropoutGradAddGradPattern"; } }; class FusedDropoutAddPass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc index 4284dd126e248..6a39c015893e3 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_gemm_epilogue_pass.cc @@ -58,7 +58,7 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("out")}); } - std::string pattern_name() const override { return "FusedLinearPattern"; } + std::string name() const override { return "FusedLinearPattern"; } }; class FusedLinearGradPattern : public paddle::drr::DrrPatternBase { @@ -116,7 +116,7 @@ class FusedLinearGradPattern : public paddle::drr::DrrPatternBase { &res.Tensor("bias_grad")}); } - std::string pattern_name() const override { return "FusedLinearGradPattern"; } + std::string name() const override { return "FusedLinearGradPattern"; } }; class FusedLinearGeluPattern : public paddle::drr::DrrPatternBase { @@ -156,7 +156,7 @@ class FusedLinearGeluPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("out"), &res.Tensor("reserve_space")}); } - std::string pattern_name() const override { return "FusedLinearGeluPattern"; } + std::string name() const override { return "FusedLinearGeluPattern"; } }; class FusedLinearReluPattern : public paddle::drr::DrrPatternBase { @@ -196,7 +196,7 @@ class FusedLinearReluPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("out"), &res.Tensor("reserve_space")}); } - std::string pattern_name() const override { return "FusedLinearReluPattern"; } + std::string name() const override { return "FusedLinearReluPattern"; } }; class FusedLinearGeluGradPattern : public paddle::drr::DrrPatternBase { @@ -265,9 +265,7 @@ class FusedLinearGeluGradPattern : public paddle::drr::DrrPatternBase { &res.Tensor("bias1_grad")}); } - std::string pattern_name() const override { - return "FusedLinearGeluGradPattern"; - } + std::string name() const override { return "FusedLinearGeluGradPattern"; } }; class FusedLinearReluGradPattern : public paddle::drr::DrrPatternBase { @@ -336,9 +334,7 @@ class FusedLinearReluGradPattern : public paddle::drr::DrrPatternBase { &res.Tensor("bias1_grad")}); } - std::string pattern_name() const override { - return "FusedLinearReluGradPattern"; - } + std::string name() const override { return "FusedLinearReluGradPattern"; } }; class FusedGemmEpiloguePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc index 63c76404b6136..1453426cc8df6 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc +++ 
b/paddle/fluid/pir/transforms/fusion/fused_linear_param_grad_add_pass.cc @@ -96,9 +96,7 @@ class FusedMatmulAddGradAddPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out"), &res.Tensor("dbias")}); } - std::string pattern_name() const override { - return "FusedMatmulAddGradAddPattern"; - } + std::string name() const override { return "FusedMatmulAddGradAddPattern"; } }; // matmul_grad + add_ -> matmul + fused_liner_param_gard_add @@ -161,9 +159,7 @@ class FusedMatmulGradAddPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } - std::string pattern_name() const override { - return "FusedMatmulGradAddPattern"; - } + std::string name() const override { return "FusedMatmulGradAddPattern"; } }; // matmul + 0 = add_(0,1) -> fused_liner_param_gard_add @@ -215,7 +211,7 @@ class FusedMatmulAddaPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } - std::string pattern_name() const override { return "FusedMatmulAddaPattern"; } + std::string name() const override { return "FusedMatmulAddaPattern"; } }; // matmul + 1 = add_(1,0) -> fused_liner_param_gard_add @@ -267,7 +263,7 @@ class FusedMatmulAddbPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out"), &res.Tensor("dbias_out")}); } - std::string pattern_name() const override { return "FusedMatmulAddbPattern"; } + std::string name() const override { return "FusedMatmulAddbPattern"; } }; // add_grad + matmul + 0 = add_(0,1) -> fused_liner_param_gard_add @@ -327,9 +323,7 @@ class FusedMatmulAddGradAddaPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); } - std::string pattern_name() const override { - return "FusedMatmulAddGradAddaPattern"; - } + std::string name() const override { return "FusedMatmulAddGradAddaPattern"; } }; // add_grad + matmul + 1 = add_(1,0) -> fused_liner_param_gard_add @@ -389,9 +383,7 @@ class FusedMatmulAddGradAddbPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("dweight_out"), &res.Tensor("dbias")}); } - std::string pattern_name() const override { - return "FusedMatmulAddGradAddbPattern"; - } + std::string name() const override { return "FusedMatmulAddGradAddbPattern"; } }; class FusedLinearParamGradAddPass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc index 2cabca6a07b6c..df61b1eb25ba2 100644 --- a/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/fused_weight_only_linear_pass.cc @@ -128,9 +128,7 @@ class FusedWeightOnlyLinearPattern : public paddle::drr::DrrPatternBase { {&res.Tensor("add_out")}); } - std::string pattern_name() const override { - return "FusedWeightOnlyLinearPattern"; - } + std::string name() const override { return "FusedWeightOnlyLinearPattern"; } }; class FusedWeightOnlyLinearPass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc index 0eeff6c6598e3..cabd7a7274cb7 100644 --- a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc +++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc @@ -72,7 +72,7 @@ class MatmulScaleFusePattern : public paddle::drr::DrrPatternBase { {&res.Tensor("scale_out")}); } - std::string pattern_name() const override { return "MatmulScaleFusePattern"; } + 
std::string name() const override { return "MatmulScaleFusePattern"; } }; class MatmulScaleFusePass : public pir::PatternRewritePass { diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index cde25aaefe4ad..53210443eda4e 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -47,9 +47,7 @@ class RemoveUselessScalePattern : public paddle::drr::DrrPatternBase { res.Tensor("scale_out").Assign(res.Tensor("x")); } - std::string pattern_name() const override { - return "RemoveUselessScalePattern"; - } + std::string name() const override { return "RemoveUselessScalePattern"; } }; class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { @@ -121,9 +119,7 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { {&res.Tensor("scale_2_out")}); } - std::string pattern_name() const override { - return "RemoveRedundentScalePattern"; - } + std::string name() const override { return "RemoveRedundentScalePattern"; } }; class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { @@ -136,9 +132,7 @@ class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret").Assign(res.Tensor("arg0")); } - std::string pattern_name() const override { - return "RemoveUselessCastPattern"; - } + std::string name() const override { return "RemoveUselessCastPattern"; } }; class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase { @@ -158,9 +152,7 @@ class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase { res.Tensor("out").Assign(res.Tensor("x")); } - std::string pattern_name() const override { - return "RemoveUselessConcatPattern"; - } + std::string name() const override { return "RemoveUselessConcatPattern"; } }; class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { @@ -175,9 +167,7 @@ class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } - std::string pattern_name() const override { - return "RemoveRedundentCastPattern"; - } + std::string name() const override { return "RemoveRedundentCastPattern"; } }; class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { @@ -208,7 +198,7 @@ class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); } - std::string pattern_name() const override { + std::string name() const override { return "RemoveRedundentTransposePattern"; } }; diff --git a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc index fcffa97d4084d..342311cf76b77 100644 --- a/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_same_type_binding_test.cc @@ -179,9 +179,7 @@ class SameTypeBindingTestPattern : public paddle::drr::DrrPatternBase { res.Tensor("output6") = full_6(); } - std::string pattern_name() const override { - return "SameTypeBindingTestPattern"; - } + std::string name() const override { return "SameTypeBindingTestPattern"; } }; void BuildProgram(pir::Builder &builder) { // NOLINT diff --git a/test/cpp/pir/pattern_rewrite/drr_test.cc b/test/cpp/pir/pattern_rewrite/drr_test.cc index bfb59e39fe1b8..6efe87d8ca70c 100644 --- a/test/cpp/pir/pattern_rewrite/drr_test.cc +++ b/test/cpp/pir/pattern_rewrite/drr_test.cc @@ -42,9 +42,7 @@ class 
RemoveRedundentReshapePattern : public paddle::drr::DrrPatternBase { {&res.Tensor("ret"), &res.Tensor("xshape_1")}); } - std::string pattern_name() const override { - return "RemoveRedundentReshapePattern"; - } + std::string name() const override { return "RemoveRedundentReshapePattern"; } }; class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { @@ -82,9 +80,7 @@ class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = full2(); } - std::string pattern_name() const override { - return "FoldExpandToConstantPattern"; - } + std::string name() const override { return "FoldExpandToConstantPattern"; } }; class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { @@ -115,7 +111,7 @@ class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose")); } - std::string pattern_name() const override { + std::string name() const override { return "RemoveRedundentTransposePattern"; } }; @@ -132,9 +128,7 @@ class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } - std::string pattern_name() const override { - return "RemoveRedundentCastPattern"; - } + std::string name() const override { return "RemoveRedundentCastPattern"; } }; class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { @@ -147,9 +141,7 @@ class RemoveUselessCastPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret").Assign(res.Tensor("arg0")); } - std::string pattern_name() const override { - return "RemoveUselessCastPattern"; - } + std::string name() const override { return "RemoveUselessCastPattern"; } }; void BuildProgram(pir::Builder &builder) { // NOLINT From 7c12b9f6e76024e7541d948a6b1c1ee5e32a8507 Mon Sep 17 00:00:00 2001 From: yuanlehome Date: Mon, 15 Jan 2024 07:03:26 +0000 Subject: [PATCH 5/5] fix --- .../dialect/operator/transforms/pd_to_cinn_pass.cc | 10 +++++----- paddle/fluid/pir/drr/README.md | 4 +--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc index 3c503be702410..9a7db9b7a0a1f 100644 --- a/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc +++ b/paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc @@ -55,7 +55,7 @@ class SumOpPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0")); } - std::string pattern_name() const override { return "SumOpPattern"; } + std::string name() const override { return "SumOpPattern"; } }; class MaxOpPattern : public paddle::drr::DrrPatternBase { @@ -82,7 +82,7 @@ class MaxOpPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } - std::string pattern_name() const override { return "MaxOpPattern"; } + std::string name() const override { return "MaxOpPattern"; } }; class MinOpPattern : public paddle::drr::DrrPatternBase { @@ -109,7 +109,7 @@ class MinOpPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0")); } - std::string pattern_name() const override { return "MinOpPattern"; } + std::string name() const override { return "MinOpPattern"; } }; class ProdOpPattern : public paddle::drr::DrrPatternBase { @@ -136,7 +136,7 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = 
cinn_reduce_max(res.Tensor("arg0")); } - std::string pattern_name() const override { return "ProdOpPattern"; } + std::string name() const override { return "ProdOpPattern"; } }; class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> { @@ -640,7 +640,7 @@ class UniformOpPattern : public paddle::drr::DrrPatternBase { res.Tensor("ret") = cinn_uniform(); } - std::string pattern_name() const override { return "ProdOpPattern"; } + std::string name() const override { return "UniformOpPattern"; } }; PdOpToCinnOpPass::PdOpToCinnOpPass() diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 6d320e61fb857..9b9790538d48a 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -32,9 +32,7 @@ class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0")); } - std::string pattern_name() const override { - return "RemoveRedundentCastPattern"; - } + std::string name() const override { return "RemoveRedundentCastPattern"; } }; ~~~
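
Taken together, the series leaves two public surfaces: a DRR pattern overrides `operator()` plus the renamed `name()`/`benefit()` hooks, and a pass hands its patterns to `pir::PatternRewritePass::InitializePatterns` instead of writing `Initialize()`/`Run()` by hand. The sketch below combines the two; it is a minimal illustration written against the post-series API, not code from these patches — `FoldCastCastPattern`, `FoldCastCastPass`, and the pass name/opt level are made up for illustration, with the pattern body mirroring the README's redundant-cast example.

~~~
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/pir/pass/pass.h"

namespace {

// cast(cast(x, dtype1), dtype2) -> cast(x, dtype2)
class FoldCastCastPattern : public paddle::drr::DrrPatternBase {
 public:
  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern pat = ctx->SourcePattern();
    pat.Tensor("tmp") = pat.Op(
        "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0"));
    pat.Tensor("ret") = pat.Op(
        "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp"));

    paddle::drr::ResultPattern res = pat.ResultPattern();
    res.Tensor("ret") = res.Op(
        "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
  }

  // Hooks renamed in PATCH 4 (previously pattern_name()/pattern_benefit()).
  std::string name() const override { return "FoldCastCastPattern"; }
  uint32_t benefit() const override { return 1; }
};

// PatternRewritePass supplies Initialize()/Run(); subclasses only provide
// the pattern set, as replace_fetch_with_shadow_output_pass does above.
class FoldCastCastPass : public pir::PatternRewritePass {
 public:
  FoldCastCastPass() : pir::PatternRewritePass("fold_cast_cast_pass", 2) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    // DrrPatternBase::Build() lowers the declarative pattern into a
    // DrrRewritePattern that the greedy rewrite driver can apply.
    ps.Add(FoldCastCastPattern().Build(context));
    return ps;
  }
};

}  // namespace
~~~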