Code Repository for CSST

Cost Sensitive Self-Training for Optimising Non-decomposable Measures

Authors: Harsh Rangwani*, Shrinivas Ramasubramanian*, Sho Takemori*, Kato Takashi, Yuhei Umeda, Venkatesh Babu Radhakrishnan

NeurIPS 2022

Paper

Introduction

Self-training with semi-supervised learning algorithms allows highly accurate deep neural networks to be learned using only a fraction of labeled data. However, most self-training work focuses on improving accuracy, while practical machine learning systems have non-decomposable goals, such as maximizing recall across classes. We introduce the Cost-Sensitive Self-Training (CSST) framework, which generalizes self-training methods for optimizing non-decomposable metrics. Our framework can better optimize desired metrics using unlabeled data, under similar data distribution assumptions made for the analysis of self-training. Using CSST, we obtain practical self-training methods for optimizing different non-decomposable metrics in both vision and NLP tasks. Our results show that CSST outperforms the state of the art in most cases across datasets and objectives.
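To make "non-decomposable" concrete: accuracy is an average of per-example terms, whereas mean recall and worst-class (min) recall are computed from per-class aggregates of the whole confusion matrix, so they cannot be written as a sum of independent per-example losses. The minimal sketch below uses hypothetical toy labels (not data from the paper) to show how a majority-class predictor can look good on accuracy while doing poorly on these metrics. It assumes every class has at least one true example; otherwise the recall division yields NaN.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] counts examples whose true class is i and predicted class is j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

def accuracy(cm):
    # Decomposable: fraction of correctly classified examples.
    return np.trace(cm) / cm.sum()

def per_class_recall(cm):
    # Recall of class i = correct predictions of i / number of true examples of i.
    return np.diag(cm) / cm.sum(axis=1)

def mean_recall(cm):
    return per_class_recall(cm).mean()

def min_recall(cm):
    return per_class_recall(cm).min()

# Hypothetical imbalanced toy data: 90 examples of class 0, 10 of class 1.
y_true = np.array([0] * 90 + [1] * 10)
# A degenerate classifier that always predicts the majority class.
y_pred = np.zeros_like(y_true)

cm = confusion_matrix(y_true, y_pred, num_classes=2)
print(f"accuracy    = {accuracy(cm):.2f}")     # 0.90
print(f"mean recall = {mean_recall(cm):.2f}")  # 0.50
print(f"min recall  = {min_recall(cm):.2f}")   # 0.00
```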

Usage

Installation

  1. Create and activate a conda environment:
conda create -n CSST python
conda activate CSST
  2. Clone the repository and install the required libraries:
git clone https://github.com/val-iisc/CostSensitiveSelfTraining
cd CostSensitiveSelfTraining
pip install -r requirements.txt
  3. We recommend installing W&B (Weights & Biases) for detailed logging of performance metrics.

Training

Below is a sample training command for CIFAR-10 with imbalance factor 100 and a labeled-to-unlabeled data ratio of 1:4. The objective can be changed as required via the --M argument (see the docs).

python trainMetricOpt.py --M mean_recall_coverage --world-size 1 --rank 0 \
  --multiprocessing-distributed --uratio 4 --num_labels 12500 \
  --save_name <local logging name> --dataset cifar10 --imbalance 100 \
  --num_classes 10 --amp --net WideResNet --overwrite --widen_factor 2 \
  --wandb-project <Project name> --wandb-runid <your-runid> \
  --vanilla_opt True --ult True --num_workers 4 --seed 0
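The imbalance factor is the ratio between the sizes of the most and least frequent classes. A common way to build long-tailed CIFAR variants, assumed here (the repository's own data code may differ), is an exponential profile n_i = n_max * (1/ρ)^(i/(K-1)) for classes i = 0..K-1. The sketch below only computes per-class counts for the full long-tailed pool. How that pool is divided into labeled and unlabeled parts is determined by the repository's code and is not reproduced here. `long_tailed_counts` is a hypothetical helper, not a function from the repository.

```python
def long_tailed_counts(n_max, imbalance, num_classes):
    """Hypothetical helper: per-class counts decaying exponentially
    from n_max (head class) down to n_max / imbalance (tail class)."""
    return [
        int(n_max * (1.0 / imbalance) ** (i / (num_classes - 1)))
        for i in range(num_classes)
    ]

# CIFAR-10 has 5000 training images per class; imbalance factor 100 as in the command above.
counts = long_tailed_counts(n_max=5000, imbalance=100, num_classes=10)
print(counts)                # [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]
print(sum(counts))           # 12406
print(counts[0] / counts[-1])  # 100.0
```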

Evaluation

To evaluate, load the saved checkpoint and evaluate the model on the same seed-dependent split of the dataset that was used for training (pass the same --seed value):

python eval.py --load <PATH> --dataset cifar10 --uratio 4 --net WideResNet --widen_factor 2 --imbalance 100 --num_classes 10 --seed 0
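The mean_recall_coverage objective used in the training command pairs mean recall with per-class coverage constraints. As an assumption based on a common formulation in the literature on non-decomposable objectives (check the repository docs for the exact definition and threshold), the coverage of class j is the fraction of all examples predicted as j, and each coverage must be at least α/K, for example with α = 0.95. The hypothetical helpers below compute these quantities from a confusion matrix, such as one built from a model's predictions on the test split. They are not part of eval.py.

```python
import numpy as np

def recall_and_coverage(cm):
    """cm[i, j] = number of examples with true class i predicted as class j."""
    cm = np.asarray(cm, dtype=np.float64)
    recall = np.diag(cm) / cm.sum(axis=1)  # per true class (row-normalized diagonal)
    coverage = cm.sum(axis=0) / cm.sum()   # fraction of all predictions per class
    return recall, coverage

def coverage_satisfied(coverage, alpha=0.95):
    # Assumed constraint: every class receives at least alpha / K of the predictions.
    k = len(coverage)
    return bool(np.all(coverage >= alpha / k))

# Hypothetical 3-class confusion matrix (rows: true class, columns: predicted class).
cm = [[50, 5, 5],
      [10, 25, 5],
      [5, 5, 30]]
recall, coverage = recall_and_coverage(cm)
print("mean recall:", round(float(recall.mean()), 4))
print("coverage:", np.round(coverage, 4))
print("constraints met:", coverage_satisfied(coverage))  # False: class 1 gets 35/140 = 0.25 < 0.95/3
```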

Results

Below we summarize results on CIFAR-10 LT for two objectives, compared with the state of the art:

Result Image

Citation

In case you find our work useful, please consider citing us as:

@inproceedings{
rangwani2022costsensitive,
title={Cost-Sensitive Self-Training for Optimizing Non-Decomposable Metrics},
author={Harsh Rangwani and Shrinivas Ramasubramanian and Sho Takemori and Kato Takashi and Yuhei Umeda and Venkatesh Babu Radhakrishnan},
booktitle={Advances in Neural Information Processing Systems},
editor={Alice H. Oh and Alekh Agarwal and Danielle Belgrave and Kyunghyun Cho},
year={2022},
url={https://openreview.net/forum?id=bGo0A4bJBc}
}

Contact

Please feel free to file an issue or send us an email if you have any comments or suggestions.
