def _bary_sample_bcd(
    X_s,
    X_init,
    a_s,
    b_init,
    w_s,
    metric,
    inner_solver,
    update_masses,
    warmstart_plan,
    warmstart_potentials,
    stopping_criterion,
    max_iter_bary,
    tol_bary,
    verbose,
    log,
    nx,
):
    """Compute a free-support OT barycenter with Block-Coordinate Descent.

    Each iteration (i) solves one inner OT problem per source distribution
    with the barycenter fixed, then (ii) updates the barycenter samples as the
    plan-weighted average of the source samples.

    Parameters
    ----------
    X_s : list of array-like, shape (n_samples_k, dim)
        List of samples in each source distribution.
    X_init : array-like, shape (n_samples_b, dim)
        Initialization of the barycenter samples.
    a_s : list of array-like, shape (n_samples_k,)
        List of sample weights in each source distribution.
    b_init : array-like, shape (n_samples_b,)
        Initialization of the barycenter weights.
    w_s : array-like, shape (N,)
        Barycentric weight of each of the N source distributions.
    metric : str
        Metric used for the cost matrix. Only "sqeuclidean" and "euclidean"
        are supported by the barycenter update.
    inner_solver : callable
        Function solving one inner OT problem, called as
        ``inner_solver(X_k, X, a_k, b, plan_init, potentials_init)``. The
        returned object must expose ``plan``, ``potentials``, ``value`` and
        ``value_linear``.
    update_masses : bool
        If True, re-estimate the barycenter weights at each iteration from the
        second marginal of the plans (used for unbalanced OT).
    warmstart_plan : bool
        Use the previous plan as initialization for the inner solver. Takes
        precedence over `warmstart_potentials`.
    warmstart_potentials : bool
        Use the previous potentials as initialization for the inner solver.
    stopping_criterion : str
        Quantity monitored for convergence: "loss" (weighted sum of the inner
        values) or "bary" (norm of the barycenter displacement).
    max_iter_bary : int
        Maximum number of BCD iterations (at least 1).
    tol_bary : float
        Tolerance on the relative change of the monitored criterion.
    verbose : bool
        Print the criterion at each iteration and the convergence message.
    log : bool
        Record the criterion of each iteration in the returned log.
    nx : backend
        Backend used for the computation.

    Returns
    -------
    res : BaryResult
        Barycenter samples ``X`` and weights ``b``, the loss ``value`` and
        ``value_linear``, the ``log`` dictionary (``None`` if `log` is False)
        and the inner results of the last iteration in ``list_res``.

    Raises
    ------
    NotImplementedError
        If `metric` is not "sqeuclidean" or "euclidean".
    ValueError
        If `max_iter_bary` is smaller than 1.
    """
    # Fail fast: no need to solve the inner problems to find out that the
    # barycenter update is not available for this metric.
    if metric not in ["sqeuclidean", "euclidean"]:
        raise NotImplementedError('Not implemented metric="{}"'.format(metric))
    if max_iter_bary < 1:
        raise ValueError(
            "max_iter_bary must be at least 1, got {}".format(max_iter_bary)
        )

    n_distributions = len(X_s)
    indices = range(n_distributions)

    X = X_init
    b = b_init
    inv_b = 1.0 / b

    prev_criterion = None
    list_res = None
    log_ = {"stopping_criterion": []} if log else None

    for it in range(max_iter_bary):
        # Choose the initialization of the inner solvers (nothing to reuse on
        # the first iteration).
        if it == 0 or not (warmstart_plan or warmstart_potentials):
            inits = [(None, None)] * n_distributions
        elif warmstart_plan:
            inits = [(res.plan, None) for res in list_res]
        else:
            inits = [(None, res.potentials) for res in list_res]

        # Solve the inner OT problem for each source distribution
        list_res = [
            inner_solver(X_s[k], X, a_s[k], b, inits[k][0], inits[k][1])
            for k in indices
        ]

        # Update the estimated barycenter weights in unbalanced cases
        if update_masses:
            b = sum([w_s[k] * list_res[k].plan.sum(axis=0) for k in indices])
            inv_b = 1.0 / b

        # Update the barycenter samples (weighted barycentric projection)
        X_new = (
            sum([w_s[k] * list_res[k].plan.T @ X_s[k] for k in indices])
            * inv_b[:, None]
        )

        # compute criterion
        if stopping_criterion == "loss":
            new_criterion = sum([w_s[k] * list_res[k].value for k in indices])
        else:  # stopping_criterion = "bary"
            new_criterion = nx.norm(X_new - X, ord=2)

        if verbose:
            print(
                f"BCD iteration {it}: criterion {stopping_criterion} = {new_criterion:.4f}"
            )

        if log:
            log_["stopping_criterion"].append(new_criterion)

        # Check convergence on the relative change of the criterion. There is
        # no previous value on the first iteration, and the change is taken
        # in absolute terms when the previous criterion is exactly zero so
        # that we never divide by zero.
        if prev_criterion is not None:
            change = abs(new_criterion - prev_criterion)
            scale = abs(prev_criterion)
            if scale > 0:
                change = change / scale
            if change < tol_bary:
                if verbose:
                    print(f"BCD converged in {it} iterations")
                break

        X = X_new
        prev_criterion = new_criterion

    # compute loss values
    value_linear = sum([w_s[k] * list_res[k].value_linear for k in indices])
    if stopping_criterion == "loss":
        value = new_criterion
    else:
        value = sum([w_s[k] * list_res[k].value for k in indices])

    return BaryResult(
        X=X_new,
        b=b,
        value=value,
        value_linear=value_linear,
        log=log_,
        list_res=list_res,
        backend=nx,
    )


