Rework sample() call stack to use LogDensityFunction #2555

Open
penelopeysm opened this issue May 20, 2025 · 2 comments · May be fixed by #2566

penelopeysm (Member) commented May 20, 2025

Current situation

Right now, if you call sample(::Model, ::InferenceAlgorithm, ::Int), the call first goes to src/mcmc/Inference.jl, where the InferenceAlgorithm gets wrapped in a DynamicPPL.Sampler, e.g.

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::InferenceAlgorithm,
    N::Integer;
    check_model::Bool=true,
    kwargs...,
)
    check_model && _check_model(model, alg)
    return AbstractMCMC.sample(rng, model, Sampler(alg), N; kwargs...)
end
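
For concreteness, the entry point for all of this is an ordinary user call like the following (the model here is just an illustrative placeholder):

using Turing

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# This call hits the method above, which wraps NUTS() in a DynamicPPL.Sampler
# before re-dispatching.
chain = sample(demo(), NUTS(), 1000)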

This then goes to src/mcmc/$sampler.jl, which defines methods of the form sample(::Model, ::Sampler{<:InferenceAlgorithm}, ::Int), e.g.

Turing.jl/src/mcmc/hmc.jl, lines 82 to 104 at 5acc97f:

function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::DynamicPPL.Model,
    sampler::Sampler{<:AdaptiveHamiltonian},
    N::Integer;
    chain_type=DynamicPPL.default_chain_type(sampler),
    resume_from=nothing,
    initial_state=DynamicPPL.loadstate(resume_from),
    progress=PROGRESS[],
    nadapts=sampler.alg.n_adapts,
    discard_adapt=true,
    discard_initial=-1,
    kwargs...,
)
    if resume_from === nothing
        # If `nadapts` is `-1`, then the user called a convenience
        # constructor like `NUTS()` or `NUTS(0.65)`,
        # and we should set a default for them.
        if nadapts == -1
            _nadapts = min(1000, N ÷ 2)
        else
            _nadapts = nadapts
        end

This then goes to AbstractMCMC's sample:

https://github.com/TuringLang/AbstractMCMC.jl/blob/fdaa0ebce22ce227b068e847415cd9ee0e15c004/src/sample.jl#L255-L259

Which then calls step(::AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm}), which is defined in DynamicPPL:

https://github.com/TuringLang/DynamicPPL.jl/blob/072234d094d1d68064bf259d3c3e815a87c18c8e/src/sampler.jl#L108-L126
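
Roughly paraphrased (this is a sketch, not the exact source), that DynamicPPL method does something like the following:

function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    spl::DynamicPPL.Sampler;
    initial_params=nothing,
    kwargs...,
)
    # Construct a fresh VarInfo for the model (the real method also handles
    # initial_params, resuming, etc.) ...
    vi = DynamicPPL.VarInfo(model)
    # ... and then hand over to the sampler-specific initialstep, which lives
    # back in Turing's src/mcmc/$sampler.jl.
    return DynamicPPL.initialstep(rng, model, spl, vi; initial_params=initial_params, kwargs...)
end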

Which then calls initialstep, which goes back to being defined in src/mcmc/$sampler.jl:

Turing.jl/src/mcmc/hmc.jl, lines 141 to 149 at 5acc97f:

function DynamicPPL.initialstep(
    rng::AbstractRNG,
    model::AbstractModel,
    spl::Sampler{<:Hamiltonian},
    vi_original::AbstractVarInfo;
    initial_params=nothing,
    nadapts=0,
    kwargs...,
)

(This signature claims to work on any AbstractModel, but it only really works for DynamicPPL.Model.)

Inside this method, we finally construct a LogDensityFunction from the model. So there are a great many steps between the point where sample() is called and the point where a LogDensityFunction is actually constructed.
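
(For reference, the construction inside initialstep looks roughly like this; simplified, not the exact source:)

# Somewhere inside DynamicPPL.initialstep for Hamiltonian samplers (simplified sketch):
# only at this point are the model and the varinfo combined into a LogDensityFunction,
# which is what the HMC machinery actually differentiates through.
ℓ = DynamicPPL.LogDensityFunction(model, vi; adtype=spl.alg.adtype)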

Proposal

Rework everything below the very first call to accept LogDensityFunction rather than Model. That is to say, the method sample(::Model, ::InferenceAlgorithm, ::Int) should look something like this:

function sample(rng::Random.AbstractRNG, model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    adtype = get_adtype(alg) # returns nothing (e.g. for MH) or AbstractADType (e.g. NUTS)
    ldf = DynamicPPL.LogDensityFunction(model; adtype=adtype)
    spl = DynamicPPL.Sampler(alg)
    sample(rng, ldf, spl, N; kwargs...)
end

# similar methods for sample(::Random.AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm})
# as well as sample(::Random.AbstractRNG, ::LogDensityFunction, ::InferenceAlgorithm)

function sample(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,
    spl::DynamicPPL.Sampler{<:InferenceAlgorithm}, N::Int;
    chain_type=MCMCChains.Chains, check=true, kwargs...
)
    # All of Turing's magic behaviour, e.g. checking the model, setting the chain_type,
    # should happen in this method.
    check && check_model(ldf.model)
    # ... (whatever else)
    # Then we can call mcmcsample, and that means that everything from mcmcsample
    # onwards _knows_ that it's dealing with a LogDensityFunction and a Sampler{...}.
    AbstractMCMC.mcmcsample(...)
end

# handle rng-less methods (ugly boilerplate, because somebody decided that the *optional*
# rng argument should be the first argument and not a keyword argument, ugh!)
function sample(model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    sample(Random.default_rng(), model, alg, N; kwargs...)
end
# similar for sample(::Model, ::Sampler{<:InferenceAlgorithm})
#  as well as sample(::LogDensityFunction, ::InferenceAlgorithm)
#  as well as sample(::LogDensityFunction, ::Sampler{<:InferenceAlgorithm})

# oh, don't forget to handle the methods with MCMCEnsemble... argh...
# so that's 16 methods to make sure that everything works correctly
# - 2x with/without rng
# - 2x for Model and LDF
# - 2x for InferenceAlgorithm and Sampler{...}
# - 2x for standard and parallel

This would require making several changes across DynamicPPL and Turing. It (thankfully) probably does not need to touch AbstractMCMC, as long as we make LogDensityFunction a subtype of AbstractMCMC.AbstractModel (so that mcmcsample can work). That should be fine, because AbstractModel has no interface.
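
(Concretely, that would amount to changing the struct declaration on the DynamicPPL side to something like the following, with type parameters and fields abbreviated:)

# Sketch of the change in DynamicPPL (type parameters and fields abbreviated):
struct LogDensityFunction{M<:Model,V<:AbstractVarInfo} <: AbstractMCMC.AbstractModel
    model::M
    varinfo::V
    # ... remaining fields (context, adtype, gradient preparation, ...) unchanged
end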

Why?

For one, this is probably the best way to let people have greater control over their sampling process. For example, a user could construct the LogDensityFunction themselves (choosing, say, the AD backend) and pass it straight to sample():
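
(Sketch only: this assumes the proposed sample(::LogDensityFunction, ::InferenceAlgorithm, ::Int) method from above, and AutoForwardDiff is just an arbitrary choice of AD backend.)

using Turing, DynamicPPL
using ADTypes: AutoForwardDiff

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# Build the LogDensityFunction up front, with an explicit choice of AD backend ...
ldf = DynamicPPL.LogDensityFunction(demo(); adtype=AutoForwardDiff())
# ... and hand it straight to sample() via the proposed sample(::LogDensityFunction, ...) method.
chain = sample(ldf, NUTS(), 1000)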

More philosophically, it's IMO the first step that's necessary towards encapsulating Turing's "magic behaviour" at the very top level of the call stack. We know that a DynamicPPL.Model on its own does not actually give enough information about how to evaluate it — it's only LogDensityFunction that contains the necessary information. Thus, it shouldn't be the job of the low-level functions like step to make this decision — they should just 'receive' objects that are already complete.

I'm still not sure about:

- How this will work with Gibbs. I haven't looked at it deeply enough.

mhauru (Member) commented May 20, 2025

Sounds reasonable to me, but I haven't thought about this call stack deeply, so my opinion isn't strong.

For Gibbs, I don't immediately see a problem as long as we can call contextualize on an LDF.
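
(Very roughly something like this, assuming the LDF keeps references to its model, varinfo, and adtype:)

# Hypothetical sketch: "contextualize" an LDF by rebuilding it around a contextualized model.
# Assumes LogDensityFunction exposes model/varinfo/adtype fields.
function contextualize(ldf::DynamicPPL.LogDensityFunction, context::DynamicPPL.AbstractContext)
    new_model = DynamicPPL.contextualize(ldf.model, context)
    return DynamicPPL.LogDensityFunction(new_model, ldf.varinfo; adtype=ldf.adtype)
end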

yebai (Member) commented May 21, 2025

Great idea. Let's rename LogDensityFunction to NamedLogDensity (TuringLang/DynamicPPL.jl#880).

FYI, there is an ongoing discussion between @sunxd3 and me about refactoring all modelling APIs like condition, fix, returned, rand (sampling from the prior), logdensity, etc., around this new NamedLogDensity object instead of DynamicPPL.Model. However, that is a separate concern from this issue. (Related: TuringLang/AbstractPPL.jl#108; it likely also requires TuringLang/DynamicPPL.jl#900.)
