Skip to content

Commit

Permalink
Don't require primary in HierarchicalMeasure to have known DOF
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Jun 29, 2023
1 parent 73b66a6 commit 042882d
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 68 deletions.
213 changes: 149 additions & 64 deletions src/combinators/hierarchical.jl
Original file line number Diff line number Diff line change
@@ -1,113 +1,198 @@
export HierarchicalMeasure


struct HierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure
# TODO: Document and use FlattenMode
abstract type FlattenMode end
struct NoFlatten <: FlattenMode end
struct AutoFlatten <: FlattenMode end


struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure
f::F
m::M
dof_m::Int
flatten_mode::FM
end

# TODO: Document
const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten}
export HierarchicalProductMeasure

function HierarchicalMeasure(f, m::AbstractMeasure, ::NoDOF)
throw(ArgumentError("Primary measure in HierarchicalMeasure must have fixed and known DOF"))
end
HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten())

# TODO: Document
const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten}
export FlatHierarchicalMeasure

FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten())

HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m)

HierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, dynamic(getdof(m)))


function _split_variate(h::HierarchicalMeasure, x)
# TODO: Splitting x will be more complicated in general:
x_primary, x_secondary = x
return (x_primary, x_secondary)
function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2})
@assert x isa Tuple{2}
return x[1], x[2]
end


function _combine_variates(x_primary, x_secondary)
# TODO: Must offer optional flattening
return (x_primary, x_secondary)
function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x)
a_test = testvalue(μ)
return _autosplit_variate_after_testvalue(a_test, x)
end

function _autosplit_variate_after_testvalue(::Any, x)
@assert x isa Tuple{2}
return x[1], x[2]
end

function localmeasure(h::HierarchicalMeasure, x)
x_primary, x_secondary = _split_variate(h, x)
m_primary = h.m
m_primary_local = localmeasure(m_primary, x_primary)
m_secondary = m.f(x_secondary)
m_secondary_local = localmeasure(m_secondary, x_secondary)
# TODO: Must optionally return a flattened product measure
return productmeasure(m_primary_local, m_secondary_local)
function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector)
n, m = length(eachindex(a_test)), length(eachindex(x))
# TODO: Use getindex or view?
return x[begin:n], x[begin+n:m]
end

function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M}
return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M))
end

@inline function insupport(h::HierarchicalMeasure, x)
# Only test primary for efficiency:
x_primary = _split_variate(h, x)[1]
insupport(h.m, x_primary)
@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names}
# TODO: implement
@assert false
end


#!!!!!!! WON'T WORK: Only use primary measure for efficiency:
logdensity_type(h::HierarchicalMeasure{F,M}, ::Type{T}) where {F,M,T} = unstatic(float(logdensity_type(M, T)))

# Can't implement logdensity_def(::HierarchicalMeasure, x) directly.
_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b)


_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b)

_autoflat_combine_variates(a::Any, b::Any) = (a, b)

_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b)

_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b)

# TODO: Check that names don't overlap:
_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b)


_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2)

# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2)
# Needs a FlatProductMeasure type.

function _localmeasure_with_rest::HierarchicalProductMeasure, x)
μ_primary = μ.m
local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x)
μ_secondary = μ.f(x_secondary)
local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary)
return _local_productmeasure.flatten_mode, local_primary, local_secondary), x_rest
end

function _localmeasure_with_rest::AbstractMeasure, x)
x_checked = checked_arg(μ, x)
return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0)
end

function localmeasure::HierarchicalProductMeasure, x)
h_local, x_rest = _localmeasure_with_rest(μ, x)
if !isempty(x_rest)
throw(ArgumentError("Variate too long while computing localmeasure of HierarchicalMeasure"))
end
return h_local
end

# Can't implement getdof(::HierarchicalMeasure) efficiently

# No way to return a functional base measure:
struct _BaseOfHierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure end
@inline basemeasure(::HierarchicalMeasure{F,M}) where {F,M} = _BaseOfHierarchicalMeasure{F,M}()
@inline insupport(::HierarchicalMeasure, x) = NoFastInsupport()

@inline getdof::HierarchicalMeasure) = NoDOF{typeof(μ)}()

# Bypass `checked_arg`, would require potentially costly evaluation of h.f:
@inline checked_arg(::HierarchicalMeasure, x) = x

function unsafe_logdensityof(h::HierarchicalMeasure, x)
x_primary, x_secondary = _split_variate(h, x)
h_primary, h_secondary = h.m, h.f(x_secondary)
unsafe_logdensityof(h_primary, x_primary) + logdensityof(h_secondary, x_secondary)
end
rootmeasure(::HierarchicalMeasure) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for HierarchicalMeasure"))

basemeasure(::HierarchicalMeasure) = throw(ArgumentError("basemeasure is not available for HierarchicalMeasure"))

logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def is not available for HierarchicalMeasure"))


# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient
# # for AutoFlatten, since variate will be split in `localmeasure` and then
# # split again in log-density evaluation. Maybe add something like
# function unsafe_logdensityof(h::HierarchicalMeasure, x)
# local_primary, local_secondary, x_primary, x_secondary = ...
# # Need to call full logdensityof for h_secondary since x_secondary hasn't
# # been checked yet:
# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary)
# end


function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real}
x_primary = rand(rng, T, h.m)
x_secondary = rand(rng, T, h.f(x_primary))
return _combine_variates(x_primary, x_secondary)
return _combine_variates(h.flatten_mode, x_primary, x_secondary)
end


function _split_measure_at::PowerMeasure{M, Tuple{R}}, n::Integer) where {M<:StdMeasure,R}
dof_μ = getdof(μ)
return M()^n, M()^(dof_μ - n)
end


