
Improving Self-supervised Lightweight Model Learning via Hard-aware Metric Distillation

A PyTorch implementation of our paper:

Improving Self-supervised Lightweight Model Learning via Hard-aware Metric Distillation (Video)

  • Accepted at ECCV 2022 (Oral).

Self-supervised Distillation on ImageNet

Dependencies

If you don't have a Python 3 environment:

conda create -n SMD python=3.8
conda activate SMD

Then install the required packages:

pip install -r requirements.txt

Only multi-GPU DistributedDataParallel training is supported; single-GPU and DataParallel training are not supported.
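Multi-GPU runs use the `--multiprocessing-distributed` pattern inherited from SimSiam/MoCo, which spawns one process per GPU. Assuming this repository keeps that convention (not confirmed here), `--world-size` counts machines, each GPU gets a global rank derived from `--rank`, and the batch size passed with `-b` is split across the GPUs of a machine. The helper below is a hypothetical, stdlib-only sketch of that bookkeeping, not code from the repository:

```python
def ddp_layout(ngpus_per_node, nodes, node_rank, total_batch_size):
    """Hypothetical sketch of MoCo/SimSiam-style DDP bookkeeping (assumed convention)."""
    # --world-size on the command line counts machines; expand it to count processes.
    world_size = ngpus_per_node * nodes
    # Each local GPU index is offset by this machine's --rank to get a global rank.
    global_ranks = [node_rank * ngpus_per_node + gpu for gpu in range(ngpus_per_node)]
    # The batch size given with -b is divided evenly among the GPUs of one machine.
    per_gpu_batch = total_batch_size // ngpus_per_node
    return world_size, global_ranks, per_gpu_batch


if __name__ == "__main__":
    # Equivalent of --world-size 1 --rank 0 on a single 4-GPU machine.
    print(ddp_layout(ngpus_per_node=4, nodes=1, node_rank=0, total_batch_size=4096))
```

For the 4-GPU commands in this README (`--world-size 1 --rank 0`), this yields 4 processes with ranks 0 to 3, and a total batch of 4096 would become 1024 per GPU.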

Get the teacher network

To pre-train an unsupervised ResNet-50 model on ImageNet, run:

python main_simsiam.py \
  -a resnet50 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  --fix-pred-lr \
  [your imagenet-folder with train and val folders]

Alternatively, you can download the pre-trained ResNet-50 teacher model released by SimSiam.
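The `--fix-pred-lr` flag comes from SimSiam: the encoder follows a cosine learning-rate schedule while the predictor's learning rate stays constant. A minimal sketch of that schedule, assuming `main_simsiam.py` keeps SimSiam's linear scaling rule (base lr × batch size / 256); the function name and return format are hypothetical:

```python
import math


def simsiam_lrs(base_lr, batch_size, epoch, epochs, fix_pred_lr=True):
    """Hypothetical sketch of SimSiam-style per-epoch learning rates."""
    # Linear scaling rule: base_lr is specified for a batch size of 256.
    init_lr = base_lr * batch_size / 256
    # Cosine decay from init_lr toward 0 over the full run.
    cur_lr = init_lr * 0.5 * (1.0 + math.cos(math.pi * epoch / epochs))
    # With --fix-pred-lr the predictor keeps init_lr at every epoch.
    pred_lr = init_lr if fix_pred_lr else cur_lr
    return {"encoder": cur_lr, "predictor": pred_lr}


if __name__ == "__main__":
    for ep in (0, 50, 99):
        print(ep, simsiam_lrs(base_lr=0.05, batch_size=512, epoch=ep, epochs=100))
```

Keeping the predictor's learning rate fixed is the design choice SimSiam reports as helpful; the encoder and projector still decay as usual.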

Unsupervised Distillation

To distill a ResNet-18 model on ImageNet on a 4-GPU machine, run:

python main_distill.py \
  -a resnet18 \
  --dist-url 'tcp://localhost:10002' --multiprocessing-distributed --world-size 1 --rank 0 \
  --teacher_path [your pre-trained teacher model path] \
  [your imagenet-folder with train and val folders]
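A SimSiam checkpoint stores a DDP-wrapped model, so its `state_dict` keys typically look like `module.encoder.layer1.0.conv1.weight`, with `module.encoder.fc.*` holding the projection MLP and `module.predictor.*` the predictor. How `main_distill.py` actually reads `--teacher_path` is not documented here; the hypothetical helper below only sketches how ResNet backbone weights could be pulled out of such a checkpoint, assuming that key layout:

```python
def extract_backbone(state_dict, prefix="module.encoder.", skip=("fc.",)):
    """Hypothetical helper: keep backbone weights from a SimSiam-style state_dict.

    Assumes keys are prefixed with 'module.encoder.' and that 'fc.*' under the
    encoder is the projector head rather than part of the ResNet backbone.
    """
    backbone = {}
    for key, value in state_dict.items():
        if not key.startswith(prefix):
            continue  # e.g. predictor weights under 'module.predictor.'
        name = key[len(prefix):]
        if name.startswith(skip):
            continue  # drop the projection head
        backbone[name] = value
    return backbone
```

With PyTorch, the input would usually come from `torch.load(path, map_location="cpu")["state_dict"]`, and the result could be loaded into a torchvision ResNet with `model.load_state_dict(..., strict=False)` so that the missing `fc` layer is tolerated.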

Linear Classification

With a pre-trained model, to train a supervised linear classifier on frozen features on a 4-GPU machine, run:

python main_lincls.py \
  -a resnet18 \
  --dist-url 'tcp://localhost:10003' --multiprocessing-distributed --world-size 1 --rank 0 \
  --pretrained [your checkpoint path]/checkpoint_0099.pth.tar \
  --lars \
  [your imagenet-folder with train and val folders]

The above command uses the LARS optimizer and a default batch size of 4096.
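LARS rescales each layer's update by a trust ratio, roughly η·‖w‖ / ‖g + λw‖, which keeps large-batch training (such as the 4096 batch here) stable. The sketch below mirrors the LARS variant used in SimSiam's linear-evaluation code as we understand it: weight decay and the trust ratio apply only to multi-dimensional parameters (not biases or BatchNorm), followed by momentum. It is plain Python on flat lists for illustration, not the repository's optimizer, and the function name and default hyperparameters are assumptions:

```python
import math


def lars_step(param, grad, buf, lr, weight_decay=0.0, momentum=0.9,
              trust_coefficient=0.001, adapt=True):
    """One LARS update on a flat parameter; returns (new_param, new_buf).

    adapt=True stands in for "parameter has more than one dimension".
    """
    update = list(grad)
    if adapt:
        # Fold weight decay into the update before computing the trust ratio.
        update = [g + weight_decay * p for g, p in zip(update, param)]
        param_norm = math.sqrt(sum(p * p for p in param))
        update_norm = math.sqrt(sum(u * u for u in update))
        # Layer-wise trust ratio; fall back to 1 when either norm is zero.
        if param_norm > 0 and update_norm > 0:
            q = trust_coefficient * param_norm / update_norm
        else:
            q = 1.0
        update = [u * q for u in update]
    # Heavy-ball momentum buffer, then an SGD-style step.
    buf = [momentum * m + u for m, u in zip(buf, update)]
    param = [p - lr * m for p, m in zip(param, buf)]
    return param, buf
```

For example, a weight vector with norm 5 and a gradient with norm 1 gets its step scaled by 0.001 × 5 / 1 = 0.005 before momentum and the learning rate are applied.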

Acknowledgement

This repository is partly built upon SimSiam, DisCo and SEED. Thanks for their great work!

Citation

If you use SMD in your research or wish to refer to the baseline results published in this paper, please use the following BibTeX entry.

@inproceedings{ECCV2022smd,
  title={Improving Self-supervised Lightweight Model Learning via Hard-Aware Metric Distillation},
  author={Liu, Hao and Ye, Mang},
  booktitle={European Conference on Computer Vision},
  pages={295--311},
  year={2022},
  organization={Springer}
}
