Skip to content

Commit

Permalink
[OneDNN][PIR] Add depthwise_conv_onednn_pass (#63051)
Browse files Browse the repository at this point in the history
* first commit of depthwise conv pass

* style fix

* add copy_onnx

* add other create to test

* check if onednn pass is not registered

* add ifdef PADDLE_WITH_DNNL

* add WITH_MKLDNN

* fix style bug

* add PADDLE_WITH_DNNL

* add condition in onnx

* SKIP WIN32 CI

* name change mkl to onednn

* change name

* use python ut for depthwise conv

* delete skipif

* Rename test_depthwise_conv_onednn_pass.py to test_pir_depthwise_conv_onednn_pass.py
  • Loading branch information
zhanglirong1999 authored Apr 9, 2024
1 parent e7a515b commit af9b069
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 0 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ const std::vector<std::string> kPirXpuPasses{// Functional pass
"add_layernorm_xpu_fuse_pass"};

const std::vector<std::string> kPirMkldnnPasses{
"depthwise_conv_onednn_pass",
"squeeze_transpose_onednn_fuse_pass",
"conv2d_bias_fuse_pass",
"conv2d_transpose_bias_fuse_pass",
Expand Down
107 changes: 107 additions & 0 deletions paddle/fluid/pir/transforms/onednn/depthwise_conv_onednn_pass.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/pir/transforms/onednn/depthwise_conv_onednn_pass.h"

#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/include/drr_pattern_base.h"

#include "paddle/pir/include/pass/pass.h"
#include "paddle/pir/include/pass/pass_registry.h"

namespace {

// DRR pattern that rewrites a depthwise conv op (pd_op.depthwise_conv2d)
// into a plain pd_op.conv2d carrying the same attributes, so the generic
// conv2d kernel is used instead.
class DepthwiseConvPattern : public paddle::drr::DrrPatternBase {
 private:
  std::string depthwise_conv_name_;  // op name matched by the source pattern

 public:
  explicit DepthwiseConvPattern(const std::string &conv_name)
      : depthwise_conv_name_(conv_name) {}

  std::string name() const override { return "DepthwiseConvPattern"; }

  uint32_t benefit() const override { return 2; }

  void operator()(paddle::drr::DrrPatternContext *ctx) const override {
    paddle::drr::SourcePattern src = ctx->SourcePattern();

    // Source side: match the depthwise conv and capture every attribute so
    // they can be forwarded unchanged to the replacement op.
    const auto &depthwise_conv =
        src.Op(depthwise_conv_name_,
               {{"strides", src.Attr("strides")},
                {"paddings", src.Attr("paddings")},
                {"padding_algorithm", src.Attr("padding_algorithm")},
                {"dilations", src.Attr("dilations")},
                {"groups", src.Attr("groups")},
                {"data_format", src.Attr("data_format")}});

    depthwise_conv({&src.Tensor("input"), &src.Tensor("filter")},
                   {&src.Tensor("conv_out")});

    // Constraint: only rewrite when the attribute values are ones conv2d
    // accepts (known padding algorithm, known layout, positive group count).
    src.RequireNativeCall([&](const paddle::drr::MatchContext &match_ctx) {
      const std::set<std::string> allowed_padding = {
          "EXPLICIT", "SAME", "VALID"};
      const std::set<std::string> allowed_layout = {
          "NCHW", "NHWC", "AnyLayout"};
      const bool padding_ok =
          allowed_padding.count(
              match_ctx.Attr<std::string>("padding_algorithm")) != 0;
      const bool layout_ok =
          allowed_layout.count(match_ctx.Attr<std::string>("data_format")) !=
          0;
      const bool groups_ok = match_ctx.Attr<int>("groups") >= 1;
      return padding_ok && layout_ok && groups_ok;
    });

    // Result side: emit pd_op.conv2d with the captured attributes and the
    // same input/output tensors.
    paddle::drr::ResultPattern res = src.ResultPattern();

    const auto &conv2d =
        res.Op(paddle::dialect::Conv2dOp::name(),
               {{
                   {"strides", src.Attr("strides")},
                   {"paddings", src.Attr("paddings")},
                   {"padding_algorithm", src.Attr("padding_algorithm")},
                   {"dilations", src.Attr("dilations")},
                   {"groups", src.Attr("groups")},
                   {"data_format", src.Attr("data_format")},
               }});

    conv2d({&res.Tensor("input"), &res.Tensor("filter")},
           {&res.Tensor("conv_out")});
  }
};

// Pass wrapper that installs DepthwiseConvPattern: every matched
// pd_op.depthwise_conv2d is replaced by an equivalent pd_op.conv2d.
class DepthwiseConvMKLDNNPass : public pir::PatternRewritePass {
 public:
  DepthwiseConvMKLDNNPass()
      // Fix: the internal pass name previously read
      // "depthwise_conv_mkldnn_pass", while the pass is registered via
      // REGISTER_IR_PASS (and referenced from kPirMkldnnPasses /
      // USE_PIR_PASS) as "depthwise_conv_onednn_pass". Keep the two names
      // consistent so name-based lookups and pass logging agree.
      : pir::PatternRewritePass("depthwise_conv_onednn_pass", 2) {}

  pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override {
    pir::RewritePatternSet ps(context);
    ps.Add(paddle::drr::Create<DepthwiseConvPattern>(
        context, paddle::dialect::DepthwiseConv2dOp::name()));
    return ps;
  }
};

} // namespace

namespace pir {

/// Factory for the oneDNN depthwise-conv pass (registered as
/// `depthwise_conv_onednn_pass`), which rewrites pd_op.depthwise_conv2d
/// ops into pd_op.conv2d ops.
std::unique_ptr<Pass> CreateDepthwiseConvMKLDNNPass() {
  // pd_op.depthwise_conv -> pd_op.conv2d
  return std::make_unique<DepthwiseConvMKLDNNPass>();
}

} // namespace pir

REGISTER_IR_PASS(depthwise_conv_onednn_pass, DepthwiseConvMKLDNNPass);
26 changes: 26 additions & 0 deletions paddle/fluid/pir/transforms/onednn/depthwise_conv_onednn_pass.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include "paddle/pir/include/core/dll_decl.h"

namespace pir {

class Pass;

/// Creates the pass that converts pd_op.depthwise_conv2d ops into
/// pd_op.conv2d ops (registered as `depthwise_conv_onednn_pass`).
/// Ownership of the pass is transferred to the caller.
IR_API std::unique_ptr<Pass> CreateDepthwiseConvMKLDNNPass();

}  // namespace pir
1 change: 1 addition & 0 deletions paddle/fluid/pir/transforms/passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ USE_PIR_PASS(fused_dot_product_attention_pass);
USE_PIR_PASS(fused_flash_attn_pass);

#ifdef PADDLE_WITH_DNNL
USE_PIR_PASS(depthwise_conv_onednn_pass);
USE_PIR_PASS(squeeze_transpose_onednn_fuse_pass);
USE_PIR_PASS(batch_norm_act_fuse_pass);
USE_PIR_PASS(conv2d_bias_fuse_pass);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from pass_test import PassTest

import paddle

paddle.enable_static()


class TestDepthwiseConvOnednnPass(PassTest):
    """Exercises `depthwise_conv_onednn_pass`.

    The pass rewrites `pd_op.depthwise_conv2d` into `pd_op.conv2d`; the
    program below builds a depthwise convolution and checks that after the
    pass exactly one plain conv2d remains and no depthwise op survives.

    Renamed from the copy-pasted `TestConv2dAddFusePass`, which described a
    different pass.
    """

    def is_program_valid(self, program=None):
        # Every program produced by build_ir_program is valid for this pass.
        return True

    def build_ir_program(self):
        """Builds [main_prog, start_prog] containing one depthwise conv."""
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.pir.core.program_guard(main_prog, start_prog):
                x = paddle.static.data(
                    name='x', shape=[5, 2, 5, 5], dtype='float32'
                )

                # groups == in_channels == out_channels — presumably this is
                # what makes Paddle emit pd_op.depthwise_conv2d rather than
                # pd_op.conv2d; TODO(review): confirm against Conv2D lowering.
                conv2d = paddle.nn.Conv2D(
                    in_channels=2,
                    out_channels=2,
                    kernel_size=[2, 2],
                    groups=2,
                    stride=[1, 1],
                    padding=[1, 1, 1, 1],
                    dilation=[1, 1],
                    data_format='NCHW',
                    bias_attr=False,
                )

                conv2d_out = conv2d(x)
                out = paddle.assign(conv2d_out)
                self.pass_list = ['depthwise_conv_onednn_pass']

                self.feeds = {
                    "x": np.random.random((5, 2, 5, 5)).astype("float32"),
                }
                self.fetch_list = [out]
                # Stronger check than before: also require that the
                # depthwise op was actually removed, not merely that a
                # conv2d exists.
                self.valid_op_map = {
                    "pd_op.conv2d": 1,
                    "pd_op.depthwise_conv2d": 0,
                }
        return [main_prog, start_prog]

    def sample_program(self):
        yield self.build_ir_program(), False

    def setUp(self):
        # oneDNN passes run on CPU only.
        self.places.append(paddle.CPUPlace())

    def test_check_output(self):
        self.check_pass_correct()


if __name__ == "__main__":
unittest.main()

0 comments on commit af9b069

Please sign in to comment.