Skip to content

Commit

Permalink
update, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
sandyhouse committed Sep 2, 2021
1 parent bca1df6 commit 0b5902f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
29 changes: 15 additions & 14 deletions paddle/fluid/operators/collective/recv_v2_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,24 @@ class RecvOpV2 : public framework::OperatorWithKernel {
ring_id, 0,
platform::errors::InvalidArgument(
"The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
PADDLE_ENFORCE_GE(out_shape.size(), 1,
platform::errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(out_shape[i], 1,
platform::errors::InvalidArgument(
"The shape attribute for recv_v2 must be set "
"explicitly, but the %dth element is %d which "
"is less than 1.",
i, out_shape[i]));
}

if (ctx->GetOutputsVarType("Out").front() ==
framework::proto::VarType::LOD_TENSOR) {
auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
PADDLE_ENFORCE_GE(
out_shape.size(), 1,
platform::errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(out_shape[i], 1,
platform::errors::InvalidArgument(
"The shape attribute for recv_v2 must be set "
"explicitly, but the %dth element is %d which "
"is less than 1.",
i, out_shape[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
}
}
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/operators/collective/recv_v2_op.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
peer, 0,
platform::errors::InvalidArgument(
"The peer (%d) for recv_v2 op must be non-negative.", peer));
auto out_shape = ctx.Attr<std::vector<int>>("out_shape");

gpuStream_t stream = nullptr;
auto place = ctx.GetPlace();
Expand All @@ -66,8 +65,8 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
auto out_array = out_var->GetMutable<framework::LoDTensorArray>();
for (size_t idx = 0; idx < out_array->size(); ++idx) {
VLOG(3) << "LodTensorArray: idx(" << idx << ")";
auto out_dims = framework::make_ddim(out_shape);
auto out = &out_array->at(idx);
auto out_dims = out->dims();
out->mutable_data<T>(out_dims, place, 0);
auto numel = out->numel();
PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
Expand All @@ -78,6 +77,7 @@ class RecvOpV2CUDAKernel : public framework::OpKernel<T> {
return;
}

auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
auto out = ctx.Output<framework::LoDTensor>("Out");
auto out_dims = out->dims();
auto numel = out->numel();
Expand Down

1 comment on commit 0b5902f

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.