diff --git a/examples/LinearizedTransport/script.jl b/examples/LinearizedTransport/script.jl
new file mode 100644
index 00000000..d723a764
--- /dev/null
+++ b/examples/LinearizedTransport/script.jl
@@ -0,0 +1,77 @@
+"""
+Optimal transport costs are expensive to compute in general, so scaling can be quite bad
+if we need to, say, compute the OT cost pairwise for a reasonably sized family of
+measures. When this is the situation, it may be beneficial to linearize the OT distance
+using the manifold-like structure induced by the Wasserstein cost. Fix μ, and consider
+the transformation ν → T_ν, where T_ν is the optimal transport map pushing μ forward to
+ν. Now fix two other measures ν, ρ, not equal to μ. We may approximate the Wasserstein
+distance between ν and ρ via W_2(ν, ρ) ≈ ||T_ν - T_ρ||_{L^2(μ)}. If μ, ν, and ρ are
+"nice" (i.e. have smooth and accessible densities w.r.t. the Lebesgue measure), then the
+right-hand side is easy to approximate well via standard numerical methods.
+
+Now, it is a sad fact that recovering the maps T_ν is generally no easy task itself. But
+in the case of entropically regularized transport, there exists a very nice entropic
+approximation to the transport map, which depends only on the measure ν and a family of
+N i.i.d. samples Y_i ∼ ν.
+
+The following example is rather contrived, since if we only wanted to compute one
+distance, we're actually doing much more work than we need to by computing 2 Sinkhorn
+problems and an integral on top of that, but again the main application here would be
+when we have O(n^2) distances to compute.
+
+Note that the choice of reference measure can significantly affect the quality of the
+approximation, and as of writing there is no non-heuristic method for choosing a "good"
+reference.
+
+Relevant sources:
+
+Moosmüller, Caroline, and Alexander Cloninger. "Linear optimal transport embedding:
+provable Wasserstein classification for certain rigid transformations and perturbations."
+Information and Inference: A Journal of the IMA 12.1 (2023): 363-389.
+
+Pooladian, A.-A. and Niles-Weed, J. Entropic estimation of optimal transport maps.
+arXiv: 2109.12004, 2021
+"""
+
+using Distances
+using Distributions
+using OptimalTransport
+
+N = 1000 # number of samples
+ε = 0.1 # entropic regularization parameter
+
+# sample some points according to our chosen reference and target distributions
+μ = rand(Normal(1, 1), N)
+ν = rand(Normal(0, 1), N)
+ρ = rand(Normal(2, 1), N)
+
+# set the weights on the samples to be uniform
+a = fill(1 / N, N)
+
+# Compute the cost matrices between the reference and the two targets. The weights in
+# `entropic_transport_map` are based on the cost ‖x - y‖² / 2, so the Sinkhorn problems
+# have to use the same cost for the dual potentials to match.
+C = pairwise(SqEuclidean(), μ', ν'; dims=2) ./ 2
+D = pairwise(SqEuclidean(), μ', ρ'; dims=2) ./ 2
+
+# get the entropic transport maps
+T_ν = entropic_transport_map(a, a, ν, C, ε, SinkhornGibbs())
+T_ρ = entropic_transport_map(a, a, ρ, D, ε, SinkhornGibbs())
+
+# Integrand for the linearization. The maps return one-element arrays for these
+# one-dimensional samples, so we sum the squared entries of the difference (`^2` is not
+# defined for vectors).
+f(x) = sum(abs2, T_ν([x]) - T_ρ([x]))
+
+# convert target distributions to Dirac clouds
+ν_dist = DiscreteNonParametric(ν, a)
+ρ_dist = DiscreteNonParametric(ρ, a)
+
+# compute and compare
+# naive Monte Carlo approximation of the L2 distance between the entropic maps
+I = sqrt(sum(f, μ) / N)
+# `ot_cost` returns the transport cost w.r.t. the squared Euclidean distance, i.e. the
+# squared distance, so we take the square root to compare it with `I`
+J = sqrt(ot_cost(sqeuclidean, ν_dist, ρ_dist))
+
+println("Linear approximation of the distance: $I; True OT distance: $J")
diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl
index bbf0a29a..9e9d49b3 100644
--- a/src/OptimalTransport.jl
+++ b/src/OptimalTransport.jl
@@ -18,7 +18,7 @@ export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling
 export SinkhornBarycenterGibbs
 export QuadraticOTNewton
 
