Skip to content

Commit

Permalink
[CINN] Support injective group fusion (#61197)
Browse files Browse the repository at this point in the history
* support injective group fusion

* [CINN+PIR]Fix IsSupportCINN Logic

* fix comment

* merge gather_nd

* refactor some trick code

* revert some code

* fix bug

---------

Co-authored-by: Aurelius84 <zhangliujie@baidu.com>
  • Loading branch information
zyfncg and Aurelius84 authored Jan 30, 2024
1 parent 6054d64 commit 79f3b1c
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h"
#include "paddle/phi/core/flags.h"

#include "paddle/cinn/common/is_reachable_predicator.h"
Expand Down Expand Up @@ -63,6 +64,9 @@ class FuseHelper {
virtual bool HorizontalWithInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool InjectiveFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool ElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

Expand All @@ -78,6 +82,9 @@ class FuseHelper {
virtual bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool ReduceFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

Expand Down Expand Up @@ -121,12 +128,18 @@ class GraphGroupFuseHelper final : public FuseHelper {
bool InjectiveHorizontalWithReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool InjectiveFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseElementwise(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

Expand Down Expand Up @@ -357,6 +370,12 @@ bool GraphGroupFuseHelper<FusePassCtxT>::HorizontalWithInjective(
return horizontal_with_injective(src.GetGroup(), dst.GetGroup());
}

template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::InjectiveFuseInjective(
    const OpGroupPtr& /*src*/, const OpGroupPtr& /*dst*/) const {
  // Vertical fusion of an injective group into another injective group is
  // unconditionally allowed: no per-op or per-shape checks are performed
  // here, so both group arguments are intentionally unused.
  return true;
}

template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ElementwiseFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
Expand Down Expand Up @@ -387,6 +406,21 @@ bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseBroadcast(
return reduce_fuse_broadcast(src.GetGroup(), dst.GetGroup());
}

template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseInjective(
    const OpGroupPtr& src, const OpGroupPtr& dst) const {
  // A reduction group may fuse into an injective consumer group only when
  //   (1) every op in the consumer group allows a kReduction producer per
  //       the special-op rules (gather_nd / slice veto this case), and
  //   (2) the generic horizontal-with-injective group check passes.
  //
  // BUG FIX: the accumulator must start at true. It was initialized to
  // false, and since it is only ever AND-ed with each op's rule result, the
  // conjunction stayed false for every group — this helper unconditionally
  // rejected reduce->injective fusion.
  bool can_all_special_ops_fused = true;
  dst.WalkOpNodes([&](const OpNode& op) {
    can_all_special_ops_fused =
        can_all_special_ops_fused &&
        SpecialOpsFusionRule::GetInstance().ConsumerOpAllowsFusion(
            op.node(), OpPatternKind::kReduction);
  });

  return can_all_special_ops_fused &&
         horizontal_with_injective(src.GetGroup(), dst.GetGroup());
}

template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
Expand Down Expand Up @@ -790,7 +824,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
{{OpPatternKind::kInjective, OpPatternKind::kBroadcast},
&DefaultVerticalFusePass::IsSameSize},
{{OpPatternKind::kInjective, OpPatternKind::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
&DefaultVerticalFusePass::InjectiveFuseInjective},
{{OpPatternKind::kInjective, OpPatternKind::kReduction},
&DefaultVerticalFusePass::InjectiveHorizontalWithReduce},

Expand All @@ -799,7 +833,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
{{OpPatternKind::kReduction, OpPatternKind::kBroadcast},
&DefaultVerticalFusePass::ReduceFuseBroadcast},
{{OpPatternKind::kReduction, OpPatternKind::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
&DefaultVerticalFusePass::ReduceFuseInjective},
{{OpPatternKind::kReduction, OpPatternKind::kReduction},
&DefaultVerticalFusePass::ReduceFuseReduce},
};
Expand All @@ -823,6 +857,12 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
return ctx->fuse_helper().HorizontalWithInjective(src, dst);
}

// Dispatch entry used by the injective->injective row of the vertical
// fuse-condition table: delegates the decision to the context's FuseHelper.
static bool InjectiveFuseInjective(LightwareFusePassCtx* ctx,
                                   const OpGroupPtr& src,
                                   const OpGroupPtr& dst) {
  const auto& helper = ctx->fuse_helper();
  return helper.InjectiveFuseInjective(src, dst);
}

static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
Expand Down Expand Up @@ -853,6 +893,12 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
return ctx->fuse_helper().ReduceFuseBroadcast(src, dst);
}

// Dispatch entry used by the reduction->injective row of the vertical
// fuse-condition table: delegates the decision to the context's FuseHelper.
static bool ReduceFuseInjective(LightwareFusePassCtx* ctx,
                                const OpGroupPtr& src,
                                const OpGroupPtr& dst) {
  const auto& helper = ctx->fuse_helper();
  return helper.ReduceFuseInjective(src, dst);
}

static bool ReduceFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class OpNode {
return paddle::get<T>(attr);
}

// Returns the wrapped pir operation pointer (non-owning; the OpNode does
// not manage the operation's lifetime).
::pir::Operation* node() const { return node_; }

private:
friend struct std::hash<OpNode>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,13 +609,17 @@ class OpFusionPassHelper {
hlir::framework::pir::CompatibleInfo::OpKind(*consumer))) {
auto& consumer_group = fusion_groups_[consumer];
// second step: check producer can be fused into consumer group
VLOG(3) << "Call ConditionFunction, Producer Op Pattern : "
VLOG(3) << "Call ConditionFunction, Producer Op: [" << producer->name()
<< "] Pattern : "
<< hlir::framework::pir::CompatibleInfo::OpKind(*producer)
<< " , Consumer Group Pattern : "
<< consumer_group->op_pattern_kind;
<< " , Consumer Group [" << consumer->name()
<< "] Pattern : " << consumer_group->op_pattern_kind;

return relation.fusion_op_kind[consumer_group->op_pattern_kind](
bool result = relation.fusion_op_kind[consumer_group->op_pattern_kind](
producer, fusion_groups_[consumer], shape_analysis);
VLOG(3) << " CanFuse: " << result;

return result;
}

return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,11 @@ inline bool horizontal_or_can_inline(
return false;
}
}

// vertical relation: 1.can compute inline
if (producer->result(0).use_count() == 1) {
return true;
}
// if (helper->GetNodeData(producer)->outlinks().size() == 1 &&
// helper->output_ops_set_.count(producer) == 0) {
// return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

namespace cinn {
namespace dialect {
namespace ir {

// Consumer-side rule for gather_nd: it refuses to consume the output of a
// reduction group; any other producer pattern is allowed.
// (NOTE: "Fule" is a typo for "Rule", kept because the registration site in
// SpecialOpsFusionRule::Init refers to this exact name.)
bool GatherNdFusionFule(const ::pir::Operation* consumer,
                        OpPatternKind producer_group_pattern) {
  return producer_group_pattern != OpPatternKind::kReduction;
}

// Consumer-side rule for cinn slice: it refuses to consume the output of a
// reduction group; any other producer pattern is allowed.
// (NOTE: "Fule" is a typo for "Rule", kept because the registration site in
// SpecialOpsFusionRule::Init refers to this exact name.)
bool SliceFusionFule(const ::pir::Operation* consumer,
                     OpPatternKind producer_group_pattern) {
  return producer_group_pattern != OpPatternKind::kReduction;
}

// Registers the consumer-side rules for the ops whose fusion behavior is
// special-cased: both gather_nd and cinn slice decline to consume the
// output of a reduction group. No producer-side rules are registered yet.
void SpecialOpsFusionRule::Init() {
  RegisterConsumerOpRule(paddle::dialect::GatherNdOp::name(),
                         &GatherNdFusionFule);
  RegisterConsumerOpRule(cinn::dialect::SliceOp::name(), &SliceFusionFule);
}

} // namespace ir
} // namespace dialect
} // namespace cinn
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) 2024 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/pir/core/operation.h"

namespace cinn {
namespace dialect {
namespace ir {

using OpPatternKind = hlir::framework::OpPatternKind;

class SpecialOpsFusionRule {
public:
typedef bool (*RuleFunc)(const ::pir::Operation*, OpPatternKind);

static const SpecialOpsFusionRule& GetInstance() {
thread_local static SpecialOpsFusionRule instance;
return instance;
}

bool ProducerOpAllowsFusion(const ::pir::Operation* producer,
OpPatternKind consumer_group_pattern) const {
auto iter = producer_op_rules_.find(producer->name());
if (iter != producer_op_rules_.end()) {
return iter->second(producer, consumer_group_pattern);
}
return true;
}

bool ConsumerOpAllowsFusion(const ::pir::Operation* consumer,
OpPatternKind producer_group_pattern) const {
auto iter = consumer_op_rules_.find(consumer->name());
if (iter != consumer_op_rules_.end()) {
return iter->second(consumer, producer_group_pattern);
}
return true;
}

private:
SpecialOpsFusionRule() { Init(); }

SpecialOpsFusionRule(const SpecialOpsFusionRule&) = delete;
SpecialOpsFusionRule(const SpecialOpsFusionRule&&) = delete;
SpecialOpsFusionRule& operator=(const SpecialOpsFusionRule&) = delete;

void Init();

void RegisterProducerOpRule(const std::string& producer_op_name,
RuleFunc rule) {
producer_op_rules_[producer_op_name] = rule;
}

void RegisterConsumerOpRule(const std::string& consumer_op_name,
RuleFunc rule) {
consumer_op_rules_[consumer_op_name] = rule;
}

std::map<std::string, RuleFunc> producer_op_rules_;
std::map<std::string, RuleFunc> consumer_op_rules_;
};

} // namespace ir
} // namespace dialect
} // namespace cinn
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/op/contrib/gather_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
MakeOpFunction(cinn::hlir::op::InferShapeForGatherNd))
.set_attr("inferdtype",
MakeOpFunction(cinn::hlir::op::InferDtypeForGatherNd))
.set_attr<cinn::hlir::framework::OpPatternKind>(
"OpPattern", cinn::hlir::framework::OpPatternKind::kInjective)
.set_support_level(4);

return true;
Expand Down
19 changes: 16 additions & 3 deletions test/ir/pir/cinn/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@
# limitations under the License.
import unittest

from test_cinn_sub_graph import TestCinnSubGraphBase, apply_to_static

import paddle
from paddle import nn


def apply_to_static(net, use_cinn, input_spec=None):
    """Wrap ``net`` with ``paddle.jit.to_static`` for static-graph execution.

    Args:
        net: Dygraph layer or function to convert.
        use_cinn: Whether to enable the CINN compiler build pass.
        input_spec: Optional InputSpec list forwarded to ``to_static``.

    Returns:
        The static-graph callable produced by ``paddle.jit.to_static``.
    """
    strategy = paddle.static.BuildStrategy()
    strategy.build_cinn_pass = use_cinn
    static_net = paddle.jit.to_static(
        net,
        input_spec=input_spec,
        build_strategy=strategy,
        full_graph=True,
    )
    return static_net


class RotaryPosEmb(nn.Layer):
def __init__(self):
super().__init__()
Expand All @@ -40,7 +49,11 @@ def rotate_half(self, x):
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x


class TestRotaryPosEmb(TestCinnSubGraphBase):
class TestRotaryPosEmb(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.q = paddle.randn([1, 2048, 8, 96], dtype="float32")
self.q.stop_gradient = False
Expand Down

0 comments on commit 79f3b1c

Please sign in to comment.