diff --git a/src/higher_fwd_rules.jl b/src/higher_fwd_rules.jl index 8486b8bd..e8a7e32e 100644 --- a/src/higher_fwd_rules.jl +++ b/src/higher_fwd_rules.jl @@ -30,16 +30,16 @@ end # TODO: It's a bit embarassing that we need to write these out, but currently the # compiler is not strong enough to automatically lift the frule. Let's hope we # can delete these in the near future. -function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N, T}, b::TaylorBundle{N, T}) where {N, T} TaylorBundle{N}(primal(a) + primal(b), map(+, a.tangent.coeffs, b.tangent.coeffs)) end -function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N}, b::AbstractZeroBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(+)}, a::TaylorBundle{N, T}, b::AbstractZeroBundle{N, T}) where {N, T} TaylorBundle{N}(primal(a) + primal(b), a.tangent.coeffs) end -function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(-)}, a::TaylorBundle{N}, b::TaylorBundle{N}) where {N} +function (∂☆ₙ::∂☆{N})(fb::AbstractZeroBundle{N, typeof(-)}, a::TaylorBundle{N, T}, b::TaylorBundle{N, T}) where {N, T} TaylorBundle{N}(primal(a) - primal(b), map(-, a.tangent.coeffs, b.tangent.coeffs)) end diff --git a/test/forward.jl b/test/forward.jl index f8040639..0d9b6dc3 100644 --- a/test/forward.jl +++ b/test/forward.jl @@ -172,6 +172,27 @@ end end end +@testset "binops of mixed number types" begin + # We have had issues with mixed number types before + struct StoreHalfed <: Number + val::Float64 + StoreHalfed(x) = new(x/2) + end + Base.:-(x::StoreHalfed, y::Number) = 2*x.val - y + Base.:+(x::StoreHalfed, y::Number) = 2*x.val + y + + sub_sh(a) = StoreHalfed(a) - 10*a + add_sh(a) = StoreHalfed(a) + 10*a + let var"'" = Diffractor.PrimeDerivativeFwd + @test add_sh'(100.0) == 11.0 + @test add_sh''(100.0) == 0.0 + + @test sub_sh'(100.0) == -9.0 + @test sub_sh''(100.0) == 0.0 + end +end + + @testset "taylor_compatible" begin taylor_compatible = Diffractor.taylor_compatible