Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR][CUSTOM PASS] Part1: Reconstruct drr #60783

Merged
merged 6 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/dialect/shape/utils/shape_utils.h"
#include "paddle/pir/pass/pass.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
Expand Down
23 changes: 16 additions & 7 deletions paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/core/builtin_op.h"
#include "paddle/pir/pass/pass.h"
Expand All @@ -31,7 +30,7 @@ namespace cinn {
namespace dialect {
namespace ir {

class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> {
class SumOpPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
Expand All @@ -55,9 +54,11 @@ class SumOpPattern : public paddle::drr::DrrPatternBase<SumOpPattern> {
{"keep_dim", pattern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_sum(res.Tensor("arg0"));
}

std::string name() const override { return "SumOpPattern"; }
};

class MaxOpPattern : public paddle::drr::DrrPatternBase<MaxOpPattern> {
class MaxOpPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
Expand All @@ -80,9 +81,11 @@ class MaxOpPattern : public paddle::drr::DrrPatternBase<MaxOpPattern> {
{"keep_dim", pattern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
}

std::string name() const override { return "MaxOpPattern"; }
};

class MinOpPattern : public paddle::drr::DrrPatternBase<MinOpPattern> {
class MinOpPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
Expand All @@ -105,9 +108,11 @@ class MinOpPattern : public paddle::drr::DrrPatternBase<MinOpPattern> {
{"keep_dim", pattern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
}

std::string name() const override { return "MinOpPattern"; }
};

class ProdOpPattern : public paddle::drr::DrrPatternBase<ProdOpPattern> {
class ProdOpPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
Expand All @@ -130,6 +135,8 @@ class ProdOpPattern : public paddle::drr::DrrPatternBase<ProdOpPattern> {
{"keep_dim", pattern.Attr("keep_dim")}});
res.Tensor("ret") = cinn_reduce_max(res.Tensor("arg0"));
}

std::string name() const override { return "ProdOpPattern"; }
};

class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> {
Expand Down Expand Up @@ -586,7 +593,7 @@ class ExpandOpPattern
}
};

class UniformOpPattern : public paddle::drr::DrrPatternBase<UniformOpPattern> {
class UniformOpPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Source Pattern
Expand Down Expand Up @@ -632,6 +639,8 @@ class UniformOpPattern : public paddle::drr::DrrPatternBase<UniformOpPattern> {
{"diag_val", pattern.Attr("min_value")}});
res.Tensor("ret") = cinn_uniform();
}

std::string name() const override { return "ProdOpPattern"; }
};

PdOpToCinnOpPass::PdOpToCinnOpPass()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/dialect/shape/utils/dim_expr.h"
#include "paddle/pir/pass/pass.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/drr/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
file(GLOB DRR_SRCS "*.cc" "api/*.cc")
file(GLOB DRR_SRCS "*.cc" "include/*.cc")

set(op_creator_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_creator_drr_gen.py
Expand Down
16 changes: 10 additions & 6 deletions paddle/fluid/pir/drr/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ DRR can reduce the development cost of PASS, allowing developers to focus on pro

Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows:
~~~ c++
// 1. Inherit specialized template class from DrPatternBase
class RemoveRedundentCastPattern
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
// 1. Inherit class from DrPatternBase
class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase {
// 2. Overload operator()
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 3. Define a SourcePattern containing two consecutive CastOps using Op, Tensor, and Attribute
Expand All @@ -32,6 +31,8 @@ class RemoveRedundentCastPattern
res.Op(paddle::dialect::CastOp::name(),
{{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
}

std::string name() const override { return "RemoveRedundentCastPattern"; }
};
~~~

Expand Down Expand Up @@ -165,7 +166,7 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
## 3 Example
Example 1: Matmul + Add -> FusedGemmEpilogue
~~~ c++
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
class FusedLinearPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Define SourcePattern
Expand Down Expand Up @@ -193,13 +194,14 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
{&res.Tensor("out")});
}

std::string name() const override { return "FusedLinearPattern"; }
};
~~~

Example 2: Full + Expand -> Full
~~~ c++
class FoldExpandToConstantPattern
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// Define SourcePattern
Expand All @@ -226,5 +228,7 @@ class FoldExpandToConstantPattern
{"place", pat.Attr("place_1")}});
res.Tensor("ret") = full2();
}

std::string name() const override { return "FoldExpandToConstantPattern"; }
};
~~~
16 changes: 10 additions & 6 deletions paddle/fluid/pir/drr/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P

以消除冗余 CastOp 的 PASS 为例,使用 DRR 的代码开发示例如下:
~~~ c++
// 1. 继承 DrrPatternBase 的特化模板类
class RemoveRedundentCastPattern
: public paddle::drr::DrrPatternBase<RemoveRedundentCastPattern> {
// 1. 继承 DrrPatternBase 类
class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase {
// 2. 重载 operator()
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 3. 使用 Op、Tensor 和 Attribute 定义一个包含两个连续 CastOp 的 SourcePattern
Expand All @@ -32,6 +31,8 @@ class RemoveRedundentCastPattern
res.Op(paddle::dialect::CastOp::name(),
{{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
}

std::string name() const override { return "RemoveRedundentCastPattern"; }
};
~~~

Expand Down Expand Up @@ -168,7 +169,7 @@ Attribute Attr(const AttrComputeFunc& attr_compute_func) const</pre></td>
## 3 使用示例
Example 1: Matmul + Add -> FusedGemmEpilogue
~~~ c++
class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern> {
class FusedLinearPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 定义 Source Pattern
Expand Down Expand Up @@ -196,13 +197,14 @@ class FusedLinearPattern : public paddle::drr::DrrPatternBase<FusedLinearPattern
{&res.Tensor("x"), &res.Tensor("w"), &res.Tensor("bias")},
{&res.Tensor("out")});
}

std::string name() const override { return "FusedLinearPattern"; }
};
~~~

Example 2: Full + Expand -> Full
~~~ c++
class FoldExpandToConstantPattern
: public paddle::drr::DrrPatternBase<FoldExpandToConstantPattern> {
class FoldExpandToConstantPattern : public paddle::drr::DrrPatternBase {
public:
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
// 定义 Source Pattern
Expand All @@ -229,5 +231,7 @@ class FoldExpandToConstantPattern
{"place", pat.Attr("place_1")}});
res.Tensor("ret") = full2();
}

std::string name() const override { return "FoldExpandToConstantPattern"; }
};
~~~
36 changes: 0 additions & 36 deletions paddle/fluid/pir/drr/api/tensor_interface.cc

This file was deleted.

63 changes: 0 additions & 63 deletions paddle/fluid/pir/drr/api/tensor_interface.h

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
#include <memory>
#include <string>

#include "paddle/fluid/pir/drr/api/tensor_interface.h"
#include "paddle/fluid/pir/drr/ir_operation.h"
namespace pir {
class Value;
}

namespace paddle {
namespace drr {
Expand All @@ -30,7 +31,7 @@ class MatchContext final {
public:
MatchContext(std::shared_ptr<const MatchContextImpl> impl);

const TensorInterface& Tensor(const std::string& tensor_name) const;
const pir::Value& Tensor(const std::string& tensor_name) const;

template <typename T>
T Attr(const std::string& attr_name) const;
Expand Down
Loading