mcat's Introduction

Multimodal Co-Attention Transformer (MCAT) for Survival Prediction in Gigapixel Whole Slide Images

Multimodal Co-Attention Transformer for Survival Prediction in Gigapixel Whole Slide Images, ICCV 2021. [HTML]
Richard J Chen, Ming Y Lu, Wei-Hung Weng, Tiffany Y Chen, Drew FK Williamson, Trevor Manz, Maha Shady, Faisal Mahmood
@inproceedings{chen2021multimodal,
  title={Multimodal Co-Attention Transformer for Survival Prediction in Gigapixel Whole Slide Images},
  author={Chen, Richard J and Lu, Ming Y and Weng, Wei-Hung and Chen, Tiffany Y and Williamson, Drew FK and Manz, Trevor and Shady, Maha and Mahmood, Faisal},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={4015--4025},
  year={2021}
}

Summary: We develop a method for performing early fusion between histology and genomics by: 1) formulating both WSIs and genomic inputs as embedding-like structures, and 2) using a co-attention mechanism that learns pairwise interactions between instance-level histology patches and genomic embeddings. In addition, we make connections between MIL and Set Transformers, and adapt Transformer attention to WSIs for learning long-range dependencies for survival outcome prediction.
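As a rough illustration, the co-attention step can be sketched with a stock PyTorch attention layer, where the genomic embeddings act as queries over the bag of patch embeddings (the embedding dimension, bag size, and number of genomic signatures below are illustrative, not the exact MCAT configuration):

import torch
import torch.nn as nn

# Illustrative sizes: d-dim embeddings, M = 15000 patches, 6 genomic signatures.
d = 256
coattn = nn.MultiheadAttention(embed_dim=d, num_heads=1)

wsi = torch.randn(15000, 1, d)   # (M, batch, d): instance-level patch embeddings
omic = torch.randn(6, 1, d)      # (N_sig, batch, d): genomic signature embeddings

# Genomic embeddings act as queries over the bag of patch embeddings,
# yielding one genomic-guided WSI embedding per signature.
h, attn = coattn(query=omic, key=wsi, value=wsi)
print(h.shape, attn.shape)       # torch.Size([6, 1, 256]) torch.Size([1, 6, 15000])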

Updates:

  • 11/12/2021: Several users have raised concerns about the low c-Index for GBMLGG with SNN (genomics only). When using the gene families from MSigDB as gene signatures, the IDH1 mutation was not included (a key biomarker for distinguishing GBM from LGG).
  • 06/18/2021: Updated data preprocessing section for reproducibility.
  • 06/17/2021: Uploaded predicted risk scores on the validation folds for each model, and the evaluation script to compute the c-Index and Integrated AUC (I-AUC) validation metrics, found in the following Jupyter Notebook. Model checkpoints for MCAT are uploaded in the results directory.
  • 06/17/2021: Uploaded a notebook detailing the MCAT network architecture, with sample input, in the following Jupyter Notebook, in which we print the shape of the tensors at each stage of MCAT.

Installation Guide for Linux (using anaconda)

Pre-requisites:

  • Linux (Tested on Ubuntu 18.04)
  • NVIDIA GPU (Tested on Nvidia GeForce RTX 2080 Ti x 16) with CUDA 11.0 and cuDNN 7.5
  • Python (3.7.7), h5py (2.10.0), matplotlib (3.1.1), numpy (1.18.1), opencv-python (4.1.1), openslide-python (1.1.1), openslide (3.4.1), pandas (1.1.3), pillow (7.0.0), PyTorch (1.6.0), scikit-learn (0.22.1), scipy (1.4.1), tensorflow (1.13.1), tensorboardx (1.9), torchvision (0.7.0), captum (0.2.0), shap (0.35.0)

Downloading TCGA Data

To download diagnostic WSIs (formatted as .svs files), molecular feature data and other clinical metadata, please refer to the NIH Genomic Data Commons Data Portal and the cBioPortal. WSIs for each cancer type can be downloaded using the GDC Data Transfer Tool.

Processing Whole Slide Images

To process WSIs, the tissue regions in each biopsy slide are first segmented using Otsu's thresholding on a downsampled WSI via OpenSlide. Non-overlapping 256 x 256 patches are then extracted from the segmented tissue regions at the desired magnification. Subsequently, a pretrained, truncated ResNet50 is used to encode the raw image patches into 1024-dim feature vectors, which we then save as .pt files for each WSI. The extracted features then serve as input (in a .pt file) to the network. The following folder structure is assumed for the extracted feature vectors:

DATA_ROOT_DIR/
    └──TCGA_BLCA/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_BRCA/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_GBMLGG/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_LUAD/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    └──TCGA_UCEC/
        ├── slide_1.pt
        ├── slide_2.pt
        └── ...
    ...

DATA_ROOT_DIR is the base directory for all datasets / cancer types (e.g., the directory on your SSD). Within DATA_ROOT_DIR, each folder contains the list of .pt files for that dataset / cancer type.
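As a rough sketch of the patch-encoding step above, assuming patches have already been extracted and batched (the truncation point mirrors the common "ResNet50 through the third residual stage" recipe, which yields 1024-dim features; this is an illustration, not our exact extraction code):

import torch
import torch.nn as nn
from torchvision import models

# Truncate ResNet50 after the third residual stage (1024 channels),
# then global-average-pool to get one 1024-dim vector per patch.
resnet = models.resnet50(pretrained=True)
encoder = nn.Sequential(*list(resnet.children())[:-3],
                        nn.AdaptiveAvgPool2d(1), nn.Flatten())
encoder.eval()

patches = torch.randn(32, 3, 256, 256)   # stand-in for one batch of patches
with torch.no_grad():
    feats = encoder(patches)             # (32, 1024)
torch.save(feats, 'slide_1.pt')          # one .pt file of features per WSI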

Molecular Features and Genomic Signatures

Processed molecular profile features containing mutation status, copy number variation, and RNA-Seq abundance can be downloaded from the cBioPortal, and are included as CSV files in the following directory. For ordering gene features into gene embeddings, we used the following categorization of gene families (categorized via common features such as homology or biochemical activity) from MSigDB. Gene sets for homeodomain proteins and translocated cancer genes were not used due to overlap with transcription factors and oncogenes, respectively. The curation of "genomic signatures" can be modified to create genomic embeddings that reflect other unique biological functions.
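As an illustration only (the family names and column layout below are hypothetical; the actual grouping comes from the MSigDB gene-family CSVs), the signature curation amounts to slicing the molecular-profile table by gene family:

import pandas as pd

# Hypothetical molecular-profile table: one row per case, one column per gene feature.
df = pd.DataFrame({'TP53_mut': [0, 1], 'RB1_mut': [1, 0],
                   'KRAS_rnaseq': [2.3, 0.7], 'MYC_cnv': [0, 2]})

# Hypothetical gene-family categorization (in practice, taken from MSigDB).
gene_families = {'tumor_suppressors': ['TP53_mut', 'RB1_mut'],
                 'oncogenes': ['KRAS_rnaseq', 'MYC_cnv']}

# Each family's columns form one genomic signature input to the network.
signatures = {name: df[cols].values for name, cols in gene_families.items()}
print({k: v.shape for k, v in signatures.items()})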

Training-Validation Splits

To evaluate the algorithm's performance, we randomly partitioned each dataset using 5-fold cross-validation. Splits for each cancer type are found in the splits/5foldcv folder, each of which contains splits_{k}.csv for k = 1 to 5. In each splits_{k}.csv, the first column corresponds to the TCGA Case IDs used for training, and the second column corresponds to the TCGA Case IDs used for validation. Alternatively, you can define your own splits; however, the files need to follow the same format. The dataset loader for these train-val splits is defined in the get_split_from_df function of the Generic_WSI_Survival_Dataset class (which inherits from the PyTorch Dataset class).
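For example, a minimal sketch of loading one fold (the folder name under splits/5foldcv is illustrative; the two-column access follows the layout described above):

import pandas as pd

k = 1
split = pd.read_csv('splits/5foldcv/tcga_blca/splits_{}.csv'.format(k))  # path is illustrative
train_ids = split.iloc[:, 0].dropna().tolist()  # first column: training Case IDs
val_ids = split.iloc[:, 1].dropna().tolist()    # second column: validation Case IDs
print(len(train_ids), len(val_ids))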

Running Experiments

To run experiments using the SNN, AMIL, and MMF networks defined in this repository, use the following generic command line:

CUDA_VISIBLE_DEVICES=<DEVICE ID> python main.py --which_splits <SPLIT FOLDER PATH> --split_dir <SPLITS FOR CANCER TYPE> --mode <WHICH MODALITY> --model_type <WHICH MODEL>

Commands for all experiments / models can be found in the Commands.md file.
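For instance, a run on the BLCA splits might look like the following (the flag values here are illustrative; see Commands.md for the exact commands):

CUDA_VISIBLE_DEVICES=0 python main.py --which_splits 5foldcv --split_dir tcga_blca --mode coattn --model_type mcat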

Issues

  • Please open new issue threads, or report urgent blockers directly to [email protected].
  • Immediate responses to minor issues may not be available.

License & Usage

If you find our work useful in your research, please consider citing our paper:

@inproceedings{chen2021multimodal,
  title={Multimodal Co-Attention Transformer for Survival Prediction in Gigapixel Whole Slide Images},
  author={Chen, Richard J and Lu, Ming Y and Weng, Wei-Hung and Chen, Tiffany Y and Williamson, Drew FK and Manz, Trevor and Shady, Maha and Mahmood, Faisal},
  booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
  pages={4015--4025},
  year={2021}
}

© Mahmood Lab - This code is made available under the GPLv3 License and is available for non-commercial academic purposes.

mcat's Issues

Mi-FC reproducibility

Hi, I am trying to reproduce your results but am having trouble with MI-FC, as I do not have the required fast_cluster_ids.pkl file. I saw quite a bit of discussion in various issues; the closest answer to what I was looking for was this:

Hi @genhao3 - the fast_cluster_ids.pkl is a dictionary that I load for each cancer type, which maps case_id to [M x 1]-dim array of cluster assignments 1⋯C, where M is the number of patches and the indices correspond to the cluster assignment of a given patch embedding. You can use packages like faiss to generate these cluster assignments for running MI-FCN / DeepAttnMISL comparisons.

but it is a bit unspecific, and the library mentioned is not straightforward for me. Could you provide either the code used to create the file or the .pkl file itself? Any help is much appreciated :)
I know it is a baseline and not your proposed model, but I would be interested in running it as well.

Best
Valentin
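For reference, a hedged sketch of generating such a dictionary with faiss, following the quoted reply (the paths, case listing, and cluster count C below are illustrative assumptions, and whether k-means is fit per case or over all cases is our guess):

import pickle
import numpy as np
import torch
import faiss

C = 10                      # illustrative number of clusters
cluster_ids = {}
# Hypothetical (case_id, feature-file) listing; replace with the real one.
for case_id, pt_path in [('TCGA-XX-XXXX', 'slide_1.pt')]:
    feats = torch.load(pt_path, map_location='cpu').numpy().astype(np.float32)  # (M, 1024)
    kmeans = faiss.Kmeans(feats.shape[1], C, niter=20)
    kmeans.train(feats)
    _, ids = kmeans.index.search(feats, 1)   # (M, 1) cluster assignment per patch
    cluster_ids[case_id] = ids

with open('fast_cluster_ids.pkl', 'wb') as f:
    pickle.dump(cluster_ids, f)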

Co-attention visualizations?

Hi,

Thanks for publishing this nice code. I can't find code in this repo that produces the slide attention visualizations similar to those seen in Figures 2 and 3 of the paper. Is this available somewhere?

Thanks!
Ben

Questions about the computation of the survival layer in `MCAT_Surv`

It seems that the computation of the survival layer in MCAT_Surv (link) is wrong: logits = self.classifier(h).unsqueeze(0) should be logits = self.classifier(h). With the old version, supposing batch_size=6 and n_classes=4, the logits will have size (1, 6, 4), the hazards will have size (1, 6, 4), and Y_hat will have size (1, 1, 4), which certainly does not contain the Y_hat for the 6 samples of the batch. Besides, S will then be the cumulative product of the survival probabilities (i.e., 1 - hazards) along the batch dimension; what does that mean? This S has size (1, 6, 4), so len(S) in CoxSurvLoss (link) will be 1, which certainly is not the batch size as expected.

In the end, could you provide a reference for the equations you used when writing this Cox loss?
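For concreteness, here is a minimal sketch of the discrete-time survival head being discussed, with the unsqueeze(0) removed (batch_size=6 and n_classes=4 as in the example above; this is our reading of the intended shapes, not the repository's exact code):

import torch

logits = torch.randn(6, 4)               # (batch, n_bins): classifier output, no unsqueeze(0)
hazards = torch.sigmoid(logits)           # conditional hazard for each time bin
S = torch.cumprod(1 - hazards, dim=1)     # survival curve: cumprod over *bins*, not the batch
risk = -torch.sum(S, dim=1)               # one risk score per sample
Y_hat = torch.argmax(logits, dim=1)       # predicted bin per sample
print(S.shape, risk.shape, Y_hat.shape)   # (6, 4), (6,), (6,)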

"The archive is either unknown or damaged" - tcga_luad_all_clean.csv.zip

Hi Chen,

I appreciate your excellent repository as usual :)

I encountered an issue while attempting to read the tcga_luad_all_clean.csv.zip file using pandas or through manual extraction. When trying to unzip the archive, I get a popup window saying "The archive is either unknown or damaged", and when attempting to read it with pandas, I encounter this error: "BadZipFile: File is not a zip file".
I'm unsure whether the file is still valid or if the problem lies on my end. Could you please advise whether this is an issue with the file on GitHub, or whether there might be another problem?

Thanks a lot.
Omnia

Nan values in training

Hi there,

Thank you for sharing your nice work!
I met a problem when I tried to train your model; it returned NaN loss and risk values, like below:
batch 99, loss: nan, label: 1, event_time: 14.6800, risk: nan, bag_size:
batch 199, loss: nan, label: 1, event_time: 20.1700, risk: nan, bag_size:
batch 299, loss: nan, label: 2, event_time: 29.3000, risk: nan, bag_size:

The error info is:
File "/opt/anaconda3/envs/mcat/lib/python3.11/site-packages/sksurv/metrics.py", line 214, in concordance_index_censored
event_indicator, event_time, estimate = _check_inputs(event_indicator, event_time, estimate)

File "/opt/anaconda3/envs/mcat/lib/python3.11/site-packages/sksurv/metrics.py", line 47, in _check_inputs
estimate = _check_estimate_1d(estimate, event_time)

File "/opt/anaconda3/envs/mcat/lib/python3.11/site-packages/sksurv/metrics.py", line 36, in _check_estimate_1d
estimate = check_array(estimate, ensure_2d=False, input_name="estimate")

File "/opt/anaconda3/envs/mcat/lib/python3.11/site-packages/sklearn/utils/validation.py", line 921, in check_array
_assert_all_finite(
File "/opt/anaconda3/envs/mcat/lib/python3.11/site-packages/sklearn/utils/validation.py", line 161, in _assert_all_finite
raise ValueError(msg_err)
ValueError: Input estimate contains NaN.

I checked the input and output of the model and found many NaN values in the features of both the WSI and omic data, which lead to NaN outputs for the hazards and S. I strictly followed the instructions you provided and am really confused about why these NaN values appear. If you have met this problem before, could you tell me how to solve it?

Thank you!

Best.

ValueError: 'a' cannot be empty unless no samples are taken

In the following lines:

MCAT/utils/utils.py

Lines 109 to 110 in b9cca63

ids = np.random.choice(np.arange(len(split_dataset), int(len(split_dataset)*0.1)), replace = False)
loader = DataLoader(split_dataset, batch_size=1, sampler = SubsetSequentialSampler(ids), collate_fn = collate, **kwargs )

I hit the above bug. Did you mean the following?

np.random.choice(np.arange(0, len(split_dataset)), int(len(split_dataset)*0.1), replace = False)

Question of C-Index

Hello, I would like to ask whether you have ever encountered the situation where the loss decreases on both the training and validation sets, while the C-Index increases on the training set but remains unchanged or decreases on the validation set. How did you resolve it? Looking forward to your reply!

Method for processing genetic data

Hello, thank you very much for your paper and project.
Could you please provide the method for processing the genetic data? What should I do to get the .csv file from the cBioPortal website?
Looking forward to your reply.

about args.alpha_surv

Hi, the hyperparameter alpha_surv seems important when the censoring rate is high.
I wonder if you have any recommendations for choosing an appropriate alpha_surv when the censoring rate is high.
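For context, here is a sketch of a discrete-time NLL survival loss with an alpha term, as commonly formulated (an illustration of how alpha interacts with censoring, not necessarily the repository's exact implementation). Rearranging, the total is uncensored + (1 - alpha) * censored, so a larger alpha down-weights the censored patients' contribution, which matters when most patients are censored:

import torch

def nll_surv_loss(hazards, S, Y, c, alpha=0.4, eps=1e-7):
    # hazards, S: (batch, n_bins); Y: (batch, 1) ground-truth bin; c: (batch, 1) censorship.
    S_padded = torch.cat([torch.ones_like(c), S], dim=1)  # prepend S(0) = 1
    uncensored = -(1 - c) * (torch.log(torch.gather(S_padded, 1, Y).clamp(min=eps))
                             + torch.log(torch.gather(hazards, 1, Y).clamp(min=eps)))
    censored = -c * torch.log(torch.gather(S_padded, 1, Y + 1).clamp(min=eps))
    # alpha down-weights the censored contribution.
    return ((1 - alpha) * (uncensored + censored) + alpha * uncensored).mean()

# Illustrative call with random inputs:
hazards = torch.sigmoid(torch.randn(6, 4))
S = torch.cumprod(1 - hazards, dim=1)
Y = torch.randint(0, 4, (6, 1))
c = torch.randint(0, 2, (6, 1)).float()
print(nll_surv_loss(hazards, S, Y, c, alpha=0.4))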

Single slide in training

Hi @Richarizardd ,

Again, thanks for the amazing work! I am keen to know how cases with multiple slides are handled.

From the code, cases are dropped if duplicated, and then slides are read from a dataframe containing only a single slide per case. Am I missing something?

patients_df = slide_data.drop_duplicates(['case_id']).copy()

patients_df = slide_data.drop_duplicates(['case_id'])

Thanks,
Shubham

ValueError: too many values to unpack (expected 10)

for batch_idx, (data_WSI, data_omic, label, event_time, c) in enumerate(loader):

When the program runs to the line above, an error occurs. How can I solve this problem?
After I change the line to:

for batch_idx, (data_WSI, data_omic, label, event_time, c,a,b,d,e,f) in enumerate(loader):

I still can't run the program successfully, because of a new error, namely 'KeyError: 'x_omic1''.
I notice that the slide_id for this error is 'TCGA-5T-A9QA', the first one in the validation list, and that it only has the key 'x_omic' instead of 'x_omic1'.
How can I solve this problem?

Please help check this line

Dear authors,
Please help check the following line:

if "IDC" in slide_data['oncotree_code']: # must be BRCA (and if so, use only IDCs)

I have tested the code as follows:

import pandas as pd
import numpy as np

csv_path = 'MCAT_master/datasets_sig_csv/tcga_brca_all_clean.csv.zip'
slide_data = pd.read_csv(csv_path, low_memory=False)

if "IDC" in slide_data['oncotree_code']:  # must be BRCA (and if so, use only IDCs)
    print('Yes, IDC is in there')
else:
    print('No, IDC is not in there')

if "IDC" in slide_data['oncotree_code'].values:  # must be BRCA (and if so, use only IDCs)
    print('Yes, IDC is in there')
else:
    print('No, IDC is not in there')

And the output is as follows:

No, IDC is not in there
Yes, IDC is in there

Is this a bug? My pandas version is 1.4.1. Could you please help check it? Thanks a lot!

Question about GPU devices

May I ask which GPU everyone is using when running MCAT, and how many in total? I am running the code on a server with 2080 Ti GPUs using the following command: CUDA_VISIBLE_DEVICES=0,1 python3 /usr/CI/MCAT-master/main.py. But the run always uses only one GPU and reports a CUDA out-of-memory error. Has anyone else encountered this problem?

Question about 'fast_cluster_ids.pkl'

Thank you for sharing your code! I'm interested in your research; it gives me a lot of inspiration.
While trying to run the code, I'm confused about the file 'fast_cluster_ids.pkl' in dataset_survival.py. I can't find a description of it.
Could you please tell me what this file contains? Thank you very much!

extracted WSI features

Hello,

I want to ask whether you can provide the extracted WSI features stored in .pt files. When I use CLAM to extract features, there are some differences, so could you provide them so that I can stay consistent with you?

*** TypeError: object of type 'numpy.int64' has no len()

all_risk_scores = np.zeros((len(loader)))
all_censorships = np.zeros((len(loader)))
all_event_times = np.zeros((len(loader)))

In pdb, 'loader' is of type 'torch.utils.data.dataloader.DataLoader', but when I use iter to check the contents of the loader, another error occurs: *** TypeError: 'numpy.int64' object is not iterable. How can I solve this problem? Thanks a lot.

About test

Hello! Thanks for your contributions. Why don't you add a test mode to the model? I notice that you only use the validation set to examine the performance of the model, but you have a test mode in your other repo, HistoFL.

About self.indices

In the class SubsetSequentialSampler, the indices should be iterable, but I get a variable of type numpy.int64. My torch version is 2.0.1 and my Python version is 3.9. How can I solve this problem?
This is my full traceback:

Traceback (most recent call last):
  File "/mnt/data0/LI_jihao/mydata/MCAT-master/main.py", line 199, in <module>
    dataset = Generic_MIL_Survival_Dataset(csv_path = '/home/jupyter-ljh/data/mydata/MCAT-master/dataset_csv/tcga_brca_all_clean.csv',
  File "/mnt/data0/LI_jihao/mydata/MCAT-master/main.py", line 71, in main
    val_latest, cindex_latest = train(datasets, i, args)
  File "/mnt/data0/LI_jihao/mydata/MCAT-master/utils/core_utils.py", line 181, in train
    for i,data in enumerate(train_loader):
  File "/opt/tljh/user/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/opt/tljh/user/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 676, in _next_data
    index = self._next_index()  # may raise StopIteration
  File "/opt/tljh/user/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 623, in _next_index
    return next(self._sampler_iter)  # may raise StopIteration
  File "/opt/tljh/user/lib/python3.9/site-packages/torch/utils/data/sampler.py", line 254, in __iter__
    for idx in self.sampler:
  File "/mnt/data0/LI_jihao/mydata/MCAT-master/utils/utils.py", line 36, in __iter__
    return iter(self.indices)
TypeError: 'numpy.int64' object is not iterable
