Skip to content

Commit 03bdffb

Browse files
sethaxengithub-actions[bot]devmotionyebaitorfjelde
authored
Make SimplexBijector actually bijective (#263)
* Remove unused proj field * Update simplex bijector calls * Update simplex jacobian calls * Remove proj type entry * Compute logdetjac from square part of jacobian * Increment minor version number * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Update test/interface.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed link and invlink for SimplexBijector * Update src/Bijectors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * super-hacky fix to size issue of TransformedDistribution * added fixme comment * removed redundant constructor for Stacked * added implementation of output_size for SimplexBijector * Update src/bijectors/simplex.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fixed tests * removed more references to old SimplexBijector code * fixed more dirichlet tests * formatting * possibly fixed weird formatting complaints * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: Tor Erlend Fjelde <[email protected]>
1 parent 2147089 commit 03bdffb

File tree

6 files changed

+92
-141
lines changed

6 files changed

+92
-141
lines changed

ext/BijectorsDistributionsADExt.jl

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,19 @@ Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true
7878
Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true
7979
Bijectors.isdirichlet(::TuringDirichlet) = true
8080

81-
function Bijectors.link(
82-
d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
83-
) where {proj}
84-
return Bijectors.SimplexBijector{proj}()(x)
81+
function Bijectors.link(d::TuringDirichlet, x::AbstractVecOrMat{<:Real})
82+
return Bijectors.SimplexBijector()(x)
8583
end
8684

87-
function Bijectors.link_jacobian(
88-
d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true)
89-
) where {proj}
90-
return jacobian(Bijectors.SimplexBijector{proj}(), x)
85+
function Bijectors.link_jacobian(d::TuringDirichlet, x::AbstractVector{<:Real})
86+
return jacobian(Bijectors.SimplexBijector(), x)
9187
end
9288

93-
function Bijectors.invlink(
94-
d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
95-
) where {proj}
96-
return inverse(Bijectors.SimplexBijector{proj}())(y)
89+
function Bijectors.invlink(d::TuringDirichlet, y::AbstractVecOrMat{<:Real})
90+
return inverse(Bijectors.SimplexBijector())(y)
9791
end
98-
function Bijectors.invlink_jacobian(
99-
d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true)
100-
) where {proj}
101-
return jacobian(inverse(Bijectors.SimplexBijector{proj}()), y)
92+
function Bijectors.invlink_jacobian(d::TuringDirichlet, y::AbstractVector{<:Real})
93+
return jacobian(inverse(Bijectors.SimplexBijector()), y)
10294
end
10395

10496
Bijectors.ispd(::TuringWishart) = true

src/Bijectors.jl

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -206,25 +206,12 @@ isdirichlet(::Distribution) = false
206206
# ∑xᵢ = 1 #
207207
###########
208208

209-
function link(d::Dirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)) where {proj}
210-
return SimplexBijector{proj}()(x)
211-
end
212-
213-
function link_jacobian(
214-
d::Dirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true)
215-
) where {proj}
216-
return jacobian(SimplexBijector{proj}(), x)
217-
end
209+
link(d::Dirichlet, x::AbstractVecOrMat{<:Real}) = SimplexBijector()(x)
210+
link_jacobian(d::Dirichlet, x::AbstractVector{<:Real}) = jacobian(SimplexBijector(), x)
218211

219-
function invlink(
220-
d::Dirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)
221-
) where {proj}
222-
return inverse(SimplexBijector{proj}())(y)
223-
end
224-
function invlink_jacobian(
225-
d::Dirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true)
226-
) where {proj}
227-
return jacobian(inverse(SimplexBijector{proj}()), y)
212+
invlink(d::Dirichlet, y::AbstractVecOrMat{<:Real}) = inverse(SimplexBijector())(y)
213+
function invlink_jacobian(d::Dirichlet, y::AbstractVector{<:Real})
214+
return jacobian(inverse(SimplexBijector()), y)
228215
end
229216

230217
## Matrix

src/bijectors/simplex.jl

Lines changed: 42 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
####################
22
# Simplex bijector #
33
####################
4-
struct SimplexBijector{T} <: Bijector end
5-
SimplexBijector() = SimplexBijector{true}()
4+
struct SimplexBijector <: Bijector end
5+
6+
output_size(::SimplexBijector, sz::Tuple{Int}) = (first(sz) - 1,)
7+
output_size(::Inverse{SimplexBijector}, sz::Tuple{Int}) = (first(sz) + 1,)
8+
9+
output_size(::SimplexBijector, sz::Tuple{Int,Int}) = (first(sz) - 1, last(sz))
10+
function output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int})
11+
return (first(sz) + 1, last(sz))
12+
end
613

714
with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x)
815

916
transform(b::SimplexBijector, x) = _simplex_bijector(x, b)
1017
transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b)
1118

