
attentiondeepmil's Introduction

Attention-based Deep Multiple Instance Learning

by Maximilian Ilse, Jakub M. Tomczak and Max Welling

Overview

PyTorch implementation of our paper "Attention-based Deep Multiple Instance Learning":

  • Ilse, M., Tomczak, J. M., & Welling, M. (2018). Attention-based Deep Multiple Instance Learning. arXiv preprint arXiv:1802.04712.

Installation

Installing PyTorch 0.3.1 with pip or conda should resolve all dependencies. The code was tested with Python 2.7 but should work with 3.x as well. It was tested on both CPU and GPU.

Content

The code can be used to run the MNIST-BAGS experiment, see Section 4.2 and Figure 1 in our paper. In order to keep the experimental setup small and concise, the code has the following limitations:

  • The mean bag length parameter should not be much larger than 10. For larger values the training dataset becomes unbalanced very quickly. You can run the data loader on its own to check this, see the main part of dataloader.py.
  • No validation set is used during training, and there is no early stopping.
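The imbalance mentioned in the first limitation follows from simple arithmetic. As a rough sketch, assume every instance is drawn independently and is the target digit with probability p of about 0.1 (an assumption: the real MNIST class frequencies differ slightly, and the loader draws random rather than fixed bag lengths). A bag of n instances is then positive with probability 1 - (1 - p)^n. The function name below is hypothetical.

```python
def positive_bag_probability(bag_length, p_target=0.1):
    """Probability that a bag of `bag_length` i.i.d. instances contains the target.

    p_target=0.1 is an assumed, approximate frequency of one digit in MNIST.
    """
    return 1.0 - (1.0 - p_target) ** bag_length


# Bag lengths 10, 50 and 100, treated as fixed for simplicity.
for n in (10, 50, 100):
    print(n, round(positive_bag_probability(n), 4))
```

Under these assumptions roughly 65% of bags are positive at a length of 10, and more than 99% at a length of 50, which leaves almost no negative bags to train on.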

NOTE: In order to run experiments on the histopathology datasets, please download the Breast Cancer and Colon Cancer datasets. In the histopathology experiments we used a model similar to the one in model.py, please see the paper for details.

How to Use

dataloader.py: Generates the training and test sets by combining multiple MNIST images into bags. A bag is given a positive label if it contains one or more images with the label specified by the variable target_number. If run as main, it computes the ratio of positive bags as well as the mean, max and min number of instances per bag.
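The bag construction described above can be illustrated with a minimal sketch. It is not the code in dataloader.py: the function name, its parameters, the Gaussian bag length, and the toy label vector are all assumptions made for illustration, and only the labeling rule (positive if at least one instance equals target_number) comes from the description.

```python
import random


def make_bags(instance_labels, target_number=9, mean_bag_length=10,
              std_bag_length=2, num_bags=5, seed=0):
    """Group instance indices into bags and derive one binary label per bag."""
    rng = random.Random(seed)
    bags = []
    for _ in range(num_bags):
        # Draw a bag length (assumed Gaussian) and keep it at least 1.
        length = max(1, int(round(rng.gauss(mean_bag_length, std_bag_length))))
        # Sample instance indices uniformly, with replacement.
        indices = [rng.randrange(len(instance_labels)) for _ in range(length)]
        labels_in_bag = [instance_labels[i] == target_number for i in indices]
        # A bag is positive if at least one instance is the target.
        bags.append((indices, labels_in_bag, any(labels_in_bag)))
    return bags


# Toy stand-in for the MNIST label vector: digits 0..9 repeated.
toy_labels = [d for _ in range(100) for d in range(10)]
bags = make_bags(toy_labels)
positive_ratio = sum(bag_label for _, _, bag_label in bags) / len(bags)
```

Here positive_ratio plays the role of the positive-bag ratio that the description says is computed when the loader is run as main.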

mnist_bags_loader.py: The original data loader we used in the experiments. It can handle any bag length without the dataset becoming unbalanced. It is most probably not the most efficient way to create the bags. Furthermore, it has only been tested for the case where the target number is ‘9’.

main.py: Trains a small CNN with the Adam optimization algorithm. Training runs for 20 epochs. Finally, the accuracy and loss of the model on the test set are computed. In addition, a subset of the bag labels and instance labels is printed.

model.py: The model is a modified LeNet-5, see http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf. The Attention-based MIL pooling is located before the last layer of the model. The objective function is the negative log-likelihood of the Bernoulli distribution.
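For reference, the two ingredients named above can be written compactly. The formulas below restate the paper's notation rather than the variable names in model.py: a bag holds K instance embeddings h_k, w and V are learned parameters, z is the pooled bag representation, and θ(X) is the predicted probability that the bag label Y equals 1.

```latex
a_k = \frac{\exp\{\mathbf{w}^\top \tanh(\mathbf{V}\mathbf{h}_k^\top)\}}
           {\sum_{j=1}^{K} \exp\{\mathbf{w}^\top \tanh(\mathbf{V}\mathbf{h}_j^\top)\}},
\qquad
\mathbf{z} = \sum_{k=1}^{K} a_k \mathbf{h}_k

\mathcal{L}\big(Y, \theta(X)\big) = -\Big[\, Y \log \theta(X) + (1 - Y) \log\big(1 - \theta(X)\big) \Big]
```

The attention weights a_k are non-negative and sum to one over the instances of a bag, so z is a weighted average of the instance embeddings.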

Questions and Issues

If you find any bugs or have any questions about this code please contact Maximilian or Jakub. We cannot guarantee any support for this software.

Citation

Please cite our paper if you use this code in your research:

@article{ITW:2018,
  title={Attention-based Deep Multiple Instance Learning},
  author={Ilse, Maximilian and Tomczak, Jakub M and Welling, Max},
  journal={arXiv preprint arXiv:1802.04712},
  year={2018}
}

Acknowledgments

The work conducted by Maximilian Ilse was funded by the Nederlandse Organisatie voor Wetenschappelijk Onderzoek (Grant DLMedIa: Deep Learning for Medical Image Analysis).

The work conducted by Jakub Tomczak was funded by the European Commission within the Marie Skłodowska-Curie Individual Fellowship (Grant No. 702666, "Deep learning and Bayesian inference for medical imaging").

attentiondeepmil's People

Contributors

georgebatch, jmtomczak, kaminyou, max-ilse, mdraw, nzw0301, piyush01123

attentiondeepmil's Issues

dataloader.py line 67

Hi,
labels_in_bag = labels_in_bag >= self.target_number
I think the purpose of this line is to check whether the target exists in labels_in_bag, so '>=' should be '=='.
If my understanding is wrong, please correct me.
Thanks
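A small illustration of the difference between the two comparisons, using plain Python lists in place of the tensor from dataloader.py (the toy labels and the target of 3 are chosen only for this example):

```python
target_number = 3
labels_in_bag = [0, 3, 7, 9]

# Comparison as quoted in the issue: every digit >= target is flagged.
flags_ge = [label >= target_number for label in labels_in_bag]
# Equality test: only the target digit itself is flagged.
flags_eq = [label == target_number for label in labels_in_bag]

print(flags_ge)
print(flags_eq)
```

For a target of 9 the two comparisons agree, because 9 is the largest digit. They diverge for any smaller target, as the digits 7 and 9 show here.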

Can you provide a command for reproducing the results in the paper?

Hi

Thanks for the great model and codebase. I am trying to apply the model to my medical imaging dataset.
As a first step, I want to make sure I can run the code correctly. Would you be able to provide sample commands that reproduce the results in your paper?

Thanks!

High CUDA memory usage

Hi,
I have things done and results are pretty nice! Thanks a lot for this open-source code!

However, I have been running with a mean_bag_length of 10 and 200 bags for testing. Is it usual to see an average of 25 to 30 GiB of GPU memory usage? How can I run it more efficiently?
Many thanks!

Question about batch size (not bag size)

Hi, I am just wondering whether the model supports batch size > 1?

I tried batch size > 1 with the MNIST data loader and with my own data loader, and it always fails with "Expected 4-dimensional input for 4-dimensional weight [20, 1, 5, 5], but got 5-dimensional input of size [2, 10, 1, 28, 28] instead".

thanks!

FYI, I used fixed bag size, so I think it should be supported for batch size > 1
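The error message indicates that the convolution expects 4-dimensional input while it received a 5-dimensional tensor of size [2, 10, 1, 28, 28], that is, a batch axis in front of the bag axis. One common workaround for fixed-size bags (an assumption here, not a feature of this repository) is to merge the batch and bag axes before the per-instance feature extractor, giving [20, 1, 28, 28], and to split them again before pooling so that attention is still computed per bag. The sketch below shows only this bookkeeping on plain Python lists; fold_batch and unfold_batch are hypothetical helper names.

```python
def fold_batch(batch):
    """Flatten a batch of equally sized bags into one list of instances."""
    bag_size = len(batch[0])
    assert all(len(bag) == bag_size for bag in batch), "bags must share one size"
    instances = [inst for bag in batch for inst in bag]
    return instances, bag_size


def unfold_batch(per_instance, bag_size):
    """Regroup per-instance results into one list per bag."""
    return [per_instance[i:i + bag_size]
            for i in range(0, len(per_instance), bag_size)]


# Two bags of three instances each; strings stand in for images.
batch = [["bag%d_inst%d" % (b, k) for k in range(3)] for b in range(2)]
instances, bag_size = fold_batch(batch)
# Stand-in for the per-instance feature extractor.
embeddings = [len(name) for name in instances]
per_bag = unfold_batch(embeddings, bag_size)
```

With tensors, the same idea would amount to a reshape before the convolutional layers and another one before the attention pooling.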

Code for generating Fig 4 & 5

Hi,
Thanks for the paper and for releasing the code.

I was able to run the code for MNIST.

I was curious to know whether you have any plans to release code for the attention visualizations shown in Figures 4 and 5 of the paper.

Thanks

Steps to modify the code for the Breast and colon dataset

I wanted to replicate the experiment in the paper for the Breast and the Colon dataset. I wanted to know if this rough roadmap of code modification is correct or not:

  1. Create a custom data loader by modifying the existing MnistBags. Also, create a custom loader for actually reading in the Histopath images.
  2. Change the number of patches per bag as described in the paper.
  3. The network remains the same LeNet (as mentioned in the README).
  4. The hyperparameters remain the same, I presume.

Any guidance on creating the modifications especially the data loader will be appreciated.
Thanks

Unable to access Breast Cancer and Colon Cancer Datasets

I would like to reproduce your results and run some experiments on the Breast Cancer and Colon Cancer datasets.
Unfortunately, both of the links provided have led me to dead ends.

Breast Cancer
The link takes me to the site, however when I click Download dataset I get a 404 from nginx.

Colon Cancer
The link takes me to a Warwick Uni page, asking me to sign in. I don't have a Warwick Uni login so cannot access the data.

If anyone could provide alternative links, it would be much appreciated.

What is the number of instances per bag in Breast dataset?

Hi,
In the paper, you said that the breast images are divided into 32x32 patches, resulting in 672 patches per bag. But you also mention that a patch is discarded if it contains 75% or more white pixels. So I am wondering whether you use a fixed number of patches per bag, or whether it varies from image to image.
Many thanks.
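To make the question concrete, here is a minimal sketch of a patch filter of the kind described, under stated assumptions: a patch is a nested list of RGB tuples, a pixel counts as white when all three channels reach a hypothetical threshold of 220, and both function names are invented for illustration. Only the 75% rule comes from the issue text.

```python
def white_fraction(patch, white_threshold=220):
    """Fraction of pixels whose R, G and B values all reach the threshold.

    white_threshold=220 is a hypothetical cut-off, not a value from the paper.
    """
    pixels = [pixel for row in patch for pixel in row]
    white = sum(1 for r, g, b in pixels if min(r, g, b) >= white_threshold)
    return white / len(pixels)


def keep_patch(patch, max_white_fraction=0.75):
    """Keep a patch only if fewer than 75% of its pixels are white."""
    return white_fraction(patch) < max_white_fraction


# Two toy 2x2 patches: one mostly tissue-coloured, one entirely white.
tissue = [[(180, 90, 150), (200, 120, 170)], [(255, 255, 255), (170, 80, 140)]]
blank = [[(255, 255, 255)] * 2 for _ in range(2)]
bag = [p for p in (tissue, blank) if keep_patch(p)]
```

With a rule like this, the number of patches that survive depends on the image content, so bags would not all have the same length. Whether the published experiments worked this way is exactly what the question asks and is not settled here.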

invalid index of a 0-dim tensor

When I am running the code, here is an error for the program.

exec(code_obj, self.user_global_ns, self.user_ns)
File "", line 141, in
train(epoch)
File "", line 95, in train
error, _ = model.calculate_classification_error(data, bag_label)
File "/home/huafeng/AttentionDeepMIL-master/model.py", line 60, in calculate_classification_error
error = 1. - Y_hat.eq(Y).cpu().float().mean().data[0]
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number

And I have tried to change .data[0] into .data or .item(), but neither worked, could anyone give me some suggestions?

PyTorch version: 1.1.0

How to process the histopathology datasets and produce a heatmap?

Hi,

Thanks for your research and for releasing the code.

I am trying to replicate the experiments in the paper on the histopathology datasets (Colon dataset and Breast dataset).
But since I do not know how to process the histopathology datasets, I cannot proceed with this experiment.

Specifically, I do not understand the following part of the paper.
"We randomly adjust the amount of H&E by decomposing the RGB color of the tissue into the H&E color space, followed by multiplying the magnitude of H&E for a pixel by two i.i.d. Gaussian random variables with expectation equal to one."

I am wondering if you could describe this in detail or release the code.

I also do not know how to produce the heatmap shown in the paper.
(I was able to get the rescaled attention weights a', but I do not know how to multiply them with each patch to get the heatmap described in the paper.)

I am wondering if you could describe this in detail or release the code.
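One plausible reading of the heatmap description, offered as an interpretation and not as the authors' released code: rescale the attention weights of a bag to [0, 1] with min-max scaling, then multiply the pixel intensities of each patch by its rescaled weight before placing the patch back at its position in the image. The sketch below uses grayscale patches as nested lists, and both function names are hypothetical.

```python
def rescale_attention(weights):
    """Min-max rescale attention weights to [0, 1]."""
    lo, hi = min(weights), max(weights)
    if hi == lo:
        # Assumed convention for a constant weight vector.
        return [1.0 for _ in weights]
    return [(a - lo) / (hi - lo) for a in weights]


def weight_patch(patch, scale):
    """Multiply every pixel intensity of a grayscale patch by its weight."""
    return [[value * scale for value in row] for row in patch]


attention = [0.1, 0.6, 0.3]          # toy attention weights for three patches
rescaled = rescale_attention(attention)
patches = [[[10, 20], [30, 40]]] * 3  # three identical toy 2x2 patches
heat = [weight_patch(p, s) for p, s in zip(patches, rescaled)]
```

The patch with the highest attention keeps its original intensities, the one with the lowest attention is blacked out, and the rest are dimmed in proportion.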

The problem of mean bag length

Hi,
You said 'Mean bag length parameter shouldn't be much larger than 10', however the paper Attention-based Deep Multiple Instance Learning says that 'In the experiments we use different numbers of the mean bag size, namely, 10, 50 and 100'. How do you explain this contradiction? In my experiments, no matter what the variance is, the bag labels become unbalanced as soon as the mean bag length exceeds 10.

What's the bag length in Colon data?

Hi,

I just tried to apply the model to the Colon data. The average bag length is only about 100, and the result is pretty bad. Could you tell me roughly how many patches you generated for each image in your experiment?

Regards,
Mark

Explanation of variable comments

I am having trouble matching the variables in the code to their respective meanings, mostly N in model.py.

Can you provide more information / references to the paper? Here is my interpretation:

  • N: number of bags (i.e., batch size)
  • K: number of instances in a bag (i.e., embeddings)

However, that doesn't make any sense with these lines from the Attention class:

        H = self.feature_extractor_part2(H)  # NxL

        A = self.attention(H)  # NxK
        A = torch.transpose(A, 1, 0)  # KxN
        A = F.softmax(A, dim=1)  # softmax over N

        M = torch.mm(A, H)  # KxL

because H = self.feature_extractor_part2(H) returns a KxL matrix ([13, 500]), no N involved.

Am I misinterpreting the meaning of N?
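One reading that is consistent with the reported shape [13, 500] is that the comments use N for the number of instances in the bag and K for the number of attention outputs per instance (assumed to be 1 here). This is an interpretation, not a statement from the authors. The sketch below traces the shapes through the quoted lines with plain Python lists: the embedding size is shrunk from 500 to 4 to keep it small, and the sum of each row stands in for the learned attention network.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def shape(m):
    return (len(m), len(m[0]))


n_instances, emb_dim = 13, 4  # emb_dim is 500 in the issue; 4 keeps the toy small
H = [[0.01 * (i + j) for j in range(emb_dim)] for i in range(n_instances)]

# Stand-in for self.attention(H): one raw score per instance -> n_instances x 1.
scores = [[sum(row)] for row in H]
A = [list(col) for col in zip(*scores)]  # transpose -> 1 x n_instances
A = [softmax(row) for row in A]          # normalise across the instances
M = matmul(A, H)                         # (1 x n_instances) @ (n_instances x emb_dim)

print(shape(scores), shape(A), shape(M))
```

Under this reading the softmax runs across the instances of a single bag, and M is one pooled embedding for that bag, so no batch dimension appears anywhere.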

Does the MNIST bag loader make sure it sees all data?

Hello,

I was wondering whether your MNIST bag loader makes sure that the network sees all samples of the digit 9 against negative samples. In particular, does this line generate different random indices each time, so that none have already been seen by the network? To be more clear, my understanding is that your data loader generates random indices based on the bag length, looks at the data at those indices, checks whether a target number (here 9) is among them, and makes a bag out of them. In the next iteration, the random index generator does not take into account the indices generated in the previous iteration and draws random numbers from scratch. So, within one epoch of training, your network does not see all the 9s, and in the worst case it sees only a single 9 in the dataset (say the random number generator "accidentally" picks the same index of a 9 every time).
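The concern can be quantified with a small simulation. It assumes the sampling scheme as described in the question (every bag draws its indices independently and uniformly with replacement) and does not settle what the repository's loader actually does. The function names and the toy sizes are hypothetical.

```python
import random


def fraction_seen(num_instances, num_bags, bag_length, seed=0):
    """Fraction of instances drawn at least once when every bag samples
    indices independently and uniformly with replacement."""
    rng = random.Random(seed)
    seen = set()
    for _ in range(num_bags):
        seen.update(rng.randrange(num_instances) for _ in range(bag_length))
    return len(seen) / num_instances


def expected_fraction_seen(num_instances, total_draws):
    """Closed form for the same quantity: 1 - (1 - 1/n)^draws."""
    return 1.0 - (1.0 - 1.0 / num_instances) ** total_draws


# Toy setting: 1000 instances, 50 bags of 10 instances each.
simulated = fraction_seen(1000, 50, 10)
expected = expected_fraction_seen(1000, 50 * 10)
print(round(simulated, 3), round(expected, 3))
```

Under these assumptions coverage is governed by the total number of draws relative to the dataset size, so a small number of bags leaves much of the data unseen, while seeing only a single target instance is extremely unlikely.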

Question regarding mi-svm

Hi, I know this is not an issue with the repository, but I'm curious about the framework you used for mi-svm and how the MNIST data was fed into mi-svm. Did you flatten the images?

Thanks

Data Loader for Histopathology images

Hi, could you please share the data loader script used for running the experiments on the Breast Cancer dataset?
P.S. I'm trying to adapt this model to a bigger dataset (TCGA, where the images are huge), and I don't have any information about the instance patches.
Giving all the instances of a bag in a single batch exceeds the GPU memory limit.

Reproducing Test AUCs

Hi,

Thank you all for providing this repository for public use. I am trying to reproduce the results from the paper, namely the test AUC given for the 50-bag, 10-instance experiment.

I've run the implementation in this repository with the following command:

python main.py --num_bags_train 50 --num_bags_test 1000

Doing so actually overshoots the result given in the paper, by a substantial margin (0.768 (paper) vs. 0.898 (repo)). I understand that there are differences in the repository vs. the implementation in the paper (i.e. no validation set, no early stopping). However, given that the sample count in the training bags is so small, I am not convinced that such a large difference is due to the data split, and I am printing the AUC at each step for both the train and test set, which should allow me to reason about the early stopping difference. Is there something else that I am missing? Let me know, thanks!
