Skip to content

Commit f7b4444

Browse files
authored
Address some transform-related naming inconsistencies (#63)
* Rename VarTransformation to TransportFunction * Rename test_vartransform to test_transport * Rename NoVarTransform to NoTransport * Rename checked_var to checked_arg * Rename NoVarCheck to NoArgCheck * Clarify NoVolCorr and WithVolCorr docstrings
1 parent 648376d commit f7b4444

File tree

10 files changed

+82
-82
lines changed

10 files changed

+82
-82
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "MeasureBase"
22
uuid = "fa1605e6-acd5-459c-a1e6-7e635759db14"
33
authors = ["Chad Scherrer <[email protected]> and contributors"]
4-
version = "0.11.0"
4+
version = "0.12.0"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/combinators/power.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ end
114114
@inline getdof(::PowerMeasure{<:Any, NTuple{N,Base.OneTo{StaticInt{0}}}}) where N = static(0)
115115

116116

117-
@propagate_inbounds function checked_var::PowerMeasure, x::AbstractArray{<:Any})
117+
@propagate_inbounds function checked_arg::PowerMeasure, x::AbstractArray{<:Any})
118118
@boundscheck begin
119119
sz_μ = map(length, μ.axes)
120120
sz_x = size(x)
@@ -125,6 +125,6 @@ end
125125
return x
126126
end
127127

128-
function checked_var::PowerMeasure, x::Any)
128+
function checked_arg::PowerMeasure, x::Any)
129129
throw(ArgumentError("Size of variate doesn't match size of power measure"))
130130
end

src/combinators/transformedmeasure.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,8 @@ _pushfwd_dof(::Type{MU}, ::Type{<:Tuple{Any,Real}}, dof) where MU = dof
8282
_pushfwd_dof(MU, R, getdof.origin))
8383
end
8484

85-
# Bypass `checked_var`, would require potentially costly transformation:
86-
@inline checked_var(::PushforwardMeasure, x) = x
85+
# Bypass `checked_arg`, would require potentially costly transformation:
86+
@inline checked_arg(::PushforwardMeasure, x) = x
8787

8888

8989
@inline transport_origin::PushforwardMeasure) = ν.origin

src/getdof.jl

+10-10
Original file line numberDiff line numberDiff line change
@@ -51,27 +51,27 @@ ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_do
5151

5252

5353
"""
54-
MeasureBase.NoVarCheck{MU,T}
54+
MeasureBase.NoArgCheck{MU,T}
5555
5656
Indicates that there is no way to check of a values of type `T` are
5757
variate of measures of type `MU`.
5858
"""
59-
struct NoVarCheck{MU,T} end
59+
struct NoArgCheck{MU,T} end
6060

6161

6262
"""
63-
MeasureBase.checked_var(μ::MU, x::T)::T
63+
MeasureBase.checked_arg(μ::MU, x::T)::T
6464
6565
Return `x` if `x` is a valid variate of `μ`, throw an `ArgumentError` if not,
66-
return `NoVarCheck{MU,T}()` if not check can be performed.
66+
return `NoArgCheck{MU,T}()` if not check can be performed.
6767
"""
68-
function checked_var end
68+
function checked_arg end
6969

7070
# Prevent infinite recursion:
71-
@propagate_inbounds _default_checked_var(::Type{MU}, ::MU, ::T) where {MU,T} = NoVarCheck{MU,T}
72-
@propagate_inbounds _default_checked_var(::Type{MU}, mu_base, x) where MU = checked_var(mu_base, x)
71+
@propagate_inbounds _default_checked_arg(::Type{MU}, ::MU, ::T) where {MU,T} = NoArgCheck{MU,T}
72+
@propagate_inbounds _default_checked_arg(::Type{MU}, mu_base, x) where MU = checked_arg(mu_base, x)
7373

74-
@propagate_inbounds checked_var(mu::MU, x) where MU = _default_checked_var(MU, basemeasure(mu), x)
74+
@propagate_inbounds checked_arg(mu::MU, x) where MU = _default_checked_arg(MU, basemeasure(mu), x)
7575

76-
_checked_var_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
77-
ChainRulesCore.rrule(::typeof(checked_var), ν, x) = checked_var(ν, x), _checked_var_pullback
76+
_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
77+
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback

src/interface.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ using Reexport
66

77
using MeasureBase: basemeasure_depth, proxy
88
using MeasureBase: insupport, basemeasure_sequence, commonbase
9-
using MeasureBase: transport_to, NoVarTransform
9+
using MeasureBase: transport_to, NoTransport
1010

