Skip to content

Address some transform-related naming inconsistencies #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jun 20, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MeasureBase"
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
authors = ["Chad Scherrer <[email protected]> and contributors"]
version = "0.11.0"
version = "0.12.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
4 changes: 2 additions & 2 deletions src/combinators/power.jl
Original file line number Diff line number Diff line change
@@ -114,7 +114,7 @@ end
@inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0)


@propagate_inbounds function checked_var(μ::PowerMeasure, x::AbstractArray{<:Any})
@propagate_inbounds function checked_arg(μ::PowerMeasure, x::AbstractArray{<:Any})
@boundscheck begin
sz_μ = map(length, μ.axes)
sz_x = size(x)
@@ -125,6 +125,6 @@ end
return x
end

function checked_var(μ::PowerMeasure, x::Any)
function checked_arg(μ::PowerMeasure, x::Any)
throw(ArgumentError("Size of variate doesn't match size of power measure"))
end
4 changes: 2 additions & 2 deletions src/combinators/transformedmeasure.jl
Original file line number Diff line number Diff line change
@@ -82,8 +82,8 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof
_pushfwd_dof(MU, R, getdof(ν.origin))
end

# Bypass `checked_var`, would require potentially costly transformation:
@inline checked_var(::PushforwardMeasure, x) = x
# Bypass `checked_arg`, would require potentially costly transformation:
@inline checked_arg(::PushforwardMeasure, x) = x


@inline transport_origin(ν::PushforwardMeasure) = ν.origin
20 changes: 10 additions & 10 deletions src/getdof.jl
Original file line number Diff line number Diff line change
@@ -51,27 +51,27 @@ ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_do


"""
MeasureBase.NoVarCheck{MU,T}
MeasureBase.NoArgCheck{MU,T}

Indicates that there is no way to check of a values of type `T` are
variate of measures of type `MU`.
"""
struct NoVarCheck{MU,T} end
struct NoArgCheck{MU,T} end


"""
MeasureBase.checked_var(μ::MU, x::T)::T
MeasureBase.checked_arg(μ::MU, x::T)::T

Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not,
return `NoVarCheck{MU,T}()` if not check can be performed.
return `NoArgCheck{MU,T}()` if not check can be performed.
"""
function checked_var end
function checked_arg end

# Prevent infinite recursion:
@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T}
@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x)
@propagate_inbounds _default_checked_arg(::Type{MU}, ::MU, ::T) where {MU,T} = NoArgCheck{MU,T}
@propagate_inbounds _default_checked_arg(::Type{MU}, mu_base, x) where MU = checked_arg(mu_base, x)

@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x)
@propagate_inbounds checked_arg(mu::MU, x) where MU = _default_checked_arg(MU, basemeasure(mu), x)

_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback
_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback
8 changes: 4 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -6,14 +6,14 @@ using Reexport

using MeasureBase: basemeasure_depth, proxy
using MeasureBase: insupport, basemeasure_sequence, commonbase
using MeasureBase: transport_to, NoVarTransform
using MeasureBase: transport_to, NoTransport

using DensityInterface: logdensityof
using InverseFunctions: inverse
using ChangesOfVariables: with_logabsdet_jacobian

export test_interface
export test_vartransform
export test_transport
export basemeasure_depth
export proxy
export insupport
@@ -66,13 +66,13 @@ function test_interface(μ::M) where {M}
end


function test_vartransform(ν, μ)
function test_transport(ν, μ)
supertype(x::Real) = Real
supertype(x::AbstractArray{<:Real,N}) where N = AbstractArray{<:Real,N}

@testset "transport_to $μ to $ν" begin
x = rand(μ)
@test !(@inferred(transport_to(ν, μ)(x)) isa NoVarTransform)
@test !(@inferred(transport_to(ν, μ)(x)) isa NoTransport)
f = transport_to(ν, μ)
y = f(x)
@test @inferred(inverse(f)(y)) ≈ x
2 changes: 1 addition & 1 deletion src/primitives/dirac.jl
Original file line number Diff line number Diff line change
@@ -32,7 +32,7 @@ insupport(d::Dirac, x) = x == d.x

@inline getdof(::Dirac) = static(0)

@propagate_inbounds function checked_var(μ::Dirac, x)
@propagate_inbounds function checked_arg(μ::Dirac, x)
@boundscheck insupport(μ, x) || throw(ArgumentError("Invalid variate for measure"))
x
end
4 changes: 2 additions & 2 deletions src/primitives/lebesgue.jl
Original file line number Diff line number Diff line change
@@ -43,8 +43,8 @@ logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf

@inline getdof(::Lebesgue) = static(1)

@inline checked_var(::Lebesgue, x::Real) = x
@inline checked_arg(::Lebesgue, x::Real) = x

@propagate_inbounds function checked_var(::Lebesgue, x::Any)
@propagate_inbounds function checked_arg(::Lebesgue, x::Any)
@boundscheck throw(ArgumentError("Invalid variate type for measure"))
end
82 changes: 41 additions & 41 deletions src/transport.jl
Original file line number Diff line number Diff line change
@@ -41,12 +41,12 @@ to_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}(ν)


"""
struct MeasureBase.NoVarTransform{NU,MU} end
struct MeasureBase.NoTransport{NU,MU} end

Indicates that no transformation from a measure of type `MU` to a measure of
type `NU` could be found.
"""
struct NoVarTransform{NU,MU} end
struct NoTransport{NU,MU} end


"""
@@ -120,10 +120,10 @@ See [`transport_to`](@ref).
function transport_def end

transport_def(::Any, ::Any, x::NoTransformOrigin) = x
transport_def(::Any, ::Any, x::NoVarTransform) = x
transport_def(::Any, ::Any, x::NoTransport) = x

function transport_def(ν, μ, x)
_vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x)
_transport_with_intermediate(ν, _checked_transport_origin(ν), _checked_transport_origin(μ), μ, x)
end


@@ -132,92 +132,92 @@ function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU
throw(ArgumentError("Measure of type $MU and its origin must have separate types"))
end

@inline function _checked_vartransform_origin(μ::MU) where MU
@inline function _checked_transport_origin(μ::MU) where MU
μ_o = transport_origin(μ)
_origin_must_have_separate_type(MU, μ_o)
end


function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x)
function _transport_with_intermediate(ν, ν_o, μ_o, μ, x)
x_o = to_origin(μ, x)
# If μ is a pushforward then checked_var may have been bypassed, so check now:
y_o = transport_def(ν_o, μ_o, checked_var(μ_o, x_o))
# If μ is a pushforward then checked_arg may have been bypassed, so check now:
y_o = transport_def(ν_o, μ_o, checked_arg(μ_o, x_o))
y = from_origin(ν, y_o)
return y
end

function _vartransform_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x)
function _transport_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x)
y_o = transport_def(ν_o, μ, x)
y = from_origin(ν, y_o)
return y
end

function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x)
function _transport_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x)
x_o = to_origin(μ, x)
# If μ is a pushforward then checked_var may have been bypassed, so check now:
y = transport_def(ν, μ_o, checked_var(μ_o, x_o))
# If μ is a pushforward then checked_arg may have been bypassed, so check now:
y = transport_def(ν, μ_o, checked_arg(μ_o, x_o))
return y
end

function _vartransform_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x)
_vartransform_with_intermediate(ν, _vartransform_intermediate(ν, μ), μ, x)
function _transport_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x)
_transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x)
end


@inline _vartransform_intermediate(ν, μ) = _vartransform_intermediate(getdof(ν), getdof(μ))
@inline _vartransform_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ
@inline _vartransform_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform()
@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ))
@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ
@inline _transport_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform()

function _vartransform_with_intermediate(ν, m, μ, x)
function _transport_with_intermediate(ν, m, μ, x)
z = transport_def(m, μ, x)
y = transport_def(ν, m, z)
return y
end

# Prevent infinite recursion in case vartransform_intermediate doesn't change type:
@inline _vartransform_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}()
@inline _vartransform_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}()
@inline _transport_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}()
@inline _transport_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}()


"""
struct VarTransformation <: Function
struct TransportFunction <: Function

