Skip to content

Commit

Permalink
perfect unittest for pir (PaddlePaddle#61076)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanRisheng authored and eee4017 committed Jan 30, 2024
1 parent 5ba536c commit 070a394
Show file tree
Hide file tree
Showing 13 changed files with 96 additions and 81 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
'FusedConv2dAddActInferMeta',
'InterpolateInferMeta',
'DeformableConvInferMeta',
'MatrixNMSInferMeta',
}

_PREPARE_DATA_WITH_VECTOR_INT64_MTTABLE_ATTRIBUTE = {'FrobeniusNormOp'}
Expand Down
144 changes: 73 additions & 71 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2154,49 +2154,50 @@ def mm(input, mat2, name=None):
"""
if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.matmul(input, mat2, False, False)
else:

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'mm'
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'mm'
)
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
x_shape = [1] + x_shape
if len(y_shape) == 1:
y_shape = y_shape + [1]

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-2]:
if not ((x_shape[-1] == -1) or (y_shape[-2] == -1)):
raise ValueError(
"After performing an optional transpose, Input X's width should be "
"equal to Y's width for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
)
x_shape = list(x.shape)
y_shape = list(y.shape)
if len(x_shape) == 1:
x_shape = [1] + x_shape
if len(y_shape) == 1:
y_shape = y_shape + [1]

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-2]:
if not ((x_shape[-1] == -1) or (y_shape[-2] == -1)):
if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError(
"After performing an optional transpose, Input X's width should be "
"equal to Y's width for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape)
)

if len(y_shape) > 2 and len(x_shape) > 2:
for i, dim_x in enumerate(x_shape[:-2]):
# don't check neg shape
if dim_x < 0 or y_shape[i] < 0:
continue
if dim_x != y_shape[i]:
raise ValueError(
"When the matrix is larger than 2 dimensions, the higher "
"dimensional values of the two matrices need to be equal. "
"But received x_shape[%d] != y_shape[%d]. X's shape: %s, "
"Y's shape: %s.\n" % (i, i, x_shape, y_shape)
)

__check_input(input, mat2)

__check_input(input, mat2)
if in_pir_mode():
return _C_ops.matmul(input, mat2, False, False)
else:
helper = LayerHelper('mm', **locals())
out = helper.create_variable_for_type_inference(dtype=input.dtype)
helper.append_op(
Expand Down Expand Up @@ -2514,33 +2515,33 @@ def inner(x, y, name=None):
nx = x.reshape((-1, xshape[-1]))
ny = y.reshape((-1, yshape[-1]))

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'inner'
)
x_shape = list(xshape)
y_shape = list(yshape)

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-1]:
if not ((x_shape[-1] == -1) or (y_shape[-1] == -1)):
raise ValueError(
"After performing an optional transpose, Input X's last dim should be "
"equal to Y's last dim for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
)

__check_input(nx, ny)

if in_dynamic_or_pir_mode():
return _C_ops.matmul(
nx, paddle.transpose(ny, [1, 0]), False, False
).reshape(dstshape)
else:

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val, name, ['float16', 'float32', 'float64'], 'inner'
)
x_shape = list(xshape)
y_shape = list(yshape)

# check the inner 2 dimensions
if x_shape[-1] != y_shape[-1]:
if not ((x_shape[-1] == -1) or (y_shape[-1] == -1)):
raise ValueError(
"After performing an optional transpose, Input X's last dim should be "
"equal to Y's last dim for multiplication "
"prerequisites. But received X's shape: {}, Y's shape: {}\n".format(
x_shape, y_shape
)
)

__check_input(nx, ny)
helper = LayerHelper('inner', **locals())
out = helper.create_variable_for_type_inference(dtype=nx.dtype)
helper.append_op(
Expand Down Expand Up @@ -2584,22 +2585,23 @@ def outer(x, y, name=None):
nx = x.reshape((-1, 1))
ny = y.reshape((1, -1))

if in_dynamic_or_pir_mode():
if in_dynamic_mode():
return _C_ops.matmul(nx, ny, False, False)
else:

def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val,
name,
['float16', 'float32', 'float64', 'int32', 'int64'],
'outer',
)

__check_input(nx, ny)
def __check_input(x, y):
var_names = {'x': x, 'y': y}
for name, val in var_names.items():
check_variable_and_dtype(
val,
name,
['float16', 'float32', 'float64', 'int32', 'int64'],
'outer',
)

__check_input(nx, ny)
if in_pir_mode():
return _C_ops.matmul(nx, ny, False, False)
else:
helper = LayerHelper('outer', **locals())
out = helper.create_variable_for_type_inference(dtype=nx.dtype)
helper.append_op(
Expand Down
5 changes: 3 additions & 2 deletions test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,7 @@ def test_dygraph_api(self):
np.testing.assert_allclose(out_ref, r.numpy(), rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_errors(self):
with static_guard():
with paddle.static.program_guard(paddle.static.Program()):
Expand Down Expand Up @@ -4910,7 +4911,7 @@ def test_check_grad(self):
create_test_act_fp16_class(
TestSilu, check_prim=True, enable_cinn=True, check_prim_pir=True
)
create_test_act_fp16_class(TestLogSigmoid)
create_test_act_fp16_class(TestLogSigmoid, check_pir=True)
create_test_act_fp16_class(
TestTanh, check_prim=True, check_prim_pir=True, enable_cinn=True
)
Expand Down Expand Up @@ -5100,7 +5101,7 @@ def test_check_grad(self):
TestSigmoid, check_prim=True, check_pir=True, check_prim_pir=True
)
create_test_act_bf16_class(TestSilu, check_prim=True, check_prim_pir=True)
create_test_act_bf16_class(TestLogSigmoid)
create_test_act_bf16_class(TestLogSigmoid, check_pir=True)
create_test_act_bf16_class(TestTanh, check_prim=True, check_prim_pir=True)
create_test_act_bf16_class(TestTanhshrink, check_pir=True)
create_test_act_bf16_class(TestHardShrink, check_pir=True)
Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_inner.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_multiply_dynamic_case5(self):


class TestMultiplyError(unittest.TestCase):
@test_with_pir_api
def test_errors_static_case1(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()
Expand All @@ -134,6 +135,7 @@ def test_errors_static_case1(self):
y = paddle.static.data(name='y', shape=[100], dtype=np.int8)
self.assertRaises(TypeError, paddle.inner, x, y)

@test_with_pir_api
def test_errors_static_case2(self):
# test static computation graph: inputs must be broadcastable
paddle.enable_static()
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_log_softmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ def test_check_api(self):
self.check_api(axis)
self.check_api(-1, 'float64')

@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
x = paddle.static.data(name='X1', shape=[100], dtype='int32')
Expand Down
1 change: 0 additions & 1 deletion test/legacy_test/test_logcumsumexp_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ def test_gpu(self):

self.run_static(use_gpu=True)

# @test_with_pir_api
def test_name(self):
with base.program_guard(base.Program()):
x = paddle.static.data('x', [3, 4])
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_logsumexp.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,7 @@ def set_attrs_addition(self):


class TestLogsumexpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
self.assertRaises(TypeError, paddle.logsumexp, 1)
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_matmul_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,7 @@ def test_dygraph_without_out(self):


class API_TestMmError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with paddle_static_guard():

Expand Down
2 changes: 2 additions & 0 deletions test/legacy_test/test_matrix_nms_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from op_test import OpTest

import paddle
from paddle.pir_utils import test_with_pir_api


def python_matrix_nms(
Expand Down Expand Up @@ -310,6 +311,7 @@ def set_argument(self):


class TestMatrixNMSError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
M = 1200
N = 7
Expand Down
16 changes: 9 additions & 7 deletions test/legacy_test/test_matrix_power_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ def test_dygraph(self):


class TestMatrixPowerAPIError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
input_np = np.random.random([4, 4]).astype("float64")

Expand All @@ -317,13 +318,6 @@ def test_errors(self):
)
self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2)

# When out is set, the data type must be the same as input.
input = paddle.static.data(
name="input_1", shape=[4, 4], dtype="float32"
)
out = paddle.static.data(name="output", shape=[4, 4], dtype="float64")
self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2, out)

# The number of dimensions of input must be >= 2.
input = paddle.static.data(name="input_2", shape=[4], dtype="float32")
self.assertRaises(ValueError, paddle.linalg.matrix_power, input, 2)
Expand All @@ -348,6 +342,14 @@ def test_errors(self):
ValueError, paddle.linalg.matrix_power, input, -956301312
)

def test_old_ir_errors(self):
# When out is set, the data type must be the same as input.
input = paddle.static.data(
name="input_1", shape=[4, 4], dtype="float32"
)
out = paddle.static.data(name="output", shape=[4, 4], dtype="float64")
self.assertRaises(TypeError, paddle.linalg.matrix_power, input, 2, out)


class TestMatrixPowerSingularAPI(unittest.TestCase):
def setUp(self):
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_maxout_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def test_dygraph_api(self):
np.testing.assert_allclose(out3_ref, out3.numpy(), rtol=1e-05)
paddle.enable_static()

@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(paddle.static.Program()):
# The input type must be Variable.
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_multi_dot_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def get_inputs_and_outputs(self):

# python API test
class TestMultiDotOpError(unittest.TestCase):
@test_with_pir_api
def test_errors(self):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
Expand Down
1 change: 1 addition & 0 deletions test/legacy_test/test_outer.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def test_multiply_dynamic(self):


class TestMultiplyError(unittest.TestCase):
@test_with_pir_api
def test_errors_static(self):
# test static computation graph: dtype can not be int8
paddle.enable_static()
Expand Down

0 comments on commit 070a394

Please sign in to comment.