diff --git a/src/special/double.jl b/src/special/double.jl
index 019e758c..c81baedc 100644
--- a/src/special/double.jl
+++ b/src/special/double.jl
@@ -315,9 +315,19 @@ end
 end
 
 # two-prod-fma
+#
+# These `::True` branches implement the TwoProduct error-free transformation
+# (e.g. `Double(z, fma(x, y, -z))`), which extracts the rounding error of
+# `x*y` as the low part of a double-double. The identity requires a
+# **single-rounded** FMA; using the may-fuse `vfmsub`/`vfnmadd` lets LLVM
+# constant-fold `(x*y) + (-(x*y))` to `0` even on hardware that has FMA,
+# which destroys every double-double error term and propagates as wildly
+# wrong results in downstream `asinh`/`sin`/`cos`/etc. (e.g. ~10^5 ULP on
+# Float32). Call `vfma` (which lowers to `llvm.fma`) directly with an
+# explicit negation so the FMA is preserved.
 @inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::True)
     z = (x * y)
-    Double(z, vfmsub(x, y, z))
+    Double(z, vfma(x, y, -z))
 end
 @inline function dmul(x::vIEEEFloat, y::vIEEEFloat, ::False)
     hx, lx = splitprec(x)
@@ -329,7 +339,7 @@ end
 end
 @inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::True)
     z = (x.hi * y)
-    Double(z, vfmsub(x.hi, y, z) + x.lo * y)
+    Double(z, vfma(x.hi, y, -z) + x.lo * y)
 end
 @inline function dmul(x::Double{<:vIEEEFloat}, y::vIEEEFloat, ::False)
     hx, lx = splitprec(x.hi)
@@ -341,7 +351,7 @@ end
 end
 @inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True)
     z = x.hi * y.hi
-    Double(z, vfmsub(x.hi, y.hi, z) + x.hi * y.lo + x.lo * y.hi)
+    Double(z, vfma(x.hi, y.hi, -z) + x.hi * y.lo + x.lo * y.hi)
 end
 @inline function dmul(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False)
     hx, lx = splitprec(x.hi)
@@ -361,7 +371,7 @@ end
 # x^2
 @inline function dsqu(x::T, ::True) where {T<:vIEEEFloat}
     z = x * x
-    Double(z, vfmsub(x, x, z))
+    Double(z, vfma(x, x, -z))
 end
 @inline function dsqu(x::T, ::False) where {T<:vIEEEFloat}
     hx, lx = splitprec(x)
@@ -372,7 +382,7 @@ end
 end
 @inline function dsqu(x::Double{T}, ::True) where {T<:vIEEEFloat}
     z = x.hi * x.hi
-    Double(z, vfmsub(x.hi, x.hi, z) + (x.hi * (x.lo + x.lo)))
+    Double(z, vfma(x.hi, x.hi, -z) + (x.hi * (x.lo + x.lo)))
 end
 @inline function dsqu(x::Double{T}, ::False) where {T<:vIEEEFloat}
     hx, lx = splitprec(x.hi)
@@ -386,7 +396,7 @@ end
 # sqrt(x)
 @inline function dsqrt(x::Double{T}, ::True) where {T<:vIEEEFloat}
     zhi = @fastmath sqrt(x.hi)
-    Double(zhi, (x.lo + vfnmadd(zhi, zhi, x.hi)) / (zhi + zhi))
+    Double(zhi, (x.lo + vfma(-zhi, zhi, x.hi)) / (zhi + zhi))
 end
 @inline function dsqrt(x::Double{T}, ::False) where {T<:vIEEEFloat}
     c = @fastmath sqrt(x.hi)
@@ -399,7 +409,7 @@ end
 @inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::True)
     invy = inv(y.hi)
     zhi = (x.hi * invy)
-    Double(zhi, ((vfnmadd(zhi, y.hi, x.hi) + vfnmadd(zhi, y.lo, x.lo)) * invy))
+    Double(zhi, ((vfma(-zhi, y.hi, x.hi) + vfma(-zhi, y.lo, x.lo)) * invy))
 end
 @inline function ddiv(x::Double{<:vIEEEFloat}, y::Double{<:vIEEEFloat}, ::False)
     @ieee begin
@@ -412,7 +422,7 @@ end
 @inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::True)
     ry = inv(y)
     r = (x * ry)
-    Double(r, (vfnmadd(r, y, x) * ry))
+    Double(r, (vfma(-r, y, x) * ry))
 end
 @inline function ddiv(x::vIEEEFloat, y::vIEEEFloat, ::False)
     @ieee begin
@@ -427,7 +437,7 @@ end
 # 1/x
 @inline function drec(x::vIEEEFloat, ::True)
     zhi = inv(x)
-    Double(zhi, (vfnmadd(zhi, x, one(eltype(x))) * zhi))
+    Double(zhi, (vfma(-zhi, x, one(eltype(x))) * zhi))
 end
 @inline function drec(x::vIEEEFloat, ::False)
     @ieee begin
@@ -439,7 +449,7 @@ end
 
 @inline function drec(x::Double{<:vIEEEFloat}, ::True)
     zhi = inv(x.hi)
-    Double(zhi, ((vfnmadd(zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi))
+    Double(zhi, ((vfma(-zhi, x.hi, one(eltype(x))) - (zhi * x.lo)) * zhi))
 end
 @inline function drec(x::Double{<:vIEEEFloat}, ::False)
     @ieee begin