Skip to content

[WIP] Partial optimal transport 1d solver #741

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
aef4b55
Implemented the partial optimal transport 1d solver from Chapel & Tav…
rtavenar Jul 1, 2025
c36a45d
bugfix in cythonize
rtavenar Jul 1, 2025
5a6ef86
yet another bugfix in setup.py
rtavenar Jul 1, 2025
978e0f4
relative imports
rtavenar Jul 1, 2025
189a98b
make partial_wasserstein_1d visible at the subpackage level
rtavenar Jul 1, 2025
e873da6
minor fix
rtavenar Jul 1, 2025
ea5f59b
add tests
rtavenar Jul 2, 2025
d0689aa
fix data gen in test
rtavenar Jul 2, 2025
ae504fa
bugfix
rtavenar Jul 2, 2025
5214d4b
make costs double
rtavenar Jul 2, 2025
8245cd9
renaming
rtavenar Jul 2, 2025
8a78d6d
minor
rtavenar Jul 2, 2025
796d09d
remove unused log arg
rtavenar Jul 2, 2025
a0c241d
check precommit
rtavenar Jul 2, 2025
5b37670
better docs and test
rtavenar Jul 2, 2025
6ef4ca3
linting
rtavenar Jul 2, 2025
7eee412
info
rtavenar Jul 2, 2025
b68aa7d
empty commit for co-authorship
rtavenar Jul 2, 2025
843324c
add gallery example
rtavenar Jul 2, 2025
ff42483
minor refactor
rtavenar Jul 2, 2025
ec7d9c1
bugfix: use heapq also at init step
rtavenar Jul 2, 2025
bbc930b
define a function for plotting
rtavenar Jul 2, 2025
2ba8d56
bugfix
rtavenar Jul 2, 2025
9958a31
example fig tweaking
rtavenar Jul 2, 2025
27ec673
minor docs
rtavenar Jul 2, 2025
f11575d
figsize
rtavenar Jul 3, 2025
98967fb
minor
rtavenar Jul 3, 2025
63b5c05
removed pure python types as much as possible (yet to be tested prope…
rtavenar Jul 3, 2025
6ccad78
bugfix in insert_new_chain
rtavenar Jul 4, 2025
11ab6b3
added tests
rtavenar Jul 4, 2025
eba542d
should not change anything
rtavenar Jul 4, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
- Backend implementation of `ot.dist` for (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)
- Fix jax version for auto-grad (PR #732)
- Implement 1d solver for partial optimal transport (PR #741)

#### Closed issues
- Fixed `ot.mapping` solvers which depended on deprecated `cvxpy` `ECOS` solver (PR #692, Issue #668)
Expand Down
85 changes: 85 additions & 0 deletions examples/unbalanced-partial/plot_partial_1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
=========================
Partial Wasserstein in 1D
=========================

This script demonstrates how to compute and visualize the Partial Wasserstein distance between two 1D discrete distributions using `ot.partial.partial_wasserstein_1d`.

We illustrate the intermediate transport plans for all `k = 1...n`, where `n = min(len(x_a), len(x_b))`.
"""

# sphinx_gallery_thumbnail_number = 5

import numpy as np
import matplotlib.pyplot as plt
from ot.partial import partial_wasserstein_1d


def plot_partial_transport(
ax, x_a, x_b, indices_a=None, indices_b=None, marginal_costs=None
):
y_a = np.ones_like(x_a)
y_b = -np.ones_like(x_b)
min_min = min(x_a.min(), x_b.min())
max_max = max(x_a.max(), x_b.max())

ax.plot([min_min - 1, max_max + 1], [1, 1], "k-", lw=0.5, alpha=0.5)
ax.plot([min_min - 1, max_max + 1], [-1, -1], "k-", lw=0.5, alpha=0.5)

# Plot transport lines
if indices_a is not None and indices_b is not None:
subset_a = np.sort(x_a[indices_a])
subset_b = np.sort(x_b[indices_b])

for x_a_i, x_b_j in zip(subset_a, subset_b):
ax.plot([x_a_i, x_b_j], [1, -1], "k--", alpha=0.7)

# Plot all points
ax.plot(x_a, y_a, "o", color="C0", label="x_a", markersize=8)
ax.plot(x_b, y_b, "o", color="C1", label="x_b", markersize=8)

if marginal_costs is not None:
k = len(marginal_costs)
ax.set_title(
f"Partial Transport - k = {k}, Cumulative Cost = {sum(marginal_costs):.2f}",
fontsize=16,
)
else:
ax.set_title("Original 1D Discrete Distributions", fontsize=16)
ax.legend(loc="upper right", fontsize=14)
ax.set_yticks([])
ax.set_xticks([])
ax.set_ylim(-2, 2)
ax.set_xlim(min(x_a.min(), x_b.min()) - 1, max(x_a.max(), x_b.max()) + 1)
ax.axis("off")


# Simulate two 1D discrete distributions
np.random.seed(0)
n = 6
x_a = np.sort(np.random.uniform(0, 10, size=n))
x_b = np.sort(np.random.uniform(0, 10, size=n))

# Plot original distributions
plt.figure(figsize=(6, 2))
plot_partial_transport(plt.gca(), x_a, x_b)
plt.show()

# %%
indices_a, indices_b, marginal_costs = partial_wasserstein_1d(x_a, x_b)

# Compute cumulative cost
cumulative_costs = np.cumsum(marginal_costs)

# Visualize all partial transport plans
for k in range(n):
plt.figure(figsize=(6, 2))
plot_partial_transport(
plt.gca(),
x_a,
x_b,
indices_a[: k + 1],
indices_b[: k + 1],
marginal_costs[: k + 1],
)
plt.show()
37 changes: 37 additions & 0 deletions ot/partial/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
"""
Efficient 1D solver for the partial optimal transport problem.
"""

# Author: Romain Tavenard <[email protected]>
#
# License: MIT License

# import compiled emd
from .partial_solvers import (
partial_wasserstein_lagrange,
partial_wasserstein,
partial_wasserstein2,
entropic_partial_wasserstein,
gwgrad_partial,
gwloss_partial,
partial_gromov_wasserstein,
partial_gromov_wasserstein2,
entropic_partial_gromov_wasserstein,
entropic_partial_gromov_wasserstein2,
partial_wasserstein_1d,
)

__all__ = [
"partial_wasserstein_1d",
"partial_wasserstein_lagrange",
"partial_wasserstein",
"partial_wasserstein2",
"entropic_partial_wasserstein",
"gwgrad_partial",
"gwloss_partial",
"partial_gromov_wasserstein",
"partial_gromov_wasserstein2",
"entropic_partial_gromov_wasserstein",
"entropic_partial_gromov_wasserstein2",
]
Loading
Loading