Skip to content

Commit 69347ff

Browse files
Support TensorRT provider (#921)
Signed-off-by: manickavela1998@gmail.com <manickavela1998@gmail.com> Signed-off-by: manickavela1998@gmail.com <manickavela.arumugam@uniphore.com>
1 parent 7e0931c commit 69347ff

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

sherpa-onnx/csrc/provider.cc

+2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ Provider StringToProvider(std::string s) {
2424
return Provider::kXnnpack;
2525
} else if (s == "nnapi") {
2626
return Provider::kNNAPI;
27+
} else if (s == "trt") {
28+
return Provider::kTRT;
2729
} else {
2830
SHERPA_ONNX_LOGE("Unsupported string: %s. Fallback to cpu", s.c_str());
2931
return Provider::kCPU;

sherpa-onnx/csrc/provider.h

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ enum class Provider {
1818
kCoreML = 2, // CoreMLExecutionProvider
1919
kXnnpack = 3, // XnnpackExecutionProvider
2020
kNNAPI = 4, // NnapiExecutionProvider
21+
kTRT = 5, // TensorRTExecutionProvider
2122
};
2223

2324
/**

sherpa-onnx/csrc/session.cc

+61-1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,16 @@
2121

2222
namespace sherpa_onnx {
2323

24+
25+
/**
 * Log a TensorRT setup failure and release the ONNX Runtime status object.
 *
 * @param status Status returned by an ONNX Runtime C API call. Its error
 *               message is read with GetErrorMessage() and the status is
 *               released with ReleaseStatus() before this function returns,
 *               so the caller must not use it afterwards.
 * @param s      Text printed after "Available providers:" in the log line.
 */
static void OrtStatusFailure(OrtStatus *status, const char *s) {
  const auto &api = Ort::GetApi();
  const char *msg = api.GetErrorMessage(status);
  // Adjacent string literals are concatenated verbatim, so the second literal
  // starts with a space; otherwise the log reads "...<msg>.Available ...".
  SHERPA_ONNX_LOGE(
      "Failed to enable TensorRT : %s."
      " Available providers: %s. Fallback to cuda", msg, s);
  api.ReleaseStatus(status);
}
33+
2434
static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
2535
std::string provider_str) {
2636
Provider p = StringToProvider(std::move(provider_str));
@@ -53,6 +63,57 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
5363
}
5464
break;
5565
}
66+
case Provider::kTRT: {
67+
struct TrtPairs {
68+
const char* op_keys;
69+
const char* op_values;
70+
};
71+
72+
std::vector<TrtPairs> trt_options = {
73+
{"device_id", "0"},
74+
{"trt_max_workspace_size", "2147483648"},
75+
{"trt_max_partition_iterations", "10"},
76+
{"trt_min_subgraph_size", "5"},
77+
{"trt_fp16_enable", "0"},
78+
{"trt_detailed_build_log", "0"},
79+
{"trt_engine_cache_enable", "1"},
80+
{"trt_engine_cache_path", "."},
81+
{"trt_timing_cache_enable", "1"},
82+
{"trt_timing_cache_path", "."}
83+
};
84+
// ToDo : Trt configs
85+
// "trt_int8_enable"
86+
// "trt_int8_use_native_calibration_table"
87+
// "trt_dump_subgraphs"
88+
89+
std::vector<const char*> option_keys, option_values;
90+
for (const TrtPairs& pair : trt_options) {
91+
option_keys.emplace_back(pair.op_keys);
92+
option_values.emplace_back(pair.op_values);
93+
}
94+
95+
std::vector<std::string> available_providers =
96+
Ort::GetAvailableProviders();
97+
if (std::find(available_providers.begin(), available_providers.end(),
98+
"TensorrtExecutionProvider") != available_providers.end()) {
99+
const auto& api = Ort::GetApi();
100+
101+
OrtTensorRTProviderOptionsV2* tensorrt_options;
102+
OrtStatus *statusC = api.CreateTensorRTProviderOptions(
103+
&tensorrt_options);
104+
OrtStatus *statusU = api.UpdateTensorRTProviderOptions(
105+
tensorrt_options, option_keys.data(), option_values.data(),
106+
option_keys.size());
107+
sess_opts.AppendExecutionProvider_TensorRT_V2(*tensorrt_options);
108+
109+
if (statusC) { OrtStatusFailure(statusC, os.str().c_str()); }
110+
if (statusU) { OrtStatusFailure(statusU, os.str().c_str()); }
111+
112+
api.ReleaseTensorRTProviderOptions(tensorrt_options);
113+
}
114+
// break; is omitted intentionally: control always falls through to
115+
// the CUDA case below, so CUDA is used if TRT is not available
116+
}
56117
case Provider::kCUDA: {
57118
if (std::find(available_providers.begin(), available_providers.end(),
58119
"CUDAExecutionProvider") != available_providers.end()) {
@@ -116,7 +177,6 @@ static Ort::SessionOptions GetSessionOptionsImpl(int32_t num_threads,
116177
break;
117178
}
118179
}
119-
120180
return sess_opts;
121181
}
122182

0 commit comments

Comments
 (0)