Skip to content

Commit

Permalink
fused_ln: Added implementation for the HIP platform (#8472)
Browse files Browse the repository at this point in the history
  • Loading branch information
asr-sheep1 authored Jun 4, 2024
1 parent cb8a229 commit 4c9a3f5
Show file tree
Hide file tree
Showing 2 changed files with 209 additions and 66 deletions.
137 changes: 126 additions & 11 deletions model_zoo/gpt-3/external_ops/fused_ln/layer_norm_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,12 @@

#pragma once // NOLINT

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#else
#include <cuda.h> // NOLINT
#include <cuda_runtime.h> // NOLINT

#endif
#include "paddle/extension.h"

#define DEFAULT_THROW(NAME, TYPE) \
Expand Down Expand Up @@ -71,22 +74,34 @@
DEFAULT_THROW(NAME, TYPEIN); \
}

#ifdef PADDLE_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif

// Butterfly (XOR) shuffle: each lane exchanges `value` with the lane whose
// index is (laneIdx ^ laneMask) within `width`-wide groups.
// On HIP the mask-less legacy __shfl_xor intrinsic is used and `mask` goes
// unused (wavefronts are 64 lanes, see WARP_SIZE above); on CUDA the
// synchronizing __shfl_xor_sync is used with `mask` (default: full 32-lane
// warp participates).
template <typename T>
__device__ __forceinline__ T WARP_SHFL_XOR(T value,
                                           int laneMask,
                                           int width = WARP_SIZE,
                                           unsigned int mask = 0xffffffff) {
#ifdef PADDLE_WITH_HIP
  return __shfl_xor(value, laneMask, width);
#else
  return __shfl_xor_sync(mask, value, laneMask, width);
#endif
}

// Reads `value` from lane `srcLane` within `width`-wide groups of the
// warp/wavefront (broadcast-style shuffle). On HIP the legacy mask-less
// __shfl intrinsic is used and `mask` is ignored; on CUDA the synchronizing
// __shfl_sync is used with `mask` (default 0xffffffff, i.e. all 32 lanes —
// NOTE(review): HIP wavefronts are 64 lanes, see WARP_SIZE above, so the
// default `mask` is only meaningful on the CUDA path).
template <typename T>
__device__ __forceinline__ T WARP_SHFL(T value,
                                       int srcLane,
                                       int width = WARP_SIZE,
                                       unsigned int mask = 0xffffffff) {
#ifdef PADDLE_WITH_HIP
  return __shfl(value, srcLane, width);
#else
  return __shfl_sync(mask, value, srcLane, width);
#endif
}

