
sam_myelin_seg_tem's Introduction

Myelin segmentation with SAM on TEM histology images

Introduction

Histology (microscopy) data is widely used by neuropathologists to study demyelination in the nervous system. This project aims to leverage a general-purpose foundation model to segment myelin on histology images. Foundation models are large deep learning models trained on large-scale data; they learn a general representation that can be adapted to a variety of downstream tasks. OpenAI's GPT series, for example, are foundation models for natural language processing. Meta's Segment Anything Model (SAM) is a promptable foundation model for segmentation tasks.

Data

The data used for this project is the data_axondeepseg_tem dataset, privately hosted on an internal server with git-annex. It was used to train this model and is our biggest annotated dataset for myelin segmentation (20 subjects, 1360 MPx of manually segmented images). An older version of this dataset is publicly available on this OSF repository, under the data/raw/ directory. For more information on how to access the data, see the How to reproduce section below.

SAM architecture

Results

[...]

How to reproduce

For a complete guide to reproduce these results, please see the README in the scripts folder.

sam_myelin_seg_tem's People

Contributors

hermancollin

Watchers

Mathieu Boudreau, PhD; Julien Cohen-Adad; Simen Chen; Kostas Georgiou

sam_myelin_seg_tem's Issues

Train object detection model to predict axon centroids

Pretty straightforward. This would make this pipeline fully automatic.

From my understanding, accurate results should be achievable because state-of-the-art object detectors are quite powerful, and the data is already there. For datasets other than TEM, some minor preprocessing would be required to extract the axon centroids.
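As a sketch of that preprocessing step, centroids can be derived directly from existing axon instance masks. The helper below (hypothetical name `axon_centroids`, not project code) uses scipy and is only a minimal illustration:

```python
import numpy as np
from scipy import ndimage

def axon_centroids(instance_mask):
    """Return (y, x) centroids for each labeled axon in an instance mask.

    These centroids could serve as ground-truth targets for an object
    detection model. Hypothetical helper, not the project's actual code.
    """
    labels = np.unique(instance_mask)
    labels = labels[labels != 0]  # 0 is background
    # with a label image, center_of_mass returns one centroid per label
    return ndimage.center_of_mass(instance_mask > 0, instance_mask, labels)
```

Each centroid could then be turned into a point prompt or a small detection target box.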

Try fully automatic axon segmentation

We would like to know how easily SAM can be fine-tuned for fully automatic segmentation (i.e. prompting with a bbox covering the whole image instead of a bbox around a specific ROI).

The perfect pretext to try this is axon segmentation.

  1. If it works well, this axon segmentation could then be used to generate the bbox prompts used subsequently for the myelin segmentation
  2. There is no overlap between instances in the axon class. The myelin class would not be well suited because there are many overlaps (or rather touching myelin sheaths). A big advantage of segmenting the myelin with localized bbox prompts is that we get a clean and reliable instance segmentation. The way AxonDeepSeg currently works is that the semantic segmentation is "subdivided" (semantic to instance) with a watershed algorithm. Although this works fine in many cases, sometimes the process deteriorates the segmentation. See the example below, and look for small axons touching big axons: in the instance segmentation, a "leak" artifact occurs, where the myelin of the small axon is wrongly attributed to the big axon.
    [image: instance segmentation showing the myelin "leak" artifact]
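The semantic-to-instance step can be mimicked with a nearest-axon assignment. The toy sketch below (scipy, illustrative data only, a crude stand-in for the actual watershed) shows how myelin pixels get attributed to the closest axon, which is exactly where a "leak" can happen when a small axon touches a big one:

```python
import numpy as np
from scipy import ndimage

# Toy semantic masks: two axons, each surrounded by a 1-px myelin ring.
axon = np.zeros((8, 8), dtype=bool)
axon[1:3, 1:3] = True   # small axon
axon[4:7, 4:7] = True   # big axon
myelin = ndimage.binary_dilation(axon) & ~axon

# Semantic-to-instance: label the axons, then attribute every myelin
# pixel to its nearest axon pixel (stand-in for the watershed step).
axon_labels, n_axons = ndimage.label(axon)
_, (iy, ix) = ndimage.distance_transform_edt(~axon, return_indices=True)
myelin_instances = np.where(myelin, axon_labels[iy, ix], 0)
```

When two sheaths touch, pixels near the boundary flip to whichever axon is marginally closer, which is the mechanism behind the leak artifact described above.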

Save best model based on validation dice

It would be important to keep the model that gives the best validation metrics. Currently, we simply use the final checkpoint for testing, which is not optimal.
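A minimal sketch of the bookkeeping this would need (hypothetical class, framework-agnostic; the actual checkpointing call, e.g. torch.save, would go wherever update() returns True):

```python
class BestCheckpointTracker:
    """Track the best validation Dice seen so far (sketch, not project code)."""

    def __init__(self):
        self.best_dice = float("-inf")
        self.best_epoch = None

    def update(self, epoch, val_dice):
        """Return True when the current model should be saved."""
        if val_dice > self.best_dice:
            self.best_dice, self.best_epoch = val_dice, epoch
            return True
        return False
```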

Train image encoder + mask decoder

It would be interesting to see if training the image encoder as well could help. MedSAM trains the encoder as well and they specify that all the weights in the image encoder were updated.

Randomize bounding box

def get_myelin_bbox(bbox_df, axon_id):
    return np.array(bbox_df.iloc[axon_id])

Currently, bounding boxes are loaded directly for training. They are generated by extracting the tightest bounding box around each annotation. The exact coordinates of the bounding box should include a random component to avoid overfitting, similar to what was done in MedSAM:

https://github.com/bowang-lab/MedSAM/blob/8432244ac07be6baba120dcb786e8a694c188eb9/train_one_gpu.py#L104-L107
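A hedged sketch of such a perturbation (hypothetical `jitter_bbox` helper; the shift range is illustrative, not MedSAM's exact values):

```python
import numpy as np

def jitter_bbox(bbox, img_size, max_shift=20, rng=None):
    """Randomly expand a tight [x_min, y_min, x_max, y_max] bbox and clamp
    it to the image bounds, so the model never sees pixel-perfect boxes.
    Illustrative sketch only; not the project's or MedSAM's exact code."""
    rng = rng if rng is not None else np.random.default_rng()
    x0, y0, x1, y1 = bbox
    h, w = img_size
    x0 = max(0, x0 - int(rng.integers(0, max_shift + 1)))
    y0 = max(0, y0 - int(rng.integers(0, max_shift + 1)))
    x1 = min(w - 1, x1 + int(rng.integers(0, max_shift + 1)))
    y1 = min(h - 1, y1 + int(rng.integers(0, max_shift + 1)))
    return np.array([x0, y0, x1, y1])
```

The jitter would be applied on the fly in the dataloader, so each epoch sees a slightly different box for the same annotation.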

Add regularization to loss

I would like to add some regularization to the loss function to make the model more robust and discourage it from producing "glitchy" segmentations. For a perfect illustration, see the image below, taken from the validation set of https://github.com/brainhack-school2023/collin_project/tree/main (first iteration of this project).

[image: Screenshot_20230722_142152]

  1. The pink axon has discontinuities.
  2. The brown axon is incomplete.

I am not yet entirely sure how to regularize the myelin prediction, but will update this issue later.
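One candidate (explicitly not the project's chosen method, since that is still open) would be a total-variation term that penalizes speckled, discontinuous masks. A minimal numpy sketch:

```python
import numpy as np

def tv_penalty(prob_map):
    """Total variation of a predicted probability map: sums absolute
    differences between neighboring pixels, so fragmented, "glitchy"
    predictions cost more than smooth, contiguous ones.
    One candidate regularizer only, not the method used in this project."""
    dy = np.abs(np.diff(prob_map, axis=0)).sum()
    dx = np.abs(np.diff(prob_map, axis=1)).sum()
    return dy + dx
```

This term would be added to the Dice/cross-entropy loss with a small weight, trading a bit of boundary accuracy for continuity.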

Move on to other datasets and aggregate datasets

The TEM dataset was the biggest of our private datasets, but we will also need to move on to other modalities to compare performance with ivadomed and nnunetv2.

The ultimate goal of using SAM was to train a single "foundation" model for every type of contrast/resolution, so we will eventually want to train SAM on an aggregation of all our private datasets. For comparison, we will need to train models on every dataset individually before aggregating. This necessary step will also allow us to fix bugs with these datasets before we use them all at once.

ViT_H was trained with ViT_B image embeddings...

I just realized that I forgot to re-compute the image embeddings for ViT_H training. Not sure I understand why the training could still be completed... This needs to be fixed for proper ViT_H results.

Train cascaded pipeline

Currently, the axon and myelin segmentation models are trained separately and independently. The two cascaded models should be trained at once. This would also allow parameter sharing, e.g. a common image encoder for both models.

This month, the myelin segmentation should get good enough to move on to this "cascaded" training. I already expect a lot of autograd problems...
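The shared-encoder idea can be sketched as a toy torch module (hypothetical layer sizes, not the actual SAM components):

```python
import torch
import torch.nn as nn

class CascadedSegmenter(nn.Module):
    """Toy sketch of the target architecture: one shared image encoder
    feeding separate axon and myelin decoders, so both heads can be
    trained jointly. Hypothetical tiny layers, not the SAM modules."""

    def __init__(self, dim=16):
        super().__init__()
        self.encoder = nn.Conv2d(1, dim, 3, padding=1)   # shared between heads
        self.axon_decoder = nn.Conv2d(dim, 1, 1)
        self.myelin_decoder = nn.Conv2d(dim, 1, 1)

    def forward(self, x):
        feats = torch.relu(self.encoder(x))  # computed once, used twice
        return self.axon_decoder(feats), self.myelin_decoder(feats)
```

Backpropagating both losses through the shared encoder is where the anticipated autograd headaches would show up.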

However, after this, the model should be ready for a public release.

Integrate SAM into ADS

(Maybe we should keep the training scripts separate from the rest)
This would be pretty straightforward to integrate. The model checkpoints will need a release, and we can use the inference script as a reference. The only time-consuming part would be to add tests.

Add shuffling in dataloader

def bids_dataloader(data_dict, maps_path, embeddings_path, sub_list):
    '''
    :param data_dict: contains img, mask and px_size info per sample per subject
    :param maps_path: paths to myelin maps (instance masks)
    :param embeddings_path: paths to pre-computed image embeddings
    :param sub_list: subjects included
    '''
    subjects = list(data_dict.keys())
    # # we keep the last subject for testing
    # for sub in subjects[:-1]:
    for sub in subjects:
        if sub in sub_list:
            samples = (s for s in data_dict[sub].keys() if 'sample' in s)
            for sample in samples:
                emb_path = embeddings_path / sub / 'micr' / f'{sub}_{sample}_TEM_embedding.pt'
                bboxes = get_sample_bboxes(sub, sample, maps_path)
                myelin_map = get_myelin_map(sub, sample, maps_path)
                yield (emb_path, bboxes, myelin_map)

The order in which samples are loaded should be shuffled.
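A minimal sketch of the shuffling itself, separated from the loading logic (hypothetical helper; the generator would then iterate over these pairs instead of its nested loops):

```python
import random

def shuffled_sample_pairs(data_dict, sub_list, seed=None):
    """List all (subject, sample) pairs, then shuffle their order.
    Sketch only; the real dataloader would consume these pairs."""
    pairs = [
        (sub, s)
        for sub in data_dict
        if sub in sub_list
        for s in data_dict[sub]
        if 'sample' in s
    ]
    random.Random(seed).shuffle(pairs)
    return pairs
```

Re-shuffling at the start of each epoch (a new call per epoch) avoids the fixed subject-by-subject ordering.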

Roadmap

This is a general overview of what needs to be done in this project before moving on to other datasets. Currently, both the axon and myelin seg models outperform ivadomed, but the overall pipeline is not efficient and nnUNet still beats SAM.

  • integrate "patch-based" training: similar to how we trained U-Nets, this would allow bigger batch sizes and would eventually allow for data augmentation. Ideally, implement this with the MONAI dataloader for easier data augmentation integration
  • merge axon and myelin image encoders (halves overall model size, allows parameter sharing, more efficient training pipeline); see #10. Eventually, all datasets would be aggregated and the image encoder would learn to process all modalities.
  • implement multi-GPU training to further increase the batch size and be able to train longer
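The tiling step that patch-based training implies can be sketched as follows (hypothetical helper; MONAI provides its own patch samplers, this only illustrates the idea):

```python
import numpy as np

def extract_patches(img, patch_size=256, stride=256):
    """Tile a 2D image into fixed-size patches (non-overlapping when
    stride == patch_size). Sketch of the idea only; a MONAI dataloader
    would handle this, plus data augmentation, in practice."""
    h, w = img.shape[:2]
    patches = [
        img[y:y + patch_size, x:x + patch_size]
        for y in range(0, h - patch_size + 1, stride)
        for x in range(0, w - patch_size + 1, stride)
    ]
    return np.stack(patches)
```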
