-
Notifications
You must be signed in to change notification settings - Fork 36
MarginalLogDensities extension #1036
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
4b820f5
8fab001
105013a
1b3e76b
aaca138
14f9f46
00c08b2
e9eabb4
844ec4c
f4049c1
144bee7
29840be
e590e02
06599cf
015534b
206d661
d8fd73c
d5f2063
2b1d5e3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
module DynamicPPLMarginalLogDensitiesExt | ||
|
||
using DynamicPPL: DynamicPPL, LogDensityProblems, VarName | ||
using MarginalLogDensities: MarginalLogDensities | ||
|
||
_to_varname(n::Symbol) = VarName{n}() | ||
_to_varname(n::VarName) = n | ||
|
||
""" | ||
marginalize( | ||
model::DynamicPPL.Model, | ||
varnames::AbstractVector{<:Union{Symbol,<:VarName}}, | ||
getlogprob=DynamicPPL.getlogjoint, | ||
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); | ||
kwargs..., | ||
) | ||
Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal | ||
log-density of the given `model`, after marginalizing out the variables specified in | ||
`varnames`. | ||
The resulting object can be called with a vector of parameter values to compute the marginal | ||
log-density. | ||
The `getlogprob` argument can be used to specify which kind of marginal log-density to | ||
compute. Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint | ||
probability. | ||
By default the marginalization is performed with a Laplace approximation. Please see [the | ||
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/) | ||
for other options. | ||
## Example | ||
```jldoctest | ||
julia> using DynamicPPL, Distributions, MarginalLogDensities | ||
julia> @model function demo() | ||
x ~ Normal(1.0) | ||
y ~ Normal(2.0) | ||
end | ||
demo (generic function with 2 methods) | ||
julia> marginalized = marginalize(demo(), [:x]); | ||
julia> # The resulting callable computes the marginal log-density of `y`. | ||
marginalized([1.0]) | ||
-1.4189385332046727 | ||
julia> logpdf(Normal(2.0), 1.0) | ||
-1.4189385332046727 | ||
``` | ||
""" | ||
function DynamicPPL.marginalize( | ||
model::DynamicPPL.Model, | ||
varnames::AbstractVector{<:Union{Symbol,<:VarName}}, | ||
getlogprob=DynamicPPL.getlogjoint, | ||
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(); | ||
kwargs..., | ||
) | ||
# Determine the indices for the variables to marginalise out. | ||
varinfo = DynamicPPL.typed_varinfo(model) | ||
vns = map(_to_varname, varnames) | ||
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns)) | ||
penelopeysm marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
# Construct the marginal log-density model. | ||
# Use linked `varinfo` to that we're working in unconstrained space | ||
varinfo_linked = DynamicPPL.link(varinfo, model) | ||
|
||
|
||
f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked) | ||
mdl = MarginalLogDensities.MarginalLogDensity( | ||
(x, _) -> LogDensityProblems.logdensity(f, x), | ||
varinfo_linked[:], | ||
varindices, | ||
(), | ||
method; | ||
kwargs..., | ||
) | ||
return mdl | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -122,6 +122,7 @@ export AbstractVarInfo, | |
fix, | ||
unfix, | ||
predict, | ||
marginalize, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. given that we are exporting new function, I want to tag @yebai to take a look my assessment is that the risk is low, and benefit is high There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is useful, but given that it is implemented as a pkg extension, we should print a message if relevant packages are not correctly loaded. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just added an error hint so now it does this: julia> using DynamicPPL, Distributions; @model f() = x ~ Normal(); marginalize(f(), [@varname(x)])
ERROR: MethodError: no method matching marginalize(::Model{typeof(f), (), (), (), Tuple{}, Tuple{}, DefaultContext}, ::Vector{VarName{:x, typeof(identity)}})
The function `marginalize` exists, but no method is defined for this combination of argument types.
`marginalize` requires MarginalLogDensities.jl to be loaded.
Please run `using MarginalLogDensities` before calling `marginalize`.
Stacktrace:
[1] top-level scope
@ REPL[1]:1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks @penelopeysm. One more thing: we tend to follow British spelling in Turing.jl. Can we alias There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's better to choose one and stick with it. The standard practice for programming is to use AmE spellings and that's generally true even for packages that are written in the UK. IMO it's best to stick to AmE spellings to reduce the amount of mental overhead for users (most people will expect AmE spellings). So far, we have actually stuck to -ize, e.g. Comments are a different matter of course :) -- but IMO identifiers and argument names should always use -ize |
||
prefix, | ||
returned, | ||
to_submodel, | ||
|
@@ -199,10 +200,6 @@ include("test_utils.jl") | |
include("experimental.jl") | ||
include("deprecated.jl") | ||
|
||
if !isdefined(Base, :get_extension) | ||
using Requires | ||
end | ||
|
||
# Better error message if users forget to load JET | ||
if isdefined(Base.Experimental, :register_error_hint) | ||
function __init__() | ||
|
@@ -247,4 +244,7 @@ end | |
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/ | ||
struct DynamicPPLTag end | ||
|
||
# Extended in MarginalLogDensitiesExt | ||
function marginalize end | ||
|
||
end # module |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
module MarginalLogDensitiesExtTests | ||
|
||
using DynamicPPL, Distributions, Test | ||
using MarginalLogDensities | ||
using ADTypes: AutoForwardDiff | ||
|
||
@testset "MarginalLogDensities" begin | ||
# Simple test case. | ||
@model function demo() | ||
x ~ MvNormal(zeros(2), [1, 1]) | ||
return y ~ Normal(0, 1) | ||
end | ||
model = demo() | ||
# Marginalize out `x`. | ||
|
||
for vn in [@varname(x), :x] | ||
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint] | ||
marginalized = marginalize( | ||
model, [vn], getlogprob; hess_adtype=AutoForwardDiff() | ||
) | ||
# Compute the marginal log-density of `y = 0.0`. | ||
@test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol = 1e-5 | ||
end | ||
end | ||
end | ||
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.