12 changes: 12 additions & 0 deletions HISTORY.md
@@ -1,5 +1,17 @@
# DynamicPPL Changelog

## 0.37.4

An extension for MarginalLogDensities.jl has been added.

Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalise out variables from a model.
This is useful for averaging out random effects or nuisance parameters while improving inference on fixed effects/parameters of interest.
The `marginalize` function returns a `MarginalLogDensities.MarginalLogDensity`, a function-like callable struct that returns the approximate log-density of a subset of the parameters after integrating out the rest of them.
By default, this uses the Laplace approximation and sparse AD, making the marginalisation computationally very efficient.
Note that the Laplace approximation relies on the model being differentiable with respect to the marginalised variables, and that their posteriors are unimodal and approximately Gaussian.

Please see [the MarginalLogDensities documentation](https://eloceanografo.github.io/MarginalLogDensities.jl/stable) and the [new Marginalisation section of the DynamicPPL documentation](https://turinglang.org/DynamicPPL.jl/v0.37/api/#Marginalisation) for further information.
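
As a rough sketch of the intended usage (the model and variable names here are purely illustrative):

```julia
using DynamicPPL, Distributions, MarginalLogDensities

@model function demo()
    x ~ Normal(1.0)  # nuisance parameter to marginalise out
    y ~ Normal(2.0)  # parameter of interest
end

# `marginalize` returns a callable MarginalLogDensities.MarginalLogDensity.
marginalized = marginalize(demo(), [@varname(x)])

# Calling it evaluates the approximate marginal log-density of `y`.
marginalized([1.0])
```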

## 0.37.3

Prevents inlining of `DynamicPPL.istrans` with Enzyme, which allows Enzyme to differentiate models where `VarName`s have the same symbol but different types.
5 changes: 4 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.37.3"
version = "0.37.4"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -33,6 +33,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
@@ -41,6 +42,7 @@ DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMooncakeExt = ["Mooncake"]

[compat]
@@ -66,6 +68,7 @@ LinearAlgebra = "1.6"
LogDensityProblems = "2"
MCMCChains = "6, 7"
MacroTools = "0.5.6"
MarginalLogDensities = "0.4.3"
Mooncake = "0.4.147"
OrderedCollections = "1"
Printf = "1.10"
1 change: 1 addition & 0 deletions docs/Project.toml
@@ -10,6 +10,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
10 changes: 9 additions & 1 deletion docs/make.jl
@@ -11,6 +11,10 @@ using Distributions
using DocumenterMermaid
# load MCMCChains package extension to make `predict` available
using MCMCChains
using MarginalLogDensities: MarginalLogDensities

# Need this to document a method which uses a type inside the extension...
DPPLMLDExt = Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt)

# Doctest setup
DocMeta.setdocmeta!(
@@ -24,7 +28,11 @@ makedocs(;
format=Documenter.HTML(;
size_threshold=2^10 * 400, mathengine=Documenter.HTMLWriter.MathJax3()
),
modules=[DynamicPPL, Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt)],
modules=[
DynamicPPL,
Base.get_extension(DynamicPPL, :DynamicPPLMCMCChainsExt),
Base.get_extension(DynamicPPL, :DynamicPPLMarginalLogDensitiesExt),
],
pages=[
"Home" => "index.md", "API" => "api.md", "Internals" => ["internals/varinfo.md"]
],
16 changes: 16 additions & 0 deletions docs/src/api.md
@@ -136,6 +136,22 @@ When using `predict` with `MCMCChains.Chains`, you can control which variables are included:
- `include_all=false` (default): Include only newly predicted variables
- `include_all=true`: Include both parameters from the original chain and predicted variables

## Marginalisation

DynamicPPL provides the `marginalize` function to marginalise out variables from a model.
This requires `MarginalLogDensities.jl` to be loaded in your environment.

```@docs
marginalize
```

A `MarginalLogDensity` object acts as a function which maps non-marginalised parameter values to a marginal log-probability.
To retrieve a VarInfo object from it, you can use:

```@docs
VarInfo(::MarginalLogDensities.MarginalLogDensity{<:DPPLMLDExt.LogDensityFunctionWrapper}, ::Union{AbstractVector,Nothing})
```
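
For orientation, here is a brief sketch of how `marginalize` and the `VarInfo` constructor fit together (adapted from the extension's doctests; the exact values are illustrative):

```julia
using DynamicPPL, Distributions, MarginalLogDensities

@model function demo()
    x ~ Normal()
    y ~ Beta(2, 2)
end

# Marginalise out `x`; by default this uses a linked VarInfo and a Laplace approximation.
mld = marginalize(demo(), [@varname(x)])

# Evaluate the marginal log-density of `y` at a linked (unconstrained) value.
mld([0.5])

# Recover a VarInfo: the marginalised `x` is set to its cached mode, and the
# supplied vector fills in the (linked) value of `y`.
vi = VarInfo(mld, [0.5])
vi[@varname(x)]
```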

## Models within models

One can include models and call another model inside the model function with `left ~ to_submodel(model)`.
204 changes: 204 additions & 0 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
@@ -0,0 +1,204 @@
module DynamicPPLMarginalLogDensitiesExt

using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
using MarginalLogDensities: MarginalLogDensities

# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type
# below.
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction}
    logdensity::L
end
function (lw::LogDensityFunctionWrapper)(x, _)
    return LogDensityProblems.logdensity(lw.logdensity, x)
end

"""
    marginalize(
        model::DynamicPPL.Model,
        marginalized_varnames::AbstractVector{<:VarName};
        varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model),
        getlogprob=DynamicPPL.getlogjoint,
        method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
        kwargs...,
    )

Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal
log-density of the given `model`, after marginalizing out the variables specified in
`marginalized_varnames`.

The resulting object can be called with a vector of parameter values to compute the marginal
log-density.

## Keyword arguments

- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`,
meaning that the resulting log-density function accepts parameters that have been
transformed to unconstrained space.

- `getlogprob`: A function which specifies which kind of marginal log-density to compute.
Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint
probability.

- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
for other options.

- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity`
constructor.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
           x ~ Normal(1.0)
           y ~ Normal(2.0)
       end
demo (generic function with 2 methods)

julia> marginalized = marginalize(demo(), [:x]);

julia> # The resulting callable computes the marginal log-density of `y`.
       marginalized([1.0])
-1.4189385332046727

julia> logpdf(Normal(2.0), 1.0)
-1.4189385332046727
```


!!! warning

    The default usage of linked VarInfo means that, for example, optimization of the
    marginal log-density can be performed in unconstrained space. However, care must be
    taken if the model contains variables where the link transformation depends on a
    marginalized variable. For example:

    ```julia
    @model function f()
        x ~ Normal()
        y ~ truncated(Normal(); lower=x)
    end
    ```

    Here, the support of `y`, and hence the link transformation used, depends on the value
    of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of
    `y` to log-probabilities. However, it will not be possible to use DynamicPPL to
    correctly retrieve _unlinked_ values of `y`.
"""
function DynamicPPL.marginalize(
    model::DynamicPPL.Model,
    marginalized_varnames::AbstractVector{<:VarName};
    varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model),
    getlogprob::Function=DynamicPPL.getlogjoint,
    method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
    kwargs...,
)
    # Determine the indices for the variables to marginalise out.
    varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames))
    # Construct the marginal log-density model.
    f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
    mld = MarginalLogDensities.MarginalLogDensity(
        LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
    )
    return mld
end

"""
    VarInfo(
        mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
        unmarginalized_params::Union{AbstractVector,Nothing}=nothing
    )

Retrieve the `VarInfo` object used in the marginalisation process.

If a Laplace approximation was used for the marginalisation, the values of the marginalized
parameters are also set to their mode (note that this only happens if the `mld` object has
been used to compute the marginal log-density at least once, so that the mode has been
computed).

If a vector of `unmarginalized_params` is specified, the values for the corresponding
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
performing an optimization of the marginal log-density.

All other aspects of the VarInfo, such as link status, are preserved from the original
VarInfo used in the marginalisation.

!!! note

    The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
    updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the
    model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model,
    vi))`).

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
           x ~ Normal()
           y ~ Beta(2, 2)
       end
demo (generic function with 2 methods)

julia> # Note that by default `marginalize` uses a linked VarInfo.
       mld = marginalize(demo(), [@varname(x)]);

julia> using MarginalLogDensities: Optimization, OptimizationOptimJL

julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
       y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
OptimizationProblem. In-place: true
u0: 1-element Vector{Float64}:
 2.0

julia> # This tells us the optimal (linked) value of `y` is around 0.
       opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
retcode: Success
u: 1-element Vector{Float64}:
 4.88281250001733e-5

julia> # Get the VarInfo corresponding to the mode of `y`.
       vi = VarInfo(mld, opt_solution.u);

julia> # `x` is set to its mode (which for `Normal()` is zero).
       vi[@varname(x)]
0.0

julia> # `y` is set to the optimal value we found above.
       DynamicPPL.getindex_internal(vi, @varname(y))
1-element Vector{Float64}:
 4.88281250001733e-5

julia> # To obtain values in the original constrained space, we can either
       # use `getindex`:
       vi[@varname(y)]
0.5000122070312476

julia> # Or invlink the entire VarInfo object using the model:
       vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
2-element Vector{Float64}:
 0.0
 0.5000122070312476
```
"""
function DynamicPPL.VarInfo(
    mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
    unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
)
    # Extract the original VarInfo. Its contents will in general be junk.
    original_vi = mld.logdensity.logdensity.varinfo
    # Extract the stored parameters, which includes the modes for any marginalized
    # parameters
    full_params = MarginalLogDensities.cached_params(mld)
    # We can then (if needed) set the values for any non-marginalized parameters
    if unmarginalized_params !== nothing
        full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
    end
    return DynamicPPL.unflatten(original_vi, full_params)
end

end
23 changes: 22 additions & 1 deletion src/DynamicPPL.jl
@@ -122,6 +122,7 @@ export AbstractVarInfo,
fix,
unfix,
predict,
marginalize,
Member: given that we are exporting a new function, I want to tag @yebai to take a look

my assessment is that the risk is low and the benefit is high

Member: It is useful, but given that it is implemented as a pkg extension, we should print a message if relevant packages are not correctly loaded.

Member Author: Just added an error hint so now it does this:

julia> using DynamicPPL, Distributions; @model f() = x ~ Normal(); marginalize(f(), [@varname(x)])
ERROR: MethodError: no method matching marginalize(::Model{typeof(f), (), (), (), Tuple{}, Tuple{}, DefaultContext}, ::Vector{VarName{:x, typeof(identity)}})
The function `marginalize` exists, but no method is defined for this combination of argument types.

    `marginalize` requires MarginalLogDensities.jl to be loaded.
    Please run `using MarginalLogDensities` before calling `marginalize`.

Stacktrace:
 [1] top-level scope
   @ REPL[1]:1

Member: Thanks @penelopeysm.

One more thing: we tend to follow British spelling in Turing.jl. Can we alias marginalise to marginalize and export both? This would keep everyone happy.

Member Author: I think it's better to choose one and stick with it. The standard practice for programming is to use AmE spellings and that's generally true even for packages that are written in the UK. IMO it's best to stick to AmE spellings to reduce the amount of mental overhead for users (most people will expect AmE spellings).

So far, we have actually stuck to -ize, e.g. DynamicPPL.contextualize, AbstractPPL.concretize, AdvancedHMC.initialize!, ... so it would be both internally and externally consistent to use marginalize. Turing has src/optimisation/Optimisation.jl but doesn't export anything called 'optimise', only MAP and MLE.

Comments are a different matter of course :) -- but IMO identifiers and argument names should always use -ize

prefix,
returned,
to_submodel,
@@ -199,9 +200,9 @@ include("test_utils.jl")
include("experimental.jl")
include("deprecated.jl")

# Better error message if users forget to load JET
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
# Better error message if users forget to load JET.jl
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
requires_jet =
exc.f === DynamicPPL.Experimental._determine_varinfo_jet &&
@@ -222,6 +223,23 @@ end
end
end

# Same for MarginalLogDensities.jl
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
requires_mld =
exc.f === DynamicPPL.marginalize &&
length(argtypes) == 2 &&
argtypes[1] <: Model &&
argtypes[2] <: AbstractVector{<:Union{Symbol,<:VarName}}
if requires_mld
printstyled(
io,
"\n\n `$(exc.f)` requires MarginalLogDensities.jl to be loaded.\n Please run `using MarginalLogDensities` before calling `$(exc.f)`.\n";
color=:cyan,
bold=true,
)
end
end

Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _
is_evaluate_three_arg =
exc.f === AbstractPPL.evaluate!! &&
@@ -243,4 +261,7 @@ end
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct DynamicPPLTag end

# Extended in MarginalLogDensitiesExt
function marginalize end

end # module
1 change: 1 addition & 0 deletions test/Project.toml
@@ -18,6 +18,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"