Skip to content

Commit 60736ba

Browse files
committed
Improve type stability when all parameters are linked or unlinked
1 parent 766f663 commit 60736ba

File tree

3 files changed

+74
-18
lines changed

3 files changed

+74
-18
lines changed

src/contexts/init.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ struct RangeAndLinked
214214
end
215215

216216
"""
217-
VectorWithRanges(
217+
VectorWithRanges{Tlink}(
218218
iden_varname_ranges::NamedTuple,
219219
varname_ranges::Dict{VarName,RangeAndLinked},
220220
vect::AbstractVector{<:Real},
@@ -231,13 +231,19 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
231231
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
232232
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
233233
"""
234-
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
234+
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
235235
# This NamedTuple stores the ranges for identity VarNames
236236
iden_varname_ranges::N
237237
# This Dict stores the ranges for all other VarNames
238238
varname_ranges::Dict{VarName,RangeAndLinked}
239239
# The full parameter vector which we index into to get variable values
240240
vect::T
241+
242+
function VectorWithRanges{Tlink}(
243+
iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T
244+
) where {Tlink,N,T}
245+
return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect)
246+
end
241247
end
242248

243249
function _get_range_and_linked(
@@ -252,11 +258,15 @@ function init(
252258
::Random.AbstractRNG,
253259
vn::VarName,
254260
dist::Distribution,
255-
p::InitFromParams{<:VectorWithRanges},
256-
)
261+
p::InitFromParams{<:VectorWithRanges{T}},
262+
) where {T}
257263
vr = p.params
258264
range_and_linked = _get_range_and_linked(vr, vn)
259-
transform = if range_and_linked.is_linked
265+
# T is either `nothing` (i.e., link status is mixed, in which
266+
# case we use the stored link status), or `true` / `false`, which
267+
# indicates that all variables are linked / unlinked.
268+
linked = isnothing(T) ? range_and_linked.is_linked : T
269+
transform = if linked
260270
from_linked_vec_transform(dist)
261271
else
262272
from_vec_transform(dist)

src/logdensityfunction.jl

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
140140
`unflatten` + `evaluate!!` approach also fails with such models.
141141
"""
142142
struct LogDensityFunction{
143+
# true if all variables are linked; false if all variables are unlinked; nothing if
144+
# mixed
145+
Tlink,
143146
M<:Model,
144147
AD<:Union{ADTypes.AbstractADType,Nothing},
145148
F<:Function,
@@ -163,6 +166,21 @@ struct LogDensityFunction{
163166
# Figure out which variable corresponds to which index, and
164167
# which variables are linked.
165168
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
169+
# Figure out if all variables are linked, unlinked, or mixed
170+
link_statuses = Bool[]
171+
for ral in all_iden_ranges
172+
push!(link_statuses, ral.is_linked)
173+
end
174+
for (_, ral) in all_ranges
175+
push!(link_statuses, ral.is_linked)
176+
end
177+
Tlink = if all(link_statuses)
178+
true
179+
elseif all(!s for s in link_statuses)
180+
false
181+
else
182+
nothing
183+
end
166184
x = [val for val in varinfo[:]]
167185
dim = length(x)
168186
# Do AD prep if needed
@@ -172,12 +190,13 @@ struct LogDensityFunction{
172190
# Make backend-specific tweaks to the adtype
173191
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
174192
DI.prepare_gradient(
175-
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
193+
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
176194
adtype,
177195
x,
178196
)
179197
end
180198
return new{
199+
Tlink,
181200
typeof(model),
182201
typeof(adtype),
183202
typeof(getlogdensity),
@@ -209,36 +228,45 @@ end
209228
fast_ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
210229
fast_ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))
211230

212-
struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
231+
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
213232
model::M
214233
getlogdensity::F
215234
iden_varname_ranges::N
216235
varname_ranges::Dict{VarName,RangeAndLinked}
236+
237+
function LogDensityAt{Tlink}(
238+
model::M,
239+
getlogdensity::F,
240+
iden_varname_ranges::N,
241+
varname_ranges::Dict{VarName,RangeAndLinked},
242+
) where {Tlink,M,F,N}
243+
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
244+
end
217245
end
218-
function (f::LogDensityAt)(params::AbstractVector{<:Real})
246+
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
219247
strategy = InitFromParams(
220-
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
248+
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
221249
)
222250
accs = fast_ldf_accs(f.getlogdensity)
223251
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
224252
return f.getlogdensity(vi)
225253
end
226254

227255
function LogDensityProblems.logdensity(
228-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
229-
)
230-
return LogDensityAt(
256+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
257+
) where {Tlink}
258+
return LogDensityAt{Tlink}(
231259
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
232260
)(
233261
params
234262
)
235263
end
236264

237265
function LogDensityProblems.logdensity_and_gradient(
238-
ldf::LogDensityFunction, params::AbstractVector{<:Real}
239-
)
266+
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
267+
) where {Tlink}
240268
return DI.value_and_gradient(
241-
LogDensityAt(
269+
LogDensityAt{Tlink}(
242270
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
243271
),
244272
ldf._adprep,
@@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient(
247275
)
248276
end
249277

250-
function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
278+
function LogDensityProblems.capabilities(
279+
::Type{<:LogDensityFunction{T,M,Nothing}}
280+
) where {T,M}
251281
return LogDensityProblems.LogDensityOrder{0}()
252282
end
253283
function LogDensityProblems.capabilities(
254-
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
255-
) where {M}
284+
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
285+
) where {T,M}
256286
return LogDensityProblems.LogDensityOrder{1}()
257287
end
258288
function LogDensityProblems.dimension(ldf::LogDensityFunction)

test/logdensityfunction.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,22 @@ end
108108
end
109109
end
110110

111+
@testset "LogDensityFunction: Type stability" begin
112+
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
113+
unlinked_vi = DynamicPPL.VarInfo(m)
114+
@testset "$islinked" for islinked in (false, true)
115+
vi = if islinked
116+
DynamicPPL.link!!(unlinked_vi, m)
117+
else
118+
unlinked_vi
119+
end
120+
ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi)
121+
x = vi[:]
122+
@inferred LogDensityProblems.logdensity(ldf, x)
123+
end
124+
end
125+
end
126+
111127
@testset "LogDensityFunction: performance" begin
112128
if Threads.nthreads() == 1
113129
# Evaluating these three models should not lead to any allocations (but only when

0 commit comments

Comments
 (0)