Skip to content

Commit bfdf592

Browse files
committed
Add MonadMeasureTrans newtype
1 parent 1f6d7bc commit bfdf592

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

monad-bayes.cabal

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,14 @@ library
9494

9595
default-extensions:
9696
BlockArguments
97+
DerivingVia
9798
FlexibleContexts
99+
GeneralizedNewtypeDeriving
98100
ImportQualifiedPost
101+
KindSignatures
99102
LambdaCase
100103
OverloadedStrings
104+
StandaloneDeriving
101105
TupleSections
102106

103107
if flag(dev)

src/Control/Monad/Bayes/Class.hs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ module Control.Monad.Bayes.Class
7171
Measure,
7272
Kernel,
7373
Log (ln, Exp),
74+
MonadMeasureTrans (..),
7475
)
7576
where
7677

@@ -82,9 +83,11 @@ import Control.Monad.Identity (IdentityT)
8283
import Control.Monad.List (ListT)
8384
import Control.Monad.Reader (ReaderT)
8485
import Control.Monad.State (StateT)
86+
import Control.Monad.Trans (MonadTrans)
8587
import Control.Monad.Writer (WriterT)
8688
import Data.Histogram qualified as H
8789
import Data.Histogram.Fill qualified as H
90+
import Data.Kind (Type)
8891
import Data.Matrix
8992
( Matrix,
9093
cholDecomp,
@@ -407,3 +410,41 @@ instance MonadFactor m => MonadFactor (ContT r m) where
407410
score = lift . score
408411

409412
instance MonadMeasure m => MonadMeasure (ContT r m)
413+
414+
-- * Utility for deriving MonadDistribution, MonadFactor and MonadMeasure
415+
416+
-- | Newtype to derive 'MonadDistribution', 'MonadFactor' and 'MonadMeasure' automatically for monad transformers.
417+
--
418+
-- The typical usage is with the `StandaloneDeriving` and `DerivingVia` extensions.
419+
-- For example, to derive all instances for the 'IdentityT' transformer, one writes:
420+
--
421+
-- @
422+
-- deriving via (MonadMeasureTrans IdentityT m) instance MonadDistribution m => MonadDistribution (IdentityT m)
423+
-- deriving via (MonadMeasureTrans IdentityT m) instance MonadFactor m => MonadFactor (IdentityT m)
424+
-- instance MonadMeasure m => MonadMeasure (IdentityT m)
425+
-- @
426+
-- (The final 'MonadMeasure' could also be derived `via`, but this isn't necessary because it doesn't contain any methods.)
427+
newtype MonadMeasureTrans (t :: (Type -> Type) -> Type -> Type) (m :: Type -> Type) a = MonadMeasureTrans {getMonadMeasureTrans :: t m a}
428+
deriving (Functor, Applicative, Monad)
429+
430+
instance MonadTrans t => MonadTrans (MonadMeasureTrans t) where
431+
lift = MonadMeasureTrans . lift
432+
433+
instance (MonadTrans t, MonadDistribution m, Monad (t m)) => MonadDistribution (MonadMeasureTrans t m) where
434+
random = lift random
435+
uniform = (lift .) . uniform
436+
normal = (lift .) . normal
437+
gamma = (lift .) . gamma
438+
beta = (lift .) . beta
439+
bernoulli = lift . bernoulli
440+
categorical = lift . categorical
441+
logCategorical = lift . logCategorical
442+
uniformD = lift . uniformD
443+
geometric = lift . geometric
444+
poisson = lift . poisson
445+
dirichlet = lift . dirichlet
446+
447+
instance (MonadFactor m, MonadTrans t, Monad (t m)) => MonadFactor (MonadMeasureTrans t m) where
448+
score = lift . score
449+
450+
instance (MonadDistribution m, MonadFactor m, MonadTrans t, Monad (t m)) => MonadMeasure (MonadMeasureTrans t m)

0 commit comments

Comments
 (0)