From 2d017c96fad3b7ce30b6ed0193784385c3da4803 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Tue, 7 Nov 2023 11:29:37 +0100 Subject: [PATCH] STASH bind --- src/MeasureBase.jl | 2 +- src/combinators/bind.jl | 39 ++++++++++++++++++++----------------- src/combinators/combined.jl | 5 ++--- 3 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 093793c8..4fc3bf6a 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -28,7 +28,7 @@ import Base.iterate import ConstructionBase using ConstructionBase: constructorof using IntervalSets -using OneTwoMany: secondarg +using OneTwoMany: firstarg, secondarg using PrettyPrinting const Pretty = PrettyPrinting diff --git a/src/combinators/bind.jl b/src/combinators/bind.jl index a70f27eb..fe6af2bf 100644 --- a/src/combinators/bind.jl +++ b/src/combinators/bind.jl @@ -30,12 +30,6 @@ See also [`mbind`](@ref). function mkernel end export mkernel -@inline mkernel(f_β::MKernel) = f_β -@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c) - -@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c) -@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β - """ struct MeasureBase.MKernel <: Function @@ -45,12 +39,20 @@ Represents a generalized monatic transition kernel. User code should not create instances of `MKernel` directly, but should call [`mkernel`](@ref) instead. """ -struct MKernel - f_β::FK +struct MKernel{FT,FC} <: Function + f_β::FT f_c::FC end +@inline mkernel(f_β::MKernel) = f_β +@inline mkernel(f_β, f_c = secondarg) = _generic_mkernel_impl(f_β, f_c) + +@inline _generic_mkernel_impl(f_β, f_c) = MKernel(f_β, f_c) +@inline _generic_mkernel_impl(f_β::MKernel, ::typeof(secondarg)) = f_β + + + @doc raw""" mbind(f_β, α::AbstractMeasure, f_c = OneTwoMany.secondarg) mbind(f_β::MeasureBase.MKernel, α::AbstractMeasure) @@ -102,7 +104,7 @@ The measure `α` that went into the bind can be retrieved via Densities on hierarchical measures can only be evaluated if `ab = f_c(a, b)` can be unambiguously split into `a` and `b` again, knowing `α`. This is -currently implemented for `f_c` that is either tuple or `=>`/`Pair` (these +currently implemented for `f_c` that is either `tuple` or `=>`/`Pair` (these work for any combination of variate types), `vcat` (for tuple- or vector-like variates) and `merge` (`NamedTuple` variates). [`MeasureBase.split_point(::typeof(f_c), α)`](@ref) can be specialized to @@ -152,19 +154,20 @@ export mbind @inline mbind(f_β) = Base.Fix1(mbind, f_β) -@inline mbind(f_k::MKernel, α::AbstractMeasure) = mbind(f_k.f_β, α, f_k.f_c) - -#@inline mbind(f_β, α::AbstractMeasure, f_c = getsecond) = _generic_mbind_impl(f_β, α, f_c) --- temporary --- -@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, α, f_c) +@inline mbind(f_β, α::AbstractMeasure, f_c = secondarg) = _generic_mbind_impl(f_β, asmeasure(α), f_c) @inline function _generic_mbind_impl(f_β, α::AbstractMeasure, f_c) F, M, G = Core.Typeof(f_β), Core.Typeof(α), Core.Typeof(f_c) Bind{F,M,G}(f_β, α, f_c) end -function _generic_mbind_impl(f_β, α::Dirac, f_c) - mcombine(f_c, α, f_β(α.x)) -end +@inline _generic_mbind_impl(f_β, α::Dirac, f_c) = mcombine(f_c, α, f_β(α.x)) + +@inline _generic_mbind_impl(@nospecialize(f_β), α::AbstractMeasure, ::typeof(firstarg)) = α +@inline _generic_mbind_impl(@nospecialize(f_β), α::Dirac, ::typeof(firstarg)) = α + +@inline _generic_mbind_impl(f_k::MKernel, α::AbstractMeasure, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c) +@inline _generic_mbind_impl(f_k::MKernel, α::Dirac, ::typeof(secondarg)) = mbind(f_k.f_β, α, f_k.f_c) """ @@ -175,8 +178,8 @@ Represents a monatic bind resp. a mbind in general. User code should not create instances of `Bind` directly, but should call [`mbind`](@ref) instead. """ -struct Bind{FK,M<:AbstractMeasure,FC} <: AbstractMeasure - f_β::FK +struct Bind{FT,M<:AbstractMeasure,FC} <: AbstractMeasure + f_β::FT α::M f_c::FC end diff --git a/src/combinators/combined.jl b/src/combinators/combined.jl index 11ebc558..4127e952 100644 --- a/src/combinators/combined.jl +++ b/src/combinators/combined.jl @@ -62,9 +62,8 @@ export mcombine @inline mcombine(f_c, α::AbstractMeasure, β::AbstractMeasure) = _generic_mcombine_impl_stage1(f_c, α, β) -@inline _generic_mcombine_impl_stage1(::typeof(first), α::AbstractMeasure, β::AbstractMeasure) = α -@inline _generic_mcombine_impl_stage1(::typeof(getsecond), α::AbstractMeasure, β::AbstractMeasure) = β -@inline _generic_mcombine_impl_stage1(::typeof(last), α::AbstractMeasure, β::AbstractMeasure) = β +@inline _generic_mcombine_impl_stage1(::typeof(firstarg), α::AbstractMeasure, β::AbstractMeasure) = α +@inline _generic_mcombine_impl_stage1(::typeof(secondarg), α::AbstractMeasure, β::AbstractMeasure) = β @inline function _generic_mcombine_impl_stage1(::typeof(tuple), α::AbstractMeasure, β::AbstractMeasure) productmeasure((α, β))