@@ -14,6 +14,8 @@ extern "C" {
14
14
pub ( crate ) fn OrtSessionOptionsAppendExecutionProvider_CoreML ( options : * mut sys:: OrtSessionOptions , flags : u32 ) -> sys:: OrtStatusPtr ;
15
15
#[ cfg( feature = "directml" ) ]
16
16
pub ( crate ) fn OrtSessionOptionsAppendExecutionProvider_DML ( options : * mut sys:: OrtSessionOptions , device_id : std:: os:: raw:: c_int ) -> sys:: OrtStatusPtr ;
17
+ #[ cfg( feature = "rocm" ) ]
18
+ pub ( crate ) fn OrtSessionOptionsAppendExecutionProvider_ROCm ( options : * mut sys:: OrtSessionOptions , device_id : std:: os:: raw:: c_int ) -> sys:: OrtStatusPtr ;
17
19
}
18
20
19
21
/// Execution provider container. See [the ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/) for more
@@ -73,7 +75,8 @@ impl ExecutionProvider {
73
75
dnnl = "DnnlExecutionProvider" ,
74
76
onednn = "DnnlExecutionProvider" ,
75
77
coreml = "CoreMLExecutionProvider" ,
76
- directml = "DmlExecutionProvider"
78
+ directml = "DmlExecutionProvider" ,
79
+ rocm = "ROCmExecutionProvider"
77
80
}
78
81
79
82
/// Returns `true` if this execution provider is available, `false` otherwise.
@@ -212,6 +215,14 @@ pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, ex
212
215
return ; // EP found
213
216
}
214
217
}
218
+ #[ cfg( feature = "rocm" ) ]
219
+ "ROCmExecutionProvider" => {
220
+ let device_id = init_args. get ( "device_id" ) . map_or ( 0 , |s| s. parse :: < i32 > ( ) . unwrap_or ( 0 ) ) ;
221
+ let status = unsafe { OrtSessionOptionsAppendExecutionProvider_ROCm ( options, device_id) } ;
222
+ if status_to_result_and_log ( "DirectML" , status) . is_ok ( ) {
223
+ return ; // EP found
224
+ }
225
+ }
215
226
_ => { }
216
227
} ;
217
228
}
0 commit comments