-export sinkhorn, sinkhorn2
+export sinkhorn, sinkhorn2, sinkhorn_potentials, entropic_transport_map
 export sinkhorn_stabilized, sinkhorn_stabilized_epsscaling, sinkhorn_barycenter
 export sinkhorn_unbalanced, sinkhorn_unbalanced2
 export sinkhorn_divergence, sinkhorn_divergence_unbalanced
diff --git a/src/entropic/sinkhorn.jl b/src/entropic/sinkhorn.jl
index 22c13942..b959f4f3 100644
--- 
a/src/entropic/sinkhorn.jl
+++ b/src/entropic/sinkhorn.jl
@@ -183,6 +183,78 @@ function sinkhorn(μ, ν, C, ε, alg::Sinkhorn; kwargs...)
     return γ
 end
 
+"""
+    sinkhorn_potentials(μ, ν, C, ε, alg::Sinkhorn; kwargs...)
+
+Compute the dual potentials of the entropically regularized optimal transport problem
+with source and target marginals `μ` and `ν`, cost matrix `C` of size
+`(length(μ), length(ν))`, and entropic regularization parameter `ε`.
+
+The problem is solved with the Sinkhorn algorithm `alg`, and the keyword arguments are
+forwarded to the solver (e.g. `atol`, `rtol`, `check_convergence`, and `maxiter`; see
+[`sinkhorn`](@ref) for their meaning).
+
+Return the tuple `(f, g)` with `f = ε * log.(u)` and `g = ε * log.(v)`, where `u` and `v`
+are the variables of the same names in the cache of the solver.
+
+See also: [`sinkhorn`](@ref), [`sinkhorn2`](@ref)
+"""
+function sinkhorn_potentials(μ, ν, C, ε, alg::Sinkhorn; kwargs...)
+    # build solver
+    solver = build_solver(μ, ν, C, ε, alg; kwargs...)
+
+    # perform Sinkhorn algorithm
+    solve!(solver)
+
+    # convert the variables of the solver to dual potentials
+    # NOTE(review): assumes that `u` and `v` are the Sinkhorn scaling vectors for every
+    # `alg`, as the name of the function suggests - confirm for the stabilized algorithms
+    cache = solver.cache
+    f = ε .* log.(cache.u)
+    g = ε .* log.(cache.v)
+
+    return (f, g)
+end
+
+"""
+    entropic_transport_map(μ, ν, samples_ν, C, ε, alg::Sinkhorn; kwargs...)
+
+Compute the entropic estimator of the transport map from `μ` to `ν` for the
+entropically regularized optimal transport problem with source and target marginals `μ`
+and `ν`, cost matrix `C` of size `(length(μ), length(ν))`, and entropic regularization
+parameter `ε`.
+
+The dual potential `g` of `ν` is computed with [`sinkhorn_potentials`](@ref), to which
+`alg` and the keyword arguments are forwarded. The returned function `T` maps a point
+`x` to the weighted average of the samples
+
+    T(x) = Σᵢ wᵢ(x) yᵢ / Σᵢ wᵢ(x),    wᵢ(x) = exp((gᵢ - ‖x - yᵢ‖² / 2) / ε),
+
+where `yᵢ = samples_ν[i, :]` is the sample that belongs to the weight `ν[i]`. The result
+has the same shape as `sum(samples_ν; dims=1)`.
+
+Since the weights are based on the cost `‖x - y‖² / 2`, the entries of `C` should be
+values of this cost; otherwise the potential `g` does not match the weights.
+
+See also: [`sinkhorn_potentials`](@ref)
+"""
+function entropic_transport_map(μ, ν, samples_ν, C, ε, alg::Sinkhorn; kwargs...)
+    _, g = sinkhorn_potentials(μ, ν, C, ε, alg; kwargs...)
+    N = size(ν, 1)
+    function T(x::AbstractVecOrMat)
+        # logarithms of the unnormalized weights of the samples
+        logw = map(1:N) do i
+            return (g[i] - sum(abs2, x .- samples_ν[i, :]) / 2) / ε
+        end
+        # The weighted average is invariant under a common shift of the log-weights.
+        # Subtracting the maximum makes the largest weight equal to 1, so the weights
+        # cannot all underflow to zero (and yield NaN) for small `ε` or distant `x`.
+        w = exp.(logw .- maximum(logw))
+        return sum(w .* samples_ν; dims=1) / sum(w)
+    end
+    return T
+end
+
 function sinkhorn_cost_from_plan(γ, C, ε; regularization=false)
     cost = if regularization
         dot_matwise(γ, C) .+