def bary_sample(
    X_s,
    n,
    a_s=None,
    w_s=None,
    X_init=None,
    b_init=None,
    learn_X=True,
    learn_b=False,
    metric="sqeuclidean",
    reg=None,
    c=None,
    reg_type="KL",
    unbalanced=None,
    unbalanced_type="KL",
    lazy=False,
    batch_size=None,
    method=None,
    n_threads=1,
    warmstart=False,
    stopping_criterion="loss",
    max_iter_bary=1000,
    max_iter=None,
    rank=100,
    scaling=0.95,
    tol_bary=1e-5,
    tol=None,
    random_state=0,
    verbose=False,
):
    r"""Solve the discrete OT barycenter problem over source distributions using Block-Coordinate Descent.

    The function solves the following general OT barycenter problem

    .. math::
        \min_{\mathbf{X} \in \mathbb{R}^{n \times d}, \mathbf{b} \in \Sigma_n}
        \min_{\{ \mathbf{T}^{(k)} \}_k \in \mathbb{R}_+^{n_k \times n}} \quad
        \sum_k w_k \left[ \sum_{i,j} T^{(k)}_{i,j}M^{(k)}_{i,j}
        + \lambda_r R(\mathbf{T}^{(k)})
        + \lambda_u U(\mathbf{T}^{(k)}\mathbf{1},\mathbf{a}^{(k)})
        + \lambda_u U(\mathbf{T}^{(k)T}\mathbf{1},\mathbf{b}) \right]

    where the cost matrix :math:`\mathbf{M}^{(k)}` for each input distribution
    :math:`(\mathbf{X}^{(k)}, \mathbf{a}^{(k)})` is computed from the samples
    in the source and barycenter domains such that
    :math:`M^{(k)}_{i,j} = d(x^{(k)}_i,x_j)` where :math:`d` is a metric (by
    default the squared Euclidean distance).

    The regularization is selected with `reg` (:math:`\lambda_r`) and
    `reg_type`. By default ``reg=None`` and there is no regularization. The
    unbalanced marginal penalization can be selected with `unbalanced`
    (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None``
    and the function solves the exact optimal transport problem (respecting
    the marginals).

    Parameters
    ----------
    X_s : list of array-like, shape (n_samples_k, dim)
        List of samples in each source distribution
    n : int
        number of samples in the barycenter domain
    a_s : list of array-like, shape (n_samples_k,), optional
        List of samples weights in each source distribution (default is uniform)
    w_s : array-like, shape (N,), optional
        Barycentric weight of each of the N source distributions (default is
        uniform)
    X_init : array-like, shape (n, dim), optional
        Initialization of the barycenter samples (default is gaussian random
        sampling). Shape must match with required n.
    b_init : array-like, shape (n,), optional
        Initialization of the barycenter weights (default is uniform).
        Shape must match with required n.
    learn_X : bool, optional
        Learn the barycenter samples (default is True). Currently the samples
        are always learned.
    learn_b : bool, optional
        Learn the barycenter weights (default is False). ``learn_b=True`` is
        not implemented yet.
    metric : str, optional
        Metric to use for the cost matrix, by default "sqeuclidean". Only
        "sqeuclidean" and "euclidean" are supported.
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    c : array-like, shape (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}=`'entropy', then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
    unbalanced : float or indexable object of length 1 or 2
        Marginal relaxation term.
        If it is a scalar or an indexable object of length 1,
        then the same relaxation is applied to both marginal relaxations.
        The balanced OT can be recovered using :math:`unbalanced=float("inf")`.
        For semi-relaxed case, use either
        :math:`unbalanced=(float("inf"), scalar)` or
        :math:`unbalanced=(scalar, float("inf"))`.
        If unbalanced is an array,
        it must have the same backend as input arrays `(a, b, M)`.
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
    lazy : bool, optional
        Must be False (default): the lazy barycenter solver is not implemented
        yet.
    batch_size : int, optional
        Batch size for lazy solver, by default None. Currently unused.
    method : str, optional
        Method for solving the inner problems, this can be used to select the
        solver for unbalanced problems (see :any:`ot.solve`). Lazy methods
        are not supported yet.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    warmstart : bool, optional
        Use the previous OT plan or potentials as initialization for the next
        inner solver iteration, by default False.
    stopping_criterion : str, optional
        Stopping criterion for the outer loop of the BCD solver, by default 'loss'.
        Either 'loss' to use the optimized objective or 'bary' for variations
        of the barycenter w.r.t. the Frobenius norm.
    max_iter_bary : int, optional
        Maximum number of iterations for the BCD solver, by default 1000.
    max_iter : int, optional
        Maximum number of iterations of the inner OT solver, by default None
        (default values in each solvers)
    rank : int, optional
        Rank of the OT matrix for lazy solvers (method='factored'), by default
        100. Currently unused.
    scaling : float, optional
        Scaling factor for the epsilon scaling lazy solvers
        (method='geomloss'), by default 0.95. Currently unused.
    tol_bary : float, optional
        Tolerance for solution precision of barycenter problem, by default
        1e-5 (also used when None is given)
    tol : float, optional
        Tolerance for solution precision of inner OT solver, by default None (default values in each solvers)
    random_state : int, optional
        Random seed for the initialization of the barycenter samples, by default 0.
        Only used if `X_init` is None.
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------
    res : BaryResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.X : Barycenter samples
        - res.b : Barycenter weights
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plans
        - res.list_res : List of :any:`OTResult` of the inner OT problems
        - res.log : Dictionary with the values of the stopping criterion

        See :any:`BaryResult` for more information.

    Raises
    ------
    NotImplementedError
        If ``learn_b=True``, ``lazy=True``, `method` is a lazy method or
        `metric` is not supported.
    ValueError
        If `stopping_criterion` is unknown or if the shape of `X_init` or
        `b_init` does not match `n`.

    Notes
    -----
    Each BCD iteration solves one inner OT problem per source distribution
    with :any:`ot.solve_sample`, to which `metric`, `reg`, `c`, `reg_type`,
    `unbalanced`, `unbalanced_type`, `method`, `n_threads`, `max_iter` and
    `tol` are forwarded. See :any:`ot.solve_sample` for the formulation and
    references of each inner problem. The inner problem is selected as follows:

    - **Classical exact OT** (default parameters):

      .. code-block:: python

          res = ot.bary_sample(X_s, n)

          # with sample weights and barycentric weights
          res = ot.bary_sample(X_s, n, a_s=a_s, w_s=w_s)

    - **Regularized OT** (when ``reg!=None``):

      .. code-block:: python

          # default is ``"KL"`` regularization (``reg_type="KL"``)
          res = ot.bary_sample(X_s, n, reg=1.0)
          # quadratic regularization
          res = ot.bary_sample(X_s, n, reg=1.0, reg_type='L2')

    - **Unbalanced OT** (when ``unbalanced!=None``), in which case the
      barycenter weights are re-estimated from the plans at each iteration:

      .. code-block:: python

          # default is ``"KL"``
          res = ot.bary_sample(X_s, n, unbalanced=1.0)
          # regularized unbalanced OT
          res = ot.bary_sample(X_s, n, reg=1.0, unbalanced=1.0)

    With ``warmstart=True`` the inner solvers are initialized with the
    potentials of the previous iteration when ``reg!=None`` and both
    `reg_type` and `unbalanced_type` are "KL", and with the previous plan
    otherwise.

    .. _references-bary-sample:
    References
    ----------

    """
    if learn_b:
        raise NotImplementedError("Barycenter weights learning not implemented yet")

    if method is not None and method.lower() in lst_method_lazy:
        raise NotImplementedError("Barycenter with Lazy tensors not implemented yet")

    if stopping_criterion not in ["loss", "bary"]:
        raise ValueError(
            "stopping_criterion must be either 'loss' or 'bary', got {}".format(
                stopping_criterion
            )
        )

    if lazy:
        raise NotImplementedError("Barycenter solver with lazy=True not implemented")

    # The docstring historically advertised None as a valid value
    if tol_bary is None:
        tol_bary = 1e-5

    # default non lazy solver calls ot.solve_sample within _bary_sample_bcd
    n_samples = len(X_s)

    # Detect backend
    nx = get_backend(*X_s, X_init, b_init, w_s)

    # check sample weights
    if a_s is None:
        a_s = [
            nx.ones((X_s[k].shape[0],), type_as=X_s[k]) / X_s[k].shape[0]
            for k in range(n_samples)
        ]

    # check samples barycentric weights
    if w_s is None:
        w_s = nx.ones(n_samples, type_as=X_s[0]) / n_samples

    # check X_init
    if X_init is None:
        if (not learn_X) and learn_b:
            raise ValueError(
                "X_init must be provided if learn_X=False and learn_b=True"
            )
        rng = np.random.RandomState(random_state)
        # NOTE(review): the per-distribution statistics are concatenated (not
        # stacked) before averaging, so a single scalar mean/std is shared by
        # all dimensions -- confirm this is intended.
        mean_ = nx.concatenate(
            [nx.mean(X_s[k], axis=0) for k in range(n_samples)],
            axis=0,
        )
        mean_ = nx.mean(mean_, axis=0)
        std_ = nx.concatenate(
            [nx.std(X_s[k], axis=0) for k in range(n_samples)],
            axis=0,
        )
        std_ = nx.mean(std_, axis=0)
        # NumPy's generator must be fed NumPy values, whatever the backend
        X_init = rng.normal(
            loc=nx.to_numpy(mean_),
            scale=nx.to_numpy(std_),
            size=(n, X_s[0].shape[1]),
        )
        X_init = nx.from_numpy(X_init, type_as=X_s[0])
    elif (X_init.shape[0] != n) or (X_init.shape[1] != X_s[0].shape[1]):
        raise ValueError("X_init must have shape (n, dim)")

    # check b_init
    if b_init is None:
        b_init = nx.ones((n,), type_as=X_s[0]) / n
    elif b_init.shape[0] != n:
        raise ValueError("b_init must have shape (n,)")

    # Potentials are reused for KL-regularized problems with KL marginal
    # penalization; every other warm-started solver reuses the plan.
    warmstart_potentials = bool(
        warmstart
        and reg is not None
        and not isinstance(reg_type, tuple)
        and reg_type.lower() == "kl"
        and unbalanced_type.lower() == "kl"
    )
    warmstart_plan = bool(warmstart) and not warmstart_potentials

    def inner_solver(X_a, X, a, b, plan_init, potentials_init):
        return solve_sample(
            X_a=X_a,
            X_b=X,
            a=a,
            b=b,
            metric=metric,
            reg=reg,
            c=c,
            reg_type=reg_type,
            unbalanced=unbalanced,
            unbalanced_type=unbalanced_type,
            method=method,
            n_threads=n_threads,
            max_iter=max_iter,
            tol=tol,
            plan_init=plan_init,
            potentials_init=potentials_init,
            verbose=False,
        )

    # compute the barycenter using BCD
    update_masses = unbalanced is not None
    return _bary_sample_bcd(
        X_s,
        X_init,
        a_s,
        b_init,
        w_s,
        metric,
        inner_solver,
        update_masses,
        warmstart_plan,
        warmstart_potentials,
        stopping_criterion,
        max_iter_bary,
        tol_bary,
        verbose,
        True,  # log set to True by default
        nx,
    )
