Skip to content

Commit

Permalink
use python ut for depthwise conv
Browse files Browse the repository at this point in the history
  • Loading branch information
zhanglirong1999 committed Apr 2, 2024
1 parent 7c06e68 commit f5ebb8f
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 74 deletions.
5 changes: 0 additions & 5 deletions test/cpp/pir/pattern_rewrite/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@ paddle_test(drr_fuse_linear_test SRCS drr_fuse_linear_test.cc)
paddle_test(drr_fuse_linear_param_grad_add_test SRCS
drr_fuse_linear_param_grad_add_test.cc)

if(WITH_MKLDNN AND NOT WIN32)
paddle_test(depthwise_conv_onednn_pass_test SRCS
depthwise_conv_onednn_pass_test.cc)
endif()

if(WITH_GPU)
paddle_test(drr_attention_fuse_test SRCS drr_attention_fuse_test.cc)
endif()
Expand Down
69 changes: 0 additions & 69 deletions test/cpp/pir/pattern_rewrite/depthwise_conv_onednn_pass_test.cc

This file was deleted.

77 changes: 77 additions & 0 deletions test/ir/pir/fused_pass/onednn/test_depthwise_conv_onednn_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

import numpy as np
from pass_test import PassTest

import paddle

paddle.enable_static()


@unittest.skipIf(
not paddle.base.core.is_compiled_with_mkldnn(),
"Test case only for OneDNN pass.",
)
class TestConv2dAddFusePass(PassTest):
def is_program_valid(self, program=None):
return True

def build_ir_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(
name='x', shape=[5, 2, 5, 5], dtype='float32'
)

conv2d = paddle.nn.Conv2D(
in_channels=2,
out_channels=2,
kernel_size=[2, 2],
groups=2,
stride=[1, 1],
padding=[1, 1, 1, 1],
dilation=[1, 1],
data_format='NCHW',
bias_attr=False,
)

conv2d_out = conv2d(x)
out = paddle.assign(conv2d_out)
self.pass_list = ['depthwise_conv_onednn_pass']

self.feeds = {
"x": np.random.random((5, 2, 5, 5)).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.conv2d": 1,
}
return [main_prog, start_prog]

def sample_program(self):
yield self.build_ir_program(), False

def setUp(self):
self.places.append(paddle.CPUPlace())

def test_check_output(self):
self.check_pass_correct()


if __name__ == "__main__":
unittest.main()

0 comments on commit f5ebb8f

Please sign in to comment.