diff --git a/paddle/fluid/operators/expand_as_v2_op.h b/paddle/fluid/operators/expand_as_v2_op.h index 2c62dc570ff214..abc89ba75c6717 100644 --- a/paddle/fluid/operators/expand_as_v2_op.h +++ b/paddle/fluid/operators/expand_as_v2_op.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/expand_op.cc b/paddle/fluid/operators/expand_op.cc index 71295296218f02..bd558ee944359d 100644 --- a/paddle/fluid/operators/expand_op.cc +++ b/paddle/fluid/operators/expand_op.cc @@ -44,10 +44,11 @@ class ExpandOp : public framework::OperatorWithKernel { static_cast(x_dims.size()))); PADDLE_ENFORCE_LE( x_dims.size(), - 6, + MAX_RANK_SUPPORTED, platform::errors::InvalidArgument( "The number of dimensions of the input for Op(expand) " - "must not be greater than 6, but the value received is %d.", + "must not be greater than %d, but the value received is %d.", + MAX_RANK_SUPPORTED, x_dims.size())); std::vector out_shape(x_dims.size()); @@ -98,7 +99,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor, default Tensor). A tensor with rank in [1, 6]." + "(Tensor, default Tensor). A tensor with rank in [1, 8]." "X is the input to be expanded."); AddInput("ExpandTimes", "(Tensor), optional). If provided, expand according to " @@ -112,7 +113,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { .AsDuplicable() .AsDispensable(); AddOutput("Out", - "(Tensor, default Tensor). A tensor with rank in [1, 6]." + "(Tensor, default Tensor). A tensor with rank in [1, 8]." "The rank of Output(Out) have the same with Input(X). " "After expanding, size of each dimension of Output(Out) is equal " "to size of the corresponding dimension of Input(X) multiplying " @@ -123,7 +124,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( Expand operator tiles the input by given times number. You should set times number for each dimension by providing attribute 'expand_times'. The rank of X -should be in [1, 6]. Please note that size of 'expand_times' must be the same +should be in [1, 8]. Please note that size of 'expand_times' must be the same with X's rank. Following is a using case: Input(X) is a 3-D tensor with shape [2, 3, 1]: [ diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h index ee100b3b484184..3d9fbe883b31b4 100644 --- a/paddle/fluid/operators/expand_op.h +++ b/paddle/fluid/operators/expand_op.h @@ -21,7 +21,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace paddle { namespace operators { @@ -128,6 +128,12 @@ class ExpandKernel : public framework::OpKernel { case 6: Expand<6>(context); break; + case 7: + Expand<7>(context); + break; + case 8: + Expand<8>(context); + break; } } @@ -249,10 +255,17 @@ class ExpandGradKernel : public framework::OpKernel { case 6: ExpandBackward<6>(context, reshape_dims_vec, reduce_dims_vec); break; + case 7: + ExpandBackward<7>(context, reshape_dims_vec, reduce_dims_vec); + break; + case 8: + ExpandBackward<8>(context, reshape_dims_vec, reduce_dims_vec); + break; default: PADDLE_THROW(platform::errors::InvalidArgument( - "Only support tensor with rank being between 1 and 6. But " + "Only support tensor with rank being between 1 and %d. But " "received tensor's rank = %d.", + MAX_RANK_SUPPORTED, dims)); } } diff --git a/paddle/fluid/operators/expand_v2_op.h b/paddle/fluid/operators/expand_v2_op.h index 0a70faddb7d589..b61cf2dc485e5c 100644 --- a/paddle/fluid/operators/expand_v2_op.h +++ b/paddle/fluid/operators/expand_v2_op.h @@ -22,7 +22,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace paddle { namespace operators { diff --git a/paddle/fluid/prim/api/manual_prim/utils/utils.h b/paddle/fluid/prim/api/manual_prim/utils/utils.h index 9062d979b40db0..cbbe8466711140 100644 --- a/paddle/fluid/prim/api/manual_prim/utils/utils.h +++ b/paddle/fluid/prim/api/manual_prim/utils/utils.h @@ -88,7 +88,7 @@ static phi::DDim get_reduce_dims(const phi::DDim& x_dims, * y_dims = [2, 1, 6, 1] <-- shaped are right-aligned for comparison * <-- broadcast --> * z_dims = [10, 2, 4, 6, 5] - * ==> reduce_dims_from_z_to_x = [0, 1, 3] + * ==> reduce_dims_from_z_to_x = [1, 3] * ==> reduce_dims_from_z_to_y = [0, 2, 4] */ auto out_dims = paddle::operators::details::BroadcastTwoDims(x_dims, y_dims); diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 97edce9ad79530..63d1d1c9b32d0e 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1532,7 +1532,7 @@ void ExpandAsInferMeta(const MetaTensor& x, const MetaTensor& y, const std::vector& target_shape, MetaTensor* out) { -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 auto x_dims = x.dims(); PADDLE_ENFORCE_GE( target_shape.size(), diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 74d04da5de8f2b..a152bc152ae6bb 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -1219,7 +1219,7 @@ void EinsumRawInferMeta(const std::vector& inputs, void ExpandInferMeta(const MetaTensor& x, const IntArray& shape, MetaTensor* out) { -#define MAX_RANK_SUPPORTED 6 +#define EXPAND_MAX_RANK_SUPPORTED 8 auto x_dims = x.dims(); auto expand_shape = shape.GetData(); @@ -1238,11 +1238,11 @@ void ExpandInferMeta(const MetaTensor& x, static_cast(x_dims.size()))); PADDLE_ENFORCE_LE( expand_shape.size(), - MAX_RANK_SUPPORTED, + EXPAND_MAX_RANK_SUPPORTED, phi::errors::InvalidArgument("The number of elements (%d) of 'shape' for " "must not be greater than %d.", expand_shape.size(), - MAX_RANK_SUPPORTED)); + EXPAND_MAX_RANK_SUPPORTED)); PADDLE_ENFORCE_GE( expand_shape.size(), 0, @@ -1283,6 +1283,7 @@ void ExpandInferMeta(const MetaTensor& x, if (out_rank > 0 && out_shape[0] == x_dims[0]) { out->share_lod(x); } +#undef EXPAND_MAX_RANK_SUPPORTED } void FillAnyLikeInferMeta(const MetaTensor& x, @@ -4722,7 +4723,7 @@ void TileInferMeta(const MetaTensor& x, const IntArray& repeat_times, MetaTensor* out, MetaConfig config) { -#define MAX_RANK_SUPPORTED 6 +#define TILE_MAX_RANK_SUPPORTED 6 auto repeat_times_data = repeat_times.GetData(); auto x_dims = x.dims(); @@ -4732,19 +4733,19 @@ void TileInferMeta(const MetaTensor& x, PADDLE_ENFORCE_LE( x_dims.size(), - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, errors::InvalidArgument( "The rank of the input 'x' for tile op " "must not be greater than %d, but the value received is %d.", - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, x_dims.size())); PADDLE_ENFORCE_LE( repeat_times_data.size(), - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, errors::InvalidArgument( "The size of the shape of input 'repeat_times' for tile op " "must not be greater than %d, but the value received is %d.", - MAX_RANK_SUPPORTED, + TILE_MAX_RANK_SUPPORTED, repeat_times_data.size())); PADDLE_ENFORCE_GE( repeat_times_data.size(), @@ -4785,6 +4786,7 @@ void TileInferMeta(const MetaTensor& x, out->share_lod(x); } out->set_dtype(x.dtype()); +#undef TILE_MAX_RANK_SUPPORTED } void TopKInferMeta(const MetaTensor& x, diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cc b/paddle/phi/kernels/funcs/eigen/broadcast.cc index 04e13a6799931d..0bf9d37d60e4a4 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cc +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cc @@ -73,7 +73,9 @@ struct EigenBroadcastGrad { template struct FUNCTOR; \ template struct FUNCTOR; \ template struct FUNCTOR; \ - template struct FUNCTOR + template struct FUNCTOR; \ + template struct FUNCTOR; \ + template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::bfloat16); diff --git a/paddle/phi/kernels/funcs/eigen/broadcast.cu b/paddle/phi/kernels/funcs/eigen/broadcast.cu index 0c5a3408872c47..fe16588c9bce6e 100644 --- a/paddle/phi/kernels/funcs/eigen/broadcast.cu +++ b/paddle/phi/kernels/funcs/eigen/broadcast.cu @@ -72,7 +72,9 @@ struct EigenBroadcastGrad { template struct FUNCTOR; \ template struct FUNCTOR; \ template struct FUNCTOR; \ - template struct FUNCTOR + template struct FUNCTOR; \ + template struct FUNCTOR; \ + template struct FUNCTOR INSTANTIATION(EigenBroadcast, bool); INSTANTIATION(EigenBroadcast, dtype::float16); INSTANTIATION(EigenBroadcast, dtype::bfloat16); diff --git a/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h b/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h index 54ef6e0c1f9cb7..2b1d0d60bee50a 100644 --- a/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_as_grad_kernel_impl.h @@ -116,10 +116,19 @@ void ExpandAsGradKernel(const Context& context, ExpandAsBackward( context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); break; + case 7: + ExpandAsBackward( + context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 8: + ExpandAsBackward( + context, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; default: PADDLE_THROW(errors::InvalidArgument( - "Only support tensor with rank being between 1 and 6. But " + "Only support tensor with rank being between 1 and %d. But " "received tensor's rank = %d.", + MAX_RANK_SUPPORTED, dims)); } } diff --git a/paddle/phi/kernels/impl/expand_as_kernel_impl.h b/paddle/phi/kernels/impl/expand_as_kernel_impl.h index cee562b42778e1..927cd73b3eb4ed 100755 --- a/paddle/phi/kernels/impl/expand_as_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_as_kernel_impl.h @@ -20,7 +20,7 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace phi { @@ -158,6 +158,12 @@ void ExpandAsKernel(const Context& ctx, case 6: ExpandAs(ctx, x, real_target_shape, out); break; + case 7: + ExpandAs(ctx, x, real_target_shape, out); + break; + case 8: + ExpandAs(ctx, x, real_target_shape, out); + break; } } diff --git a/paddle/phi/kernels/impl/expand_grad_kernel_impl.h b/paddle/phi/kernels/impl/expand_grad_kernel_impl.h index 4dd9dc4d50337a..f24fff253558a4 100644 --- a/paddle/phi/kernels/impl/expand_grad_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_grad_kernel_impl.h @@ -128,10 +128,19 @@ void ExpandGradKernel(const Context& ctx, ExpandBackward( ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); break; + case 7: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; + case 8: + ExpandBackward( + ctx, out_grad, reshape_dims_vec, reduce_dims_vec, in_grad); + break; default: PADDLE_THROW(phi::errors::InvalidArgument( - "Only support tensor with rank being between 1 and 6. But " + "Only support tensor with rank being between 1 and %d. But " "received tensor's rank = %d.", + MAX_RANK_SUPPORTED, dims)); } } diff --git a/paddle/phi/kernels/impl/expand_kernel_impl.h b/paddle/phi/kernels/impl/expand_kernel_impl.h index 181dd2558fa385..7d675e036a55e5 100644 --- a/paddle/phi/kernels/impl/expand_kernel_impl.h +++ b/paddle/phi/kernels/impl/expand_kernel_impl.h @@ -19,7 +19,7 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace phi { using Tensor = DenseTensor; @@ -169,6 +169,12 @@ void ExpandKernel(const Context& ctx, case 6: Expand(ctx, x, shape, out); break; + case 7: + Expand(ctx, x, shape, out); + break; + case 8: + Expand(ctx, x, shape, out); + break; } } diff --git a/paddle/phi/kernels/xpu/expand_as_kernel.cc b/paddle/phi/kernels/xpu/expand_as_kernel.cc index 0701294217f412..45d0515a0b822c 100644 --- a/paddle/phi/kernels/xpu/expand_as_kernel.cc +++ b/paddle/phi/kernels/xpu/expand_as_kernel.cc @@ -17,7 +17,7 @@ #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" -#define MAX_RANK_SUPPORTED 6 +#define MAX_RANK_SUPPORTED 8 namespace phi { diff --git a/test/legacy_test/op_test.py b/test/legacy_test/op_test.py index c18a142a1ec9dc..4e2543571b002a 100644 --- a/test/legacy_test/op_test.py +++ b/test/legacy_test/op_test.py @@ -527,7 +527,7 @@ def is_complex_test(): not in check_shape_white_list.NEED_TO_FIX_OP_LIST ): raise AssertionError( - "Input's shape should be large than or equal to 100 for " + "Number of element(s) of input should be large than or equal to 100 for " + cls.op_type + " Op." ) diff --git a/test/legacy_test/test_expand_v2_op.py b/test/legacy_test/test_expand_v2_op.py index ff96f28ba5caa8..8cbbfb2a2e39a4 100644 --- a/test/legacy_test/test_expand_v2_op.py +++ b/test/legacy_test/test_expand_v2_op.py @@ -110,6 +110,110 @@ def init_data(self): self.expand_times = (1, 1, 1, 1) +class TestExpandV2OpRank5(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [5, 2, 1, 4, 5] + self.shape = [5, 2, 3, 4, 5] + self.expand_times = [1, 1, 3, 1, 1] + + +class TestExpandV2OpRank5_Corner(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [5, 2, 3, 4, 5] + self.shape = [5, 2, 3, 4, 5] + self.expand_times = [1, 1, 1, 1, 1] + + +class TestExpandV2OpRank5_ZeroDim(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [] + self.shape = [5, 2, 3, 4, 5] + self.expand_times = [5, 2, 3, 4, 5] + + def if_enable_cinn(self): + self.enable_cinn = False + + +class TestExpandV2OpRank6(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [1, 2, 1, 4, 5, 6] + self.shape = [1, 2, 3, 4, 5, 6] + self.expand_times = [1, 1, 3, 1, 1, 1] + + +class TestExpandV2OpRank6_Corner(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [1, 2, 3, 4, 5, 6] + self.shape = [1, 2, 3, 4, 5, 6] + self.expand_times = [1, 1, 1, 1, 1, 1] + + +class TestExpandV2OpRank6_ZeroDim(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [] + self.shape = [1, 2, 3, 4, 5, 6] + self.expand_times = [1, 2, 3, 4, 5, 6] + + def if_enable_cinn(self): + self.enable_cinn = False + + +class TestExpandV2OpRank7(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [5, 2, 1, 4, 5, 6, 7] + self.shape = [5, 2, 3, 4, 5, 6, 7] + self.expand_times = [1, 1, 3, 1, 1, 1, 1] + + +class TestExpandV2OpRank7_Corner(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [1, 2, 3, 4, 5, 2, 2] + self.shape = [1, 2, 3, 4, 5, 2, 2] + self.expand_times = [1, 1, 1, 1, 1, 1, 1] + + +class TestExpandV2OpRank7_ZeroDim(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [] + self.shape = [1, 2, 3, 4, 5, 6, 7] + self.expand_times = [1, 2, 3, 4, 5, 6, 7] + + def if_enable_cinn(self): + self.enable_cinn = False + + +class TestExpandV2OpRank8(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [1, 2, 1, 4, 5, 6, 7, 8] + self.shape = [1, 2, 3, 4, 5, 6, 7, 8] + self.expand_times = [1, 1, 3, 1, 1, 1, 1, 1] + + +class TestExpandV2OpRank8_Corner(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [1, 2, 3, 4, 5, 2, 2, 2] + self.shape = [1, 2, 3, 4, 5, 2, 2, 2] + self.expand_times = [1, 1, 1, 1, 1, 1, 1, 1] + + def test_check_grad(self): + self.check_grad( + ['X'], + 'Out', + check_prim=True, + check_pir=True, + check_prim_pir=True, + numeric_grad_delta=1e-5, + max_relative_error=2e-7, # need slightly larger than 1e-7. + ) + + +class TestExpandV2OpRank8_ZeroDim(TestExpandV2OpRank1): + def init_data(self): + self.ori_shape = [] + self.shape = [1, 2, 3, 4, 5, 6, 7, 8] + self.expand_times = [1, 2, 3, 4, 5, 6, 7, 8] + + # Situation 2: shape is a list(with tensor) class TestExpandV2OpRank1_tensor_attr(OpTest): def setUp(self): @@ -300,22 +404,23 @@ def test_check_grad(self): class TestExpandV2Error(unittest.TestCase): @test_with_pir_api def test_errors(self): - paddle.enable_static() - with paddle.static.program_guard( - paddle.static.Program(), paddle.static.Program() - ): - shape = [2, 2] - if not in_pir_mode(): - x1 = base.create_lod_tensor( - np.array([[-1]]), [[1]], base.CPUPlace() - ) - self.assertRaises(TypeError, paddle.tensor.expand, x1, shape) - x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="bool") - x2.stop_gradient = False - self.assertRaises(ValueError, paddle.tensor.expand, x2, shape) - x2.stop_gradient = True - self.assertRaises(TypeError, paddle.tensor.expand, x2, 1) - paddle.disable_static() + with static_guard(): + with paddle.static.program_guard( + paddle.static.Program(), paddle.static.Program() + ): + shape = [2, 2] + if not in_pir_mode(): + x1 = base.create_lod_tensor( + np.array([[-1]]), [[1]], base.CPUPlace() + ) + self.assertRaises( + TypeError, paddle.tensor.expand, x1, shape + ) + x2 = paddle.static.data(name='x2', shape=[-1, 4], dtype="bool") + x2.stop_gradient = False + self.assertRaises(ValueError, paddle.tensor.expand, x2, shape) + x2.stop_gradient = True + self.assertRaises(TypeError, paddle.tensor.expand, x2, 1) # Test python API @@ -552,7 +657,7 @@ class TestExpandPirValueListShape(unittest.TestCase): def test_value_list_shape1(self): with static_guard(): with paddle.static.program_guard(paddle.static.Program()): - x = paddle.static.data('x', [1, 3]) + x = paddle.static.data('x', [1, 1]) shape = [2, paddle.full([], 4)] out = paddle.expand(x, shape) np.testing.assert_array_equal(tuple(out.shape), (2, -1)) diff --git a/test/white_list/op_threshold_white_list.py b/test/white_list/op_threshold_white_list.py index 9b9d590fd0a210..351efe8da96b02 100644 --- a/test/white_list/op_threshold_white_list.py +++ b/test/white_list/op_threshold_white_list.py @@ -54,6 +54,7 @@ 'solve', 'qr', 'layer_norm', + # 'expand_v2', ] NEED_FIX_FP64_CHECK_OUTPUT_THRESHOLD_OP_LIST = [