diff --git a/src/chains.jl b/src/chains.jl index 2fcd4e713..d01606c3d 100644 --- a/src/chains.jl +++ b/src/chains.jl @@ -130,13 +130,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons: """ function ParamsWithStats( param_vector::AbstractVector, - ldf::DynamicPPL.LogDensityFunction, + ldf::DynamicPPL.LogDensityFunction{Tlink}, stats::NamedTuple=NamedTuple(); include_colon_eq::Bool=true, include_log_probs::Bool=true, -) +) where {Tlink} strategy = InitFromParams( - VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector), + VectorWithRanges{Tlink}( + ldf._iden_varname_ranges, ldf._varname_ranges, param_vector + ), nothing, ) accs = if include_log_probs diff --git a/src/contexts/init.jl b/src/contexts/init.jl index a79969a13..80a494c23 100644 --- a/src/contexts/init.jl +++ b/src/contexts/init.jl @@ -214,7 +214,7 @@ struct RangeAndLinked end """ - VectorWithRanges( + VectorWithRanges{Tlink}( iden_varname_ranges::NamedTuple, varname_ranges::Dict{VarName,RangeAndLinked}, vect::AbstractVector{<:Real}, @@ -223,6 +223,12 @@ end A struct that wraps a vector of parameter values, plus information about how random variables map to ranges in that vector. +The type parameter `Tlink` can be either `true` or `false`, to mark that the variables in +this `VectorWithRanges` are linked/not linked, or `nothing` if either the linking status is +not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not +affect functionality or correctness, but causes more work to be done at runtime, with +possible impacts on type stability and performance. + In the simplest case, this could be accomplished only with a single dictionary mapping VarNames to ranges and link status. However, for performance reasons, we separate out VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All @@ -231,13 +237,26 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict. It would be nice to improve the NamedTuple and Dict approach. See, e.g. https://github.com/TuringLang/DynamicPPL.jl/issues/1116. """ -struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}} +struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}} # This NamedTuple stores the ranges for identity VarNames iden_varname_ranges::N # This Dict stores the ranges for all other VarNames varname_ranges::Dict{VarName,RangeAndLinked} # The full parameter vector which we index into to get variable values vect::T + + function VectorWithRanges{Tlink}( + iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T + ) where {Tlink,N,T} + if !(Tlink isa Union{Bool,Nothing}) + throw( + ArgumentError( + "VectorWithRanges type parameter has to be one of `true`, `false`, or `nothing`.", + ), + ) + end + return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect) + end end function _get_range_and_linked( @@ -252,11 +271,15 @@ function init( ::Random.AbstractRNG, vn::VarName, dist::Distribution, - p::InitFromParams{<:VectorWithRanges}, -) + p::InitFromParams{<:VectorWithRanges{T}}, +) where {T} vr = p.params range_and_linked = _get_range_and_linked(vr, vn) - transform = if range_and_linked.is_linked + # T can either be `nothing` (i.e., link status is mixed, in which + # case we use the stored link status), or `true` / `false`, which + # indicates that all variables are linked / unlinked. + linked = isnothing(T) ? range_and_linked.is_linked : T + transform = if linked from_linked_vec_transform(dist) else from_vec_transform(dist) diff --git a/src/logdensityfunction.jl b/src/logdensityfunction.jl index bcdd0bb25..7d1094fa3 100644 --- a/src/logdensityfunction.jl +++ b/src/logdensityfunction.jl @@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o `unflatten` + `evaluate!!` approach also fails with such models. """ struct LogDensityFunction{ + # true if all variables are linked; false if all variables are unlinked; nothing if + # mixed + Tlink, M<:Model, AD<:Union{ADTypes.AbstractADType,Nothing}, F<:Function, @@ -163,6 +166,21 @@ struct LogDensityFunction{ # Figure out which variable corresponds to which index, and # which variables are linked. all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo) + # Figure out if all variables are linked, unlinked, or mixed + link_statuses = Bool[] + for ral in all_iden_ranges + push!(link_statuses, ral.is_linked) + end + for (_, ral) in all_ranges + push!(link_statuses, ral.is_linked) + end + Tlink = if all(link_statuses) + true + elseif all(!s for s in link_statuses) + false + else + nothing + end x = [val for val in varinfo[:]] dim = length(x) # Do AD prep if needed @@ -172,12 +190,13 @@ struct LogDensityFunction{ # Make backend-specific tweaks to the adtype adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo) DI.prepare_gradient( - LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges), + LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges), adtype, x, ) end return new{ + Tlink, typeof(model), typeof(adtype), typeof(getlogdensity), @@ -209,15 +228,24 @@ end ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),)) ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),)) -struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple} +struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple} model::M getlogdensity::F iden_varname_ranges::N varname_ranges::Dict{VarName,RangeAndLinked} + + function LogDensityAt{Tlink}( + model::M, + getlogdensity::F, + iden_varname_ranges::N, + varname_ranges::Dict{VarName,RangeAndLinked}, + ) where {Tlink,M,F,N} + return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges) + end end -function (f::LogDensityAt)(params::AbstractVector{<:Real}) +function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink} strategy = InitFromParams( - VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing + VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing ) accs = ldf_accs(f.getlogdensity) _, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy) @@ -225,9 +253,9 @@ function (f::LogDensityAt)(params::AbstractVector{<:Real}) end function LogDensityProblems.logdensity( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) - return LogDensityAt( + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} + return LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges )( params @@ -235,10 +263,10 @@ function LogDensityProblems.logdensity( end function LogDensityProblems.logdensity_and_gradient( - ldf::LogDensityFunction, params::AbstractVector{<:Real} -) + ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real} +) where {Tlink} return DI.value_and_gradient( - LogDensityAt( + LogDensityAt{Tlink}( ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges ), ldf._adprep, @@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient( ) end -function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M} +function LogDensityProblems.capabilities( + ::Type{<:LogDensityFunction{T,M,Nothing}} +) where {T,M} return LogDensityProblems.LogDensityOrder{0}() end function LogDensityProblems.capabilities( - ::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}} -) where {M} + ::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}} +) where {T,M} return LogDensityProblems.LogDensityOrder{1}() end function LogDensityProblems.dimension(ldf::LogDensityFunction) diff --git a/test/integration/enzyme/main.jl b/test/integration/enzyme/main.jl index ea4ec497d..edfd67d18 100644 --- a/test/integration/enzyme/main.jl +++ b/test/integration/enzyme/main.jl @@ -5,11 +5,15 @@ using Test: @test, @testset import Enzyme: set_runtime_activity, Forward, Reverse, Const using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test -ADTYPES = Dict( - "EnzymeForward" => +ADTYPES = ( + ( + "EnzymeForward", AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const), - "EnzymeReverse" => + ), + ( + "EnzymeReverse", AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const), + ), ) @testset "$ad_key" for (ad_key, ad_type) in ADTYPES diff --git a/test/logdensityfunction.jl b/test/logdensityfunction.jl index 06492d6e1..f43ed45a4 100644 --- a/test/logdensityfunction.jl +++ b/test/logdensityfunction.jl @@ -108,6 +108,22 @@ end end end +@testset "LogDensityFunction: Type stability" begin + @testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS + unlinked_vi = DynamicPPL.VarInfo(m) + @testset "$islinked" for islinked in (false, true) + vi = if islinked + DynamicPPL.link!!(unlinked_vi, m) + else + unlinked_vi + end + ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi) + x = vi[:] + @inferred LogDensityProblems.logdensity(ldf, x) + end + end +end + @testset "LogDensityFunction: performance" begin if Threads.nthreads() == 1 # Evaluating these three models should not lead to any allocations (but only when