Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Support injective group fusion #61197

Merged
merged 10 commits into from
Jan 30, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/group_with_group_merge_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/op_with_group_merge_util.h"
#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h"
#include "paddle/phi/core/flags.h"

#include "paddle/cinn/common/is_reachable_predicator.h"
Expand Down Expand Up @@ -63,6 +64,9 @@ class FuseHelper {
virtual bool HorizontalWithInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool InjectiveFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool ElementwiseFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

Expand All @@ -78,6 +82,9 @@ class FuseHelper {
virtual bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool ReduceFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

virtual bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const = 0;

Expand Down Expand Up @@ -121,12 +128,18 @@ class GraphGroupFuseHelper final : public FuseHelper {
bool InjectiveHorizontalWithReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool InjectiveFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseElementwise(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseBroadcast(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseInjective(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

bool ReduceFuseReduce(const OpGroupPtr& src,
const OpGroupPtr& dst) const override;

Expand Down Expand Up @@ -357,6 +370,12 @@ bool GraphGroupFuseHelper<FusePassCtxT>::HorizontalWithInjective(
return horizontal_with_injective(src.GetGroup(), dst.GetGroup());
}

// Vertical fusion rule for an injective producer group into an injective
// consumer group: always permitted.
// NOTE(review): src/dst are intentionally unused here — the decision is
// unconditional; presumably shape/size compatibility is enforced by the
// surrounding fuse-pass machinery — TODO confirm.
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::InjectiveFuseInjective(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
return true;
}

template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ElementwiseFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
Expand Down Expand Up @@ -387,6 +406,21 @@ bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseBroadcast(
return reduce_fuse_broadcast(src.GetGroup(), dst.GetGroup());
}

// Decide whether a reduction producer group (src) may be vertically fused
// into an injective consumer group (dst).
// Fusion is allowed only when (1) every op in |dst| accepts a kReduction
// producer according to SpecialOpsFusionRule, and (2) the generic
// horizontal-with-injective structural check passes.
template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseInjective(
    const OpGroupPtr& src, const OpGroupPtr& dst) const {
  // BUGFIX: the accumulator must start at true. It was initialized to
  // false, which made the conjunction below identically false, so this
  // rule could never approve a fusion regardless of the per-op rules.
  bool can_all_special_ops_fused = true;
  dst.WalkOpNodes([&](const OpNode& op) {
    can_all_special_ops_fused =
        can_all_special_ops_fused &&
        SpecialOpsFusionRule::GetInstance().ConsumerOpAllowsFusion(
            op.node(), OpPatternKind::kReduction);
  });

  return can_all_special_ops_fused &&
         horizontal_with_injective(src.GetGroup(), dst.GetGroup());
}

template <typename FusePassCtxT>
bool GraphGroupFuseHelper<FusePassCtxT>::ReduceFuseReduce(
const OpGroupPtr& src, const OpGroupPtr& dst) const {
Expand Down Expand Up @@ -790,7 +824,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
{{OpPatternKind::kInjective, OpPatternKind::kBroadcast},
&DefaultVerticalFusePass::IsSameSize},
{{OpPatternKind::kInjective, OpPatternKind::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
&DefaultVerticalFusePass::InjectiveFuseInjective},
{{OpPatternKind::kInjective, OpPatternKind::kReduction},
&DefaultVerticalFusePass::InjectiveHorizontalWithReduce},

Expand All @@ -799,7 +833,7 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
{{OpPatternKind::kReduction, OpPatternKind::kBroadcast},
&DefaultVerticalFusePass::ReduceFuseBroadcast},
{{OpPatternKind::kReduction, OpPatternKind::kInjective},
&DefaultVerticalFusePass::HorizontalWithInjective},
&DefaultVerticalFusePass::ReduceFuseInjective},
{{OpPatternKind::kReduction, OpPatternKind::kReduction},
&DefaultVerticalFusePass::ReduceFuseReduce},
};
Expand All @@ -823,6 +857,12 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
return ctx->fuse_helper().HorizontalWithInjective(src, dst);
}

// Thin adapter used by the fusion dispatch table: forwards the
// injective->injective vertical-fusion decision to the FuseHelper owned
// by the pass context.
static bool InjectiveFuseInjective(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().InjectiveFuseInjective(src, dst);
}

static bool ElementwiseFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
Expand Down Expand Up @@ -853,6 +893,12 @@ class DefaultVerticalFusePass final : public VerticalFusePass {
return ctx->fuse_helper().ReduceFuseBroadcast(src, dst);
}

// Thin adapter used by the fusion dispatch table: forwards the
// reduction->injective vertical-fusion decision to the FuseHelper owned
// by the pass context.
static bool ReduceFuseInjective(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
return ctx->fuse_helper().ReduceFuseInjective(src, dst);
}

static bool ReduceFuseReduce(LightwareFusePassCtx* ctx,
const OpGroupPtr& src,
const OpGroupPtr& dst) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ class OpNode {
return paddle::get<T>(attr);
}

::pir::Operation* node() const { return node_; }

private:
friend struct std::hash<OpNode>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -609,13 +609,17 @@ class OpFusionPassHelper {
hlir::framework::pir::CompatibleInfo::OpKind(*consumer))) {
auto& consumer_group = fusion_groups_[consumer];
// second step: check producer can be fused into consumer group
VLOG(3) << "Call ConditionFunction, Producer Op Pattern : "
VLOG(3) << "Call ConditionFunction, Producer Op: [" << producer->name()
<< "] Pattern : "
<< hlir::framework::pir::CompatibleInfo::OpKind(*producer)
<< " , Consumer Group Pattern : "
<< consumer_group->op_pattern_kind;
<< " , Consumer Group [" << consumer->name()
<< "] Pattern : " << consumer_group->op_pattern_kind;

return relation.fusion_op_kind[consumer_group->op_pattern_kind](
bool result = relation.fusion_op_kind[consumer_group->op_pattern_kind](
producer, fusion_groups_[consumer], shape_analysis);
VLOG(3) << " CanFuse: " << result;

return result;
}

return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,11 @@ inline bool horizontal_or_can_inline(
return false;
}
}

// vertical relation: 1.can compute inline
if (producer->result(0).use_count() == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这是能融合slice/concat的tricky代码吗?

Copy link
Contributor Author

@zyfncg zyfncg Jan 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是的,这里的判断不算很tricky,如果融合规则明确的话,后面是可以直接使用的

return true;
}
// if (helper->GetNodeData(producer)->outlinks().size() == 1 &&
// helper->output_ops_set_.count(producer) == 0) {
// return true;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/dialect/operator/transforms/group_merge/special_ops_fusion_rule.h"

#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"

namespace cinn {
namespace dialect {
namespace ir {

bool GatherNdFusionFule(const ::pir::Operation* consumer,
OpPatternKind producer_group_pattern) {
if (producer_group_pattern == OpPatternKind::kReduction) {
return false;
}
return true;
}

bool SliceFusionFule(const ::pir::Operation* consumer,
OpPatternKind producer_group_pattern) {
if (producer_group_pattern == OpPatternKind::kReduction) {
return false;
}
return true;
}

void SpecialOpsFusionRule::Init() {
RegisterConsumerOpRule(paddle::dialect::GatherNdOp::name(),
&GatherNdFusionFule);
RegisterConsumerOpRule(cinn::dialect::SliceOp::name(), &SliceFusionFule);
}

} // namespace ir
} // namespace dialect
} // namespace cinn
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright (c) 2024 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <string>

#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/pir/core/operation.h"

namespace cinn {
namespace dialect {
namespace ir {

using OpPatternKind = hlir::framework::OpPatternKind;

// Registry of per-op exceptions to the generic group-fusion rules.
// An op can veto fusion either as a producer (given the consumer group's
// pattern) or as a consumer (given the producer group's pattern). Ops with
// no registered rule always allow fusion.
class SpecialOpsFusionRule {
 public:
  // Rule callback: returns whether the given op allows fusion with a group
  // of the given pattern kind.
  typedef bool (*RuleFunc)(const ::pir::Operation*, OpPatternKind);

  static const SpecialOpsFusionRule& GetInstance() {
    // thread_local: each thread lazily builds its own (identical) registry,
    // so lookups need no locking.
    thread_local static SpecialOpsFusionRule instance;
    return instance;
  }

  // True if |producer| may be fused into a consumer group whose pattern is
  // |consumer_group_pattern|; defaults to true when no rule is registered.
  bool ProducerOpAllowsFusion(const ::pir::Operation* producer,
                              OpPatternKind consumer_group_pattern) const {
    auto iter = producer_op_rules_.find(producer->name());
    if (iter != producer_op_rules_.end()) {
      return iter->second(producer, consumer_group_pattern);
    }
    return true;
  }

  // True if |consumer| may absorb a producer group whose pattern is
  // |producer_group_pattern|; defaults to true when no rule is registered.
  bool ConsumerOpAllowsFusion(const ::pir::Operation* consumer,
                              OpPatternKind producer_group_pattern) const {
    auto iter = consumer_op_rules_.find(consumer->name());
    if (iter != consumer_op_rules_.end()) {
      return iter->second(consumer, producer_group_pattern);
    }
    return true;
  }

 private:
  SpecialOpsFusionRule() { Init(); }

  // Non-copyable and non-movable singleton.
  // BUGFIX: the deleted "move constructor" previously took `const&&`, which
  // is not the move-constructor signature; the move assignment operator was
  // not deleted at all. Delete all four explicitly.
  SpecialOpsFusionRule(const SpecialOpsFusionRule&) = delete;
  SpecialOpsFusionRule(SpecialOpsFusionRule&&) = delete;
  SpecialOpsFusionRule& operator=(const SpecialOpsFusionRule&) = delete;
  SpecialOpsFusionRule& operator=(SpecialOpsFusionRule&&) = delete;

  // Populates the registries (defined in special_ops_fusion_rule.cc).
  void Init();

  void RegisterProducerOpRule(const std::string& producer_op_name,
                              RuleFunc rule) {
    producer_op_rules_[producer_op_name] = rule;
  }

  void RegisterConsumerOpRule(const std::string& consumer_op_name,
                              RuleFunc rule) {
    consumer_op_rules_[consumer_op_name] = rule;
  }

  std::map<std::string, RuleFunc> producer_op_rules_;
  std::map<std::string, RuleFunc> consumer_op_rules_;
};

}  // namespace ir
}  // namespace dialect
}  // namespace cinn
2 changes: 2 additions & 0 deletions paddle/cinn/hlir/op/contrib/gather_nd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ CINN_REGISTER_HELPER(gather_nd_ops) {
MakeOpFunction(cinn::hlir::op::InferShapeForGatherNd))
.set_attr("inferdtype",
MakeOpFunction(cinn::hlir::op::InferDtypeForGatherNd))
.set_attr<cinn::hlir::framework::OpPatternKind>(
"OpPattern", cinn::hlir::framework::OpPatternKind::kInjective)
.set_support_level(4);

return true;
Expand Down
19 changes: 16 additions & 3 deletions test/ir/pir/cinn/test_rope.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@
# limitations under the License.
import unittest

from test_cinn_sub_graph import TestCinnSubGraphBase, apply_to_static

import paddle
from paddle import nn


def apply_to_static(net, use_cinn, input_spec=None):
    """Wrap ``net`` for static-graph execution via ``paddle.jit.to_static``.

    Args:
        net: The ``paddle.nn.Layer`` to convert.
        use_cinn: Whether to enable the CINN compiler pass.
        input_spec: Optional input specs forwarded to ``to_static``.

    Returns:
        The statically-compiled layer.
    """
    strategy = paddle.static.BuildStrategy()
    strategy.build_cinn_pass = use_cinn
    static_net = paddle.jit.to_static(
        net,
        input_spec=input_spec,
        build_strategy=strategy,
        full_graph=True,
    )
    return static_net


class RotaryPosEmb(nn.Layer):
def __init__(self):
super().__init__()
Expand All @@ -40,7 +49,11 @@ def rotate_half(self, x):
return paddle.concat([-x2, x1], axis=-1) # shape is the same as x


class TestRotaryPosEmb(TestCinnSubGraphBase):
class TestRotaryPosEmb(unittest.TestCase):
def setUp(self):
paddle.seed(2022)
self.prepare_data()

def prepare_data(self):
self.q = paddle.randn([1, 2048, 8, 96], dtype="float32")
self.q.stop_gradient = False
Expand Down