careamics / careamics

A deep-learning library for N2V and friends

Home Page: https://careamics.github.io/

License: BSD 3-Clause "New" or "Revised" License

Python 100.00%
deep-learning denoising python restoration n2v

careamics's Introduction

CAREamics

CAREamics is a PyTorch library aimed at simplifying the use of Noise2Void and its many variants and cousins (CARE, Noise2Noise, N2V2, P(P)N2V, HDN, muSplit etc.).

Why CAREamics?

Noise2Void is a widely used denoising algorithm, and is readily available from the n2v python package. However, n2v is based on TensorFlow, while more recent denoising methods (PPN2V, DivNoising, HDN) are all implemented in PyTorch but lack the extra features that would make them usable by the community.

The aim of CAREamics is to provide a PyTorch library reuniting all the latest methods in one package, while providing a simple and consistent API. The library relies on PyTorch Lightning as a back-end. In addition, we will provide extensive documentation and tutorials on how to best apply these methods in a scientific context.

Installation and use

Check out the documentation for installation instructions and guides!

careamics's People

Contributors

cateek, dependabot[bot], jdeschamps, melisande-c, mese79, pre-commit-ci[bot], veegalinova

Forkers

koumal8 ashesh-0

careamics's Issues

[BUG] config/config_path params

If config_path is provided instead of config, there's no validation or assertion happening

engine = Engine(config_path="n2v.yml")

This will lead to

---> 78 log_path = self.cfg.working_directory / "log.txt"
     79 self.progress = ProgressLogger()
     80 self.logger = get_logger(__name__, log_path=log_path)

AttributeError: 'str' object has no attribute 'working_directory'

in https://github.com/CAREamics/careamics-restoration/blob/019c3322ea65c3ff250bd09ed61b95305eeb1a68/src/careamics_restoration/engine.py#L73
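One way to fail early here is to check the two arguments before anything touches self.cfg. The sketch below is only an illustration: resolve_config and load_config_file are hypothetical helpers, not part of the CAREamics API, and the rule that exactly one of the two arguments must be given is an assumption.

```python
from pathlib import Path
from typing import Any, Optional, Union


def load_config_file(path: Path) -> dict:
    """Hypothetical stand-in for the real YAML loader."""
    return {"source": str(path)}


def resolve_config(
    config: Optional[Any] = None,
    config_path: Optional[Union[str, Path]] = None,
) -> Any:
    """Return a configuration object or raise a clear error (illustrative only)."""
    # Assumption: exactly one of the two arguments must be provided.
    if (config is None) == (config_path is None):
        raise ValueError("Provide exactly one of `config` or `config_path`.")
    if config is not None:
        return config
    path = Path(config_path)
    if not path.is_file():
        raise FileNotFoundError(f"Configuration file not found: {path}")
    # The path is turned into a configuration object here, so a plain string
    # never ends up being used as the configuration itself.
    return load_config_file(path)
```

With a check of this kind, a wrong or unparsed path surfaces as an explicit error instead of the AttributeError shown above.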

Batched prediction [BUG]

Describe the bug
Batched prediction does not work for both tiled prediction and non-tiled prediction, for different reasons.

To Reproduce
Non-tiled prediction
Code snippet allowing reproducing the behaviour:

import numpy as np

from careamics import CAREamist
from careamics.config import create_n2v_configuration

config = create_n2v_configuration(
    experiment_name="PredBatchingTest", 
    data_type="array",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=4,
    num_epochs=1,
)
images = np.random.random((4, 512, 512))

engine = CAREamist(source=config)
engine.train(train_source=images)
pred = engine.predict(source=images, batch_size=2)

Results in error

  ...
  File "/home/melisande.croft/Documents/Repos/careamics/src/careamics/lightning_prediction_loop.py", line 108, in run
    last_tiles = [t.last_tile for t in self.tile_information]
  File "/home/melisande.croft/Documents/Repos/careamics/src/careamics/lightning_prediction_loop.py", line 108, in <listcomp>
    last_tiles = [t.last_tile for t in self.tile_information]
AttributeError: 'Tensor' object has no attribute 'last_tile'

This is happening because in CAREamanics.predict_step the batch unpacking assumes the batch size is 1, see below. If batch has shape (4, C, Y, X), then x will be a tensor of shape (1, C, Y, X) and aux will be a tuple of length 3 containing tensors of shape (1, C, Y, X).

def predict_step(self, batch: Tensor, batch_idx: Any) -> Any:
    """Prediction step.

    Parameters
    ----------
    batch : Tensor
        Input batch.
    batch_idx : Any
        Batch index.

    Returns
    -------
    Any
        Model output.
    """
    x, *aux = batch
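The effect of the star-unpacking can be seen without any model. In this minimal sketch a plain list of four items stands in for a batch of four samples; this is an assumption made only to keep the example dependency-free, since a tensor is iterated along its first axis in the same way.

```python
# A stand-in for a batch holding four samples (hypothetical data).
batch = ["sample_0", "sample_1", "sample_2", "sample_3"]

# Star-unpacking iterates over the first axis: `x` receives only the first
# sample and `aux` collects the remaining three.
x, *aux = batch

print(x)         # sample_0
print(len(aux))  # 3
```

If the dataloader yields a pair such as (tiles, tile_information), the same statement splits that pair instead, which is presumably the case the code was written for.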

Tiled prediction
Code snippet allowing reproducing the behaviour:

import numpy as np

from careamics import CAREamist
from careamics.config import create_n2v_configuration

config = create_n2v_configuration(
    experiment_name="PredBatchingTest", 
    data_type="array",
    axes="SYX",
    patch_size=[64, 64],
    batch_size=4,
    num_epochs=1,
)
images = np.random.random((4, 512, 512))

engine = CAREamist(source=config)
engine.train(train_source=images)
pred = engine.predict(source=images, batch_size=2, tile_size=(64, 64))

Doesn't result in an error but produces the following outputs:
predictions 0 & 2: [image: bad_tiling]
predictions 1 & 3: [image: output1]

Cause unknown.

Please share any insight into this strange behaviour!

[BUG] Rich logging issue

In a Jupyter notebook, whenever the cell with a progress bar is restarted or another cell is executed, I keep getting this message:

LiveError: Only one live display may be active at once

After this either the progress bar doesn't show up, or the cell stalls

Refactoring: Decouple tiling from prediction loop

Description

I think the tiling logic should probably be removed from the prediction loop.

Why

  • Currently the prediction loop's run function is copied from the original Lightning _PredictionLoop class with the tiling logic added in between the Lightning code. This is hard to maintain if Lightning ever changes its code.
  • Currently I'm pretty sure CAREamicsPredictionLoop is incompatible with trainer.predict(*args, **kwargs, return_predictions=False).
  • Once the save to disk function is added, when a full (stitched) prediction is saved we want to free up memory so that the prediction can be run on a large number of files. This is hard to add into the current implementation.
  • When zarr datasets are added, tiles can be written into the correct place in the file without waiting for the complete set of tiles. This is hard to add into the current implementation.

Two solutions

Tiling as a Callback

Lightning already has a BasePredictionWriter Callback. There is a write_on_batch_end hook that could handle writing to zarr files or caching tiles until the last tile to save to tiff. The outputs of trainer.predict can also be changed so that it is the full prediction and not the tiles. I have implemented a version of this as a demo in this branch, where I move the current tiling logic to a callback. (This would have to change a lot to accommodate writing predictions).
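The caching idea described above can be sketched independently of Lightning. Everything in this block is hypothetical: TileInfo, TileCache and on_batch_end are illustrative names, and only the last_tile flag mirrors the attribute seen in the batched-prediction traceback. In a real callback the method body would live in a hook such as write_on_batch_end.

```python
from dataclasses import dataclass
from typing import Any, Callable, List


@dataclass
class TileInfo:
    """Hypothetical per-tile metadata."""

    last_tile: bool


class TileCache:
    """Collect tiles until the last one of an image arrives, then hand them over."""

    def __init__(self, on_image_complete: Callable[[List[Any]], None]) -> None:
        self._tiles: List[Any] = []
        self._on_image_complete = on_image_complete

    def on_batch_end(self, tile: Any, info: TileInfo) -> None:
        self._tiles.append(tile)
        if info.last_tile:
            # Stitching or writing to disk would happen in this callable.
            self._on_image_complete(self._tiles)
            # Start a fresh list so the finished image can be garbage collected.
            self._tiles = []


# Usage: three tiles belonging to one image, the third one flagged as last.
completed: List[List[Any]] = []
cache = TileCache(on_image_complete=completed.append)
for i in range(3):
    cache.on_batch_end(f"tile_{i}", TileInfo(last_tile=(i == 2)))
```

Because the cache is emptied after each image, memory use is bounded by one image's worth of tiles, which is the property wanted for saving predictions over many files.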

Tiling in CAREamicsModule

I don't like this option as much because it doesn't feel like what LightningModules are for. LightningModules have all the same hooks as Callback (on_predict_batch_end, etc.), so all the tiling logic could move there.

Discussion points

Please comment likes/dislikes with either solution or a new solution and any other thoughts!

  • Does trainer.predict have to output the stitched predictions?

    • To change the outputs of trainer.predict, in on_predict_epoch_end we have to do:
      trainer.predict_loop._predictions = stitched_predictions
      (see implementation in mentioned branch above) which I don't like because it feels a bit hacky.
    • predictions keep the tiling information alongside them, why not just do:
      predictions = trainer.predict(model=model, datamodule=datamodule, ckpt_path=checkpoint)
      stitched_preds = stitched_predictions(predictions)
    • Users using the CAREamist class never need to see this (happens in CAREamist.predict). Users not using the CAREamist class have more control, which is why they might not be using it.
  • Something to consider:

    • Users not using the CAREamist class, if they want tiled predictions, currently have to do:
      trainer = Trainer(*args, **kwargs)
      trainer.predict_loop = CAREamicsPredictionLoop(trainer)
      # datamodule is a CAREamicsDataModule
      predictions = trainer.predict(model=model, datamodule=datamodule)
    • With the Callback option:
      trainer = Trainer(*args, **kwargs, callbacks=[TiledPredictionCallback, ...])
      predictions = trainer.predict(model=model, datamodule=datamodule)
    • With the LightningModule option:
      trainer = Trainer(*args, **kwargs)
      predictions = trainer.predict(model=model, datamodule=datamodule)
    • Although I have just realised this is redundant if stitching is applied afterwards (as described in point above). However they would still have to add the PredictionWriterCallback for saving tiles.

Progress bar [BUG]

during predict

del self.progress
        ^^^^^^^^^^^^^
AttributeError: 'Engine' object has no attribute 'progress'

BioImage.io export [BUG]

Method save_as_bioimage of Engine produces this for certain output shapes

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[9], line 1
----> 1 engine.save_as_bioimage(engine.cfg.experiment_name + "bioimage.zip")

File ~/projects/caremics/src/careamics_restoration/engine.py:762, in Engine.save_as_bioimage(self, output_zip, model_specs)
    759     specs = self._generate_rdf(model_specs)
    761     # Build model
--> 762     raw_model = build_zip_model(
    763         path=output_zip,
    764         config=self.cfg,
    765         model_specs=specs,
    766     )
    768     return raw_model
    769 else:

File ~/projects/caremics/src/careamics_restoration/bioimage/io.py:166, in build_zip_model(path, config, model_specs)
    163     model_specs["tags"].append("2D")
    165 # build model zip
--> 166 raw_model = build_model(
    167     output_path=Path(path).absolute(),
    168     **model_specs,
    169 )
    171 # remove the temporary files
    172 weight_path.unlink()

File /localfast/mambaforge/envs/cmcs_lf/lib/python3.11/site-packages/bioimageio/core/build_spec/build_model.py:802, in build_model(weight_uri, test_inputs, test_outputs, input_axes, output_axes, name, description, authors, tags, documentation, cite, output_path, architecture, model_kwargs, weight_type, sample_inputs, sample_outputs, input_names, input_step, input_min_shape, input_data_range, output_names, output_reference, output_scale, output_offset, output_data_range, halo, preprocessing, postprocessing, pixel_sizes, maintainers, license, covers, git_repo, attachments, packaged_by, run_mode, parent, config, dependencies, links, training_data, root, add_deepimagej_config, tensorflow_version, opset_version, pytorch_version, weight_attachments)
    800 documentation = _ensure_local(documentation, root)
    801 if covers is None:
--> 802     covers = _generate_covers(root / test_inputs[0], root / test_outputs[0], input_axes[0], output_axes[0], root)
    803 else:
    804     covers = _ensure_local(covers, root)

File /localfast/mambaforge/envs/cmcs_lf/lib/python3.11/site-packages/bioimageio/core/build_spec/build_model.py:499, in _generate_covers(in_path, out_path, input_axes, output_axes, root)
    496 cover_path = os.path.join(root, "cover.png")
    497 input_, output = np.load(in_path), np.load(out_path)
--> 499 input_ = to_image(input_, input_axes)
    500 # this is not image data so we only save the input image
    501 if output.ndim < 4:

File /localfast/mambaforge/envs/cmcs_lf/lib/python3.11/site-packages/bioimageio/core/build_spec/build_model.py:479, in _generate_covers.<locals>.to_image(data, data_axes)
    477 # transpose the data to "bczyx" / "bcyx" order
    478 axes = "bczyx" if data.ndim == 5 else "bcyx"
--> 479 assert set(data_axes) == set(axes)
    480 if axes != data_axes:
    481     ax_permutation = tuple(data_axes.index(ax) for ax in axes)

AssertionError:

Need to change the dummy outputs shape in _get_sample_io_files

Combine this with

https://www.notion.so/RSE-Projects-12392129a1284f88b153dd4c60cff2da?p=fec0e23175df4ffca8653cd7e3d675a2&pm=s

Checkpoint loading

  • If an incorrect path is provided, the error message is ambiguous.
  • If the path is incorrect, it asks for a config instead of reporting that the path is incorrect.
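A possible shape for a clearer failure is to validate the path before any loading logic runs. This is a minimal sketch under stated assumptions: check_checkpoint_path is a hypothetical helper and the error wording is illustrative, not what CAREamics currently emits.

```python
from pathlib import Path
from typing import Union


def check_checkpoint_path(path: Union[str, Path]) -> Path:
    """Validate a checkpoint path up front (hypothetical helper)."""
    path = Path(path)
    if not path.exists():
        # Report the real problem instead of falling through to a config request.
        raise FileNotFoundError(f"Checkpoint path does not exist: {path}")
    if not path.is_file():
        raise ValueError(f"Checkpoint path is not a file: {path}")
    return path
```

Calling such a check first means the user is told that the path is wrong before any code asks for a configuration.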

Loading checkpoint with a different device [BUG]

Describe the bug

A checkpoint trained on a GPU does not seem to be loadable on a CPU-only machine.

To Reproduce

# checkpoint trained on a linux machine with GPU
from pathlib import Path
from careamics import CAREamist

careamist = CAREamist(Path("sem_n2v_1epoch.ckpt"))

Leads to the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[8], line 4
      1 from pathlib import Path
      2 from careamics import CAREamist
----> 4 careamist = CAREamist(Path("sem_n2v_1epoch.ckpt"))

(...)

File ~/git/careamics/careamics/src/careamics/model_io/model_io_utils.py:67, in _load_checkpoint(path)
...
    254                        'to map your storages to the CPU.')
    255 device_count = torch.cuda.device_count()
    256 if device >= device_count:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Expected behavior
Loading should be possible regardless of the device.

Additional context
Trained on VDI Linux with GPU, attempted loading in macOS intel.
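The error message itself points at the map_location argument of torch.load. The sketch below shows one way this could be wired in; pick_map_location and load_checkpoint_any_device are hypothetical names, not functions from careamics.model_io, and all other torch.load arguments are left at their defaults.

```python
from pathlib import Path
from typing import Any, Union


def pick_map_location(cuda_available: bool) -> str:
    """Choose the device onto which checkpoint tensors are mapped."""
    return "cuda" if cuda_available else "cpu"


def load_checkpoint_any_device(path: Union[str, Path]) -> Any:
    """Load a checkpoint regardless of the device it was saved from (sketch)."""
    import torch  # imported lazily so the helper above works without PyTorch

    location = pick_map_location(torch.cuda.is_available())
    # `map_location` remaps storages saved from a GPU, as the error suggests.
    return torch.load(Path(path), map_location=location)
```

Mapping to "cpu" whenever CUDA is unavailable keeps GPU machines unaffected while letting the same checkpoint open on a laptop.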

Prediction Callback

Hi! I recently came across this repo and it looks very promising for my use case!

As I was playing around I wanted to see how the predictions improved throughout the epochs. Ideally, I would like to have a small separate dataset on which the model could be run after each training epoch, or just a random draw of some images from the validation dataset.

I tried to implement this myself with PyTorch Lightning callbacks, but I don't see a clear way to get around having to call trainer.predict inside the callback, and I fear that this messes up the trainer.fit loop by deleting the validation losses it keeps track of.

Given that you may have a lot more intuition of how CAREamics works, do you have an idea of something that would work? I am happy to implement it myself and create a PR, but currently I do not see how I can avoid calling trainer.predict since the prediction loop is modified to stitch the image together.

Thank you,

Conrad

Training progress bars (rich)

Two currently annoying (and related) problems in the progress bars:

  • unknown number of patches leads to a "?" and no apparent progress
  • the time counter continues after the end of the training

Feature: HDN and muSplit

What

HDN and muSplit integration into CAREamics will pose a certain number of challenges:

Specific datasets

muSplit requires specific datasets.

Choice of Dataset will now depend on the algorithm

e.g. muSplit requires a specific dataset spitting out two targets for each input, but also multiple scaled down images for the LC.

Different ways to deal with that:

  • Pass the whole configuration to the LightningDataModule, but that means asking for many more parameters in the TrainingDataWrapper.
  • Create a SplitLightningDataModule that only instantiates splitting datasets.
  • Add a parameter to the DataConfig for the type of dataset and validate it from the Configuration. This is a bit hacky but might be the simplest solution.
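The third option could look roughly like the following. This is a sketch only: the dataclasses borrow the names DataConfig and Configuration but do not reproduce the library's real models, and the dataset_type field, the algorithm field and the REQUIRED_DATASET mapping are all assumptions introduced for illustration.

```python
from dataclasses import dataclass

# Hypothetical mapping from algorithm name to the dataset type it needs.
REQUIRED_DATASET = {"n2v": "standard", "musplit": "splitting"}


@dataclass
class DataConfig:
    # Hypothetical field naming the kind of dataset to instantiate.
    dataset_type: str = "standard"


@dataclass
class Configuration:
    algorithm: str
    data: DataConfig

    def __post_init__(self) -> None:
        # Cross-field validation: the dataset type must match the algorithm.
        expected = REQUIRED_DATASET.get(self.algorithm)
        if expected is None:
            raise ValueError(f"Unknown algorithm: {self.algorithm!r}")
        if self.data.dataset_type != expected:
            raise ValueError(
                f"Algorithm {self.algorithm!r} requires dataset type "
                f"{expected!r}, got {self.data.dataset_type!r}."
            )
```

The data module would then only need to read the dataset type from its own configuration, while the coupling to the algorithm stays in one validation step.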

Model parameters used by the Dataset

Related to the previous point, the number of LC levels is both a parameter of the model and of the dataset, as it also influences how many lower res images are generated by the dataset.

Model output

muSplit outputs multiple channels in a single Tensor, so this should not be a problem with CAREamics (as opposed to outputting multiple tensors).

Sampling, generation and averaging

We will need to implement prediction methods for sampling, averaging and generation of images.

Calibration

Calibration is a fitting/learning process that uses many sampled images and the pixel-wise std.

Specific losses

That should not be an issue!

Noise Models

Noise models will need to be implemented in CAREamics. Old code (that was never used) currently resides in https://github.com/CAREamics/careamics/tree/jd/refac/save_noisemodels.

Model Zoo export

What challenges await us? @mese79


Any other issue we foresee?

@CatEek @federico-carrara @ashesh-0

Stitched prediction not compatible with multi channel [BUG]

Description
Stitching prediction does not work with multiple channels. This is because in the function stitch_prediction (snippet linked below) tiles are assumed to have only 1 channel. When the tile is cropped on line 61 it should not be squeezed and the slices should account for the channel axis.

def stitch_prediction(
    tiles: List[torch.Tensor],
    stitching_data: List[List[torch.Tensor]],
) -> torch.Tensor:
    """
    Stitch tiles back together to form a full image.

    Parameters
    ----------
    tiles : List[torch.Tensor]
        Cropped tiles and their respective stitching coordinates.
    stitching_coords : List
        List of information and coordinates obtained from
        `dataset.tiled_patching.extract_tiles`.

    Returns
    -------
    np.ndarray
        Full image.
    """
    # retrieve whole array size, there is two cases to consider:
    # 1. the tiles are stored in a list
    # 2. the tiles are stored in a list with batches along the first dim
    if tiles[0].shape[0] > 1:
        input_shape = np.array(
            [el.numpy() for el in stitching_data[0][0][0]], dtype=int
        ).squeeze()
    else:
        input_shape = np.array(
            [el.numpy() for el in stitching_data[0][0]], dtype=int
        ).squeeze()

    # TODO should use torch.zeros instead of np.zeros
    predicted_image = torch.Tensor(np.zeros(input_shape, dtype=np.float32))

    for tile_batch, (_, overlap_crop_coords_batch, stitch_coords_batch) in zip(
        tiles, stitching_data
    ):
        for batch_idx in range(tile_batch.shape[0]):
            # Compute coordinates for cropping predicted tile
            slices = tuple(
                [
                    slice(c[0][batch_idx], c[1][batch_idx])
                    for c in overlap_crop_coords_batch
                ]
            )

            # Crop predited tile according to overlap coordinates
            cropped_tile = tile_batch[batch_idx].squeeze()[slices]

            # Insert cropped tile into predicted image using stitch coordinates
            predicted_image[
                (
                    ...,
                    *[
                        slice(c[0][batch_idx], c[1][batch_idx])
                        for c in stitch_coords_batch
                    ],
                )
            ] = cropped_tile.to(torch.float32)

    return predicted_image

To Reproduce
Code snippet allowing reproducing the behaviour:

import numpy as np

from careamics import CAREamist
from careamics.config import create_n2v_configuration

config = create_n2v_configuration(
    experiment_name="MultiChannel", 
    data_type="array",
    axes="CYX",
    patch_size=[64, 64],
    batch_size=1,
    num_epochs=1,
    independent_channels=True,
    n_channels=3,
)
image = np.random.random((3, 512, 512))

engine = CAREamist(source=config)
engine.train(train_source=image)
pred = engine.predict(source=image, tile_size=[64, 64])

Resulting Error

Traceback (most recent call last):
  File "<CODE SNIPPET>", line 20, in <module>
    pred = engine.predict(source=image, tile_size=[64, 64])
...
  File "<CODE DIR>/careamics/src/careamics/lightning_prediction_loop.py", line 100, in run
    predicted_batches = stitch_prediction(
  File "<CODE DIR>/careamics/src/careamics/prediction/stitch_prediction.py", line 64, in stitch_prediction
    predicted_image[
RuntimeError: The expanded size of the tensor (40) must match the existing size (64) at non-singleton dimension 2.  Target sizes: [3, 40, 40].  Tensor sizes: [3, 40, 64]

This error occurs because stitch_prediction slices the first two axes of the tensor. However, the first axis is the channel axis; therefore the last axis is left at the tile size, 64, instead of the desired crop size, 40.
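The shape arithmetic behind this error can be reproduced without tensors. In the sketch below, sliced_shape is a hypothetical helper that computes the shape left after applying slices to the leading axes, and the crop range 0 to 40 is an assumed value chosen to match the sizes in the error message.

```python
from typing import Sequence, Tuple


def sliced_shape(shape: Sequence[int], slices: Sequence[slice]) -> Tuple[int, ...]:
    """Shape obtained when `slices` are applied to the leading axes of `shape`."""
    cropped = [len(range(*s.indices(dim))) for s, dim in zip(slices, shape)]
    # Axes without a slice keep their full length.
    return tuple(cropped) + tuple(shape[len(slices):])


tile_shape = (3, 64, 64)             # (C, Y, X)
crop = (slice(0, 40), slice(0, 40))  # assumed overlap crop for Y and X

# Current behaviour: the slices land on (C, Y) and X is left untouched.
print(sliced_shape(tile_shape, crop))                  # (3, 40, 64)

# Keeping the channel axis whole shifts the crop onto (Y, X).
print(sliced_shape(tile_shape, (slice(None), *crop)))  # (3, 40, 40)
```

With real tensors the same effect could be obtained by indexing with (..., *slices), as is already done for the stitch coordinates, and by not squeezing the channel axis away.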

Expected behavior
Images with multiple channels should be able to be stitched at prediction the same way as images with 1 channel.

Add a warning for prediction w/o loading a model

Trying to directly call
preds = engine.predict()
will raise

293     raise ValueError(
    294         "Mean or std are not specified in the configuration and in parameters"
    295     )
    297 pred_loader, stitch = self.get_predict_dataloader(
    298     external_input=external_input,
    299     mean=mean,
    300     std=std,
    301 )
    302 # TODO keep getting this ValueError: Mean or std are not specified in the
    303 # configuration and in parameters
    304 # TODO where is this error? is this linked to an issue? Mention issue here.

ValueError: Mean or std are not specified in the configuration and in parameters

in https://github.com/CAREamics/careamics-restoration/blob/019c3322ea65c3ff250bd09ed61b95305eeb1a68/src/careamics_restoration/engine.py#L293

Add a warning and/or restrict prediction without loading a saved model?

[BUG] ProgressBar Iterable vs Sequence

Describe the bug
In the Engine, ProgressLogger is called several times. There are two things I don't really understand and that mypy complains about.

  1. The first one is that task_iterable is declared as an Iterable, but then:
    https://github.com/CAREamics/careamics-restoration/blob/235b96c17073e19f9fde5a788b7ef3ba514b3bb2/src/careamics_restoration/utils/logging.py#L148
    As far as I understand, an Iterable doesn't have __len__.

  2. The ProgressLogger is called on different object types (Iterable, Sequence and enumerate).

Regarding the different types:

Here range is a Sequence, so it has len.
https://github.com/CAREamics/careamics-restoration/blob/235b96c17073e19f9fde5a788b7ef3ba514b3bb2/src/careamics_restoration/engine.py#L117

The dataloaders are Iterable I guess, so no len:
https://github.com/CAREamics/careamics-restoration/blob/235b96c17073e19f9fde5a788b7ef3ba514b3bb2/src/careamics_restoration/engine.py#L189
https://github.com/CAREamics/careamics-restoration/blob/235b96c17073e19f9fde5a788b7ef3ba514b3bb2/src/careamics_restoration/engine.py#L223

enumerate also doesn't have len:
https://github.com/CAREamics/careamics-restoration/blob/235b96c17073e19f9fde5a788b7ef3ba514b3bb2/src/careamics_restoration/engine.py#L292
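The difference between these types can be handled in one place by asking for a length and falling back when none is available. This is a minimal sketch; total_or_none is a hypothetical helper and not part of ProgressLogger.

```python
from typing import Iterable, Optional


def total_or_none(task_iterable: Iterable) -> Optional[int]:
    """Return the number of items when it is known, otherwise None."""
    try:
        # Works for Sequences such as range; `len` raises TypeError for
        # objects that do not support it, such as enumerate.
        return len(task_iterable)  # type: ignore[arg-type]
    except TypeError:
        return None


print(total_or_none(range(5)))             # 5
print(total_or_none(enumerate(range(5))))  # None
```

A progress bar can then show a total when one exists and an indeterminate state otherwise, while the parameter can stay annotated as Iterable.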
