Skip to content

Commit 17eeed3

Browse files
committed
docs: more rust documentation
1 parent bb2f924 commit 17eeed3

14 files changed

+85
-17
lines changed

build.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ fn generate_bindings(include_dir: &Path) {
530530
}
531531

532532
fn main() {
533-
if !std::env::var("DOCS_RS").is_ok() {
533+
if std::env::var("DOCS_RS").is_err() {
534534
let (install_dir, needs_link) = prepare_libort_dir();
535535

536536
let include_dir = install_dir.join("include");

src/download.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
pub mod language;
22
pub mod vision;
33

4+
/// Represents a type that returns an ONNX model URL.
45
pub trait ModelUrl {
6+
/// Returns the model URL associated with this model.
57
fn fetch_url(&self) -> &'static str;
68
}

src/download/language.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! Models for language understanding.
2+
13
pub mod machine_comprehension;
24

3-
pub use machine_comprehension::MachineComprehension;
5+
pub use machine_comprehension::{MachineComprehension, RoBERTa, GPT2};

src/download/language/machine_comprehension.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#![allow(clippy::upper_case_acronyms)]
22

3+
//! Models for machine language comprehension.
4+
35
use crate::download::ModelUrl;
46

57
/// Machine comprehension models.
@@ -20,14 +22,18 @@ pub enum MachineComprehension {
2022
/// Large transformer-based model that predicts sentiment based on given input text.
2123
#[derive(Debug, Clone)]
2224
pub enum RoBERTa {
25+
/// Base RoBERTa model.
2326
RoBERTaBase,
27+
/// RoBERTa model for sequence classification.
2428
RoBERTaSequenceClassification
2529
}
2630

2731
/// Generates synthetic text samples in response to the model being primed with an arbitrary input.
2832
#[derive(Debug, Clone)]
2933
pub enum GPT2 {
34+
/// Base GPT-2 model.
3035
GPT2,
36+
/// GPT-2 model with a causal LM head.
3137
GPT2LmHead
3238
}
3339

src/download/vision.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! Models for computer vision.
2+
13
pub mod body_face_gesture_analysis;
24
pub mod domain_based_image_classification;
35
pub mod image_classification;
@@ -6,6 +8,6 @@ pub mod object_detection_image_segmentation;
68

79
pub use body_face_gesture_analysis::BodyFaceGestureAnalysis;
810
pub use domain_based_image_classification::DomainBasedImageClassification;
9-
pub use image_classification::ImageClassification;
10-
pub use image_manipulation::ImageManipulation;
11+
pub use image_classification::{ImageClassification, InceptionVersion, ResNet, ResNetV1, ResNetV2, ShuffleNetVersion, Vgg};
12+
pub use image_manipulation::{FastNeuralStyleTransferStyle, ImageManipulation};
1113
pub use object_detection_image_segmentation::ObjectDetectionImageSegmentation;

src/download/vision/body_face_gesture_analysis.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
//! Models for body, face, & gesture analysis.
2+
13
use crate::download::ModelUrl;
24

5+
/// Models for body, face, & gesture analysis.
36
#[derive(Debug, Clone)]
47
pub enum BodyFaceGestureAnalysis {
58
/// A CNN based model for face recognition which learns discriminative features of faces and produces embeddings for

src/download/vision/domain_based_image_classification.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
//! Models for domain-based image classification.
2+
13
use crate::download::ModelUrl;
24

5+
/// Models for domain-based image classification.
36
#[derive(Debug, Clone)]
47
pub enum DomainBasedImageClassification {
58
/// Handwritten digit prediction using CNN.

src/download/vision/image_classification.rs

+9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
//! Models for image classification.
2+
13
#![allow(clippy::upper_case_acronyms)]
24

35
use crate::download::ModelUrl;
46

7+
/// Models for image classification.
58
#[derive(Debug, Clone)]
69
pub enum ImageClassification {
710
/// Image classification aimed for mobile targets.
@@ -73,10 +76,15 @@ pub enum ResNet {
7376

7477
#[derive(Debug, Clone)]
7578
pub enum ResNetV1 {
79+
/// ResNet v1 with 18 layers.
7680
ResNet18,
81+
/// ResNet v1 with 34 layers.
7782
ResNet34,
83+
/// ResNet v1 with 50 layers.
7884
ResNet50,
85+
/// ResNet v1 with 101 layers.
7986
ResNet101,
87+
/// ResNet v1 with 152 layers.
8088
ResNet152
8189
}
8290

@@ -109,6 +117,7 @@ pub enum Vgg {
109117
/// power.
110118
#[derive(Debug, Clone)]
111119
pub enum ShuffleNetVersion {
120+
/// The original ShuffleNet.
112121
V1,
113122
/// ShuffleNetV2 is an improved architecture that is the state-of-the-art in terms of speed and accuracy tradeoff
114123
/// used for image classification.

src/error.rs

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! Types and helpers for handling ORT errors.
2+
13
use std::{io, path::PathBuf, string};
24

35
use thiserror::Error;
@@ -7,9 +9,11 @@ use super::{char_p_to_string, ort, sys, tensor::TensorElementDataType};
79
/// Type alias for the Result type returned by ORT functions.
810
pub type OrtResult<T> = std::result::Result<T, OrtError>;
911

12+
/// An enum of all errors returned by ORT functions.
1013
#[non_exhaustive]
1114
#[derive(Error, Debug)]
1215
pub enum OrtError {
16+
/// An error occurred when converting an FFI C string to a Rust `String`.
1317
#[error("Failed to construct Rust String")]
1418
FfiStringConversion(OrtApiError),
1519
/// An error occurred while creating an ONNX environment.

src/execution_providers.rs

+13-4
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ extern "C" {
1616
pub(crate) fn OrtSessionOptionsAppendExecutionProvider_DML(options: *mut sys::OrtSessionOptions, device_id: std::os::raw::c_int) -> sys::OrtStatusPtr;
1717
}
1818

19+
/// Execution provider container. See [the ONNX Runtime docs](https://onnxruntime.ai/docs/execution-providers/) for more
20+
/// info on execution providers. Execution providers are actually registered via the `with_execution_providers()`
21+
/// functions [`crate::SessionBuilder`] (per-session) or [`crate::EnvBuilder`] (default for all sessions in an
22+
/// environment).
1923
#[derive(Debug, Clone)]
2024
pub struct ExecutionProvider {
2125
provider: String,
@@ -51,6 +55,9 @@ macro_rules! ep_options {
5155
}
5256

5357
impl ExecutionProvider {
58+
/// Creates an `ExecutionProvider` for the given execution provider name.
59+
///
60+
/// You probably want the dedicated methods instead, e.g. [`ExecutionProvider::cuda`].
5461
pub fn new(provider: impl Into<String>) -> Self {
5562
Self {
5663
provider: provider.into(),
@@ -69,6 +76,8 @@ impl ExecutionProvider {
6976
directml = "DmlExecutionProvider"
7077
}
7178

79+
/// Returns `true` if this execution provider is available, `false` otherwise.
80+
/// The CPU execution provider will always be available.
7281
pub fn is_available(&self) -> bool {
7382
let mut providers: *mut *mut c_char = std::ptr::null_mut();
7483
let mut num_providers = 0;
@@ -90,9 +99,9 @@ impl ExecutionProvider {
9099
false
91100
}
92101

93-
/// Configure this execution provider with the given option.
94-
pub fn with(mut self, k: impl Into<String>, v: impl Into<String>) -> Self {
95-
self.options.insert(k.into(), v.into());
102+
/// Configure this execution provider with the given option name and value
103+
pub fn with(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
104+
self.options.insert(name.into(), value.into());
96105
self
97106
}
98107

@@ -198,7 +207,7 @@ pub(crate) fn apply_execution_providers(options: *mut sys::OrtSessionOptions, ex
198207
"DmlExecutionProvider" => {
199208
let device_id = init_args.get("device_id").map_or(0, |s| s.parse::<i32>().unwrap_or(0));
200209
// TODO: extended options with OrtSessionOptionsAppendExecutionProviderEx_DML
201-
let status = unsafe { OrtSessionOptionsAppendExecutionProvider_DML(options, device_id.into()) };
210+
let status = unsafe { OrtSessionOptionsAppendExecutionProvider_DML(options, device_id) };
202211
if status_to_result_and_log("DirectML", status).is_ok() {
203212
return; // EP found
204213
}

src/lib.rs

+23-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#![doc = include_str!("../README.md")]
2+
13
pub mod download;
24
pub mod environment;
35
pub mod error;
@@ -42,6 +44,11 @@ lazy_static! {
4244
};
4345
}
4446

47+
/// Attempts to acquire the global OrtApi object.
48+
///
49+
/// # Panics
50+
///
51+
/// Panics if another thread panicked while holding the API lock, or if the ONNX Runtime API could not be initialized.
4552
pub fn ort() -> sys::OrtApi {
4653
let mut api_ref = G_ORT_API.lock().expect("failed to acquire OrtApi lock; another thread panicked?");
4754
let api_ref_mut: &mut *mut sys::OrtApi = api_ref.get_mut();
@@ -176,15 +183,20 @@ extern_system_fn! {
176183
}
177184
}
178185

179-
/// ONNX Runtime logging level.
186+
/// The minimum logging level. Logs will be handled by the `tracing` crate.
180187
#[derive(Debug)]
181188
#[cfg_attr(not(windows), repr(u32))]
182189
#[cfg_attr(windows, repr(i32))]
183190
pub enum LoggingLevel {
191+
/// Verbose logging level. This will log *a lot* of messages!
184192
Verbose = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_VERBOSE as OnnxEnumInt,
193+
/// Info logging level.
185194
Info = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_INFO as OnnxEnumInt,
195+
/// Warning logging level. Recommended to receive potentially important warnings.
186196
Warning = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_WARNING as OnnxEnumInt,
197+
/// Error logging level.
187198
Error = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_ERROR as OnnxEnumInt,
199+
/// Fatal logging level.
188200
Fatal = sys::OrtLoggingLevel_ORT_LOGGING_LEVEL_FATAL as OnnxEnumInt
189201
}
190202

@@ -212,7 +224,7 @@ impl From<LoggingLevel> for sys::OrtLoggingLevel {
212224
/// The optimizations belonging to one level are performed after the optimizations of the previous level have been
213225
/// applied (e.g., extended optimizations are applied after basic optimizations have been applied).
214226
///
215-
/// **All optimizations are enabled by default.**
227+
/// **All optimizations (i.e. [`GraphOptimizationLevel::Level3`]) are enabled by default.**
216228
///
217229
/// # Online/offline mode
218230
/// All optimizations can be performed either online or offline. In online mode, when initializing an inference session,
@@ -233,6 +245,7 @@ impl From<LoggingLevel> for sys::OrtLoggingLevel {
233245
#[cfg_attr(not(windows), repr(u32))]
234246
#[cfg_attr(windows, repr(i32))]
235247
pub enum GraphOptimizationLevel {
248+
/// Disables all graph optimizations.
236249
Disable = sys::GraphOptimizationLevel_ORT_DISABLE_ALL as OnnxEnumInt,
237250
/// Level 1 includes semantics-preserving graph rewrites which remove redundant nodes and redundant computation.
238251
/// They run before graph partitioning and thus apply to all the execution providers. Available basic/level 1 graph
@@ -292,13 +305,13 @@ impl From<GraphOptimizationLevel> for sys::GraphOptimizationLevel {
292305
}
293306
}
294307

295-
/// Allocator type
308+
/// Execution provider allocator type.
296309
#[derive(Debug, Clone)]
297310
#[repr(i32)]
298311
pub enum AllocatorType {
299-
/// Device allocator
312+
/// Default device-specific allocator.
300313
Device = sys::OrtAllocatorType_OrtDeviceAllocator,
301-
/// Arena allocator
314+
/// Arena allocator.
302315
Arena = sys::OrtAllocatorType_OrtArenaAllocator
303316
}
304317

@@ -311,17 +324,20 @@ impl From<AllocatorType> for sys::OrtAllocatorType {
311324
}
312325
}
313326

314-
/// Memory type
327+
/// Memory types for allocated memory.
315328
#[derive(Debug, Clone)]
316329
#[repr(i32)]
317330
pub enum MemType {
331+
/// Any CPU memory used by non-CPU execution provider.
318332
CPUInput = sys::OrtMemType_OrtMemTypeCPUInput,
333+
/// CPU accessible memory outputted by non-CPU execution provider, i.e. CUDA_PINNED.
319334
CPUOutput = sys::OrtMemType_OrtMemTypeCPUOutput,
320-
/// Default memory type
335+
/// The default allocator for an execution provider.
321336
Default = sys::OrtMemType_OrtMemTypeDefault
322337
}
323338

324339
impl MemType {
340+
/// Temporary CPU accessible memory allocated by non-CPU execution provider, i.e. CUDA_PINNED.
325341
pub const CPU: MemType = MemType::CPUOutput;
326342
}
327343

src/metadata.rs

+6
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::{ffi::CString, os::raw::c_char};
44

55
use super::{char_p_to_string, error::OrtResult, ortfree, ortsys, sys, OrtError};
66

7+
/// Container for model metadata, including name & producer information.
78
pub struct Metadata {
89
metadata_ptr: *mut sys::OrtModelMetadata,
910
allocator_ptr: *mut sys::OrtAllocator
@@ -14,6 +15,7 @@ impl Metadata {
1415
Metadata { metadata_ptr, allocator_ptr }
1516
}
1617

18+
/// Gets the model description, returning an error if no description is present.
1719
pub fn description(&self) -> OrtResult<String> {
1820
let mut str_bytes: *mut c_char = std::ptr::null_mut();
1921
ortsys![unsafe ModelMetadataGetDescription(self.metadata_ptr, self.allocator_ptr, &mut str_bytes) -> OrtError::GetModelMetadata; nonNull(str_bytes)];
@@ -23,6 +25,7 @@ impl Metadata {
2325
Ok(value)
2426
}
2527

28+
/// Gets the model producer name, returning an error if no producer name is present.
2629
pub fn producer(&self) -> OrtResult<String> {
2730
let mut str_bytes: *mut c_char = std::ptr::null_mut();
2831
ortsys![unsafe ModelMetadataGetProducerName(self.metadata_ptr, self.allocator_ptr, &mut str_bytes) -> OrtError::GetModelMetadata; nonNull(str_bytes)];
@@ -32,6 +35,7 @@ impl Metadata {
3235
Ok(value)
3336
}
3437

38+
/// Gets the model name, returning an error if no name is present.
3539
pub fn name(&self) -> OrtResult<String> {
3640
let mut str_bytes: *mut c_char = std::ptr::null_mut();
3741
ortsys![unsafe ModelMetadataGetGraphName(self.metadata_ptr, self.allocator_ptr, &mut str_bytes) -> OrtError::GetModelMetadata; nonNull(str_bytes)];
@@ -41,12 +45,14 @@ impl Metadata {
4145
Ok(value)
4246
}
4347

48+
/// Gets the model version, returning an error if no version is present.
4449
pub fn version(&self) -> OrtResult<i64> {
4550
let mut ver = 0i64;
4651
ortsys![unsafe ModelMetadataGetVersion(self.metadata_ptr, &mut ver) -> OrtError::GetModelMetadata];
4752
Ok(ver)
4853
}
4954

55+
/// Fetch the value of a custom metadata key. Returns `Ok(None)` if the key is not found.
5056
pub fn custom(&self, key: &str) -> OrtResult<Option<String>> {
5157
let mut str_bytes: *mut c_char = std::ptr::null_mut();
5258
let key_str = CString::new(key)?;

src/session.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
//! Contains the [`Session`] and [`SessionBuilder`] types for managing ONNX Runtime sessions and performing inference.
2+
13
#![allow(clippy::tabs_in_doc_comments)]
24

35
#[cfg(not(target_family = "windows"))]
@@ -95,6 +97,7 @@ impl Drop for SessionBuilder {
9597
}
9698

9799
impl SessionBuilder {
100+
/// Creates a new session builder in the given environment.
98101
pub fn new(env: &Arc<Environment>) -> OrtResult<Self> {
99102
let mut session_options_ptr: *mut sys::OrtSessionOptions = std::ptr::null_mut();
100103
ortsys![unsafe CreateSessionOptions(&mut session_options_ptr) -> OrtError::CreateSessionOptions; nonNull(session_options_ptr)];
@@ -673,14 +676,16 @@ impl Session {
673676
Ok(())
674677
}
675678

679+
/// Gets the session model metadata. See [`Metadata`] for more info.
676680
pub fn metadata(&self) -> OrtResult<Metadata> {
677681
let mut metadata_ptr: *mut sys::OrtModelMetadata = std::ptr::null_mut();
678682
ortsys![unsafe SessionGetModelMetadata(self.session_ptr, &mut metadata_ptr) -> OrtError::GetModelMetadata; nonNull(metadata_ptr)];
679683
Ok(Metadata::new(metadata_ptr, self.allocator_ptr))
680684
}
681685

682-
/// Ends profiling for this session. Note that this must be explicitly called at the end of profiling, otherwise
683-
/// the profiing file will be empty.
686+
/// Ends profiling for this session.
687+
///
688+
/// Note that this must be explicitly called at the end of profiling, otherwise the profiing file will be empty.
684689
#[cfg(feature = "profiling")]
685690
pub fn end_profiling(&self) -> OrtResult<String> {
686691
let mut profiling_name: *mut c_char = std::ptr::null_mut();

src/tensor.rs

+1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ impl_type_trait!(half::bf16, Bfloat16);
148148
/// would conflict with the implementations of [IntoTensorElementDataType] for primitive numeric
149149
/// types (which might implement [`AsRef<str>`] at some point in the future).
150150
pub trait Utf8Data {
151+
/// Returns the contents of this value as a slice of UTF-8 bytes.
151152
fn utf8_bytes(&self) -> &[u8];
152153
}
153154

0 commit comments

Comments
 (0)