Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions .github/labeler.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ CI:

ot.bregman:
- changed-files:
- any-glob-to-any-file: ot/bergman/**
- any-glob-to-any-file: ot/bregman/**

ot.gromov:
- changed-files:
Expand All @@ -32,6 +32,10 @@ ot.lp:
- changed-files:
- any-glob-to-any-file: ot/lp/**

ot.bsp:
- changed-files:
- any-glob-to-any-file: ot/bsp/**

ot.utils:
- changed-files:
- any-glob-to-any-file: ot/utils.py
Expand Down Expand Up @@ -68,17 +72,21 @@ ot.lowrank:
- changed-files:
- any-glob-to-any-file: ot/lowrank.py

ot.sliced:
- changed-files:
- any-glob-to-any-file: ot/sliced/**

ot.solvers:
- changed-files:
- any-glob-to-any-file: ot/solvers.py
- any-glob-to-any-file: ot/solvers/**

ot.partial:
ot.batch:
- changed-files:
- any-glob-to-any-file: ot/partial/**
- any-glob-to-any-file: ot/batch/**

ot.sliced:
ot.partial:
- changed-files:
- any-glob-to-any-file: ot/sliced.py
- any-glob-to-any-file: ot/partial/**

ot.smooth:
- changed-files:
Expand All @@ -95,3 +103,7 @@ ot.dr:
ot.gnn:
- changed-files:
- any-glob-to-any-file: ot/gnn/**

ot.sgot:
- changed-files:
- any-glob-to-any-file: ot/sgot.py
3 changes: 2 additions & 1 deletion RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Wrapper for barycenter solvers with free support `ot.solvers.bary_free_support` (PR #730)
- Build wheels on ubuntu ARM to avoid QEMU emulation (PR #818)
- Add new methods to compute the linear transport map and the related 2-Wasserstein distance between high-dimensional (HD) Gaussian distributions as described in [88], implemented in `ot.gaussian.bures_wasserstein_mapping_hd` and `ot.gaussian.bures_wasserstein_distance_hd`, respectively. Two additional methods estimate the same quantities from the source and destination observed data and are implemented in `ot.gaussian.empirical_bures_wasserstein_mapping_hd` and `ot.gaussian.empirical_bures_wasserstein_distance_hd`, respectively (PR #814)
- Fix docstrings for `lowrank_gromov_wasserstein_samples` and `lowrank_sinkhorn` (PR #823)
- Update the geomloss wrapper to the new version and API (PR #826)


#### Closed issues
Expand All @@ -56,6 +56,7 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
- Fix entropic regularization in `gcg` (PR #817, Issue #758)
- Fix documentation build on master with submodules (PR #818)
- Fix failing test for unbalanced solver with generic regularization (PR #824)
- Fix docstrings for `lowrank_gromov_wasserstein_samples` and `lowrank_sinkhorn` (PR #823)


## 0.9.6.post1
Expand Down
3 changes: 2 additions & 1 deletion ot/bregman/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@

from ._dictionary import unmix

from ._geomloss import empirical_sinkhorn2_geomloss, geomloss
from ._geomloss import empirical_sinkhorn2_geomloss, geomloss, old_geomloss


__all__ = [
Expand Down Expand Up @@ -76,6 +76,7 @@
"empirical_sinkhorn_nystroem",
"empirical_sinkhorn_nystroem2",
"geomloss",
"old_geomloss",
"screenkhorn",
"unmix",
]
124 changes: 69 additions & 55 deletions ot/bregman/_geomloss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,15 @@
import torch
from torch.autograd import grad
from ..utils import get_backend, LazyTensor, dist

if geomloss.__version__ < "0.3.1":
old_geomloss = True
else:
old_geomloss = False

except ImportError:
geomloss = False
old_geomloss = False


def get_sinkhorn_geomloss_lazytensor(
Expand Down Expand Up @@ -127,6 +134,12 @@ def empirical_sinkhorn2_geomloss(
better stability and epsilon-scaling. The solution is computed in a lazy way
using the Geomloss [60]_ and the KeOps library [61]_.

.. warning::
The Geomloss library is required for this function to work. Also
when setting `log=True`, the dual potentials are computed using autograd and
may be slow for large problems and prevent computing backward gradients. Use
the function :func:`ot.solve_sample` with `method='geomloss'` for better performance and gradient computation.

Parameters
----------
X_s : array-like, shape (n_samples_a, dim)
Expand Down Expand Up @@ -178,63 +191,66 @@ def empirical_sinkhorn2_geomloss(

"""

if geomloss:
nx = get_backend(X_s, X_t, a, b)
if not geomloss:
raise ImportError("geomloss not installed")

if nx.__name__ not in ["torch", "numpy"]:
raise ValueError("geomloss only support torch or numpy backend")
nx = get_backend(X_s, X_t, a, b)

if a is None:
a = nx.ones(X_s.shape[0], type_as=X_s) / X_s.shape[0]
if b is None:
b = nx.ones(X_t.shape[0], type_as=X_t) / X_t.shape[0]
if nx.__name__ not in ["torch", "numpy"]:
raise ValueError("geomloss only support torch or numpy backend")

if nx.__name__ == "numpy":
X_s_torch = torch.tensor(X_s)
X_t_torch = torch.tensor(X_t)
if a is None:
a = nx.ones(X_s.shape[0], type_as=X_s) / X_s.shape[0]
if b is None:
b = nx.ones(X_t.shape[0], type_as=X_t) / X_t.shape[0]

a_torch = torch.tensor(a)
b_torch = torch.tensor(b)
if nx.__name__ == "numpy":
X_s_torch = torch.tensor(X_s)
X_t_torch = torch.tensor(X_t)

else:
X_s_torch = X_s
X_t_torch = X_t
a_torch = torch.tensor(a)
b_torch = torch.tensor(b)

a_torch = a
b_torch = b
else:
X_s_torch = X_s
X_t_torch = X_t

# after that we are all in torch
a_torch = a
b_torch = b

# set blur value and p
if metric == "sqeuclidean":
p = 2
blur = np.sqrt(reg / 2) # because geomloss divides cost by two
elif metric == "euclidean":
p = 1
blur = np.sqrt(reg)
else:
raise ValueError("geomloss only supports sqeuclidean and euclidean metrics")
# after that we are all in torch

# set blur value and p
if metric == "sqeuclidean":
p = 2
blur = np.sqrt(reg / 2) # because geomloss divides cost by two
elif metric == "euclidean":
p = 1
blur = np.sqrt(reg)
else:
raise ValueError("geomloss only supports sqeuclidean and euclidean metrics")

if log:
# force gradients for computing dual
a_torch.requires_grad = True
b_torch.requires_grad = True

loss = SamplesLoss(
loss="sinkhorn",
p=p,
blur=blur,
backend=backend,
debias=debias,
scaling=scaling,
verbose=verbose,
)
loss = SamplesLoss(
loss="sinkhorn",
p=p,
blur=blur,
backend=backend,
debias=debias,
scaling=scaling,
verbose=verbose,
)

# compute value
value = loss(
a_torch, X_s_torch, b_torch, X_t_torch
) # linear + entropic/KL reg?
# compute value
value = loss(a_torch, X_s_torch, b_torch, X_t_torch) # linear + entropic/KL reg?

# get dual potentials
# get dual potentials

if log: # recover dual potentials
f, g = grad(value, [a_torch, b_torch])

if metric == "sqeuclidean":
Expand All @@ -245,20 +261,18 @@ def empirical_sinkhorn2_geomloss(
g = g.cpu().detach().numpy()
value = value.cpu().detach().numpy()

if log:
log = {}
log["f"] = f
log["g"] = g
log["value"] = value

log["lazy_plan"] = get_sinkhorn_geomloss_lazytensor(
X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx
)
log = {}
log["f"] = f
log["g"] = g
log["value"] = value

return value, log
log["lazy_plan"] = get_sinkhorn_geomloss_lazytensor(
X_s, X_t, f, g, a, b, metric=metric, blur=blur, nx=nx
)

else:
return value
return value, log

else:
raise ImportError("geomloss not installed")
if nx.__name__ == "numpy":
value = value.cpu().detach().numpy()
return value
52 changes: 30 additions & 22 deletions ot/solvers/_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
from ..bregman import (
sinkhorn_log,
empirical_sinkhorn2,
empirical_sinkhorn2_geomloss,
empirical_sinkhorn_nystroem2,
old_geomloss,
)
from ..smooth import smooth_ot_dual
from ..gaussian import empirical_bures_wasserstein_distance
from ..factored import factored_optimal_transport
from ..lowrank import lowrank_sinkhorn
from ..optim import cg
from warnings import warn


lst_method_lazy = [
Expand Down Expand Up @@ -1175,29 +1176,36 @@ def solve_sample(
backend = "online"
else:
backend = "tensorized"
if lazy0 is None:
warn(
f"geomloss backend is set to '{backend}' but is not yet supported by unified geomloss API yet."
)

value, log = empirical_sinkhorn2_geomloss(
X_a,
X_b,
reg=reg,
a=a,
b=b,
metric=metric,
log=True,
verbose=verbose,
scaling=scaling,
backend=backend,
)
if old_geomloss: # old wrapper for old geomloss versions
raise NotImplementedError(
"geomloss version >= 0.3.1 required for ot.solve_sample() geomloss backend."
)

lazy_plan = log["lazy_plan"]
if not lazy0: # store plan if not lazy
plan = lazy_plan[:]
else: # new geomloss wrapper
import geomloss.ot as got

# return scaled potentials (to be consistent with other solvers)
potentials = (
log["f"] / (lazy_plan.blur**2),
log["g"] / (lazy_plan.blur**2),
)
if max_iter is None:
max_iter = 1000

res = got.solve_sample(
X_a,
X_b,
a=a,
b=b,
reg=reg,
cost=metric,
unbalanced=unbalanced,
unbalanced_type=unbalanced_type,
max_iter=max_iter,
tol=tol,
)
res.value_linear = None
return res

elif reg is None or reg == 0: # exact OT
if unbalanced is None: # balanced EMD solver not available for lazy
Expand All @@ -1210,7 +1218,7 @@ def solve_sample(
else:
raise (
NotImplementedError(
'Non regularized solver with unbalanced_type="{}" not implemented'.format(
'Non regularized lazy solver with unbalanced_type="{}" not implemented'.format(
unbalanced_type
)
)
Expand Down
2 changes: 1 addition & 1 deletion requirements_all.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ tensorflow; python_version < '3.14'
pytest
torch_geometric
cvxpy
geomloss
geomloss>=0.3.1
pykeops
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@
"dr": ["scikit-learn", "pymanopt", "autograd"],
"gnn": ["torch", "torch_geometric"],
"plot": ["matplotlib"],
"geomloss": ["geomloss"],
"doc": [
"sphinx",
"sphinx-rtd-theme",
Expand All @@ -151,6 +152,7 @@
"autograd",
"torch_geometric",
"matplotlib",
"geomloss",
],
},
python_requires=">=3.7",
Expand Down
Loading
Loading