
env-jempp's Introduction

JEM++: Improved Techniques for Training JEM

Official code for the paper JEM++: Improved Techniques for Training JEM

Added main.py and script.sh for running the code.

Usage

Prerequisite

The pretrained model and the mean/covariance data are available at https://1drv.ms/u/s!AgCFFlwzHuH8l0QWQ8E2aMtC0ApO?e=pE43fR

The pretrained model is jempp_M10.pt.

  1. Install the dependencies listed in requirements.txt: pip install -r requirements.txt
  2. Download the mean/covariance data (cifar10_mean/cov.pt) from the link above.

Training

To train a model on CIFAR10 as in the paper, refer to scripts/cifar10.sh:

python train_jempp.py --dataset=cifar10 \
 --lr=.1 --optimizer=sgd \
 --p_x_weight=1.0 --p_y_given_x_weight=1.0 --p_x_y_weight=0.0 \
 --sigma=.03 --width=10 --depth=28 \
 --plot_uncond --warmup_iters=1000 \
 --log_arg=JEMPP-n_steps-in_steps-pyld_lr \
 --model=yopo \
 --norm batch \
 --print_every=100 \
 --n_epochs=150 --decay_epochs 50 100 125 \
 --n_steps=10 \
 --in_steps=5 \
 --pyld_lr=0.2 \
 --gpu-id=3
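The flags --lr, --warmup_iters and --decay_epochs together describe a learning-rate schedule. The pure-Python sketch below shows one common way such a schedule works, assuming a linear warmup over the first --warmup_iters iterations followed by a step decay at each epoch listed in --decay_epochs. The command above does not set a decay factor, so decay_rate=0.3 and the function name learning_rate are hypothetical and are not taken from train_jempp.py.

```python
def learning_rate(epoch, itr, base_lr=0.1, warmup_iters=1000,
                  decay_epochs=(50, 100, 125), decay_rate=0.3):
    """Hypothetical warmup + step-decay schedule.

    epoch: current epoch (0-based); itr: global iteration count.
    decay_rate is an assumed value, not a flag from the command above.
    """
    if itr < warmup_iters:
        # Linear warmup: ramp from base_lr / warmup_iters up to base_lr.
        return base_lr * (itr + 1) / warmup_iters
    # Multiply by decay_rate once for every milestone already reached.
    n_decays = sum(1 for e in decay_epochs if epoch >= e)
    return base_lr * decay_rate ** n_decays


if __name__ == "__main__":
    for epoch, itr in [(0, 0), (0, 999), (10, 5000), (60, 25000), (149, 60000)]:
        print(epoch, itr, learning_rate(epoch, itr))
```

With --n_epochs=150 this gives three decays, so the final epochs would train at base_lr * decay_rate**3 under these assumptions.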

Evaluation

To evaluate the classifier (on CIFAR10):

python eval_jempp.py --load_path /PATH/TO/YOUR/MODEL.pt --eval test_clf --dataset cifar_test --model yopo --norm batch
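The test_clf evaluation reports classification accuracy on the CIFAR10 test set. As a hedged illustration (not the code in eval_jempp.py), the sketch below computes accuracy from hypothetical logits and also shows how a JEM reuses the same logits f(x)[y]: p(y|x) is their softmax, and LogSumExp_y f(x)[y] equals the unnormalized log-density log p(x) + log Z. The function names and example values are made up for illustration.

```python
import math


def softmax(logits):
    """p(y|x): softmax over the classifier logits f(x)[y]."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def log_unnormalized_px(logits):
    """JEM's unnormalized log-density: log p(x) + log Z = LogSumExp_y f(x)[y]."""
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits))


def accuracy(batch_logits, labels):
    """Fraction of examples whose argmax prediction matches the label."""
    correct = sum(
        1
        for logits, y in zip(batch_logits, labels)
        if max(range(len(logits)), key=logits.__getitem__) == y
    )
    return correct / len(labels)


if __name__ == "__main__":
    batch_logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
    labels = [0, 2, 0]
    print(accuracy(batch_logits, labels))  # two of three predictions are correct
    print(softmax(batch_logits[0]))
    print(log_unnormalized_px(batch_logits[0]))
```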

To evaluate the FID of the images in the replay buffer (on CIFAR10), run the command below; setting --ratio greater than or equal to the buffer size uses all images:

python eval_jempp.py --load_path /PATH/TO/YOUR/MODEL.pt --eval fid --model yopo --norm batch --ratio 10000  
# jempp_M10.pt's FID is 36.5 
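FID compares Gaussians fitted to Inception features of generated and real images: FID = ||mu_1 - mu_2||^2 + Tr(Sigma_1 + Sigma_2 - 2 (Sigma_1 Sigma_2)^(1/2)). Presumably the downloaded cifar10_mean/cov.pt files hold the real-data statistics, but that is an assumption. The sketch below implements the formula only for diagonal covariances, where the matrix square root reduces to an elementwise square root; standard FID uses full covariance matrices, and fid_diagonal is a hypothetical name.

```python
import math


def fid_diagonal(mu1, var1, mu2, var2):
    """Frechet distance between N(mu1, diag(var1)) and N(mu2, diag(var2)).

    With diagonal covariances, Tr(S1 + S2 - 2 (S1 S2)^(1/2)) reduces to a
    per-dimension sum, so no matrix square root is needed. This is a
    simplification; real FID uses full covariance matrices.
    """
    mean_term = sum((a - b) ** 2 for a, b in zip(mu1, mu2))
    cov_term = sum(v1 + v2 - 2.0 * math.sqrt(v1 * v2) for v1, v2 in zip(var1, var2))
    return mean_term + cov_term


if __name__ == "__main__":
    # Identical statistics give a distance of zero.
    print(fid_diagonal([0.0, 0.0], [1.0, 1.0], [0.0, 0.0], [1.0, 1.0]))
    # Mean shift of 1 in one dimension plus a variance change from 1 to 4.
    print(fid_diagonal([0.0, 0.0], [1.0, 1.0], [1.0, 0.0], [4.0, 1.0]))
```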

Demo

Samples generated from scratch by the trained model, shown as a GIF (Demo_Gif) and a video (evo.mp4).
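JEM-style models generate samples "from scratch" by starting from noise and running stochastic gradient Langevin dynamics (SGLD) on the energy E(x) = -LogSumExp_y f(x)[y]. The toy sketch below illustrates the idea in one dimension with hypothetical logits f(x)[y] = -(x - c_y)^2 / 2, so the target density is a two-component Gaussian mixture centred at -2 and 2. It uses the textbook Langevin update x <- x - step * dE/dx + sqrt(2 * step) * noise; the step size, chain count and step count are assumptions, and the actual sampler in this repository (which keeps a replay buffer) may differ.

```python
import math
import random

CENTERS = (-2.0, 2.0)  # hypothetical toy "class" centres


def logits(x):
    """Toy classifier logits f(x)[y] = -(x - c_y)^2 / 2 (hypothetical)."""
    return [-(x - c) ** 2 / 2.0 for c in CENTERS]


def energy_grad(x):
    """dE/dx for E(x) = -LogSumExp_y f(x)[y].

    dE/dx = -sum_y p(y|x) * df_y/dx = sum_y p(y|x) * (x - c_y).
    """
    f = logits(x)
    m = max(f)
    w = [math.exp(v - m) for v in f]
    total = sum(w)
    return sum((wi / total) * (x - c) for wi, c in zip(w, CENTERS))


def sgld_sample(n_chains=200, n_steps=500, step=0.05, seed=0):
    """Run Langevin dynamics from uniform noise toward p(x) proportional to exp(-E(x))."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n_chains)]  # start from noise
    for _ in range(n_steps):
        xs = [
            x - step * energy_grad(x) + math.sqrt(2.0 * step) * rng.gauss(0.0, 1.0)
            for x in xs
        ]
    return xs


if __name__ == "__main__":
    samples = sgld_sample()
    # Chains should settle around the two modes, so mean |x| lands near 2.
    print(sum(abs(x) for x in samples) / len(samples))
```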

Citation

If you found this work useful and used it in your own research, please consider citing this paper.

@article{yang2021jempp,
    title={JEM++: Improved Techniques for Training JEM},
    author={Xiulong Yang and Shihao Ji},
    journal={International Conference on Computer Vision (ICCV)},
    month={Oct.},
    year={2021}
}
