From fd62c7cbe110e1da4e874120a3ac86d43f88cdec Mon Sep 17 00:00:00 2001 From: bzinberg Date: Wed, 14 Oct 2020 12:56:43 -0400 Subject: [PATCH 1/4] Specialize `Base.:+` and `Base.:-` to work on `StaticArray`s See https://github.com/JuliaDiff/ReverseDiff.jl/issues/153 --- src/derivatives/linalg/arithmetic.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/derivatives/linalg/arithmetic.jl b/src/derivatives/linalg/arithmetic.jl index 271af22..5398ef9 100644 --- a/src/derivatives/linalg/arithmetic.jl +++ b/src/derivatives/linalg/arithmetic.jl @@ -33,6 +33,9 @@ for A in ARRAY_TYPES @eval @inline Base.:+(x::$(A), y::TrackedArray{V,D}) where {V,D} = record_plus(x, y, D) end +@inline Base.:+(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_plus(x, Array(y), D) +@inline Base.:+(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_plus(Array(x), y, D) + function record_plus(x, y, ::Type{D}) where D tp = tape(x, y) out = track(value(x) + value(y), D, tp) @@ -108,6 +111,9 @@ for A in ARRAY_TYPES @eval Base.:-(x::$(A), y::TrackedArray{V,D}) where {V,D} = record_minus(x, y, D) end +@inline Base.:+(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_plus(x, Array(y), D) +@inline Base.:+(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_plus(Array(x), y, D) + function Base.:-(x::TrackedArray{V,D}) where {V,D} tp = tape(x) out = track(-(value(x)), D, tp) From 622e50f85d58adf0b076e70662b0f682cbac1a13 Mon Sep 17 00:00:00 2001 From: bzinberg Date: Wed, 14 Oct 2020 13:02:37 -0400 Subject: [PATCH 2/4] Specialize `capture` to work with `StaticArray`s --- src/tracked.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tracked.jl b/src/tracked.jl index 049c1a9..fa13eb0 100644 --- a/src/tracked.jl +++ b/src/tracked.jl @@ -223,6 +223,8 @@ unseed!(x::AbstractArray, i) = unseed!(x[i]) capture(t::TrackedReal) = ifelse(hastape(t), t, value(t)) capture(t::TrackedArray) = t capture(t::AbstractArray) = istracked(t) ? map!(capture, similar(t), t) : copy(t) +# `StaticArray`s don't support mutation unless the eltype is a bits type (`isbitstype`). +capture(t::SA) where SA <: StaticArray = istracked(t) ? SA(map(capture, t)) : copy(t) ######################## # Conversion/Promotion # From 2055889f301862c3f21e5beeac6886e04b2d6d4e Mon Sep 17 00:00:00 2001 From: bzinberg Date: Wed, 14 Oct 2020 13:25:55 -0400 Subject: [PATCH 3/4] copy/paste error --- src/derivatives/linalg/arithmetic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/derivatives/linalg/arithmetic.jl b/src/derivatives/linalg/arithmetic.jl index 5398ef9..f3532a7 100644 --- a/src/derivatives/linalg/arithmetic.jl +++ b/src/derivatives/linalg/arithmetic.jl @@ -111,8 +111,8 @@ for A in ARRAY_TYPES @eval Base.:-(x::$(A), y::TrackedArray{V,D}) where {V,D} = record_minus(x, y, D) end -@inline Base.:+(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_plus(x, Array(y), D) -@inline Base.:+(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_plus(Array(x), y, D) +@inline Base.:-(x::TrackedArray{V,D}, y::StaticArray) where {V,D} = record_minus(x, Array(y), D) +@inline Base.:-(x::StaticArray, y::TrackedArray{V,D}) where {V,D} = record_minus(Array(x), y, D) function Base.:-(x::TrackedArray{V,D}) where {V,D} tp = tape(x) From 9e6e9a838573f85c5ad24b2f75c46296f45fa6bb Mon Sep 17 00:00:00 2001 From: bzinberg Date: Wed, 14 Oct 2020 16:25:57 -0400 Subject: [PATCH 4/4] Specialize `deriv!` to not mutate `StaticArray`s --- src/tracked.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tracked.jl b/src/tracked.jl index fa13eb0..08cf9bd 100644 --- a/src/tracked.jl +++ b/src/tracked.jl @@ -172,6 +172,8 @@ function deriv!(t::NTuple{N,Any}, v::NTuple{N,Any}) where N return nothing end +deriv!(t::StaticArray, v::AbstractArray) = deriv!(Tuple(t), Tuple(v)) + # pulling values from origin # #----------------------------#