-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Don't require primary in HierarchicalMeasure to have known DOF
- Loading branch information
Showing
3 changed files
with
167 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,113 +1,198 @@ | ||
export HierarchicalMeasure | ||
|
||
|
||
struct HierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure | ||
# TODO: Document and use FlattenMode | ||
abstract type FlattenMode end | ||
struct NoFlatten <: FlattenMode end | ||
struct AutoFlatten <: FlattenMode end | ||
|
||
|
||
struct HierarchicalMeasure{F,M<:AbstractMeasure,FM<:FlattenMode} <: AbstractMeasure | ||
f::F | ||
m::M | ||
dof_m::Int | ||
flatten_mode::FM | ||
end | ||
|
||
# TODO: Document | ||
const HierarchicalProductMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,NoFlatten} | ||
export HierarchicalProductMeasure | ||
|
||
function HierarchicalMeasure(f, m::AbstractMeasure, ::NoDOF) | ||
throw(ArgumentError("Primary measure in HierarchicalMeasure must have fixed and known DOF")) | ||
end | ||
HierarchicalProductMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, NoFlatten()) | ||
|
||
# TODO: Document | ||
const FlatHierarchicalMeasure{F,M<:AbstractMeasure} = HierarchicalMeasure{F,M,AutoFlatten} | ||
export FlatHierarchicalMeasure | ||
|
||
FlatHierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, AutoFlatten()) | ||
|
||
HierarchicalMeasure(f, m::AbstractMeasure) = FlatHierarchicalMeasure(f, m) | ||
|
||
HierarchicalMeasure(f, m::AbstractMeasure) = HierarchicalMeasure(f, m, dynamic(getdof(m))) | ||
|
||
|
||
function _split_variate(h::HierarchicalMeasure, x) | ||
# TODO: Splitting x will be more complicated in general: | ||
x_primary, x_secondary = x | ||
return (x_primary, x_secondary) | ||
function _split_variate_after(::NoFlatten, μ::AbstractMeasure, x::Tuple{2}) | ||
@assert x isa Tuple{2} | ||
return x[1], x[2] | ||
end | ||
|
||
|
||
function _combine_variates(x_primary, x_secondary) | ||
# TODO: Must offer optional flattening | ||
return (x_primary, x_secondary) | ||
function _split_variate_after(::AutoFlatten, μ::AbstractMeasure, x) | ||
a_test = testvalue(μ) | ||
return _autosplit_variate_after_testvalue(a_test, x) | ||
end | ||
|
||
function _autosplit_variate_after_testvalue(::Any, x) | ||
@assert x isa Tuple{2} | ||
return x[1], x[2] | ||
end | ||
|
||
function localmeasure(h::HierarchicalMeasure, x) | ||
x_primary, x_secondary = _split_variate(h, x) | ||
m_primary = h.m | ||
m_primary_local = localmeasure(m_primary, x_primary) | ||
m_secondary = m.f(x_secondary) | ||
m_secondary_local = localmeasure(m_secondary, x_secondary) | ||
# TODO: Must optionally return a flattened product measure | ||
return productmeasure(m_primary_local, m_secondary_local) | ||
function _autosplit_variate_after_testvalue(a_test::AbstractVector, x::AbstractVector) | ||
n, m = length(eachindex(a_test)), length(eachindex(x)) | ||
# TODO: Use getindex or view? | ||
return x[begin:n], x[begin+n:m] | ||
end | ||
|
||
function _autosplit_variate_after_testvalue(::Tuple{N}, x::Tuple{M}) where {N,M} | ||
return ntuple(i -> x[i], Val(1:N)), ntuple(i -> x[i], Val(N+1:M)) | ||
end | ||
|
||
@inline function insupport(h::HierarchicalMeasure, x) | ||
# Only test primary for efficiency: | ||
x_primary = _split_variate(h, x)[1] | ||
insupport(h.m, x_primary) | ||
@generated function _autosplit_variate_after_testvalue(::NamedTuple{names_a}, x::NamedTuple{names}) where {names_a,names} | ||
# TODO: implement | ||
@assert false | ||
end | ||
|
||
|
||
#!!!!!!! WON'T WORK: Only use primary measure for efficiency: | ||
logdensity_type(h::HierarchicalMeasure{F,M}, ::Type{T}) where {F,M,T} = unstatic(float(logdensity_type(M, T))) | ||
|
||
# Can't implement logdensity_def(::HierarchicalMeasure, x) directly. | ||
_combine_variates(::NoFlatten, a::Any, b::Any) = (a, b) | ||
|
||
|
||
_combine_variates(::AutoFlatten, a::Any, b::Any) = _autoflat_combine_variates(a, b) | ||
|
||
_autoflat_combine_variates(a::Any, b::Any) = (a, b) | ||
|
||
_autoflat_combine_variates(a::AbstractVector, b::AbstractVector) = vcat(a, b) | ||
|
||
_autoflat_combine_variates(a::Tuple, b::Tuple) = (a, b) | ||
|
||
# TODO: Check that names don't overlap: | ||
_autoflat_combine_variates(a::NamedTuple, b::NamedTuple) = merge(a, b) | ||
|
||
|
||
_local_productmeasure(::NoFlatten, μ1, μ2) = productmeasure(μ1, μ2) | ||
|
||
# TODO: _local_productmeasure(::AutoFlatten, μ1, μ2) = productmeasure(μ1, μ2) | ||
# Needs a FlatProductMeasure type. | ||
|
||
function _localmeasure_with_rest(μ::HierarchicalProductMeasure, x) | ||
μ_primary = μ.m | ||
local_primary, x_secondary = _localmeasure_with_rest(μ_primary, x) | ||
μ_secondary = μ.f(x_secondary) | ||
local_secondary, x_rest = _localmeasure_with_rest(μ_secondary, x_secondary) | ||
return _local_productmeasure(μ.flatten_mode, local_primary, local_secondary), x_rest | ||
end | ||
|
||
function _localmeasure_with_rest(μ::AbstractMeasure, x) | ||
x_checked = checked_arg(μ, x) | ||
return localmeasure(μ, x_checked), Fill(zero(eltype(x)), 0) | ||
end | ||
|
||
function localmeasure(μ::HierarchicalProductMeasure, x) | ||
h_local, x_rest = _localmeasure_with_rest(μ, x) | ||
if !isempty(x_rest) | ||
throw(ArgumentError("Variate too long while computing localmeasure of HierarchicalMeasure")) | ||
end | ||
return h_local | ||
end | ||
|
||
# Can't implement getdof(::HierarchicalMeasure) efficiently | ||
|
||
# No way to return a functional base measure: | ||
struct _BaseOfHierarchicalMeasure{F,M<:AbstractMeasure} <: AbstractMeasure end | ||
@inline basemeasure(::HierarchicalMeasure{F,M}) where {F,M} = _BaseOfHierarchicalMeasure{F,M}() | ||
@inline insupport(::HierarchicalMeasure, x) = NoFastInsupport() | ||
|
||
@inline getdof(μ::HierarchicalMeasure) = NoDOF{typeof(μ)}() | ||
|
||
# Bypass `checked_arg`, would require potentially costly evaluation of h.f: | ||
@inline checked_arg(::HierarchicalMeasure, x) = x | ||
|
||
function unsafe_logdensityof(h::HierarchicalMeasure, x) | ||
x_primary, x_secondary = _split_variate(h, x) | ||
h_primary, h_secondary = h.m, h.f(x_secondary) | ||
unsafe_logdensityof(h_primary, x_primary) + logdensityof(h_secondary, x_secondary) | ||
end | ||
rootmeasure(::HierarchicalMeasure) = throw(ArgumentError("root measure is implicit, but can't be instantiated, for HierarchicalMeasure")) | ||
|
||
basemeasure(::HierarchicalMeasure) = throw(ArgumentError("basemeasure is not available for HierarchicalMeasure")) | ||
|
||
logdensity_def(::HierarchicalMeasure, x) = throw(ArgumentError("logdensity_def is not available for HierarchicalMeasure")) | ||
|
||
|
||
# # TODO: Default implementation of unsafe_logdensityof is a bit inefficient | ||
# # for AutoFlatten, since variate will be split in `localmeasure` and then | ||
# # split again in log-density evaluation. Maybe add something like | ||
# function unsafe_logdensityof(h::HierarchicalMeasure, x) | ||
# local_primary, local_secondary, x_primary, x_secondary = ... | ||
# # Need to call full logdensityof for h_secondary since x_secondary hasn't | ||
# # been checked yet: | ||
# unsafe_logdensityof(local_primary, x_primary) + logdensityof(local_secondary, x_secondary) | ||
# end | ||
|
||
|
||
function Base.rand(rng::Random.AbstractRNG, ::Type{T}, h::HierarchicalMeasure) where {T<:Real} | ||
x_primary = rand(rng, T, h.m) | ||
x_secondary = rand(rng, T, h.f(x_primary)) | ||
return _combine_variates(x_primary, x_secondary) | ||
return _combine_variates(h.flatten_mode, x_primary, x_secondary) | ||
end | ||
|
||
|
||
function _split_measure_at(μ::PowerMeasure{M, Tuple{R}}, n::Integer) where {M<:StdMeasure,R} | ||
dof_μ = getdof(μ) | ||
return M()^n, M()^(dof_μ - n) | ||
end | ||
|
||
|
||
function transport_def( | ||
ν::PowerMeasure{M, Tuple{R}}, | ||
μ::HierarchicalMeasure, | ||
x, | ||
) where {M<:StdMeasure,R} | ||
ν_primary, ν_secondary = _split_measure_at(ν, μ.dof_m) | ||
x_primary, x_secondary = _split_variate(μ, x) | ||
function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::HierarchicalMeasure, x) | ||
μ_primary = μ.m | ||
y_primary, x_secondary = _to_std_with_rest(flatten_mode, ν_inner, μ_primary, x) | ||
μ_secondary = μ.f(x_secondary) | ||
y_primary = transport_to(ν_primary, μ_primary, x_primary) | ||
y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) | ||
return vcat(y_primary, y_secondary) | ||
y_secondary, x_rest = _to_std_with_rest(flatten_mode, ν_inner, μ_secondary, x_secondary) | ||
return _combine_variates(μ.flatten_mode, y_primary, y_secondary), x_rest | ||
end | ||
|
||
function _to_std_with_rest(flatten_mode::FlattenMode, ν_inner::StdMeasure, μ::AbstractMeasure, x) | ||
dof_μ = getdof(μ) | ||
x_μ, x_rest = _split_variate_after(flatten_mode, μ, x) | ||
y = transport_to(ν_inner^dof_μ, μ, x_μ) | ||
return y, x_rest | ||
end | ||
|
||
function transport_def(ν::_PowerStdMeasure{1}, μ::HierarchicalMeasure, x) | ||
ν_inner = _get_inner_stdmeasure(ν) | ||
y, x_rest = _to_std_with_rest(ν_inner, μ, x) | ||
if !isempty(x_rest) | ||
throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) | ||
end | ||
return y | ||
end | ||
|
||
|
||
function transport_def( | ||
ν::HierarchicalMeasure, | ||
μ::PowerMeasure{M, Tuple{R}}, | ||
x, | ||
) where {M<:StdMeasure,R} | ||
dof_primary = ν.dof_m | ||
μ_primary, μ_secondary = _split_measure_at(μ, dof_primary) | ||
x_primary, x_secondary = x[begin:begin+dof_primary-1], x[begin+dof_primary:end] | ||
function _from_std_with_rest(ν::HierarchicalMeasure, μ_inner::StdMeasure, x) | ||
ν_primary = ν.m | ||
y_primary = transport_to(ν_primary, μ_primary, x_primary) | ||
y_primary, x_secondary = _from_std_with_rest(ν_primary, μ_inner, x) | ||
ν_secondary = ν.f(y_primary) | ||
y_secondary = transport_to(ν_secondary, μ_secondary, x_secondary) | ||
return _combine_variates(y_primary, y_secondary) | ||
y_secondary, x_rest = _from_std_with_rest(ν_secondary, μ_inner, x_secondary) | ||
return _combine_variates(ν.flatten_mode, y_primary, y_secondary), x_rest | ||
end | ||
|
||
function _from_std_with_rest(ν::AbstractMeasure, μ_inner::StdMeasure, x) | ||
dof_ν = getdof(ν) | ||
len_x = length(eachindex(x)) | ||
|
||
# Since we can't check DOF of original HierarchicalMeasure, we could "run out x" if | ||
# the original x was too short. `transport_to` below will detect this, but better | ||
# throw a more informative exception here: | ||
if len_x < dof_ν | ||
throw(ArgumentError("Variate too short during transport involving HierarchicalMeasure")) | ||
end | ||
|
||
y = transport_to(ν, μ_inner^dof_ν, x[begin:begin+dof_ν-1]) | ||
x_rest = Fill(zero(eltype(x)), dof_ν - len_x) | ||
return y, x_rest | ||
end | ||
|
||
function transport_def(ν::HierarchicalMeasure, μ::_PowerStdMeasure{1}, x) | ||
# Sanity check, should be checked by transport machinery already: | ||
@assert getdof(μ) == length(eachindex(x)) && x isa AbstractVector | ||
μ_inner = _get_inner_stdmeasure(μ) | ||
y, x_rest = _from_std_with_rest(ν, μ_inner, x) | ||
if !isempty(x_rest) | ||
throw(ArgumentError("Variate too long during transport involving HierarchicalMeasure")) | ||
end | ||
return y | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters