Skip to content

Commit ca505eb

Browse files
Merge pull request #30 from frankschae/ReverseDiff_buffer
Change type annotation and `similar` to `zero`
2 parents ff072ea + fab0ffd commit ca505eb

File tree

4 files changed

+63
-2
lines changed

4 files changed

+63
-2
lines changed

Project.toml

+6-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ version = "0.4.1"
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
1011

1112
[compat]
1213
Adapt = "3"
@@ -15,15 +16,19 @@ ForwardDiff = "0.10.3"
1516
julia = "1.6"
1617

1718
[extras]
19+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1820
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
1921
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2022
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
2123
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
2224
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2325
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
26+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2427
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2528
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
29+
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
2630
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
31+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2732

2833
[targets]
29-
test = ["LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL"]
34+
test = ["FiniteDiff", "LabelledArrays", "LinearAlgebra", "OrdinaryDiffEq", "Test", "Random", "RecursiveArrayTools", "Pkg", "SafeTestsets", "Optimization", "OptimizationOptimJL", "SciMLSensitivity", "Zygote"]

src/PreallocationTools.jl

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module PreallocationTools
22

33
using ForwardDiff, ArrayInterfaceCore, Adapt
4+
import ReverseDiff
45

56
struct DiffCache{T <: AbstractArray, S <: AbstractArray}
67
du::T
@@ -95,7 +96,17 @@ function Base.getindex(b::LazyBufferCache, u::T) where {T <: AbstractArray}
9596
s = b.sizemap(size(u)) # required buffer size
9697
buf = get!(b.bufs, (T, s)) do
9798
similar(u, s) # buffer to allocate if it was not found in b.bufs
98-
end::T # declare type since b.bufs dictionary is untyped
99+
end::T # declare type since b.bufs dictionary is untyped
100+
return buf
101+
end
102+
103+
function Base.getindex(b::LazyBufferCache, u::ReverseDiff.TrackedArray)
104+
s = b.sizemap(size(u)) # required buffer size
105+
T = ReverseDiff.TrackedArray
106+
buf = get!(b.bufs, (T, s)) do
107+
# declare type since b.bufs dictionary is untyped
108+
similar(u, s)
109+
end
99110
return buf
100111
end
101112

test/runtests.jl

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ if GROUP == "All" || GROUP == "Core"
1515
@safetestset "ODE tests" begin include("core_odes.jl") end
1616
@safetestset "Resizing" begin include("core_resizing.jl") end
1717
@safetestset "Nested Duals" begin include("core_nesteddual.jl") end
18+
@safetestset "ODE Sensitivity analysis" begin include("upstream/sensitivity_analysis.jl") end
1819
end
1920

2021
if !is_APPVEYOR && GROUP == "GPU"

test/upstream/sensitivity_analysis.jl

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
using LinearAlgebra, OrdinaryDiffEq, Test, PreallocationTools
2+
using Random, FiniteDiff, ForwardDiff, ReverseDiff, SciMLSensitivity, Zygote
3+
4+
# see https://github.com/SciML/PreallocationTools.jl/issues/29
5+
@testset "VJP computation with LazyBuffer" begin
6+
u0 = rand(2, 2)
7+
p = rand(2, 2)
8+
struct foo{T}
9+
lbc::T
10+
end
11+
12+
f = foo(LazyBufferCache())
13+
14+
function (f::foo)(du, u, p, t)
15+
tmp = f.lbc[u]
16+
mul!(tmp, p, u) # avoid tmp = p*u
17+
@. du = u + tmp
18+
nothing
19+
end
20+
21+
prob = ODEProblem(f, u0, (0.0, 1.0), p)
22+
23+
function loss(u0, p; sensealg = nothing)
24+
_prob = remake(prob, u0 = u0, p = p)
25+
_sol = solve(_prob, Tsit5(), sensealg = sensealg, saveat = 0.1, abstol = 1e-14,
26+
reltol = 1e-14)
27+
sum(abs2, _sol)
28+
end
29+
30+
loss(u0, p)
31+
32+
du0 = FiniteDiff.finite_difference_gradient(u0 -> loss(u0, p), u0)
33+
dp = FiniteDiff.finite_difference_gradient(p -> loss(u0, p), p)
34+
Fdu0 = ForwardDiff.gradient(u0 -> loss(u0, p), u0)
35+
Fdp = ForwardDiff.gradient(p -> loss(u0, p), p)
36+
@test du0Fdu0 rtol=1e-8
37+
@test dpFdp rtol=1e-8
38+
39+
Zdu0, Zdp = Zygote.gradient((u0, p) -> loss(u0, p;
40+
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP())),
41+
u0, p)
42+
@test du0Zdu0 rtol=1e-8
43+
@test dpZdp rtol=1e-8
44+
end

0 commit comments

Comments
 (0)