[PIR] add identity_op_clean_pass and matmul_scale_fuse_pass (#59840)
* add identity_op_clean_pass

* update

* Rename test_identity_op_clean_pass.py to test_pir_identity_op_clean_pass.py

* Rename test_matmul_scale_fuse_pass.py to test_pir_matmul_scale_fuse_pass.py

* update

* update drr

* new_executor_sequential_run
yuanlehome authored Dec 12, 2023
1 parent 2360fae commit 3d5fbe3
Showing 12 changed files with 790 additions and 6 deletions.
4 changes: 4 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1734,6 +1734,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetEnableIrOptim(true);
pass_builder->ClearPasses();
pass_builder->AppendPass("auto_mixed_precision_pass");
pass_builder->AppendPass("inplace_op_var_pass");
LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
"optimization.";
} else {
@@ -1918,6 +1919,9 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
if (std::getenv("FLAGS_initial_cpu_memory_in_mb") == nullptr) {
SetGflag("initial_cpu_memory_in_mb", "0");
}
if (std::getenv("FLAGS_new_executor_sequential_run") == nullptr) {
SetGflag("new_executor_sequential_run", "1");
}
});

if (config.thread_local_stream_enabled() &&
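
The second hunk applies a library default without clobbering user intent: new_executor_sequential_run is forced to 1 only when the matching FLAGS_* environment variable is absent, so an explicit user setting still wins. A self-contained sketch of that guard pattern follows; the SetGflag stub is a hypothetical stand-in for Paddle's helper, for illustration only:

#include <cstdlib>
#include <iostream>
#include <string>

// Hypothetical stand-in for Paddle's SetGflag helper (illustration only).
static void SetGflag(const std::string& name, const std::string& value) {
  std::cout << "default applied: FLAGS_" << name << "=" << value << "\n";
}

// Install a default only when the user has not exported FLAGS_<name>,
// mirroring the getenv guard in CreatePaddlePredictor above.
static void SetGflagDefault(const std::string& name, const std::string& value) {
  if (std::getenv(("FLAGS_" + name).c_str()) == nullptr) {
    SetGflag(name, value);
  }
}

int main() {
  // Exporting FLAGS_new_executor_sequential_run=0 before launch still wins.
  SetGflagDefault("new_executor_sequential_run", "1");
  return 0;
}
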
25 changes: 25 additions & 0 deletions paddle/fluid/pir/drr/attr_type_uilts.h
@@ -43,6 +43,7 @@ PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int32_t>, pir::ArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int64_t>,
paddle::dialect::IntArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<float>, pir::ArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray,
paddle::dialect::IntArrayAttribute);

@@ -66,6 +67,18 @@ struct IrAttrbuteCreator<std::vector<int32_t>> {
}
};

template <>
struct IrAttrbuteCreator<std::vector<float>> {
pir::ArrayAttribute operator()(std::vector<float> obj) const {
std::vector<pir::Attribute> attr_vec;
attr_vec.reserve(obj.size());
for (float x : obj) {
attr_vec.push_back(FloatAttribute::get(pir::IrContext::Instance(), x));
}
return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec);
}
};

template <typename T>
struct IrAttrTypeCast {
static T To(const pir::Attribute& attr) {
@@ -114,5 +127,17 @@ struct IrAttrTypeCast<std::vector<int64_t>> {
}
};

template <>
struct IrAttrTypeCast<std::vector<float>> {
static std::vector<float> To(const pir::Attribute& attr) {
std::vector<float> result;
auto array_attr = attr.dyn_cast<pir::ArrayAttribute>();
for (size_t i = 0; i < array_attr.size(); i++) {
result.push_back(array_attr.at(i).dyn_cast<pir::FloatAttribute>().data());
}
return result;
}
};

} // namespace drr
} // namespace pir
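
These two specializations give DRR a round trip between std::vector<float> and a pir::ArrayAttribute whose elements are FloatAttribute. A minimal sketch of the round trip, assuming a translation unit built inside the Paddle tree against these headers:

#include <vector>

#include "paddle/fluid/pir/drr/attr_type_uilts.h"

// Sketch only: encode a vector<float> as an ArrayAttribute, then decode it
// back through the matching IrAttrTypeCast specialization.
std::vector<float> RoundTrip(const std::vector<float>& scales) {
  pir::ArrayAttribute attr =
      pir::drr::IrAttrbuteCreator<std::vector<float>>()(scales);
  return pir::drr::IrAttrTypeCast<std::vector<float>>::To(attr);
}
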
16 changes: 10 additions & 6 deletions paddle/fluid/pir/drr/drr_rewrite_pattern.cc
@@ -43,7 +43,7 @@ bool DrrRewritePattern::PatternGraphMatch(
std::vector<const OpCall*> drr_output_sequence;
std::vector<Operation*> ir_output_sequence;
std::unordered_map<const OpCall*, Operation*> output_op_map;
-  for (auto pair : bind_map) {
+  for (const auto& pair : bind_map) {
drr_output_sequence.push_back(pair.first);
}
// using dfs to obtain the arrangement of all candidate ir ops
@@ -396,10 +396,11 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
Value ir_val = res_match_ctx.GetIrValue(input->name()).get();
if (ir_val) {
Operation* ir_input_op = ir_val.dyn_cast<pir::OpResult>().owner();
-      if (max_input_op_index < op_2_temp_program_index[ir_input_op]) {
-        max_input_op_index = op_2_temp_program_index[ir_input_op];
+      if (max_input_op_index < op_2_temp_program_index.at(ir_input_op)) {
+        max_input_op_index = op_2_temp_program_index.at(ir_input_op);
        max_index_op = ir_input_op;
-      } else if (max_input_op_index == op_2_temp_program_index[ir_input_op]) {
+      } else if (max_input_op_index ==
+                 op_2_temp_program_index.at(ir_input_op)) {
const auto& ops_vec = temp_program[max_input_op_index];
for (auto it = ops_vec.begin(); it != ops_vec.end(); it++) {
if (*it == max_index_op) {
@@ -430,6 +431,9 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
Operation* new_op =
CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
op_2_temp_program_index[new_op] = max_input_op_index + 1;
if (max_input_op_index + 1 >= temp_program.size()) {
temp_program.push_back({});
}
temp_program[max_input_op_index + 1].push_back(new_op);
});

@@ -471,13 +475,13 @@ void DrrRewritePattern::DeleteSourcePatternOp(
const ResultPatternGraph& result_pattern_graph,
const MatchContextImpl& src_match_ctx,
pir::PatternRewriter& rewriter) const { // NOLINT

std::queue<Operation*> delete_ops_que;
std::unordered_set<Operation*> delete_ops_set;
GraphTopo graph_topo_visit(&source_pattern_graph);
graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) {
Operation* op = src_match_ctx.Operation(&op_call).get();
-    if (op->use_empty()) {
+    VLOG(5) << "DRR delete op: " << op->name() << " pointer: " << op;
+    if (delete_ops_set.count(op) == 0 && op->use_empty()) {
delete_ops_que.push(op);
delete_ops_set.insert(op);
}
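
Three defensive fixes land in this file: map lookups switch from operator[] (which silently default-inserts index 0 for an op it has never seen) to .at() (which throws), temp_program grows a new row before an op is placed one slot past the end, and the delete queue skips ops that were already enqueued. A self-contained sketch of the grow-before-place guard, with simplified stand-in types:

#include <vector>

// Place `op` in row `index`, creating the row first when it does not exist.
// In the pass, index is max_input_op_index + 1, which is at most
// temp_program.size(), so a single push_back suffices.
void PlaceAt(std::vector<std::vector<int>>* temp_program, size_t index, int op) {
  if (index >= temp_program->size()) {
    temp_program->push_back({});
  }
  (*temp_program)[index].push_back(op);
}
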
14 changes: 14 additions & 0 deletions paddle/fluid/pir/drr/ir_operation_factory.cc
@@ -57,6 +57,17 @@ void OperationFactory::RegisterManualOpCreator() {
pir::PatternRewriter& rewriter) {
return rewriter.Build<pir::CombineOp>(inputs);
});
RegisterOperationCreator(
"pd_op.scale",
[](const std::vector<Value>& inputs,
const pir::AttributeMap& attrs,
pir::PatternRewriter& rewriter) {
return rewriter.Build<paddle::dialect::ScaleOp>(
inputs[0].dyn_cast<pir::OpResult>(),
inputs[1].dyn_cast<pir::OpResult>(),
attrs.at("bias").dyn_cast<pir::FloatAttribute>().data(),
attrs.at("bias_after_scale").dyn_cast<pir::BoolAttribute>().data());
});
}

static pir::Attribute CreateIrAttribute(const std::any& obj) {
@@ -83,6 +94,9 @@ static pir::Attribute CreateIrAttribute(const std::any& obj) {
} else if (obj.type() == typeid(std::vector<int64_t>)) {
return IrAttrbuteCreator<std::vector<int64_t>>()(
std::any_cast<std::vector<int64_t>>(obj));
} else if (obj.type() == typeid(std::vector<float>)) {
return IrAttrbuteCreator<std::vector<float>>()(
std::any_cast<std::vector<float>>(obj));
} else if (obj.type() == typeid(phi::IntArray)) {
return IrAttrbuteCreator<phi::IntArray>()(
std::any_cast<phi::IntArray>(obj));
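
The factory keeps a map from op names to creator lambdas; pd_op.scale needs a manual entry because its Build call unpacks attributes (bias, bias_after_scale) into plain C++ arguments instead of forwarding an AttributeMap. A simplified, self-contained sketch of the registry shape; the types here are stand-ins, not Paddle's real signatures:

#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct Op {};  // stand-in for the operation type the real factory builds

using Creator = std::function<Op(const std::vector<int>& inputs)>;

class OperationFactorySketch {
 public:
  // Register a name-keyed lambda, as RegisterOperationCreator does above.
  void RegisterOperationCreator(const std::string& name, Creator fn) {
    creators_[name] = std::move(fn);
  }
  Op Create(const std::string& name, const std::vector<int>& inputs) {
    return creators_.at(name)(inputs);  // throws if the op was never registered
  }

 private:
  std::unordered_map<std::string, Creator> creators_;
};
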
102 changes: 102 additions & 0 deletions paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc
@@ -0,0 +1,102 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/common/ddim.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

class MatmulScaleFusePattern
: public pir::drr::DrrPatternBase<MatmulScaleFusePattern> {
public:
void operator()(pir::drr::DrrPatternContext *ctx) const override {
pir::drr::SourcePattern pat = ctx->SourcePattern();
const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("transpose_x")},
{"transpose_y", pat.Attr("transpose_y")}});

matmul_op({&pat.Tensor("x"), &pat.Tensor("y")},
{&pat.Tensor("matmul_out")});
const auto &full_op = pat.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("shape")},
{"value", pat.Attr("value")},
{"dtype", pat.Attr("dtype")},
{"place", pat.Attr("place")}});
const auto &scale_op =
pat.Op(paddle::dialect::ScaleOp::name(),
{{"bias", pat.Attr("bias")},
{"bias_after_scale", pat.Attr("bias_after_scale")}});
scale_op({&pat.Tensor("matmul_out"), &full_op()},
{&pat.Tensor("scale_out")});

pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
return std::abs(match_ctx.Attr<float>("bias")) <= 1e-6;
});

pir::drr::ResultPattern res = pat.ResultPattern();
const auto &full_op_res = res.Op(paddle::dialect::FullOp::name(),
{{"shape", pat.Attr("shape")},
{"value", pat.Attr("value")},
{"dtype", pat.Attr("dtype")},
{"place", pat.Attr("place")}});
const auto &scale_op_res =
res.Op(paddle::dialect::ScaleOp::name(),
{{"bias",
res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
return 0.0;
})},
{"bias_after_scale", pat.Attr("bias_after_scale")}});
const auto &matmul_op_res =
res.Op(paddle::dialect::MatmulOp::name(),
{{"transpose_x", pat.Attr("transpose_x")},
{"transpose_y", pat.Attr("transpose_y")}});
scale_op_res({&res.Tensor("y"), &full_op_res()},
{&res.Tensor("scale_res_out")});
matmul_op_res({&res.Tensor("x"), &res.Tensor("scale_res_out")},
{&res.Tensor("scale_out")});
}
};

class MatmulScaleFusePass : public pir::PatternRewritePass {
public:
MatmulScaleFusePass()
: pir::PatternRewritePass("matmul_scale_fuse_pass", 2) {}

pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(MatmulScaleFusePattern().Build(context));
return ps;
}
};

} // namespace

namespace pir {

std::unique_ptr<Pass> CreateMatmulScaleFusePass() {
return std::make_unique<MatmulScaleFusePass>();
}
} // namespace pir

REGISTER_IR_PASS(matmul_scale_fuse_pass, MatmulScaleFusePass);
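
The rewrite is the algebra (x @ y) * s == x @ (y * s): once RequireNativeCall confirms the scale's bias is numerically zero (|bias| <= 1e-6), the result pattern moves the scale onto y ahead of the matmul, where a later pass can often fold it into constant weights. A usage sketch, assuming header paths and the pass-manager API as they stood at this commit:

#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h"
#include "paddle/pir/core/ir_context.h"
#include "paddle/pir/core/program.h"
#include "paddle/pir/pass/pass_manager.h"

// Sketch: run the new fusion pass over a PIR program.
void ApplyMatmulScaleFuse(pir::Program* program) {
  pir::PassManager pm(pir::IrContext::Instance());
  pm.AddPass(pir::CreateMatmulScaleFusePass());
  pm.Run(program);
}
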
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateMatmulScaleFusePass();

} // namespace pir
(Diffs for the remaining six of the 12 changed files, including the renamed test files, did not load and are not shown.)