class BaryResult:
    """Base class for OT barycenter results.

    All values given at construction are exposed through read-only properties
    of the same name.

    Parameters
    ----------
    X : array-like, shape (`n`, `d`)
        Barycenter features.
    C: array-like, shape (`n`, `n`)
        Barycenter structure for Gromov Wasserstein solutions.
    b : array-like, shape (`n`,)
        Barycenter weights.
    value : float, array-like
        Full transport cost, including possible regularization terms and
        quadratic term for Gromov Wasserstein solutions.
    value_linear : float, array-like
        The linear part of the transport cost, i.e. the product between the
        transport plan and the cost.
    value_quad : float, array-like
        The quadratic part of the transport cost for Gromov-Wasserstein
        solutions.
    log : dict
        Dictionary containing potential information about the solver.
    list_res: list of OTResult
        List of results for the individual OT matching.
    status : int or str
        Status of the solver.
    backend : Backend
        Backend used to compute the results. A ``NumpyBackend`` is used when
        None is given.

    Attributes
    ----------

    X : array-like, shape (`n`, `d`)
        Barycenter features.
    C: array-like, shape (`n`, `n`)
        Barycenter structure for Gromov Wasserstein solutions.
    b : array-like, shape (`n`,)
        Barycenter weights.
    value : float, array-like
        Full transport cost, including possible regularization terms and
        quadratic term for Gromov Wasserstein solutions.
    value_linear : float, array-like
        The linear part of the transport cost, i.e. the product between the
        transport plan and the cost.
    value_quad : float, array-like
        The quadratic part of the transport cost for Gromov-Wasserstein
        solutions.
    log : dict
        Dictionary containing potential information about the solver.
    list_res: list of OTResult
        List of results for the individual OT matching.
    status : int or str
        Status of the solver.
    """

    def __init__(
        self,
        X=None,
        C=None,
        b=None,
        value=None,
        value_linear=None,
        value_quad=None,
        log=None,
        list_res=None,
        status=None,
        backend=None,
    ):
        self._X = X
        self._C = C
        self._b = b
        self._value = value
        self._value_linear = value_linear
        self._value_quad = value_quad
        self._log = log
        self._list_res = list_res
        self._status = status
        self._backend = backend if backend is not None else NumpyBackend()

    def __repr__(self):
        # Only the fields that are set are reported; arrays are summarized by
        # their type and shape.
        fields = []
        if self._value is not None:
            fields.append("value={}".format(self._value))
        if self._value_linear is not None:
            fields.append("value_linear={}".format(self._value_linear))
        for name, array in (("X", self._X), ("C", self._C), ("b", self._b)):
            if array is not None:
                fields.append(
                    "{}={}(shape={})".format(
                        name, array.__class__.__name__, array.shape
                    )
                )
        return "BaryResult(" + ",".join(fields) + ")"

    # Barycenters --------------------------------

    @property
    def X(self):
        """Barycenter features."""
        return self._X

    @property
    def C(self):
        """Barycenter structure for Gromov Wasserstein solutions."""
        return self._C

    @property
    def b(self):
        """Barycenter weights."""
        return self._b

    # Loss values --------------------------------

    @property
    def value(self):
        """Full transport cost, including possible regularization terms and
        quadratic term for Gromov Wasserstein solutions."""
        return self._value

    @property
    def value_linear(self):
        """The "minimal" transport cost, i.e. the product between the transport plan and the cost."""
        return self._value_linear

    @property
    def value_quad(self):
        """The quadratic part of the transport cost for Gromov-Wasserstein solutions."""
        return self._value_quad

    # List of OTResult objects -------------------------

    @property
    def list_res(self):
        """List of results for the individual OT matching."""
        return self._list_res

    @property
    def status(self):
        """Optimization status of the solver."""
        return self._status

    @property
    def log(self):
        """Dictionary containing potential information about the solver."""
        return self._log

    # Miscellaneous --------------------------------

    @property
    def citation(self):
        """Appropriate citation(s) for this result, in plain text and BibTex formats."""

        # The string below refers to the POT library:
        # successor methods may concatenate the relevant references
        # to the original definitions, solvers and underlying numerical backends.
        # It is a raw string so that the BibTeX accents (``\'e``) keep their
        # backslash instead of being read as Python escape sequences.
        return r"""POT library:

            POT Python Optimal Transport library, Journal of Machine Learning Research, 22(78):1−8, 2021.
            Website: https://pythonot.github.io/
            Rémi Flamary, Nicolas Courty, Alexandre Gramfort, Mokhtar Z. Alaya, Aurélie Boisbunon, Stanislas Chambon, Laetitia Chapel, Adrien Corenflos, Kilian Fatras, Nemo Fournier, Léo Gautheron, Nathalie T.H. Gayraud, Hicham Janati, Alain Rakotomamonjy, Ievgen Redko, Antoine Rolet, Antony Schutz, Vivien Seguy, Danica J. Sutherland, Romain Tavenard, Alexander Tong, Titouan Vayer;

            @article{flamary2021pot,
              author  = {R{\'e}mi Flamary and Nicolas Courty and Alexandre Gramfort and Mokhtar Z. Alaya and Aur{\'e}lie Boisbunon and Stanislas Chambon and Laetitia Chapel and Adrien Corenflos and Kilian Fatras and Nemo Fournier and L{\'e}o Gautheron and Nathalie T.H. Gayraud and Hicham Janati and Alain Rakotomamonjy and Ievgen Redko and Antoine Rolet and Antony Schutz and Vivien Seguy and Danica J. Sutherland and Romain Tavenard and Alexander Tong and Titouan Vayer},
              title   = {{POT}: {Python} {Optimal} {Transport}},
              journal = {Journal of Machine Learning Research},
              year    = {2021},
              volume  = {22},
              number  = {78},
              pages   = {1-8},
              url     = {http://jmlr.org/papers/v22/20-451.html}
            }
        """


