diff --git a/ot/batch/_quadratic.py b/ot/batch/_quadratic.py index 0da4b8962..e549c7e72 100644 --- a/ot/batch/_quadratic.py +++ b/ot/batch/_quadratic.py @@ -152,6 +152,90 @@ def h2(C2): return compute_tensor_batch(f1, f2, h1, h2, a, b, C1, C2, symmetric=symmetric) +def div_to_product_batch( + T, a, b, T1=None, T2=None, divergence="kl", mass=True, nx=None +): + r"""Fast computation of the Bregman divergence between a batch of arbitrary measures and a product measures. + Only support for Kullback-Leibler and half-squared L2 divergences. + + - For half-squared L2 divergence: + + .. math:: + \frac{1}{2} || \pi - a \otimes b ||^2 + = \frac{1}{2} \Big[ \sum_{i, j} \pi_{ij}^2 + (\sum_i a_i^2) ( \sum_j b_j^2) - 2 \sum_{i, j} a_i \pi_{ij} b_j \Big] + + - For Kullback-Leibler divergence: + + .. math:: + KL(\pi | a \otimes b) + = \langle \pi, \log \pi \rangle - \langle \pi_1, \log a \rangle + - \langle \pi_2, \log b \rangle - m(\pi) + m(a) m(b) + + where : + + - :math:`\pi` is the (`dim_a`, `dim_b`) transport plan + - :math:`\pi_1` and :math:`\pi_2` are the marginal distributions + - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions + - :math:`m` denotes the mass of the measure + + Parameters + ---------- + pi : array-like (B, n, m) + Transport plan for each problem in the batch + a : array-like (B,n) + Unnormalized histogram of dimension `n` for each problem in the batch + b : array-like (B,m) + Unnormalized histogram of dimension `m` for each problem in the batch + T1 : array-like (B, n), optional (default = None) + Marginal distribution with respect to the first dimension of the transport plan for each problem in the batch + Only used in case of Kullback-Leibler divergence. + T2 : array-like (B, m), optional (default = None) + Marginal distribution with respect to the second dimension of the transport plan for each problem in the batch + Only used in case of Kullback-Leibler divergence. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + mass : bool, optional. Default is False. + Only used in case of Kullback-Leibler divergence. + If False, calculate the relative entropy. + If True, calculate the Kullback-Leibler divergence. + nx : backend, optional + If let to its default value None, a backend test will be conducted. + + Returns + ------- + Bregman divergence between an arbitrary measure and a product measure for each problem in the batch. + """ + + arr = [T, a, b, T1, T2] + + if nx is None: + nx = get_backend(*arr, T1, T2) + + if divergence == "kl": + if T1 is None: + T1 = nx.sum(T, 2) + if T2 is None: + T2 = nx.sum(T, 1) + + if divergence == "kl": + res = ( + nx.sum((T * nx.log(T + 1.0 * (T == 0))), (1, 2)) + - nx.sum(T1 * nx.log(a), 1) + - nx.sum(T2 * nx.log(b), 1) + ) + if mass: + res = res - nx.sum(T1, 1) + nx.sum(a, 1) * nx.sum(b, 1) + + elif divergence == "l2": + res = ( + nx.sum(T**2, (1, 2)) + + nx.sum(a**2, 1) * nx.sum(b**2, 1) + - 2 * nx.sum((a * (T @ b[:, :, None]).squeeze(-1)), 1) + ) / 2 + + return res + + def loss_quadratic_batch(L, T, recompute_const=False, symmetric=True, nx=None): r""" Computes the gromov-wasserstein cost given a cost tensor and transport plan. Batched version. @@ -266,6 +350,74 @@ def loss_quadratic_samples_batch( ) +def loss_fugw_batch( + L, M, T, alpha=0.5, reg_marginals=1, symmetric=True, divergence="kl", nx=None +): + r""" + Computes the fused unbalanced gromov-wasserstein cost given a cost tensor (Gromov term), a cost matrix between features across domains (linear term) and a transport plan. Batched version. + + Parameters + ---------- + L : dict + Cost tensor as returned by `tensor_batch`. + M : array-like, shape (B, n, m) + Cost matrix between features across domains. + T : array-like, shape (B, n, m) + Transport plan. + alpha : float or array-like( B,) optional + Weight the quadratic term (alpha*Gromov) and the linear term + ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. If alpha + a scalar it is used for all problems in the batch. + reg_marginals : float or array-like( B,) optional + Marginal relaxation terms. If rho is + a scalar it is used for all problems in the batch. + symmetric : bool, optional + Whether to use symmetric version. Default is True. + divergence : string, default = "kl" + Bregman divergence, either "kl" (Kullback-Leibler divergence) or "l2" (half-squared L2 divergence) + nx : module, optional + Backend to use. Default is None. + + Examples + -------- + >>> import numpy as np + >>> from ot.batch import tensor_batch, loss_quadratic_batch + >>> # Create batch of cost matrices + >>> C1 = np.random.rand(3, 5, 5) # 3 problems, 5x5 source matrices + >>> C2 = np.random.rand(3, 4, 4) # 3 problems, 4x4 target matrices + >>> a = np.ones((3, 5)) / 5 # Uniform source distributions + >>> b = np.ones((3, 4)) / 4 # Uniform target distributions + >>> L = tensor_batch(a, b, C1, C2, loss='sqeuclidean') + >>> # Use the uniform transport plan for testing + >>> T = np.ones((3, 5, 4)) / (5 * 4) + >>> loss = loss_quadratic_batch(L, T, recompute_const=True) + >>> loss.shape + (3,) + + See Also + -------- + ot.batch.tensor_batch : From computing the cost tensor L. + ot.batch.solve_gromov_batch : For finding the optimal transport plan T. + """ + if nx is None: + nx = get_backend(T) + + Q = loss_quadratic_batch(L, T, recompute_const=True, symmetric=symmetric, nx=nx) + + L = loss_linear_batch(M, T, nx=nx) + + unbalanced = div_to_product_batch( + T, + a=nx.sum(T, axis=2), + b=nx.sum(T, axis=1), + divergence=divergence, + mass=True, + nx=nx, + ) + + return (1 - alpha) * L + alpha * Q + reg_marginals * unbalanced + + def solve_gromov_batch( C1, C2, diff --git a/test/batch/test_solve_unbalanced_batch.py b/test/batch/test_solve_unbalanced_batch.py new file mode 100644 index 000000000..e69de29bb