Merge pull request PaddlePaddle#69 from tc20042008/xk-cinn-trivalop-fuse
Xk cinn trivalop fuse
tc20042008 authored Mar 14, 2024
2 parents c371dca + 27a647c commit 0937581
Showing 2 changed files with 114 additions and 31 deletions.
23 changes: 22 additions & 1 deletion paddle/cinn/frontend/group_pattern.h
@@ -97,7 +97,7 @@ struct ShardableAxesUtil {
return ret;
}

static ShardableAxes GetFullyShardableAxes(size_t rank) {
static ShardableAxes MakeFullyShardableAxes(const size_t rank) {
ShardableAxes ret;
for (int i = 0; i < rank; ++i) {
ret.emplace_back(ShardableAxis{
@@ -107,6 +107,27 @@ struct ShardableAxesUtil {
}
return ret;
}

static ShardableAxes MakeReduceOpInputShardableAxes(
const size_t input_rank, const std::vector<int64_t>& reduce_axes) {
if (reduce_axes.empty()) return ShardableAxes{};
for (int64_t reduce_axis : reduce_axes) {
CHECK_GE(reduce_axis, 0);
CHECK_LT(reduce_axis, input_rank);
}
const auto IsReduceAxis = [&](int64_t i) {
return std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end();
};
ShardableAxes ret;
for (int64_t i = 0; i < input_rank; ++i) {
if (IsReduceAxis(i)) continue;
ret.emplace_back(ShardableAxis{
.axis=i,
.axis_name=std::string("D") + std::to_string(ShardableAxis::UnqiueSeqNo()),
});
}
return ret;
}
};

struct SoleOutputShardableAxes {
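The new MakeReduceOpInputShardableAxes helper marks every non-reduced input axis as shardable. Below is a minimal standalone sketch of that behavior; it uses a simplified stand-in type instead of the repository's ShardableAxis/ShardableAxes types (the real code also names axes via ShardableAxis::UnqiueSeqNo()), so it is an illustration rather than code from this commit.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the repository's ShardableAxis.
struct SimpleAxis {
  int64_t axis;
  std::string axis_name;
};

// Mirrors MakeReduceOpInputShardableAxes: every input axis that is not reduced
// stays shardable; an empty reduce_axes (full reduction) yields no shardable axes.
std::vector<SimpleAxis> ReduceInputShardableAxes(
    size_t input_rank, const std::vector<int64_t>& reduce_axes) {
  if (reduce_axes.empty()) return {};
  std::vector<SimpleAxis> ret;
  for (int64_t i = 0; i < static_cast<int64_t>(input_rank); ++i) {
    const bool is_reduce_axis =
        std::find(reduce_axes.begin(), reduce_axes.end(), i) != reduce_axes.end();
    if (is_reduce_axis) continue;
    ret.push_back(SimpleAxis{i, "D" + std::to_string(i)});  // simplified naming
  }
  return ret;
}

int main() {
  // A rank-4 input reduced over axes {1, 3}: axes 0 and 2 remain shardable.
  for (const auto& sa : ReduceInputShardableAxes(4, {1, 3})) {
    std::cout << sa.axis << " " << sa.axis_name << "\n";  // prints "0 D0" and "2 D2"
  }
  return 0;
}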
122 changes: 92 additions & 30 deletions paddle/cinn/frontend/group_pattern_util.cc
@@ -206,10 +206,88 @@ size_t GetRank(pir::Value value) {
return value.type().dyn_cast<pir::DenseTensorType>().dims().size();
}

std::vector<int64_t> GetReduceAxes(const pir::Operation* reduce_op) {
const size_t input_rank = GetRank(reduce_op->operand_source(0));
const auto& attr_val = reduce_op->attributes().at("dim");
CHECK(attr_val.isa<::pir::ArrayAttribute>());
const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>();
std::vector<int64_t> reduce_axes;
for (int i = 0; i < axis_attr.size(); ++i) {
int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data();
if (axis < 0) {
axis += input_rank;
}
CHECK_GE(axis, 0);
CHECK_LT(axis, input_rank);
reduce_axes.push_back(axis);
}
return reduce_axes;
}
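// Illustration (hypothetical values): for an input of rank 3 and attribute
// dim = [-1, 0], the negative axis is normalized to 2, so this returns {2, 0}.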

bool GetReduceOpKeepDims(const pir::Operation* reduce_op) {
const auto& attr_val = reduce_op->attributes().at("keep_dim");
CHECK(attr_val.isa<::pir::BoolAttribute>());
return attr_val.dyn_cast<::pir::BoolAttribute>().data();
}

ShardableAxes SequeezeShardableAxes(const ShardableAxes& sa) {
ShardableAxes ret_sa(sa);
for (int i = 0; i < ret_sa.size(); ++i) {
for (int j = i + 1; j < ret_sa.size(); ++j) {
CHECK_LT(ret_sa.at(i).axis, ret_sa.at(j).axis);
}
ret_sa.at(i).axis = i;
}
return ret_sa;
}
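// Illustration (hypothetical values): axes {(0, "D7"), (2, "D8")} left after a
// reduced axis 1 is dropped are renumbered densely to {(0, "D7"), (1, "D8")};
// the axis names are preserved, only the axis indices are compacted.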

ShardableAxesSignature MakeEmptyShardableAxesSignature(const pir::Operation* op) {
const int result_idx = GetOutputShardableAxesResultIdx(op);
pir::Value output = op->result(result_idx);
ShardableAxes output_sa = ShardableAxesUtil::MakeFullyShardableAxes(GetRank(output));
using InputSignature = std::unordered_map<OpAndOperandIndex, ShardableAxes>;
InputSignature empty_input_sig;
for (int i = 0; i < op->num_operands(); ++i) {
empty_input_sig[OpAndOperandIndex{op, i}] = ShardableAxes{};
}
return ShardableAxesSignature{
.sole_output_sa = SoleOutputShardableAxes{
.shardable_axes=output_sa,
},
.input_shardable_axes = empty_input_sig,
};
}
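// Fallback signature: the sole output is treated as fully shardable while every
// operand gets an empty ShardableAxes, so no shardability is propagated through
// ops whose signature is unknown.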

ShardableAxesSignature MakeShardableAxesSignature4ReduceOp(
const pir::Operation* reduce_op) {
const size_t input_rank = GetRank(reduce_op->operand_source(0));
const auto& reduce_axes = GetReduceAxes(reduce_op);
const ShardableAxes input_sa =
ShardableAxesUtil::MakeReduceOpInputShardableAxes(input_rank, reduce_axes);
using InputSignature = std::unordered_map<OpAndOperandIndex, ShardableAxes>;
const ShardableAxes output_sa =
(GetReduceOpKeepDims(reduce_op) ? input_sa : SequeezeShardableAxes(input_sa));
return ShardableAxesSignature{
.sole_output_sa = SoleOutputShardableAxes{
.shardable_axes=output_sa,
},
.input_shardable_axes = InputSignature{
{OpAndOperandIndex{reduce_op, 0}, input_sa},
},
};
}
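// Illustration (hypothetical values): for a rank-3 input reduced over axis 1,
// input_sa covers axes {0, 2}; with keep_dim the output keeps {0, 2}, otherwise
// the reduced dimension vanishes and the axes are squeezed to {0, 1}.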

bool IsDisabledElementwiseOp(const pir::Operation* op) {
if (op->isa<cinn::dialect::ReshapeOp>()) return true;
return false;
}

ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
const pir::Operation* op) {
CHECK(!op->isa<cinn::dialect::ReshapeOp>())
<< "reshape not supported. TODO(wuzhanfei).";
if (IsDisabledElementwiseOp(op)) {
LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature found. op_name : " << op->name();
return MakeEmptyShardableAxesSignature(op);
}
const size_t rank = [&] {
std::optional<size_t> rank;
for (int i = 0; i < op->num_operands(); ++i) {
@@ -229,7 +307,7 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(
return rank.value();
}();
const ShardableAxes output_shardable_axes =
ShardableAxesUtil::GetFullyShardableAxes(rank);
ShardableAxesUtil::MakeFullyShardableAxes(rank);
std::unordered_map<OpAndOperandIndex, ShardableAxes> input_shardable_axes;
for (int i = 0; i < op->num_operands(); ++i) {
input_shardable_axes[OpAndOperandIndex{op, i}] = output_shardable_axes;
@@ -244,21 +322,24 @@ ShardableAxesSignature MakeShardableAxesSignature4ElementWiseOp(

ShardableAxesSignature MakeShardableAxesSignature4BroadcastOp(
const pir::Operation* op) {
LOG(FATAL) << "TODO(wuzhanfei).";
LOG(ERROR) << "[ShardableAxesSignature] no shardable axes signature found. op_name : " << op->name();
return MakeEmptyShardableAxesSignature(op);
}

ShardableAxesSignature MakeShardableAxesSignature4Op(const pir::Operation* op) {
const hlir::framework::OpPatternKind kind = GetOpPatternKind(op);
if (kind == hlir::framework::kElementWise) {
if (kind == hlir::framework::kReduction) {
return MakeShardableAxesSignature4ReduceOp(op);
} else if (kind == hlir::framework::kElementWise) {
return MakeShardableAxesSignature4ElementWiseOp(op);
} else if (kind == hlir::framework::kBroadcast) {
return MakeShardableAxesSignature4BroadcastOp(op);
} else {
LOG(FATAL)
<< "only kReduction, kElementWise, kBroadcast supported. op_name:"
LOG(ERROR)
<< "[ShardableAxesSignature] no shardable axes signature found. op_name:"
<< op->name();
}
LOG(FATAL) << "Dead code";
return MakeEmptyShardableAxesSignature(op);
}

template<typename InputIt>
@@ -498,7 +579,7 @@ std::unordered_map<pir::Value, ShardableAxes> InferShardableAxesFromSink(
CHECK_GT(op_topo.ops->count(sink), 0);
const int result_idx = GetOutputShardableAxesResultIdx(sink);
size_t rank = GetRank(sink->result(result_idx));
const auto& init_sa = ShardableAxesUtil::GetFullyShardableAxes(rank);
const auto& init_sa = ShardableAxesUtil::MakeFullyShardableAxes(rank);
return ReversedInferShardableAxes(reversed_walker, sink, init_sa);
}

@@ -1718,22 +1799,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
const pir::Operation* reduce_op,
const ShardableAxes& shardable_axes) {
const size_t input_rank = GetRank(reduce_op->operand_source(0));
const auto& reduce_axes = [&]{
const auto& attr_val = reduce_op->attributes().at("dim");
CHECK(attr_val.isa<::pir::ArrayAttribute>());
const auto& axis_attr = attr_val.dyn_cast<::pir::ArrayAttribute>();
std::vector<int64_t> reduce_axes;
for (int i = 0; i < axis_attr.size(); ++i) {
int64_t axis = axis_attr.at(i).dyn_cast<::pir::Int64Attribute>().data();
if (axis < 0) {
axis += input_rank;
}
CHECK_GE(axis, 0);
CHECK_LT(axis, input_rank);
reduce_axes.push_back(axis);
}
return reduce_axes;
}();
const auto& reduce_axes = GetReduceAxes(reduce_op);

// no shardability if input reduced into one element.
if (reduce_axes.empty()) return false;
@@ -1747,11 +1813,7 @@ class LoopAlignableClusteringPolicy final : public ClusteringPolicy {
};
return std::find_if(shardable_axes.begin(), shardable_axes.end(), Condition) != shardable_axes.end();
};
const bool keepdims = [&]{
const auto& attr_val = reduce_op->attributes().at("keep_dim");
CHECK(attr_val.isa<::pir::BoolAttribute>());
return attr_val.dyn_cast<::pir::BoolAttribute>();
}();
const bool keepdims = GetReduceOpKeepDims(reduce_op);
if (keepdims) {
const size_t output_rank = input_rank;
CHECK(!reduce_axes.empty());
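For reference, a similarly hedged sketch of how the reduce op's output axes follow from its input axes under keep_dim, mirroring what MakeShardableAxesSignature4ReduceOp and SequeezeShardableAxes do, again with a simplified stand-in type and made-up axis names rather than the commit's actual types.

#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

struct SimpleAxis {
  int64_t axis;
  std::string axis_name;
};

// Mirrors SequeezeShardableAxes: the surviving axes (assumed strictly
// increasing) are renumbered densely to 0, 1, 2, ... and keep their names.
std::vector<SimpleAxis> Squeeze(std::vector<SimpleAxis> sa) {
  for (size_t i = 0; i < sa.size(); ++i) {
    sa[i].axis = static_cast<int64_t>(i);
  }
  return sa;
}

int main() {
  // Input shardable axes of a rank-3 reduce over axis 1: axes 0 and 2 survive.
  std::vector<SimpleAxis> input_sa = {{0, "D0"}, {2, "D1"}};
  const bool keep_dim = false;
  // keep_dim = true  -> the output keeps the input numbering {0, 2};
  // keep_dim = false -> the reduced dimension disappears, so squeeze to {0, 1}.
  const auto output_sa = keep_dim ? input_sa : Squeeze(input_sa);
  for (const auto& sa : output_sa) {
    std::cout << sa.axis << " " << sa.axis_name << "\n";  // prints "0 D0" and "1 D1"
  }
  return 0;
}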
