Neural Joint Source-Channel Coding

This repo contains a reference implementation for NECST as described in the paper:

Neural Joint Source-Channel Coding
Kristy Choi, Kedar Tatwawadi, Aditya Grover, Tsachy Weissman, Stefano Ermon
International Conference on Machine Learning (ICML), 2019.
Paper: https://arxiv.org/abs/1811.07557

Requirements

The codebase is implemented in Python 3.6 and TensorFlow. To install the necessary dependencies, run:

pip3 install -r requirements.txt

Datasets

Scripts for data pre-processing are included in the ./data_setup directory. The NECST model operates on TensorFlow TFRecords. A few points to note:

  1. Raw data files for MNIST and BinaryMNIST can be downloaded using data_setup/download.py. CelebA files can be downloaded using data_setup/celebA_download.py. CIFAR10 can be downloaded (with tfrecords automatically generated) using data_setup/generate_cifar10_tfrecords.py. All other data files (Omniglot, SVHN) must be downloaded separately.
  2. CelebA and Omniglot should be converted into .hdf5 format using data_setup/convert_celebA_h5.py and data_setup/convert_omniglot_h5.py, respectively.
  3. Random {0,1} bits can be generated using data_setup/gen_random_bits.py.
  4. After the steps above, tfrecords must be generated with data_setup/convert_to_records.py before running the model.
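
The points above can be summarized per datasource. The mapping below is only a reading of this list, not code from the repo: the names CONVERT, MANUAL, PREP_STEPS and prep_steps are hypothetical, and it assumes that every datasource except CIFAR10 (whose script generates tfrecords itself) finishes with data_setup/convert_to_records.py.

```python
# Preparation steps per datasource, as read from the list above.
# Assumption: all datasources except cifar10 end with convert_to_records.py.
CONVERT = "data_setup/convert_to_records.py"
MANUAL = "(manual download)"  # data the list says must be fetched separately

PREP_STEPS = {
    "mnist": ["data_setup/download.py", CONVERT],
    "BinaryMNIST": ["data_setup/download.py", CONVERT],
    "random": ["data_setup/gen_random_bits.py", CONVERT],
    "omniglot": [MANUAL, "data_setup/convert_omniglot_h5.py", CONVERT],
    "celebA": [
        "data_setup/celebA_download.py",
        "data_setup/convert_celebA_h5.py",
        CONVERT,
    ],
    "svhn": [MANUAL, CONVERT],
    "cifar10": ["data_setup/generate_cifar10_tfrecords.py"],
}


def prep_steps(datasource):
    """Return the ordered preparation steps for one datasource."""
    try:
        return PREP_STEPS[datasource]
    except KeyError:
        raise ValueError(f"unknown datasource: {datasource!r}") from None


for name in PREP_STEPS:
    print(f"{name}: " + " -> ".join(prep_steps(name)))
```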

Options

The main.py script trains the NECST model and takes a set of command-line arguments. The most relevant ones are listed below:

--datasource (STRING):    one of [mnist, BinaryMNIST, random, omniglot, celebA, svhn, cifar10]
--is_binary (BOOL):       whether or not the data is binary {0,1}, e.g. BinaryMNIST
--vimco_samples (INT):    number of samples to use for VIMCO
--channel_model (STRING): BSC/BEC
--noise (FLOAT):          channel noise level during training
--test_noise (FLOAT):     channel noise level at TEST time
--n_epochs (INT):         number of training epochs
--batch_size (INT):       size of minibatch
--lr (FLOAT):             learning rate of optimizer
--optimizer (STRING):     one of [adam, sgd]
--dec_arch (STRING):      comma-separated decoder architecture
--enc_arch (STRING):      comma-separated encoder architecture
--reg_param (FLOAT):      regularization for encoder architecture
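
To make the --channel_model and --noise options concrete, here is a minimal, dependency-free sketch of the two standard channel models: a binary symmetric channel (BSC) flips each bit independently with probability noise, and a binary erasure channel (BEC) erases each bit independently with probability noise. This is not the repo's TensorFlow implementation. The function names and the ERASURE marker are hypothetical, and the repo may represent erasures differently.

```python
import random

ERASURE = -1  # hypothetical marker for an erased bit (an assumption, not the repo's encoding)


def bsc(bits, noise, rng=random):
    """Binary symmetric channel: flip each bit independently with probability `noise`."""
    return [b ^ 1 if rng.random() < noise else b for b in bits]


def bec(bits, noise, rng=random):
    """Binary erasure channel: erase each bit independently with probability `noise`."""
    return [ERASURE if rng.random() < noise else b for b in bits]


rng = random.Random(0)  # seeded so repeated runs give the same result
codeword = [rng.randint(0, 1) for _ in range(100)]  # a 100-bit code, as with --n_bits=100

received = bsc(codeword, noise=0.1, rng=rng)
flipped = sum(c != r for c, r in zip(codeword, received))
print(f"BSC flipped {flipped} of {len(codeword)} bits")

received = bec(codeword, noise=0.1, rng=rng)
erased = received.count(ERASURE)
print(f"BEC erased {erased} of {len(codeword)} bits")
```

With noise=0.1, roughly one bit in ten is corrupted on average. Setting --test_noise to a different value than --noise corresponds to evaluating on a channel with a different corruption probability than the one used during training.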

Examples

Download and train a 100-bit NECST model with BSC noise = 0.1 on BinaryMNIST:

# Download the BinaryMNIST dataset
python3 data_setup/download.py BinaryMNIST

# Generate a tfrecords file corresponding to the dataset
python3 data_setup/convert_to_records.py --dataset=BinaryMNIST

# Train the model
python3 main.py --datadir=./data --datasource=BinaryMNIST --channel_model=bsc --noise=0.1 --test_noise=0.1 --n_bits=100 --is_binary=True

Training a 1000-bit NECST model with BSC noise = 0.2 on CelebA:

python3 main.py --datadir=./data --datasource=celebA --channel_model=bsc --noise=0.2 --test_noise=0.2 --n_bits=1000
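
To train at several noise levels, the commands above can be generated rather than typed by hand. The helper below is a hypothetical sketch, not part of the repo: train_command is an invented name, it uses only flags that appear in the examples above, and the noise levels in the loop are arbitrary illustrations. It only prints the commands and does not run them.

```python
def train_command(datasource, noise, n_bits, channel_model="bsc",
                  datadir="./data", is_binary=False):
    """Build the argument list for one main.py training run.

    The same noise level is used for training and testing, as in the
    examples above.
    """
    args = [
        "python3", "main.py",
        f"--datadir={datadir}",
        f"--datasource={datasource}",
        f"--channel_model={channel_model}",
        f"--noise={noise}",
        f"--test_noise={noise}",
        f"--n_bits={n_bits}",
    ]
    if is_binary:
        args.append("--is_binary=True")
    return args


# Illustrative sweep; these noise levels are arbitrary choices.
for noise in (0.1, 0.2, 0.3):
    print(" ".join(train_command("BinaryMNIST", noise, n_bits=100, is_binary=True)))
```

Each argument list can be passed to subprocess.run to launch the run, once the dataset has been downloaded and converted to tfrecords.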

Citing

If you find NECST useful in your research, please consider citing the following paper:

@article{choi2018necst,
  title={Neural Joint Source-Channel Coding},
  author={Choi, Kristy and Tatwawadi, Kedar and Grover, Aditya and Weissman, Tsachy and Ermon, Stefano},
  journal={arXiv preprint arXiv:1811.07557},
  year={2018}
}
