-
Notifications
You must be signed in to change notification settings - Fork 36
MarginalLogDensities extension #1036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
4b820f5
8fab001
105013a
1b3e76b
aaca138
14f9f46
00c08b2
e9eabb4
844ec4c
f4049c1
144bee7
29840be
e590e02
06599cf
015534b
206d661
d8fd73c
d5f2063
2b1d5e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
module DynamicPPLMarginalLogDensitiesExt | ||
|
||
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName | ||
using MarginalLogDensities: MarginalLogDensities | ||
|
||
# A thin wrapper to adapt a DynamicPPL.LogDensityFunction to the interface expected by | ||
# MarginalLogDensities. It's helpful to have a struct so that we can dispatch on its type | ||
# below. | ||
struct LogDensityFunctionWrapper{L<:DynamicPPL.LogDensityFunction} | ||
logdensity::L | ||
end | ||
function (lw::LogDensityFunctionWrapper)(x, _) | ||
return LogDensityProblems.logdensity(lw.logdensity, x) | ||
end | ||
|
||
""" | ||
marginalize( | ||
model::DynamicPPL.Model, | ||
marginalized_varnames::AbstractVector{<:VarName}; | ||
varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model), | ||
getlogprob=DynamicPPL.getlogjoint, | ||
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); | ||
kwargs..., | ||
) | ||
|
||
Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal | ||
log-density of the given `model`, after marginalizing out the variables specified in | ||
`varnames`. | ||
|
||
The resulting object can be called with a vector of parameter values to compute the marginal | ||
log-density. | ||
|
||
## Keyword arguments | ||
|
||
- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`, | ||
meaning that the resulting log-density function accepts parameters that have been | ||
transformed to unconstrained space. | ||
|
||
- `getlogprob`: A function which specifies which kind of marginal log-density to compute. | ||
Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint | ||
probability. | ||
|
||
- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the | ||
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) | ||
for other options. | ||
|
||
- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity` | ||
constructor. | ||
|
||
## Example | ||
|
||
```jldoctest | ||
julia> using DynamicPPL, Distributions, MarginalLogDensities | ||
|
||
julia> @model function demo() | ||
x ~ Normal(1.0) | ||
y ~ Normal(2.0) | ||
end | ||
demo (generic function with 2 methods) | ||
|
||
julia> marginalized = marginalize(demo(), [:x]); | ||
|
||
julia> # The resulting callable computes the marginal log-density of `y`. | ||
marginalized([1.0]) | ||
-1.4189385332046727 | ||
|
||
julia> logpdf(Normal(2.0), 1.0) | ||
-1.4189385332046727 | ||
``` | ||
|
||
|
||
!!! warning | ||
|
||
The default usage of linked VarInfo means that, for example, optimization of the | ||
marginal log-density can be performed in unconstrained space. However, care must be | ||
taken if the model contains variables where the link transformation depends on a | ||
marginalized variable. For example: | ||
|
||
```julia | ||
@model function f() | ||
x ~ Normal() | ||
y ~ truncated(Normal(); lower=x) | ||
end | ||
``` | ||
|
||
Here, the support of `y`, and hence the link transformation used, depends on the value | ||
of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of | ||
`y` to log-probabilities. However, it will not be possible to use DynamicPPL to | ||
correctly retrieve _unlinked_ values of `y`. | ||
""" | ||
function DynamicPPL.marginalize( | ||
model::DynamicPPL.Model, | ||
marginalized_varnames::AbstractVector{<:VarName}; | ||
varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model), | ||
getlogprob::Function=DynamicPPL.getlogjoint, | ||
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(), | ||
kwargs..., | ||
) | ||
# Determine the indices for the variables to marginalise out. | ||
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames)) | ||
# Construct the marginal log-density model. | ||
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo) | ||
mld = MarginalLogDensities.MarginalLogDensity( | ||
LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs... | ||
) | ||
return mld | ||
end | ||
|
||
""" | ||
VarInfo( | ||
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, | ||
unmarginalized_params::Union{AbstractVector,Nothing}=nothing | ||
) | ||
|
||
Retrieve the `VarInfo` object used in the marginalisation process. | ||
|
||
If a Laplace approximation was used for the marginalisation, the values of the marginalized | ||
parameters are also set to their mode (note that this only happens if the `mld` object has | ||
been used to compute the marginal log-density at least once, so that the mode has been | ||
computed). | ||
|
||
If a vector of `unmarginalized_params` is specified, the values for the corresponding | ||
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by | ||
performing an optimization of the marginal log-density. | ||
|
||
All other aspects of the VarInfo, such as link status, are preserved from the original | ||
VarInfo used in the marginalisation. | ||
|
||
!!! note | ||
|
||
The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be | ||
updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the | ||
model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model, | ||
vi))`). | ||
|
||
## Example | ||
|
||
```jldoctest | ||
julia> using DynamicPPL, Distributions, MarginalLogDensities | ||
|
||
julia> @model function demo() | ||
x ~ Normal() | ||
y ~ Beta(2, 2) | ||
end | ||
demo (generic function with 2 methods) | ||
|
||
julia> # Note that by default `marginalize` uses a linked VarInfo. | ||
mld = marginalize(demo(), [@varname(x)]); | ||
|
||
julia> using MarginalLogDensities: Optimization, OptimizationOptimJL | ||
|
||
julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`. | ||
y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0]) | ||
OptimizationProblem. In-place: true | ||
u0: 1-element Vector{Float64}: | ||
2.0 | ||
|
||
julia> # This tells us the optimal (linked) value of `y` is around 0. | ||
opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead()) | ||
retcode: Success | ||
u: 1-element Vector{Float64}: | ||
4.88281250001733e-5 | ||
|
||
julia> # Get the VarInfo corresponding to the mode of `y`. | ||
vi = VarInfo(mld, opt_solution.u); | ||
|
||
julia> # `x` is set to its mode (which for `Normal()` is zero). | ||
vi[@varname(x)] | ||
0.0 | ||
|
||
julia> # `y` is set to the optimal value we found above. | ||
DynamicPPL.getindex_internal(vi, @varname(y)) | ||
1-element Vector{Float64}: | ||
4.88281250001733e-5 | ||
|
||
julia> # To obtain values in the original constrained space, we can either | ||
# use `getindex`: | ||
vi[@varname(y)] | ||
0.5000122070312476 | ||
|
||
julia> # Or invlink the entire VarInfo object using the model: | ||
vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:] | ||
2-element Vector{Float64}: | ||
0.0 | ||
0.5000122070312476 | ||
``` | ||
""" | ||
function DynamicPPL.VarInfo( | ||
mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper}, | ||
unmarginalized_params::Union{AbstractVector,Nothing}=nothing, | ||
) | ||
# Extract the original VarInfo. Its contents will in general be junk. | ||
original_vi = mld.logdensity.logdensity.varinfo | ||
# Extract the stored parameters, which includes the modes for any marginalized | ||
# parameters | ||
full_params = MarginalLogDensities.cached_params(mld) | ||
# We can then (if needed) set the values for any non-marginalized parameters | ||
if unmarginalized_params !== nothing | ||
full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params | ||
end | ||
return DynamicPPL.unflatten(original_vi, full_params) | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,6 +122,7 @@ export AbstractVarInfo, | |
fix, | ||
unfix, | ||
predict, | ||
marginalize, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. given that we are exporting new function, I want to tag @yebai to take a look my assessment is that the risk is low, and benefit is high There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is useful, but given that it is implemented as a pkg extension, we should print a message if relevant packages are not correctly loaded. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just added an error hint so now it does this: julia> using DynamicPPL, Distributions; @model f() = x ~ Normal(); marginalize(f(), [@varname(x)])
ERROR: MethodError: no method matching marginalize(::Model{typeof(f), (), (), (), Tuple{}, Tuple{}, DefaultContext}, ::Vector{VarName{:x, typeof(identity)}})
The function `marginalize` exists, but no method is defined for this combination of argument types.
`marginalize` requires MarginalLogDensities.jl to be loaded.
Please run `using MarginalLogDensities` before calling `marginalize`.
Stacktrace:
[1] top-level scope
@ REPL[1]:1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @penelopeysm. One more thing: we tend to follow British spelling in Turing.jl. Can we alias There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's better to choose one and stick with it. The standard practice for programming is to use AmE spellings and that's generally true even for packages that are written in the UK. IMO it's best to stick to AmE spellings to reduce the amount of mental overhead for users (most people will expect AmE spellings). So far, we have actually stuck to -ize, e.g. Comments are a different matter of course :) -- but IMO identifiers and argument names should always use -ize |
||
prefix, | ||
returned, | ||
to_submodel, | ||
|
@@ -199,9 +200,9 @@ include("test_utils.jl") | |
include("experimental.jl") | ||
include("deprecated.jl") | ||
|
||
# Better error message if users forget to load JET | ||
if isdefined(Base.Experimental, :register_error_hint) | ||
function __init__() | ||
# Better error message if users forget to load JET.jl | ||
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ | ||
requires_jet = | ||
exc.f === DynamicPPL.Experimental._determine_varinfo_jet && | ||
|
@@ -222,6 +223,23 @@ if isdefined(Base.Experimental, :register_error_hint) | |
end | ||
end | ||
|
||
# Same for MarginalLogDensities.jl | ||
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ | ||
requires_mld = | ||
exc.f === DynamicPPL.marginalize && | ||
length(argtypes) == 2 && | ||
argtypes[1] <: Model && | ||
argtypes[2] <: AbstractVector{<:Union{Symbol,<:VarName}} | ||
if requires_mld | ||
printstyled( | ||
io, | ||
"\n\n `$(exc.f)` requires MarginalLogDensities.jl to be loaded.\n Please run `using MarginalLogDensities` before calling `$(exc.f)`.\n"; | ||
color=:cyan, | ||
bold=true, | ||
) | ||
end | ||
end | ||
|
||
Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, _ | ||
is_evaluate_three_arg = | ||
exc.f === AbstractPPL.evaluate!! && | ||
|
@@ -243,4 +261,7 @@ end | |
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ | ||
struct DynamicPPLTag end | ||
|
||
# Extended in MarginalLogDensitiesExt | ||
function marginalize end | ||
|
||
end # module |
Uh oh!
There was an error while loading. Please reload this page.