Skip to content

Commit

Permalink
Rename local_measure, change insupprt handling
Browse files Browse the repository at this point in the history
  • Loading branch information
oschulz committed Jun 23, 2023
1 parent a77681b commit f084797
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 16 deletions.
6 changes: 3 additions & 3 deletions src/combinators/hierarchical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ function _combine_variates(x_primary, x_secondary)
end


function local_measure(h::HierarchicalMeasure, x)
function localmeasure(h::HierarchicalMeasure, x)
x_primary, x_secondary = _split_variate(h, x)
m_primary = h.m
m_primary_local = local_measure(m_primary, x_primary)
m_primary_local = localmeasure(m_primary, x_primary)
m_secondary = m.f(x_secondary)
m_secondary_local = local_measure(m_secondary, x_secondary)
m_secondary_local = localmeasure(m_secondary, x_secondary)
# TODO: Must optionally return a flattened product measure
return productmeasure(m_primary_local, m_secondary_local)
end
Expand Down
43 changes: 34 additions & 9 deletions src/density-core.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export local_measure
export localmeasure

export logdensityof
export logdensity_rel
Expand All @@ -13,7 +13,7 @@ export density_def


"""
local_measure(m::AbstractMeasure, x)::AbstractMeasure
localmeasure(m::AbstractMeasure, x)::AbstractMeasure
Return a local measure of `m` at `x` which will be `m` itself for many
measures.
Expand All @@ -25,9 +25,9 @@ Note that the resulting measure may not be well defined outside of such a
neighborhood of `x`.
See [`HierarchicalMeasure`](@ref) as an example of a measure where
`local_measure` returns different measures depending on `x`.
`localmeasure` returns different measures depending on `x`.
"""
local_measure(m::AbstractMeasure, x) = m
localmeasure(m::AbstractMeasure, x) = m


"""
Expand Down Expand Up @@ -113,23 +113,42 @@ known to be in the support of both, it can be more efficient to call
`unsafe_logdensity_rel`.
"""
@inline function logdensity_rel::M, ν::N, x::X) where {M,N,X}
inμ = insupport(μ, x)
inν = insupport(ν, x)
return unsafe_logdensity_rel(μ, ν, x, inμ, inν)
end


function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, inν::Bool) where {M,N,X}
T = unstatic(
promote_type(
logdensity_type(μ, X),
logdensity_type(ν, X),
),
)
inμ = insupport(μ, x)
inν = insupport(ν, x)

istrue(inμ) || return convert(T, ifelse(inν, -Inf, NaN))
istrue(inν) || return convert(T, Inf)

μ_local = localmeasure, x)
ν_local = localmeasure(ν, x)
return unsafe_logdensity_rel(μ, ν, x)
end

return unsafe_logdensity_rel(μ_local, ν_local, x)

function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), @nospecialize(::NoFastInsupport)) where {M,N,X}
unsafe_logdensity_rel(μ, ν, x)
end

function _logdensity_rel_impl::M, ν::N, x::X, inμ::Bool, @nospecialize(::NoFastInsupport)) where {M,N,X}
logd = unsafe_logdensity_rel(μ, ν, x)
return istrue(inμ) ? logd : logd * oftypeof(logd, -Inf)
end

function _logdensity_rel_impl::M, ν::N, x::X, @nospecialize(::NoFastInsupport), inν::Bool) where {M,N,X}
logd = unsafe_logdensity_rel(μ, ν, x)
return istrue(inν) ? logd : logd * oftypeof(logd, +Inf)
end


"""
unsafe_logdensity_rel(m1, m2, x)
Expand All @@ -139,6 +158,12 @@ known to be in the support of both `m1` and `m2`.
See also `logdensity_rel`.
"""
@inline function unsafe_logdensity_rel::M, ν::N, x::X) where {M,N,X}
μ_local = localmeasure(μ, x)
ν_local = localmeasure(ν, x)
return _unsafe_logdensity_rel_local(μ_local, ν_local, x)
end

@inline function _unsafe_logdensity_rel_local::M, ν::N, x::X) where {M,N,X}
if static_hasmethod(logdensity_def, Tuple{M,N,X})
return logdensity_def(μ, ν, x)
end
Expand Down
22 changes: 18 additions & 4 deletions src/insupport.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,33 @@
"""
MeasureBase.NoFastInsupport{MU}
Indicates that there is no fast way to compute if a point lies within the
support of measures of type `MU`
"""
struct NoFastInsupport{MU} end


"""
inssupport(m, x)
insupport(m)
`insupport(m,x)` computes whether `x` is in the support of `m`.
`insupport(m,x)` computes whether `x` is in the support of `m` and
returns either a `Bool` or an instance of [`NoFastInsupport`](@ref).
`insupport(m)` returns a function, and satisfies
insupport(m)(x) == insupport(m, x)
`insupport(m)(x) == insupport(m, x)``
"""
function insupport end


"""
MeasureBase.require_insupport(μ, x)::Nothing
Checks if `x` is in the support of distribution/measure `μ`, throws an
`ArgumentError` if not.
Will not throw an exception if `insupport` returns an instance of
[`NoFastInsupport`](@ref).
"""
function require_insupport end

Expand All @@ -24,7 +37,8 @@ function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
end

function require_insupport(μ, x)
if !insupport(μ, x)
r = insupport(μ, x)
if !(r isa NoFastInsupport) || r
throw(ArgumentError("x is not within the support of μ"))
end
return nothing
Expand Down

0 comments on commit f084797

Please sign in to comment.