Skip to content

Commit 766f663

Browse files
authored
Make FastLDF the default (#1139)
* Make FastLDF the default * Add miscellaneous LogDensityProblems tests * Use `init!!` instead of `fast_evaluate!!` * Rename files, rebalance tests
1 parent 4a11560 commit 766f663

File tree

14 files changed

+584
-1140
lines changed

14 files changed

+584
-1140
lines changed

HISTORY.md

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,17 @@
44

55
### Breaking changes
66

7+
#### Fast Log Density Functions
8+
9+
This version provides a reimplementation of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
10+
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
11+
12+
For more information about how this is accomplished, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
13+
14+
As a result of this change, `LogDensityFunction` no longer stores a VarInfo inside it.
15+
In general, if `ldf` is a `LogDensityFunction`, it is now only valid to access `ldf.model` and `ldf.adtype`.
16+
If you were previously relying on this behaviour, you will need to store a VarInfo separately.
17+
718
#### Parent and leaf contexts
819

920
The `DynamicPPL.NodeTrait` function has been removed.
@@ -24,18 +35,6 @@ Removed the method `returned(::Model, values, keys)`; please use `returned(::Mod
2435
The method `DynamicPPL.init` (for implementing `AbstractInitStrategy`) now has a different signature: it must return a tuple of the generated value, plus a transform function that maps it back to unlinked space.
2536
This is a generalisation of the previous behaviour, where `init` would always return an unlinked value (in effect forcing the transform to be the identity function).
2637

27-
### Other changes
28-
29-
#### FastLDF
30-
31-
Added `DynamicPPL.Experimental.FastLDF`, a version of `LogDensityFunction` that provides performance improvements on the order of 2–10× for both model evaluation as well as automatic differentiation.
32-
Exact speedups depend on the model size: larger models have less significant speedups because the bulk of the work is done in calls to `logpdf`.
33-
34-
Please note that `FastLDF` is currently considered internal and its API may change without warning.
35-
We intend to replace `LogDensityFunction` with `FastLDF` in a release in the near future, but until then we recommend not using it.
36-
37-
For more information about `FastLDF`, please see https://github.com/TuringLang/DynamicPPL.jl/pull/1113 as well as the `src/fasteval.jl` file, which contains extensive comments.
38-
3938
## 0.38.9
4039

4140
Remove warning when using Enzyme as the AD backend.

docs/src/api.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ The [LogDensityProblems.jl](https://github.com/tpapp/LogDensityProblems.jl) inte
6666
LogDensityFunction
6767
```
6868

69+
Internally, this is accomplished using [`init!!`](@ref) on:
70+
71+
```@docs
72+
OnlyAccsVarInfo
73+
```
74+
6975
## Condition and decondition
7076

7177
A [`Model`](@ref) can be conditioned on a set of observations with [`AbstractPPL.condition`](@ref) or its alias [`|`](@ref).
@@ -510,7 +516,7 @@ The function `init!!` is used to initialise, or overwrite, values in a VarInfo.
510516
It is really a thin wrapper around using `evaluate!!` with an `InitContext`.
511517

512518
```@docs
513-
DynamicPPL.init!!
519+
init!!
514520
```
515521

516522
To accomplish this, an initialisation _strategy_ is required, which defines how new values are to be obtained.

ext/DynamicPPLMarginalLogDensitiesExt.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,13 @@ using MarginalLogDensities: MarginalLogDensities
66
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
77
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
88
# below.
9-
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
9+
struct LogDensityFunctionWrapper{
10+
L<:DynamicPPL.LogDensityFunction,V<:DynamicPPL.AbstractVarInfo
11+
}
1012
logdensity::L
13+
# This field is used only to reconstruct the VarInfo later on; it's not needed for the
14+
# actual log-density evaluation.
15+
varinfo::V
1116
end
1217
function (lw::LogDensityFunctionWrapper)(x, _)
1318
return LogDensityProblems.logdensity(lw.logdensity, x)
@@ -101,7 +106,7 @@ function DynamicPPL.marginalize(
101106
# Construct the marginal log-density model.
102107
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
103108
mld = MarginalLogDensities.MarginalLogDensity(
104-
LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
109+
LogDensityFunctionWrapper(f, varinfo), varinfo[:], varindices, (), method; kwargs...
105110
)
106111
return mld
107112
end
@@ -190,7 +195,7 @@ function DynamicPPL.VarInfo(
190195
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
191196
)
192197
# Extract the original VarInfo. Its contents will in general be junk.
193-
original_vi = mld.logdensity.logdensity.varinfo
198+
original_vi = mld.logdensity.varinfo
194199
# Extract the stored parameters, which includes the modes for any marginalized
195200
# parameters
196201
full_params = MarginalLogDensities.cached_params(mld)

src/DynamicPPL.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,12 @@ export AbstractVarInfo,
9292
getargnames,
9393
extract_priors,
9494
values_as_in_model,
95+
# evaluation
96+
evaluate!!,
97+
init!!,
9598
# LogDensityFunction
9699
LogDensityFunction,
100+
OnlyAccsVarInfo,
97101
# Leaf contexts
98102
AbstractContext,
99103
contextualize,

src/chains.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ end
137137
"""
138138
ParamsWithStats(
139139
param_vector::AbstractVector,
140-
ldf::DynamicPPL.Experimental.FastLDF,
140+
ldf::DynamicPPL.LogDensityFunction,
141141
stats::NamedTuple=NamedTuple();
142142
include_colon_eq::Bool=true,
143143
include_log_probs::Bool=true,
@@ -156,7 +156,7 @@ via `unflatten` plus re-evaluation. It is faster for two reasons:
156156
"""
157157
function ParamsWithStats(
158158
param_vector::AbstractVector,
159-
ldf::DynamicPPL.Experimental.FastLDF,
159+
ldf::DynamicPPL.LogDensityFunction,
160160
stats::NamedTuple=NamedTuple();
161161
include_colon_eq::Bool=true,
162162
include_log_probs::Bool=true,
@@ -174,9 +174,7 @@ function ParamsWithStats(
174174
else
175175
(DynamicPPL.ValuesAsInModelAccumulator(include_colon_eq),)
176176
end
177-
_, vi = DynamicPPL.Experimental.fast_evaluate!!(
178-
ldf.model, strategy, AccumulatorTuple(accs)
179-
)
177+
_, vi = DynamicPPL.init!!(ldf.model, OnlyAccsVarInfo(AccumulatorTuple(accs)), strategy)
180178
params = DynamicPPL.getacc(vi, Val(:ValuesAsInModel)).values
181179
if include_log_probs
182180
stats = merge(

src/experimental.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,6 @@ module Experimental
22

33
using DynamicPPL: DynamicPPL
44

5-
include("fasteval.jl")
6-
75
# This file only defines the names of the functions, and their docstrings. The actual implementations are in `ext/DynamicPPLJETExt.jl`, since we don't want to depend on JET.jl other than as a weak dependency.
86
"""
97
is_suitable_varinfo(model::Model, varinfo::AbstractVarInfo; kwargs...)

0 commit comments

Comments
 (0)