From fb5bb0c569f036c271a6c77ec3642c558b8a2b33 Mon Sep 17 00:00:00 2001 From: bluppes Date: Fri, 17 May 2024 11:31:52 +0200 Subject: [PATCH 1/6] add tests --- test/gromov/test_gw.py | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index 0008cebce..a05943d7d 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -832,3 +832,33 @@ def test_fgw_barycenter(nx): # test correspondance with utils function recovered_C = ot.gromov.update_kl_loss(p, lambdas, log['T'], [C1, C2]) np.testing.assert_allclose(C, recovered_C) + + +# Related to issue 469 +def test_gromov2_nan_in_source_cost(): + # GIVEN a source cost matrix with a NaN value + source_cost = np.zeros((2, 2)) + target_cost = np.ones((2, 2)) + source_distribution = np.array([0.5, 0.5]) + target_distribution = np.array([0.5, 0.5]) + + source_cost[0, 0] = np.nan + + # WHEN we call gromov_wasserstein2 - THEN we expect a ValueError + with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'): + ot.gromov_wasserstein2(source_cost, target_cost, source_distribution, target_distribution) + + +# Related to issue 469 +def test_gromov2_nan_in_target_cost(): + # GIVEN - a target cost matrix with a NaN value + source_cost = np.zeros((2, 2)) + target_cost = np.ones((2, 2)) + source_distribution = np.array([0.5, 0.5]) + target_distribution = np.array([0.5, 0.5]) + + target_cost[0, 0] = np.nan + + # WHEN - we call + with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'): + ot.gromov_wasserstein2(source_cost, target_cost, source_distribution, target_distribution) From fc53a26d5dd5a0dbd6106d2bfc06cbaaf1578dfc Mon Sep 17 00:00:00 2001 From: bluppes Date: Fri, 17 May 2024 11:47:46 +0200 Subject: [PATCH 2/6] test emd directly --- test/test_ot.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_ot.py b/test/test_ot.py index a90321d5f..6096061cb 100644 --- a/test/test_ot.py +++ b/test/test_ot.py @@ -245,6 +245,11 @@ def test_emd_empty(): np.testing.assert_allclose(w, 0) +def test_emd_nan_in_loss_matrix(): + with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'): + ot.emd([], [], [np.nan]) + + def test_emd2_multi(): n = 500 # nb bins From 9942d1e2ded85c2caf30a78c934609ad9036c9c1 Mon Sep 17 00:00:00 2001 From: bluppes Date: Fri, 17 May 2024 11:33:35 +0200 Subject: [PATCH 3/6] perform check in emd entrypoint --- ot/lp/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 752c5d2d7..8c4f4852d 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -237,6 +237,8 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c .. note:: An error will be raised if the vectors :math:`\mathbf{a}` and :math:`\mathbf{b}` do not sum to the same value. + .. note:: An error will be raided if the loss matrix :math:`\mathbf{M}` contains NaNs. + Uses the algorithm proposed in :ref:`[1] `. Parameters @@ -302,6 +304,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c ot.optim.cg : General regularized OT """ + if np.isnan(M).any(): + raise ValueError('The loss matrix should not contain NaN values.') + a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) From be8a5ea37c3b2e764d0168adf45de47f407a3892 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Jun 2024 10:41:52 +0200 Subject: [PATCH 4/6] test if nans --- ot/lp/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py index 8c4f4852d..cec98de27 100644 --- a/ot/lp/__init__.py +++ b/ot/lp/__init__.py @@ -304,9 +304,6 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c ot.optim.cg : General regularized OT """ - if np.isnan(M).any(): - raise ValueError('The loss matrix should not contain NaN values.') - a, b, M = list_to_array(a, b, M) nx = get_backend(M, a, b) @@ -329,6 +326,9 @@ def emd(a, b, M, numItermax=100000, log=False, center_dual=True, numThreads=1, c # convert to numpy M, a, b = nx.to_numpy(M, a, b) + if np.isnan(M).any(): + raise ValueError('The loss matrix should not contain NaN values.') + # ensure float64 a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) @@ -507,6 +507,9 @@ def emd2(a, b, M, processes=1, # convert to numpy M, a, b = nx.to_numpy(M, a, b) + if np.isnan(M).any(): + raise ValueError('The loss matrix should not contain NaN values.') + a = np.asarray(a, dtype=np.float64) b = np.asarray(b, dtype=np.float64) M = np.asarray(M, dtype=np.float64, order='C') From 79d00b9a8f821eaacaeb78f3aaeba3f143db16be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Jun 2024 10:46:01 +0200 Subject: [PATCH 5/6] Update test_gw.py --- test/gromov/test_gw.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index 8a247d9a1..5fa0acd9e 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -906,22 +906,6 @@ def test_fgw_barycenter(nx): np.testing.assert_allclose(C, Cb, atol=1e-06) np.testing.assert_allclose(X, Xb, atol=1e-06) - -# Related to issue 469 -def test_gromov2_nan_in_source_cost(): - # GIVEN a source cost matrix with a NaN value - source_cost = np.zeros((2, 2)) - target_cost = np.ones((2, 2)) - source_distribution = np.array([0.5, 0.5]) - target_distribution = np.array([0.5, 0.5]) - - source_cost[0, 0] = np.nan - - # WHEN we call gromov_wasserstein2 - THEN we expect a ValueError - with pytest.raises(ValueError, match='The loss matrix should not contain NaN values.'): - ot.gromov_wasserstein2(source_cost, target_cost, source_distribution, target_distribution) - - # Related to issue 469 def test_gromov2_nan_in_target_cost(): # GIVEN - a target cost matrix with a NaN value From b75d07cbe87a27ae40b83657e3647c6a940e07e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Flamary?= Date: Tue, 25 Jun 2024 10:48:03 +0200 Subject: [PATCH 6/6] pep8 --- test/gromov/test_gw.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/gromov/test_gw.py b/test/gromov/test_gw.py index 5fa0acd9e..e61d0eeac 100644 --- a/test/gromov/test_gw.py +++ b/test/gromov/test_gw.py @@ -906,6 +906,7 @@ def test_fgw_barycenter(nx): np.testing.assert_allclose(C, Cb, atol=1e-06) np.testing.assert_allclose(X, Xb, atol=1e-06) + # Related to issue 469 def test_gromov2_nan_in_target_cost(): # GIVEN - a target cost matrix with a NaN value