Coder Social home page Coder Social logo

aurooj / vlm_ss Goto Github PK

View Code? Open in Web Editor NEW
0.0 0.0 0.0 30.73 MB

Mini-batch selective sampling for knowledge adaption of VLMs for mammography.

License: Apache License 2.0

Jupyter Notebook 70.69% Python 29.31% Shell 0.01%
mammogram medical-imaging medical-vision-language-model miccai miccai2024 multimodal-learning multimodal-representation-learning multimodal-retrieval vision-and-language vision-language-transformer minibatch-selective-sampling

vlm_ss's Introduction

VLM_SS

Mini-batch selective sampling for knowledge adaption of VLMs

Knowledge-grounded Adaptation Strategy for Vision-language Models: Building a Unique Case-set for Screening Mammograms for Residents Training [MICCAI 2024] Aisha Urooj Khan, John Garrett, Tyler Bradshaw, Lonie Salkowski, Jiwoong Jeong, Amara Tariq, Imon Banerjee

Paper | BibTeX

Official Pytorch implementation of Mini-batch selective sampling for knowledge adaption of VLMs and pre-trained models.

Requirements

The code for selective sampling works for python 3.x versions and requires the following packages: random, math, and tqdm

For selective sampling in ALBEF, we use the environment provided from ALBEF original repo, and update/install a few packages using ALBEF/requirements.txt.

For selective sampling in MedCLIP, we use the environment provided in MedCLIP/requirements.txt.

Usage

Mini-batch selective sampling requires the mammography image-report pairs to have the group information. group information can be extracted from radiology reports using the notebook extract_groups.ipynb provided in this code repo.

Assuming that your data is in the format as provided in sample_data.json, you can enable the selective sampling by creating a SelectiveSampling object in your custom Pytorch Dataset class as follows:

  1. Import selective sampling in your custom dataset code:
from selective_sampling import SelectiveSampling
  1. Add selective sampling to your custom dataset class in the __init__() function:
self.selective_sampling = SelectiveSampling(data)

where data is the list of data instances' dictionaries

Each dictionary in this list belongs to a data instance, and can have an arbitrary number of columns but expects the key 'group' in each item's dictionary. Example: Format: [ { 'colA':<value>, 'colB':<value>, 'group':['a', 'b'] }, { 'colA':<value>, 'colB':<value>, 'group':['a', 'b'] }, ... ]

  1. Add shuffle() definition to your dataset class:
def shuffle(self, bs=8, rare_grp_ratio=0.375, batch_shuffle=False):
        #function to prepare minibatches based on selective sampling strategy
        # calls shuffle function from the base class (SelectiveSampling) 
        self.ann = self.selective_sampling.shuffle(bs=bs, rare_grp_ratio=rare_grp_ratio, batch_shuffle=batch_shuffle)

where bs=batch_size, rare_grp_ratio=ratio of samples from rare groups in a mini-batch, batch_shuffle= a boolean flag, if set to True, mini-batch is shuffled further; default value is set to False based on the ablations reported in table 3, rows (5) and (6) in the paper.

  1. Now in your main training loop, you can call the selective sampling based shuffling by calling the shuffle method in your custom dataset class, i.e., data_loader.dataset.shuffle(bs=<batch_size>).

    for epoch in range(epochs):
         
         if config['selective_sampling']:
             data_loader.dataset.shuffle(bs=config['batch_size'])
    

    See example snippet from ALBEF/Pretrain.py below:

        for epoch in range(start_epoch, max_epoch):

                if config['selective_sampling']:
                    #at every epoch, shuffle data with custom sampling function for medical data
                    print(f"Shuffling training data for epoch {epoch}")
                    data_loader.dataset.shuffle(bs=config['batch_size'],
                                                rare_grp_ratio=config['rare_grp_ratio'],
                                                batch_shuffle=config['batch_shuffle']
                                                )
                        ...

NOTE:

The default sampling from PyTorch, i.e., data_loader.shuffle should be set to False in the training dataloader when using selective sampling based shuffling.

We integrated selective sampling as part of this work into ALBEF and MedCLIP code repos. The updated codes are provided as part of this code repo. We thank the authors of ALBEF and MedCLIP for providing their amazing code repos.

ALBEF with Selective Sampling

Read instructions from albef.md for ALBEF trained with selective sampling.

MedCLIP with Selective Sampling

Read instructions from medclip.md for MedCLIP trained with selective sampling.

Todo:
  • Add MedCLIP's modified version with selective sampling here.
  • Add pretrained model weights

Citation

If this work and/or its findings are useful for your research, please cite our paper. Todo:

  • Replace with bibtex from MICCAI
@misc{khan2024knowledgegrounded,
    title={Knowledge-grounded Adaptation Strategy for Vision-language Models: Building Unique Case-set for Screening Mammograms for Residents Training},
    author={Aisha Urooj Khan and John Garrett and Tyler Bradshaw and Lonie Salkowski and Jiwoong Jason Jeong and Amara Tariq and Imon Banerjee},
    year={2024},
    eprint={2405.19675},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}

Questions?

Please contact '[email protected]'

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    ๐Ÿ–– Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. ๐Ÿ“Š๐Ÿ“ˆ๐ŸŽ‰

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google โค๏ธ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.