Skip to content

Commit 0af3894

Browse files
MarginalLogDensities extension (#1036)
* Transfer MarginalLogDensities extension from Turing See: TuringLang/Turing.jl#2664 Co-authored-by: Sam Urmy <[email protected]> * Add documentation * Bump patch, add changelog * Add compat entry for MLD * Fix docs * Allow user to specify VarInfo used for marginalisation * Use linked varinfo by default * Make the non-essential stuff all keyword arguments * Fix docs * Add `VarInfo(::MarginalLogDensity)` method * Use new `cached_params` function * Add more detailed changelog Co-authored-by: Sam Urmy <[email protected]> * Add error hint if marginalize is called before loading MLD * fix comma -> semicolon typo * remove test with symbol * Update changelog, use -ise in prose --------- Co-authored-by: Sam Urmy <[email protected]>
1 parent 73113c7 commit 0af3894

File tree

10 files changed

+374
-3
lines changed

10 files changed

+374
-3
lines changed

HISTORY.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
# DynamicPPL Changelog
22

3+
## 0.37.4
4+
5+
An extension for MarginalLogDensities.jl has been added.
6+
7+
Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalise out variables from a model.
8+
This is useful for averaging out random effects or nuisance parameters while improving inference on fixed effects/parameters of interest.
9+
The `marginalize` function returns a `MarginalLogDensities.MarginalLogDensity`, a function-like callable struct that returns the approximate log-density of a subset of the parameters after integrating out the rest of them.
10+
By default, this uses the Laplace approximation and sparse AD, making the marginalisation computationally very efficient.
11+
Note that the Laplace approximation relies on the model being differentiable with respect to the marginalised variables, and that their posteriors are unimodal and approximately Gaussian.
12+
13+
Please see [the MarginalLogDensities documentation](https://eloceanografo.github.io/MarginalLogDensities.jl/stable) and the [new Marginalisation section of the DynamicPPL documentation](https://turinglang.org/DynamicPPL.jl/v0.37/api/#Marginalisation) for further information.
14+
315
## 0.37.3
416

517
Prevents inlining of `DynamicPPL.istrans` with Enzyme, which allows Enzyme to differentiate models where `VarName`s have the same symbol but different types.

Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.37.3"
3+
version = "0.37.4"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -33,6 +33,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3333
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
3434
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3535
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
36+
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
3637
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3738

3839
[extensions]
@@ -41,6 +42,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
4142
DynamicPPLForwardDiffExt = ["ForwardDiff"]
4243
DynamicPPLJETExt = ["JET"]
4344
DynamicPPLMCMCChainsExt = ["MCMCChains"]
45+
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
4446
DynamicPPLMooncakeExt = ["Mooncake"]
4547

4648
[compat]
@@ -66,6 +68,7 @@ LinearAlgebra = "1.6"
6668
LogDensityProblems = "2"
6769
MCMCChains = "6, 7"
6870
MacroTools = "0.5.6"
71+
MarginalLogDensities = "0.4.3"
6972
Mooncake = "0.4.147"
7073
OrderedCollections = "1"
7174
Printf = "1.10"

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1010
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1111
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1212
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
13+
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
1314
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1415

1516
[compat]

docs/make.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@ using Distributions
1111
using DocumenterMermaid
1212
# load MCMCChains package extension to make `predict` available
1313
using MCMCChains
14+
using MarginalLogDensities: MarginalLogDensities
15+
16+
# Need this to document a method which uses a type inside the extension...
17+
DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt)
1418

1519
# Doctest setup
1620
DocMeta.setdocmeta!(
@@ -24,7 +28,11 @@ makedocs(;
2428
format=Documenter.HTML(;
2529
size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3()
2630
),
27-
modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)],
31+
modules=[
32+
DynamicPPL,
33+
Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt),
34+
Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt),
35+
],
2836
pages=[
2937
"Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"]
3038
],

docs/src/api.md

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,22 @@ When using `predict` with `MCMCChains.Chains`, you can control which variables a
136136
- `include_all=false` (default): Include only newly predicted variables
137137
- `include_all=true`: Include both parameters from the original chain and predicted variables
138138

