MGCA

Multi-Granularity Cross-modal Alignment for Generalized Medical Visual Representation Learning, NeurIPS 2022.

Installation

To clone this repository:

git clone https://github.com/fuying-wang/MGCA.git

To install Python dependencies:

pip install -r requirements.txt

To install package mgca:

pip install -e .

Dataset downloading

We used the following datasets:

  • MIMIC-CXR: We used the MIMIC-CXR-JPG dataset for the radiographs. The paired medical reports can be downloaded from MIMIC-CXR.

  • CheXpert: We downloaded the CheXpert dataset, which consists of 224,316 chest radiographs of 65,240 patients.

  • RSNA: We used stage 2 of the RSNA dataset from Kaggle.

  • COVIDx: We used version 6 of the COVIDx dataset from Kaggle.

  • SIIM: We downloaded stage 1 of the SIIM dataset from Kaggle.

  • Object-CXR: We downloaded the object-CXR dataset from its official website.

After downloading the datasets, check that the paths in mgca/constants.py point to the correct locations.
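A quick way to catch a wrong path before training is to test each configured directory for existence. The snippet below is a minimal sketch, not part of the repository: the helper `find_missing_paths` and the entries in `example_paths` are hypothetical and stand in for whatever values mgca/constants.py actually defines.

```python
from pathlib import Path


def find_missing_paths(paths):
    """Return the sorted names whose configured path does not exist on disk."""
    return sorted(name for name, p in paths.items() if not Path(p).exists())


# Hypothetical dataset locations (assumption): replace these with the
# values you set in mgca/constants.py.
example_paths = {
    "MIMIC_CXR": "/data/mimic-cxr-jpg",
    "CHEXPERT": "/data/chexpert",
}

if __name__ == "__main__":
    for name in find_missing_paths(example_paths):
        print(f"missing dataset directory: {name}")
```

Running the script prints one line per dataset whose directory cannot be found, and prints nothing when every path resolves.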

Data Preprocessing

We preprocessed these datasets and split each one into train/val/test sets using the code in mgca/preprocess.
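For readers who want to see the general shape of such a split, here is a minimal sketch of a reproducible train/val/test partition over a list of sample IDs. It is not the code in mgca/preprocess: the function name `split_ids`, the 80/10/10 fractions, and the seed are all assumptions chosen for illustration.

```python
import random


def split_ids(ids, val_frac=0.1, test_frac=0.1, seed=42):
    """Shuffle IDs with a fixed seed and cut them into train/val/test lists."""
    ids = sorted(ids)                    # fixed starting order for reproducibility
    random.Random(seed).shuffle(ids)     # local RNG, global random state untouched
    n_val = int(len(ids) * val_frac)
    n_test = int(len(ids) * test_frac)
    val = ids[:n_val]
    test = ids[n_val:n_val + n_test]
    train = ids[n_val + n_test:]
    return train, val, test


if __name__ == "__main__":
    train, val, test = split_ids(range(1000))
    print(len(train), len(val), len(test))
```

Because the shuffle uses a seeded local generator, the same IDs always land in the same subset, which keeps results comparable across runs.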

Pre-training

We pre-trained MGCA on MIMIC-CXR using this command:

cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=0,1 python mgca_module.py --gpus 2 --strategy ddp

We trained our framework for 50 epochs on two RTX 3090 GPUs with a batch size of 144. Pre-training takes about one day.

Other pre-training models can also be developed under this framework: create a folder in mgca/models and implement the {MODEL_NAME}_module.py file.

Pre-trained models can be found here.

Finetune on downstream tasks

We evaluate the performance of the MGCA framework on three downstream tasks: classification, object detection and semantic segmentation. Before finetuning, set the path (or ckpt_path) argument to the path of the pre-trained MGCA model.

Linear classification

We evaluate linear classification performance of our model using this command:

cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=1 python mgca_finetuner.py --gpus 1 --dataset chexpert --data_pct 0.01

Use --dataset to select the dataset for finetuning. Three datasets are available: chexpert, rsna and covidx. Use --data_pct to set the fraction of training data used for finetuning.
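To make the meaning of a value such as --data_pct 0.01 concrete, the sketch below draws a fixed-seed random subset of the training samples. It is an illustration only and may differ from how the repository's dataset classes apply the fraction; the helper `subsample` and its seed are assumptions.

```python
import random


def subsample(items, data_pct, seed=0):
    """Return a reproducible random subset holding a data_pct fraction of items."""
    if not 0 < data_pct <= 1:
        raise ValueError("data_pct must be in the interval (0, 1]")
    items = list(items)
    # Keep at least one sample, but never ask for more than are available.
    k = min(len(items), max(1, round(len(items) * data_pct)))
    return random.Random(seed).sample(items, k)


if __name__ == "__main__":
    subset = subsample(range(1000), 0.01)
    print(len(subset))
```

With 1,000 training samples, a fraction of 0.01 keeps 10 of them and a fraction of 1 keeps all of them.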

Object detection

We evaluate object detection performance of our model using this command:

cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=0 python mgca_detector.py --devices 1 --dataset rsna --data_pct 1 --learning_rate 5e-4

Two datasets are available here: rsna and object_cxr.

To run all experiments for this detection task:

sh run_det_funetune.sh

Semantic segmentation

We evaluate semantic segmentation performance of our model using this command:

cd mgca/models/mgca
CUDA_VISIBLE_DEVICES=0 python mgca_segmenter.py --gpus 1 --data_pct 1 --dataset rsna --batch_size 16 --learning_rate 5e-4

Two datasets are available here: rsna and siim.

To run all experiments for this segmentation task:

sh run_seg_funetune.sh

TODO List

  • Refactor the train/valid/test lists.

Reference

If you find our work useful in your research, please consider citing it:

@article{wang2022multi,
  title={Multi-Granularity Cross-modal Alignment for Generalized Medical Visual Representation Learning},
  author={Wang, Fuying and Zhou, Yuyin and Wang, Shujun and Vardhanabhuti, Varut and Yu, Lequan},
  journal={arXiv preprint arXiv:2210.06044},
  year={2022}
}
