From 3d5fbe3be00dfcf4c61de363237783b65a94de5a Mon Sep 17 00:00:00 2001
From: Yuanle Liu
Date: Tue, 12 Dec 2023 22:20:02 +0800
Subject: [PATCH] [PIR] add identity_op_clean_pass and matmul_scale_fuse_pass
 (#59840)

* add identity_op_clean_pass

* update

* Rename test_identity_op_clean_pass.py to test_pir_identity_op_clean_pass.py

* Rename test_matmul_scale_fuse_pass.py to test_pir_matmul_scale_fuse_pass.py

* update

* update drr

* new_executor_sequential_run
---
 .../fluid/inference/api/analysis_predictor.cc |   4 +
 paddle/fluid/pir/drr/attr_type_uilts.h        |  25 ++
 paddle/fluid/pir/drr/drr_rewrite_pattern.cc   |  16 +-
 paddle/fluid/pir/drr/ir_operation_factory.cc  |  14 +
 .../fusion/matmul_scale_fuse_pass.cc          | 102 +++++++
 .../fusion/matmul_scale_fuse_pass.h           |  26 ++
 .../pir/transforms/identity_op_clean_pass.cc  | 233 ++++++++++++++++
 .../pir/transforms/identity_op_clean_pass.h   |  26 ++
 paddle/fluid/pybind/pir.cc                    |   2 +
 test/ir/pir/fused_pass/pass_test.py           |   1 +
 .../test_pir_matmul_scale_fuse_pass.py        |  93 +++++++
 .../ir/pir/test_pir_identity_op_clean_pass.py | 254 ++++++++++++++++++
 12 files changed, 790 insertions(+), 6 deletions(-)
 create mode 100644 paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc
 create mode 100644 paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h
 create mode 100644 paddle/fluid/pir/transforms/identity_op_clean_pass.cc
 create mode 100644 paddle/fluid/pir/transforms/identity_op_clean_pass.h
 create mode 100644 test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py
 create mode 100644 test/ir/pir/test_pir_identity_op_clean_pass.py

diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 476c78638c47f..5cddb1db14989 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1734,6 +1734,7 @@ void AnalysisPredictor::PrepareArgument() {
       argument_->SetEnableIrOptim(true);
       pass_builder->ClearPasses();
       pass_builder->AppendPass("auto_mixed_precision_pass");
+      pass_builder->AppendPass("inplace_op_var_pass");
       LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
                    "optimization.";
     } else {
@@ -1918,6 +1919,9 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
     if (std::getenv("FLAGS_initial_cpu_memory_in_mb") == nullptr) {
       SetGflag("initial_cpu_memory_in_mb", "0");
     }
+    if (std::getenv("FLAGS_new_executor_sequential_run") == nullptr) {
+      SetGflag("new_executor_sequential_run", "1");
+    }
   });
 
   if (config.thread_local_stream_enabled() &&
diff --git a/paddle/fluid/pir/drr/attr_type_uilts.h b/paddle/fluid/pir/drr/attr_type_uilts.h
index 28b26ba26a2a1..4043aa3c64383 100644
--- a/paddle/fluid/pir/drr/attr_type_uilts.h
+++ b/paddle/fluid/pir/drr/attr_type_uilts.h
@@ -43,6 +43,7 @@ PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute);
 PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int32_t>, pir::ArrayAttribute);
 PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int64_t>,
                                    paddle::dialect::IntArrayAttribute);
+PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<float>, pir::ArrayAttribute);
 PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray,
                                    paddle::dialect::IntArrayAttribute);
 
@@ -66,6 +67,18 @@ struct IrAttrbuteCreator<std::vector<int32_t>> {
   }
 };
 
