Coder Social home page Coder Social logo

tch-rs's Introduction

tch-rs

Rust bindings for PyTorch. The goal of the tch crate is to provide some thin wrappers around the C++ PyTorch api (a.k.a. libtorch). It aims at staying as close as possible to the original C++ api. More idiomatic rust bindings could then be developed on top of this. The documentation can be found on docs.rs.

Build Status Latest version Documentation License

The code generation part for the C api on top of libtorch comes from ocaml-torch.

Getting Started

This crate requires the C++ version of PyTorch (libtorch) to be available on your system. You can either install it manually and let the build script know about it via the LIBTORCH environment variable. If not set, the build script will try downloading and extracting a pre-built binary version of libtorch.

Libtorch Manual Install

export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=${LIBTORCH}/lib:$LD_LIBRARY_PATH
  • You should now be able to run some examples, e.g. cargo run --example basics.

Examples

Writing a Simple Neural Network

The following code defines a simple model with one hidden layer.

struct Net {
    fc1: nn::Linear,
    fc2: nn::Linear,
}

impl Net {
    fn new(vs: &mut nn::VarStore) -> Net {
        let fc1 = nn::Linear::new(vs, IMAGE_DIM, HIDDEN_NODES);
        let fc2 = nn::Linear::new(vs, HIDDEN_NODES, LABELS);
        Net { fc1, fc2 }
    }
}

impl nn::Module for Net {
    fn forward(&self, xs: &Tensor) -> Tensor {
        xs.apply(&self.fc1).relu().apply(&self.fc2)
    }
}

This model can be trained on the MNIST dataset by running the following command.

cargo run --example mnist

More details on the training loop can be found in the detailed tutorial.

Using some Pre-Trained Model

The pretrained-models example illustrates how to use some pre-trained computer vision model on an image. The weights - which have been extracted from the PyTorch implementation - can be downloaded here resnet18.ot and here resnet34.ot.

The example can then be run via the following command:

cargo run --example pretrained-models -- resnet18.ot tiger.jpg

This should print the top 5 imagenet categories for the image. The code for this example is pretty simple.

    // First the image is loaded and resized to 224x224.
    let image = imagenet::load_image_and_resize(image_file)?;

    // A variable store is created to hold the model parameters.
    let vs = tch::nn::VarStore::new(tch::Device::Cpu);

    // Then the model is built on this variable store, and the weights are loaded.
    let resnet18 = tch::vision::resnet::resnet18(vs.root(), imagenet::CLASS_COUNT);
    vs.load(weight_file)?;

    // Apply the forward pass of the model to get the logits and convert them
    // to probabilities via a softmax.
    let output = resnet18
        .forward_t(&image.unsqueeze(0), /*train=*/ false)
        .softmax(-1);

    // Finally print the top 5 categories and their associated probabilities.
    for (probability, class) in imagenet::top(&output, 5).iter() {
        println!("{:50} {:5.2}%", class, 100.0 * probability)
    }

Further examples include:

tch-rs's People

Contributors

laurentmazare avatar ehsanmok avatar mwilliammyers avatar beingflo avatar

Watchers

cksac avatar James Cloos avatar  avatar

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.