5 changes: 5 additions & 0 deletions HISTORY.md
@@ -1,5 +1,10 @@
# DynamicPPL Changelog

## 0.37.3

An extension for MarginalLogDensities.jl has been added.
Loading DynamicPPL and MarginalLogDensities now provides the `DynamicPPL.marginalize` function to marginalize out variables from a model; please see the documentation for further information.

## 0.37.2

Make the `resume_from` keyword work for multiple-chain (parallel) sampling as well.
4 changes: 3 additions & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.37.2"
version = "0.37.3"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -34,13 +34,15 @@
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[extensions]
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
DynamicPPLForwardDiffExt = ["ForwardDiff"]
DynamicPPLJETExt = ["JET"]
DynamicPPLMarginalLogDensitiesExt = ["MarginalLogDensities"]
DynamicPPLMCMCChainsExt = ["MCMCChains"]
DynamicPPLMooncakeExt = ["Mooncake"]

1 change: 1 addition & 0 deletions docs/Project.toml
@@ -10,6 +10,7 @@
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"

[compat]
1 change: 1 addition & 0 deletions docs/make.jl
@@ -11,6 +11,7 @@
using Distributions
using DocumenterMermaid
# load MCMCChains package extension to make `predict` available
using MCMCChains
using MarginalLogDensities: MarginalLogDensities

# Doctest setup
DocMeta.setdocmeta!(
9 changes: 9 additions & 0 deletions docs/src/api.md
@@ -123,6 +123,15 @@
The `predict` function has two main methods:
predict
```

## Marginalization

DynamicPPL provides the `marginalize` function to marginalize out variables from a model.
This requires `MarginalLogDensities.jl` to be loaded in your environment.

```@docs
marginalize
```
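
A minimal usage sketch (mirroring the docstring example added in this PR; `demo` is a stand-in model):

```julia
using DynamicPPL, Distributions, MarginalLogDensities

@model function demo()
    x ~ Normal(1.0)
    y ~ Normal(2.0)
end

marginalized = marginalize(demo(), [:x])  # marginalize out `x`
marginalized([1.0])                       # marginal log-density at y = 1.0
```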

### Basic Usage

The typical workflow for posterior prediction involves:
81 changes: 81 additions & 0 deletions ext/DynamicPPLMarginalLogDensitiesExt.jl
@@ -0,0 +1,81 @@
module DynamicPPLMarginalLogDensitiesExt

using DynamicPPL: DynamicPPL, LogDensityProblems, VarName
using MarginalLogDensities: MarginalLogDensities

_to_varname(n::Symbol) = VarName{n}()
_to_varname(n::VarName) = n

"""
marginalize(
model::DynamicPPL.Model,
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
getlogprob=DynamicPPL.getlogjoint,
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
kwargs...,
)
Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal
log-density of the given `model`, after marginalizing out the variables specified in
`varnames`.

The resulting object can be called with a vector of parameter values to compute the marginal
log-density.

The `getlogprob` argument can be used to specify which kind of marginal log-density to
compute. Its default value is `DynamicPPL.getlogjoint`, which returns the marginal log-joint
probability.

By default the marginalization is performed with a Laplace approximation. Please see [the
MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
for other options.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
           x ~ Normal(1.0)
           y ~ Normal(2.0)
       end
demo (generic function with 2 methods)

julia> marginalized = marginalize(demo(), [:x]);

julia> # The resulting callable computes the marginal log-density of `y`.
       marginalized([1.0])
-1.4189385332046727

julia> logpdf(Normal(2.0), 1.0)
-1.4189385332046727
```
"""
function DynamicPPL.marginalize(
model::DynamicPPL.Model,
varnames::AbstractVector{<:Union{Symbol,<:VarName}},
getlogprob=DynamicPPL.getlogjoint,
method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox();
kwargs...,
)
# Determine the indices for the variables to marginalise out.
varinfo = DynamicPPL.typed_varinfo(model)
vns = map(_to_varname, varnames)
varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, vns))
# Construct the marginal log-density model.
# Use linked `varinfo` so that we're working in unconstrained space
varinfo_linked = DynamicPPL.link(varinfo, model)
penelopeysm (Member, Author), Sep 5, 2025:

Does MLD require a linked VarInfo, or would an unlinked VarInfo be fine?

The reason I'm thinking about this is because if the VarInfo is linked, then all the parameters that are later supplied must be in linked space, which is potentially a bit confusing (though nothing that can't be fixed by documentation). Example:

julia> using DynamicPPL, Distributions, Bijectors, MarginalLogDensities

julia> @model function f()
           x ~ Normal()
           y ~ Beta(2, 2)
       end
f (generic function with 2 methods)

julia> m = marginalize(f(), [@varname(x)]);

julia> m([0.5]) # this 0.5 is in linked space
0.3436055008678415

julia> logpdf(Beta(2, 2), 0.5) # this 0.5 is unlinked, so logp is wrong
0.4054651081081644

julia> inverse(Bijectors.bijector(Beta(2, 2)))(0.5) # this is the unlinked value corresponding to 0.5
0.6224593312018546

julia> logpdf(Beta(2, 2), 0.6224593312018546) # now logp matches
0.3436055008678416

If an unlinked VarInfo is acceptable, then the choice of varinfo should probably also be added as an argument to marginalize.

Contributor:

I had just copied this from @torfjelde's original example without understanding the linked/unlinked VarInfo distinction. I agree that unlinked would make more sense in this context.

penelopeysm (Member, Author):

Cool, will change

penelopeysm (Member, Author):

While writing a test I ran into this:

julia> using MarginalLogDensities, Distributions

julia> f(x, d) = logpdf(d.dist, x[1])
f (generic function with 1 method)

julia> MarginalLogDensity(f, [0.5], [1], (;dist=Normal()))(Float64[])
-1.1102230246251565e-16

julia> MarginalLogDensity(f, [0.5], [1], (;dist=Beta(2, 2)))(Float64[])
0.2846828704729192

My impression is that it should be zero. I added Cubature() and it goes down to -3e-10 (yay!), but for the equivalent linked varinfo case, even using Cubature() doesn't help: it still returns a constant additive term of around 1.7. Am I correct in saying that this is generally expected?

Contributor:

I think this is just because the Laplace approximation is fitting a beta distribution with a normal, based on the curvature at the peak of the beta, and it's not a good approximation. By default, Cubature also does an LA to select its upper and lower integration limits, based on 6x the standard deviation of the approximating normal, which may or may not be good enough:

# default 6 sigma
julia> MarginalLogDensity(f, [0.5], [1], (;dist=TDist(5)), Cubature())(Float64[])
-0.0018478445031467727

# manually setting wider integration limits
julia> MarginalLogDensity(f, [0.5], [1], (;dist=TDist(5)), Cubature(upper=[20], lower=[-20]))(Float64[])
-5.775533047958272e-6

Thinking a bit more about the linked/unlinked question, I'm not actually sure unlinked is right. The MarginalLogDensity is almost always going to be used to optimize or sample from the fixed-effect parameters, meaning they need to be linked.

Interpreting the optimum/chain is then potentially confusing, and this confusion is probably more likely if you used DynamicPPL/Turing to write your model, since you don't have to handle the parameter transforms yourself, as you would if you wrote the model function from scratch. But I think that's better than having an MLD object that produces optimizer/sampler errors.

For now, I think we can just document this behavior, along with a clear example of how to un-link the variables.
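
For reference, a hedged sketch of what such un-linking could look like using DynamicPPL's own tools (`mld`, `varinfo_linked`, and `model` as in the extension code; this is illustrative, not part of the PR):

```julia
# Rebuild a VarInfo from the full (linked) parameter vector cached in the
# MarginalLogDensity object, then transform back to constrained space.
vi = DynamicPPL.unflatten(varinfo_linked, mld.u)
vi_constrained = DynamicPPL.invlink(vi, model)
vi_constrained[@varname(y)]  # constrained value of `y`
```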

penelopeysm (Member, Author), Sep 16, 2025:

(For example, if we marginalise x out of the model above, do you know what it would mean to 'unlink' the value of y? I don't think I've fully wrapped my head around it yet 😅. My current suspicion is that we'd need to access the original value of x that's stored inside the MarginalLogDensity object, but I am pretty confused. I might have to read the MLD source code more to see what's going on.)

Contributor:

I see, the problem is when the link function for one variable depends on another one.

The MLD object stores the full parameter vector, and updates the marginalized subset to their optimal values every time it does a Laplace approximation. So if you need the full set of parameters you can get them via mld.u, see here. So you could get a link function for y, though it wouldn't necessarily be the right one, since it would be for a point estimate of x rather than integrating over its distribution.

That should be doable taking advantage of the cached Laplace approx info...i.e., if the unlinked $y$ is related to the linked $y_L$ as $y = \mathrm{invlink}(y_L, x)$, then

$$ E[y] = \int \mathrm{invlink}(y_L, x) p(x) dx $$

where we can define $p(x)$ as an MVNormal based on the mode and Hessian stored in the MLD object.
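
A hedged sketch of that estimator (`x_mode`, `H`, `y_linked`, and `invlink_y` are illustrative stand-ins for the cached mode, Hessian, linked value, and link function; they are not MLD API):

```julia
using Distributions, LinearAlgebra, Statistics

# Laplace approximation of p(x) from the cached mode and Hessian.
p_x = MvNormal(x_mode, Symmetric(inv(H)))

# Monte Carlo estimate of E[y] = ∫ invlink(y_L, x) p(x) dx.
E_y = mean(invlink_y(y_linked, rand(p_x)) for _ in 1:10_000)
```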

Also worth noting that calling maximum_a_posteriori(f()) with your example produces a different mode than MCMC:

julia> m = f()

julia> mean(sample(m, NUTS(), 1000))
┌ Info: Found initial step size
└   ϵ = 0.4
Sampling 100%|██████████████████████████████████████████| Time: 0:00:00
Mean
  parameters      mean 
      Symbol   Float64 

           x   -0.0507
           y    0.9262


julia> maximum_a_posteriori(m)
ModeResult with maximized lp of -0.90
[0.6120031806556528, 0.6120031809603704]

so maybe for now we can file this under "problematic model structure for deterministic optimization," add a warning box to the docs, and call it good enough...

penelopeysm (Member, Author):

Thanks! Yup, I'd be happy to leave that as a "to be figured out". Then, that just leaves the question of how to invlink the values for a non-pathological model, which I think should be easy enough with the u field. Let me see if I can put together an example and a test, then I'd be quite happy to merge this in.

penelopeysm (Member, Author), Sep 17, 2025:

Working through it analytically (not 100% sure if I got everything right, but some plots back me up qualitatively), it's probably a bit of a nonsensical example because we have that

$$
\begin{align*}
p(x) &= \mathrm{npdf}(x) \\
p(y \mid x) &= \begin{cases} \dfrac{\mathrm{npdf}(y)}{1 - \mathrm{ncdf}(x)} & x \leq y \\ 0 & \text{otherwise} \end{cases} \\
\implies \quad p(y) &= \int_{-\infty}^{\infty} p(y \mid x)\, p(x)\, \mathrm{d}x \\
&= \int_{-\infty}^{y} \mathrm{npdf}(y) \left( \frac{\mathrm{npdf}(x)}{1 - \mathrm{ncdf}(x)} \right) \mathrm{d}x
\end{align*}
$$

where npdf and ncdf are the pdf and cdf for the unit normal distribution. But I'm pretty sure that $p(y)$ can be made arbitrarily large by increasing $y$, because as $x \to +\infty$, $ncdf(x) \to 1$ and so the denominator in the integrand vanishes.

I guess the obvious question is whether there is some example that doesn't suffer from this, but maybe that's for another time.

penelopeysm (Member, Author):

Added a warning to the docstring, so hopefully that'll be fine for now!


f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo_linked)
mdl = MarginalLogDensities.MarginalLogDensity(
(x, _) -> LogDensityProblems.logdensity(f, x),
varinfo_linked[:],
varindices,
(),
method;
kwargs...,
)
return mdl
end

end
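
As a side note, a hedged sketch of exercising the optional positional arguments documented above (this mirrors the test file later in the diff; it assumes `LaplaceApprox` is exported by MarginalLogDensities):

```julia
using DynamicPPL, Distributions, MarginalLogDensities
using ADTypes: AutoForwardDiff

@model function demo()
    x ~ Normal(1.0)
    y ~ Normal(2.0)
end

# Marginal log-prior instead of the default log-joint, with an explicit
# marginalizer; `hess_adtype` is forwarded to MarginalLogDensities.
m = marginalize(demo(), [:x], DynamicPPL.getlogprior, LaplaceApprox();
                hess_adtype=AutoForwardDiff())
m([1.0])
```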
8 changes: 4 additions & 4 deletions src/DynamicPPL.jl
@@ -122,6 +122,7 @@
export AbstractVarInfo,
fix,
unfix,
predict,
marginalize,
Member:

Given that we are exporting a new function, I want to tag @yebai to take a look.

My assessment is that the risk is low and the benefit is high.

Member:

It is useful, but given that it is implemented as a pkg extension, we should print a message if relevant packages are not correctly loaded.

penelopeysm (Member, Author):

Just added an error hint so now it does this:

julia> using DynamicPPL, Distributions; @model f() = x ~ Normal(); marginalize(f(), [@varname(x)])
ERROR: MethodError: no method matching marginalize(::Model{typeof(f), (), (), (), Tuple{}, Tuple{}, DefaultContext}, ::Vector{VarName{:x, typeof(identity)}})
The function `marginalize` exists, but no method is defined for this combination of argument types.

    `marginalize` requires MarginalLogDensities.jl to be loaded.
    Please run `using MarginalLogDensities` before calling `marginalize`.

Stacktrace:
 [1] top-level scope
   @ REPL[1]:1
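
For context, a minimal sketch of how such a hint can be wired up via `Base.Experimental.register_error_hint` (illustrative; the actual hook lives in DynamicPPL's `__init__`, visible further down this diff):

```julia
if isdefined(Base.Experimental, :register_error_hint)
    Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs
        if exc.f === DynamicPPL.marginalize
            print(io, "\n\n    `marginalize` requires MarginalLogDensities.jl to be loaded.")
            print(io, "\n    Please run `using MarginalLogDensities` before calling `marginalize`.")
        end
    end
end
```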

Member:

Thanks @penelopeysm.

One more thing: we tend to follow British spelling in Turing.jl. Can we alias marginalise to marginalize and export both? This would keep everyone happy.

penelopeysm (Member, Author):

I think it's better to choose one and stick with it. The standard practice in programming is to use AmE spellings, even for packages written in the UK, and IMO sticking to AmE reduces mental overhead since that's what most users will expect.

So far, we have actually stuck to -ize, e.g. DynamicPPL.contextualize, AbstractPPL.concretize, AdvancedHMC.initialize!, ... so it would be both internally and externally consistent to use marginalize. Turing has src/optimisation/Optimisation.jl but doesn't export anything called 'optimise', only MAP and MLE.

Comments are a different matter of course :) -- but IMO identifiers and argument names should always use -ize

prefix,
returned,
to_submodel,
@@ -199,10 +200,6 @@
include("test_utils.jl")
include("experimental.jl")
include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end

# Better error message if users forget to load JET
if isdefined(Base.Experimental, :register_error_hint)
function __init__()
@@ -247,4 +244,7 @@
end
# Ref: https://www.stochasticlifestyle.com/improved-forwarddiff-jl-stacktraces-with-package-tags/
struct DynamicPPLTag end

# Extended in MarginalLogDensitiesExt
function marginalize end

end # module
1 change: 1 addition & 0 deletions test/Project.toml
@@ -18,6 +18,7 @@
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MarginalLogDensities = "f0c3360a-fb8d-11e9-1194-5521fd7ee392"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
27 changes: 27 additions & 0 deletions test/ext/DynamicPPLMarginalLogDensitiesExt.jl
@@ -0,0 +1,27 @@
module MarginalLogDensitiesExtTests

using DynamicPPL, Distributions, Test
using MarginalLogDensities
using ADTypes: AutoForwardDiff

@testset "MarginalLogDensities" begin
# Simple test case.
@model function demo()
x ~ MvNormal(zeros(2), [1, 1])
return y ~ Normal(0, 1)
end
model = demo()
# Marginalize out `x`.

for vn in [@varname(x), :x]
for getlogprob in [DynamicPPL.getlogprior, DynamicPPL.getlogjoint]
marginalized = marginalize(
model, [vn], getlogprob; hess_adtype=AutoForwardDiff()
)
# Compute the marginal log-density of `y = 0.0`.
@test marginalized([0.0]) ≈ logpdf(Normal(0, 1), 0.0) atol = 1e-5
end
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -80,6 +80,7 @@
include("test_util.jl")
@testset "extensions" begin
include("ext/DynamicPPLMCMCChainsExt.jl")
include("ext/DynamicPPLJETExt.jl")
include("ext/DynamicPPLMarginalLogDensitiesExt.jl")
end
@testset "ad" begin
include("ext/DynamicPPLForwardDiffExt.jl")