From a72ffa98d592619e497fbf8ee923d2d2a512e4f7 Mon Sep 17 00:00:00 2001 From: co63oc Date: Thu, 29 Feb 2024 08:21:12 +0800 Subject: [PATCH] Fix --- paddle/fluid/pir/drr/README.md | 4 ++-- paddle/fluid/pir/drr/README_cn.md | 4 ++-- paddle/fluid/pir/drr/src/attr_type_uilts.h | 6 ++--- .../fluid/pir/drr/src/ir_operation_factory.cc | 24 +++++++++---------- paddle/fluid/pir/drr/src/pattern_graph.cc | 24 +++++++++---------- paddle/fluid/pir/drr/src/pattern_graph.h | 2 +- paddle/fluid/pir/drr/src/rewrite_pattern.cc | 8 +++---- .../pir/transforms/identity_op_clean_pass.cc | 22 ++++++++--------- 8 files changed, 47 insertions(+), 47 deletions(-) diff --git a/paddle/fluid/pir/drr/README.md b/paddle/fluid/pir/drr/README.md index 1c5de89780c6f..d9b435160c41d 100644 --- a/paddle/fluid/pir/drr/README.md +++ b/paddle/fluid/pir/drr/README.md @@ -9,9 +9,9 @@ DRR can reduce the development cost of PASS, allowing developers to focus on pro Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows: ~~~ c++ // 1. Inherit class from DrPatternBase -class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentCastPattern"; } + std::string name() const override { return "RemoveRedundantCastPattern"; } // 2. Overload operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { diff --git a/paddle/fluid/pir/drr/README_cn.md b/paddle/fluid/pir/drr/README_cn.md index e621e7112ac30..c01b21febeda3 100644 --- a/paddle/fluid/pir/drr/README_cn.md +++ b/paddle/fluid/pir/drr/README_cn.md @@ -9,9 +9,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P 以消除冗余 CastOp 的 PASS 为例,使用 DRR 的代码开发示例如下: ~~~ c++ // 1. 继承 DrrPatternBase 类 -class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentCastPattern"; } + std::string name() const override { return "RemoveRedundantCastPattern"; } // 2. 重载 operator() void operator()(paddle::drr::DrrPatternContext *ctx) const override { diff --git a/paddle/fluid/pir/drr/src/attr_type_uilts.h b/paddle/fluid/pir/drr/src/attr_type_uilts.h index 02f5a4defc155..a48ed382a7d19 100644 --- a/paddle/fluid/pir/drr/src/attr_type_uilts.h +++ b/paddle/fluid/pir/drr/src/attr_type_uilts.h @@ -48,7 +48,7 @@ PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray, paddle::dialect::IntArrayAttribute); template -struct IrAttrbuteCreator { +struct IrAttributeCreator { typename CppTypeToIrAttribute::type operator()(T obj) const { return CppTypeToIrAttribute::type::template get( pir::IrContext::Instance(), obj); @@ -56,7 +56,7 @@ struct IrAttrbuteCreator { }; template <> -struct IrAttrbuteCreator> { +struct IrAttributeCreator> { pir::ArrayAttribute operator()(std::vector obj) const { std::vector attr_vec; attr_vec.reserve(obj.size()); @@ -69,7 +69,7 @@ struct IrAttrbuteCreator> { }; template <> -struct IrAttrbuteCreator> { +struct IrAttributeCreator> { pir::ArrayAttribute operator()(std::vector obj) const { std::vector attr_vec; attr_vec.reserve(obj.size()); diff --git a/paddle/fluid/pir/drr/src/ir_operation_factory.cc b/paddle/fluid/pir/drr/src/ir_operation_factory.cc index f792ccbdaff92..cc92c223feacf 100644 --- a/paddle/fluid/pir/drr/src/ir_operation_factory.cc +++ b/paddle/fluid/pir/drr/src/ir_operation_factory.cc @@ -65,33 +65,33 @@ void OperationFactory::RegisterManualOpCreator() { pir::Attribute CreateIrAttribute(const std::any& obj) { if (obj.type() == typeid(bool)) { - return IrAttrbuteCreator()(std::any_cast(obj)); + return IrAttributeCreator()(std::any_cast(obj)); } else if (obj.type() == typeid(int32_t)) { - return IrAttrbuteCreator()(std::any_cast(obj)); + return IrAttributeCreator()(std::any_cast(obj)); } else if (obj.type() == typeid(int64_t)) { - return IrAttrbuteCreator()(std::any_cast(obj)); + return IrAttributeCreator()(std::any_cast(obj)); } else if (obj.type() == typeid(float)) { - return IrAttrbuteCreator()(std::any_cast(obj)); + return IrAttributeCreator()(std::any_cast(obj)); } else if (obj.type() == typeid(std::string)) { - return IrAttrbuteCreator()(std::any_cast(obj)); + return IrAttributeCreator()(std::any_cast(obj)); } else if (obj.type() == typeid(const char*)) { - return IrAttrbuteCreator()(std::any_cast(obj)); + return IrAttributeCreator()(std::any_cast(obj)); } else if (obj.type() == typeid(phi::DataType)) { - return IrAttrbuteCreator()( + return IrAttributeCreator()( std::any_cast(obj)); } else if (obj.type() == typeid(phi::Place)) { - return IrAttrbuteCreator()(std::any_cast(obj)); + return IrAttributeCreator()(std::any_cast(obj)); } else if (obj.type() == typeid(std::vector)) { - return IrAttrbuteCreator>()( + return IrAttributeCreator>()( std::any_cast>(obj)); } else if (obj.type() == typeid(std::vector)) { - return IrAttrbuteCreator>()( + return IrAttributeCreator>()( std::any_cast>(obj)); } else if (obj.type() == typeid(std::vector)) { - return IrAttrbuteCreator>()( + return IrAttributeCreator>()( std::any_cast>(obj)); } else if (obj.type() == typeid(phi::IntArray)) { - return IrAttrbuteCreator()( + return IrAttributeCreator()( std::any_cast(obj)); } else { PADDLE_THROW( diff --git a/paddle/fluid/pir/drr/src/pattern_graph.cc b/paddle/fluid/pir/drr/src/pattern_graph.cc index a8c72a064d0b8..be57150ed8ffd 100644 --- a/paddle/fluid/pir/drr/src/pattern_graph.cc +++ b/paddle/fluid/pir/drr/src/pattern_graph.cc @@ -147,8 +147,8 @@ void GraphTopo::WalkGraphNodesTopoOrder( const std::unordered_set &inputs_tensor = graph_->input_tensors(); const std::unordered_map> - &id2owned_tensor = graph_->id2owend_tensor(); - const std::vector> &owend_opcall = + &id2owned_tensor = graph_->id2owned_tensor(); + const std::vector> &owned_opcall = graph_->owned_op_call(); std::queue opcall_queue; @@ -156,7 +156,7 @@ void GraphTopo::WalkGraphNodesTopoOrder( opcall_dependent; // init opcall_dependent - for (const std::shared_ptr &opcall_sptr : owend_opcall) { + for (const std::shared_ptr &opcall_sptr : owned_opcall) { if (opcall_sptr.get()->inputs().empty()) { // opcall inputs is empty opcall_queue.push(opcall_sptr.get()); } else { @@ -174,11 +174,11 @@ void GraphTopo::WalkGraphNodesTopoOrder( "The input tensor [%s] must exists " "in pattern graph to be obtained.", tensor_name)); - for (const auto &tensor_comsumer : + for (const auto &tensor_consumer : id2owned_tensor.at(tensor_name).get()->consumers()) { - opcall_dependent[tensor_comsumer].erase(tensor_name); - if (opcall_dependent[tensor_comsumer].empty()) { - opcall_queue.push(tensor_comsumer); + opcall_dependent[tensor_consumer].erase(tensor_name); + if (opcall_dependent[tensor_consumer].empty()) { + opcall_queue.push(tensor_consumer); } } } @@ -190,10 +190,10 @@ void GraphTopo::WalkGraphNodesTopoOrder( // update opcall_dependent for (const auto &output_tensor : opcall->outputs()) { - for (const auto &tensor_comsumer : output_tensor->consumers()) { - opcall_dependent[tensor_comsumer].erase(output_tensor->name()); - if (opcall_dependent[tensor_comsumer].empty()) { - opcall_queue.push(tensor_comsumer); + for (const auto &tensor_consumer : output_tensor->consumers()) { + opcall_dependent[tensor_consumer].erase(output_tensor->name()); + if (opcall_dependent[tensor_consumer].empty()) { + opcall_queue.push(tensor_consumer); } } } @@ -202,7 +202,7 @@ void GraphTopo::WalkGraphNodesTopoOrder( std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) { os << "\nAll Tensors:\n"; - for (const auto &kv : pattern_graph.id2owend_tensor()) { + for (const auto &kv : pattern_graph.id2owned_tensor()) { os << " " << kv.first; } os << "\n\n"; diff --git a/paddle/fluid/pir/drr/src/pattern_graph.h b/paddle/fluid/pir/drr/src/pattern_graph.h index e5cd74b2fa217..7243c99bfc853 100644 --- a/paddle/fluid/pir/drr/src/pattern_graph.h +++ b/paddle/fluid/pir/drr/src/pattern_graph.h @@ -57,7 +57,7 @@ class PatternGraph { } const std::unordered_map>& - id2owend_tensor() const { + id2owned_tensor() const { return id2owned_tensor_; } diff --git a/paddle/fluid/pir/drr/src/rewrite_pattern.cc b/paddle/fluid/pir/drr/src/rewrite_pattern.cc index 68a7b14f81a3e..13d61a304f097 100644 --- a/paddle/fluid/pir/drr/src/rewrite_pattern.cc +++ b/paddle/fluid/pir/drr/src/rewrite_pattern.cc @@ -58,7 +58,7 @@ bool DrrRewritePattern::MatchAndRewrite( if (PatternGraphMatch(op, src_match_ctx.get())) { VLOG(4) << "DRR pattern (" << pattern_name_ << ") is matched in program."; PatternGraphRewrite(*src_match_ctx, rewriter); - VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewrited in program."; + VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewritten in program."; return true; } return false; @@ -414,13 +414,13 @@ MatchContextImpl DrrRewritePattern::CreateOperations( // add input tensors info for res_match_ctx for (const auto& in_tensor : result_pattern_graph.input_tensors()) { PADDLE_ENFORCE_NE( - result_pattern_graph.id2owend_tensor().count(in_tensor), + result_pattern_graph.id2owned_tensor().count(in_tensor), 0, phi::errors::NotFound("Not found the input tensor." "Drr input tensor [%s] must exist in the result " "pattern graph to be obtained.", in_tensor)); - if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) { + if (!result_pattern_graph.id2owned_tensor().at(in_tensor)->is_none()) { res_match_ctx.BindIrValue(in_tensor, src_match_ctx.GetIrValue(in_tensor)); } } @@ -508,7 +508,7 @@ void DrrRewritePattern::ReplaceOutputTensor( const MatchContextImpl& res_match_ctx, pir::PatternRewriter& rewriter) const { // NOLINT for (const auto& output_name : result_pattern_graph_->output_tensors()) { - if (source_pattern_graph_->id2owend_tensor().count(output_name)) { + if (source_pattern_graph_->id2owned_tensor().count(output_name)) { const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name); const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name); rewriter.ReplaceAllUsesWith(src_ir_tensor, res_ir_tensor); diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc index cf27800512b0b..f312ae2f27bd3 100644 --- a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc +++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc @@ -53,9 +53,9 @@ class RemoveUselessScalePattern : public paddle::drr::DrrPatternBase { } }; -class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantScalePattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentScalePattern"; } + std::string name() const override { return "RemoveRedundantScalePattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { paddle::drr::SourcePattern pat = ctx->SourcePattern(); @@ -83,7 +83,7 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { paddle::drr::ResultPattern res = pat.ResultPattern(); - const auto &bais_attr = res.ComputeAttr( + const auto &bias_attr = res.ComputeAttr( [](const paddle::drr::MatchContext &match_ctx) -> float { float res_bias_1 = 0.f; float res_bias_2 = 0.f; @@ -115,7 +115,7 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase { {"place", pat.Attr("place_1")}}); const auto &scale_op_res = res.Op("pd_op.scale", - {{"bias", bais_attr}, {"bias_after_scale", res.BoolAttr(true)}}); + {{"bias", bias_attr}, {"bias_after_scale", res.BoolAttr(true)}}); scale_op_res({&res.Tensor("x"), &full_op_res()}, {&res.Tensor("scale_2_out")}); } @@ -154,9 +154,9 @@ class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase { } }; -class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase { public: - std::string name() const override { return "RemoveRedundentCastPattern"; } + std::string name() const override { return "RemoveRedundantCastPattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { auto pat = ctx->SourcePattern(); @@ -245,10 +245,10 @@ class ReplaceDropoutWithScalePattern : public paddle::drr::DrrPatternBase { } }; -class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase { +class RemoveRedundantTransposePattern : public paddle::drr::DrrPatternBase { public: std::string name() const override { - return "RemoveRedundentTransposePattern"; + return "RemoveRedundantTransposePattern"; } void operator()(paddle::drr::DrrPatternContext *ctx) const override { @@ -286,13 +286,13 @@ class IdentityOpCleanPass : public pir::PatternRewritePass { pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { pir::RewritePatternSet ps(context); ps.Add(paddle::drr::Create(context)); - ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); - ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); ps.Add(paddle::drr::Create(context)); - ps.Add(paddle::drr::Create(context)); + ps.Add(paddle::drr::Create(context)); return ps; } };