mehta-lab / viscy

computer vision models for single-cell phenotyping

Home Page: https://pypi.org/project/viscy/

License: BSD 3-Clause "New" or "Revised" License

Python 99.52% Shell 0.48%
bioimage-analysis computer-vision machine-learning

viscy's Introduction

VisCy

VisCy is a deep learning pipeline for training and deploying computer vision models for image-based phenotyping at single-cell resolution.

The following methods are being developed:

  • Image translation
    • Robust virtual staining of landmark organelles
  • Image classification
    • Supervised learning of cell state (e.g., state of infection)
  • Image representation learning
    • Self-supervised learning of the cell state and organelle phenotypes
Note:
VisCy is currently considered alpha software and is under active development. Frequent breaking changes are expected.

Virtual staining

Pipeline

A full illustration of the virtual staining pipeline can be found here.

Library of virtual staining (VS) models

The robust virtual staining models (i.e., VSCyto2D, VSCyto3D, VSNeuromast) and fine-tuned models can be found here.

Demos

Image-to-Image translation using VisCy

  • Guide for Virtual Staining Models: Instructions for how to train and run inference on VisCy's virtual staining models (VSCyto3D, VSCyto2D and VSNeuromast)

  • Image translation Exercise: Example showing how to use VisCy to train, predict and evaluate the VSCyto2D model. This notebook was developed for the DL@MBL2024 course.

  • Virtual staining exercise: exploring the label-free to fluorescence virtual staining and fluorescence to label-free image translation tasks using VisCy UNeXt2. More usage examples and demos can be found here.

Gallery

Below are some examples of virtually stained images (click to play videos). See the full gallery here.

VSCyto3D (HEK293T), VSNeuromast (Neuromast), VSCyto2D (A549)

Reference

The virtual staining models and training protocols are reported in our recent preprint on robust virtual staining:

@article {Liu2024.05.31.596901,
    author = {Liu, Ziwen and Hirata-Miyasaki, Eduardo and Pradeep, Soorya and Rahm, Johanna and Foley, Christian and Chandler, Talon and Ivanov, Ivan and Woosley, Hunter and Lao, Tiger and Balasubramanian, Akilandeswari and Liu, Chad and Leonetti, Manu and Arias, Carolina and Jacobo, Adrian and Mehta, Shalin B.},
    title = {Robust virtual staining of landmark organelles},
    elocation-id = {2024.05.31.596901},
    year = {2024},
    doi = {10.1101/2024.05.31.596901},
    publisher = {Cold Spring Harbor Laboratory},
    URL = {https://www.biorxiv.org/content/early/2024/06/03/2024.05.31.596901},
    eprint = {https://www.biorxiv.org/content/early/2024/06/03/2024.05.31.596901.full.pdf},
    journal = {bioRxiv}
}

This package evolved from the TensorFlow version of the virtual staining pipeline, which we reported in this paper in 2020:

@article {10.7554/eLife.55502,
article_type = {journal},
title = {Revealing architectural order with quantitative label-free imaging and deep learning},
author = {Guo, Syuan-Ming and Yeh, Li-Hao and Folkesson, Jenny and Ivanov, Ivan E and Krishnan, Anitha P and Keefe, Matthew G and Hashemi, Ezzat and Shin, David and Chhun, Bryant B and Cho, Nathan H and Leonetti, Manuel D and Han, May H and Nowakowski, Tomasz J and Mehta, Shalin B},
editor = {Forstmann, Birte and Malhotra, Vivek and Van Valen, David},
volume = 9,
year = 2020,
month = {jul},
pub_date = {2020-07-27},
pages = {e55502},
citation = {eLife 2020;9:e55502},
doi = {10.7554/eLife.55502},
url = {https://doi.org/10.7554/eLife.55502},
keywords = {label-free imaging, inverse algorithms, deep learning, human tissue, polarization, phase},
journal = {eLife},
issn = {2050-084X},
publisher = {eLife Sciences Publications, Ltd},
}

Installation

  1. We recommend using a new Conda/virtual environment.

    conda create --name viscy python=3.10
    # OR specify a custom path since the dependencies are large:
    # conda create --prefix /path/to/conda/envs/viscy python=3.10
  2. Install a released version of VisCy from PyPI:

    pip install viscy

    If evaluating virtually stained images for segmentation tasks, install additional dependencies:

    pip install "viscy[metrics]"

    Visualizing the model architecture requires visual dependencies:

    pip install "viscy[visual]"
  3. Verify installation by accessing the CLI help message:

    viscy --help

Contributing

For development installation, see the contributing guide.

Additional Notes

The pipeline is built using the PyTorch Lightning framework. The iohub library is used for reading and writing data in OME-Zarr format.

The full functionality is tested on Linux x86_64 with NVIDIA Ampere GPUs (CUDA 12.4). Some features (e.g., mixed precision and distributed training) may not be available with other setups; see the PyTorch documentation for details.

viscy's People

Contributors

alishbaimran, edyoshikun, mattersoflight, soorya19pradeep, ziw-liu

viscy's Issues

Contrastive Learning Implementation

My understanding of the current classifier and task with the 60X-resolution data can be found here: https://docs.google.com/document/d/1j3UePmDJL_1V_9j7v3I4nLgAgKuFXXuqmlW8ZFOyTk0/edit?usp=sharing.

Given this approach, we'd need to modify HCSDataModule to support triplet sampling. Specifically:

  • We can apply transformations like rotation, cropping, and color jitter to create different views of the same cell.
  • Generate an anchor and a positive sample using different augmentations of the same cell image.
  • Select a different cell with a different label (infected vs. uninfected) as the negative sample.

The goal of triplet sampling is to minimize the distance between the anchor and the positive while maximizing the distance between the anchor and the negative in the learned embedding space.
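The objective described above is the standard triplet margin loss, max(d(a, p) - d(a, n) + margin, 0), averaged over the batch. A minimal NumPy sketch, assuming Euclidean distance and a margin of 1.0 (both are illustrative choices, not settled in this issue); for training, PyTorch's `torch.nn.TripletMarginLoss` implements the same idea on tensors.

```python
import numpy as np

def triplet_margin_loss(anchor, positive, negative, margin=1.0):
    """Mean triplet margin loss over a batch of embeddings with shape (N, D)."""
    # Euclidean distance between each anchor and its positive / negative
    d_pos = np.linalg.norm(anchor - positive, axis=1)
    d_neg = np.linalg.norm(anchor - negative, axis=1)
    # Penalize triplets where the negative is not at least `margin` farther than the positive
    return float(np.mean(np.maximum(d_pos - d_neg + margin, 0.0)))

anchor = np.array([[0.0, 0.0], [1.0, 1.0]])
positive = np.array([[0.0, 0.1], [1.0, 1.0]])
negative = np.array([[3.0, 0.0], [1.0, 1.5]])
print(triplet_margin_loss(anchor, positive, negative))  # 0.25
```

The first triplet is already "easy" (the negative is far away) and contributes zero; only the second, where the negative is too close, produces a gradient.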

# TripletTransform wraps a base transform and applies it to a sample.
# When __call__ is invoked with a sample, the base transform is applied twice: once to create the anchor and once to create the positive.

class TripletTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, sample):
        anchor = self.transform(sample)
        positive = self.transform(sample)
        return anchor, positive


# TripletDataset is initialized with the data and an optional transform. When __getitem__ is called with an index (idx):
# Anchor and positive: the same data sample is retrieved for both.
# Negative: a different sample is randomly selected.
# If a transform is provided, it is applied independently to the anchor, positive, and negative,
# so the anchor and positive become two different random augmentations of the same cell.

from torch.utils.data import Dataset

class TripletDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        anchor = self.data[idx]
        positive = self.data[idx]
        # simple negative sampling
        negative_idx = ...
        negative = self.data[negative_idx]
        if self.transform:
            anchor = self.transform(anchor)
            positive = self.transform(positive)
            negative = self.transform(negative)
        return (anchor, positive, negative)


Here the TripletTransform class takes a base transformation (defined in base_transform) and applies it to create the anchor and positive samples.
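`negative_idx` is left elided in the dataset code above. Following the proposal to use a cell with a different label (infected vs. uninfected) as the negative, one way to fill it is sketched below; `sample_negative_index` and the `labels` sequence are hypothetical names, and drawing uniformly among cells with a different label is only one possible strategy.

```python
import numpy as np

def sample_negative_index(idx, labels, rng):
    """Pick a random index whose label differs from labels[idx] (hypothetical helper)."""
    labels = np.asarray(labels)
    candidates = np.flatnonzero(labels != labels[idx])
    if candidates.size == 0:
        raise ValueError("no sample with a different label is available")
    return int(rng.choice(candidates))

rng = np.random.default_rng(seed=0)
labels = ["uninfected", "infected", "uninfected", "infected"]
neg = sample_negative_index(0, labels, rng)
print(neg, labels[neg])  # an index whose label is "infected"
```

Inside `TripletDataset.__getitem__`, a call like this could replace `negative_idx = ...`, with the labels read from wherever the infection annotations are stored.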

Modify HCSDataModule:

class TripletHCSDataModule(HCSDataModule):
    def __init__(
        self,
        data_path: str,
        source_channel: Union[str, Sequence[str]],
        target_channel: Union[str, Sequence[str]],
        z_window_size: int,
        split_ratio: float = 0.8,
        batch_size: int = 16,
        num_workers: int = 8,
        architecture: Literal["2D", "2.1D", "2.2D", "2.5D", "3D", "fcmae"] = "2.5D",
        yx_patch_size: tuple[int, int] = (256, 256),
        normalizations: list[MapTransform] = [],
        augmentations: list[MapTransform] = [],
        caching: bool = False,
        ground_truth_masks: Optional[Path] = None,
    ):
        super().__init__(
            data_path,
            source_channel,
            target_channel,
            z_window_size,
            split_ratio,
            batch_size,
            num_workers,
            architecture,
            yx_patch_size,
            normalizations,
            augmentations,
            caching,
            ground_truth_masks
        )
        self.triplet_transform = TripletTransform(transforms.Compose(normalizations + augmentations))

    # Update to use TripletDataset
    def setup(self, stage: Optional[str] = None):
        super().setup(stage)
        if stage in ("fit", "validate"):
            self.train_dataset = TripletDataset(self.train_dataset.data, transform=self.triplet_transform)
            self.val_dataset = TripletDataset(self.val_dataset.data, transform=self.triplet_transform)

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size // 3,  # adjust batch size for triplets
            num_workers=self.num_workers,
            shuffle=True,
            persistent_workers=bool(self.num_workers),
            prefetch_factor=4 if self.num_workers else None,
            drop_last=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size // 3,  # adjust batch size for triplets
            num_workers=self.num_workers,
            shuffle=False,
            prefetch_factor=4 if self.num_workers else None,
            persistent_workers=bool(self.num_workers),
        )

# example of what could be included in the augmentations list
base_transform = transforms.Compose([
    transforms.RandomApply([transforms.ColorJitter(0.2, 0.2, 0.2, 0.2)], p=0.5),
    transforms.RandomRotation(10),
    transforms.RandomResizedCrop(128, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ToTensor()
])

Using this updated dataloader:

data_module = TripletHCSDataModule(
    data_path="...",
    source_channel=["Phase", "Sensor"],
    target_channel=["Inf_mask"],
    yx_patch_size=[128, 128],
    split_ratio=0.8,
    z_window_size=1,
    architecture="2D",
    num_workers=4,
    batch_size=64,
    normalizations=[
        NormalizeSampled(
            keys=["Sensor", "Phase"],
            level="fov_statistics",
            subtrahend="median",
            divisor="iqr",
        )
    ],
    augmentations=[
        RandWeightedCropd(
            num_samples=8,
            spatial_size=[-1, 128, 128],
            keys=["Sensor", "Phase", "Inf_mask"],
            w_key="Inf_mask",
        )
    ]
)

Model details:

  • Use an encoder for embeddings. The contrastive learning model uses this encoder to generate embeddings and compute the triplet loss. Different losses: Triplet Margin Loss, AllTripletMiner, NTXent.
  • Input: The input to the model is the same (e.g., phase and sensor data).
  • Output: The model outputs embeddings.
  • Loss Function: Triplet loss is used to train the model to minimize the distance between embeddings of similar samples and maximize the distance between embeddings of dissimilar samples.
  • Validation: The validation process compares the embeddings using the triplet loss, ensuring that the model learns useful representations of the cells.

Other ideas: compare SimCLR with triplet sampling

  • SimCLR: generates positive pairs by applying different augmentations to the same sample. Negative samples are implicitly created from other samples in the same batch.
  • Triplet Sampling: explicitly forms triplets consisting of an anchor, a positive, and a negative.
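For comparison with the triplet loss, SimCLR's NT-Xent loss uses the other augmented view of each sample as its positive and treats every other sample in the batch as a negative. A NumPy sketch with cosine similarity; the temperature of 0.5 is an illustrative default, not a tuned value.

```python
import numpy as np

def nt_xent_loss(z1, z2, temperature=0.5):
    """SimCLR-style NT-Xent loss for two augmented views with shape (N, D)."""
    z = np.concatenate([z1, z2], axis=0)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)  # unit vectors, so dot product = cosine similarity
    sim = z @ z.T / temperature
    np.fill_diagonal(sim, -np.inf)  # a sample is never compared with itself
    n = z1.shape[0]
    # The positive for row i is the other view of the same sample
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(n)])
    # Numerically stable log-softmax over all other samples in the batch
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), pos_idx] - log_denom
    return float(-log_prob_pos.mean())

views = np.eye(3)
aligned = nt_xent_loss(views, views)  # matching views: low loss
shuffled = nt_xent_loss(views, views[[1, 2, 0]])  # mismatched views: higher loss
```

Unlike the triplet setup, no explicit negative sampling is needed: the batch itself supplies 2(N - 1) negatives per anchor, which is why SimCLR tends to benefit from large batches.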

Good resource: https://lilianweng.github.io/posts/2021-05-31-contrastive/

organize application scripts on GitHub rather than disk

@ziw-liu @edyoshikun I propose we organize the scripts that use VisCy (with other libraries) for specific applications in applications/infection_screens, applications/organelle_phenotyping, and applications/ultrack. I suggest this approach because the scope of the repo is single-cell phenotyping. It is much easier to find code on GitHub than on disk, and it encourages clean code organization.

Test framework

The test infrastructure of microDL 1.0 was a mixture of unittest and nose. However, since we are rewriting it in a different framework, very few (if any) tests can be reused. This is an opportunity to switch to pytest, which most contributors to this project will be more familiar with, since all of our other projects use it.

sample['target'] from a dataloader batch returns the wrong z-shape

We expected that when iterating over HCSDataModule().test_dataloader(), the target would contain a single Z slice:

for i, sample in enumerate(HCSDataModule().test_dataloader()):
    break
sample['target'].shape  # returns torch.Size([64, 2, 5, 512, 512])

The expected behavior is that the Z dimension of sample['target'] is 1.

Note:
When predicting, I see that we select the middle of the stack, but the default setting for the target should always be one.

def on_before_batch_transfer(self, batch: Sample, dataloader_idx: int) -> Sample:
    predicting = False
    if self.trainer:
        if self.trainer.predicting:
            predicting = True
    if predicting or isinstance(batch, torch.Tensor):
        # skipping example input array
        return batch
    if self.target_2d:
        # slice the center during training or testing
        z_index = self.z_window_size // 2
        batch["target"] = batch["target"][:, :, slice(z_index, z_index + 1)]
    return batch
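The center-slice logic above can be checked on a small dummy array with the same (B, C, Z, Y, X) layout as the reported batch. The sizes are shrunk to keep the example light, and NumPy stands in for the torch tensor (basic slicing behaves the same way).

```python
import numpy as np

z_window_size = 5
# Dummy target batch laid out as (batch, channel, Z, Y, X)
target = np.random.default_rng(0).random((4, 2, z_window_size, 8, 8))
z_index = z_window_size // 2  # center slice: index 2 for a 5-slice window
center = target[:, :, slice(z_index, z_index + 1)]
print(center.shape)  # (4, 2, 1, 8, 8)
```

Note that `on_before_batch_transfer` is a Lightning hook that the Trainer calls before moving a batch to the device. Iterating over `test_dataloader()` directly, as in the snippet in this issue, bypasses the hook, which may explain why the target still has 5 Z slices there.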

Normalization for patches ome-zarr

We want to compute statistics from the FOV-scale zarr store and store it with patch-scale zarr store, which dataloader will parse.

The process for this is:

  • Compute statistics per FOV (not patch) and store these with metadata.
  • Normalize using FOV statistics at training and test time - use existing CLI or preprocessing script.
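The two steps can be sketched with plain NumPy, assuming each FOV channel is available as an array and that statistics are kept in a nested dictionary of the kind that could be written to Zarr attributes. The key names (`normalization`, `fov_statistics`, `median`, `iqr`) mirror the `plate.zattrs["normalization"]` lookup and the `NormalizeSampled(level="fov_statistics", subtrahend="median", divisor="iqr")` call elsewhere on this page, but the exact metadata layout VisCy expects is an assumption here.

```python
import numpy as np

def fov_statistics(fov):
    """Median and interquartile range of one FOV channel (computed on the full FOV, not a patch)."""
    q25, median, q75 = np.percentile(fov, [25, 50, 75])
    return {"median": float(median), "iqr": float(q75 - q25)}

def normalize_patch(patch, stats):
    """Normalize a patch cropped from a FOV using that FOV's statistics."""
    return (patch - stats["median"]) / stats["iqr"]

fov = np.arange(100, dtype=float).reshape(10, 10)  # stand-in for one FOV channel
metadata = {"normalization": {"Phase3D": {"fov_statistics": fov_statistics(fov)}}}
patch = fov[:4, :4]  # a patch cut from this FOV
normalized = normalize_patch(patch, metadata["normalization"]["Phase3D"]["fov_statistics"])
```

Because the statistics come from the whole FOV, every patch cut from it is normalized consistently, regardless of which part of the FOV the patch covers.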

This is the current structure of the patches OME-Zarr store:

=== Summary ===
Format: omezarr v0.4
Axes: T (time); C (channel); Z (space); Y (space); X (space);
Channel names: ['RFP', 'Phase3D']
Row names: ['A', 'B']
Column names: ['3', '4']
Wells: 4
Positions: 2629


This is the structure of the track_labels Zarr store, where we can store the metadata (FOV information can be stored here):
=== Summary ===
Format: omezarr v0.4
Axes: T (time); C (channel); Z (space); Y (space); X (space);
Channel names: ['tracking']
Row names: ['A', 'B']
Column names: ['3', '4']
Wells: 4
Positions: 61


We need help with how to integrate the metadata into the tracking OME-Zarr store, how it will be used to generate the patch OME-Zarr store, and, eventually, how the dataloader will use it for normalization.

Illustrative test datasets

Hi. I really want to try virtual staining, but I can't find the illustrative test datasets mentioned in the article in this repository. Am I missing something?

Augmentation strategy for generalization across magnification

Changing the microscope's magnification alters the sampling in all three spatial dimensions, and the change in Z differs from that in XY. If we want to train a model on a single dataset that generalizes across magnifications, we need augmentation strategies that simulate these changes in spatial sampling.

@mattersoflight pointed out that scaling down is not a good approximation of reducing magnification. Blurring (e.g., Gaussian filtering) before rescaling can simulate the integration of information along the light path and reduce artifacts.

Another question is how to determine the Z sampling used during training to make better use of defocus information. This can potentially be estimated from the magnification, the Z step size, and the NA of illumination and detection.
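A sketch of the blur-before-rescale idea with SciPy, assuming a (Z, Y, X) volume and independent scale factors for Z and YX. The sigma rule, (1 / scale - 1) / 2 per axis, is a generic anti-aliasing heuristic similar to what scikit-image uses when resizing, not a model of the optical transfer function; `simulate_lower_magnification` is a hypothetical name.

```python
import numpy as np
from scipy import ndimage

def simulate_lower_magnification(volume, scale_z, scale_yx):
    """Blur, then downsample a (Z, Y, X) volume; scales < 1 mimic coarser sampling."""
    scales = np.array([scale_z, scale_yx, scale_yx], dtype=float)
    # Anti-aliasing sigma grows with the downsampling factor; no blur on axes that are not shrunk
    sigma = np.maximum(0.0, (1.0 / scales - 1.0) / 2.0)
    blurred = ndimage.gaussian_filter(volume, sigma=sigma)
    # Linear interpolation onto the coarser grid; Z and YX are rescaled independently
    return ndimage.zoom(blurred, zoom=scales, order=1)

volume = np.random.default_rng(0).random((16, 64, 64)).astype(np.float32)
smaller = simulate_lower_magnification(volume, scale_z=0.75, scale_yx=0.5)
print(smaller.shape)  # (12, 32, 32)
```

Keeping `scale_z` separate from `scale_yx` reflects the point above that Z sampling changes differently from XY; during training both could be drawn at random from ranges matched to the target microscopes.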

Predict module not giving expected output

The output images generated by the prediction module are not as expected. Here is the snapshot of the prediction from the 2D infection classifier trained for BJ5 cells generated using the test module:
image
This is the ground truth image of the infection score:
image
Here is the output from the prediction module:
image
@ziw-liu, where do you think this could originate? Can you please help me debug this? Thanks!

document published models on Wiki

We should document the models and releases that match each preprint or paper we post.
The page should describe how to run inference, and where appropriate, how to train the model.
Creating a "Model library" page in Wiki makes sense, with the following three sub-pages:

The Zenodo record you put together for mantis (https://zenodo.org/records/10403605) is a good template for what to share via these pages. The Wiki pages can also go further and point to specific scripts in this or other repos.

strategies to improve invariance with data augmentation

We use intensity scaling and noise augmentations to make the virtual staining model invariant to variations in intensity and noise. We should make full use of augmentations while keeping the training process stable and efficient.

This paper suggests a simple strategy and reports that it is effective: include many augmentations of the same sample to construct the batch, and average the losses (which happens naturally). @ziw-liu what is the current strategy in HCSDataModule? Can you test the strategy reported in Fig. 1B (top) of this paper?
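A minimal sketch of the batch-augmentation strategy, with NumPy standing in for tensors: each sample is repeated `m` times with independent random flips and intensity scaling (illustrative stand-ins for the real augmentation pipeline), so averaging the per-pixel loss over the enlarged batch also averages over augmentations. `augment_batch` and the choice of augmentations are assumptions for illustration.

```python
import numpy as np

def augment_batch(source, target, m, rng):
    """Repeat each (source, target) pair m times with independent augmentations."""
    src = np.repeat(source, m, axis=0)  # (B*m, C, Y, X); copies of one sample are adjacent
    tgt = np.repeat(target, m, axis=0)
    for i in range(src.shape[0]):
        if rng.random() < 0.5:
            # Geometric augmentations must be applied identically to source and target
            src[i] = src[i][..., ::-1].copy()
            tgt[i] = tgt[i][..., ::-1].copy()
        # Intensity scaling is applied to the input only
        src[i] *= rng.uniform(0.8, 1.2)
    return src, tgt

rng = np.random.default_rng(0)
source = rng.random((4, 1, 16, 16))
target = rng.random((4, 1, 16, 16))
src_aug, tgt_aug = augment_batch(source, target, m=3, rng=rng)
prediction = src_aug  # stand-in for model(src_aug)
loss = np.mean((prediction - tgt_aug) ** 2)  # averages over samples and their augmentations
```

The effective batch grows by a factor of `m`, so the number of distinct samples per batch would likely need to shrink by the same factor to keep GPU memory constant.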

PS: The paper also reports the regularization of a classification model with KL divergence over the augmentations. This doesn't translate naturally to virtual staining.

Clarify how to use different stages of VisCy

I was expecting that if you call HCSDataModule().setup('fit') the DataModule should fit the data and re-write the normalization dictionary. However, when this is called twice in a row, we get:

KeyError                                  Traceback (most recent call last)
/home/eduardoh/vs_data/1-test_dataloader/edhirata_dataloader.py in line 14
     34 # %%
     35 data_module = HCSDataModule(
     36     input_data_path,
     37     source_channel="Phase3D",
   (...)
     45     augment=False,  # Turn off augmentation for now.
     46 )
---> 47 data_module.setup("fit")

File ~/VisCy/viscy/light/data.py:404, in HCSDataModule.setup(self, stage)
    402 dataset_settings = dict(channels=channels, z_window_size=self.z_window_size)
    403 if stage in ("fit", "validate"):
--> 404     self._setup_fit(dataset_settings)
    405 elif stage == "test":
    406     self._setup_test(dataset_settings)

File ~/VisCy/viscy/light/data.py:429, in HCSDataModule._setup_fit(self, dataset_settings)
    428 def _setup_fit(self, dataset_settings: dict):
--> 429     plate, normalize_transform = self._setup_eval(dataset_settings)
    430     fit_transform = self._fit_transform()
    431     train_transform = Compose(
    432         [normalize_transform] + self._train_transform() + fit_transform
    433     )

File ~/VisCy/viscy/light/data.py:424, in HCSDataModule._setup_eval(self, dataset_settings)
    420 if self.normalize_source:
    421     norm_keys += self.source_channel
    422 normalize_transform = NormalizeSampled(
    423     norm_keys,
--> 424     plate.zattrs["normalization"],
    425 )
    426 return plate, normalize_transform

File ~/conda/envs/viscy/lib/python3.10/site-packages/zarr/attrs.py:73, in Attributes.__getitem__(self, item)
     72 def __getitem__(self, item):
---> 73     return self.asdict()[item]

KeyError: 'normalization'

Unable to load two channels as inputs to do fluorescence-to-phase image translation

I was running the MBL DL2023 example notebook and ran into this issue at the end trying to predict the phase using 2 fluorescence channels.

tune_data = HCSDataModule(
    data_path,
    source_channel= ["Nuclei","Membrane"],
    target_channel="Phase",
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    num_workers=10,
    architecture="2D",
    yx_patch_size=YX_PATCH_SIZE,
    augment=True,
)
tune_data.setup("fit")

tune_config = {
    "architecture": "2D",
    "num_filters": [24, 48, 96, 192, 384],
    "in_channels": 2,
    "out_channels":1,
    "residual": True,
    "dropout": 0.1,  # dropout randomly turns off weights to avoid overfitting of the model to data.
    "task": "reg",  # reg = regression task.
}
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[30], line 58
     42 n_epochs = 50 # Set this to 50 or the number of epochs you want to train for.
     44 trainer = VSTrainer(
     45     accelerator="gpu",
     46     devices=[GPU_ID],
   (...)
     55         ),
     56 )  
---> 58 trainer.fit(fluor2phase_model, datamodule=fluor2phase_data)
     61 # Visualize the graph of fluor2phase model as image.
     62 model_graph_fluor2phase = torchview.draw_graph(
     63     fluor2phase_model,
     64     fluor2phase_data.train_dataset[0]["source"],
     65     depth=2,  # adjust depth to zoom in.
     66     device="cpu",
     67 )

File ~/conda/envs/04_image_translation/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py:532, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    530 self.strategy._lightning_module = model
    531 _verify_strategy_supports_compile(model, self.strategy)
--> 532 call._call_and_handle_interrupt(
    533     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    534 )
...
    458                     _pair(0), self.dilation, self.groups)
--> 459 return F.conv2d(input, weight, bias, self.stride,
    460                 self.padding, self.dilation, self.groups)

RuntimeError: Given groups=1, weight of size [24, 2, 3, 3], expected input[1, 1, 512, 512] to have 2 channels, but got 1 channels instead

Discussion of what metrics to prioritize in inference tensorboard

Each inference run generates TensorBoard output that displays the predictions and the metrics calculated along multiple orientations. Each added scalar (one sub-tag per calculated metric) creates a new folder in a recursive structure, which can make session load times prohibitively long depending on the number and type of metrics calculated.

Let's have a discussion here about:

  1. What metrics have we implemented (@Soorya19Pradeep) and whether they are integrated into the inference module.
  2. Which of these metrics should we prioritize supporting on inference calls.

Implement test time augmentations to avoid high frequency fluctuations

With a variety of models, we see that virtually stained structures show high-frequency fluctuations over time. This may be due to the sensitivity of ConvNets to small changes in local texture and orientation of edges. One way to average over these effects is to do test time augmentation. Let's experiment with one of our trained models and compare the temporal consistency of predictions with and without test-time augmentations.
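A sketch of flip/rotation test-time augmentation for a 2D image: predict on each transformed copy, undo the transform on the prediction, and average. `model` is any callable that maps a (Y, X) array to a same-shaped array; the eight flips and 90-degree rotations used here are one common choice, not necessarily the set used in the linked paper.

```python
import numpy as np

def predict_with_tta(model, image):
    """Average predictions over the 8 flips/90-degree rotations of a (Y, X) image."""
    predictions = []
    for flip in (False, True):
        flipped = image[:, ::-1] if flip else image
        for k in range(4):
            pred = model(np.rot90(flipped, k))
            # Undo the rotation, then the flip, so every prediction is back in the input frame
            pred = np.rot90(pred, -k)
            predictions.append(pred[:, ::-1] if flip else pred)
    return np.mean(predictions, axis=0)

image = np.random.default_rng(1).random((6, 8))
averaged = predict_with_tta(lambda x: 2.0 * x, image)  # stand-in for a trained model
```

For a (Z, Y, X) stack, the same flips and rotations would be applied over the last two axes (e.g. `np.rot90(x, k, axes=(-2, -1))`) so that Z is left untouched. Comparing frame-to-frame differences of `averaged` against single-pass predictions would quantify the temporal-consistency gain.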

Relevant paper: https://www.nature.com/articles/s41598-020-61808-3

Does phase input need to be normalized with dataset statistics?

When deploying on a microscope system for online inference, we may not have access to dataset-level summary statistics for normalization.

During a discussion with @ieivanov and @Soorya19Pradeep, the necessity of normalizing phase input came under question. If deconvolution already guarantees zero mean for background, and the exact values of foreground pixels carry physical meaning (relative phase delay), further subtraction and scaling seems redundant for thin samples such as monolayer cell cultures. The remaining small variations can be adjusted by augmentation during model training.

@mattersoflight may have more input on whether there is empirical evidence that normalization for phase is still necessary at inference time.

Save augmented tiles

Feature request: Saving a fraction of the augmented tiles would help build intuition for the augmentation parameters by making it possible to visualize them. Since many tiles are created during training, only a small fraction should be saved. For example, the training config could define whether to save tiles, the path to save them to, and how many to save.
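A sketch of the requested option, assuming hypothetical config fields (`save_dir`, `max_tiles`, `probability`, `enabled`) and NumPy `.npy` files as the output format. Each tile is kept with a fixed probability until the budget is reached, so only a small random fraction of training tiles is written.

```python
import tempfile
from pathlib import Path

import numpy as np

class AugmentedTileSaver:
    """Save a random subset of augmented tiles for visual inspection (hypothetical helper)."""

    def __init__(self, save_dir, max_tiles=100, probability=0.01, enabled=True, seed=0):
        self.save_dir = Path(save_dir)
        self.max_tiles = max_tiles
        self.probability = probability
        self.enabled = enabled
        self.rng = np.random.default_rng(seed)
        self.saved = 0

    def __call__(self, tile):
        """Possibly save `tile`; always return it unchanged so the saver can sit in a pipeline."""
        if self.enabled and self.saved < self.max_tiles and self.rng.random() < self.probability:
            self.save_dir.mkdir(parents=True, exist_ok=True)
            np.save(self.save_dir / f"tile_{self.saved:05d}.npy", tile)
            self.saved += 1
        return tile

with tempfile.TemporaryDirectory() as tmp:
    saver = AugmentedTileSaver(tmp, max_tiles=3, probability=1.0)
    for _ in range(10):
        saver(np.zeros((1, 64, 64), dtype=np.float32))
    n_files = len(list(Path(tmp).glob("*.npy")))
print(saver.saved, n_files)  # 3 3
```

Placing the saver as the last step of the augmentation chain would capture tiles exactly as the model sees them.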

architectures and nomenclature

Now that 2D and 2.5D UNets from our 2020 paper are implemented in pytorch, we are exploring the space of architectures in two ways:
a) input-output tensor dimensions.
b) Using SOTA convolutional layers, particularly inspired by ConvNeXt.

At this point, the 2.1D network combines both. It is useful to have distinct nomenclature and models to compare these two innovations.

I suggest:

  • 2D, 2.5D, 2.1D, 3D architectures use classical convolutional layers and activations.
  • Architectures that use ConvNeXt design principles can use 2NeX, 2.5NeX, 2.1NeX, ... nomenclature.

Variable input size training and data pooling

To train a model with datasets of different dimensionalities (2D/3D) and imaging modalities (bright field, quantitative phase, Zernike phase contrast etc.), the training pipeline needs to have these new features:

  • Data loading that draws samples from different data stores
  • Batch collation and training loop that works with 2D and 3D data at the same time
    • Use loss aggregation or alternation?
  • A dynamic model stem with projection layers for different input dimensions

Move to NumPy style docstrings?

This package inherited single-line Sphinx-style docstrings from microDL. However, this differs from the NumPy style that most of our contributors are used to, and IDE settings sometimes get mixed up. Maybe moving to NumPy style will reduce some friction?
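For reference, here is a function documented in NumPy style, the convention parsed by numpydoc and by Sphinx's napoleon extension. The function itself is a hypothetical example, not part of VisCy.

```python
import numpy as np

def center_crop(image, size):
    """Crop the center of a 2D image.

    Parameters
    ----------
    image : numpy.ndarray
        Input array with shape (Y, X).
    size : tuple of int
        Output shape (height, width); must not exceed the input shape.

    Returns
    -------
    numpy.ndarray
        View of the central ``size`` region of ``image``.
    """
    height, width = size
    y0 = (image.shape[0] - height) // 2
    x0 = (image.shape[1] - width) // 2
    return image[y0 : y0 + height, x0 : x0 + width]

crop = center_crop(np.arange(36).reshape(6, 6), (2, 2))
```

The section headers (`Parameters`, `Returns`) underlined with dashes are what distinguish this style from the single-line `:param x:` form.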

sampling strategy to improve class balance across cell cycle

At a recent meeting, we discussed strategies to achieve class balance across the cell cycle. @ziw-liu proposed a selection of FOVs based on the rough measure of the shape of the cells, which I think is a good way to digitally sort the FOVs while constructing a batch.

Let's continue to think about this. We need:
a) rough measures of the cell cycle stage:

  • cytoplasm/nucleus ratio as measured from 'dirty' segmentations of target channels comes to mind as the first measure to try. It should work for multiple cell types and microscopes.

b) strategies to assign a probability of sampling to a FOV or a patch:

  • @ziw-liu I recall you used the fluorescence channel itself as a weight mask. Can you point to that call?
  • We can preprocess or annotate each FOV to assign it a score and use the score to achieve class balance.
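A minimal sketch of the "score each FOV, then sample for balance" idea, assuming each FOV has already been given a scalar score (for example the cytoplasm/nucleus ratio) and that hand-picked bin edges stand in for cell-cycle classes. The function names, FOV keys, and bin edges are hypothetical. Each FOV's weight is inversely proportional to the size of its bin, so every bin is drawn with roughly equal probability.

```python
import random
from collections import Counter


def balanced_fov_weights(fov_scores: dict[str, float], bin_edges: list[float]) -> dict[str, float]:
    """Weight each FOV by 1 / (number of FOVs in its score bin)."""

    def bin_of(score: float) -> int:
        # Index of the first edge the score falls below; scores above all edges go to the last bin.
        for i, edge in enumerate(bin_edges):
            if score < edge:
                return i
        return len(bin_edges)

    bins = {fov: bin_of(score) for fov, score in fov_scores.items()}
    counts = Counter(bins.values())
    return {fov: 1.0 / counts[b] for fov, b in bins.items()}


def sample_fovs(weights: dict[str, float], k: int, seed: int = 0) -> list[str]:
    """Draw k FOV names with replacement, proportionally to their weights."""
    rng = random.Random(seed)
    fovs = list(weights)
    return rng.choices(fovs, weights=[weights[f] for f in fovs], k=k)


# Hypothetical scores: three "small ratio" FOVs and one "large ratio" FOV.
scores = {"A/1/0": 0.8, "A/1/1": 0.9, "A/1/2": 0.85, "B/2/0": 2.5}
weights = balanced_fov_weights(scores, bin_edges=[1.5])
batch_fovs = sample_fovs(weights, k=8)
```

In a PyTorch data pipeline, the same per-item weights could be passed to torch.utils.data.WeightedRandomSampler instead of calling random.choices directly.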

image translation exercise for DL@MBL

The students will program a UNet near the start of the course, and we have scheduled image translation after the segmentation and denoising exercises. We have a 6-hour window for this exercise, which will be divided between deterministic and generative image translation. For deterministic image translation we'd use the UNeXt2 architecture, and for generative translation we'll probably use the conditional GAN from this paper.

Considering the above plan and the need for a demo notebook for release 0.2.0 of VisCy, I suggest that we develop a demo notebook that illustrates the training of the VSCyto3D and VSNeuromast models.

Here are Alishba's fixes to last year's exercise:
https://github.com/alishbaimran/image_translation/blob/solution/solution.ipynb
https://docs.google.com/document/d/1h3u42hodHN7nQz9qND-NQc7uOm72fBi7DxURNYuuYPM/edit

configurable augmentations

A key hyperparameter of our training process is the choice of augmentations. At this point, the augmentations are hard-coded. @ziw-liu, how can we make the augmentations configurable via the PyTorch Lightning config? A useful side effect is that the augmentations used in each computational experiment will be documented automatically.
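One possible shape for this, sketched under the assumption that each augmentation is described in the config by a class_path and init_args mapping (the same convention LightningCLI/jsonargparse uses for subclass arguments). The helper names and the "source"/"target" keys are hypothetical; the two MONAI class paths are real dictionary transforms, but the argument values are placeholders.

```python
import importlib
from typing import Any


def instantiate(spec: dict[str, Any]) -> Any:
    """Build an object from a {"class_path": ..., "init_args": {...}} mapping."""
    module_name, _, class_name = spec["class_path"].rpartition(".")
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


def build_augmentations(specs: list[dict[str, Any]]) -> list[Any]:
    """Instantiate augmentation specs in the order they appear in the config."""
    return [instantiate(spec) for spec in specs]


# Hypothetical config content (resolving these requires MONAI to be installed).
EXAMPLE_SPECS = [
    {
        "class_path": "monai.transforms.RandAffined",
        "init_args": {"keys": ["source", "target"], "prob": 0.8, "rotate_range": [3.14, 0.0, 0.0]},
    },
    {
        "class_path": "monai.transforms.RandGaussianNoised",
        "init_args": {"keys": ["source"], "prob": 0.5, "std": 0.1},
    },
]
```

If the data module accepted the transform list as an init argument typed with a base transform class, LightningCLI should be able to parse such entries directly from YAML, and its saved config.yaml would then record the augmentations of every run.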

API for 2D UNeXt2 model construction needs clarity

The DL@MBL exercise (#100) uses the following call to create a 2D UNeXt2 model.

phase2fluor_model = VSUNet(
    architecture='fcmae',  # 2D UNeXt2 architecture
...

The above can mislead users into thinking they are working with a masked autoencoder.

@ziw-liu does this also affect #70 ? Can you please fix this while merging #100?

CLI for multiple training tasks

As part of #126, @ziw-liu suggested:

For the future iteration (v0.3?) it might make sense to replace the manual CLI entry point (viscy [command]) with Python's built-in module runner (python -m viscy.cli.module [command]) to make supporting multiple learning tasks easier.

This makes a lot of sense, and I'd vote to make this change sooner to clarify that the repo does implement multiple learning tasks. It is easy to change our virtual staining scripts that launch training from viscy [command] to python -m viscy.cli.module [command]. Are there other factors for which we should wait to refactor the CLI?

@ziw-liu @edyoshikun

Augmentation strategy for mantis data

Label-free datasets acquired on the mantis have a similar in-plane pixel size to hummingbird@63x, where much of the training data was acquired. To apply 2.5D virtual staining across microscopes, however, axial spatial augmentation is needed to match the Z-sampling of the training data (250 nm) to the wide range of existing (570 nm, czbiohub-sf/shrimPy#69) and future (205 nm) Z-sampling on mantis.

@talonchandler suggests that trilinear interpolation is a good starting point (0.4x to 1.2x scaling in Z). We will also investigate the units of voxel intensity values in the reconstructed phase images, since mantis uses a different illumination wavelength (450 nm) than hummingbird (532 nm).

Pinging @Soorya19Pradeep and @edyoshikun in case these numbers are not accurate.
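The 0.4x to 1.2x range follows from the numbers above: making a 250 nm stack look like a 570 nm stack needs a Z zoom of 250/570 ≈ 0.44, and making it look like a 205 nm stack needs 250/205 ≈ 1.22. Because only Z is rescaled, trilinear interpolation reduces to linear interpolation along Z. The NumPy sketch below illustrates this; the function names are hypothetical, and in the training pipeline an existing zoom transform (e.g. from MONAI) would likely be used instead.

```python
import numpy as np

TRAIN_Z_NM = 250.0  # axial sampling of the training data, from the text above


def z_scale_factor(target_z_nm: float, train_z_nm: float = TRAIN_Z_NM) -> float:
    """Zoom factor that makes a stack sampled at train_z_nm mimic one sampled at target_z_nm."""
    return train_z_nm / target_z_nm


def rescale_z(volume: np.ndarray, factor: float) -> np.ndarray:
    """Linearly resample a (Z, Y, X) stack along Z only (trilinear with unit XY scale)."""
    nz = volume.shape[0]
    new_nz = max(1, int(round(nz * factor)))
    # Source Z coordinate of each output slice; first and last slices stay aligned.
    src = np.linspace(0.0, nz - 1, new_nz)
    lo = np.floor(src).astype(int)
    hi = np.minimum(lo + 1, nz - 1)
    w = (src - lo)[:, None, None]
    return (1.0 - w) * volume[lo] + w * volume[hi]


def random_z_rescale(volume: np.ndarray, rng: np.random.Generator,
                     low: float = 0.4, high: float = 1.2) -> np.ndarray:
    """Augmentation: rescale Z by a factor drawn uniformly from [low, high]."""
    return rescale_z(volume, rng.uniform(low, high))
```

After rescaling, the augmented stack would still need to be cropped or padded to the model's input depth.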

upgrading lightning to >2.0.8 results in trainer issues

This error does not happen with lightning==2.0.1, which is the version installed by default with viscy. However, when I upgraded to lightning 2.3.0.dev0 to circumvent the caching timeout issue here, I got the following error, which means we will have to make sure our tensors are on the right device according to this. Flagging it in case you also encounter it, @ziw-liu.

Traceback (most recent call last):
  File "/hpc/projects/comp.micro/virtual_staining/models/fcmae-3d/fit/pretrain_scratch_path.py", line 141, in <module>
    trainer.fit(model, data)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 987, in _run
    results = self._run_stage()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1031, in _run_stage
    self._run_sanity_check()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 1060, in _run_sanity_check
    val_loop.run()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 142, in run
    return self.on_run_end()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 254, in on_run_end
    self._on_evaluation_epoch_end()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 336, in _on_evaluation_epoch_end
    trainer._logger_connector.on_epoch_end()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 195, in on_epoch_end
    metrics = self.metrics
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py", line 234, in metrics
    return self.trainer._results.metrics(on_step)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 483, in metrics
    value = self._get_cache(result_metric, on_step)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 447, in _get_cache
    result_metric.compute()
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 289, in wrapped_func
    self._computed = compute(*args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/logger_connector/result.py", line 249, in compute
    value = self.meta.sync(self.value.clone())  # `clone` because `sync` is in-place
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/pytorch/strategies/ddp.py", line 342, in reduce
    return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 173, in _sync_ddp_if_available
    return _sync_ddp(result, group=group, reduce_op=reduce_op)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/lightning/fabric/utilities/distributed.py", line 223, in _sync_ddp
    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 72, in wrapper
    return func(*args, **kwargs)
  File "/hpc/mydata/eduardo.hirata/.conda/envs/viscy/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 1992, in all_reduce
    work = group.allreduce([tensor], opts)
RuntimeError: No backend type associated with device type cpu

clean up viscy cli display

viscy --help prints a useful and succinct help message.

But viscy subcommand --help prints a lot of Lightning CLI information that is not relevant, e.g.,

viscy preprocess --help prints:

  --lr_scheduler CONFIG | CLASS_PATH_OR_NAME | .INIT_ARG_NAME VALUE
                        One or more arguments specifying "class_path" and
                        "init_args" for any subclass of {torch.optim.lr_schedu
                        ler.LRScheduler,lightning.pytorch.cli.ReduceLROnPlatea
                        u}. (type: Union[LRScheduler, ReduceLROnPlateau],
                        known subclasses:
                        torch.optim.lr_scheduler.LRScheduler,
                        monai.optimizers.LinearLR,
                        monai.optimizers.ExponentialLR,
                        torch.optim.lr_scheduler.LambdaLR,
                        monai.optimizers.WarmupCosineSchedule,
                        torch.optim.lr_scheduler.MultiplicativeLR,
                        torch.optim.lr_scheduler.StepLR,
                        torch.optim.lr_scheduler.MultiStepLR,
                        torch.optim.lr_scheduler.ConstantLR,
                        torch.optim.lr_scheduler.LinearLR,
                        torch.optim.lr_scheduler.ExponentialLR,
                        torch.optim.lr_scheduler.SequentialLR,
                        torch.optim.lr_scheduler.PolynomialLR,
                        torch.optim.lr_scheduler.CosineAnnealingLR,
                        torch.optim.lr_scheduler.ChainedScheduler,
                        torch.optim.lr_scheduler.ReduceLROnPlateau,
                        lightning.pytorch.cli.ReduceLROnPlateau,
                        torch.optim.lr_scheduler.CyclicLR,
                        torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
                        torch.optim.lr_scheduler.OneCycleLR,
                        torch.optim.swa_utils.SWALR,
                        lightning.pytorch.cli.ReduceLROnPlateau)

Compute dataset statistics before training or testing for normalization:
  --data_path DATA_PATH
                        (required, type: <class 'Path'>)
  --channel_names CHANNEL_NAMES, --channel_names+ CHANNEL_NAMES
                        (type: Union[list[str], Literal[-1]], default: -1)
  --num_workers NUM_WORKERS
                        (type: int, default: 1)
  --block_size BLOCK_SIZE
                        (type: int, default: 32)

@ziw-liu please fix this.

Three-channel output using ConvNeXt model

I want to train the ConvNeXt model for the infection classifier problem. The model will use two or more input channels (phase + HSP90 channels + infection score channels, etc.) and should output three channels (background + uninfected + infected). Currently, I can perform this training with the 2D and 2.5D UNets, but not with ConvNeXt, which doesn't allow three-channel output due to a hardcoded scaling parameter value. @ziw-liu, can you help with this? Thanks!

documentation of data.py

Since we are training a number of models with HCSDataModule, it is timely to document data.py to clarify the flow of data through different methods.

I suggest the following:

  • Add a note to HCSDataModule docstring that explains how the data on disk is turned into a batch (what public and private methods are called in the process).
  • Add docstrings to all of the above methods.

@ziw-liu additional improvements are welcome, but the above should be sufficient to understand the design.

Testing the training pipeline

The training loop lacks automated tests. Although unit-testing DL code is not as straightforward as testing other software, there are some strategies to improve coverage. See these blog posts for some ideas: 1, 2.
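One widely used strategy is a smoke test that the model can overfit a single small batch: if the loss does not collapse, something in the data flow, loss, or optimizer wiring is broken. Lightning's Trainer already offers fast_dev_run and overfit_batches for this kind of check. The framework-free sketch below shows the shape of such a test with a tiny linear model trained by gradient descent; it is an illustration, not a test of VisCy itself, and all names are hypothetical.

```python
import numpy as np


def train_linear(x: np.ndarray, y: np.ndarray, steps: int = 200, lr: float = 0.1):
    """Fit y ~ x @ w + b by full-batch gradient descent on MSE; return params and loss history."""
    rng = np.random.default_rng(0)
    w = rng.normal(size=(x.shape[1], y.shape[1]))
    b = np.zeros(y.shape[1])
    losses = []
    for _ in range(steps):
        err = x @ w + b - y
        losses.append(float(np.mean(err ** 2)))
        # Gradients of the mean squared error over all elements.
        grad_w = 2.0 * x.T @ err / err.size
        grad_b = 2.0 * err.sum(axis=0) / err.size
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b, losses


def test_overfits_single_batch() -> None:
    rng = np.random.default_rng(1)
    x = rng.normal(size=(8, 3))
    y = x @ np.array([[1.0], [-2.0], [0.5]]) + 0.3
    _, _, losses = train_linear(x, y)
    # A correctly wired model should memorize one small, noise-free batch.
    assert losses[-1] < 0.01 * losses[0]
```

The same pattern applied to a real LightningModule would train on one fixed batch for a few dozen steps and assert that the loss drops by a large factor.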

draft demo dataset and notebook

Dataset
HEK293T cells with phase, membrane, and nuclei channels. Let's start with 50 FOVs.

checkpoint 1
Load the zarr store, view the label-free and fluorescence channels, configure the model, browse the 2D UNet in TensorBoard, and start training a phase->nuclei model.

checkpoint 2
Examine loss after lunch, see the regression metrics for the phase->nuclei model, train nuclei->phase model, and see the regression metrics for the nuclei->phase model.

checkpoint 3
Adjust the network capacity by different amounts and have each student train one model (phase -> nuclei, phase -> membrane, or phase -> nuclei + membrane). Record the metrics in a Google Doc.
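The checkpoints refer to "regression metrics" without naming them. As a hypothetical example of what students could compute and record for each model, here is a NumPy sketch of three common choices (MSE, MAE, and Pearson correlation); libraries such as torchmetrics also provide PearsonCorrCoef and StructuralSimilarityIndexMeasure if SSIM is wanted.

```python
import numpy as np


def regression_metrics(pred: np.ndarray, target: np.ndarray) -> dict[str, float]:
    """Pixel-wise regression metrics between a prediction and its target image."""
    p = pred.astype(np.float64).ravel()
    t = target.astype(np.float64).ravel()
    return {
        "mse": float(np.mean((p - t) ** 2)),
        "mae": float(np.mean(np.abs(p - t))),
        # Pearson r is invariant to affine intensity changes, unlike MSE/MAE.
        "pearson": float(np.corrcoef(p, t)[0, 1]),
    }
```

Reporting Pearson r alongside MSE is useful because virtual staining predictions can be well correlated with the target while differing in absolute intensity.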

Unify stem implementation

          @ziw-liu I wrote the `StemDepthtoChannels` to work for ResNet and ConvNeXt models. Can you test the construction of the UNeXt2 model with this stem and report back whether it works? If it does, we should deprecate the other classes for building the stem to avoid confusion in the future.

Originally posted by @mattersoflight in #113 (comment)
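For readers unfamiliar with the term, a "depth to channels" stem folds the Z axis of a 3D input into the channel axis so that a 2D encoder can consume it. The NumPy sketch below shows only that tensor bookkeeping; it is not the StemDepthtoChannels implementation, which we assume also includes a learned 3D convolution that reduces depth before the fold.

```python
import numpy as np


def depth_to_channels(x: np.ndarray) -> np.ndarray:
    """Fold depth into channels: (B, C, D, H, W) -> (B, C * D, H, W).

    With C-order reshaping, output channel c * D + d holds input channel c at depth d.
    """
    b, c, d, h, w = x.shape
    return x.reshape(b, c * d, h, w)


x = np.arange(2 * 3 * 5 * 4 * 4).reshape(2, 3, 5, 4, 4)
y = depth_to_channels(x)  # shape (2, 15, 4, 4)
```

A shared stem for ResNet and ConvNeXt backbones would mainly need the folded channel count (C * D after any depth-reducing convolution) to match the width of the encoder's first stage.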

Compatibility issues

The models we report in the preprints were trained with different versions of the codebase:

  • VSCyto3D and the model used in the Mantis preprint were trained with v0.1.0a0.
    • Old config files are not compatible with the current HEAD of main (d7d1200).
    • Weights can be loaded through the Python API.
  • VSNeuromast was trained with the current HEAD of main (after v0.1.0a1?)
  • VSCyto2D was trained with #71.
    • Weights are not compatible with main due to architectural changes.

The VSCyto3D weights can be provided with an updated config file, while #71 will need to add API changes to separate 2D model building and also change config files accordingly before merging.

Overflow error seen in training and prediction steps

I receive an overflow error when I try to run training or prediction in VisCy. I tried creating a new environment to see if that fixes it, but that did not work.

> [rank1]: Traceback (most recent call last):
> [rank1]:   File "/hpc/mydata/soorya.pradeep/scratch/VisCy/viscy/scripts/DragonFly_vsModel/pretrain.py", line 70, in <module>
> [rank1]:     trainer.fit(model(), data(batch_size=32, caching=True, num_workers=16))
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 543, in fit
> [rank1]:     call._call_and_handle_interrupt(
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 43, in _call_and_handle_interrupt
> [rank1]:     return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/lightning/pytorch/strategies/launchers/subprocess_script.py", line 105, in launch
> [rank1]:     return function(*args, **kwargs)
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 579, in _fit_impl
> [rank1]:     self._run(model, ckpt_path=ckpt_path)
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/trainer.py", line 948, in _run
> [rank1]:     call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 94, in _call_setup_hook
> [rank1]:     _call_lightning_datamodule_hook(trainer, "setup", stage=fn)
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/lightning/pytorch/trainer/call.py", line 181, in _call_lightning_datamodule_hook
> [rank1]:     return fn(*args, **kwargs)
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy_main/VisCy/viscy/data/combined.py", line 52, in setup
> [rank1]:     dm.setup(stage)
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy_main/VisCy/viscy/data/hcs.py", line 383, in setup
> [rank1]:     self._setup_fit(dataset_settings)
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy_main/VisCy/viscy/data/hcs.py", line 396, in _setup_fit
> [rank1]:     train_transform = Compose(
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/monai/transforms/compose.py", line 251, in __init__
> [rank1]:     self.set_random_state(seed=get_seed())
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/monai/transforms/compose.py", line 263, in set_random_state
> [rank1]:     _transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32"))
> [rank1]:   File "/hpc/mydata/soorya.pradeep/viscy/lib/python3.10/site-packages/monai/transforms/transform.py", line 207, in set_random_state
> [rank1]:     _seed = _seed % MAX_SEED
> [rank1]: OverflowError: Python integer 4294967296 out of bounds for uint32
> [rank0], [rank2], [rank3]: identical tracebacks, each ending in the same OverflowError: Python integer 4294967296 out of bounds for uint32

I see that the same error is mentioned here: Project-MONAI/MONAI#7856
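The number in the message is 2**32, one past the largest uint32 value. The traceback shows MONAI computing _seed % MAX_SEED where _seed is a NumPy uint32 and MAX_SEED is the Python int 2**32. Under NumPy 2's promotion rules (NEP 50), a Python int that does not fit the NumPy scalar's dtype raises OverflowError instead of being upcast, which appears consistent with this failure and with the linked MONAI issue. Assuming the environment has NumPy 2.x installed (not confirmed here), pinning numpy<2 or upgrading MONAI to a release with NumPy 2 support are likely workarounds. The sketch below only illustrates the arithmetic and a safe way to wrap a seed; wrap_seed is hypothetical.

```python
import numpy as np

# One past the largest uint32 value: 2**32 == 4294967296, the number in the error.
MAX_SEED = np.iinfo(np.uint32).max + 1


def wrap_seed(seed) -> int:
    """Reduce a seed into [0, 2**32) without mixing a NumPy uint32 with an out-of-range Python int."""
    # Converting to a Python int first makes the modulo use arbitrary-precision arithmetic.
    return int(seed) % MAX_SEED
```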

Keep segmentation task for U-Nets?

In microDL, using U-Nets for segmentation was explored, leaving branching code paths for both regression (virtual staining) and segmentation tasks (e.g. in preprocessing and model architecture).

@mattersoflight Given the strategy of using virtual staining models with off-the-shelf segmentation tools such as Cellpose, is it still useful to keep this under-tested code?

Improving visualization of predictions on tensorboard with appropriate intensity scale

Improve the visualization of predictions logged to TensorBoard by setting an appropriate intensity scale. Currently, the intensity scaling of the predicted images amplifies signal from cellular structures that are not labeled by the fluorescent marker, producing false signals in the visualization. Using the same intensity scale for all prediction windows can reduce this artifact. @ziw-liu, can you help implement this? Thank you!
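One way to implement a shared scale, sketched under the assumption that the display window is taken from percentiles of the target image and then reused for the prediction, so faint off-target signal in the prediction is not stretched to full brightness. The function names and the 1st/99th percentile choice are hypothetical. TensorBoard's add_image expects float images in [0, 1], which is what this produces.

```python
import numpy as np


def shared_display_range(target: np.ndarray, low_pct: float = 1.0,
                         high_pct: float = 99.0) -> tuple[float, float]:
    """Display window computed once from the target image."""
    lo, hi = np.percentile(target, [low_pct, high_pct])
    return float(lo), float(hi)


def rescale_for_display(images: list[np.ndarray], lo: float, hi: float) -> list[np.ndarray]:
    """Map every image to [0, 1] with the same (lo, hi) window, clipping outside values."""
    scale = max(hi - lo, 1e-12)  # guard against a constant target
    return [np.clip((img - lo) / scale, 0.0, 1.0) for img in images]


target = np.linspace(0.0, 100.0, 101)
prediction = 2.0 * target
lo, hi = shared_display_range(target)
shown_target, shown_pred = rescale_for_display([target, prediction], lo, hi)
```

By contrast, per-image min-max scaling would map the brightest pixel of each prediction to 1 regardless of how dim it is relative to the target, which is what makes off-target structures look like real signal.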

switching to tensorboard logger for organelle phenotyping

Hi @alishbaimran @ziw-liu, I looked carefully into model logging and profiling yesterday and concluded that W&B is not the optimal backend for our projects. Please start using the TensorBoard logger in your training script on the infection_state branch.

If this turns out to be tricky, I will start a PR after Tuesday to help with the transition.

Here are the features we need for the project, and options.

Model visualization: There doesn’t seem to be a way to extract images from W&B logs to illustrate training runs. See this movie for the visualizations we need to interpret models. We have code to make movies like this from TensorBoard logs, and W&B doesn't yet provide an API to do this.

Profiling resources: W&B’s visualization of resources is useful, but resource usage can also be profiled with Lightning's SimpleProfiler and PyTorchProfiler. The PyTorchProfiler is more feature-complete than W&B.

Profiling gradients: W&B makes it particularly easy to visualize vanishing gradients in a layer, but this can also be done by logging the .grad attribute of parameters to TensorBoard. We should look for a convenient API that plays well with Lightning. @ziw-liu, have you come across any good solutions?

Sharing logs: W&B makes it easy to share logs via URL, but sharing TensorBoard logs by pointing to a directory is only marginally harder.

Lastly, monetary and time cost: For the active team of 4, W&B's pricing would cost us about as much as sending one team member to a conference every 6-9 months. In addition to $50/user/month, we'd pay $3/100GB/month. If we use W&B only for this project, our virtual staining and organelle phenotyping logs will be fragmented across different systems. If we use W&B for all of our projects, it is expensive.
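The estimate can be checked with the prices quoted above: 4 users × $50 × 6 months = $1,200, and 9 months gives $1,800, before storage. The sketch below parameterizes this; the storage volume is a hypothetical input, and it assumes storage is billed linearly at the quoted rate.

```python
def wandb_cost_usd(users: int, months: int, storage_gb: float = 0.0,
                   seat_usd: float = 50.0, storage_usd_per_100gb: float = 3.0) -> float:
    """Estimated spend from the per-seat and per-storage prices quoted in the text."""
    seats = users * seat_usd * months
    storage = (storage_gb / 100.0) * storage_usd_per_100gb * months
    return seats + storage


six_months = wandb_cost_usd(users=4, months=6)
nine_months = wandb_cost_usd(users=4, months=9)
with_1tb = wandb_cost_usd(users=4, months=6, storage_gb=1000)
```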

Support loading data from a specific resolution of a .zarr pyramid

@dsundarraman and I are planning to bring multiscale pyramids to the course so that we can iterate quickly on downsampled versions of our data.

@mattersoflight and I think it would be useful to add a resolution parameter to HCSDataModule. Default resolution=0 would load from the highest resolution, and larger integers would pull from lower resolutions of the pyramid.

What do you think @ziw-liu? I will have an example store to test with soon.
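A minimal sketch of how a resolution index could map to an array inside an OME-Zarr position, based on the OME-NGFF multiscales metadata, where the datasets list is ordered from highest to lowest resolution. The function name is hypothetical, and how HCSDataModule would actually open the array is left open.

```python
from typing import Any


def dataset_path_for_resolution(zattrs: dict[str, Any], resolution: int = 0) -> str:
    """Return the array path of a pyramid level from OME-NGFF 'multiscales' metadata.

    resolution=0 is the full-resolution level; larger values are coarser levels.
    """
    datasets = zattrs["multiscales"][0]["datasets"]
    if not 0 <= resolution < len(datasets):
        raise ValueError(f"resolution must be in [0, {len(datasets) - 1}], got {resolution}")
    return datasets[resolution]["path"]


# Metadata shaped like a three-level pyramid (transformations omitted for brevity).
example_zattrs = {"multiscales": [{"datasets": [{"path": "0"}, {"path": "1"}, {"path": "2"}]}]}
```

Reading the path from the metadata rather than assuming "0", "1", ... keeps the parameter working for stores whose level names differ.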

Configure sample image logging during training

For 2D models, logging one sample from each batch resulted in too few samples being logged for the validation set, since the number of validation FOVs is small relative to typical batch sizes. However, changing the logging scheme to use the first N samples of only the first batch will log samples from only the first validation FOV, due to sliding-window sampling. To support both use cases, I propose the following change to the logging interface:

Current:

model:
  log_num_samples: 12

Change to:

model:
  log_batches_per_epoch: 4
  log_samples_per_batch: 3

So for 2D training, these numbers can be (1, 12): even if the validation epoch has only 1 batch, the same number of images will still be logged.
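A minimal sketch of the selection logic the two proposed options imply: log the first log_samples_per_batch samples of each of the first log_batches_per_epoch batches. The helper name is hypothetical; the option names come from the proposal above.

```python
def samples_to_log(batch_idx: int, batch_size: int,
                   log_batches_per_epoch: int, log_samples_per_batch: int) -> list[int]:
    """Indices within the current batch to log; empty if this batch is not logged."""
    if batch_idx >= log_batches_per_epoch:
        return []
    return list(range(min(log_samples_per_batch, batch_size)))


# (4, 3): three samples from each of the first four batches.
per_epoch_4x3 = sum(len(samples_to_log(i, 32, 4, 3)) for i in range(10))
# (1, 12): twelve samples from the first batch only.
per_epoch_1x12 = sum(len(samples_to_log(i, 32, 1, 12)) for i in range(10))
```

Both settings log 12 images per epoch when enough batches exist; they differ only in how many batches (and therefore FOVs) the images are spread across.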

Time sampling for positive pair

Goal: be able to sample the positive pair given a "timestep". For example, sampling at timestep 1 means sampling the nearest timepoint.

Key features:

  • Specify range of timestep to sample. Timestep 1 means nearest timepoint, timestep 2 means second nearest timepoint, etc.
  • Use the sampling rate (specific to each dataset) to determine the experimental timepoint from which to draw the pair for the anchor.
  • Sample either forward or backward. We can pick one direction for now and always sample in that direction, since this theoretically shouldn't affect the prediction. If the anchor is at the edge and has no timepoint before or after it, sample the other way. If timepoints are available both forward and backward, always sample in the chosen direction.
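The rules above can be sketched as two small functions. This assumes the dataset-specific sampling rate is expressed as a frame interval in minutes and that timepoints are indexed 0..n_timepoints-1; both function names are hypothetical.

```python
from typing import Optional


def timestep_from_interval(interval_min: float, frame_interval_min: float) -> int:
    """Convert a desired time offset into a whole number of frames (at least 1)."""
    return max(1, round(interval_min / frame_interval_min))


def sample_positive_timepoint(anchor_t: int, timestep: int, n_timepoints: int,
                              forward: bool = True) -> Optional[int]:
    """Timepoint `timestep` frames from the anchor in the preferred direction.

    Falls back to the opposite direction at the edge of the time series and
    returns None if neither direction is available.
    """
    if timestep < 1:
        raise ValueError("timestep must be >= 1")
    step = timestep if forward else -timestep
    for candidate in (anchor_t + step, anchor_t - step):
        if 0 <= candidate < n_timepoints:
            return candidate
    return None
```

Sampling deterministically in one direction keeps pairs reproducible; randomizing the direction later would only require choosing forward per call.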

WIP: refactor contrastive learning code with virtual staining code

This issue tracks our progress toward integration of contrastive learning code with virtual staining code.

As these steps are being developed, let's version the CLI and configs in compmicro-hpc using the toy dataset of ~10 cells. I am also noting a current/natural home for each step; "?" marks steps whose natural home I am not sure about. Please comment on the order and on missing steps.

The preprocessing steps: repo

  • fluorescence deskewing/deconvolution: shrimPy
  • phase deconvolution: recOrder
  • Registration with fluorescence: shrimPy
  • virtual staining of nuclei and membrane: VisCy
  • Segmentation: VisCy
  • Tracking: ultrack
  • Patchification: VisCy?
  • Normalization: VisCy

Training steps: repo

  • Model training and logging: VisCy

Evaluation steps: repo

Features we need for flexible model training with diverse organelle phenotyping datasets, most of them focused on the data loader:

  • specify any z slices and any channels.
  • data loader that pools multiple datasets.
  • (next iteration) enable flexible definitions of positive pairs and negative pairs.

Regarding the DataModule, it will be useful (but not a must) if @ziw-liu can extend the existing HCSDataModule to return a doublet or a triplet and we can deprecate the special data module for contrastive learning. The special thing in ContrastiveDataModule is the parsing of tracks output by ultrack.
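To illustrate what "return a triplet" could mean, here is a sketch that assembles (anchor, positive, negative) observations from a tracks table. It assumes each row carries a track ID and a timepoint (we believe ultrack's tracks table has track_id and t columns, but this should be checked), takes the positive from the same track a fixed number of frames later, and takes the negative from a different track at the anchor's timepoint. All names are hypothetical, and image loading is left out.

```python
import random
from collections import defaultdict
from typing import NamedTuple, Optional


class Observation(NamedTuple):
    track_id: int
    t: int


def build_triplet(tracks: list[Observation], anchor: Observation, rng: random.Random,
                  timestep: int = 1) -> Optional[tuple[Observation, Observation, Observation]]:
    """(anchor, positive, negative), or None if no valid positive or negative exists."""
    times_by_track: dict[int, set[int]] = defaultdict(set)
    for obs in tracks:
        times_by_track[obs.track_id].add(obs.t)
    positive_t = anchor.t + timestep
    if positive_t not in times_by_track.get(anchor.track_id, set()):
        return None
    negatives = [obs for obs in tracks if obs.track_id != anchor.track_id and obs.t == anchor.t]
    if not negatives:
        return None
    return anchor, Observation(anchor.track_id, positive_t), rng.choice(negatives)


toy_tracks = [Observation(1, 0), Observation(1, 1), Observation(2, 0), Observation(2, 1)]
triplet = build_triplet(toy_tracks, Observation(1, 0), random.Random(0))
```

A doublet is the same call without the negative, which is one reason a single data module with a "pairs or triplets" option seems feasible.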

Inference output image format

@Christianfoley and I are trying to decide which image format is best for saving the images predicted during inference: zarr or TIFF. Zarr is better for storing the data, but some of the software used to process the predicted images works with single-page TIFFs. @mattersoflight has commented that we should aim to store the predictions as zarr. We can read zarr into a NumPy array and then perform downstream analysis (e.g., metrics evaluation; this relates to issue #202). Anything to add, @Christianfoley, @mattersoflight, @ziw-liu?
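As an example of downstream analysis on predictions read back from zarr (a zarr array can be converted with `np.asarray`, e.g. after `zarr.open(path, mode="r")`), the sketch below computes a per-slice Pearson correlation between prediction and target. The function name and the choice of metric are assumptions for illustration; see #202 for the metrics discussion.

```python
import numpy as np


def pearson_per_slice(prediction, target) -> np.ndarray:
    """Pearson correlation for each YX slice of two equally shaped arrays.

    Hypothetical helper; accepts anything np.asarray understands (e.g. a zarr array).
    """
    pred = np.asarray(prediction, dtype=np.float64)
    tgt = np.asarray(target, dtype=np.float64)
    if pred.shape != tgt.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {tgt.shape}")
    # Flatten all leading axes so we iterate over 2D (Y, X) slices.
    pred = pred.reshape(-1, *pred.shape[-2:])
    tgt = tgt.reshape(-1, *tgt.shape[-2:])
    scores = []
    for p, t in zip(pred, tgt):
        p = p - p.mean()
        t = t - t.mean()
        denom = np.sqrt((p * p).sum() * (t * t).sum())
        # Constant slices have undefined correlation.
        scores.append(float((p * t).sum() / denom) if denom > 0 else float("nan"))
    return np.array(scores)
```

Because the function only needs array-like inputs, the same code works whether the predictions come from zarr or from TIFFs loaded into memory.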

Viscy workflow questions and experience from DL@MBL

I am documenting here some of the questions that arose while using VisCy and running it during the DL@MBL course:

I list them as tasks (please cross them out if you think they have been solved) and/or open new issues if they are high priority.

  • Ability to choose between std and iqr as the normalization strategy.
  • TensorBoard validation shows a Z-stack rather than multiple batches.
  • Difficulty determining which config file to use and which parameters to change.
  • Ability to do tiled predictions.
  • HCS DataLoader outputting the wrong shape #47
  • Tool to easily crop and/or rechunk OME-Zarr stores. The current mantis chunking (ZYX chunks of <500 MB) creates IO bottlenecks.
  • Unclear preprocessing steps, config.yml, and order of CLI calls. Mostly solved by #43 and #45.
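The two normalization strategies in the first item differ in the statistics used. A minimal sketch, assuming statistics computed over the whole input array and hypothetical parameter names (`subtrahend`, `divisor`): either subtract the mean and divide by the standard deviation, or subtract the median and divide by the interquartile range.

```python
import numpy as np


def normalize(x: np.ndarray, subtrahend: str = "mean", divisor: str = "std",
              eps: float = 1e-8) -> np.ndarray:
    """Center and scale an array; hypothetical helper, not VisCy's transform."""
    if subtrahend == "mean":
        center = x.mean()
    elif subtrahend == "median":
        center = np.median(x)
    else:
        raise ValueError(f"unknown subtrahend: {subtrahend}")
    if divisor == "std":
        scale = x.std()
    elif divisor == "iqr":
        # Interquartile range: 75th minus 25th percentile.
        q1, q3 = np.percentile(x, [25, 75])
        scale = q3 - q1
    else:
        raise ValueError(f"unknown divisor: {divisor}")
    # eps guards against division by zero on constant inputs.
    return (x - center) / (scale + eps)
```

The median/IQR variant is less sensitive to a few very bright or very dark pixels (e.g. debris), which can otherwise dominate the mean and std.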

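Tiled prediction (the fourth item) could work along these lines: run the model on overlapping YX tiles and average the predictions where tiles overlap. In this sketch `predict_fn` stands in for a model that preserves the tile's shape; the helper names, tile size, overlap, and uniform averaging are assumptions, and weighted blending (e.g. tapering toward tile edges) is a common alternative for reducing seams.

```python
from typing import Callable, List

import numpy as np


def tile_starts(length: int, tile: int, stride: int) -> List[int]:
    """Start offsets of tiles of size `tile` that together cover [0, length)."""
    if length <= tile:
        return [0]
    starts = list(range(0, length - tile + 1, stride))
    if starts[-1] + tile < length:
        starts.append(length - tile)  # final tile flush with the edge
    return starts


def predict_tiled(image: np.ndarray, predict_fn: Callable[[np.ndarray], np.ndarray],
                  tile: int = 256, overlap: int = 32) -> np.ndarray:
    """Apply predict_fn to overlapping YX tiles and average the overlaps."""
    if overlap >= tile:
        raise ValueError("overlap must be smaller than tile")
    h, w = image.shape[-2:]
    stride = tile - overlap
    out = np.zeros(image.shape, dtype=np.float64)
    weight = np.zeros((h, w), dtype=np.float64)  # how many tiles cover each pixel
    for y in tile_starts(h, tile, stride):
        for x in tile_starts(w, tile, stride):
            window = (..., slice(y, y + tile), slice(x, x + tile))
            out[window] += predict_fn(image[window])
            weight[y : y + tile, x : x + tile] += 1.0
    # weight broadcasts over any leading (e.g. C or Z) axes.
    return out / weight
```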