mosure / bevy_ort Goto Github PK
View Code? Open in Web Editor NEW
ort (onnxruntime) plugin for bevy
License: GNU Affero General Public License v3.0
ort (onnxruntime) plugin for bevy
License: GNU Affero General Public License v3.0
Line 14 in 0c25186
use bevy::{
prelude::*,
app::AppExit,
};
use ndarray::Array;
use bevy_ort::{
BevyOrtPlugin,
inputs,
Onnx,
};
// TODO: provide an example /w AsyncComputeTask inference
/// Entry point: builds the Bevy app with the ORT plugin, registers the
/// MODNet resource, and wires up the load and inference systems.
fn main() {
    let mut app = App::new();

    app.add_plugins((DefaultPlugins, BevyOrtPlugin));
    app.init_resource::<Modnet>();
    app.add_systems(Startup, load_modnet);
    app.add_systems(Update, inference);

    app.run();
}
/// Resource holding the handle to the MODNet ONNX asset.
#[derive(Resource, Default)]
pub struct Modnet {
    // Populated by `load_modnet` at startup; the asset (and its session)
    // becomes available asynchronously.
    pub onnx: Handle<Onnx>,
}
/// Startup system: kicks off asynchronous loading of the MODNet
/// portrait-matting ONNX model.
fn load_modnet(asset_server: Res<AssetServer>, mut modnet: ResMut<Modnet>) {
    modnet.onnx = asset_server.load("modnet_photographic_portrait_matting.onnx");
}
/// Update system: runs MODNet inference on a zeroed dummy input once the
/// ONNX asset (and its session) is available, then exits the app.
fn inference(
    mut exit: EventWriter<AppExit>,
    modnet: Res<Modnet>,
    onnx_assets: Res<Assets<Onnx>>,
) {
    // Dummy NCHW input: batch=1, 3 channels, 640x640, f32.
    let input = Array::<f32, _>::zeros((1, 3, 640, 640));

    // Use Result instead of Option so failures carry a reason (resolves the
    // previous TODO; matches the error style used by the lightglue example).
    let output = (|| -> Result<_, String> {
        let onnx = onnx_assets.get(&modnet.onnx)
            .ok_or("ONNX asset not loaded yet")?;
        let session = onnx.session.as_ref()
            .ok_or("ONNX asset has no session")?;

        println!("inputs: {:?}", session.inputs);

        // Propagate input-binding errors instead of panicking via `.unwrap()`.
        let input_values = inputs!["image" => input.view()]
            .map_err(|e| e.to_string())?;

        session.run(input_values).map_err(|e| e.to_string())
    })();

    match output {
        Ok(output) => {
            println!("outputs: {:?}", output.keys());
            exit.send(AppExit);
        },
        // Before the asset finishes loading this fires every frame with the
        // "not loaded yet" reason, mirroring the original behavior.
        Err(e) => println!("inference failed: {}", e),
    }
}
Line 46 in 1a18b54
use criterion::{
BenchmarkId,
criterion_group,
criterion_main,
Criterion,
Throughput,
};
use bevy::{
prelude::*,
render::{
render_asset::RenderAssetUsages,
render_resource::{
Extent3d,
TextureDimension,
},
},
};
use bevy_ort::{
inputs,
models::yolo_v8::{
prepare_input,
process_output,
},
Session,
};
use ort::GraphOptimizationLevel;
// Benchmark registration: three YOLOv8 benches sharing a reduced sample
// size (10), since each sample may include full model inference.
criterion_group!{
    name = yolo_v8_benches;
    config = Criterion::default().sample_size(10);
    targets = prepare_input_benchmark,
        process_output_benchmark,
        inference_benchmark,
}
criterion_main!(yolo_v8_benches);

// Source resolutions exercised by every bench (360p, 720p, 1080p).
const RESOLUTIONS: [(u32, u32); 3] = [
    (640, 360),
    (1280, 720),
    (1920, 1080),
];

