Skip to content

Commit e0f48f4

Browse files
committed
Implement hyperbolic ufuncs
Add hyperbolic ufunc tests Add more tests for -0.0 Guarantee that min/max are zero-sign-sensitive
1 parent 90f6197 commit e0f48f4

File tree

4 files changed

+235
-26
lines changed

4 files changed

+235
-26
lines changed

quaddtype/numpy_quaddtype/src/ops.hpp

Lines changed: 104 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,42 @@ quad_atan(const Sleef_quad *op)
151151
return Sleef_atanq1_u10(*op);
152152
}
153153

154+
static inline Sleef_quad
155+
quad_sinh(const Sleef_quad *op)
156+
{
157+
return Sleef_sinhq1_u10(*op);
158+
}
159+
160+
static inline Sleef_quad
161+
quad_cosh(const Sleef_quad *op)
162+
{
163+
return Sleef_coshq1_u10(*op);
164+
}
165+
166+
static inline Sleef_quad
167+
quad_tanh(const Sleef_quad *op)
168+
{
169+
return Sleef_tanhq1_u10(*op);
170+
}
171+
172+
static inline Sleef_quad
173+
quad_asinh(const Sleef_quad *op)
174+
{
175+
return Sleef_asinhq1_u10(*op);
176+
}
177+
178+
static inline Sleef_quad
179+
quad_acosh(const Sleef_quad *op)
180+
{
181+
return Sleef_acoshq1_u10(*op);
182+
}
183+
184+
static inline Sleef_quad
185+
quad_atanh(const Sleef_quad *op)
186+
{
187+
return Sleef_atanhq1_u10(*op);
188+
}
189+
154190
// Unary long double operations
155191
typedef long double (*unary_op_longdouble_def)(const long double *);
156192

@@ -299,6 +335,42 @@ ld_atan(const long double *op)
299335
return atanl(*op);
300336
}
301337

338+
static inline long double
339+
ld_sinh(const long double *op)
340+
{
341+
return sinhl(*op);
342+
}
343+
344+
static inline long double
345+
ld_cosh(const long double *op)
346+
{
347+
return coshl(*op);
348+
}
349+
350+
static inline long double
351+
ld_tanh(const long double *op)
352+
{
353+
return tanhl(*op);
354+
}
355+
356+
static inline long double
357+
ld_asinh(const long double *op)
358+
{
359+
return asinhl(*op);
360+
}
361+
362+
static inline long double
363+
ld_acosh(const long double *op)
364+
{
365+
return acoshl(*op);
366+
}
367+
368+
static inline long double
369+
ld_atanh(const long double *op)
370+
{
371+
return atanhl(*op);
372+
}
373+
302374
// Unary Quad properties
303375
typedef npy_bool (*unary_prop_quad_def)(const Sleef_quad *);
304376

@@ -442,33 +514,53 @@ quad_mod(const Sleef_quad *a, const Sleef_quad *b)
442514
static inline Sleef_quad
443515
quad_minimum(const Sleef_quad *in1, const Sleef_quad *in2)
444516
{
445-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in1, *in1) ? *in1 : *in2)
446-
: Sleef_icmpleq1(*in1, *in2) ? *in1
447-
: *in2;
517+
if (Sleef_iunordq1(*in1, *in2)) {
518+
return Sleef_iunordq1(*in1, *in1) ? *in1 : *in2;
519+
}
520+
// minimum(-0.0, +0.0) = -0.0
521+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
522+
return Sleef_icmpleq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
523+
}
524+
return Sleef_fminq1(*in1, *in2);
448525
}
449526

450527
static inline Sleef_quad
451528
quad_maximum(const Sleef_quad *in1, const Sleef_quad *in2)
452529
{
453-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in1, *in1) ? *in1 : *in2)
454-
: Sleef_icmpgeq1(*in1, *in2) ? *in1
455-
: *in2;
530+
if (Sleef_iunordq1(*in1, *in2)) {
531+
return Sleef_iunordq1(*in1, *in1) ? *in1 : *in2;
532+
}
533+
// maximum(-0.0, +0.0) = +0.0
534+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
535+
return Sleef_icmpgeq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
536+
}
537+
return Sleef_fmaxq1(*in1, *in2);
456538
}
457539

458540
static inline Sleef_quad
459541
quad_fmin(const Sleef_quad *in1, const Sleef_quad *in2)
460542
{
461-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in2, *in2) ? *in1 : *in2)
462-
: Sleef_icmpleq1(*in1, *in2) ? *in1
463-
: *in2;
543+
if (Sleef_iunordq1(*in1, *in2)) {
544+
return Sleef_iunordq1(*in2, *in2) ? *in1 : *in2;
545+
}
546+
// fmin(-0.0, +0.0) = -0.0
547+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
548+
return Sleef_icmpleq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
549+
}
550+
return Sleef_fminq1(*in1, *in2);
464551
}
465552

466553
static inline Sleef_quad
467554
quad_fmax(const Sleef_quad *in1, const Sleef_quad *in2)
468555
{
469-
return Sleef_iunordq1(*in1, *in2) ? (Sleef_iunordq1(*in2, *in2) ? *in1 : *in2)
470-
: Sleef_icmpgeq1(*in1, *in2) ? *in1
471-
: *in2;
556+
if (Sleef_iunordq1(*in1, *in2)) {
557+
return Sleef_iunordq1(*in2, *in2) ? *in1 : *in2;
558+
}
559+
// maximum(-0.0, +0.0) = +0.0
560+
if (Sleef_icmpeqq1(*in1, QUAD_ZERO) && Sleef_icmpeqq1(*in2, QUAD_ZERO)) {
561+
return Sleef_icmpgeq1(Sleef_copysignq1(QUAD_ONE, *in1), Sleef_copysignq1(QUAD_ONE, *in2)) ? *in1 : *in2;
562+
}
563+
return Sleef_fmaxq1(*in1, *in2);
472564
}
473565

474566
static inline Sleef_quad

quaddtype/numpy_quaddtype/src/umath/unary_ops.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,5 +216,23 @@ init_quad_unary_ops(PyObject *numpy)
216216
if (create_quad_unary_ufunc<quad_atan, ld_atan>(numpy, "arctan") < 0) {
217217
return -1;
218218
}
219+
if (create_quad_unary_ufunc<quad_sinh, ld_sinh>(numpy, "sinh") < 0) {
220+
return -1;
221+
}
222+
if (create_quad_unary_ufunc<quad_cosh, ld_cosh>(numpy, "cosh") < 0) {
223+
return -1;
224+
}
225+
if (create_quad_unary_ufunc<quad_tanh, ld_tanh>(numpy, "tanh") < 0) {
226+
return -1;
227+
}
228+
if (create_quad_unary_ufunc<quad_asinh, ld_asinh>(numpy, "arcsinh") < 0) {
229+
return -1;
230+
}
231+
if (create_quad_unary_ufunc<quad_acosh, ld_acosh>(numpy, "arccosh") < 0) {
232+
return -1;
233+
}
234+
if (create_quad_unary_ufunc<quad_atanh, ld_atanh>(numpy, "arctanh") < 0) {
235+
return -1;
236+
}
219237
return 0;
220238
}

quaddtype/release_tracker.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,12 +50,12 @@
5050
| arctan ||_Need: basic tests + edge cases (NaN/inf/0/asymptotes)_ |
5151
| arctan2 ||_Need: basic tests + edge cases (NaN/inf/0/quadrant coverage)_ |
5252
| hypot | | |
53-
| sinh | | |
54-
| cosh | | |
55-
| tanh | | |
56-
| arcsinh | | |
57-
| arccosh | | |
58-
| arctanh | | |
53+
| sinh | | |
54+
| cosh | | |
55+
| tanh | | |
56+
| arcsinh | | |
57+
| arccosh | | |
58+
| arctanh | | |
5959
| degrees | | |
6060
| radians | | |
6161
| deg2rad | | |

quaddtype/tests/test_quaddtype.py

Lines changed: 107 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,28 @@ def test_basic_equality():
3535

3636

