[PIR] add identity_op_clean_pass and matmul_scale_fuse_pass #59840

Merged · 9 commits · Dec 12, 2023
4 changes: 4 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1734,6 +1734,7 @@ void AnalysisPredictor::PrepareArgument() {
      argument_->SetEnableIrOptim(true);
      pass_builder->ClearPasses();
      pass_builder->AppendPass("auto_mixed_precision_pass");
+     pass_builder->AppendPass("inplace_op_var_pass");
      LOG(INFO) << "This model run in GPU mixed precision mode with no ir "
                   "optimization.";
    } else {
@@ -1918,6 +1919,9 @@ CreatePaddlePredictor<AnalysisConfig, PaddleEngineKind::kAnalysis>(
    if (std::getenv("FLAGS_initial_cpu_memory_in_mb") == nullptr) {
      SetGflag("initial_cpu_memory_in_mb", "0");
    }
+   if (std::getenv("FLAGS_new_executor_sequential_run") == nullptr) {
+     SetGflag("new_executor_sequential_run", "1");
+   }
  });

  if (config.thread_local_stream_enabled() &&
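The FLAGS change follows the surrounding idiom: a process-wide gflag default is applied only when the user has not exported the matching environment variable, so an explicit user setting always wins. A standalone sketch of that guard, where SetGflag is only a stand-in declaration for Paddle's helper:

#include <cstdlib>
#include <string>

// Stand-in declaration for Paddle's SetGflag helper.
void SetGflag(const std::string &name, const std::string &value);

void ApplySequentialRunDefault() {
  // Respect an exported FLAGS_new_executor_sequential_run; otherwise
  // default the new executor to sequential run, as the predictor does above.
  if (std::getenv("FLAGS_new_executor_sequential_run") == nullptr) {
    SetGflag("new_executor_sequential_run", "1");
  }
}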
25 changes: 25 additions & 0 deletions paddle/fluid/pir/drr/attr_type_uilts.h
@@ -43,6 +43,7 @@ PD_SPECIALIZE_CppTypeToIrAttribute(phi::Place, paddle::dialect::PlaceAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int32_t>, pir::ArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<int64_t>,
                                   paddle::dialect::IntArrayAttribute);
+PD_SPECIALIZE_CppTypeToIrAttribute(std::vector<float>, pir::ArrayAttribute);
PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray,
                                   paddle::dialect::IntArrayAttribute);

@@ -66,6 +67,18 @@ struct IrAttrbuteCreator<std::vector<int32_t>> {
  }
};

+template <>
+struct IrAttrbuteCreator<std::vector<float>> {
+  pir::ArrayAttribute operator()(std::vector<float> obj) const {
+    std::vector<pir::Attribute> attr_vec;
+    attr_vec.reserve(obj.size());
+    for (float x : obj) {
+      attr_vec.push_back(FloatAttribute::get(pir::IrContext::Instance(), x));
+    }
+    return pir::ArrayAttribute::get(pir::IrContext::Instance(), attr_vec);
+  }
+};
+
template <typename T>
struct IrAttrTypeCast {
  static T To(const pir::Attribute& attr) {
@@ -114,5 +127,17 @@ struct IrAttrTypeCast<std::vector<int64_t>> {
  }
};

+template <>
+struct IrAttrTypeCast<std::vector<float>> {
+  static std::vector<float> To(const pir::Attribute& attr) {
+    std::vector<float> result;
+    auto array_attr = attr.dyn_cast<pir::ArrayAttribute>();
+    for (size_t i = 0; i < array_attr.size(); i++) {
+      result.push_back(array_attr.at(i).dyn_cast<pir::FloatAttribute>().data());
+    }
+    return result;
+  }
+};
+
}  // namespace drr
}  // namespace pir
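For context, a minimal round-trip sketch of the new std::vector<float> support, assuming the header above is included and the process-wide IrContext is initialized; this is illustration, not code from the PR:

#include <vector>
#include "paddle/fluid/pir/drr/attr_type_uilts.h"

void FloatVectorAttrRoundTrip() {
  std::vector<float> scales = {0.5f, 1.0f, 2.0f};

  // std::vector<float> -> pir::ArrayAttribute of FloatAttribute elements.
  pir::ArrayAttribute attr =
      pir::drr::IrAttrbuteCreator<std::vector<float>>()(scales);

  // pir::ArrayAttribute -> std::vector<float> again.
  std::vector<float> restored =
      pir::drr::IrAttrTypeCast<std::vector<float>>::To(attr);
  // restored == {0.5f, 1.0f, 2.0f}
}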
16 changes: 10 additions & 6 deletions paddle/fluid/pir/drr/drr_rewrite_pattern.cc
@@ -43,7 +43,7 @@ bool DrrRewritePattern::PatternGraphMatch(
  std::vector<const OpCall*> drr_output_sequence;
  std::vector<Operation*> ir_output_sequence;
  std::unordered_map<const OpCall*, Operation*> output_op_map;
- for (auto pair : bind_map) {
+ for (const auto& pair : bind_map) {
    drr_output_sequence.push_back(pair.first);
  }
  // using dfs to obtain the arrangement of all candidate ir ops
@@ -396,10 +396,11 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
        Value ir_val = res_match_ctx.GetIrValue(input->name()).get();
        if (ir_val) {
          Operation* ir_input_op = ir_val.dyn_cast<pir::OpResult>().owner();
-         if (max_input_op_index < op_2_temp_program_index[ir_input_op]) {
-           max_input_op_index = op_2_temp_program_index[ir_input_op];
+         if (max_input_op_index < op_2_temp_program_index.at(ir_input_op)) {
+           max_input_op_index = op_2_temp_program_index.at(ir_input_op);
            max_index_op = ir_input_op;
-         } else if (max_input_op_index == op_2_temp_program_index[ir_input_op]) {
+         } else if (max_input_op_index ==
+                    op_2_temp_program_index.at(ir_input_op)) {
            const auto& ops_vec = temp_program[max_input_op_index];
            for (auto it = ops_vec.begin(); it != ops_vec.end(); it++) {
              if (*it == max_index_op) {
@@ -430,6 +431,9 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
    Operation* new_op =
        CreateOperation(op_call, src_match_ctx, rewriter, &res_match_ctx);
    op_2_temp_program_index[new_op] = max_input_op_index + 1;
+   if (max_input_op_index + 1 >= temp_program.size()) {
+     temp_program.push_back({});
+   }
    temp_program[max_input_op_index + 1].push_back(new_op);
  });

@@ -471,13 +475,13 @@ void DrrRewritePattern::DeleteSourcePatternOp(
    const ResultPatternGraph& result_pattern_graph,
    const MatchContextImpl& src_match_ctx,
    pir::PatternRewriter& rewriter) const {  // NOLINT
-
  std::queue<Operation*> delete_ops_que;
  std::unordered_set<Operation*> delete_ops_set;
  GraphTopo graph_topo_visit(&source_pattern_graph);
  graph_topo_visit.WalkGraphNodesTopoOrder([&](const OpCall& op_call) {
    Operation* op = src_match_ctx.Operation(&op_call).get();
-   if (op->use_empty()) {
+   VLOG(5) << "DRR delete op: " << op->name() << " pointer: " << op;
+   if (delete_ops_set.count(op) == 0 && op->use_empty()) {
      delete_ops_que.push(op);
      delete_ops_set.insert(op);
    }
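The switch from operator[] to .at() here is more than style: on std::unordered_map, operator[] default-constructs a value for a missing key, so a read during the max-index scan could silently insert a bogus index 0, while .at() throws instead of inventing an entry. A minimal illustration, not Paddle code:

#include <cassert>
#include <stdexcept>
#include <unordered_map>

int main() {
  std::unordered_map<int, size_t> op_index = {{1, 7}};

  size_t v = op_index[2];  // missing key: inserts {2, 0} and returns 0
  assert(v == 0 && op_index.size() == 2);

  bool threw = false;
  try {
    op_index.at(3);  // missing key: throws, map left untouched
  } catch (const std::out_of_range &) {
    threw = true;
  }
  assert(threw && op_index.size() == 2);
  return 0;
}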
14 changes: 14 additions & 0 deletions paddle/fluid/pir/drr/ir_operation_factory.cc
@@ -57,6 +57,17 @@ void OperationFactory::RegisterManualOpCreator() {
         pir::PatternRewriter& rewriter) {
        return rewriter.Build<pir::CombineOp>(inputs);
      });
+  RegisterOperationCreator(
+      "pd_op.scale",
+      [](const std::vector<Value>& inputs,
+         const pir::AttributeMap& attrs,
+         pir::PatternRewriter& rewriter) {
+        return rewriter.Build<paddle::dialect::ScaleOp>(
+            inputs[0].dyn_cast<pir::OpResult>(),
+            inputs[1].dyn_cast<pir::OpResult>(),
+            attrs.at("bias").dyn_cast<pir::FloatAttribute>().data(),
+            attrs.at("bias_after_scale").dyn_cast<pir::BoolAttribute>().data());
+      });
}

@@ -83,6 +94,9 @@ static pir::Attribute CreateIrAttribute(const std::any& obj) {
  } else if (obj.type() == typeid(std::vector<int64_t>)) {
    return IrAttrbuteCreator<std::vector<int64_t>>()(
        std::any_cast<std::vector<int64_t>>(obj));
+ } else if (obj.type() == typeid(std::vector<float>)) {
+   return IrAttrbuteCreator<std::vector<float>>()(
+       std::any_cast<std::vector<float>>(obj));
  } else if (obj.type() == typeid(phi::IntArray)) {
    return IrAttrbuteCreator<phi::IntArray>()(
        std::any_cast<phi::IntArray>(obj));
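With the new branch in CreateIrAttribute, a std::vector<float> boxed in std::any now lowers like the other supported types. A sketch of the dispatch; since CreateIrAttribute is file-static, this call is only meaningful inside that translation unit:

#include <any>
#include <vector>

// Takes the new typeid(std::vector<float>) branch above and yields a
// pir::ArrayAttribute whose elements are FloatAttribute.
std::any boxed = std::vector<float>{0.5f, 2.0f};
pir::Attribute attr = CreateIrAttribute(boxed);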
102 changes: 102 additions & 0 deletions paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.cc
@@ -0,0 +1,102 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/drr_pattern_base.h"
#include "paddle/fluid/pir/transforms/transform_general_functions.h"

#include "paddle/common/ddim.h"

#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pass/pass_registry.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"

namespace {

class MatmulScaleFusePattern
    : public pir::drr::DrrPatternBase<MatmulScaleFusePattern> {
 public:
  void operator()(pir::drr::DrrPatternContext *ctx) const override {
    pir::drr::SourcePattern pat = ctx->SourcePattern();
    const auto &matmul_op = pat.Op(paddle::dialect::MatmulOp::name(),
                                   {{"transpose_x", pat.Attr("transpose_x")},
                                    {"transpose_y", pat.Attr("transpose_y")}});

    matmul_op({&pat.Tensor("x"), &pat.Tensor("y")},
              {&pat.Tensor("matmul_out")});
    const auto &full_op = pat.Op(paddle::dialect::FullOp::name(),
                                 {{"shape", pat.Attr("shape")},
                                  {"value", pat.Attr("value")},
                                  {"dtype", pat.Attr("dtype")},
                                  {"place", pat.Attr("place")}});
    const auto &scale_op =
        pat.Op(paddle::dialect::ScaleOp::name(),
               {{"bias", pat.Attr("bias")},
                {"bias_after_scale", pat.Attr("bias_after_scale")}});
    scale_op({&pat.Tensor("matmul_out"), &full_op()},
             {&pat.Tensor("scale_out")});

    pat.RequireNativeCall([&](const pir::drr::MatchContext &match_ctx) {
      return std::abs(match_ctx.Attr<float>("bias")) <= 1e-6;
    });

    pir::drr::ResultPattern res = pat.ResultPattern();
    const auto &full_op_res = res.Op(paddle::dialect::FullOp::name(),
                                     {{"shape", pat.Attr("shape")},
                                      {"value", pat.Attr("value")},
                                      {"dtype", pat.Attr("dtype")},
                                      {"place", pat.Attr("place")}});
    const auto &scale_op_res =
        res.Op(paddle::dialect::ScaleOp::name(),
               {{"bias",
                 res.Attr([](const pir::drr::MatchContext &match_ctx) -> float {
                   return 0.0;
                 })},
                {"bias_after_scale", pat.Attr("bias_after_scale")}});
    const auto &matmul_op_res =
        res.Op(paddle::dialect::MatmulOp::name(),
               {{"transpose_x", pat.Attr("transpose_x")},
                {"transpose_y", pat.Attr("transpose_y")}});
    scale_op_res({&res.Tensor("y"), &full_op_res()},
                 {&res.Tensor("scale_res_out")});
    matmul_op_res({&res.Tensor("x"), &res.Tensor("scale_res_out")},
                  {&res.Tensor("scale_out")});
  }
};

class MatmulScaleFusePass : public pir::PatternRewritePass {
 public:
  MatmulScaleFusePass()
      : pir::PatternRewritePass("matmul_scale_fuse_pass", 2) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    ps.Add(MatmulScaleFusePattern().Build(context));
    return ps;
  }
};

}  // namespace

namespace pir {

std::unique_ptr<Pass> CreateMatmulScaleFusePass() {
  return std::make_unique<MatmulScaleFusePass>();
}

REGISTER_IR_PASS(matmul_scale_fuse_pass, MatmulScaleFusePass);
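The rewrite rests on scalar associativity, (x · y) · s = x · (y · s): folding the scalar into y before the matmul is only legal when the scale op adds no bias, which is exactly what the RequireNativeCall constraint (|bias| <= 1e-6) checks before the result pattern re-emits scale on y with bias forced to 0. A tiny numeric sanity sketch of that identity, with scalars standing in for tensors; not part of the PR:

#include <cassert>

int main() {
  float x = 3.0f, y = 4.0f, s = 0.5f;
  float lhs = (x * y) * s;  // source pattern: matmul, then scale
  float rhs = x * (y * s);  // result pattern: scale y, then matmul
  assert(lhs == rhs);       // 6.0f == 6.0f, exact in float here
  return 0;
}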
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h
@@ -0,0 +1,26 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/core/dll_decl.h"

namespace pir {

class Pass;

IR_API std::unique_ptr<Pass> CreateMatmulScaleFusePass();

} // namespace pir
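A hedged usage sketch for the new entry point: registering the pass with a pir::PassManager and running it over a program. The PassManager wiring is assumed from pir's pass infrastructure, not shown in this PR:

#include "paddle/fluid/pir/transforms/fusion/matmul_scale_fuse_pass.h"
#include "paddle/pir/pass/pass_manager.h"

void ApplyMatmulScaleFuse(pir::IrContext *ctx, pir::Program *program) {
  pir::PassManager pm(ctx);
  pm.AddPass(pir::CreateMatmulScaleFusePass());
  pm.Run(program);  // rewrites matmul+scale chains where bias == 0
}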