// TODO: read input shape from session
// Fixed YOLOv8 model input size; frames are letterboxed/resized to this.
const MODEL_WIDTH: u32 = 640;
const MODEL_HEIGHT: u32 = 640;
/// Benchmarks `prepare_input` (resize + tensor conversion) across the
/// common source resolutions.
fn prepare_input_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("yolo_v8_prepare_input");

    for &(width, height) in RESOLUTIONS.iter() {
        // Zeroed RGBA8 frame at the source resolution.
        let data = vec![0u8; (width * height * 4) as usize];
        let image = Image::new(
            Extent3d {
                width,
                height,
                depth_or_array_layers: 1,
            },
            TextureDimension::D2,
            // The buffer is not used again, so no clone is needed.
            data,
            bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
            RenderAssetUsages::all(),
        );

        group.throughput(Throughput::Elements(1));
        group.bench_with_input(
            BenchmarkId::from_parameter(format!("{}x{}", width, height)),
            &image,
            // Fix: use the closure's input parameter (previously bound as a
            // misleading `images` and ignored in favor of the captured outer
            // `image`).
            |b, image| {
                b.iter(|| prepare_input(image, MODEL_WIDTH, MODEL_HEIGHT));
            },
        );
    }

    group.finish();
}
/// Benchmarks `process_output` (detection decoding) in isolation: the model
/// is run once per resolution outside the measured loop.
fn process_output_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("yolo_v8_process_output");

    let session = Session::builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
        .with_model_from_file("assets/yolov8n.onnx").unwrap();

    for &(width, height) in RESOLUTIONS.iter() {
        // Zeroed RGBA8 frame at the source resolution.
        let data = vec![0u8; (width * height * 4) as usize];
        let image = Image::new(
            Extent3d {
                width,
                height,
                depth_or_array_layers: 1,
            },
            TextureDimension::D2,
            // The buffer is not used again, so no clone is needed.
            data,
            bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
            RenderAssetUsages::all(),
        );

        let input = prepare_input(&image, MODEL_WIDTH, MODEL_HEIGHT);
        let input_values = inputs!["images" => &input.as_standard_layout()]
            .expect("failed to build input values");
        // Fix: keep the ort error message instead of stringifying it with
        // `map_err` and then discarding it via `.ok().unwrap()`.
        let outputs = session.run(input_values).expect("inference failed");
        let output_value: &ort::Value = outputs.get("output0").expect("missing output0");

        group.throughput(Throughput::Elements(1));
        group.bench_with_input(
            BenchmarkId::from_parameter(format!("{}x{}", width, height)),
            // Fix: pass the value itself, not a reference to the reference.
            output_value,
            |b, output_value| {
                b.iter(|| process_output(output_value, width, height, MODEL_WIDTH, MODEL_HEIGHT));
            },
        );
    }

    group.finish();
}
/// Benchmarks the full pipeline per frame: input binding, model inference,
/// and output decoding (input preparation is done once, outside the loop).
fn inference_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("yolo_v8_inference");

    let session = Session::builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
        .with_model_from_file("assets/yolov8n.onnx").unwrap();

    for &(width, height) in RESOLUTIONS.iter() {
        // Zeroed RGBA8 frame at the source resolution.
        let data = vec![0u8; width as usize * height as usize * 4];
        let image = Image::new(
            Extent3d {
                width,
                height,
                depth_or_array_layers: 1,
            },
            TextureDimension::D2,
            // The buffer is not used again, so no clone is needed.
            data,
            bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
            RenderAssetUsages::all(),
        );

        let input = prepare_input(&image, MODEL_WIDTH, MODEL_HEIGHT);

        group.throughput(Throughput::Elements(1));
        // The original passed `&(width, height)` as a bench input and then
        // ignored it; `bench_function` keeps the same benchmark IDs without
        // the unused parameter.
        group.bench_function(
            BenchmarkId::from_parameter(format!("{}x{}", width, height)),
            |b| {
                b.iter(|| {
                    let input_values = inputs!["images" => &input.as_standard_layout()]
                        .expect("failed to build input values");
                    // Fix: keep the ort error message instead of stringifying
                    // it and then discarding it via `.ok().unwrap()`.
                    let outputs = session.run(input_values).expect("inference failed");
                    let output_value: &ort::Value = outputs.get("output0").expect("missing output0");
                    process_output(output_value, width, height, MODEL_WIDTH, MODEL_HEIGHT);
                });
            },
        );
    }

    group.finish();
}
Line 81 in 0c25186
use std::io::ErrorKind;
use bevy::{
prelude::*,
asset::{
AssetLoader,
AsyncReadExt,
LoadContext,
io::Reader,
},
utils::BoxedFuture,
};
use ort::{
CoreMLExecutionProvider,
CPUExecutionProvider,
CUDAExecutionProvider,
GraphOptimizationLevel,
OpenVINOExecutionProvider,
};
use thiserror::Error;
pub use ort::{
inputs,
Session,
};
/// Bevy plugin that initializes the global ONNX Runtime environment and
/// registers the `Onnx` asset type and its loader.
pub struct BevyOrtPlugin;

