refactor!: undo The Flattening
decahedron1 committed Nov 13, 2024
1 parent 17fe990 commit d4f82fc
Showing 67 changed files with 521 additions and 343 deletions.
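The import changes in this commit follow one pattern: items that were reachable from the crate root (for example `ort::Session`) are now imported through their modules (for example `ort::session::Session`). A minimal sketch of that difference is below. It assumes a hypothetical stand-in crate `mylib`, whose module layout only mirrors the paths visible in this diff; it is not `ort` itself.

```rust
// Hypothetical stand-in crate; names only mirror the import paths seen in this diff.
mod mylib {
    pub mod session {
        pub mod builder {
            #[allow(dead_code)]
            #[derive(Debug, Clone, Copy, PartialEq)]
            pub enum GraphOptimizationLevel {
                Disable,
                Level3,
            }
        }

        #[derive(Debug, PartialEq)]
        pub struct Session {
            pub level: builder::GraphOptimizationLevel,
        }

        impl Session {
            pub fn new(level: builder::GraphOptimizationLevel) -> Self {
                Session { level }
            }
        }
    }

    // A "flattened" layout re-exports nested items from a single place.
    // (Assumption: this module exists only to illustrate the older style.)
    pub mod flattened {
        pub use super::session::{Session, builder::GraphOptimizationLevel};
    }
}

// Nested imports, in the style this commit moves to.
use mylib::session::{Session, builder::GraphOptimizationLevel};

fn make_session() -> Session {
    Session::new(GraphOptimizationLevel::Level3)
}

fn main() {
    let nested = make_session();
    // The flattened path names the very same type; only the import path differs.
    let flat = mylib::flattened::Session::new(mylib::flattened::GraphOptimizationLevel::Level3);
    assert_eq!(nested, flat);
    println!("{:?}", nested);
}
```

Because a re-export is only an alias, code that compiled against the flattened paths needs nothing more than its `use` lines updated, which matches the hunks that follow.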
2 changes: 1 addition & 1 deletion benches/squeezenet.rs
Expand Up @@ -3,7 +3,7 @@ use std::{path::Path, sync::Arc};
use glassbench::{Bench, pretend_used};
use image::{ImageBuffer, Pixel, Rgb, imageops::FilterType};
use ndarray::{Array4, s};
use ort::{GraphOptimizationLevel, Session};
use ort::session::{Session, builder::GraphOptimizationLevel};

fn load_squeezenet_data() -> ort::Result<(Session, Array4<f32>)> {
const IMAGE_TO_LOAD: &str = "mushroom.png";
18 changes: 9 additions & 9 deletions docs/pages/fundamentals/value.mdx
Expand Up @@ -16,9 +16,9 @@ Sessions in `ort` return a map of `DynValue`s. You can determine a value's type
## Downcasting
**Downcasting** means to convert a `Dyn` type like `DynValue` to a stronger type like `DynTensor`. Downcasting can be performed using the `.downcast()` function on `DynValue`:
```rs
let value: ort::DynValue = outputs.remove("output0").unwrap();
let value: ort::value::DynValue = outputs.remove("output0").unwrap();

let dyn_tensor: ort::DynTensor = value.downcast()?;
let dyn_tensor: ort::value::DynTensor = value.downcast()?;
```

If `value` is not actually a tensor, the `downcast()` call will fail.
Expand All @@ -30,9 +30,9 @@ If `value` is not actually a tensor, the `downcast()` call will fail.

The strongly typed variants of these types (`Tensor<T>`, `Sequence<T>`, and `Map<K, V>`) can also be downcast to directly:
```rs
let dyn_value: ort::DynValue = outputs.remove("output0").unwrap();
let dyn_value: ort::value::DynValue = outputs.remove("output0").unwrap();

let f32_tensor: ort::Tensor<f32> = dyn_value.downcast()?;
let f32_tensor: ort::value::Tensor<f32> = dyn_value.downcast()?;
```

If `value` is not a tensor, **or** if the element type of the value does not match what was requested (`f32`), the `downcast()` call will fail.
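The two failure modes described above can be modelled without `ort` at all. The sketch below is an illustration only: the `DynValue` enum, the `TensorF32` struct, the `DowncastError` type, and the `downcast_f32` method are hypothetical stand-ins, not `ort`'s actual types or implementation.

```rust
// Hypothetical stand-in for a dynamically typed value; NOT ort's `DynValue`.
#[derive(Debug, PartialEq)]
enum DynValue {
    TensorF32(Vec<f32>),
    TensorI64(Vec<i64>),
    Sequence(Vec<DynValue>),
}

// Hypothetical strongly typed tensor of `f32` elements.
#[derive(Debug, PartialEq)]
struct TensorF32(Vec<f32>);

// The two ways a downcast to an `f32` tensor can fail in this model.
#[derive(Debug, PartialEq)]
enum DowncastError {
    NotATensor,
    WrongElementType,
}

impl DynValue {
    // Consumes the dynamic value and returns the stronger type,
    // or an error describing why the conversion is not possible.
    fn downcast_f32(self) -> Result<TensorF32, DowncastError> {
        match self {
            DynValue::TensorF32(data) => Ok(TensorF32(data)),
            DynValue::TensorI64(_) => Err(DowncastError::WrongElementType),
            DynValue::Sequence(_) => Err(DowncastError::NotATensor),
        }
    }
}

fn main() {
    // Matching kind and element type: the downcast succeeds.
    let ok = DynValue::TensorF32(vec![1.0, 2.0]).downcast_f32();
    assert_eq!(ok, Ok(TensorF32(vec![1.0, 2.0])));

    // A tensor, but with a different element type.
    let wrong_elem = DynValue::TensorI64(vec![1, 2]).downcast_f32();
    assert_eq!(wrong_elem, Err(DowncastError::WrongElementType));

    // Not a tensor at all.
    let not_tensor = DynValue::Sequence(vec![]).downcast_f32();
    assert_eq!(not_tensor, Err(DowncastError::NotATensor));
}
```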
Expand All @@ -43,7 +43,7 @@ Stronger typed values have infallible variants of the `.try_extract_*` methods:
let f32_array: ArrayViewD<f32> = dyn_value.try_extract_tensor()?;

// Or, we can first convert it to a tensor, and then extract afterwards:
let tensor: ort::Tensor<f32> = dyn_value.downcast()?;
let tensor: ort::value::Tensor<f32> = dyn_value.downcast()?;
let f32_array = tensor.extract_tensor(); // no `?` required, this will never fail!
```

Expand Down Expand Up @@ -89,16 +89,16 @@ View types are suffixed with `Ref` or `RefMut` for shared/mutable variants respe

These views can be acquired with `.view()` or `.view_mut()` on a value type:
```rs
let my_tensor: ort::Tensor<f32> = Tensor::new(...)?;
let my_tensor: ort::value::Tensor<f32> = Tensor::new(...)?;

let tensor_view: ort::TensorRef<'_, f32> = my_tensor.view();
let tensor_view: ort::value::TensorRef<'_, f32> = my_tensor.view();
```

Views act identically to a borrow of their type - `TensorRef` supports `extract_tensor`, `TensorRefMut` supports `extract_tensor_mut`. The same is true for sequences & maps. Views also support down/upcasting via `.downcast()` & `.into_dyn()` (but not `.upcast()` at the moment).

You can also directly downcast a value to a stronger-typed view using `.downcast_ref()` and `.downcast_mut()`:
```rs
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.downcast_ref()?;
let tensor_view: ort::value::TensorRef<'_, f32> = dyn_value.downcast_ref()?;
// is equivalent to
let tensor_view: ort::TensorRef<'_, f32> = dyn_value.view().downcast()?;
let tensor_view: ort::value::TensorRef<'_, f32> = dyn_value.view().downcast()?;
```
2 changes: 1 addition & 1 deletion docs/pages/index.mdx
Expand Up @@ -52,7 +52,7 @@ Your model will need to be converted to an ONNX graph before you can use it.
Once you've got a model, load it via `ort` by creating a [`Session`](/fundamentals/session):

```rust
use ort::{GraphOptimizationLevel, Session};
use ort::session::{builder::GraphOptimizationLevel, Session};

let model = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
5 changes: 1 addition & 4 deletions docs/pages/migrating/v2.mdx
Expand Up @@ -48,7 +48,7 @@ let dyn_value: DynValue = tensor.into_dyn();
```

### Tensor extraction directly returns an `ArrayView`
The new `extract_tensor` and `try_extract_tensor` functions return an `ndarray::ArrayView` directly, instead of putting it behind the old `ort::Tensor<T>` type (not to be confused with the new specialized value type). This means you don't have to `.view()` on the result:
The new `extract_tensor` and `try_extract_tensor` functions return an `ndarray::ArrayView` directly, instead of putting it behind the old `ort::value::Tensor<T>` type (not to be confused with the new specialized value type). This means you don't have to `.view()` on the result:
```diff
-let generated_tokens: Tensor<f32> = outputs["output1"].try_extract()?;
-let generated_tokens = generated_tokens.view();
Expand Down Expand Up @@ -201,9 +201,6 @@ You can still use `Session::commit_from_url`, it just now takes a URL string ins
## Changes to logging
Environment-level logging configuration (i.e. `EnvironmentBuilder::with_log_level`) has been removed because it could cause unnecessary confusion with our `tracing` integration.

## The Flattening
All modules except `download` and `sys` are now private. Exports have been flattened to the crate root, so i.e. `ort::session::Session` becomes `ort::Session`.

## Renamed types
The following types have been renamed with no other changes.
- `NdArrayExtensions` -> `ArrayExtensions`
23 changes: 16 additions & 7 deletions docs/pages/perf/execution-providers.mdx
Expand Up @@ -52,7 +52,7 @@ ONNX Runtime must be compiled with support for each execution provider. pyke pro
In order to configure sessions to use certain execution providers, you must **register** them when creating an environment or session. You can do this via the `SessionBuilder::with_execution_providers` method. For example, to register the CUDA execution provider for a session:

```rust
use ort::{CUDAExecutionProvider, Session};
use ort::{execution_providers::CUDAExecutionProvider, session::Session};

fn main() -> anyhow::Result<()> {
let session = Session::builder()?
Expand All @@ -66,7 +66,10 @@ fn main() -> anyhow::Result<()> {
You can, of course, specify multiple execution providers. `ort` will register all EPs specified, in order. If an EP does not support a certain operator in a graph, it will fall back to the next successfully registered EP, or to the CPU if all else fails.

```rust
use ort::{CoreMLExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider, TensorRTExecutionProvider, Session};
use ort::{
execution_providers::{CoreMLExecutionProvider, CUDAExecutionProvider, DirectMLExecutionProvider, TensorRTExecutionProvider},
session::Session
};

fn main() -> anyhow::Result<()> {
let session = Session::builder()?
Expand All @@ -89,7 +92,7 @@ fn main() -> anyhow::Result<()> {
EPs have configuration options to control behavior or increase performance. Each `XXXExecutionProvider` struct returns a builder with configuration methods. See the [API reference](https://docs.rs/ort/2.0.0-rc.8/ort/index.html?search=ExecutionProvider) for the EP structs for more information on which options are supported and what they do.

```rust
use ort::{CoreMLExecutionProvider, Session};
use ort::{execution_providers::CoreMLExecutionProvider, session::Session};

fn main() -> anyhow::Result<()> {
let session = Session::builder()?
Expand All @@ -112,7 +115,7 @@ fn main() -> anyhow::Result<()> {

You can configure an EP to return an error on failure by adding `.error_on_failure()` after you `.build()` it. In this example, if CUDA doesn't register successfully, the program will exit with an error at `with_execution_providers`:
```rust
use ort::{CoreMLExecutionProvider, Session};
use ort::{execution_providers::CoreMLExecutionProvider, session::Session};

fn main() -> anyhow::Result<()> {
let session = Session::builder()?
Expand All @@ -128,7 +131,10 @@ fn main() -> anyhow::Result<()> {
If you require more complex error handling, you can also manually register execution providers via the `ExecutionProvider::register` method:

```rust
use ort::{CUDAExecutionProvider, ExecutionProvider, Session};
use ort::{
execution_providers::{CUDAExecutionProvider, ExecutionProvider},
session::Session
};

fn main() -> anyhow::Result<()> {
let builder = Session::builder()?;
Expand All @@ -148,7 +154,10 @@ fn main() -> anyhow::Result<()> {
You can also check whether ONNX Runtime is even compiled with support for the execution provider with the `is_available` method.

```rust
use ort::{CoreMLExecutionProvider, ExecutionProvider, Session};
use ort::{
execution_providers::{CoreMLExecutionProvider, ExecutionProvider},
session::Session
};

fn main() -> anyhow::Result<()> {
let builder = Session::builder()?;
Expand All @@ -172,7 +181,7 @@ fn main() -> anyhow::Result<()> {
You can configure `ort` to attempt to register a list of execution providers for all sessions created in an environment.

```rust
use ort::{CUDAExecutionProvider, Session};
use ort::{execution_providers::CUDAExecutionProvider, session::Session};

fn main() -> anyhow::Result<()> {
ort::init()
14 changes: 7 additions & 7 deletions docs/pages/troubleshooting/compiling.mdx
@@ -1,20 +1,20 @@
## The trait bound `ort::Value: From<...>` is not satisfied
## The trait bound `ort::value::Value: From<...>` is not satisfied
An error like this might come up when attempting to upgrade from an earlier (1.x) version of `ort` to a more recent version:
```
error[E0277]: the trait bound `ort::Value: From<ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>>` is not satisfied
error[E0277]: the trait bound `ort::value::Value: From<ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>>` is not satisfied
--> src/main.rs:72:16
|
72 | let inputs = ort::inputs![
| ______________________^
73 | | input1,
74 | | ]?;
| |_________^ the trait `From<ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>>` is not implemented for `ort::Value`, which is required by `ort::Value: TryFrom<ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>>`
| |_________^ the trait `From<ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>>` is not implemented for `ort::value::Value`, which is required by `ort::value::Value: TryFrom<ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>>`
|
= help: the following other types implement trait `From<T>`:
`ort::Value` implements `From<ort::Value<DynTensorValueType>>`
`ort::Value` implements `From<ort::Value<TensorValueType<T>>>`
= note: required for `ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>` to implement `Into<ort::Value>`
= note: required for `ort::Value` to implement `TryFrom<ArrayBase<OwnedRepr<i64>, Dim<[usize; 2]>>>`
`ort::value::Value` implements `From<ort::value::Value<DynTensorValueType>>`
`ort::value::Value` implements `From<ort::value::Value<TensorValueType<T>>>`
= note: required for `ArrayBase<OwnedRepr<f32>, Dim<[usize; 3]>>` to implement `Into<ort::value::Value>`
= note: required for `ort::value::Value` to implement `TryFrom<ArrayBase<OwnedRepr<i64>, Dim<[usize; 2]>>>`
= note: this error originates in the macro `ort::inputs` (in Nightly builds, run with -Z macro-backtrace for more info)
```

6 changes: 5 additions & 1 deletion examples/async-gpt2-api/examples/async-gpt2-api.rs
Expand Up @@ -11,7 +11,11 @@ use axum::{
};
use futures::Stream;
use ndarray::{Array1, ArrayViewD, Axis, array, concatenate, s};
use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session, inputs};
use ort::{
execution_providers::CUDAExecutionProvider,
inputs,
session::{Session, builder::GraphOptimizationLevel}
};
use rand::Rng;
use tokenizers::Tokenizer;
use tokio::net::TcpListener;
24 changes: 13 additions & 11 deletions examples/cudarc/src/main.rs
@@ -1,10 +1,15 @@
use std::{ops::Mul, path::Path};

use cudarc::driver::{sys::CUdeviceptr, CudaDevice, DevicePtr, DevicePtrMut};
use image::{imageops::FilterType, GenericImageView, ImageBuffer, Rgba};
use cudarc::driver::{CudaDevice, DevicePtr, DevicePtrMut, sys::CUdeviceptr};
use image::{GenericImageView, ImageBuffer, Rgba, imageops::FilterType};
use ndarray::Array;
use ort::{AllocationDevice, AllocatorType, CUDAExecutionProvider, ExecutionProvider, MemoryInfo, MemoryType, Session, TensorRefMut};
use show_image::{event, AsImageView, WindowOptions};
use ort::{
execution_providers::{CUDAExecutionProvider, ExecutionProvider},
memory::{AllocationDevice, AllocatorType, MemoryInfo, MemoryType},
session::Session,
value::TensorRefMut
};
use show_image::{AsImageView, WindowOptions, event};

#[show_image::main]
fn main() -> anyhow::Result<()> {
Expand Down Expand Up @@ -66,13 +71,10 @@ fn main() -> anyhow::Result<()> {
let window = show_image::context()
.run_function_wait(move |context| -> Result<_, String> {
let mut window = context
.create_window(
"ort + modnet",
WindowOptions {
size: Some([img_width, img_height]),
..WindowOptions::default()
}
)
.create_window("ort + modnet", WindowOptions {
size: Some([img_width, img_height]),
..WindowOptions::default()
})
.map_err(|e| e.to_string())?;
window.set_image("photo", &output.as_image_view().map_err(|e| e.to_string())?);
Ok(window.proxy())
Expand Down
10 changes: 9 additions & 1 deletion examples/custom-ops/examples/custom-ops.rs
@@ -1,5 +1,13 @@
use ndarray::Array2;
use ort::{Kernel, KernelAttributes, KernelContext, Operator, OperatorDomain, OperatorInput, OperatorOutput, Session, TensorElementType};
use ort::{
operator::{
Operator, OperatorDomain,
io::{OperatorInput, OperatorOutput},
kernel::{Kernel, KernelAttributes, KernelContext}
},
session::Session,
tensor::TensorElementType
};

struct CustomOpOne;
struct CustomOpOneKernel;
6 changes: 5 additions & 1 deletion examples/gpt2/examples/gpt2-no-ndarray.rs
Expand Up @@ -4,7 +4,11 @@ use std::{
sync::Arc
};

use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session, inputs};
use ort::{
execution_providers::CUDAExecutionProvider,
inputs,
session::{Session, builder::GraphOptimizationLevel}
};
use rand::Rng;
use tokenizers::Tokenizer;

6 changes: 5 additions & 1 deletion examples/gpt2/examples/gpt2.rs
Expand Up @@ -4,7 +4,11 @@ use std::{
};

use ndarray::{Array1, ArrayViewD, Axis, array, concatenate, s};
use ort::{CUDAExecutionProvider, GraphOptimizationLevel, Session, inputs};
use ort::{
execution_providers::CUDAExecutionProvider,
inputs,
session::{Session, builder::GraphOptimizationLevel}
};
use rand::Rng;
use tokenizers::Tokenizer;

2 changes: 1 addition & 1 deletion examples/model-info/examples/model-info.rs
@@ -1,6 +1,6 @@
use std::{env, process};

use ort::{Session, TensorElementType, ValueType};
use ort::{session::Session, tensor::TensorElementType, value::ValueType};

fn display_element_type(t: TensorElementType) -> &'static str {
match t {
2 changes: 1 addition & 1 deletion examples/modnet/examples/modnet.rs
Expand Up @@ -4,7 +4,7 @@ use std::{ops::Mul, path::Path};

use image::{GenericImageView, ImageBuffer, Rgba, imageops::FilterType};
use ndarray::Array;
use ort::{CUDAExecutionProvider, Session, inputs};
use ort::{execution_providers::CUDAExecutionProvider, inputs, session::Session};
use show_image::{AsImageView, WindowOptions, event};

#[show_image::main]
2 changes: 1 addition & 1 deletion examples/phi-3-vision/src/image_process.rs
Expand Up @@ -18,7 +18,7 @@ pub const NUM_CROPS: usize = 1;
pub const _NUM_IMG_TOKENS: usize = 144;

const OPENAI_CLIP_MEAN: [f32; 3] = [0.48145466, 0.4578275, 0.40821073];
const OPENAI_CLIP_STD: [f32; 3] = [0.26862954, 0.26130258, 0.27577711];
const OPENAI_CLIP_STD: [f32; 3] = [0.26862954, 0.2613026, 0.2757771];

pub struct Phi3VImageProcessor {
num_crops: usize,
33 changes: 13 additions & 20 deletions examples/phi-3-vision/src/main.rs
Expand Up @@ -4,12 +4,12 @@ use std::{path::Path, time::Instant};
use anyhow::Result;
use image::DynamicImage;
use ndarray::{Array, Array2, Array3, Array4, ArrayView, Ix3, Ix4, s};
use ort::{Session, Tensor};
use ort::{session::Session, value::Tensor};
use tokenizers::Tokenizer;

const VISION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-vision.onnx";
const TEXT_EMBEDDING_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text-embedding.onnx";
const GENERATION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text.onnx";
const VISION_MODEL_NAME: &str = "phi-3-v-128k-instruct-vision.onnx";
const TEXT_EMBEDDING_MODEL_NAME: &str = "phi-3-v-128k-instruct-text-embedding.onnx";
const GENERATION_MODEL_NAME: &str = "phi-3-v-128k-instruct-text.onnx";

const MAX_LENGTH: usize = 1000; // max length of the generated text
const EOS_TOKEN_ID: i64 = 32007; // <|end|>
Expand Down Expand Up @@ -37,8 +37,7 @@ fn get_image_embedding(vision_model: &Session, img: &Option<DynamicImage>) -> Re
]?;
let outputs = vision_model.run(model_inputs)?;
let predictions_view: ArrayView<f32, _> = outputs["visual_features"].try_extract_tensor::<f32>()?;
let predictions = predictions_view.into_dimensionality::<Ix3>()?.to_owned();
predictions
predictions_view.into_dimensionality::<Ix3>()?.to_owned()
} else {
Array::zeros((1, 0, 0))
};
Expand Down Expand Up @@ -71,7 +70,7 @@ fn merge_text_and_image_embeddings(
// Insert visual features
combined_embeds
.slice_mut(s![.., image_token_position..(image_token_position + visual_features.shape()[1]), ..])
.assign(&visual_features);
.assign(visual_features);

// Copy the remaining text embeddings
combined_embeds
Expand Down Expand Up @@ -109,13 +108,13 @@ pub async fn generate_text(
text: &str
) -> Result<()> {
let (inputs_embeds, mut attention_mask) = {
let visual_features = get_image_embedding(&vision_model, &image)?;
let prompt = format_chat_template(&image, text);
let visual_features = get_image_embedding(vision_model, image)?;
let prompt = format_chat_template(image, text);
let encoding = tokenizer.encode(prompt, true).map_err(|e| anyhow::anyhow!("Error encoding: {:?}", e))?;

let input_ids: Vec<i64> = encoding.get_ids().iter().map(|&id| id as i64).collect();
let input_ids: Array2<i64> = Array2::from_shape_vec((1, input_ids.len()), input_ids)?;
let mut inputs_embeds: Array3<f32> = get_text_embedding(&text_embedding_model, &input_ids)?;
let mut inputs_embeds: Array3<f32> = get_text_embedding(text_embedding_model, &input_ids)?;

let attention_mask: Vec<i64> = encoding.get_attention_mask().iter().map(|&mask| mask as i64).collect();
let mut attention_mask: Array2<i64> = Array2::from_shape_vec((1, attention_mask.len()), attention_mask)?;
Expand Down Expand Up @@ -190,7 +189,7 @@ pub async fn generate_text(

// Update current_embeds, attention_mask, and past_key_values for the next iteration
let new_token_id = Array2::from_elem((1, 1), next_token_id);
next_inputs_embeds = get_text_embedding(&text_embedding_model, &new_token_id)?;
next_inputs_embeds = get_text_embedding(text_embedding_model, &new_token_id)?;
attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
for i in 0..32 {
past_key_values[i * 2] = model_outputs[format!("present.{}.key", i)]
Expand All @@ -213,15 +212,9 @@ async fn main() -> Result<()> {

let data_dir = Path::new(env!("CARGO_MANIFEST_DIR")).join("data");
let tokenizer = Tokenizer::from_file(data_dir.join("tokenizer.json")).map_err(|e| anyhow::anyhow!("Error loading tokenizer: {:?}", e))?;
let vision_model = Session::builder()?
.with_execution_providers([ort::CPUExecutionProvider::default().build()])?
.commit_from_file(data_dir.join(VISION_MODEL_NAME))?;
let text_embedding_model = Session::builder()?
.with_execution_providers([ort::CPUExecutionProvider::default().build()])?
.commit_from_file(data_dir.join(TEXT_EMBEDDING_MODEL_NAME))?;
let generation_model = Session::builder()?
.with_execution_providers([ort::CPUExecutionProvider::default().build()])?
.commit_from_file(data_dir.join(GENERATION_MODEL_NAME))?;
let vision_model = Session::builder()?.commit_from_file(data_dir.join(VISION_MODEL_NAME))?;
let text_embedding_model = Session::builder()?.commit_from_file(data_dir.join(TEXT_EMBEDDING_MODEL_NAME))?;
let generation_model = Session::builder()?.commit_from_file(data_dir.join(GENERATION_MODEL_NAME))?;

// Generate text from text
let image: Option<DynamicImage> = None;