@@ -71,6 +71,7 @@ module Control.Monad.Bayes.Class
7171 Measure ,
7272 Kernel ,
7373 Log (ln , Exp ),
74+ MonadMeasureTrans (.. ),
7475 )
7576where
7677
@@ -82,9 +83,11 @@ import Control.Monad.Identity (IdentityT)
8283import Control.Monad.List (ListT )
8384import Control.Monad.Reader (ReaderT )
8485import Control.Monad.State (StateT )
86+ import Control.Monad.Trans (MonadTrans )
8587import Control.Monad.Writer (WriterT )
8688import Data.Histogram qualified as H
8789import Data.Histogram.Fill qualified as H
90+ import Data.Kind (Type )
8891import Data.Matrix
8992 ( Matrix ,
9093 cholDecomp ,
@@ -407,3 +410,41 @@ instance MonadFactor m => MonadFactor (ContT r m) where
407410 score = lift . score
408411
409412instance MonadMeasure m => MonadMeasure (ContT r m )
413+
414+ -- * Utility for deriving MonadDistribution, MonadFactor and MonadMeasure
415+
416+ -- | Newtype to derive 'MonadDistribution', 'MonadFactor' and 'MonadMeasure' automatically for monad transformers.
417+ --
418+ -- The typical usage is with the `StandaloneDeriving` and `DerivingVia` extensions.
419+ -- For example, to derive all instances for the 'IdentityT' transformer, one writes:
420+ --
421+ -- @
422+ -- deriving via (MonadMeasureTrans IdentityT m) instance MonadDistribution m => MonadDistribution (IdentityT m)
423+ -- deriving via (MonadMeasureTrans IdentityT m) instance MonadFactor m => MonadFactor (IdentityT m)
424+ -- instance MonadMeasure m => MonadMeasure (IdentityT m)
425+ -- @
426+ -- (The final 'MonadMeasure' could also be derived `via`, but this isn't necessary because it doesn't contain any methods.)
427+ newtype MonadMeasureTrans (t :: (Type -> Type ) -> Type -> Type ) (m :: Type -> Type ) a = MonadMeasureTrans { getMonadMeasureTrans :: t m a }
428+ deriving (Functor , Applicative , Monad )
429+
430+ instance MonadTrans t => MonadTrans (MonadMeasureTrans t ) where
431+ lift = MonadMeasureTrans . lift
432+
433+ instance (MonadTrans t , MonadDistribution m , Monad (t m )) => MonadDistribution (MonadMeasureTrans t m ) where
434+ random = lift random
435+ uniform = (lift . ) . uniform
436+ normal = (lift . ) . normal
437+ gamma = (lift . ) . gamma
438+ beta = (lift . ) . beta
439+ bernoulli = lift . bernoulli
440+ categorical = lift . categorical
441+ logCategorical = lift . logCategorical
442+ uniformD = lift . uniformD
443+ geometric = lift . geometric
444+ poisson = lift . poisson
445+ dirichlet = lift . dirichlet
446+
447+ instance (MonadFactor m , MonadTrans t , Monad (t m )) => MonadFactor (MonadMeasureTrans t m ) where
448+ score = lift . score
449+
450+ instance (MonadDistribution m , MonadFactor m , MonadTrans t , Monad (t m )) => MonadMeasure (MonadMeasureTrans t m )
0 commit comments