
Direction Concentration Learning: Enhancing Congruency in Machine Learning

Yan Luo, Yongkang Wong, Mohan S. Kankanhalli, and Qi Zhao

This repository contains the code for the classification task. For the continual learning task, please refer to the congruency_continual repository.
DCL (arXiv, IEEE) studies the agreement between previously learned knowledge and newly arriving information in a learning process.
The code is built on PyTorch and is partly based on GEM. It has been tested under Ubuntu 16.04 LTS with Python 3.6. State-of-the-art pretrained EfficientNets on CIFAR and Tiny ImageNet are included.

Convergence-path illustrations for GD, RMSProp, and Adam (see the Illustration section below to reproduce them).
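
Informally, congruency measures how well the current update direction agrees with the directions accumulated so far. A rough, hypothetical illustration in PyTorch (cosine similarity between a running accumulation of past gradients and the current gradient; not the paper's exact formulation):

import torch
import torch.nn.functional as F

def congruency(accumulated_direction, current_gradient):
    # cosine similarity: +1 = fully congruent, -1 = fully conflicting
    return F.cosine_similarity(accumulated_direction.flatten(),
                               current_gradient.flatten(), dim=0)

# toy usage with made-up gradient vectors
accumulated = torch.tensor([1.0, 0.5])
current = torch.tensor([0.8, -0.1])
print(congruency(accumulated, current).item())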

If you find this work or the code useful in your research, please consider citing:
@article{Luo_DCL_2019,
    author={Y. {Luo} and Y. {Wong} and M. {Kankanhalli} and Q. {Zhao}},
    journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
    title={Direction Concentration Learning: Enhancing Congruency in Machine Learning},
    year={2019},
    pages={1-1},
    doi={10.1109/TPAMI.2019.2963387},
    ISSN={1939-3539}
}

TOC

  1. Prerequisites
  2. Illustration
  3. Training on ImageNet
  4. Training on Tiny ImageNet
  5. Training on CIFAR
  6. Pretrained Model

Prerequisites

  1. PyTorch 0.4.1, e.g.,
conda install pytorch=0.4.1 cuda80 -c pytorch # for CUDA 8.0
conda install pytorch=0.4.1 cuda90 -c pytorch # for CUDA 9.0

To use EfficientNet as the baseline model, PyTorch 1.1.0+ is required, e.g.,

conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=10.0 -c pytorch
  2. torchvision 0.2.1+, e.g.,
pip install torchvision==0.2.1
  3. quadprog (see the sketch after this list), i.e.,
pip install msgpack
pip install Cython
pip install quadprog
  4. EfficientNet (optional), i.e.,
pip install efficientnet_pytorch
  5. TensorboardX (optional), i.e.,
pip install tensorboardX==1.2
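
quadprog is needed because the code, like GEM on which it is partly based, corrects the update direction by solving a small quadratic program over reference gradients. The sketch below is a minimal GEM-style projection given for intuition only; the function name, margin default, and toy vectors are assumptions, not this repository's exact implementation.

import numpy as np
import quadprog

def project_gradient(g, refs, margin=0.0, eps=1e-3):
    # Project gradient g toward the cone of directions that do not conflict
    # with the reference gradients (rows of refs), via the GEM-style dual QP:
    #   min_v 0.5 v^T (R R^T) v + (R g)^T v   s.t.  v >= margin,
    # then return the corrected gradient g + R^T v.
    R = np.asarray(refs, dtype=np.float64)
    g = np.asarray(g, dtype=np.float64).ravel()
    t = R.shape[0]
    P = R @ R.T
    P = 0.5 * (P + P.T) + np.eye(t) * eps      # symmetrize and regularize
    q = -(R @ g)
    C = np.eye(t)                              # quadprog constraint: C^T v >= b
    b = np.zeros(t) + margin
    v = quadprog.solve_qp(P, q, C, b)[0]
    return g + R.T @ v

# toy usage: a gradient that conflicts with one reference direction
g_new = project_gradient([1.0, -2.0], [[1.0, 1.0]])
print(g_new, g_new @ np.array([1.0, 1.0]))    # alignment goes from -1 to roughly 0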

Illustration

To plot the convergence paths and the corresponding z-vs-iteration curves, execute the following commands
# go to the illustrator folder
cd illustrator 
# generate convergence path with one of optimizers (gd, rmsprop, or adam)
python convergence_visualization.py --opt gd|rmsprop|adam
# generate the corresponding z-vs-iteration curves
python plot_z.py
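
For reference, the kind of path these scripts plot can be generated with a few lines of PyTorch. The objective below (Rosenbrock) is a stand-in chosen purely for illustration; convergence_visualization.py may use a different surface and settings.

import torch

def objective(x, y):
    # Rosenbrock function as a stand-in 2D surface (assumption)
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2

params = torch.tensor([-1.5, 2.0], requires_grad=True)
opt = torch.optim.Adam([params], lr=0.05)  # swap for torch.optim.SGD / RMSprop to compare
path = []
for _ in range(200):
    opt.zero_grad()
    z = objective(params[0], params[1])
    z.backward()
    opt.step()
    path.append((params[0].item(), params[1].item(), z.item()))
# path[i] = (x, y, z) at iteration i; plotting z against i gives the z-vs-iteration curve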

Training on ImageNet

Step 1: Download ImageNet
The validation set should be re-organized into per-class (synset) subfolders so that the directory structure looks as follows:

imagenet/images
├── train
│   ├── n01440764
│   ├── n01443537
│   ├── n01484850
│   ├── n01491361
│   ├── n01494475
│   ...
└── val
    ├── n01440764
    ├── n01443537
    ├── n01484850
    ├── n01491361
    ├── n01494475
    ...

You can either write your own script to achieve this or use the script provided in the TensorFlow repo.
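
If you prefer a Python alternative, the sketch below moves each validation image into its synset folder. It assumes a mapping file, here called val_to_synset.txt with one "filename synset" pair per line, which you would need to build from the ILSVRC devkit annotations; the file name is hypothetical.

import os
import shutil

val_dir = '/path-to-folder/imagenet/images/val'
mapping_file = 'val_to_synset.txt'  # hypothetical: e.g. "ILSVRC2012_val_00000001.JPEG n01751748" per line

with open(mapping_file) as f:
    for line in f:
        filename, synset = line.split()
        dst_dir = os.path.join(val_dir, synset)
        os.makedirs(dst_dir, exist_ok=True)
        src = os.path.join(val_dir, filename)
        if os.path.isfile(src):
            shutil.move(src, os.path.join(dst_dir, filename))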
Step 2: We use 8 GPUs to train the models on ImageNet. To train ResNet-50 with DCL, run

python -W ignore train_imgnet.py \
    -a resnet50 \
    --lr 0.1 \
    --lr-decay-epoch 30 \
    -b 512 \
    --epochs 90 \
    --n-classes 1000 \
    --workers 16 \
    --checkpoint-path 'checkpoints/imagenet/resnet50_dcl_1_1' \
    --mtype 'dcl' \
    --dcl-refsize 1 \
    --dcl-window 1 \
    '/path-to-folder/imagenet/images/'

To train ResNet-50 (baseline), run

python -W ignore train_imgnet.py \
    -a resnet50 \
    --lr 0.1 \
    --lr-decay-epoch 30 \
    -b 512 \
    --epochs 90 \
    --n-classes 1000 \
    --workers 16 \
    --checkpoint-path 'checkpoints/imagenet/resnet50_baseline' \
    --mtype 'baseline' \
    '/path-to-folder/imagenet/images/'

To train ResNet-50 with GEM, run

python -W ignore train_imgnet.py \
    -a resnet50 \
    --lr 0.1 \
    --lr-decay-epoch 30 \
    -b 512 \
    --epochs 90 \
    --n-classes 1000 \
    --workers 16 \
    --checkpoint-path 'checkpoints/imagenet/resnet50_gem_1' \
    --mtype 'gem' \
    --gem-memsize 1 \
    '/path-to-folder/imagenet/images/'

The above commands are also collected in train_imgnet.sh.

Training on Tiny ImageNet

Step 1: Download Tiny ImageNet
The directory structure of the dataset should be re-organized as

/home/yluo/project/dataset/tinyimagenet
├── images
│   ├── test
│   ├── train
│   └── val
├── wnids.txt
└── words.txt

Step 2: For training, please refer to train_timgnet.sh.

Training on CIFAR

The code for the CIFAR experiments is built on pytorch-classification.
For training, please refer to train_cifar.sh.

Pretrained Model

ImageNet

The pre-trained model with DCL on ImageNet is available at a shared Google drive. In this folder, the per-epoch statistics, including 1-crop top-1 validation accuracy, are recorded in stat.csv. The highest validation accuracy is 75.93% (at epoch 86), while the mean accuracy of the baseline over 3 runs is 75.66%.
To load this pre-trained model, first download model_best.pth.tar into a folder named data (create it if necessary) and then run

python load_pretrained.py
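
In essence, load_pretrained.py builds the network and restores the saved weights. A minimal sketch of that pattern is shown below; the 'state_dict' key and the 'module.' prefix handling are assumptions based on common torch.save / nn.DataParallel conventions, not necessarily this script's exact code.

import torch
import torchvision.models as models

model = models.resnet50(num_classes=1000)
checkpoint = torch.load('data/model_best.pth.tar', map_location='cpu')
# checkpoints saved from nn.DataParallel typically prefix parameter names with 'module.'
state_dict = {k.replace('module.', '', 1): v
              for k, v in checkpoint.get('state_dict', checkpoint).items()}
model.load_state_dict(state_dict)
model.eval()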

Tiny ImageNet

The pre-trained models with DCL on Tiny ImageNet, i.e., ResNet-101 and EfficientNet-B1, are available at the shared drives ResNet_TImgNet and EfficientNet_TImgNet. The lowest validation top-1 error is 15.61%, versus 15.73% for the baseline.

CIFAR

The pre-trained models with DCL on CIFAR-10 and CIFAR-100, i.e., ResNeXt-29 and EfficientNet-B1, are available at the shared drives ResNeXt_CIFAR10, EfficientNet_CIFAR10, ResNeXt_CIFAR100, and EfficientNet_CIFAR100, respectively. The lowest validation error on CIFAR-10 is 1.79%, versus 1.91% for the baseline; on CIFAR-100 it is 11.65%, versus 11.81% for the baseline. We used dcl_margin=0.1 on CIFAR-10 and 0.3 on CIFAR-100.

Contact

luoxx648 at umn.edu
Any discussions, suggestions, and questions are welcome!
