
License: MIT


PromptCAL for Generalized Novel Category Discovery (CVPR'23)

This repo contains the code for our paper: PromptCAL: Contrastive Affinity Learning via Auxiliary Prompts for Generalized Novel Category Discovery

Abstract:

Although existing semi-supervised learning models achieve remarkable success in learning with unannotated in-distribution data, they mostly fail to learn on unlabeled data sampled from novel semantic classes due to their closed-set assumption. In this work, we target a pragmatic but under-explored Generalized Novel Category Discovery (GNCD) setting. The GNCD setting aims to categorize unlabeled training data coming from known and novel classes by leveraging the information of partially labeled known classes. We propose a two-stage Contrastive Affinity Learning method with auxiliary visual Prompts, dubbed PromptCAL, to address this challenging problem. Our approach discovers reliable pairwise sample affinities to learn better semantic clustering of both known and novel classes for the class token and visual prompts. First, we propose a discriminative prompt regularization loss to reinforce semantic discriminativeness of prompt-adapted pre-trained vision transformer for refined affinity relationships. Besides, we propose a contrastive affinity learning stage to calibrate semantic representations based on our iterative semi-supervised affinity graph generation method for semantically-enhanced prompt supervision. Extensive experimental evaluation demonstrates that our PromptCAL method is more effective in discovering novel classes even with limited annotations and surpasses the current state-of-the-art on generic and fine-grained benchmarks (with nearly 11% gain on CUB-200, and 9% on ImageNet-100) on overall accuracy.


🔖Contents

  1. Description
  2. Dependencies and datasets
  3. Running
    1. GNCD training
    2. Inductive GNCD training
    3. Evaluation
    4. Checkpoints
  4. Acknowledgement
  5. Citation

Description

This is the official PyTorch implementation of "PromptCAL: Contrastive Affinity Learning via Auxiliary Prompts for Generalized Novel Category Discovery".

Dependencies and datasets

Our experiments cover fine-grained benchmarks (CUB-200, StanfordCars, Aircraft, Herbarium19) and generic benchmarks (CIFAR, ImageNet-100).

Note that we use the same dataset splits as GCD. For ImageNet-100, we sample the 100 classes adopted in this issue from the original ImageNet-1k dataset into a separate folder and name the result imagenet_100_gcd.
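Building that folder amounts to copying the 100 sampled class directories out of ImageNet-1k. A minimal sketch (the helper name and directory layout are assumptions; use the class names given in the issue above):

```python
import shutil
from pathlib import Path

def build_imagenet_100(src: Path, dst: Path, class_names: list) -> None:
    """Copy the sampled class folders from an ImageNet-1k tree into a
    separate imagenet_100_gcd directory, for both train and val splits."""
    for split in ("train", "val"):
        for cls in class_names:
            src_dir = src / split / cls
            if src_dir.is_dir():
                shutil.copytree(src_dir, dst / split / cls, dirs_exist_ok=True)
```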

The dependencies are listed in requirement.txt. Install them with:

pip install -r requirement.txt

Download the datasets to your data folder and update the corresponding data paths in /config.py, /data/imagenet.py, /data/stanford_cars.py, and /methods/contrastive_training/common.py.
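The path entries look roughly like the following (a hypothetical sketch; the actual variable names in /config.py and the data files may differ, so match them against the files themselves):

```python
# config.py -- hypothetical example; actual variable names may differ.
cifar_10_root = '/path/to/datasets/cifar10'
cifar_100_root = '/path/to/datasets/cifar100'
cub_root = '/path/to/datasets/CUB_200_2011'
imagenet_root = '/path/to/datasets/imagenet_100_gcd'
herbarium_dataroot = '/path/to/datasets/herbarium_2019'
```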

Running

Preparation

Before training, create a /cache folder for saving checkpoints and results.

Before evaluation, create a /tmp folder for saving intermediate and evaluation results.

Training under the GNCD setting

We train on a single A6000 GPU (48 GB). Multi-GPU training with nn.DataParallel is also supported via the --devices option; for example, pass --device 'cuda:0' --devices '0,1,2,3' to train on four GPUs (when --devices is set, --device must be 'cuda:0').
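The device options map onto nn.DataParallel roughly as follows (a minimal sketch of the idea, not the repo's actual argument handling):

```python
import torch.nn as nn

def wrap_model(model, device, devices=None):
    """Move `model` to `device`; if `devices` (e.g. '0,1,2,3') is given,
    wrap it in nn.DataParallel so batches are split across those GPUs.
    With --devices set, --device must be 'cuda:0'."""
    model = model.to(device)
    if devices:
        device_ids = [int(d) for d in devices.split(',')]
        model = nn.DataParallel(model, device_ids=device_ids)
    return model
```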

Training scripts: bash_scripts/train_stage_1.sh for the 1st stage and bash_scripts/train_stage_2.sh for the 2nd stage. Configurations for each dataset are listed separately in these two files.

Annotated example commands for training are listed below. Note that the inline # comments are illustrative only; strip them before running, since a trailing comment after a backslash breaks the line continuation.

### stage 1 training for fine-grained dataset
...
python -m methods.contrastive_training.contrastive_training_1 \
            --device 'cuda:3' \
          # --devices '0,1' \ # if use cuda:0 and cuda:1 for data parallelism, --device must be 'cuda:0'
            --exp_id $EXP_ID \ # experiment id
            --lr 0.1 \
            --epochs 200 \
            --dataset_name ${DATASET} \ # dataset
            --prop_train_labels 0.5 \ # labeling ratio for known classes
            --use_val 'False' \ # false for GNCD setting
            --kmeans_interval -1 \ # fast kmeans evaluation interval, -1 means every epoch
            --eval_interval 1 \ # kmeans evaluation interval
            --use_fast_kmeans 'True' \ # fast kmeans is GPU-side KMeans
            --inkd_T 5 \ # ImageNet Knowledge Distillation (INKD) loss epoch
            --w_inkd_loss 0.5 \ # INKD loss initial weight
            --w_inkd_loss_min 0.01 \ # INKD loss minimum decayed weight
            --num_dpr 2 \ # number of supervised prompts
            --w_prompt_clu 0.35 \ # DPR loss weight
            --num_prompts 5 \ # total prompt number
            --early_stop -2
...

### stage 2 training for fine-grained dataset
...
python -m methods.contrastive_training.contrastive_training_2 \
            --device 'cuda:2' \
            --epochs 100 \
            --lr 0.1 \
            --runner_name $EXP_NAME \
            --exp_id $EXP_ID \
            --eval_interval 1 \
            --eval_interval_t 10 \ # evaluation interval for teacher model
            --kmeans_interval -1 \
            --use_fast_kmeans 'True' \
            --use_val 'False' \
            --val_split 0.1 \
            --dataset_name ${dataset_name} \ # dataset
            --prop_train_labels 0.5 \
            --num_prompts 5 \
            --num_dpr 2 \
            --w_prompt_clu 0.35 \
            --knn 5 \ # neighborhood size
            --w_knn_loss 0.6 \ # CAL loss weight
            --diffusion_q 0.8 \ # affinity threshold
            --load_from_model ${HOME}/cache/log-${save_dir}-surveillance/log/${checkpoint_exp}/checkpoints/${checkpoint_model}.pt \ ### warmup backbone checkpoint
            --load_from_head ${HOME}/cache/log-${save_dir}-surveillance/log/${checkpoint_exp}/checkpoints/${checkpoint_head}.pt ### warmup CLS head checkpoint
...
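The --knn and --diffusion_q options above control the semi-supervised affinity graph: each sample is connected to its nearest neighbours, keeping only high-confidence edges above a similarity quantile. A simplified numpy illustration of that idea (not the repo's actual graph-diffusion code):

```python
import numpy as np

def knn_affinity(features: np.ndarray, knn: int = 5, q: float = 0.8) -> np.ndarray:
    """Build a binary affinity graph: connect each sample to its `knn`
    nearest neighbours, keeping only edges whose cosine similarity
    exceeds the q-quantile (cf. --knn 5 --diffusion_q 0.8)."""
    f = features / np.linalg.norm(features, axis=1, keepdims=True)
    sim = f @ f.T                        # pairwise cosine similarity
    np.fill_diagonal(sim, -np.inf)       # exclude self-similarity
    thresh = np.quantile(sim[np.isfinite(sim)], q)
    graph = np.zeros_like(sim)
    idx = np.argsort(-sim, axis=1)[:, :knn]   # knn neighbours per row
    for i, nbrs in enumerate(idx):
        for j in nbrs:
            if sim[i, j] >= thresh:
                graph[i, j] = 1.0
    return graph
```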

Training under the Inductive GNCD setting

Switching between the GNCD and Inductive GNCD settings is easy: set --use_val 'True' to train under the Inductive GNCD setting. Note that --use_val 'True' must be set for both stage 1 and stage 2; otherwise, the class splits will be inconsistent between the two stages. The option --val_split 0.1 holds out 10% of the labeled and unlabeled training data as a small validation set.
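Conceptually, --val_split 0.1 performs a holdout like the following (a minimal sketch assuming a flat list of training indices; the repo's exact splitting code may differ):

```python
import random

def holdout_split(indices, val_split=0.1, seed=0):
    """Hold out a fraction of the (labeled and unlabeled) training
    indices as a validation set, as --val_split 0.1 does."""
    rng = random.Random(seed)
    shuffled = indices[:]
    rng.shuffle(shuffled)
    n_val = int(len(shuffled) * val_split)
    return shuffled[n_val:], shuffled[:n_val]   # train, val
```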

Annotated example commands for 2nd-stage training are listed below:

### stage 2 training for fine-grained dataset
...
python -m methods.contrastive_training.contrastive_training_2 \
            --device 'cuda:2' \
            --epochs 100 \
            --lr 0.1 \
            --runner_name $EXP_NAME \
            --exp_id $EXP_ID \
            --eval_interval 1 \
            --eval_interval_t 10 \
            --kmeans_interval -1 \
            --use_fast_kmeans 'True' \
            --use_val 'True' \
            --val_split 0.1 \
            --dataset_name ${dataset_name} \
            --prop_train_labels 0.5 \
            --use_vpt 'True' \
            --num_prompts 5 \
            --predict_token 'cop' \
            --num_dpr 2 \
            --w_prompt_clu 0.35 \
            --knn 5 \
            --w_knn_loss 0.6 \
            --diffusion_q 0.8 \
            --load_from_model ${HOME}/cache/log-${save_dir}-surveillance/log/${checkpoint_exp}/checkpoints/${checkpoint_model}.pt \
            --load_from_head ${HOME}/cache/log-${save_dir}-surveillance/log/${checkpoint_exp}/checkpoints/${checkpoint_head}.pt
...

Evaluation

We evaluate the best checkpoint from the 2nd stage using a semi-supervised KMeans algorithm. The script is given in /bash_scripts/k_means.sh.

HOME=<YOUR HOME PATH>
DATASET=<DATASET>
...
PARALLEL='False' # set 'True' when using data parallelism
EVAL_TEST='False' # set 'True' to evaluate on the test set under the Inductive GNCD setting; else 'False'
...

EXP_NAME=<EXP_NAME> # experiment folder name
EXP_ID=<EXP_ID> # exp_id of the 2nd stage
DEVICE='cuda:0'

...

### ==============================================================================================


python -m methods.clustering.extract_features --dataset ${DATASET} --use_best_model ${postfix_2} \
 --warmup_model_dir ${MODEL_PATH} --transform ${TRANSFORM} \
 --num_prompts ${NUM_PROMPTS} --device ${DEVICE} \
 --with_parallel ${PARALLEL}

python -m methods.clustering.k_means --dataset ${DATASET} --semi_sup 'True' --use_ssb_splits 'True' \
 --use_best_model ${postfix_2} --max_kmeans_iter 200 --k_means_init 100 --warmup_model_exp_id ${EXP_ID} \
 --device ${DEVICE} --eval_test ${EVAL_TEST} \
 >> ${SAVE_DIR}kmeans/${EXP_ID}_logfile_.out

Checkpoints

We release trained models at checkpoints.

Setting Dataset Stage 1 Stage 2
GNCD CIFAR10 model model
GNCD CIFAR100 model model
GNCD ImageNet-100 model model
GNCD CUB-200 model model
GNCD StanfordCars model model
GNCD Aircraft model model
GNCD Herbarium19 model model
Inductive GNCD CIFAR100 model model
Inductive GNCD CUB-200 model model
Inductive GNCD ImageNet-100 model model

Acknowledgement

Some of our code is borrowed from Generalized Category Discovery.

Citation

@article{zhang2022promptcal,
  title={PromptCAL: Contrastive Affinity Learning via Auxiliary Prompts for Generalized Novel Category Discovery},
  author={Zhang, Sheng and Khan, Salman and Shen, Zhiqiang and Naseer, Muzammal and Chen, Guangyi and Khan, Fahad},
  journal={arXiv preprint arXiv:2212.05590},
  year={2022}
}

