-
Notifications
You must be signed in to change notification settings - Fork 5.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[OneDNN][PIR] Add depthwise_conv_onednn_pass (#63051)
* first commit of depthwise conv pass * style fix * add copy_onnx * add other create to test * check if onednn pass not register * add ifdef PADDLE_WITH_DNNL * add WITH_MKLDNN * fix style bug * add PADDLE_WITH_DNNL * add condition in onnx * SKIP WIN32 CI * name change mkl to onednn * change name * use python ut for depthwise conv * delete skipif * Rename test_depthwise_conv_onednn_pass.py to test_pir_depthwise_conv_onednn_pass.py
- Loading branch information
1 parent
e7a515b
commit af9b069
Showing
5 changed files
with
208 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
107 changes: 107 additions & 0 deletions
107
paddle/fluid/pir/transforms/onednn/depthwise_conv_onednn_pass.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/pir/transforms/onednn/depthwise_conv_onednn_pass.h" | ||
|
||
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h" | ||
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" | ||
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h" | ||
|
||
#include "paddle/pir/include/pass/pass.h" | ||
#include "paddle/pir/include/pass/pass_registry.h" | ||
|
||
namespace { | ||
|
||
class DepthwiseConvPattern : public paddle::drr::DrrPatternBase { | ||
private: | ||
std::string depthwise_conv_name_; | ||
|
||
public: | ||
explicit DepthwiseConvPattern(const std::string &conv_name) | ||
: depthwise_conv_name_(conv_name) {} | ||
|
||
std::string name() const override { return "DepthwiseConvPattern"; } | ||
|
||
uint32_t benefit() const override { return 2; } | ||
|
||
void operator()(paddle::drr::DrrPatternContext *ctx) const override { | ||
paddle::drr::SourcePattern pat = ctx->SourcePattern(); | ||
|
||
const auto &depthwise_conv = | ||
pat.Op(depthwise_conv_name_, | ||
{{"strides", pat.Attr("strides")}, | ||
{"paddings", pat.Attr("paddings")}, | ||
{"padding_algorithm", pat.Attr("padding_algorithm")}, | ||
{"dilations", pat.Attr("dilations")}, | ||
{"groups", pat.Attr("groups")}, | ||
{"data_format", pat.Attr("data_format")}}); | ||
|
||
depthwise_conv({&pat.Tensor("input"), &pat.Tensor("filter")}, | ||
{&pat.Tensor("conv_out")}); | ||
|
||
pat.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) { | ||
std::set<std::string> padding_algorithm = {"EXPLICIT", "SAME", "VALID"}; | ||
std::set<std::string> data_format = {"NCHW", "NHWC", "AnyLayout"}; | ||
if (padding_algorithm.count( | ||
match_ctx.Attr<std::string>("padding_algorithm")) == 0 || | ||
data_format.count(match_ctx.Attr<std::string>("data_format")) == 0 || | ||
match_ctx.Attr<int>("groups") < 1) { | ||
return false; | ||
} | ||
return true; | ||
}); | ||
|
||
paddle::drr::ResultPattern res = pat.ResultPattern(); | ||
|
||
const auto &conv2d = | ||
res.Op(paddle::dialect::Conv2dOp::name(), | ||
{{ | ||
{"strides", pat.Attr("strides")}, | ||
{"paddings", pat.Attr("paddings")}, | ||
{"padding_algorithm", pat.Attr("padding_algorithm")}, | ||
{"dilations", pat.Attr("dilations")}, | ||
{"groups", pat.Attr("groups")}, | ||
{"data_format", pat.Attr("data_format")}, | ||
}}); | ||
|
||
conv2d({&res.Tensor("input"), &res.Tensor("filter")}, | ||
{&res.Tensor("conv_out")}); | ||
} | ||
}; | ||
|
||
class DepthwiseConvMKLDNNPass : public pir::PatternRewritePass { | ||
public: | ||
DepthwiseConvMKLDNNPass() | ||
: pir::PatternRewritePass("depthwise_conv_mkldnn_pass", 2) {} | ||
|
||
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override { | ||
pir::RewritePatternSet ps(context); | ||
ps.Add(paddle::drr::Create<DepthwiseConvPattern>( | ||
context, paddle::dialect::DepthwiseConv2dOp::name())); | ||
return ps; | ||
} | ||
}; | ||
|
||
} // namespace | ||
|
||
namespace pir { | ||
|
||
std::unique_ptr<Pass> CreateDepthwiseConvMKLDNNPass() { | ||
// pd_op.depthwise_conv -> pd_op.conv2d | ||
return std::make_unique<DepthwiseConvMKLDNNPass>(); | ||
} | ||
|
||
} // namespace pir | ||
|
||
REGISTER_IR_PASS(depthwise_conv_onednn_pass, DepthwiseConvMKLDNNPass); |
26 changes: 26 additions & 0 deletions
26
paddle/fluid/pir/transforms/onednn/depthwise_conv_onednn_pass.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <memory> | ||
#include "paddle/pir/include/core/dll_decl.h" | ||
|
||
namespace pir { | ||
|
||
class Pass; | ||
|
||
IR_API std::unique_ptr<Pass> CreateDepthwiseConvMKLDNNPass(); | ||
|
||
} // namespace pir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
73 changes: 73 additions & 0 deletions
73
test/ir/pir/fused_pass/onednn/test_pir_depthwise_conv_onednn_pass.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import unittest | ||
|
||
import numpy as np | ||
from pass_test import PassTest | ||
|
||
import paddle | ||
|
||
paddle.enable_static() | ||
|
||
|
||
class TestConv2dAddFusePass(PassTest): | ||
def is_program_valid(self, program=None): | ||
return True | ||
|
||
def build_ir_program(self): | ||
with paddle.pir_utils.IrGuard(): | ||
main_prog = paddle.static.Program() | ||
start_prog = paddle.static.Program() | ||
with paddle.pir.core.program_guard(main_prog, start_prog): | ||
x = paddle.static.data( | ||
name='x', shape=[5, 2, 5, 5], dtype='float32' | ||
) | ||
|
||
conv2d = paddle.nn.Conv2D( | ||
in_channels=2, | ||
out_channels=2, | ||
kernel_size=[2, 2], | ||
groups=2, | ||
stride=[1, 1], | ||
padding=[1, 1, 1, 1], | ||
dilation=[1, 1], | ||
data_format='NCHW', | ||
bias_attr=False, | ||
) | ||
|
||
conv2d_out = conv2d(x) | ||
out = paddle.assign(conv2d_out) | ||
self.pass_list = ['depthwise_conv_onednn_pass'] | ||
|
||
self.feeds = { | ||
"x": np.random.random((5, 2, 5, 5)).astype("float32"), | ||
} | ||
self.fetch_list = [out] | ||
self.valid_op_map = { | ||
"pd_op.conv2d": 1, | ||
} | ||
return [main_prog, start_prog] | ||
|
||
def sample_program(self): | ||
yield self.build_ir_program(), False | ||
|
||
def setUp(self): | ||
self.places.append(paddle.CPUPlace()) | ||
|
||
def test_check_output(self): | ||
self.check_pass_correct() | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |