Skip to content

Commit a295e20

Browse files
committed
Fix maximum and minimum for NaNs
1 parent dc7abe5 commit a295e20

File tree

2 files changed

+60
-8
lines changed

2 files changed

+60
-8
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,17 @@ quad_mod(Sleef_quad *a, Sleef_quad *b)
302302
static inline Sleef_quad
303303
quad_minimum(Sleef_quad *in1, Sleef_quad *in2)
304304
{
305-
return Sleef_icmpleq1(*in1, *in2) ? *in1 : *in2;
305+
return Sleef_iunordq1(*in1, *in2) ? (
306+
Sleef_iunordq1(*in1, *in1) ? *in1 : *in2
307+
) : Sleef_icmpleq1(*in1, *in2) ? *in1 : *in2;
306308
}
307309

308310
static inline Sleef_quad
309311
quad_maximum(Sleef_quad *in1, Sleef_quad *in2)
310312
{
311-
return Sleef_icmpgeq1(*in1, *in2) ? *in1 : *in2;
313+
return Sleef_iunordq1(*in1, *in2) ? (
314+
Sleef_iunordq1(*in1, *in1) ? *in1 : *in2
315+
) : Sleef_icmpgeq1(*in1, *in2) ? *in1 : *in2;
312316
}
313317

314318
static inline Sleef_quad

quaddtype/tests/test_quaddtype.py

Lines changed: 54 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,11 @@ def test_basic_equality():
1818

1919

2020
@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv", "pow"])
21-
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0"])
21+
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
2222
def test_binary_ops(op, other):
23+
if op == "truediv" and float(other) == 0:
24+
pytest.xfail("float division by zero")
25+
2326
op_func = getattr(operator, op)
2427
quad_a = QuadPrecision("12.5")
2528
quad_b = QuadPrecision(other)
@@ -29,12 +32,17 @@ def test_binary_ops(op, other):
2932
quad_result = op_func(quad_a, quad_b)
3033
float_result = op_func(float_a, float_b)
3134

32-
assert np.abs(np.float64(quad_result) - float_result) < 1e-10
35+
with np.errstate(invalid="ignore"):
36+
assert (
37+
(np.float64(quad_result) == float_result) or
38+
(np.abs(np.float64(quad_result) - float_result) < 1e-10) or
39+
((float_result != float_result) and (quad_result != quad_result))
40+
)
3341

3442

3543
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
36-
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "inf", "-inf", "nan", "-nan"])
37-
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "inf", "-inf", "nan", "-nan"])
44+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
45+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
3846
def test_comparisons(op, a, b):
3947
op_func = getattr(operator, op)
4048
quad_a = QuadPrecision(a)
@@ -46,8 +54,8 @@ def test_comparisons(op, a, b):
4654

4755

4856
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
49-
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "inf", "-inf", "nan", "-nan"])
50-
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "inf", "-inf", "nan", "-nan"])
57+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
58+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
5159
def test_array_comparisons(op, a, b):
5260
op_func = getattr(operator, op)
5361
quad_a = np.array(QuadPrecision(a))
@@ -58,6 +66,46 @@ def test_array_comparisons(op, a, b):
5866
assert np.array_equal(op_func(quad_a, quad_b), op_func(float_a, float_b))
5967

6068

69+
@pytest.mark.parametrize("op", ["minimum", "maximum", "fmin", "fmax"])
70+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
71+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
72+
def test_array_minmax(op, a, b):
73+
if op in ["fmin", "fmax"]:
74+
pytest.skip("fmin and fmax ufuncs are not yet supported")
75+
76+
op_func = getattr(np, op)
77+
quad_a = np.array([QuadPrecision(a)])
78+
quad_b = np.array([QuadPrecision(b)])
79+
float_a = np.array([float(a)])
80+
float_b = np.array([float(b)])
81+
82+
quad_res = op_func(quad_a, quad_b)
83+
float_res = op_func(float_a, float_b)
84+
85+
# FIXME: @juntyr: replace with array_equal once isnan is supported
86+
with np.errstate(invalid="ignore"):
87+
assert np.all((quad_res == float_res) | ((quad_res != quad_res) & (float_res != float_res)))
88+
89+
90+
@pytest.mark.parametrize("op", ["amin", "amax", "nanmin", "nanmax"])
91+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
92+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
93+
def test_array_aminmax(op, a, b):
94+
if op in ["nanmin", "nanmax"]:
95+
pytest.skip("fmin and fmax ufuncs are not yet supported")
96+
97+
op_func = getattr(np, op)
98+
quad_ab = np.array([QuadPrecision(a), QuadPrecision(b)])
99+
float_ab = np.array([float(a), float(b)])
100+
101+
quad_res = op_func(quad_ab)
102+
float_res = op_func(float_ab)
103+
104+
# FIXME: @juntyr: replace with array_equal once isnan is supported
105+
with np.errstate(invalid="ignore"):
106+
assert np.all((quad_res == float_res) | ((quad_res != quad_res) & (float_res != float_res)))
107+
108+
61109
@pytest.mark.parametrize("op, val, expected", [
62110
("neg", "3.0", "-3.0"),
63111
("neg", "-3.0", "3.0"),

0 commit comments

Comments
 (0)