Skip to content

Commit

Permalink
Cinn trivalop fuse (PaddlePaddle#61)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Mar 12, 2024
1 parent dbc7a90 commit 55f975c
Showing 1 changed file with 33 additions and 10 deletions.
43 changes: 33 additions & 10 deletions paddle/cinn/hlir/framework/pir/trivial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,28 @@ struct MappingTargetExprToDestExprMutator : public ir::IRMutator<> {
void operator()(Expr* expr) { IRMutator::Visit(expr, expr); }

private:
void Visit(const ir::Expr* current, Expr* op) override {
if (current == &source_) {
void Visit(const ir::Load* load, Expr* op) override {
if (load == source_.ptr()) {
VLOG(4) << "substitude find!";
*op = dest_;
} else {
IRMutator::Visit(current, op);
IRMutator::Visit(load, op);
}
}
void Visit(const ir::Store* store, Expr* op) override {
if (store == source_.ptr()) {
VLOG(4) << "substitude find!";
*op = dest_;
} else {
IRMutator::Visit(store, op);
}
}
void Visit(const ir::Reduce* reduce, Expr* op) override {
if (reduce == source_.ptr()) {
VLOG(4) << "substitude find!";
*op = dest_;
} else {
IRMutator::Visit(reduce, op);
}
}

Expand Down Expand Up @@ -496,8 +512,8 @@ TrivialOp TTFusion(TrivialOp upstream, TrivialOp downstream) {
VLOG(4) << "TTFusion begin.";

const auto& replaced_tensor = upstream.GetOutputTensor();
VLOG(4) << "connected tensor is:" << replaced_tensor;
VLOG(4) << "store value is :" << downstream.GetStoreValue();
VLOG(4) << "upstream is " << upstream.GetFuncBody();
VLOG(4) << "downstream is " << downstream.GetFuncBody();

TrivialOp fused(ir::ir_utils::IRCopy(downstream.GetFuncBody()));
SequenceMutator(
Expand All @@ -517,8 +533,8 @@ ReduceOp TRFusion(TrivialOp upstream, ReduceOp downstream) {
VLOG(4) << "TRFusion begin.";

const auto& replaced_tensor = upstream.GetOutputTensor();
VLOG(4) << "connected tensor is:" << replaced_tensor;
VLOG(4) << "store value is :" << downstream.GetStoreValue();
VLOG(4) << "upstream is " << upstream.GetFuncBody();
VLOG(4) << "downstream is " << downstream.GetFuncBody();

ReduceOp fused(ir::ir_utils::IRCopy(downstream.GetFuncBody()));
SequenceMutator(
Expand Down Expand Up @@ -580,10 +596,10 @@ FusibleOp TrivialFusion(FusionNode* upstream, FusionNode* downstream) {
CHECK(upstream->IsTrivial());
if (downstream->IsTrivial()) {
return TTFusion(std::get<TrivialOp>(upstream->fusible_op),
std::get<TrivialOp>(upstream->fusible_op));
std::get<TrivialOp>(downstream->fusible_op));
} else {
return TRFusion(std::get<TrivialOp>(upstream->fusible_op),
std::get<ReduceOp>(upstream->fusible_op));
std::get<ReduceOp>(downstream->fusible_op));
}
}

Expand All @@ -608,7 +624,11 @@ std::vector<ReduceOp> ReduceTransformRecursive(ReduceOp reduce_op,

std::vector<FusibleOp> ReduceTransform(FusionNode* downstream) {
if (downstream->IsTrivial()) {
PADDLE_THROW("TODO: implement the R + T fusion.");
if (downstream->upstream.empty()) {
return {downstream->fusible_op};
} else {
PADDLE_THROW("TODO: implement the R + T fusion.");
}
} else {
auto reduces = ReduceTransformRecursive(
std::get<ReduceOp>(downstream->fusible_op), downstream);
Expand Down Expand Up @@ -697,8 +717,11 @@ struct FusionGraph {
}

std::vector<ir::Expr> DoFusion() {
VLOG(4) << "Start Trivial Fusion";
DoTrivialFusion();
VLOG(4) << "Start Transform T2R";
TransformSinkTrivialOpToReduce();
VLOG(4) << "Start RR Fusion";
ReduceLoopTranform();
return GetExprResults();
}
Expand Down

0 comments on commit 55f975c

Please sign in to comment.