Fixes and improvements to experimental Gibbs (#2231)
* moved new Gibbs tests all into a single block

* initial work on making Gibbs work with `externalsampler` (see the usage sketch just before the diffs below)

* removed references to Setfield.jl

* fixed crucial bug in experimental Gibbs sampler

* added ground-truth comparison for Gibbs sampler on demo models

* added convenience method for performing two sample KS test

* use thinning to avoid OOM issues

* removed incredibly slow testset that didn't really add much

* removed now-redundant testset

* use Anderson-Darling test instead of Kolmogorov-Smirnov to better
capture tail differences + remove subsampling of chains since it
doesn't really matter when we're using aggressive thinning and
test statistics based on comparing order stats

* more work on testing

* fixed tests

* make failures of `two_sample_ad_test` a bit more informative

* make failures of `two_sample_ad_test` produce more informative logs

* additional information upon `two_sample_ad_test` failure

* rename `two_sample_ad_test` to `two_sample_test` and use KS test instead

* added minor test for externalsampler usage

* also test AdvancedHMC samplers with Gibbs

* forgot to add updates to src/mcmc/abstractmcmc.jl in previous commits

* removed usage of `timeit_testset` macro

* added temporary fix for externalsampler that needs to be removed once
DPPL has been updated

* minor reorg of two testsets

* set random seeds more aggressively in an attempt to make tests more reproducible

* removed hack, awaiting PR to DynamicPPL

* renamed `_getmodel` to `getmodel`, `_setmodel` to `setmodel`, and
`_varinfo` to `varinfo_from_logdensityfn`

* missed some instances during renaming

* fixed missing merge in initial step for experimental `Gibbs`

* Always reconstruct `ADGradientWrapper` using the `adtype` available in `ExternalSampler`

* Test Gibbs with different adtype in externalsampler to ensure that works

* Update Project.toml

* Update Project.toml

---------

Co-authored-by: Hong Ge <[email protected]>
torfjelde and yebai committed Jul 16, 2024
1 parent 142dab3 commit 29a1342
Showing 6 changed files with 342 additions and 118 deletions.
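The headline change: the experimental `Gibbs` sampler can now use samplers wrapped in `externalsampler` as components. Below is a minimal usage sketch of what this enables. It is illustrative only: the toy model, the `MH()` component, and the `AdvancedHMC.NUTS(0.8)` settings are assumptions for the example, not part of this commit.

using Turing
using DynamicPPL: @varname
import AdvancedHMC

# Toy model with a variance parameter `s` and a mean parameter `m`.
@model function demo()
    s ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s))
end

# Pair each variable with its own conditional sampler; `m` is updated by an
# external AdvancedHMC sampler wrapped via `externalsampler`, which is the
# combination this commit makes work.
spl = Turing.Experimental.Gibbs(
    @varname(s) => MH(),
    @varname(m) => externalsampler(AdvancedHMC.NUTS(0.8)),
)
chain = sample(demo(), spl, 1_000)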
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.33.1"
version = "0.33.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -63,7 +63,7 @@ Distributions = "0.23.3, 0.24, 0.25"
DistributionsAD = "0.6"
DocStringExtensions = "0.8, 0.9"
DynamicHMC = "3.4"
DynamicPPL = "0.28"
DynamicPPL = "0.28.1"
Compat = "4.15.0"
EllipticalSliceSampling = "0.5, 1, 2"
ForwardDiff = "0.10.3"
53 changes: 36 additions & 17 deletions src/experimental/gibbs.jl
@@ -78,7 +78,7 @@ function DynamicPPL.dot_tilde_assume(context::GibbsContext, right, left, vns, vi
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
-return value, broadcast_logpdf(right, values), vi
+return value, broadcast_logpdf(right, value), vi
end

# Otherwise, falls back to the default behavior.
Expand All @@ -90,8 +90,8 @@ function DynamicPPL.dot_tilde_assume(
)
# Short-circuits the tilde assume if `vn` is present in `context`.
if has_conditioned_gibbs(context, vns)
-values = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
-return values, broadcast_logpdf(right, values), vi
+value = reconstruct_getvalue(right, get_conditioned_gibbs(context, vns))
+return value, broadcast_logpdf(right, value), vi
end

# Otherwise, falls back to the default behavior.
@@ -144,14 +144,14 @@ end
Return a `GibbsContext` with the values extracted from the given `varinfos` treated as conditioned.
"""
function condition_gibbs(context::DynamicPPL.AbstractContext, varinfo::DynamicPPL.AbstractVarInfo)
-return DynamicPPL.condition(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
+return condition_gibbs(context, DynamicPPL.values_as(varinfo, preferred_value_type(varinfo)))
end
-function DynamicPPL.condition(
+function condition_gibbs(
context::DynamicPPL.AbstractContext,
varinfo::DynamicPPL.AbstractVarInfo,
varinfos::DynamicPPL.AbstractVarInfo...
)
-return DynamicPPL.condition(DynamicPPL.condition(context, varinfo), varinfos...)
+return condition_gibbs(condition_gibbs(context, varinfo), varinfos...)
end
# Allow calling this on a `DynamicPPL.Model` directly.
function condition_gibbs(model::DynamicPPL.Model, values...)
@@ -238,6 +238,9 @@ function Gibbs(algs::Pair...)
return Gibbs(map(first, algs), map(wrap_algorithm_maybe, map(last, algs)))
end

+# TODO: Remove when no longer needed.
+DynamicPPL.getspace(::Gibbs) = ()

struct GibbsState{V<:DynamicPPL.AbstractVarInfo,S}
vi::V
states::S
@@ -252,6 +255,7 @@ function DynamicPPL.initialstep(
model::DynamicPPL.Model,
spl::DynamicPPL.Sampler{<:Gibbs},
vi_base::DynamicPPL.AbstractVarInfo;
+initial_params=nothing,
kwargs...,
)
alg = spl.alg
@@ -260,15 +264,35 @@

# 1. Run the model once to get the varnames present + initial values to condition on.
vi_base = DynamicPPL.VarInfo(model)

+# Simple way of setting the initial parameters: set them in the `vi_base`
+# if they are given so they propagate to the subset varinfos used by each sampler.
+if initial_params !== nothing
+vi_base = DynamicPPL.unflatten(vi_base, initial_params)
+end

# Create the varinfos for each sampler.
varinfos = map(Base.Fix1(DynamicPPL.subset, vi_base) ∘ _maybevec, varnames)
+initial_params_all = if initial_params === nothing
+fill(nothing, length(varnames))
+else
+# Extract from the `vi_base`, which should have the values set correctly from above.
+map(vi -> vi[:], varinfos)
+end

# 2. Construct a varinfo for every vn + sampler combo.
-states_and_varinfos = map(samplers, varinfos) do sampler_local, varinfo_local
+states_and_varinfos = map(samplers, varinfos, initial_params_all) do sampler_local, varinfo_local, initial_params_local
# Construct the conditional model.
model_local = make_conditional(model, varinfo_local, varinfos)

# Take initial step.
-new_state_local = last(AbstractMCMC.step(rng, model_local, sampler_local; kwargs...))
+new_state_local = last(AbstractMCMC.step(
+rng, model_local, sampler_local;
+# FIXME: This will cause issues if the sampler expects initial params in unconstrained space.
+# This is not the case for any samplers in Turing.jl, but will be for external samplers, etc.
+initial_params=initial_params_local,
+kwargs...
+))

# Return the new state and the invlinked `varinfo`.
vi_local_state = Turing.Inference.varinfo(new_state_local)
@@ -284,7 +308,7 @@ function DynamicPPL.initialstep(
varinfos = map(last, states_and_varinfos)

# Update the base varinfo from the first varinfo and replace it.
-varinfos_new = DynamicPPL.setindex!!(varinfos, vi_base, 1)
+varinfos_new = DynamicPPL.setindex!!(varinfos, merge(vi_base, first(varinfos)), 1)
# Merge the updated initial varinfo with the rest of the varinfos + update the logp.
vi = DynamicPPL.setlogp!!(
reduce(merge, varinfos_new),
@@ -365,12 +389,7 @@ function gibbs_requires_recompute_logprob(model_dst, sampler_dst, sampler_src, s
end

# TODO: Remove `rng`?
"""
recompute_logprob!!(rng, model, sampler, state)
Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
"""
function recompute_logprob!!(
function Turing.Inference.recompute_logprob!!(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler,
@@ -436,7 +455,7 @@ function gibbs_step_inner(
state_local,
state_previous
)
-current_state_local = recompute_logprob!!(
+state_local = Turing.Inference.recompute_logprob!!(
rng,
model_local,
sampler_local,
@@ -450,7 +469,7 @@
rng,
model_local,
sampler_local,
-current_state_local;
+state_local;
kwargs...,
),
)
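The `initial_params` plumbing added to `DynamicPPL.initialstep` above means that initial values passed to `sample` now propagate to every conditional sampler. A hedged sketch, reusing `demo()` and `spl` from the example before the diffs; the flat vector is assumed to follow the model's variable order (`s`, then `m`):

# Initial values are written into the base varinfo via `unflatten` and then
# flow into the per-sampler subset varinfos, so each conditional sampler takes
# its first step from the requested point (in constrained space).
chain = sample(demo(), spl, 1_000; initial_params=[1.0, 0.0])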
2 changes: 2 additions & 0 deletions src/mcmc/Inference.jl
@@ -114,6 +114,8 @@ struct ExternalSampler{S<:AbstractSampler,AD<:ADTypes.AbstractADType,Unconstrain
end
end

+DynamicPPL.getspace(::ExternalSampler) = ()

"""
requires_unconstrained_space(sampler::ExternalSampler)
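`DynamicPPL.getspace(::ExternalSampler) = ()` declares that an external sampler is not restricted to a particular subspace of variables, which is what lets it serve as a Gibbs component. The `adtype` carried by `ExternalSampler` is also what `setmodel` in the next file uses when rebuilding the gradient wrapper. A sketch of selecting a non-default AD backend; it assumes ReverseDiff is available, and the sampler choice is again illustrative:

using Turing
import AdvancedHMC, ReverseDiff
using ADTypes: AutoReverseDiff

# The adtype travels with the sampler and is reused whenever the conditioned
# model changes during a Gibbs sweep.
ext = externalsampler(AdvancedHMC.NUTS(0.8); adtype=AutoReverseDiff())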
108 changes: 105 additions & 3 deletions src/mcmc/abstractmcmc.jl
@@ -17,10 +17,61 @@ function transition_to_turing(f::LogDensityProblemsAD.ADGradientWrapper, transit
return transition_to_turing(parent(f), transition)
end

"""
getmodel(f)
Return the `DynamicPPL.Model` wrapped in the given log-density function `f`.
"""
getmodel(f::LogDensityProblemsAD.ADGradientWrapper) = getmodel(parent(f))
getmodel(f::DynamicPPL.LogDensityFunction) = f.model

"""
setmodel(f, model[, adtype])
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
!!! warning
Note that if `f` is a `LogDensityProblemsAD.ADGradientWrapper` wrapping a
`DynamicPPL.LogDensityFunction`, performing an update of the `model` in `f`
might require recompilation of the gradient tape, depending on the AD backend.
"""
function setmodel(
f::LogDensityProblemsAD.ADGradientWrapper,
model::DynamicPPL.Model,
adtype::ADTypes.AbstractADType
)
# TODO: Should we handle `SciMLBase.NoAD`?
# For an `ADGradientWrapper` we do the following:
# 1. Update the `Model` in the underlying `LogDensityFunction`.
# 2. Re-construct the `ADGradientWrapper` using `ADgradient` using the provided `adtype`
# to ensure that the recompilation of gradient tapes, etc. also occur. For example,
# ReverseDiff.jl in compiled mode will cache the compiled tape, which means that just
# replacing the corresponding field with the new model won't be sufficient to obtain
# the correct gradients.
return LogDensityProblemsAD.ADgradient(adtype, setmodel(parent(f), model))
end
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return Accessors.@set f.model = model
end

function varinfo_from_logdensityfn(f::LogDensityProblemsAD.ADGradientWrapper)
return varinfo_from_logdensityfn(parent(f))
end
varinfo_from_logdensityfn(f::DynamicPPL.LogDensityFunction) = f.varinfo

function varinfo(state::TuringState)
θ = getparams(getmodel(state.logdensity), state.state)
# TODO: Do we need to link here first?
return DynamicPPL.unflatten(varinfo_from_logdensityfn(state.logdensity), θ)
end

# NOTE: Only thing that depends on the underlying sampler.
# Something similar should be part of AbstractMCMC at some point:
# https://github.com/TuringLang/AbstractMCMC.jl/pull/86
getparams(::DynamicPPL.Model, transition::AdvancedHMC.Transition) = transition.z.θ
function getparams(model::DynamicPPL.Model, state::AdvancedHMC.HMCState)
return getparams(model, state.transition)
end
getstats(transition::AdvancedHMC.Transition) = transition.stat

getparams(::DynamicPPL.Model, transition::AdvancedMH.Transition) = transition.params
@@ -33,13 +84,59 @@ function setvarinfo(f::LogDensityProblemsAD.ADGradientWrapper, varinfo)
return Accessors.@set f.ℓ = setvarinfo(f.ℓ, varinfo)
end

"""
recompute_logprob!!(rng, model, sampler, state)
Recompute the log-probability of the `model` based on the given `state` and return the resulting state.
"""
function recompute_logprob!!(
rng::Random.AbstractRNG, # TODO: Do we need the `rng` here?
model::DynamicPPL.Model,
sampler::DynamicPPL.Sampler{<:ExternalSampler},
state,
)
# Re-using the log-density function from the `state` and updating only the `model` field,
# since the `model` might now contain different conditioning values.
f = setmodel(state.logdensity, model, sampler.alg.adtype)
# Recompute the log-probability with the new `model`.
state_inner = recompute_logprob!!(
rng, AbstractMCMC.LogDensityModel(f), sampler.alg.sampler, state.state
)
return state_to_turing(f, state_inner)
end

function recompute_logprob!!(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::AdvancedHMC.AbstractHMCSampler,
state::AdvancedHMC.HMCState,
)
# Construct hamiltionian.
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
# Re-compute the log-probability and gradient.
return Accessors.@set state.transition.z = AdvancedHMC.phasepoint(
hamiltonian, state.transition.z.θ, state.transition.z.r
)
end

function recompute_logprob!!(
rng::Random.AbstractRNG,
model::AbstractMCMC.LogDensityModel,
sampler::AdvancedMH.MetropolisHastings,
state::AdvancedMH.Transition,
)
logdensity = model.logdensity
return Accessors.@set state.lp = LogDensityProblems.logdensity(logdensity, state.params)
end

# TODO: Do we also support `resume`, etc?
function AbstractMCMC.step(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
sampler_wrapper::Sampler{<:ExternalSampler};
initial_state=nothing,
initial_params=nothing,
-kwargs...
+kwargs...,
)
alg = sampler_wrapper.alg
sampler = alg.sampler
@@ -69,7 +166,12 @@ function AbstractMCMC.step(
)
else
transition_inner, state_inner = AbstractMCMC.step(
-rng, AbstractMCMC.LogDensityModel(f), sampler, initial_state; initial_params, kwargs...
+rng,
+AbstractMCMC.LogDensityModel(f),
+sampler,
+initial_state;
+initial_params,
+kwargs...,
)
end
# Update the `state`
@@ -81,7 +183,7 @@ function AbstractMCMC.step(
model::DynamicPPL.Model,
sampler_wrapper::Sampler{<:ExternalSampler},
state::TuringState;
-kwargs...
+kwargs...,
)
sampler = sampler_wrapper.alg.sampler
f = state.logdensity
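To see why `setmodel` takes an `adtype`, consider the round trip it performs on a wrapped log-density. A sketch under the assumption that these helpers live in `Turing.Inference`, as the file placement and the `Turing.Inference.recompute_logprob!!` references above suggest; `demo()` is the toy model from the first example:

using Turing, DynamicPPL, LogDensityProblemsAD
using ADTypes: AutoForwardDiff

f = DynamicPPL.LogDensityFunction(demo())
g = LogDensityProblemsAD.ADgradient(AutoForwardDiff(), f)

Turing.Inference.getmodel(g)  # unwraps down to the `DynamicPPL.Model`

# Swapping in a (re)conditioned model rebuilds the wrapper with the given
# adtype, so backends that cache compiled tapes (e.g. ReverseDiff in compiled
# mode) get a fresh tape instead of silently reusing a stale one.
g2 = Turing.Inference.setmodel(g, demo(), AutoForwardDiff())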
(Diff for the remaining two changed files is not expanded here.)

2 comments on commit 29a1342

@sunxd3 (Collaborator) commented on 29a1342, Jul 16, 2024


@JuliaRegistrator register

Release notes:

The release brought bug fixes and improvements to the new Experimental.Gibbs (👏 for @torfjelde's heroic efforts)

@JuliaRegistrator commented:


Registration pull request created: JuliaRegistries/General/111170

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.33.2 -m "<description of version>" 29a134245b2499d59fa992420eba37ab2b9f5945
git push origin v0.33.2
