|
| 1 | +module DynamicPPLMarginalLogDensitiesExt |
| 2 | + |
| 3 | +using DynamicPPL: DynamicPPL, LogDensityProblems, VarName |
| 4 | +using MarginalLogDensities: MarginalLogDensities |
| 5 | + |
# Adapter type: holds a `DynamicPPL.LogDensityFunction` so that it can be handed to
# MarginalLogDensities in the form that package expects. Using a dedicated struct (rather
# than, say, a closure) gives the methods further down a concrete type to dispatch on.
struct LogDensityFunctionWrapper{F<:DynamicPPL.LogDensityFunction}
    logdensity::F
end
# Calling the wrapper evaluates the wrapped log-density at `x`. The second positional
# argument is accepted but ignored (presumably the extra data/parameter slot of the
# MarginalLogDensities calling convention -- `marginalize` passes `()` for it).
(w::LogDensityFunctionWrapper)(x, _) = LogDensityProblems.logdensity(w.logdensity, x)
| 15 | + |
"""
    marginalize(
        model::DynamicPPL.Model,
        marginalized_varnames::AbstractVector{<:VarName};
        varinfo::DynamicPPL.AbstractVarInfo=link(VarInfo(model), model),
        getlogprob::Function=DynamicPPL.getlogjoint,
        method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
        kwargs...,
    )

Construct a `MarginalLogDensities.MarginalLogDensity` object that represents the marginal
log-density of the given `model`, after marginalizing out the variables specified in
`marginalized_varnames`.

The resulting object can be called with a vector of parameter values to compute the marginal
log-density.

## Keyword arguments

- `varinfo`: The `varinfo` to use for the model. By default we use a linked `VarInfo`,
  meaning that the resulting log-density function accepts parameters that have been
  transformed to unconstrained space.

- `getlogprob`: A function which specifies which kind of marginal log-density to compute.
  Its default value is `DynamicPPL.getlogjoint` which returns the marginal log-joint
  probability.

- `method`: The marginalization method; defaults to a Laplace approximation. Please see [the
  MarginalLogDensities.jl package](https://github.com/ElOceanografo/MarginalLogDensities.jl/)
  for other options.

- Other keyword arguments are passed to the `MarginalLogDensities.MarginalLogDensity`
  constructor.

## Throws

- `ArgumentError` if `marginalized_varnames` is empty.

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
           x ~ Normal(1.0)
           y ~ Normal(2.0)
       end
demo (generic function with 2 methods)

julia> marginalized = marginalize(demo(), [@varname(x)]);

julia> # The resulting callable computes the marginal log-density of `y`.
       marginalized([1.0])
-1.4189385332046727

julia> logpdf(Normal(2.0), 1.0)
-1.4189385332046727
```


!!! warning

    The default usage of linked VarInfo means that, for example, optimization of the
    marginal log-density can be performed in unconstrained space. However, care must be
    taken if the model contains variables where the link transformation depends on a
    marginalized variable. For example:

    ```julia
    @model function f()
        x ~ Normal()
        y ~ truncated(Normal(); lower=x)
    end
    ```

    Here, the support of `y`, and hence the link transformation used, depends on the value
    of `x`. If we now marginalize over `x`, we obtain a function mapping linked values of
    `y` to log-probabilities. However, it will not be possible to use DynamicPPL to
    correctly retrieve _unlinked_ values of `y`.
"""
function DynamicPPL.marginalize(
    model::DynamicPPL.Model,
    marginalized_varnames::AbstractVector{<:VarName};
    varinfo::DynamicPPL.AbstractVarInfo=DynamicPPL.link(DynamicPPL.VarInfo(model), model),
    getlogprob::Function=DynamicPPL.getlogjoint,
    method::MarginalLogDensities.AbstractMarginalizer=MarginalLogDensities.LaplaceApprox(),
    kwargs...,
)
    # Marginalizing over nothing is not meaningful, and `reduce(vcat, ...)` below is called
    # without `init`, so it is not defined for an empty collection of ranges. Fail early
    # with a descriptive message instead.
    if isempty(marginalized_varnames)
        throw(
            ArgumentError(
                "`marginalized_varnames` must contain at least one variable name"
            ),
        )
    end
    # Determine the indices for the variables to marginalise out: one range per variable
    # name, concatenated into a single flat vector of indices.
    varindices = reduce(vcat, DynamicPPL.vector_getranges(varinfo, marginalized_varnames))
    # Construct the marginal log-density model. `varinfo[:]` supplies the initial parameter
    # vector and `()` fills the extra-data slot, which the wrapper ignores when called.
    f = DynamicPPL.LogDensityFunction(model, getlogprob, varinfo)
    mld = MarginalLogDensities.MarginalLogDensity(
        LogDensityFunctionWrapper(f), varinfo[:], varindices, (), method; kwargs...
    )
    return mld