139+
## Marginalisation
140+
141+
DynamicPPL provides the `marginalize` function to marginalise out variables from a model.
142+
This requires `MarginalLogDensities.jl` to be loaded in your environment.
143+
144+
```@docs
145+
marginalize
146+
```
147+
148+
A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability.
149+
To retrieve a VarInfo object from it, you can use:
150+
151+
```@docs
152+
VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
153+
```
154+
139155
## Models within models
140156

141157
One can include models and call another model inside the model function with `left ~ to_submodel(model)`.
Lines changed: 204 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,204 @@
1+
module DynamicPPLMarginalLogDensitiesExt
2+
3+
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
4+
using MarginalLogDensities: MarginalLogDensities
5+
6+
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
7+
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
8+
# below.
9+
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
10+
logdensity::L
11+
end
12+
function (lw::LogDensityFunctionWrapper)(x, _)
13+
return LogDensityProblems.logdensity(lw.logdensity, x)
14+
end
15+
16+
"""
17+
marginalize(
18+
model::DynamicPPL.Model,
19+
marginalized_varnames::AbstractVector{<:VarName};
20+
varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model),
21+
getlogprob=DynamicPPL.getlogjoint,
22+
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
23+
kwargs...,
24+
)
25+
26+
Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal
27+
log-density of the given `model`, after marginalizing out the variables specified in
28+
`varnames`.
29+
30+
The resulting object can be called with a vector of parameter values to compute the marginal
31+
log-density.
32+
33+
## Keyword arguments
34+
35+
- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`,
36+
meaning that the resulting log-density function accepts parameters that have been
37+
transformed to unconstrained space.
38+
39+
- `getlogprob`: A function which specifies which kind of marginal log-density to compute.
40+
Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint
41+
probability.
42+
43+
- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the
44+
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
45+
for other options.
46+
47+
- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity`
48+
constructor.
49+
50+
## Example
51+
52+
```jldoctest
53+
julia> using DynamicPPL, Distributions, MarginalLogDensities
54+
55+
julia> @model function demo()
56+
x ~ Normal(1.0)
57+
y ~ Normal(2.0)
58+
end
59+
demo (generic function with 2 methods)
60+
61+
julia> marginalized = marginalize(demo(), [:x]);
62+
63+
julia> # The resulting callable computes the marginal log-density of `y`.
64+
marginalized([1.0])
65+
-1.4189385332046727
66+
67+
julia> logpdf(Normal(2.0), 1.0)
68+
-1.4189385332046727
69+
```
70+
71+
72+
!!! warning
73+
74+
The default usage of linked VarInfo means that, for example, optimization of the
75+
marginal log-density can be performed in unconstrained space. However, care must be
76+
taken if the model contains variables where the link transformation depends on a
77+
marginalized variable. For example:
78+
79+
```julia
80+
@model function f()
81+
x ~ Normal()
82+
y ~ truncated(Normal(); lower=x)
83+
end
84+
```
85+
86+
Here, the support of `y`, and hence the link transformation used, depends on the value
87+
of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of
88+
`y` to log-probabilities. However, it will not be possible to use DynamicPPL to
89+
correctly retrieve _unlinked_ values of `y`.
90+
"""
91+
function DynamicPPL.marginalize(
92+
model::DynamicPPL.Model,
93+
marginalized_varnames::AbstractVector{<:VarName};
94+
varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model),
95+
getlogprob::Function=DynamicPPL.getlogjoint,
96+
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
97+
kwargs...,
98+
)
99+
# Determine the indices for the variables to marginalise out.
100+
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames))
101+
# Construct the marginal log-density model.
102+
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
103+
mld = MarginalLogDensities.MarginalLogDensity(
104+
LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
105+
)
106+
return mld
107+
end
108+
109+
"""
110+
VarInfo(
111+
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
112+
unmarginalized_params::Union{AbstractVector,Nothing}=nothing
113+
)
114+
115+
Retrieve the `VarInfo` object used in the marginalisation process.
116+
117+
If a Laplace approximation was used for the marginalisation, the values of the marginalized
118+
parameters are also set to their mode (note that this only happens if the `mld` object has
119+
been used to compute the marginal log-density at least once, so that the mode has been
120+
computed).
121+
122+
If a vector of `unmarginalized_params` is specified, the values for the corresponding
123+
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
124+
performing an optimization of the marginal log-density.
125+
126+
All other aspects of the VarInfo, such as link status, are preserved from the original
127+
VarInfo used in the marginalisation.
128+
129+
!!! note
130+
131+
The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
132+
updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the
133+
model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model,
134+
vi))`).
135+
136+
## Example
137+
138+
```jldoctest
139+
julia> using DynamicPPL, Distributions, MarginalLogDensities
140+
141+
julia> @model function demo()
142+
x ~ Normal()
143+
y ~ Beta(2, 2)
144+
end
145+
demo (generic function with 2 methods)
146+
147+
julia> # Note that by default `marginalize` uses a linked VarInfo.
148+
mld = marginalize(demo(), [@varname(x)]);
149+
150+
julia> using MarginalLogDensities: Optimization, OptimizationOptimJL
151+
152+
julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
153+
y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
154+
OptimizationProblem. In-place: true
155+
u0: 1-element Vector{Float64}:
156+
2.0
157+
158+
julia> # This tells us the optimal (linked) value of `y` is around 0.
159+
opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
160+
retcode: Success
161+
u: 1-element Vector{Float64}:
162+
4.88281250001733e-5
163+
164+
julia> # Get the VarInfo corresponding to the mode of `y`.
165+
vi = VarInfo(mld, opt_solution.u);
166+
167+
julia> # `x` is set to its mode (which for `Normal()` is zero).
168+
vi[@varname(x)]
169+
0.0
170+
171+
julia> # `y` is set to the optimal value we found above.
172+
DynamicPPL.getindex_internal(vi, @varname(y))
173+
1-element Vector{Float64}:
174+
4.88281250001733e-5
175+
176+
julia> # To obtain values in the original constrained space, we can either
177+
# use `getindex`:
178+
vi[@varname(y)]
179+
0.5000122070312476
180+
181+
julia> # Or invlink the entire VarInfo object using the model:
182+
vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
183+
2-element Vector{Float64}:
184+
0.0
185+
0.5000122070312476
186+
```
187+
"""
188+
function DynamicPPL.VarInfo(
189+
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
190+
unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
191+
)
192+
# Extract the original VarInfo. Its contents will in general be junk.
193+
original_vi = mld.logdensity.logdensity.varinfo
194+
# Extract the stored parameters, which includes the modes for any marginalized
195+
# parameters
196+
full_params = MarginalLogDensities.cached_params(mld)
197+
# We can then (if needed) set the values for any non-marginalized parameters
198+
if unmarginalized_params !== nothing
199+
full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
200+
end
201+
return DynamicPPL.unflatten(original_vi, full_params)
202+
end
203+
204+
end

