
ott's Introduction

The toolbox has migrated to the OTT-JAX ORG.

Development on this branch stopped on Jan. 28, 2022.

ott's People

Contributors

alantian, bunnech, geoff-davis, laetitiapapaxanthos, marcocuturi, mucdk, olivierteboul


ott's Issues

GW default sinkhorn kwargs raises AttributeError

Not specifying sinkhorn_kwargs in gromov_wasserstein raises an AttributeError; to reproduce:

from ott.core.gromov_wasserstein import gromov_wasserstein
from ott.geometry.geometry import Geometry
import jax.numpy as jnp

x = Geometry(cost_matrix=jnp.ones((10, 10)))
y = Geometry(cost_matrix=jnp.ones((5, 5)))
gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
gromov_wasserstein(x, y)  # raises the error below
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
/tmp/ipykernel_246820/4027485735.py in <module>
      6 y = Geometry(cost_matrix=jnp.ones((5, 5)))
      7 gromov_wasserstein(x, y, sinkhorn_kwargs={})  # works as expected
----> 8 gromov_wasserstein(x, y)  # raises the error below

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/core/gromov_wasserstein.py in gromov_wasserstein(geom_x, geom_y, a, b, epsilon, loss, max_iterations, jit, warm_start, sinkhorn_kwargs, **kwargs)
    151     raise ValueError('Unknown loss. Either pass an instance of GWLoss or '
    152                      f'a string among: [{",".join(GW_LOSSES.keys())}]')
--> 153   tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
    154   tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
    155   if tau_a != 1.0 or tau_b != 1.0:

AttributeError: 'NoneType' object has no attribute 'get'

Version 0.1.17.
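
Until this is patched, a possible workaround is a thin wrapper that fills in the default, mirroring the one-line guard the library call would need before the .get calls (a sketch, not the actual fix):

from ott.core.gromov_wasserstein import gromov_wasserstein

def gw_safe(geom_x, geom_y, sinkhorn_kwargs=None, **kwargs):
    # default to an empty dict so the internal sinkhorn_kwargs.get(...) calls never see None
    sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs
    return gromov_wasserstein(geom_x, geom_y, sinkhorn_kwargs=sinkhorn_kwargs, **kwargs)

gw_safe(x, y)  # no longer raises on version 0.1.17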

Sinkhorn's reg_ot_cost can be nan even when converged

Hello and many thanks for this package!

I noticed that in the point_clouds Jupyter notebook that you provide in the documentation, after executing:
out = sinkhorn.sinkhorn(geom, a, b)

one gets:
out.reg_ot_cost --> DeviceArray(nan, dtype=float32)

although out.converged is True and out.f, out.g have no nan values.

Setting epsilon to larger values overcomes this issue. However, I was wondering if you could clarify why the OT objective returns nan, even though Sinkhorn seems to have run with no numerical issues.
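
For what it's worth, here is a small diagnostic sketch (assuming the geom, a and b from the notebook are in scope) that recomputes the primal transport cost directly from the returned potentials, to locate where the nan first appears:

import jax.numpy as jnp
from ott.core import sinkhorn

out = sinkhorn.sinkhorn(geom, a, b)
P = geom.transport_from_potentials(out.f, out.g)   # transport plan from the dual potentials
print(bool(jnp.isnan(geom.cost_matrix).any()), bool(jnp.isnan(P).any()))
print(jnp.sum(P * geom.cost_matrix))               # primal transport cost, without the entropic term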

Thanks in advance.

Creating a PointCloud using np.ndarrays leads to exceptions in the sinkhorn algorithm

If I do the following:

import numpy as np
from ott.geometry import pointcloud
from ott.core import sinkhorn

x1 = np.zeros((2048, 64))
x2 = np.zeros((120, 64))
geom = pointcloud.PointCloud(x1, x2, epsilon=1.e-3)
output = sinkhorn.sinkhorn(geom)
print(output.reg_ot_cost)

I get the exception below:

...

File "ott/geometry/geometry.py", line 311, in _center
return f[:, jnp.newaxis] + g[jnp.newaxis, :] - self.cost_matrix
File "jax/_src/numpy/lax_numpy.py", line 6589, in deferring_binary_op
return binary_op(self, other)
File "bug_demo/ott_bug.runfiles/google3/third_party/py/jax/_src/numpy/lax_numpy.py", line 679, in
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.__name__, x1, x2))
TypeError: sub got incompatible shapes for broadcasting: (0, 0), (2048, 120).

If, instead, I use jnp.zeros above, everything works as expected.
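
So a hedged interim workaround is simply to convert the NumPy inputs before building the geometry:

import jax.numpy as jnp

geom = pointcloud.PointCloud(jnp.asarray(x1), jnp.asarray(x2), epsilon=1.e-3)
output = sinkhorn.sinkhorn(geom)
print(output.reg_ot_cost)  # behaves the same as passing jnp.zeros directly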

Gradient of regularised OT cost w.r.t. positions of mass is very different from that computed analytically for the squared Euclidean distance cost.

Hi! Many thanks for this package and its updates.

In the case where the cost for the transportation of mass is the squared Euclidean distance, one can compute the gradient of the regularised OT cost w.r.t. the positions of mass using expression (9.7) in https://arxiv.org/pdf/1803.00567.pdf

I noticed that the same gradient, when computed with jax.grad (with implicit differentiation or backpropagation, log-stabilized or not), returns very different values.

I share a colab demonstrating this. The code for the closed-form computation of the gradients is taken from here, and the code for the computation with jax.grad on the sinkhorn output from this test. It seems that the gradients returned by jax.grad have significantly smaller values.
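
For concreteness, a hedged sketch of the comparison (assuming x, y are the (n, d) and (m, d) point clouds and a, b the weights used in the colab; the closed form is the one implied by expression (9.7) for the squared Euclidean cost):

import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

def reg_ot_cost(x):
    geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
    return sinkhorn.sinkhorn(geom, a, b).reg_ot_cost

grad_autodiff = jax.grad(reg_ot_cost)(x)

# Closed form for the squared Euclidean cost: with P the optimal plan,
# d/dx_i <P, C> = sum_j P_ij * 2 (x_i - y_j) = 2 (a_i x_i - (P y)_i).
geom = pointcloud.PointCloud(x, y, epsilon=1e-2)
out = sinkhorn.sinkhorn(geom, a, b)
P = geom.transport_from_potentials(out.f, out.g)
grad_closed_form = 2.0 * (a[:, None] * x - P @ y)

print(jnp.abs(grad_autodiff - grad_closed_form).max())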

Were you expecting this difference?

Thank you in advance for the response.

Specifying Geometry with kernel matrix and without epsilon causes RecursionError

Code to reproduce:

from ott.geometry.geometry import Geometry
import jax.numpy as jnp

geom_e = Geometry(kernel_matrix=jnp.ones((10, 10)), epsilon=1e-2)
print(geom_e.cost_matrix)  # ok
geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
print(geom.cost_matrix)  # raises the error below
---------------------------------------------------------------------------
RecursionError                            Traceback (most recent call last)
/tmp/ipykernel_246820/705571593.py in <module>
      2 print(geom_e.cost_matrix)
      3 geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
----> 4 print(geom.cost_matrix)

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    107   def cost_matrix(self):
    108     if self._cost_matrix is None:
--> 109       return -self.epsilon * jnp.log(self._kernel_matrix)
    110     return self._cost_matrix
    111 

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in epsilon(self)
    130   @property
    131   def epsilon(self):
--> 132     return self._epsilon.target
    133 
    134   @property

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in _epsilon(self)
    102       return self._epsilon_init
    103     eps = 5e-2 if self._epsilon_init is None else self._epsilon_init
--> 104     return epsilon_scheduler.Epsilon.make(eps, scale=self.scale, **self._kwargs)
    105 
    106   @property

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in scale(self)
     92     if (self._scale is None) and (trigger is not None):  # for dry run
     93       return jnp.where(
---> 94           trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0)
     95     else:
     96       return self._scale

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in mean_cost_matrix(self)
    116   @property
    117   def mean_cost_matrix(self):
--> 118     if isinstance(self.shape[0], int) and (self.shape[0] > 0):
    119       return jnp.sum(self.apply_cost(jnp.ones((self.shape[0],)))) / (
    120           self.shape[0] * self.shape[1])

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
    134   @property
    135   def shape(self):
--> 136     mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
    137     if mat is not None:
    138       return mat.shape

... last 6 frames repeated, from the frame below ...

~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
    107   def cost_matrix(self):
    108     if self._cost_matrix is None:
--> 109       return -self.epsilon * jnp.log(self._kernel_matrix)
    110     return self._cost_matrix
    111 

RecursionError: maximum recursion depth exceeded while calling a Python object

Version: 0.1.17
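
Until this is fixed, a possible workaround sketch is to make epsilon explicit whenever the geometry is built from a kernel matrix, either via the constructor (as with geom_e above) or by converting the kernel to a cost matrix yourself:

eps = 1e-2
K = jnp.ones((10, 10))
geom = Geometry(cost_matrix=-eps * jnp.log(K), epsilon=eps)  # equivalent cost, no recursion
print(geom.cost_matrix)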

Sinkhorn values

Hi,

I'm calculating the Wasserstein distance between two observations using the Sinkhorn approximation (no matching, just a pure L2-type distance). I see that the distance is much higher than the distance calculated by other methods (an OT library, the L2 norm). In particular, I have embeddings of two 18-bit observations that differ by one bit, so the distance between them should be low.

The results are as follows:
L2 distance: 3.464944 (calculation: np.sqrt(np.sum(np.square(obs_standard-obs_terminal))))
OTT Sinkhorn distance: 12.005859

I would like to know where the difference is coming from.
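
One possible source of the gap: OTT's PointCloud uses the squared Euclidean cost by default, and 3.464944² ≈ 12.0059, which matches the reported Sinkhorn value almost exactly. A hedged diagnostic sketch (obs_standard and obs_terminal are assumed to be (n, d) arrays of point coordinates) to separate the cost convention from the entropic term:

import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

geom = pointcloud.PointCloud(obs_standard, obs_terminal)  # default cost: squared Euclidean
out = sinkhorn.sinkhorn(geom)
P = geom.transport_from_potentials(out.f, out.g)
print(jnp.sum(P * geom.cost_matrix))             # transport cost under the squared Euclidean cost
print(jnp.sqrt(jnp.sum(P * geom.cost_matrix)))   # rough comparison on the scale of an L2 distance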

Thanks!

Low Rank - Linear, Quadratic problems API

Hi,
we noticed some inconsistencies in the API for solvers (regular and low rank) between Linear and Quadratic problems and thought it may be useful to report them. These may well be the result of our misunderstanding, in which case we would be grateful for your clarifications.
Issues:
(1) epsilon: for a Linear problem one must provide it through the Geometry object, whereas for a Quadratic problem it is passed to the solver.
(2) Low Rank solver: there are two separate solvers in the Linear case, but a single solver in the Quadratic case.

We believe the code snippet below depicts the above in a straightforward manner.
Thank you in advance! :)

    if alpha == 0: # Linear problem
        ot_prob = ott.core.problems.LinearProblem(M)
        if epsilon: # regular (epsilon passed through Geometry M )
            solver = ott.core.sinkhorn.Sinkhorn()
        else: # Low rank (rank passed)
            solver = ott.core.sinkhorn_lr.LRSinkhorn(rank=rank)
    else: # Quadratic problem
        if alpha == 1: # GW
            ot_prob = ott.core.quad_problems.QuadraticProblem(geom_xx=C1,
                                                              geom_yy=C2,
                                                              fused_penalty=0)
        else: # FGW
            ot_prob = ott.core.quad_problems.QuadraticProblem(geom_xx=C1,
                                                              geom_yy=C2,
                                                              geom_xy=M,
                                                              fused_penalty=(1 - alpha) / alpha)
        # a one-liner passing both epsilon and rank (whichever is not None implies usage)
        solver = ott.core.gromov_wasserstein.GromovWasserstein(epsilon=epsilon, rank=rank)
    ot_gw = solver(ot_prob)
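
For concreteness, a condensed sketch of the two entry points as we currently understand them (module paths as used in the snippet above; x and y are hypothetical names for the point clouds M is built from, while C1, C2, epsilon and rank are as in the snippet):

import ott

# Linear: epsilon travels with the Geometry; low rank is a separate solver class.
M = ott.geometry.pointcloud.PointCloud(x, y, epsilon=epsilon)
lin_out = ott.core.sinkhorn_lr.LRSinkhorn(rank=rank)(ott.core.problems.LinearProblem(M))

# Quadratic: a single solver takes both epsilon and rank in its constructor.
quad = ott.core.quad_problems.QuadraticProblem(geom_xx=C1, geom_yy=C2)
gw_out = ott.core.gromov_wasserstein.GromovWasserstein(epsilon=epsilon, rank=rank)(quad)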

Negative value and no convergence for regularized Wasserstein distance between a measure and itself with parallel_dual_updates=True

Issue overview

This issue describes a problem I encounter when trying to compute the regularized Wasserstein distance between a measure and itself with the parallel_dual_updates=True feature of the sinkhorn.sinkhorn function in ott.core.

Code to reproduce experiment

The code to reproduce my experiment is:

import jax
import jax.numpy as jnp
from ott.geometry import pointcloud
from ott.core import sinkhorn

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, 3)

n = 100  # number of points in the support of the input and output measure
d = 2    # dimension of the points

x = jax.random.normal(rngs[0], (n, d))  # support of the measures
unif = jnp.ones(len(x)) / len(x)        # uniform weights

# define the geometry object
geom = pointcloud.PointCloud(x, x, epsilon=1e-2)

# run the Sinkhorn algorithm for various maximal numbers of iterations, with parallel_dual_updates=True
max_iters = [100, 1000, 10000, 50000, 100000, 250000, 500000]
for max_iter in max_iters:
    out = sinkhorn.sinkhorn(geom=geom, a=unif, b=unif, threshold=1e-1,
                            max_iterations=int(max_iter), parallel_dual_updates=True)
    print(f"max_iter={int(max_iter)}\nW_eps(alpha, alpha) = {out.reg_ot_cost}")
    print(f"Has Sinkhorn converged? {out.converged}\n")

Why is there a problem?

Data generation

In this experiment, we start by generating 100 points in R^2 and we consider α, the uniform measure on these 100 points. We then want to calculate the regularized 2-Wasserstein distance W_ε between α and α, i.e. W_ε(α, α).

Positivity of the regularized Wasserstein distance W_ε

By default, OTT uses the KL divergence between the transport plan and the product of the input and output measures to regularize the Kantorovich problem. In our case, the regularization is of the form KL(P | α x α). As this quantity is always positive, and the scalar product between the cost and the transport plan is also positive because we use the Euclidean cost, the solution of this regularized Kantorovich problem is positive, being the (attained) minimum of a positive function. Thus we have W_ε(α, α) > 0, and when we compute W_ε(α, α) from out = sinkhorn.sinkhorn(...) of ott.core, we should have out.reg_ot_cost > 0.

With parallel_dual_updates=False

When we perform this calculation with parallel_dual_updates=False in out = sinkhorn.sinkhorn(...), we get a positive result and we converge (i.e. out.converged=True) for a reasonable threshold and max_iterations.

With parallel_dual_updates=True

On the other hand, when we perform this calculation with parallel_dual_updates=True in out = sinkhorn.sinkhorn(...), we get a negative result and we do not converge (i.e. out.converged=False), even for a small threshold and a large max_iterations. The fact that it does not converge is consistent, since the returned value is wrong. Precisely, by varying max_iterations, we oscillate slightly around the same negative value. This behavior can be reproduced by executing the code above.

Where does the bug seem to come from?

As described by [Feydy et al., 2019], when we compute the regularized Wasserstein distance between a measure and itself, the regularized Kantorovich dual problem becomes a concave maximization problem that is symmetric with respect to its two variables f and g. Hence, there exists a (unique) optimal dual pair (f, g = f) on the diagonal whose optimality (in the discrete setting) is expressed by a well-conditioned fixed-point equation. In the same article, it is said that one can typically achieve convergence in about 3 iterations. The parallel_dual_updates=True feature seems to handle this case by solving this fixed-point equation in f only, rather than the classical Sinkhorn fixed-point equations in f and g. The bug seems to come from the way this equation is solved.
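
For reference, a standalone sketch of the damped symmetric fixed-point update described in [Feydy et al., 2019] (this is not OTT's internal code; C, a and eps stand for the cost matrix, weights and regularization of the experiment above):

from jax.scipy.special import logsumexp

def symmetric_step(f, C, a, eps):
    # T(f)_i = -eps * log sum_j a_j exp((f_j - C_ij) / eps); the damped update
    # f <- (f + T(f)) / 2 is the well-conditioned iteration referred to above.
    Tf = -eps * logsumexp((f[None, :] - C) / eps, b=a[None, :], axis=1)
    return 0.5 * (f + Tf)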

I hope this remark is helpful, and I thank you very much for the development of OTT, which really facilitates the use of numerical optimal transport in ML pipelines.

OOM error

hi,

Please excuse me in advance if there is a basic misuse in the attached code/setting.
In this drive you can find the following folders:
[1] code: two code files (differing in the cost object type: in 'gwlr.py' the costs are point clouds, and in 'gwlr_geom.py' we use general Geometry objects) which reproduce the OOM errors we encounter when running on GPUs.
** we noticed that, when changing alpha to 0.75, we can avoid OOM using low rank; however, while the Geometry run returned quickly, the point-cloud calculation took >2 hrs.
[2] logs: 4 log files documenting the errors in regular (quad) and low-rank (lr) runs, using point-cloud costs and Geometry (geom) objects.
[3] data: the data to run the code.
We realize the data may simply be too large, but we would be grateful for your input.
(I am still trying to set up a different memory-allocation scheme on our GPUs that may allow this; see the sketch below.)
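
For the memory-allocation part, a minimal sketch of the standard JAX GPU allocator flags (these are generic JAX environment variables, not anything OTT-specific):

import os

# must be set before the first JAX import in the process
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # allocate on demand instead of preallocating most of the GPU
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.8"    # or cap the preallocated fraction

import jax  # noqa: E402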

Thank you in advance :)

Numerical instability of Sinkhorn (even with lse_mode=True) & weird behavior with lse_mode=False

Generate two empirical measures:

  • points in the support generated i.i.d. from U([0,1]^5)
  • weights generated i.i.d. from U([0,1]) and normalized
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)

# parameters of the measures
dim = 5
n = 100
m = 150

# support points of the measures
x = jax.random.uniform(keys[0], (n, dim))
y = jax.random.uniform(keys[1], (m, dim))

# weights of the measures
a = jax.random.uniform(keys[0], (len(x),))
b = jax.random.uniform(keys[1], (len(y),))
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

Compute the regularized Wasserstein distance for decreasing epsilons with lse_mode=True:

from ott.geometry import pointcloud
from ott.core import sinkhorn

# regularization strength candidates
eps_cand = [10**(-i) for i in range(10)]

for eps in eps_cand:

    # define the geometry
    geom = pointcloud.PointCloud(x, y, epsilon=eps)

    # run the Sinkhorn algorithm (lse_mode=True is the default; set explicitly here for emphasis)
    out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=True)

    print(f'epsilon = {eps}: regularised optimal transport cost = {out.reg_ot_cost}')

    if jnp.isnan(out.reg_ot_cost):
        break

Compute the regularized Wasserstein distance for decreasing epsilons with lse_mode=False:

# regularization strength candidates 
eps_cand = [10**(-i) for i in range(10)]

for eps in eps_cand:
    
    # define the geometry
    geom = pointcloud.PointCloud(x, y, epsilon=eps)

    # run the Sinkhorn algorithm
    out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=False)

    print(f'epsilon = {eps}: regularised optimal transport cost = {out.reg_ot_cost}')
    
    if jnp.isnan(out.reg_ot_cost):
        break

Comments:

Using the logsumexp mode (lse_mode=True), I get overflow (i.e. nan) from epsilon = 10e-3. This is quite strange because, as the support points are randomly drawn in U([0,1]^5), the maximum distance between two support points of the first and second measure is 25. So we have ||C||_inf / eps ~ 2.5 * 10e4, which is reasonable, so the logsumexp should not generate any overflow.

Moreover, when not using the logsumexp mode (lse_mode=False), I don't get any overflow (down to epsilon = 1e-9). This is strange, since logsumexp is supposed to be more stable than the version using matrix-vector products against the Gibbs kernel. On the other hand, as epsilon decreases, the regularized Wasserstein distance tends to 0 (9.99 * 10e-10 for epsilon = 1e-9). But when epsilon becomes very small, the regularized Wasserstein distance tends towards the unregularized Wasserstein distance. This is therefore strange, because there is no reason a priori for the Wasserstein distance between the two measures to be zero. Indeed, even if the support points of each measure are drawn according to the same law U([0,1]^5), the weights are not uniform and are also drawn randomly and then normalized. The quantity we compute is therefore not, a priori, an estimator of the regularized Wasserstein distance between two measures following the same law U([0,1]^5), which is the case in which it would make sense for it to tend towards 0.
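
One extra check that may be relevant here (geom.epsilon is a public property, as the tracebacks elsewhere on this page show): compare the epsilon passed to the constructor with the value the geometry actually ends up using, in case the epsilon scheduler rescales it:

for eps in eps_cand:
    geom = pointcloud.PointCloud(x, y, epsilon=eps)
    print(eps, float(geom.epsilon))  # requested vs. effective regularization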

I hope this remark is helpful, and I thank you very much for the development of OTT, which really facilitates the use of numerical optimal transport, especially for the differentiation of OT metrics.

Reverse-mode differentiation for grad of Sinkhorn

Hello,

Given JAX's concept of composable function transformations, it would be great if the gradient of Sinkhorn could also be computed with reverse mode.
However, given that the current implementation of the Sinkhorn loop uses jax.lax.while_loop (which is only forward-mode differentiable), this is not possible.
It would really be helpful if a solution could be offered for that. Perhaps with jax.lax.scan, along the lines of this?

Any solution would be strongly appreciated, as currently it is not possible to jit the grad of Sinkhorn and do automatic differentiation; the only option, to my understanding, is to leave it un-jitted and run in op-by-op mode.
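
As a rough illustration of the suggested pattern (a generic sketch, not OTT's code), a fixed-length jax.lax.scan is reverse-mode differentiable, at the price of a fixed iteration count:

import jax

def fixed_point_scan(step_fn, init, num_iters):
    # lax.scan, unlike lax.while_loop, supports reverse-mode differentiation
    def body(carry, _):
        return step_fn(carry), None
    final, _ = jax.lax.scan(body, init, xs=None, length=num_iters)
    return final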

Thanks in advance.
