From 265b6ab742374361f090d380048a21dddc21fc3c Mon Sep 17 00:00:00 2001
From: co63oc
Date: Sat, 3 Aug 2024 08:11:25 +0800
Subject: [PATCH] Fix

---
 .../operators/collective/c_scatter_op.cc      |  96 ----------
 .../operators/collective/c_scatter_op.cu.cc   | 176 ------------------
 .../fluid/operators/collective/c_scatter_op.h |  58 ------
 paddle/phi/infermeta/unary.cc                 |  22 ++-
 paddle/phi/infermeta/unary.h                  |   3 +-
 paddle/phi/kernels/cpu/c_scatter_kernel.cc    |  63 +++++++
 paddle/phi/kernels/gpu/c_scatter_kernel.cu    | 136 ++++++++++++++
 .../phi/ops/yaml/inconsistent/static_ops.yaml |   9 -
 paddle/phi/ops/yaml/ops.yaml                  |   9 +
 9 files changed, 231 insertions(+), 341 deletions(-)
 delete mode 100644 paddle/fluid/operators/collective/c_scatter_op.cc
 delete mode 100644 paddle/fluid/operators/collective/c_scatter_op.cu.cc
 delete mode 100644 paddle/fluid/operators/collective/c_scatter_op.h
 create mode 100644 paddle/phi/kernels/cpu/c_scatter_kernel.cc
 create mode 100644 paddle/phi/kernels/gpu/c_scatter_kernel.cu

