The toolbox has migrated to the OTT-JAX ORG.
Development on this branch has stopped on Jan. 28 2022.
License: Apache License 2.0
Not specifying sinkhorn_kwargs in gromov_wasserstein raises an AttributeError; to reproduce:
from ott.core.gromov_wasserstein import gromov_wasserstein
from ott.geometry.geometry import Geometry
import jax.numpy as jnp
x = Geometry(cost_matrix=jnp.ones((10, 10)))
y = Geometry(cost_matrix=jnp.ones((5, 5)))
gromov_wasserstein(x, y, sinkhorn_kwargs={}) # works as expected
gromov_wasserstein(x, y) # raises the error below
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
/tmp/ipykernel_246820/4027485735.py in <module>
6 y = Geometry(cost_matrix=jnp.ones((5, 5)))
7 gromov_wasserstein(x, y, sinkhorn_kwargs={}) # works as expected
----> 8 gromov_wasserstein(x, y) # raises the error below
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/core/gromov_wasserstein.py in gromov_wasserstein(geom_x, geom_y, a, b, epsilon, loss, max_iterations, jit, warm_start, sinkhorn_kwargs, **kwargs)
151 raise ValueError('Unknown loss. Either pass an instance of GWLoss or '
152 f'a string among: [{",".join(GW_LOSSES.keys())}]')
--> 153 tau_a = sinkhorn_kwargs.get('tau_a', 1.0)
154 tau_b = sinkhorn_kwargs.get('tau_b', 1.0)
155 if tau_a != 1.0 or tau_b != 1.0:
AttributeError: 'NoneType' object has no attribute 'get'
Version: 0.1.17.
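The traceback shows .get() being called on a None default. A common way to guard against this is sketched below. The function is a hypothetical stand-in that only mirrors the two failing lines (tau_a and tau_b); it is not OTT's gromov_wasserstein and does not claim to be the fix adopted by the library.

```python
from typing import Any, Dict, Optional


def solve_with_optional_kwargs(
    sinkhorn_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, float]:
    """Hypothetical stand-in for the failing code path, not OTT's function."""
    # Treat a missing dictionary as an empty one before calling .get().
    sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs
    tau_a = sinkhorn_kwargs.get("tau_a", 1.0)
    tau_b = sinkhorn_kwargs.get("tau_b", 1.0)
    return {"tau_a": tau_a, "tau_b": tau_b}


defaults = solve_with_optional_kwargs()               # no kwargs given
custom = solve_with_optional_kwargs({"tau_a": 0.9})   # partial kwargs given
```

Defaulting to None and substituting an empty dict inside the function also avoids the mutable-default-argument pitfall that a `sinkhorn_kwargs={}` signature would introduce.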
Hello and many thanks for this package!
I noticed that in the point_clouds Jupyter notebook provided in the documentation, after executing:
out = sinkhorn.sinkhorn(geom, a, b)
one gets:
out.reg_ot_cost --> DeviceArray(nan, dtype=float32)
although out.converged is True and out.f, out.g have no nan values.
Setting epsilon to a larger value overcomes this issue. However, could you clarify why the OT objective is nan, although Sinkhorn seems to have run with no numerical issues?
Thanks in advance.
If I do the following:
import numpy as np
from ott.core import sinkhorn
from ott.geometry import pointcloud

x1 = np.zeros((2048, 64))
x2 = np.zeros((120, 64))
geom = pointcloud.PointCloud(x1, x2, epsilon=1.e-3)
output = sinkhorn.sinkhorn(geom)
print(output.reg_ot_cost)
I get the exception below:
...
File "ott/geometry/geometry.py", line 311, in _center
return f[:, jnp.newaxis] + g[jnp.newaxis, :] - self.cost_matrix
File "jax/_src/numpy/lax_numpy.py", line 6589, in deferring_binary_op
return binary_op(self, other)
File "bug_demo/ott_bug.runfiles/google3/third_party/py/jax/_src/numpy/lax_numpy.py", line 679, in
fn = lambda x1, x2: lax_fn(*_promote_args(numpy_fn.name, x1, x2))
TypeError: sub got incompatible shapes for broadcasting: (0, 0), (2048, 120).
If, instead, I use jnp.zeros above, everything works as expected.
Hi! Many thanks for this package and its updates.
In the case where the cost for the transportation of mass is the squared Euclidean distance, one can compute the gradient of the regularised OT cost w.r.t. the positions of mass using the expression (9.7) in https://arxiv.org/pdf/1803.00567.pdf.
I noticed that the same gradient, when computed with jax.grad (with implicit differentiation or backpropagation, log-stabilized or not), returned very different values.
I am sharing a Colab notebook demonstrating this. The code for the closed-form computation of the gradients is taken from here, and the code for the computation with jax.grad on the Sinkhorn output is taken from this test. It seems that the gradients returned by jax.grad have significantly smaller values.
Were you expecting this difference?
Thank you in advance for the response.
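For reference, here is a minimal NumPy sketch of the kind of closed-form gradient such a comparison relies on. It assumes the cost c(x, y) = ||x - y||^2 and a coupling P held fixed, in which case the gradient of <P, C(x)> with respect to x_i is 2 * sum_j P_ij (x_i - y_j). Whether this matches expression (9.7) term by term, including constant factors and cost convention, is an assumption. The function names are hypothetical, and the coupling used is the independent one rather than a Sinkhorn solution.

```python
import numpy as np


def transport_cost(x, y, plan):
    """<plan, C> for the squared Euclidean cost C_ij = ||x_i - y_j||^2."""
    diff = x[:, None, :] - y[None, :, :]
    return float(np.sum(plan * np.sum(diff ** 2, axis=-1)))


def closed_form_grad_x(x, y, plan):
    """Gradient of transport_cost w.r.t. x with the plan held fixed."""
    row_mass = plan.sum(axis=1, keepdims=True)  # sum_j P_ij for each row i
    return 2.0 * (row_mass * x - plan @ y)


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
y = rng.normal(size=(5, 3))
plan = np.full((4, 5), 1.0 / 20.0)  # independent coupling of uniform weights

grad = closed_form_grad_x(x, y, plan)

# Central finite differences as an independent check of the formula.
step = 1e-5
fd_grad = np.zeros_like(x)
for i in range(x.shape[0]):
    for k in range(x.shape[1]):
        shift = np.zeros_like(x)
        shift[i, k] = step
        fd_grad[i, k] = (transport_cost(x + shift, y, plan)
                         - transport_cost(x - shift, y, plan)) / (2 * step)
```

When comparing against automatic differentiation, a mismatch in the factor 2 or in the cost convention (squared versus unsquared distance) is one thing worth ruling out first.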
Code to reproduce:
from ott.geometry.geometry import Geometry
import jax.numpy as jnp
geom_e = Geometry(kernel_matrix=jnp.ones((10, 10)), epsilon=1e-2)
print(geom_e.cost_matrix) # ok
geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
print(geom.cost_matrix) # raises the error below
---------------------------------------------------------------------------
RecursionError Traceback (most recent call last)
/tmp/ipykernel_246820/705571593.py in <module>
2 print(geom_e.cost_matrix)
3 geom = Geometry(kernel_matrix=jnp.ones((10, 10)))
----> 4 print(geom.cost_matrix)
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
107 def cost_matrix(self):
108 if self._cost_matrix is None:
--> 109 return -self.epsilon * jnp.log(self._kernel_matrix)
110 return self._cost_matrix
111
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in epsilon(self)
130 @property
131 def epsilon(self):
--> 132 return self._epsilon.target
133
134 @property
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in _epsilon(self)
102 return self._epsilon_init
103 eps = 5e-2 if self._epsilon_init is None else self._epsilon_init
--> 104 return epsilon_scheduler.Epsilon.make(eps, scale=self.scale, **self._kwargs)
105
106 @property
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in scale(self)
92 if (self._scale is None) and (trigger is not None): # for dry run
93 return jnp.where(
---> 94 trigger, jax.lax.stop_gradient(self.mean_cost_matrix), 1.0)
95 else:
96 return self._scale
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in mean_cost_matrix(self)
116 @property
117 def mean_cost_matrix(self):
--> 118 if isinstance(self.shape[0], int) and (self.shape[0] > 0):
119 return jnp.sum(self.apply_cost(jnp.ones((self.shape[0],)))) / (
120 self.shape[0] * self.shape[1])
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in shape(self)
134 @property
135 def shape(self):
--> 136 mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
137 if mat is not None:
138 return mat.shape
... last 6 frames repeated, from the frame below ...
~/.miniconda3/envs/cellrank/lib/python3.8/site-packages/ott/geometry/geometry.py in cost_matrix(self)
107 def cost_matrix(self):
108 if self._cost_matrix is None:
--> 109 return -self.epsilon * jnp.log(self._kernel_matrix)
110 return self._cost_matrix
111
RecursionError: maximum recursion depth exceeded while calling a Python object
Version: 0.1.17
Hi,
I'm calculating Wasserstein distance using Sinkhorn approximation between two observations (no matching, just pure L2 distance). I see that the distance is way higher compared to the distance calculated by other methods (OT library, L2 norm). In particular, I have embeddings of two 18-bit observations that vary by one bit. Thus, the distance between them should be low.
The results are as follows:
L2 distance: 3.464944 (calculation: np.sqrt(np.sum(np.square(obs_standard-obs_terminal))))
OTT Sinkhorn distance: 12.005859
I would like to know where the difference is coming from.
Thanks!
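A possible explanation, offered as an assumption rather than a confirmed diagnosis: if the solver uses the squared Euclidean distance as ground cost, then transporting a single point onto another costs ||x - y||^2, not ||x - y||. The short check below squares the reported L2 value and compares it with the reported Sinkhorn value.

```python
import math

l2_distance = 3.464944        # L2 value reported above
sinkhorn_value = 12.005859    # Sinkhorn value reported above

# With a squared Euclidean cost, moving one point onto another costs ||x - y||^2.
squared_l2 = l2_distance ** 2
relative_gap = abs(squared_l2 - sinkhorn_value) / sinkhorn_value

print(f"squared L2 distance: {squared_l2:.6f}")
print(f"relative gap to the reported Sinkhorn value: {relative_gap:.2e}")
```

If the two numbers agree closely, taking the square root of the Sinkhorn value (or using an unsquared cost) would make the two quantities comparable.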
Hi,
we noticed some inconsistencies in the API for solvers (Regular and Low Rank) between Linear and Quadratic problems and thought it may be useful to report them. These may well be the result of a misunderstanding on our side, in which case we would be grateful for your clarifications.
Issues:
(1) epsilon: for a Linear problem one must provide it through the Geometry object, whereas for a Quadratic problem it is passed to the solver.
(2) Low Rank solver: two separate solvers in the Linear case, a single solver in the Quadratic case.
We believe the code snippet below depicts the above in a straightforward manner.
Thank you in advance! :)
if alpha == 0:  # Linear problem
    ot_prob = ott.core.problems.LinearProblem(M)
    if epsilon:  # regular (epsilon passed through Geometry M)
        solver = ott.core.sinkhorn.Sinkhorn()
    else:  # Low rank (rank passed)
        solver = ott.core.sinkhorn_lr.LRSinkhorn(rank=rank)
else:  # Quadratic problem
    if alpha == 1:  # GW
        ot_prob = ott.core.quad_problems.QuadraticProblem(geom_xx=C1,
                                                          geom_yy=C2,
                                                          fused_penalty=0)
    else:  # FGW
        ot_prob = ott.core.quad_problems.QuadraticProblem(geom_xx=C1,
                                                          geom_yy=C2,
                                                          geom_xy=M,
                                                          fused_penalty=(1 - alpha) / alpha)
    # a one-liner passing both epsilon and rank (whichever is not None implies usage)
    solver = ott.core.gromov_wasserstein.GromovWasserstein(epsilon=epsilon, rank=rank)
ot_gw = solver(ot_prob)
This issue describes a problem I encounter when trying to compute the regularized Wasserstein distance between a measure and itself with the parallel_dual_updates=True feature of the sinkhorn.sinkhorn function of ott.core.
The code to reproduce my experiment is:
import jax
import jax.numpy as jnp
from ott.core import sinkhorn
from ott.geometry import pointcloud

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, 3)
n = 100  # number of points in the support of the input and output measure
d = 2  # dimension of the points
x = jax.random.normal(rngs[0], (n, d))  # supports of the measures
unif = jnp.ones(len(x)) / len(x)  # uniform weights
# define the geometry object
geom = pointcloud.PointCloud(x, x, epsilon=1e-2)
# run the Sinkhorn algorithm for various maximal numbers of iterations and parallel_dual_updates=True
max_iters = [100, 1000, 10000, 50000, 100000, 250000, 500000]
for max_iter in max_iters:
    # np.int is deprecated and removed in recent NumPy, so use the builtin int
    out = sinkhorn.sinkhorn(geom=geom, a=unif, b=unif, threshold=1e-1,
                            max_iterations=int(max_iter), parallel_dual_updates=True)
    print(f"max_iter={int(max_iter)}\nW_eps(alpha, alpha) = {out.reg_ot_cost}")
    print(f"Has Sinkhorn converged? {out.converged}\n")
In this experiment, we start by generating 100 points in R^2 and we consider α, the uniform measure on these 100 points. We then want to calculate the regularized 2-Wasserstein distance W_ε between α and α, i.e. W_ε(α, α).
By default, OTT uses the KL divergence between the transport plan and the product of the input and output measures to regularize the Kantorovich problem. In our case, the regularization is of the form KL(P | α x α). As this quantity is always positive and the scalar product between the cost and the transport plan is always positive because we use the Euclidean cost, the solution of this regularized Kantorovich problem is positive, as the (attained) minimum of a positive function. Thus, we have W_ε(α, α) > 0, and when we compute W_ε(α, α) from out = sinkhorn.sinkhorn(...) of ott.core, we should have out.reg_ot_cost > 0.
parallel_dual_updates=False
When we perform this calculation with parallel_dual_updates=False in out = sinkhorn.sinkhorn(...), we get a positive result and we converge (i.e. out.converged=True) for a reasonable threshold and max_iterations.
parallel_dual_updates=True
On the other hand, when we perform this calculation with parallel_dual_updates=True in out = sinkhorn.sinkhorn(...), we get a negative result and we do not converge (i.e. out.converged=False), even for a large threshold (1e-1 in the code above) and a large max_iterations. The lack of convergence is consistent with the returned value being incorrect. More precisely, when varying max_iterations, the result oscillates slightly around the same negative value. This behavior can be reproduced by executing the above code.
As described by [Feydy et al., 2019], when we compute the regularized Wasserstein distance between a measure and itself, the regularized Kantorovich dual problem becomes a concave maximization problem that is symmetric with respect to its two variables f and g. Hence, there exists a (unique) optimal dual pair (f, g = f) on the diagonal whose optimality (in the discrete setting) is expressed by a well-conditioned fixed-point equation. In the same article, it is said that one can typically achieve convergence in about 3 iterations. The parallel_dual_updates=True feature seems to handle the problem by solving this fixed-point equation on f only, rather than the classical Sinkhorn fixed-point equations in f and g. The bug seems to come from solving this equation.
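To make the symmetric fixed-point equation concrete, here is a minimal NumPy sketch of the averaged iteration f <- (f + T(f)) / 2 for OT(α, α). It is not OTT's implementation of parallel_dual_updates. It assumes a squared Euclidean cost, uniform weights, and regularization by KL(P | α x α), and all function names are hypothetical.

```python
import numpy as np


def softmin_update(f, cost, weights, eps):
    """T(f)_i = -eps * log(sum_j weights_j * exp((f_j - cost_ij) / eps))."""
    z = (f[None, :] - cost) / eps + np.log(weights)[None, :]
    z_max = z.max(axis=1, keepdims=True)  # shift for a stable log-sum-exp
    return -eps * (z_max[:, 0] + np.log(np.exp(z - z_max).sum(axis=1)))


def symmetric_sinkhorn(cost, weights, eps, num_iters=200):
    """Averaged fixed-point iteration f <- (f + T(f)) / 2 for OT(alpha, alpha)."""
    f = np.zeros(cost.shape[0])
    for _ in range(num_iters):
        f = 0.5 * (f + softmin_update(f, cost, weights, eps))
    return f


rng = np.random.default_rng(0)
x = rng.normal(size=(100, 2))                                  # support points
cost = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)   # squared distances
alpha = np.full(100, 1.0 / 100)                                # uniform weights
eps = 1e-2

f = symmetric_sinkhorn(cost, alpha, eps)
residual = np.max(np.abs(f - softmin_update(f, cost, alpha, eps)))
reg_ot_cost = 2.0 * float(alpha @ f)  # dual value at a symmetric fixed point
```

At a fixed point the marginal constraint holds exactly, so the dual objective reduces to 2 <α, f>. Under the assumptions above, this gives a value to compare against out.reg_ot_cost.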
I hope this remark is helpful, and I thank you very much for the development of OTT, which really facilitates the use of numerical optimal transport in ML pipelines.
Hi,
Please excuse me in advance if there is a basic misuse in the attached code/setting.
In this drive you can find the following folders:
[1] code: two code files (differing in the cost object type: in gwlr.py the costs are pointclouds and in gwlr_geom.py we use a general Geometry) which reproduce the OOM errors we encounter when running on GPUs.
** We noticed that by changing alpha to 0.75 we are able to avoid OOM using low rank; however, while the Geometry run returned quickly, the pointcloud calculation took >2 hrs.
[2] logs: 4 log files documenting the errors in regular (quad) and low rank (lr) runs using pointcloud costs and Geometry (geom) objects.
[3] data: the data to run the code.
We realize it may be that the data is too large, but we would be grateful for your input.
(I am still trying to set up a different memory allocation on our GPUs that may allow this.)
Thank you in advance :)
import jax
import jax.numpy as jnp
from ott.core import sinkhorn
from ott.geometry import pointcloud

rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)
# parameters of the measures
dim = 5
n = 100
m = 150
# support points of the measures
x = jax.random.uniform(keys[0], (n, dim))
y = jax.random.uniform(keys[1], (m, dim))
# weights of the measures
a = jax.random.uniform(keys[0], (len(x),))
b = jax.random.uniform(keys[1], (len(y),))
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)
# regularization strength candidates
eps_cand = [10**(-i) for i in range(10)]
for eps in eps_cand:
    # define the geometry
    geom = pointcloud.PointCloud(x, y, epsilon=eps)
    # run the Sinkhorn algorithm (lse_mode=True is already the default; set explicitly for emphasis)
    out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=True)
    print(f'epsilon = {eps}: regularised optimal transport cost = {out.reg_ot_cost}')
    if jnp.isnan(out.reg_ot_cost):
        break
# regularization strength candidates
eps_cand = [10**(-i) for i in range(10)]
for eps in eps_cand:
    # define the geometry
    geom = pointcloud.PointCloud(x, y, epsilon=eps)
    # run the Sinkhorn algorithm without the log-sum-exp mode
    out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=False)
    print(f'epsilon = {eps}: regularised optimal transport cost = {out.reg_ot_cost}')
    if jnp.isnan(out.reg_ot_cost):
        break
Using the logsumexp mode (lse_mode=True), I get overflow (i.e. nan) from epsilon = 1e-3. This is quite strange because, as the support points are randomly drawn in U([0,1]^5), the squared Euclidean distance between a support point of the first measure and one of the second is at most 5. So we have ||C||_inf / eps <= 5e3, which is reasonable, and the logsumexp should not generate any overflow.
Moreover, when not using the logsumexp mode (lse_mode=False), I don't get any overflow (down to epsilon = 1e-9). This is strange since logsumexp is supposed to be more stable than the version using matrix-vector products against the Gibbs kernel. On the other hand, as epsilon decreases, the regularized Wasserstein distance tends to 0 (9.99e-10 for epsilon = 1e-9). But when epsilon becomes very small, the regularized Wasserstein distance tends towards the Wasserstein distance. This is therefore strange because there is no reason a priori for the Wasserstein distance between the two measures to be zero. Indeed, even if the support points of each measure are drawn according to the same law (U([0,1]^5)), the weights are not uniform: they are also drawn randomly and then normalized. The regularized Wasserstein distance that we compute is therefore not a priori an estimator of the regularized Wasserstein distance between two measures following the same law (U([0,1]^5)), which is the case in which it would make sense for it to tend towards 0.
I hope to have helped you with this remark, and I thank you very much for the development of OTT which really facilitates the use of numerical optimal transport, especially for the differentiation of OT metrics.
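One way to see how a kernel-based computation can return a value near zero without ever producing nan is sketched below. It is an illustration, not a diagnosis of OTT's internals: for very small epsilon the Gibbs kernel entries exp(-c / epsilon) underflow to exactly 0, whereas a log-sum-exp with the maximum subtracted stays finite. The cost values are made up for illustration and the helper name is hypothetical.

```python
import math


def stable_logsumexp(values):
    """log(sum(exp(v))) computed after subtracting the maximum."""
    v_max = max(values)
    return v_max + math.log(sum(math.exp(v - v_max) for v in values))


costs = [0.5, 1.0, 2.0]   # illustrative cost entries, not taken from the issue
eps = 1e-9

# Kernel mode: every Gibbs kernel entry underflows to exactly 0.0.
kernel_entries = [math.exp(-c / eps) for c in costs]

# Log mode: the same quantity stays finite and is dominated by the smallest cost.
log_kernel_sum = stable_logsumexp([-c / eps for c in costs])
soft_min = -eps * log_kernel_sum
```

Here soft_min recovers the smallest cost, while any product against kernel_entries is identically zero, which would be consistent with a cost that silently collapses towards 0 in kernel mode.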
Hello,
Given JAX's concept of composable function transformations, it would be great if the grad function of Sinkhorn could also be differentiated with reverse mode.
However, given the fact that the current implementation of the Sinkhorn loop uses jax.lax.while_loop (which is only fwd-mode-differentiable), this is not possible.
It would really be helpful if a solution could be offered for that, perhaps with jax.lax.scan and along the lines of this?
Any solution would be strongly appreciated, as currently it is not possible to jit the grad of Sinkhorn and do automatic differentiation. The only solution, to my understanding, is to leave it unjitted and run in op-by-op mode.
Thanks in advance.