Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix typos: tensor_comsumer → tensor_consumer, etc. #62213

Merged
merged 3 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ DRR can reduce the development cost of PASS, allowing developers to focus on pro
Taking PASS to eliminate redundant CastOp as an example, the code example developed using DRR is as follows:
~~~ c++
// 1. Inherit class from DrPatternBase
class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase {
class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase {
public:
std::string name() const override { return "RemoveRedundentCastPattern"; }
std::string name() const override { return "RemoveRedundantCastPattern"; }

// 2. Overload operator()
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/pir/drr/README_cn.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ DRR ( Declarative Rewrite Rule ) 是来处理这种 DAG-to-DAG 类型的一套 P
以消除冗余 CastOp 的 PASS 为例,使用 DRR 的代码开发示例如下:
~~~ c++
// 1. 继承 DrrPatternBase 类
class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase {
class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase {
public:
std::string name() const override { return "RemoveRedundentCastPattern"; }
std::string name() const override { return "RemoveRedundantCastPattern"; }

// 2. 重载 operator()
void operator()(paddle::drr::DrrPatternContext *ctx) const override {
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/drr/src/attr_type_uilts.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,15 @@ PD_SPECIALIZE_CppTypeToIrAttribute(phi::IntArray,
paddle::dialect::IntArrayAttribute);

template <typename T>
struct IrAttrbuteCreator {
struct IrAttributeCreator {
typename CppTypeToIrAttribute<T>::type operator()(T obj) const {
return CppTypeToIrAttribute<T>::type::template get(
pir::IrContext::Instance(), obj);
}
};

template <>
struct IrAttrbuteCreator<std::vector<int32_t>> {
struct IrAttributeCreator<std::vector<int32_t>> {
pir::ArrayAttribute operator()(std::vector<int32_t> obj) const {
std::vector<pir::Attribute> attr_vec;
attr_vec.reserve(obj.size());
Expand All @@ -69,7 +69,7 @@ struct IrAttrbuteCreator<std::vector<int32_t>> {
};

template <>
struct IrAttrbuteCreator<std::vector<float>> {
struct IrAttributeCreator<std::vector<float>> {
pir::ArrayAttribute operator()(std::vector<float> obj) const {
std::vector<pir::Attribute> attr_vec;
attr_vec.reserve(obj.size());
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/pir/drr/src/ir_operation_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,33 +65,33 @@ void OperationFactory::RegisterManualOpCreator() {

pir::Attribute CreateIrAttribute(const std::any& obj) {
if (obj.type() == typeid(bool)) {
return IrAttrbuteCreator<bool>()(std::any_cast<bool>(obj));
return IrAttributeCreator<bool>()(std::any_cast<bool>(obj));
} else if (obj.type() == typeid(int32_t)) {
return IrAttrbuteCreator<int32_t>()(std::any_cast<int32_t>(obj));
return IrAttributeCreator<int32_t>()(std::any_cast<int32_t>(obj));
} else if (obj.type() == typeid(int64_t)) {
return IrAttrbuteCreator<int64_t>()(std::any_cast<int64_t>(obj));
return IrAttributeCreator<int64_t>()(std::any_cast<int64_t>(obj));
} else if (obj.type() == typeid(float)) {
return IrAttrbuteCreator<float>()(std::any_cast<float>(obj));
return IrAttributeCreator<float>()(std::any_cast<float>(obj));
} else if (obj.type() == typeid(std::string)) {
return IrAttrbuteCreator<std::string>()(std::any_cast<std::string>(obj));
return IrAttributeCreator<std::string>()(std::any_cast<std::string>(obj));
} else if (obj.type() == typeid(const char*)) {
return IrAttrbuteCreator<std::string>()(std::any_cast<const char*>(obj));
return IrAttributeCreator<std::string>()(std::any_cast<const char*>(obj));
} else if (obj.type() == typeid(phi::DataType)) {
return IrAttrbuteCreator<phi::DataType>()(
return IrAttributeCreator<phi::DataType>()(
std::any_cast<phi::DataType>(obj));
} else if (obj.type() == typeid(phi::Place)) {
return IrAttrbuteCreator<phi::Place>()(std::any_cast<phi::Place>(obj));
return IrAttributeCreator<phi::Place>()(std::any_cast<phi::Place>(obj));
} else if (obj.type() == typeid(std::vector<int32_t>)) {
return IrAttrbuteCreator<std::vector<int32_t>>()(
return IrAttributeCreator<std::vector<int32_t>>()(
std::any_cast<std::vector<int32_t>>(obj));
} else if (obj.type() == typeid(std::vector<int64_t>)) {
return IrAttrbuteCreator<std::vector<int64_t>>()(
return IrAttributeCreator<std::vector<int64_t>>()(
std::any_cast<std::vector<int64_t>>(obj));
} else if (obj.type() == typeid(std::vector<float>)) {
return IrAttrbuteCreator<std::vector<float>>()(
return IrAttributeCreator<std::vector<float>>()(
std::any_cast<std::vector<float>>(obj));
} else if (obj.type() == typeid(phi::IntArray)) {
return IrAttrbuteCreator<phi::IntArray>()(
return IrAttributeCreator<phi::IntArray>()(
std::any_cast<phi::IntArray>(obj));
} else {
PADDLE_THROW(
Expand Down
24 changes: 12 additions & 12 deletions paddle/fluid/pir/drr/src/pattern_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,16 +147,16 @@ void GraphTopo::WalkGraphNodesTopoOrder(
const std::unordered_set<std::string> &inputs_tensor =
graph_->input_tensors();
const std::unordered_map<std::string, std::shared_ptr<Tensor>>
&id2owned_tensor = graph_->id2owend_tensor();
const std::vector<std::shared_ptr<OpCall>> &owend_opcall =
&id2owned_tensor = graph_->id2owned_tensor();
const std::vector<std::shared_ptr<OpCall>> &owned_opcall =
graph_->owned_op_call();

std::queue<const OpCall *> opcall_queue;
std::unordered_map<const OpCall *, std::unordered_set<std::string>>
opcall_dependent;

// init opcall_dependent
for (const std::shared_ptr<OpCall> &opcall_sptr : owend_opcall) {
for (const std::shared_ptr<OpCall> &opcall_sptr : owned_opcall) {
if (opcall_sptr.get()->inputs().empty()) { // opcall inputs is empty
opcall_queue.push(opcall_sptr.get());
} else {
Expand All @@ -174,11 +174,11 @@ void GraphTopo::WalkGraphNodesTopoOrder(
"The input tensor [%s] must exists "
"in pattern graph to be obtained.",
tensor_name));
for (const auto &tensor_comsumer :
for (const auto &tensor_consumer :
id2owned_tensor.at(tensor_name).get()->consumers()) {
opcall_dependent[tensor_comsumer].erase(tensor_name);
if (opcall_dependent[tensor_comsumer].empty()) {
opcall_queue.push(tensor_comsumer);
opcall_dependent[tensor_consumer].erase(tensor_name);
if (opcall_dependent[tensor_consumer].empty()) {
opcall_queue.push(tensor_consumer);
}
}
}
Expand All @@ -190,10 +190,10 @@ void GraphTopo::WalkGraphNodesTopoOrder(

// update opcall_dependent
for (const auto &output_tensor : opcall->outputs()) {
for (const auto &tensor_comsumer : output_tensor->consumers()) {
opcall_dependent[tensor_comsumer].erase(output_tensor->name());
if (opcall_dependent[tensor_comsumer].empty()) {
opcall_queue.push(tensor_comsumer);
for (const auto &tensor_consumer : output_tensor->consumers()) {
opcall_dependent[tensor_consumer].erase(output_tensor->name());
if (opcall_dependent[tensor_consumer].empty()) {
opcall_queue.push(tensor_consumer);
}
}
}
Expand All @@ -202,7 +202,7 @@ void GraphTopo::WalkGraphNodesTopoOrder(

std::ostream &operator<<(std::ostream &os, const PatternGraph &pattern_graph) {
os << "\nAll Tensors:\n";
for (const auto &kv : pattern_graph.id2owend_tensor()) {
for (const auto &kv : pattern_graph.id2owned_tensor()) {
os << " " << kv.first;
}
os << "\n\n";
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/drr/src/pattern_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class PatternGraph {
}

const std::unordered_map<std::string, std::shared_ptr<Tensor>>&
id2owend_tensor() const {
id2owned_tensor() const {
return id2owned_tensor_;
}

Expand Down
8 changes: 4 additions & 4 deletions paddle/fluid/pir/drr/src/rewrite_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ bool DrrRewritePattern::MatchAndRewrite(
if (PatternGraphMatch(op, src_match_ctx.get())) {
VLOG(4) << "DRR pattern (" << pattern_name_ << ") is matched in program.";
PatternGraphRewrite(*src_match_ctx, rewriter);
VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewrited in program.";
VLOG(4) << "DRR pattern (" << pattern_name_ << ") is rewritten in program.";
return true;
}
return false;
Expand Down Expand Up @@ -414,13 +414,13 @@ MatchContextImpl DrrRewritePattern::CreateOperations(
// add input tensors info for res_match_ctx
for (const auto& in_tensor : result_pattern_graph.input_tensors()) {
PADDLE_ENFORCE_NE(
result_pattern_graph.id2owend_tensor().count(in_tensor),
result_pattern_graph.id2owned_tensor().count(in_tensor),
0,
phi::errors::NotFound("Not found the input tensor."
"Drr input tensor [%s] must exist in the result "
"pattern graph to be obtained.",
in_tensor));
if (!result_pattern_graph.id2owend_tensor().at(in_tensor)->is_none()) {
if (!result_pattern_graph.id2owned_tensor().at(in_tensor)->is_none()) {
res_match_ctx.BindIrValue(in_tensor, src_match_ctx.GetIrValue(in_tensor));
}
}
Expand Down Expand Up @@ -508,7 +508,7 @@ void DrrRewritePattern::ReplaceOutputTensor(
const MatchContextImpl& res_match_ctx,
pir::PatternRewriter& rewriter) const { // NOLINT
for (const auto& output_name : result_pattern_graph_->output_tensors()) {
if (source_pattern_graph_->id2owend_tensor().count(output_name)) {
if (source_pattern_graph_->id2owned_tensor().count(output_name)) {
const auto& src_ir_tensor = src_match_ctx.GetIrValue(output_name);
const auto& res_ir_tensor = res_match_ctx.GetIrValue(output_name);
rewriter.ReplaceAllUsesWith(src_ir_tensor, res_ir_tensor);
Expand Down
22 changes: 11 additions & 11 deletions paddle/fluid/pir/transforms/identity_op_clean_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,9 @@ class RemoveUselessScalePattern : public paddle::drr::DrrPatternBase {
}
};

class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase {
class RemoveRedundantScalePattern : public paddle::drr::DrrPatternBase {
public:
std::string name() const override { return "RemoveRedundentScalePattern"; }
std::string name() const override { return "RemoveRedundantScalePattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
paddle::drr::SourcePattern pat = ctx->SourcePattern();
Expand Down Expand Up @@ -83,7 +83,7 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase {

paddle::drr::ResultPattern res = pat.ResultPattern();

const auto &bais_attr = res.ComputeAttr(
const auto &bias_attr = res.ComputeAttr(
[](const paddle::drr::MatchContext &match_ctx) -> float {
float res_bias_1 = 0.f;
float res_bias_2 = 0.f;
Expand Down Expand Up @@ -115,7 +115,7 @@ class RemoveRedundentScalePattern : public paddle::drr::DrrPatternBase {
{"place", pat.Attr("place_1")}});
const auto &scale_op_res =
res.Op("pd_op.scale",
{{"bias", bais_attr}, {"bias_after_scale", res.BoolAttr(true)}});
{{"bias", bias_attr}, {"bias_after_scale", res.BoolAttr(true)}});
scale_op_res({&res.Tensor("x"), &full_op_res()},
{&res.Tensor("scale_2_out")});
}
Expand Down Expand Up @@ -154,9 +154,9 @@ class RemoveUselessConcatPattern : public paddle::drr::DrrPatternBase {
}
};

class RemoveRedundentCastPattern : public paddle::drr::DrrPatternBase {
class RemoveRedundantCastPattern : public paddle::drr::DrrPatternBase {
public:
std::string name() const override { return "RemoveRedundentCastPattern"; }
std::string name() const override { return "RemoveRedundantCastPattern"; }

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
auto pat = ctx->SourcePattern();
Expand Down Expand Up @@ -245,10 +245,10 @@ class ReplaceDropoutWithScalePattern : public paddle::drr::DrrPatternBase {
}
};

class RemoveRedundentTransposePattern : public paddle::drr::DrrPatternBase {
class RemoveRedundantTransposePattern : public paddle::drr::DrrPatternBase {
public:
std::string name() const override {
return "RemoveRedundentTransposePattern";
return "RemoveRedundantTransposePattern";
}

void operator()(paddle::drr::DrrPatternContext *ctx) const override {
Expand Down Expand Up @@ -286,13 +286,13 @@ class IdentityOpCleanPass : public pir::PatternRewritePass {
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
pir::RewritePatternSet ps(context);
ps.Add(paddle::drr::Create<RemoveUselessScalePattern>(context));
ps.Add(paddle::drr::Create<RemoveRedundentScalePattern>(context));
ps.Add(paddle::drr::Create<RemoveRedundantScalePattern>(context));
ps.Add(paddle::drr::Create<RemoveUselessCastPattern>(context));
ps.Add(paddle::drr::Create<RemoveUselessConcatPattern>(context));
ps.Add(paddle::drr::Create<RemoveRedundentCastPattern>(context));
ps.Add(paddle::drr::Create<RemoveRedundantCastPattern>(context));
ps.Add(paddle::drr::Create<DeleteDropoutOpPattern>(context));
ps.Add(paddle::drr::Create<ReplaceDropoutWithScalePattern>(context));
ps.Add(paddle::drr::Create<RemoveRedundentTransposePattern>(context));
ps.Add(paddle::drr::Create<RemoveRedundantTransposePattern>(context));
return ps;
}
};
Expand Down