Skip to content

Commit 5914811

Browse files
authored
Merge pull request #115 from juntyr/nan-comparisons
Test and fix comparisons with NaNs
2 parents 53ccf00 + 2db17f4 commit 5914811

File tree

3 files changed

+78
-12
lines changed

3 files changed

+78
-12
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,17 @@ quad_mod(Sleef_quad *a, Sleef_quad *b)
302302
static inline Sleef_quad
303303
quad_minimum(Sleef_quad *in1, Sleef_quad *in2)
304304
{
305-
return Sleef_icmpleq1(*in1, *in2) ? *in1 : *in2;
305+
return Sleef_iunordq1(*in1, *in2) ? (
306+
Sleef_iunordq1(*in1, *in1) ? *in1 : *in2
307+
) : Sleef_icmpleq1(*in1, *in2) ? *in1 : *in2;
306308
}
307309

308310
static inline Sleef_quad
309311
quad_maximum(Sleef_quad *in1, Sleef_quad *in2)
310312
{
311-
return Sleef_icmpgeq1(*in1, *in2) ? *in1 : *in2;
313+
return Sleef_iunordq1(*in1, *in2) ? (
314+
Sleef_iunordq1(*in1, *in1) ? *in1 : *in2
315+
) : Sleef_icmpgeq1(*in1, *in2) ? *in1 : *in2;
312316
}
313317

314318
static inline Sleef_quad
@@ -386,7 +390,7 @@ quad_equal(const Sleef_quad *a, const Sleef_quad *b)
386390
static inline npy_bool
387391
quad_notequal(const Sleef_quad *a, const Sleef_quad *b)
388392
{
389-
return Sleef_icmpneq1(*a, *b);
393+
return Sleef_icmpneq1(*a, *b) || Sleef_iunordq1(*a, *b);
390394
}
391395

392396
static inline npy_bool

quaddtype/numpy_quaddtype/src/scalar_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ quad_richcompare(QuadPrecisionObject *self, PyObject *other, int cmp_op)
163163
cmp = Sleef_icmpeqq1(self->value.sleef_value, other_quad->value.sleef_value);
164164
break;
165165
case Py_NE:
166-
cmp = Sleef_icmpneq1(self->value.sleef_value, other_quad->value.sleef_value);
166+
cmp = Sleef_icmpneq1(self->value.sleef_value, other_quad->value.sleef_value) || Sleef_iunordq1(self->value.sleef_value, other_quad->value.sleef_value);
167167
break;
168168
case Py_GT:
169169
cmp = Sleef_icmpgtq1(self->value.sleef_value, other_quad->value.sleef_value);

quaddtype/tests/test_quaddtype.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ def test_basic_equality():
1818

1919

2020
@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv", "pow"])
21-
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0"])
21+
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
2222
def test_binary_ops(op, other):
23+
if op == "truediv" and float(other) == 0:
24+
pytest.xfail("float division by zero")
25+
2326
op_func = getattr(operator, op)
2427
quad_a = QuadPrecision("12.5")
2528
quad_b = QuadPrecision(other)
@@ -29,21 +32,80 @@ def test_binary_ops(op, other):
2932
quad_result = op_func(quad_a, quad_b)
3033
float_result = op_func(float_a, float_b)
3134

32-
assert np.abs(np.float64(quad_result) - float_result) < 1e-10
35+
with np.errstate(invalid="ignore"):
36+
assert (
37+
(np.float64(quad_result) == float_result) or
38+
(np.abs(np.float64(quad_result) - float_result) < 1e-10) or
39+
((float_result != float_result) and (quad_result != quad_result))
40+
)
3341

3442

3543
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
36-
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0"])
37-
def test_comparisons(op, other):
44+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
45+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
46+
def test_comparisons(op, a, b):
3847
op_func = getattr(operator, op)
39-
quad_a = QuadPrecision("12.5")
40-
quad_b = QuadPrecision(other)
41-
float_a = 12.5
42-
float_b = float(other)
48+
quad_a = QuadPrecision(a)
49+
quad_b = QuadPrecision(b)
50+
float_a = float(a)
51+
float_b = float(b)
4352

4453
assert op_func(quad_a, quad_b) == op_func(float_a, float_b)
4554

4655

56+
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
57+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
58+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
59+
def test_array_comparisons(op, a, b):
60+
op_func = getattr(operator, op)
61+
quad_a = np.array(QuadPrecision(a))
62+
quad_b = np.array(QuadPrecision(b))
63+
float_a = np.array(float(a))
64+
float_b = np.array(float(b))
65+
66+
assert np.array_equal(op_func(quad_a, quad_b), op_func(float_a, float_b))
67+
68+
69+
@pytest.mark.parametrize("op", ["minimum", "maximum", "fmin", "fmax"])
70+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
71+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
72+
def test_array_minmax(op, a, b):
73+
if op in ["fmin", "fmax"]:
74+
pytest.skip("fmin and fmax ufuncs are not yet supported")
75+
76+
op_func = getattr(np, op)
77+
quad_a = np.array([QuadPrecision(a)])
78+
quad_b = np.array([QuadPrecision(b)])
79+
float_a = np.array([float(a)])
80+
float_b = np.array([float(b)])
81+
82+
quad_res = op_func(quad_a, quad_b)
83+
float_res = op_func(float_a, float_b)
84+
85+
# FIXME: @juntyr: replace with array_equal once isnan is supported
86+
with np.errstate(invalid="ignore"):
87+
assert np.all((quad_res == float_res) | ((quad_res != quad_res) & (float_res != float_res)))
88+
89+
90+
@pytest.mark.parametrize("op", ["amin", "amax", "nanmin", "nanmax"])
91+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
92+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
93+
def test_array_aminmax(op, a, b):
94+
if op in ["nanmin", "nanmax"]:
95+
pytest.skip("fmin and fmax ufuncs are not yet supported")
96+
97+
op_func = getattr(np, op)
98+
quad_ab = np.array([QuadPrecision(a), QuadPrecision(b)])
99+
float_ab = np.array([float(a), float(b)])
100+
101+
quad_res = op_func(quad_ab)
102+
float_res = op_func(float_ab)
103+
104+
# FIXME: @juntyr: replace with array_equal once isnan is supported
105+
with np.errstate(invalid="ignore"):
106+
assert np.all((quad_res == float_res) | ((quad_res != quad_res) & (float_res != float_res)))
107+
108+
47109
@pytest.mark.parametrize("op, val, expected", [
48110
("neg", "3.0", "-3.0"),
49111
("neg", "-3.0", "3.0"),

0 commit comments

Comments
 (0)