@@ -21,6 +21,16 @@
 
 namespace sherpa_onnx {
 
+
+static void OrtStatusFailure(OrtStatus *status, const char *s) {
+  const auto &api = Ort::GetApi();
+  const char *msg = api.GetErrorMessage(status);
+  SHERPA_ONNX_LOGE(
+      "Failed to enable TensorRT: %s. "
+      "Available providers: %s. Falling back to CUDA.", msg, s);
+  api.ReleaseStatus(status);
+}
+
 static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
                                                  std::string provider_str) {
   Provider p = StringToProvider(std::move(provider_str));
@@ -53,6 +63,63 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       }
       break;
     }
+    case Provider::kTRT: {
+      struct TrtPairs {
+        const char *op_keys;
+        const char *op_values;
+      };
+
+      std::vector<TrtPairs> trt_options = {
+          {"device_id", "0"},
+          {"trt_max_workspace_size", "2147483648"},  // in bytes, i.e., 2 GB
+          {"trt_max_partition_iterations", "10"},
+          {"trt_min_subgraph_size", "5"},
+          {"trt_fp16_enable", "0"},
+          {"trt_detailed_build_log", "0"},
+          {"trt_engine_cache_enable", "1"},
+          {"trt_engine_cache_path", "."},
+          {"trt_timing_cache_enable", "1"},
+          {"trt_timing_cache_path", "."}
+      };
+      // TODO: expose more TRT options, e.g.,
+      //   "trt_int8_enable"
+      //   "trt_int8_use_native_calibration_table"
+      //   "trt_dump_subgraphs"
+
+      std::vector<const char *> option_keys, option_values;
+      for (const TrtPairs &pair : trt_options) {
+        option_keys.emplace_back(pair.op_keys);
+        option_values.emplace_back(pair.op_values);
+      }
+
+      // available_providers is declared earlier in this function; the
+      // CUDA case below uses it as well.
+      if (std::find(available_providers.begin(), available_providers.end(),
+                    "TensorrtExecutionProvider") != available_providers.end()) {
+        const auto &api = Ort::GetApi();
+
+        OrtTensorRTProviderOptionsV2 *tensorrt_options = nullptr;
+        OrtStatus *statusC =
+            api.CreateTensorRTProviderOptions(&tensorrt_options);
+        if (statusC) {
+          // Creation failed, so tensorrt_options is not valid: report and
+          // skip straight to the CUDA fallback below.
+          OrtStatusFailure(statusC, os.str().c_str());
+        } else {
+          OrtStatus *statusU = api.UpdateTensorRTProviderOptions(
+              tensorrt_options, option_keys.data(), option_values.data(),
+              option_keys.size());
+          if (statusU) {
+            OrtStatusFailure(statusU, os.str().c_str());
+          } else {
+            sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
+          }
+          api.ReleaseTensorRTProviderOptions(tensorrt_options);
+        }
+      }
+      // break; is intentionally omitted: if TRT is not available (or fails
+      // to initialize), control falls through to the CUDA case below.
+    }
     case Provider::kCUDA: {
       if (std::find(available_providers.begin(), available_providers.end(),
                     "CUDAExecutionProvider") != available_providers.end()) {
@@ -116,7 +183,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
       break;
     }
   }
-
   return sess_opts;
 }
 
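A note on the deliberate fallthrough at the end of the kTRT case: under C++17 the same intent can be stated with the [[fallthrough]] attribute, which turns an accidentally missing break into a compiler-checked annotation. A minimal standalone sketch (assumes a C++17 build; the enum and function names are illustrative, not part of the patch):

#include <cstdio>

enum class Provider { kTRT, kCUDA, kCPU };

static void SelectProvider(Provider p) {
  switch (p) {
    case Provider::kTRT: {
      std::puts("trying TensorRT");
      // On failure, fall through and try CUDA next.
    }
      [[fallthrough]];
    case Provider::kCUDA:
      std::puts("trying CUDA");
      break;
    case Provider::kCPU:
      std::puts("using CPU");
      break;
  }
}

int main() { SelectProvider(Provider::kTRT); }

If the project targets an older standard, the explanatory comment used in the patch is the usual alternative.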
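To check what the fallback logic will see at runtime, the execution providers compiled into onnxruntime can be printed directly with Ort::GetAvailableProviders(), the same call the patch relies on. A standalone sketch, not part of the patch:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include "onnxruntime_cxx_api.h"  // NOLINT

int main() {
  const std::vector<std::string> providers = Ort::GetAvailableProviders();
  for (const std::string &p : providers) {
    // Typical entries: TensorrtExecutionProvider, CUDAExecutionProvider,
    // CPUExecutionProvider.
    std::cout << p << "\n";
  }
  const bool has_trt =
      std::find(providers.begin(), providers.end(),
                "TensorrtExecutionProvider") != providers.end();
  std::cout << (has_trt ? "TensorRT EP available"
                        : "TensorRT EP missing; expect the CUDA fallback")
            << "\n";
  return 0;
}

If TensorrtExecutionProvider is absent, the patched switch falls through to the CUDA case, and from there to whatever the existing code already does when CUDA is unavailable.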