@@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
140140`unflatten` + `evaluate!!` approach also fails with such models.
141141"""
142142struct LogDensityFunction{
143+ # true if all variables are linked; false if all variables are unlinked; nothing if
144+ # mixed
145+ Tlink,
143146 M<: Model ,
144147 AD<: Union{ADTypes.AbstractADType,Nothing} ,
145148 F<: Function ,
@@ -163,6 +166,21 @@ struct LogDensityFunction{
163166 # Figure out which variable corresponds to which index, and
164167 # which variables are linked.
165168 all_iden_ranges, all_ranges = get_ranges_and_linked (varinfo)
169+ # Figure out if all variables are linked, unlinked, or mixed
170+ link_statuses = Bool[]
171+ for ral in all_iden_ranges
172+ push! (link_statuses, ral. is_linked)
173+ end
174+ for (_, ral) in all_ranges
175+ push! (link_statuses, ral. is_linked)
176+ end
177+ Tlink = if all (link_statuses)
178+ true
179+ elseif all (! s for s in link_statuses)
180+ false
181+ else
182+ nothing
183+ end
166184 x = [val for val in varinfo[:]]
167185 dim = length (x)
168186 # Do AD prep if needed
@@ -172,12 +190,13 @@ struct LogDensityFunction{
172190 # Make backend-specific tweaks to the adtype
173191 adtype = DynamicPPL. tweak_adtype (adtype, model, varinfo)
174192 DI. prepare_gradient (
175- LogDensityAt (model, getlogdensity, all_iden_ranges, all_ranges),
193+ LogDensityAt {Tlink} (model, getlogdensity, all_iden_ranges, all_ranges),
176194 adtype,
177195 x,
178196 )
179197 end
180198 return new{
199+ Tlink,
181200 typeof (model),
182201 typeof (adtype),
183202 typeof (getlogdensity),
@@ -209,36 +228,45 @@ end
209228fast_ldf_accs (:: typeof (getlogprior)) = AccumulatorTuple ((LogPriorAccumulator (),))
210229fast_ldf_accs (:: typeof (getloglikelihood)) = AccumulatorTuple ((LogLikelihoodAccumulator (),))
211230
212- struct LogDensityAt{M<: Model ,F<: Function ,N<: NamedTuple }
231+ struct LogDensityAt{Tlink, M<: Model ,F<: Function ,N<: NamedTuple }
213232 model:: M
214233 getlogdensity:: F
215234 iden_varname_ranges:: N
216235 varname_ranges:: Dict{VarName,RangeAndLinked}
236+
237+ function LogDensityAt {Tlink} (
238+ model:: M ,
239+ getlogdensity:: F ,
240+ iden_varname_ranges:: N ,
241+ varname_ranges:: Dict{VarName,RangeAndLinked} ,
242+ ) where {Tlink,M,F,N}
243+ return new {Tlink,M,F,N} (model, getlogdensity, iden_varname_ranges, varname_ranges)
244+ end
217245end
218- function (f:: LogDensityAt )(params:: AbstractVector{<:Real} )
246+ function (f:: LogDensityAt{Tlink} )(params:: AbstractVector{<:Real} ) where {Tlink}
219247 strategy = InitFromParams (
220- VectorWithRanges (f. iden_varname_ranges, f. varname_ranges, params), nothing
248+ VectorWithRanges {Tlink} (f. iden_varname_ranges, f. varname_ranges, params), nothing
221249 )
222250 accs = fast_ldf_accs (f. getlogdensity)
223251 _, vi = DynamicPPL. init!! (f. model, OnlyAccsVarInfo (accs), strategy)
224252 return f. getlogdensity (vi)
225253end
226254
227255function LogDensityProblems. logdensity (
228- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
229- )
230- return LogDensityAt (
256+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
257+ ) where {Tlink}
258+ return LogDensityAt {Tlink} (
231259 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
232260 )(
233261 params
234262 )
235263end
236264
237265function LogDensityProblems. logdensity_and_gradient (
238- ldf:: LogDensityFunction , params:: AbstractVector{<:Real}
239- )
266+ ldf:: LogDensityFunction{Tlink} , params:: AbstractVector{<:Real}
267+ ) where {Tlink}
240268 return DI. value_and_gradient (
241- LogDensityAt (
269+ LogDensityAt {Tlink} (
242270 ldf. model, ldf. _getlogdensity, ldf. _iden_varname_ranges, ldf. _varname_ranges
243271 ),
244272 ldf. _adprep,
@@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient(
247275 )
248276end
249277
250- function LogDensityProblems. capabilities (:: Type{<:LogDensityFunction{M,Nothing}} ) where {M}
278+ function LogDensityProblems. capabilities (
279+ :: Type{<:LogDensityFunction{T,M,Nothing}}
280+ ) where {T,M}
251281 return LogDensityProblems. LogDensityOrder {0} ()
252282end
253283function LogDensityProblems. capabilities (
254- :: Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
255- ) where {M}
284+ :: Type{<:LogDensityFunction{T, M,<:ADTypes.AbstractADType}}
285+ ) where {T, M}
256286 return LogDensityProblems. LogDensityOrder {1} ()
257287end
258288function LogDensityProblems. dimension (ldf:: LogDensityFunction )
0 commit comments