From 25fa2e108afef7bcbe1a70ff9881c0db2d255bb4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 15:14:07 +0200 Subject: [PATCH 01/70] Add ChainRulesCore to dependencies Already an indirect dependency via FillArrays. --- Project.toml | 2 ++ src/MeasureBase.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/Project.toml b/Project.toml index a80ce47a..445d3d94 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Chad Scherrer and contributors"] version = "0.10.0" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" @@ -24,6 +25,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [compat] +ChainRulesCore = "1" Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index fb304c84..e57de3ee 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -18,6 +18,7 @@ using ConstructionBase: constructorof using PrettyPrinting const Pretty = PrettyPrinting +using ChainRulesCore using FillArrays using Static From a3e1bbf73a86d93fb04ad4427fbb7c20d9f6bcd4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 15:33:55 +0200 Subject: [PATCH 02/70] Add InverseFunctions and ChangesOfVariables to deps --- Project.toml | 4 ++++ src/MeasureBase.jl | 3 +++ 2 files changed, 7 insertions(+) diff --git a/Project.toml b/Project.toml index 445d3d94..c6e134bb 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,13 @@ version = "0.10.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DensityInterface = "b429d917-457f-4dbc-8f4c-0cc954292b1d" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" +InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112" IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" @@ -26,11 +28,13 @@ Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775" [compat] ChainRulesCore = "1" +ChangesOfVariables = "0.1.3" Compat = "3.35, 4" ConstructionBase = "1.3" DensityInterface = "0.4" FillArrays = "0.12, 0.13" IfElse = "0.1" +InverseFunctions = "0.1.7" IrrationalConstants = "0.1" LogExpFunctions = "0.3" LogarithmicNumbers = "1" diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index e57de3ee..8be134cb 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -11,6 +11,9 @@ import DensityInterface: densityof import DensityInterface: DensityKind using DensityInterface +using InverseFunctions +using ChangesOfVariables + import Base.iterate import ConstructionBase using ConstructionBase: constructorof From 2015ccdc736b16794e503b9a66ad0afe3301b79b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 15:27:21 +0200 Subject: [PATCH 03/70] Add require_insupport --- src/MeasureBase.jl | 15 ++------------- src/insupport.jl | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 13 deletions(-) create mode 100644 src/insupport.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 8be134cb..a8d6a013 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -36,21 +36,10 @@ export logdensity_def export basemeasure export basekernel export productmeasure - -""" - inssupport(m, x) - insupport(m) - -`insupport(m,x)` computes whether `x` is in the support of `m`. - -`insupport(m)` returns a function, and satisfies - - insupport(m)(x) == insupport(m, x) -""" -function insupport end - export insupport +include("insupport.jl") + abstract type AbstractMeasure end using Static: @constprop diff --git a/src/insupport.jl b/src/insupport.jl new file mode 100644 index 00000000..ccfc841b --- /dev/null +++ b/src/insupport.jl @@ -0,0 +1,32 @@ +""" + inssupport(m, x) + insupport(m) + +`insupport(m,x)` computes whether `x` is in the support of `m`. + +`insupport(m)` returns a function, and satisfies + +insupport(m)(x) == insupport(m, x) +""" +function insupport end + + +""" + MeasureBase.require_insupport(μ, x)::Nothing + +Checks if `x` is in the support of distribution/measure `μ`, throws an +`ArgumentError` if not. +""" +function require_insupport end + +_check_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent() +function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) + return require_insupport(μ, x), _check_insupport_pullback +end + +function require_insupport(μ, x::AbstractArray{T,N}) where {T,N} + if !insupport(μ, x) + throw(ArgumentError("x is not within the support of μ")) + end + return nothing +end From e86cf2c3aa0cbc4ab1a29c2d783e4fbb367ef941 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 15:29:04 +0200 Subject: [PATCH 04/70] Add effndof and require_same_effndof --- src/MeasureBase.jl | 1 + src/effndof.jl | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+) create mode 100644 src/effndof.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index a8d6a013..102259a1 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -87,6 +87,7 @@ using Compat using IrrationalConstants +include("effndof.jl") include("schema.jl") include("splat.jl") include("proxies.jl") diff --git a/src/effndof.jl b/src/effndof.jl new file mode 100644 index 00000000..af0fddb9 --- /dev/null +++ b/src/effndof.jl @@ -0,0 +1,33 @@ +""" + effndof(μ) + +Returns the effective number of degrees of freedom of variates of +measure-like object `μ`. + +The effective NDOF my differ from the length of the variates. For example, +the effective NDOF for a Dirichlet distribution with variates of length `n` +is `n - 1`. + +Also see [`require_same_effndof`](@ref). +""" +function effndof end + + +""" + MeasureBase.require_same_effndof(a, b)::Nothing + +Check if `a` and `b` have the same effective number of degrees of freedom +according to [`MeasureBase.effndof`](@ref). +""" +function require_same_effndof end + +ChainRulesCore.rrule(::typeof(require_same_effndof), a, b) = nothing, _nogradient_pullback2 + +function require_same_effndof(a, b) + trg_d_n = effndof(ν) + src_d_n = effndof(μ) + if trg_d_n != src_d_n + throw(ArgumentError("Can't convert to $(typeof(ν).name) with $(trg_d_n) eff. DOF from $(typeof(μ).name) with $(src_d_n) eff. DOF")) + end + return nothing +end From 2e8e7d6fed52a3a9c47dc997545d33858e69ffbf Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 15:42:26 +0200 Subject: [PATCH 05/70] Add check_varshape --- src/insupport.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/insupport.jl b/src/insupport.jl index ccfc841b..c063e03d 100644 --- a/src/insupport.jl +++ b/src/insupport.jl @@ -30,3 +30,15 @@ function require_insupport(μ, x::AbstractArray{T,N}) where {T,N} end return nothing end + + +""" + MeasureBase.check_varshape(μ, x)::Nothing + +Checks if `x` has the correct shape/size for a variate of measure-like object +`μ`, throws an `ArgumentError` if not. +""" +function check_varshape end + +_check_varshape_pullback(ΔΩ) = NoTangent(), ZeroTangent() +ChainRulesCore.rrule(::typeof(check_varshape), μ, x) = check_varshape(μ, x), _check_varshape_pullback From 5019b1415c947d67d5657382b88d8d58e4b6346a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 14:56:11 +0200 Subject: [PATCH 06/70] Add vartransform --- src/MeasureBase.jl | 1 + src/vartransform.jl | 216 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 src/vartransform.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 102259a1..f8574b9d 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -88,6 +88,7 @@ using Compat using IrrationalConstants include("effndof.jl") +include("vartransform.jl") include("schema.jl") include("splat.jl") include("proxies.jl") diff --git a/src/vartransform.jl b/src/vartransform.jl new file mode 100644 index 00000000..271f60cc --- /dev/null +++ b/src/vartransform.jl @@ -0,0 +1,216 @@ +""" + struct MeasureBase.NoTransformOrigin{MU} + +Indicates that no (default) pullback measure is available for measures of +type `MU`. + +See [`MeasureBase.vartransform_origin`](@ref). +""" +struct NoTransformOrigin{MU} end + + +""" + MeasureBase.vartransform_origin(μ) + +Default measure to pullback to resp. pushforward from when transforming +between `μ` and another measure. +""" +function vartransform_origin end + +vartransform_origin(m::M) where M = NoTransformOrigin{M}() + + +""" + MeasureBase.from_origin(μ, y) + +Push `y` from `MeasureBase.vartransform_origin(μ)` forward to `μ`. +""" +function from_origin end + +from_origin(m::M) where M = NoTransformOrigin{M}() + + +""" + MeasureBase.to_origin(μ, x) + +Pull `x` from `μ` back to `MeasureBase.vartransform_origin(μ)`. +""" +function to_origin end + +to_origin(m::M) where M = NoTransformOrigin{M}() + + +""" + struct MeasureBase.NoVarTransform{NU,MU} end + +Indicates that no transformation from a measure of type `MU` to a measure of +type `NU` could be found. +""" +struct NoVarTransform{NU,MU} end + + +""" + f = vartransform(ν, μ)::Function + +Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function) +`f` that transforms values distributed according to measure-like object `μ` to +values distributed according to a measure-like object `ν`. + + y = vartransform(ν, μ, x) + +Transforms a value `x` distributed according to `μ` to a value `y` distributed +according to `ν`. + +The [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) +from `μ` under `f` is is equivalent to `ν`. + +If terms of random values this implies that `f(rand(μ))` is equivalent to +`rand(ν)` (if `rand(μ)` and `rand(ν)` are supported). + +The resulting function `f` should support +`ChangesOfVariables.with_logabsdet_jacobian(f, x)` if mathematically well-defined, +so that densities of `ν` can be derived from densities of `μ` via `f` (using +appropriate base measures). + +Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from +`μ` to `ν` can be found. + +To add transformation rules for a measure-like type `MyMeasure`, specialize + +* `MeasureBase.vartransform(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` +* `MeasureBase.vartransform(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` + +and/or + +* `MeasureBase.vartransform_origin(ν::MyMeasure) = SomeMeasure(...)` +* `MeasureBase.from_origin(μ::MyMeasure, y) = x` +* `MeasureBase.to_origin(μ::MyMeasure, x) = y` + +and ensure `MeasureBase.effndof(μ::MyMeasure)` is defined correctly. + +If no direct transformation rule is available, `vartransform(ν, μ, x)` uses +the following strategy: + +* Evaluate [`vartransform_origin`](@ref) for μ and ν. If both have an origin, + select one as an intermediate measure using + [`select_vartransform_intermediate`](@ref). Try to transform from `μ` to + that intermediate measure and then to `ν` origin(s) of `μ` and/or `ν` if + available. + +* If all else fails, try to transform from μ to a standard multivariate + uniform measure and then to ν. +""" +function vartransform end + + +function _vartransform_with_intermediate(ν, m, μ, x) + x_m = vartransform(m, μ, x) + _vartransform_with_intermediate_step2(ν, m, x_m) +end + +@inline _vartransform_with_intermediate_step2(ν, m, x_m) = vartransform(ν, m, x_m) +@inline _vartransform_with_intermediate_step2(ν, m, x_m::NoTransformOrigin) = x_m + +function _vartransform_with_intermediate(ν, m::NoTransformOrigin, μ, x) + _vartransform_with_intermediate(ν, StdUniform()^effndof(μ), μ, x) +end + + +# Prevent endless recursion: +_vartransform_with_intermediate(::NU, ::NU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() +_vartransform_with_intermediate(::NU, ::MU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() + +function vartransform(ν, μ, x) + require_same_effndof(ν, μ) + m = vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ)) + _vartransform_with_intermediate(ν, m, μ, x) +end + +vartransform(::Any, ::Any, x::NoTransformOrigin) = x + + +""" + struct VarTransformation <: Function + +Transforms a variate from one measure-like object to a variate of another. + +In general users should not call `VarTransformation` directly, call +[`vartransform`](@ref) instead. +""" +struct VarTransformation{NU,MU} <: Function + ν::NU + μ::MU + + function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU} + require_same_effndof(ν, μ) + return new{NU,MU}(ν, μ) + end + + function VarTransformation(ν::NU, μ::MU) where {NU,MU} + require_same_effndof(ν, μ) + return new{NU,MU}(ν, μ) + end +end + +vartransform(ν, μ) = VarTransformation(ν, μ) + + +(f::VarTransformation)(x) = vartransform(f.ν, f.μ, x) + +InverseFunctions.inverse(f::VarTransformation) = VarTransformation(f.μ, f.ν) + + +function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x) + y = f(x) + logpdf_src = logdensityof(f.μ, x) + logpdf_trg = logdensityof(f.ν, y) + ladj = logpdf_src - logpdf_trg + # If logpdf_src and logpdf_trg are -Inf setting lafj to zero is safe: + fixed_ladj = logpdf_src == logpdf_trg == -Inf ? zero(ladj) : ladj + return y, fixed_ladj +end + + +Base.:(∘)(::typeof(identity), f::VarTransformation) = f +Base.:(∘)(f::VarTransformation, ::typeof(identity)) = f + +function Base.:∘(outer::VarTransformation, inner::VarTransformation) + if !(outer.μ == inner.ν || isequal(outer.μ, inner.ν) || outer.μ ≈ inner.ν) + throw(ArgumentError("Cannot compose VarTransformation if source of outer doesn't equal target of inner.")) + end + return VarTransformation(outer.ν, inner.μ) +end + + +function Base.show(io::IO, f::VarTransformation) + print(io, Base.typename(typeof(f)).name, "(") + show(io, f.ν) + print(io, ", ") + show(io, f.μ) + print(io, ")") +end + +Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f) + + + + + +""" + MeasureBase.select_vartransform_intermediate(a, b) + +Selects one of two candidate pullback measures `a, b` to use as an +intermediate in variate transformations. + +See [`MeasureBase.vartransform_intermediate`](@ref). +""" +function select_vartransform_intermediate end + +select_vartransform_intermediate(nu, ::NoTransformOrigin) = nu +select_vartransform_intermediate(::NoTransformOrigin, mu) = mu +select_vartransform_intermediate(::NoTransformOrigin, mu::NoTransformOrigin) = mu + +# Ensure forward and inverse transformation use the same intermediate: +@generated function select_vartransform_intermediate(a, b) + return nameof(a) < nameof(b) ? :a : :b +end From 6ee4e3da286fb95d323e39fad9fb983da8168b16 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 21:10:26 +0200 Subject: [PATCH 07/70] Remove "measure-like" terminology --- src/effndof.jl | 2 +- src/vartransform.jl | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/effndof.jl b/src/effndof.jl index af0fddb9..bd941826 100644 --- a/src/effndof.jl +++ b/src/effndof.jl @@ -2,7 +2,7 @@ effndof(μ) Returns the effective number of degrees of freedom of variates of -measure-like object `μ`. +measure `μ`. The effective NDOF my differ from the length of the variates. For example, the effective NDOF for a Dirichlet distribution with variates of length `n` diff --git a/src/vartransform.jl b/src/vartransform.jl index 271f60cc..0c77f6d5 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -53,8 +53,8 @@ struct NoVarTransform{NU,MU} end f = vartransform(ν, μ)::Function Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function) -`f` that transforms values distributed according to measure-like object `μ` to -values distributed according to a measure-like object `ν`. +`f` that transforms values distributed according to measure `μ` to +values distributed according to a measure `ν`. y = vartransform(ν, μ, x) @@ -75,7 +75,7 @@ appropriate base measures). Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from `μ` to `ν` can be found. -To add transformation rules for a measure-like type `MyMeasure`, specialize +To add transformation rules for a measure type `MyMeasure`, specialize * `MeasureBase.vartransform(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` * `MeasureBase.vartransform(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` @@ -132,7 +132,7 @@ vartransform(::Any, ::Any, x::NoTransformOrigin) = x """ struct VarTransformation <: Function -Transforms a variate from one measure-like object to a variate of another. +Transforms a variate from one measure to a variate of another. In general users should not call `VarTransformation` directly, call [`vartransform`](@ref) instead. From 2217a4275bdc81e529dd5b6b1f287daaa0370024 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 21:11:47 +0200 Subject: [PATCH 08/70] Remove requirement for vartransform to return a Function --- src/vartransform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vartransform.jl b/src/vartransform.jl index 0c77f6d5..59314cda 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -50,7 +50,7 @@ struct NoVarTransform{NU,MU} end """ - f = vartransform(ν, μ)::Function + f = vartransform(ν, μ) Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f` that transforms values distributed according to measure `μ` to From bb0257d15741cc25879f8aaa2ea695787bf1eaca Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 21:17:57 +0200 Subject: [PATCH 09/70] Remove check_varshape --- src/insupport.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/insupport.jl b/src/insupport.jl index c063e03d..ccfc841b 100644 --- a/src/insupport.jl +++ b/src/insupport.jl @@ -30,15 +30,3 @@ function require_insupport(μ, x::AbstractArray{T,N}) where {T,N} end return nothing end - - -""" - MeasureBase.check_varshape(μ, x)::Nothing - -Checks if `x` has the correct shape/size for a variate of measure-like object -`μ`, throws an `ArgumentError` if not. -""" -function check_varshape end - -_check_varshape_pullback(ΔΩ) = NoTangent(), ZeroTangent() -ChainRulesCore.rrule(::typeof(check_varshape), μ, x) = check_varshape(μ, x), _check_varshape_pullback From 15952bbf27c2ef6c2045d0315a10f1e481b5edfa Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Wed, 15 Jun 2022 21:53:45 +0200 Subject: [PATCH 10/70] Rename effndof to getdof --- src/MeasureBase.jl | 2 +- src/{effndof.jl => getdof.jl} | 20 ++++++++++---------- src/vartransform.jl | 10 +++++----- 3 files changed, 16 insertions(+), 16 deletions(-) rename src/{effndof.jl => getdof.jl} (58%) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index f8574b9d..86162f4d 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -87,7 +87,7 @@ using Compat using IrrationalConstants -include("effndof.jl") +include("getdof.jl") include("vartransform.jl") include("schema.jl") include("splat.jl") diff --git a/src/effndof.jl b/src/getdof.jl similarity index 58% rename from src/effndof.jl rename to src/getdof.jl index bd941826..47c0c49d 100644 --- a/src/effndof.jl +++ b/src/getdof.jl @@ -1,5 +1,5 @@ """ - effndof(μ) + getdof(μ) Returns the effective number of degrees of freedom of variates of measure `μ`. @@ -8,24 +8,24 @@ The effective NDOF my differ from the length of the variates. For example, the effective NDOF for a Dirichlet distribution with variates of length `n` is `n - 1`. -Also see [`require_same_effndof`](@ref). +Also see [`check_dof`](@ref). """ -function effndof end +function getdof end """ - MeasureBase.require_same_effndof(a, b)::Nothing + MeasureBase.check_dof(a, b)::Nothing Check if `a` and `b` have the same effective number of degrees of freedom -according to [`MeasureBase.effndof`](@ref). +according to [`MeasureBase.getdof`](@ref). """ -function require_same_effndof end +function check_dof end -ChainRulesCore.rrule(::typeof(require_same_effndof), a, b) = nothing, _nogradient_pullback2 +ChainRulesCore.rrule(::typeof(check_dof), a, b) = nothing, _nogradient_pullback2 -function require_same_effndof(a, b) - trg_d_n = effndof(ν) - src_d_n = effndof(μ) +function check_dof(a, b) + trg_d_n = getdof(ν) + src_d_n = getdof(μ) if trg_d_n != src_d_n throw(ArgumentError("Can't convert to $(typeof(ν).name) with $(trg_d_n) eff. DOF from $(typeof(μ).name) with $(src_d_n) eff. DOF")) end diff --git a/src/vartransform.jl b/src/vartransform.jl index 59314cda..ba9627f0 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -86,7 +86,7 @@ and/or * `MeasureBase.from_origin(μ::MyMeasure, y) = x` * `MeasureBase.to_origin(μ::MyMeasure, x) = y` -and ensure `MeasureBase.effndof(μ::MyMeasure)` is defined correctly. +and ensure `MeasureBase.getdof(μ::MyMeasure)` is defined correctly. If no direct transformation rule is available, `vartransform(ν, μ, x)` uses the following strategy: @@ -112,7 +112,7 @@ end @inline _vartransform_with_intermediate_step2(ν, m, x_m::NoTransformOrigin) = x_m function _vartransform_with_intermediate(ν, m::NoTransformOrigin, μ, x) - _vartransform_with_intermediate(ν, StdUniform()^effndof(μ), μ, x) + _vartransform_with_intermediate(ν, StdUniform()^getdof(μ), μ, x) end @@ -121,7 +121,7 @@ _vartransform_with_intermediate(::NU, ::NU, ::MU, x) where {NU,MU} = NoVarTransf _vartransform_with_intermediate(::NU, ::MU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() function vartransform(ν, μ, x) - require_same_effndof(ν, μ) + check_dof(ν, μ) m = vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ)) _vartransform_with_intermediate(ν, m, μ, x) end @@ -142,12 +142,12 @@ struct VarTransformation{NU,MU} <: Function μ::MU function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU} - require_same_effndof(ν, μ) + check_dof(ν, μ) return new{NU,MU}(ν, μ) end function VarTransformation(ν::NU, μ::MU) where {NU,MU} - require_same_effndof(ν, μ) + check_dof(ν, μ) return new{NU,MU}(ν, μ) end end From d5635fe027bed2662df49db75bda8f840b92f5d0 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 00:20:29 +0200 Subject: [PATCH 11/70] Separate vartransform and vartransform_def --- src/vartransform.jl | 43 +++++++++++++++++++++++++------------------ 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/src/vartransform.jl b/src/vartransform.jl index ba9627f0..22e2af47 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -53,13 +53,8 @@ struct NoVarTransform{NU,MU} end f = vartransform(ν, μ) Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function) -`f` that transforms values distributed according to measure `μ` to -values distributed according to a measure `ν`. - - y = vartransform(ν, μ, x) - -Transforms a value `x` distributed according to `μ` to a value `y` distributed -according to `ν`. +`f` that transforms a value `x` distributed according to measure `μ` to +a value `y = f(x)` distributed according to a measure `ν`. The [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` under `f` is is equivalent to `ν`. @@ -77,8 +72,8 @@ Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from To add transformation rules for a measure type `MyMeasure`, specialize -* `MeasureBase.vartransform(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` -* `MeasureBase.vartransform(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` +* `MeasureBase.vartransform_def(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` +* `MeasureBase.vartransform_def(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` and/or @@ -87,9 +82,19 @@ and/or * `MeasureBase.to_origin(μ::MyMeasure, x) = y` and ensure `MeasureBase.getdof(μ::MyMeasure)` is defined correctly. +""" +function vartransform end + -If no direct transformation rule is available, `vartransform(ν, μ, x)` uses -the following strategy: +""" + vartransform_def(ν, μ, x) + +Transforms a value `x` distributed according to `μ` to a value `y` distributed +according to `ν`. + +If no specialized `vartransform_def(::MU, ::NU, ...)` is available then +the default implementation of`vartransform_def(ν, μ, x)` uses the following +strategy: * Evaluate [`vartransform_origin`](@ref) for μ and ν. If both have an origin, select one as an intermediate measure using @@ -99,16 +104,18 @@ the following strategy: * If all else fails, try to transform from μ to a standard multivariate uniform measure and then to ν. + +See [`vartransform`](@ref). """ -function vartransform end +function vartransform_def end function _vartransform_with_intermediate(ν, m, μ, x) - x_m = vartransform(m, μ, x) + x_m = vartransform_def(m, μ, x) _vartransform_with_intermediate_step2(ν, m, x_m) end -@inline _vartransform_with_intermediate_step2(ν, m, x_m) = vartransform(ν, m, x_m) +@inline _vartransform_with_intermediate_step2(ν, m, x_m) = vartransform_def(ν, m, x_m) @inline _vartransform_with_intermediate_step2(ν, m, x_m::NoTransformOrigin) = x_m function _vartransform_with_intermediate(ν, m::NoTransformOrigin, μ, x) @@ -120,13 +127,13 @@ end _vartransform_with_intermediate(::NU, ::NU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() _vartransform_with_intermediate(::NU, ::MU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() -function vartransform(ν, μ, x) +function vartransform_def(ν, μ, x) check_dof(ν, μ) m = vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ)) _vartransform_with_intermediate(ν, m, μ, x) end -vartransform(::Any, ::Any, x::NoTransformOrigin) = x +vartransform_def(::Any, ::Any, x::NoTransformOrigin) = x """ @@ -134,7 +141,7 @@ vartransform(::Any, ::Any, x::NoTransformOrigin) = x Transforms a variate from one measure to a variate of another. -In general users should not call `VarTransformation` directly, call +In general `VarTransformation` should not be called directly, call [`vartransform`](@ref) instead. """ struct VarTransformation{NU,MU} <: Function @@ -155,7 +162,7 @@ end vartransform(ν, μ) = VarTransformation(ν, μ) -(f::VarTransformation)(x) = vartransform(f.ν, f.μ, x) +(f::VarTransformation)(x) = vartransform_def(f.ν, f.μ, x) InverseFunctions.inverse(f::VarTransformation) = VarTransformation(f.μ, f.ν) From 57c83b3e752624c7c0280f16263f0e913aad4e81 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 01:21:44 +0200 Subject: [PATCH 12/70] Fix default vartransform_def --- src/vartransform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vartransform.jl b/src/vartransform.jl index 22e2af47..a265cd3e 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -129,7 +129,7 @@ _vartransform_with_intermediate(::NU, ::MU, ::MU, x) where {NU,MU} = NoVarTransf function vartransform_def(ν, μ, x) check_dof(ν, μ) - m = vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ)) + m = select_vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ)) _vartransform_with_intermediate(ν, m, μ, x) end From b2817f045657fefae95c911bc52e02d0123d2445 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 01:51:36 +0200 Subject: [PATCH 13/70] Remove select_vartransform_intermediate --- src/vartransform.jl | 69 ++++++++++++++++++++------------------------- 1 file changed, 30 insertions(+), 39 deletions(-) diff --git a/src/vartransform.jl b/src/vartransform.jl index a265cd3e..eddc41b4 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -96,11 +96,9 @@ If no specialized `vartransform_def(::MU, ::NU, ...)` is available then the default implementation of`vartransform_def(ν, μ, x)` uses the following strategy: -* Evaluate [`vartransform_origin`](@ref) for μ and ν. If both have an origin, - select one as an intermediate measure using - [`select_vartransform_intermediate`](@ref). Try to transform from `μ` to - that intermediate measure and then to `ν` origin(s) of `μ` and/or `ν` if - available. +* Evaluate [`vartransform_origin`](@ref) for μ and ν. Transform between + each and it's origin, if available, and use the origin(s) as intermediate + measures for another transformation. * If all else fails, try to transform from μ to a standard multivariate uniform measure and then to ν. @@ -110,30 +108,46 @@ See [`vartransform`](@ref). function vartransform_def end -function _vartransform_with_intermediate(ν, m, μ, x) - x_m = vartransform_def(m, μ, x) - _vartransform_with_intermediate_step2(ν, m, x_m) +function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) + x_o = to_origin(μ, x) + y_o = vartransform_def(ν_o, μ_o, x_o) + y = from_origin(ν, y_o) + return y end -@inline _vartransform_with_intermediate_step2(ν, m, x_m) = vartransform_def(ν, m, x_m) -@inline _vartransform_with_intermediate_step2(ν, m, x_m::NoTransformOrigin) = x_m +function _vartransform_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x) + y_o = vartransform_def(ν_o, μ, x) + y = from_origin(ν, y_o) + return y +end + +function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) + x_o = to_origin(μ, x) + y = vartransform_def(ν, μ_o, x_o) + return y +end -function _vartransform_with_intermediate(ν, m::NoTransformOrigin, μ, x) +function _vartransform_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x) _vartransform_with_intermediate(ν, StdUniform()^getdof(μ), μ, x) end +@inline _origin_must_have_separate_type(::Type{MU}, μ_o) where MU = μ_o +function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU + throw(ArgumentError("Measure of type $MU and its origin must have separate types")) +end -# Prevent endless recursion: -_vartransform_with_intermediate(::NU, ::NU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() -_vartransform_with_intermediate(::NU, ::MU, ::MU, x) where {NU,MU} = NoVarTransform{NU,MU}() +@inline function _checked_vartransform_origin(μ::MU) where MU + μ_o = vartransform_origin(μ) + _origin_must_have_separate_type(MU, μ_o) +end function vartransform_def(ν, μ, x) check_dof(ν, μ) - m = select_vartransform_intermediate(vartransform_origin(ν), vartransform_origin(μ)) - _vartransform_with_intermediate(ν, m, μ, x) + _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) end vartransform_def(::Any, ::Any, x::NoTransformOrigin) = x +vartransform_def(::Any, ::Any, x::NoVarTransform) = x """ @@ -198,26 +212,3 @@ function Base.show(io::IO, f::VarTransformation) end Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f) - - - - - -""" - MeasureBase.select_vartransform_intermediate(a, b) - -Selects one of two candidate pullback measures `a, b` to use as an -intermediate in variate transformations. - -See [`MeasureBase.vartransform_intermediate`](@ref). -""" -function select_vartransform_intermediate end - -select_vartransform_intermediate(nu, ::NoTransformOrigin) = nu -select_vartransform_intermediate(::NoTransformOrigin, mu) = mu -select_vartransform_intermediate(::NoTransformOrigin, mu::NoTransformOrigin) = mu - -# Ensure forward and inverse transformation use the same intermediate: -@generated function select_vartransform_intermediate(a, b) - return nameof(a) < nameof(b) ? :a : :b -end From 66f59907d66596be158fff1a0803e4c4568216c7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 02:38:52 +0200 Subject: [PATCH 14/70] Fix check_dof --- src/getdof.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/getdof.jl b/src/getdof.jl index 47c0c49d..8311edf0 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -23,7 +23,7 @@ function check_dof end ChainRulesCore.rrule(::typeof(check_dof), a, b) = nothing, _nogradient_pullback2 -function check_dof(a, b) +function check_dof(ν, μ) trg_d_n = getdof(ν) src_d_n = getdof(μ) if trg_d_n != src_d_n From de9684290d8b73f8cc6460ed002cb82678cd10e4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 11:27:13 +0200 Subject: [PATCH 15/70] Export getdof --- src/MeasureBase.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 86162f4d..7c451fd8 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -37,6 +37,7 @@ export basemeasure export basekernel export productmeasure export insupport +export getdof include("insupport.jl") From 91010150b2824276395c8918fae62c32c9a198f9 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 11:27:35 +0200 Subject: [PATCH 16/70] Export vartransform --- src/MeasureBase.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 7c451fd8..9c805d50 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -38,6 +38,7 @@ export basekernel export productmeasure export insupport export getdof +export vartransform include("insupport.jl") From 45fe22050a22af5741ba20f34c04accbd25161a6 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 11:29:13 +0200 Subject: [PATCH 17/70] Implement getdof for measures --- src/standard/stdmeasure.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index ef13e837..28a9c0e3 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -3,3 +3,7 @@ abstract type StdMeasure <: AbstractMeasure end StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randn)) = StdNormal() StdMeasure(::typeof(randexp)) = StdExponential() + +getdof(::StdMeasure) = static(1) + +getdof(μ::PowerMeasure{<:StdMeasure}) = prod(map(length, d.axes)) From e7fff14cf2cf78951004f9da8832bbe90a1deac7 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 14:20:23 +0200 Subject: [PATCH 18/70] Remove StdNormal Can't be fully implemented without depending on SpecialFunctions.jl. --- src/MeasureBase.jl | 1 - src/standard/stdmeasure.jl | 1 - src/standard/stdnormal.jl | 11 ----------- 3 files changed, 13 deletions(-) delete mode 100644 src/standard/stdnormal.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 9c805d50..112fba6e 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -122,7 +122,6 @@ include("combinators/powerweighted.jl") include("combinators/conditional.jl") include("standard/stdmeasure.jl") -include("standard/stdnormal.jl") include("standard/stduniform.jl") include("standard/stdexponential.jl") diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 28a9c0e3..0753a28c 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -1,7 +1,6 @@ abstract type StdMeasure <: AbstractMeasure end StdMeasure(::typeof(rand)) = StdUniform() -StdMeasure(::typeof(randn)) = StdNormal() StdMeasure(::typeof(randexp)) = StdExponential() getdof(::StdMeasure) = static(1) diff --git a/src/standard/stdnormal.jl b/src/standard/stdnormal.jl deleted file mode 100644 index a5beffaf..00000000 --- a/src/standard/stdnormal.jl +++ /dev/null @@ -1,11 +0,0 @@ -struct StdNormal <: StdMeasure end - -export StdNormal - -insupport(d::StdNormal, x) = true -insupport(d::StdNormal) = Returns(true) - -@inline logdensity_def(::StdNormal, x) = -x^2 / 2 -@inline basemeasure(::StdNormal) = WeightedMeasure(static(-0.5 * log2π), Lebesgue(ℝ)) - -Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdNormal) where {T} = randn(rng, T) From 680e8f8b71968ef6de94d9040e4954e77b47aac5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 14:20:51 +0200 Subject: [PATCH 19/70] FIXUP implement getdof --- src/combinators/power.jl | 3 +++ src/standard/stdmeasure.jl | 2 -- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 86dc2379..d3c2e955 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -100,3 +100,6 @@ end dynamic(insupport(p, xj)) end end + + +@inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 0753a28c..742ee5ae 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -4,5 +4,3 @@ StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) = StdExponential() getdof(::StdMeasure) = static(1) - -getdof(μ::PowerMeasure{<:StdMeasure}) = prod(map(length, d.axes)) From 7d9e11e99ce6e15c15415c54788c5ac399e27de9 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 15:02:54 +0200 Subject: [PATCH 20/70] FIXUP implement getdof --- src/MeasureBase.jl | 2 +- src/standard/stdexponential.jl | 2 ++ src/standard/stdmeasure.jl | 2 -- src/standard/stduniform.jl | 2 ++ 4 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 112fba6e..be155e18 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -58,7 +58,7 @@ gentype(μ::AbstractMeasure) = typeof(testvalue(μ)) # gentype(μ::AbstractMeasure) = gentype(basemeasure(μ)) using NaNMath -using LogExpFunctions: logsumexp +using LogExpFunctions: logsumexp, logistic, logit @deprecate instance_type(x) Core.Typeof(x) false diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index a29794a8..263c4631 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -7,6 +7,8 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline logdensity_def(::StdExponential, x) = -x @inline basemeasure(::StdExponential) = Lebesgue() +@inline getdof(::StdExponential) = static(1) + function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} randexp(rng, T) end diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 742ee5ae..2ef880ba 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -2,5 +2,3 @@ abstract type StdMeasure <: AbstractMeasure end StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) = StdExponential() - -getdof(::StdMeasure) = static(1) diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 0bc0263f..674d15d1 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -7,4 +7,6 @@ insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) @inline basemeasure(::StdUniform) = Lebesgue() +@inline getdof(::StdUniform) = static(1) + Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = randn(rng, T) From c6368bba76ca6e0e0692e046b5e7b10bdaeba61a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 15:20:00 +0200 Subject: [PATCH 21/70] Add StdLogistic --- src/MeasureBase.jl | 1 + src/standard/stdlogistic.jl | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) create mode 100644 src/standard/stdlogistic.jl diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index be155e18..9bfd6a9a 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -124,6 +124,7 @@ include("combinators/conditional.jl") include("standard/stdmeasure.jl") include("standard/stduniform.jl") include("standard/stdexponential.jl") +include("standard/stdlogistic.jl") include("rand.jl") diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl new file mode 100644 index 00000000..14e20759 --- /dev/null +++ b/src/standard/stdlogistic.jl @@ -0,0 +1,18 @@ +struct StdLogistic <: StdMeasure end + +export StdLogistic + +@inline _isreal(x) = false +@inline _isreal(x::Real) = true + +@inline insupport(d::StdLogistic, x) = _isreal(x) +@inline insupport(d::StdLogistic) = _isreal + +@inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u)) +@inline basemeasure(::StdLogistic) = Lebesgue() + +@inline getdof(::StdLogistic) = static(1) + +@inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) + +@inline StdMeasure(::typeof(randn)) = StdLogistic() From 9372871e00df8964ee89bfaa953daf9246efbe2e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 16:12:38 +0200 Subject: [PATCH 22/70] Implement vartransform_def for StdMeasure --- src/standard/stdlogistic.jl | 3 +++ src/standard/stdmeasure.jl | 34 ++++++++++++++++++++++++++++++++++ src/standard/stduniform.jl | 3 +++ 3 files changed, 40 insertions(+) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 14e20759..efc94425 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -11,6 +11,9 @@ export StdLogistic @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u)) @inline basemeasure(::StdLogistic) = Lebesgue() +@inline vartransform_def(::StdUniform, ::StdLogistic, x::Real) = logistic(x) +@inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) + @inline getdof(::StdLogistic) = static(1) @inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 2ef880ba..ff11de63 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -2,3 +2,37 @@ abstract type StdMeasure <: AbstractMeasure end StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) = StdExponential() + + +@inline vartransform_def(::MU, ::MU, x::Real) where {MU<:StdMeasure} = x + +function vartransform_def(@nospecialize(::StdMeasure), @nospecialize(d::StdMeasure), @nospecialize(x)) + throw(ArgumentError("$(typeof(x)) is not a valid variate type for measures of type $(typeof(d))")) +end + +@inline vartransform_def(::StdUniform, ::StdLogistic, x::Real) = logistic(x) +@inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) + +@inline vartransform_def(::StdUniform, ::StdExponential, x::Real) = - expm1(-x) +@inline vartransform_def(::StdExponential, ::StdUniform, x::Real) = - log1p(-x) + + +function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) + check_dof(ν, μ) + vartransform_def(ν, μ.parent, only(x)) +end + +function vartransform_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) + check_dof(ν, μ) + Fill(vartransform_def(ν.parent, μ, only(x)), map(length, ν.axes)...) +end + +function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x) + check_dof(ν, μ) + vartransform(ν.parent, μ.parent).(x) +end + +function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M} + check_dof(ν, μ) + reshape(vartransform(ν.parent, μ.parent).(x), map(length, ν.axes)...) +end diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 674d15d1..9763a1d0 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -7,6 +7,9 @@ insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) @inline basemeasure(::StdUniform) = Lebesgue() +@inline vartransform_def(::StdUniform, ::StdLogistic, x::Real) = logistic(x) +@inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) + @inline getdof(::StdUniform) = static(1) Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = randn(rng, T) From 70f1e3067beb901cff6719763d14afd5b69c5ffb Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 16:14:08 +0200 Subject: [PATCH 23/70] FIXUP vartransform_def for StdMeasure --- src/standard/stdmeasure.jl | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index ff11de63..51578b0c 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -10,12 +10,6 @@ function vartransform_def(@nospecialize(::StdMeasure), @nospecialize(d::StdMeasu throw(ArgumentError("$(typeof(x)) is not a valid variate type for measures of type $(typeof(d))")) end -@inline vartransform_def(::StdUniform, ::StdLogistic, x::Real) = logistic(x) -@inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) - -@inline vartransform_def(::StdUniform, ::StdExponential, x::Real) = - expm1(-x) -@inline vartransform_def(::StdExponential, ::StdUniform, x::Real) = - log1p(-x) - function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) check_dof(ν, μ) From 147ba2abb7f0be1f96fb5941754de1491f7c6472 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 16:43:46 +0200 Subject: [PATCH 24/70] Add _vartransform_intermediate --- src/vartransform.jl | 44 ++++++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/src/vartransform.jl b/src/vartransform.jl index eddc41b4..64bcf6ef 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -107,6 +107,25 @@ See [`vartransform`](@ref). """ function vartransform_def end +vartransform_def(::Any, ::Any, x::NoTransformOrigin) = x +vartransform_def(::Any, ::Any, x::NoVarTransform) = x + +function vartransform_def(ν, μ, x) + check_dof(ν, μ) + _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) +end + + +@inline _origin_must_have_separate_type(::Type{MU}, μ_o) where MU = μ_o +function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU + throw(ArgumentError("Measure of type $MU and its origin must have separate types")) +end + +@inline function _checked_vartransform_origin(μ::MU) where MU + μ_o = vartransform_origin(μ) + _origin_must_have_separate_type(MU, μ_o) +end + function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) x_o = to_origin(μ, x) @@ -128,26 +147,23 @@ function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) end function _vartransform_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x) - _vartransform_with_intermediate(ν, StdUniform()^getdof(μ), μ, x) + _vartransform_with_intermediate(ν, _vartransform_intermediate(ν, μ), μ, x) end -@inline _origin_must_have_separate_type(::Type{MU}, μ_o) where MU = μ_o -function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU - throw(ArgumentError("Measure of type $MU and its origin must have separate types")) -end -@inline function _checked_vartransform_origin(μ::MU) where MU - μ_o = vartransform_origin(μ) - _origin_must_have_separate_type(MU, μ_o) -end +@inline _vartransform_intermediate(ν, μ) = _vartransform_intermediate(getdof(ν), getdof(μ)) +@inline _vartransform_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ +@inline _vartransform_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform() -function vartransform_def(ν, μ, x) - check_dof(ν, μ) - _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) +function _vartransform_with_intermediate(ν, m, μ, x) + z = vartransform_def(m, μ, x) + y = vartransform_def(ν, m, z) + return y end -vartransform_def(::Any, ::Any, x::NoTransformOrigin) = x -vartransform_def(::Any, ::Any, x::NoVarTransform) = x +# Prevent infinite recursion in case vartransform_intermediate doesn't change type: +@inline _vartransform_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}() +@inline _vartransform_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}() """ From 0f263239921e535b29ee3e67f3e24f2ef936f749 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 16:45:47 +0200 Subject: [PATCH 25/70] FIXUP Implement vartransform_def for StdMeasure --- src/standard/stdexponential.jl | 3 +++ src/standard/stdlogistic.jl | 4 ++-- src/standard/stduniform.jl | 3 --- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index 263c4631..fb039893 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -9,6 +9,9 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline getdof(::StdExponential) = static(1) +@inline vartransform_def(::StdUniform, ::StdExponential, x::Real) = - expm1(-x) +@inline vartransform_def(::StdExponential, ::StdUniform, x::Real) = - log1p(-x) + function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} randexp(rng, T) end diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index efc94425..cac7d32a 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -11,11 +11,11 @@ export StdLogistic @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u)) @inline basemeasure(::StdLogistic) = Lebesgue() +@inline getdof(::StdLogistic) = static(1) + @inline vartransform_def(::StdUniform, ::StdLogistic, x::Real) = logistic(x) @inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) -@inline getdof(::StdLogistic) = static(1) - @inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) @inline StdMeasure(::typeof(randn)) = StdLogistic() diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 9763a1d0..674d15d1 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -7,9 +7,6 @@ insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) @inline basemeasure(::StdUniform) = Lebesgue() -@inline vartransform_def(::StdUniform, ::StdLogistic, x::Real) = logistic(x) -@inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) - @inline getdof(::StdUniform) = static(1) Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = randn(rng, T) From 6ffa92ce24754eb50ea2626db6de79f63af060ba Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 16:51:10 +0200 Subject: [PATCH 26/70] Fix insupport for StdLogistic --- src/standard/stdlogistic.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index cac7d32a..01ff6494 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -2,11 +2,7 @@ struct StdLogistic <: StdMeasure end export StdLogistic -@inline _isreal(x) = false -@inline _isreal(x::Real) = true - -@inline insupport(d::StdLogistic, x) = _isreal(x) -@inline insupport(d::StdLogistic) = _isreal +@inline insupport(d::StdLogistic, x) = true @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u)) @inline basemeasure(::StdLogistic) = Lebesgue() From 101f947f5a51c4c74eed3788753ed370f467699e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 16:56:54 +0200 Subject: [PATCH 27/70] Fix StdLogistic --- src/standard/stdlogistic.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 01ff6494..5089d082 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -13,5 +13,3 @@ export StdLogistic @inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) @inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) - -@inline StdMeasure(::typeof(randn)) = StdLogistic() From 1193e05e2ae74173034d9a0e471ee67906d24d24 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 17:06:05 +0200 Subject: [PATCH 28/70] FIXUP StdMeasure vartransform --- src/standard/stdmeasure.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 51578b0c..ea4a3014 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -6,11 +6,6 @@ StdMeasure(::typeof(randexp)) = StdExponential() @inline vartransform_def(::MU, ::MU, x::Real) where {MU<:StdMeasure} = x -function vartransform_def(@nospecialize(::StdMeasure), @nospecialize(d::StdMeasure), @nospecialize(x)) - throw(ArgumentError("$(typeof(x)) is not a valid variate type for measures of type $(typeof(d))")) -end - - function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) check_dof(ν, μ) vartransform_def(ν, μ.parent, only(x)) From d0509e96138213759c3070e9637559565189b23b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 17:21:29 +0200 Subject: [PATCH 29/70] Fix check_dof --- src/getdof.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/getdof.jl b/src/getdof.jl index 8311edf0..1b275d94 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -14,20 +14,20 @@ function getdof end """ - MeasureBase.check_dof(a, b)::Nothing + MeasureBase.check_dof(ν, μ)::Nothing -Check if `a` and `b` have the same effective number of degrees of freedom +Check if `ν` and `μ` have the same effective number of degrees of freedom according to [`MeasureBase.getdof`](@ref). """ function check_dof end -ChainRulesCore.rrule(::typeof(check_dof), a, b) = nothing, _nogradient_pullback2 - function check_dof(ν, μ) - trg_d_n = getdof(ν) - src_d_n = getdof(μ) - if trg_d_n != src_d_n - throw(ArgumentError("Can't convert to $(typeof(ν).name) with $(trg_d_n) eff. DOF from $(typeof(μ).name) with $(src_d_n) eff. DOF")) + n_ν = getdof(ν) + n_μ = getdof(μ) + if n_ν != n_μ + throw(ArgumentError("Measure ν of type $(nameof(typeof(ν))) has $(n_ν) DOF but μ of type $(nameof(typeof(μ))) has $(n_μ) DOF")) end return nothing end + +ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = NoTangent(), NoTangent(), NoTangent() From 9556e313711443d1dd5de90f1061673e99ec6c2c Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 18:10:34 +0200 Subject: [PATCH 30/70] Add checked_var --- src/combinators/power.jl | 11 +++++++++++ src/getdof.jl | 22 ++++++++++++++++++++++ src/standard/stdexponential.jl | 4 ++-- src/standard/stdlogistic.jl | 4 ++-- src/standard/stdmeasure.jl | 20 +++++++++++++++----- 5 files changed, 52 insertions(+), 9 deletions(-) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index d3c2e955..a3a63041 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -103,3 +103,14 @@ end @inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) + +@propagate_inbounds function checked_var(μ::PowerMeasure, x::AbstractArray{<:Any}) + @boundscheck begin + sz_μ = map(length, μ.axes) + sz_x = size(x) + if sz_μ != sz_x + throw(ArgumentError("Size of variate doesn't match size of power measure")) + end + end + return x +end diff --git a/src/getdof.jl b/src/getdof.jl index 1b275d94..999cec25 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -31,3 +31,25 @@ function check_dof(ν, μ) end ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = NoTangent(), NoTangent(), NoTangent() + + +""" + MeasureBase.NoVarCheck{MU,T} + +Indicates that there is no way to check of a values of type `T` are +variate of measures of type `MU`. +""" +struct NoVarCheck{MU,T} end + + +""" + MeasureBase.checked_var(μ::MU, x::T)::T + +Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not, +return `NoVarCheck{MU,T}()` if not check can be performed. +""" +function checked_var end + +@inline checked_var(::MU, ::T) where {MU,T} = NoVarCheck{MU,T} + +ChainRulesCore.rrule(::typeof(checked_var), ν, x) = NoTangent(), NoTangent(), ZeroTangent() diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index fb039893..09376af6 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -9,8 +9,8 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline getdof(::StdExponential) = static(1) -@inline vartransform_def(::StdUniform, ::StdExponential, x::Real) = - expm1(-x) -@inline vartransform_def(::StdExponential, ::StdUniform, x::Real) = - log1p(-x) +@inline vartransform_def(::StdUniform, μ::StdExponential, x) = - expm1(- checked_var(μ, x)) +@inline vartransform_def(::StdExponential, μ::StdUniform, x) = - log1p(- checked_var(μ, x)) function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} randexp(rng, T) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 5089d082..cc5a3690 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -9,7 +9,7 @@ export StdLogistic @inline getdof(::StdLogistic) = static(1) -@inline vartransform_def(::StdUniform, ::StdLogistic, x::Real) = logistic(x) -@inline vartransform_def(::StdLogistic, ::StdUniform, x::Real) = logit(x) +@inline vartransform_def(::StdUniform, μ::StdLogistic, x) = logistic(checked_var(μ, x)) +@inline vartransform_def(::StdLogistic, μ::StdUniform, x) = logit(checked_var(μ, x)) @inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index ea4a3014..3730dbec 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -4,24 +4,34 @@ StdMeasure(::typeof(rand)) = StdUniform() StdMeasure(::typeof(randexp)) = StdExponential() -@inline vartransform_def(::MU, ::MU, x::Real) where {MU<:StdMeasure} = x +@inline check_dof(::StdMeasure, ::StdMeasure) = nothing + +@inline checked_var(::StdMeasure, x::Real) = x + +@propagate_inbounds function checked_var(::StdMeasure, x::Any) + @boundscheck throw(ArgumentError("Invalid variate type for measure")) +end + + +@inline vartransform_def(::MU, μ::MU, x) where {MU<:StdMeasure} = checked_var(μ, x) + function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) check_dof(ν, μ) - vartransform_def(ν, μ.parent, only(x)) + return vartransform_def(ν, μ.parent, only(x)) end function vartransform_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) check_dof(ν, μ) - Fill(vartransform_def(ν.parent, μ, only(x)), map(length, ν.axes)...) + return Fill(vartransform_def(ν.parent, μ, only(x)), map(length, ν.axes)...) end function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x) check_dof(ν, μ) - vartransform(ν.parent, μ.parent).(x) + return vartransform(ν.parent, μ.parent).(checked_var(μ, x)) end function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M} check_dof(ν, μ) - reshape(vartransform(ν.parent, μ.parent).(x), map(length, ν.axes)...) + return reshape(vartransform(ν.parent, μ.parent).(checked_var(μ, x)), map(length, ν.axes)...) end From acde08b4c4cf96b796fe1bc4c4a00c1ea98b1019 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 20:02:21 +0200 Subject: [PATCH 31/70] WIP Add vartransform tests --- test/vartransform.jl | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 test/vartransform.jl diff --git a/test/vartransform.jl b/test/vartransform.jl new file mode 100644 index 00000000..ee637fb9 --- /dev/null +++ b/test/vartransform.jl @@ -0,0 +1,38 @@ +using Test + +using MeasureBase: vartransform, NoVarTransform +using DensityInterface: logdensityof +using InverseFunctions: inverse +using ChangesOfVariables: with_logabsdet_jacobian + +@testset "vartransform" begin + function test_transform_and_back(ν, μ) + @testset "vartransform powers from $(nameof(typeof(μ))) to $(ν)" begin + x = rand(μ) + @test !(@inferred(vartransform(ν, μ)(x)) isa NoVarTransform) + f = vartransform(ν, μ) + y = f(x) + @test @inferred(inverse(f)(y)) ≈ x + @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{typeof(y),Real} + @test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{typeof(x),Real} + y2, ladj_fwd = with_logabsdet_jacobian(f, x) + x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y) + @test x ≈ x2 + @test y ≈ y2 + @test ladj_fwd ≈ - ladj_inv + @test ladj_fwd ≈ logdensityof(μ, x) - logdensityof(ν, y) + end + end + + for μ in [StdUniform(), StdExponential(), StdLogistic()], ν in [StdUniform(), StdExponential(), StdLogistic()] + @testset "vartransform powers of $(nameof(typeof(μ))) to $(ν)" begin + test_transform_and_back(ν, μ) + test_transform_and_back(ν, μ^1) + test_transform_and_back(ν^1, μ) + test_transform_and_back(ν^3, μ^3) + test_transform_and_back(ν^(2,3,2), μ^(3,4)) + @test_throws ArgumentError vartransform(ν, μ)(rand(μ^12)) + @test_throws ArgumentError vartransform(ν^3, μ^3)(rand(μ^(3,4))) + end + end +end From 5a16befbba4f25329b8238c3d364eb3ef1d3db0d Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 20:14:43 +0200 Subject: [PATCH 32/70] FIXUP vartransform tests --- test/vartransform.jl | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/test/vartransform.jl b/test/vartransform.jl index ee637fb9..c3fee6af 100644 --- a/test/vartransform.jl +++ b/test/vartransform.jl @@ -6,15 +6,18 @@ using InverseFunctions: inverse using ChangesOfVariables: with_logabsdet_jacobian @testset "vartransform" begin + supertype(x::Real) = Real + supertype(x::AbstractArray{T,N}) where {T,N} = AbstractArray{T,N} + function test_transform_and_back(ν, μ) - @testset "vartransform powers from $(nameof(typeof(μ))) to $(ν)" begin + @testset "vartransform $μ to $ν0" begin x = rand(μ) @test !(@inferred(vartransform(ν, μ)(x)) isa NoVarTransform) f = vartransform(ν, μ) y = f(x) @test @inferred(inverse(f)(y)) ≈ x - @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{typeof(y),Real} - @test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{typeof(x),Real} + @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{supertype(y),Real} + @test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{supertype(x),Real} y2, ladj_fwd = with_logabsdet_jacobian(f, x) x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y) @test x ≈ x2 @@ -24,15 +27,15 @@ using ChangesOfVariables: with_logabsdet_jacobian end end - for μ in [StdUniform(), StdExponential(), StdLogistic()], ν in [StdUniform(), StdExponential(), StdLogistic()] - @testset "vartransform powers of $(nameof(typeof(μ))) to $(ν)" begin - test_transform_and_back(ν, μ) - test_transform_and_back(ν, μ^1) - test_transform_and_back(ν^1, μ) - test_transform_and_back(ν^3, μ^3) - test_transform_and_back(ν^(2,3,2), μ^(3,4)) - @test_throws ArgumentError vartransform(ν, μ)(rand(μ^12)) - @test_throws ArgumentError vartransform(ν^3, μ^3)(rand(μ^(3,4))) + for μ0 in [StdUniform(), StdExponential(), StdLogistic()], ν0 in [StdUniform(), StdExponential(), StdLogistic()] + @testset "vartransform (powers of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin + test_transform_and_back(ν0, μ0) + test_transform_and_back(ν0, μ0^1) + test_transform_and_back(ν0^1, μ0) + test_transform_and_back(ν0^3, μ0^3) + test_transform_and_back(ν0^(2,3,2), μ0^(3,4)) + @test_throws ArgumentError vartransform(ν0, μ0)(rand(μ0^12)) + @test_throws ArgumentError vartransform(ν0^3, μ0^3)(rand(μ0^(3,4))) end end end From ebb7ddb14369bb0c0ee08a67f8ebaf00f8322830 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 21:13:15 +0200 Subject: [PATCH 33/70] Fix rand for StdUniform --- src/standard/stdexponential.jl | 5 ++--- src/standard/stduniform.jl | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index 09376af6..263e968a 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -12,6 +12,5 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline vartransform_def(::StdUniform, μ::StdExponential, x) = - expm1(- checked_var(μ, x)) @inline vartransform_def(::StdExponential, μ::StdUniform, x) = - log1p(- checked_var(μ, x)) -function Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} - randexp(rng, T) -end +Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} = randexp(rng, T) + diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 674d15d1..81bbdc4c 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -9,4 +9,4 @@ insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline getdof(::StdUniform) = static(1) -Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = randn(rng, T) +Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = rand(rng, T) From 7b56f080991380db47c2b292054b63416f965f7a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 21:21:01 +0200 Subject: [PATCH 34/70] FIXUP vartransform tests --- test/vartransform.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/vartransform.jl b/test/vartransform.jl index c3fee6af..90088272 100644 --- a/test/vartransform.jl +++ b/test/vartransform.jl @@ -1,6 +1,7 @@ using Test using MeasureBase: vartransform, NoVarTransform +using MeasureBase: StdUniform, StdExponential, StdLogistic using DensityInterface: logdensityof using InverseFunctions: inverse using ChangesOfVariables: with_logabsdet_jacobian From 23e41c7cf11c864a1e2313ede423cd9072ce69d3 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 21:47:54 +0200 Subject: [PATCH 35/70] Use checked_var at VarTransformation input stage --- src/standard/stdexponential.jl | 4 ++-- src/standard/stdlogistic.jl | 4 ++-- src/standard/stdmeasure.jl | 6 +++--- src/vartransform.jl | 8 +++++--- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index 263e968a..b8daa1bb 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -9,8 +9,8 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline getdof(::StdExponential) = static(1) -@inline vartransform_def(::StdUniform, μ::StdExponential, x) = - expm1(- checked_var(μ, x)) -@inline vartransform_def(::StdExponential, μ::StdUniform, x) = - log1p(- checked_var(μ, x)) +@inline vartransform_def(::StdUniform, μ::StdExponential, x) = - expm1(-x) +@inline vartransform_def(::StdExponential, μ::StdUniform, x) = - log1p(-x) Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} = randexp(rng, T) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index cc5a3690..3c90690d 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -9,7 +9,7 @@ export StdLogistic @inline getdof(::StdLogistic) = static(1) -@inline vartransform_def(::StdUniform, μ::StdLogistic, x) = logistic(checked_var(μ, x)) -@inline vartransform_def(::StdLogistic, μ::StdUniform, x) = logit(checked_var(μ, x)) +@inline vartransform_def(::StdUniform, μ::StdLogistic, x) = logistic(x) +@inline vartransform_def(::StdLogistic, μ::StdUniform, x) = logit(x) @inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 3730dbec..53f353e6 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -13,7 +13,7 @@ StdMeasure(::typeof(randexp)) = StdExponential() end -@inline vartransform_def(::MU, μ::MU, x) where {MU<:StdMeasure} = checked_var(μ, x) +@inline vartransform_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) @@ -28,10 +28,10 @@ end function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x) check_dof(ν, μ) - return vartransform(ν.parent, μ.parent).(checked_var(μ, x)) + return vartransform(ν.parent, μ.parent).(x) end function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M} check_dof(ν, μ) - return reshape(vartransform(ν.parent, μ.parent).(checked_var(μ, x)), map(length, ν.axes)...) + return reshape(vartransform(ν.parent, μ.parent).(x), map(length, ν.axes)...) end diff --git a/src/vartransform.jl b/src/vartransform.jl index 64bcf6ef..3fb99ed1 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -189,12 +189,14 @@ struct VarTransformation{NU,MU} <: Function end end -vartransform(ν, μ) = VarTransformation(ν, μ) +@inline vartransform(ν, μ) = VarTransformation(ν, μ) -(f::VarTransformation)(x) = vartransform_def(f.ν, f.μ, x) +Base.@propagate_inbounds function (f::VarTransformation)(x) + return vartransform_def(f.ν, f.μ, checked_var(f.μ, x)) +end -InverseFunctions.inverse(f::VarTransformation) = VarTransformation(f.μ, f.ν) +@inline InverseFunctions.inverse(f::VarTransformation) = VarTransformation(f.μ, f.ν) function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x) From 6f5b2465558ad23a111c51d54bb0b36a4c86e073 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 21:48:35 +0200 Subject: [PATCH 36/70] FIX vartransform tests --- test/vartransform.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/vartransform.jl b/test/vartransform.jl index 90088272..15835af4 100644 --- a/test/vartransform.jl +++ b/test/vartransform.jl @@ -11,7 +11,7 @@ using ChangesOfVariables: with_logabsdet_jacobian supertype(x::AbstractArray{T,N}) where {T,N} = AbstractArray{T,N} function test_transform_and_back(ν, μ) - @testset "vartransform $μ to $ν0" begin + @testset "vartransform $μ to $ν" begin x = rand(μ) @test !(@inferred(vartransform(ν, μ)(x)) isa NoVarTransform) f = vartransform(ν, μ) From f5ebe6da7739a777194904828faa204e37b11604 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 23:07:06 +0200 Subject: [PATCH 37/70] Add defaults for check_dof and checked_var --- src/getdof.jl | 22 +++++++++++++++++++++- src/primitives/lebesgue.jl | 8 ++++++++ src/standard/stdexponential.jl | 2 -- src/standard/stdlogistic.jl | 2 -- src/standard/stdmeasure.jl | 7 ------- src/standard/stduniform.jl | 2 -- 6 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/getdof.jl b/src/getdof.jl index 999cec25..0f3dd4f3 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -1,3 +1,13 @@ +""" +MeasureBase.NoDOF{MU} + +Indicates that there is no way to compute degrees of freedom of a measure +of type `MU` with the given information, e.g. because the DOF are not +a global property of the measure. +""" +struct NoDOF{MU} end + + """ getdof(μ) @@ -12,6 +22,12 @@ Also see [`check_dof`](@ref). """ function getdof end +# Prevent infinite recursion: +@inline _default_getdof(::Type{MU}, ::MU) where MU = NoDOF{MU} +@inline _default_getdof(::Type{MU}, mu_base) where MU = getdof(mu_base) + +@inline getdof(μ::MU) where MU = _default_getdof(MU, basemeasure(μ)) + """ MeasureBase.check_dof(ν, μ)::Nothing @@ -50,6 +66,10 @@ return `NoVarCheck{MU,T}()` if not check can be performed. """ function checked_var end -@inline checked_var(::MU, ::T) where {MU,T} = NoVarCheck{MU,T} +# Prevent infinite recursion: +@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::Any) where MU = NoVarCheck{MU,T} +@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x) + +@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) ChainRulesCore.rrule(::typeof(checked_var), ν, x) = NoTangent(), NoTangent(), ZeroTangent() diff --git a/src/primitives/lebesgue.jl b/src/primitives/lebesgue.jl index f7a00917..7049e88a 100644 --- a/src/primitives/lebesgue.jl +++ b/src/primitives/lebesgue.jl @@ -40,3 +40,11 @@ insupport(::Lebesgue{RealNumbers}, ::Real) = true logdensity_def(::LebesgueMeasure, ::CountingMeasure, x) = -Inf logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf + +@inline getdof(::Lebesgue) = static(1) + +@inline checked_var(::Lebesgue, x::Real) = x + +@propagate_inbounds function checked_var(::Lebesgue, x::Any) + @boundscheck throw(ArgumentError("Invalid variate type for measure")) +end diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index b8daa1bb..2b7b81a0 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -7,8 +7,6 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline logdensity_def(::StdExponential, x) = -x @inline basemeasure(::StdExponential) = Lebesgue() -@inline getdof(::StdExponential) = static(1) - @inline vartransform_def(::StdUniform, μ::StdExponential, x) = - expm1(-x) @inline vartransform_def(::StdExponential, μ::StdUniform, x) = - log1p(-x) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 3c90690d..8ddb0137 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -7,8 +7,6 @@ export StdLogistic @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u)) @inline basemeasure(::StdLogistic) = Lebesgue() -@inline getdof(::StdLogistic) = static(1) - @inline vartransform_def(::StdUniform, μ::StdLogistic, x) = logistic(x) @inline vartransform_def(::StdLogistic, μ::StdUniform, x) = logit(x) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 53f353e6..71d88d67 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -6,16 +6,9 @@ StdMeasure(::typeof(randexp)) = StdExponential() @inline check_dof(::StdMeasure, ::StdMeasure) = nothing -@inline checked_var(::StdMeasure, x::Real) = x - -@propagate_inbounds function checked_var(::StdMeasure, x::Any) - @boundscheck throw(ArgumentError("Invalid variate type for measure")) -end - @inline vartransform_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x - function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) check_dof(ν, μ) return vartransform_def(ν, μ.parent, only(x)) diff --git a/src/standard/stduniform.jl b/src/standard/stduniform.jl index 81bbdc4c..d29dce80 100644 --- a/src/standard/stduniform.jl +++ b/src/standard/stduniform.jl @@ -7,6 +7,4 @@ insupport(d::StdUniform, x) = zero(x) ≤ x ≤ one(x) @inline logdensity_def(::StdUniform, x) = zero(x) @inline basemeasure(::StdUniform) = Lebesgue() -@inline getdof(::StdUniform) = static(1) - Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdUniform) where {T} = rand(rng, T) From cbb6873eeb34dae5f77a31e5c5a99e56547e1a2e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 23:08:10 +0200 Subject: [PATCH 38/70] Add vartransform_origin for WeightedMeasure --- src/combinators/weighted.jl | 4 ++++ test/vartransform.jl | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index b31d1939..2cacb0a8 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -48,3 +48,7 @@ Base.:*(m::AbstractMeasure, k::Real) = k * m gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) + +vartransform_origin(μ::WeightedMeasure) = μ.base +to_origin(μ::WeightedMeasure, y) = y +from_origin(μ::WeightedMeasure, x) = x diff --git a/test/vartransform.jl b/test/vartransform.jl index 15835af4..d5abe66a 100644 --- a/test/vartransform.jl +++ b/test/vartransform.jl @@ -29,12 +29,14 @@ using ChangesOfVariables: with_logabsdet_jacobian end for μ0 in [StdUniform(), StdExponential(), StdLogistic()], ν0 in [StdUniform(), StdExponential(), StdLogistic()] - @testset "vartransform (powers of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin + @testset "vartransform (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin test_transform_and_back(ν0, μ0) + test_transform_and_back(2.2 * ν0, 3 * μ0) test_transform_and_back(ν0, μ0^1) test_transform_and_back(ν0^1, μ0) test_transform_and_back(ν0^3, μ0^3) test_transform_and_back(ν0^(2,3,2), μ0^(3,4)) + test_transform_and_back(2.2 * ν0^(2,3,2), 3 * μ0^(3,4)) @test_throws ArgumentError vartransform(ν0, μ0)(rand(μ0^12)) @test_throws ArgumentError vartransform(ν0^3, μ0^3)(rand(μ0^(3,4))) end From 1c52ed07cdc7e00e98b2bab0b8999dc0b2a1d7d5 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 23:10:08 +0200 Subject: [PATCH 39/70] Fix deps --- src/MeasureBase.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index 9bfd6a9a..e6c436d6 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -1,5 +1,7 @@ module MeasureBase +using Base: @propagate_inbounds + using Random import Random: rand! import Random: gentype From 7b954a53ac4b0d41751b4cc17961dae00e1e51ab Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Thu, 16 Jun 2022 23:21:46 +0200 Subject: [PATCH 40/70] Fix tests --- test/runtests.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index c3128102..a3358f04 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,13 +37,13 @@ test_measures = [ Dirac(0) + Dirac(1) Dirac(0.0) + Lebesgue(ℝ) SpikeMixture(Lebesgue(ℝ), 0.2) - StdNormal() - StdNormal()^3 - StdNormal()^(2, 3) - 3 * StdNormal() - 0.2 * StdNormal() + 0.8 * Dirac(0.0) - Dirac(0.0) + StdNormal() - SpikeMixture(StdNormal(), 0.2) + StdLogistic() + StdLogistic()^3 + StdLogistic()^(2, 3) + 3 * StdLogistic() + 0.2 * StdLogistic() + 0.8 * Dirac(0.0) + Dirac(0.0) + StdLogistic() + SpikeMixture(StdLogistic(), 0.2) StdUniform() StdUniform()^3 StdUniform()^(2, 3) From 5a1252313d057b6100b4bea4b7b6994ef879f327 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 00:25:47 +0200 Subject: [PATCH 41/70] WIP Add PushforwardMeasure --- src/combinators/transformedmeasure.jl | 74 +++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 45beb316..0c7abb58 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -13,3 +13,77 @@ function params(::AbstractTransformedMeasure) end function paramnames(::AbstractTransformedMeasure) end function parent(::AbstractTransformedMeasure) end + + +export PushforwardMeasure + +""" + struct PushforwardMeasure{FF,IF,M} <: AbstractPushforward + f :: FF + inv_f :: IF + origin :: M + end +""" +struct PushforwardMeasure{FF,IF,M} <: AbstractPushforward + f::FF + inv_f::IF + origin::M +end + +gettransform(μ::PushforwardMeasure) = μ.f +parent(μ::PushforwardMeasure) = μ.origin + + +function Pretty.tile(μ::PushforwardMeasure) + Pretty.list_layout(Pretty.tile.([μ.f, μ.inv_f, μ.origin]); prefix = :PushforwardMeasure) +end + +@inline function logdensity_def(μ::PushforwardMeasure, x) + x_orig, inv_ladj = with_logabsdet_jacobian(μ.inv_f, x) + logd_orig = logdensityof(μ.origin, x_orig) + + logd = logd_orig + inv_ladj + R = typeof(logd) + if isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf + # Zero density wins against infinite volume: + R(-Inf) + elseif isfinite(logd_orig) && (inv_ladj == -Inf) + # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ? + # Return constant -Inf to prevent problems with ForwardDiff: + R(-Inf) + else + logd + end +end + + +insupport(μ::PushforwardMeasure, x) = insupport(to_origin(μ, x)) + +testvalue(μ::PushforwardMeasure) = from_origin(μ, testvalue(vartransform_origin(μ))) + +@inline function basemeasure(μ::PushforwardMeasure) + PushforwardMeasure(μ.f, μ.inv_f, basemeasure(vartransform_origin(μ))) +end + +@inline getdof(::MU) where {MU<:PushforwardMeasure} = NoDOF{MU}() + +@inline checked_var(::MU, ::Any) where {MU<:PushforwardMeasure} = NoVarCheck{MU}() + +@inline vartransform_origin(μ::PushforwardMeasure) = μ.origin +@inline to_origin(μ::PushforwardMeasure, y) = μ.inv_f(y) +@inline from_origin(μ::PushforwardMeasure, x) = μ.f(x) + +function Base.rand(rng::AbstractRNG, ::Type{T}, μ::PushforwardMeasure) where T + return from_origin(μ, rand(rng, T, vartransform_origin(μ))) +end + + +export pushfwd + +""" + pushfwd(f, μ) + +Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) +from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`. +""" +pushfwd(f, μ) = PushforwardMeasure(f, inverse(f), μ) From 520562c2f108af1d22de2bab55c31fdfaaf0559b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 00:37:15 +0200 Subject: [PATCH 42/70] WIP improve PushforwardMeasure --- src/combinators/transformedmeasure.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 0c7abb58..6b3df6c2 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -42,15 +42,17 @@ end x_orig, inv_ladj = with_logabsdet_jacobian(μ.inv_f, x) logd_orig = logdensityof(μ.origin, x_orig) - logd = logd_orig + inv_ladj - R = typeof(logd) - if isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf + logd = float(logd_orig + inv_ladj) + neginf = oftype(logd, -Inf) + if ( # Zero density wins against infinite volume: - R(-Inf) - elseif isfinite(logd_orig) && (inv_ladj == -Inf) + (isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) || + # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ? # Return constant -Inf to prevent problems with ForwardDiff: - R(-Inf) + (isfinite(logd_orig) && (inv_ladj == -Inf)) + ) + neginf else logd end From 10b12dc1f99fd70dcfe61d5b19f26524baffaa48 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 00:39:09 +0200 Subject: [PATCH 43/70] WIP improve PushforwardMeasure --- src/combinators/transformedmeasure.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 6b3df6c2..7da8ce78 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -44,18 +44,15 @@ end logd = float(logd_orig + inv_ladj) neginf = oftype(logd, -Inf) - if ( + ifelse( # Zero density wins against infinite volume: (isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) || - # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ? # Return constant -Inf to prevent problems with ForwardDiff: - (isfinite(logd_orig) && (inv_ladj == -Inf)) - ) - neginf - else + (isfinite(logd_orig) && (inv_ladj == -Inf)), + neginf, logd - end + ) end From bfda82bc92949b92413589925f9ad9640c32f0d2 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 00:50:31 +0200 Subject: [PATCH 44/70] WIP improve PushforwardMeasure --- src/combinators/transformedmeasure.jl | 52 ++++++++++++++++----------- 1 file changed, 32 insertions(+), 20 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 7da8ce78..1e1f2bc4 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -15,36 +15,43 @@ function paramnames(::AbstractTransformedMeasure) end function parent(::AbstractTransformedMeasure) end +abstract type TransformVolCorr end +struct WithVolCorr <: TDVolCorr end +struct NoVolCorr <: TDVolCorr end + + export PushforwardMeasure """ - struct PushforwardMeasure{FF,IF,M} <: AbstractPushforward + struct PushforwardMeasure{FF,IF,MU,VC<:TransformVolCorr} <: AbstractPushforward f :: FF inv_f :: IF - origin :: M + origin :: MU + volcorr :: VC end """ -struct PushforwardMeasure{FF,IF,M} <: AbstractPushforward +struct PushforwardMeasure{FF,IF,M,VC<:TransformVolCorr} <: AbstractPushforward f::FF inv_f::IF origin::M + volcorr::VC end -gettransform(μ::PushforwardMeasure) = μ.f -parent(μ::PushforwardMeasure) = μ.origin +gettransform(ν::PushforwardMeasure) = ν.f +parent(ν::PushforwardMeasure) = ν.origin -function Pretty.tile(μ::PushforwardMeasure) - Pretty.list_layout(Pretty.tile.([μ.f, μ.inv_f, μ.origin]); prefix = :PushforwardMeasure) +function Pretty.tile(ν::PushforwardMeasure) + Pretty.list_layout(Pretty.tile.([ν.f, ν.inv_f, ν.origin]); prefix = :PushforwardMeasure) end -@inline function logdensity_def(μ::PushforwardMeasure, x) - x_orig, inv_ladj = with_logabsdet_jacobian(μ.inv_f, x) - logd_orig = logdensityof(μ.origin, x_orig) +@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M} + x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y) + logd_orig = logdensityof(ν.origin, x_orig) logd = float(logd_orig + inv_ladj) neginf = oftype(logd, -Inf) - ifelse( + return ifelse( # Zero density wins against infinite volume: (isnan(logd) && logd_orig == -Inf && inv_ladj == +Inf) || # Maybe also for (logd_orig == -Inf) && isfinite(inv_ladj) ? @@ -55,25 +62,30 @@ end ) end +@inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M} + x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y) + return logdensityof(ν.origin, x_orig) +end + -insupport(μ::PushforwardMeasure, x) = insupport(to_origin(μ, x)) +insupport(ν::PushforwardMeasure, y) = insupport(to_origin(ν, y)) -testvalue(μ::PushforwardMeasure) = from_origin(μ, testvalue(vartransform_origin(μ))) +testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(vartransform_origin(ν))) -@inline function basemeasure(μ::PushforwardMeasure) - PushforwardMeasure(μ.f, μ.inv_f, basemeasure(vartransform_origin(μ))) +@inline function basemeasure(ν::PushforwardMeasure) + PushforwardMeasure(ν.f, ν.inv_f, basemeasure(vartransform_origin(ν))) end @inline getdof(::MU) where {MU<:PushforwardMeasure} = NoDOF{MU}() @inline checked_var(::MU, ::Any) where {MU<:PushforwardMeasure} = NoVarCheck{MU}() -@inline vartransform_origin(μ::PushforwardMeasure) = μ.origin -@inline to_origin(μ::PushforwardMeasure, y) = μ.inv_f(y) -@inline from_origin(μ::PushforwardMeasure, x) = μ.f(x) +@inline vartransform_origin(ν::PushforwardMeasure) = ν.origin +@inline to_origin(ν::PushforwardMeasure, x) = ν.inv_f(x) +@inline from_origin(ν::PushforwardMeasure, y) = ν.f(y) -function Base.rand(rng::AbstractRNG, ::Type{T}, μ::PushforwardMeasure) where T - return from_origin(μ, rand(rng, T, vartransform_origin(μ))) +function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T + return from_origin(ν, rand(rng, T, vartransform_origin(ν))) end From 62eabb0c6c01a673e0b7e8187a381c7bb3277b2a Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 00:53:47 +0200 Subject: [PATCH 45/70] WIP improve PushforwardMeasure --- src/combinators/transformedmeasure.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 1e1f2bc4..0f26d7e6 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -63,7 +63,7 @@ end end @inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M} - x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y) + x_orig = to_origin(ν, y) return logdensityof(ν.origin, x_orig) end From a3a7b00b79fc38802d6dd20db10c04018c2b8e5b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 01:12:23 +0200 Subject: [PATCH 46/70] FIX PushforwardMeasure --- src/combinators/transformedmeasure.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 0f26d7e6..e83a4832 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -16,8 +16,8 @@ function parent(::AbstractTransformedMeasure) end abstract type TransformVolCorr end -struct WithVolCorr <: TDVolCorr end -struct NoVolCorr <: TDVolCorr end +struct WithVolCorr <: TransformVolCorr end +struct NoVolCorr <: TransformVolCorr end export PushforwardMeasure @@ -48,7 +48,7 @@ end @inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:WithVolCorr}, y) where {FF,IF,M} x_orig, inv_ladj = with_logabsdet_jacobian(ν.inv_f, y) - logd_orig = logdensityof(ν.origin, x_orig) + logd_orig = logdensity_def(ν.origin, x_orig) logd = float(logd_orig + inv_ladj) neginf = oftype(logd, -Inf) return ifelse( @@ -64,16 +64,16 @@ end @inline function logdensity_def(ν::PushforwardMeasure{FF,IF,M,<:NoVolCorr}, y) where {FF,IF,M} x_orig = to_origin(ν, y) - return logdensityof(ν.origin, x_orig) + return logdensity_def(ν.origin, x_orig) end -insupport(ν::PushforwardMeasure, y) = insupport(to_origin(ν, y)) +insupport(ν::PushforwardMeasure, y) = insupport(vartransform_origin(ν), to_origin(ν, y)) testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(vartransform_origin(ν))) @inline function basemeasure(ν::PushforwardMeasure) - PushforwardMeasure(ν.f, ν.inv_f, basemeasure(vartransform_origin(ν))) + PushforwardMeasure(ν.f, ν.inv_f, basemeasure(vartransform_origin(ν)), NoVolCorr()) end @inline getdof(::MU) where {MU<:PushforwardMeasure} = NoDOF{MU}() @@ -92,9 +92,9 @@ end export pushfwd """ - pushfwd(f, μ) + pushfwd(f, μ, volcorr = WithVolCorr()) Return the [pushforward measure](https://en.wikipedia.org/wiki/Pushforward_measure) from `μ` the [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f`. """ -pushfwd(f, μ) = PushforwardMeasure(f, inverse(f), μ) +pushfwd(f, μ, volcorr = WithVolCorr()) = PushforwardMeasure(f, inverse(f), μ, volcorr) From 39bf7b0ccaf336708f50c47e893134f8aab66410 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 02:27:39 +0200 Subject: [PATCH 47/70] Allow PushforwardMeasure to bypass checked_var --- src/combinators/transformedmeasure.jl | 15 +++++++++++++-- src/vartransform.jl | 6 ++++-- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index e83a4832..e2fc130d 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -76,9 +76,20 @@ testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(vartransform_origi PushforwardMeasure(ν.f, ν.inv_f, basemeasure(vartransform_origin(ν)), NoVolCorr()) end -@inline getdof(::MU) where {MU<:PushforwardMeasure} = NoDOF{MU}() -@inline checked_var(::MU, ::Any) where {MU<:PushforwardMeasure} = NoVarCheck{MU}() +_pushfwd_dof(::Type{MU}, ::Type, dof) where MU = NoDOF{MU}() +_pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof + +# Assume that DOF are preserved if with_logabsdet_jacobian is functional: +@inline function getdof(ν::MU) where {MU<:PushforwardMeasure} + T = Core.Compiler.return_type(testvalue, Tuple{typeof(ν.origin)}) + R = Core.Compiler.return_type(with_logabsdet_jacobian, Tuple{typeof(ν.f), T}) + _pushfwd_dof(MU, R, getdof(ν.origin)) +end + +# Bypass `checked_var`, would require potentially costly transformation: +@inline checked_var(::PushforwardMeasure, x) = x + @inline vartransform_origin(ν::PushforwardMeasure) = ν.origin @inline to_origin(ν::PushforwardMeasure, x) = ν.inv_f(x) diff --git a/src/vartransform.jl b/src/vartransform.jl index 3fb99ed1..46933a33 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -129,7 +129,8 @@ end function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) x_o = to_origin(μ, x) - y_o = vartransform_def(ν_o, μ_o, x_o) + # If μ is a pushforward then checked_var may have been bypassed, so check now: + y_o = vartransform_def(ν_o, μ_o, checked_var(μ_o, x_o)) y = from_origin(ν, y_o) return y end @@ -142,7 +143,8 @@ end function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) x_o = to_origin(μ, x) - y = vartransform_def(ν, μ_o, x_o) + # If μ is a pushforward then checked_var may have been bypassed, so check now: + y = vartransform_def(ν, μ_o, checked_var(μ_o, x_o)) return y end From da7ecc697b30c5ace6c122388dddd7b73007bcb2 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 02:30:46 +0200 Subject: [PATCH 48/70] Test PushforwardMeasure --- test/combinators/transformedmeasure.jl | 21 +++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 22 insertions(+) create mode 100644 test/combinators/transformedmeasure.jl diff --git a/test/combinators/transformedmeasure.jl b/test/combinators/transformedmeasure.jl new file mode 100644 index 00000000..7097d038 --- /dev/null +++ b/test/combinators/transformedmeasure.jl @@ -0,0 +1,21 @@ +using Test + +using MeasureBase: pushfwd, StdUniform, StdExponential, StdLogistic +using MeasureBase: pushfwd, PushforwardMeasure +using MeasureBase: vartransform +using Statistics: var +using DensityInterface: logdensityof + +@testset "transformedmeasure.jl" begin + μ = StdUniform() + @test @inferred(pushfwd((-) ∘ log1p ∘ (-), μ)) isa PushforwardMeasure + ν = pushfwd((-) ∘ log1p ∘ (-), μ) + ν_ref = StdExponential() + + y = rand(ν_ref) + @test @inferred(logdensityof(ν, y)) ≈ logdensityof(ν_ref, y) + + @test isapprox(var(rand(ν^(10^5))), 1, rtol = 0.05) + + @test vartransform(StdLogistic(), ν)(y) ≈ vartransform(StdLogistic(), ν)(y) +end diff --git a/test/runtests.jl b/test/runtests.jl index a3358f04..174b2111 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -234,3 +234,4 @@ end end include("combinators/weighted.jl") +include("combinators/transformedmeasure.jl") From 75e1fb3e4972415fbd7fed860cd08a5a35035704 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 16:49:37 +0200 Subject: [PATCH 49/70] Fix docstring of NoDOF Co-authored-by: Moritz Schauer --- src/getdof.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/getdof.jl b/src/getdof.jl index 0f3dd4f3..87dce2c3 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -1,5 +1,5 @@ """ -MeasureBase.NoDOF{MU} + MeasureBase.NoDOF{MU} Indicates that there is no way to compute degrees of freedom of a measure of type `MU` with the given information, e.g. because the DOF are not From 6250b20bcbb6eadbc490db8de4c51f4e507896c1 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 17:21:09 +0200 Subject: [PATCH 50/70] Add test_vartransform to Interface --- src/interface.jl | 28 ++++++++++++++++++++++++++++ test/vartransform.jl | 41 +++++++++-------------------------------- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index f6207fad..b1b6722f 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -6,8 +6,14 @@ using Reexport using MeasureBase: basemeasure_depth, proxy using MeasureBase: insupport, basemeasure_sequence, commonbase +using MeasureBase: vartransform, NoVarTransform + +using DensityInterface: logdensityof +using InverseFunctions: inverse +using ChangesOfVariables: with_logabsdet_jacobian export test_interface +export test_vartransform export basemeasure_depth export proxy export insupport @@ -59,4 +65,26 @@ function test_interface(μ::M) where {M} end end + +function test_vartransform(ν, μ) + supertype(x::Real) = Real + supertype(x::AbstractArray{T,N}) where {T,N} = AbstractArray{T,N} + + @testset "vartransform $μ to $ν" begin + x = rand(μ) + @test !(@inferred(vartransform(ν, μ)(x)) isa NoVarTransform) + f = vartransform(ν, μ) + y = f(x) + @test @inferred(inverse(f)(y)) ≈ x + @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{supertype(y),Real} + @test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{supertype(x),Real} + y2, ladj_fwd = with_logabsdet_jacobian(f, x) + x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y) + @test x ≈ x2 + @test y ≈ y2 + @test ladj_fwd ≈ - ladj_inv + @test ladj_fwd ≈ logdensityof(μ, x) - logdensityof(ν, y) + end +end + end # module Interface diff --git a/test/vartransform.jl b/test/vartransform.jl index d5abe66a..1070bae0 100644 --- a/test/vartransform.jl +++ b/test/vartransform.jl @@ -1,42 +1,19 @@ using Test -using MeasureBase: vartransform, NoVarTransform +using MeasureBase.Interface: vartransform, test_vartransform using MeasureBase: StdUniform, StdExponential, StdLogistic -using DensityInterface: logdensityof -using InverseFunctions: inverse -using ChangesOfVariables: with_logabsdet_jacobian -@testset "vartransform" begin - supertype(x::Real) = Real - supertype(x::AbstractArray{T,N}) where {T,N} = AbstractArray{T,N} - - function test_transform_and_back(ν, μ) - @testset "vartransform $μ to $ν" begin - x = rand(μ) - @test !(@inferred(vartransform(ν, μ)(x)) isa NoVarTransform) - f = vartransform(ν, μ) - y = f(x) - @test @inferred(inverse(f)(y)) ≈ x - @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{supertype(y),Real} - @test @inferred(with_logabsdet_jacobian(inverse(f), y)) isa Tuple{supertype(x),Real} - y2, ladj_fwd = with_logabsdet_jacobian(f, x) - x2, ladj_inv = with_logabsdet_jacobian(inverse(f), y) - @test x ≈ x2 - @test y ≈ y2 - @test ladj_fwd ≈ - ladj_inv - @test ladj_fwd ≈ logdensityof(μ, x) - logdensityof(ν, y) - end - end +@testset "vartransform" begin for μ0 in [StdUniform(), StdExponential(), StdLogistic()], ν0 in [StdUniform(), StdExponential(), StdLogistic()] @testset "vartransform (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin - test_transform_and_back(ν0, μ0) - test_transform_and_back(2.2 * ν0, 3 * μ0) - test_transform_and_back(ν0, μ0^1) - test_transform_and_back(ν0^1, μ0) - test_transform_and_back(ν0^3, μ0^3) - test_transform_and_back(ν0^(2,3,2), μ0^(3,4)) - test_transform_and_back(2.2 * ν0^(2,3,2), 3 * μ0^(3,4)) + test_vartransform(ν0, μ0) + test_vartransform(2.2 * ν0, 3 * μ0) + test_vartransform(ν0, μ0^1) + test_vartransform(ν0^1, μ0) + test_vartransform(ν0^3, μ0^3) + test_vartransform(ν0^(2,3,2), μ0^(3,4)) + test_vartransform(2.2 * ν0^(2,3,2), 3 * μ0^(3,4)) @test_throws ArgumentError vartransform(ν0, μ0)(rand(μ0^12)) @test_throws ArgumentError vartransform(ν0^3, μ0^3)(rand(μ0^(3,4))) end From 9bfa9f94e415057058c454bf6d2fd920540380fc Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Fri, 17 Jun 2022 21:09:08 +0200 Subject: [PATCH 51/70] FIXUP _default_checked_var --- src/getdof.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/getdof.jl b/src/getdof.jl index 87dce2c3..aeee1738 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -67,7 +67,7 @@ return `NoVarCheck{MU,T}()` if not check can be performed. function checked_var end # Prevent infinite recursion: -@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::Any) where MU = NoVarCheck{MU,T} +@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T} @propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x) @propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) From d4f0246ae577e678ac4fffd00aa06aa56634c614 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 00:09:35 +0200 Subject: [PATCH 52/70] FIXUP vartransform_origin docs and defaults --- src/vartransform.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/vartransform.jl b/src/vartransform.jl index 46933a33..f585a08b 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -1,43 +1,43 @@ """ - struct MeasureBase.NoTransformOrigin{MU} + struct MeasureBase.NoTransformOrigin{NU} Indicates that no (default) pullback measure is available for measures of -type `MU`. +type `NU`. See [`MeasureBase.vartransform_origin`](@ref). """ -struct NoTransformOrigin{MU} end +struct NoTransformOrigin{NU} end """ - MeasureBase.vartransform_origin(μ) + MeasureBase.vartransform_origin(ν) Default measure to pullback to resp. pushforward from when transforming -between `μ` and another measure. +between `ν` and another measure. """ function vartransform_origin end -vartransform_origin(m::M) where M = NoTransformOrigin{M}() +vartransform_origin(ν::NU) where NU = NoTransformOrigin{NU}() """ - MeasureBase.from_origin(μ, y) + MeasureBase.from_origin(ν, x) -Push `y` from `MeasureBase.vartransform_origin(μ)` forward to `μ`. +Push `x` from `MeasureBase.vartransform_origin(μ)` forward to `ν`. """ function from_origin end -from_origin(m::M) where M = NoTransformOrigin{M}() +from_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}() """ - MeasureBase.to_origin(μ, x) + MeasureBase.to_origin(ν, y) -Pull `x` from `μ` back to `MeasureBase.vartransform_origin(μ)`. +Pull `y` from `ν` back to `MeasureBase.vartransform_origin(ν)`. """ function to_origin end -to_origin(m::M) where M = NoTransformOrigin{M}() +to_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}(ν) """ From 051293009d25a67355adaedb748951ea926fb4bd Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 00:09:54 +0200 Subject: [PATCH 53/70] Run vartransform tests --- test/runtests.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/runtests.jl b/test/runtests.jl index 174b2111..51fc560b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -233,5 +233,7 @@ end # end end +include("vartransform.jl") + include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") From 14890bb9346350d5fe5f5d2ad787e59c23cdb495 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 00:37:25 +0200 Subject: [PATCH 54/70] Improve vartransform_origin def for WeightedMeasure --- src/combinators/weighted.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index 2cacb0a8..22b9e440 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -49,6 +49,6 @@ gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) -vartransform_origin(μ::WeightedMeasure) = μ.base -to_origin(μ::WeightedMeasure, y) = y -from_origin(μ::WeightedMeasure, x) = x +vartransform_origin(ν::WeightedMeasure) = ν.base +to_origin(::WeightedMeasure, y) = y +from_origin(::WeightedMeasure, x) = x From b32f34d17cc54aff5b67b0ce6e47f1556280680e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 01:45:51 +0200 Subject: [PATCH 55/70] Add vartransform stdmeasure autodim --- src/standard/stdmeasure.jl | 10 ++++++++++ src/vartransform.jl | 12 ++++++++++++ test/vartransform.jl | 7 +++++++ 3 files changed, 29 insertions(+) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 71d88d67..7a9684fc 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -28,3 +28,13 @@ function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, check_dof(ν, μ) return reshape(vartransform(ν.parent, μ.parent).(x), map(length, ν.axes)...) end + + +# Implement vartransform(NU::Type{<:StdMeasure}, μ) and vartransform(ν, MU::Type{<:StdMeasure}): + +_std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M() +_std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof +_std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) + +MeasureBase.vartransform(::Type{NU}, μ) where {NU<:StdMeasure} = vartransform(_std_measure_for(NU, μ), μ) +MeasureBase.vartransform(ν, ::Type{MU}) where {MU<:StdMeasure} = vartransform(ν, _std_measure_for(MU, ν)) diff --git a/src/vartransform.jl b/src/vartransform.jl index f585a08b..7234e382 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -82,6 +82,18 @@ and/or * `MeasureBase.to_origin(μ::MyMeasure, x) = y` and ensure `MeasureBase.getdof(μ::MyMeasure)` is defined correctly. + +A standard measure type like `StdUniform`, `StdExponential` or +`StdLogistic` may also be used as the source or target of the transform: + +```julia +f_to_uniform(StdUniform, μ) +f_to_uniform(ν, StdUniform) +``` + +Depending on [`getdof(μ)`](@ref) (resp. `ν`), an instance of the standard +distribution itself or a power of it (e.g. `StdUniform()` or +`StdUniform()^dof`) will be chosen as the transformation partner. """ function vartransform end diff --git a/test/vartransform.jl b/test/vartransform.jl index 1070bae0..d6457539 100644 --- a/test/vartransform.jl +++ b/test/vartransform.jl @@ -18,4 +18,11 @@ using MeasureBase: StdUniform, StdExponential, StdLogistic @test_throws ArgumentError vartransform(ν0^3, μ0^3)(rand(μ0^(3,4))) end end + + @testset "vartransform autosel" begin + @test @inferred(vartransform(StdExponential, StdUniform())) == vartransform(StdExponential(), StdUniform()) + @test @inferred(vartransform(StdExponential, StdUniform()^(2,3))) == vartransform(StdExponential()^6, StdUniform()^(2,3)) + @test @inferred(vartransform(StdUniform(), StdExponential)) == vartransform(StdUniform(), StdExponential()) + @test @inferred(vartransform(StdUniform()^(2,3), StdExponential)) == vartransform(StdUniform()^(2,3), StdExponential()^6) + end end From f12f1b593ed8a28e34025fdf3f110779f1356e4c Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 02:31:29 +0200 Subject: [PATCH 56/70] Specialize equality for VarTransformation --- src/vartransform.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/vartransform.jl b/src/vartransform.jl index 7234e382..f229b7ea 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -193,7 +193,6 @@ struct VarTransformation{NU,MU} <: Function μ::MU function VarTransformation{NU,MU}(ν::NU, μ::MU) where {NU,MU} - check_dof(ν, μ) return new{NU,MU}(ν, μ) end @@ -205,12 +204,18 @@ end @inline vartransform(ν, μ) = VarTransformation(ν, μ) +function Base.:(==)(a::VarTransformation, b::VarTransformation) + return a.ν == b.ν && a.μ == b.μ +end + Base.@propagate_inbounds function (f::VarTransformation)(x) return vartransform_def(f.ν, f.μ, checked_var(f.μ, x)) end -@inline InverseFunctions.inverse(f::VarTransformation) = VarTransformation(f.μ, f.ν) +@inline function InverseFunctions.inverse(f::VarTransformation{NU,MU}) where {NU,MU} + return VarTransformation{MU,NU}(f.μ, f.ν) +end function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x) From b3dfe13b2e84dc1b5a5fc9bfddbf6dbf3118e996 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 03:09:33 +0200 Subject: [PATCH 57/70] Don't call check_dof so often --- src/standard/stdmeasure.jl | 4 ---- src/vartransform.jl | 1 - 2 files changed, 5 deletions(-) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 7a9684fc..fa567908 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -10,22 +10,18 @@ StdMeasure(::typeof(randexp)) = StdExponential() @inline vartransform_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) - check_dof(ν, μ) return vartransform_def(ν, μ.parent, only(x)) end function vartransform_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) - check_dof(ν, μ) return Fill(vartransform_def(ν.parent, μ, only(x)), map(length, ν.axes)...) end function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x) - check_dof(ν, μ) return vartransform(ν.parent, μ.parent).(x) end function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M} - check_dof(ν, μ) return reshape(vartransform(ν.parent, μ.parent).(x), map(length, ν.axes)...) end diff --git a/src/vartransform.jl b/src/vartransform.jl index f229b7ea..1c8f548b 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -123,7 +123,6 @@ vartransform_def(::Any, ::Any, x::NoTransformOrigin) = x vartransform_def(::Any, ::Any, x::NoVarTransform) = x function vartransform_def(ν, μ, x) - check_dof(ν, μ) _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) end From 1088fc18e0d1462cab2d78d5b9c7020b48bce980 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 12:54:58 +0200 Subject: [PATCH 58/70] Improve checked_var for PowerMeasure --- src/combinators/power.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index a3a63041..c672d55b 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -114,3 +114,7 @@ end end return x end + +function checked_var(μ::PowerMeasure, x::Any) + throw(ArgumentError("Size of variate doesn't match size of power measure")) +end From 1c2311c52e3f571b5066d13aacd6a3a32055a15b Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 18:02:53 +0200 Subject: [PATCH 59/70] Fix check_dof and require_insupport rrules --- src/getdof.jl | 6 ++++-- src/insupport.jl | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/getdof.jl b/src/getdof.jl index aeee1738..b0c8d864 100644 --- a/src/getdof.jl +++ b/src/getdof.jl @@ -46,7 +46,8 @@ function check_dof(ν, μ) return nothing end -ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = NoTangent(), NoTangent(), NoTangent() +_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent() +ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback """ @@ -72,4 +73,5 @@ function checked_var end @propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x) -ChainRulesCore.rrule(::typeof(checked_var), ν, x) = NoTangent(), NoTangent(), ZeroTangent() +_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ +ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback diff --git a/src/insupport.jl b/src/insupport.jl index ccfc841b..7d407d0d 100644 --- a/src/insupport.jl +++ b/src/insupport.jl @@ -19,12 +19,12 @@ Checks if `x` is in the support of distribution/measure `μ`, throws an """ function require_insupport end -_check_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent() +_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent() function ChainRulesCore.rrule(::typeof(require_insupport), μ, x) - return require_insupport(μ, x), _check_insupport_pullback + return require_insupport(μ, x), _require_insupport_pullback end -function require_insupport(μ, x::AbstractArray{T,N}) where {T,N} +function require_insupport(μ, x) if !insupport(μ, x) throw(ArgumentError("x is not within the support of μ")) end From 1fb805c6df8d1528552c135746ebe42ae7cec95e Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 18:24:13 +0200 Subject: [PATCH 60/70] Test getdof --- Project.toml | 3 ++- test/getdof.jl | 33 +++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 test/getdof.jl diff --git a/Project.toml b/Project.toml index c6e134bb..9ef0d01f 100644 --- a/Project.toml +++ b/Project.toml @@ -48,6 +48,7 @@ julia = "1.3" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" [targets] -test = ["Aqua"] +test = ["Aqua", "ChainRulesTestUtils"] diff --git a/test/getdof.jl b/test/getdof.jl new file mode 100644 index 00000000..087c7b01 --- /dev/null +++ b/test/getdof.jl @@ -0,0 +1,33 @@ +using Test + +using MeasureBase: getdof, check_dof, checked_var +using MeasureBase: StdUniform, StdExponential, StdLogistic +using ChainRulesTestUtils: test_rrule +using Static: static + + +@testset "getdof" begin + μ0 = StdExponential() + x0 = rand(μ0) + + μ2 = StdExponential()^(2,3) + x2 = rand(μ2) + + @test @inferred(getdof(μ0)) === static(1) + @test (check_dof(μ0, StdUniform()); true) + @test_throws ArgumentError check_dof(μ2, μ0) + test_rrule(check_dof, μ0, StdUniform()) + + @test @inferred(checked_var(μ0, x0)) === x0 + @test_throws ArgumentError checked_var(μ0, x2) + test_rrule(checked_var, μ0, x0) + + @test @inferred(getdof(μ2)) == 6 + @test (check_dof(μ2, StdUniform()^(1,6,1)); true) + @test_throws ArgumentError check_dof(μ2, μ0) + test_rrule(check_dof, μ2, StdUniform()^(1,6,1)) + + @test @inferred(checked_var(μ2, x2)) === x2 + @test_throws ArgumentError checked_var(μ2, x0) + test_rrule(checked_var, μ2, x2) +end diff --git a/test/runtests.jl b/test/runtests.jl index 51fc560b..c8f06d36 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -233,6 +233,7 @@ end # end end +include("getdof.jl") include("vartransform.jl") include("combinators/weighted.jl") From 24affe6f6400db973c1fe4c0b2f4aa24e07cedeb Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 19:07:44 +0200 Subject: [PATCH 61/70] Document TransformVolCorr --- src/combinators/transformedmeasure.jl | 5 ----- src/vartransform.jl | 28 +++++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index e2fc130d..8512e566 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -15,11 +15,6 @@ function paramnames(::AbstractTransformedMeasure) end function parent(::AbstractTransformedMeasure) end -abstract type TransformVolCorr end -struct WithVolCorr <: TransformVolCorr end -struct NoVolCorr <: TransformVolCorr end - - export PushforwardMeasure """ diff --git a/src/vartransform.jl b/src/vartransform.jl index 1c8f548b..688d739a 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -248,3 +248,31 @@ function Base.show(io::IO, f::VarTransformation) end Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f) + + +""" + abstract type TransformVolCorr + +Provides control over density correction by transform volume element. +Either [`NoVolCorr()`](@ref) or [`WithVolCorr()`](@ref) +""" +abstract type TransformVolCorr end + +""" + NoVolCorr() + +Indicate that density calculations should ignore the volume element of +var transformations. Should only be used in special cases in which +the volume element has already been taken into account in a different +way. +""" +struct NoVolCorr <: TransformVolCorr end + +""" + WithVolCorr() + +Indicate that density calculations should take the volume element of +var transformations into account (typically via the +log-abs-det-Jacobian of the transform). +""" +struct WithVolCorr <: TransformVolCorr end From 35fae34c172a99bea2bc84893c73175550500994 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 21:30:37 +0200 Subject: [PATCH 62/70] Fix transform variable naming inconsistencies --- src/combinators/transformedmeasure.jl | 4 ++-- src/vartransform.jl | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 8512e566..56db1f1d 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -87,8 +87,8 @@ end @inline vartransform_origin(ν::PushforwardMeasure) = ν.origin -@inline to_origin(ν::PushforwardMeasure, x) = ν.inv_f(x) -@inline from_origin(ν::PushforwardMeasure, y) = ν.f(y) +@inline from_origin(ν::PushforwardMeasure, x) = ν.f(x) +@inline to_origin(ν::PushforwardMeasure, y) = ν.inv_f(y) function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T return from_origin(ν, rand(rng, T, vartransform_origin(ν))) diff --git a/src/vartransform.jl b/src/vartransform.jl index 688d739a..5bd859a7 100644 --- a/src/vartransform.jl +++ b/src/vartransform.jl @@ -78,8 +78,8 @@ To add transformation rules for a measure type `MyMeasure`, specialize and/or * `MeasureBase.vartransform_origin(ν::MyMeasure) = SomeMeasure(...)` -* `MeasureBase.from_origin(μ::MyMeasure, y) = x` -* `MeasureBase.to_origin(μ::MyMeasure, x) = y` +* `MeasureBase.from_origin(μ::MyMeasure, x) = y` +* `MeasureBase.to_origin(μ::MyMeasure, y) = x` and ensure `MeasureBase.getdof(μ::MyMeasure)` is defined correctly. From d67369244ddb4055ecd1abd88d1db45290ff6206 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 17:35:00 -0400 Subject: [PATCH 63/70] Specialize gotdof for inferrably empty power measures --- src/combinators/power.jl | 3 +++ test/getdof.jl | 2 ++ 2 files changed, 5 insertions(+) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index c672d55b..62468120 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -104,6 +104,9 @@ end @inline getdof(μ::PowerMeasure) = getdof(μ.parent) * prod(map(length, μ.axes)) +@inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0) + + @propagate_inbounds function checked_var(μ::PowerMeasure, x::AbstractArray{<:Any}) @boundscheck begin sz_μ = map(length, μ.axes) diff --git a/test/getdof.jl b/test/getdof.jl index 087c7b01..c8d3953b 100644 --- a/test/getdof.jl +++ b/test/getdof.jl @@ -30,4 +30,6 @@ using Static: static @test @inferred(checked_var(μ2, x2)) === x2 @test_throws ArgumentError checked_var(μ2, x0) test_rrule(checked_var, μ2, x2) + + @test @inferred(getdof((StdExponential()^3)^(static(0),static(0)))) === static(0) end From 3965732922735bf06250813403c7b6b50a4a5fd4 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 17:52:24 -0400 Subject: [PATCH 64/70] Add trafos for Dirac --- src/primitives/dirac.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/primitives/dirac.jl b/src/primitives/dirac.jl index 57067ec7..8ed0e5ac 100644 --- a/src/primitives/dirac.jl +++ b/src/primitives/dirac.jl @@ -29,3 +29,16 @@ dirac(d::AbstractMeasure) = Dirac(rand(d)) testvalue(d::Dirac) = d.x insupport(d::Dirac, x) = x == d.x + +@inline getdof(::Dirac) = static(0) + +@propagate_inbounds function checked_var(μ::Dirac, x) + @boundscheck insupport(μ, x) || throw(ArgumentError("Invalid variate for measure")) + x +end + +@inline vartransform_def(ν::Dirac, ::PowerMeasure{<:MeasureBase.StdMeasure}, ::Any) = ν.x + +@inline function vartransform_def(ν::PowerMeasure{<:MeasureBase.StdMeasure}, ::Dirac, ::Any) + Zeros{Bool}(map(_ -> 0, ν.axes)) +end From 5b5eb7f7abcb944038b1abad9cdd3cc9334e9f62 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 21:50:46 -0400 Subject: [PATCH 65/70] Support logdensity calculation on empty power measures --- src/combinators/power.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/combinators/power.jl b/src/combinators/power.jl index 62468120..db5b72c9 100644 --- a/src/combinators/power.jl +++ b/src/combinators/power.jl @@ -85,6 +85,13 @@ end end end +@inline function logdensity_def( + d::PowerMeasure{M,NTuple{N, Base.OneTo{StaticInt{0}}}}, + x, +) where {M,N} + static(0.0) +end + @inline function insupport(μ::PowerMeasure, x) p = μ.parent all(x) do xj From e70f294aabc516009a9fcd92a832c50d13c7b324 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 21:51:12 -0400 Subject: [PATCH 66/70] Improve test_vartransform --- src/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/interface.jl b/src/interface.jl index b1b6722f..938d8a0b 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -68,7 +68,7 @@ end function test_vartransform(ν, μ) supertype(x::Real) = Real - supertype(x::AbstractArray{T,N}) where {T,N} = AbstractArray{T,N} + supertype(x::AbstractArray{<:Real,N}) where N = AbstractArray{<:Real,N} @testset "vartransform $μ to $ν" begin x = rand(μ) From e85e54fcb955098729e16d8f0ce1a59b056eb406 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sat, 18 Jun 2022 21:51:55 -0400 Subject: [PATCH 67/70] Fix and test vartransform for Dirac --- src/primitives/dirac.jl | 6 ------ src/standard/stdmeasure.jl | 9 +++++++++ test/vartransform.jl | 11 +++++++++++ 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/primitives/dirac.jl b/src/primitives/dirac.jl index 8ed0e5ac..2575382c 100644 --- a/src/primitives/dirac.jl +++ b/src/primitives/dirac.jl @@ -36,9 +36,3 @@ insupport(d::Dirac, x) = x == d.x @boundscheck insupport(μ, x) || throw(ArgumentError("Invalid variate for measure")) x end - -@inline vartransform_def(ν::Dirac, ::PowerMeasure{<:MeasureBase.StdMeasure}, ::Any) = ν.x - -@inline function vartransform_def(ν::PowerMeasure{<:MeasureBase.StdMeasure}, ::Dirac, ::Any) - Zeros{Bool}(map(_ -> 0, ν.axes)) -end diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index fa567908..8ab68ecd 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -34,3 +34,12 @@ _std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, get MeasureBase.vartransform(::Type{NU}, μ) where {NU<:StdMeasure} = vartransform(_std_measure_for(NU, μ), μ) MeasureBase.vartransform(ν, ::Type{MU}) where {MU<:StdMeasure} = vartransform(ν, _std_measure_for(MU, ν)) + + +# Transform between standard measures and Dirac: + +@inline vartransform_def(ν::Dirac, ::PowerMeasure{<:MeasureBase.StdMeasure}, ::Any) = ν.x + +@inline function vartransform_def(ν::PowerMeasure{<:MeasureBase.StdMeasure}, ::Dirac, ::Any) + Zeros{Bool}(map(_ -> 0, ν.axes)) +end diff --git a/test/vartransform.jl b/test/vartransform.jl index d6457539..b808b4eb 100644 --- a/test/vartransform.jl +++ b/test/vartransform.jl @@ -2,6 +2,7 @@ using Test using MeasureBase.Interface: vartransform, test_vartransform using MeasureBase: StdUniform, StdExponential, StdLogistic +using MeasureBase: Dirac @testset "vartransform" begin @@ -19,6 +20,16 @@ using MeasureBase: StdUniform, StdExponential, StdLogistic end end + @testset "transfrom from/to Dirac" begin + μ = Dirac(4.2) + test_vartransform(StdExponential()^0, μ) + test_vartransform(StdExponential()^(0,0,0), μ) + test_vartransform(μ, StdExponential()^static(0)) + test_vartransform(μ, StdExponential()^(static(0),static(0))) + @test_throws ArgumentError vartransform(StdExponential()^1, μ) + @test_throws ArgumentError vartransform(μ, StdExponential()^1) + end + @testset "vartransform autosel" begin @test @inferred(vartransform(StdExponential, StdUniform())) == vartransform(StdExponential(), StdUniform()) @test @inferred(vartransform(StdExponential, StdUniform()^(2,3))) == vartransform(StdExponential()^6, StdUniform()^(2,3)) From 8f2da10e29e16d9d640f314a658ed1ef24952863 Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 15:27:10 -0400 Subject: [PATCH 68/70] Rename vartransform to transport_to --- src/MeasureBase.jl | 4 +-- src/interface.jl | 8 +++--- src/standard/stdexponential.jl | 4 +-- src/standard/stdlogistic.jl | 4 +-- src/standard/stdmeasure.jl | 28 +++++++++--------- src/{vartransform.jl => transport.jl} | 40 +++++++++++++------------- test/combinators/transformedmeasure.jl | 4 +-- test/runtests.jl | 2 +- test/{vartransform.jl => transport.jl} | 24 ++++++++-------- 9 files changed, 59 insertions(+), 59 deletions(-) rename src/{vartransform.jl => transport.jl} (88%) rename test/{vartransform.jl => transport.jl} (51%) diff --git a/src/MeasureBase.jl b/src/MeasureBase.jl index e6c436d6..1905719a 100644 --- a/src/MeasureBase.jl +++ b/src/MeasureBase.jl @@ -40,7 +40,7 @@ export basekernel export productmeasure export insupport export getdof -export vartransform +export transport_to include("insupport.jl") @@ -92,7 +92,7 @@ using Compat using IrrationalConstants include("getdof.jl") -include("vartransform.jl") +include("transport.jl") include("schema.jl") include("splat.jl") include("proxies.jl") diff --git a/src/interface.jl b/src/interface.jl index 938d8a0b..d27ee505 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -6,7 +6,7 @@ using Reexport using MeasureBase: basemeasure_depth, proxy using MeasureBase: insupport, basemeasure_sequence, commonbase -using MeasureBase: vartransform, NoVarTransform +using MeasureBase: transport_to, NoVarTransform using DensityInterface: logdensityof using InverseFunctions: inverse @@ -70,10 +70,10 @@ function test_vartransform(ν, μ) supertype(x::Real) = Real supertype(x::AbstractArray{<:Real,N}) where N = AbstractArray{<:Real,N} - @testset "vartransform $μ to $ν" begin + @testset "transport_to $μ to $ν" begin x = rand(μ) - @test !(@inferred(vartransform(ν, μ)(x)) isa NoVarTransform) - f = vartransform(ν, μ) + @test !(@inferred(transport_to(ν, μ)(x)) isa NoVarTransform) + f = transport_to(ν, μ) y = f(x) @test @inferred(inverse(f)(y)) ≈ x @test @inferred(with_logabsdet_jacobian(f, x)) isa Tuple{supertype(y),Real} diff --git a/src/standard/stdexponential.jl b/src/standard/stdexponential.jl index 2b7b81a0..76442639 100644 --- a/src/standard/stdexponential.jl +++ b/src/standard/stdexponential.jl @@ -7,8 +7,8 @@ insupport(d::StdExponential, x) = x ≥ zero(x) @inline logdensity_def(::StdExponential, x) = -x @inline basemeasure(::StdExponential) = Lebesgue() -@inline vartransform_def(::StdUniform, μ::StdExponential, x) = - expm1(-x) -@inline vartransform_def(::StdExponential, μ::StdUniform, x) = - log1p(-x) +@inline transport_def(::StdUniform, μ::StdExponential, x) = - expm1(-x) +@inline transport_def(::StdExponential, μ::StdUniform, x) = - log1p(-x) Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdExponential) where {T} = randexp(rng, T) diff --git a/src/standard/stdlogistic.jl b/src/standard/stdlogistic.jl index 8ddb0137..922cb2c6 100644 --- a/src/standard/stdlogistic.jl +++ b/src/standard/stdlogistic.jl @@ -7,7 +7,7 @@ export StdLogistic @inline logdensity_def(::StdLogistic, x) = (u = -abs(x); u - 2*log1pexp(u)) @inline basemeasure(::StdLogistic) = Lebesgue() -@inline vartransform_def(::StdUniform, μ::StdLogistic, x) = logistic(x) -@inline vartransform_def(::StdLogistic, μ::StdUniform, x) = logit(x) +@inline transport_def(::StdUniform, μ::StdLogistic, x) = logistic(x) +@inline transport_def(::StdLogistic, μ::StdUniform, x) = logit(x) @inline Base.rand(rng::Random.AbstractRNG, ::Type{T}, ::StdLogistic) where {T} = logit(rand(rng, T)) diff --git a/src/standard/stdmeasure.jl b/src/standard/stdmeasure.jl index 8ab68ecd..0dba932b 100644 --- a/src/standard/stdmeasure.jl +++ b/src/standard/stdmeasure.jl @@ -7,39 +7,39 @@ StdMeasure(::typeof(randexp)) = StdExponential() @inline check_dof(::StdMeasure, ::StdMeasure) = nothing -@inline vartransform_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x +@inline transport_def(::MU, μ::MU, x) where {MU<:StdMeasure} = x -function vartransform_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) - return vartransform_def(ν, μ.parent, only(x)) +function transport_def(ν::StdMeasure, μ::PowerMeasure{<:StdMeasure}, x) + return transport_def(ν, μ.parent, only(x)) end -function vartransform_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) - return Fill(vartransform_def(ν.parent, μ, only(x)), map(length, ν.axes)...) +function transport_def(ν::PowerMeasure{<:StdMeasure}, μ::StdMeasure, x) + return Fill(transport_def(ν.parent, μ, only(x)), map(length, ν.axes)...) end -function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x) - return vartransform(ν.parent, μ.parent).(x) +function transport_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{1,Base.OneTo}}, x) + return transport_to(ν.parent, μ.parent).(x) end -function vartransform_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M} - return reshape(vartransform(ν.parent, μ.parent).(x), map(length, ν.axes)...) +function transport_def(ν::PowerMeasure{<:StdMeasure,<:NTuple{N,Base.OneTo}}, μ::PowerMeasure{<:StdMeasure,<:NTuple{M,Base.OneTo}}, x) where {N,M} + return reshape(transport_to(ν.parent, μ.parent).(x), map(length, ν.axes)...) end -# Implement vartransform(NU::Type{<:StdMeasure}, μ) and vartransform(ν, MU::Type{<:StdMeasure}): +# Implement transport_to(NU::Type{<:StdMeasure}, μ) and transport_to(ν, MU::Type{<:StdMeasure}): _std_measure(::Type{M}, ::StaticInt{1}) where {M<:StdMeasure} = M() _std_measure(::Type{M}, dof::Integer) where {M<:StdMeasure} = M()^dof _std_measure_for(::Type{M}, μ::Any) where {M<:StdMeasure} = _std_measure(M, getdof(μ)) -MeasureBase.vartransform(::Type{NU}, μ) where {NU<:StdMeasure} = vartransform(_std_measure_for(NU, μ), μ) -MeasureBase.vartransform(ν, ::Type{MU}) where {MU<:StdMeasure} = vartransform(ν, _std_measure_for(MU, ν)) +MeasureBase.transport_to(::Type{NU}, μ) where {NU<:StdMeasure} = transport_to(_std_measure_for(NU, μ), μ) +MeasureBase.transport_to(ν, ::Type{MU}) where {MU<:StdMeasure} = transport_to(ν, _std_measure_for(MU, ν)) # Transform between standard measures and Dirac: -@inline vartransform_def(ν::Dirac, ::PowerMeasure{<:MeasureBase.StdMeasure}, ::Any) = ν.x +@inline transport_def(ν::Dirac, ::PowerMeasure{<:MeasureBase.StdMeasure}, ::Any) = ν.x -@inline function vartransform_def(ν::PowerMeasure{<:MeasureBase.StdMeasure}, ::Dirac, ::Any) +@inline function transport_def(ν::PowerMeasure{<:MeasureBase.StdMeasure}, ::Dirac, ::Any) Zeros{Bool}(map(_ -> 0, ν.axes)) end diff --git a/src/vartransform.jl b/src/transport.jl similarity index 88% rename from src/vartransform.jl rename to src/transport.jl index 5bd859a7..5d83647a 100644 --- a/src/vartransform.jl +++ b/src/transport.jl @@ -50,7 +50,7 @@ struct NoVarTransform{NU,MU} end """ - f = vartransform(ν, μ) + f = transport_to(ν, μ) Generates a [measurable function](https://en.wikipedia.org/wiki/Measurable_function) `f` that transforms a value `x` distributed according to measure `μ` to @@ -72,8 +72,8 @@ Returns NoTransformOrigin{typeof(ν),typeof(μ)} if no transformation from To add transformation rules for a measure type `MyMeasure`, specialize -* `MeasureBase.vartransform_def(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` -* `MeasureBase.vartransform_def(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` +* `MeasureBase.transport_def(ν::SomeStdMeasure, μ::CustomMeasure, x) = ...` +* `MeasureBase.transport_def(ν::MyMeasure, μ::SomeStdMeasure, x) = ...` and/or @@ -95,17 +95,17 @@ Depending on [`getdof(μ)`](@ref) (resp. `ν`), an instance of the standard distribution itself or a power of it (e.g. `StdUniform()` or `StdUniform()^dof`) will be chosen as the transformation partner. """ -function vartransform end +function transport_to end """ - vartransform_def(ν, μ, x) + transport_def(ν, μ, x) Transforms a value `x` distributed according to `μ` to a value `y` distributed according to `ν`. -If no specialized `vartransform_def(::MU, ::NU, ...)` is available then -the default implementation of`vartransform_def(ν, μ, x)` uses the following +If no specialized `transport_def(::MU, ::NU, ...)` is available then +the default implementation of`transport_def(ν, μ, x)` uses the following strategy: * Evaluate [`vartransform_origin`](@ref) for μ and ν. Transform between @@ -115,14 +115,14 @@ strategy: * If all else fails, try to transform from μ to a standard multivariate uniform measure and then to ν. -See [`vartransform`](@ref). +See [`transport_to`](@ref). """ -function vartransform_def end +function transport_def end -vartransform_def(::Any, ::Any, x::NoTransformOrigin) = x -vartransform_def(::Any, ::Any, x::NoVarTransform) = x +transport_def(::Any, ::Any, x::NoTransformOrigin) = x +transport_def(::Any, ::Any, x::NoVarTransform) = x -function vartransform_def(ν, μ, x) +function transport_def(ν, μ, x) _vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x) end @@ -141,13 +141,13 @@ end function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x) x_o = to_origin(μ, x) # If μ is a pushforward then checked_var may have been bypassed, so check now: - y_o = vartransform_def(ν_o, μ_o, checked_var(μ_o, x_o)) + y_o = transport_def(ν_o, μ_o, checked_var(μ_o, x_o)) y = from_origin(ν, y_o) return y end function _vartransform_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x) - y_o = vartransform_def(ν_o, μ, x) + y_o = transport_def(ν_o, μ, x) y = from_origin(ν, y_o) return y end @@ -155,7 +155,7 @@ end function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x) x_o = to_origin(μ, x) # If μ is a pushforward then checked_var may have been bypassed, so check now: - y = vartransform_def(ν, μ_o, checked_var(μ_o, x_o)) + y = transport_def(ν, μ_o, checked_var(μ_o, x_o)) return y end @@ -169,8 +169,8 @@ end @inline _vartransform_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform() function _vartransform_with_intermediate(ν, m, μ, x) - z = vartransform_def(m, μ, x) - y = vartransform_def(ν, m, z) + z = transport_def(m, μ, x) + y = transport_def(ν, m, z) return y end @@ -185,7 +185,7 @@ end Transforms a variate from one measure to a variate of another. In general `VarTransformation` should not be called directly, call -[`vartransform`](@ref) instead. +[`transport_to`](@ref) instead. """ struct VarTransformation{NU,MU} <: Function ν::NU @@ -201,7 +201,7 @@ struct VarTransformation{NU,MU} <: Function end end -@inline vartransform(ν, μ) = VarTransformation(ν, μ) +@inline transport_to(ν, μ) = VarTransformation(ν, μ) function Base.:(==)(a::VarTransformation, b::VarTransformation) return a.ν == b.ν && a.μ == b.μ @@ -209,7 +209,7 @@ end Base.@propagate_inbounds function (f::VarTransformation)(x) - return vartransform_def(f.ν, f.μ, checked_var(f.μ, x)) + return transport_def(f.ν, f.μ, checked_var(f.μ, x)) end @inline function InverseFunctions.inverse(f::VarTransformation{NU,MU}) where {NU,MU} diff --git a/test/combinators/transformedmeasure.jl b/test/combinators/transformedmeasure.jl index 7097d038..d4a9fcaa 100644 --- a/test/combinators/transformedmeasure.jl +++ b/test/combinators/transformedmeasure.jl @@ -2,7 +2,7 @@ using Test using MeasureBase: pushfwd, StdUniform, StdExponential, StdLogistic using MeasureBase: pushfwd, PushforwardMeasure -using MeasureBase: vartransform +using MeasureBase: transport_to using Statistics: var using DensityInterface: logdensityof @@ -17,5 +17,5 @@ using DensityInterface: logdensityof @test isapprox(var(rand(ν^(10^5))), 1, rtol = 0.05) - @test vartransform(StdLogistic(), ν)(y) ≈ vartransform(StdLogistic(), ν)(y) + @test transport_to(StdLogistic(), ν)(y) ≈ transport_to(StdLogistic(), ν)(y) end diff --git a/test/runtests.jl b/test/runtests.jl index c8f06d36..22dfb993 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -234,7 +234,7 @@ end end include("getdof.jl") -include("vartransform.jl") +include("transport.jl") include("combinators/weighted.jl") include("combinators/transformedmeasure.jl") diff --git a/test/vartransform.jl b/test/transport.jl similarity index 51% rename from test/vartransform.jl rename to test/transport.jl index b808b4eb..0cb6e55a 100644 --- a/test/vartransform.jl +++ b/test/transport.jl @@ -1,13 +1,13 @@ using Test -using MeasureBase.Interface: vartransform, test_vartransform +using MeasureBase.Interface: transport_to, test_vartransform using MeasureBase: StdUniform, StdExponential, StdLogistic using MeasureBase: Dirac -@testset "vartransform" begin +@testset "transport_to" begin for μ0 in [StdUniform(), StdExponential(), StdLogistic()], ν0 in [StdUniform(), StdExponential(), StdLogistic()] - @testset "vartransform (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin + @testset "transport_to (variations of) $(nameof(typeof(μ0))) to $(nameof(typeof(ν0)))" begin test_vartransform(ν0, μ0) test_vartransform(2.2 * ν0, 3 * μ0) test_vartransform(ν0, μ0^1) @@ -15,8 +15,8 @@ using MeasureBase: Dirac test_vartransform(ν0^3, μ0^3) test_vartransform(ν0^(2,3,2), μ0^(3,4)) test_vartransform(2.2 * ν0^(2,3,2), 3 * μ0^(3,4)) - @test_throws ArgumentError vartransform(ν0, μ0)(rand(μ0^12)) - @test_throws ArgumentError vartransform(ν0^3, μ0^3)(rand(μ0^(3,4))) + @test_throws ArgumentError transport_to(ν0, μ0)(rand(μ0^12)) + @test_throws ArgumentError transport_to(ν0^3, μ0^3)(rand(μ0^(3,4))) end end @@ -26,14 +26,14 @@ using MeasureBase: Dirac test_vartransform(StdExponential()^(0,0,0), μ) test_vartransform(μ, StdExponential()^static(0)) test_vartransform(μ, StdExponential()^(static(0),static(0))) - @test_throws ArgumentError vartransform(StdExponential()^1, μ) - @test_throws ArgumentError vartransform(μ, StdExponential()^1) + @test_throws ArgumentError transport_to(StdExponential()^1, μ) + @test_throws ArgumentError transport_to(μ, StdExponential()^1) end - @testset "vartransform autosel" begin - @test @inferred(vartransform(StdExponential, StdUniform())) == vartransform(StdExponential(), StdUniform()) - @test @inferred(vartransform(StdExponential, StdUniform()^(2,3))) == vartransform(StdExponential()^6, StdUniform()^(2,3)) - @test @inferred(vartransform(StdUniform(), StdExponential)) == vartransform(StdUniform(), StdExponential()) - @test @inferred(vartransform(StdUniform()^(2,3), StdExponential)) == vartransform(StdUniform()^(2,3), StdExponential()^6) + @testset "transport_to autosel" begin + @test @inferred(transport_to(StdExponential, StdUniform())) == transport_to(StdExponential(), StdUniform()) + @test @inferred(transport_to(StdExponential, StdUniform()^(2,3))) == transport_to(StdExponential()^6, StdUniform()^(2,3)) + @test @inferred(transport_to(StdUniform(), StdExponential)) == transport_to(StdUniform(), StdExponential()) + @test @inferred(transport_to(StdUniform()^(2,3), StdExponential)) == transport_to(StdUniform()^(2,3), StdExponential()^6) end end From 343d20b6aca828b48e4c948ed602925a67b9c9ef Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 15:39:02 -0400 Subject: [PATCH 69/70] Rename vartransform_origin --- src/combinators/transformedmeasure.jl | 10 +++++----- src/combinators/weighted.jl | 2 +- src/transport.jl | 18 +++++++++--------- 3 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/combinators/transformedmeasure.jl b/src/combinators/transformedmeasure.jl index 56db1f1d..ee471861 100644 --- a/src/combinators/transformedmeasure.jl +++ b/src/combinators/transformedmeasure.jl @@ -63,12 +63,12 @@ end end -insupport(ν::PushforwardMeasure, y) = insupport(vartransform_origin(ν), to_origin(ν, y)) +insupport(ν::PushforwardMeasure, y) = insupport(transport_origin(ν), to_origin(ν, y)) -testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(vartransform_origin(ν))) +testvalue(ν::PushforwardMeasure) = from_origin(ν, testvalue(transport_origin(ν))) @inline function basemeasure(ν::PushforwardMeasure) - PushforwardMeasure(ν.f, ν.inv_f, basemeasure(vartransform_origin(ν)), NoVolCorr()) + PushforwardMeasure(ν.f, ν.inv_f, basemeasure(transport_origin(ν)), NoVolCorr()) end @@ -86,12 +86,12 @@ end @inline checked_var(::PushforwardMeasure, x) = x -@inline vartransform_origin(ν::PushforwardMeasure) = ν.origin +@inline transport_origin(ν::PushforwardMeasure) = ν.origin @inline from_origin(ν::PushforwardMeasure, x) = ν.f(x) @inline to_origin(ν::PushforwardMeasure, y) = ν.inv_f(y) function Base.rand(rng::AbstractRNG, ::Type{T}, ν::PushforwardMeasure) where T - return from_origin(ν, rand(rng, T, vartransform_origin(ν))) + return from_origin(ν, rand(rng, T, transport_origin(ν))) end diff --git a/src/combinators/weighted.jl b/src/combinators/weighted.jl index 22b9e440..aef9dbee 100644 --- a/src/combinators/weighted.jl +++ b/src/combinators/weighted.jl @@ -49,6 +49,6 @@ gentype(μ::WeightedMeasure) = gentype(μ.base) insupport(μ::WeightedMeasure, x) = insupport(μ.base, x) -vartransform_origin(ν::WeightedMeasure) = ν.base +transport_origin(ν::WeightedMeasure) = ν.base to_origin(::WeightedMeasure, y) = y from_origin(::WeightedMeasure, x) = x diff --git a/src/transport.jl b/src/transport.jl index 5d83647a..e3970f55 100644 --- a/src/transport.jl +++ b/src/transport.jl @@ -4,26 +4,26 @@ Indicates that no (default) pullback measure is available for measures of type `NU`. -See [`MeasureBase.vartransform_origin`](@ref). +See [`MeasureBase.transport_origin`](@ref). """ struct NoTransformOrigin{NU} end """ - MeasureBase.vartransform_origin(ν) + MeasureBase.transport_origin(ν) Default measure to pullback to resp. pushforward from when transforming between `ν` and another measure. """ -function vartransform_origin end +function transport_origin end -vartransform_origin(ν::NU) where NU = NoTransformOrigin{NU}() +transport_origin(ν::NU) where NU = NoTransformOrigin{NU}() """ MeasureBase.from_origin(ν, x) -Push `x` from `MeasureBase.vartransform_origin(μ)` forward to `ν`. +Push `x` from `MeasureBase.transport_origin(μ)` forward to `ν`. """ function from_origin end @@ -33,7 +33,7 @@ from_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}() """ MeasureBase.to_origin(ν, y) -Pull `y` from `ν` back to `MeasureBase.vartransform_origin(ν)`. +Pull `y` from `ν` back to `MeasureBase.transport_origin(ν)`. """ function to_origin end @@ -77,7 +77,7 @@ To add transformation rules for a measure type `MyMeasure`, specialize and/or -* `MeasureBase.vartransform_origin(ν::MyMeasure) = SomeMeasure(...)` +* `MeasureBase.transport_origin(ν::MyMeasure) = SomeMeasure(...)` * `MeasureBase.from_origin(μ::MyMeasure, x) = y` * `MeasureBase.to_origin(μ::MyMeasure, y) = x` @@ -108,7 +108,7 @@ If no specialized `transport_def(::MU, ::NU, ...)` is available then the default implementation of`transport_def(ν, μ, x)` uses the following strategy: -* Evaluate [`vartransform_origin`](@ref) for μ and ν. Transform between +* Evaluate [`transport_origin`](@ref) for μ and ν. Transform between each and it's origin, if available, and use the origin(s) as intermediate measures for another transformation. @@ -133,7 +133,7 @@ function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU end @inline function _checked_vartransform_origin(μ::MU) where MU - μ_o = vartransform_origin(μ) + μ_o = transport_origin(μ) _origin_must_have_separate_type(MU, μ_o) end From a8faa6625c716063bb372cf58ddd96d9389e5bbf Mon Sep 17 00:00:00 2001 From: Oliver Schulz Date: Sun, 19 Jun 2022 16:14:17 -0400 Subject: [PATCH 70/70] Increase package version to v0.11.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 9ef0d01f..fbce68d7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MeasureBase" uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14" authors = ["Chad Scherrer and contributors"] -version = "0.10.0" +version = "0.11.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"