Current situation
Right now, if you call sample(::Model, ::InferenceAlgorithm, ::Int), the call passes through a long chain before anything concrete happens:

1. It first goes to src/mcmc/Inference.jl, where the InferenceAlgorithm gets wrapped in a DynamicPPL.Sampler (Turing.jl/src/mcmc/Inference.jl, lines 268 to 278 at 5acc97f).
2. This then goes to src/mcmc/$sampler.jl, which defines the methods sample(::Model, ::Sampler{<:InferenceAlgorithm}, ::Int) (e.g. Turing.jl/src/mcmc/hmc.jl, lines 82 to 104 at 5acc97f).
3. This then goes to AbstractMCMC's sample: https://github.com/TuringLang/AbstractMCMC.jl/blob/fdaa0ebce22ce227b068e847415cd9ee0e15c004/src/sample.jl#L255-L259
4. That in turn calls step(::AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm}), which is defined in DynamicPPL: https://github.com/TuringLang/DynamicPPL.jl/blob/072234d094d1d68064bf259d3c3e815a87c18c8e/src/sampler.jl#L108-L126
5. step then calls initialstep, which goes back to being defined in src/mcmc/$sampler.jl (e.g. Turing.jl/src/mcmc/hmc.jl, lines 141 to 149 at 5acc97f). This signature claims to work on AbstractModel, but it only really works for DynamicPPL.Model.

Only inside initialstep do we finally construct a LogDensityFunction from the model. So there are very many steps between the time that sample() is called and the time when a LogDensityFunction is actually constructed.
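For concreteness, here is a minimal example of the entry point being traced above (the model definition is purely illustrative):

```julia
using Turing

@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# This single call passes the Model itself all the way down the stack;
# the LogDensityFunction only appears once initialstep runs.
chain = sample(demo(), NUTS(), 100)
```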
Proposal
Rework everything below the very first call to accept LogDensityFunction rather than Model. That is to say, the method sample(::Model, ::InferenceAlgorithm, ::Int) should look something like this:
```julia
function sample(rng::Random.AbstractRNG, model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    adtype = get_adtype(alg)  # returns nothing (e.g. for MH) or an AbstractADType (e.g. for NUTS)
    ldf = DynamicPPL.LogDensityFunction(model; adtype=adtype)
    spl = DynamicPPL.Sampler(alg)
    sample(rng, ldf, spl, N; kwargs...)
end

# similar methods for sample(::Random.AbstractRNG, ::Model, ::Sampler{<:InferenceAlgorithm})
# as well as sample(::Random.AbstractRNG, ::LogDensityFunction, ::InferenceAlgorithm)

function sample(
    rng::Random.AbstractRNG,
    ldf::DynamicPPL.LogDensityFunction,
    spl::DynamicPPL.Sampler{<:InferenceAlgorithm},
    N::Int;
    chain_type=MCMCChains.Chains,
    check=true,
    kwargs...,
)
    # All of Turing's magic behaviour, e.g. checking the model and setting the
    # chain_type, should happen in this method.
    check && check_model(ldf.model)
    # ... (whatever else)

    # Then we can call mcmcsample, and that means that everything from mcmcsample
    # onwards _knows_ that it's dealing with a LogDensityFunction and a Sampler{...}.
    AbstractMCMC.mcmcsample(...)
end

# handle rng-less methods (ugly boilerplate, because somebody decided that the *optional*
# rng argument should be the first argument and not a keyword argument, ugh!)
function sample(model::Model, alg::InferenceAlgorithm, N::Int; kwargs...)
    sample(Random.default_rng(), model, alg, N; kwargs...)
end

# similar for sample(::Model, ::Sampler{<:InferenceAlgorithm})
# as well as sample(::LogDensityFunction, ::InferenceAlgorithm)
# as well as sample(::LogDensityFunction, ::Sampler{<:InferenceAlgorithm})
# oh, and don't forget to handle the methods with MCMCEnsemble... argh...
# So that's 16 methods to make sure that everything works correctly:
# - 2x with/without rng
# - 2x for Model and LDF
# - 2x for InferenceAlgorithm and Sampler{...}
# - 2x for standard and parallel
```
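Note that get_adtype above is not an existing Turing function; it is assumed by this proposal. A rough sketch of what it could look like, assuming gradient-based algorithms keep carrying their adtype field as they do today:

```julia
# Hypothetical helper assumed by the proposal: extract the AD backend from
# algorithms that need gradients, and return nothing for gradient-free ones.
get_adtype(alg::Turing.NUTS) = alg.adtype
get_adtype(alg::Turing.HMC) = alg.adtype
get_adtype(::Turing.MH) = nothing
get_adtype(::InferenceAlgorithm) = nothing  # conservative fallback
```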
This would require making several changes across DynamicPPL and Turing. It (thankfully) probably does not need to touch AbstractMCMC, as long as we make LogDensityFunction a subtype of AbstractMCMC.AbstractModel (so that mcmcsample can work). That should be fine, because AbstractModel has no interface.
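A minimal sketch of that change, assuming nothing about the real struct beyond what this issue needs (the actual LogDensityFunction has more fields and type parameters):

```julia
using ADTypes: AbstractADType
import AbstractMCMC

# Sketch only: the point is the supertype declaration, which is all that
# AbstractMCMC.mcmcsample needs, since AbstractModel carries no interface.
struct LogDensityFunction{M,V} <: AbstractMCMC.AbstractModel
    model::M                               # the DynamicPPL.Model being evaluated
    varinfo::V                             # the AbstractVarInfo used for evaluation
    adtype::Union{Nothing,AbstractADType}  # nothing => no gradient preparation
end
```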
Why?
For one, this is probably the best way to let people have greater control over their sampling process. For example:
- If you want to use a particular type of VarInfo, this interface would allow you to construct it, make the LogDensityFunction, and then pass that to sample(). Right now, this is actually very difficult to do. (Just try it!!) Note that this also provides a natural interface for opting into ThreadSafeVarInfo (cf. "ThreadSafeVarInfo and threadid", DynamicPPL.jl#924); see the sketch after this list.
- More philosophically, it's IMO the first step that's necessary towards encapsulating Turing's "magic behaviour" at the very top level of the call stack. We know that a DynamicPPL.Model on its own does not actually give enough information about how to evaluate it — it's only LogDensityFunction that contains the necessary information. Thus, it shouldn't be the job of low-level functions like step to make this decision — they should just receive objects that are already complete.
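A hypothetical sketch of that opt-in workflow under the proposed interface (the LogDensityFunction constructor with an explicit varinfo argument is assumed from DynamicPPL's current API; sample(::LogDensityFunction, ...) is the method proposed above, not something that works today):

```julia
using Turing, DynamicPPL

@model function demo()
    x ~ Normal()
end
model = demo()

# Construct exactly the VarInfo you want...
vi = DynamicPPL.VarInfo(model)  # or any other AbstractVarInfo flavour

# ...wrap it, together with the AD backend, in a LogDensityFunction...
ldf = DynamicPPL.LogDensityFunction(model, vi; adtype=AutoForwardDiff())

# ...and hand that directly to sample(), as proposed in this issue.
chain = sample(ldf, NUTS(), 1000)
```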
I'm still not sure about
How this will work with Gibbs. I haven't looked at it deeply enough.
FYI, there is an ongoing discussion between @sunxd3 and me about refactoring all modelling APIs, like condition, fix, returned, rand (sampling from the prior), logdensity, etc., around this new NamedLogDensity object instead of DynamicPPL.Model. However, that is a separate concern from this issue. (Related: TuringLang/AbstractPPL.jl#108; it likely also requires TuringLang/DynamicPPL.jl#900.)