Skip to content

Commit

Permalink
softmax fwd: force vec size to 1 when dtype is float (PaddlePaddle#54183)
Browse files — Browse the repository at this point in the history

* softmax fwd: force vec size to 1 when dtype is float

* use 1024 as threshold to use cudnn

shaojiewang authored May 30, 2023 (branch information loaded)
1 parent 44bd592 commit f5a3b42
Showing 1 changed file with 36 additions and 22 deletions.
58 changes: 36 additions & 22 deletions paddle/phi/kernels/gpudnn/softmax_gpudnn.h
Original file line number Diff line number Diff line change
Expand Up @@ -798,6 +798,7 @@ void SwitchWarpSoftmaxForward(const IndexType blocks,
SOFTMAX_WARP_FORWARD_CASE(7, AccT);
SOFTMAX_WARP_FORWARD_CASE(8, AccT);
SOFTMAX_WARP_FORWARD_CASE(9, AccT);
SOFTMAX_WARP_FORWARD_CASE(10, AccT);
default:
break;
}
Expand Down Expand Up @@ -836,6 +837,7 @@ void SwitchWarpSoftmaxBackward(const int blocks,
SOFTMAX_WARP_BACKWARD_CASE(7, AccT);
SOFTMAX_WARP_BACKWARD_CASE(8, AccT);
SOFTMAX_WARP_BACKWARD_CASE(9, AccT);
SOFTMAX_WARP_BACKWARD_CASE(10, AccT);
default:
break;
}
Expand Down Expand Up @@ -1262,7 +1264,7 @@ bool UseCudnnSoftmax(const GPUContext& ctx,
#endif
}
}
constexpr int max_dim = 512;
constexpr int max_dim = 1024;
if (!cudnn_available || !last_dim ||
(softmax_dim <= max_dim && sizeof(T) <= 4) ||
softmax_dim >= MATRIX_SOFTMAX_THREAHOLD) {
Expand Down Expand Up @@ -1311,27 +1313,7 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
using T4 = typename VecT4<T>::Type;
using T2 = typename VecT2<T>::Type;

if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
if (std::is_same<T, float>::value) {
SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
threads,
dev_ctx,
Expand All @@ -1341,6 +1323,38 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx,
dim,
dim,
dim_log2);
} else {
if (dim % 4 == 0) {
SwitchWarpSoftmaxForward<T, T4, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else if (dim % 2 == 0) {
SwitchWarpSoftmaxForward<T, T2, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
} else {
SwitchWarpSoftmaxForward<T, T, IndexType, LogMode>(blocks,
threads,
dev_ctx,
out_data,
x.data<T>(),
N,
dim,
dim,
dim_log2);
}
}
} else {
LaunchSoftmaxForwardCudnnKernel<T>(dev_ctx, x, axis, LogMode, out);
Expand Down

0 comments on commit f5a3b42

Please sign in to comment.