Skip to content

Commit

Permalink
update code
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate committed Mar 31, 2024
1 parent 0a5944a commit 1aa3783
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 12 deletions.
6 changes: 3 additions & 3 deletions paddle/fluid/operators/expand_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"(Tensor, default Tensor<float>). A tensor with rank in [1, 8]."
"X is the input to be expanded.");
AddInput("ExpandTimes",
"(Tensor<int>), optional). If provided, expand according to "
Expand All @@ -113,7 +113,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
.AsDuplicable()
.AsDispensable();
AddOutput("Out",
"(Tensor, default Tensor<float>). A tensor with rank in [1, 6]."
"(Tensor, default Tensor<float>). A tensor with rank in [1, 8]."
"The rank of Output(Out) have the same with Input(X). "
"After expanding, size of each dimension of Output(Out) is equal "
"to size of the corresponding dimension of Input(X) multiplying "
Expand All @@ -124,7 +124,7 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Expand operator tiles the input by given times number. You should set times
number for each dimension by providing attribute 'expand_times'. The rank of X
should be in [1, 6]. Please note that size of 'expand_times' must be the same
should be in [1, 8]. Please note that size of 'expand_times' must be the same
with X's rank. Following is a using case:
Input(X) is a 3-D tensor with shape [2, 3, 1]:
[
Expand Down
15 changes: 14 additions & 1 deletion paddle/fluid/operators/expand_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,12 @@ class ExpandKernel : public framework::OpKernel<T> {
case 6:
Expand<6>(context);
break;
case 7:
Expand<7>(context);
break;
case 8:
Expand<8>(context);
break;
}
}

Expand Down Expand Up @@ -249,10 +255,17 @@ class ExpandGradKernel : public framework::OpKernel<T> {
case 6:
ExpandBackward<6>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 7:
ExpandBackward<7>(context, reshape_dims_vec, reduce_dims_vec);
break;
case 8:
ExpandBackward<8>(context, reshape_dims_vec, reduce_dims_vec);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support tensor with rank being between 1 and 6. But "
"Only support tensor with rank being between 1 and %d. But "
"received tensor's rank = %d.",
MAX_RANK_SUPPORTED,
dims));
}
}
Expand Down
18 changes: 10 additions & 8 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1219,7 +1219,7 @@ void EinsumRawInferMeta(const std::vector<const MetaTensor*>& inputs,
void ExpandInferMeta(const MetaTensor& x,
const IntArray& shape,
MetaTensor* out) {
#define EXPAND_MAX_RANK_SUPPORTED 8
#define MAX_RANK_SUPPORTED 8
auto x_dims = x.dims();
auto expand_shape = shape.GetData();

Expand All @@ -1238,11 +1238,11 @@ void ExpandInferMeta(const MetaTensor& x,
static_cast<size_t>(x_dims.size())));
PADDLE_ENFORCE_LE(
expand_shape.size(),
EXPAND_MAX_RANK_SUPPORTED,
MAX_RANK_SUPPORTED,
phi::errors::InvalidArgument("The number of elements (%d) of 'shape' for "
"must not be greater than %d.",
expand_shape.size(),
EXPAND_MAX_RANK_SUPPORTED));
MAX_RANK_SUPPORTED));
PADDLE_ENFORCE_GE(
expand_shape.size(),
0,
Expand Down Expand Up @@ -1283,6 +1283,7 @@ void ExpandInferMeta(const MetaTensor& x,
if (out_rank > 0 && out_shape[0] == x_dims[0]) {
out->share_lod(x);
}
#undef MAX_RANK_SUPPORTED
}

void FillAnyLikeInferMeta(const MetaTensor& x,
Expand Down Expand Up @@ -4722,7 +4723,7 @@ void TileInferMeta(const MetaTensor& x,
const IntArray& repeat_times,
MetaTensor* out,
MetaConfig config) {
#define TILE_MAX_RANK_SUPPORTED 6
#define MAX_RANK_SUPPORTED 6

auto repeat_times_data = repeat_times.GetData();
auto x_dims = x.dims();
Expand All @@ -4732,19 +4733,19 @@ void TileInferMeta(const MetaTensor& x,

PADDLE_ENFORCE_LE(
x_dims.size(),
TILE_MAX_RANK_SUPPORTED,
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The rank of the input 'x' for tile op "
"must not be greater than %d, but the value received is %d.",
TILE_MAX_RANK_SUPPORTED,
MAX_RANK_SUPPORTED,
x_dims.size()));
PADDLE_ENFORCE_LE(
repeat_times_data.size(),
TILE_MAX_RANK_SUPPORTED,
MAX_RANK_SUPPORTED,
errors::InvalidArgument(
"The size of the shape of input 'repeat_times' for tile op "
"must not be greater than %d, but the value received is %d.",
TILE_MAX_RANK_SUPPORTED,
MAX_RANK_SUPPORTED,
repeat_times_data.size()));
PADDLE_ENFORCE_GE(
repeat_times_data.size(),
Expand Down Expand Up @@ -4785,6 +4786,7 @@ void TileInferMeta(const MetaTensor& x,
out->share_lod(x);
}
out->set_dtype(x.dtype());
#undef MAX_RANK_SUPPORTED
}

void TopKInferMeta(const MetaTensor& x,
Expand Down

0 comments on commit 1aa3783

Please sign in to comment.