What's hidden in a randomly weighted neural network?

by Vivek Ramanujan*, Mitchell Wortsman*, Aniruddha Kembhavi, Ali Farhadi, Mohammad Rastegari (* equal contribution)

arxiv link: https://arxiv.org/abs/1911.13299

News & Updates

  • Simple one-file example! Check out simple_mnist_example.py.
  • Faster version of GetSubnet, written by Suchin Gururangan! Feel free to replace the old version with this:
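import torch

def percentile(t, q):
    # Value at the q-th percentile (q in [0, 100]) of all entries in t.
    # kthvalue is 1-indexed, so k ranges from 1 to t.numel().
    k = 1 + round(.01 * float(q) * (t.numel() - 1))
    return t.reshape(-1).kthvalue(k).values.item()

class GetSubnetFaster(torch.autograd.Function):
    @staticmethod
    def forward(ctx, scores, zeros, ones, sparsity):
        # sparsity is the fraction of scores to drop: entries below the
        # threshold get 0 in the mask, the rest get 1.
        k_val = percentile(scores, sparsity * 100)
        return torch.where(scores < k_val, zeros.to(scores.device), ones.to(scores.device))

    @staticmethod
    def backward(ctx, g):
        # Straight-through estimator: the gradient reaches scores unchanged;
        # zeros, ones and sparsity receive no gradient.
        return g, None, None, None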
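To see what the threshold in GetSubnetFaster.forward does without PyTorch or a GPU, here is a stdlib-only restatement of the same percentile rule on a plain list. It is a sketch for illustration only: percentile_list and supermask are hypothetical names, not part of this repository, and only the forward pass is mirrored (the straight-through backward has no list analogue).

```python
def percentile_list(values, q):
    """q-th percentile (q in [0, 100]) using the same index rule as percentile().

    torch.kthvalue returns the k-th smallest value (1-indexed), so the
    equivalent list index is k - 1 after sorting.
    """
    k = 1 + round(0.01 * float(q) * (len(values) - 1))
    return sorted(values)[k - 1]


def supermask(scores, sparsity):
    """Hypothetical helper: 0 for scores below the threshold, 1 otherwise.

    sparsity is the fraction of scores to drop; scores equal to the
    threshold are kept, so about (1 - sparsity) of the entries get a 1.
    """
    k_val = percentile_list(scores, sparsity * 100)
    return [0 if s < k_val else 1 for s in scores]


print(supermask([0.1, 0.5, 0.3, 0.9], 0.5))  # [0, 1, 0, 1]
```

With four scores and sparsity 0.5, the rule picks the 3rd smallest score (0.5) as the threshold, so the two lowest scores are masked out.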

Setup

  1. Set up a virtualenv with Python 3.7.4. You can use venv (formerly pyvenv) or conda for this.
  2. Run pip install -r requirements.txt to install the requirements.
  3. Create a data directory as a base for all datasets. For example, if your base directory is /mnt/datasets, then ImageNet would be located at /mnt/datasets/imagenet and CIFAR-10 would be located at /mnt/datasets/cifar10.
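The data directory convention in step 3 can be checked before launching a run. This is a minimal sketch assuming every dataset lives in a sub-folder named after it under one base directory; only the imagenet and cifar10 folder names come from the example above, and the helper names are hypothetical.

```python
from pathlib import Path

# Folder names from the example above; other datasets are assumed to
# follow the same "<base>/<name>" pattern.
EXAMPLE_DATASETS = ("imagenet", "cifar10")


def dataset_dirs(base, names=EXAMPLE_DATASETS):
    """Map each dataset name to its expected location under the base data directory."""
    base = Path(base)
    return {name: base / name for name in names}


def missing_datasets(base, names=EXAMPLE_DATASETS):
    """Return the dataset names whose folders do not exist yet."""
    return [name for name, path in dataset_dirs(base, names).items() if not path.is_dir()]


if __name__ == "__main__":
    for name, path in dataset_dirs("/mnt/datasets").items():
        status = "ok" if path.is_dir() else "missing"
        print(f"{name}: {path} ({status})")
```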

Starting an Experiment

We use config files located in the configs/ folder to organize our experiments. The basic setup for any experiment is:

python main.py --config <path/to/config> <override-args>

Common override-args include --multigpu=<gpu-ids separated by commas, no spaces> to run on GPUs, and --prune-rate to set the prune rate (weights_remaining in our paper) for an experiment. Run python main.py --help for more details.
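When launching runs from a script, the main pitfall is the --multigpu format. The sketch below assembles the argument list for subprocess; build_command is a hypothetical helper, and it only uses the --config, --multigpu and --prune-rate flags named above, passed as separate tokens the way the example run does.

```python
import shlex


def build_command(config, gpu_ids=None, prune_rate=None, extra=()):
    """Hypothetical helper: assemble a main.py invocation as an argument list.

    --multigpu expects GPU ids joined by commas with no spaces, e.g. "0,1".
    """
    cmd = ["python", "main.py", "--config", config]
    if gpu_ids:
        cmd += ["--multigpu", ",".join(str(i) for i in gpu_ids)]
    if prune_rate is not None:
        cmd += ["--prune-rate", str(prune_rate)]
    cmd.extend(extra)  # any further override-args, already split into tokens
    return cmd


cmd = build_command("configs/smallscale/conv4/conv4_usc_unsigned.yml", [0, 1], 0.5)
# shlex.quote (not shlex.join) keeps this working on Python 3.7.
print(" ".join(shlex.quote(arg) for arg in cmd))
```

The list form can be passed straight to subprocess.run(cmd, check=True), which avoids shell quoting issues entirely.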

YAML Name Key

(u)uc -> (unscaled) unsigned constant
(u)sc -> (unscaled) signed constant
(u)pt -> (unscaled) pretrained init
(u)kn -> (unscaled) kaiming normal
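The key can also be applied mechanically to a config file name. This sketch assumes the initialization code appears as its own underscore-separated token in the file name, as in conv4_usc_unsigned.yml; describe_init is a hypothetical helper, not something the repository provides.

```python
import re
from pathlib import PurePath

# Codes from the YAML name key above; a leading "u" marks the unscaled variant.
INIT_CODES = {
    "uc": "unsigned constant",
    "sc": "signed constant",
    "pt": "pretrained init",
    "kn": "kaiming normal",
}


def describe_init(config_path):
    """Hypothetical helper: decode the init code in a config file name, or return None."""
    stem = PurePath(config_path).stem  # "conv4_usc_unsigned.yml" -> "conv4_usc_unsigned"
    for token in stem.split("_"):
        match = re.fullmatch(r"(u?)(uc|sc|pt|kn)", token)
        if match:
            prefix = "unscaled " if match.group(1) else ""
            return prefix + INIT_CODES[match.group(2)]
    return None


print(describe_init("configs/smallscale/conv4/conv4_usc_unsigned.yml"))
```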

Example Run

python main.py --config configs/smallscale/conv4/conv4_usc_unsigned.yml \
               --multigpu 0 \
               --name example \
               --data <path/to/data-dir> \
               --prune-rate 0.5
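The example run extends naturally to a sweep over several prune rates. This is a dry-run sketch that reuses exactly the flags shown above; the data path is an assumption to replace with your own. Because each run is stored under its own prune-rate=<prune-rate> directory, the same --name can be reused across the sweep.

```python
CONFIG = "configs/smallscale/conv4/conv4_usc_unsigned.yml"
DATA_DIR = "/mnt/datasets"  # assumption: replace with your data directory


def sweep_commands(prune_rates, gpu="0", name="example"):
    """Build one main.py command (as an argument list) per prune rate."""
    return [
        ["python", "main.py", "--config", CONFIG, "--multigpu", gpu,
         "--name", name, "--data", DATA_DIR, "--prune-rate", str(rate)]
        for rate in prune_rates
    ]


if __name__ == "__main__":
    for cmd in sweep_commands([0.3, 0.5, 0.7]):
        print(" ".join(cmd))
        # To launch instead of printing: subprocess.run(cmd, check=True)
```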

Expected Results and Pretrained Models

| Model | Params (remaining) | % Weights Remaining | Initialization | ImageNet Top-1 Accuracy (%) |
|---|---|---|---|---|
| ResNet-50 | 7.7M | 30% | Kaiming Normal | 61.7 |
| ResNet-50 | 7.7M | 30% | Signed Kaiming Constant | 68.6 |
| ResNet-101 | 13.3M | 30% | Kaiming Normal | 66.15 |
| ResNet-101 | 13.3M | 30% | Signed Kaiming Constant | 72.3 |
| Wide ResNet-50 | 20.6M | 30% | Kaiming Normal | 67.9 |
| Wide ResNet-50 | 20.6M | 30% | Signed Kaiming Constant | 73.3 |
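Reading the table pairwise, the signed Kaiming constant initialization is ahead of Kaiming normal for every model at 30% of weights remaining. The snippet below just computes those differences from the numbers in the table; it introduces no new results.

```python
# (model, Kaiming Normal accuracy, Signed Kaiming Constant accuracy), copied from the table above.
RESULTS = [
    ("ResNet-50", 61.7, 68.6),
    ("ResNet-101", 66.15, 72.3),
    ("Wide ResNet-50", 67.9, 73.3),
]


def gains():
    """Accuracy difference (percentage points) of Signed Kaiming Constant over Kaiming Normal."""
    return {model: round(skc - kn, 2) for model, kn, skc in RESULTS}


for model, gain in gains().items():
    print(f"{model}: +{gain} points")
```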

To use a pretrained model, pass the --pretrained=<path/to/pretrained-checkpoint> flag.

Tracking

tensorboard --logdir runs/ --bind_all

When your experiment is done, a CSV entry will be written (or appended) to runs/results.csv. Your experiment base directory will automatically be created at runs/<config-name>/prune-rate=<prune-rate>/<experiment-name>, with checkpoints/ and logs/ subdirectories. If your experiment matches a previously created experiment base directory, an integer increment will be added to the filepath (e.g., /0, /1). By default, checkpoints are saved for the first, best, and last models; to change this behavior, use the --save-every flag.
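To locate outputs from a script, the layout above can be reproduced with pathlib. This is a sketch under two assumptions: that the /0, /1 increments are numbered sub-directories of the base directory, and that results.csv starts with a header row. run_dir, latest_increment and read_results are hypothetical helpers, not repository functions.

```python
import csv
from pathlib import Path


def run_dir(config_name, prune_rate, experiment_name, root="runs"):
    """Base directory layout described above: runs/<config-name>/prune-rate=<prune-rate>/<experiment-name>."""
    return Path(root) / config_name / f"prune-rate={prune_rate}" / experiment_name


def latest_increment(base):
    """Highest numbered sub-directory (0, 1, ...) under base, or None if there is none.

    Assumption: repeated runs are stored as numbered sub-directories.
    """
    base = Path(base)
    if not base.is_dir():
        return None
    numbers = [int(p.name) for p in base.iterdir() if p.is_dir() and p.name.isdigit()]
    return max(numbers, default=None)


def read_results(path="runs/results.csv"):
    """Load results.csv as a list of dicts, assuming its first row is a header."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f))
```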

Requirements

Python 3.7.4, CUDA Version 10.1 (also works with 9.2 and 10.0):

absl-py==0.8.1
grpcio==1.24.3
Markdown==3.1.1
numpy==1.17.3
Pillow==6.2.1
protobuf==3.10.0
PyYAML==5.1.2
six==1.12.0
tensorboard==2.0.0
torch==1.3.0
torchvision==0.4.1
tqdm==4.36.1
Werkzeug==0.16.0
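A quick way to compare an environment against these pins is sketched below. parse_pins and mismatches use only the standard library; installed_versions relies on importlib.metadata, which is in the standard library only from Python 3.8, so on the pinned Python 3.7.4 you would need the importlib_metadata backport or skip that part. All three helper names are hypothetical.

```python
def parse_pins(text):
    """Turn "name==version" lines (as in requirements.txt) into a dict, ignoring comments."""
    pins = {}
    for line in text.splitlines():
        line = line.split("#", 1)[0].strip()
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def mismatches(pins, installed):
    """Return {name: (required, found)} for packages that are missing or at another version."""
    return {
        name: (want, installed.get(name))
        for name, want in pins.items()
        if installed.get(name) != want
    }


def installed_versions(names):
    """Look up installed versions (needs importlib.metadata, Python 3.8+)."""
    from importlib import metadata
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    with open("requirements.txt") as f:
        pins = parse_pins(f.read())
    print(mismatches(pins, installed_versions(pins)) or "all pins satisfied")
```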
