Skip to content

Commit 838ef22

Browse files
Circle CICircle CI
authored andcommitted
CircleCI update of dev docs (3888).
1 parent b775c44 commit 838ef22

309 files changed

Lines changed: 92239 additions & 90666 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.
Binary file not shown.
Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
# -*- coding: utf-8 -*-
2+
r"""
3+
===============================================================================
4+
Solve Fused Unbalanced Gromov Wasserstein with Adam
5+
===============================================================================
6+
7+
Since the FUGW loss is differentiable, it can be minimized with first-order optimization.
8+
We show how to do this with the `loss_fugw_batch` function and compare the results with
9+
the dedicated FUGW solver `fused_unbalanced_gromov_wasserstein`.
10+
"""
11+
12+
# Author: Rémi Flamary <remi.flamary@polytechnique.edu>
13+
# Sonia Mazelet <sonia.mazelet@polytechnique.edu>
14+
#
15+
# License: MIT License
16+
17+
# sphinx_gallery_thumbnail_number = 3
18+
19+
import numpy as np
20+
import matplotlib.pylab as pl
21+
import torch
22+
from time import perf_counter
23+
import ot
24+
from ot.batch._quadratic import loss_quadratic_batch, tensor_batch
25+
from ot.gromov import fused_unbalanced_gromov_wasserstein
26+
from sklearn.manifold import MDS
27+
28+
29+
# %%
30+
# Generation of source and target graphs
31+
# ----------------
32+
33+
rng = np.random.RandomState(42)
34+
35+
36+
def get_sbm(n, nc, ratio, P):
37+
nbpc = np.round(n * ratio).astype(int)
38+
n = np.sum(nbpc)
39+
C = np.zeros((n, n))
40+
for c1 in range(nc):
41+
for c2 in range(c1 + 1):
42+
if c1 == c2:
43+
for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):
44+
for j in range(np.sum(nbpc[:c2]), i):
45+
if rng.rand() <= P[c1, c2]:
46+
C[i, j] = 1
47+
else:
48+
for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):
49+
for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[: c2 + 1])):
50+
if rng.rand() <= P[c1, c2]:
51+
C[i, j] = 1
52+
53+
return C + C.T
54+
55+
56+
def plot_graph(x, C, color="C0", s=100):
57+
for j in range(C.shape[0]):
58+
for i in range(j):
59+
if C[i, j] > 0:
60+
pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color="k")
61+
pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors="k")
62+
63+
64+
def get_sbm_labels(n, ratio):
65+
nbpc = np.round(n * ratio).astype(int)
66+
return np.concatenate(
67+
[np.full(count, label, dtype=int) for label, count in enumerate(nbpc)]
68+
)
69+
70+
71+
def get_noisy_one_hot(labels, n_classes, noise_level=0.1):
72+
x = np.eye(n_classes)[labels]
73+
x += noise_level * rng.randn(*x.shape)
74+
return x
75+
76+
77+
n1 = 15
78+
n2 = 10
79+
nc1 = 3
80+
nc2 = 2
81+
ratio1 = np.array([0.33, 0.33, 0.33])
82+
ratio2 = np.array([0.5, 0.5])
83+
84+
P1 = np.array([[0.8, 0.03, 0.0], [0.08, 0.8, 0.03], [0.0, 0.08, 0.8]])
85+
P2 = np.array(0.8 * np.eye(2) + 0.01 * np.ones((2, 2)))
86+
C1 = get_sbm(n1, nc1, ratio1, P1)
87+
C2 = get_sbm(n2, nc2, ratio2, P2)
88+
labels1 = get_sbm_labels(n1, ratio1)
89+
labels2 = get_sbm_labels(n2, ratio2)
90+
91+
# Use noisy one-hot encodings of the SBM classes as node features.
92+
feature_dim = max(nc1, nc2)
93+
x1 = get_noisy_one_hot(labels1, feature_dim)
94+
x2 = get_noisy_one_hot(labels2, feature_dim)
95+
all_features = np.vstack([x1, x2])
96+
feature_min = all_features[:, :3].min(axis=0, keepdims=True)
97+
feature_max = all_features[:, :3].max(axis=0, keepdims=True)
98+
99+
# get 2d positions for visualization
100+
pos1 = MDS(dissimilarity="precomputed", random_state=0, n_init=1).fit_transform(1 - C1)
101+
pos2 = MDS(dissimilarity="precomputed", random_state=0, n_init=1).fit_transform(1 - C2)
102+
103+
colors1 = np.clip(
104+
(x1 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0
105+
)
106+
colors2 = np.clip(
107+
(x2 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0
108+
)
109+
110+
111+
pl.figure(1, (10, 5))
112+
pl.clf()
113+
pl.subplot(1, 2, 1)
114+
plot_graph(pos1, C1, color=colors1)
115+
pl.title("SBM source graph")
116+
pl.axis("off")
117+
pl.subplot(1, 2, 2)
118+
plot_graph(pos2, C2, color=colors2)
119+
pl.title("SBM target graph")
120+
_ = pl.axis("off")
121+
122+
123+
# %%
124+
# Solve FUGW with Adam
125+
# ----------------
126+
127+
# Even though `loss_fugw_batch` supports batches of problems, we use a
128+
# batch of size 1 here for clarity.
129+
130+
a = ot.unif(C1.shape[0])
131+
b = ot.unif(C2.shape[0])
132+
M = ot.dist(x1, x2)
133+
M /= M.max()
134+
135+
a_torch = torch.tensor(a[None, :])
136+
b_torch = torch.tensor(b[None, :])
137+
C1_torch = torch.tensor(C1[None, :, :])
138+
C2_torch = torch.tensor(C2[None, :, :])
139+
M_torch = torch.tensor(M[None, :, :])
140+
L = tensor_batch(a_torch, b_torch, C1_torch, C2_torch, loss="sqeuclidean")
141+
142+
alpha = 0.5
143+
reg_marginals = 0.5
144+
lr = 5e-2
145+
nb_iter_max = 1500
146+
tol = 1e-7
147+
148+
T0_torch = a_torch[:, :, None] * b_torch[:, None, :]
149+
T_torch = torch.log(torch.expm1(T0_torch)).clone().requires_grad_(True)
150+
optimizer = torch.optim.Adam([T_torch], lr=lr)
151+
loss_iter = []
152+
mass_iter = []
153+
previous_plan_torch = None
154+
155+
tic = perf_counter()
156+
for i in range(nb_iter_max):
157+
optimizer.zero_grad()
158+
# Positive transport plan parameterized as log(1 + exp(T)).
159+
plan_torch = torch.nn.functional.softplus(T_torch)
160+
loss = loss_quadratic_batch(
161+
a_torch,
162+
b_torch,
163+
C1_torch,
164+
C2_torch,
165+
plan_torch,
166+
M_torch,
167+
alpha=alpha,
168+
unbalanced=reg_marginals,
169+
unbalanced_type="kl",
170+
recompute_const=True,
171+
)[0]
172+
173+
loss_iter.append(float(loss.detach()))
174+
mass_iter.append(float(plan_torch.detach().sum()))
175+
if previous_plan_torch is not None:
176+
err = float(torch.sum(torch.abs(plan_torch.detach() - previous_plan_torch)))
177+
if err < tol:
178+
break
179+
previous_plan_torch = plan_torch.detach().clone()
180+
loss.backward()
181+
optimizer.step()
182+
time_adam = perf_counter() - tic
183+
184+
T_adam = torch.nn.functional.softplus(T_torch).detach().cpu().numpy()[0]
185+
186+
187+
# %%
188+
# Compare with the dedicated FUGW solver
189+
# -------------------------------------
190+
#
191+
# The dedicated solver uses a block coordinate descent (BCD) scheme. We compare
192+
# the coupling it returns with the one obtained by direct Adam minimization of
193+
# `loss_fugw_batch`.
194+
195+
196+
def evaluate_batch_fugw_loss(plan):
197+
plan_torch = torch.tensor(plan[None, :, :], dtype=M_torch.dtype)
198+
loss = loss_quadratic_batch(
199+
a_torch,
200+
b_torch,
201+
C1_torch,
202+
C2_torch,
203+
plan_torch,
204+
M_torch,
205+
alpha=alpha,
206+
unbalanced=reg_marginals,
207+
unbalanced_type="kl",
208+
recompute_const=True,
209+
)[0]
210+
return float(loss.detach())
211+
212+
213+
tic = perf_counter()
214+
result = ot.solve_gromov(
215+
C1, C2, M, a, b, alpha=alpha, reg=0, unbalanced_type="kl", unbalanced=reg_marginals
216+
)
217+
time_bcd = perf_counter() - tic
218+
219+
loss_adam_final = evaluate_batch_fugw_loss(T_adam)
220+
T_bcd = result.plan
221+
loss_bcd_final = evaluate_batch_fugw_loss(T_bcd)
222+
mass_bcd = T_bcd.sum()
223+
224+
pl.figure(2, (10, 4))
225+
pl.clf()
226+
pl.subplot(1, 2, 1)
227+
pl.plot(loss_iter, label="Adam")
228+
pl.axhline(loss_bcd_final, color="C1", linestyle="--", label="BCD solver")
229+
pl.grid()
230+
pl.title("FUGW loss along iterations")
231+
pl.xlabel("Iterations")
232+
pl.legend()
233+
pl.subplot(1, 2, 2)
234+
pl.plot(mass_iter, label="Adam")
235+
pl.axhline(mass_bcd, color="C1", linestyle="--", label="BCD solver")
236+
pl.grid()
237+
pl.title("Transport mass")
238+
pl.xlabel("Iterations")
239+
_ = pl.legend()
240+
241+
242+
# %%
243+
# Visualize the learned couplings
244+
# -------------------------------
245+
# We visualize the couplings obtained by both methods to compare them. On this example, both methods recover similar couplings,
246+
# but direct minimization reaches a lower `loss_fugw_batch` value at the cost
247+
# of a longer runtime.
248+
249+
vmin = min(T_adam.min(), T_bcd.min())
250+
vmax = max(T_adam.max(), T_bcd.max())
251+
pl.figure(3, (10, 4))
252+
pl.clf()
253+
pl.subplot(1, 2, 1)
254+
pl.imshow(T_adam, interpolation="nearest", cmap="Blues", vmin=vmin, vmax=vmax)
255+
pl.title(
256+
f"Coupling from direct minimization\nloss={loss_adam_final:.3f}, time={time_adam:.2f}s"
257+
)
258+
pl.xlabel("Target nodes")
259+
pl.ylabel("Source nodes")
260+
pl.colorbar()
261+
pl.subplot(1, 2, 2)
262+
pl.imshow(T_bcd, interpolation="nearest", cmap="Blues", vmin=vmin, vmax=vmax)
263+
pl.title(f"Coupling from BCD solver\nloss={loss_bcd_final:.3f}, time={time_bcd:.2f}s")
264+
pl.xlabel("Target nodes")
265+
pl.ylabel("Source nodes")
266+
_ = pl.colorbar()
Binary file not shown.
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"\n# Solve Fused Unbalanced Gromov Wasserstein with Adam\n\nSince the FUGW loss is differentiable, it can be minimized with first-order optimization.\nWe show how to do this with the `loss_fugw_batch` function and compare the results with\nthe dedicated FUGW solver `fused_unbalanced_gromov_wasserstein`.\n"
8+
]
9+
},
10+
{
11+
"cell_type": "code",
12+
"execution_count": null,
13+
"metadata": {
14+
"collapsed": false
15+
},
16+
"outputs": [],
17+
"source": [
18+
"# Author: R\u00e9mi Flamary <remi.flamary@polytechnique.edu>\n# Sonia Mazelet <sonia.mazelet@polytechnique.edu>\n#\n# License: MIT License\n\n# sphinx_gallery_thumbnail_number = 3\n\nimport numpy as np\nimport matplotlib.pylab as pl\nimport torch\nfrom time import perf_counter\nimport ot\nfrom ot.batch._quadratic import loss_quadratic_batch, tensor_batch\nfrom ot.gromov import fused_unbalanced_gromov_wasserstein\nfrom sklearn.manifold import MDS"
19+
]
20+
},
21+
{
22+
"cell_type": "markdown",
23+
"metadata": {},
24+
"source": [
25+
"## Generation of source and target graphs\n\n"
26+
]
27+
},
28+
{
29+
"cell_type": "code",
30+
"execution_count": null,
31+
"metadata": {
32+
"collapsed": false
33+
},
34+
"outputs": [],
35+
"source": [
36+
"rng = np.random.RandomState(42)\n\n\ndef get_sbm(n, nc, ratio, P):\n nbpc = np.round(n * ratio).astype(int)\n n = np.sum(nbpc)\n C = np.zeros((n, n))\n for c1 in range(nc):\n for c2 in range(c1 + 1):\n if c1 == c2:\n for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):\n for j in range(np.sum(nbpc[:c2]), i):\n if rng.rand() <= P[c1, c2]:\n C[i, j] = 1\n else:\n for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[: c1 + 1])):\n for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[: c2 + 1])):\n if rng.rand() <= P[c1, c2]:\n C[i, j] = 1\n\n return C + C.T\n\n\ndef plot_graph(x, C, color=\"C0\", s=100):\n for j in range(C.shape[0]):\n for i in range(j):\n if C[i, j] > 0:\n pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color=\"k\")\n pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors=\"k\")\n\n\ndef get_sbm_labels(n, ratio):\n nbpc = np.round(n * ratio).astype(int)\n return np.concatenate(\n [np.full(count, label, dtype=int) for label, count in enumerate(nbpc)]\n )\n\n\ndef get_noisy_one_hot(labels, n_classes, noise_level=0.1):\n x = np.eye(n_classes)[labels]\n x += noise_level * rng.randn(*x.shape)\n return x\n\n\nn1 = 15\nn2 = 10\nnc1 = 3\nnc2 = 2\nratio1 = np.array([0.33, 0.33, 0.33])\nratio2 = np.array([0.5, 0.5])\n\nP1 = np.array([[0.8, 0.03, 0.0], [0.08, 0.8, 0.03], [0.0, 0.08, 0.8]])\nP2 = np.array(0.8 * np.eye(2) + 0.01 * np.ones((2, 2)))\nC1 = get_sbm(n1, nc1, ratio1, P1)\nC2 = get_sbm(n2, nc2, ratio2, P2)\nlabels1 = get_sbm_labels(n1, ratio1)\nlabels2 = get_sbm_labels(n2, ratio2)\n\n# Use noisy one-hot encodings of the SBM classes as node features.\nfeature_dim = max(nc1, nc2)\nx1 = get_noisy_one_hot(labels1, feature_dim)\nx2 = get_noisy_one_hot(labels2, feature_dim)\nall_features = np.vstack([x1, x2])\nfeature_min = all_features[:, :3].min(axis=0, keepdims=True)\nfeature_max = all_features[:, :3].max(axis=0, keepdims=True)\n\n# get 2d positions for visualization\npos1 = MDS(dissimilarity=\"precomputed\", random_state=0, n_init=1).fit_transform(1 - C1)\npos2 = MDS(dissimilarity=\"precomputed\", random_state=0, n_init=1).fit_transform(1 - C2)\n\ncolors1 = np.clip(\n (x1 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0\n)\ncolors2 = np.clip(\n (x2 - feature_min) / np.maximum(feature_max - feature_min, 1e-15), 0.0, 1.0\n)\n\n\npl.figure(1, (10, 5))\npl.clf()\npl.subplot(1, 2, 1)\nplot_graph(pos1, C1, color=colors1)\npl.title(\"SBM source graph\")\npl.axis(\"off\")\npl.subplot(1, 2, 2)\nplot_graph(pos2, C2, color=colors2)\npl.title(\"SBM target graph\")\n_ = pl.axis(\"off\")"
37+
]
38+
},
39+
{
40+
"cell_type": "markdown",
41+
"metadata": {},
42+
"source": [
43+
"## Solve FUGW with Adam\n\n"
44+
]
45+
},
46+
{
47+
"cell_type": "code",
48+
"execution_count": null,
49+
"metadata": {
50+
"collapsed": false
51+
},
52+
"outputs": [],
53+
"source": [
54+
"# Even though `loss_fugw_batch` supports batches of problems, we use a\n# batch of size 1 here for clarity.\n\na = ot.unif(C1.shape[0])\nb = ot.unif(C2.shape[0])\nM = ot.dist(x1, x2)\nM /= M.max()\n\na_torch = torch.tensor(a[None, :])\nb_torch = torch.tensor(b[None, :])\nC1_torch = torch.tensor(C1[None, :, :])\nC2_torch = torch.tensor(C2[None, :, :])\nM_torch = torch.tensor(M[None, :, :])\nL = tensor_batch(a_torch, b_torch, C1_torch, C2_torch, loss=\"sqeuclidean\")\n\nalpha = 0.5\nreg_marginals = 0.5\nlr = 5e-2\nnb_iter_max = 1500\ntol = 1e-7\n\nT0_torch = a_torch[:, :, None] * b_torch[:, None, :]\nT_torch = torch.log(torch.expm1(T0_torch)).clone().requires_grad_(True)\noptimizer = torch.optim.Adam([T_torch], lr=lr)\nloss_iter = []\nmass_iter = []\nprevious_plan_torch = None\n\ntic = perf_counter()\nfor i in range(nb_iter_max):\n optimizer.zero_grad()\n # Positive transport plan parameterized as log(1 + exp(T)).\n plan_torch = torch.nn.functional.softplus(T_torch)\n loss = loss_quadratic_batch(\n a_torch,\n b_torch,\n C1_torch,\n C2_torch,\n plan_torch,\n M_torch,\n alpha=alpha,\n unbalanced=reg_marginals,\n unbalanced_type=\"kl\",\n recompute_const=True,\n )[0]\n\n loss_iter.append(float(loss.detach()))\n mass_iter.append(float(plan_torch.detach().sum()))\n if previous_plan_torch is not None:\n err = float(torch.sum(torch.abs(plan_torch.detach() - previous_plan_torch)))\n if err < tol:\n break\n previous_plan_torch = plan_torch.detach().clone()\n loss.backward()\n optimizer.step()\ntime_adam = perf_counter() - tic\n\nT_adam = torch.nn.functional.softplus(T_torch).detach().cpu().numpy()[0]"
55+
]
56+
},
57+
{
58+
"cell_type": "markdown",
59+
"metadata": {},
60+
"source": [
61+
"## Compare with the dedicated FUGW solver\n\nThe dedicated solver uses a block coordinate descent (BCD) scheme. We compare\nthe coupling it returns with the one obtained by direct Adam minimization of\n`loss_fugw_batch`.\n\n"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": null,
67+
"metadata": {
68+
"collapsed": false
69+
},
70+
"outputs": [],
71+
"source": [
72+
"def evaluate_batch_fugw_loss(plan):\n plan_torch = torch.tensor(plan[None, :, :], dtype=M_torch.dtype)\n loss = loss_quadratic_batch(\n a_torch,\n b_torch,\n C1_torch,\n C2_torch,\n plan_torch,\n M_torch,\n alpha=alpha,\n unbalanced=reg_marginals,\n unbalanced_type=\"kl\",\n recompute_const=True,\n )[0]\n return float(loss.detach())\n\n\ntic = perf_counter()\nresult = ot.solve_gromov(\n C1, C2, M, a, b, alpha=alpha, reg=0, unbalanced_type=\"kl\", unbalanced=reg_marginals\n)\ntime_bcd = perf_counter() - tic\n\nloss_adam_final = evaluate_batch_fugw_loss(T_adam)\nT_bcd = result.plan\nloss_bcd_final = evaluate_batch_fugw_loss(T_bcd)\nmass_bcd = T_bcd.sum()\n\npl.figure(2, (10, 4))\npl.clf()\npl.subplot(1, 2, 1)\npl.plot(loss_iter, label=\"Adam\")\npl.axhline(loss_bcd_final, color=\"C1\", linestyle=\"--\", label=\"BCD solver\")\npl.grid()\npl.title(\"FUGW loss along iterations\")\npl.xlabel(\"Iterations\")\npl.legend()\npl.subplot(1, 2, 2)\npl.plot(mass_iter, label=\"Adam\")\npl.axhline(mass_bcd, color=\"C1\", linestyle=\"--\", label=\"BCD solver\")\npl.grid()\npl.title(\"Transport mass\")\npl.xlabel(\"Iterations\")\n_ = pl.legend()"
73+
]
74+
},
75+
{
76+
"cell_type": "markdown",
77+
"metadata": {},
78+
"source": [
79+
"## Visualize the learned couplings\nWe visualize the couplings obtained by both methods to compare them. On this example, both methods recover similar couplings,\nbut direct minimization reaches a lower `loss_fugw_batch` value at the cost\nof a longer runtime.\n\n"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"metadata": {
86+
"collapsed": false
87+
},
88+
"outputs": [],
89+
"source": [
90+
"vmin = min(T_adam.min(), T_bcd.min())\nvmax = max(T_adam.max(), T_bcd.max())\npl.figure(3, (10, 4))\npl.clf()\npl.subplot(1, 2, 1)\npl.imshow(T_adam, interpolation=\"nearest\", cmap=\"Blues\", vmin=vmin, vmax=vmax)\npl.title(\n f\"Coupling from direct minimization\\nloss={loss_adam_final:.3f}, time={time_adam:.2f}s\"\n)\npl.xlabel(\"Target nodes\")\npl.ylabel(\"Source nodes\")\npl.colorbar()\npl.subplot(1, 2, 2)\npl.imshow(T_bcd, interpolation=\"nearest\", cmap=\"Blues\", vmin=vmin, vmax=vmax)\npl.title(f\"Coupling from BCD solver\\nloss={loss_bcd_final:.3f}, time={time_bcd:.2f}s\")\npl.xlabel(\"Target nodes\")\npl.ylabel(\"Source nodes\")\n_ = pl.colorbar()"
91+
]
92+
}
93+
],
94+
"metadata": {
95+
"kernelspec": {
96+
"display_name": "Python 3",
97+
"language": "python",
98+
"name": "python3"
99+
},
100+
"language_info": {
101+
"codemirror_mode": {
102+
"name": "ipython",
103+
"version": 3
104+
},
105+
"file_extension": ".py",
106+
"mimetype": "text/x-python",
107+
"name": "python",
108+
"nbconvert_exporter": "python",
109+
"pygments_lexer": "ipython3",
110+
"version": "3.12.13"
111+
}
112+
},
113+
"nbformat": 4,
114+
"nbformat_minor": 0
115+
}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)