
explainableml / uncerguidedi2i


Uncertainty Guided Progressive GANs for Medical Image Translation

License: GNU General Public License v3.0



UncerGuidedI2I

PyTorch implementation of Uncertainty Guided Progressive GANs for Medical Image Translation

Introduction

This repository provides the code for the MICCAI 2021 paper "Uncertainty-guided Progressive GANs for Medical Image Translation". We take inspiration from the progressive learning schemes demonstrated in MedGAN and Progressive GANs, and augment the learning with the estimation of intermediate uncertainty maps (as presented here and here). These maps are used as attention maps that focus the image translation on poorly generated (highly uncertain) regions, progressively improving the images over multiple phases.
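The README does not spell out how an uncertainty map is obtained. As a hedged sketch: assuming each generator's three heads predict, per pixel, a mean, an inverse scale 1/α and a shape β of a generalized Gaussian (as in the authors' UGAC work cited below), one natural uncertainty map is that distribution's variance, α²Γ(3/β)/Γ(1/β). The function names and the min-max scaling to [0, 1] are hypothetical choices for using the map as attention, not necessarily what the repository does.

```python
import math


def gg_variance(one_over_alpha, beta):
    """Variance of a generalized Gaussian with scale alpha = 1 / one_over_alpha
    and shape beta: alpha**2 * Gamma(3 / beta) / Gamma(1 / beta).
    Computed in log-space with math.lgamma for numerical stability."""
    if one_over_alpha <= 0 or beta <= 0:
        raise ValueError("1/alpha and beta must be positive")
    log_alpha = -math.log(one_over_alpha)
    return math.exp(2.0 * log_alpha + math.lgamma(3.0 / beta) - math.lgamma(1.0 / beta))


def uncertainty_map(one_over_alphas, betas):
    """Per-pixel variances min-max scaled to [0, 1] (hypothetical normalization)
    so that the map can be used as an attention weight."""
    var = [gg_variance(a, b) for a, b in zip(one_over_alphas, betas)]
    lo, hi = min(var), max(var)
    if hi == lo:  # constant map: no pixel is more uncertain than another
        return [0.0] * len(var)
    return [(v - lo) / (hi - lo) for v in var]


# beta = 2 recovers a Gaussian with alpha = sqrt(2) * sigma, so sigma = 1 gives variance 1.
print(gg_variance(1 / math.sqrt(2), 2.0))  # ≈ 1.0
print(uncertainty_map([1.0, 0.5, 0.25], [1.0, 1.0, 1.0]))  # ≈ [0.0, 0.2, 1.0]
```

Working in log-space keeps Γ(1/β) from overflowing when a head predicts a very small β.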

The structure of the repository is as follows:

root
 |-ckpt/ (will save all the checkpoints)
 |-data/ (save your data and related script)
 |-src/ (contains all the source code)
    |-ds.py 
    |-networks.py
    |-utils.py
    |-losses.py

Getting started

Requirements

python >= 3.6.10
pytorch >= 1.6.0
jupyter lab
torchio
scikit-image
scikit-learn

Preparing Datasets

The experiments of the paper used T1 MRI scans from the IXI dataset and a proprietary PET/CT dataset.

data/IXI/ contains Jupyter notebooks and scripts that prepare the data for motion correction (data/IXI/prepare_motion_correction_data.py and data/IXI/viz_motion_correction_data.ipynb) and for undersampled MRI reconstruction (data/IXI/viz_kspace_undersample_data.ipynb). For custom datasets, use these notebooks as examples to prepare the data and place it under data/. The dataset class in src/ds.py loads paired images (the corrupted version and the corresponding non-corrupted version).
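To illustrate what "paired" means for the training APIs, here is a minimal, hypothetical stand-in for the dataset class in src/ds.py; its name, constructor and transform argument are assumptions, not the repository's interface. It only needs __len__ and __getitem__ to act as a map-style dataset that torch.utils.data.DataLoader should be able to batch.

```python
class PairedImageDataset:
    """Hypothetical, minimal stand-in for the dataset class in src/ds.py.

    Holds two aligned sequences: corrupted inputs (e.g. motion-corrupted or
    undersampled slices) and their clean targets.
    """

    def __init__(self, corrupted, clean, transform=None):
        if len(corrupted) != len(clean):
            raise ValueError("corrupted and clean must have the same length")
        self.corrupted = list(corrupted)
        self.clean = list(clean)
        self.transform = transform  # optional callable applied to both images

    def __len__(self):
        return len(self.corrupted)

    def __getitem__(self, idx):
        x, y = self.corrupted[idx], self.clean[idx]
        if self.transform is not None:
            x, y = self.transform(x), self.transform(y)
        return x, y  # (generator input, target)


# Strings stand in for image arrays here.
ds = PairedImageDataset(["noisy_0", "noisy_1"], ["clean_0", "clean_1"], transform=str.upper)
print(len(ds), ds[1])  # 2 ('NOISY_1', 'CLEAN_1')
```

In practice the two sequences would hold aligned image arrays loaded from the files prepared under data/, and the transform would convert them to tensors.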

Learning models with uncertainty

src/networks.py provides the generator and discriminator architectures.

src/utils.py provides two training APIs, train_i2i_UNet3headGAN and train_i2i_Cas_UNet3headGAN. The first trains the primary GAN; the second trains each subsequent GAN in the cascade.

An example call to the first API:

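import torch
# assumes this runs from src/, as the ../ckpt/ paths suggest
from networks import CasUNet_3head, NLayerDiscriminator
from utils import train_i2i_UNet3headGAN

# train_loader / test_loader: DataLoaders over the paired dataset class in src/ds.py
netG_A = CasUNet_3head(1,1)
netD_A = NLayerDiscriminator(1, n_layers=4)
netG_A, netD_A = train_i2i_UNet3headGAN(
    netG_A, netD_A,
    train_loader, test_loader,
    dtype=torch.cuda.FloatTensor,
    device='cuda',
    num_epochs=50,
    init_lr=1e-5,
    ckpt_path='../ckpt/i2i_0_UNet3headGAN',
)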

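This saves checkpoints in ../ckpt/ with names of the form i2i_0_UNet3headGAN_eph*.pth, such as the i2i_0_UNet3headGAN_eph49_G_A.pth generator checkpoint loaded in the next example.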
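When resuming or chaining stages it helps to locate the last saved epoch. A small sketch, assuming the naming {ckpt_path}_eph{epoch}_{net}.pth implied by the examples here; the _G_A suffix appears in the loading example, while a matching _D_A discriminator file is an assumption, and both helper names are hypothetical.

```python
import os
import re


def checkpoint_file(ckpt_path, epoch, net="G_A"):
    """Build a checkpoint filename, e.g. ../ckpt/i2i_0_UNet3headGAN_eph49_G_A.pth."""
    return f"{ckpt_path}_eph{epoch}_{net}.pth"


def latest_epoch(filenames, ckpt_path, net="G_A"):
    """Return the highest epoch with a saved `net` checkpoint for `ckpt_path`, or None."""
    prefix = os.path.basename(ckpt_path)
    pattern = re.compile(re.escape(prefix) + r"_eph(\d+)_" + re.escape(net) + r"\.pth")
    epochs = []
    for name in filenames:
        match = pattern.fullmatch(os.path.basename(name))
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs) if epochs else None


files = ["i2i_0_UNet3headGAN_eph9_G_A.pth", "i2i_0_UNet3headGAN_eph49_G_A.pth",
         "i2i_0_UNet3headGAN_eph49_D_A.pth", "i2i_1_UNet3headGAN_eph60_G_A.pth"]
last = latest_epoch(files, "../ckpt/i2i_0_UNet3headGAN")
print(checkpoint_file("../ckpt/i2i_0_UNet3headGAN", last))  # ../ckpt/i2i_0_UNet3headGAN_eph49_G_A.pth
```

On a real run, the file list would come from os.listdir('../ckpt').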

An example call to the second API, assuming the primary GAN and the first subsequent GAN have already been trained:

import torch
# assumes this runs from src/, as the ../ckpt/ paths suggest
from networks import CasUNet_3head, UNet_3head, NLayerDiscriminator
from utils import train_i2i_Cas_UNet3headGAN

# first load the previously trained generators
netG_A1 = CasUNet_3head(1,1)
netG_A1.load_state_dict(torch.load('../ckpt/i2i_0_UNet3headGAN_eph49_G_A.pth'))
netG_A2 = UNet_3head(4,1)
netG_A2.load_state_dict(torch.load('../ckpt/i2i_1_UNet3headGAN_eph49_G_A.pth'))

# initialize the current GAN
netG_A3 = UNet_3head(4,1)
netD_A = NLayerDiscriminator(1, n_layers=4)

# train the cascaded framework
list_netG_A, list_netD_A = train_i2i_Cas_UNet3headGAN(
    [netG_A1, netG_A2, netG_A3], [netD_A],
    train_loader, test_loader,
    dtype=torch.cuda.FloatTensor,
    device='cuda',
    num_epochs=50,
    init_lr=1e-5,
    ckpt_path='../ckpt/i2i_2_UNet3headGAN',
)
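The later generators are built as UNet_3head(4,1), i.e. with four input channels. A hedged sketch of how a cascade could feed them, assuming each generator returns a (mean, 1/alpha, beta) triple shaped (N, 1, H, W) and that each later stage sees the previous stage's three outputs concatenated with the original input, in the order (reconstruction, alpha, beta, input) that one of the repository's issues describes. run_cascade and dummy_generator are hypothetical names, not the repository's API.

```python
import torch


def run_cascade(generators, x):
    """Chain three-headed generators: stage 1 sees only the corrupted image x;
    every later stage sees (mean, one_over_alpha, beta, x) stacked on dim 1."""
    outputs = []
    inp = x
    for gen in generators:
        mean, one_over_alpha, beta = gen(inp)
        outputs.append((mean, one_over_alpha, beta))
        inp = torch.cat([mean, one_over_alpha, beta, x], dim=1)  # 4 channels
    return outputs


# Usage with a stand-in generator that echoes its first channel as the mean.
def dummy_generator(t):
    mean = t[:, :1]
    return mean, torch.ones_like(mean), torch.ones_like(mean)


outs = run_cascade([dummy_generator] * 3, torch.zeros(2, 1, 8, 8))
print(len(outs), tuple(outs[-1][0].shape))  # 3 (2, 1, 8, 8)
```

Keeping x in every stage's input lets later generators refine the image while still seeing the original corrupted data.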

BibTeX

If you find this project helpful, please cite the following works:

@inproceedings{upadhyay2021uncerguidedi2i,
  title={Uncertainty Guided Progressive GANs for Medical Image Translation},
  author={Upadhyay, Uddeshya and Chen, Yanbei and Hebb, Tobias and Gatidis, Sergios and Akata, Zeynep},
  booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention (MICCAI)},
  year={2021},
  organization={Springer}
}

and

@article{upadhyay2021uncertainty,
  title={Uncertainty-aware Generalized Adaptive CycleGAN},
  author={Upadhyay, Uddeshya and Chen, Yanbei and Akata, Zeynep},
  journal={arXiv preprint arXiv:2102.11747},
  year={2021}
}


Issues

uncertainty map is not computed

Hi,

Thanks for sharing the code for your great work here!

I have a question. In your paper, the output of each generator (except the last one) is the concatenation of the uncertainty map and the reconstructed image, as described in Fig. 1 and equations (3) and (4).

But in your code, the output is the concatenation of rec img, alpha, beta and the original input img.

Looking forward to your reply. Thanks.

About loading datasets

Dear author,
Hello! While reproducing the results, I found that utils.py does not show how the dataset is loaded. I am confused about where to place the dataset and how to load it. Could you please explain?

question about loss function

def bayeGen_loss(out_mean, out_1alpha, out_beta, target):

Amazing work! I still have some questions about the loss calculation, which seems different from the paper's formula. Why is β, which the paper uses as the exponent of the residual, instead multiplied with the residual? And doesn't the clamp operation break the formula?
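For comparison with the formula the question refers to, a per-pixel sketch of the standard generalized Gaussian negative log-likelihood, written directly from the density β / (2αΓ(1/β)) · exp(-(|y - μ|/α)^β). It takes 1/α as input to mirror the out_1alpha argument of bayeGen_loss; gg_nll and mean_gg_nll are hypothetical names, and this is the textbook expression with no clamping, not the repository's implementation.

```python
import math


def gg_nll(mean, one_over_alpha, beta, target):
    """NLL of `target` under a generalized Gaussian with location `mean`,
    scale alpha = 1 / one_over_alpha and shape beta."""
    if one_over_alpha <= 0 or beta <= 0:
        raise ValueError("1/alpha and beta must be positive")
    resi = abs(target - mean) * one_over_alpha  # |y - mu| / alpha
    return (resi ** beta                 # beta is an exponent here, not a factor
            - math.log(one_over_alpha)   # equals + log(alpha)
            - math.log(beta)
            + math.lgamma(1.0 / beta)
            + math.log(2.0))             # constant; often dropped in training losses


def mean_gg_nll(means, one_over_alphas, betas, targets):
    """Average the per-pixel NLL, as a training loss would."""
    terms = [gg_nll(m, a, b, t) for m, a, b, t in zip(means, one_over_alphas, betas, targets)]
    return sum(terms) / len(terms)


# beta = 1 is a Laplace distribution: NLL = |r| / alpha + log(2 * alpha).
print(gg_nll(0.0, 2.0, 1.0, 1.5))  # ≈ 3.0, i.e. 1.5 / 0.5 + log(2 * 0.5)
```

With β = 2 and α = σ√2 the same expression reduces to the Gaussian NLL r²/(2σ²) + ½·log(2πσ²), which is a convenient check for any reimplementation.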

list_epochs = [50, 50, 150]

Hello
Could you please explain the purpose of this line and what it does, given that num_epochs is also passed in as an input?

list_epochs = [50, 50, 150]

Question about: ResConv module

Hello,
I am confused by your ResConv module:

class ResConv(nn.Module):
    """
    Residual convolutional block, where
    convolutional block consists: (convolution => [BN] => ReLU) * 3
    residual connection adds the input to the output
    """
    ...

    def forward(self, x):
        x_in = self.double_conv1(x)
        x1 = self.double_conv(x)
        return self.double_conv(x) + x_in

Doesn't your forward method just return x1 + x_in? That seems different from the description in the docstring.

Looking forward to your reply. Thanks.
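For a deterministic double_conv, the quoted forward computes the same body output twice, so the return value equals x1 + x_in, and the skip path is double_conv1(x) rather than x itself. Below is one hedged reading of the docstring, not the repository's code: ResConvSketch is a hypothetical name, and the 1x1 projection on the skip path is an assumption that is only needed when in_ch != out_ch.

```python
import torch
from torch import nn


class ResConvSketch(nn.Module):
    """body = (Conv3x3 => BN => ReLU) * 3; output = body(x) + skip(x)."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        ch = in_ch
        for _ in range(3):
            layers += [nn.Conv2d(ch, out_ch, kernel_size=3, padding=1),
                       nn.BatchNorm2d(out_ch),
                       nn.ReLU(inplace=True)]
            ch = out_ch
        self.body = nn.Sequential(*layers)
        # identity skip when shapes match, otherwise a 1x1 projection (assumption)
        self.skip = nn.Identity() if in_ch == out_ch else nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        # the body is evaluated once; its result is added to the (projected) input
        return self.body(x) + self.skip(x)


block = ResConvSketch(4, 8)
print(tuple(block(torch.randn(2, 4, 16, 16)).shape))  # (2, 8, 16, 16)
```

Evaluating the body once avoids the redundant second pass and, in training mode, avoids updating BatchNorm running statistics twice per step.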
