From 042882d59b89987c0117b6ce19513679daeb48fb Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 29 Jun 2023 10:52:56 +0200 Subject: [PATCH] Don't require primary in HierarchicalMeasure to have known DOF --- src/combinators/hierarchical.jl | 213 ++++++++++++++++++++++---------- src/density-core.jl | 16 ++- src/standard/stdmeasure.jl | 6 + 3 files changed, 167 insertions(+), 68 deletions(-) diff --git a/src/combinators/hierarchical.jl b/src/combinators/hierarchical.jl index 32f08940..0a2b1909 100644 --- a/src/combinators/hierarchical.jl +++ b/src/combinators/hierarchical.jl @@ -1,113 +1,198 @@ export HierarchicalMeasure -struct HierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure +# TODO: Document and use FlattenMode +abstract type FlattenMode end +struct NoFlatten <: FlattenMode end +struct AutoFlatten <: FlattenMode end + + +struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure f::F m::M - dof_m::Int + flatten_mode::FM end +# TODO: Document +const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten} +export HierarchicalProductMeasure -function HierarchicalMeasure(f, m::AbstractMeasure, ::NoDOF) - throw(ArgumentError("Primary measure in HierarchicalMeasure must have fixed and known DOF")) -end +HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten()) + +# TODO: Document +const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten} +export FlatHierarchicalMeasure + +FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten()) + +HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m) -HierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, dynamic(getdof(m))) -function _split_variate(h::HierarchicalMeasure, x) - # TODO: Splitting x will be more complicated in general: - x_primary, x_secondary = x - return (x_primary, x_secondary) +function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2}) + @assert x isa Tuple{2} + return x[1], x[2] end -function _combine_variates(x_primary, x_secondary) - # TODO: Must offer optional flattening - return (x_primary, x_secondary) +function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x) + a_test = testvalue(μ) + return _autosplit_variate_after_testvalue(a_test, x) end +function _autosplit_variate_after_testvalue(::Any, x) + @assert x isa Tuple{2} + return x[1], x[2] +end -function localmeasure(h::HierarchicalMeasure, x) - x_primary, x_secondary = _split_variate(h, x) - m_primary = h.m - m_primary_local = localmeasure(m_primary, x_primary) - m_secondary = m.f(x_secondary) - m_secondary_local = localmeasure(m_secondary, x_secondary) - # TODO: Must optionally return a flattened product measure - return productmeasure(m_primary_local, m_secondary_local) +function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector) + n, m = length(eachindex(a_test)), length(eachindex(x)) + # TODO: Use getindex or view? + return x[begin:n], x[begin+n:m] end +function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M} + return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) +end -@inline function insupport(h::HierarchicalMeasure, x) - # Only test primary for efficiency: - x_primary = _split_variate(h, x)[1] - insupport(h.m, x_primary) +@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} + # TODO: implement + @assert false end -#!!!!!!! WON'T WORK: Only use primary measure for efficiency: -logdensity_type(h::HierarchicalMeasure{F,M}, ::Type{T}) where {F,M,T} = unstatic(float(logdensity_type(M, T))) -# Can't implement logdensity_def(::HierarchicalMeasure, x) directly. +_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b) + + +_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b) + +_autoflat_combine_variates(a::Any, b::Any) = (a, b) + +_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b) + +_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b) + +# TODO: Check that names don't overlap: +_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b) + + +_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2) + +# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) +# Needs a FlatProductMeasure type. + +function _localmeasure_with_rest(μ::HierarchicalProductMeasure, x) + μ_primary = μ.m + local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x) + μ_secondary = μ.f(x_secondary) + local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary) + return _local_productmeasure(μ.flatten_mode, local_primary, local_secondary), x_rest +end + +function _localmeasure_with_rest(μ::AbstractMeasure, x) + x_checked = checked_arg(μ, x) + return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0) +end + +function localmeasure(μ::HierarchicalProductMeasure, x) + h_local, x_rest = _localmeasure_with_rest(μ, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long while computing localmeasure of HierarchicalMeasure")) + end + return h_local +end -# Can't implement getdof(::HierarchicalMeasure) efficiently -# No way to return a functional base measure: -struct _BaseOfHierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure end -@inline basemeasure(::HierarchicalMeasure{F,M}) where {F,M} = _BaseOfHierarchicalMeasure{F,M}() +@inline insupport(::HierarchicalMeasure, x) = NoFastInsupport() @inline getdof(μ::HierarchicalMeasure) = NoDOF{typeof(μ)}() # Bypass `checked_arg`, would require potentially costly evaluation of h.f: @inline checked_arg(::HierarchicalMeasure, x) = x -function unsafe_logdensityof(h::HierarchicalMeasure, x) - x_primary, x_secondary = _split_variate(h, x) - h_primary, h_secondary = h.m, h.f(x_secondary) - unsafe_logdensityof(h_primary, x_primary) + logdensityof(h_secondary, x_secondary) -end +rootmeasure(::HierarchicalMeasure) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for HierarchicalMeasure")) + +basemeasure(::HierarchicalMeasure) = throw(ArgumentError("basemeasure is not available for HierarchicalMeasure")) + +logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def is not available for HierarchicalMeasure")) + + +# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient +# # for AutoFlatten, since variate will be split in `localmeasure` and then +# # split again in log-density evaluation. Maybe add something like +# function unsafe_logdensityof(h::HierarchicalMeasure, x) +# local_primary, local_secondary, x_primary, x_secondary = ... +# # Need to call full logdensityof for h_secondary since x_secondary hasn't +# # been checked yet: +# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) +# end function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real} x_primary = rand(rng, T, h.m) x_secondary = rand(rng, T, h.f(x_primary)) - return _combine_variates(x_primary, x_secondary) + return _combine_variates(h.flatten_mode, x_primary, x_secondary) end -function _split_measure_at(μ::PowerMeasure{M, Tuple{R}}, n::Integer) where {M<:StdMeasure,R} - dof_μ = getdof(μ) - return M()^n, M()^(dof_μ - n) -end - -function transport_def( - ν::PowerMeasure{M, Tuple{R}}, - μ::HierarchicalMeasure, - x, -) where {M<:StdMeasure,R} - ν_primary, ν_secondary = _split_measure_at(ν, μ.dof_m) - x_primary, x_secondary = _split_variate(μ, x) +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::HierarchicalMeasure, x) μ_primary = μ.m + y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) μ_secondary = μ.f(x_secondary) - y_primary = transport_to(ν_primary, μ_primary, x_primary) - y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) - return vcat(y_primary, y_secondary) + y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) + return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest +end + +function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) + dof_μ = getdof(μ) + x_μ, x_rest = _split_variate_after(flatten_mode, μ, x) + y = transport_to(ν_inner^dof_μ, μ, x_μ) + return y, x_rest +end + +function transport_def(ν::_PowerStdMeasure{1}, μ::HierarchicalMeasure, x) + ν_inner = _get_inner_stdmeasure(ν) + y, x_rest = _to_std_with_rest(ν_inner, μ, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) + end + return y end -function transport_def( - ν::HierarchicalMeasure, - μ::PowerMeasure{M, Tuple{R}}, - x, -) where {M<:StdMeasure,R} - dof_primary = ν.dof_m - μ_primary, μ_secondary = _split_measure_at(μ, dof_primary) - x_primary, x_secondary = x[begin:begin+dof_primary-1], x[begin+dof_primary:end] +function _from_std_with_rest(ν::HierarchicalMeasure, μ_inner::StdMeasure, x) ν_primary = ν.m - y_primary = transport_to(ν_primary, μ_primary, x_primary) + y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) ν_secondary = ν.f(y_primary) - y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) - return _combine_variates(y_primary, y_secondary) + y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) + return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest +end + +function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) + dof_ν = getdof(ν) + len_x = length(eachindex(x)) + + # Since we can't check DOF of original HierarchicalMeasure, we could "run out x" if + # the original x was too short. `transport_to` below will detect this, but better + # throw a more informative exception here: + if len_x < dof_ν + throw(ArgumentError("Variate too short during transport involving HierarchicalMeasure")) + end + + y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) + x_rest = Fill(zero(eltype(x)), dof_ν - len_x) + return y, x_rest +end + +function transport_def(ν::HierarchicalMeasure, μ::_PowerStdMeasure{1}, x) + # Sanity check, should be checked by transport machinery already: + @assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector + μ_inner = _get_inner_stdmeasure(μ) + y, x_rest = _from_std_with_rest(ν, μ_inner, x) + if !isempty(x_rest) + throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) + end + return y end diff --git a/src/density-core.jl b/src/density-core.jl index ef045b92..626c6ced 100644 --- a/src/density-core.jl +++ b/src/density-core.jl @@ -55,6 +55,7 @@ To compute a log-density relative to a specific base-measure, see end _checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf)) +@inline _checksupport(::NoFastInsupport, result) = result import ChainRulesCore @inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result) @@ -77,6 +78,12 @@ This is "unsafe" because it does not check `insupport(m, x)`. See also `logdensityof`. """ @inline function unsafe_logdensityof(μ::M, x) where {M} + μ_local = localmeasure(μ, x) + # Extra dispatch boundary to reduce number of required specializations of implementation: + return _unsafe_logdensityof_local(μ_local, x) +end + +@inline function _unsafe_logdensityof_local(μ::M, x) where {M} ℓ_0 = logdensity_def(μ, x) b_0 = μ Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number @@ -119,7 +126,7 @@ known to be in the support of both, it can be more efficient to call end -function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X} T = unstatic( promote_type( logdensity_type(μ, X), @@ -134,16 +141,16 @@ function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where end -function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X} unsafe_logdensity_rel(μ, ν, x) end -function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X} logd = unsafe_logdensity_rel(μ, ν, x) return istrue(inμ) ? logd : logd * oftypeof(logd, -Inf) end -function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X} +@inline function _logdensity_rel_impl(μ::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X} logd = unsafe_logdensity_rel(μ, ν, x) return istrue(inν) ? logd : logd * oftypeof(logd, +Inf) end @@ -160,6 +167,7 @@ See also `logdensity_rel`. @inline function unsafe_logdensity_rel(μ::M, ν::N, x::X) where {M,N,X} μ_local = localmeasure(μ, x) ν_local = localmeasure(ν, x) + # Extra dispatch boundary to reduce number of required specializations of implementation: return _unsafe_logdensity_rel_local(μ_local, ν_local, x) end diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 833f280e..0df8977f 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -1,5 +1,11 @@ abstract type StdMeasure <: AbstractMeasure end + +const _PowerStdMeasure{N,M<:StdMeasure} = PowerMeasure{M,<:NTuple{N,Base.OneTo}} + +_get_inner_stdmeasure(μ::_PowerStdMeasure{N,M}) where {N,M} = M() + + StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) = StdExponential() StdMeasure(::typeof(randn)) = StdNormal()