sheetalgiri / nerve-segmentation-classification

This project forked from johnpeterflynn/nerve-segmentation-classification

Student project with teammate John P. Flynn: a collection of experimental U-Net architectures for multitask learning on nerve ultrasound images.

Shell 0.11% Python 99.89%

Multitask Learning on Nerve Ultrasound Images

This repository contains several multi-task learning[2] extensions of a U-Net model[1], developed to improve segmentation results on a small nerve ultrasound dataset. Our approaches were guided by TUM's Chair for Computer Aided Medical Procedures. We implemented multiple architectures: hard parameter sharing using an FCN classifier at the U-Net bottleneck, soft parameter sharing using cross-stitch networks[3], and a ResNet-18 benchmark classifier. Our classifiers used cross-entropy loss and our segmenters used Dice loss. We experimented with several multitask loss approaches, including linear weighting of the classification and segmentation losses, uncertainty weighting[4] and loss scheduling[5].
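The loss terms named above can be sketched as follows. This is a minimal, framework-free illustration in plain Python, not the repository's implementation: the function names, the `alpha` weight and the `log_var_*` parameters are assumptions made for this example. The uncertainty-weighted form is the simplified `exp(-s) * loss + s` variant often used in the spirit of [4], where `s` would be a learned per-task parameter during training.

```python
import math


def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss for one flattened binary mask.

    pred:   predicted foreground probabilities in [0, 1]
    target: ground-truth labels (0 or 1), same length as pred
    """
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    dice = (2.0 * intersection + eps) / (denominator + eps)
    return 1.0 - dice


def cross_entropy_loss(logits, label):
    """Cross entropy for one sample: -log softmax(logits)[label]."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


def linear_weighted_loss(seg_loss, cls_loss, alpha=0.5):
    """Fixed convex combination; alpha is a hand-tuned hyperparameter."""
    return alpha * seg_loss + (1.0 - alpha) * cls_loss


def uncertainty_weighted_loss(seg_loss, cls_loss, log_var_seg, log_var_cls):
    """Simplified uncertainty weighting (assumed form, see lead-in).

    Each task loss is scaled by exp(-s), and s itself is added as a
    regularizer so the weights cannot collapse to zero.
    """
    return (math.exp(-log_var_seg) * seg_loss + log_var_seg
            + math.exp(-log_var_cls) * cls_loss + log_var_cls)


# Toy inputs: a 4-pixel mask and a 3-class classification.
pred = [0.9, 0.8, 0.1, 0.2]
target = [1, 1, 0, 0]
seg = dice_loss(pred, target)
cls = cross_entropy_loss([2.0, 0.5, -1.0], label=0)
print(linear_weighted_loss(seg, cls, alpha=0.7))
print(uncertainty_weighted_loss(seg, cls, 0.0, 0.0))
```

With both log-variances at zero, the uncertainty-weighted loss reduces to the plain sum of the two task losses. In a PyTorch setting the same arithmetic would operate on tensors.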

Networks

model/quicknat.py - Vanilla QuickNAT architecture[1], which is very similar to U-Net but adapted for fast brain image segmentation. This network serves as the benchmark for segmentation results.

model/resultnet.py - PyTorch ResNet-18 adapted for nerve image classification. It similarly serves as a gold standard for classification results. Note that our multitask networks use an encoder + FCN classifier, so we do not expect their accuracy to match that of ResNet-18.

model/quickfcn.py - Hard parameter sharing model[2]: QuickNAT nerve segmentation network with a fully connected layer for nerve classification attached to the bottleneck.

model/softquickfcn.py - Soft parameter sharing model[2]: QuickNAT nerve segmentation network and a separate nerve classifier network with an identical encoder. Both networks are first pretrained independently on their specific tasks. The encoders of both networks are then joined using cross-stitch networks[3] for a second round of training.
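The idea behind a cross-stitch unit[3] can be sketched in a few lines: at a given layer, each network's features are replaced by a linear combination of both networks' features. The snippet below is a plain-Python illustration under simplifying assumptions. The function name and the 2x2 `alpha` layout are hypothetical, the weights are fixed here although they would be learned during training, and real feature maps would be tensors rather than lists.

```python
def cross_stitch(x_seg, x_cls, alpha):
    """Apply one cross-stitch unit to two equally sized feature vectors.

    alpha is a 2x2 matrix [[a_ss, a_sc], [a_cs, a_cc]]:
      new_seg = a_ss * x_seg + a_sc * x_cls
      new_cls = a_cs * x_seg + a_cc * x_cls
    """
    (a_ss, a_sc), (a_cs, a_cc) = alpha
    new_seg = [a_ss * s + a_sc * c for s, c in zip(x_seg, x_cls)]
    new_cls = [a_cs * s + a_cc * c for s, c in zip(x_seg, x_cls)]
    return new_seg, new_cls


seg_features = [1.0, 2.0]
cls_features = [3.0, 4.0]

# Identity weights: the two encoders stay fully independent.
print(cross_stitch(seg_features, cls_features, [[1.0, 0.0], [0.0, 1.0]]))

# Off-diagonal weights let each encoder borrow from the other.
print(cross_stitch(seg_features, cls_features, [[0.9, 0.1], [0.1, 0.9]]))
```

Because identity weights leave both feature streams unchanged, the networks can fall back to learning separately whenever sharing does not help.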

Results

We compare two different multitask learning extensions of QuickNAT. Hard parameter sharing uses the same encoder and bottleneck weights for classification and segmentation, while soft parameter sharing optionally shares cross-stitch weights between independent encoders.

Below we see that both networks segment nerves well when the nerve class is predicted correctly (ground truth in red and prediction in blue). However, when a nerve is misclassified, hard parameter sharing fails to predict an accurate segmentation. We hypothesize that this is because the tasks of segmentation and classification are quite distinct: the more the shared weights learn about classification, the less they retain about segmentation. In contrast, soft parameter sharing segments nerves correctly even when they are misclassified. This is expected, because the classification and segmentation networks are still free to learn separately.

Note: Replace image

Dependencies

  • PyTorch - Python deep learning library
  • TensorBoard - Visualization of losses, metrics, segmentation results and confusion matrices
  • Polyaxon - GPU cluster scheduling

References

  1. Abhijit Guha Roy, Sailesh Conjeti, Nassir Navab, Christian Wachinger (2018). QuickNAT: Segmenting MRI Neuroanatomy in 20 seconds. CoRR.
  2. Sebastian Ruder (2017). An Overview of Multi-Task Learning in Deep Neural Networks. CoRR.
  3. Ishan Misra, Abhinav Shrivastava, Abhinav Gupta, Martial Hebert (2016). Cross-stitch Networks for Multi-task Learning. CVPR.
  4. Alex Kendall, Yarin Gal, Roberto Cipolla (2017). Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry and Semantics. CoRR.
  5. Sailesh Conjeti, Magdalini Paschali, Amin Katouzian, Nassir Navab (2017). Learning Robust Hash Codes for Multiple Instance Image Retrieval. MICCAI.

Contributors

faridyagubbayli, florianh3000, johnpeterflynn, malamleh93, sheetalgiri20, waltsims
