diff --git a/Project.toml b/Project.toml index a80ce47a..fbce68d7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,17 @@ name = "MeasureBase" uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14" authors = ["Chad Scherrer and contributors"] -version = "0.10.0" +version = "0.11.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" @@ -24,11 +27,14 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [compat] +ChainRulesCore = "1" +ChangesOfVariables = "0.1.3" Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" FillArrays = "0.12, 0.13" IfElse = "0.1" +InverseFunctions = "0.1.7" IrrationalConstants = "0.1" LogExpFunctions = "0.3" LogarithmicNumbers = "1" @@ -42,6 +48,7 @@ julia = "1.3" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" [targets] -test = ["Aqua"] +test = ["Aqua", "ChainRulesTestUtils"] diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index fb304c84..1905719a 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -1,5 +1,7 @@ module MeasureBase +using Base: @propagate_inbounds + using Random import Random: rand! import Random: gentype @@ -11,6 +13,9 @@ import DensityInterface: densityof import DensityInterface: DensityKind using DensityInterface +using InverseFunctions +using ChangesOfVariables + import Base.iterate import ConstructionBase using ConstructionBase: constructorof @@ -18,6 +23,7 @@ using ConstructionBase: constructorof using PrettyPrinting const Pretty = PrettyPrinting +using ChainRulesCore using FillArrays using Static @@ -32,20 +38,11 @@ export logdensity_def export basemeasure export basekernel export productmeasure - -""" - inssupport(m, x) - insupport(m) - -`insupport(m,x)` computes whether `x` is in the support of `m`. - -`insupport(m)` returns a function, and satisfies - - insupport(m)(x) == insupport(m, x) -""" -function insupport end - export insupport +export getdof +export transport_to + +include("insupport.jl") abstract type AbstractMeasure end @@ -63,7 +60,7 @@ gentype(μ::AbstractMeasure) = typeof(testvalue(μ)) # gentype(μ::AbstractMeasure) = gentype(basemeasure(μ)) using NaNMath -using LogExpFunctions: logsumexp +using LogExpFunctions: logsumexp, logistic, logit @deprecate instance_type(x) Core.Typeof(x) false @@ -94,6 +91,8 @@ using Compat using IrrationalConstants +include("getdof.jl") +include("transport.jl") include("schema.jl") include("splat.jl") include("proxies.jl") @@ -125,9 +124,9 @@ include("combinators/powerweighted.jl") include("combinators/conditional.jl") include("standard/stdmeasure.jl") -include("standard/stdnormal.jl") include("standard/stduniform.jl") include("standard/stdexponential.jl") +include("standard/stdlogistic.jl") include("rand.jl") diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 86dc2379..db5b72c9 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -85,6 +85,13 @@ end end end +@inline function logdensity_def( + d::PowerMeasure{M,NTuple{N, Base.OneTo{StaticInt{0}}}}, + x, +) where {M,N} + static(0.0) +end + @inline function insupport(μ::PowerMeasure, x) p = μ.parent all(x) do xj @@ -100,3 +107,24 @@ end dynamic(insupport(p, xj)) end end + + +@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) + +@inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0) + + +@propagate_inbounds function checked_var(μ::PowerMeasure, x::AbstractArray{<:Any}) + @boundscheck begin + sz_μ = map(length, μ.axes) + sz_x = size(x) + if sz_μ != sz_x + throw(ArgumentError("Size of variate doesn't match size of power measure")) + end + end + return x +end + +function checked_var(μ::PowerMeasure, x::Any) + throw(ArgumentError("Size of variate doesn't match size of power measure")) +end diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 45beb316..ee471861 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -13,3 +13,94 @@ function params(::AbstractTransformedMeasure) end function paramnames(::AbstractTransformedMeasure) end function parent(::AbstractTransformedMeasure) end + + +export PushforwardMeasure + +""" + struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward + f :: FF + inv_f :: IF + origin :: MU + volcorr :: VC + end +""" +struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward + f::FF + inv_f::IF + origin::M + volcorr::VC +end + +gettransform(ν::PushforwardMeasure) = ν.f +parent(ν::PushforwardMeasure) = ν.origin + + +function Pretty.tile(ν::PushforwardMeasure) + Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure) +end + + +@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M} + x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y) + logd_orig = logdensity_def(ν.origin, x_orig) + logd = float(logd_orig + inv_ladj) + neginf = oftype(logd, -Inf) + return ifelse( + # Zero density wins against infinite volume: + (isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) || + # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ? + # Return constant -Inf to prevent problems with ForwardDiff: + (isfinite(logd_orig) && (inv_ladj == -Inf)), + neginf, + logd + ) +end + +@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M} + x_orig = to_origin(ν, y) + return logdensity_def(ν.origin, x_orig) +end + + +insupport(ν::PushforwardMeasure, y) = insupport(transport_origin(ν), to_origin(ν, y)) + +testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(ν))) + +@inline function basemeasure(ν::PushforwardMeasure) + PushforwardMeasure(ν.f, ν.inv_f, basemeasure(transport_origin(ν)), NoVolCorr()) +end + + +_pushfwd_dof(::Type{MU}, ::Type, dof) where MU = NoDOF{MU}() +_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof + +# Assume that DOF are preserved if with_logabsdet_jacobian is functional: +@inline function getdof(ν::MU) where {MU<:PushforwardMeasure} + T = Core.Compiler.return_type(testvalue, Tuple{typeof(ν.origin)}) + R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof(ν.f), T}) + _pushfwd_dof(MU, R, getdof(ν.origin)) +end + +# Bypass `checked_var`, would require potentially costly transformation: +@inline checked_var(::PushforwardMeasure, x) = x + + +@inline transport_origin(ν::PushforwardMeasure) = ν.origin +@inline from_origin(ν::PushforwardMeasure, x) = ν.f(x) +@inline to_origin(ν::PushforwardMeasure, y) = ν.inv_f(y) + +function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T + return from_origin(ν, rand(rng, T, transport_origin(ν))) +end + + +export pushfwd + +""" + pushfwd(f, μ, volcorr = WithVolCorr()) + +Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) +from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`. +""" +pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr) diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index b31d1939..aef9dbee 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -48,3 +48,7 @@ Base.:*(m::AbstractMeasure, k::Real) = k * m gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) + +transport_origin(ν::WeightedMeasure) = ν.base +to_origin(::WeightedMeasure, y) = y +from_origin(::WeightedMeasure, x) = x diff --git a/src/getdof.jl b/src/getdof.jl new file mode 100644 index 00000000..b0c8d864 --- /dev/null +++ b/src/getdof.jl @@ -0,0 +1,77 @@ +""" + MeasureBase.NoDOF{MU} + +Indicates that there is no way to compute degrees of freedom of a measure +of type `MU` with the given information, e.g. because the DOF are not +a global property of the measure. +""" +struct NoDOF{MU} end + + +""" + getdof(μ) + +Returns the effective number of degrees of freedom of variates of +measure `μ`. + +The effective NDOF my differ from the length of the variates. For example, +the effective NDOF for a Dirichlet distribution with variates of length `n` +is `n - 1`. + +Also see [`check_dof`](@ref). +""" +function getdof end + +# Prevent infinite recursion: +@inline _default_getdof(::Type{MU}, ::MU) where MU = NoDOF{MU} +@inline _default_getdof(::Type{MU}, mu_base) where MU = getdof(mu_base) + +@inline getdof(μ::MU) where MU = _default_getdof(MU, basemeasure(μ)) + + +""" + MeasureBase.check_dof(ν, μ)::Nothing + +Check if `ν` and `μ` have the same effective number of degrees of freedom +according to [`MeasureBase.getdof`](@ref). +""" +function check_dof end + +function check_dof(ν, μ) + n_ν = getdof(ν) + n_μ = getdof(μ) + if n_ν != n_μ + throw(ArgumentError("Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF")) + end + return nothing +end + +_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent() +ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback + + +""" + MeasureBase.NoVarCheck{MU,T} + +Indicates that there is no way to check of a values of type `T` are +variate of measures of type `MU`. +""" +struct NoVarCheck{MU,T} end + + +""" + MeasureBase.checked_var(μ::MU, x::T)::T + +Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not, +return `NoVarCheck{MU,T}()` if not check can be performed. +""" +function checked_var end + +# Prevent infinite recursion: +@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T} +@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x) + +@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) + +_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ +ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback diff --git a/src/insupport.jl b/src/insupport.jl new file mode 100644 index 00000000..7d407d0d --- /dev/null +++ b/src/insupport.jl @@ -0,0 +1,32 @@ +""" + inssupport(m, x) + insupport(m) + +`insupport(m,x)` computes whether `x` is in the support of `m`. + +`insupport(m)` returns a function, and satisfies + +insupport(m)(x) == insupport(m, x) +""" +function insupport end + + +""" + MeasureBase.require_insupport(μ, x)::Nothing + +Checks if `x` is in the support of distribution/measure `μ`, throws an +`ArgumentError` if not. +""" +function require_insupport end + +_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent() +function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) + return require_insupport(μ, x), _require_insupport_pullback +end + +function require_insupport(μ, x) + if !insupport(μ, x) + throw(ArgumentError("x is not within the support of μ")) + end + return nothing +end diff --git a/src/interface.jl b/src/interface.jl index f6207fad..d27ee505 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -6,8 +6,14 @@ using Reexport using MeasureBase: basemeasure_depth, proxy using MeasureBase: insupport, basemeasure_sequence, commonbase +using MeasureBase: transport_to, NoVarTransform + +using DensityInterface: logdensityof +using InverseFunctions: inverse +using ChangesOfVariables: with_logabsdet_jacobian export test_interface +export test_vartransform export basemeasure_depth export proxy export insupport @@ -59,4 +65,26 @@ function test_interface(μ::M) where {M} end end + +function test_vartransform(ν, μ) + supertype(x::Real) = Real + supertype(x::AbstractArray{<:Real,N}) where N = AbstractArray{<:Real,N} + + @testset "transport_to $μ to $ν" begin + x = rand(μ) + @test !(@inferred(transport_to(ν, μ)(x)) isa NoVarTransform) + f = transport_to(ν, μ) + y = f(x) + @test @inferred(inverse(f)(y)) ≈ x + @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{supertype(y),Real} + @test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{supertype(x),Real} + y2, ladj_fwd = with_logabsdet_jacobian(f, x) + x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y) + @test x ≈ x2 + @test y ≈ y2 + @test ladj_fwd ≈ - ladj_inv + @test ladj_fwd ≈ logdensityof(μ, x) - logdensityof(ν, y) + end +end + end # module Interface diff --git a/src/primitives/dirac.jl b/src/primitives/dirac.jl index 57067ec7..2575382c 100644 --- a/src/primitives/dirac.jl +++ b/src/primitives/dirac.jl @@ -29,3 +29,10 @@ dirac(d::AbstractMeasure) = Dirac(rand(d)) testvalue(d::Dirac) = d.x insupport(d::Dirac, x) = x == d.x + +@inline getdof(::Dirac) = static(0) + +@propagate_inbounds function checked_var(μ::Dirac, x) + @boundscheck insupport(μ, x) || throw(ArgumentError("Invalid variate for measure")) + x +end diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index f7a00917..7049e88a 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -40,3 +40,11 @@ insupport(::Lebesgue{RealNumbers}, ::Real) = true logdensity_def(::LebesgueMeasure, ::CountingMeasure, x) = -Inf logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf + +@inline getdof(::Lebesgue) = static(1) + +@inline checked_var(::Lebesgue, x::Real) = x + +@propagate_inbounds function checked_var(::Lebesgue, x::Any) + @boundscheck throw(ArgumentError("Invalid variate type for measure")) +end diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index a29794a8..76442639 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -7,6 +7,8 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline logdensity_def(::StdExponential, x) = -x @inline basemeasure(::StdExponential) = Lebesgue() -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} - randexp(rng, T) -end +@inline transport_def(::StdUniform, μ::StdExponential, x) = - expm1(-x) +@inline transport_def(::StdExponential, μ::StdUniform, x) = - log1p(-x) + +Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} = randexp(rng, T) + diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl new file mode 100644 index 00000000..922cb2c6 --- /dev/null +++ b/src/standard/stdlogistic.jl @@ -0,0 +1,13 @@ +struct StdLogistic <: StdMeasure end + +export StdLogistic + +@inline insupport(d::StdLogistic, x) = true + +@inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u)) +@inline basemeasure(::StdLogistic) = Lebesgue() + +@inline transport_def(::StdUniform, μ::StdLogistic, x) = logistic(x) +@inline transport_def(::StdLogistic, μ::StdUniform, x) = logit(x) + +@inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index ef13e837..0dba932b 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -1,5 +1,45 @@ abstract type StdMeasure <: AbstractMeasure end StdMeasure(::typeof(rand)) = StdUniform() -StdMeasure(::typeof(randn)) = StdNormal() StdMeasure(::typeof(randexp)) = StdExponential() + + +@inline check_dof(::StdMeasure, ::StdMeasure) = nothing + + +@inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x + +function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) + return transport_def(ν, μ.parent, only(x)) +end + +function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) + return Fill(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)...) +end + +function transport_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x) + return transport_to(ν.parent, μ.parent).(x) +end + +function transport_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M} + return reshape(transport_to(ν.parent, μ.parent).(x), map(length, ν.axes)...) +end + + +# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}): + +_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M() +_std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof +_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) + +MeasureBase.transport_to(::Type{NU}, μ) where {NU<:StdMeasure} = transport_to(_std_measure_for(NU, μ), μ) +MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} = transport_to(ν, _std_measure_for(MU, ν)) + + +# Transform between standard measures and Dirac: + +@inline transport_def(ν::Dirac, ::PowerMeasure{<:MeasureBase.StdMeasure}, ::Any) = ν.x + +@inline function transport_def(ν::PowerMeasure{<:MeasureBase.StdMeasure}, ::Dirac, ::Any) + Zeros{Bool}(map(_ -> 0, ν.axes)) +end diff --git a/src/standard/stdnormal.jl b/src/standard/stdnormal.jl deleted file mode 100644 index a5beffaf..00000000 --- a/src/standard/stdnormal.jl +++ /dev/null @@ -1,11 +0,0 @@ -struct StdNormal <: StdMeasure end - -export StdNormal - -insupport(d::StdNormal, x) = true -insupport(d::StdNormal) = Returns(true) - -@inline logdensity_def(::StdNormal, x) = -x^2 / 2 -@inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), Lebesgue(ℝ)) - -Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdNormal) where {T} = randn(rng, T) diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 0bc0263f..d29dce80 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -7,4 +7,4 @@ insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) @inline basemeasure(::StdUniform) = Lebesgue() -Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = randn(rng, T) +Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = rand(rng, T) diff --git a/src/transport.jl b/src/transport.jl new file mode 100644 index 00000000..e3970f55 --- /dev/null +++ b/src/transport.jl @@ -0,0 +1,278 @@ +""" + struct MeasureBase.NoTransformOrigin{NU} + +Indicates that no (default) pullback measure is available for measures of +type `NU`. + +See [`MeasureBase.transport_origin`](@ref). +""" +struct NoTransformOrigin{NU} end + + +""" + MeasureBase.transport_origin(ν) + +Default measure to pullback to resp. pushforward from when transforming +between `ν` and another measure. +""" +function transport_origin end + +transport_origin(ν::NU) where NU = NoTransformOrigin{NU}() + + +""" + MeasureBase.from_origin(ν, x) + +Push `x` from `MeasureBase.transport_origin(μ)` forward to `ν`. +""" +function from_origin end + +from_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}() + + +""" + MeasureBase.to_origin(ν, y) + +Pull `y` from `ν` back to `MeasureBase.transport_origin(ν)`. +""" +function to_origin end + +to_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}(ν) + + +""" + struct MeasureBase.NoVarTransform{NU,MU} end + +Indicates that no transformation from a measure of type `MU` to a measure of +type `NU` could be found. +""" +struct NoVarTransform{NU,MU} end + + +""" + f = transport_to(ν, μ) + +Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function) +`f` that transforms a value `x` distributed according to measure `μ` to +a value `y = f(x)` distributed according to a measure `ν`. + +The [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) +from `μ` under `f` is is equivalent to `ν`. + +If terms of random values this implies that `f(rand(μ))` is equivalent to +`rand(ν)` (if `rand(μ)` and `rand(ν)` are supported). + +The resulting function `f` should support +`ChangesOfVariables.with_logabsdet_jacobian(f, x)` if mathematically well-defined, +so that densities of `ν` can be derived from densities of `μ` via `f` (using +appropriate base measures). + +Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from +`μ` to `ν` can be found. + +To add transformation rules for a measure type `MyMeasure`, specialize + +* `MeasureBase.transport_def(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` +* `MeasureBase.transport_def(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` + +and/or + +* `MeasureBase.transport_origin(ν::MyMeasure) = SomeMeasure(...)` +* `MeasureBase.from_origin(μ::MyMeasure, x) = y` +* `MeasureBase.to_origin(μ::MyMeasure, y) = x` + +and ensure `MeasureBase.getdof(μ::MyMeasure)` is defined correctly. + +A standard measure type like `StdUniform`, `StdExponential` or +`StdLogistic` may also be used as the source or target of the transform: + +```julia +f_to_uniform(StdUniform, μ) +f_to_uniform(ν, StdUniform) +``` + +Depending on [`getdof(μ)`](@ref) (resp. `ν`), an instance of the standard +distribution itself or a power of it (e.g. `StdUniform()` or +`StdUniform()^dof`) will be chosen as the transformation partner. +""" +function transport_to end + + +""" + transport_def(ν, μ, x) + +Transforms a value `x` distributed according to `μ` to a value `y` distributed +according to `ν`. + +If no specialized `transport_def(::MU, ::NU, ...)` is available then +the default implementation of`transport_def(ν, μ, x)` uses the following +strategy: + +* Evaluate [`transport_origin`](@ref) for μ and ν. Transform between + each and it's origin, if available, and use the origin(s) as intermediate + measures for another transformation. + +* If all else fails, try to transform from μ to a standard multivariate + uniform measure and then to ν. + +See [`transport_to`](@ref). +""" +function transport_def end + +transport_def(::Any, ::Any, x::NoTransformOrigin) = x +transport_def(::Any, ::Any, x::NoVarTransform) = x + +function transport_def(ν, μ, x) + _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) +end + + +@inline _origin_must_have_separate_type(::Type{MU}, μ_o) where MU = μ_o +function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU + throw(ArgumentError("Measure of type $MU and its origin must have separate types")) +end + +@inline function _checked_vartransform_origin(μ::MU) where MU + μ_o = transport_origin(μ) + _origin_must_have_separate_type(MU, μ_o) +end + + +function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) + x_o = to_origin(μ, x) + # If μ is a pushforward then checked_var may have been bypassed, so check now: + y_o = transport_def(ν_o, μ_o, checked_var(μ_o, x_o)) + y = from_origin(ν, y_o) + return y +end + +function _vartransform_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x) + y_o = transport_def(ν_o, μ, x) + y = from_origin(ν, y_o) + return y +end + +function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) + x_o = to_origin(μ, x) + # If μ is a pushforward then checked_var may have been bypassed, so check now: + y = transport_def(ν, μ_o, checked_var(μ_o, x_o)) + return y +end + +function _vartransform_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x) + _vartransform_with_intermediate(ν, _vartransform_intermediate(ν, μ), μ, x) +end + + +@inline _vartransform_intermediate(ν, μ) = _vartransform_intermediate(getdof(ν), getdof(μ)) +@inline _vartransform_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ +@inline _vartransform_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform() + +function _vartransform_with_intermediate(ν, m, μ, x) + z = transport_def(m, μ, x) + y = transport_def(ν, m, z) + return y +end + +# Prevent infinite recursion in case vartransform_intermediate doesn't change type: +@inline _vartransform_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}() +@inline _vartransform_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}() + + +""" + struct VarTransformation <: Function + +Transforms a variate from one measure to a variate of another. + +In general `VarTransformation` should not be called directly, call +[`transport_to`](@ref) instead. +""" +struct VarTransformation{NU,MU} <: Function + ν::NU + μ::MU + + function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU} + return new{NU,MU}(ν, μ) + end + + function VarTransformation(ν::NU, μ::MU) where {NU,MU} + check_dof(ν, μ) + return new{NU,MU}(ν, μ) + end +end + +@inline transport_to(ν, μ) = VarTransformation(ν, μ) + +function Base.:(==)(a::VarTransformation, b::VarTransformation) + return a.ν == b.ν && a.μ == b.μ +end + + +Base.@propagate_inbounds function (f::VarTransformation)(x) + return transport_def(f.ν, f.μ, checked_var(f.μ, x)) +end + +@inline function InverseFunctions.inverse(f::VarTransformation{NU,MU}) where {NU,MU} + return VarTransformation{MU,NU}(f.μ, f.ν) +end + + +function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x) + y = f(x) + logpdf_src = logdensityof(f.μ, x) + logpdf_trg = logdensityof(f.ν, y) + ladj = logpdf_src - logpdf_trg + # If logpdf_src and logpdf_trg are -Inf setting lafj to zero is safe: + fixed_ladj = logpdf_src == logpdf_trg == -Inf ? zero(ladj) : ladj + return y, fixed_ladj +end + + +Base.:(∘)(::typeof(identity), f::VarTransformation) = f +Base.:(∘)(f::VarTransformation, ::typeof(identity)) = f + +function Base.:∘(outer::VarTransformation, inner::VarTransformation) + if !(outer.μ == inner.ν || isequal(outer.μ, inner.ν) || outer.μ ≈ inner.ν) + throw(ArgumentError("Cannot compose VarTransformation if source of outer doesn't equal target of inner.")) + end + return VarTransformation(outer.ν, inner.μ) +end + + +function Base.show(io::IO, f::VarTransformation) + print(io, Base.typename(typeof(f)).name, "(") + show(io, f.ν) + print(io, ", ") + show(io, f.μ) + print(io, ")") +end + +Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f) + + +""" + abstract type TransformVolCorr + +Provides control over density correction by transform volume element. +Either [`NoVolCorr()`](@ref) or [`WithVolCorr()`](@ref) +""" +abstract type TransformVolCorr end + +""" + NoVolCorr() + +Indicate that density calculations should ignore the volume element of +var transformations. Should only be used in special cases in which +the volume element has already been taken into account in a different +way. +""" +struct NoVolCorr <: TransformVolCorr end + +""" + WithVolCorr() + +Indicate that density calculations should take the volume element of +var transformations into account (typically via the +log-abs-det-Jacobian of the transform). +""" +struct WithVolCorr <: TransformVolCorr end diff --git a/test/combinators/transformedmeasure.jl b/test/combinators/transformedmeasure.jl new file mode 100644 index 00000000..d4a9fcaa --- /dev/null +++ b/test/combinators/transformedmeasure.jl @@ -0,0 +1,21 @@ +using Test + +using MeasureBase: pushfwd, StdUniform, StdExponential, StdLogistic +using MeasureBase: pushfwd, PushforwardMeasure +using MeasureBase: transport_to +using Statistics: var +using DensityInterface: logdensityof + +@testset "transformedmeasure.jl" begin + μ = StdUniform() + @test @inferred(pushfwd((-) ∘ log1p ∘ (-), μ)) isa PushforwardMeasure + ν = pushfwd((-) ∘ log1p ∘ (-), μ) + ν_ref = StdExponential() + + y = rand(ν_ref) + @test @inferred(logdensityof(ν, y)) ≈ logdensityof(ν_ref, y) + + @test isapprox(var(rand(ν^(10^5))), 1, rtol = 0.05) + + @test transport_to(StdLogistic(), ν)(y) ≈ transport_to(StdLogistic(), ν)(y) +end diff --git a/test/getdof.jl b/test/getdof.jl new file mode 100644 index 00000000..c8d3953b --- /dev/null +++ b/test/getdof.jl @@ -0,0 +1,35 @@ +using Test + +using MeasureBase: getdof, check_dof, checked_var +using MeasureBase: StdUniform, StdExponential, StdLogistic +using ChainRulesTestUtils: test_rrule +using Static: static + + +@testset "getdof" begin + μ0 = StdExponential() + x0 = rand(μ0) + + μ2 = StdExponential()^(2,3) + x2 = rand(μ2) + + @test @inferred(getdof(μ0)) === static(1) + @test (check_dof(μ0, StdUniform()); true) + @test_throws ArgumentError check_dof(μ2, μ0) + test_rrule(check_dof, μ0, StdUniform()) + + @test @inferred(checked_var(μ0, x0)) === x0 + @test_throws ArgumentError checked_var(μ0, x2) + test_rrule(checked_var, μ0, x0) + + @test @inferred(getdof(μ2)) == 6 + @test (check_dof(μ2, StdUniform()^(1,6,1)); true) + @test_throws ArgumentError check_dof(μ2, μ0) + test_rrule(check_dof, μ2, StdUniform()^(1,6,1)) + + @test @inferred(checked_var(μ2, x2)) === x2 + @test_throws ArgumentError checked_var(μ2, x0) + test_rrule(checked_var, μ2, x2) + + @test @inferred(getdof((StdExponential()^3)^(static(0),static(0)))) === static(0) +end diff --git a/test/runtests.jl b/test/runtests.jl index c3128102..22dfb993 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,13 +37,13 @@ test_measures = [ Dirac(0) + Dirac(1) Dirac(0.0) + Lebesgue(ℝ) SpikeMixture(Lebesgue(ℝ), 0.2) - StdNormal() - StdNormal()^3 - StdNormal()^(2, 3) - 3 * StdNormal() - 0.2 * StdNormal() + 0.8 * Dirac(0.0) - Dirac(0.0) + StdNormal() - SpikeMixture(StdNormal(), 0.2) + StdLogistic() + StdLogistic()^3 + StdLogistic()^(2, 3) + 3 * StdLogistic() + 0.2 * StdLogistic() + 0.8 * Dirac(0.0) + Dirac(0.0) + StdLogistic() + SpikeMixture(StdLogistic(), 0.2) StdUniform() StdUniform()^3 StdUniform()^(2, 3) @@ -233,4 +233,8 @@ end # end end +include("getdof.jl") +include("transport.jl") + include("combinators/weighted.jl") +include("combinators/transformedmeasure.jl") diff --git a/test/transport.jl b/test/transport.jl new file mode 100644 index 00000000..0cb6e55a --- /dev/null +++ b/test/transport.jl @@ -0,0 +1,39 @@ +using Test + +using MeasureBase.Interface: transport_to, test_vartransform +using MeasureBase: StdUniform, StdExponential, StdLogistic +using MeasureBase: Dirac + + +@testset "transport_to" begin + for μ0 in [StdUniform(), StdExponential(), StdLogistic()], ν0 in [StdUniform(), StdExponential(), StdLogistic()] + @testset "transport_to (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin + test_vartransform(ν0, μ0) + test_vartransform(2.2 * ν0, 3 * μ0) + test_vartransform(ν0, μ0^1) + test_vartransform(ν0^1, μ0) + test_vartransform(ν0^3, μ0^3) + test_vartransform(ν0^(2,3,2), μ0^(3,4)) + test_vartransform(2.2 * ν0^(2,3,2), 3 * μ0^(3,4)) + @test_throws ArgumentError transport_to(ν0, μ0)(rand(μ0^12)) + @test_throws ArgumentError transport_to(ν0^3, μ0^3)(rand(μ0^(3,4))) + end + end + + @testset "transfrom from/to Dirac" begin + μ = Dirac(4.2) + test_vartransform(StdExponential()^0, μ) + test_vartransform(StdExponential()^(0,0,0), μ) + test_vartransform(μ, StdExponential()^static(0)) + test_vartransform(μ, StdExponential()^(static(0),static(0))) + @test_throws ArgumentError transport_to(StdExponential()^1, μ) + @test_throws ArgumentError transport_to(μ, StdExponential()^1) + end + + @testset "transport_to autosel" begin + @test @inferred(transport_to(StdExponential, StdUniform())) == transport_to(StdExponential(), StdUniform()) + @test @inferred(transport_to(StdExponential, StdUniform()^(2,3))) == transport_to(StdExponential()^6, StdUniform()^(2,3)) + @test @inferred(transport_to(StdUniform(), StdExponential)) == transport_to(StdUniform(), StdExponential()) + @test @inferred(transport_to(StdUniform()^(2,3), StdExponential)) == transport_to(StdUniform()^(2,3), StdExponential()^6) + end +end