Skip to content

Commit 4f6b297

Browse files
committed
Unsuccessfull attempt at fixing SMC2
1 parent 084862f commit 4f6b297

File tree

6 files changed

+56
-17
lines changed

6 files changed

+56
-17
lines changed

src/Control/Monad/Bayes/Inference/RMSMC.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ rmsmc ::
5151
PopulationT m a
5252
rmsmc (MCMCConfig {..}) (SMCConfig {..}) =
5353
marginal
54-
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps
55-
. S.hoistFirst (TrStat.hoist (withParticles numParticles))
54+
. S.sequentially (composeCopies numMCMCSteps (TrStat.hoistModel (single . flatten) . TrStat.hoist (single . flatten) . mhStep) . TrStat.hoist resampler) numSteps
55+
. S.hoistFirst (TrStat.hoistModel (single . flatten) . TrStat.hoist (withParticles numParticles))
5656

5757
-- | Resample-move Sequential Monte Carlo with a more efficient
5858
-- tracing representation.

src/Control/Monad/Bayes/Inference/SMC.hs

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,18 +22,38 @@ where
2222

2323
import Control.Monad.Bayes.Class (MonadDistribution, MonadMeasure)
2424
import Control.Monad.Bayes.Population
25-
( PopulationT,
25+
( PopulationT (..),
26+
flatten,
2627
pushEvidence,
28+
single,
2729
withParticles,
2830
)
31+
import Control.Monad.Bayes.Population.Applicative qualified as Applicative
2932
import Control.Monad.Bayes.Sequential.Coroutine as Coroutine
33+
import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT
34+
import Control.Monad.Bayes.Weighted (WeightedT (..), weightedT)
35+
import Control.Monad.Coroutine
36+
import Control.Monad.Trans.Free (FreeF (..), FreeT (..))
3037

3138
data SMCConfig m = SMCConfig
3239
{ resampler :: forall x. PopulationT m x -> PopulationT m x,
3340
numSteps :: Int,
3441
numParticles :: Int
3542
}
3643

44+
sequentialToPopulation :: (Monad m) => Coroutine.SequentialT (Applicative.PopulationT m) a -> PopulationT m a
45+
sequentialToPopulation =
46+
PopulationT
47+
. weightedT
48+
. coroutineToFree
49+
. Coroutine.runSequentialT
50+
where
51+
coroutineToFree =
52+
FreeT
53+
. fmap (Free . fmap (\(cont, p) -> either (coroutineToFree . extract) (pure . (,p)) cont))
54+
. Applicative.runPopulationT
55+
. resume
56+
3757
-- | Sequential importance resampling.
3858
-- Basically an SMC template that takes a custom resampler.
3959
smc ::
@@ -42,12 +62,15 @@ smc ::
4262
Coroutine.SequentialT (PopulationT m) a ->
4363
PopulationT m a
4464
smc SMCConfig {..} =
45-
Coroutine.sequentially resampler numSteps
65+
(single . flatten)
66+
. Coroutine.sequentially resampler numSteps
67+
. SequentialT.hoist (single . flatten)
4668
. Coroutine.hoistFirst (withParticles numParticles)
69+
. SequentialT.hoist (single . flatten)
4770

4871
-- | Sequential Monte Carlo with multinomial resampling at each timestep.
4972
-- Weights are normalized at each timestep and the total weight is pushed
5073
-- as a score into the transformed monad.
5174
smcPush ::
5275
(MonadMeasure m) => SMCConfig m -> Coroutine.SequentialT (PopulationT m) a -> PopulationT m a
53-
smcPush config = smc config {resampler = (pushEvidence . resampler config)}
76+
smcPush config = smc config {resampler = (single . flatten . pushEvidence . resampler config)}

src/Control/Monad/Bayes/Inference/SMC2.hs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ import Control.Monad.Bayes.Class
2727
import Control.Monad.Bayes.Inference.MCMC
2828
import Control.Monad.Bayes.Inference.RMSMC (rmsmc)
2929
import Control.Monad.Bayes.Inference.SMC (SMCConfig (SMCConfig, numParticles, numSteps, resampler), smcPush)
30-
import Control.Monad.Bayes.Population as Pop (PopulationT, resampleMultinomial, runPopulationT)
30+
import Control.Monad.Bayes.Population as Pop (PopulationT, flatten, resampleMultinomial, runPopulationT, single)
31+
import Control.Monad.Bayes.Population qualified as PopulationT
3132
import Control.Monad.Bayes.Sequential.Coroutine (SequentialT)
33+
import Control.Monad.Bayes.Sequential.Coroutine qualified as SequentialT
3234
import Control.Monad.Bayes.Traced
3335
import Control.Monad.Trans (MonadTrans (..))
3436
import Numeric.Log (Log)
@@ -71,4 +73,10 @@ smc2 k n p t param m =
7173
rmsmc
7274
MCMCConfig {numMCMCSteps = t, proposal = SingleSiteMH, numBurnIn = 0}
7375
SMCConfig {numParticles = p, numSteps = k, resampler = resampleMultinomial}
74-
(param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . m)
76+
(flattenSequentiallyTraced param >>= setup . runPopulationT . smcPush (SMCConfig {numSteps = k, numParticles = n, resampler = resampleMultinomial}) . flattenSMC2 . m)
77+
78+
flattenSequentiallyTraced :: (Monad m) => SequentialT (TracedT (PopulationT m)) a -> SequentialT (TracedT (PopulationT m)) a
79+
flattenSequentiallyTraced = SequentialT.hoist $ hoistModel (single . flatten) . hoist (single . flatten)
80+
81+
flattenSMC2 :: (Monad m) => SequentialT (PopulationT (SMC2 m)) a -> SequentialT (PopulationT (SMC2 m)) a
82+
flattenSMC2 = SequentialT.hoist $ single . flatten . PopulationT.hoist (SMC2 . flattenSequentiallyTraced . setup)

src/Control/Monad/Bayes/Population.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ pushEvidence ::
238238
(MonadFactor m) =>
239239
PopulationT m a ->
240240
PopulationT m a
241-
pushEvidence = hoist applyWeight . extractEvidence
241+
pushEvidence = single . flatten . hoist applyWeight . extractEvidence
242242

243243
-- | A properly weighted single sample, that is one picked at random according
244244
-- to the weights, with the sum of all weights.

src/Control/Monad/Bayes/Sequential/Coroutine.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ module Control.Monad.Bayes.Sequential.Coroutine
2222
hoist,
2323
sequentially,
2424
sis,
25+
runSequentialT,
26+
extract,
2527
)
2628
where
2729

