diff --git a/paddle/cinn/frontend/group_pattern.h b/paddle/cinn/frontend/group_pattern.h
index 255eab33894d60..3e63f3626a2f1d 100644
--- a/paddle/cinn/frontend/group_pattern.h
+++ b/paddle/cinn/frontend/group_pattern.h
@@ -97,7 +97,7 @@ struct ShardableAxesUtil {
     return ret;
   }
 
-  static ShardableAxes GetFullyShardableAxes(size_t rank) {
+  static ShardableAxes MakeFullyShardableAxes(const size_t rank) {
     ShardableAxes ret;
     for (int i = 0; i < rank; ++i) {
       ret.emplace_back(ShardableAxis{
@@ -107,6 +107,27 @@ struct ShardableAxesUtil {
     }
     return ret;
   }
+
+  // Shardable axes of a reduce op's input: every axis that is not reduced.
+  static ShardableAxes MakeReduceOpInputShardableAxes(
+      const size_t input_rank, const std::vector<int64_t>& reduce_axes) {
+    if (reduce_axes.empty()) return ShardableAxes{};
+    for (int64_t reduce_axis : reduce_axes) {
+      CHECK_GE(reduce_axis, 0);
+      CHECK_LT(reduce_axis, input_rank);
+    }
+    const auto IsReduceAxis = [&](int64_t i) {
+      return std::find(reduce_axes.begin(), reduce_axes.end(), i) !=
+             reduce_axes.end();
+    };
+    ShardableAxes ret;
+    for (int64_t i = 0; i < input_rank; ++i) {
+      if (IsReduceAxis(i)) continue;
+      ret.emplace_back(ShardableAxis{
+          .axis=i,
+          .axis_name=std::string("D") + std::to_string(ShardableAxis::UnqiueSeqNo()),
+      });
+    }
+    return ret;
+  }
 };
 
 struct SoleOutputShardableAxes {
diff --git a/paddle/cinn/frontend/group_pattern_util.cc b/paddle/cinn/frontend/group_pattern_util.cc
index 61638d01df64a4..836ddb850e6834 100644
--- a/paddle/cinn/frontend/group_pattern_util.cc
+++ b/paddle/cinn/frontend/group_pattern_util.cc
@@ -206,10 +206,88 @@ size_t GetRank(pir::Value value) {
   return value.type().dyn_cast<pir::DenseTensorType>().dims().size();
 }
 
+// Reads the reduce op's "dim" attribute and normalizes negative axes
+// (an axis < 0 counts from the end of the input shape).
+std::vector<int64_t> GetReduceAxes(const pir::Operation* reduce_op) {
+  const size_t input_rank = GetRank(reduce_op->operand_source(0));
+  const auto& attr_val = reduce_op->attributes().at("dim");
+  CHECK(attr_val.isa<::pir::ArrayAttribute>());
+  const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>();
+  std::vector<int64_t> reduce_axes;
+  for (int i = 0; i < axis_attr.size(); ++i) {
+    int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data();
+    if (axis < 0) {
+      axis += input_rank;
+    }
+    CHECK_GE(axis, 0);
+    CHECK_LT(axis, input_rank);
+    reduce_axes.push_back(axis);
+  }
+  return reduce_axes;
+}
+
+bool GetReduceOpKeepDims(const pir::Operation* reduce_op) {
+  const auto& attr_val = reduce_op->attributes().at("keep_dim");
+  CHECK(attr_val.isa<::pir::BoolAttribute>());
+  return attr_val.dyn_cast<::pir::BoolAttribute>().data();
+}
+
+// After the reduced dimensions are dropped (keep_dim=false), renumber the
+// surviving axes compactly; requires the axes to be strictly ascending.
+ShardableAxes SqueezeShardableAxes(const ShardableAxes& sa) {
+  ShardableAxes ret_sa(sa);
+  for (int i = 0; i < ret_sa.size(); ++i) {
+    for (int j = i + 1; j < ret_sa.size(); ++j) {
+      CHECK_LT(ret_sa.at(i).axis, ret_sa.at(j).axis);
+    }
+    ret_sa.at(i).axis = i;
+  }
+  return ret_sa;
+}
+
+ShardableAxesSignature MakeEmptyShardableAxesSignature(
+    const pir::Operation* op) {
+  const int result_idx = GetOutputShardableAxesResultIdx(op);
+  pir::Value output = op->result(result_idx);
+  ShardableAxes output_sa =
+      ShardableAxesUtil::MakeFullyShardableAxes(GetRank(output));
+  using InputSignature = std::unordered_map<OpAndOperandIndex, ShardableAxes>;
+  InputSignature empty_input_sig;
+  for (int i = 0; i < op->num_operands(); ++i) {
+    empty_input_sig[OpAndOperandIndex{op, i}] = ShardableAxes{};
+  }
+  return ShardableAxesSignature{
+      .sole_output_sa = SoleOutputShardableAxes{
+          .shardable_axes=output_sa,
+      },
+      .input_shardable_axes = empty_input_sig,
+  };
+}
+
+ShardableAxesSignature MakeShardableAxesSignature4ReduceOp(
+    const pir::Operation* reduce_op) {
+  const size_t input_rank = GetRank(reduce_op->operand_source(0));
+  const auto& reduce_axes = GetReduceAxes(reduce_op);
+  const ShardableAxes input_sa =
+      ShardableAxesUtil::MakeReduceOpInputShardableAxes(input_rank,
+                                                        reduce_axes);
+  using InputSignature = std::unordered_map<OpAndOperandIndex, ShardableAxes>;
+  // keep_dim=true preserves the input's axis numbering; otherwise the
+  // reduced dimensions vanish and the surviving axes are squeezed.
+  const ShardableAxes output_sa =
+      (GetReduceOpKeepDims(reduce_op) ? input_sa
+                                      : SqueezeShardableAxes(input_sa));
+  return ShardableAxesSignature{
+      .sole_output_sa = SoleOutputShardableAxes{
+          .shardable_axes=output_sa,
+      },
+      .input_shardable_axes = InputSignature{
+          {OpAndOperandIndex{reduce_op, 0}, input_sa},
+      },
+  };
+}
+
+bool IsDisabledElementwiseOp(const pir::Operation* op) {
+  if (op->isa<cinn::dialect::ReshapeOp>()) return true;
+  return false;
+}
+
 ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
     const pir::Operation* op) {
-  CHECK(!op->isa<cinn::dialect::ReshapeOp>())
-      << "reshape not supported. TODO(wuzhanfei).";
+  if (IsDisabledElementwiseOp(op)) {
+    LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature "
+                  "found. op_name : "
+               << op->name();
+    return MakeEmptyShardableAxesSignature(op);
+  }
   const size_t rank = [&] {
     std::optional<size_t> rank;
     for (int i = 0; i < op->num_operands(); ++i) {
@@ -229,7 +307,7 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
     return rank.value();
   }();
   const ShardableAxes output_shardable_axes =
-      ShardableAxesUtil::GetFullyShardableAxes(rank);
+      ShardableAxesUtil::MakeFullyShardableAxes(rank);
   std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
   for (int i = 0; i < op->num_operands(); ++i) {
     input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes;
@@ -244,21 +322,24 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
 
 ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(
     const pir::Operation* op) {
-  LOG(FATAL) << "TODO(wuzhanfei).";
+  LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature "
+                "found. op_name : "
+             << op->name();
+  return MakeEmptyShardableAxesSignature(op);
 }
 
 ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) {
   const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
-  if (kind == hlir::framework::kElementWise) {
+  if (kind == hlir::framework::kReduction) {
+    return MakeShardableAxesSignature4ReduceOp(op);
+  } else if (kind == hlir::framework::kElementWise) {
     return MakeShardableAxesSignature4ElementWiseOp(op);
   } else if (kind == hlir::framework::kBroadcast) {
     return MakeShardableAxesSignature4BroadcastOp(op);
   } else {
-    LOG(FATAL)
-        << "only kReduction, kElementWise, kBroadcast supported. op_name:"
+    LOG(ERROR)
+        << "[ShardableAxesSignature] no shardable axes signature found. op_name:"
op_name:" << op->name(); } - LOG(FATAL) << "Dead code"; + return MakeEmptyShardableAxesSignature(op); } template @@ -498,7 +579,7 @@ std::unordered_map InferShardableAxesFromSink( CHECK_GT(op_topo.ops->count(sink), 0); const int result_idx = GetOutputShardableAxesResultIdx(sink); size_t rank = GetRank(sink->result(result_idx)); - const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank); + const auto& init_sa = ShardableAxesUtil::MakeFullyShardableAxes(rank); return ReversedInferShardableAxes(reversed_walker, sink, init_sa); } @@ -1718,22 +1799,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy { const pir::Operation* reduce_op, const ShardableAxes& shardable_axes) { const size_t input_rank = GetRank(reduce_op->operand_source(0)); - const auto& reduce_axes = [&]{ - const auto& attr_val = reduce_op->attributes().at("dim"); - CHECK(attr_val.isa<::pir::ArrayAttribute>()); - const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>(); - std::vector reduce_axes; - for (int i = 0; i < axis_attr.size(); ++i) { - int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data(); - if (axis < 0) { - axis += input_rank; - } - CHECK_GE(axis, 0); - CHECK_LT(axis, input_rank); - reduce_axes.push_back(axis); - } - return reduce_axes; - }(); + const auto& reduce_axes = GetReduceAxes(reduce_op); // no shardability if input reduced into one element. if (reduce_axes.empty()) return false; @@ -1747,11 +1813,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy { }; return std::find_if(shardable_axes.begin(), shardable_axes.end(), Condition) != shardable_axes.end(); }; - const bool keepdims = [&]{ - const auto& attr_val = reduce_op->attributes().at("keep_dim"); - CHECK(attr_val.isa<::pir::BoolAttribute>()); - return attr_val.dyn_cast<::pir::BoolAttribute>(); - }(); + const bool keepdims = GetReduceOpKeepDims(reduce_op); if (keepdims) { const size_t output_rank = input_rank; CHECK(!reduce_axes.empty());