template <typename U>
Expand Down Expand Up @@ -181,8 +196,17 @@ __device__ void cuWelfordMuSigma2(const T* __restrict__ vals,
}
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
#ifdef PADDLE_WITH_HIP
for (int l = 0; l <= 5; ++l)
#else
for (int l = 0; l <= 4; ++l)
#endif
{
#ifdef PADDLE_WITH_HIP
int srcLaneB = (threadIdx.x + (1 << l)) & 63;
#else
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
#endif
U sigma2B = WARP_SHFL(sigma2, srcLaneB);
if (!rms_only) {
U muB = WARP_SHFL(mu, srcLaneB);
Expand Down Expand Up @@ -306,8 +330,17 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals,
}
}
// intra-warp reductions
for (int l = 0; l <= 4; ++l) {
#ifdef PADDLE_WITH_HIP
for (int l = 0; l <= 5; ++l)
#else
for (int l = 0; l <= 4; ++l)
#endif
{
#ifdef PADDLE_WITH_HIP
int srcLaneB = (threadIdx.x + (1 << l)) & 63;
#else
int srcLaneB = (threadIdx.x + (1 << l)) & 31;
#endif
float sigma2B = WARP_SHFL(sigma2, srcLaneB);
if (!rms_only) {
float muB = WARP_SHFL(mu, srcLaneB);
Expand Down Expand Up @@ -369,15 +402,15 @@ __device__ void cuWelfordMuSigma2(const phi::dtype::float16* __restrict__ vals,
}
}

// Reciprocal square root helpers: rsqrt(v) == 1/sqrt(v), used on the device
// to turn a variance into an inverse standard deviation.
// (Stale pre-diff header lines without the __device__ qualifier were
// interleaved here by the diff view and have been removed — with both copies
// present the block is ill-formed.)
// Generic fallback for types without a dedicated intrinsic.
template <typename U> __device__
U rsqrt(U v) {
  return U(1) / sqrt(v);
}
// float: use the fast hardware rsqrtf intrinsic.
template <> __device__
float rsqrt(float v) {
  return rsqrtf(v);
}
// double: forwards to the device math library's non-template rsqrt(double)
// overload, which overload resolution prefers over this template — so the
// call is presumably not self-recursive. NOTE(review): confirm against the
// CUDA/HIP math headers.
template <> __device__
double rsqrt(double v) {
  return rsqrt(v);
}
Expand Down Expand Up @@ -914,6 +947,22 @@ __global__ void cuComputeGradInput(const V* __restrict__ dout,
}
}

#ifdef PADDLE_WITH_HIP
// Queries the HIP runtime for the properties of the currently selected
// device. Any runtime failure aborts through PD_CHECK.
static hipDeviceProp_t GetDevicePropImpl() {
  int device_id = -1;
  PD_CHECK(hipGetDevice(&device_id) == hipSuccess);
  hipDeviceProp_t device_prop;
  PD_CHECK(hipGetDeviceProperties(&device_prop, device_id) == hipSuccess);
  return device_prop;
}

// Returns a pointer to a process-wide cached copy of the current device's
// properties; the underlying runtime query runs only once thanks to the
// function-local static.
static hipDeviceProp_t* GetDeviceProp() {
  static hipDeviceProp_t cached = GetDevicePropImpl();
  return &cached;
}

#else

static cudaDeviceProp GetDevicePropImpl() {
int device = -1;
PD_CHECK(cudaGetDevice(&device) == cudaSuccess);
Expand All @@ -926,8 +975,10 @@ static cudaDeviceProp* GetDeviceProp() {
static auto prop = GetDevicePropImpl();
return &prop;
}
#endif

template <typename T, typename U, typename V>
#ifdef PADDLE_WITH_HIP
void HostApplyLayerNorm(V* output,
U* mean,
U* invvar,
Expand All @@ -937,8 +988,25 @@ void HostApplyLayerNorm(V* output,
double epsilon,
const V* gamma,
const V* beta,
cudaStream_t stream) {
hipStream_t stream)
#else
void HostApplyLayerNorm(V* output,
U* mean,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
const V* beta,
cudaStream_t stream)
#endif
{
#ifdef PADDLE_WITH_HIP
const dim3 threads(64, 4, 1);
#else
const dim3 threads(32, 4, 1);
#endif
const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1];
const dim3 blocks(1, std::min((uint64_t)n1, maxGridY), 1);
int nshared =
Expand All @@ -948,16 +1016,32 @@ void HostApplyLayerNorm(V* output,
}

template <typename T, typename U, typename V = T>
#ifdef PADDLE_WITH_HIP
void HostApplyRMSNorm(V* output,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
hipStream_t stream)
#else
void HostApplyRMSNorm(V* output,
U* invvar,
const T* input,
int n1,
int n2,
double epsilon,
const V* gamma,
cudaStream_t stream) {
cudaStream_t stream)
#endif
{
// auto stream = at::cuda::getCurrentCUDAStream().stream();
#ifdef PADDLE_WITH_HIP
const dim3 threads(64, 4, 1);
#else
const dim3 threads(32, 4, 1);
#endif
// const uint64_t maxGridY =
// at::cuda::getCurrentDeviceProperties()->maxGridSize[1];
const uint64_t maxGridY = GetDeviceProp()->maxGridSize[1];
Expand Down Expand Up @@ -1015,6 +1099,7 @@ static void cuda_rms_norm(const paddle::Tensor& x,
}

template <typename T, typename U, typename V>
#ifdef PADDLE_WITH_HIP
void HostLayerNormGradient(const V* dout,
const U* mean,
const U* invvar,
Expand All @@ -1027,7 +1112,23 @@ void HostLayerNormGradient(const V* dout,
T* grad_input,
V* grad_gamma,
V* grad_beta,
cudaStream_t stream) {
hipStream_t stream)
#else
void HostLayerNormGradient(const V* dout,
const U* mean,
const U* invvar,
const paddle::Tensor& input,
int n1,
int n2,
const V* gamma,
const V* beta,
double epsilon,
T* grad_input,
V* grad_gamma,
V* grad_beta,
cudaStream_t stream)
#endif
{
if (gamma != NULL && beta != NULL) {
// compute grad_gamma(j) and grad_beta(j)
const int part_size = 16;
Expand Down Expand Up @@ -1085,6 +1186,18 @@ void HostLayerNormGradient(const V* dout,
}

template <typename T, typename U, typename V>
#ifdef PADDLE_WITH_HIP
void HostRMSNormGradient(const V* dout,
const U* invvar,
const paddle::Tensor& input,
int n1,
int n2,
const V* gamma,
double epsilon,
T* grad_input,
V* grad_gamma,
hipStream_t stream)
#else
void HostRMSNormGradient(const V* dout,
const U* invvar,
const paddle::Tensor& input,
Expand All @@ -1094,7 +1207,9 @@ void HostRMSNormGradient(const V* dout,
double epsilon,
T* grad_input,
V* grad_gamma,
cudaStream_t stream) {
cudaStream_t stream)
#endif
{
if (gamma != NULL) {
const int part_size = 16;
const dim3 threads2(32, 4, 1);
Expand Down
Loading

0 comments on commit 4c9a3f5

Please sign in to comment.