Counterfactual Generative Networks


## Setup ##
Install Anaconda (if you don't have it yet)
```Shell
wget https://repo.anaconda.com/archive/Anaconda3-2020.11-Linux-x86_64.sh
bash Anaconda3-2020.11-Linux-x86_64.sh
source ~/.profile
```

Clone the repo and build the environment

git clone https://github.com/ambekarsameer96/FACT_AI.git
cd code/counterfactual_generative_networks-main
conda env create -f environment.yml
conda activate cgn

Make all scripts executable with `chmod +x scripts/*`. Then download the datasets (colored MNIST, Cue-Conflict, IN-9) and the pre-trained weights (CGN, U2-Net). Before running the two scripts below, comment out the downloads you don't need.

./scripts/download_data.sh
./scripts/download_weights.sh

MNIST Experiments

The main functions of this sub-repo are:

  • Generating the MNIST variants
  • Training a CGN
  • Generating counterfactual datasets
  • Training a shape classifier

Train the CGN

To run the ablation study:

python mnists/train_classifier.py --dataset double_colored_MNIST_counterfactual --ablation True

To train with the SSIM loss, set the SSIM flag and pass the config file of the corresponding dataset, along with any other fields.

usage: python mnists/train_cgn.py --cfg Dataset_cfg_file --ssim_flag true

python mnists/train_cgn.py --cfg mnists/experiments/cgn_wildlife_MNIST/cfg.yaml --ssim_flag true
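The ablation and SSIM commands above pass boolean values as text (`True`, `true`). How the training scripts parse these values is not shown here. As a general caution, argparse's `type=bool` treats every non-empty string, including `"False"`, as true. Below is a minimal sketch of a stricter converter; `str2bool` and `build_parser` are hypothetical names for illustration, not taken from the repo.

```python
import argparse


def str2bool(value: str) -> bool:
    """Map common textual spellings onto a real boolean."""
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser: it mirrors the flags used in the README commands,
    # not the actual argument definitions in the training scripts.
    parser = argparse.ArgumentParser(description="boolean flag parsing demo")
    parser.add_argument("--dataset", required=True)
    parser.add_argument("--ablation", type=str2bool, default=False)
    return parser


if __name__ == "__main__":
    args = build_parser().parse_args(
        ["--dataset", "double_colored_MNIST_counterfactual", "--ablation", "True"]
    )
    print(args.dataset, args.ablation)
```

With a converter like this, `--ablation False` really disables the option, whereas `type=bool` would silently enable it.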

python mnists/generate_data.py \
--dataset wildlife_MNIST --no_cfs 10 --dataset_size 100000

python mnists/train_classifier.py --dataset wildlife_MNIST_counterfactual

To run with color jitter augmentation, use the separate scripts with the _augment suffix. They perform the same tasks as the original scripts; the only difference is that they apply 'Color Jitter' augmentation.

usage: python mnists/train_cgn_augment.py --cfg Dataset_cfg_file

python mnists/train_cgn_augment.py --cfg mnists/experiments/cgn_wildlife_MNIST/cfg.yaml

python mnists/generate_data_augment.py \
--dataset wildlife_MNIST --no_cfs 10 --dataset_size 100000

python mnists/train_classifier_augment.py --dataset wildlife_MNIST_counterfactual
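
Both MNIST workflows above follow the same three steps: train the CGN, generate the counterfactual dataset, then train the classifier. The sketch below assembles those commands in order and only prints them unless `dry_run=False` is passed. `build_pipeline` and `run_pipeline` are hypothetical helpers, and the config path pattern `mnists/experiments/cgn_<dataset>/cfg.yaml` is an assumption generalized from the wildlife_MNIST example.

```python
import subprocess
import sys
from typing import List


def build_pipeline(dataset: str, augment: bool = False,
                   no_cfs: int = 10, dataset_size: int = 100000) -> List[List[str]]:
    """Return the three commands (as argv lists) for one MNIST variant."""
    suffix = "_augment" if augment else ""
    # Assumed layout, generalized from the wildlife_MNIST example.
    cfg = f"mnists/experiments/cgn_{dataset}/cfg.yaml"
    return [
        [sys.executable, f"mnists/train_cgn{suffix}.py", "--cfg", cfg],
        [sys.executable, f"mnists/generate_data{suffix}.py", "--dataset", dataset,
         "--no_cfs", str(no_cfs), "--dataset_size", str(dataset_size)],
        [sys.executable, f"mnists/train_classifier{suffix}.py",
         "--dataset", f"{dataset}_counterfactual"],
    ]


def run_pipeline(steps: List[List[str]], dry_run: bool = True) -> None:
    """Print each command and, unless dry_run is set, execute it."""
    for argv in steps:
        print(" ".join(argv))
        if not dry_run:
            # check=True stops the pipeline at the first failing step.
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_pipeline(build_pipeline("wildlife_MNIST", augment=True))
```

Extra flags, such as the SSIM flag, can be appended to the first argv list before calling `run_pipeline`.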

__Distributed Training__. To switch to multi-GPU training, first run `echo $CUDA_VISIBLE_DEVICES` to check that the GPUs are visible; this applies, e.g., to a single node with several GPUs.
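
The same visibility check can be done from Python before launching a job. This is a minimal sketch using only the standard library; `visible_devices` is a hypothetical helper. It relies on the CUDA convention that an unset variable places no restriction on the GPUs, while an empty value hides all of them.

```python
import os
from typing import List, Mapping, Optional


def visible_devices(env: Optional[Mapping[str, str]] = None) -> Optional[List[str]]:
    """Return the device ids listed in CUDA_VISIBLE_DEVICES.

    Returns None when the variable is unset (no restriction applied) and an
    empty list when it is set to an empty string (no GPU visible).
    """
    env = os.environ if env is None else env
    raw = env.get("CUDA_VISIBLE_DEVICES")
    if raw is None:
        return None
    return [device.strip() for device in raw.split(",") if device.strip()]


if __name__ == "__main__":
    devices = visible_devices()
    if devices is None:
        print("CUDA_VISIBLE_DEVICES is unset: all GPUs are visible")
    else:
        print(f"{len(devices)} GPU(s) visible: {devices}")
```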

ImageNet dataset

Test Inception score

python imagenet/generate_data.py --n_data 32 --weights_path imagenet/weights/cgn.pth --mode random --run_name val --truncation 0.5 --batch_sz 1

Training the CGN

NOTE: Training the CGN for ImageNet uses the BigGAN-256 and U2-Net weights. It runs for 1.2M iterations (approx. 0.5/s). Skip this part if you do not have adequate compute resources.
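
To get a feel for the cost, the figures in the note can be turned into a rough wall-clock estimate. The note does not say whether "0.5/s" means 0.5 seconds per iteration or 0.5 iterations per second, so the sketch below computes both readings. It is back-of-the-envelope arithmetic, not a measured runtime.

```python
ITERATIONS = 1_200_000   # from the note above
RATE = 0.5               # the "approx. 0.5/s" figure, unit ambiguous
SECONDS_PER_DAY = 86_400


def to_days(seconds: float) -> float:
    """Convert a duration in seconds to days."""
    return seconds / SECONDS_PER_DAY


# Reading A (assumption): 0.5 seconds per iteration.
seconds_a = ITERATIONS * RATE
# Reading B (assumption): 0.5 iterations per second.
seconds_b = ITERATIONS / RATE

print(f"0.5 s per iteration:  {to_days(seconds_a):.1f} days")   # 6.9 days
print(f"0.5 iterations per s: {to_days(seconds_b):.1f} days")   # 27.8 days
```

Under either reading, training takes on the order of days to weeks on a single machine.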

python imagenet/train_cgn.py --model_name MODEL_NAME

Generate Counterfactual Images

Two folders of counterfactual images are needed (train and val). Train holds 35,000 counterfactual images and val holds 5,000. This split provides the 1:1 ratio recommended in the paper (with ImageNet-1k replaced by ImageNet-1k (mini)).

python imagenet/generate_data.py --n_data 35000 --weights_path imagenet/weights/cgn.pth --mode random --run_name train --truncation 0.5 --batch_sz 1

python imagenet/generate_data.py --n_data 5000 --weights_path imagenet/weights/cgn.pth --mode random --run_name val --truncation 0.5 --batch_sz 1

Copy the train and val folders into a cf_data folder. Create cf_data first so that both folders end up inside it:

mkdir -p /content/counterfactual_generative_networks/imagenet/data/cf_data
cp -r /content/counterfactual_generative_networks/imagenet/data/train /content/counterfactual_generative_networks/imagenet/data/cf_data
cp -r /content/counterfactual_generative_networks/imagenet/data/val /content/counterfactual_generative_networks/imagenet/data/cf_data
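
Outside a notebook, the copy step can also be done from Python. This is a minimal sketch, assuming the desired layout is `cf_data/train` and `cf_data/val`; `collect_cf_data` is a hypothetical helper, not part of the repo.

```python
import shutil
from pathlib import Path
from typing import Sequence


def collect_cf_data(data_root: Path, splits: Sequence[str] = ("train", "val")) -> Path:
    """Copy data_root/<split> into data_root/cf_data/<split> for each split."""
    cf_root = data_root / "cf_data"
    cf_root.mkdir(parents=True, exist_ok=True)
    for split in splits:
        # dirs_exist_ok (Python 3.8+) lets the copy be re-run safely.
        shutil.copytree(data_root / split, cf_root / split, dirs_exist_ok=True)
    return cf_root


if __name__ == "__main__":
    root = Path("/content/counterfactual_generative_networks/imagenet/data")
    if root.exists():
        print(collect_cf_data(root))
```

Creating `cf_data` before copying avoids a pitfall of a bare `cp -r`: when the destination does not exist yet, the first copy becomes `cf_data` itself instead of landing in `cf_data/train`.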

Training the classifier

We replaced ImageNet-1k with ImageNet-1k (mini). Make sure --cf_data points to the cf_data folder created in the previous step.

python imagenet/train_classifier.py -a resnet50 -b 32 --lr 0.001 -j 6 \
--epochs 45 --pretrained --cf_data imagenet/cf_data --name RUN_NAME

Plotting Explainability heatmaps

Replace the --image_loc path with the path to an image of your choice.

python lime_plot.py --image_loc '/content/counterfactual_generative_networks/imagenet/data/mini-imagenet/train/fg_n02002556_34176_bg_n03272562_13865.JPEG'

__Visualization__. To visualize the TensorBoard outputs, run `tensorboard --logdir=imagenet/runs` and open the local address in your browser.

Acknowledgments

We would like to acknowledge the repos whose code, data, or models we use in our implementation.