Transforms a variate from one measure to a variate of another.

In general `VarTransformation` should not be called directly, call
In general `TransportFunction` should not be called directly, call
[`transport_to`](@ref) instead.
"""
struct VarTransformation{NU,MU} <: Function
struct TransportFunction{NU,MU} <: Function
ν::NU
μ::MU

function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU}
function TransportFunction{NU,MU}(ν::NU, μ::MU) where {NU,MU}
return new{NU,MU}(ν, μ)
end

function VarTransformation(ν::NU, μ::MU) where {NU,MU}
function TransportFunction(ν::NU, μ::MU) where {NU,MU}
check_dof(ν, μ)
return new{NU,MU}(ν, μ)
end
end

@inline transport_to(ν, μ) = VarTransformation(ν, μ)
@inline transport_to(ν, μ) = TransportFunction(ν, μ)

function Base.:(==)(a::VarTransformation, b::VarTransformation)
function Base.:(==)(a::TransportFunction, b::TransportFunction)
return a.ν == b.ν && a.μ == b.μ
end


Base.@propagate_inbounds function (f::VarTransformation)(x)
return transport_def(f.ν, f.μ, checked_var(f.μ, x))
Base.@propagate_inbounds function (f::TransportFunction)(x)
return transport_def(f.ν, f.μ, checked_arg(f.μ, x))
end

@inline function InverseFunctions.inverse(f::VarTransformation{NU,MU}) where {NU,MU}
return VarTransformation{MU,NU}(f.μ, f.ν)
@inline function InverseFunctions.inverse(f::TransportFunction{NU,MU}) where {NU,MU}
return TransportFunction{MU,NU}(f.μ, f.ν)
end


function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x)
function ChangesOfVariables.with_logabsdet_jacobian(f::TransportFunction, x)
y = f(x)
logpdf_src = logdensityof(f.μ, x)
logpdf_trg = logdensityof(f.ν, y)
@@ -228,26 +228,26 @@ function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x)
end


Base.:(∘)(::typeof(identity), f::VarTransformation) = f
Base.:(∘)(f::VarTransformation, ::typeof(identity)) = f
Base.:(∘)(::typeof(identity), f::TransportFunction) = f
Base.:(∘)(f::TransportFunction, ::typeof(identity)) = f

function Base.:∘(outer::VarTransformation, inner::VarTransformation)
function Base.:∘(outer::TransportFunction, inner::TransportFunction)
if !(outer.μ == inner.ν || isequal(outer.μ, inner.ν) || outer.μ ≈ inner.ν)
throw(ArgumentError("Cannot compose VarTransformation if source of outer doesn't equal target of inner."))
throw(ArgumentError("Cannot compose TransportFunction if source of outer doesn't equal target of inner."))
end
return VarTransformation(outer.ν, inner.μ)
return TransportFunction(outer.ν, inner.μ)
end


function Base.show(io::IO, f::VarTransformation)
function Base.show(io::IO, f::TransportFunction)
print(io, Base.typename(typeof(f)).name, "(")
show(io, f.ν)
print(io, ", ")
show(io, f.μ)
print(io, ")")
end

Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f)
Base.show(io::IO, M::MIME"text/plain", f::TransportFunction) = show(io, f)


"""
@@ -262,7 +262,7 @@ abstract type TransformVolCorr end
NoVolCorr()

Indicate that density calculations should ignore the volume element of
var transformations. Should only be used in special cases in which
variate transformations. Should only be used in special cases in which
the volume element has already been taken into account in a different
way.
"""
@@ -272,7 +272,7 @@ struct NoVolCorr <: TransformVolCorr end
WithVolCorr()

Indicate that density calculations should take the volume element of
var transformations into account (typically via the
variate transformations into account (typically via the
log-abs-det-Jacobian of the transform).
"""
struct WithVolCorr <: TransformVolCorr end
14 changes: 7 additions & 7 deletions test/getdof.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using Test

