iamalexkorotin / neuraloptimaltransport

PyTorch implementation of "Neural Optimal Transport" (ICLR 2023 Spotlight)

Home Page: https://openreview.net/forum?id=d8CBRlWNkqH

License: MIT License

Neural Optimal Transport (NOT)

This is the official Python implementation of the ICLR 2023 spotlight paper Neural Optimal Transport (NOT paper on openreview) by Alexander Korotin, Daniil Selikhanovych and Evgeny Burnaev.

The repository contains reproducible PyTorch source code for computing optimal transport (OT) maps and plans for strong and weak transport costs in high dimensions with neural networks. Examples are provided for toy problems (1D, 2D) and for the unpaired image-to-image translation task for various pairs of datasets.

Repository structure

The implementation is GPU-based with multi-GPU support. It was tested with torch==1.9.0 on 1-4 Tesla V100 GPUs.

All experiments are provided as self-explanatory Jupyter notebooks (notebooks/). For convenience, most of the evaluation output is preserved. Auxiliary source code is placed in .py modules (src/).

  • notebooks/NOT_toy_1D.ipynb - toy experiments in 1D (weak costs);
  • notebooks/NOT_toy_2D.ipynb - toy experiments in 2D (weak costs);
  • notebooks/NOT_training_strong.ipynb - unpaired image-to-image translation (one-to-one, strong costs);
  • notebooks/NOT_training_weak.ipynb - unpaired image-to-image translation (one-to-many, weak costs);
  • notebooks/NOT_plots.ipynb - plotting the translation results (pre-trained models are needed);
  • stats/compute_stats.ipynb - pre-compute InceptionV3 statistics to speed up test FID computation;

Setup

To run the notebooks, it is recommended to create a virtual environment using either conda or venv. Once the virtual environment is set up, install the required dependencies by running the following command:

pip install -r requirements.txt

Finally, make sure to install torch and torchvision. It is advisable to install these packages based on your system and CUDA version. Please refer to the official website for detailed installation instructions.
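
After installation, a quick sanity check can confirm that the packages are importable and whether a GPU is visible. The snippet below is a minimal sketch and is not part of the repository: the helper name `check_environment` is hypothetical, and it relies only on the standard library plus the `__version__` and `torch.cuda.is_available()` attributes of the installed packages.

```python
import importlib
import importlib.util


def check_environment(packages=("torch", "torchvision")):
    """Map each package name to its version string, or None if it is not installed."""
    report = {}
    for name in packages:
        if importlib.util.find_spec(name) is None:
            report[name] = None  # package cannot be found on the import path
            continue
        module = importlib.import_module(name)
        report[name] = getattr(module, "__version__", "unknown")
        if name == "torch":
            # Only queried when torch is actually installed.
            report["cuda_available"] = module.cuda.is_available()
    return report
```

Calling `print(check_environment())` in a notebook cell lists the detected versions; a `None` entry means the corresponding package still needs to be installed.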

Educational Materials

Citation

@inproceedings{
    korotin2023neural,
    title={Neural Optimal Transport},
    author={Korotin, Alexander and Selikhanovych, Daniil and Burnaev, Evgeny},
    booktitle={International Conference on Learning Representations},
    year={2023},
    url={https://openreview.net/forum?id=d8CBRlWNkqH}
}

Application to Unpaired Image-to-Image Translation Task

The unpaired domain translation task can be posed as an OT problem. Our NOT algorithm is applicable here. It searches for a transport map with the minimal transport cost (we use $\ell^{2}$), i.e., it naturally aims to preserve certain image attributes during the translation.
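
To make the idea of a minimal-cost transport map concrete, here is a small sketch that is independent of the repository code. It assumes two equal-size 1D samples with uniform weights and the squared-distance cost, a setting in which pairing the sorted points is known to be optimal; a brute-force search over all one-to-one assignments confirms this on a tiny example. All function names are illustrative.

```python
from itertools import permutations


def transport_cost(xs, ys):
    """Average squared distance between paired points xs[i] -> ys[i]."""
    return sum((x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def monotone_map(xs, ys):
    """Send the i-th smallest source point to the i-th smallest target point."""
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    targets = sorted(ys)
    mapping = [None] * len(xs)
    for rank, i in enumerate(order):
        mapping[i] = targets[rank]
    return mapping


def brute_force_cost(xs, ys):
    """Minimal cost over all one-to-one assignments (tiny samples only)."""
    return min(transport_cost(xs, perm) for perm in permutations(ys))


source = [0.9, 0.1, 0.5, 0.3]
target = [1.4, 2.0, 1.2, 1.9]
mapped = monotone_map(source, target)  # image of each source point
print("monotone cost:   ", transport_cost(source, mapped))
print("brute-force cost:", brute_force_cost(source, target))
```

In high dimensions no such closed-form pairing is available, which is where a learned transport map becomes useful.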

Compared to popular image-to-image translation models based on GANs or diffusion models, our method provides the following key advantages:

  • controllable amount of diversity in generated samples (without any duct tape or heuristics);
  • better interpretability of the learned map.

Qualitative examples are shown below for various pairs of datasets (at resolutions $128\times 128$ and $64\times 64$).

One-to-one translation, strong OT

We show unpaired translation with NOT with the strong quadratic cost on outdoor → church, celeba (female) → anime, shoes → handbags, handbags → shoes, male → female, celeba (female) → anime, anime → shoes, anime → celeba (female) dataset pairs.

One-to-many translation, weak OT

We show unpaired translation with NOT with the $\gamma$-weak quadratic cost on handbags → shoes, celeba (female) → anime, outdoor → church, anime → shoes, shoes → handbags, anime → celeba (female) dataset pairs.

Controlling the amount of diversity

Our method offers a single parameter $\gamma\in[0,+\infty)$ in the weak quadratic cost to control the amount of diversity.
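
As an illustration of how $\gamma$ enters, the sketch below estimates a $\gamma$-weak quadratic cost from samples. It assumes the cost has the form $C(x,\mu)=\int\frac{1}{2}\|x-y\|^{2}d\mu(y)-\frac{\gamma}{2}\mathrm{Var}(\mu)$; please check the paper for the exact definition. The function name is illustrative and does not come from src/.

```python
def weak_quadratic_cost(x, ys, gamma):
    """Empirical gamma-weak quadratic cost of sending point x to the samples ys.

    ASSUMED form: mean(0.5 * ||x - y||^2) - 0.5 * gamma * Var(ys),
    where Var(ys) is the sum of per-coordinate (biased) sample variances.
    """
    n = len(ys)
    dim = len(x)
    mean_half_sq_dist = sum(
        0.5 * sum((xi - yi) ** 2 for xi, yi in zip(x, y)) for y in ys
    ) / n
    center = [sum(y[d] for y in ys) / n for d in range(dim)]
    variance = sum(
        sum((y[d] - center[d]) ** 2 for d in range(dim)) for y in ys
    ) / n
    return mean_half_sq_dist - 0.5 * gamma * variance


x = [0.0, 0.0]
collapsed = [[1.0, 1.0], [1.0, 1.0]]  # all outputs identical (no diversity)
diverse = [[0.0, 2.0], [2.0, 0.0]]    # same mean, but spread out

for gamma in (0.0, 2.0):
    print(gamma,
          weak_quadratic_cost(x, collapsed, gamma),
          weak_quadratic_cost(x, diverse, gamma))
```

Under this assumed form, $\gamma=0$ reduces to the plain average quadratic cost, so the collapsed output is cheaper; a larger $\gamma$ subtracts more for variance and makes the diverse output cheaper.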

Datasets

The dataloaders can be created with the load_dataset function from src/tools.py. The latter four datasets get loaded directly to RAM.

Presentations

Related repositories

Credits

Contributors

abdurrahheem, iamalexkorotin

neuraloptimaltransport's Issues

AttributeError: module 'ot' has no attribute 'weak_optimal_transport'

Hi Alexander Korotin,

Thanks for your great work! I am quite interested in your work that utilizes neural networks to learn optimal maps and plans.

When I run the illustration part of the demo NOT_toy_1D.ipynb, the following error emerges:

AttributeError: module 'ot' has no attribute 'weak_optimal_transport'

It seems that the function weak_optimal_transport is missing from the ot module. Could you please help solve the above error? Thank you very much.

How to apply NOT to discrete data?

Hi! First of all, thank you so much for the NOT! I am wondering how to use NOT in an unpaired translation task for discrete data like text tokens. I guess the loss function should be changed, but I have no clue! Would you like to talk about it?
Thank you anyway for making the NOT available!

Best,
Zhangzhi

Missing requirement file

Maybe it would be better to add a requirements.txt file? It is quite annoying to install packages by following the errors from the notebook logs.

Usage

Hello,

I'm trying to apply OT to an image, but I'm not sure how to use this repo. Can someone write a usage description?

Thanks
