Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,15 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
"""
function ParamsWithStats(
param_vector::AbstractVector,
ldf::DynamicPPL.LogDensityFunction,
ldf::DynamicPPL.LogDensityFunction{Tlink},
stats::NamedTuple=NamedTuple();
include_colon_eq::Bool=true,
include_log_probs::Bool=true,
)
) where {Tlink}
strategy = InitFromParams(
VectorWithRanges(ldf._iden_varname_ranges, ldf._varname_ranges, param_vector),
VectorWithRanges{Tlink}(
ldf._iden_varname_ranges, ldf._varname_ranges, param_vector
),
nothing,
)
accs = if include_log_probs
Expand Down
33 changes: 28 additions & 5 deletions src/contexts/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ struct RangeAndLinked
end

"""
VectorWithRanges(
VectorWithRanges{Tlink}(
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
vect::AbstractVector{<:Real},
Expand All @@ -223,6 +223,12 @@ end
A struct that wraps a vector of parameter values, plus information about how random
variables map to ranges in that vector.

The type parameter `Tlink` can be either `true` or `false`, to mark that the variables in
this `VectorWithRanges` are linked/not linked, or `nothing` if either the linking status is
not known or is mixed, i.e. some are linked while others are not. Using `nothing` does not
affect functionality or correctness, but causes more work to be done at runtime, with
possible impacts on type stability and performance.

In the simplest case, this could be accomplished only with a single dictionary mapping
VarNames to ranges and link status. However, for performance reasons, we separate out
VarNames with identity optics into a NamedTuple (`iden_varname_ranges`). All
Expand All @@ -231,13 +237,26 @@ non-identity-optic VarNames are stored in the `varname_ranges` Dict.
It would be nice to improve the NamedTuple and Dict approach. See, e.g.
https://github.com/TuringLang/DynamicPPL.jl/issues/1116.
"""
struct VectorWithRanges{N<:NamedTuple,T<:AbstractVector{<:Real}}
struct VectorWithRanges{Tlink,N<:NamedTuple,T<:AbstractVector{<:Real}}
# This NamedTuple stores the ranges for identity VarNames
iden_varname_ranges::N
# This Dict stores the ranges for all other VarNames
varname_ranges::Dict{VarName,RangeAndLinked}
# The full parameter vector which we index into to get variable values
vect::T

function VectorWithRanges{Tlink}(
iden_varname_ranges::N, varname_ranges::Dict{VarName,RangeAndLinked}, vect::T
) where {Tlink,N,T}
if !(Tlink isa Union{Bool,Nothing})
throw(
ArgumentError(
"VectorWithRanges type parameter has to be one of `true`, `false`, or `nothing`.",
),
)
end
return new{Tlink,N,T}(iden_varname_ranges, varname_ranges, vect)
end
end

function _get_range_and_linked(
Expand All @@ -252,11 +271,15 @@ function init(
::Random.AbstractRNG,
vn::VarName,
dist::Distribution,
p::InitFromParams{<:VectorWithRanges},
)
p::InitFromParams{<:VectorWithRanges{T}},
) where {T}
vr = p.params
range_and_linked = _get_range_and_linked(vr, vn)
transform = if range_and_linked.is_linked
# T can either be `nothing` (i.e., link status is mixed, in which
# case we use the stored link status), or `true` / `false`, which
# indicates that all variables are linked / unlinked.
linked = isnothing(T) ? range_and_linked.is_linked : T
transform = if linked
from_linked_vec_transform(dist)
else
from_vec_transform(dist)
Expand Down
56 changes: 43 additions & 13 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ with such models.** This is a general limitation of vectorised parameters: the o
`unflatten` + `evaluate!!` approach also fails with such models.
"""
struct LogDensityFunction{
# true if all variables are linked; false if all variables are unlinked; nothing if
# mixed
Tlink,
M<:Model,
AD<:Union{ADTypes.AbstractADType,Nothing},
F<:Function,
Expand All @@ -163,6 +166,21 @@ struct LogDensityFunction{
# Figure out which variable corresponds to which index, and
# which variables are linked.
all_iden_ranges, all_ranges = get_ranges_and_linked(varinfo)
# Figure out if all variables are linked, unlinked, or mixed
link_statuses = Bool[]
for ral in all_iden_ranges
push!(link_statuses, ral.is_linked)
end
for (_, ral) in all_ranges
push!(link_statuses, ral.is_linked)
end
Tlink = if all(link_statuses)
true
elseif all(!s for s in link_statuses)
false
else
nothing
end
x = [val for val in varinfo[:]]
dim = length(x)
# Do AD prep if needed
Expand All @@ -172,12 +190,13 @@ struct LogDensityFunction{
# Make backend-specific tweaks to the adtype
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
DI.prepare_gradient(
LogDensityAt(model, getlogdensity, all_iden_ranges, all_ranges),
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
adtype,
x,
)
end
return new{
Tlink,
typeof(model),
typeof(adtype),
typeof(getlogdensity),
Expand Down Expand Up @@ -209,36 +228,45 @@ end
ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

struct LogDensityAt{M<:Model,F<:Function,N<:NamedTuple}
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
model::M
getlogdensity::F
iden_varname_ranges::N
varname_ranges::Dict{VarName,RangeAndLinked}

function LogDensityAt{Tlink}(
model::M,
getlogdensity::F,
iden_varname_ranges::N,
varname_ranges::Dict{VarName,RangeAndLinked},
) where {Tlink,M,F,N}
return new{Tlink,M,F,N}(model, getlogdensity, iden_varname_ranges, varname_ranges)
end
end
function (f::LogDensityAt)(params::AbstractVector{<:Real})
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
strategy = InitFromParams(
VectorWithRanges(f.iden_varname_ranges, f.varname_ranges, params), nothing
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
)
accs = ldf_accs(f.getlogdensity)
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
return f.getlogdensity(vi)
end

function LogDensityProblems.logdensity(
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
return LogDensityAt(
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
return LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
)(
params
)
end

function LogDensityProblems.logdensity_and_gradient(
ldf::LogDensityFunction, params::AbstractVector{<:Real}
)
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
return DI.value_and_gradient(
LogDensityAt(
LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
),
ldf._adprep,
Expand All @@ -247,12 +275,14 @@ function LogDensityProblems.logdensity_and_gradient(
)
end

function LogDensityProblems.capabilities(::Type{<:LogDensityFunction{M,Nothing}}) where {M}
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{T,M,Nothing}}
) where {T,M}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,<:ADTypes.AbstractADType}}
) where {M}
::Type{<:LogDensityFunction{T,M,<:ADTypes.AbstractADType}}
) where {T,M}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.dimension(ldf::LogDensityFunction)
Expand Down
10 changes: 7 additions & 3 deletions test/integration/enzyme/main.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@ using Test: @test, @testset
import Enzyme: set_runtime_activity, Forward, Reverse, Const
using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test

ADTYPES = Dict(
"EnzymeForward" =>
ADTYPES = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Entirely ambivalent about which constructor to use, but curious if you had a reason for changing.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, CI crashes with a Julia GC error when using a Dict.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(don't ask 🙃)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

...

(
"EnzymeForward",
AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const),
"EnzymeReverse" =>
),
(
"EnzymeReverse",
AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const),
),
)

@testset "$ad_key" for (ad_key, ad_type) in ADTYPES
Expand Down
16 changes: 16 additions & 0 deletions test/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,22 @@ end
end
end

@testset "LogDensityFunction: Type stability" begin
@testset "$(m.f)" for m in DynamicPPL.TestUtils.DEMO_MODELS
unlinked_vi = DynamicPPL.VarInfo(m)
@testset "$islinked" for islinked in (false, true)
vi = if islinked
DynamicPPL.link!!(unlinked_vi, m)
else
unlinked_vi
end
ldf = DynamicPPL.LogDensityFunction(m, DynamicPPL.getlogjoint_internal, vi)
x = vi[:]
@inferred LogDensityProblems.logdensity(ldf, x)
end
end
end

@testset "LogDensityFunction: performance" begin
if Threads.nthreads() == 1
# Evaluating these three models should not lead to any allocations (but only when
Expand Down