Skip to content

Commit

Permalink
merge upstream/develop
Browse files Browse the repository at this point in the history
  • Loading branch information
YKTian-x2b committed Mar 29, 2024
2 parents 8fd8d3e + 70cc347 commit 519a02b
Show file tree
Hide file tree
Showing 163 changed files with 4,920 additions and 2,075 deletions.
12 changes: 8 additions & 4 deletions .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
<!-- TemplateReference: https://github.com/PaddlePaddle/Paddle/wiki/PULL-REQUEST-TEMPLATE--REFERENCE -->
<!-- Demo: https://github.com/PaddlePaddle/Paddle/pull/24810 -->
### PR types
<!-- One of [ New features | Bug fixes | Function optimization | Performance optimization | Breaking changes | Others ] -->

### PR changes
<!-- One of [ OPs | APIs | Docs | Others ] -->
### PR Category
<!-- One of [ User Experience | Execute Infrastructure | Operator Mechanism | CINN | Custom Device | Performance Optimization | Distributed Strategy | Parameter Server | Communication Library | Auto Parallel | Inference | Environment Adaptation | Others ] -->


### PR Types
<!-- One of [ New features | Bug fixes | Improvements | Performance | BC Breaking | Deprecations | Docs | Devs | Not User Facing | Security | Deprecations | Others ] -->


### Description
<!-- Describe what you’ve done -->
130 changes: 65 additions & 65 deletions paddle/cinn/common/broadcast_tree.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,71 +115,6 @@ void ForEachBroadcastDimExpr(const BroadcastLeaf& leaves,
}
}

std::optional<symbol::Broadcastable<symbol::DimExpr>> GetFirstCstrBroadcastable(
const BroadcastLeaf& leaves) {
std::optional<symbol::Broadcastable<symbol::DimExpr>> ret;
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
const auto& operands = broadcast.operands;
std::optional<symbol::DimExpr> lhs_symbol;
std::optional<symbol::DimExpr> rhs_symbol;
size_t i = 0;
for (; i < operands->size(); ++i) {
if (operands->at(i).template isa<std::string>()) {
lhs_symbol = operands->at(i);
break;
}
}
for (i++; i < operands->size(); ++i) {
if (operands->at(i).template isa<std::string>()) {
rhs_symbol = operands->at(i);
break;
}
}
if (lhs_symbol.has_value() && rhs_symbol.has_value()) {
CHECK(lhs_symbol != rhs_symbol)
<< lhs_symbol.value() << " != " << rhs_symbol.value();
ret = symbol::Broadcastable<symbol::DimExpr>{lhs_symbol.value(),
rhs_symbol.value()};
return true;
}
return false;
});
if (ret.has_value()) return ret.value();
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
const auto& operands = broadcast.operands;
std::optional<symbol::DimExpr> lhs_symbol;
std::optional<symbol::DimExpr> rhs;
for (const auto& operand : *operands) {
if (operand.template isa<std::string>()) {
lhs_symbol = operand;
break;
}
}
for (const auto& operand : *operands) {
if (operand != lhs_symbol) {
rhs = operand;
break;
}
}
if (lhs_symbol.has_value() && rhs.has_value()) {
ret = symbol::Broadcastable<symbol::DimExpr>{lhs_symbol.value(),
rhs.value()};
return true;
}
return false;
});
if (ret.has_value()) return ret.value();
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
const auto& operands = broadcast.operands;
CHECK_GE(operands->size(), 2);
CHECK(operands->at(0) != operands->at(1));
ret = symbol::Broadcastable<symbol::DimExpr>{operands->at(0),
operands->at(1)};
return true;
});
return ret;
}

using Pattern2Placement = std::unordered_map<symbol::DimExpr, symbol::DimExpr>;

