[DCU] fix bugs and support some fused ops #63217

Merged: 13 commits, Apr 12, 2024
2 changes: 1 addition & 1 deletion paddle/phi/core/visit_type.h
@@ -355,7 +355,7 @@ namespace phi {
"`"); \
} \
}()
#if defined(PADDLE_WITH_XPU) || defined(PADDLE_WITH_HIP)
#if defined(PADDLE_WITH_XPU)
#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
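This hunk drops PADDLE_WITH_HIP from the reduced-type branch, so ROCm builds now fall through to the default PD_VISIT_ALL_TYPES definition with the full dtype list. As context, a minimal sketch of how this kind of dtype visitor dispatches; DataType, VISIT_CASE, and VISIT_ALL_TYPES_SKETCH are illustrative stand-ins, not Paddle's actual definitions:

#include <stdexcept>

// Illustrative stand-in for phi::DataType.
enum class DataType { FLOAT32, FLOAT64 };

// Each case binds the matching C++ type to `data_t`, then runs the body.
#define VISIT_CASE(dtype_enum, ctype, ...) \
  case dtype_enum: {                       \
    using data_t = ctype;                  \
    __VA_ARGS__();                         \
    break;                                 \
  }

#define VISIT_ALL_TYPES_SKETCH(TYPE, NAME, ...)                   \
  [&] {                                                           \
    const auto& __dtype__ = TYPE;                                 \
    switch (__dtype__) {                                          \
      VISIT_CASE(DataType::FLOAT32, float, __VA_ARGS__)           \
      VISIT_CASE(DataType::FLOAT64, double, __VA_ARGS__)          \
      default:                                                    \
        throw std::runtime_error("unsupported dtype for " #NAME); \
    }                                                             \
  }()

A caller writes VISIT_ALL_TYPES_SKETCH(dtype, my_op, ([&] { Run<data_t>(); })); a platform-specific variant restricts support simply by listing fewer cases.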
2 changes: 0 additions & 2 deletions paddle/phi/kernels/CMakeLists.txt
@@ -209,11 +209,9 @@ if(WITH_ROCM)
"gpu/lu_kernel.cu"
"gpu/matrix_rank_kernel.cu"
"gpu/matrix_rank_tol_kernel.cu"
"gpu/multiclass_nms3_kernel.cu"
"gpu/put_along_axis_grad_kernel.cu"
"gpu/put_along_axis_kernel.cu"
"gpu/qr_kernel.cu"
"gpu/rms_norm_grad_kernel.cu"
"gpu/svd_kernel.cu"
"gpudnn/mha_cudnn_frontend.cu"
"fusion/gpu/block_multi_head_attention_kernel.cu"
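Deleting these two entries removes the kernels from the set of sources that the ROCm build skips, so they now compile under HIP. This is the build-system side of enabling these ops on DCU.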
24 changes: 21 additions & 3 deletions paddle/phi/kernels/funcs/layer_norm_impl.cu.h
@@ -166,14 +166,14 @@ __inline__ __device__ double rsqrt_(const double val) {
return ::rsqrt(val);
}

#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) || defined(PADDLE_WITH_HIP)
template <>
__inline__ __device__ half rsqrt_(const half val) {
return hrsqrt(val);
}
#endif

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename T,
typename U,
typename ScaleT = U,
@@ -254,7 +254,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(

#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
#ifdef PADDLE_WITH_HIP
mu_local += __shfl_xor(mu_local, it);
#else
mu_local += __shfl_xor_sync(uint32_t(-1), mu_local, it);
#endif
}
if (WARPS_N > 1) {
if (lane == 0) {
@@ -290,7 +294,11 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fast_ln_fwd_kernel(

#pragma unroll
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
#ifdef PADDLE_WITH_HIP
var_local += __shfl_xor(var_local, it);
#else
var_local += __shfl_xor_sync(uint32_t(-1), var_local, it);
#endif
}

if (WARPS_N > 1) {
@@ -546,7 +554,7 @@ __inline__ __device__ void cuLoadAddStridedInputs(const int64_t i1_block,
}
}

#ifdef PADDLE_WITH_CUDA
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <bool IsFusedDropoutResidualLn,
bool NeedDDropoutSrcPtr,
typename T,
@@ -678,16 +686,26 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void fused_ln_bwd_fast_kernel(
#pragma unroll
// row reduction among 32 threads.
for (int it = 1; it < THREADS_PER_WARP; it *= 2) {
#ifdef PADDLE_WITH_HIP
sum_loss1 += __shfl_xor(sum_loss1, it);
sum_loss2 += __shfl_xor(sum_loss2, it);
#else
sum_loss1 += __shfl_xor_sync(uint32_t(-1), sum_loss1, it);
sum_loss2 += __shfl_xor_sync(uint32_t(-1), sum_loss2, it);
#endif
}
sum_loss1 *= rn;
sum_loss2 *= rn;
} else {
#pragma unroll
for (int it = 16; it > 0; it /= 2) {
#ifdef PADDLE_WITH_HIP
sum_loss1 += __shfl_down(sum_loss1, it);
sum_loss2 += __shfl_down(sum_loss2, it);
#else
sum_loss1 += __shfl_down_sync(uint32_t(-1), sum_loss1, it);
sum_loss2 += __shfl_down_sync(uint32_t(-1), sum_loss2, it);
#endif
}

if (lane == 0) {
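The shuffle hunks above are the core of the HIP port: CUDA's warp shuffles take an explicit participation mask (__shfl_xor_sync, __shfl_down_sync), while HIP exposes the maskless legacy forms (__shfl_xor, __shfl_down). A minimal hedged sketch of the portable butterfly sum used in the forward kernel; WarpReduceSum and kWarpSize are illustrative names, not Paddle's:

// Sketch of the portable warp-level "butterfly" sum. kWarpSize should match
// the hardware warp width (32 on NVIDIA; AMD wavefronts are 64 lanes wide).
template <int kWarpSize, typename T>
__device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int it = 1; it < kWarpSize; it *= 2) {
#ifdef PADDLE_WITH_HIP
    val += __shfl_xor(val, it);                     // HIP: maskless form
#else
    val += __shfl_xor_sync(uint32_t(-1), val, it);  // CUDA: full-warp mask
#endif
  }
  return val;  // the XOR butterfly leaves the sum in every lane
}

The backward kernel's __shfl_down variant instead funnels the sum into lane 0, which is why it is followed by an if (lane == 0) write.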
@@ -11,7 +11,12 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_WITH_HIP
#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cuda_fp16.h>
#include <cub/cub.cuh>
#endif
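Mapping hipCUB onto the cub namespace means the code below can keep using CUB's API unchanged under ROCm. A minimal hedged sketch of what the alias enables; BlockSumKernel and BLOCK_THREADS are illustrative, not part of this file:

#ifdef PADDLE_WITH_HIP
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#else
#include <cub/cub.cuh>
#endif

// A block-wide sum written once against cub::, compiling on both CUDA and
// ROCm thanks to the namespace alias.
template <int BLOCK_THREADS>
__global__ void BlockSumKernel(const float* in, float* out) {
  using BlockReduce = cub::BlockReduce<float, BLOCK_THREADS>;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  float thread_val = in[blockIdx.x * BLOCK_THREADS + threadIdx.x];
  // Only thread 0 receives the valid block-wide aggregate.
  float block_sum = BlockReduce(temp_storage).Sum(thread_val);

  if (threadIdx.x == 0) out[blockIdx.x] = block_sum;
}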
@@ -21,9 +26,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#endif

namespace phi {
namespace fusion {
@@ -51,7 +54,6 @@ void FusedBiasDropoutResidualLnGradKernel(
DenseTensor* bias_grad,
DenseTensor* ln_scale_grad,
DenseTensor* ln_bias_grad) {
#ifndef PADDLE_WITH_HIP
using U = LayerNormParamType<T>;
auto* d_y_data = y_grad.data<T>();
auto* ln_scale_data =
@@ -114,19 +116,24 @@
d_x_data,
d_bias_data,
d_residual_data);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FusedBiasDropoutResidualLnGradKernel not surpport for rocm"));
#endif
}

} // namespace fusion
} // namespace phi

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedBiasDropoutResidualLnGradKernel,
float,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm_grad,
GPU,
ALL_LAYOUT,
phi::fusion::FusedBiasDropoutResidualLnGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
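Note that the two registrations differ only in their dtype lists: the ROCm build registers float and phi::dtype::float16, while the CUDA build also registers double.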
@@ -17,9 +17,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
#ifndef PADDLE_WITH_HIP
#include "paddle/phi/kernels/fusion/gpu/fused_dropout_helper.h"
#endif

namespace phi {
namespace fusion {
@@ -42,7 +40,6 @@ void FusedBiasDropoutResidualLnKernel(
DenseTensor* dropout_mask_out,
DenseTensor* ln_mean,
DenseTensor* ln_variance) {
#ifndef PADDLE_WITH_HIP
using U = phi::funcs::LayerNormParamType<T>;
auto* x_data = x.data<T>();
auto* bias_data = (bias.get_ptr() == nullptr) ? nullptr : bias->data<T>();
@@ -95,14 +92,20 @@ void FusedBiasDropoutResidualLnKernel(
y_data,
ln_mean_data,
ln_var_data);
#else
PADDLE_THROW(phi::errors::Unimplemented(
"FusedBiasDropoutResidualLnKernel not support for rocm"));
#endif
}
} // namespace fusion
} // namespace phi

#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
GPU,
ALL_LAYOUT,
phi::fusion::FusedBiasDropoutResidualLnKernel,
float,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#else
PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
GPU,
ALL_LAYOUT,
@@ -112,3 +115,4 @@ PD_REGISTER_KERNEL(fused_bias_dropout_residual_layer_norm,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UINT8);
}
#endif
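The forward kernel mirrors the gradient kernel's registration: the ROCm variant drops double but keeps the UINT8 dtype override for the dropout-mask output.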
8 changes: 6 additions & 2 deletions paddle/phi/kernels/fusion/gpu/fused_dropout_act_bias.h
@@ -35,7 +35,11 @@ struct GeluFunctor {
template <typename T>
struct FastGeluFunctor {
inline __device__ T operator()(const T x) const {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE(0, "FastGelu not surpport for rocm");
Contributor review comment: CI reports an error here; the suggested fix is:

PADDLE_THROW(phi::errors::Unimplemented("ROCM does not support FastGelu"));

#else
return phi::GeluFwd<T, true>(x);
#endif
}
};

@@ -92,8 +96,8 @@ __global__ void FusedDropoutActBias(
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;

curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
GPURAND(StatePhilox4_32_10_t) state;
GPURAND(_init)(seed, idx, increment, &state);

const T factor =
phi::fusion::GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
39 changes: 29 additions & 10 deletions paddle/phi/kernels/fusion/gpu/fused_dropout_common.h
@@ -20,10 +20,25 @@ limitations under the License. */
#include <curand_kernel.h>
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_fp16.h>
#include <hip/hip_runtime.h>
#include <hiprand.h>
#include <hiprand_kernel.h>
#endif

#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"

#ifdef PADDLE_WITH_HIP
#define GPU(str) hip##str
#define GPURAND(str) hiprand##str
#else
#define GPU(str) cuda##str
#define GPURAND(str) curand##str
#endif
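Together, GPU and GPURAND rename a whole API family at its prefix: GPU(MemsetAsync) expands to cudaMemsetAsync or hipMemsetAsync, GPURAND(StatePhilox4_32_10_t) to curandStatePhilox4_32_10_t or hiprandStatePhilox4_32_10_t, and GPURAND(_init) to curand_init or hiprand_init, which is what the FusedDropoutActBias hunk above relies on. A small hedged sketch in the same style; ZeroDeviceBuffer is an illustrative helper, not part of this PR:

// GPU(Stream_t) expands to cudaStream_t on CUDA builds and hipStream_t on
// ROCm builds; the same trick covers the Memset call below.
template <typename T>
void ZeroDeviceBuffer(T* ptr, size_t count, GPU(Stream_t) stream) {
  PADDLE_ENFORCE_GPU_SUCCESS(
      GPU(MemsetAsync)(ptr, 0, count * sizeof(T), stream));
}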

namespace phi {
namespace fusion {

@@ -63,34 +78,38 @@ inline phi::backends::gpu::GpuLaunchConfig Get1DBlocksAnd2DGrids(
}

template <int VecSize>
__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
__forceinline__ __device__ void RandVec(GPURAND(StatePhilox4_32_10_t) * state,
float *data);

template <>
__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state,
__forceinline__ __device__ void RandVec<1>(GPURAND(StatePhilox4_32_10_t) *
state,
float *data) {
data[0] = curand_uniform(state);
data[0] = GPURAND(_uniform)(state);
}

template <>
__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state,
__forceinline__ __device__ void RandVec<2>(GPURAND(StatePhilox4_32_10_t) *
state,
float *data) {
data[0] = curand_uniform(state);
data[1] = curand_uniform(state);
data[0] = GPURAND(_uniform)(state);
data[1] = GPURAND(_uniform)(state);
}

template <>
__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
__forceinline__ __device__ void RandVec<4>(GPURAND(StatePhilox4_32_10_t) *
state,
float *data) {
float4 rand4 = curand_uniform4(state);
float4 rand4 = GPURAND(_uniform4)(state);
data[0] = rand4.x;
data[1] = rand4.y;
data[2] = rand4.w;
data[3] = rand4.z;
}

template <>
__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
__forceinline__ __device__ void RandVec<8>(GPURAND(StatePhilox4_32_10_t) *
state,
float *data) {
RandVec<4>(state, data);
RandVec<4>(state, data + 4);
@@ -99,7 +118,7 @@ __forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
template <typename T>
inline void SetZero(const phi::GPUContext &ctx, T *ptr, const size_t size) {
PADDLE_ENFORCE_GPU_SUCCESS(
cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream()));
GPU(MemsetAsync)(ptr, 0, size * sizeof(T), ctx.stream()));
}
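To see how these pieces compose, here is a hedged sketch of a vectorized dropout-mask kernel built on RandVec<4> and the GPURAND macros; VecDropoutMask and its signature are illustrative, not code from this PR:

// Illustrative only: draw four uniforms per thread with RandVec<4> and turn
// them into keep/drop bytes. Assumes n is a multiple of 4 for brevity.
__global__ void VecDropoutMask(uint64_t seed, uint64_t increment,
                               float dropout_prob, uint8_t* mask, int n) {
  int idx = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
  if (idx >= n) return;

  GPURAND(StatePhilox4_32_10_t) state;
  GPURAND(_init)(seed, idx, increment, &state);

  float rands[4];
  RandVec<4>(&state, rands);  // fills rands[0..3] with uniforms in (0, 1]

#pragma unroll
  for (int i = 0; i < 4; ++i) {
    mask[idx + i] = rands[i] >= dropout_prob ? 1 : 0;
  }
}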
