maxschambach / csgan

This project forked from po0ya/csgan

Task-Aware Compressed Sensing Using Generative Adversarial Networks (published in AAAI18)

License: Apache License 2.0

Python 96.42% Shell 3.58%

Task-aware Compressed Sensing with Generative Adversarial Networks

This repository contains the implementation of our AAAI-18 paper, Task-aware Compressed Sensing with Generative Adversarial Networks.

If you find this code or the paper useful, please consider citing:

@inproceedings{kabkab2018task,
title={Task-aware Compressed Sensing with Generative Adversarial Networks},
author={Kabkab, Maya and Samangouei, Pouya and Chellappa, Rama},
booktitle={AAAI Conference on Artificial Intelligence},
year={2018}
}

Contents

  1. Requirements
  2. Installation
  3. Usage
  4. Reproducing paper results
  5. Training algorithms
  6. Some results

Requirements

  • Python 2.7
  • TensorFlow 1.5+: pip install tensorflow-gpu
  • Scikit-learn: pip install scikit-learn
  • metric_learn: pip install metric_learn
  • pillow: pip install pillow
  • tqdm: pip install tqdm
  • requests: pip install requests
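
A quick way to confirm that the packages listed above are importable is sketched below. The helper name missing_modules is hypothetical and not part of this repository; note that the import names for scikit-learn and pillow (sklearn and PIL) differ from their pip names.

```python
import importlib

# Import names for the packages listed above. The import name differs from
# the pip name for scikit-learn (sklearn) and pillow (PIL).
REQUIRED_MODULES = ["tensorflow", "sklearn", "metric_learn", "PIL", "tqdm", "requests"]


def missing_modules(names):
    """Return the names in `names` that cannot be imported (hypothetical helper)."""
    missing = []
    for name in names:
        try:
            importlib.import_module(name)
        except ImportError:
            missing.append(name)
    return missing
```

Calling missing_modules(REQUIRED_MODULES) returns an empty list when every requirement is installed, and otherwise the list of import names that still need a pip install.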

Installation

First clone this repository:

$ git clone http://github.com/po0ya/csgan

Then download the datasets with:

$ python download.py [mnist|fmnist|celeba]

Finally, create the following directory structure in the project root:

  • output, for saving the models: ln -s <path-to-proj-data>/output
  • data, for datasets and intermediate files such as image means: ln -s <path-to-proj-data>/data
  • debug, for debug visualizations and test outputs (anything that can be removed after a day): ln -s <path-to-proj-data>/debug
  • experiments, which contains the configurations for the different experiments, their logs, and their run scripts.
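
The three symbolic links can also be created from Python. This is a minimal sketch, assuming a POSIX file system; the function name link_project_dirs and its arguments are hypothetical, and experiments is left out on the assumption that it ships with the repository.

```python
import os


def link_project_dirs(proj_root, data_root, names=("output", "data", "debug")):
    """Create <proj_root>/<name> symlinks pointing at <data_root>/<name>.

    Target directories are created first if they do not exist yet.
    Entries that already exist in proj_root are left untouched.
    """
    for name in names:
        target = os.path.join(data_root, name)
        link = os.path.join(proj_root, name)
        if not os.path.isdir(target):
            os.makedirs(target)
        if not os.path.lexists(link):
            os.symlink(target, link)
```

Here data_root plays the role of <path-to-proj-data>, so large outputs and datasets can live on a different disk than the source tree.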

Usage

For each experiment, a set of hyper-parameters is set either via the tf.flags mechanism or read from a config file. When a value is given in both places, the flag takes precedence over the config file.
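
The precedence rule can be illustrated with a small sketch. It is not the code in main.py: merge_config is a hypothetical helper, the numeric values are made up, and treating None as "flag not given" is an assumption.

```python
def merge_config(cfg_values, flag_values):
    """Return a new dict in which explicitly set flags override config values.

    A flag whose value is None is treated as "not given on the command line"
    (an assumption made for this sketch).
    """
    merged = dict(cfg_values)
    for key, value in flag_values.items():
        if value is not None:
            merged[key] = value
    return merged


# Illustrative values only.
cfg = {"cs_learning_rate": 0.01, "cs_num_random_restarts": 10}
flags = {"cs_learning_rate": 0.1, "cs_num_random_restarts": None}
print(merge_config(cfg, flags))
```

In this example cs_learning_rate comes from the flag, while cs_num_random_restarts keeps the config file value because no flag was given for it.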

Training a model

We first need to generate an appropriate sampling matrix for the config file.

$ python main.py --cfg <path-to-cfg> --generate_A
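
For readers new to compressed sensing, the role of the sampling matrix can be sketched with NumPy. This is an illustration under stated assumptions rather than what --generate_A does: the Gaussian entries, the 1/m variance scaling, the row-vector layout, and the function names are all choices made for this sketch.

```python
import numpy as np


def make_sampling_matrix(n, m, seed=0):
    """Draw a fixed n x m Gaussian measurement matrix.

    Entries are i.i.d. N(0, 1/m); this scaling is an assumption of the sketch.
    """
    rng = np.random.RandomState(seed)
    return rng.normal(loc=0.0, scale=1.0 / np.sqrt(m), size=(n, m))


def measure(x, A):
    """Compress a batch of flattened signals x (batch, n) to (batch, m)."""
    return x.dot(A)


A = make_sampling_matrix(n=784, m=200)      # 28x28 images, m = 200 measurements
x = np.random.RandomState(1).rand(4, 784)   # stand-in for four flattened images
y = measure(x, A)
print(y.shape)  # (4, 200)
```

Fixing the seed matters: the same matrix has to be used for training and for reconstruction, which is why the matrix is generated once and saved.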

Now, we are ready to train a model.

$ python main.py --cfg <path-to-cfg> --is_train <flags>

This will create an output directory based on the provided config files and flags. This directory will contain the TensorFlow checkpoints as well as the extracted features and the results of the evaluations for that experiment.

To save the reconstructions for a dataset, run:

$ python main.py --cfg experiments/cfgs/csgan/<cfg_file>.yml \
                 --reconstruction_res \
                 --cs_learning_rate <lr> \
                 --cs_max_update_iter <max_update_iter> \
                 --cs_num_random_restarts <rr>

Refer to main.py and configs for more information about flags.
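
The three cs_* flags suggest an iterative search over the latent code with several random starting points. The sketch below shows that idea on a toy problem and is not the repository's implementation: the generator is replaced by a linear map G(z) = z W so the gradient can be written by hand, and the function name, sizes, and values are assumptions.

```python
import numpy as np


def reconstruct(y, A, W, lr, max_update_iter, num_random_restarts, seed=0):
    """Recover x_hat = G(z_hat) from measurements y = x A.

    G(z) = z W is a toy linear stand-in for a trained generator, so the
    gradient of ||G(z) A - y||^2 with respect to z has a closed form.
    """
    rng = np.random.RandomState(seed)
    M = W.dot(A)                                   # maps z directly to measurements
    best_loss, best_z = np.inf, None
    for _ in range(num_random_restarts):
        z = rng.normal(size=W.shape[0])            # fresh random starting point
        for _ in range(max_update_iter):
            residual = z.dot(M) - y
            z = z - lr * 2.0 * residual.dot(M.T)   # gradient step on z
        loss = float(np.sum((z.dot(M) - y) ** 2))
        if loss < best_loss:                       # keep the best restart
            best_loss, best_z = loss, z
    return best_z.dot(W), best_loss


rng = np.random.RandomState(0)
k, n, m = 5, 50, 20                                   # latent, signal, measurement sizes
W = rng.normal(scale=1.0 / np.sqrt(n), size=(k, n))   # toy generator weights
A = rng.normal(scale=1.0 / np.sqrt(m), size=(n, m))   # sampling matrix
x = rng.normal(size=k).dot(W)                         # a signal in the generator's range
y = x.dot(A)

x_hat, loss = reconstruct(y, A, W, lr=0.05, max_update_iter=2000, num_random_restarts=3)
print(loss, np.max(np.abs(x_hat - x)))
```

With a linear generator the objective is convex, so every restart reaches the same solution. With a real, non-convex generator the restarts can end in different local minima, which is where keeping the best of several runs pays off.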

Classification

After a model is trained and the reconstructions are extracted, the features are saved under the project output directory output/default/<exp_name>/. The following Python script trains a classifier with the parameters specified in the classification config file. It also calculates the accuracy on the train and test splits.

$ python classify.py --cfg experiments/cfgs/cls/<cfg_file>.yml \
                     --feature_file <experiment_output_dir>/cache_iter_<num_of_cs_iter>/<feature_filename> \
                     --test_split test \
                     --validate

  • The --cfg flag points to the classification config file. Note that this is different from the CSGAN config file. Sample configs can be found in experiments/cfgs/cls/
  • The --feature_file specifies the path to the training features. <feature_filename> is typically given by <split>_<feature>_lr<lr>_rr<rr>_m<measurements>_a<a_index>_c<counter>.pkl, where split=[train|val|test], and feature=[x_hats|z_hats|measurements] (specifies which features to train on: reconstructed images, latent variables, or compressed measurements).
  • The --test_split flag is optional (default=test). It sets which split (train|val|test) to test the classifier on.
  • The --validate flag is optional (default=False). When set, the classifier is evaluated on the validation set and the best-performing checkpoint is chosen. This is only needed for neural network classifiers.
  • The --retrain flag is optional (default=False). When set, any existing checkpoints will be ignored, and training will start from scratch. Otherwise, training will resume from the most recent checkpoint. This is only needed for neural network classifiers.
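
The feature filename pattern described above can be taken apart with a regular expression. This is a minimal sketch: parse_feature_filename is a hypothetical helper, the example filename is made up, and the exact formatting of the numeric fields (in particular lr) is an assumption.

```python
import re

# Mirrors <split>_<feature>_lr<lr>_rr<rr>_m<measurements>_a<a_index>_c<counter>.pkl
FEATURE_RE = re.compile(
    r"^(?P<split>train|val|test)_"
    r"(?P<feature>x_hats|z_hats|measurements)_"
    r"lr(?P<lr>[^_]+)_rr(?P<rr>\d+)_m(?P<measurements>\d+)_"
    r"a(?P<a_index>\d+)_c(?P<counter>\d+)\.pkl$"
)


def parse_feature_filename(filename):
    """Split a feature filename into its named fields (all returned as strings)."""
    match = FEATURE_RE.match(filename)
    if match is None:
        raise ValueError("unrecognized feature filename: {}".format(filename))
    return match.groupdict()


# Made-up example filename following the documented pattern.
fields = parse_feature_filename("train_z_hats_lr0.01_rr10_m200_a0_c0.pkl")
print(fields["split"], fields["feature"], fields["measurements"])
```

A parser like this makes it easy to check that a file passed to --feature_file really holds the intended split and feature type before a classifier is trained on it.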

Reproducing paper results

To reproduce the results reported in the paper, refer to experiments/scripts, which contains one script per results table.

Notes:

  • These experiments should be run from <proj_root>.
  • Before running these scripts, first run generate_As.sh in order to generate the needed compressed sensing matrices.
  • Each script runs its experiments in the background. stdout and stderr outputs of each run are redirected to a file debug/<cfg>/<exp_file>.
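
The background execution with redirected output described in the last note can be reproduced from Python as follows. This is a sketch of the general pattern, not a transcription of the shell scripts; run_in_background is a hypothetical helper.

```python
import os
import subprocess


def run_in_background(cmd, log_path):
    """Start cmd without waiting for it; send stdout and stderr to log_path."""
    log_dir = os.path.dirname(log_path)
    if log_dir and not os.path.isdir(log_dir):
        os.makedirs(log_dir)
    log_file = open(log_path, "w")
    process = subprocess.Popen(cmd, stdout=log_file, stderr=subprocess.STDOUT)
    log_file.close()  # the child process keeps its own handle to the file
    return process
```

For example, run_in_background(["python", "main.py", "--cfg", cfg_path, "--is_train"], log_path) starts a training run and returns immediately, where cfg_path and log_path are placeholders for a config file and a log file under debug/. Calling wait() on the returned process blocks until the run finishes.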

Scripts

  • generate_As.sh generates fixed random measurement matrices and saves them into <proj_root>/output/sampling_mats/. Usage:

      ./experiments/scripts/generate_As.sh
    
  • diff_measurements_base.sh trains and tests models with various numbers of measurements.

      ./experiments/scripts/diff_measurements_base.sh <path-to-cfg> <comma-separated list of numbers of measurements (10,20,100)> <extra configs (optional)>
    
  • figure_1_{mnist|fmnist|celeba}_reconstruction.sh runs the experiments for Figure 1 using diff_measurements_base.sh.

      ./experiments/scripts/figure_1_{mnist|fmnist|celeba}_reconstruction.sh
    
  • table_1_base.sh trains and tests models with different numbers of uncompressed training samples.

      ./experiments/scripts/table_1_base.sh <path-to-cfg> <comma-separated list of numbers of samples (100,1000,8000)> <extra configs (optional)>
    
  • table_1_{csgan|dcgan}.sh runs the experiments of Table 1 using table_1_base.sh.

      ./experiments/scripts/table_1_{csgan|dcgan}.sh
    
  • table_2_{random|superres}.sh runs the experiments of Table 2 using table_1_base.sh.

      ./experiments/scripts/table_2_{random|superres}.sh
    
  • table_3.sh runs the experiments of Table 3.

      ./experiments/scripts/table_3.sh
    
  • diff_measurements_base_all.sh trains and tests models with different numbers of measurements, and extracts all features of the {train|val|test} sets.

      ./experiments/scripts/diff_measurements_base_all.sh <path-to-cfg> <comma-separated number of measurements (10,20,100)> <extra configs (optional)>
    
  • table_4_5_{mnist|fmnist}.sh runs the experiments of Tables 4 and 5, which train models with a discriminative latent space.

      ./experiments/scripts/table_4_5_{mnist|fmnist}.sh
    
  • cl_base.sh is the base script for classifying the saved features of each experiment. The path of the saved training features from table_4_5_{mnist|fmnist}.sh should be provided.

      ./experiments/scripts/cl_base.sh <classifier config> <path-to-train-features>
    

The corresponding hyper-parameters for all scripts can be found and set in experiments/cfgs.

Training algorithms

Left: One iteration of the task-aware GAN training algorithm when only non-compressed (original) training samples are used. Right: One iteration of the task-aware GAN training algorithm when a combination of non-compressed (original) and compressed training samples is used.

Some results

CelebA super-resolution results. Top row: original image; middle row: blurred image; bottom row: reconstructed image.

MNIST reconstruction results with m = 200. Top to bottom rows: original images, reconstructions with NC = 0, reconstructions with NC = 100, reconstructions with NC = 1,000, and reconstructions with NC = 8,000.

Fashion-MNIST reconstruction results when only compressed training data is available. Top row: original image; middle row: reconstructed image from m = 200 measurements; bottom row: reconstructed image from m = 400 measurements.