diff --git a/paddle/fluid/operators/collective/c_scatter_op.cc b/paddle/fluid/operators/collective/c_scatter_op.cc
deleted file mode 100644
index 1d9d40997f933..0000000000000
--- a/paddle/fluid/operators/collective/c_scatter_op.cc
+++ /dev/null
@@ -1,96 +0,0 @@
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
*/ - -#include "paddle/fluid/operators/collective/c_scatter_op.h" - -namespace paddle::operators { - -class CScatterOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "CScatter"); - OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CScatter"); - int root_id = ctx->Attrs().Get("root"); - int ring_id = ctx->Attrs().Get("ring_id"); - int nranks = ctx->Attrs().Get("nranks"); - PADDLE_ENFORCE_GE(nranks, - 2, - common::errors::InvalidArgument( - "The number of ranks (%d) must be greater than 1 " - "to use collective op (c_scatter op).", - nranks)); - PADDLE_ENFORCE_GE( - root_id, - 0, - common::errors::InvalidArgument( - "The root_id (%d) for c_scatter_op must be non-negative.", - root_id)); - PADDLE_ENFORCE_GE( - ring_id, - 0, - common::errors::InvalidArgument( - "The ring_id (%d) for c_scatter_op must be non-negative.", - root_id)); - phi::DDim dim = ctx->GetInputDim("X"); - dim[0] = dim[0] / nranks; - if (dim[0] < 0) dim[0] = -1; - ctx->SetOutputDim("Out", dim); - } - - protected: - phi::KernelKey GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.GetPlace()); - } -}; - -class CScatterOpMaker : public framework::OpProtoAndCheckerMaker { - public: - void Make() override { - AddInput("X", "(Tensor) tensor to be broadcasted."); - AddOutput("Out", "(Tensor) the result of broadcast."); - AddAttr("ring_id", "(int default 0) nccl communication ring id.") - .SetDefault(0); - AddAttr("root", "(int default 0) root id for broadcasting.") - .SetDefault(0); - AddAttr("nranks", "(int default 0) number of ranks.").SetDefault(0); - AddAttr( - "use_calc_stream", - "(bool default false) eject CUDA operations to calculation stream.") - .SetDefault(false); - AddComment(R"DOC( -CScatter Operator -Scatter the source to all participators. -)DOC"); - } -}; - -} // namespace paddle::operators - -namespace ops = paddle::operators; - -REGISTER_OP_WITHOUT_GRADIENT(c_scatter, ops::CScatterOp, ops::CScatterOpMaker); - -PD_REGISTER_STRUCT_KERNEL(c_scatter, - CPU, - ALL_LAYOUT, - ops::CScatterOpCPUKernel, - float, - double, - int, - int64_t, - phi::dtype::float16) {} diff --git a/paddle/fluid/operators/collective/c_scatter_op.cu.cc b/paddle/fluid/operators/collective/c_scatter_op.cu.cc deleted file mode 100644 index c74a520d0dc71..0000000000000 --- a/paddle/fluid/operators/collective/c_scatter_op.cu.cc +++ /dev/null @@ -1,176 +0,0 @@ -/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include "paddle/fluid/operators/collective/c_scatter_op.h" -#include "paddle/phi/core/distributed/comm_context_manager.h" - -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) -#include "paddle/common/flags.h" -#include "paddle/fluid/platform/device/gpu/nccl_helper.h" -#include "paddle/phi/core/distributed/nccl_comm_context.h" -#include "paddle/phi/core/platform/collective_helper.h" -COMMON_DECLARE_bool(dynamic_static_unified_comm); -#endif - -namespace paddle { -namespace operators { - -template -class CScatterOpCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { -#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) - auto x = ctx.Input("X"); - auto out = ctx.Output("Out"); - int numel = x->numel(); - ncclDataType_t dtype = phi::ToNCCLDataType(x->dtype()); - - int nranks = ctx.Attr("nranks"); - int root_id = ctx.Attr("root"); - int ring_id = ctx.Attr("ring_id"); - auto place = ctx.GetPlace(); - gpuStream_t stream = nullptr; - platform::NCCLComm* comm = nullptr; - phi::distributed::NCCLCommContext* comm_ctx = nullptr; - PADDLE_ENFORCE_GE( - root_id, - 0, - common::errors::InvalidArgument( - "The root_id (%d) for c_scatter_op must be non-negative.", - root_id)); - PADDLE_ENFORCE_GE( - ring_id, - 0, - common::errors::InvalidArgument( - "The ring_id (%d) for c_scatter_op must be non-negative.", - ring_id)); - - const auto& comm_context_manager = - phi::distributed::CommContextManager::GetInstance(); - if (FLAGS_dynamic_static_unified_comm) { - PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), - true, - common::errors::InvalidArgument( - "You choose to use new communication library by " - "setting environment " - "variable FLAGS_dynamic_static_unified_comm True. " - "But ring_id(%d) is " - "not found in comm_context_manager.", - std::to_string(ring_id))); - comm_ctx = static_cast( - comm_context_manager.Get(std::to_string(ring_id))); - PADDLE_ENFORCE_NE(comm_ctx, - nullptr, - common::errors::Unavailable( - "NCCLCommContext is nullptr, collective op should " - "has ring_id attr.")); - PADDLE_ENFORCE_EQ(nranks, - comm_ctx->GetSize(), - common::errors::InvalidArgument( - "The number of ranks (%d) you set of must " - "be equal to comm_ctx->GetSize() (%d).", - nranks, - comm_ctx->GetSize())); - - stream = comm_ctx->GetStream(); - VLOG(3) << "new comm_context_manager has ring_id " << ring_id; - } else { // old comm_context - comm = platform::NCCLCommContext::Instance().Get(ring_id, place); - PADDLE_ENFORCE_EQ(nranks, - comm->nranks(), - common::errors::InvalidArgument( - "The number of ranks (%d) you set of must " - "be equal to comm->nranks (%d).", - nranks, - comm->nranks())); - - stream = comm->stream(); - VLOG(3) << "old NCCLCommContext has ring_id " << ring_id; - } - if (ctx.Attr("use_calc_stream")) { - // should ExecutionContext for calc stream. 
- stream = ctx.cuda_device_context().stream(); - } - - phi::DDim x_dims = x->dims(); - phi::DDim out_dims(x_dims); - phi::DenseTensor temp; - auto out_ptr = temp.mutable_data(out_dims, place); - - if (FLAGS_dynamic_static_unified_comm) { - if (root_id == comm_ctx->GetRank()) { - comm_ctx->Broadcast( - const_cast(x), *x, root_id, stream); - framework::TensorCopy(*static_cast(x), - place, - *phi::DeviceContextPool::Instance().Get(place), - static_cast(&temp)); - } else { - comm_ctx->Broadcast(&temp, temp, root_id, stream); - } - } else { - if (root_id == comm->rank()) { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBcast( - reinterpret_cast(const_cast(x->data())), - numel, - dtype, - root_id, - comm->comm(), - stream)); - - framework::TensorCopy(*static_cast(x), - place, - *phi::DeviceContextPool::Instance().Get(place), - static_cast(&temp)); - } else { - PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::ncclBcast( - out_ptr, numel, dtype, root_id, comm->comm(), stream)); - } - } - - out_dims[0] = out_dims[0] / nranks; - auto start_index = FLAGS_dynamic_static_unified_comm - ? out_dims[0] * comm_ctx->GetRank() - : out_dims[0] * comm->rank(); - auto end_index = start_index + out_dims[0]; - temp = temp.Slice(start_index, end_index); - temp.Resize(out_dims); - out->mutable_data(out_dims, place); - framework::TensorCopySync(*static_cast(&temp), - place, - static_cast(out)); - out->Resize(out_dims); -#else - PADDLE_ENFORCE_EQ( - true, - false, - common::errors::Unavailable("PaddlePaddle should compile with GPU.")); -#endif - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -PD_REGISTER_STRUCT_KERNEL(c_scatter, - GPU, - ALL_LAYOUT, - ops::CScatterOpCUDAKernel, - float, - double, - int, - int64_t, - phi::dtype::float16) {} diff --git a/paddle/fluid/operators/collective/c_scatter_op.h b/paddle/fluid/operators/collective/c_scatter_op.h deleted file mode 100644 index 3174aac97e76d..0000000000000 --- a/paddle/fluid/operators/collective/c_scatter_op.h +++ /dev/null @@ -1,58 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#pragma once - -#include -#include -#include - -#include "paddle/fluid/framework/data_type.h" -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" - -#if defined(PADDLE_WITH_GLOO) -#include "paddle/phi/core/distributed/gloo_comm_context.h" -#endif - -namespace paddle { -namespace operators { - -template -class CScatterOpCPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { -#if defined(PADDLE_WITH_GLOO) - auto& dev_ctx = ctx.device_context(); - auto in = ctx.Input("X"); - auto out = ctx.Output("Out"); - auto root_id = ctx.Attr("root"); - - auto comm_ctx = static_cast( - dev_ctx.GetCommContext()); - PADDLE_ENFORCE_NE(comm_ctx, - nullptr, - ::common::errors::Unavailable( - "NCCLCommContext is nullptr, collective op should " - "has ring_id attr.")); - comm_ctx->Scatter(out, *in, root_id); -#else - PADDLE_THROW(common::errors::Unavailable( - "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON")); -#endif - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index 8c0c953cfdffe..ade75a14a3105 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -795,7 +795,27 @@ void CropInferMeta(const MetaTensor& x, out->set_dtype(x.dtype()); } -void CScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { +void CScatterInferMeta(const MetaTensor& x, + int ring_id, + int root_id, + int nranks, + MetaTensor* out) { + PADDLE_ENFORCE_GE(nranks, + 2, + common::errors::InvalidArgument( + "The number of ranks (%d) must be greater than 1 " + "to use collective op (c_scatter op).", + nranks)); + PADDLE_ENFORCE_GE( + root_id, + 0, + common::errors::InvalidArgument( + "The root_id (%d) for c_scatter_op must be non-negative.", root_id)); + PADDLE_ENFORCE_GE( + ring_id, + 0, + common::errors::InvalidArgument( + "The ring_id (%d) for c_scatter_op must be non-negative.", ring_id)); auto dim = x.dims(); dim[0] = dim[0] / nranks; if (dim[0] < 0) dim[0] = -1; diff --git a/paddle/phi/infermeta/unary.h b/paddle/phi/infermeta/unary.h index 2ac8b5363b670..4cf1d558414ae 100644 --- a/paddle/phi/infermeta/unary.h +++ b/paddle/phi/infermeta/unary.h @@ -158,7 +158,8 @@ void CropInferMeta(const MetaTensor& x, MetaTensor* out, MetaConfig config = MetaConfig()); -void CScatterInferMeta(const MetaTensor& x, int nranks, MetaTensor* out); +void CScatterInferMeta( + const MetaTensor& x, int ring_id, int root, int nranks, MetaTensor* out); void CSplitInferMeta(const MetaTensor& x, int nranks, MetaTensor* out); diff --git a/paddle/phi/kernels/cpu/c_scatter_kernel.cc b/paddle/phi/kernels/cpu/c_scatter_kernel.cc new file mode 100644 index 0000000000000..c8558b8db36e7 --- /dev/null +++ b/paddle/phi/kernels/cpu/c_scatter_kernel.cc @@ -0,0 +1,63 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include +#include +#include + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +#if defined(PADDLE_WITH_GLOO) +#include "paddle/phi/core/distributed/gloo_comm_context.h" +#endif + +namespace phi { + +template +void CScatterOpCPUKernel(const Context &dev_ctx, + const DenseTensor &x, + int ring_id, + int root, + int nranks, + bool use_calc_stream, + DenseTensor *out) { +#if defined(PADDLE_WITH_GLOO) + auto in = &x; + auto root_id = root; + + auto comm_ctx = static_cast( + dev_ctx.GetCommContext()); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + ::common::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + comm_ctx->Scatter(out, *in, root_id); +#else + PADDLE_THROW(common::errors::Unavailable( + "PaddlePaddle should compile with GLOO by setting WITH_GLOO=ON")); +#endif +} +} // namespace phi + +PD_REGISTER_KERNEL(c_scatter, + CPU, + ALL_LAYOUT, + phi::CScatterOpCPUKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/c_scatter_kernel.cu b/paddle/phi/kernels/gpu/c_scatter_kernel.cu new file mode 100644 index 0000000000000..4ea62f468e58e --- /dev/null +++ b/paddle/phi/kernels/gpu/c_scatter_kernel.cu @@ -0,0 +1,136 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "glog/logging.h" +#include "paddle/phi/core/distributed/comm_context_manager.h" + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +#include "paddle/phi/core/distributed/nccl_comm_context.h" +#endif +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CScatterOpCUDAKernel(const Context& dev_ctx, + const DenseTensor& input, + int ring_id, + int root, + int nranks, + bool use_calc_stream, + DenseTensor* out) { +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) + auto x = &input; + int numel = x->numel(); + ncclDataType_t dtype = phi::ToNCCLDataType(x->dtype()); + + int root_id = root; + auto place = dev_ctx.GetPlace(); + gpuStream_t stream = nullptr; + phi::distributed::NCCLCommContext* comm_ctx = nullptr; + PADDLE_ENFORCE_GE( + root_id, + 0, + common::errors::InvalidArgument( + "The root_id (%d) for c_scatter_op must be non-negative.", root_id)); + PADDLE_ENFORCE_GE( + ring_id, + 0, + common::errors::InvalidArgument( + "The ring_id (%d) for c_scatter_op must be non-negative.", ring_id)); + + const auto& comm_context_manager = + phi::distributed::CommContextManager::GetInstance(); + + PADDLE_ENFORCE_EQ(comm_context_manager.Has(std::to_string(ring_id)), + true, + common::errors::InvalidArgument( + "You choose to use new communication library by " + "setting environment " + "variable FLAGS_dynamic_static_unified_comm True. 
" + "But ring_id(%d) is " + "not found in comm_context_manager.", + std::to_string(ring_id))); + comm_ctx = static_cast( + comm_context_manager.Get(std::to_string(ring_id))); + PADDLE_ENFORCE_NE(comm_ctx, + nullptr, + common::errors::Unavailable( + "NCCLCommContext is nullptr, collective op should " + "has ring_id attr.")); + PADDLE_ENFORCE_EQ(nranks, + comm_ctx->GetSize(), + common::errors::InvalidArgument( + "The number of ranks (%d) you set of must " + "be equal to comm_ctx->GetSize() (%d).", + nranks, + comm_ctx->GetSize())); + + stream = comm_ctx->GetStream(); + VLOG(3) << "new comm_context_manager has ring_id " << ring_id; + + if (use_calc_stream) { + // should ExecutionContext for calc stream. + stream = dev_ctx.stream(); + } + + phi::DDim x_dims = x->dims(); + phi::DDim out_dims(x_dims); + phi::DenseTensor temp; + temp.Resize(out_dims); + auto out_ptr = dev_ctx.template Alloc(&temp); + + if (root_id == comm_ctx->GetRank()) { + comm_ctx->Broadcast(const_cast(x), *x, root_id, stream); + phi::Copy(dev_ctx, + *static_cast(x), + place, + false, + static_cast(&temp)); + } else { + comm_ctx->Broadcast(&temp, temp, root_id, stream); + } + + out_dims[0] = out_dims[0] / nranks; + auto start_index = out_dims[0] * comm_ctx->GetRank(); + auto end_index = start_index + out_dims[0]; + temp = temp.Slice(start_index, end_index); + temp.Resize(out_dims); + out->Resize(out_dims); + dev_ctx.template Alloc(out); + phi::Copy(dev_ctx, + *static_cast(&temp), + place, + true, + static_cast(out)); + out->Resize(out_dims); +#else + PADDLE_ENFORCE_EQ( + true, + false, + common::errors::Unavailable("PaddlePaddle should compile with GPU.")); +#endif +} +} // namespace phi + +PD_REGISTER_KERNEL(c_scatter, + GPU, + ALL_LAYOUT, + phi::CScatterOpCUDAKernel, + float, + double, + int, + int64_t, + phi::dtype::float16) {} diff --git a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml index 5afaa90709d2e..b3a1958483f5c 100644 --- a/paddle/phi/ops/yaml/inconsistent/static_ops.yaml +++ b/paddle/phi/ops/yaml/inconsistent/static_ops.yaml @@ -174,15 +174,6 @@ func : reduce_scatter param: [x, nranks] -- op : c_scatter - args : (Tensor x, int ring_id = 0, int root = 0, int nranks = 0, bool use_calc_stream = false) - output : Tensor(out) - infer_meta : - func : CScatterInferMeta - param : [x, nranks] - kernel : - func : c_scatter - - op : c_split args : (Tensor x, int rank = 0, int nranks = 1, int ring_id = 0, bool use_calc_stream = false, bool use_model_parallel = true) output : Tensor(out) diff --git a/paddle/phi/ops/yaml/ops.yaml b/paddle/phi/ops/yaml/ops.yaml index b782651301cd8..999dd62494705 100755 --- a/paddle/phi/ops/yaml/ops.yaml +++ b/paddle/phi/ops/yaml/ops.yaml @@ -758,6 +758,15 @@ func : c_reduce_sum inplace : (x -> out) +- op : c_scatter + args : (Tensor x, int ring_id = 0, int root = 0, int nranks = 0, bool use_calc_stream = false) + output : Tensor(out) + infer_meta : + func : CScatterInferMeta + param : [x, ring_id, root, nranks] + kernel : + func : c_scatter + - op : c_sync_calc_stream args : (Tensor x) output : Tensor(out)