1219
function _simplex_bijector(x::AbstractArray, b::SimplexBijector)
13-
return _simplex_bijector!(similar(x), x, b)
20+
sz = size(x)
21+
K = size(x, 1)
22+
y = similar(x, Base.setindex(sz, K - 1, 1))
23+
_simplex_bijector!(y, x, b)
24+
return y
1425
end
1526

1627
# Vector implementation.
17-
function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj}
28+
function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector)
1829
K = length(x)
1930
@assert K > 1 "x needs to be of length greater than 1"
2031
T = eltype(x)
@@ -29,18 +40,11 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where
2940
z = (x[k] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp)
3041
y[k] = LogExpFunctions.logit(z) + log(T(K - k))
3142
end
32-
@inbounds sum_tmp += x[K - 1]
33-
@inbounds if proj
34-
y[K] = zero(T)
35-
else
36-
y[K] = one(T) - sum_tmp - x[K]
37-
end
38-
3943
return y
4044
end
4145

4246
# Matrix implementation.
43-
function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj}
47+
function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector)
4448
K, N = size(X, 1), size(X, 2)
4549
@assert K > 1 "x needs to be of length greater than 1"
4650
T = eltype(X)
@@ -54,12 +58,6 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where
5458
z = (X[k, n] + ϵ) * (one(T) - 2ϵ) / ((one(T) + ϵ) - sum_tmp)
5559
Y[k, n] = LogExpFunctions.logit(z) + log(T(K - k))
5660
end
57-
sum_tmp += X[K - 1, n]
58-
if proj
59-
Y[K, n] = zero(T)
60-
else
61-
Y[K, n] = one(T) - sum_tmp - X[K, n]
62-
end
6361
end
6462

6563
return Y
@@ -75,10 +73,16 @@ function transform!(
7573
return _simplex_inv_bijector!(x, y, ib.orig)
7674
end
7775

78-
_simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b)
76+
function _simplex_inv_bijector(y, b)
77+
sz = size(y)
78+
K = sz[1] + 1
79+
x = similar(y, Base.setindex(sz, K, 1))
80+
_simplex_inv_bijector!(x, y, b)
81+
return x
82+
end
7983

80-
function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj}
81-
K = length(y)
84+
function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector)
85+
K = length(y) + 1
8286
@assert K > 1 "x needs to be of length greater than 1"
8387
T = eltype(y)
8488
ϵ = _eps(T)
@@ -91,17 +95,12 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj})
9195
x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
9296
end
9397
@inbounds sum_tmp += x[K - 1]
94-
@inbounds if proj
95-
x[K] = _clamp(one(T) - sum_tmp, 0, 1)
96-
else
97-
x[K] = _clamp(one(T) - sum_tmp - y[K], 0, 1)
98-
end
99-
98+
x[K] = _clamp(one(T) - sum_tmp, 0, 1)
10099
return x
101100
end
102101

103-
function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj}
104-
K, N = size(Y, 1), size(Y, 2)
102+
function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector)
103+
K, N = size(Y, 1) + 1, size(Y, 2)
105104
@assert K > 1 "x needs to be of length greater than 1"
106105
T = eltype(Y)
107106
ϵ = _eps(T)
@@ -114,11 +113,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj})
114113
X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1)
115114
end
116115
sum_tmp += X[K - 1, n]
117-
if proj
118-
X[K, n] = _clamp(one(T) - sum_tmp, 0, 1)
119-
else
120-
X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], 0, 1)
121-
end
116+
X[K, n] = _clamp(one(T) - sum_tmp, 0, 1)
122117
end
123118

124119
return X
@@ -213,13 +208,10 @@ function simplex_logabsdetjac_gradient(x::AbstractMatrix)
213208
return g
214209
end
215210

216-
function simplex_link_jacobian(
217-
x::AbstractVector{T}, ::Val{proj}=Val(true)
218-
) where {T<:Real,proj}
211+
function simplex_link_jacobian(x::AbstractVector{T}) where {T<:Real}
219212
K = length(x)
220213
@assert K > 1 "x needs to be of length greater than 1"
221-
dydxt = similar(x, length(x), length(x))
222-
@inbounds dydxt .= 0
214+
dydxt = fill!(similar(x, K, K - 1), 0)
223215
ϵ = _eps(T)
224216
sum_tmp = zero(T)
225217

@@ -237,16 +229,10 @@ function simplex_link_jacobian(
237229
((one(T) + ϵ) - sum_tmp)^2
238230
end
239231
end
240-
@inbounds sum_tmp += x[K - 1]
241-
@inbounds if !proj
242-
@simd for i in 1:K
243-
dydxt[i, K] = -1
244-
end
245-
end
246-
return UpperTriangular(dydxt)'
232+
return dydxt'
247233
end
248-
function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj,T}
249-
return simplex_link_jacobian(x, Val(proj))
234+
function jacobian(b::SimplexBijector, x::AbstractVector{T}) where {T}
235+
return simplex_link_jacobian(x)
250236
end
251237

