Skip to content

Commit 3582ba8

Browse files
Manuel Bärenzturion
authored andcommitted
Fix RMSMC algorithms
1 parent d050182 commit 3582ba8

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

src/Control/Monad/Bayes/Inference/RMSMC.hs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import Control.Monad.Bayes.Inference.MCMC (MCMCConfig (..))
2525
import Control.Monad.Bayes.Inference.SMC
2626
import Control.Monad.Bayes.Population
2727
( PopulationT,
28-
spawn,
28+
flatten,
2929
withParticles,
3030
)
3131
import Control.Monad.Bayes.Sequential.Coroutine as Seq
@@ -50,8 +50,8 @@ rmsmc ::
5050
PopulationT m a
5151
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
5252
marginal
53-
. S.sequentially (composeCopies numMCMCSteps mhStep . TrStat.hoist resampler) numSteps
54-
. S.hoistFirst (TrStat.hoist (spawn numParticles >>))
53+
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoist flatten . mhStep) . TrStat.hoist resampler) numSteps
54+
. S.hoistFirst (TrStat.hoist (withParticles numParticles))
5555

5656
-- | Resample-move Sequential Monte Carlo with a more efficient
5757
-- tracing representation.
@@ -64,7 +64,7 @@ rmsmcBasic ::
6464
PopulationT m a
6565
rmsmcBasic (MCMCConfig {..}) (SMCConfig {..}) =
6666
TrBas.marginal
67-
. S.sequentially (composeCopies numMCMCSteps TrBas.mhStep . TrBas.hoist resampler) numSteps
67+
. S.sequentially (TrBas.hoist flatten . composeCopies numMCMCSteps (TrBas.hoist flatten . TrBas.mhStep) . TrBas.hoist resampler) numSteps
6868
. S.hoistFirst (TrBas.hoist (withParticles numParticles))
6969

7070
-- | A variant of resample-move Sequential Monte Carlo
@@ -79,7 +79,7 @@ rmsmcDynamic ::
7979
PopulationT m a
8080
rmsmcDynamic (MCMCConfig {..}) (SMCConfig {..}) =
8181
TrDyn.marginal
82-
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps TrDyn.mhStep . TrDyn.hoist resampler) numSteps
82+
. S.sequentially (TrDyn.freeze . composeCopies numMCMCSteps (TrDyn.hoist flatten . TrDyn.mhStep) . TrDyn.hoist resampler) numSteps
8383
. S.hoistFirst (TrDyn.hoist (withParticles numParticles))
8484

8585
-- | Apply a function a given number of times.

src/Control/Monad/Bayes/Population.hs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ module Control.Monad.Bayes.Population
3434
collapse,
3535
popAvg,
3636
withParticles,
37+
flatten,
3738
)
3839
where
3940

@@ -274,3 +275,7 @@ hoist ::
274275
PopulationT m a ->
275276
PopulationT n a
276277
hoist f = PopulationT . Weighted.hoist (hoistFreeT f) . getPopulationT
278+
279+
-- | Flatten all layers of the free structure
280+
flatten :: (Monad m) => PopulationT m a -> PopulationT m a
281+
flatten = fromWeightedList . runPopulationT

0 commit comments

Comments
 (0)