Skip to content

Commit dc7abe5

Browse files
committed
Fix rich comparison
1 parent 708be1a commit dc7abe5

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ quad_equal(const Sleef_quad *a, const Sleef_quad *b)
386386
static inline npy_bool
387387
quad_notequal(const Sleef_quad *a, const Sleef_quad *b)
388388
{
389-
return Sleef_icmpneq1(*a, *b) | Sleef_iunordq1(*a, *b);
389+
return Sleef_icmpneq1(*a, *b) || Sleef_iunordq1(*a, *b);
390390
}
391391

392392
static inline npy_bool

quaddtype/numpy_quaddtype/src/scalar_ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ quad_richcompare(QuadPrecisionObject *self, PyObject *other, int cmp_op)
163163
cmp = Sleef_icmpeqq1(self->value.sleef_value, other_quad->value.sleef_value);
164164
break;
165165
case Py_NE:
166-
cmp = Sleef_icmpneq1(self->value.sleef_value, other_quad->value.sleef_value);
166+
cmp = Sleef_icmpneq1(self->value.sleef_value, other_quad->value.sleef_value) || Sleef_iunordq1(self->value.sleef_value, other_quad->value.sleef_value);
167167
break;
168168
case Py_GT:
169169
cmp = Sleef_icmpgtq1(self->value.sleef_value, other_quad->value.sleef_value);

quaddtype/tests/test_quaddtype.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ def test_comparisons(op, a, b):
4545
assert op_func(quad_a, quad_b) == op_func(float_a, float_b)
4646

4747

48+
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
49+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "inf", "-inf", "nan", "-nan"])
50+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "inf", "-inf", "nan", "-nan"])
51+
def test_array_comparisons(op, a, b):
52+
op_func = getattr(operator, op)
53+
quad_a = np.array(QuadPrecision(a))
54+
quad_b = np.array(QuadPrecision(b))
55+
float_a = np.array(float(a))
56+
float_b = np.array(float(b))
57+
58+
assert np.array_equal(op_func(quad_a, quad_b), op_func(float_a, float_b))
59+
60+
4861
@pytest.mark.parametrize("op, val, expected", [
4962
("neg", "3.0", "-3.0"),
5063
("neg", "-3.0", "3.0"),

0 commit comments

Comments
 (0)