src/DynamicPPL.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ export AbstractVarInfo,
122122
fix,
123123
unfix,
124124
predict,
125+
marginalize,
125126
prefix,
126127
returned,
127128
to_submodel,
@@ -199,9 +200,9 @@ include("test_utils.jl")
199200
include("experimental.jl")
200201
include("deprecated.jl")
201202

202-
# Better error message if users forget to load JET
203203
if isdefined(Base.Experimental, :register_error_hint)
204204
function __init__()
205+
# Better error message if users forget to load JET.jl
205206
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
206207
requires_jet =
207208
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
@@ -222,6 +223,23 @@ if isdefined(Base.Experimental, :register_error_hint)
222223
end
223224
end
224225

226+
# Same for MarginalLogDensities.jl
227+
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
228+
requires_mld =
229+
exc.f === DynamicPPL.marginalize &&
230+
length(argtypes) == 2 &&
231+
argtypes[1] <: Model &&
232+
argtypes[2] <: AbstractVector{<:Union{Symbol,<:VarName}}
233+
if requires_mld
234+
printstyled(
235+
io,
236+
"\n\n `$(exc.f)` requires MarginalLogDensities.jl to be loaded.\n Please run `using MarginalLogDensities` before calling `$(exc.f)`.\n";
237+
color=:cyan,
238+
bold=true,
239+
)
240+
end
241+
end
242+
225243
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
226244
is_evaluate_three_arg =
227245
exc.f === AbstractPPL.evaluate!! &&
@@ -243,4 +261,7 @@ end
243261
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
244262
struct DynamicPPLTag end
245263

264+
# Extended in MarginalLogDensitiesExt
265+
function marginalize end
266+
246267
end # module

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1818
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1919
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2020
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
21+
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
2122
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
2223
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2324
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

0 commit comments

Comments
 (0)