Skip to content

Commit

Permalink
feat: update environment after initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
decahedron1 committed Jan 12, 2024
1 parent a683702 commit ea7d059
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 62 deletions.
126 changes: 70 additions & 56 deletions src/environment.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
#[cfg(feature = "load-dynamic")]
use std::sync::Arc;
use std::{
ffi::CString,
sync::{atomic::AtomicPtr, OnceLock}
};
use std::{cell::UnsafeCell, ffi::CString, sync::atomic::AtomicPtr, sync::Arc};

use tracing::debug;

Expand All @@ -15,20 +10,40 @@ use super::{
#[cfg(feature = "load-dynamic")]
use crate::G_ORT_DYLIB_PATH;

static G_ENV: OnceLock<EnvironmentSingleton> = OnceLock::new();
struct EnvironmentSingleton {
cell: UnsafeCell<Option<Arc<Environment>>>
}

unsafe impl Sync for EnvironmentSingleton {}

static G_ENV: EnvironmentSingleton = EnvironmentSingleton { cell: UnsafeCell::new(None) };

#[derive(Debug)]
pub(crate) struct EnvironmentSingleton {
pub(crate) struct Environment {
pub(crate) execution_providers: Vec<ExecutionProviderDispatch>,
pub(crate) env_ptr: AtomicPtr<ort_sys::OrtEnv>
}

pub(crate) fn get_environment() -> Result<&'static EnvironmentSingleton> {
if G_ENV.get().is_none() {
EnvironmentBuilder::default().commit()?;
Ok(G_ENV.get().unwrap())
impl Drop for Environment {
#[tracing::instrument]
fn drop(&mut self) {
let env_ptr: *mut ort_sys::OrtEnv = *self.env_ptr.get_mut();

debug!("Releasing environment");

assert_ne!(env_ptr, std::ptr::null_mut());
ortsys![unsafe ReleaseEnv(env_ptr)];
}
}

pub(crate) fn get_environment() -> Result<&'static Arc<Environment>> {
if let Some(c) = unsafe { &*G_ENV.cell.get() } {
Ok(c)
} else {
Ok(unsafe { G_ENV.get().unwrap_unchecked() })
debug!("Environment not yet initialized, creating a new one");
EnvironmentBuilder::default().commit()?;

Ok(unsafe { (*G_ENV.cell.get()).as_ref().unwrap_unchecked() })
}
}

Expand Down Expand Up @@ -131,63 +146,62 @@ impl EnvironmentBuilder {

/// Commit the configuration to a new [`Environment`].
pub fn commit(self) -> Result<()> {
if G_ENV.get().is_none() {
debug!("Environment not yet initialized, creating a new one");

let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap();

let mut thread_options: *mut ort_sys::OrtThreadingOptions = std::ptr::null_mut();
ortsys![unsafe CreateThreadingOptions(&mut thread_options) -> Error::CreateEnvironment; nonNull(thread_options)];
if let Some(inter_op_parallelism) = global_thread_pool.inter_op_parallelism {
ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism) -> Error::CreateEnvironment];
}
if let Some(intra_op_parallelism) = global_thread_pool.intra_op_parallelism {
ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism) -> Error::CreateEnvironment];
}
if let Some(spin_control) = global_thread_pool.spin_control {
ortsys![unsafe SetGlobalSpinControl(thread_options, if spin_control { 1 } else { 0 }) -> Error::CreateEnvironment];
}
if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity {
let cstr = CString::new(intra_op_thread_affinity).unwrap();
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment];
}

ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
let env_ptr = if let Some(global_thread_pool) = self.global_thread_pool_options {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap();

let mut thread_options: *mut ort_sys::OrtThreadingOptions = std::ptr::null_mut();
ortsys![unsafe CreateThreadingOptions(&mut thread_options) -> Error::CreateEnvironment; nonNull(thread_options)];
if let Some(inter_op_parallelism) = global_thread_pool.inter_op_parallelism {
ortsys![unsafe SetGlobalInterOpNumThreads(thread_options, inter_op_parallelism) -> Error::CreateEnvironment];
}
if let Some(intra_op_parallelism) = global_thread_pool.intra_op_parallelism {
ortsys![unsafe SetGlobalIntraOpNumThreads(thread_options, intra_op_parallelism) -> Error::CreateEnvironment];
}
if let Some(spin_control) = global_thread_pool.spin_control {
ortsys![unsafe SetGlobalSpinControl(thread_options, if spin_control { 1 } else { 0 }) -> Error::CreateEnvironment];
}
if let Some(intra_op_thread_affinity) = global_thread_pool.intra_op_thread_affinity {
let cstr = CString::new(intra_op_thread_affinity).unwrap();
ortsys![unsafe SetGlobalIntraOpThreadAffinity(thread_options, cstr.as_ptr()) -> Error::CreateEnvironment];
}

ortsys![unsafe CreateEnvWithCustomLoggerAndGlobalThreadPools(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
thread_options,
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
env_ptr
} else {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap();
ortsys![unsafe CreateEnvWithCustomLogger(
ortsys![unsafe ReleaseThreadingOptions(thread_options)];
env_ptr
} else {
let mut env_ptr: *mut ort_sys::OrtEnv = std::ptr::null_mut();
let logging_function: ort_sys::OrtLoggingFunction = Some(custom_logger);
// FIXME: What should go here?
let logger_param: *mut std::ffi::c_void = std::ptr::null_mut();
let cname = CString::new(self.name.clone()).unwrap();
ortsys![unsafe CreateEnvWithCustomLogger(
logging_function,
logger_param,
ort_sys::OrtLoggingLevel::ORT_LOGGING_LEVEL_VERBOSE,
cname.as_ptr(),
&mut env_ptr
) -> Error::CreateEnvironment; nonNull(env_ptr)];
env_ptr
};
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");
env_ptr
};
debug!(env_ptr = format!("{:?}", env_ptr).as_str(), "Environment created");

let _ = G_ENV.set(EnvironmentSingleton {
unsafe {
*G_ENV.cell.get() = Some(Arc::new(Environment {
execution_providers: self.execution_providers,
env_ptr: AtomicPtr::new(env_ptr)
});
}
}));
};

Ok(())
}
}
Expand Down Expand Up @@ -223,11 +237,11 @@ mod tests {
use super::*;

fn is_env_initialized() -> bool {
G_ENV.get().is_some() && !G_ENV.get().unwrap().env_ptr.load(Ordering::Relaxed).is_null()
unsafe { (*G_ENV.cell.get()).as_ref() }.is_some() && !unsafe { (*G_ENV.cell.get()).as_ref() }.unwrap().env_ptr.load(Ordering::Relaxed).is_null()
}

fn env_ptr() -> Option<*mut ort_sys::OrtEnv> {
G_ENV.get().map(|f| f.env_ptr.load(Ordering::Relaxed))
unsafe { (*G_ENV.cell.get()).as_ref() }.map(|f| f.env_ptr.load(Ordering::Relaxed))
}

struct ConcurrentTestRun {
Expand Down
18 changes: 14 additions & 4 deletions src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ use super::{
value::{Value, ValueType},
AllocatorType, GraphOptimizationLevel, MemType
};
use crate::environment::Environment;

pub(crate) mod input;
pub(crate) mod output;
Expand Down Expand Up @@ -432,7 +433,11 @@ impl SessionBuilder {
.collect::<Result<Vec<Output>>>()?;

Ok(Session {
inner: Arc::new(SharedSessionInner { session_ptr, allocator }),
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_environment: Arc::clone(env)
}),
inputs,
outputs
})
Expand Down Expand Up @@ -490,7 +495,11 @@ impl SessionBuilder {
.collect::<Result<Vec<Output>>>()?;

let session = Session {
inner: Arc::new(SharedSessionInner { session_ptr, allocator }),
inner: Arc::new(SharedSessionInner {
session_ptr,
allocator,
_environment: Arc::clone(env)
}),
inputs,
outputs
};
Expand All @@ -503,7 +512,8 @@ impl SessionBuilder {
#[derive(Debug)]
pub struct SharedSessionInner {
pub(crate) session_ptr: *mut ort_sys::OrtSession,
allocator: Allocator
allocator: Allocator,
_environment: Arc<Environment>
}

unsafe impl Send for SharedSessionInner {}
Expand Down Expand Up @@ -680,7 +690,7 @@ impl Session {
// The C API expects pointers for the arrays (pointers to C-arrays)
let input_ort_values: Vec<*const ort_sys::OrtValue> = input_values.iter().map(|input_array_ort| input_array_ort.ptr() as *const _).collect();

let run_options_ptr = if let Some(run_options) = run_options {
let run_options_ptr = if let Some(run_options) = &run_options {
run_options.run_options_ptr
} else {
std::ptr::null_mut()
Expand Down
38 changes: 36 additions & 2 deletions tests/upsample.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::path::Path;
use std::{
path::Path,
sync::{mpsc, Arc}
};

use image::RgbImage;
use ndarray::{Array, CowArray, Ix4};
use ort::{inputs, GraphOptimizationLevel, Session, Tensor};
use ort::{inputs, GraphOptimizationLevel, RunOptions, Session, Tensor};
use test_log::test;

fn load_input_image<P: AsRef<Path>>(name: P) -> RgbImage {
Expand Down Expand Up @@ -113,3 +116,34 @@ fn upsample_with_ort_model() -> ort::Result<()> {

Ok(())
}

#[test]
fn upsample_termination() -> ort::Result<()> {
const IMAGE_TO_LOAD: &str = "mushroom.png";

ort::init().with_name("integration_test").commit()?;

let session_data =
std::fs::read(Path::new(env!("CARGO_MANIFEST_DIR")).join("tests").join("data").join("upsample.onnx")).expect("Could not open model from file");
let session = Session::builder()?
.with_model_from_memory(&session_data)
.expect("Could not read model from memory");

// Load image, converting to RGB format
let image_buffer = load_input_image(IMAGE_TO_LOAD);
let array = convert_image_to_cow_array(&image_buffer);

let run_options = Arc::new(RunOptions::new()?);
let (sender, receiver) = mpsc::channel::<()>();

let run_options_ = Arc::clone(&run_options);
std::thread::spawn(move || {
receiver.recv().unwrap();
run_options_.set_terminate().unwrap();
});

sender.send(()).unwrap();
panic!("{:?}", session.run_with_options(inputs![&array]?, run_options).map(|_| ()).unwrap_err());

Ok(())
}

0 comments on commit ea7d059

Please sign in to comment.