using MeasureBase: getdof, check_dof, checked_var
using MeasureBase: getdof, check_dof, checked_arg
using MeasureBase: StdUniform, StdExponential, StdLogistic
using ChainRulesTestUtils: test_rrule
using Static: static
@@ -18,18 +18,18 @@ using Static: static
@test_throws ArgumentError check_dof(μ2, μ0)
test_rrule(check_dof, μ0, StdUniform())

@test @inferred(checked_var(μ0, x0)) === x0
@test_throws ArgumentError checked_var(μ0, x2)
test_rrule(checked_var, μ0, x0)
@test @inferred(checked_arg(μ0, x0)) === x0
@test_throws ArgumentError checked_arg(μ0, x2)
test_rrule(checked_arg, μ0, x0)

@test @inferred(getdof(μ2)) == 6
@test (check_dof(μ2, StdUniform()^(1,6,1)); true)
@test_throws ArgumentError check_dof(μ2, μ0)
test_rrule(check_dof, μ2, StdUniform()^(1,6,1))

@test @inferred(checked_var(μ2, x2)) === x2
@test_throws ArgumentError checked_var(μ2, x0)
test_rrule(checked_var, μ2, x2)
@test @inferred(checked_arg(μ2, x2)) === x2
@test_throws ArgumentError checked_arg(μ2, x0)
test_rrule(checked_arg, μ2, x2)

@test @inferred(getdof((StdExponential()^3)^(static(0),static(0)))) === static(0)
end
24 changes: 12 additions & 12 deletions test/transport.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,31 @@
using Test

using MeasureBase.Interface: transport_to, test_vartransform
using MeasureBase.Interface: transport_to, test_transport
using MeasureBase: StdUniform, StdExponential, StdLogistic
using MeasureBase: Dirac


@testset "transport_to" begin
for μ0 in [StdUniform(), StdExponential(), StdLogistic()], ν0 in [StdUniform(), StdExponential(), StdLogistic()]
@testset "transport_to (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin
test_vartransform(ν0, μ0)
test_vartransform(2.2 * ν0, 3 * μ0)
test_vartransform(ν0, μ0^1)
test_vartransform(ν0^1, μ0)
test_vartransform(ν0^3, μ0^3)
test_vartransform(ν0^(2,3,2), μ0^(3,4))
test_vartransform(2.2 * ν0^(2,3,2), 3 * μ0^(3,4))
test_transport(ν0, μ0)
test_transport(2.2 * ν0, 3 * μ0)
test_transport(ν0, μ0^1)
test_transport(ν0^1, μ0)
test_transport(ν0^3, μ0^3)
test_transport(ν0^(2,3,2), μ0^(3,4))
test_transport(2.2 * ν0^(2,3,2), 3 * μ0^(3,4))
@test_throws ArgumentError transport_to(ν0, μ0)(rand(μ0^12))
@test_throws ArgumentError transport_to(ν0^3, μ0^3)(rand(μ0^(3,4)))
end
end

@testset "transfrom from/to Dirac" begin
μ = Dirac(4.2)
test_vartransform(StdExponential()^0, μ)
test_vartransform(StdExponential()^(0,0,0), μ)
test_vartransform(μ, StdExponential()^static(0))
test_vartransform(μ, StdExponential()^(static(0),static(0)))
test_transport(StdExponential()^0, μ)
test_transport(StdExponential()^(0,0,0), μ)
test_transport(μ, StdExponential()^static(0))
test_transport(μ, StdExponential()^(static(0),static(0)))
@test_throws ArgumentError transport_to(StdExponential()^1, μ)
@test_throws ArgumentError transport_to(μ, StdExponential()^1)
end