Add cuDNN 9.0 (#62498)
* fix cuDNN 9 problem

* remove glog
jeng1220 authored Mar 11, 2024
1 parent f512028 commit c8e8be2
Showing 11 changed files with 492 additions and 41 deletions.
82 changes: 77 additions & 5 deletions paddle/fluid/operators/cudnn_rnn_cache.h
@@ -30,8 +30,13 @@ struct CudnnRNNCache {
~CudnnRNNCache() { release(); }

cudnnRNNDescriptor_t rnn_desc_;
#if CUDNN_VERSION >= 90000
cudnnRNNDataDescriptor_t x_desc_;
cudnnRNNDataDescriptor_t y_desc_;
#else
cudnnTensorDescriptor_t *x_desc_;
cudnnTensorDescriptor_t *y_desc_;
#endif

cudnnTensorDescriptor_t hx_desc_;
cudnnTensorDescriptor_t cx_desc_;
@@ -93,7 +98,37 @@ struct CudnnRNNCache {
const auto numDirections = is_bidirec_ ? 2 : 1;
auto cudnn_size =
cudnn_type == CUDNN_DATA_FLOAT ? sizeof(float) : sizeof(double);
#if CUDNN_VERSION >= 90000
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateRNNDataDescriptor(&x_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateRNNDataDescriptor(&y_desc_));

std::vector<int> seq_length_array(batch_size_);
for (int i = 0; i < batch_size_; ++i) {
seq_length_array[i] = seq_length_;
}

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDataDescriptor(
x_desc_,
cudnn_type,
CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
seq_length_,
batch_size_,
input_size_,
reinterpret_cast<const int *>(seq_length_array.data()),
nullptr));

PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDataDescriptor(
y_desc_,
cudnn_type,
CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
seq_length_,
batch_size_,
hidden_size_ * numDirections,
reinterpret_cast<const int *>(seq_length_array.data()),
nullptr));
#else
x_desc_ = new cudnnTensorDescriptor_t[seq_length_];
y_desc_ = new cudnnTensorDescriptor_t[seq_length_];
std::vector<int> dims = {batch_size_, input_size_, 1};
@@ -114,6 +149,7 @@ struct CudnnRNNCache {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetTensorNdDescriptor(
y_desc_[i], cudnn_type, 3, dims_y.data(), strides_y.data()));
}
#endif

std::vector<int> dims_hx = {
num_layers_ * numDirections, batch_size_, hidden_size_};
@@ -185,7 +221,24 @@

PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateRNNDescriptor(&rnn_desc_));

#if CUDNN_VERSION >= 90000
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v8(
rnn_desc_,
CUDNN_RNN_ALGO_STANDARD,
CUDNN_LSTM,
CUDNN_RNN_DOUBLE_BIAS,
is_bidirec_ ? CUDNN_BIDIRECTIONAL : CUDNN_UNIDIRECTIONAL,
CUDNN_LINEAR_INPUT,
cudnn_type,
cudnn_type,
CUDNN_DEFAULT_MATH,
input_size_,
hidden_size_,
hidden_size_,
num_layers_,
dropout_desc_,
CUDNN_RNN_PADDED_IO_ENABLED));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetRNNDescriptor_v6(
handle,
rnn_desc_,
@@ -197,15 +250,19 @@
CUDNN_LSTM,
CUDNN_RNN_ALGO_STANDARD,
cudnn_type));

#endif
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateFilterDescriptor(&w_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnCreateFilterDescriptor(&dw_desc_));

#if CUDNN_VERSION >= 90000
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNWeightSpaceSize(
handle, rnn_desc_, &weights_size_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNParamsSize(
handle, rnn_desc_, x_desc_[0], &weights_size_, cudnn_type));

#endif
PADDLE_ENFORCE_EQ(
weights_size_,
cudnn_size * weight_numel,
@@ -220,18 +277,32 @@
w_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSetFilterNdDescriptor(
dw_desc_, cudnn_type, CUDNN_TENSOR_NCHW, 3, dim_w));

#if CUDNN_VERSION >= 90000
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetRNNTempSpaceSizes(handle,
rnn_desc_,
CUDNN_FWD_MODE_TRAINING,
x_desc_,
&workspace_size_,
reserve_size_));
#else
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnGetRNNWorkspaceSize(
handle, rnn_desc_, seq_length_, x_desc_, &workspace_size_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnGetRNNTrainingReserveSize(
handle, rnn_desc_, seq_length_, x_desc_, reserve_size_));

#endif
workspace_data_.Resize({static_cast<int64_t>(workspace_size_)});
workspace_data_.mutable_data<uint8_t>(place);
}

void release() {
#if CUDNN_VERSION >= 90000
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyRNNDataDescriptor(x_desc_));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyRNNDataDescriptor(y_desc_));
#else
for (size_t i = 0; i < seq_length_; ++i) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(x_desc_[i]));
@@ -241,6 +312,7 @@

delete[] x_desc_;
delete[] y_desc_;
#endif

PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cudnnDestroyTensorDescriptor(hx_desc_));
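Note (not part of the commit): below is a minimal, self-contained sketch of the cuDNN 9 path that the CUDNN_VERSION >= 90000 branches above adopt, calling the cuDNN API directly rather than through Paddle's platform::dynload wrappers. The tensor sizes, the zero-dropout setup, and the CHECK_CUDNN macro are illustrative assumptions, not code from this repository.

#include <cstdio>
#include <cstdlib>
#include <vector>
#include <cudnn.h>

#define CHECK_CUDNN(expr)                                                      \
  do {                                                                         \
    cudnnStatus_t s_ = (expr);                                                 \
    if (s_ != CUDNN_STATUS_SUCCESS) {                                          \
      std::fprintf(stderr, "cuDNN error %d at line %d\n", (int)s_, __LINE__);  \
      std::exit(1);                                                            \
    }                                                                          \
  } while (0)

int main() {
  const int seq_length = 20, batch_size = 8, input_size = 32;
  const int hidden_size = 64, num_layers = 2;

  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  // cudnnRNNDataDescriptor_t replaces the per-timestep tensor descriptor arrays.
  cudnnRNNDataDescriptor_t x_desc, y_desc;
  CHECK_CUDNN(cudnnCreateRNNDataDescriptor(&x_desc));
  CHECK_CUDNN(cudnnCreateRNNDataDescriptor(&y_desc));
  std::vector<int> seq_lengths(batch_size, seq_length);
  CHECK_CUDNN(cudnnSetRNNDataDescriptor(
      x_desc, CUDNN_DATA_FLOAT, CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
      seq_length, batch_size, input_size, seq_lengths.data(),
      /*paddingFill=*/nullptr));
  CHECK_CUDNN(cudnnSetRNNDataDescriptor(
      y_desc, CUDNN_DATA_FLOAT, CUDNN_RNN_DATA_LAYOUT_BATCH_MAJOR_UNPACKED,
      seq_length, batch_size, hidden_size, seq_lengths.data(),
      /*paddingFill=*/nullptr));

  // Zero-probability dropout needs no state buffer.
  cudnnDropoutDescriptor_t dropout_desc;
  CHECK_CUDNN(cudnnCreateDropoutDescriptor(&dropout_desc));
  CHECK_CUDNN(cudnnSetDropoutDescriptor(dropout_desc, handle, 0.0f, nullptr, 0, 0));

  // One cudnnSetRNNDescriptor_v8 call replaces cudnnSetRNNDescriptor_v6.
  cudnnRNNDescriptor_t rnn_desc;
  CHECK_CUDNN(cudnnCreateRNNDescriptor(&rnn_desc));
  CHECK_CUDNN(cudnnSetRNNDescriptor_v8(
      rnn_desc, CUDNN_RNN_ALGO_STANDARD, CUDNN_LSTM, CUDNN_RNN_DOUBLE_BIAS,
      CUDNN_UNIDIRECTIONAL, CUDNN_LINEAR_INPUT, CUDNN_DATA_FLOAT,
      CUDNN_DATA_FLOAT, CUDNN_DEFAULT_MATH, input_size, hidden_size,
      /*projSize=*/hidden_size, num_layers, dropout_desc,
      CUDNN_RNN_PADDED_IO_ENABLED));

  // Size queries that replace cudnnGetRNNParamsSize, cudnnGetRNNWorkspaceSize,
  // and cudnnGetRNNTrainingReserveSize.
  size_t weight_size = 0, workspace_size = 0, reserve_size = 0;
  CHECK_CUDNN(cudnnGetRNNWeightSpaceSize(handle, rnn_desc, &weight_size));
  CHECK_CUDNN(cudnnGetRNNTempSpaceSizes(handle, rnn_desc, CUDNN_FWD_MODE_TRAINING,
                                        x_desc, &workspace_size, &reserve_size));
  std::printf("weights=%zu workspace=%zu reserve=%zu\n",
              weight_size, workspace_size, reserve_size);

  CHECK_CUDNN(cudnnDestroyRNNDescriptor(rnn_desc));
  CHECK_CUDNN(cudnnDestroyDropoutDescriptor(dropout_desc));
  CHECK_CUDNN(cudnnDestroyRNNDataDescriptor(x_desc));
  CHECK_CUDNN(cudnnDestroyRNNDataDescriptor(y_desc));
  CHECK_CUDNN(cudnnDestroy(handle));
  return 0;
}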
12 changes: 12 additions & 0 deletions paddle/fluid/platform/dynload/cudnn.cc
@@ -44,6 +44,18 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP);
#endif

#ifdef CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9
CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DEFINE_WRAP);
#endif

#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9
CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(DEFINE_WRAP);
#endif

#ifdef CUDNN_DNN_ROUTINE_EACH_R9
CUDNN_DNN_ROUTINE_EACH_R9(DEFINE_WRAP);
#endif

bool HasCUDNN() { return phi::dynload::HasCUDNN(); }

} // namespace dynload
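A brief aside on the guard pattern above: CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9, CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9, and CUDNN_DNN_ROUTINE_EACH_R9 are defined (or left undefined) by the matching cudnn.h according to CUDNN_VERSION, so wrapping their use here in #ifdef emits wrapper definitions only for routine lists that exist for the cuDNN being built against. Assuming DEFINE_WRAP keeps its usual shape of defining one wrapper object per routine, the R9 block expands roughly as sketched below (illustrative expansion, not literal compiler output):

// Assuming #define DEFINE_WRAP(__name) DynLoad__##__name __name,
// CUDNN_DNN_ROUTINE_EACH_R9(DEFINE_WRAP); becomes, for CUDNN_VERSION >= 90000:
DynLoad__cudnnGetRNNWeightSpaceSize cudnnGetRNNWeightSpaceSize;
DynLoad__cudnnGetRNNTempSpaceSizes cudnnGetRNNTempSpaceSizes;
DynLoad__cudnnRNNForward cudnnRNNForward;
DynLoad__cudnnRNNBackwardData_v8 cudnnRNNBackwardData_v8;
DynLoad__cudnnRNNBackwardWeights_v8 cudnnRNNBackwardWeights_v8;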
50 changes: 35 additions & 15 deletions paddle/fluid/platform/dynload/cudnn.h
@@ -90,13 +90,6 @@ extern bool HasCUDNN();
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
__macro(cudnnGetRNNTrainingReserveSize); \
__macro(cudnnRNNForwardTraining); \
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx); \
@@ -111,8 +104,7 @@ extern bool HasCUDNN();
__macro(cudnnCreateActivationDescriptor); \
__macro(cudnnSetActivationDescriptor); \
__macro(cudnnGetActivationDescriptor); \
__macro(cudnnDestroyActivationDescriptor); \
__macro(cudnnSetRNNDescriptor_v6);
__macro(cudnnDestroyActivationDescriptor);
CUDNN_DNN_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)

#if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000
@@ -147,12 +139,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \
__macro(cudnnCreateRNNDataDescriptor); \
__macro(cudnnDestroyRNNDataDescriptor); \
__macro(cudnnSetRNNDataDescriptor); \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx);
__macro(cudnnSetRNNDataDescriptor);
CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

@@ -182,6 +169,39 @@ CUDNN_DNN_ROUTINE_EACH_AFTER_R7(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
CUDNN_DNN_ROUTINE_EACH_R8(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

#if CUDNN_VERSION < 90000
#define CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(__macro) \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
__macro(cudnnGetRNNTrainingReserveSize); \
__macro(cudnnSetRNNDescriptor_v6); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnRNNForwardTraining); \
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights);
CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

#if CUDNN_VERSION < 90000 && CUDNN_VERSION >= 7201
#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(__macro) \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNForwardInferenceEx); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx);
CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(
PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

#if CUDNN_VERSION >= 90000
#define CUDNN_DNN_ROUTINE_EACH_R9(__macro) \
__macro(cudnnGetRNNWeightSpaceSize); \
__macro(cudnnGetRNNTempSpaceSizes); \
__macro(cudnnRNNForward); \
__macro(cudnnRNNBackwardData_v8); \
__macro(cudnnRNNBackwardWeights_v8);
CUDNN_DNN_ROUTINE_EACH_R9(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
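Why the v6-era RNN routines have to move out of the unconditional CUDNN_DNN_ROUTINE_EACH list rather than simply staying put: each declared wrapper recovers the function type from the cuDNN header (via decltype(&::name) in the usual dynload pattern), so once the cuDNN 9 headers stop declaring cudnnRNNForwardTraining, cudnnGetRNNWorkspaceSize, and friends, an unguarded wrapper fails to compile even if it is never called. A simplified sketch of that wrapper pattern follows; it is illustrative, not copied from the Paddle macro, and assumes cudnn_dso_flag, cudnn_dso_handle, and GetCUDNNDsoHandle() are in scope as they are in these dynload files.

// Simplified sketch of a lazy dynamic-load wrapper (illustrative only).
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP_SKETCH(__name)                    \
  struct DynLoad__##__name {                                              \
    template <typename... Args>                                           \
    auto operator()(Args... args) {                                       \
      /* decltype(&::__name) needs cudnn.h to declare __name, which is */ \
      /* why symbols removed in cuDNN 9 must sit behind #if guards.    */ \
      using cudnn_func = decltype(&::__name);                             \
      std::call_once(cudnn_dso_flag,                                      \
                     []() { cudnn_dso_handle = GetCUDNNDsoHandle(); });   \
      static void* p_##__name = dlsym(cudnn_dso_handle, #__name);         \
      return reinterpret_cast<cudnn_func>(p_##__name)(args...);           \
    }                                                                     \
  };                                                                      \
  extern DynLoad__##__name __name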
12 changes: 12 additions & 0 deletions paddle/phi/backends/dynload/cudnn.cc
@@ -50,6 +50,18 @@ CUDNN_DNN_ROUTINE_EACH_R8(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_FRONTEND(DEFINE_WRAP);
#endif

#ifdef CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9
CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DEFINE_WRAP);
#endif

#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9
CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(DEFINE_WRAP);
#endif

#ifdef CUDNN_DNN_ROUTINE_EACH_R9
CUDNN_DNN_ROUTINE_EACH_R9(DEFINE_WRAP);
#endif

bool HasCUDNN() {
std::call_once(cudnn_dso_flag,
[]() { cudnn_dso_handle = GetCUDNNDsoHandle(); });
50 changes: 35 additions & 15 deletions paddle/phi/backends/dynload/cudnn.h
@@ -103,13 +103,6 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnSetDropoutDescriptor); \
__macro(cudnnRestoreDropoutDescriptor); \
__macro(cudnnCreateRNNDescriptor); \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
__macro(cudnnGetRNNTrainingReserveSize); \
__macro(cudnnRNNForwardTraining); \
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnDestroyDropoutDescriptor); \
__macro(cudnnDestroyRNNDescriptor); \
__macro(cudnnSetTensorNdDescriptorEx); \
@@ -124,8 +117,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
__macro(cudnnCreateActivationDescriptor); \
__macro(cudnnSetActivationDescriptor); \
__macro(cudnnGetActivationDescriptor); \
__macro(cudnnDestroyActivationDescriptor); \
__macro(cudnnSetRNNDescriptor_v6);
__macro(cudnnDestroyActivationDescriptor);
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)

#if CUDNN_VERSION >= 7000 && CUDNN_VERSION < 8000
@@ -159,12 +151,7 @@ CUDNN_DNN_ROUTINE_EACH_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(__macro) \
__macro(cudnnCreateRNNDataDescriptor); \
__macro(cudnnDestroyRNNDataDescriptor); \
__macro(cudnnSetRNNDataDescriptor); \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx); \
__macro(cudnnRNNForwardInferenceEx);
__macro(cudnnSetRNNDataDescriptor);
CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

@@ -207,6 +194,39 @@ CUDNN_DNN_ROUTINE_EACH_R8(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
CUDNN_DNN_ROUTINE_EACH_FRONTEND(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

#if CUDNN_VERSION < 90000
#define CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(__macro) \
__macro(cudnnGetRNNParamsSize); \
__macro(cudnnGetRNNWorkspaceSize); \
__macro(cudnnGetRNNTrainingReserveSize); \
__macro(cudnnSetRNNDescriptor_v6); \
__macro(cudnnRNNForwardInference); \
__macro(cudnnRNNForwardTraining); \
__macro(cudnnRNNBackwardData); \
__macro(cudnnRNNBackwardWeights);
CUDNN_DNN_ROUTINE_EACH_REMOVED_IN_E9(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

#if CUDNN_VERSION < 90000 && CUDNN_VERSION >= 7201
#define CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(__macro) \
__macro(cudnnSetRNNPaddingMode); \
__macro(cudnnRNNForwardInferenceEx); \
__macro(cudnnRNNForwardTrainingEx); \
__macro(cudnnRNNBackwardDataEx); \
__macro(cudnnRNNBackwardWeightsEx);
CUDNN_DNN_ROUTINE_EACH_AFTER_TWO_R7_REMOVED_IN_E9(
DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif

#if CUDNN_VERSION >= 90000
#define CUDNN_DNN_ROUTINE_EACH_R9(__macro) \
__macro(cudnnGetRNNWeightSpaceSize); \
__macro(cudnnGetRNNTempSpaceSizes); \
__macro(cudnnRNNForward); \
__macro(cudnnRNNBackwardData_v8); \
__macro(cudnnRNNBackwardWeights_v8);
CUDNN_DNN_ROUTINE_EACH_R9(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
} // namespace dynload
} // namespace phi

(Remaining changed files not shown.)