impl Plugin for BevyOrtPlugin {
    fn build(&self, app: &mut App) {
        // TODO: configurable execution providers via plugin settings
        // Providers are tried in listed order; CPU is the final fallback.
        ort::init()
            .with_execution_providers([
                CoreMLExecutionProvider::default().build(),
                CUDAExecutionProvider::default().build(),
                OpenVINOExecutionProvider::default().build(),
                CPUExecutionProvider::default().build(),
            ])
            // NOTE(review): `.ok()` deliberately ignores init failure here;
            // any error will resurface when a session is created later.
            .commit().ok();

        app.init_asset::<Onnx>();
        app.init_asset_loader::<OnnxLoader>();
    }
}
/// Asset wrapping an optional ONNX Runtime session.
#[derive(Asset, Debug, Default, TypePath)]
pub struct Onnx {
    // `None` until the loader successfully builds a session from the model.
    pub session: Option<Session>,
}
/// Errors produced while loading an `Onnx` asset.
#[derive(Debug, Error)]
pub enum BevyOrtError {
    /// Failure reading the asset bytes (or an unsupported extension).
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// Failure constructing the ONNX Runtime session.
    #[error("ort error: {0}")]
    Ort(#[from] ort::Error),
}
/// Stateless asset loader for `.onnx` model files.
#[derive(Default)]
pub struct OnnxLoader;
/// Loads `.onnx` files into `Onnx` assets by building an ort `Session`
/// from the raw model bytes.
impl AssetLoader for OnnxLoader {
    type Asset = Onnx;
    type Settings = ();
    type Error = BevyOrtError;

