6 changes: 6 additions & 0 deletions HISTORY.md
@@ -1,5 +1,11 @@
# DynamicPPL Changelog

## 0.39.2

The internals of `LogDensityFunction` have been changed slightly so that you no longer need to specify `function_annotation` when performing AD with Enzyme.jl.

There should also be some minor performance improvements (on the order of 10%) for AD with ForwardDiff and Mooncake.
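
For illustration, a minimal hedged sketch (the model is made up, and the keyword-argument `LogDensityFunction(model; adtype=...)` constructor is assumed):

```julia
using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoEnzyme
import Enzyme: set_runtime_activity, Reverse

@model demo() = x ~ Normal()

# As of 0.39.2, no `function_annotation=Const` is needed here.
ldf = DynamicPPL.LogDensityFunction(
    demo(); adtype=AutoEnzyme(; mode=set_runtime_activity(Reverse))
)
LogDensityProblems.logdensity_and_gradient(ldf, [0.5])
```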

## 0.39.1

`LogDensityFunction` now allows you to call `logdensity_and_gradient(ldf, x)` with `AbstractVector`s `x` that are not plain `Vector`s (they will be converted internally before calculating the gradient).
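
For example (a hedged sketch; the model and backend are illustrative), a `SubArray` is accepted where a plain `Vector` would previously have been needed:

```julia
using DynamicPPL, Distributions, LogDensityProblems
using ADTypes: AutoForwardDiff

@model demo() = x ~ Normal()
ldf = DynamicPPL.LogDensityFunction(demo(); adtype=AutoForwardDiff())

xv = view([0.25, 0.75], 1:1)  # an AbstractVector, but not a plain Vector
LogDensityProblems.logdensity_and_gradient(ldf, xv)  # converted internally
```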
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.39.1"
version = "0.39.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
150 changes: 128 additions & 22 deletions src/logdensityfunction.jl
@@ -191,11 +191,18 @@ struct LogDensityFunction{
else
# Make backend-specific tweaks to the adtype
adtype = DynamicPPL.tweak_adtype(adtype, model, varinfo)
DI.prepare_gradient(
LogDensityAt{Tlink}(model, getlogdensity, all_iden_ranges, all_ranges),
adtype,
x,
)
args = (model, getlogdensity, all_iden_ranges, all_ranges)
if _use_closure(adtype)
DI.prepare_gradient(LogDensityAt{Tlink}(args...), adtype, x)
else
DI.prepare_gradient(
logdensity_at,
adtype,
x,
DI.Constant(Val{Tlink}()),
map(DI.Constant, args)...,
)
end
end
return new{
Tlink,
@@ -235,6 +242,47 @@ end
ldf_accs(::typeof(getlogprior)) = AccumulatorTuple((LogPriorAccumulator(),))
ldf_accs(::typeof(getloglikelihood)) = AccumulatorTuple((LogLikelihoodAccumulator(),))

"""
logdensity_at(
params::AbstractVector{<:Real},
::Val{Tlink},
model::Model,
getlogdensity::Function,
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
) where {Tlink}

Calculate the log density at the given `params`, using the provided
information extracted from a `LogDensityFunction`.
"""
function logdensity_at(
params::AbstractVector{<:Real},
::Val{Tlink},
model::Model,
getlogdensity::Function,
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
) where {Tlink}
strategy = InitFromParams(
VectorWithRanges{Tlink}(iden_varname_ranges, varname_ranges, params), nothing
)
accs = ldf_accs(getlogdensity)
_, vi = DynamicPPL.init!!(model, OnlyAccsVarInfo(accs), strategy)
return getlogdensity(vi)
end

"""
LogDensityAt{Tlink}(
model::Model,
getlogdensity::Function,
iden_varname_ranges::NamedTuple,
varname_ranges::Dict{VarName,RangeAndLinked},
) where {Tlink}

A callable struct that behaves in the same way as `logdensity_at`, but stores the model and
other information internally. Having both the closure-style struct and the plain function
allows each AD backend to use whichever approach performs better (see `_use_closure`).
"""
struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
model::M
getlogdensity::F
@@ -251,36 +299,57 @@ struct LogDensityAt{Tlink,M<:Model,F<:Function,N<:NamedTuple}
end
end
function (f::LogDensityAt{Tlink})(params::AbstractVector{<:Real}) where {Tlink}
strategy = InitFromParams(
VectorWithRanges{Tlink}(f.iden_varname_ranges, f.varname_ranges, params), nothing
return logdensity_at(
params,
Val{Tlink}(),
f.model,
f.getlogdensity,
f.iden_varname_ranges,
f.varname_ranges,
)
accs = ldf_accs(f.getlogdensity)
_, vi = DynamicPPL.init!!(f.model, OnlyAccsVarInfo(accs), strategy)
return f.getlogdensity(vi)
end
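
# Illustrative equivalence (a sketch; `args` and `x` are hypothetical): the callable
# struct above and the free function `logdensity_at` compute the same value, and differ
# only in how the non-differentiated data is passed to the AD backend:
#
#     args = (model, getlogdensity, iden_varname_ranges, varname_ranges)
#     LogDensityAt{Tlink}(args...)(x) == logdensity_at(x, Val{Tlink}(), args...)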

function LogDensityProblems.logdensity(
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
return LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
)(
params
return logdensity_at(
params,
Val{Tlink}(),
ldf.model,
ldf._getlogdensity,
ldf._iden_varname_ranges,
ldf._varname_ranges,
)
end

function LogDensityProblems.logdensity_and_gradient(
ldf::LogDensityFunction{Tlink}, params::AbstractVector{<:Real}
) where {Tlink}
# `params` has to be converted to the same vector type that was used for AD preparation,
# otherwise the preparation will not be valid.
params = convert(_get_input_vector_type(ldf), params)
return DI.value_and_gradient(
LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
),
ldf._adprep,
ldf.adtype,
params,
)
return if _use_closure(ldf.adtype)
DI.value_and_gradient(
LogDensityAt{Tlink}(
ldf.model, ldf._getlogdensity, ldf._iden_varname_ranges, ldf._varname_ranges
),
ldf._adprep,
ldf.adtype,
params,
)
else
DI.value_and_gradient(
logdensity_at,
ldf._adprep,
ldf.adtype,
params,
DI.Constant(Val{Tlink}()),
DI.Constant(ldf.model),
DI.Constant(ldf._getlogdensity),
DI.Constant(ldf._iden_varname_ranges),
DI.Constant(ldf._varname_ranges),
)
end
end
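
# A toy sketch of the two DifferentiationInterface calling conventions used above
# (the function and backend here are illustrative, not part of DynamicPPL):
#
#     using ADTypes: AutoForwardDiff
#     f(x, c) = c * sum(abs2, x)
#     # Constant-context form, as in the `else` branch: `c` is wrapped in DI.Constant.
#     DI.value_and_gradient(f, AutoForwardDiff(), [1.0, 2.0], DI.Constant(3.0))
#     # Closure form, as in the `if` branch: `c` is captured in the callable itself.
#     DI.value_and_gradient(Base.Fix2(f, 3.0), AutoForwardDiff(), [1.0, 2.0])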

function LogDensityProblems.capabilities(
@@ -314,6 +383,43 @@ By default, this just returns the input unchanged.
"""
tweak_adtype(adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo) = adtype

"""
_use_closure(adtype::ADTypes.AbstractADType)

In LogDensityProblems, we want to calculate the derivative of `logdensity(f, x)` with
respect to `x`, where `f` is the model (in our case, the `LogDensityFunction` or its
constituent fields) and is treated as a constant. However, DifferentiationInterface
generally expects a single-argument function `g(x)` to differentiate.

There are two ways of dealing with this:

1. Construct a closure over the model, i.e. let `g = Base.Fix1(logdensity, f)`.

2. Use a constant `DI.Context`. This lets us pass a multi-argument function to DI, as long
   as we also wrap each 'inactive' argument (i.e. the model and its associated information)
   in `DI.Constant`.

The relative performance of the two approaches, however, depends on the AD backend used.
Some benchmarks are provided here: https://github.com/TuringLang/DynamicPPL.jl/pull/1172

This function is used to determine whether a given AD backend should use a closure or a
constant. If `_use_closure(adtype)` returns `true`, then the closure approach will be used.
By default, this function returns `false`, i.e. the constant approach will be used.
"""
# For these AD backends both closure and no closure work, but it is just faster to not use a
# closure (see link in the docstring).
_use_closure(::ADTypes.AutoForwardDiff) = false
_use_closure(::ADTypes.AutoMooncake) = false
_use_closure(::ADTypes.AutoMooncakeForward) = false
# For ReverseDiff, with the compiled tape, you _must_ use a closure because otherwise with
# DI.Constant arguments the tape will always be recompiled upon each call to
# value_and_gradient. For non-compiled ReverseDiff, it is faster to not use a closure.
_use_closure(::ADTypes.AutoReverseDiff{compile}) where {compile} = compile
# For AutoEnzyme, the constant approach lets us avoid setting `function_annotation`.
_use_closure(::ADTypes.AutoEnzyme) = false
# Since for most backends it's faster to not use a closure, we set that as the default
# for unknown AD backends
_use_closure(::ADTypes.AbstractADType) = false
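
# For instance, the dispatches above resolve as follows (illustrative):
#
#     _use_closure(ADTypes.AutoForwardDiff())                 # false
#     _use_closure(ADTypes.AutoReverseDiff(; compile=true))   # true: compiled tape needs a closure
#     _use_closure(ADTypes.AutoReverseDiff(; compile=false))  # false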

######################################################
# Helper functions to extract ranges and link status #
######################################################
10 changes: 2 additions & 8 deletions test/integration/enzyme/main.jl
@@ -6,14 +6,8 @@ import Enzyme: set_runtime_activity, Forward, Reverse, Const
using ForwardDiff: ForwardDiff # run_ad uses FD for correctness test

ADTYPES = (
(
"EnzymeForward",
AutoEnzyme(; mode=set_runtime_activity(Forward), function_annotation=Const),
),
(
"EnzymeReverse",
AutoEnzyme(; mode=set_runtime_activity(Reverse), function_annotation=Const),
),
("EnzymeForward", AutoEnzyme(; mode=set_runtime_activity(Forward))),
("EnzymeReverse", AutoEnzyme(; mode=set_runtime_activity(Reverse))),
)

@testset "$ad_key" for (ad_key, ad_type) in ADTYPES