Skip to content

fix: check for NaNs in emd loss matrix #623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
8 changes: 8 additions & 0 deletions ot/lp/__init__.py
Original file line number Diff line number Diff line change
@@ -237,6 +237,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c

.. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value.

.. note:: An error will be raised if the loss matrix :math:`\mathbf{M}` contains NaNs.

Uses the algorithm proposed in :ref:`[1] <references-emd>`.

Parameters
@@ -324,6 +326,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c
# convert to numpy
M, a, b = nx.to_numpy(M, a, b)

if np.isnan(M).any():
raise ValueError('The loss matrix should not contain NaN values.')

# ensure float64
a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
@@ -502,6 +507,9 @@ def emd2(a, b, M, processes=1,
# convert to numpy
M, a, b = nx.to_numpy(M, a, b)

if np.isnan(M).any():
raise ValueError('The loss matrix should not contain NaN values.')

a = np.asarray(a, dtype=np.float64)
b = np.asarray(b, dtype=np.float64)
M = np.asarray(M, dtype=np.float64, order='C')
15 changes: 15 additions & 0 deletions test/gromov/test_gw.py
Original file line number Diff line number Diff line change
@@ -910,3 +910,18 @@ def test_fgw_barycenter(nx):

np.testing.assert_allclose(C, Cb, atol=1e-06)
np.testing.assert_allclose(X, Xb, atol=1e-06)


# Related to issue 469
def test_gromov2_nan_in_target_cost():
    """Check that ``ot.gromov_wasserstein2`` rejects a NaN in the target cost.

    The call is expected to fail with a ``ValueError`` carrying the
    "loss matrix should not contain NaN values" message.
    """
    # GIVEN - a target cost matrix with a NaN value
    source_cost = np.zeros((2, 2))
    target_cost = np.ones((2, 2))
    source_distribution = np.array([0.5, 0.5])
    target_distribution = np.array([0.5, 0.5])

    target_cost[0, 0] = np.nan

    # WHEN - we call gromov_wasserstein2 with that target cost
    # THEN - a ValueError with the NaN loss-matrix message is raised.
    # ``match`` is interpreted as a regular expression, so the trailing
    # period is escaped to be matched literally instead of as a wildcard.
    with pytest.raises(
        ValueError,
        match=r'The loss matrix should not contain NaN values\.',
    ):
        ot.gromov_wasserstein2(
            source_cost,
            target_cost,
            source_distribution,
            target_distribution,
        )
5 changes: 5 additions & 0 deletions test/test_ot.py
Original file line number Diff line number Diff line change
@@ -245,6 +245,11 @@ def test_emd_empty():
np.testing.assert_allclose(w, 0)


def test_emd_nan_in_loss_matrix():
    """Check that ``ot.emd`` and ``ot.emd2`` raise on a NaN loss matrix.

    A well-formed problem (two-bin marginals, 2x2 loss matrix) is used so the
    only invalid aspect of the input is the NaN entry; the test therefore does
    not depend on the NaN check running before any shape handling.
    """
    a = np.array([0.5, 0.5])
    b = np.array([0.5, 0.5])
    M = np.array([[0.0, np.nan], [1.0, 0.0]])

    # ``match`` is a regular expression: escape the final period so the
    # message is matched literally.
    expected_message = r'The loss matrix should not contain NaN values\.'

    with pytest.raises(ValueError, match=expected_message):
        ot.emd(a, b, M)

    # emd2 carries the same guard and must fail the same way.
    with pytest.raises(ValueError, match=expected_message):
        ot.emd2(a, b, M)


def test_emd2_multi():
n = 500 # nb bins