Commit

Fix
co63oc committed May 28, 2024
1 parent 3bf4e57 commit 4439621
Showing 2 changed files with 36 additions and 35 deletions.
35 changes: 35 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_utils.h
@@ -99,5 +99,40 @@ static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
#endif
}

template <typename T>
static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
                      const int ring_id,
                      const int count,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (map->has(ring_id)) {
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg->AllReduce(&tensor, tensor, opts, false, true);
    task->Wait();
  } else {
    auto dtype = paddle::platform::ToNCCLDataType(
        paddle::framework::TransToProtoVarType(tensor.dtype()));
    int64_t numel = tensor.numel();
    const void *sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recvbuff = tensor.mutable_data<T>(place);
    auto comm =
        paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
  }
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

} // namespace fusion
} // namespace phi
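
For reference, a minimal usage sketch (not part of this commit) of how a fused kernel could call the helper now that it lives in the shared header; the wrapper name, the `ffn_out` tensor, and the `dev_ctx` variable below are hypothetical:

// Hypothetical caller, assuming each tensor-parallel rank has produced a
// partial result `ffn_out` and the op carries the communicator ring id.
#include "paddle/fluid/operators/fused/fused_attention_utils.h"

template <typename T>
void SumAcrossTensorParallelRanks(phi::DenseTensor &ffn_out,  // NOLINT
                                  int ring_id,
                                  const phi::GPUContext &dev_ctx) {
  // Element-wise SUM of `ffn_out` across all ranks in the ring; with
  // ring_id == -1 the helper returns immediately (tensor parallelism off).
  phi::fusion::AllReduce<T>(
      ffn_out, ring_id, static_cast<int>(ffn_out.numel()), dev_ctx);
}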
36 changes: 1 addition & 35 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h
@@ -30,6 +30,7 @@ limitations under the License. */
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h"
@@ -38,41 +39,6 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace phi {
namespace fusion {

template <typename T>
static void AllReduce(phi::DenseTensor &tensor,  // NOLINT
                      const int ring_id,
                      const int count,
                      const phi::GPUContext &ctx) {
  if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
  auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

  if (map->has(ring_id)) {
    paddle::distributed::ProcessGroup *pg = map->get(ring_id);
    paddle::distributed::AllreduceOptions opts;
    opts.reduce_op = distributed::ReduceOp::SUM;
    auto task = pg->AllReduce(&tensor, tensor, opts, false, true);
    task->Wait();
  } else {
    auto dtype = paddle::platform::ToNCCLDataType(
        paddle::framework::TransToProtoVarType(tensor.dtype()));
    int64_t numel = tensor.numel();
    const void *sendbuff = tensor.data<T>();
    auto place = ctx.GetPlace();
    void *recvbuff = tensor.mutable_data<T>(place);
    auto comm =
        paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
    auto stream = ctx.stream();
    PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
        sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
  }
#else
  PADDLE_THROW(phi::errors::Unimplemented(
      "PaddlePaddle should compile with NCCL or RCCL when used tensor model "
      "parallel op."));
#endif
}

namespace { // NOLINT

using float16 = phi::dtype::float16;