252238
#=
@@ -315,13 +301,10 @@ function add_simplex_link_adjoint!(
315301
end
316302
=#
317303

318-
function simplex_invlink_jacobian(
319-
y::AbstractVector{T}, ::Val{proj}=Val(true)
320-
) where {T<:Real,proj}
321-
K = length(y)
304+
function simplex_invlink_jacobian(y::AbstractVector{T}) where {T<:Real}
305+
K = length(y) + 1
322306
@assert K > 1 "x needs to be of length greater than 1"
323-
dxdy = similar(y, length(y), length(y))
324-
@inbounds dxdy .= 0
307+
dxdy = fill!(similar(y, K, K - 1), 0)
325308

326309
ϵ = _eps(T)
327310
@inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1)))
@@ -346,28 +329,20 @@ function simplex_invlink_jacobian(
346329
end
347330
end
348331
@inbounds sum_tmp += clamped_x
349-
@inbounds if proj
350-
unclamped_x = one(T) - sum_tmp
351-
clamped_x = _clamp(unclamped_x, 0, 1)
352-
else
353-
unclamped_x = one(T) - sum_tmp - y[K]
354-
clamped_x = _clamp(unclamped_x, 0, 1)
355-
if unclamped_x == clamped_x
356-
dxdy[K, K] = -1
357-
end
358-
end
332+
unclamped_x = one(T) - sum_tmp
333+
clamped_x = _clamp(unclamped_x, 0, 1)
359334
@inbounds if unclamped_x == clamped_x
360335
for i in 1:(K - 1)
361336
@simd for j in i:(K - 1)
362337
dxdy[K, i] += -dxdy[j, i]
363338
end
364339
end
365340
end
366-
return LowerTriangular(dxdy)
341+
return dxdy
367342
end
368343
# jacobian
369-
function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj,T}
370-
return simplex_invlink_jacobian(y, Val(proj))
344+
function jacobian(ib::Inverse{<:SimplexBijector}, y::AbstractVector{T}) where {T}
345+
return simplex_invlink_jacobian(y)
371346
end
372347

373348
#=

src/bijectors/stacked.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,6 @@ end
8585
end
8686
end
8787

88-
# Avoid mixing tuples and arrays.
89-
Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges)
90-
9188
Functors.@functor Stacked (bs,)
9289

9390
function Base.show(io::IO, b::Stacked)

test/interface.jl

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ end
156156
# verify against AD
157157
# similar to what we do in test/transform.jl for Dirichlet
158158
if dist isa Dirichlet
159-
b = Bijectors.SimplexBijector{false}()
159+
b = Bijectors.SimplexBijector()
160160
# HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]`
161161
# which in turn will lead to differences between `ForwardDiff.jacobian`
162162
# and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`.
@@ -168,8 +168,9 @@ end
168168
end
169169
y = b(x)
170170
@test b(param(x)) isa TrackedArray
171-
@test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x)
172-
@test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈
171+
@test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1] ≈
172+
logabsdetjac(b, x)
173+
@test logabsdet(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), :])[1] ≈
173174
logabsdetjac(inverse(b), y)
174175
else
175176
b = bijector(dist)
@@ -420,35 +421,37 @@ end
420421
b = SimplexBijector()
421422
ib = inverse(b)
422423

423-
x = ib(randn(10))
424+
d_x = 10
425+
x = ib(randn(d_x - 1))
424426
y = b(x)
425427

426428
@test Bijectors.jacobian(b, x) ≈ ForwardDiff.jacobian(b, x)
427429
@test Bijectors.jacobian(ib, y) ≈ ForwardDiff.jacobian(ib, y)
428430

429431
# Just some additional computation so we also ensure the pullbacks are the same
430-
weights = randn(10)
432+
weights_x = randn(d_x)
433+
weights_y = randn(d_x - 1)
431434

432435
# Tracker.jl
433436
x_tracked = Tracker.param(x)
434-
z = sum(weights .* b(x_tracked))
437+
z = sum(weights_y .* b(x_tracked))
435438
Tracker.back!(z)
436439
Δ_tracker = Tracker.grad(x_tracked)
437440

438441
# ForwardDiff.jl
439-
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* b(z)), x)
442+
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_y .* b(z)), x)
440443

441444
# Compare
442445
@test Δ_forwarddiff ≈ Δ_tracker
443446

444447
# Tracker.jl
445448
y_tracked = Tracker.param(y)
446-
z = sum(weights .* ib(y_tracked))
449+
z = sum(weights_x .* ib(y_tracked))
447450
Tracker.back!(z)
448451
Δ_tracker = Tracker.grad(y_tracked)
449452

450453
# ForwardDiff.jl
451-
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* ib(z)), y)
454+
Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_x .* ib(z)), y)
452455

453456
@test Δ_forwarddiff ≈ Δ_tracker
454457
end

0 commit comments

Comments
 (0)