def assert_allclose_bary_sol(sol1, sol2):
    """Assert that two barycenter solutions are numerically close.

    The attributes ``X``, ``b``, ``value``, ``value_linear`` and ``log`` are
    compared with the tolerances of ``np.allclose`` (NaNs are considered
    equal). An attribute that is None in both solutions is skipped.

    Returns
    -------
    bool
        True when all the attributes match.

    Raises
    ------
    AssertionError
        If an attribute is set in only one solution or if the values differ.
    """
    nx1 = sol1._backend if sol1._backend is not None else ot.backend.NumpyBackend()
    nx2 = sol2._backend if sol2._backend is not None else ot.backend.NumpyBackend()

    def check(val1, val2, what):
        np.testing.assert_allclose(
            val1, val2, rtol=1e-5, atol=1e-8, equal_nan=True, err_msg=what
        )

    for attr in ["X", "b", "value", "value_linear", "log"]:
        var1 = getattr(sol1, attr)
        var2 = getattr(sol2, attr)

        if var1 is None and var2 is None:
            continue
        if var1 is None or var2 is None:
            raise AssertionError(
                "attribute '{}' is set in only one of the solutions".format(attr)
            )

        if isinstance(var1, dict):  # only contains lists
            for key in var1.keys():
                check(
                    np.array(var1[key]),
                    np.array(var2[key]),
                    "{}[{!r}]".format(attr, key),
                )
            continue

        try:
            val1 = nx1.to_numpy(var1)
            val2 = nx2.to_numpy(var2)
        except NotImplementedError:
            # the backend cannot export this attribute: nothing to compare
            continue
        check(val1, val2, attr)

    return True