3737
@pytest.mark.parametrize("op", ["add", "sub", "mul", "truediv", "pow", "copysign"])
38-
@pytest.mark.parametrize("other", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
39-
def test_binary_ops(op, other):
40-
if op == "truediv" and float(other) == 0:
38+
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
39+
@pytest.mark.parametrize("b", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
40+
def test_binary_ops(op, a, b):
41+
if op == "truediv" and float(b) == 0:
4142
pytest.xfail("float division by zero")
4243

4344
op_func = getattr(operator, op, None) or getattr(np, op)
44-
quad_a = QuadPrecision("12.5")
45-
quad_b = QuadPrecision(other)
46-
float_a = 12.5
47-
float_b = float(other)
45+
quad_a = QuadPrecision(a)
46+
quad_b = QuadPrecision(b)
47+
float_a = float(a)
48+
float_b = float(b)
4849

4950
quad_result = op_func(quad_a, quad_b)
5051
float_result = op_func(float_a, float_b)
5152

5253
np.testing.assert_allclose(np.float64(quad_result), float_result, atol=1e-10, rtol=0, equal_nan=True)
5354

55+
# Check sign for zero results
56+
if float_result == 0.0:
57+
assert np.signbit(float_result) == np.signbit(
58+
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
59+
5460

5561
@pytest.mark.parametrize("op", ["eq", "ne", "le", "lt", "ge", "gt"])
5662
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
@@ -91,8 +97,20 @@ def test_array_minmax(op, a, b):
9197
quad_res = op_func(quad_a, quad_b)
9298
float_res = op_func(float_a, float_b)
9399

100+
# native implementation may not be sensitive to zero signs
101+
# but we want to enforce it for the quad dtype
102+
# e.g. min(+0.0, -0.0) = -0.0
103+
if float_a == 0.0 and float_b == 0.0:
104+
assert float_res == 0.0
105+
float_res = np.copysign(0.0, op_func(np.copysign(1.0, float_a), np.copysign(1.0, float_b)))
106+
94107
np.testing.assert_array_equal(quad_res.astype(float), float_res)
95108

109+
# Check sign for zero results
110+
if float_res == 0.0:
111+
assert np.signbit(float_res) == np.signbit(
112+
quad_res), f"Zero sign mismatch for {op}({a}, {b})"
113+
96114

97115
@pytest.mark.parametrize("op", ["amin", "amax", "nanmin", "nanmax"])
98116
@pytest.mark.parametrize("a", ["3.0", "12.5", "100.0", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
@@ -105,8 +123,20 @@ def test_array_aminmax(op, a, b):
105123
quad_res = op_func(quad_ab)
106124
float_res = op_func(float_ab)
107125

126+
# native implementation may not be sensitive to zero signs
127+
# but we want to enforce it for the quad dtype
128+
# e.g. min(+0.0, -0.0) = -0.0
129+
if float(a) == 0.0 and float(b) == 0.0:
130+
assert float_res == 0.0
131+
float_res = np.copysign(0.0, op_func(np.array([np.copysign(1.0, float(a)), np.copysign(1.0, float(b))])))
132+
108133
np.testing.assert_array_equal(np.array(quad_res).astype(float), float_res)
109134

135+
# Check sign for zero results
136+
if float_res == 0.0:
137+
assert np.signbit(float_res) == np.signbit(
138+
quad_res), f"Zero sign mismatch for {op}({a}, {b})"
139+
110140

111141
@pytest.mark.parametrize("op", ["negative", "positive", "absolute", "sign", "signbit", "isfinite", "isinf", "isnan", "sqrt", "square", "reciprocal"])
112142
@pytest.mark.parametrize("val", ["3.0", "-3.0", "12.5", "100.0", "1e100", "0.0", "-0.0", "inf", "-inf", "nan", "-nan"])
@@ -126,7 +156,7 @@ def test_unary_ops(op, val):
126156

127157
np.testing.assert_array_equal(np.array(quad_result).astype(float), float_result)
128158

129-
if op in ["negative", "positive", "absolute", "sign"]:
159+
if (float_result == 0.0) and (op not in ["signbit", "isfinite", "isinf", "isnan"]):
130160
assert np.signbit(float_result) == np.signbit(quad_result)
131161

132162

@@ -290,6 +320,11 @@ def test_logarithmic_functions(op, val):
290320
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=atol,
291321
err_msg=f"Value mismatch for {op}({val})")
292322

323+
# Check sign for zero results
324+
if float_result == 0.0:
325+
assert np.signbit(float_result) == np.signbit(
326+
quad_result), f"Zero sign mismatch for {op}({a}, {b})"
327+
293328

294329
@pytest.mark.parametrize("val", [
295330
# Basic cases around -1 (critical point for log1p)
@@ -304,6 +339,8 @@ def test_logarithmic_functions(op, val):
304339
"-1.1", "-2.0", "-10.0",
305340
# Large positive values
306341
"1e10", "1e15", "1e100",
342+
# Edge cases
343+
"0.0", "-0.0",
307344
# Special values
308345
"inf", "-inf", "nan", "-nan"
309346
])
@@ -341,9 +378,16 @@ def test_log1p(val):
341378
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=atol,
342379
err_msg=f"Value mismatch for log1p({val})")
343380

381+
# Check sign for zero results
382+
if float_result == 0.0:
383+
assert np.signbit(float_result) == np.signbit(
384+
quad_result), f"Zero sign mismatch for {op}({val})"
385+
344386
def test_inf():
345387
assert QuadPrecision("inf") > QuadPrecision("1e1000")
388+
assert np.signbit(QuadPrecision("inf")) == 0
346389
assert QuadPrecision("-inf") < QuadPrecision("-1e1000")
390+
assert np.signbit(QuadPrecision("-inf")) == 1
347391

348392

349393
def test_dtype_creation():
@@ -448,3 +492,58 @@ def test_mod(a, b, backend, op):
448492
numpy_negative = numpy_result < 0
449493

450494
assert result_negative == numpy_negative, f"Sign mismatch for {a} % {b}: quad={result_negative}, numpy={numpy_negative}"
495+
496+
497+
@pytest.mark.parametrize("op", ["sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh"])
498+
@pytest.mark.parametrize("val", [
499+
# Basic cases
500+
"0.0", "-0.0", "1.0", "-1.0", "2.0", "-2.0",
501+
# Small values
502+
"1e-10", "-1e-10", "1e-15", "-1e-15",
503+
# Values near one
504+
"0.9", "-0.9", "0.9999", "-0.9999",
505+
"1.1", "-1.1", "1.0001", "-1.0001",
506+
# Medium values
507+
"10.0", "-10.0", "20.0", "-20.0",
508+
# Large values
509+
"100.0", "200.0", "700.0", "1000.0", "1e100", "1e308",
510+
"-100.0", "-200.0", "-700.0", "-1000.0", "-1e100", "-1e308",
511+
# Fractional values
512+
"0.5", "-0.5", "1.5", "-1.5", "2.5", "-2.5",
513+
# Special values
514+
"inf", "-inf", "nan", "-nan"
515+
])
516+
def test_hyperbolic_functions(op, val):
517+
"""Comprehensive test for hyperbolic functions: sinh, cosh, tanh, arcsinh, arccosh, arctanh"""
518+
op_func = getattr(np, op)
519+
520+
quad_val = QuadPrecision(val)
521+
float_val = float(val)
522+
523+
quad_result = op_func(quad_val)
524+
float_result = op_func(float_val)
525+
526+
# Handle NaN cases
527+
if np.isnan(float_result):
528+
assert np.isnan(
529+
float(quad_result)), f"Expected NaN for {op}({val}), got {float(quad_result)}"
530+
return
531+
532+
# Handle infinity cases
533+
if np.isinf(float_result):
534+
assert np.isinf(
535+
float(quad_result)), f"Expected inf for {op}({val}), got {float(quad_result)}"
536+
assert np.sign(float_result) == np.sign(
537+
float(quad_result)), f"Infinity sign mismatch for {op}({val})"
538+
return
539+
540+
# For finite non-zero results
541+
# Use relative tolerance for exponential functions due to their rapid growth
542+
rtol = 1e-13 if abs(float_result) < 1e100 else 1e-10
543+
np.testing.assert_allclose(float(quad_result), float_result, rtol=rtol, atol=1e-15,
544+
err_msg=f"Value mismatch for {op}({val})")
545+
546+
# Check sign for zero results
547+
if float_result == 0.0:
548+
assert np.signbit(float_result) == np.signbit(
549+
quad_result), f"Zero sign mismatch for {op}({val})"

0 commit comments

Comments
 (0)