src/Control/Monad/Bayes/Traced/Static.hs

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
module Control.Monad.Bayes.Traced.Static
1313
( TracedT (..),
1414
hoist,
15+
hoistModel,
1516
marginal,
1617
mhStep,
1718
mh,
@@ -25,6 +26,7 @@ import Control.Monad.Bayes.Class
2526
MonadMeasure,
2627
)
2728
import Control.Monad.Bayes.Density.Free (DensityT)
29+
import Control.Monad.Bayes.Density.Free qualified as DensityT
2830
import Control.Monad.Bayes.Traced.Common
2931
( Trace (..),
3032
bind,
@@ -33,6 +35,7 @@ import Control.Monad.Bayes.Traced.Common
3335
singleton,
3436
)
3537
import Control.Monad.Bayes.Weighted (WeightedT)
38+
import Control.Monad.Bayes.Weighted qualified as WeightedT
3639
import Control.Monad.Trans (MonadTrans (..))
3740
import Data.List.NonEmpty as NE (NonEmpty ((:|)), toList)
3841

@@ -72,6 +75,9 @@ instance (MonadMeasure m) => MonadMeasure (TracedT m)
7275
hoist :: (forall x. m x -> m x) -> TracedT m a -> TracedT m a
7376
hoist f (TracedT m d) = TracedT m (f d)
7477

78+
hoistModel :: (Monad m) => (forall x. m x -> m x) -> TracedT m a -> TracedT m a
79+
hoistModel f (TracedT m d) = TracedT (WeightedT.hoist (DensityT.hoist f) m) d
80+
7581
-- | Discard the trace and supporting infrastructure.
7682
marginal :: (Monad m) => TracedT m a -> m a
7783
marginal (TracedT _ d) = fmap output d
@@ -98,15 +104,15 @@ mhStep (TracedT m d) = TracedT m d'
98104
-- * What is the probability that it is the weekend?
99105
--
100106
-- >>> :{
101-
-- let
102-
-- bus = do x <- bernoulli (2/7)
103-
-- let rate = if x then 3 else 10
104-
-- factor $ poissonPdf rate 4
105-
-- return x
106-
-- mhRunBusSingleObs = do
107-
-- let nSamples = 2
108-
-- sampleIOfixed $ unweighted $ mh nSamples bus
109-
-- in mhRunBusSingleObs
107+
-- let
108+
-- bus = do x <- bernoulli (2/7)
109+
-- let rate = if x then 3 else 10
110+
-- factor $ poissonPdf rate 4
111+
-- return x
112+
-- mhRunBusSingleObs = do
113+
-- let nSamples = 2
114+
-- sampleIOfixed $ unweighted $ mh nSamples bus
115+
-- in mhRunBusSingleObs
110116
-- :}
111117
-- [True,True,True]
112118
--

0 commit comments

Comments
 (0)