    fn load<'a>(
        &'a self,
        reader: &'a mut Reader,
        _settings: &'a Self::Settings,
        load_context: &'a mut LoadContext,
    ) -> BoxedFuture<'a, Result<Self::Asset, Self::Error>> {
        Box::pin(async move {
            // Read the whole model file into memory before handing it to ort.
            let mut bytes = Vec::new();
            reader.read_to_end(&mut bytes).await.map_err(BevyOrtError::from)?;

            match load_context.path().extension() {
                Some(ext) if ext == "onnx" => {
                    // TODO: add session configuration
                    let session = Session::builder()?
                        .with_optimization_level(GraphOptimizationLevel::Level3)?
                        .with_intra_threads(4)?
                        .with_model_from_memory(&bytes)?;

                    Ok(Onnx {
                        session: Some(session),
                    })
                },
                // Any other extension is rejected with an I/O error.
                _ => Err(BevyOrtError::Io(std::io::Error::new(ErrorKind::Other, "only .onnx supported"))),
            }
        })
    }

    fn extensions(&self) -> &[&str] {
        &["onnx"]
    }
}
Line 50 in 6b7b48d
use bevy::{prelude::*, render::render_asset::RenderAssetUsages};
use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage};
use ndarray::{Array, Array4, ArrayView4, Axis, s};
/// Converts a MODNet output tensor into one grayscale `Image` per batch
/// entry. Assumes layout (batch, 1, height, width) with f32 matte values.
pub fn modnet_output_to_luma_images(
    output_value: &ort::Value,
) -> Vec<Image> {
    let tensor: ort::Tensor<f32> = output_value.extract_tensor::<f32>().unwrap();
    let data = tensor.view();
    let shape = data.shape();

    // NCHW: shape[2] is height, shape[3] is width.
    let batch_size = shape[0];
    let width = shape[3];
    let height = shape[2];

    // `as_slice` requires the tensor data to be contiguous (standard layout).
    let tensor_data = ArrayView4::from_shape((batch_size, 1, height, width), data.as_slice().unwrap())
        .expect("failed to create ArrayView4 from shape and data");

    let mut images = Vec::new();

    for i in 0..batch_size {
        let mut imgbuf = ImageBuffer::<Luma<u8>, Vec<u8>>::new(width as u32, height as u32);

        for y in 0..height {
            for x in 0..width {
                // Clamp the matte value into [0, 1] and quantize to u8.
                let pixel_value = tensor_data[(i, 0, y, x)];
                let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8;
                imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value]));
            }
        }

        let dyn_img = DynamicImage::ImageLuma8(imgbuf);
        // `false`: the matte is treated as linear data, not sRGB.
        images.push(Image::from_dynamic(dyn_img, false, RenderAssetUsages::all()));
    }

    images
}
/// Converts a batch of Bevy images into a single MODNet input tensor.
///
/// Every image is resized according to MODNet's reference size and stacked
/// along the batch axis. All images must resize to the same dimensions as
/// the first, or the batch assignment will panic.
///
/// # Panics
/// Panics when `images` is empty or the first image is not Rgba8UnormSrgb.
pub fn images_to_modnet_input(
    images: Vec<&Image>,
) -> Array4<f32> {
    // TODO: better error handling
    if images.is_empty() {
        panic!("no images provided");
    }

    // MODNet reference size used to compute per-image scale factors.
    let ref_size = 512;

    let &first_image = images.first().unwrap();
    assert_eq!(first_image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb);

    // Preprocess every image exactly once. (The original resized and
    // converted the first image twice: once as a shape probe, then again
    // inside the batch loop.)
    let processed: Vec<_> = images.iter()
        .map(|&image| {
            let dynamic_image = image.clone().try_into_dynamic().unwrap();
            let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size);
            let resized_image = resize_image(&dynamic_image, x_scale, y_scale);
            image_to_ndarray(&resized_image)
        })
        .collect();

    // Derive the batch shape from the first preprocessed image.
    let single_image_shape = processed[0].dim();
    let batch_shape = (processed.len(), single_image_shape.1, single_image_shape.2, single_image_shape.3);

    let mut aggregate = Array4::<f32>::zeros(batch_shape);
    for (i, image_ndarray) in processed.iter().enumerate() {
        // Drop each image's singleton batch axis and copy it into slot `i`.
        aggregate.slice_mut(s![i, .., .., ..]).assign(&image_ndarray.index_axis(Axis(0), 0));
    }

    aggregate
}
Line 189 in 6231f04
use bevy::{
prelude::*,
window::PrimaryWindow,
};
use bevy_ort::{
BevyOrtPlugin,
models::lightglue::{
GluedPair,
lightglue_inference,
Lightglue,
LightgluePlugin,
},
Onnx,
};
/// Entry point: builds the Bevy app with the ORT and LightGlue plugins,
/// registers the input resource, and wires up the load/inference systems.
fn main() {
    let mut app = App::new();

    app.add_plugins((DefaultPlugins, BevyOrtPlugin, LightgluePlugin));
    app.init_resource::<LightglueInput>();
    app.add_systems(Startup, load_lightglue);
    app.add_systems(Update, inference);

    app.run();
}
/// Resource holding the two images to be matched by LightGlue.
#[derive(Resource, Default)]
pub struct LightglueInput {
    // First image of the pair ("from" side of each match).
    pub a: Handle<Image>,
    // Second image of the pair ("to" side of each match).
    pub b: Handle<Image>,
}
/// Startup system: begins loading the fused LightGlue ONNX model and the
/// two sample images to be matched.
fn load_lightglue(
    asset_server: Res<AssetServer>,
    mut lightglue: ResMut<Lightglue>,
    mut input: ResMut<LightglueInput>,
) {
    lightglue.onnx = asset_server.load("models/disk_lightglue_end2end_fused_cpu.onnx");

    input.a = asset_server.load("images/sacre_coeur1.png");
    input.b = asset_server.load("images/sacre_coeur2.png");
}
/// Update system: once the LightGlue session is available, matches the two
/// input images and builds a 2x2 UI grid — top row shows each image with
/// its matched keypoints overlaid (red = image A, green = image B), bottom
/// row shows the raw images. Runs its successful path at most once.
fn inference(
    mut commands: Commands,
    lightglue: Res<Lightglue>,
    input: Res<LightglueInput>,
    onnx_assets: Res<Assets<Onnx>>,
    images: Res<Assets<Image>>,
    primary_window: Query<&Window, With<PrimaryWindow>>,
    mut complete: Local<bool>,
) {
    // Guard: only perform inference and UI construction once.
    if *complete {
        return;
    }

    let window = primary_window.single();

    // NOTE(review): `expect` panics if the images are not yet loaded —
    // TODO confirm the image assets are guaranteed ready before this runs.
    let images = vec![
        images.get(&input.a).expect("failed to get image asset"),
        images.get(&input.b).expect("failed to get image asset"),
    ];
    let images = images.iter().map(|image| *image).collect::<Vec<_>>();

    // Each entry pairs two image indices with their matched keypoints.
    let glued_pairs: Result<Vec<(usize, usize, Vec<GluedPair>)>, String> = (|| {
        let onnx = onnx_assets.get(&lightglue.onnx).ok_or("failed to get ONNX asset")?;
        let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
        let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?;

        Ok(lightglue_inference(
            session,
            images.as_slice(),
        ))
    })();

    match glued_pairs {
        Ok(glued_pairs) => {
            println!("glued_pairs: {:?}", glued_pairs[0].2.len());

            // Root node: 2x2 CSS-style grid filling the window.
            commands.spawn(NodeBundle {
                style: Style {
                    display: Display::Grid,
                    width: Val::Percent(100.0),
                    height: Val::Percent(100.0),
                    grid_template_columns: RepeatedGridTrack::flex(2, 1.0),
                    grid_template_rows: RepeatedGridTrack::flex(2, 1.0),
                    ..default()
                },
                background_color: BackgroundColor(Color::DARK_GRAY),
                ..default()
            })
            .with_children(|builder| {
                // Top-left cell: image A with red markers at match origins.
                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.a.clone()),
                    ..default()
                })
                .with_children(|builder| {
                    // Scale keypoints from image pixels to the on-screen
                    // quadrant (half the window in each dimension).
                    let image_width = images[0].width() as f32;
                    let image_height = images[0].height() as f32;
                    let display_width = window.width() as f32 / 2.0;
                    let display_height = window.height() as f32 / 2.0;
                    let scale_x = display_width / image_width;
                    let scale_y = display_height / image_height;

                    glued_pairs[0].2.iter().for_each(|pair| {
                        let scaled_x = pair.from_x as f32 * scale_x;
                        let scaled_y = pair.from_y as f32 * scale_y;

                        // 2x2 px red dot at the match location in image A.
                        builder.spawn(NodeBundle {
                            style: Style {
                                position_type: PositionType::Absolute,
                                left: Val::Px(scaled_x),
                                top: Val::Px(scaled_y),
                                width: Val::Px(2.0),
                                height: Val::Px(2.0),
                                ..default()
                            },
                            background_color: Color::rgb(1.0, 0.0, 0.0).into(),
                            ..default()
                        });
                    });
                });

                // Top-right cell: image B with green markers at match targets.
                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.b.clone()),
                    ..default()
                })
                .with_children(|builder| {
                    // Same scaling as above, using image B's dimensions.
                    let image_width = images[1].width() as f32;
                    let image_height = images[1].height() as f32;
                    let display_width = window.width() as f32 / 2.0;
                    let display_height = window.height() as f32 / 2.0;
                    let scale_x = display_width / image_width;
                    let scale_y = display_height / image_height;

                    glued_pairs[0].2.iter().for_each(|pair| {
                        let scaled_x = pair.to_x as f32 * scale_x;
                        let scaled_y = pair.to_y as f32 * scale_y;

                        // 2x2 px green dot at the match location in image B.
                        builder.spawn(NodeBundle {
                            style: Style {
                                position_type: PositionType::Absolute,
                                left: Val::Px(scaled_x),
                                top: Val::Px(scaled_y),
                                width: Val::Px(2.0),
                                height: Val::Px(2.0),
                                ..default()
                            },
                            background_color: Color::rgb(0.0, 1.0, 0.0).into(),
                            ..default()
                        });
                    });
                });

                // Bottom row: the two unannotated images for reference.
                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.a.clone()),
                    ..default()
                });

                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.b.clone()),
                    ..default()
                });

                // TODO: draw lines between keypoints
            });

            commands.spawn(Camera2dBundle::default());

            *complete = true;
        }
        Err(e) => {
            // Expected while the ONNX asset is still loading.
            eprintln!("inference failed: {}", e);
        }
    }
}
Line 63 in 0c25186
use bevy::{
prelude::*,
app::AppExit,
};
use ndarray::Array;
use bevy_ort::{
BevyOrtPlugin,
inputs,
Onnx,
};
// TODO: provide an example /w AsyncComputeTask inference
/// Entry point: builds the Bevy app with the ORT plugin, registers the
/// MODNet resource, and wires up the load and inference systems.
fn main() {
    let mut app = App::new();

    app.add_plugins((DefaultPlugins, BevyOrtPlugin));
    app.init_resource::<Modnet>();
    app.add_systems(Startup, load_modnet);
    app.add_systems(Update, inference);

    app.run();
}
/// Resource holding the handle to the MODNet ONNX asset.
#[derive(Resource, Default)]
pub struct Modnet {
    // Populated by `load_modnet` at startup; the asset (and its session)
    // becomes available asynchronously.
    pub onnx: Handle<Onnx>,
}
/// Startup system: kicks off asynchronous loading of the MODNet
/// portrait-matting ONNX model.
fn load_modnet(asset_server: Res<AssetServer>, mut modnet: ResMut<Modnet>) {
    modnet.onnx = asset_server.load("modnet_photographic_portrait_matting.onnx");
}
/// Update system: runs MODNet inference on a zeroed dummy input once the
/// ONNX asset (and its session) is available, then exits the app.
fn inference(
    mut exit: EventWriter<AppExit>,
    modnet: Res<Modnet>,
    onnx_assets: Res<Assets<Onnx>>,
) {
    // Dummy NCHW input: batch=1, 3 channels, 640x640, f32.
    let input = Array::<f32, _>::zeros((1, 3, 640, 640));

    // Use Result instead of Option so failures carry a reason (resolves the
    // previous TODO; matches the error style used by the lightglue example).
    let output = (|| -> Result<_, String> {
        let onnx = onnx_assets.get(&modnet.onnx)
            .ok_or("ONNX asset not loaded yet")?;
        let session = onnx.session.as_ref()
            .ok_or("ONNX asset has no session")?;

        println!("inputs: {:?}", session.inputs);

        // Propagate input-binding errors instead of panicking via `.unwrap()`.
        let input_values = inputs!["image" => input.view()]
            .map_err(|e| e.to_string())?;

        session.run(input_values).map_err(|e| e.to_string())
    })();

    match output {
        Ok(output) => {
            println!("outputs: {:?}", output.keys());
            exit.send(AppExit);
        },
        // Before the asset finishes loading this fires every frame with the
        // "not loaded yet" reason, mirroring the original behavior.
        Err(e) => println!("inference failed: {}", e),
    }
}
Line 31 in 0c25186
use std::io::ErrorKind;
use bevy::{
prelude::*,
asset::{
AssetLoader,
AsyncReadExt,
LoadContext,
io::Reader,
},
utils::BoxedFuture,
};
use ort::{
CoreMLExecutionProvider,
CPUExecutionProvider,
CUDAExecutionProvider,
GraphOptimizationLevel,
OpenVINOExecutionProvider,
};
use thiserror::Error;
pub use ort::{
inputs,
Session,
};
/// Bevy plugin that initializes the global ONNX Runtime environment and
/// registers the `Onnx` asset type and its loader.
pub struct BevyOrtPlugin;

impl Plugin for BevyOrtPlugin {
    fn build(&self, app: &mut App) {
        // TODO: configurable execution providers via plugin settings
        // Providers are tried in listed order; CPU is the final fallback.
        ort::init()
            .with_execution_providers([
                CoreMLExecutionProvider::default().build(),
                CUDAExecutionProvider::default().build(),
                OpenVINOExecutionProvider::default().build(),
                CPUExecutionProvider::default().build(),
            ])
            // NOTE(review): `.ok()` deliberately ignores init failure here;
            // any error will resurface when a session is created later.
            .commit().ok();

        app.init_asset::<Onnx>();
        app.init_asset_loader::<OnnxLoader>();
    }
}
/// Asset wrapping an optional ONNX Runtime session.
#[derive(Asset, Debug, Default, TypePath)]
pub struct Onnx {
    // `None` until the loader successfully builds a session from the model.
    pub session: Option<Session>,
}
/// Errors produced while loading an `Onnx` asset.
#[derive(Debug, Error)]
pub enum BevyOrtError {
    /// Failure reading the asset bytes (or an unsupported extension).
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    /// Failure constructing the ONNX Runtime session.
    #[error("ort error: {0}")]
    Ort(#[from] ort::Error),
}
/// Stateless asset loader for `.onnx` model files.
#[derive(Default)]
pub struct OnnxLoader;
/// Loads `.onnx` files into `Onnx` assets by building an ort `Session`
/// from the raw model bytes.
impl AssetLoader for OnnxLoader {
    type Asset = Onnx;
    type Settings = ();
    type Error = BevyOrtError;

    fn load<'a>(
        &'a self,
        reader: &'a mut Reader,
        _settings: &'a Self::Settings,
        load_context: &'a mut LoadContext,
    ) -> BoxedFuture<'a, Result<Self::Asset, Self::Error>> {
        Box::pin(async move {
            // Read the whole model file into memory before handing it to ort.
            let mut bytes = Vec::new();
            reader.read_to_end(&mut bytes).await.map_err(BevyOrtError::from)?;

            match load_context.path().extension() {
                Some(ext) if ext == "onnx" => {
                    // TODO: add session configuration
                    let session = Session::builder()?
                        .with_optimization_level(GraphOptimizationLevel::Level3)?
                        .with_intra_threads(4)?
                        .with_model_from_memory(&bytes)?;

                    Ok(Onnx {
                        session: Some(session),
                    })
                },
                // Any other extension is rejected with an I/O error.
                _ => Err(BevyOrtError::Io(std::io::Error::new(ErrorKind::Other, "only .onnx supported"))),
            }
        })
    }

    fn extensions(&self) -> &[&str] {
        &["onnx"]
    }
}
bevy_ort/src/models/yolo_v8.rs
Line 36 in 904f83c
}
/// Bevy plugin that registers the `Yolo` resource.
pub struct YoloPlugin;