+template <>
+struct IrAttrbuteCreator<std::vector<float>> {
+  pir::ArrayAttribute operator()(std::vector<float> obj) const {
+    std::vector<pir::Attribute> attr_vec;
+    attr_vec.reserve(obj.size());
+    for (float x : obj) {
+      attr_vec.push_back(FloatAttribute::get(pir::IrContext::Instance(), x));
+    }
+    return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec);
+  }
+};
+
 template <typename T>
 struct IrAttrTypeCast {
   static T To(const pir::Attribute& attr) {
@@ -114,5 +127,17 @@ struct IrAttrTypeCast<std::vector<int64_t>> {
   }
 };
 
+template <>
+struct IrAttrTypeCast<std::vector<float>> {
+  static std::vector<float> To(const pir::Attribute& attr) {
+    std::vector<float> result;
+    auto array_attr = attr.dyn_cast<pir::ArrayAttribute>();
+    for (size_t i = 0; i < array_attr.size(); i++) {
+      result.push_back(
+          array_attr.at(i).dyn_cast<pir::FloatAttribute>().data());
+    }
+    return result;
+  }
+};
+
 }  // namespace drr
 }  // namespace pir
diff --git a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc
index 91be95e788805..9be191d2d6c43 100644
--- a/paddle/fluid/pir/drr/drr_rewrite_pattern.cc
+++ b/paddle/fluid/pir/drr/drr_rewrite_pattern.cc
@@ -43,7 +43,7 @@ bool DrrRewritePattern::PatternGraphMatch(
   std::vector<std::string> drr_output_sequence;
   std::vector<Operation*> ir_output_sequence;
   std::unordered_map<std::string, Operation*> output_op_map;
-  for (auto pair : bind_map) {
+  for (const auto& pair : bind_map) {
     drr_output_sequence.push_back(pair.first);
   }
   // using dfs to obtain the arrangement of all candidate ir ops
@@ -396,10 +396,11 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
       Value ir_val = res_match_ctx.GetIrValue(input->name()).get();
       if (ir_val) {
         Operation* ir_input_op = ir_val.dyn_cast<pir::OpResult>().owner();
-        if (max_input_op_index < op_2_temp_program_index[ir_input_op]) {
-          max_input_op_index = op_2_temp_program_index[ir_input_op];
+        if (max_input_op_index < op_2_temp_program_index.at(ir_input_op)) {
+          max_input_op_index = op_2_temp_program_index.at(ir_input_op);
           max_index_op = ir_input_op;
-        } else if (max_input_op_index == op_2_temp_program_index[ir_input_op]) {
+        } else if (max_input_op_index ==
+                   op_2_temp_program_index.at(ir_input_op)) {
           const auto& ops_vec = temp_program[max_input_op_index];
           for (auto it = ops_vec.begin(); it != ops_vec.end(); it++) {
             if (*it == max_index_op) {
@@ -430,6 +431,9 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
     Operation* new_op =
         CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
     op_2_temp_program_index[new_op] = max_input_op_index + 1;
+    if (max_input_op_index + 1 >= temp_program.size()) {
+      temp_program.push_back({});
+    }
     temp_program[max_input_op_index + 1].push_back(new_op);
   });
 
@@ -471,13 +475,13 @@ void DrrRewritePattern::DeleteSourcePatternOp(
     const ResultPatternGraph& result_pattern_graph,
     const MatchContextImpl& src_match_ctx,
     pir::PatternRewriter& rewriter) const {  // NOLINT
   std::queue<Operation*> delete_ops_que;
   std::unordered_set<Operation*> delete_ops_set;
   GraphTopo graph_topo_visit(&source_pattern_graph);
   graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) {
     Operation* op = src_match_ctx.Operation(&op_call).get();
-    if (op->use_empty()) {
+    VLOG(5) << "DRR delete op: " << op->name() << " pointer: " << op;
+    if (delete_ops_set.count(op) == 0 && op->use_empty()) {
       delete_ops_que.push(op);
       delete_ops_set.insert(op);
     }
   });
diff --git a/paddle/fluid/pir/drr/ir_operation_factory.cc b/paddle/fluid/pir/drr/ir_operation_factory.cc
index f1a565090c92e..6644026fabde0 100644
--- a/paddle/fluid/pir/drr/ir_operation_factory.cc
+++ b/paddle/fluid/pir/drr/ir_operation_factory.cc
@@ -57,6 +57,17 @@ void OperationFactory::RegisterManualOpCreator() {
          pir::PatternRewriter& rewriter) {
         return rewriter.Build<pir::CombineOp>(inputs);
       });
+  RegisterOperationCreator(
+      "pd_op.scale",
+      [](const std::vector<Value>& inputs,
+         const pir::AttributeMap& attrs,
+         pir::PatternRewriter& rewriter) {
+        return rewriter.Build<paddle::dialect::ScaleOp>(
+            inputs[0].dyn_cast<pir::OpResult>(),
+            inputs[1].dyn_cast<pir::OpResult>(),
+            attrs.at("bias").dyn_cast<pir::FloatAttribute>().data(),
+            attrs.at("bias_after_scale").dyn_cast<pir::BoolAttribute>().data());
+      });
 }
 
 static pir::Attribute CreateIrAttribute(const std::any& obj) {
@@ -83,6 +94,9 @@ static pir::Attribute CreateIrAttribute(const std::any& obj) {
   } else if (obj.type() == typeid(std::vector<int64_t>)) {
     return IrAttrbuteCreator<std::vector<int64_t>>()(
         std::any_cast<std::vector<int64_t>>(obj));
+  } else if (obj.type() == typeid(std::vector<float>)) {
+    return IrAttrbuteCreator<std::vector<float>>()(
+        std::any_cast<std::vector<float>>(obj));
   } else if (obj.type() == typeid(phi::IntArray)) {
     return IrAttrbuteCreator<phi::IntArray>()(
         std::any_cast<phi::IntArray>(obj));
diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc
new file mode 100644
index 0000000000000..627c1cd516cc8
--- /dev/null
+++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc
@@ -0,0 +1,102 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
+#include "paddle/fluid/pir/transforms/transform_general_functions.h"
+
+#include "paddle/common/ddim.h"
+
+#include "paddle/pir/pass/pass.h"
+#include "paddle/pir/pass/pass_registry.h"
+#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
+
+namespace {
+
+class MatmulScaleFusePattern
+    : public pir::drr::DrrPatternBase<MatmulScaleFusePattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    pir::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(),
+                                   {{"transpose_x", pat.Attr("transpose_x")},
+                                    {"transpose_y", pat.Attr("transpose_y")}});
+
+    matmul_op({&pat.Tensor("x"), &pat.Tensor("y")},
+              {&pat.Tensor("matmul_out")});
+    const auto &full_op = pat.Op(paddle::dialect::FullOp::name(),
+                                 {{"shape", pat.Attr("shape")},
+                                  {"value", pat.Attr("value")},
+                                  {"dtype", pat.Attr("dtype")},
+                                  {"place", pat.Attr("place")}});
+    const auto &scale_op =
+        pat.Op(paddle::dialect::ScaleOp::name(),
+               {{"bias", pat.Attr("bias")},
+                {"bias_after_scale", pat.Attr("bias_after_scale")}});
+    scale_op({&pat.Tensor("matmul_out"), &full_op()},
+             {&pat.Tensor("scale_out")});
+
+    pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
+      return std::abs(match_ctx.Attr<float>("bias")) <= 1e-6;
+    });
+
+    pir::drr::ResultPattern res = pat.ResultPattern();
+    const auto &full_op_res = res.Op(paddle::dialect::FullOp::name(),
+                                     {{"shape", pat.Attr("shape")},
+                                      {"value", pat.Attr("value")},
+                                      {"dtype", pat.Attr("dtype")},
+                                      {"place", pat.Attr("place")}});
+    const auto &scale_op_res =
+        res.Op(paddle::dialect::ScaleOp::name(),
+               {{"bias",
+                 res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
+                   return 0.0;
+                 })},
+                {"bias_after_scale", pat.Attr("bias_after_scale")}});
+    const auto &matmul_op_res =
+        res.Op(paddle::dialect::MatmulOp::name(),
+               {{"transpose_x", pat.Attr("transpose_x")},
+                {"transpose_y", pat.Attr("transpose_y")}});
+    scale_op_res({&res.Tensor("y"), &full_op_res()},
+                 {&res.Tensor("scale_res_out")});
+    matmul_op_res({&res.Tensor("x"), &res.Tensor("scale_res_out")},
+                  {&res.Tensor("scale_out")});
+  }
+};
+
+class MatmulScaleFusePass : public pir::PatternRewritePass {
+ public:
+  MatmulScaleFusePass()
+      : pir::PatternRewritePass("matmul_scale_fuse_pass", 2) {}
+
+  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
+    pir::RewritePatternSet ps(context);
+    ps.Add(MatmulScaleFusePattern().Build(context));
+    return ps;
+  }
+};
+
+}  // namespace
+
+namespace pir {
+
+std::unique_ptr<Pass> CreateMatmulScaleFusePass() {
+  return std::make_unique<MatmulScaleFusePass>();
+}
+}  // namespace pir
+
+REGISTER_IR_PASS(matmul_scale_fuse_pass, MatmulScaleFusePass);
diff --git a/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h
new file mode 100644
index 0000000000000..d2d6a1b923745
--- /dev/null
+++ b/paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/pir/core/dll_decl.h"
+
+namespace pir {
+
+class Pass;
+
+IR_API std::unique_ptr<Pass> CreateMatmulScaleFusePass();
+
+}  // namespace pir
diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.cc b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc
new file mode 100644
index 0000000000000..6f6c684e373c7
--- /dev/null
+++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.cc
@@ -0,0 +1,233 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/pir/transforms/identity_op_clean_pass.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
+#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
+#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
+#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
+#include "paddle/fluid/pir/transforms/fusion/conv2d_add_fuse_pass.h"
+#include "paddle/fluid/pir/transforms/transform_general_functions.h"
+
+#include "paddle/common/ddim.h"
+
+#include "paddle/pir/core/builtin_op.h"
+#include "paddle/pir/pass/pass.h"
+#include "paddle/pir/pass/pass_registry.h"
+#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
+
+namespace {
+
+class RemoveUselessScalePattern
+    : public pir::drr::DrrPatternBase<RemoveUselessScalePattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    pir::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &full_op = pat.Op(paddle::dialect::FullOp::name(),
+                                 {{"shape", pat.Attr("shape")},
+                                  {"value", pat.Attr("value")},
+                                  {"dtype", pat.Attr("dtype")},
+                                  {"place", pat.Attr("place")}});
+    const auto &scale_op =
+        pat.Op(paddle::dialect::ScaleOp::name(),
+               {{"bias", pat.Attr("bias")},
+                {"bias_after_scale", pat.Attr("bias_after_scale")}});
+    scale_op({&pat.Tensor("x"), &full_op()}, {&pat.Tensor("scale_out")});
+
+    pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
+      return (match_ctx.Attr<float>("value") == 1.0 &&
+              match_ctx.Attr<float>("bias") == 0.0);
+    });
+
+    pir::drr::ResultPattern res = pat.ResultPattern();
+    res.Tensor("scale_out").Assign(res.Tensor("x"));
+  }
+};
+
+class RemoveRedundentScalePattern
+    : public pir::drr::DrrPatternBase<RemoveRedundentScalePattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    pir::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &full_op_1 = pat.Op(paddle::dialect::FullOp::name(),
+                                   {{"shape", pat.Attr("shape_1")},
+                                    {"value", pat.Attr("value_1")},
+                                    {"dtype", pat.Attr("dtype_1")},
+                                    {"place", pat.Attr("place_1")}});
+    const auto &scale_op_1 =
+        pat.Op(paddle::dialect::ScaleOp::name(),
+               {{"bias", pat.Attr("bias_1")},
+                {"bias_after_scale", pat.Attr("bias_after_scale_1")}});
+    const auto &full_op_2 = pat.Op(paddle::dialect::FullOp::name(),
+                                   {{"shape", pat.Attr("shape_2")},
+                                    {"value", pat.Attr("value_2")},
+                                    {"dtype", pat.Attr("dtype_2")},
+                                    {"place", pat.Attr("place_2")}});
+    const auto &scale_op_2 =
+        pat.Op(paddle::dialect::ScaleOp::name(),
+               {{"bias", pat.Attr("bias_2")},
+                {"bias_after_scale", pat.Attr("bias_after_scale_2")}});
+    scale_op_1({&pat.Tensor("x"), &full_op_1()}, {&pat.Tensor("scale_1_out")});
+    scale_op_2({&pat.Tensor("scale_1_out"), &full_op_2()},
+               {&pat.Tensor("scale_2_out")});
+
+    pir::drr::ResultPattern res = pat.ResultPattern();
+
+    const auto &bais_res =
+        res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
+          float res_bias_1 = 0.f;
+          float res_bias_2 = 0.f;
+          if (match_ctx.Attr<bool>("bias_after_scale_1")) {
+            res_bias_1 = match_ctx.Attr<float>("bias_1");
+          } else {
+            res_bias_1 = match_ctx.Attr<float>("value_1") *
+                         match_ctx.Attr<float>("bias_1");
+          }
+          if (match_ctx.Attr<bool>("bias_after_scale_2")) {
+            res_bias_2 = res_bias_1 * match_ctx.Attr<float>("value_2") +
+                         match_ctx.Attr<float>("bias_2");
+          } else {
+            res_bias_2 = (res_bias_1 + match_ctx.Attr<float>("bias_2")) *
+                         match_ctx.Attr<float>("value_2");
+          }
+          return res_bias_2;
+        });
+    const auto &res_scale_input =
+        res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
+          return match_ctx.Attr<float>("value_1") *
+                 match_ctx.Attr<float>("value_2");
+        });
+
+    const auto &full_op_res = res.Op(paddle::dialect::FullOp::name(),
+                                     {{"shape", pat.Attr("shape_1")},
+                                      {"value", res_scale_input},
+                                      {"dtype", pat.Attr("dtype_1")},
+                                      {"place", pat.Attr("place_1")}});
+    const auto &scale_op_res =
+        res.Op("pd_op.scale",
+               {{"bias", bais_res},
+                {"bias_after_scale",
+                 res.Attr([](const pir::drr::MatchContext &match_ctx) -> bool {
+                   return true;
+                 })}});
+    scale_op_res({&res.Tensor("x"), &full_op_res()},
+                 {&res.Tensor("scale_2_out")});
+  }
+};
+
+class RemoveUselessCastPattern
+    : public pir::drr::DrrPatternBase<RemoveUselessCastPattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    auto pat = ctx->SourcePattern();
+    pat.Tensor("ret") = pat.Op("pd_op.cast")(pat.Tensor("arg0"));
+    pat.RequireEqual(pat.Tensor("ret").dtype(), pat.Tensor("arg0").dtype());
+    auto res = pat.ResultPattern();
+    res.Tensor("ret").Assign(res.Tensor("arg0"));
+  }
+};
+
+class RemoveUselessConcatPattern
+    : public pir::drr::DrrPatternBase<RemoveUselessConcatPattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    auto pat = ctx->SourcePattern();
+    const auto &combine = pat.Op(pir::CombineOp::name());
+    combine({&pat.Tensor("x")}, {&pat.Tensor("combine_out")});
+    pat.Tensor("out") = pat.Op(paddle::dialect::ConcatOp::name())(
+        pat.Tensor("combine_out"), pat.Tensor("axis"));
+    pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
+      auto x_type = dynamic_cast<const pir::drr::IrValue &>(
+                        match_ctx.Tensor("combine_out"))
+                        .get()
+                        .type();
+      return x_type.isa<pir::VectorType>() &&
+             x_type.dyn_cast<pir::VectorType>().size() == 1;
+    });
+    auto res = pat.ResultPattern();
+    res.Tensor("out").Assign(res.Tensor("x"));
+  }
+};
+
+class RemoveRedundentCastPattern
+    : public pir::drr::DrrPatternBase<RemoveRedundentCastPattern> {
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    auto pat = ctx->SourcePattern();
+    pat.Tensor("tmp") = pat.Op(
+        "pd_op.cast", {{"dtype", pat.Attr("dtype1")}})(pat.Tensor("arg0"));
+    pat.Tensor("ret") = pat.Op(
+        "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(pat.Tensor("tmp"));
+    auto res = pat.ResultPattern();
+    res.Tensor("ret") = res.Op(
+        "pd_op.cast", {{"dtype", pat.Attr("dtype2")}})(res.Tensor("arg0"));
+  }
+};
+
+class RemoveRedundentTransposePattern
+    : public pir::drr::DrrPatternBase<RemoveRedundentTransposePattern> {
+ public:
+  void operator()(pir::drr::DrrPatternContext *ctx) const override {
+    pir::drr::SourcePattern pat = ctx->SourcePattern();
+    const auto &transpose1 =
+        pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_1")}});
+    const auto &transpose2 =
+        pat.Op("pd_op.transpose", {{"perm", pat.Attr("perm_2")}});
+
+    pat.Tensor("ret") = transpose2(transpose1(pat.Tensor("arg_transpose")));
+
+    pir::drr::ResultPattern res = pat.ResultPattern();
+    const auto &new_perm_attr = res.Attr(
+        [](const pir::drr::MatchContext &match_ctx) -> std::vector<int> {
+          const auto &perm1 = match_ctx.Attr<std::vector<int>>("perm_1");
+          const auto &perm2 = match_ctx.Attr<std::vector<int>>("perm_2");
+          std::vector<int> new_perm;
+          for (int v : perm2) {
+            new_perm.emplace_back(perm1[v]);
+          }
+          return new_perm;
+        });
+    const auto &tranpose_continuous =
+        res.Op("pd_op.transpose", {{"perm", new_perm_attr}});
+
+    res.Tensor("ret") = tranpose_continuous(res.Tensor("arg_transpose"));
+  }
+};
+
+class IdentityOpCleanPass : public pir::PatternRewritePass {
+ public:
+  IdentityOpCleanPass()
+      : pir::PatternRewritePass("identity_op_clean_pass", 2) {}
+
+  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
+    pir::RewritePatternSet ps(context);
+    ps.Add(RemoveUselessScalePattern().Build(context));
+    ps.Add(RemoveRedundentScalePattern().Build(context));
+    ps.Add(RemoveUselessCastPattern().Build(context));
+    ps.Add(RemoveUselessConcatPattern().Build(context));
+    ps.Add(RemoveRedundentCastPattern().Build(context));
+    ps.Add(RemoveRedundentTransposePattern().Build(context));
+    return ps;
+  }
+};
+
+}  // namespace
+
+namespace pir {
+
+std::unique_ptr<Pass> CreateIdentityOpCleanPass() {
+  return std::make_unique<IdentityOpCleanPass>();
+}
+}  // namespace pir
+
+REGISTER_IR_PASS(identity_op_clean_pass, IdentityOpCleanPass);
diff --git a/paddle/fluid/pir/transforms/identity_op_clean_pass.h b/paddle/fluid/pir/transforms/identity_op_clean_pass.h
new file mode 100644
index 0000000000000..6f2d6dae46f70
--- /dev/null
+++ b/paddle/fluid/pir/transforms/identity_op_clean_pass.h
@@ -0,0 +1,26 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include "paddle/pir/core/dll_decl.h"
+
+namespace pir {
+
+class Pass;
+
+IR_API std::unique_ptr<Pass> CreateIdentityOpCleanPass();
+
+}  // namespace pir
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index feb37821e58ef..a243ada3d2cd0 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -103,6 +103,8 @@ USE_PIR_PASS(fused_weight_only_linear_pass);
 USE_PIR_PASS(fused_linear_param_grad_add_pass);
 USE_PIR_PASS(inplace_pass);
 USE_PIR_PASS(replace_fetch_with_shadow_output_pass);
