🚀 FKD: A Fast Knowledge Distillation Framework for Visual Recognition

Official PyTorch implementation of the paper "A Fast Knowledge Distillation Framework for Visual Recognition" (ECCV 2022, ECCV paper, arXiv) by Zhiqiang Shen and Eric Xing.

Abstract

Knowledge Distillation (KD) has been recognized as a useful tool in many visual tasks, such as supervised classification and self-supervised representation learning. The main drawback of a vanilla KD framework is that most of its computational overhead is spent on forward passes through the giant teacher network, which makes the whole learning procedure inefficient and costly.

🚀 Fast Knowledge Distillation (FKD) is a novel framework that addresses this inefficiency: it simulates the distillation training phase and generates soft labels following the multi-crop KD procedure, while training faster than other methods. FKD is even more efficient than the conventional classification framework when multiple crops are taken from the same image during data loading. It achieves 80.1% (SGD) and 80.5% (AdamW) top-1 accuracy using ResNet-50 on ImageNet-1K with plain training settings. This work also demonstrates the efficiency advantage of FKD on the self-supervised learning task.

Citation

@article{shen2021afast,
      title={A Fast Knowledge Distillation Framework for Visual Recognition}, 
      author={Zhiqiang Shen and Eric Xing},
      year={2021},
      journal={arXiv preprint arXiv:2112.01528}
}

What's New

Supervised Training

Preparation

FKD Training on CNNs

To train a model, run train_FKD.py with the desired model architecture and the paths to the soft labels and the ImageNet dataset:

python train_FKD.py -a resnet50 --lr 0.1 --num_crops 4 -b 1024 --cos --temp 1.0 --softlabel_path [soft label path] [imagenet-folder with train and val folders]

For --softlabel_path, use a path of the form ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet.
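The training command above takes pre-generated soft labels and a temperature (--temp). As a rough illustration of how a student can be trained against stored soft labels, here is a minimal, framework-free sketch of a soft-label cross-entropy with temperature scaling. It is one common formulation, not necessarily the loss implemented in train_FKD.py; the function names `log_softmax` and `soft_cross_entropy` are hypothetical, and in practice the same computation would be done on PyTorch tensors.

```python
import math


def log_softmax(logits, temperature=1.0):
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)  # subtract the max before exponentiating for stability
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in scaled))
    return [z - log_norm for z in scaled]


def soft_cross_entropy(student_logits, soft_label, temperature=1.0):
    """Cross-entropy between a stored soft label and the student prediction.

    Illustrative only: `soft_label` stands in for one entry loaded from the
    soft label folder, and is assumed to be a probability vector.
    """
    log_probs = log_softmax(student_logits, temperature)
    return -sum(q * lp for q, lp in zip(soft_label, log_probs))


# Toy example with 3 classes (real ImageNet-1K labels have 1000 entries).
student_logits = [2.0, 0.5, -1.0]
soft_label = [0.7, 0.2, 0.1]
loss = soft_cross_entropy(student_logits, soft_label, temperature=1.0)
print(f"soft-label loss: {loss:.4f}")
```

With a one-hot `soft_label` this reduces to the usual cross-entropy, which is why the same training loop can be reused for both hard and soft targets.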

Multi-processing distributed training on a single node with multiple GPUs:

python train_FKD.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
-a resnet50 --lr 0.1 --num_crops 4 -b 1024 \
--temp 1.0 --cos -j 32 \
--save_checkpoint_path ./FKD_nc_4_res50_plain \
--softlabel_path [soft label path, e.g., ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet] \
[imagenet-folder with train and val folders]
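The commands above pass --cos. Assuming this flag selects a cosine learning-rate schedule, as it does in many PyTorch training scripts (this has not been checked against train_FKD.py), the per-epoch learning rate could look like the sketch below. The function name `cosine_lr` and the total of 100 epochs are illustrative assumptions, not values taken from the repository.

```python
import math


def cosine_lr(base_lr, epoch, total_epochs):
    """Cosine-annealed learning rate for a 0-indexed epoch.

    Starts at base_lr and decays smoothly to 0 at epoch == total_epochs.
    """
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * epoch / total_epochs))


# Illustrative values: --lr 0.1 from the command above, 100 epochs assumed.
for epoch in (0, 25, 50, 75, 100):
    print(f"epoch {epoch:3d}: lr = {cosine_lr(0.1, epoch, 100):.4f}")
```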

For multi-node, multi-processing distributed training, please refer to the official PyTorch ImageNet training code for details.

Evaluation

python train_FKD.py -a resnet50 -e --resume [model path] [imagenet-folder with train and val folders]

Training Speed Comparison

The per-epoch training time was measured on the HPC/CIAI cluster at MBZUAI with 8 NVIDIA V100 GPUs. The batch size is 1024 for all three methods: (i) the regular/vanilla classification framework, (ii) ReLabel, and (iii) FKD. For Vanilla and ReLabel, we report the average over 10 epochs after the speed has stabilized. For FKD, we use num_crops = 4 and average over (4 $\times$ 10) epochs; note that using 8 crops gives faster training. All other settings are identical across the comparison.

| Method | Network | Training time per epoch |
| --- | --- | --- |
| Vanilla | ResNet-50 | 579.36 sec/epoch |
| ReLabel | ResNet-50 | 762.11 sec/epoch |
| FKD (Ours) | ResNet-50 | 486.77 sec/epoch |
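To put the timings above in relative terms, the short calculation below derives speedup factors from the three reported numbers. Only the per-epoch times come from the table; the helper name `speedup` is hypothetical. FKD comes out roughly 1.19x faster than Vanilla and roughly 1.57x faster than ReLabel per epoch.

```python
# Per-epoch training times (seconds) as reported in the table above.
seconds_per_epoch = {
    "Vanilla": 579.36,
    "ReLabel": 762.11,
    "FKD": 486.77,
}


def speedup(baseline, method="FKD"):
    """How many times faster `method` is than `baseline`, per epoch."""
    return seconds_per_epoch[baseline] / seconds_per_epoch[method]


for baseline in ("Vanilla", "ReLabel"):
    saved = 1.0 - seconds_per_epoch["FKD"] / seconds_per_epoch[baseline]
    print(f"FKD vs {baseline}: {speedup(baseline):.2f}x faster, "
          f"{saved:.1%} less time per epoch")
```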

Trained Models

| Method | Network | Accuracy (Top-1) | Weights | Configurations |
| --- | --- | --- | --- | --- |
| ReLabel | ResNet-50 | 78.9 | -- | -- |
| FKD | ResNet-50 | 80.1 (+1.2%) | link | same as ReLabel, with initial lr = $0.1 \times \frac{\text{batch size}}{512}$ |
| FKD (Plain) | ResNet-50 | 79.8 | link | Table 12 in paper (w/o warmup & colorJ) |
| FKD (AdamW) | ResNet-50 | 80.5 | link | Table 13 in paper (same as our settings on ViT and SReT) |
| ReLabel | ResNet-101 | 80.7 | -- | -- |
| FKD | ResNet-101 | 81.9 (+1.2%) | link | Table 12 in paper |
| FKD (Plain) | ResNet-101 | 81.7 | link | Table 12 in paper (w/o warmup & colorJ) |
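The configuration column above gives the initial learning rate as 0.1 times the batch size divided by 512. A minimal sketch of that scaling rule follows; the function name `initial_lr` and its keyword arguments are hypothetical, and only the constants 0.1 and 512 come from the table.

```python
def initial_lr(batch_size, base_lr=0.1, reference_batch=512):
    """Linear learning-rate scaling: base_lr * batch_size / reference_batch."""
    return base_lr * batch_size / reference_batch


# Learning rates the rule would give for a few batch sizes.
for batch_size in (256, 512, 1024):
    print(f"batch size {batch_size:4d}: initial lr = {initial_lr(batch_size):.3f}")
```

Under this rule a batch size of 1024 corresponds to an initial learning rate of 0.2, and halving the batch size halves the learning rate.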

Mobile-level Efficient Networks

| Method | Network | FLOPs | Accuracy (Top-1) | Weights |
| --- | --- | --- | --- | --- |
| FBNet | FBNet-c100 | 375M | 75.12% | -- |
| FKD | FBNet-c100 | 375M | 77.13% (+2.01%) | link |
| EfficientNetv2 | EfficientNetv2-B0 | 700M | 78.35% | -- |
| FKD | EfficientNetv2-B0 | 700M | 79.94% (+1.59%) | link |

The training protocol is the same as we used for ViT/SReT:

# Use the same settings as on ViT and SReT
cd train_ViT
# Train the model
python -u train_ViT_FKD.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
-a tf_efficientnetv2_b0 \
--lr 0.002 --wd 0.05 \
--epochs 300 --cos -j 32 \
--num_classes 1000 --temp 1.0 \
-b 1024 --num_crops 4 \
--save_checkpoint_path ./FKD_nc_4_224_efficientnetv2_b0 \
--soft_label_type marginal_smoothing_k5  \
--softlabel_path [soft label path] \
[imagenet-folder with train and val folders]
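The command above selects --soft_label_type marginal_smoothing_k5. As a hedged illustration of what such a label format might look like, the sketch below assumes that marginal smoothing keeps the k largest teacher probabilities and spreads the remaining probability mass uniformly over all other classes. This reading is an assumption and should be checked against the paper and the repository; the function name `marginal_smoothing` is hypothetical.

```python
def marginal_smoothing(probs, k=5):
    """Keep the top-k probabilities; share the leftover mass uniformly.

    Assumed behaviour, for illustration only: storing just k values per crop
    is enough to rebuild a full-length soft label this way.
    """
    num_classes = len(probs)
    if not 0 < k < num_classes:
        raise ValueError("k must be between 1 and num_classes - 1")
    # Indices of the k most probable classes.
    top = sorted(range(num_classes), key=lambda i: probs[i], reverse=True)[:k]
    top_set = set(top)
    leftover = max(0.0, 1.0 - sum(probs[i] for i in top))
    fill = leftover / (num_classes - k)
    return [probs[i] if i in top_set else fill for i in range(num_classes)]


# Toy example with 6 classes and k = 2.
teacher_probs = [0.5, 0.2, 0.1, 0.1, 0.05, 0.05]
soft_label = marginal_smoothing(teacher_probs, k=2)
print(soft_label)
```

The reconstructed label still sums to 1, so it can be used directly as a target distribution.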

FKD Training on ViT/DeiT and SReT

To train a ViT model, run train_ViT_FKD.py with the desired model architecture and the paths to the soft labels and the ImageNet dataset:

cd train_ViT
python train_ViT_FKD.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
-a SReT_LT --lr 0.002 --wd 0.05 --num_crops 4 \
--temp 1.0 -b 1024 --cos \
--softlabel_path [soft label path] \
[imagenet-folder with train and val folders]

For instructions on the SReT_LT model, please refer to SReT.

Evaluation

python train_ViT_FKD.py -a SReT_LT -e --resume [model path] [imagenet-folder with train and val folders]

Trained Models

| Model | FLOPs | #params | Accuracy (Top-1) | Weights | Configurations |
| --- | --- | --- | --- | --- | --- |
| DeiT-T-distill | 1.3B | 5.7M | 74.5 | -- | -- |
| FKD ViT/DeiT-T | 1.3B | 5.7M | 75.2 | link | Table 13 in paper |
| SReT-LT-distill | 1.2B | 5.0M | 77.7 | -- | -- |
| FKD SReT-LT | 1.2B | 5.0M | 78.7 | link | Table 13 in paper |

Fast MEAL V2

Please see MEAL V2 for instructions on running FKD with MEAL V2.

Self-supervised Representation Learning Using FKD

Please see FKD-SSL for instructions on running FKD for the SSL task.

Contact

Zhiqiang Shen (zhiqiangshen0214 at gmail.com or zhiqians at andrew.cmu.edu)
