Skip to content

Commit

Permalink
【Hackathon 6th Fundable Projects 3 No.49】 [fluid_ops] c_scatter (#66848)
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc authored Aug 6, 2024
1 parent cb8220d commit 126d59a
Show file tree
Hide file tree
Showing 9 changed files with 231 additions and 341 deletions.
96 changes: 0 additions & 96 deletions paddle/fluid/operators/collective/c_scatter_op.cc

This file was deleted.

176 changes: 0 additions & 176 deletions paddle/fluid/operators/collective/c_scatter_op.cu.cc

This file was deleted.

58 changes: 0 additions & 58 deletions paddle/fluid/operators/collective/c_scatter_op.h

This file was deleted.

22 changes: 21 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,27 @@ void CropInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void CScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
void CScatterInferMeta(const MetaTensor& x,
int ring_id,
int root_id,
int nranks,
MetaTensor* out) {
PADDLE_ENFORCE_GE(nranks,
2,
common::errors::InvalidArgument(
"The number of ranks (%d) must be greater than 1 "
"to use collective op (c_scatter op).",
nranks));
PADDLE_ENFORCE_GE(
root_id,
0,
common::errors::InvalidArgument(
"The root_id (%d) for c_scatter_op must be non-negative.", root_id));
PADDLE_ENFORCE_GE(
ring_id,
0,
common::errors::InvalidArgument(
"The ring_id (%d) for c_scatter_op must be non-negative.", ring_id));
auto dim = x.dims();
dim[0] = dim[0] / nranks;
if (dim[0] < 0) dim[0] = -1;
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,8 @@ void CropInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaConfig config = MetaConfig());

void CScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);
void CScatterInferMeta(
const MetaTensor& x, int ring_id, int root, int nranks, MetaTensor* out);

void CSplitInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);

Expand Down
Loading

0 comments on commit 126d59a

Please sign in to comment.