Skip to content

Commit b02744a

Browse files
committed
feat: ExecutionProvider::rocm, ref #16
1 parent 17eeed3 commit b02744a

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

src/execution_providers.rs

+12-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ extern "C" {
1414
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_CoreML(options: *mut sys::OrtSessionOptions, flags: u32) -> sys::OrtStatusPtr;
1515
#[cfg(feature = "directml")]
1616
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_DML(options: *mut sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> sys::OrtStatusPtr;
17+
#[cfg(feature = "rocm")]
18+
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_ROCm(options: *mut sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> sys::OrtStatusPtr;
1719
}
1820

1921
/// Execution provider container. See [the ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/) for more
@@ -73,7 +75,8 @@ impl ExecutionProvider {
7375
dnnl = "DnnlExecutionProvider",
7476
onednn = "DnnlExecutionProvider",
7577
coreml = "CoreMLExecutionProvider",
76-
directml = "DmlExecutionProvider"
78+
directml = "DmlExecutionProvider",
79+
rocm = "ROCmExecutionProvider"
7780
}
7881

7982
/// Returns `true` if this execution provider is available, `false` otherwise.
@@ -212,6 +215,14 @@ pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, ex
212215
return; // EP found
213216
}
214217
}
218+
#[cfg(feature = "rocm")]
219+
"ROCmExecutionProvider" => {
220+
let device_id = init_args.get("device_id").map_or(0, |s| s.parse::<i32>().unwrap_or(0));
221+
let status = unsafe { OrtSessionOptionsAppendExecutionProvider_ROCm(options, device_id) };
222+
if status_to_result_and_log("DirectML", status).is_ok() {
223+
return; // EP found
224+
}
225+
}
215226
_ => {}
216227
};
217228
}

0 commit comments

Comments
 (0)