function transport_def(
ν::PowerMeasure{M, Tuple{R}},
μ::HierarchicalMeasure,
x,
) where {M<:StdMeasure,R}
ν_primary, ν_secondary = _split_measure_at(ν, μ.dof_m)
x_primary, x_secondary = _split_variate(μ, x)
function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::HierarchicalMeasure, x)
μ_primary = μ.m
y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x)
μ_secondary = μ.f(x_secondary)
y_primary = transport_to(ν_primary, μ_primary, x_primary)
y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary)
return vcat(y_primary, y_secondary)
y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary)
return _combine_variates.flatten_mode, y_primary, y_secondary), x_rest
end

function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x)
dof_μ = getdof(μ)
x_μ, x_rest = _split_variate_after(flatten_mode, μ, x)
y = transport_to(ν_inner^dof_μ, μ, x_μ)
return y, x_rest
end

function transport_def::_PowerStdMeasure{1}, μ::HierarchicalMeasure, x)
ν_inner = _get_inner_stdmeasure(ν)
y, x_rest = _to_std_with_rest(ν_inner, μ, x)
if !isempty(x_rest)
throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure"))
end
return y
end


function transport_def(
ν::HierarchicalMeasure,
μ::PowerMeasure{M, Tuple{R}},
x,
) where {M<:StdMeasure,R}
dof_primary = ν.dof_m
μ_primary, μ_secondary = _split_measure_at(μ, dof_primary)
x_primary, x_secondary = x[begin:begin+dof_primary-1], x[begin+dof_primary:end]
function _from_std_with_rest::HierarchicalMeasure, μ_inner::StdMeasure, x)
ν_primary = ν.m
y_primary = transport_to(ν_primary, μ_primary, x_primary)
y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x)
ν_secondary = ν.f(y_primary)
y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary)
return _combine_variates(y_primary, y_secondary)
y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary)
return _combine_variates.flatten_mode, y_primary, y_secondary), x_rest
end

function _from_std_with_rest::AbstractMeasure, μ_inner::StdMeasure, x)
dof_ν = getdof(ν)
len_x = length(eachindex(x))

# Since we can't check DOF of original HierarchicalMeasure, we could "run out x" if
# the original x was too short. `transport_to` below will detect this, but better
# throw a more informative exception here:
if len_x < dof_ν
throw(ArgumentError("Variate too short during transport involving HierarchicalMeasure"))
end

y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1])
x_rest = Fill(zero(eltype(x)), dof_ν - len_x)
return y, x_rest
end

function transport_def::HierarchicalMeasure, μ::_PowerStdMeasure{1}, x)
# Sanity check, should be checked by transport machinery already:
@assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector
μ_inner = _get_inner_stdmeasure(μ)
y, x_rest = _from_std_with_rest(ν, μ_inner, x)
if !isempty(x_rest)
throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure"))
end
return y
end
16 changes: 12 additions & 4 deletions src/density-core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ To compute a log-density relative to a specific base-measure, see
end

_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))
@inline _checksupport(::NoFastInsupport, result) = result

import ChainRulesCore
@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
Expand All @@ -77,6 +78,12 @@ This is "unsafe" because it does not check `insupport(m, x)`.
See also `logdensityof`.
"""
@inline function unsafe_logdensityof::M, x) where {M}
μ_local = localmeasure(μ, x)
# Extra dispatch boundary to reduce number of required specializations of implementation:
return _unsafe_logdensityof_local(μ_local, x)
end

@inline function _unsafe_logdensityof_local::M, x) where {M}
ℓ_0 = logdensity_def(μ, x)
b_0 = μ
Base.Cartesian.@nexprs 10 i -> begin # 10 is just some "big enough" number
Expand Down Expand Up @@ -119,7 +126,7 @@ known to be in the support of both, it can be more efficient to call
end


function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X}
@inline function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X}
T = unstatic(
promote_type(
logdensity_type(μ, X),
Expand All @@ -134,16 +141,16 @@ function _logdensity_rel_impl(μ::M, ν::N, x::X, inμ::Bool, inν::Bool) where
end


function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X}
@inline function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X}
unsafe_logdensity_rel(μ, ν, x)
end

function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X}
@inline function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X}
logd = unsafe_logdensity_rel(μ, ν, x)
return istrue(inμ) ? logd : logd * oftypeof(logd, -Inf)
end

function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X}
@inline function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X}
logd = unsafe_logdensity_rel(μ, ν, x)
return istrue(inν) ? logd : logd * oftypeof(logd, +Inf)
end
Expand All @@ -160,6 +167,7 @@ See also `logdensity_rel`.
@inline function unsafe_logdensity_rel::M, ν::N, x::X) where {M,N,X}
μ_local = localmeasure(μ, x)
ν_local = localmeasure(ν, x)
# Extra dispatch boundary to reduce number of required specializations of implementation:
return _unsafe_logdensity_rel_local(μ_local, ν_local, x)
end

Expand Down
6 changes: 6 additions & 0 deletions src/standard/stdmeasure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
abstract type StdMeasure <: AbstractMeasure end


const _PowerStdMeasure{N,M<:StdMeasure} = PowerMeasure{M,<:NTuple{N,Base.OneTo}}

_get_inner_stdmeasure::_PowerStdMeasure{N,M}) where {N,M} = M()


StdMeasure(::typeof(rand)) = StdUniform()
StdMeasure(::typeof(randexp)) = StdExponential()
StdMeasure(::typeof(randn)) = StdNormal()
Expand Down

0 comments on commit 042882d

Please sign in to comment.