1111
using DensityInterface: logdensityof
1212
using InverseFunctions: inverse
1313
using ChangesOfVariables: with_logabsdet_jacobian
1414

1515
export test_interface
16-
export test_vartransform
16+
export test_transport
1717
export basemeasure_depth
1818
export proxy
1919
export insupport
@@ -66,13 +66,13 @@ function test_interface(μ::M) where {M}
6666
end
6767

6868

69-
function test_vartransform(ν, μ)
69+
function test_transport(ν, μ)
7070
supertype(x::Real) = Real
7171
supertype(x::AbstractArray{<:Real,N}) where N = AbstractArray{<:Real,N}
7272

7373
@testset "transport_to to " begin
7474
x = rand(μ)
75-
@test !(@inferred(transport_to(ν, μ)(x)) isa NoVarTransform)
75+
@test !(@inferred(transport_to(ν, μ)(x)) isa NoTransport)
7676
f = transport_to(ν, μ)
7777
y = f(x)
7878
@test @inferred(inverse(f)(y)) x

src/primitives/dirac.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ insupport(d::Dirac, x) = x == d.x
3232

3333
@inline getdof(::Dirac) = static(0)
3434

35-
@propagate_inbounds function checked_var::Dirac, x)
35+
@propagate_inbounds function checked_arg::Dirac, x)
3636
@boundscheck insupport(μ, x) || throw(ArgumentError("Invalid variate for measure"))
3737
x
3838
end

src/primitives/lebesgue.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ logdensity_def(::CountingMeasure, ::LebesgueMeasure, x) = Inf
4343

4444
@inline getdof(::Lebesgue) = static(1)
4545

46-
@inline checked_var(::Lebesgue, x::Real) = x
46+
@inline checked_arg(::Lebesgue, x::Real) = x
4747

48-
@propagate_inbounds function checked_var(::Lebesgue, x::Any)
48+
@propagate_inbounds function checked_arg(::Lebesgue, x::Any)
4949
@boundscheck throw(ArgumentError("Invalid variate type for measure"))
5050
end

src/transport.jl

+41-41
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,12 @@ to_origin(ν::NU, ::Any) where NU = NoTransformOrigin{NU}(ν)
4141

4242

4343
"""
44-
struct MeasureBase.NoVarTransform{NU,MU} end
44+
struct MeasureBase.NoTransport{NU,MU} end
4545
4646
Indicates that no transformation from a measure of type `MU` to a measure of
4747
type `NU` could be found.
4848
"""
49-
struct NoVarTransform{NU,MU} end
49+
struct NoTransport{NU,MU} end
5050

5151

5252
"""
@@ -120,10 +120,10 @@ See [`transport_to`](@ref).
120120
function transport_def end
121121

122122
transport_def(::Any, ::Any, x::NoTransformOrigin) = x
123-
transport_def(::Any, ::Any, x::NoVarTransform) = x
123+
transport_def(::Any, ::Any, x::NoTransport) = x
124124

125125
function transport_def(ν, μ, x)
126-
_vartransform_with_intermediate(ν, _checked_vartransform_origin(ν), _checked_vartransform_origin(μ), μ, x)
126+
_transport_with_intermediate(ν, _checked_transport_origin(ν), _checked_transport_origin(μ), μ, x)
127127
end
128128

129129

@@ -132,92 +132,92 @@ function _origin_must_have_separate_type(::Type{MU}, μ_o::MU) where MU
132132
throw(ArgumentError("Measure of type $MU and its origin must have separate types"))
133133
end
134134

135-
@inline function _checked_vartransform_origin::MU) where MU
135+
@inline function _checked_transport_origin::MU) where MU
136136
μ_o = transport_origin(μ)
137137
_origin_must_have_separate_type(MU, μ_o)
138138
end
139139

140140

141-
function _vartransform_with_intermediate(ν, ν_o, μ_o, μ, x)
141+
function _transport_with_intermediate(ν, ν_o, μ_o, μ, x)
142142
x_o = to_origin(μ, x)
143-
# If μ is a pushforward then checked_var may have been bypassed, so check now:
144-
y_o = transport_def(ν_o, μ_o, checked_var(μ_o, x_o))
143+
# If μ is a pushforward then checked_arg may have been bypassed, so check now:
144+
y_o = transport_def(ν_o, μ_o, checked_arg(μ_o, x_o))
145145
y = from_origin(ν, y_o)
146146
return y
147147
end
148148

149-
function _vartransform_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x)
149+
function _transport_with_intermediate(ν, ν_o, ::NoTransformOrigin, μ, x)
150150
y_o = transport_def(ν_o, μ, x)
151151
y = from_origin(ν, y_o)
152152
return y
153153
end
154154

155-
function _vartransform_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x)
155+
function _transport_with_intermediate(ν, ::NoTransformOrigin, μ_o, μ, x)
156156
x_o = to_origin(μ, x)
157-
# If μ is a pushforward then checked_var may have been bypassed, so check now:
158-
y = transport_def(ν, μ_o, checked_var(μ_o, x_o))
157+
# If μ is a pushforward then checked_arg may have been bypassed, so check now:
158+
y = transport_def(ν, μ_o, checked_arg(μ_o, x_o))
159159
return y
160160
end
161161

162-
function _vartransform_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x)
163-
_vartransform_with_intermediate(ν, _vartransform_intermediate(ν, μ), μ, x)
162+
function _transport_with_intermediate(ν, ::NoTransformOrigin, ::NoTransformOrigin, μ, x)
163+
_transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x)
164164
end
165165

166166

167-
@inline _vartransform_intermediate(ν, μ) = _vartransform_intermediate(getdof(ν), getdof(μ))
168-
@inline _vartransform_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ
169-
@inline _vartransform_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform()
167+
@inline _transport_intermediate(ν, μ) = _transport_intermediate(getdof(ν), getdof(μ))
168+
@inline _transport_intermediate(::Integer, n_μ::Integer) = StdUniform()^n_μ
169+
@inline _transport_intermediate(::StaticInt{1}, ::StaticInt{1}) = StdUniform()
170170

171-
function _vartransform_with_intermediate(ν, m, μ, x)
171+
function _transport_with_intermediate(ν, m, μ, x)
172172
z = transport_def(m, μ, x)
173173
y = transport_def(ν, m, z)
174174
return y
175175
end
176176

177177
# Prevent infinite recursion in case vartransform_intermediate doesn't change type:
178-
@inline _vartransform_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}()
179-
@inline _vartransform_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoVarTransform{NU,MU}()
178+
@inline _transport_with_intermediate(::NU, ::NU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}()
179+
@inline _transport_with_intermediate(::NU, ::MU, ::MU, ::Any) where {NU,MU} = NoTransport{NU,MU}()
180180

181181

182182
"""
183-
struct VarTransformation <: Function
183+
struct TransportFunction <: Function
184184
185185
Transforms a variate from one measure to a variate of another.
186186
187-
In general `VarTransformation` should not be called directly, call
187+
In general `TransportFunction` should not be called directly, call
188188
[`transport_to`](@ref) instead.
189189
"""
190-
struct VarTransformation{NU,MU} <: Function
190+
struct TransportFunction{NU,MU} <: Function
191191
ν::NU
192192
μ::MU
193193

194-
function VarTransformation{NU,MU}::NU, μ::MU) where {NU,MU}
194+
function TransportFunction{NU,MU}::NU, μ::MU) where {NU,MU}
195195
return new{NU,MU}(ν, μ)
196196
end
197197

198-
function VarTransformation::NU, μ::MU) where {NU,MU}
198+
function TransportFunction::NU, μ::MU) where {NU,MU}
199199
check_dof(ν, μ)
200200
return new{NU,MU}(ν, μ)
201201
end
202202
end
203203

204-
@inline transport_to(ν, μ) = VarTransformation(ν, μ)
204+
@inline transport_to(ν, μ) = TransportFunction(ν, μ)
205205

206-
function Base.:(==)(a::VarTransformation, b::VarTransformation)
206+
function Base.:(==)(a::TransportFunction, b::TransportFunction)
207207
return a.ν == b.ν && a.μ == b.μ
208208
end
209209

210210

211-
Base.@propagate_inbounds function (f::VarTransformation)(x)
212-
return transport_def(f.ν, f.μ, checked_var(f.μ, x))
211+
Base.@propagate_inbounds function (f::TransportFunction)(x)
212+
return transport_def(f.ν, f.μ, checked_arg(f.μ, x))
213213
end
214214

215-
@inline function InverseFunctions.inverse(f::VarTransformation{NU,MU}) where {NU,MU}
216-
return VarTransformation{MU,NU}(f.μ, f.ν)
215+
@inline function InverseFunctions.inverse(f::TransportFunction{NU,MU}) where {NU,MU}
216+
return TransportFunction{MU,NU}(f.μ, f.ν)
217217
end
218218

219219

220-
function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x)
220+
function ChangesOfVariables.with_logabsdet_jacobian(f::TransportFunction, x)
221221
y = f(x)
222222
logpdf_src = logdensityof(f.μ, x)
223223
logpdf_trg = logdensityof(f.ν, y)
@@ -228,26 +228,26 @@ function ChangesOfVariables.with_logabsdet_jacobian(f::VarTransformation, x)
228228
end
229229

230230

231-
Base.:()(::typeof(identity), f::VarTransformation) = f
232-
Base.:()(f::VarTransformation, ::typeof(identity)) = f
231+
Base.:()(::typeof(identity), f::TransportFunction) = f
232+
Base.:()(f::TransportFunction, ::typeof(identity)) = f
233233

234-
function Base.:(outer::VarTransformation, inner::VarTransformation)
234+
function Base.:(outer::TransportFunction, inner::TransportFunction)
235235
if !(outer.μ == inner.ν || isequal(outer.μ, inner.ν) || outer.μ inner.ν)
236-
throw(ArgumentError("Cannot compose VarTransformation if source of outer doesn't equal target of inner."))
236+
throw(ArgumentError("Cannot compose TransportFunction if source of outer doesn't equal target of inner."))
237237
end
238-
return VarTransformation(outer.ν, inner.μ)
238+
return TransportFunction(outer.ν, inner.μ)
239239
end
240240

241241

242-
function Base.show(io::IO, f::VarTransformation)
242+
function Base.show(io::IO, f::TransportFunction)
243243
print(io, Base.typename(typeof(f)).name, "(")
244244
show(io, f.ν)
245245
print(io, ", ")
246246
show(io, f.μ)
247247
print(io, ")")
248248
end
249249

250-
Base.show(io::IO, M::MIME"text/plain", f::VarTransformation) = show(io, f)
250+
Base.show(io::IO, M::MIME"text/plain", f::TransportFunction) = show(io, f)
251251

252252

253253
"""
@@ -262,7 +262,7 @@ abstract type TransformVolCorr end
262262
NoVolCorr()
263263
264264
Indicate that density calculations should ignore the volume element of
265-
var transformations. Should only be used in special cases in which
265+
variate transformations. Should only be used in special cases in which
266266
the volume element has already been taken into account in a different
267267
way.
268268
"""
@@ -272,7 +272,7 @@ struct NoVolCorr <: TransformVolCorr end
272272
WithVolCorr()
273273
274274
Indicate that density calculations should take the volume element of
275-
var transformations into account (typically via the
275+
variate transformations into account (typically via the
276276
log-abs-det-Jacobian of the transform).
277277
"""
278278
struct WithVolCorr <: TransformVolCorr end

test/getdof.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using Test
22

3-
using MeasureBase: getdof, check_dof, checked_var
3+
using MeasureBase: getdof, check_dof, checked_arg
44
using MeasureBase: StdUniform, StdExponential, StdLogistic
55
using ChainRulesTestUtils: test_rrule
66
using Static: static
@@ -18,18 +18,18 @@ using Static: static
1818
@test_throws ArgumentError check_dof(μ2, μ0)
1919
test_rrule(check_dof, μ0, StdUniform())
2020

21-
@test @inferred(checked_var(μ0, x0)) === x0
22-
@test_throws ArgumentError checked_var(μ0, x2)
23-
test_rrule(checked_var, μ0, x0)
21+
@test @inferred(checked_arg(μ0, x0)) === x0
22+
@test_throws ArgumentError checked_arg(μ0, x2)
23+
test_rrule(checked_arg, μ0, x0)
2424

2525
@test @inferred(getdof(μ2)) == 6
2626
@test (check_dof(μ2, StdUniform()^(1,6,1)); true)
2727
@test_throws ArgumentError check_dof(μ2, μ0)
2828
test_rrule(check_dof, μ2, StdUniform()^(1,6,1))
2929

30-
@test @inferred(checked_var(μ2, x2)) === x2
31-
@test_throws ArgumentError checked_var(μ2, x0)
32-
test_rrule(checked_var, μ2, x2)
30+
@test @inferred(checked_arg(μ2, x2)) === x2
31+
@test_throws ArgumentError checked_arg(μ2, x0)
32+
test_rrule(checked_arg, μ2, x2)
3333

3434
@test @inferred(getdof((StdExponential()^3)^(static(0),static(0)))) === static(0)
3535
end

0 commit comments

Comments
 (0)