Multi Task Multi-Scale Contrastive Knowledge Distillation for Efficient Medical Image Segmentation

Thesis Title - "Advancing Medical Image Segmentation Through Multi-Task and Multi-Scale Contrastive Knowledge Distillation" (Thesis will be added here soon!)

This is my master’s thesis, where I investigate the feasibility of knowledge transfer between neural networks for medical image segmentation tasks, specifically focusing on the transfer from a larger multi-task “Teacher” network to a smaller “Student” network using a multi-scale contrastive learning approach.

Results

Below are a few quantitative and qualitative results. KD(T1, S1) and KD(T1, S2) are the results obtained from our proposed method. More detailed results and ablation studies can be found in the thesis.


MTMS-Contrastive Knowledge Distillation

The overall architecture of our multi-task multi-scale contrastive knowledge distillation framework for segmentation.


Contrastive Learning

Representation of Contrastive Pairs. A beginner’s guide to Contrastive Learning can be found here.
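To make the idea concrete, here is a minimal NumPy sketch of an InfoNCE-style contrastive loss, where matching rows of two embedding batches form positive pairs and every other row acts as a negative. The function name and temperature value are illustrative choices, not taken from the thesis code.

```python
import numpy as np

def info_nce(z1, z2, temperature=0.1):
    """InfoNCE loss: row i of z1 and row i of z2 are a positive
    pair; all other rows serve as negatives."""
    # L2-normalise embeddings so dot products are cosine similarities
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / temperature  # (N, N) similarity matrix
    # Softmax cross-entropy with the diagonal as the correct class
    logits -= logits.max(axis=1, keepdims=True)
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_prob))
```

With identical batches (perfect positives) the loss is near zero; with unrelated batches it approaches log N for a batch of size N.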


Knowledge Distillation

Teacher-Student Framework for Knowledge Distillation. A beginner’s guide to Knowledge Distillation can be found here.
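For intuition, the classic soft-target formulation of knowledge distillation (Hinton et al.) can be sketched in NumPy as a temperature-scaled KL divergence between teacher and student outputs. This is background material only; the thesis combines distillation with a contrastive objective, and the names and temperature here are illustrative.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def distillation_loss(student_logits, teacher_logits, T=4.0):
    """KL divergence between temperature-softened teacher and
    student distributions, scaled by T^2 as in Hinton et al."""
    p_t = softmax(teacher_logits / T)
    log_p_s = np.log(softmax(student_logits / T) + 1e-12)
    log_p_t = np.log(p_t + 1e-12)
    kl = np.sum(p_t * (log_p_t - log_p_s), axis=-1)
    return T * T * np.mean(kl)
```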


Multi-Task Teacher Network

We trained two teacher models: T1, a multi-task pre-trained U-Net, and T2, a multi-task TransUNet.


Student Network

The student model, a simplified version of the teacher model, is significantly smaller in scale and is trained on only 50% of the data used for the teacher model.


Datasets Used

The CT spleen segmentation dataset from the Medical Segmentation Decathlon is used for all the experiments. Below are the links to the processed 2D images from the CT spleen dataset -

Additional Datasets

Additionally, other binary segmentation datasets that can be explored are -

Other multi-class segmentation datasets that can be explored are -

Steps to Use the Framework

Step 1 - Clone the repository to your desired location:

git clone https://github.com/RisabBiswas/MTMS-Med-Seg-KD
cd MTMS-Med-Seg-KD

Step 2 - Process Data

There are two options: either download the NIfTI files and convert them to 2D slices using the conversion script, or use the processed spleen dataset, which can be downloaded from the link above.

The data is already split into training and testing datasets.

> Input CT Volume of Spleen Dataset -

> Processed 2D Slices -
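If you take the conversion route, the core of a NIfTI-to-2D conversion can be sketched as below. The HU window values and function name are assumptions for illustration, not necessarily the repository's settings; loading the volume itself requires nibabel (shown as a comment).

```python
import numpy as np

def volume_to_slices(volume, lo=-57, hi=164):
    """Split a 3D CT volume (H, W, D) into windowed, normalised 2D
    slices. The HU window [-57, 164] is a common soft-tissue choice
    for spleen CT, used here only as an example."""
    clipped = np.clip(volume, lo, hi)
    norm = (clipped - lo) / (hi - lo)  # scale intensities to [0, 1]
    return [(norm[:, :, k] * 255).astype(np.uint8)
            for k in range(volume.shape[2])]

# Loading a NIfTI volume would look like this (requires nibabel):
#   import nibabel as nib
#   volume = nib.load("spleen_01.nii.gz").get_fdata()
```

Each returned array can then be saved as a PNG slice with any image library.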

Step 3 - Train the Teacher Network

Training the multi-task teacher network (T1 or T2) is straightforward. With the data folders already created, train the T1 model (or T2) by running the commands below.

cd "Multi-Task Teacher Network (T1)"

or,

cd "Multi-Task Teacher Network (T2)"

Run the training script -

python train.py

You can experiment with different weight values for the reconstruction loss. For all experiments, I used DiceBCE as the loss function; you can also try other loss functions, such as Dice loss.
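For reference, here is a NumPy sketch of a DiceBCE-style objective. Equal weighting of the two terms is assumed; the repository's implementation may weight or formulate them differently.

```python
import numpy as np

def dice_bce_loss(pred, target, eps=1e-7):
    """Combined Dice + binary cross-entropy loss on probability
    maps in [0, 1]. A sketch with equal term weights assumed."""
    pred = np.clip(pred, eps, 1.0 - eps)
    # Binary cross-entropy term
    bce = -np.mean(target * np.log(pred)
                   + (1.0 - target) * np.log(1.0 - pred))
    # Soft Dice term: 1 - Dice coefficient
    inter = np.sum(pred * target)
    dice = 1.0 - (2.0 * inter + eps) / (np.sum(pred) + np.sum(target) + eps)
    return bce + dice
```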

The pre-trained weights can also be downloaded from below -

  • T1 - Will be uploaded soon!
  • T2 - Will be uploaded soon!

Step 4 - Inference on the Teacher Network

Once the teacher network is trained, run inference with the following command -

python inference.py

Also, you can look at the metrics by running the following -

python metrics.py

Step 5 - Train the Student Network (S1 or S2) W/o Knowledge Distillation

Before performing knowledge distillation and analysing its effect, we first train the student model on its own to see its performance without any knowledge transfer from the teacher network.

cd "Student Network (S1)"

Run the training script -

python train.py

Run the inference script -

python inference.py

Also, you can look at the metrics by running the following -

python metrics.py

The pre-trained weights can also be downloaded from below -

  • S1 - Will be uploaded soon!
  • S2 - Will be uploaded soon!

Step 6 - Train the Student Network (S1 or S2) With Knowledge Distillation

The steps to train the student model with contrastive knowledge distillation are similar and straightforward -

cd "KD_Student Network (T1-S1)"

Run the training script -

python train_Student.py

Run the inference script -

python inference.py

Also, you can look at the metrics by running the following -

python metrics.py

The knowledge distillation is performed at various scales, which can be customised in the training code.
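As a rough illustration of the multi-scale idea, the sketch below sums a per-scale discrepancy between student and teacher feature maps. Plain MSE is used here as a simple stand-in for the contrastive term, and the function name and weights are hypothetical, not the thesis's implementation.

```python
import numpy as np

def multi_scale_kd_loss(student_feats, teacher_feats, weights=None):
    """Weighted sum of per-scale MSE losses between student and
    teacher feature maps (one array per scale, matching shapes
    assumed). The per-scale weights are an illustrative knob."""
    if weights is None:
        weights = [1.0] * len(student_feats)
    total = 0.0
    for w, s, t in zip(weights, student_feats, teacher_feats):
        total += w * np.mean((s - t) ** 2)
    return total
```

In practice, a projection head would usually map student features to the teacher's channel dimension before comparison.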

Further Exploration

Currently, the architecture has only been tested on binary segmentation tasks, and there is still room for further exploration, such as -

  • Experimenting with multi-class segmentation tasks.
  • Trying other contrastive losses.

Acknowledgement

I extend my heartfelt gratitude to my guru 🙏🏻 Dr. Chaitanya Kaul for his visionary guidance and unwavering support throughout my project. His mentorship has significantly shaped me as a researcher and a better individual. I am profoundly grateful for his invaluable contributions to my professional and personal growth.

Authors

Read the Thesis

If you are interested in reading the thesis, you can find it here. And if you like the project, we would appreciate a citation to the original work:

Citation will be added here soon!

Contact

If you have any questions, please feel free to reach out to Risab Biswas.

Conclusion

I appreciate your interest in my research. The code should not have any bugs, but if there are any, I am sorry about that. Please let me know in the issues section, and I will fix them ASAP! Cheers!
