Skip to content

Commit 6af6330

Browse files
committed
Remove loadstate and resume_from
1 parent 89a61af commit 6af6330

File tree

3 files changed

+25
-53
lines changed

3 files changed

+25
-53
lines changed

src/mcmc/emcee.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,8 @@ function DynamicPPL.init_strategy(spl::Sampler{<:Emcee})
4141
end
4242

4343
function AbstractMCMC.step(
44-
rng::Random.AbstractRNG,
45-
model::Model,
46-
spl::Sampler{<:Emcee};
47-
resume_from=nothing,
48-
initial_params,
49-
kwargs...,
44+
rng::Random.AbstractRNG, model::Model, spl::Sampler{<:Emcee}; initial_params, kwargs...
5045
)
51-
if resume_from !== nothing
52-
state = loadstate(resume_from)
53-
return AbstractMCMC.step(rng, model, spl, state; kwargs...)
54-
end
55-
5646
# Sample from the prior
5747
n = _get_n_walkers(spl)
5848
vis = [VarInfo(rng, model) for _ in 1:n]

src/mcmc/hmc.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,15 @@ function AbstractMCMC.sample(
8989
sampler::Sampler{<:AdaptiveHamiltonian},
9090
N::Integer;
9191
chain_type=TURING_CHAIN_TYPE,
92-
resume_from=nothing,
9392
initial_params=DynamicPPL.init_strategy(sampler),
94-
initial_state=DynamicPPL.loadstate(resume_from),
93+
initial_state=nothing,
9594
progress=PROGRESS[],
9695
nadapts=sampler.alg.n_adapts,
9796
discard_adapt=true,
9897
discard_initial=-1,
9998
kwargs...,
10099
)
101-
if resume_from === nothing
100+
if initial_state === nothing
102101
# If `nadapts` is `-1`, then the user called a convenience
103102
# constructor like `NUTS()` or `NUTS(0.65)`,
104103
# and we should set a default for them.

src/mcmc/particle_mcmc.jl

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ struct SMC{R} <: ParticleInference
7676
end
7777

7878
"""
79-
SMC([resampler = AdvancedPS.ResampleWithESSThreshold()])
80-
SMC([resampler = AdvancedPS.resample_systematic, ]threshold)
79+
SMC([resampler = AdvancedPS.ResampleWithESSThreshold()])
80+
SMC([resampler = AdvancedPS.resample_systematic, ]threshold)
8181
8282
Create a sequential Monte Carlo sampler of type [`SMC`](@ref).
8383
@@ -111,38 +111,22 @@ function AbstractMCMC.sample(
111111
sampler::Sampler{<:SMC},
112112
N::Integer;
113113
chain_type=TURING_CHAIN_TYPE,
114-
resume_from=nothing,
115114
initial_params=DynamicPPL.init_strategy(sampler),
116-
initial_state=DynamicPPL.loadstate(resume_from),
117115
progress=PROGRESS[],
118116
kwargs...,
119117
)
120-
if resume_from === nothing
121-
return AbstractMCMC.mcmcsample(
122-
rng,
123-
model,
124-
sampler,
125-
N;
126-
chain_type=chain_type,
127-
initial_params=initial_params,
128-
progress=progress,
129-
nparticles=N,
130-
kwargs...,
131-
)
132-
else
133-
return AbstractMCMC.mcmcsample(
134-
rng,
135-
model,
136-
sampler,
137-
N;
138-
chain_type,
139-
initial_params=initial_params,
140-
initial_state,
141-
progress=progress,
142-
nparticles=N,
143-
kwargs...,
144-
)
145-
end
118+
# need to add on the `nparticles` keyword argument for `initialstep` to make use of
119+
return AbstractMCMC.mcmcsample(
120+
rng,
121+
model,
122+
sampler,
123+
N;
124+
chain_type=chain_type,
125+
initial_params=initial_params,
126+
progress=progress,
127+
nparticles=N,
128+
kwargs...,
129+
)
146130
end
147131

148132
function DynamicPPL.initialstep(
@@ -155,7 +139,6 @@ function DynamicPPL.initialstep(
155139
)
156140
# Reset the VarInfo.
157141
vi = DynamicPPL.setacc!!(vi, ProduceLogLikelihoodAccumulator())
158-
set_all_del!(vi)
159142
vi = DynamicPPL.empty!!(vi)
160143

161144
# Create a new set of particles.
@@ -220,8 +203,8 @@ struct PG{R} <: ParticleInference
220203
end
221204

222205
"""
223-
PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold()])
224-
PG(n, [resampler = AdvancedPS.resample_systematic, ]threshold)
206+
PG(n, [resampler = AdvancedPS.ResampleWithESSThreshold()])
207+
PG(n, [resampler = AdvancedPS.resample_systematic, ]threshold)
225208
226209
Create a Particle Gibbs sampler of type [`PG`](@ref) with `n` particles.
227210
@@ -241,7 +224,7 @@ function PG(nparticles::Int, threshold::Real)
241224
end
242225

243226
"""
244-
CSMC(...)
227+
CSMC(...)
245228
246229
Equivalent to [`PG`](@ref).
247230
"""
@@ -345,7 +328,7 @@ end
345328
DynamicPPL.use_threadsafe_eval(::ParticleMCMCContext, ::AbstractVarInfo) = false
346329

347330
"""
348-
get_trace_local_varinfo_maybe(vi::AbstractVarInfo)
331+
get_trace_local_varinfo_maybe(vi::AbstractVarInfo)
349332
350333
Get the `Trace` local varinfo if one exists.
351334
@@ -362,7 +345,7 @@ function get_trace_local_varinfo_maybe(varinfo::AbstractVarInfo)
362345
end
363346

364347
"""
365-
get_trace_local_resampled_maybe(fallback_resampled::Bool)
348+
get_trace_local_resampled_maybe(fallback_resampled::Bool)
366349
367350
Get the `Trace` local `resampled` if one exists.
368351
@@ -379,7 +362,7 @@ function get_trace_local_resampled_maybe(fallback_resampled::Bool)
379362
end
380363

381364
"""
382-
get_trace_local_rng_maybe(rng::Random.AbstractRNG)
365+
get_trace_local_rng_maybe(rng::Random.AbstractRNG)
383366
384367
Get the `Trace` local rng if one exists.
385368
@@ -395,7 +378,7 @@ function get_trace_local_rng_maybe(rng::Random.AbstractRNG)
395378
end
396379

397380
"""
398-
set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
381+
set_trace_local_varinfo_maybe(vi::AbstractVarInfo)
399382
400383
Set the `Trace` local varinfo if executing within a `Trace`. Return `nothing`.
401384
@@ -477,7 +460,7 @@ function AdvancedPS.Trace(
477460
end
478461

479462
"""
480-
ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
463+
ProduceLogLikelihoodAccumulator{T<:Real} <: AbstractAccumulator
481464
482465
Exactly like `LogLikelihoodAccumulator`, but calls `Libtask.produce` on change of value.
483466

0 commit comments

Comments
 (0)