@pytest.skip_backend("jax", reason="test very slow with jax backend")
@pytest.skip_backend("tf", reason="test very slow with tf backend")
@pytest.mark.parametrize(
    "reg,reg_type,unbalanced,unbalanced_type,warmstart",
    itertools.product(
        lst_reg,
        ["tuple"],
        lst_unbalanced,
        lst_unbalanced_type,
        [True, False],
        # lst_reg, lst_reg_type, lst_unbalanced, lst_unbalanced_type, [True, False]
    ),
)
def test_bary_sample_free_support(
    nx, reg, reg_type, unbalanced, unbalanced_type, warmstart
):
    """Check that bary_sample (lazy=False) is consistent across equivalent calls.

    The same barycenter problem is solved with default weights, with explicit
    uniform weights and with the tested backend, and the solutions are
    compared. Configurations that are not implemented are skipped.
    """
    # fixed seed so that a failure can be reproduced
    rng = np.random.RandomState(42)

    K = 3  # number of distributions
    ns = rng.randint(10, 20, K)  # number of samples within each distribution
    n = 5  # number of samples in the barycenter

    X_s = [rng.randn(ns_i, 2) for ns_i in ns]

    a_s = [ot.utils.unif(X.shape[0]) for X in X_s]
    b = ot.utils.unif(n)

    w_s = ot.utils.unif(K)

    if reg_type == "tuple":
        # user-defined regularization given as (function, gradient)

        def f(G):
            return np.sum(G**2)

        def df(G):
            return 2 * G

        reg_type = (f, df)

    # options shared by all the calls
    solver_kwargs = dict(
        metric="sqeuclidean",
        reg=reg,
        reg_type=reg_type,
        unbalanced=unbalanced,
        unbalanced_type=unbalanced_type,
        warmstart=warmstart,
        max_iter_bary=3,
        tol_bary=1e-3,
        verbose=True,
    )

    try:
        # solve default None weights
        sol0 = ot.bary_sample(X_s, n, w_s=None, **solver_kwargs)

        # solve provided uniform weights
        sol = ot.bary_sample(X_s, n, a_s=a_s, b_init=b, w_s=w_s, **solver_kwargs)

        assert_allclose_bary_sol(sol0, sol)

        # solve in backend
        X_sb = nx.from_numpy(*X_s)
        a_sb = nx.from_numpy(*a_s)
        w_sb, bb = nx.from_numpy(w_s, b)

        solb = ot.bary_sample(
            X_sb, n, a_s=a_sb, b_init=bb, w_s=w_sb, **solver_kwargs
        )

        assert_allclose_bary_sol(sol, solb)

    except NotImplementedError:
        pytest.skip("Not implemented")


def test_BaryResult():
    """An empty BaryResult prints, provides a citation and exposes None everywhere."""
    res = ot.utils.BaryResult()

    # test print
    print(res)
    assert repr(res) == "BaryResult()"

    # test get citation
    print(res.citation)

    lst_attributes = [
        "X",
        "C",
        "b",
        "value",
        "value_linear",
        "value_quad",
        "list_res",
        "status",
        "log",
    ]
    for at in lst_attributes:
        print(at)
        assert getattr(res, at) is None