Move fused_multi_transformer_op.cu.h phi (#64012)
co63oc authored May 28, 2024
1 parent f8763de commit f2fa26f
Showing 6 changed files with 51 additions and 59 deletions.
35 changes: 35 additions & 0 deletions paddle/fluid/operators/fused/fused_attention_utils.h
@@ -99,5 +99,40 @@ static void AllReduce(phi::DenseTensor &tensor, // NOLINT
#endif
}

template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const int ring_id,
const int count,
const phi::GPUContext &ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(&tensor, tensor, opts, false, true);
task->Wait();
} else {
auto dtype = paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place);
auto comm =
paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}

} // namespace fusion
} // namespace phi
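The overload added above is the tensor model parallel all-reduce that this commit keeps on the fluid side, so the moved phi kernel header no longer needs the NCCL/RCCL plumbing. A minimal sketch of a call site follows; it is illustrative only (SumAcrossModelParallelGroup, partial_out and dev_ctx are invented names, not symbols from this patch) and simply forwards the whole tensor through the new count-based overload:

#include "paddle/fluid/operators/fused/fused_attention_utils.h"

template <typename T>
void SumAcrossModelParallelGroup(phi::DenseTensor &partial_out,  // NOLINT
                                 int ring_id,
                                 const phi::GPUContext &dev_ctx) {
  // ring_id == -1 is a no-op inside AllReduce, so single-card runs share this
  // path; otherwise every rank in the ring ends up with the element-wise SUM
  // of partial_out, reduced in place over the whole buffer.
  phi::fusion::AllReduce<T>(partial_out,
                            ring_id,
                            static_cast<int>(partial_out.numel()),
                            dev_ctx);
}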
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h"
#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/fluid/platform/device/gpu/gpu_resource_pool.h"
#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm_int8.h"
#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
#include "paddle/phi/kernels/fusion/gpu/fused_multi_transformer_helper.cu.h"

namespace paddle {
namespace operators {
6 changes: 2 additions & 4 deletions paddle/fluid/operators/fused/fused_multi_transformer_op.cu
@@ -12,10 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fused/fused_multi_transformer_helper.cu.h"

#include "paddle/fluid/framework/op_registry.h"

#include "paddle/fluid/operators/fused/fused_attention_utils.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
@@ -24,8 +22,8 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fmha_ref.h"
#include "paddle/phi/kernels/fusion/gpu/fused_multi_transformer_helper.cu.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace paddle {
namespace operators {

@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h"
#include "paddle/phi/kernels/funcs/cublaslt.h"
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/funcs/quant_dequant.h"
#include "paddle/phi/kernels/fusion/gpu/attention_layer.norm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm.h"
#include "paddle/phi/kernels/fusion/gpu/attn_gemm_int8.h"
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#include "paddle/phi/kernels/fusion/gpu/fused_multi_transformer_op.cu.h"

/*
Note(Zhengzekang):
@@ -22,14 +22,6 @@ limitations under the License. */
#include <fstream>
#include <iomanip>

#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/distributed/collective/process_group_nccl.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device/gpu/nccl_helper.h"
#include "paddle/phi/core/distributed/nccl_comm_context.h"
COMMON_DECLARE_bool(dynamic_static_unified_comm);
#endif

#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/kernels/funcs/load_store_util.h"
#include "paddle/phi/kernels/fusion/gpu/fused_bias_act_utils.h"
@@ -38,41 +30,6 @@ COMMON_DECLARE_bool(dynamic_static_unified_comm);
namespace phi {
namespace fusion {

template <typename T>
static void AllReduce(phi::DenseTensor &tensor, // NOLINT
const int ring_id,
const int count,
const phi::GPUContext &ctx) {
if (ring_id == -1) return;
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto map = paddle::distributed::ProcessGroupMapFromGid::getInstance();

if (map->has(ring_id)) {
paddle::distributed::ProcessGroup *pg = map->get(ring_id);
paddle::distributed::AllreduceOptions opts;
opts.reduce_op = distributed::ReduceOp::SUM;
auto task = pg->AllReduce(&tensor, tensor, opts, false, true);
task->Wait();
} else {
auto dtype = paddle::platform::ToNCCLDataType(
paddle::framework::TransToProtoVarType(tensor.dtype()));
int64_t numel = tensor.numel();
const void *sendbuff = tensor.data<T>();
auto place = ctx.GetPlace();
void *recvbuff = tensor.mutable_data<T>(place);
auto comm =
paddle::platform::NCCLCommContext::Instance().Get(ring_id, place);
auto stream = ctx.stream();
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::ncclAllReduce(
sendbuff, recvbuff, count, dtype, ncclSum, comm->comm(), stream));
}
#else
PADDLE_THROW(phi::errors::Unimplemented(
"PaddlePaddle should compile with NCCL or RCCL when used tensor model "
"parallel op."));
#endif
}

namespace { // NOLINT

using float16 = phi::dtype::float16;
21 changes: 11 additions & 10 deletions paddle/phi/kernels/fusion/gpu/mmha_util.cu.h
@@ -2503,13 +2503,13 @@ inline __device__ void zero(T& dst) { // NOLINT
dst = tmp.raw;
}

template <int WARPS_PER_BLOCK, int WARP_SIZE = 32>
template <int WARPS_PER_BLOCK, int WARP_SIZE_T = 32>
inline __device__ float block_sum(float* red_smem, float sum) {
int warp = threadIdx.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
int warp = threadIdx.x / WARP_SIZE_T;
int lane = threadIdx.x % WARP_SIZE_T;

#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
for (int mask = WARP_SIZE_T / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
}
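The shuffle loop above (identical in the old WARP_SIZE and the renamed WARP_SIZE_T spelling) is a standard XOR-butterfly warp reduction: with a warp size of 32 the mask takes the values 16, 8, 4, 2, 1, so after log2(32) = 5 exchanges every lane holds the sum over the whole warp, and the remainder of block_sum only has to combine one partial per warp through red_smem. A self-contained sketch of the same pattern, not taken from this patch (WarpAllSum is an illustrative name):

// Illustrative only: every lane of the calling warp returns the warp-wide sum.
__device__ float WarpAllSum(float v) {
#pragma unroll
  for (int mask = 16; mask >= 1; mask /= 2) {
    // Lane i exchanges with lane i ^ mask; five steps reach all 32 lanes.
    v += __shfl_xor_sync(0xffffffffu, v, mask);
  }
  return v;
}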

@@ -2830,7 +2830,7 @@ struct Qk_dot<float16, 4> {
}
};

constexpr int32_t WARP_SIZE = 32;
constexpr int32_t WARP_SIZE_TMP = 32;
constexpr int32_t HALF_WARP = 16;
constexpr float QUANT_MAX_BOUND = 127.0;
constexpr float QUANT_MIN_BOUND = -127.0;
@@ -2920,16 +2920,17 @@ template <typename T>
__inline__ __device__ T WarpReduceAbsMax(T val, unsigned lane_mask) {
#pragma unroll
for (int mask = HALF_WARP; mask > 0; mask >>= 1) {
val = MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE));
val =
MaxFunc<T>()(val, __shfl_xor_sync(lane_mask, val, mask, WARP_SIZE_TMP));
}
return val;
}

template <typename T>
__inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) {
static __shared__ T smem[WARP_SIZE];
int32_t lane_id = threadIdx.x % WARP_SIZE;
int32_t warp_id = threadIdx.x / WARP_SIZE;
static __shared__ T smem[WARP_SIZE_TMP];
int32_t lane_id = threadIdx.x % WARP_SIZE_TMP;
int32_t warp_id = threadIdx.x / WARP_SIZE_TMP;

val = WarpReduceAbsMax(val, mask);

@@ -2939,7 +2940,7 @@ __inline__ __device__ T BlockReduceAbsMax(T val, unsigned mask) {

__syncthreads();

T abs_max_val = (threadIdx.x < (blockDim.x / WARP_SIZE))
T abs_max_val = (threadIdx.x < (blockDim.x / WARP_SIZE_TMP))
? smem[threadIdx.x]
: static_cast<T>(0.0f);
abs_max_val = WarpReduceAbsMax(abs_max_val, mask);
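WarpReduceAbsMax and BlockReduceAbsMax above form the usual two-level reduction: an XOR-shuffle abs-max within each warp, then one value per warp folded through the smem buffer. Given the QUANT_MAX_BOUND of 127.0 defined earlier in this header, a typical consumer of such a block-wide abs-max is an int8 quantization scale; the sketch below is one hypothetical way to wire that up (the kernel name RowAbsMaxScale, the one-block-per-row launch, and the scale formula are assumptions, not code from this commit):

// Illustrative kernel, one block per row: reduce |x| over the row and emit a
// symmetric int8 scale. Assumes BlockReduceAbsMax<float> is visible as in
// mmha_util.cu.h.
__global__ void RowAbsMaxScale(const float* in, float* scale_out, int cols) {
  const float* row = in + static_cast<int64_t>(blockIdx.x) * cols;
  float local_max = 0.0f;
  for (int i = threadIdx.x; i < cols; i += blockDim.x) {
    local_max = fmaxf(local_max, fabsf(row[i]));
  }
  // Full-warp participation mask; the reduced value is valid on thread 0.
  float row_max = BlockReduceAbsMax<float>(local_max, 0xffffffffu);
  if (threadIdx.x == 0) {
    // 127.0f matches QUANT_MAX_BOUND; guard against an all-zero row.
    scale_out[blockIdx.x] = row_max > 0.0f ? row_max / 127.0f : 1.0f;
  }
}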
