
bevy_ort's People

Contributors: mosure · Stargazers: 2 · Watchers: 1

bevy_ort's Issues

provide an example w/ AsyncComputeTask inference

// TODO: provide an example w/ AsyncComputeTask inference

use bevy::{
    prelude::*,
    app::AppExit,
};
use ndarray::Array;

use bevy_ort::{
    BevyOrtPlugin,
    inputs,
    Onnx,
};


// TODO: provide an example w/ AsyncComputeTask inference
fn main() {
    App::new()
        .add_plugins((
            DefaultPlugins,
            BevyOrtPlugin,
        ))
        .init_resource::<Modnet>()
        .add_systems(Startup, load_modnet)
        .add_systems(Update, inference)
        .run();
}

#[derive(Resource, Default)]
pub struct Modnet {
    pub onnx: Handle<Onnx>,
}

fn load_modnet(
    asset_server: Res<AssetServer>,
    mut modnet: ResMut<Modnet>,
) {
    let modnet_handle: Handle<Onnx> = asset_server.load("modnet_photographic_portrait_matting.onnx");
    modnet.onnx = modnet_handle;
}


fn inference(
    mut exit: EventWriter<AppExit>,
    modnet: Res<Modnet>,
    onnx_assets: Res<Assets<Onnx>>,
) {
    let input = Array::<f32, _>::zeros((1, 3, 640, 640));

    let output = (|| -> Option<_> {
        let onnx = onnx_assets.get(&modnet.onnx)?;
        let session = onnx.session.as_ref()?;

        println!("inputs: {:?}", session.inputs);

        let input_values = inputs!["image" => input.view()].unwrap();
        session.run(input_values).ok()
    })();

    if let Some(output) = output {
        println!("outputs: {:?}", output.keys());

        exit.send(AppExit);
    } else {
        // TODO: use Result instead of Option for error reporting
        println!("inference failed");
    }
}
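
A minimal sketch of the requested AsyncComputeTask variant, offloading inference via Bevy's AsyncComputeTaskPool so the main schedule is not blocked. The InferenceTask component and its output type are illustrative, and it assumes the work moved into the task is 'static (e.g. the input tensor is moved in and the session is reachable from the task):

use bevy::tasks::{AsyncComputeTaskPool, Task};
use futures_lite::future;

// illustrative component holding an in-flight inference
#[derive(Component)]
struct InferenceTask(Task<Vec<f32>>);

fn spawn_inference(mut commands: Commands) {
    let pool = AsyncComputeTaskPool::get();
    let task = pool.spawn(async move {
        // run `session.run(...)` here and extract the output tensor;
        // an empty placeholder stands in for the real output
        Vec::<f32>::new()
    });
    commands.spawn(InferenceTask(task));
}

fn poll_inference(
    mut commands: Commands,
    mut tasks: Query<(Entity, &mut InferenceTask)>,
) {
    for (entity, mut task) in &mut tasks {
        // non-blocking poll; yields `Some` once the task has finished
        if let Some(output) = future::block_on(future::poll_once(&mut task.0)) {
            println!("async inference produced {} values", output.len());
            commands.entity(entity).despawn();
        }
    }
}

Both systems would be registered alongside the existing ones, e.g. spawn_inference in Startup and poll_inference in Update.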

read input shape from session

// TODO: read input shape from session

use criterion::{
    BenchmarkId,
    criterion_group,
    criterion_main,
    Criterion,
    Throughput,
};

use bevy::{
    prelude::*,
    render::{
        render_asset::RenderAssetUsages,
        render_resource::{
            Extent3d,
            TextureDimension,
        },
    },
};
use bevy_ort::{
    inputs,
    models::yolo_v8::{
        prepare_input,
        process_output,
    },
    Session,
};
use ort::GraphOptimizationLevel;


criterion_group!{
    name = yolo_v8_benches;
    config = Criterion::default().sample_size(10);
    targets = prepare_input_benchmark,
        process_output_benchmark,
        inference_benchmark,
}
criterion_main!(yolo_v8_benches);


const RESOLUTIONS: [(u32, u32); 3] = [
    (640, 360),
    (1280, 720),
    (1920, 1080),
];

// TODO: read input shape from session
const MODEL_WIDTH: u32 = 640;
const MODEL_HEIGHT: u32 = 640;
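
// A sketch of reading these dimensions from the session instead, mirroring
// the lookup `yolo_inference` performs later on this page; it assumes the
// first input is an image tensor with static dimensions, indexed in the
// same order the crate uses (dims[2], dims[3]).
fn model_input_size(session: &Session) -> (u32, u32) {
    let dims = session.inputs[0].input_type.tensor_dimensions().unwrap();
    (dims[2] as u32, dims[3] as u32)
}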


fn prepare_input_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("yolo_v8_prepare_input");

    RESOLUTIONS.iter()
        .for_each(|(width, height)| {
            let data = vec![0u8; (width * height * 4) as usize];
            let image = Image::new(
                Extent3d {
                    width: *width,
                    height: *height,
                    depth_or_array_layers: 1,
                },
                TextureDimension::D2,
                data.clone(),
                bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
                RenderAssetUsages::all(),
            );

            group.throughput(Throughput::Elements(1));
            group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &image, |b, image| {
                b.iter(|| prepare_input(image, MODEL_WIDTH, MODEL_HEIGHT));
            });
        });
}


fn process_output_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("yolo_v8_process_output");

    let session = Session::builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
        .with_model_from_file("assets/yolov8n.onnx").unwrap();

    RESOLUTIONS.iter()
        .for_each(|(width, height)| {
            let data = vec![0u8; (width * height * 4) as usize];
            let image: Image = Image::new(
                Extent3d {
                    width: *width,
                    height: *height,
                    depth_or_array_layers: 1,
                },
                TextureDimension::D2,
                data.clone(),
                bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
                RenderAssetUsages::all(),
            );

            let input = prepare_input(&image, MODEL_WIDTH, MODEL_HEIGHT);
            let input_values = inputs!["images" => &input.as_standard_layout()].unwrap();

            let outputs = session.run(input_values).expect("failed to run inference");
            let output_value: &ort::Value = outputs.get("output0").unwrap();

            group.throughput(Throughput::Elements(1));
            group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &output_value, |b, output_value| {
                b.iter(|| process_output(output_value, *width, *height, MODEL_WIDTH, MODEL_HEIGHT));
            });
        });
}


fn inference_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("yolo_v8_inference");

    let session = Session::builder().unwrap()
        .with_optimization_level(GraphOptimizationLevel::Level3).unwrap()
        .with_model_from_file("assets/yolov8n.onnx").unwrap();

    RESOLUTIONS.iter().for_each(|(width, height)| {
        let data = vec![0u8; *width as usize * *height as usize * 4];
        let image = Image::new(
            Extent3d {
                width: *width,
                height: *height,
                depth_or_array_layers: 1,
            },
            TextureDimension::D2,
            data.clone(),
            bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb,
            RenderAssetUsages::all(),
        );

        let input = prepare_input(&image, MODEL_WIDTH, MODEL_HEIGHT);

        group.throughput(Throughput::Elements(1));
        group.bench_with_input(BenchmarkId::from_parameter(format!("{}x{}", width, height)), &(width, height), |b, _| {
            b.iter(|| {
                let input_values = inputs!["images" => &input.as_standard_layout()].unwrap();
                let outputs = session.run(input_values).expect("failed to run inference");
                let output_value: &ort::Value = outputs.get("output0").unwrap();
                process_output(output_value, *width, *height, MODEL_WIDTH, MODEL_HEIGHT);
            });
        });
    });
}

add session configuration

// TODO: add session configuration

use std::io::ErrorKind;

use bevy::{
    prelude::*,
    asset::{
        AssetLoader,
        AsyncReadExt,
        LoadContext,
        io::Reader,
    },
    utils::BoxedFuture,
};
use ort::{
    CoreMLExecutionProvider,
    CPUExecutionProvider,
    CUDAExecutionProvider,
    GraphOptimizationLevel,
    OpenVINOExecutionProvider,
};
use thiserror::Error;

pub use ort::{
    inputs,
    Session,
};


pub struct BevyOrtPlugin;
impl Plugin for BevyOrtPlugin {
    fn build(&self, app: &mut App) {
        // TODO: configurable execution providers via plugin settings
        ort::init()
            .with_execution_providers([
                CoreMLExecutionProvider::default().build(),
                CUDAExecutionProvider::default().build(),
                OpenVINOExecutionProvider::default().build(),
                CPUExecutionProvider::default().build(),
            ])
            .commit().ok();

        app.init_asset::<Onnx>();
        app.init_asset_loader::<OnnxLoader>();
    }
}


#[derive(Asset, Debug, Default, TypePath)]
pub struct Onnx {
    pub session: Option<Session>,
}


#[derive(Debug, Error)]
pub enum BevyOrtError {
    #[error("io error: {0}")]
    Io(#[from] std::io::Error),
    #[error("ort error: {0}")]
    Ort(#[from] ort::Error),
}


#[derive(Default)]
pub struct OnnxLoader;
impl AssetLoader for OnnxLoader {
    type Asset = Onnx;
    type Settings = ();
    type Error = BevyOrtError;

    fn load<'a>(
        &'a self,
        reader: &'a mut Reader,
        _settings: &'a Self::Settings,
        load_context: &'a mut LoadContext,
    ) -> BoxedFuture<'a, Result<Self::Asset, Self::Error>> {
        Box::pin(async move {
            let mut bytes = Vec::new();
            reader.read_to_end(&mut bytes).await.map_err(BevyOrtError::from)?;

            match load_context.path().extension() {
                Some(ext) if ext == "onnx" => {
                    // TODO: add session configuration
                    let session = Session::builder()?
                        .with_optimization_level(GraphOptimizationLevel::Level3)?
                        .with_intra_threads(4)?
                        .with_model_from_memory(&bytes)?;

                    Ok(Onnx {
                        session: Some(session),
                    })
                },
                _ => Err(BevyOrtError::Io(std::io::Error::new(ErrorKind::Other, "only .onnx supported"))),
            }
        })
    }

    fn extensions(&self) -> &[&str] {
        &["onnx"]
    }
}
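
One direction the session-configuration TODO could take is threading options through the loader's Settings associated type, which Bevy exposes per asset via load_with_settings. A minimal sketch, assuming only the intra-op thread count is made configurable; OnnxLoaderSettings and its field are illustrative, not part of bevy_ort:

#[derive(serde::Serialize, serde::Deserialize)]
pub struct OnnxLoaderSettings {
    pub intra_threads: usize,
}

impl Default for OnnxLoaderSettings {
    fn default() -> Self {
        Self { intra_threads: 4 }
    }
}

// in `OnnxLoader`, replace `type Settings = ();` with the struct above and
// read it inside `load`:
//     .with_intra_threads(settings.intra_threads)?
//
// callers can then override it per asset:
//     asset_server.load_with_settings(
//         "modnet_photographic_portrait_matting.onnx",
//         |settings: &mut OnnxLoaderSettings| settings.intra_threads = 8,
//     );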

better error handling

// TODO: better error handling

use bevy::{prelude::*, render::render_asset::RenderAssetUsages};
use image::{DynamicImage, GenericImageView, imageops::FilterType, ImageBuffer, Luma, RgbImage};
use ndarray::{Array, Array4, ArrayView4, Axis, s};


pub fn modnet_output_to_luma_images(
    output_value: &ort::Value,
) -> Vec<Image> {
    let tensor: ort::Tensor<f32> = output_value.extract_tensor::<f32>().unwrap();

    let data = tensor.view();

    let shape = data.shape();
    let batch_size = shape[0];
    let width = shape[3];
    let height = shape[2];

    let tensor_data = ArrayView4::from_shape((batch_size, 1, height, width), data.as_slice().unwrap())
        .expect("failed to create ArrayView4 from shape and data");

    let mut images = Vec::new();

    for i in 0..batch_size {
        let mut imgbuf = ImageBuffer::<Luma<u8>, Vec<u8>>::new(width as u32, height as u32);

        for y in 0..height {
            for x in 0..width {
                let pixel_value = tensor_data[(i, 0, y, x)];
                let pixel_value = (pixel_value.clamp(0.0, 1.0) * 255.0) as u8;
                imgbuf.put_pixel(x as u32, y as u32, Luma([pixel_value]));
            }
        }

        let dyn_img = DynamicImage::ImageLuma8(imgbuf);

        images.push(Image::from_dynamic(dyn_img, false, RenderAssetUsages::all()));
    }

    images
}

pub fn images_to_modnet_input(
    images: Vec<&Image>,
) -> Array4<f32> {
    // TODO: better error handling
    if images.is_empty() {
        panic!("no images provided");
    }

    let ref_size = 512;

    let &first_image = images.first().unwrap();
    assert_eq!(first_image.texture_descriptor.format, bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb);

    let dynamic_image = first_image.clone().try_into_dynamic().unwrap();
    let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size);
    let resized_image = resize_image(&dynamic_image, x_scale, y_scale);
    let first_image_ndarray = image_to_ndarray(&resized_image);
    let single_image_shape = first_image_ndarray.dim();
    let n_images = images.len();
    let batch_shape = (n_images, single_image_shape.1, single_image_shape.2, single_image_shape.3);

    let mut aggregate = Array4::<f32>::zeros(batch_shape);

    for (i, &image) in images.iter().enumerate() {
        let dynamic_image = image.clone().try_into_dynamic().unwrap();
        let (x_scale, y_scale) = get_scale_factor(dynamic_image.height(), dynamic_image.width(), ref_size);
        let resized_image = resize_image(&dynamic_image, x_scale, y_scale);
        let image_ndarray = image_to_ndarray(&resized_image);

        let slice = s![i, .., .., ..];
        aggregate.slice_mut(slice).assign(&image_ndarray.index_axis_move(Axis(0), 0));
    }

    aggregate
}
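
A sketch of the fallible variant the error-handling TODO points at, returning Result instead of panicking on bad input; the ModnetError enum and the try_ function name are illustrative:

#[derive(Debug, thiserror::Error)]
pub enum ModnetError {
    #[error("no images provided")]
    EmptyBatch,
    #[error("unsupported texture format")]
    UnsupportedFormat,
}

pub fn try_images_to_modnet_input(
    images: Vec<&Image>,
) -> Result<Array4<f32>, ModnetError> {
    let &first_image = images.first().ok_or(ModnetError::EmptyBatch)?;

    if first_image.texture_descriptor.format != bevy::render::render_resource::TextureFormat::Rgba8UnormSrgb {
        return Err(ModnetError::UnsupportedFormat);
    }

    // inputs validated; delegate to the panicking version above
    Ok(images_to_modnet_input(images))
}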

draw lines between keypoints

// TODO: draw lines between keypoints

use bevy::{
    prelude::*,
    window::PrimaryWindow,
};

use bevy_ort::{
    BevyOrtPlugin,
    models::lightglue::{
        GluedPair,
        lightglue_inference,
        Lightglue,
        LightgluePlugin,
    },
    Onnx,
};


fn main() {
    App::new()
        .add_plugins((
            DefaultPlugins,
            BevyOrtPlugin,
            LightgluePlugin,
        ))
        .init_resource::<LightglueInput>()
        .add_systems(Startup, load_lightglue)
        .add_systems(Update, inference)
        .run();
}


#[derive(Resource, Default)]
pub struct LightglueInput {
    pub a: Handle<Image>,
    pub b: Handle<Image>,
}


fn load_lightglue(
    asset_server: Res<AssetServer>,
    mut lightglue: ResMut<Lightglue>,
    mut input: ResMut<LightglueInput>,
) {
    let lightglue_handle: Handle<Onnx> = asset_server.load("models/disk_lightglue_end2end_fused_cpu.onnx");
    lightglue.onnx = lightglue_handle;

    input.a = asset_server.load("images/sacre_coeur1.png");
    input.b = asset_server.load("images/sacre_coeur2.png");
}


fn inference(
    mut commands: Commands,
    lightglue: Res<Lightglue>,
    input: Res<LightglueInput>,
    onnx_assets: Res<Assets<Onnx>>,
    images: Res<Assets<Image>>,
    primary_window: Query<&Window, With<PrimaryWindow>>,
    mut complete: Local<bool>,
) {
    if *complete {
        return;
    }

    let window = primary_window.single();

    let images = vec![
        images.get(&input.a).expect("failed to get image asset"),
        images.get(&input.b).expect("failed to get image asset"),
    ];

    let glued_pairs: Result<Vec<(usize, usize, Vec<GluedPair>)>, String> = (|| {
        let onnx = onnx_assets.get(&lightglue.onnx).ok_or("failed to get ONNX asset")?;
        let session_lock = onnx.session.lock().map_err(|e| e.to_string())?;
        let session = session_lock.as_ref().ok_or("failed to get session from ONNX asset")?;

        Ok(lightglue_inference(
            session,
            images.as_slice(),
        ))
    })();

    match glued_pairs {
        Ok(glued_pairs) => {
            println!("glued_pairs: {:?}", glued_pairs[0].2.len());

            commands.spawn(NodeBundle {
                style: Style {
                    display: Display::Grid,
                    width: Val::Percent(100.0),
                    height: Val::Percent(100.0),
                    grid_template_columns: RepeatedGridTrack::flex(2, 1.0),
                    grid_template_rows: RepeatedGridTrack::flex(2, 1.0),
                    ..default()
                },
                background_color: BackgroundColor(Color::DARK_GRAY),
                ..default()
            })
            .with_children(|builder| {
                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.a.clone()),
                    ..default()
                })
                .with_children(|builder| {
                    let image_width = images[0].width() as f32;
                    let image_height = images[0].height() as f32;

                    let display_width = window.width() / 2.0;
                    let display_height = window.height() / 2.0;

                    let scale_x = display_width / image_width;
                    let scale_y = display_height / image_height;

                    glued_pairs[0].2.iter().for_each(|pair| {
                        let scaled_x = pair.from_x as f32 * scale_x;
                        let scaled_y = pair.from_y as f32 * scale_y;

                        builder.spawn(NodeBundle {
                            style: Style {
                                position_type: PositionType::Absolute,
                                left: Val::Px(scaled_x),
                                top: Val::Px(scaled_y),
                                width: Val::Px(2.0),
                                height: Val::Px(2.0),
                                ..default()
                            },
                            background_color: Color::rgb(1.0, 0.0, 0.0).into(),
                            ..default()
                        });
                    });
                });

                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.b.clone()),
                    ..default()
                })
                .with_children(|builder| {
                    let image_width = images[1].width() as f32;
                    let image_height = images[1].height() as f32;

                    let display_width = window.width() / 2.0;
                    let display_height = window.height() / 2.0;

                    let scale_x = display_width / image_width;
                    let scale_y = display_height / image_height;

                    glued_pairs[0].2.iter().for_each(|pair| {
                        let scaled_x = pair.to_x as f32 * scale_x;
                        let scaled_y = pair.to_y as f32 * scale_y;

                        builder.spawn(NodeBundle {
                            style: Style {
                                position_type: PositionType::Absolute,
                                left: Val::Px(scaled_x),
                                top: Val::Px(scaled_y),
                                width: Val::Px(2.0),
                                height: Val::Px(2.0),
                                ..default()
                            },
                            background_color: Color::rgb(0.0, 1.0, 0.0).into(),
                            ..default()
                        });
                    });
                });

                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.a.clone()),
                    ..default()
                });

                builder.spawn(ImageBundle {
                    style: Style {
                        ..default()
                    },
                    image: UiImage::new(input.b.clone()),
                    ..default()
                });

                // TODO: draw lines between keypoints
            });

            commands.spawn(Camera2dBundle::default());

            *complete = true;
        }
        Err(e) => {
            eprintln!("inference failed: {}", e);
        }
    }
}
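
For the line-drawing TODO itself, Bevy's immediate-mode Gizmos API is one straightforward option: collect matched endpoint pairs and draw a segment per pair each frame. A minimal sketch; the MatchLines resource is illustrative, and it assumes the coordinates have already been mapped into world space under a 2D camera (gizmos do not draw in UI space):

#[derive(Resource, Default)]
struct MatchLines(Vec<(Vec2, Vec2)>);

// gizmos are cleared every frame, so this system redraws all segments
fn draw_match_lines(matches: Res<MatchLines>, mut gizmos: Gizmos) {
    for (from, to) in matches.0.iter() {
        gizmos.line_2d(*from, *to, Color::YELLOW);
    }
}

MatchLines would be populated from glued_pairs using the same scale factors the keypoint markers above already compute.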

use Result instead of Option for error reporting

// TODO: use Result instead of Option for error reporting

(The attached code context is the same MODNet example shown under "provide an example w/ AsyncComputeTask inference" above; this TODO refers to the Option-returning closure in its inference system.)
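
A sketch of the Result-returning closure this issue asks for, mirroring the pattern the lightglue example above already uses; the error strings are illustrative:

let output = (|| -> Result<_, String> {
    let onnx = onnx_assets.get(&modnet.onnx).ok_or("failed to get ONNX asset")?;
    let session = onnx.session.as_ref().ok_or("session not yet loaded")?;

    let input_values = inputs!["image" => input.view()].map_err(|e| e.to_string())?;
    session.run(input_values).map_err(|e| e.to_string())
})();

match output {
    Ok(output) => {
        println!("outputs: {:?}", output.keys());
        exit.send(AppExit);
    },
    Err(e) => eprintln!("inference failed: {}", e),
}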

configurable execution providers via plugin settings

// TODO: configurable execution providers via plugin settings

(The attached code context is the same plugin/loader source shown under "add session configuration" above; this TODO refers to the hard-coded execution provider list in BevyOrtPlugin::build.)
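
A sketch of plugin-level provider configuration, passing the list in through the plugin struct instead of hard-coding it in build; the execution_providers field is illustrative, not the crate's API, and it assumes ort::ExecutionProviderDispatch implements Clone:

#[derive(Default)]
pub struct BevyOrtPlugin {
    pub execution_providers: Vec<ort::ExecutionProviderDispatch>,
}

impl Plugin for BevyOrtPlugin {
    fn build(&self, app: &mut App) {
        ort::init()
            // clone is assumed; providers fall back in declaration order
            .with_execution_providers(self.execution_providers.clone())
            .commit().ok();

        app.init_asset::<Onnx>();
        app.init_asset_loader::<OnnxLoader>();
    }
}

// callers choose providers at registration time:
//     app.add_plugins(BevyOrtPlugin {
//         execution_providers: vec![
//             CUDAExecutionProvider::default().build(),
//             CPUExecutionProvider::default().build(),
//         ],
//     });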

support yolo input batching

// TODO: support yolo input batching



pub struct YoloPlugin;
impl Plugin for YoloPlugin {
    fn build(&self, app: &mut App) {
        app.init_resource::<Yolo>();
    }
}

#[derive(Resource, Default)]
pub struct Yolo {
    pub onnx: Handle<Onnx>,
}


// TODO: support yolo input batching
pub fn yolo_inference(
    session: &ort::Session,
    image: &Image,
    iou_threshold: f32,
) -> Vec<BoundingBox> {
    let width = image.width();
    let height = image.height();

    let model_width = session.inputs[0].input_type.tensor_dimensions().unwrap()[2] as u32;
    let model_height = session.inputs[0].input_type.tensor_dimensions().unwrap()[3] as u32;

    let input = prepare_input(image, model_width, model_height);

    let input_values = inputs!["images" => &input.as_standard_layout()].unwrap();
    let outputs = session.run(input_values).expect("failed to run inference");
    let output_value: &ort::Value = outputs.get("output0").unwrap();

    let detections = process_output(output_value, width, height, model_width, model_height);

    nms(&detections, iou_threshold)
}


pub fn prepare_input(
    image: &Image,
    model_width: u32,