Pattern2Placement ConstructCstrLhsEqRhsReplacement(
Expand Down Expand Up @@ -291,6 +226,71 @@ BroadcastBranch<BroadcastTree> ConstructBroadcastBranch(

} // namespace

std::optional<symbol::Broadcastable<symbol::DimExpr>> GetFirstCstrBroadcastable(
const BroadcastLeaf& leaves) {
std::optional<symbol::Broadcastable<symbol::DimExpr>> ret;
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
const auto& operands = broadcast.operands;
std::optional<symbol::DimExpr> lhs_symbol;
std::optional<symbol::DimExpr> rhs_symbol;
size_t i = 0;
for (; i < operands->size(); ++i) {
if (operands->at(i).template isa<std::string>()) {
lhs_symbol = operands->at(i);
break;
}
}
for (i++; i < operands->size(); ++i) {
if (operands->at(i).template isa<std::string>()) {
rhs_symbol = operands->at(i);
break;
}
}
if (lhs_symbol.has_value() && rhs_symbol.has_value()) {
CHECK(lhs_symbol != rhs_symbol)
<< lhs_symbol.value() << " != " << rhs_symbol.value();
ret = symbol::Broadcastable<symbol::DimExpr>{lhs_symbol.value(),
rhs_symbol.value()};
return true;
}
return false;
});
if (ret.has_value()) return ret.value();
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
const auto& operands = broadcast.operands;
std::optional<symbol::DimExpr> lhs_symbol;
std::optional<symbol::DimExpr> rhs;
for (const auto& operand : *operands) {
if (operand.template isa<std::string>()) {
lhs_symbol = operand;
break;
}
}
for (const auto& operand : *operands) {
if (operand != lhs_symbol) {
rhs = operand;
break;
}
}
if (lhs_symbol.has_value() && rhs.has_value()) {
ret = symbol::Broadcastable<symbol::DimExpr>{lhs_symbol.value(),
rhs.value()};
return true;
}
return false;
});
if (ret.has_value()) return ret.value();
ForEachBroadcastDimExpr(leaves, [&](const auto& broadcast) -> bool {
const auto& operands = broadcast.operands;
CHECK_GE(operands->size(), 2);
CHECK(operands->at(0) != operands->at(1));
ret = symbol::Broadcastable<symbol::DimExpr>{operands->at(0),
operands->at(1)};
return true;
});
return ret;
}

BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves) {
std::optional<symbol::Broadcastable<symbol::DimExpr>>
broadcastable_condition = GetFirstCstrBroadcastable(leaves);
Expand Down
3 changes: 3 additions & 0 deletions paddle/cinn/common/broadcast_tree.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@ BroadcastTree ConstructBroadcastTree(const BroadcastLeaf& leaves);

std::string ToTxtString(const BroadcastTree&);

std::optional<symbol::Broadcastable<symbol::DimExpr>> GetFirstCstrBroadcastable(
const BroadcastLeaf& leaves);

} // namespace cinn::common
2 changes: 1 addition & 1 deletion paddle/cinn/frontend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ add_subdirectory(paddle)
add_subdirectory(decomposer)
add_subdirectory(op_mappers)
add_subdirectory(pass)
add_subdirectory(group_cluster)
# add_subdirectory(group_cluster)

cinn_cc_test(test_op_mapper_registry SRCS op_mapper_registry_test.cc DEPS
cinncore)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ set(cinn_transforms_deps
cinn_op_dialect
op_dialect_vjp
cinn_runtime_dialect
group_cluster
# group_cluster
pir_compiler)

