Skip to content

Commit 8596ad1

Browse files
committed
Refactor existing instances with MonadMeasureTrans
1 parent bfdf592 commit 8596ad1

File tree

2 files changed

+17
-43
lines changed

2 files changed

+17
-43
lines changed

src/Control/Monad/Bayes/Class.hs

Lines changed: 14 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -345,69 +345,45 @@ histogramToList = H.asList
345345
----------------------------------------------------------------------------
346346
-- Instances that lift probabilistic effects to standard tranformers.
347347

348-
instance MonadDistribution m => MonadDistribution (IdentityT m) where
349-
random = lift random
350-
bernoulli = lift . bernoulli
348+
deriving via (MonadMeasureTrans IdentityT m) instance MonadDistribution m => MonadDistribution (IdentityT m)
351349

352-
instance MonadFactor m => MonadFactor (IdentityT m) where
353-
score = lift . score
350+
deriving via (MonadMeasureTrans IdentityT m) instance MonadFactor m => MonadFactor (IdentityT m)
354351

355352
instance MonadMeasure m => MonadMeasure (IdentityT m)
356353

357-
instance MonadDistribution m => MonadDistribution (ExceptT e m) where
358-
random = lift random
359-
uniformD = lift . uniformD
354+
deriving via (MonadMeasureTrans (ExceptT e) m) instance MonadDistribution m => MonadDistribution (ExceptT e m)
360355

361-
instance MonadFactor m => MonadFactor (ExceptT e m) where
362-
score = lift . score
356+
deriving via (MonadMeasureTrans (ExceptT e) m) instance MonadFactor m => MonadFactor (ExceptT e m)
363357

364358
instance MonadMeasure m => MonadMeasure (ExceptT e m)
365359

366-
instance MonadDistribution m => MonadDistribution (ReaderT r m) where
367-
random = lift random
368-
bernoulli = lift . bernoulli
360+
deriving via (MonadMeasureTrans (ReaderT r) m) instance MonadDistribution m => MonadDistribution (ReaderT r m)
369361

370-
instance MonadFactor m => MonadFactor (ReaderT r m) where
371-
score = lift . score
362+
deriving via (MonadMeasureTrans (ReaderT r) m) instance MonadFactor m => MonadFactor (ReaderT r m)
372363

373364
instance MonadMeasure m => MonadMeasure (ReaderT r m)
374365

375-
instance (Monoid w, MonadDistribution m) => MonadDistribution (WriterT w m) where
376-
random = lift random
377-
bernoulli = lift . bernoulli
378-
categorical = lift . categorical
366+
deriving via (MonadMeasureTrans (WriterT w) m) instance (Monoid w, MonadDistribution m) => MonadDistribution (WriterT w m)
379367

380-
instance (Monoid w, MonadFactor m) => MonadFactor (WriterT w m) where
381-
score = lift . score
368+
deriving via (MonadMeasureTrans (WriterT w) m) instance (Monoid w, MonadFactor m) => MonadFactor (WriterT w m)
382369

383370
instance (Monoid w, MonadMeasure m) => MonadMeasure (WriterT w m)
384371

385-
instance MonadDistribution m => MonadDistribution (StateT s m) where
386-
random = lift random
387-
bernoulli = lift . bernoulli
388-
categorical = lift . categorical
389-
uniformD = lift . uniformD
372+
deriving via (MonadMeasureTrans (StateT s) m) instance MonadDistribution m => MonadDistribution (StateT s m)
390373

391-
instance MonadFactor m => MonadFactor (StateT s m) where
392-
score = lift . score
374+
deriving via (MonadMeasureTrans (StateT s) m) instance MonadFactor m => MonadFactor (StateT s m)
393375

394376
instance MonadMeasure m => MonadMeasure (StateT s m)
395377

396-
instance MonadDistribution m => MonadDistribution (ListT m) where
397-
random = lift random
398-
bernoulli = lift . bernoulli
399-
categorical = lift . categorical
378+
deriving via (MonadMeasureTrans ListT m) instance MonadDistribution m => MonadDistribution (ListT m)
400379

401-
instance MonadFactor m => MonadFactor (ListT m) where
402-
score = lift . score
380+
deriving via (MonadMeasureTrans ListT m) instance MonadFactor m => MonadFactor (ListT m)
403381

404382
instance MonadMeasure m => MonadMeasure (ListT m)
405383

406-
instance MonadDistribution m => MonadDistribution (ContT r m) where
407-
random = lift random
384+
deriving via (MonadMeasureTrans (ContT r) m) instance MonadDistribution m => MonadDistribution (ContT r m)
408385

409-
instance MonadFactor m => MonadFactor (ContT r m) where
410-
score = lift . score
386+
deriving via (MonadMeasureTrans (ContT r) m) instance MonadFactor m => MonadFactor (ContT r m)
411387

412388
instance MonadMeasure m => MonadMeasure (ContT r m)
413389

src/Control/Monad/Bayes/Sequential/Coroutine.hs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ module Control.Monad.Bayes.Sequential.Coroutine
2626
where
2727

2828
import Control.Monad.Bayes.Class
29-
( MonadDistribution (bernoulli, categorical, random),
29+
( MonadDistribution,
3030
MonadFactor (..),
3131
MonadMeasure,
32+
MonadMeasureTrans (..),
3233
)
3334
import Control.Monad.Coroutine
3435
( Coroutine (..),
@@ -54,10 +55,7 @@ newtype Sequential m a = Sequential {runSequential :: Coroutine (Await ()) m a}
5455
extract :: Await () a -> a
5556
extract (Await f) = f ()
5657

57-
instance MonadDistribution m => MonadDistribution (Sequential m) where
58-
random = lift random
59-
bernoulli = lift . bernoulli
60-
categorical = lift . categorical
58+
deriving via (MonadMeasureTrans Sequential m) instance MonadDistribution m => MonadDistribution (Sequential m)
6159

6260
-- | Execution is 'suspend'ed after each 'score'.
6361
instance MonadFactor m => MonadFactor (Sequential m) where

0 commit comments

Comments
 (0)