+USE_PIR_PASS(identity_op_clean_pass);
+USE_PIR_PASS(matmul_scale_fuse_pass);
 USE_PIR_PASS(conv2d_bn_fuse_pass);
 USE_PIR_PASS(conv2d_add_fuse_pass);
 USE_PIR_PASS(conv2d_add_act_fuse_pass);
diff --git a/test/ir/pir/fused_pass/pass_test.py b/test/ir/pir/fused_pass/pass_test.py
index ae7bab43618c4..236f4834dac15 100644
--- a/test/ir/pir/fused_pass/pass_test.py
+++ b/test/ir/pir/fused_pass/pass_test.py
@@ -79,6 +79,7 @@ def check_pass_correct(self, atol=1e-5):
             executor = paddle.static.Executor(paddle.base.CPUPlace())
         elif self.place_runtime == "gpu":
             executor = paddle.static.Executor(paddle.base.CUDAPlace(0))
+
         for program, need_translate_to_pir in self.sample_program():
             if need_translate_to_pir:
                 program = pir.translate_to_pir(program.desc)
diff --git a/test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py b/test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py
new file mode 100644
index 0000000000000..3ae1306de7bcd
--- /dev/null
+++ b/test/ir/pir/fused_pass/test_pir_matmul_scale_fuse_pass.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from pass_test import PassTest
+
+import paddle
+
+paddle.enable_static()
+
+
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestMatmulScaleFusePattern(PassTest):
+    r"""
+    x_var   f_var
+      \      /
+       matmul
+         |
+       scale
+    """
+
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        for x_shape in [[3, 2]]:
+            for w_shape in [[2, 3]]:
+                for scale_bias in [1e-7]:
+                    for scale_value in [2.0]:
+                        for bias_after_scale in [True]:
+                            pir_program = None
+                            with paddle.pir_utils.IrGuard():
+                                pir_program = paddle.static.Program()
+                                with paddle.pir.core.program_guard(pir_program):
+                                    x = paddle.static.data(
+                                        name='x', shape=x_shape, dtype='float32'
+                                    )
+                                    w = paddle.static.data(
+                                        name='w', shape=w_shape, dtype='float32'
+                                    )
+                                    out = paddle.scale(
+                                        paddle.matmul(x, w),
+                                        scale=scale_value,
+                                        bias=scale_bias,
+                                        bias_after_scale=bias_after_scale,
+                                    )
+                                    self.pass_list = ['matmul_scale_fuse_pass']
+                                    self.feeds = {
+                                        "x": np.random.random(x_shape).astype(
+                                            "float32"
+                                        ),
+                                        "w": np.random.random(w_shape).astype(
+                                            "float32"
+                                        ),
+                                    }
+                                    self.fetch_list = [out]
+                                    self.valid_op_map = {
+                                        "pd_op.scale": 1,
+                                        "pd_op.matmul": 1,
+                                    }
+                                    yield pir_program, False
+
+    def setUp(self):
+        self.place_runtime = "gpu"
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+class TestMatmulScaleFusePatternWtihCpu(TestMatmulScaleFusePattern):
+    def setUp(self):
+        self.place_runtime = "cpu"
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/ir/pir/test_pir_identity_op_clean_pass.py b/test/ir/pir/test_pir_identity_op_clean_pass.py
new file mode 100644
index 0000000000000..c69fec9143486
--- /dev/null
+++ b/test/ir/pir/test_pir_identity_op_clean_pass.py
@@ -0,0 +1,254 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from fused_pass.pass_test import PassTest
+
+import paddle
+
+paddle.enable_static()
+
+
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestRemoveUselessScalePattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def build_ir_progam(self):
+        pir_program = None
+        with paddle.pir_utils.IrGuard():
+            pir_program = paddle.static.Program()
+            with paddle.pir.core.program_guard(pir_program):
+                x = paddle.static.data(
+                    name='x', shape=[3, 1, 28, 28], dtype='float32'
+                )
+                out = paddle.scale(x, scale=1.0, bias=0.0)
+                self.pass_list = ['identity_op_clean_pass']
+                self.feeds = {"x": np.random.random((3, 1, 28, 28)).astype("float32")}
+                self.fetch_list = [out]
+                self.valid_op_map = {"pd_op.scale": 0}
+        return pir_program
+
+    def sample_program(self):
+        pir_program = self.build_ir_progam()
+        yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestRemoveRedundentScalePattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        for bias_after_scale_1 in [True, False]:
+            for bias_after_scale_2 in [True, False]:
+                pir_program = None
+                with paddle.pir_utils.IrGuard():
+                    pir_program = paddle.static.Program()
+                    with paddle.pir.core.program_guard(pir_program):
+                        x = paddle.static.data(
+                            name='x', shape=[3, 1, 28, 28], dtype='float32'
+                        )
+                        scale_out1 = paddle.scale(
+                            x,
+                            scale=2.0,
+                            bias=1.0,
+                            bias_after_scale=bias_after_scale_1,
+                        )
+                        out = paddle.scale(
+                            scale_out1,
+                            scale=2.0,
+                            bias=2.0,
+                            bias_after_scale=bias_after_scale_2,
+                        )
+                        self.pass_list = ['identity_op_clean_pass']
+                        self.feeds = {
+                            "x": np.random.random((3, 1, 28, 28)).astype("float32")
+                        }
+                        self.fetch_list = [out]
+                        self.valid_op_map = {"pd_op.scale": 1}
+                        yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestRemoveUselessCastPattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        for tmp_type in ['float32', 'float16']:
+            pir_program = None
+            with paddle.pir_utils.IrGuard():
+                pir_program = paddle.static.Program()
+                with paddle.pir.core.program_guard(pir_program):
+                    x = paddle.static.data(
+                        name='x', shape=[3, 1, 28, 28], dtype=tmp_type
+                    )
+                    out = paddle.cast(x, tmp_type)
+                    self.pass_list = ['identity_op_clean_pass']
+                    self.feeds = {
+                        "x": np.random.random((3, 1, 28, 28)).astype(tmp_type)
+                    }
+                    self.fetch_list = [out]
+                    self.valid_op_map = {"pd_op.cast": 0}
+                    yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestRemoveUselessConcatPattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        pir_program = None
+        with paddle.pir_utils.IrGuard():
+            pir_program = paddle.static.Program()
+            with paddle.pir.core.program_guard(pir_program):
+                x_input = paddle.static.data(
+                    name='x_input', shape=[3, 1, 28, 28], dtype="float32"
+                )
+                out = paddle.concat(x=[x_input])
+                self.pass_list = ['identity_op_clean_pass']
+                self.feeds = {
+                    "x_input": np.random.random((3, 1, 28, 28)).astype("float32")
+                }
+                self.fetch_list = [out]
+                self.valid_op_map = {"pd_op.concat": 0}
+                yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestRemoveRedundentCastPattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        for type_1 in ["float16"]:
+            for type_2 in ["int32"]:
+                pir_program = None
+                with paddle.pir_utils.IrGuard():
+                    pir_program = paddle.static.Program()
+                    with paddle.pir.core.program_guard(pir_program):
+                        x = paddle.static.data(
+                            name='x', shape=[3, 1, 28, 28], dtype="float32"
+                        )
+                        out = paddle.cast(paddle.cast(x, type_1), type_2)
+                        self.pass_list = ['identity_op_clean_pass']
+                        self.feeds = {
+                            "x": np.random.random((3, 1, 28, 28)).astype("float32")
+                        }
+                        self.fetch_list = [out]
+                        self.valid_op_map = {"pd_op.cast": 1}
+                        yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+@unittest.skipIf(
+    not paddle.base.core.is_compiled_with_cuda(),
+    "core is not compiled with CUDA",
+)
+class TestRemoveRedundentTransposePattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        for perm1_shape in [[1, 2, 0]]:
+            for perm2_shape in [[0, 2, 1]]:
+                pir_program = None
+                with paddle.pir_utils.IrGuard():
+                    pir_program = paddle.static.Program()
+                    with paddle.pir.core.program_guard(pir_program):
+                        x = paddle.static.data(
+                            name='x', shape=[2, 3, 4], dtype="float32"
+                        )
+                        out = paddle.transpose(
+                            paddle.transpose(x, perm1_shape), perm2_shape
+                        )
+                        self.pass_list = ['identity_op_clean_pass']
+                        self.feeds = {
+                            "x": np.random.random((2, 3, 4)).astype("float32")
+                        }
+                        self.fetch_list = [out]
+                        self.valid_op_map = {"pd_op.transpose": 1}
+                        yield pir_program, False
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
+
+class TestRemoveRedundentTransposePatternWithCpu(
+    TestRemoveRedundentTransposePattern
+):
+    def setUp(self):
+        self.place_runtime = "cpu"
+
+
+class TestRemoveRedundentCastPatternWithCpu(TestRemoveRedundentCastPattern):
+    def setUp(self):
+        self.place_runtime = "cpu"
+
+
+class TestRemoveUselessCastPatternWithCpu(TestRemoveUselessCastPattern):
+    def setUp(self):
+        self.place_runtime = "cpu"
+
+
+class TestRemoveUselessConcatPatternWithCpu(TestRemoveUselessConcatPattern):
+    def setUp(self):
+        self.place_runtime = "cpu"
+
+
+class TestRemoveRedundentScalePatternWithCpu(TestRemoveRedundentScalePattern):
+    def setUp(self):
+        self.place_runtime = "cpu"
+
+
+class TestRemoveUselessScalePatternWithCpu(TestRemoveUselessScalePattern):
+    def setUp(self):
+        self.place_runtime = "cpu"
+
+
+if __name__ == "__main__":
+    unittest.main()