diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index e038b49a1..4974302d9 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -42,7 +42,7 @@ The contributors to this library are: * [Cédric Vincent-Cuaz](https://github.com/cedricvincentcuaz) (Graph Dictionary Learning, FGW, semi-relaxed FGW, quantized FGW, partial FGW) * [Eloi Tanguy](https://github.com/eloitanguy) (Generalized Wasserstein - Barycenters, GMMOT) + Barycenters, GMMOT, Barycenters for General Transport Costs) * [Camille Le Coz](https://www.linkedin.com/in/camille-le-coz-8593b91a1/) (EMD2 debug) * [Eduardo Fernandes Montesuma](https://eddardd.github.io/my-personal-blog/) (Free support sinkhorn barycenter) * [Theo Gnassounou](https://github.com/tgnassou) (OT between Gaussian distributions) diff --git a/README.md b/README.md index 8b4cca7f7..f0e256eb0 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,8 @@ POT provides the following generic OT solvers (links to examples): * [Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_COOT.html) [49] and [unbalanced Co-Optimal Transport](https://pythonot.github.io/auto_examples/others/plot_learning_weights_with_COOT.html) [71]. * Fused unbalanced Gromov-Wasserstein [70]. +* [Optimal Transport Barycenters for Generic Costs](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter_generic_cost.html) [76] +* [Barycenters between Gaussian Mixture Models](https://pythonot.github.io/auto_examples/barycenters/plot_gmm_barycenter.html) [69, 76] POT provides the following Machine Learning related solvers: @@ -389,3 +391,5 @@ Artificial Intelligence. [74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR. [75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). 
[Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145. + +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). [Computing Barycentres of Measures for Generic Transport Costs](https://arxiv.org/abs/2501.04016). arXiv preprint 2501.04016 (2024) diff --git a/RELEASES.md b/RELEASES.md index ec7e5774c..1dbebcf2d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -7,6 +7,9 @@ - Added feature `grad=last_step` for `ot.solvers.solve` (PR #693) - Automatic PR labeling and release file update check (PR #704) - Reorganize sub-module `ot/lp/__init__.py` into separate files (PR #714) +- Implement fixed-point solver for OT barycenters with generic cost functions + (generalizes `ot.lp.free_support_barycenter`), with example. (PR #715) +- Implement fixed-point solver for barycenters between GMMs (PR #715), with example. - Fix warning raise when import the library (PR #716) - Implement projected gradient descent solvers for entropic partial FGW (PR #702) - Fix documentation in the module `ot.gaussian` (PR #718) diff --git a/examples/barycenters/plot_free_support_barycenter_generic_cost.py b/examples/barycenters/plot_free_support_barycenter_generic_cost.py new file mode 100644 index 000000000..536303a58 --- /dev/null +++ b/examples/barycenters/plot_free_support_barycenter_generic_cost.py @@ -0,0 +1,284 @@ +# -*- coding: utf-8 -*- +""" +===================================== +OT Barycenter with Generic Costs Demo +===================================== + +This example illustrates the computation of an Optimal Transport Barycenter for +a ground cost that is not a power of a norm. We take the example of ground costs +:math:`c_k(x, y) = \lambda_k\|P_k(x)-y\|_2^2`, where :math:`P_k` is the +(non-linear) projection onto a circle k, and :math:`(\lambda_k)` are weights. 
A +barycenter is defined ([76]) as a minimiser of the energy :math:`V(\mu) = \sum_k +\mathcal{T}_{c_k}(\mu, \nu_k)` where :math:`\mu` is a candidate barycenter +measure, the measures :math:`\nu_k` are the target measures and +:math:`\mathcal{T}_{c_k}` is the OT cost for ground cost :math:`c_k`. This is an +example of the fixed-point barycenter solver introduced in [76] which +generalises [20] and [43]. + +The ground barycenter function :math:`B(y_1, ..., y_K) = \mathrm{argmin}_{x \in +\mathbb{R}^2} \sum_k c_k(x, y_k)` is computed by gradient descent over +:math:`x` with PyTorch. + +We compare two algorithms from [76]: the first ([76], Algorithm 2, +'true_fixed_point' in POT) has convergence guarantees but the iterations may +increase in support size and thus require more computational resources. The +second ([76], Algorithm 3, 'L2_barycentric_proj' in POT) is a simplified +heuristic that imposes a fixed support size for the barycenter and fixed +weights. + +We initialise both algorithms with a support size of 136, computing a barycenter +between measures with uniform weights and 50 points. + +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 +(2024) + +[20] Cuturi, M. and Doucet, A. (2014) Fast Computation of Wasserstein +Barycenters. International Conference on Machine Learning + +[43] Álvarez-Esteban, Pedro C., et al. A fixed-point approach to barycenters in +Wasserstein space. Journal of Mathematical Analysis and Applications 441.2 +(2016): 744-762. 
+ +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% +# Generate data +import torch +import ot +from torch.optim import Adam +from ot.utils import dist +import numpy as np +from ot.lp import free_support_barycenter_generic_costs +import matplotlib.pyplot as plt +from time import time + + +torch.manual_seed(42) + +n = 136 # number of points of the of the barycentre +d = 2 # dimensions of the original measure +K = 4 # number of measures to barycentre +m = 50 # number of points of the measures +b_list = [torch.ones(m) / m] * K # weights of the 4 measures +weights = torch.ones(K) / K # weights for the barycentre +stop_threshold = 1e-20 # stop threshold for B and for fixed-point algo + + +# map R^2 -> R^2 projection onto circle +def proj_circle(X, origin, radius): + diffs = X - origin[None, :] + norms = torch.norm(diffs, dim=1) + return origin[None, :] + radius * diffs / norms[:, None] + + +# circles on which to project +origin1 = torch.tensor([-1.0, -1.0]) +origin2 = torch.tensor([-1.0, 2.0]) +origin3 = torch.tensor([2.0, 2.0]) +origin4 = torch.tensor([2.0, -1.0]) +r = np.sqrt(2) +P_list = [ + lambda X: proj_circle(X, origin1, r), + lambda X: proj_circle(X, origin2, r), + lambda X: proj_circle(X, origin3, r), + lambda X: proj_circle(X, origin4, r), +] + +# measures to barycentre are projections of different random circles +# onto the K circles +Y_list = [] +for k in range(K): + t = torch.rand(m) * 2 * np.pi + X_temp = 0.5 * torch.stack([torch.cos(t), torch.sin(t)], axis=1) + X_temp = X_temp + torch.tensor([0.5, 0.5])[None, :] + Y_list.append(P_list[k](X_temp)) + + +# %% +# Define costs and ground barycenter function +# cost_list[k] is a function taking x (n, d) and y (n_k, d_k) and returning a +# (n, n_k) matrix of costs +def c1(x, y): + return dist(P_list[0](x), y) + + +def c2(x, y): + return dist(P_list[1](x), y) + + +def c3(x, y): + return dist(P_list[2](x), y) + + +def c4(x, y): + return dist(P_list[3](x), y) + 
+ +cost_list = [c1, c2, c3, c4] + + +# batched total ground cost function for candidate points x (n, d) +# for computation of the ground barycenter B with gradient descent +def C(x, y): + """ + Computes the barycenter cost for candidate points x (n, d) and + measure supports y: List(n, d_k). + """ + n = x.shape[0] + K = len(y) + out = torch.zeros(n) + for k in range(K): + out += (1 / K) * torch.sum((P_list[k](x) - y[k]) ** 2, axis=1) + return out + + +# ground barycenter function +def B(y, its=150, lr=1, stop_threshold=stop_threshold): + """ + Computes the ground barycenter for measure supports y: List(n, d_k). + Output: (n, d) array + """ + x = torch.randn(y[0].shape[0], d) + x.requires_grad_(True) + opt = Adam([x], lr=lr) + for _ in range(its): + x_prev = x.data.clone() + opt.zero_grad() + loss = torch.sum(C(x, y)) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < stop_threshold: + break + return x + + +# %% +# Compute the barycenter measure with the true fixed-point algorithm +fixed_point_its = 5 +torch.manual_seed(42) +X_init = torch.rand(n, d) +t0 = time() +X_bar, a_bar, log_dict = free_support_barycenter_generic_costs( + Y_list, + b_list, + X_init, + cost_list, + B, + numItermax=fixed_point_its, + stopThr=stop_threshold, + method="true_fixed_point", + log=True, + clean_measure=True, +) +dt_true_fixed_point = time() - t0 + +# %% +# Compute the barycenter measure with the barycentric (default) algorithm +fixed_point_its = 5 +torch.manual_seed(42) +X_init = torch.rand(n, d) +t0 = time() +X_bar2, log_dict2 = free_support_barycenter_generic_costs( + Y_list, + b_list, + X_init, + cost_list, + B, + numItermax=fixed_point_its, + stopThr=stop_threshold, + log=True, +) +dt_barycentric = time() - t0 + +# %% +# Plot Barycenters (Iteration 3) +alpha = 0.4 +s = 80 +labels = ["circle 1", "circle 2", "circle 3", "circle 4"] + + +# Compute barycenter energies +def V(X, a): + v = 0 + for k in range(K): + v += (1 / K) * ot.emd2(a, b_list[k], 
cost_list[k](X, Y_list[k])) + return v + + +fig, axes = plt.subplots(1, 2, figsize=(12, 6)) + +# Plot for the true fixed-point algorithm +for Y, label in zip(Y_list, labels): + axes[0].scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) +axes[0].scatter( + *(X_bar.detach().numpy()).T, + label="Barycenter", + c="black", + alpha=alpha * a_bar.numpy() / np.max(a_bar.numpy()), + s=s, +) +axes[0].set_title( + "True Fixed-Point Algorithm\n" + f"Support size: {a_bar.shape[0]}\n" + f"Barycenter cost: {V(X_bar, a_bar).item():.6f}\n" + f"Computation time {dt_true_fixed_point:.4f}s" +) +axes[0].axis("equal") +axes[0].axis("off") +axes[0].legend() + +# Plot for the heuristic algorithm +for Y, label in zip(Y_list, labels): + axes[1].scatter(*(Y.numpy()).T, alpha=alpha, label=label, s=s) +axes[1].scatter( + *(X_bar2.detach().numpy()).T, label="Barycenter", c="black", alpha=alpha, s=s +) +axes[1].set_title( + "Heuristic Barycentric Algorithm\n" + f"Support size: {X_bar2.shape[0]}\n" + f"Barycenter cost: {V(X_bar2, torch.ones(n) / n).item():.6f}\n" + f"Computation time {dt_barycentric:.4f}s" +) +axes[1].axis("equal") +axes[1].axis("off") +axes[1].legend() + +plt.tight_layout() + +# %% +# Plot energy convergence +fig, axes = plt.subplots(1, 2, figsize=(8, 4)) + +V_list = [V(X, a).item() for (X, a) in zip(log_dict["X_list"], log_dict["a_list"])] +V_list2 = [V(X, torch.ones(n) / n).item() for X in log_dict2["X_list"]] + +# Plot for True Fixed-Point Algorithm +axes[0].plot(V_list, lw=5, alpha=0.6) +axes[0].scatter(range(len(V_list)), V_list, color="blue", alpha=0.8, s=100) +axes[0].set_title("True Fixed-Point Algorithm") +axes[0].set_xlabel("Iteration") +axes[0].set_ylabel("Barycenter Energy") +axes[0].set_yscale("log") +axes[0].xaxis.set_major_locator(plt.MaxNLocator(integer=True)) + +# Plot for Heuristic Barycentric Algorithm +axes[1].plot(V_list2, lw=5, alpha=0.6) +axes[1].scatter(range(len(V_list2)), V_list2, color="blue", alpha=0.8, s=100) +axes[1].set_title("Heuristic 
Barycentric Algorithm") +axes[1].set_xlabel("Iteration") +axes[1].set_ylabel("Barycenter Energy") +axes[1].set_yscale("log") +axes[1].xaxis.set_major_locator(plt.MaxNLocator(integer=True)) + +plt.tight_layout() +plt.show() + +# %% diff --git a/examples/barycenters/plot_generalized_free_support_barycenter.py b/examples/barycenters/plot_generalized_free_support_barycenter.py index 5b3572bd4..b21c66f13 100644 --- a/examples/barycenters/plot_generalized_free_support_barycenter.py +++ b/examples/barycenters/plot_generalized_free_support_barycenter.py @@ -14,7 +14,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # # License: MIT License diff --git a/examples/barycenters/plot_gmm_barycenter.py b/examples/barycenters/plot_gmm_barycenter.py new file mode 100644 index 000000000..6dd0ad8be --- /dev/null +++ b/examples/barycenters/plot_gmm_barycenter.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +""" +===================================== +Gaussian Mixture Model OT Barycenters +===================================== + +This example illustrates the computation of a barycenter between Gaussian +Mixtures in the sense of GMM-OT [69]. This computation is done using the +fixed-point method for OT barycenters with generic costs [76], for which POT +provides a general solver, and a specific GMM solver. Note that this is a +'free-support' method, implying that the number of components of the barycenter +GMM and their weights are fixed. + +The idea behind GMM-OT barycenters is to see the GMMs as discrete measures over +the space of Gaussian distributions :math:`\mathcal{N}` (or equivalently the +Bures-Wasserstein manifold), and to compute barycenters with respect to the +2-Wasserstein distance between measures in :math:`\mathcal{P}(\mathcal{N})`: a +gaussian mixture is a finite combination of Diracs on specific gaussians, and +two mixtures are compared with the 2-Wasserstein distance on this space, where +ground cost the squared Bures distance between gaussians. 
+ +[69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space +of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + +[76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing +Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 +(2024) + +""" + +# Author: Eloi Tanguy +# +# License: MIT License + +# sphinx_gallery_thumbnail_number = 1 + +# %% +# Generate data +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patches import Ellipse +import ot +from ot.gmm import gmm_barycenter_fixed_point + + +K = 3 # number of GMMs +d = 2 # dimension +n = 6 # number of components of the desired barycenter + + +def get_random_gmm(K, d, seed=0, min_cov_eig=1, cov_scale=1e-2): + rng = np.random.RandomState(seed=seed) + means = rng.randn(K, d) + P = rng.randn(K, d, d) * cov_scale + # C[k] = P[k] @ P[k]^T + min_cov_eig * I + covariances = np.einsum("kab,kcb->kac", P, P) + covariances += min_cov_eig * np.array([np.eye(d) for _ in range(K)]) + weights = rng.random(K) + weights /= np.sum(weights) + return means, covariances, weights + + +m_list = [5, 6, 7] # number of components in each GMM +offsets = [np.array([-3, 0]), np.array([2, 0]), np.array([0, 4])] +means_list = [] # list of means for each GMM +covs_list = [] # list of covariances for each GMM +w_list = [] # list of weights for each GMM + +# generate GMMs +for k in range(K): + means, covs, b = get_random_gmm( + m_list[k], d, seed=k, min_cov_eig=0.25, cov_scale=0.5 + ) + means = means / 2 + offsets[k][None, :] + means_list.append(means) + covs_list.append(covs) + w_list.append(b) + +# %% +# Compute the barycenter using the fixed-point method +init_means, init_covs, _ = get_random_gmm(n, d, seed=0) +weights = ot.unif(K) # barycenter coefficients +means_bar, covs_bar, log = gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + init_means, + init_covs, + weights, + iterations=3, + log=True, +) + + +# %% +# Define 
plotting functions + + +# draw a covariance ellipse +def draw_cov(mu, C, color=None, label=None, nstd=1, alpha=0.5, ax=None): + def eigsorted(cov): + vals, vecs = np.linalg.eigh(cov) + order = vals.argsort()[::-1].copy() + return vals[order], vecs[:, order] + + vals, vecs = eigsorted(C) + theta = np.degrees(np.arctan2(*vecs[:, 0][::-1])) + w, h = 2 * nstd * np.sqrt(vals) + ell = Ellipse( + xy=(mu[0], mu[1]), + width=w, + height=h, + alpha=alpha, + angle=theta, + facecolor=color, + edgecolor=color, + label=label, + fill=True, + ) + if ax is None: + ax = plt.gca() + ax.add_artist(ell) + + +# draw a gmm as a set of ellipses with weights shown in alpha value +def draw_gmm(ms, Cs, ws, color=None, nstd=0.5, alpha=1, label=None, ax=None): + for k in range(ms.shape[0]): + draw_cov( + ms[k], Cs[k], color, label if k == 0 else None, nstd, alpha * ws[k], ax=ax + ) + + +# %% +# Plot the results +c_list = ["#7ED321", "#4A90E2", "#9013FE", "#F5A623"] +c_bar = "#D0021B" +fig, ax = plt.subplots(figsize=(6, 6)) +axis = [-4, 4, -2, 6] +ax.set_title("Fixed Point Barycenter (3 Iterations)", fontsize=16) +for k in range(K): + draw_gmm(means_list[k], covs_list[k], w_list[k], color=c_list[k], ax=ax) +draw_gmm(means_bar, covs_bar, ot.unif(n), color=c_bar, ax=ax) +ax.axis(axis) +ax.axis("off") + +# %% diff --git a/examples/others/plot_GMMOT_plan.py b/examples/others/plot_GMMOT_plan.py index 7742d496e..4964ddd66 100644 --- a/examples/others/plot_GMMOT_plan.py +++ b/examples/others/plot_GMMOT_plan.py @@ -16,7 +16,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # diff --git a/examples/others/plot_GMM_flow.py b/examples/others/plot_GMM_flow.py index beb675755..dc26ff3ce 100644 --- a/examples/others/plot_GMM_flow.py +++ b/examples/others/plot_GMM_flow.py @@ -10,7 +10,7 @@ """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # Remi Flamary # Julie Delon # diff --git a/examples/others/plot_SSNB.py b/examples/others/plot_SSNB.py index fbc343a8a..e167b1ee4 100644 
--- a/examples/others/plot_SSNB.py +++ b/examples/others/plot_SSNB.py @@ -38,7 +38,7 @@ 2017. """ -# Author: Eloi Tanguy +# Author: Eloi Tanguy # License: MIT License # sphinx_gallery_thumbnail_number = 3 diff --git a/ot/gmm.py b/ot/gmm.py index 5c7a4c287..a065c73b0 100644 --- a/ot/gmm.py +++ b/ot/gmm.py @@ -3,8 +3,8 @@ Optimal transport for Gaussian Mixtures """ -# Author: Eloi Tanguy -# Remi Flamary +# Author: Eloi Tanguy +# Remi Flamary # Julie Delon # # License: MIT License @@ -13,7 +13,7 @@ from .lp import emd2, emd import numpy as np from .utils import dist -from .gaussian import bures_wasserstein_mapping +from .gaussian import bures_wasserstein_mapping, bures_wasserstein_barycenter def gaussian_logpdf(x, m, C): @@ -440,3 +440,148 @@ def Tk0k1(k0, k1): ] ) return nx.sum(mat, axis=(0, 1)) + + +def gmm_barycenter_fixed_point( + means_list, + covs_list, + w_list, + means_init, + covs_init, + weights, + w_bar=None, + iterations=100, + log=False, + barycentric_proj_method="euclidean", +): + r""" + Solves the Gaussian Mixture Model OT barycenter problem (defined in [69]) + using the fixed point algorithm (proposed in [76]). The + weights of the barycenter are not optimized, and stay the same as the input + `w_list` or are initialized to uniform. + + The algorithm uses barycentric projections of GMM-OT plans, and these can be + computed either through Bures Barycenters (slow but accurate, + barycentric_proj_method='bures') or by convex combination (fast, + barycentric_proj_method='euclidean', default). + + This is a special case of the generic free-support barycenter solver + `ot.lp.free_support_barycenter_generic_costs`. + + Parameters + ---------- + means_list : list of array-like + List of K (m_k, d) GMM means. + covs_list : list of array-like + List of K (m_k, d, d) GMM covariances. + w_list : list of array-like + List of K (m_k) arrays of weights. + means_init : array-like + Initial (n, d) GMM means. + covs_init : array-like + Initial (n, d, d) GMM covariances. 
+ weights : array-like + Array (K,) of the barycentre coefficients. + w_bar : array-like, optional + Weights (n) of the barycentre GMM, kept fixed across iterations. If None, uniform weights are used. + iterations : int, optional + Number of iterations (default is 100). + log : bool, optional + Whether to return the list of iterations (default is False). + barycentric_proj_method : str, optional + Method used to compute the barycentric projections of the GMM-OT plans: 'euclidean' (default) or 'bures'. + + Returns + ------- + means : array-like + (n, d) barycentre GMM means. + covs : array-like + (n, d, d) barycentre GMM covariances. + log_dict : dict, optional + Dictionary with keys 'means_its' and 'covs_its' (lists of the iterates), returned if log is True. + + References + ---------- + .. [69] Delon, J., & Desolneux, A. (2020). A Wasserstein-type distance in the space of Gaussian mixture models. SIAM Journal on Imaging Sciences, 13(2), 936-970. + + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing Barycentres of Measures for Generic Transport Costs. arXiv preprint 2501.04016 (2024) + + See Also + -------- + ot.lp.free_support_barycenter_generic_costs : Compute barycenter of measures for generic transport costs. 
+ """ + nx = get_backend( + means_init, covs_init, means_list[0], covs_list[0], w_list[0], weights + ) + K = len(means_list) + n = means_init.shape[0] + d = means_init.shape[1] + means_its = [nx.copy(means_init)] + covs_its = [nx.copy(covs_init)] + means, covs = means_init, covs_init + + if w_bar is None: + w_bar = nx.ones(n, type_as=means) / n + + for _ in range(iterations): + pi_list = [ + gmm_ot_plan(means, means_list[k], covs, covs_list[k], w_bar, w_list[k]) + for k in range(K) + ] + + # filled in the euclidean case + means_selection, covs_selection = None, None + + # in the euclidean case, the selection of Gaussians from each K sources + # comes from a barycentric projection: it is a convex combination of the + # selected means and covariances, which can be computed without a + # for loop on i = 0, ..., n -1 + if barycentric_proj_method == "euclidean": + means_selection = nx.zeros((n, K, d), type_as=means) + covs_selection = nx.zeros((n, K, d, d), type_as=means) + for k in range(K): + means_selection[:, k, :] = n * pi_list[k] @ means_list[k] + covs_selection[:, k, :, :] = ( + nx.einsum("ij,jab->iab", pi_list[k], covs_list[k]) * n + ) + + # each component i of the barycentre will be a Bures barycentre of the + # selected components of the K GMMs. In the 'bures' barycentric + # projection option, the selected components are also Bures barycentres. 
+ for i in range(n): + # means_selection_i (K, d) is the selected means, each comes from a + # Gaussian barycentre along the disintegration of pi_k at i + # covs_selection_i (K, d, d) are the selected covariances + means_selection_i = None + covs_selection_i = None + + # use previous computation (convex combination) + if barycentric_proj_method == "euclidean": + means_selection_i = means_selection[i] + covs_selection_i = covs_selection[i] + + # compute Bures barycentre of certain components to get the + # selection at i + elif barycentric_proj_method == "bures": + means_selection_i = nx.zeros((K, d), type_as=means) + covs_selection_i = nx.zeros((K, d, d), type_as=means) + for k in range(K): + w = (1 / w_bar[i]) * pi_list[k][i, :] + m, C = bures_wasserstein_barycenter(means_list[k], covs_list[k], w) + means_selection_i[k] = m + covs_selection_i[k] = C + + else: + raise ValueError("Unknown barycentric_proj_method") + + means[i], covs[i] = bures_wasserstein_barycenter( + means_selection_i, covs_selection_i, weights + ) + + if log: + means_its.append(nx.copy(means)) + covs_its.append(nx.copy(covs)) + + if log: + return means, covs, {"means_its": means_its, "covs_its": covs_its} + return means, covs diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 932b261df..03aeb958a 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -14,6 +14,8 @@ barycenter, free_support_barycenter, generalized_free_support_barycenter, + free_support_barycenter_generic_costs, + NorthWestMMGluing, ) from ..utils import check_number_threads @@ -45,4 +47,6 @@ "dmmot_monge_1dgrid_loss", "dmmot_monge_1dgrid_optimize", "check_number_threads", + "free_support_barycenter_generic_costs", + "NorthWestMMGluing", ] diff --git a/ot/lp/_barycenter_solvers.py b/ot/lp/_barycenter_solvers.py index 4779662e9..b7d7e0b16 100644 --- a/ot/lp/_barycenter_solvers.py +++ b/ot/lp/_barycenter_solvers.py @@ -10,7 +10,7 @@ from ..backend import get_backend from ..utils import dist -from ._network_simplex import 
emd +from ._network_simplex import emd, emd2 import numpy as np import scipy as sp @@ -199,14 +199,12 @@ def free_support_barycenter( measures_weights : list of N (k_i,) array-like Numpy arrays where each numpy array has :math:`k_i` non-negatives values summing to one representing the weights of each discrete input measure - X_init : (k,d) array-like Initialization of the support locations (on `k` atoms) of the barycenter b : (k,) array-like Initialization of the weights of the barycenter (non-negatives, sum to 1) weights : (N,) array-like Initialization of the coefficients of the barycenter (non-negatives, sum to 1) - numItermax : int, optional Max number of iterations stopThr : float, optional @@ -219,13 +217,11 @@ def free_support_barycenter( If compiled with OpenMP, chooses the number of threads to parallelize. "max" selects the highest number possible. - Returns ------- X : (k,d) array-like Support locations (on k atoms) of the barycenter - .. _references-free-support-barycenter: References ---------- @@ -426,3 +422,474 @@ def generalized_free_support_barycenter( return Y, log_dict else: return Y + + +def free_support_barycenter_generic_costs( + measure_locations, + measure_weights, + X_init, + cost_list, + ground_bary=None, + a=None, + numItermax=100, + method="L2_barycentric_proj", + stopThr=1e-5, + log=False, + ground_bary_lr=1e-2, + ground_bary_numItermax=100, + ground_bary_stopThr=1e-5, + ground_bary_solver="SGD", + clean_measure=False, +): + r""" + Solves the OT barycenter problem for generic costs using the fixed point + algorithm, iterating the ground barycenter function B on transport plans + between the current barycenter and the measures. + + The problem finds an optimal barycenter support `X` of given size (n, d) + (enforced by the initialisation), minimising a sum of pairwise transport + costs for the costs :math:`c_k`: + + .. 
math:: + \min_{X} \sum_{k=1}^K \mathcal{T}_{c_k}(X, a, Y_k, b_k), + + where: + + - :math:`X` (n, d) is the barycenter support, + - :math:`a` (n) are the (fixed) barycenter weights, + - :math:`Y_k` (m_k, d_k) is the k-th measure support + (`measure_locations[k]`), + - :math:`b_k` (m_k) are the k-th measure weights (`measure_weights[k]`), + - :math:`c_k: \mathbb{R}^{n\times d}\times\mathbb{R}^{m_k\times d_k} + \rightarrow \mathbb{R}_+^{n\times m_k}` is the k-th cost function + (which computes the pairwise cost matrix) + - :math:`\mathcal{T}_{c_k}(X, a, Y_k, b_k)` is the OT cost between the barycenter measure and the k-th measure with respect to the cost :math:`c_k`: + + .. math:: + \mathcal{T}_{c_k}(X, a, Y_k, b_k) = \min_\pi \quad \langle \pi, c_k(X, Y_k) \rangle_F + + s.t. \ \pi \mathbf{1} = \mathbf{a} + + \pi^T \mathbf{1} = \mathbf{b_k} + + \pi \geq 0 + + in other words, :math:`\mathcal{T}_{c_k}(X, a, Y_k, b_k)` is `ot.emd2(a, b_k, + c_k(X, Y_k))`. + + The algorithm requires a given ground barycenter function `B` which computes + (broadcasted over `n`) solutions of the following minimisation problem given + :math:`(Y_1, \cdots, Y_K) \in \mathbb{R}^{n\times + d_1}\times\cdots\times\mathbb{R}^{n\times d_K}`: + + .. math:: + B(y_1, \cdots, y_K) = \mathrm{argmin}_{x \in \mathbb{R}^d} \sum_{k=1}^K c_k(x, y_k), + + where :math:`c_k(x, y_k) \in \mathbb{R}_+` is the cost between the points + :math:`x` and :math:`y_k`. The function :math:`B:\mathbb{R}^{n\times + d_1}\times \cdots\times\mathbb{R}^{n\times d_K} \longrightarrow + \mathbb{R}^{n\times d}` is an input to this function, and for certain costs + it can be computed explicitly or through a numerical solver. The input + function B takes a list of K arrays of shape (n, d_k) and returns an array + of shape (n, d). 
+ + This function implements two algorithms: + + - Algorithm 2 from [76] when `method='true_fixed_point'` is used, which may + increase the support size of the barycenter at each iteration, with a + maximum final size of :math:`N_0 + T\sum_k n_k - TK` for T iterations and + an initial support size of :math:`N_0`. The computation of the iterates is + done using the North West Corner multi-marginal gluing method. This method + has convergence guarantees [76]. + + - Algorithm 3 from [76] when `method='L2_barycentric_proj'` is used, which is + a heuristic simplification which fixes the weights and support size of the + barycenter by performing barycentric projections of the pair-wise OT + matrices. This method is substantially faster than the first one, but does + not have convergence guarantees. (Default) + + The implemented methods ([76] Algorithms 2 and 3) generalise [20] and [43] + to general costs and include convergence guarantees, including for discrete + measures. + + Parameters + ---------- + measure_locations : list of array-like + List of K arrays of measure positions, each of shape (m_k, d_k). + measure_weights : list of array-like + List of K arrays of measure weights, each of shape (m_k). + X_init : array-like + Array of shape (n, d) representing initial barycenter points. + cost_list : list of callable or callable + List of K cost functions :math:`c_k: \mathbb{R}^{n\times + d}\times\mathbb{R}^{m_k\times d_k} \rightarrow \mathbb{R}_+^{n\times + m_k}`. If cost_list is a single callable, the same cost is used K times. + ground_bary : callable or None, optional + Function List(array(n, d_k)) -> array(n, d) accepting a list of K arrays + of shape (n, d_k), computing the ground barycenters (broadcasted + over n). If not provided, done with SGD or Adam on PyTorch (requires + PyTorch backend, see `ground_bary_solver`), using `cost_list` inefficiently. + a : array-like, optional + Array of shape (n,) representing weights of the barycenter + measure. Defaults to uniform. 
+ numItermax : int, optional + Maximum number of iterations (default is 100). + method : str, optional + Barycentre method: 'L2_barycentric_proj' (default) for Euclidean + barycentric projection, or 'true_fixed_point' for iterates using the + North West Corner multi-marginal gluing method. + stopThr : float, optional + If the iterations move less than this, terminate (default is 1e-5). + log : bool, optional + Whether to return the log dictionary (default is False). + ground_bary_lr : float, optional + Learning rate for the ground barycenter solver (if auto is used). + ground_bary_numItermax : int, optional + Maximum number of iterations for the ground barycenter solver (if auto + is used). + ground_bary_stopThr : float, optional + Stop threshold for the ground barycenter solver (if auto is used). + ground_bary_solver : str, optional + Solver for auto ground bary solver (torch SGD or Adam). Default is + "SGD". + clean_measure : bool, optional + For method=='true_fixed_point', whether to clean the discrete measure + (X, a) at each iteration to remove duplicate points and sum their + weights (default is False). + + Returns + ------- + X : array-like + Array of shape (n, d) representing barycenter points. For method 'true_fixed_point', the barycenter weights `a` are also returned, right after `X`. + log_dict : dict, optional + log containing the exit status, list of iterations and list of + displacements if log is True. + + References + ---------- + .. [76] Tanguy, Eloi and Delon, Julie and Gozlan, Nathaël (2024). Computing + Barycentres of Measures for Generic Transport Costs. arXiv preprint + 2501.04016 (2024) + + .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein + barycenters." International Conference on Machine Learning. 2014. + + .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to + barycenters in Wasserstein space." Journal of Mathematical Analysis and + Applications 441.2 (2016): 744-762. 
+ + See Also + -------- + ot.lp.free_support_barycenter : Free support solver for the case where + :math:`c_k(x,y) = \lambda_k\|x-y\|_2^2`. + + ot.lp.generalized_free_support_barycenter : Free support solver for the case + where :math:`c_k(x,y) = \|P_kx-y\|_2^2` with :math:`P_k` linear. + + ot.lp.NorthWestMMGluing : gluing method used in the `true_fixed_point` method. + """ + assert method in [ + "L2_barycentric_proj", + "true_fixed_point", + ], "Method must be 'L2_barycentric_proj' or 'true_fixed_point'" + nx = get_backend(X_init, measure_locations[0]) + K = len(measure_locations) + n = X_init.shape[0] + if a is None: + a = nx.ones(n, type_as=X_init) / n + if callable(cost_list): # use the given cost for all K pairs + cost_list = [cost_list] * K + auto_ground_bary = False + + if ground_bary is None: + auto_ground_bary = True + assert str(nx) == "torch", ( + f"Backend {str(nx)} is not compatible with ground_bary=None, it" + "must be provided if not using PyTorch backend" + ) + try: + import torch + from torch.optim import Adam, SGD + + def ground_bary(y, x_init): + x = x_init.clone().detach().requires_grad_(True) + solver = Adam if ground_bary_solver == "Adam" else SGD + opt = solver([x], lr=ground_bary_lr) + for _ in range(ground_bary_numItermax): + x_prev = x.data.clone() + opt.zero_grad() + # inefficient cost computation but compatible + # with the choice of cost_list[k] giving the cost matrix + loss = torch.sum( + torch.stack( + [torch.diag(cost_list[k](x, y[k])) for k in range(K)] + ) + ) + loss.backward() + opt.step() + diff = torch.sum((x.data - x_prev) ** 2) + if diff < ground_bary_stopThr: + break + return x.detach() + + except ImportError: + raise ImportError("PyTorch is required to use ground_bary=None") + + X_list = [X_init] if log else [] # store the iterations + a_list = [nx.copy(a)] if log and method == "true_fixed_point" else [] + X = X_init + diff_list = [] # store the displacement squared norms + exit_status = "Max iterations reached" + + for _ in 
def _to_int_array(x):
    """
    Converts an array to an integer type array.

    The conversion is dispatched on the name of the backend returned by
    `get_backend(x)`.

    Parameters
    ----------
    x : array-like
        Array to convert. The backends handled here are "numpy", "torch",
        "jax", "cupy" and "tf".

    Returns
    -------
    x_int : array-like
        Integer-typed version of `x` in the same backend. The "tf" branch
        casts to `tf.int32`, the other branches cast with the Python `int`
        type.

    Raises
    ------
    ValueError
        If the backend name is not one of those listed above. Without this
        check the function would silently return None.
    """
    nx = get_backend(x)
    backend_name = str(nx)

    # these backends share the numpy-style `astype` method
    if backend_name in ("numpy", "jax", "cupy"):
        return x.astype(int)

    if backend_name == "torch":
        return x.to(int)

    if backend_name == "tf":
        # imported lazily so tensorflow is only required for this backend
        import tensorflow as tf

        return tf.cast(x, tf.int32)

    raise ValueError(
        f"Backend {backend_name} is not supported for integer conversion"
    )
+ + log : bool, optional + If True, return a log dictionary (computationally expensive). + + Returns + ------- + J : array (N, K) + The indices (J[i, 1], ..., J[i, K]) of the K-plan rho. + w : array (N,) + The weights w_i of the K-plan rho. + log_dict : dict, optional + If log=True, a dictionary containing the full K+1-marginal transport + plan under the key 'gamma'. + """ + nx = get_backend(pi_list[0]) + a = nx.sum(pi_list[0], axis=1) # common first marginal a in Delta_n + nk_list = [pi.shape[1] for pi in pi_list] # list of n_k + K = len(pi_list) + n = pi_list[0].shape[0] # number of points in the first marginal + gamma = None + + log_dict = {} + if log: # n x n_1 x ... x n_K tensor + gamma = nx.zeros([n] + nk_list, type_as=pi_list[0]) + + gamma_weights = {} # dict of (j_1, ..., j_K) : weight + P_list = [nx.copy(pi) for pi in pi_list] # copy of the transport plans + + # jjs is a list of K lists of size m_k + # checks if each jj_idx[k] is < m_k + # this is to avoid over-shooting the while loop due to numerical + # imprecision in the conditions "x > 0" + def jj_idx_in_range(jj_idx, jjs): + out = True + for k in range(K): + out = out and jj_idx[k] < len(jjs[k]) + return out + + for i in range(n): + # jjs[k] is the list of indices j in [0, n_k - 1] such that Pk[i, j] >0 + jjs = [nx.to_numpy(nx.where(P[i, :] > 0)[0]) for P in P_list] + # list [0, ..., 0] of size K for use with jjs: current indices in jjs + jj_idx = [0] * K + u = a[i] # mass at i, will decrease to 0 as we fill gamma[i, :] + + # while there is mass to add to gamma[i, :] + while u > 0 and jj_idx_in_range(jj_idx, jjs): + # current multi-index j_1 ... 
def _clean_discrete_measure(X, a, tol=1e-10):
    r"""
    Simplifies a discrete measure by consolidating duplicate points and summing
    their weights. Given a discrete measure with support X (n, d) and weights a
    (n), returns points Y (m, d) and weights b (m) such that Y is the set of
    unique points in X and b is the sum of the weights in a for each point in Y.

    The input weights `a` are left unmodified.

    Parameters
    ----------
    X : array-like
        Array of shape (n, d) representing the support points of the discrete
        measure.
    a : array-like
        Array of shape (n,) representing the weights associated with the support
        points.
    tol : float, optional
        Tolerance for determining uniqueness of points in `X`. Two points
        `X[i]` and `X[j]` are merged when `dist(X, X)[i, j] < tol`.
        Default is 1e-10.

    Returns
    -------
    Y : array-like
        Array of shape (m, d) representing the unique support points of the
        discrete measure.
    b : array-like
        Array of shape (m,) representing the summed weights for each unique
        point in `Y`, converted with `type_as=X`.
    """
    nx = get_backend(X, a)
    D = dist(X, X)
    # each D[I[k], J[k]] < tol so X[I[k]] = X[J[k]]
    idxI, idxJ = nx.where(D < tol)
    idxI = nx.to_numpy(idxI)
    idxJ = nx.to_numpy(idxJ)
    # keep only the cases I[k] <= J[k] to avoid pairs (i, j) (j, i) with i != j
    mask = idxI <= idxJ
    idxI, idxJ = idxI[mask], idxJ[mask]

    # Accumulate the weights on a NumPy version of `a`: with backends whose
    # indexing returns views (e.g. torch), an in-place `+=` on `a[i]` would
    # write into the caller's weight array.
    a_np = nx.to_numpy(a)

    X_idx_to_Y_idx = {}  # X[i] = Y[X_idx_to_Y_idx[i]]
    # indices of unique points in X, at the end, Y := X[unique_X_idx]
    unique_X_idx = []

    b = []
    for i, j in zip(idxI, idxJ):
        if i not in X_idx_to_Y_idx:  # i is a new point
            unique_X_idx.append(i)
            X_idx_to_Y_idx[i] = len(unique_X_idx) - 1
            b.append(a_np[i])

        elif j not in X_idx_to_Y_idx:  # i is known, j is a new duplicate of it
            y_idx = X_idx_to_Y_idx[i]
            # out-of-place sum so that no element of `a` is ever modified
            b[y_idx] = b[y_idx] + a_np[j]
            X_idx_to_Y_idx[j] = y_idx

    # create the unique points array Y
    Y = X[tuple(unique_X_idx), :]
    b = nx.from_numpy(np.array(b), type_as=X)
    return Y, b
def test_free_support_barycenter_generic_costs():
    """
    Check `ot.lp.free_support_barycenter_generic_costs` on two 1D Diracs.

    Covers the default call, logging, explicit barycenter weights, a single
    callable cost, the two assertion failures (missing `ground_bary` with a
    numpy input, unknown method), the "true_fixed_point" method with and
    without measure cleaning, and both exit statuses.
    """
    solve = ot.lp.free_support_barycenter_generic_costs
    tol = dict(rtol=1e-5, atol=1e-7)

    # two Diracs located at -1 and 1, each carrying unit mass
    measures_locations = [np.array([[-1.0]]), np.array([[1.0]])]
    measures_weights = [np.array([1.0]), np.array([1.0])]

    X_init = np.array([[-12.0]])

    # obvious barycenter location between two Diracs
    bar_locations = np.array([[0.0]])

    def cost(x, y):
        return ot.dist(x, y)

    cost_list = [cost, cost]

    def ground_bary(y):
        # arithmetic mean of the entries of y
        return sum(yk / len(y) for yk in y)

    data = (measures_locations, measures_weights, X_init)
    args = data + (cost_list, ground_bary)

    X = solve(*args)
    np.testing.assert_allclose(X, bar_locations, **tol)

    # test with log and specific weights
    X2, log = solve(*args, a=ot.unif(1), log=True)
    for key in ("X_list", "exit_status", "diff_list"):
        assert key in log
    np.testing.assert_allclose(X, X2, **tol)

    # test with one iteration for Max Iterations Reached
    X3, log2 = solve(*args, numItermax=1, log=True)
    assert log2["exit_status"] == "Max iterations reached"

    # test with a single callable cost
    X3, log3 = solve(*data, cost, ground_bary, numItermax=1, log=True)

    # test with no ground_bary but in numpy: requires pytorch backend
    with pytest.raises(AssertionError):
        solve(*data, cost_list, ground_bary=None, numItermax=1)

    # test with unknown method
    with pytest.raises(AssertionError):
        solve(*args, numItermax=1, method="unknown_method")

    # test true fixed-point method
    X4, a4, log4 = solve(*args, numItermax=3, method="true_fixed_point", log=True)
    assert "a_list" in log4
    assert X4.shape[0] == a4.shape[0] == 1
    np.testing.assert_allclose(a4, ot.unif(1), **tol)
    np.testing.assert_allclose(X, X4, **tol)

    # test with measure cleaning and no log
    X5, a5 = solve(
        *args, numItermax=3, method="true_fixed_point", clean_measure=True
    )
    np.testing.assert_allclose(a5, ot.unif(1), **tol)
    np.testing.assert_allclose(X, X5, **tol)

    # test with (too) lax convergence criterion
    # for Stationary Point exit status
    X6, log6 = solve(
        [np.array([[-1.0]])],
        measures_weights,
        X_init,
        cost_list,
        ground_bary,
        numItermax=3,
        stopThr=1e20,
        log=True,
    )
    assert log6["exit_status"] == "Stationary Point"
def verify_gluing_validity(gamma, J, w, pi_list):
    """
    Assert that a North-West gluing output is consistent with its input plans.

    Parameters
    ----------
    gamma : array
        Full plan with one leading axis followed by one axis per plan in
        `pi_list`.
    J : array (N, K)
        Multi-indices of the glued plan.
    w : array (N,)
        Weights attached to the rows of `J`.
    pi_list : list of arrays
        The K transport plans that were glued.

    Raises
    ------
    AssertionError
        If any of the marginal, size, weight-sum or consistency checks fails.
    """
    nx = get_backend(gamma)
    K = len(pi_list)
    n = pi_list[0].shape[0]
    nk_list = [plan.shape[1] for plan in pi_list]
    axes = tuple(range(K + 1))  # axis 0, then axis k + 1 for plan k

    # Check first marginal
    first_marginal = nx.sum(gamma, axis=axes[1:])
    assert nx.allclose(first_marginal, nx.sum(pi_list[0], axis=1))

    # Check other marginals: sum over every axis except the one of plan k
    for k, plan in enumerate(pi_list):
        summed_axes = tuple(ax for ax in axes if ax != k + 1)
        marginal_k = nx.sum(gamma, axis=summed_axes)
        assert nx.allclose(marginal_k, nx.sum(plan, axis=0))

    # Check bi-marginals: keep axis 0 and the axis of plan k
    for k, plan in enumerate(pi_list):
        summed_axes = tuple(ax for ax in axes[1:] if ax != k + 1)
        bimarginal_k = nx.sum(gamma, axis=summed_axes)
        assert nx.allclose(bimarginal_k, plan)

    # Check that N <= n + sum_k n_k - K
    N = J.shape[0]
    n_k_sum = sum(nk_list)
    assert N <= n + n_k_sum - K, f"N={N}, n={n}, sum(n_k)={n_k_sum}, K={K}"

    # Check that the weights w sum to one
    w_sum = nx.sum(w)
    assert nx.allclose(w_sum, 1), f"Sum of weights w is not 1: {w_sum}"

    # Check that gamma_1...K and (J, w) are consistent: rebuild the K-plan
    # from its sparse representation and compare with gamma summed over axis 0
    rho = nx.zeros(nk_list, type_as=gamma)
    for row in range(N):
        multi_index = J[row]
        rho[tuple(multi_index)] += w[row]

    gamma_1toK = nx.sum(gamma, axis=0)
    assert nx.allclose(rho, gamma_1toK), "rho and gamma_1...K are not consistent"
@pytest.skip_backend("tf")  # skips because of array assignment
@pytest.skip_backend("jax")
def test_north_west_mm_gluing_backends(nx):
    """
    Check that `ot.lp.NorthWestMMGluing` gives the same output for backend
    arrays as for the numpy reference on random transport plans.
    """
    rng = np.random.RandomState(0)
    n = 7
    nk_list = [5, 6, 4]
    a = rng.rand(n)
    a = a / np.sum(a)
    b_list = [rng.rand(nk) for nk in nk_list]
    b_list = [b / np.sum(b) for b in b_list]
    M_list = [rng.rand(n, nk) for nk in nk_list]
    pi_list = [ot.emd(a, b, M) for b, M in zip(b_list, M_list)]

    pi_list2 = [nx.from_numpy(pi) for pi in pi_list]
    J, w, log_dict = ot.lp.NorthWestMMGluing(pi_list2, log=True)
    gamma = log_dict["gamma"]

    # Test equality with numpy solution
    J_np, w_np, log_dict_np = ot.lp.NorthWestMMGluing(pi_list, log=True)
    gamma_np = log_dict_np["gamma"]
    # convert backend outputs explicitly: numpy's testing helpers cannot
    # always consume backend arrays directly (e.g. arrays living on a device)
    np.testing.assert_allclose(nx.to_numpy(J), J_np)
    np.testing.assert_allclose(nx.to_numpy(w), w_np)
    np.testing.assert_allclose(nx.to_numpy(gamma), gamma_np)
def test_to_int_array(nx):
    """
    Check that `_to_int_array` keeps the values of a float array and returns
    an integer-typed array for the backend under test.
    """
    a_np = np.array([1.0, 2.0, 3.0])
    a = nx.from_numpy(a_np)
    a_int = ot.lp._barycenter_solvers._to_int_array(a)
    a_np_int = a_np.astype(int)

    a_int_np = nx.to_numpy(a_int)
    # value comparison alone would also pass for an unconverted float array,
    # so the integer dtype is checked explicitly
    assert np.issubdtype(a_int_np.dtype, np.integer)
    np.testing.assert_allclose(a_int_np, a_np_int)