Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dynamic Shape] Substitute Complicate DimExpr in FusionOp #62766

Merged
merged 4 commits into from
Mar 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/include/dialect/shape/utils/dim_expr_util.h"
#include "paddle/pir/include/dialect/shape/utils/shape_or_data_expr.h"
#include "paddle/pir/include/pass/pass_registry.h"
#include "paddle/pir/include/pattern_rewrite/frozen_rewrite_pattern_set.h"
Expand Down Expand Up @@ -590,26 +591,178 @@ pir::Operation* ProcessDyShapeGroup(
}
}

namespace {

bool IsComplicatedDimExpr(const symbol::DimExpr& dim_expr) {
auto lambdas = symbol::Overloaded{
[](std::int64_t dim_expr) { return false; },
[](const std::string& dim_expr) { return false; },
[](const symbol::Negative<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Reciprocal<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Add<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Mul<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Max<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Min<symbol::DimExpr>& dim_expr) { return true; },
[](const symbol::Broadcast<symbol::DimExpr>& dim_expr) { return true; }};
return std::visit(lambdas, dim_expr.variant());
}

template <typename DoEachT>
void VisitEachInputValue(const GroupPtr& group, const DoEachT& DoEach) {
for (pir::Value value : GetBlockOutsideInput(group->ops)) {
DoEach(value);
}
}

template <typename DoEachT>
void VisitEachDimExprFromTensorShapeOrData(
const symbol::TensorShapeOrDataDimExprs& shape_or_data,
const DoEachT& DoEach) {
for (const auto& dim_expr : shape_or_data.shape()) {
DoEach(dim_expr);
}
if (!shape_or_data.data().has_value()) {
return;
}
for (const auto& dim_expr : shape_or_data.data().value()) {
DoEach(dim_expr);
}
}

template <typename DoEachT>
void VisitEachDimExpr(const symbol::ShapeOrDataDimExprs& shape_or_data,
const DoEachT& DoEach) {
auto lambdas = symbol::Overloaded{
[&](const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) {
VisitEachDimExprFromTensorShapeOrData(tensor_shape_or_data, DoEach);
},
[&](const symbol::TensorListShapeOrDataDimExprs& tensor_list) {
symbol::TensorListShapeOrDataDimExprs simplified_tensor_list;
for (const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data :
tensor_list) {
VisitEachDimExprFromTensorShapeOrData(tensor_shape_or_data, DoEach);
}
}};
return std::visit(lambdas, shape_or_data.variant());
}

std::unordered_map<symbol::DimExpr, symbol::DimExpr>
CollectSubstituteDimExprMap(
const GroupPtr& group,
pir::ShapeConstraintIRAnalysis& shape_analysis) { // NOLINT
std::unordered_map<symbol::DimExpr, symbol::DimExpr> dim_expr_map;

VisitEachInputValue(group, [&](::pir::Value value) {
if (!shape_analysis.HasShapeOrDataForValue(value)) {
return;
}
auto& shape_or_data = shape_analysis.GetShapeOrDataForValue(value);
VisitEachDimExpr(shape_or_data, [&](const symbol::DimExpr& dim_expr) {
if (IsComplicatedDimExpr(dim_expr) &&
dim_expr_map.find(dim_expr) == dim_expr_map.end()) {
dim_expr_map[dim_expr] =
symbol::DimExpr(shape_analysis.GetNextSymName());
}
});
});

return dim_expr_map;
}

bool IsShapeOrDataNeedSubstitute(
const symbol::ShapeOrDataDimExprs& shape_or_data,
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>& dim_expr_map) {
bool ret = false;
VisitEachDimExpr(shape_or_data, [&](const symbol::DimExpr& dim_expr) {
if (dim_expr_map.find(dim_expr) != dim_expr_map.end()) {
ret = true;
}
});
return ret;
}

symbol::TensorShapeOrDataDimExprs SubstituteTensorShapeOrData(
const symbol::TensorShapeOrDataDimExprs& shape_or_data,
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>& dim_expr_map) {
const auto& SimplifyDimExpr =
[&](const std::vector<symbol::DimExpr>& original_dim_expr)
-> std::vector<symbol::DimExpr> {
std::vector<symbol::DimExpr> simplified_dim_expr{};
for (const symbol::DimExpr& dim_expr : original_dim_expr) {
simplified_dim_expr.push_back(symbol::SimplifyDimExpr(
symbol::SubstituteDimExpr(dim_expr, dim_expr_map)));
}
return simplified_dim_expr;
};

std::vector<symbol::DimExpr> simplified_shape =
SimplifyDimExpr(shape_or_data.shape());
if (!shape_or_data.data().has_value()) {
return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape);
}
std::vector<symbol::DimExpr> simplified_data =
SimplifyDimExpr(shape_or_data.data().value());
return symbol::ShapeOrData<symbol::DimExpr>(simplified_shape,
simplified_data);
}

symbol::ShapeOrDataDimExprs SubstituteShapeOrData(
const symbol::ShapeOrDataDimExprs& shape_or_data,
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>& dim_expr_map) {
auto lambdas = symbol::Overloaded{
[&](const symbol::TensorShapeOrDataDimExprs& tensor_shape_or_data) {
return symbol::ShapeOrDataDimExprs(
SubstituteTensorShapeOrData(tensor_shape_or_data, dim_expr_map));
},
[&](const symbol::TensorListShapeOrDataDimExprs& tensor_list) {
symbol::TensorListShapeOrDataDimExprs simplified_tensor_list;
for (symbol::TensorShapeOrDataDimExprs tensor_shape_or_data :
tensor_list) {
simplified_tensor_list.push_back(
SubstituteTensorShapeOrData(tensor_shape_or_data, dim_expr_map));
}
return symbol::ShapeOrDataDimExprs(simplified_tensor_list);
}};
return std::visit(lambdas, shape_or_data.variant());
}

symbol::ShapeOrDataDimExprs TrySubstitute(
const symbol::ShapeOrDataDimExprs& shape_or_data,
const std::unordered_map<symbol::DimExpr, symbol::DimExpr>& dim_expr_map) {
if (!IsShapeOrDataNeedSubstitute(shape_or_data, dim_expr_map)) {
return shape_or_data;
}
return SubstituteShapeOrData(shape_or_data, dim_expr_map);
}

} // namespace

std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs>
CreateGroupShapeOrDataExprs(
const GroupPtr& group,
pir::ShapeConstraintIRAnalysis& shape_analysis) { // NOLINT
std::unordered_map<symbol::DimExpr, symbol::DimExpr> dim_expr_map =
CollectSubstituteDimExprMap(group, shape_analysis);
std::unordered_map<::pir::Value, symbol::ShapeOrDataDimExprs> value2shape;
for (auto* op : group->ops) {
for (size_t i = 0; i < op->num_operands(); ++i) {
auto operand = op->operand_source(i);
if (operand && value2shape.find(operand) == value2shape.end() &&
shape_analysis.HasShapeOrDataForValue(operand)) {
value2shape.insert(
{operand, shape_analysis.GetShapeOrDataForValue(operand)});
{operand,
TrySubstitute(shape_analysis.GetShapeOrDataForValue(operand),
dim_expr_map)});
}
}
for (size_t i = 0; i < op->num_results(); ++i) {
auto result = op->result(i);
if (result && value2shape.find(result) == value2shape.end() &&
shape_analysis.HasShapeOrDataForValue(result)) {
value2shape.insert(
{result, shape_analysis.GetShapeOrDataForValue(result)});
{result,
TrySubstitute(shape_analysis.GetShapeOrDataForValue(result),
dim_expr_map)});
}
}
}
Expand Down