
Direction Concentration Learning: Enhancing Congruency in Machine Learning

Luo, Yan and Wong, Yongkang and Kankanhalli, Mohan S and Zhao, Qi

This repository contains the code for the classification task. For the continual learning task, please refer to repository congruency_continual.
DCL (arXiv, IEEE) studies the agreement between the learned knowledge and the new information that arrives during a learning process.
The code is built on PyTorch and is partly based on GEM. It has been tested under Ubuntu 16.04 LTS with Python 3.6. State-of-the-art pretrained EfficientNets on CIFAR and Tiny ImageNet are included.

(Animated convergence-path illustrations for GD, RMSProp, and Adam; see the Illustration section below.)
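At its core, DCL (like GEM, on which the code is partly based) corrects each gradient update so that it does not conflict with a small set of reference gradients; the correction is computed by a small quadratic program, which is why quadprog is a prerequisite. The snippet below is a rough NumPy/quadprog sketch of that projection step, written for intuition only; the function name project_gradient and its arguments are made up here and do not correspond to the repository's actual code.

import numpy as np
import quadprog  # the QP solver listed in the prerequisites

def project_gradient(g, G_ref, margin=0.0, eps=1e-3):
    # g:     current gradient, shape (d,)
    # G_ref: reference gradients stacked as rows, shape (k, d)
    # Returns a corrected gradient that does not conflict with the rows of
    # G_ref (for margin=0 this is the projection onto the feasible cone,
    # computed via the dual QP in the same way GEM does it).
    g = np.asarray(g, dtype=np.float64)
    G_ref = np.asarray(G_ref, dtype=np.float64)
    k = G_ref.shape[0]
    P = G_ref @ G_ref.T
    P = 0.5 * (P + P.T) + eps * np.eye(k)   # symmetrize and regularize
    a = -(G_ref @ g)                        # solve_qp minimizes 0.5 v^T P v - a^T v
    C = np.eye(k)                           # constraints: v_i >= margin
    b = np.zeros(k) + margin
    v = quadprog.solve_qp(P, a, C, b)[0]
    return G_ref.T @ v + g                  # corrected (congruent) gradient

In the training commands below, --dcl-refsize and --dcl-window appear to control how many reference gradients are kept and over how many recent iterations they are collected; please consult the paper and train_imgnet.py for their exact semantics.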

If you find this work or the code useful in your research, please consider citing:
@article{Luo_DCL_2019,
    author={Y. {Luo} and Y. {Wong} and M. {Kankanhalli} and Q. {Zhao}},
    journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
    title={Direction Concentration Learning: Enhancing Congruency in Machine Learning},
    year={2019},
    pages={1-1},
    doi={10.1109/TPAMI.2019.2963387},
    ISSN={1939-3539}
}

TOC

  1. Prerequisites
  2. Illustration
  3. Training on ImageNet
  4. Training on Tiny ImageNet
  5. Training on CIFAR
  6. Pretrained Model

Prerequisites

  1. PyTorch 0.4.1, e.g.,
conda install pytorch=0.4.1 cuda80 -c pytorch # for CUDA 8.0
conda install pytorch=0.4.1 cuda90 -c pytorch # for CUDA 9.0

Using EfficientNet as the baseline model requires PyTorch 1.1.0+, e.g.,

conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch
  2. torchvision 0.2.1+, e.g.,
pip install torchvision==0.2.1
  3. quadprog, i.e.,
pip install msgpack
pip install Cython
pip install quadprog
  4. EfficientNet (optional), i.e.,
pip install efficientnet_pytorch
  5. TensorboardX (optional), i.e.,
pip install tensorboardX==1.2
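After installation, a quick sanity check such as the following (purely optional, and not part of the repository) can confirm that the required packages and a GPU are visible to Python:

# optional sanity check for the prerequisites listed above
import torch
import torchvision
import quadprog  # only imported to confirm it is installed

print("PyTorch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("CUDA available:", torch.cuda.is_available())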

Illustration

To plot the convergence paths and the corresponding z-vs-iteration curves, run the following commands:
# go to the illustrator folder
cd illustrator
# generate the convergence path with one of the optimizers (gd, rmsprop, or adam)
python convergence_visualization.py --opt gd|rmsprop|adam
# generate the corresponding z-vs-iteration curves
python plot_z.py

Training on ImageNet

Step 1: Download ImageNet.
The validation set should be re-organized into per-class (WordNet ID) subfolders so that the directory structure looks as follows:

imagenet/images
├── train
│   ├── n01440764
│   ├── n01443537
│   ├── n01484850
│   ├── n01491361
│   ├── n01494475
│   ...
└── val
    ├── n01440764
    ├── n01443537
    ├── n01484850
    ├── n01491361
    ├── n01494475
    ...

You can either write your own script to achieve this or use the script provided in the TensorFlow repo.
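If you write your own script, a minimal Python sketch is shown below. It assumes you already have a text file that maps each validation image to its WordNet ID (e.g., derived from the ILSVRC devkit); the file name val_to_wnid.txt is hypothetical and only used for illustration.

# Minimal sketch: move the flat validation images into per-class subfolders.
# Each line of the (hypothetical) mapping file looks like:
#   ILSVRC2012_val_00000001.JPEG n01751748
import os
import shutil

val_dir = "/path-to-folder/imagenet/images/val"

with open("val_to_wnid.txt") as f:
    for line in f:
        fname, wnid = line.split()
        class_dir = os.path.join(val_dir, wnid)
        os.makedirs(class_dir, exist_ok=True)
        src = os.path.join(val_dir, fname)
        if os.path.isfile(src):
            shutil.move(src, os.path.join(class_dir, fname))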
Step 2: We use 8 GPUs to train the models on ImageNet. To train ResNet-50 with DCL, run

python -W ignore train_imgnet.py \
    -a resnet50 \
    --lr 0.1 \
    --lr-decay-epoch 30 \
    -b 512 \
    --epochs 90 \
    --n-classes 1000 \
    --workers 16 \
    --checkpoint-path 'checkpoints/imagenet/resnet50_dcl_1_1' \
    --mtype 'dcl' \
    --dcl-refsize 1 \
    --dcl-window 1 \
    '/path-to-folder/imagenet/images/'

To train ResNet-50 (baseline), run

python -W ignore train_imgnet.py \
    -a resnet50 \
    --lr 0.1 \
    --lr-decay-epoch 30 \
    -b 512 \
    --epochs 90 \
    --n-classes 1000 \
    --workers 16 \
    --checkpoint-path 'checkpoints/imagenet/resnet50_baseline' \
    --mtype 'baseline' \
    '/path-to-folder/imagenet/images/'

To train ResNet-50 with GEM, run

python -W ignore train_imgnet.py \
    -a resnet50 \
    --lr 0.1 \
    --lr-decay-epoch 30 \
    -b 512 \
    --epochs 90 \
    --n-classes 1000 \
    --workers 16 \
    --checkpoint-path 'checkpoints/imagenet/resnet50_gem_1' \
    --mtype 'gem' \
    --gem-memsize 1 \
    '/path-to-folder/imagenet/images/'

The above commands are also provided in train_imgnet.sh.

Training on Tiny ImageNet

Step 1: Download Tiny ImageNet
The directory structure of the dataset should be re-organized as follows:

/home/yluo/project/dataset/tinyimagenet
├── images
│   ├── test
│   ├── train
│   └── val
├── wnids.txt
└── words.txt
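As a quick, optional check that the layout matches what PyTorch data loaders expect, per-class folders organized this way can typically be read with torchvision's ImageFolder; the sketch below is only an illustration and is not necessarily how train_timgnet.sh builds its loaders.

# optional layout check with torchvision's ImageFolder
import torchvision.datasets as datasets
import torchvision.transforms as transforms

root = "/path-to-folder/tinyimagenet/images"
tfm = transforms.Compose([transforms.ToTensor()])
train_set = datasets.ImageFolder(root + "/train", transform=tfm)
val_set = datasets.ImageFolder(root + "/val", transform=tfm)
print(len(train_set), "train images,", len(val_set), "val images,",
      len(train_set.classes), "classes")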

Step 2: For training, please refer to train_timgnet.sh.

Training on CIFAR

The code for the CIFAR experiments is built on pytorch-classification.
For training, please refer to train_cifar.sh.

Pretrained Model

ImageNet

The pre-trained DCL model on ImageNet is available on a shared Google Drive. In that folder, statistics including the 1-crop top-1 validation accuracy per epoch are recorded in the file stat.csv. The highest validation accuracy is 75.93% at epoch 86, while the mean accuracy of the baseline over 3 runs is 75.66%.
To load this pre-trained model, first download model_best.pth.tar into a folder named data, and then run

python load_pretrained.py
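If you prefer to inspect the checkpoint directly instead of using load_pretrained.py, a rough sketch along the following lines usually works for checkpoints saved in the common ImageNet-example format (a dict with a 'state_dict' entry, possibly carrying a 'module.' prefix from nn.DataParallel); the exact keys in this particular file may differ.

# rough sketch: load data/model_best.pth.tar into a torchvision ResNet-50
import torch
import torchvision.models as models

ckpt = torch.load("data/model_best.pth.tar", map_location="cpu")
state_dict = ckpt.get("state_dict", ckpt)     # fall back to a bare state dict
# strip a possible 'module.' prefix left by nn.DataParallel
state_dict = {k.replace("module.", "", 1): v for k, v in state_dict.items()}

model = models.resnet50(num_classes=1000)
model.load_state_dict(state_dict)
model.eval()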

Tiny ImageNet

The pre-trained DCL models on Tiny ImageNet, i.e., ResNet-101 and EfficientNet-B1, are available at the shared drives ResNet_TImgNet and EfficientNet_TImgNet. The lowest validation top-1 error is 15.61%, while that of the baseline is 15.73%.

CIFAR

The pre-trained DCL models on CIFAR-10 and CIFAR-100, i.e., ResNeXt-29 and EfficientNet-B1, are available at the shared drives ResNeXt_CIFAR10, EfficientNet_CIFAR10, ResNeXt_CIFAR100, and EfficientNet_CIFAR100, respectively. The lowest validation error on CIFAR-10 is 1.79%, while that of the baseline is 1.91%. The lowest validation error on CIFAR-100 is 11.65%, while that of the baseline is 11.81%. We used dcl_margin=0.1 on CIFAR-10 and 0.3 on CIFAR-100.

Contact

luoxx648 at umn.edu
Any discussions, suggestions, and questions are welcome!