end
| 108 | + |
"""
    VarInfo(
        mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
        unmarginalized_params::Union{AbstractVector,Nothing}=nothing
    )

Retrieve the `VarInfo` object used in the marginalisation process.

If a Laplace approximation was used for the marginalisation, the values of the marginalized
parameters are also set to their mode (note that this only happens if the `mld` object has
been used to compute the marginal log-density at least once, so that the mode has been
computed).

If a vector of `unmarginalized_params` is specified, the values for the corresponding
parameters will also be updated in the returned VarInfo. This vector may be obtained e.g. by
performing an optimization of the marginal log-density.

All other aspects of the VarInfo, such as link status, are preserved from the original
VarInfo used in the marginalisation.

The parameter values are assembled in a copy of the parameters cached in `mld`, so this
function does not write `unmarginalized_params` into `mld`.

!!! note

    The other fields of the VarInfo, e.g. accumulated log-probabilities, will not be
    updated. If you wish to have a fully consistent VarInfo, you should re-evaluate the
    model with the returned VarInfo (e.g. using `vi = last(DynamicPPL.evaluate!!(model,
    vi))`).

## Example

```jldoctest
julia> using DynamicPPL, Distributions, MarginalLogDensities

julia> @model function demo()
           x ~ Normal()
           y ~ Beta(2, 2)
       end
demo (generic function with 2 methods)

julia> # Note that by default `marginalize` uses a linked VarInfo.
       mld = marginalize(demo(), [@varname(x)]);

julia> using MarginalLogDensities: Optimization, OptimizationOptimJL

julia> # Find the mode of the marginal log-density of `y`, with an initial point of `y0`.
       y0 = 2.0; opt_problem = Optimization.OptimizationProblem(mld, [y0])
OptimizationProblem. In-place: true
u0: 1-element Vector{Float64}:
 2.0

julia> # This tells us the optimal (linked) value of `y` is around 0.
       opt_solution = Optimization.solve(opt_problem, OptimizationOptimJL.NelderMead())
retcode: Success
u: 1-element Vector{Float64}:
 4.88281250001733e-5

julia> # Get the VarInfo corresponding to the mode of `y`.
       vi = VarInfo(mld, opt_solution.u);

julia> # `x` is set to its mode (which for `Normal()` is zero).
       vi[@varname(x)]
0.0

julia> # `y` is set to the optimal value we found above.
       DynamicPPL.getindex_internal(vi, @varname(y))
1-element Vector{Float64}:
 4.88281250001733e-5

julia> # To obtain values in the original constrained space, we can either
       # use `getindex`:
       vi[@varname(y)]
0.5000122070312476

julia> # Or invlink the entire VarInfo object using the model:
       vi_unlinked = DynamicPPL.invlink(vi, demo()); vi_unlinked[:]
2-element Vector{Float64}:
 0.0
 0.5000122070312476
```
"""
function DynamicPPL.VarInfo(
    mld::MarginalLogDensities.MarginalLogDensity{<:LogDensityFunctionWrapper},
    unmarginalized_params::Union{AbstractVector,Nothing}=nothing,
)
    # Extract the original VarInfo. Its contents will in general be junk.
    original_vi = mld.logdensity.logdensity.varinfo
    # Extract the stored parameters, which includes the modes for any marginalized
    # parameters. Work on a copy: `cached_params` may hand back the storage held inside
    # `mld` (NOTE(review): confirm), and the assignment below must not modify `mld` as a
    # side effect of building a VarInfo.
    full_params = copy(MarginalLogDensities.cached_params(mld))
    # We can then (if needed) set the values for any non-marginalized parameters
    if unmarginalized_params !== nothing
        full_params[MarginalLogDensities.ijoint(mld)] = unmarginalized_params
    end
    return DynamicPPL.unflatten(original_vi, full_params)
end
| 203 | + |
| 204 | +end |
0 commit comments