This project forked from swyoon/normalized-autoencoders

The official repository for <Autoencoding Under Normalization Constraints> (Yoon, Noh and Park, ICML 2021).

License: MIT License

Autoencoding Under Normalization Constraints

The paper proposes Normalized Autoencoder (NAE), which is a novel energy-based model where the energy function is the reconstruction error. NAE effectively remedies outlier reconstruction, a pathological phenomenon limiting the performance of an autoencoder as an outlier detector.
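To make the idea of "energy as reconstruction error" concrete, here is a minimal sketch. It is not the repository's implementation: it uses a toy linear autoencoder in NumPy, and the weights, the temperature parameter, and the function names are all illustrative assumptions. The per-sample squared reconstruction error plays the role of the energy E(x), and exp(-E(x)/T) is the corresponding unnormalized density.

```python
import numpy as np

def energy(x, w_enc, w_dec):
    """Per-sample energy: squared reconstruction error of a toy linear autoencoder."""
    z = x @ w_enc            # encode: (n, d) -> (n, k)
    x_hat = z @ w_dec        # decode: (n, k) -> (n, d)
    return ((x - x_hat) ** 2).sum(axis=1)

def unnormalized_density(x, w_enc, w_dec, temperature=1.0):
    """exp(-E(x)/T); the normalization constant is deliberately left out."""
    return np.exp(-energy(x, w_enc, w_dec) / temperature)

# Toy weights (assumed): keep the first coordinate, discard the second.
w_enc = np.array([[1.0], [0.0]])
w_dec = np.array([[1.0, 0.0]])

x = np.array([[2.0, 0.0],    # reconstructed exactly by the toy autoencoder
              [2.0, 3.0]])   # second coordinate is lost in reconstruction
e = energy(x, w_enc, w_dec)
p = unnormalized_density(x, w_enc, w_dec)
```

In this toy setting the first point has energy 0 and the second has energy 9, so the second point receives a much smaller unnormalized density.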

Arxiv: https://arxiv.org/abs/2105.05735
5-min video: https://www.youtube.com/watch?v=ra6usGKnPGk

MNIST-figure

Progress

  • Unit tests (tests/)
  • Training script (train.py)
  • OOD detection performance script (evaluate_ood.py)
  • Sampling script (sample.py)
  • Pretrained models for MNIST, CIFAR-10, CelebA64
  • 2D Experiments

Updates to the repository and releases of other materials will be announced through the mailing list. To stay informed, please sign up for the mailing list.

Requirements

Environment

The project is developed under a standard PyTorch environment.

  • python 3.7.2
  • numpy
  • pillow
  • pytorch 1.7.1
  • CUDA 10.1
  • scikit-learn 0.24.2
  • tensorboard 2.5.0
  • pytest 6.2.3
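As a quick way to compare a local environment against the list above, the following sketch reports the installed version of each Python package. It is not part of the repository; the helper names are hypothetical, and the import names (for example `PIL` for pillow and `torch` for pytorch) are the usual ones for these packages.

```python
import importlib

# Import names differ from package names for some entries (e.g. pillow -> PIL).
PACKAGES = {
    "numpy": "numpy",
    "pillow": "PIL",
    "pytorch": "torch",
    "scikit-learn": "sklearn",
    "tensorboard": "tensorboard",
    "pytest": "pytest",
}

def installed_version(module_name):
    """Return the module's __version__ string, or None if it cannot be imported."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return getattr(module, "__version__", None)

def report(packages=PACKAGES):
    """Map each package name to its installed version (None if missing)."""
    return {name: installed_version(mod) for name, mod in packages.items()}
```

Calling `report()` returns a dictionary such as `{"numpy": "<version>", ...}`, with `None` for any package that is not installed.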

Datasets

All datasets are stored in the datasets/ directory.

  • MNIST, CIFAR-10, SVHN, Omniglot: Retrieved using torchvision.datasets.
  • Noise, Constant, ConstantGray : Dropbox link
  • CelebA, ImageNet 32x32: Retrieved from their official sites. The website for ImageNet 32x32 was not available as of June 24, 2021, so I will temporarily upload the data to the Dropbox link above.

When set up, the dataset directory should look as follows.

datasets
├── CelebA
│   ├── Anno
│   ├── Eval
│   └── Img
├── cifar-10-batches-py
├── const_img_gray.npy
├── const_img.npy
├── FashionMNIST
├── ImageNet32
│   ├── train_32x32
│   └── valid_32x32
├── MNIST
├── noise_img.npy
├── omniglot-py
│   ├── images_background
│   └── images_evaluation
├── test_32x32.mat
└── train_32x32.mat
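Before running the tests, it can help to confirm that the top-level entries of the tree above are in place. The following is a small sketch using only the standard library; the function name and the default root path are assumptions, and it checks only the top level, not the subdirectories.

```python
from pathlib import Path

# Top-level entries taken from the directory tree above.
EXPECTED_ENTRIES = [
    "CelebA", "cifar-10-batches-py", "const_img_gray.npy", "const_img.npy",
    "FashionMNIST", "ImageNet32", "MNIST", "noise_img.npy", "omniglot-py",
    "test_32x32.mat", "train_32x32.mat",
]

def missing_entries(root="datasets", expected=EXPECTED_ENTRIES):
    """Return the expected files or directories that are absent under root."""
    root = Path(root)
    return [name for name in expected if not (root / name).exists()]
```

An empty list from `missing_entries()` means every expected top-level file and directory was found.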

Pre-trained Models

Pre-trained models are stored under pretrained/. The pre-trained models are provided through the Dropbox link.

If the pretrained models are prepared successfully, the directory structure should look like the following.

pretrained
├── celeba64_ood_nae
│   └── z64gr_h32g8
├── cifar_ood_nae
│   └── z32gn
└── mnist_ood_nae
    └── z32

Unit Testing

pytest is used for unit testing.

pytest tests

The code should pass all tests after the preparation of pre-trained models and datasets.

Execution

OOD Detection Evaluation

python evaluate_ood.py --ood ConstantGray_OOD,FashionMNIST_OOD,SVHN_OOD,CelebA_OOD,Noise_OOD --resultdir pretrained/cifar_ood_nae/z32gn/ --ckpt nae_9.pkl --config z32gn.yml --device 0 --dataset CIFAR10_OOD

Expected Results

OOD Detection Results in AUC
ConstantGray_OOD:0.9632
FashionMNIST_OOD:0.8193
SVHN_OOD:0.9196
CelebA_OOD:0.8873
Noise_OOD:1.0000
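For readers unfamiliar with the metric, the sketch below shows one way an AUC value of this kind can be computed from two sets of scores. It is an illustration, not the code in evaluate_ood.py: it assumes each sample has a scalar score (for example, a reconstruction error) where a higher value means "more likely OOD", and it computes the AUC as the fraction of (inlier, outlier) pairs that are ranked correctly. The function name and the toy scores are made up for the example.

```python
def auc_from_scores(inlier_scores, outlier_scores):
    """AUC as the probability that an outlier scores higher than an inlier.

    Ties count as one half. The double loop is O(n*m), fine for illustration.
    """
    if not inlier_scores or not outlier_scores:
        raise ValueError("both score lists must be non-empty")
    wins = 0.0
    for out in outlier_scores:
        for inl in inlier_scores:
            if out > inl:
                wins += 1.0
            elif out == inl:
                wins += 0.5
    return wins / (len(inlier_scores) * len(outlier_scores))

inliers = [0.1, 0.2, 0.3]    # toy scores for in-distribution samples
outliers = [0.25, 0.9]       # toy scores for OOD samples
auc = auc_from_scores(inliers, outliers)
```

With these toy scores, 5 of the 6 pairs are ranked correctly, so the AUC is 5/6. An AUC of 1.0 means every OOD sample scores higher than every in-distribution sample. For real evaluations, `sklearn.metrics.roc_auc_score` computes the same quantity far more efficiently.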

Training

Use train.py to train NAE.

  • --config option specifies a path to a configuration yaml file.
  • --logdir specifies a directory where results files will be written.
  • --run specifies an id for each run, i.e., an experiment.
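The options above can be pictured as a small command-line interface. The following is a hypothetical `argparse` sketch that mirrors them; it is not the parser in train.py, and the defaults and help strings are assumptions made for illustration. The `--device` option is included because it appears in the example commands in this README.

```python
import argparse

def build_parser():
    """Hypothetical parser mirroring the options described above."""
    parser = argparse.ArgumentParser(description="Train NAE (illustrative sketch)")
    parser.add_argument("--config", required=True,
                        help="path to a YAML configuration file")
    parser.add_argument("--logdir", default="results/",
                        help="directory where result files are written")
    parser.add_argument("--run", default="run",
                        help="identifier for this run (experiment)")
    parser.add_argument("--device", default="0",
                        help="device index, as used in the example commands")
    return parser

# Parse an explicit argument list instead of sys.argv so the sketch is self-contained.
args = build_parser().parse_args(
    ["--config", "configs/mnist_ood_nae/z32.yml",
     "--logdir", "results/mnist_ood_nae/", "--run", "run", "--device", "0"]
)
```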

Training on MNIST

python train.py --config configs/mnist_ood_nae/z32.yml --logdir results/mnist_ood_nae/ --run run --device 0

Training on CIFAR-10

python train.py --config configs/cifar_ood_nae/z32gn.yml --logdir results/cifar_ood_nae/ --run run --device 0

Training on CelebA 64x64

python train.py --config configs/celeba64_ood_nae/z64gr_h32g8.yml --logdir results/celeba64_ood_nae/ --run run --device 0

Sampling

Use sample.py to generate sample images from NAE. Samples are saved as a .npy file containing an (n_sample, img_h, img_w, channels) array. Note that the quality of generated images is not supposed to match that of state-of-the-art generative models. Improving the sample quality is an important direction for future research.
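Once a sample file exists, the array can be tiled into a single image for inspection. The sketch below is not part of the repository: `make_grid` is a hypothetical helper, and a random array stands in for the result of `np.load(...)` because the exact output path depends on the run.

```python
import numpy as np

def make_grid(samples, n_cols=8):
    """Tile an (n_sample, img_h, img_w, channels) array into one (H, W, channels) image."""
    n, h, w, c = samples.shape
    n_rows = -(-n // n_cols)  # ceiling division
    grid = np.zeros((n_rows * h, n_cols * w, c), dtype=samples.dtype)
    for i, img in enumerate(samples):
        row, col = divmod(i, n_cols)
        grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = img
    return grid

# Stand-in for np.load("<path to the saved .npy file>"); the real path depends on the run.
samples = np.random.rand(64, 32, 32, 3).astype(np.float32)
grid = make_grid(samples, n_cols=8)
```

For 64 samples of size 32x32 with 3 channels and 8 columns, the resulting grid has shape (256, 256, 3). Unused cells are left as zeros when the number of samples is not a multiple of `n_cols`.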

Sampling for CIFAR-10

python sample.py pretrained/cifar_ood_nae/z32gn/ z32gn.yml nae_8.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0

Sampling for CelebA 64x64

python sample.py pretrained/celeba64_ood_nae/z64gr_h32g8/ z64gr_h32g8.yml nae_3.pkl --zstep 180 --xstep 40 --batch_size 64 --n_sample 64 --name run --device 0 --x_shape 64

Sample images for CIFAR-10 and CelebA 64x64

cifar10samples

celeba64samples

Citation

@InProceedings{pmlr-v139-yoon21c,
  title = 	 {Autoencoding Under Normalization Constraints},
  author =       {Yoon, Sangwoong and Noh, Yung-Kyun and Park, Frank},
  booktitle = 	 {Proceedings of the 38th International Conference on Machine Learning},
  pages = 	 {12087--12097},
  year = 	 {2021},
  editor = 	 {Meila, Marina and Zhang, Tong},
  volume = 	 {139},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {18--24 Jul},
  publisher =    {PMLR},
  pdf = 	 {http://proceedings.mlr.press/v139/yoon21c/yoon21c.pdf},
  url = 	 {https://proceedings.mlr.press/v139/yoon21c.html}
}
 
