Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move fused_multi_transformer_op.cu.h phi #64012

Merged
merged 6 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,5 +99,40 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
#endif
}

template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const int ring_id,
const int count,
const phi::GPUContext &ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(&tensor, tensor, opts, false, true);
task->Wait();
} else {
auto dtype = paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place);
auto comm =
paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}

} // namespace fusion
} // namespace phi
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h"
#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h"
#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm_int8.h"
#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
#include "paddle/phi/kernels/fusion/gpu/fused_multi_transformer_helper.cu.h"

namespace paddle {
namespace operators {
Expand Down
6 changes: 2 additions & 4 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h"

#include "paddle/fluid/framework/op_registry.h"

#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
Expand All @@ -24,8 +22,8 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
#include "paddle/phi/kernels/fusion/gpu/fused_multi_transformer_helper.cu.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace paddle {
namespace operators {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h"
#include "paddle/phi/kernels/funcs/cublaslt.h"
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/funcs/quant_dequant.h"
#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm_int8.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#include "paddle/phi/kernels/fusion/gpu/fused_multi_transformer_op.cu.h"

/*
Note(Zhengzekang):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,6 @@ limitations under the License. */
#include <fstream>
#include <iomanip>

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h"
Expand All @@ -38,41 +30,6 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace phi {
namespace fusion {

template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const int ring_id,
const int count,
const phi::GPUContext &ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(&tensor, tensor, opts, false, true);
task->Wait();
} else {
auto dtype = paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place);
auto comm =
paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}

namespace { // NOLINT

using float16 = phi::dtype::float16;
Expand Down
21 changes: 11 additions & 10 deletions paddle/phi/kernels/fusion/gpu/mmha_util.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -2503,13 +2503,13 @@ inline __device__ void zero(T& dst) { // NOLINT
dst = tmp.raw;
}

template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
template <int WARPS_PER_BLOCK, int WARP_SIZE_T = 32>
inline __device__ float block_sum(float* red_smem, float sum) {
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
int warp = threadIdx.x / WARP_SIZE_T;
int lane = threadIdx.x % WARP_SIZE_T;

#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
for (int mask = WARP_SIZE_T / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}

Expand Down Expand Up @@ -2830,7 +2830,7 @@ struct Qk_dot<float16, 4> {
}
};

constexpr int32_t WARP_SIZE = 32;
constexpr int32_t WARP_SIZE_TMP = 32;
constexpr int32_t HALF_WARP = 16;
constexpr float QUANT_MAX_BOUND = 127.0;
constexpr float QUANT_MIN_BOUND = -127.0;
Expand Down Expand Up @@ -2920,16 +2920,17 @@ template <typename T>
__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) {
#pragma unroll
for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
val = MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE));
val =
MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE_TMP));
}
return val;
}

template <typename T>
__inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) {
static __shared__ T smem[WARP_SIZE];
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t warp_id = threadIdx.x / WARP_SIZE;
static __shared__ T smem[WARP_SIZE_TMP];
int32_t lane_id = threadIdx.x % WARP_SIZE_TMP;
int32_t warp_id = threadIdx.x / WARP_SIZE_TMP;

val = WarpReduceAbsMax(val, mask);

Expand All @@ -2939,7 +2940,7 @@ __inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) {

__syncthreads();

T abs_max_val = (threadIdx.x < (blockDim.x / WARP_SIZE))
T abs_max_val = (threadIdx.x < (blockDim.x / WARP_SIZE_TMP))
? smem[threadIdx.x]
: static_cast<T>(0.0f);
abs_max_val = WarpReduceAbsMax(abs_max_val, mask);
Expand Down