cinn_cc_library(cinn_transforms SRCS ${cinn_transforms_srcs} DEPS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ class AddBroadcastToElementwisePass : public pir::PatternRewritePass {
context);

// bitwise ops
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::BitwiseAndOp>>(
context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::BitwiseOrOp>>(
context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::BitwiseXorOp>>(
Expand Down
55 changes: 43 additions & 12 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "paddle/pir/include/dialect/shape/ir/shape_dialect.h"
#include "paddle/pir/include/pass/pass_manager.h"

#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/add_store_in_fusion_op_pass.h"
Expand All @@ -36,9 +37,10 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/divide_group_op_to_fusion_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/move_generate_shape_ops_to_prologue_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/simplify_dim_expr_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/single_op_fallback_to_phi.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/substitute_dim_expr_based_on_constraints_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/insert_broadcast_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lower_cinn_fusion_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/lowering_pass/lower_cinn_fusion_op_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/pd_to_cinn_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/remove_unchanged_reshape_pass.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/replace_dynamic_expand_pass.h"
Expand Down Expand Up @@ -73,23 +75,27 @@ bool HasDynamicShape(const pir::Program& program) {
}
} // namespace

void ApplyPdToCinnPass(
::pir::Program* program,
const std::function<std::shared_ptr<::pir::PassManager>()>&
CreatePassManager) {
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());
pass_manager->Run(program);
}

void ApplyCinnPreprocessPass(
::pir::Program* program,
const std::function<std::shared_ptr<::pir::PassManager>()>&
CreatePassManager) {
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
bool has_dynamic_shape = HasDynamicShape(*program);

pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass());
if (!has_dynamic_shape && FLAGS_check_infer_symbolic) {
pass_manager->AddPass(pir::CreateShapeOptimizationPass());
pass_manager->AddPass(cinn::dialect::ir::CreateCheckInferSymbolicPass());
}
pass_manager->AddPass(cinn::dialect::ir::CreatePdOpToCinnOpPass());

pass_manager->AddPass(
cinn::dialect::ir::CreateAddBroadcastToElementwisePass());
pass_manager->AddPass(pir::CreateDeadCodeEliminationPass());

if (has_dynamic_shape) {
pass_manager->AddPass(cinn::dialect::ir::CreateConvert0DTo1DPass());
Expand All @@ -115,18 +121,19 @@ void ApplyBuildGroupOpPass(
pass_manager->AddPass(cinn::dialect::ir::CreateRemoveUnchangedReshapePass());

pass_manager->AddPass(pir::CreateBuildCinnPass());
if (has_dynamic_shape) {
pass_manager->AddPass(cinn::dialect::ir::CreateInsertBroadcastPass());
}

pass_manager->Run(program);
}

void ApplyGroupOpPass(::pir::Program* program,
const std::function<std::shared_ptr<pir::PassManager>()>&
CreatePassManager) {
std::shared_ptr<pir::PassManager> pass_manager = CreatePassManager();
pass_manager->AddPass(
cinn::dialect::ir::CreateAddBroadcastToElementwisePass());
if (HasDynamicShape(*program)) {
pass_manager->AddPass(::pir::CreateShapeOptimizationPass());
pass_manager->AddPass(cinn::dialect::ir::CreateInsertBroadcastPass());
pass_manager->AddPass(
cinn::dialect::ir::CreateSubstituteDimExprBasedOnConstraintsPass());
pass_manager->AddPass(cinn::dialect::ir::CreateSimplifyDimExprPass());
Expand Down Expand Up @@ -175,25 +182,49 @@ void ApplyCinnLowerPass(
pass_manager->AddPass(std::move(pass.value()));
}

pass_manager->AddPass(cinn::dialect::ir::CreateSingleOpFallbackToPhiPass());
if (has_dynamic_shape && !force_static_shape) {
pass_manager->AddPass(
cinn::dialect::ir::CreateLowerCinnDyShapeFusionOpPass());
} else {
pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass());
}

pass_manager->AddPass(cinn::dialect::ir::CreateLowerCinnFusionOpPass());
pass_manager->AddPass(
cinn::dialect::ir::CreateSplitGenerateShapeIntoShapeOpsPass());

pass_manager->Run(program);
}

template <typename OP_TYPE>
int64_t GetOpCount(const ::pir::Operation* op) {
int64_t count = 0;
for (auto& region : *op) {
for (auto& block : region) {
for (auto& sub_op : block) {
if (sub_op.isa<OP_TYPE>()) {
count++;
continue;
}
if (sub_op.num_regions() > 0) {
count += GetOpCount<OP_TYPE>(&sub_op);
}
}
}
}
return count;
}

void ApplyCinnPass(::pir::Program* program,
const std::function<std::shared_ptr<pir::PassManager>()>&
CreatePassManager) {
ApplyPdToCinnPass(program, CreatePassManager);
ApplyCinnPreprocessPass(program, CreatePassManager);
ApplyBuildGroupOpPass(program, CreatePassManager);
ApplyGroupOpPass(program, CreatePassManager);
ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
LOG(INFO) << "FusionOp count before lowering : *****[ "
<< GetOpCount<cinn::dialect::FusionOp>(program->module_op())
<< " ]*****";
ApplyCinnLowerPass(program, CreatePassManager);
}

Expand Down
Loading

0 comments on commit 519a02b

Please sign in to comment.