impl Plugin for YoloPlugin {
    fn build(&self, app: &mut App) {
        app.init_resource::<Yolo>();
    }
}
/// Resource holding the handle to the YOLOv8 ONNX asset.
#[derive(Resource, Default)]
pub struct Yolo {
    // Set by the application once the model asset has been requested.
    pub onnx: Handle<Onnx>,
}
// TODO: support yolo input batching
/// Runs YOLOv8 inference on `image` and returns NMS-filtered detections.
///
/// The model's expected input size is read from the session's first input.
///
/// # Panics
/// Panics if the model lacks static input dimensions, if input binding or
/// inference fails, or if the model has no `output0` output.
pub fn yolo_inference(
    session: &ort::Session,
    image: &Image,
    iou_threshold: f32,
) -> Vec<BoundingBox> {
    let width = image.width();
    let height = image.height();

    // Query the model input shape once instead of unwrapping it twice.
    let dimensions = session.inputs[0].input_type.tensor_dimensions()
        .expect("model input has no static tensor dimensions");
    // NOTE(review): indexes 2/3 preserved from the original; for an NCHW
    // input these would be (height, width) — harmless for square models,
    // but worth confirming against the exported model.
    let model_width = dimensions[2] as u32;
    let model_height = dimensions[3] as u32;

    let input = prepare_input(image, model_width, model_height);
    let input_values = inputs!["images" => &input.as_standard_layout()]
        .expect("failed to build input values");

    // Fix: keep the ort error message instead of stringifying it with
    // `map_err` and then discarding it via `.ok().unwrap()`.
    let outputs = session.run(input_values).expect("inference failed");
    let output_value: &ort::Value = outputs.get("output0").expect("missing output0");

    let detections = process_output(output_value, width, height, model_width, model_height);
    nms(&detections, iou_threshold)
}
pub fn prepare_input(
image: &Image,
model_width: u32,
A declarative, efficient, and flexible JavaScript library for building user interfaces.
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
An Open Source Machine Learning Framework for Everyone
The Web framework for perfectionists with deadlines.
A PHP framework for web artisans
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
Some thing interesting about web. New door for the world.
A server is a program made to process requests and deliver data to clients.
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
Some thing interesting about visualization, use data art
Some thing interesting about game, make everyone happy.
We are working to build community through open source technology. NB: members must have two-factor auth.
Open source projects and samples from Microsoft.
Google ❤️ Open Source for everyone.
Alibaba Open Source for everyone
Data-Driven Documents codes.
China tencent open source team.