theislab / moscot

Multi-omic single-cell optimal transport tools

Home Page: https://moscot-tools.org

License: BSD 3-Clause "New" or "Revised" License

Languages: Python 99.80%, Shell 0.20%
Topics: optimal-transport, single-cell

moscot's Introduction


moscot - multi-omic single-cell optimal transport tools

moscot is a scalable framework for Optimal Transport (OT) applications in single-cell genomics. It can be used for

  • trajectory inference (incorporating spatial and lineage information)
  • mapping cells to their spatial organisation
  • aligning spatial transcriptomics slides
  • translating modalities
  • prototyping of new OT models in single-cell genomics

moscot is powered by OTT, a JAX-based Optimal Transport toolkit that supports just-in-time compilation, GPU acceleration, automatic differentiation and linear memory complexity for OT problems.
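For orientation, a minimal OTT call might look like this (a sketch against the ott.core API referenced in the issues below, not moscot's own interface; shapes and epsilon are illustrative):

import jax.numpy as jnp
from ott.core import sinkhorn
from ott.geometry.pointcloud import PointCloud

# Toy source/target point clouds standing in for cells in a latent space.
x = jnp.linspace(0.0, 1.0, 128).reshape(64, 2)
y = jnp.linspace(0.0, 1.0, 96).reshape(48, 2)

geom = PointCloud(x, y, epsilon=0.1)  # point-cloud geometry with entropic regularization
out = sinkhorn.sinkhorn(geom)         # returns the potentials and a convergence flag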

Installation

You can install moscot via:

pip install moscot

To install moscot in editable mode, run:

git clone https://github.com/theislab/moscot
cd moscot
pip install -e .

For further instructions on how to install JAX, please refer to https://github.com/google/jax.

Resources

Please have a look at our documentation.

Reference

Our preprint "Mapping cells through time and space with moscot" can be found here.

moscot's People

Contributors

arinadanilina, giovp, marius1311, michalk8, mucdk, nosander, pre-commit-ci[bot], selmanozleyen, weilerp, zoepiran


moscot's Issues

Print information on whether internal Sinkhorn iterations converged.

So far, our logging output looks like this:
[screenshot]

This is nice, but there's no information on whether the internal Sinkhorn iterations actually converged. It would be nice to add this as an extra column.

By the way, OTT has this implemented for their entropic GW solver, i.e. running

GW_object_reg = ott.core.gromov_wasserstein.gromov_wasserstein(C1_g, C2_g, epsilon=2, sinkhorn_kwargs={})

gives me an object where I can check

assert GW_object_reg.converged_sinkhorn.all()

Improve tests

TODOs:

  • set up and increase coverage
  • generate test data and add regression tests

Polishing the repository

TODOs:

  • fix linting (esp. type-checking)
  • improve docstrings
  • improve documentation (installation/tutorials/references/etc. + styling)
  • improve README.rst (links to docs/tutorials/pre-print + logo + badges)
  • improve setup.py
  • update/edit existing config files
  • (optional) add CONTRIBUTING.rst

Installation & importing

Both installation & importing work fine on my machine. The only thing I get is:

WARNING: scott 0.1.dev3+gff251e7 does not provide the extra 'dev'

Support for various cost functions

We should check whether the definition of the cost function has a large impact on OT performance in our applications - @ManuelGander did some initial experiments (using DPT distance vs. Euclidean distance, both in latent spaces) and saw that it can make a difference in the temporal problem (@ManuelGander, once the reproducibility issues have been resolved, please post your results in a comment and explain).

Going from there, we have to decide how to implement this - I would appreciate your input, @michalk8. Do we have a flexible way to define costs yet? In any case, we should make sure the way the cost is defined supports online evaluation and never materializes the cost matrix. Luckily, for the DPT distance, this is possible as follows. Let N = M be the number of cells.

  • Compute your latent representation Z
  • Construct a KNN graph with K neighbors in Z -> sparse connectivity matrix C (it's N by N, but that's fine: it's very sparse, approx. K entries per row)
  • Transform this into a density-normalized transition matrix T of the same sparsity (all defined in the DPT paper)
  • Compute the first G eigenvalues and eigenvectors of T (these will be real as T is generalized symmetric). For this step, we'll have to use iterative solvers which work only through matrix-vector products. Scanpy uses scipy's implementation, which is fine for now. Long term, this should be done in SLEPc, which is numerically much more stable
  • Using the G eigenvalues and eigenvectors, define a new representation for each cell in G-dimensional space (see the DPT paper)
  • Euclidean distance in that space corresponds to DPT distance, i.e. this can be passed to an OTT geometry with online=True :) (a minimal sketch follows below)
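A minimal sketch of the recipe above, assuming scanpy's diffusion-map coordinates as the G-dimensional representation (n_neighbors and n_comps are illustrative):

import jax.numpy as jnp
import scanpy as sc
from ott.geometry.pointcloud import PointCloud

sc.pp.neighbors(adata, n_neighbors=30)    # KNN graph -> sparse connectivities
sc.tl.diffmap(adata, n_comps=15)          # eigenvectors of the transition matrix
Z = jnp.asarray(adata.obsm["X_diffmap"])  # G-dimensional diffusion coordinates

# Euclidean distance in this space corresponds to DPT distance; online=True
# keeps the N x N cost matrix from ever being materialized.
geom = PointCloud(x=Z, y=Z, online=True)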

How to set `epsilon`

OTT has a heuristic; we say in an example that this can make a huge difference, and we should consider using a similar heuristic.
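For concreteness, a minimal version of such a scale-based heuristic (an assumption mirroring the idea, not OTT's exact rule) could be:

import jax.numpy as jnp

def heuristic_epsilon(cost_matrix, rel=0.05):
    # Scale epsilon to the magnitude of the costs; rel is illustrative.
    return rel * jnp.mean(cost_matrix)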

Error w/ GW solver

Running the GW solver on a GPU, I get the error:

E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2086] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 25158058240 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:   23.45GiB
              constant allocation:         0B
        maybe_live_out allocation:   23.43GiB
     preallocated temp allocation:         0B
                 total allocation:   46.88GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 23.43GiB
		Entry Parameter Subshape: f64[1036,1036,2930]
		==========================

	Buffer 2:
		Size: 23.43GiB
		Operator: op_type="sub" op_name="jit(sub)/sub" source_file="/cs/labs/mornitzan/zoe.piran/venvSiFT/lib/python3.7/site-packages/ott/geometry/ops.py" source_line=28
		XLA Label: fusion
		Shape: f64[1036,1036,2930]
		==========================

	Buffer 3:
		Size: 23.16MiB
		Entry Parameter Subshape: f64[1036,1,2930]
		==========================

@michalk8 any thoughts?

OT for aligning spatial omics data

Release preparation

TODOs:

  • set up release CI
  • set up a cron job for CI (once a week)
  • add repository token secrets
  • add .readthedocs.yml
  • add .bumpversion.cfg

Improve efficiency of FGW

TODOs:

  • jaxify the outer loop (using jax.lax.scan or jax.lax.while_loop)
  • make sure we don't materialize the cost matrices when combining them using alpha

Online method throws error

After installing the requirements as instructed, I get the following problem:

The Regularized.fit() method does not work if "online" is set to True in the geometry object.

import anndata
import scanpy as sc
from ott.geometry.pointcloud import PointCloud
import jax.numpy as jnp
from moscot._solver import Regularized

adata = anndata.read("/home/icb/dominik.klein/git_repos/data/adatas/adata_tedsim_8192.h5ad")
obs_var_time = "depth"
adata_source = adata[adata.obs[obs_var_time] == 11]
adata_target = adata[adata.obs[obs_var_time] == 12]

sc.pp.pca(adata_source)
sc.pp.pca(adata_target)

pointcloud_offline = PointCloud(x=jnp.asarray(adata_source.X), y=jnp.asarray(adata_target.X), online=False)
pointcloud_online = PointCloud(x=jnp.asarray(adata_source.X), y=jnp.asarray(adata_target.X), online=True)

moscot_solver = Regularized(epsilon=0.2)

moscot_solver.fit(pointcloud_offline) # works
moscot_solver.fit(pointcloud_online) # does not work
Error message:

TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_27778/1474808108.py in <module>
      1 moscot_solver = Regularized(epsilon=0.2)
      2 
----> 3 moscot_solver.fit(pointcloud_online)

/mnt/home/icb/dominik.klein/git_repos/moscot/moscot/_solver.py in fit(self, geom, a, b, **kwargs)
     96         """
     97         geom = self._prepare_geom(geom, **kwargs)
---> 98         self._transport = Transport(geom, a=a, b=b, **self._kwargs)
     99         self._check_marginals(a, b)
    100 

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/tools/transport.py in __init__(self, a, b, *args, **kwargs)
     66       self.geom = pointcloud.PointCloud(*args, **pc_kw)
     67 
---> 68     num_a, num_b = self.geom.shape
     69     self.a = jnp.ones((num_a,)) / num_a if a is None else a
     70     self.b = jnp.ones((num_b,)) / num_b if b is None else b

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
    138   @property
    139   def shape(self):
--> 140     mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
    141     if mat is not None:
    142       return mat.shape

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    110       # If no epsilon was passed on to the geometry, then assume it is one by
    111       # default.
--> 112       cost = -jnp.log(self._kernel_matrix)
    113       return cost if self._epsilon_init is None else self.epsilon * cost
    114     return self._cost_matrix

    [... skipping hidden 15 frame]

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in <lambda>(x)
    690 def _one_to_one_unop(numpy_fn, lax_fn, promote_to_inexact=False, lax_doc=False):
    691   if promote_to_inexact:
--> 692     fn = lambda x: lax_fn(*_promote_args_inexact(numpy_fn.__name__, x))
    693   else:
    694     fn = lambda x: lax_fn(*_promote_args(numpy_fn.__name__, x))

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _promote_args_inexact(fun_name, *args)
    600 
    601   Promotes non-inexact types to an inexact type."""
--> 602   _check_arraylike(fun_name, *args)
    603   _check_no_float0s(fun_name, *args)
    604   return _promote_shapes(fun_name, *_promote_dtypes_inexact(*args))

~/miniconda3/envs/moscot_1712/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py in _check_arraylike(fun_name, *args)
    576                     if not _arraylike(arg))
    577     msg = "{} requires ndarray or scalar arguments, got {} at position {}."
--> 578     raise TypeError(msg.format(fun_name, type(arg), pos))
    579 
    580 def _check_no_float0s(fun_name, *args):

TypeError: log requires ndarray or scalar arguments, got <class 'NoneType'> at position 0.

From the traceback, the online point cloud ends up on the base Geometry.shape path: shape falls back to cost_matrix, which in turn tries -jnp.log(kernel_matrix) on a kernel that was never set.

Which loss is moscot printing?

moscot and POT have different printouts, see Fig. 1 and Fig. 2, respectively. Which loss does moscot print?

Fig. 1 (moscot)

[screenshot]

Fig. 2 (POT)

[screenshot]

Explore alternative backends

We've discussed this before: the basic thought is improving both our memory and time complexity from the current O(n²) and O(n³ log n) to O(n) and O(n), respectively. The theory behind this, at least for the time complexity, was discussed in #26.

Move notebooks out of this repo

The notebooks folder currently contains our analysis for moscot-lineage that we did for the extended abstract. This analysis should be moved to https://github.com/theislab/moscot-lineage_reproducibility

Ongoing, further analysis for the moscot-lineage paper should be done in https://github.com/theislab/moscot-lineage_notebooks

Basically, moscot-lineage_reproducibility should be a snapshot of our analysis at extended-abstract submission time, and further analysis should be done in moscot-lineage_notebooks.

Convergence issue

For \alpha = 1e-2 (it used to be 1e-3 originally, i.e. I'm giving more weight to GW now), moscot's FGW doesn't converge, but POT's does. I'm using \epsilon = 1e-1 for both methods, and I'm running the example from MK_2021-09-07_fgw_comparison_gt. See Fig. 1 and 2 for the losses, and Fig. 3 and 4 for the results.

Fig. 1 (moscot loss does not converge)

[screenshot]
Notice how the loss starts to increase again.

Fig. 2 (POT loss does converge)

[screenshot]

Fig. 3 (moscot coupling matrix)

[screenshot]

Fig. 4 (POT coupling matrix)

[screenshot]

How to choose \tau

While in the FGW paper (Vayer et al., 2019) the authors use a line search to set \tau, the original GW paper for machine learning (Peyre et al., 2016) suggested setting \tau = 1/\epsilon, where \epsilon is the regularization weight. Note that (Vayer et al., 2019) do not consider entropic regularization, hence that choice wouldn't make sense in their context. I believe POT's implementation of FGW is taken directly from (Vayer et al., 2019), hence no entropic regularization and a line search for \tau. Is that correct @michalk8? (A sketch of the \tau = 1/\epsilon scheme follows below.)
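To make the \tau = 1/\epsilon scheme concrete, here is a minimal numpy sketch of entropic GW with square loss in the style of (Peyre et al., 2016); this is an illustration only, not moscot's or POT's implementation:

import numpy as np

def sinkhorn(C, a, b, eps, n_iter=200):
    # Plain Sinkhorn iterations for the KL-projection step.
    K = np.exp(-C / eps)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

def entropic_gw(C1, C2, a, b, eps, n_outer=50):
    # Mirror descent with step tau = 1/eps: each outer iteration
    # Sinkhorn-projects the linearized (square-loss) GW cost.
    T = np.outer(a, b)
    const = (C1**2 @ a)[:, None] + (C2**2 @ b)[None, :]
    for _ in range(n_outer):
        grad = const - 2.0 * C1 @ T @ C2.T
        T = sinkhorn(grad, a, b, eps)
    return T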

Convergence

One important question here is whether either scheme is guaranteed to converge. Let's just consider the standard entropically regularized GW case here, not FGW. In that case, (Peyre et al., 2016) provide the answer in their Remark 3:
[screenshot]

So, with a line search we are probably guaranteed to converge (at the cost of extra effort: internal iterations), while with \tau = 1/\epsilon we are in general not, but it may work well in practice and it will be a bit faster (if it works).

Complexity

The entropically regularized GW problem is a non-convex optimization problem, so finding the global optimum will be NP-hard. All algorithms we consider here only look for local optima, and I don't know whether there are complexity bounds.

Improve the current API

This is a discussion about (but not limited to) the following issues:

  • should we use sklearn's estimators as a base for our estimators? i.e. params in __init__, data in fit
    • if yes, what are meta-params (n_iters/{r,a}tol/...) and should they be included in fit?
    • use fit_transform and return the array rather than fit and accessing .matrix? should we even save it in the estimator?
  • naming conventions (currently we have the attribute .matrix for the transport map) + argument names (currently we have geom_a, geom_b, geom_ab and a/b for the marginals)
  • conversions when passing raw numpy arrays: at the moment, they are assumed to be point clouds, but they can also be cost matrices; should we add a flag? might complicate things

Jax version does not find GPU

Hey,

it seems like the current JAX version does not find the GPU. This most likely depends on the JAX version (I've encountered the issue before, e.g. see google/jax#5231), but I want to make sure our setup works.

I.e. torch.cuda.is_available() returns True, but initializing a PointCloud throws the warning:

PointCloud(x=jnp.asarray(adata_source.X), y=jnp.asarray(adata_target.X))

WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
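A quick way to check what JAX actually sees (jax.devices() and jax.default_backend() are standard JAX calls; note that GPU support requires a jaxlib wheel matching the installed CUDA version):

import jax

print(jax.default_backend())  # "gpu" if a CUDA-enabled jaxlib is installed, else "cpu"
print(jax.devices())          # the devices JAX will dispatch to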

Wishlist

What should this package be able to do?

  • Implement scalable optimal transport by circumventing the O(n²) memory cost. This can be done either through OTT or GeomLoss (see the sketch after this list)
  • Allow an easy interface to AnnData objects
  • Include both simple OT (regularized, unbalanced -> WOT) as well as Gromov-Wasserstein OT (= novosparc if combined with simple OT in the loss)
  • Add downstream functions from WOT, like computing ancestor or descendant distributions
  • Flexible computation of across-time-point and within-time-point distance metrics/matrices.

Re sparse matrices: we may need them up to the point where we compute the distance matrices/metrics - these will always be dense, I suppose.
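For the GeomLoss route mentioned in the first bullet, a minimal sketch could look as follows (torch tensors assumed; blur and scaling values are illustrative):

import torch
from geomloss import SamplesLoss

x = torch.randn(5000, 30)  # source point cloud
y = torch.randn(6000, 30)  # target point cloud

# Debiased Sinkhorn divergence; GeomLoss can evaluate this online (via KeOps)
# without materializing the full cost matrix.
loss = SamplesLoss("sinkhorn", p=2, blur=0.05, scaling=0.9)
print(loss(x, y))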

Write tutorial

I think it would be nice to have one tutorial that gives intuition on moscot-lineage, using both the toy-data example of Fig. 10 in Vayer et al., 2020 (see Fig. 1 below; this is also used in POT here) and the 2-gene simulations from LineageOT.

Fig. 1

[screenshot]

Implementation of FGW

Two questions:

  • Why do you add jnp.min(tmp) in return Geometry(cost_matrix=tmp + jnp.min(tmp), epsilon=self.epsilon) in _solver.py?
  • Why do you divide by 2 in tens = -jnp.dot(h1(C_a), T).dot(h2(C_b).T / 2.0) in the same file?

How do we measure accuracy? What to implement?

Possible methods for sequential data:

  • similar to WOT: interpolate P_t from P_{t-1} and P_{t+1}. Here, we could also use the gradient flows provided by OTT
  • similar to moscot-lineage: compare the calculated transport matrix row- or column-wise to the ground-truth transport matrix (only works if we know the ground truth, i.e. with simulated data; a possible sketch follows this list) - do I understand this correctly?
  • new approach: leave cells out for training and project them onto the subspace spanned by the training data points; apply optimal transport to the linear combination
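As one possible concretization of the ground-truth comparison in the second bullet (a hypothetical helper; only applicable to simulated data):

import numpy as np

def rowwise_accuracy(T_pred, T_true):
    # Normalize each row to a distribution over target cells.
    P = T_pred / T_pred.sum(axis=1, keepdims=True)
    Q = T_true / T_true.sum(axis=1, keepdims=True)
    # 1 - total variation distance per row, averaged over source cells.
    return float(1.0 - 0.5 * np.abs(P - Q).sum(axis=1).mean())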

Robustness with respect to \alpha / GW vs W

The parameter \alpha controls the weight given to GW (as opposed to W). Originally, in MK_2021-09-07_fgw_comparison_gt, it was set to 1e-3. I'm playing with this parameter to see how the corresponding coupling matrix changes in moscot, POT and novosparc. Note that, due to the convergence issues (see #13), I'm using novospark=True for moscot, and I'm keeping the regularization parameter fixed at \epsilon = 1e-1.

Also note that POT uses unregularized FGW while moscot uses entropically regularized FGW, so we don't expect the results to be identical.

Summary

We're basically looking at Fig. 10 from Vayer et al., Algorithms 2020:
[screenshot]

\alpha = 0 (regularized OT)

This is a reference to compare to - as \alpha approaches zero, we should converge to the pure, entropically regularized optimal transport solution. Computed using POT with the same epsilon.

[image]
This looks weird because epsilon=0.1 is actually quite a large regularization.

\alpha = 1e-3

moscot:
[image]

POT:
[image]

novosparc:
[image]

\alpha = 1e-2

moscot:
[image]

POT:
[image]

novosparc:
Gives a lot of numerical warnings (Warning: numerical errors at iteration 0)
[image]

\alpha = 1e-1 (POT flipped)

moscot:
[image]

POT:
[image]

novosparc:
Numerical warnings
[image]

\alpha = 0.2

moscot:
[image]

POT:
[image]

novosparc:
Numerical warnings
[image]

\alpha = 0.5 (moscot and novosparc flipped)

moscot:
[image]

POT:
[image]

novosparc
Numerical warnings
[image]

\alpha = 0.9

moscot:
[image]

POT:
[image]

novosparc
Numerical warnings
[image]

\alpha = 1.0 (regularized GW)

POT doesn't work here. I'm calling entropic_gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', verbose=True, log=True, epsilon=epsilon) and I'm getting Warning: numerical errors at iteration 0 and, in the end:

[image]

I checked the underlying transport matrix; it actually just returns zeros.

Using unregularized GW via POT instead, it works and I get:

[image]

However, regularized GW via OTT works if I set epsilon high enough, e.g. 2:

[image]

Note that we also had to increase the numerical precision from float32 to float64 in JAX to get this.

Conclusion

\alpha controls our interpolation between the W (\alpha = 0) and GW (\alpha = 1) losses. For low \alpha, we expect a coupling that focuses on feature similarity, whereas for high \alpha, we expect a coupling that focuses on structural similarity. This seems to work differently in POT and moscot: POT flips from W to GW behaviour at around \alpha = 0.1, whereas moscot flips at around \alpha = 0.5. Also, the GW behaviour looks quite different in the two methods.

novosparc sometimes fails to converge and emits a lot of numerical warnings.

Scaling behaviour

This issue is about recent developments in the OT literature that have been concerned with scaling for

  • entr. reg. OT
  • GW
  • FGW

We can collect some results here and note down methods that we may want to explore later on. This is more of a discussion than an issue; however, GitHub discussions are (in our case) only enabled for public repos, so it has to stay here.

Do we depend on POT?

We obviously don't have it in our requirements; however, do we need it to compute distance matrices? In your notebook, you use

C1 = ot.dist(xs)
C2 = ot.dist(xt)

I'm just wondering. I know this is just a wrapper around scipy.spatial.distance.cdist.
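If we want to avoid the dependency, the equivalent call would be (ot.dist defaults to squared Euclidean, which cdist spells 'sqeuclidean'):

from scipy.spatial.distance import cdist

C1 = cdist(xs, xs, metric="sqeuclidean")
C2 = cdist(xt, xt, metric="sqeuclidean")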
