This project forked from invictus717/unidg

Towards Unified and Effective Domain Generalization

Home Page: https://invictus717.github.io/Generalization/

License: Apache License 2.0

Shell 1.33% Python 98.24% Dockerfile 0.43%


Highlights ⭐⭐⭐

  • 🚀 UniDG is an effective Test-Time Adaptation scheme. It improves existing DG methods by an average of +5.0% accuracy on DomainBed benchmarks, including the PACS, VLCS, OfficeHome, and TerraInc datasets.

  • 🚀 UniDG is architecture-agnostic. Unified with 10+ visual backbones, including CNN-, MLP-, and Transformer-based models, UniDG yields a consistent average improvement of +5.4% on domain generalization.

  • 🏆 Achieved 79.6 mAcc on the PACS, VLCS, OfficeHome, TerraIncognita, and DomainNet datasets.

Abstract

We propose UniDG, a novel and Unified framework for Domain Generalization that is capable of significantly enhancing the out-of-distribution generalization performance of foundation models regardless of their architectures. The core idea of UniDG is to finetune models during the inference stage, which saves the cost of iterative training. Specifically, we encourage models to learn the distribution of test data in an unsupervised manner and impose a penalty regarding the updating step of model parameters. The penalty term can effectively reduce the catastrophic forgetting issue as we would like to maximally preserve the valuable knowledge in the original model. Empirically, across 12 visual backbones, including CNN-, MLP-, and Transformer-based models, ranging from 1.89M to 303M parameters, UniDG shows an average accuracy improvement of +5.4% on DomainBed.
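
The idea in the abstract, adapting on unlabeled test data while penalizing how far the parameters move, can be illustrated with a toy sketch. Everything here is an assumption made for illustration: the entropy-minimization loss, the squared L2 penalty, the tiny linear classifier, and all names and numbers. It is not the UniDG implementation, which is described in the paper.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mean_entropy(theta, batch, n_classes):
    """Average prediction entropy of a linear classifier on unlabeled inputs."""
    n_feat = len(batch[0])
    total = 0.0
    for x in batch:
        logits = [
            sum(theta[c * n_feat + j] * x[j] for j in range(n_feat))
            for c in range(n_classes)
        ]
        total -= sum(p * math.log(p) for p in softmax(logits) if p > 0.0)
    return total / len(batch)


def squared_drift(theta, theta_src):
    """Squared L2 distance between adapted and source parameters."""
    return sum((t - s) ** 2 for t, s in zip(theta, theta_src))


def objective(theta, theta_src, batch, n_classes, lam):
    """Unsupervised loss plus a penalty on the parameter update (assumed form)."""
    return mean_entropy(theta, batch, n_classes) + lam * squared_drift(theta, theta_src)


def adapt(theta_src, batch, n_classes, lam=1.0, lr=0.1, steps=20, eps=1e-5):
    """Plain gradient descent; gradients come from central finite differences."""
    theta = list(theta_src)
    for _ in range(steps):
        grad = []
        for i in range(len(theta)):
            up = theta[:i] + [theta[i] + eps] + theta[i + 1:]
            down = theta[:i] + [theta[i] - eps] + theta[i + 1:]
            grad.append(
                (objective(up, theta_src, batch, n_classes, lam)
                 - objective(down, theta_src, batch, n_classes, lam)) / (2 * eps)
            )
        theta = [t - lr * g for t, g in zip(theta, grad)]
    return theta


# Hypothetical toy setup: 2 classes, 2 features, 3 unlabeled test inputs.
theta_src = [1.0, 0.0, 0.0, 1.0]
test_batch = [[1.0, 0.2], [0.1, 0.9], [0.8, 0.5]]

free = adapt(theta_src, test_batch, n_classes=2, lam=0.0)
penalized = adapt(theta_src, test_batch, n_classes=2, lam=1.0)
print(f"drift without penalty: {squared_drift(free, theta_src):.4f}")
print(f"drift with penalty:    {squared_drift(penalized, theta_src):.4f}")
```

In this toy setting the penalized run still lowers the prediction entropy but stays much closer to the source parameters, which is the trade-off the penalty term is meant to control.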

Features 🎉🎉🎉

  1. 🌟 Extensibility: we integrate UniDG with DomainBed. More networks and algorithms can be built easily on our framework, and UniDG brings an average improvement of +5.0% to existing methods, including ERM, CORAL, and MIRO.

  2. 🌟 Reproducibility: all implemented models are trained on various tasks at least three times. Mean±std is provided in the UniDG paper. Pretrained models and logs are available.

  3. 🌟 Ease of Use: we provide tools that record experimental logs as JSON files and can convert the results directly into LaTeX:

    [Image: logs]

  4. 🌟 Visualization Tools: we provide scripts to easily visualize results with t-SNE plots and performance curves:

    • Convergence Curves:

    [Image: curves]

    • t-SNE Visualization Results:

    [Image: t-SNE]
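
The mean±std reporting and JSON-to-LaTeX conversion mentioned in the features above can be sketched as follows. The record schema (`seed`, `acc`), the function names, the accuracy values, and the use of the sample standard deviation are all hypothetical; the repository's own log format and scripts may differ.

```python
import json
import statistics


def summarize_runs(log_text):
    """Return (mean, sample std) of the accuracies in a JSON list of run records."""
    runs = json.loads(log_text)
    accs = [run["acc"] for run in runs]
    return statistics.mean(accs), statistics.stdev(accs)


def to_latex_cell(mean, std):
    """Format a result as a LaTeX table cell with one decimal place."""
    return f"{mean:.1f} $\\pm$ {std:.1f}"


# Hypothetical log: three seeds of one configuration (values are made up).
log_text = json.dumps([
    {"seed": 0, "acc": 73.0},
    {"seed": 1, "acc": 73.1},
    {"seed": 2, "acc": 73.3},
])

mean, std = summarize_runs(log_text)
print(to_latex_cell(mean, std))  # 73.1 $\pm$ 0.2
```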

Model Zoo for UniDG

We provide pretrained checkpoints for the base algorithms (CORAL and ERM) so that our experimental results can be reproduced conveniently.

Note that IID_best.pkl is the pretrained source model.

  • CORAL Source Models

    Backbone          Dataset                                      Algorithm  Base Model  Adaptation  Google Drive
    ResNet-50         VLCS, PACS, OfficeHome, TerraInc, DomainNet  CORAL      64.1 ± 0.1  69.3 ± 0.2  ckpt
    Swin Transformer  VLCS, PACS, OfficeHome, TerraInc             CORAL      77.2 ± 0.1  82.5 ± 0.2  ckpt
    ConvNeXt          VLCS, PACS, OfficeHome, TerraInc, DomainNet  CORAL      75.1 ± 0.1  79.6 ± 0.3  ckpt

  • ERM Source Models

    Backbone          Dataset                           Algorithm  Base Model  Adaptation  Google Drive
    ResNet-18         VLCS, PACS, OfficeHome, TerraInc  ERM        63.0 ± 0.0  67.2 ± 0.2  ckpt
    ResNet-50         VLCS, PACS, OfficeHome, TerraInc  ERM        67.6 ± 0.0  73.1 ± 0.2  ckpt
    ResNet-101        VLCS, PACS, OfficeHome, TerraInc  ERM        68.1 ± 0.1  72.3 ± 0.3  ckpt
    MobileNet V3      VLCS, PACS, OfficeHome, TerraInc  ERM        58.9 ± 0.0  65.3 ± 0.2  ckpt
    EfficientNet V2   VLCS, PACS, OfficeHome, TerraInc  ERM        67.2 ± 0.0  72.1 ± 0.3  ckpt
    ConvNeXt-B        VLCS, PACS, OfficeHome, TerraInc  ERM        79.7 ± 0.0  83.7 ± 0.1  ckpt
    ViT-B16           VLCS, PACS, OfficeHome, TerraInc  ERM        69.5 ± 0.0  75.4 ± 0.2  ckpt
    ViT-L16           VLCS, PACS, OfficeHome, TerraInc  ERM        74.1 ± 0.0  79.9 ± 0.3  ckpt
    DeiT              VLCS, PACS, OfficeHome, TerraInc  ERM        73.5 ± 0.0  77.8 ± 0.2  ckpt
    Swin Transformer  VLCS, PACS, OfficeHome, TerraInc  ERM        77.2 ± 0.0  81.5 ± 0.3  ckpt
    Mixer-B16         VLCS, PACS, OfficeHome, TerraInc  ERM        57.2 ± 0.1  65.6 ± 0.3  ckpt
    Mixer-L16         VLCS, PACS, OfficeHome, TerraInc  ERM        67.4 ± 0.0  73.0 ± 0.2  ckpt
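
To read the ERM table at a glance, the per-backbone gain is the "Adaptation" value minus the "Base Model" value. The short script below does that arithmetic on the table's mean values (the ± terms are ignored). The variable names are illustrative, and because the inputs are rounded table entries, the resulting average may differ slightly from figures quoted elsewhere.

```python
# (backbone, base model score, score after adaptation), copied from the ERM table.
ERM_RESULTS = [
    ("ResNet-18", 63.0, 67.2),
    ("ResNet-50", 67.6, 73.1),
    ("ResNet-101", 68.1, 72.3),
    ("MobileNet V3", 58.9, 65.3),
    ("EfficientNet V2", 67.2, 72.1),
    ("ConvNeXt-B", 79.7, 83.7),
    ("ViT-B16", 69.5, 75.4),
    ("ViT-L16", 74.1, 79.9),
    ("DeiT", 73.5, 77.8),
    ("Swin Transformer", 77.2, 81.5),
    ("Mixer-B16", 57.2, 65.6),
    ("Mixer-L16", 67.4, 73.0),
]

# Round to one decimal to remove floating-point noise from the subtraction.
gains = {name: round(adapted - base, 1) for name, base, adapted in ERM_RESULTS}
mean_gain = sum(gains.values()) / len(gains)
best = max(gains, key=gains.get)

for name, gain in gains.items():
    print(f"{name:<18} +{gain:.1f}")
print(f"mean gain over {len(gains)} backbones: +{mean_gain:.2f}")
print(f"largest gain: {best} (+{gains[best]:.1f})")
```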

🔧 Get Started

Environment Setup

git clone https://github.com/invictus717/UniDG.git && cd UniDG
conda env create -f UniDG.yaml && conda activate UniDG

Dataset Preparation

python -m domainbed.scripts.download \
       --data_dir=./data

⏳ Training & Test-time adaptation

Train a model:

python -m domainbed.scripts.train \
       --data_dir=./data \
       --algorithm ERM \
       --dataset OfficeHome \
       --test_env 2 \
       --hparams "{\"backbone\": \"resnet50\"}" \
       --output_dir my/pretrain/ERM/resnet50

Note that you can download our pretrained checkpoints in the Model Zoo.

Then you can perform self-supervised adaptation:

python -m domainbed.scripts.unsupervised_adaptation \
       --input_dir my/pretrain/ERM/resnet50 \
       --adapt_algorithm=UniDG
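
When launching these two steps from Python, one option is to build the argument lists programmatically. The sketch below only uses the flags shown above. The helper name `build_commands` is hypothetical, and the script merely prints the commands. It shows two details: `json.dumps` produces the `--hparams` string without hand-escaped quotes, and the training `--output_dir` is reused as the adaptation `--input_dir`.

```python
import json
import shlex


def build_commands(algorithm, dataset, test_env, backbone, run_dir):
    """Return (train, adapt) argument lists that share one run directory."""
    hparams = json.dumps({"backbone": backbone})
    train = [
        "python", "-m", "domainbed.scripts.train",
        "--data_dir=./data",
        "--algorithm", algorithm,
        "--dataset", dataset,
        "--test_env", str(test_env),
        "--hparams", hparams,
        "--output_dir", run_dir,
    ]
    adapt = [
        "python", "-m", "domainbed.scripts.unsupervised_adaptation",
        "--input_dir", run_dir,  # same directory the training step wrote to
        "--adapt_algorithm=UniDG",
    ]
    return train, adapt


train_cmd, adapt_cmd = build_commands(
    "ERM", "OfficeHome", 2, "resnet50", "my/pretrain/ERM/resnet50"
)
print(shlex.join(train_cmd))
print(shlex.join(adapt_cmd))
```

Each list can then be passed to `subprocess.run(cmd, check=True)` from the repository root, so that adaptation only starts if training exits successfully.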

📆 Collect Experimental Results

After adaptation, you can collect the experimental results:

python -m domainbed.scripts.collect_all_results \
       --input_dir=my/pretrain/ERM \
       --adapt_dir=results/ERM/resnet50 \
       --output_dir=log/UniDG/ \
       --adapt_algorithm=UniDG \
       --latex

📈 Visualization results

For t-SNE visualization:

python -m domainbed.scripts.visualize_tsne \
       --input_dir=my/pretrain/ERM \
       --adapt_dir=UniDG/results/ERM/resnet50 \
       --output_dir=log/UniDG/ \
       --adapt_algorithm=UniDG \
       --latex

For performance curve visualization:

python -m domainbed.scripts.visualize_curves \
       --input_dir=my/pretrain/ERM \
       --adapt_dir=UniDG/results/ERM/resnet50 \
       --output_dir=log/UniDG/ \
       --adapt_algorithm=UniDG \
       --latex

Citation

If this work is helpful for your research, please consider citing the following BibTeX entry.

@article{zhang2023unified,
      title={Towards Unified and Effective Domain Generalization}, 
      author={Yiyuan Zhang and Kaixiong Gong and Xiaohan Ding and Kaipeng Zhang and Fangrui Lv and Kurt Keutzer and Xiangyu Yue},
      year={2023},
      eprint={2310.10008},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgements

This repository is based on DomainBed, T3A, and timm. Many thanks to their authors for the great work!

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.
