Skip to content

Commit 354bbd5

Browse files
committed
invperm: 0-allocation version of Base.invperm
1 parent d67c199 commit 354bbd5

File tree

2 files changed

+39
-28
lines changed

2 files changed

+39
-28
lines changed

src/util.jl

+11-6
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,15 @@ TrivialView(a::AbstractArray{T,N}) where {T,N} = TrivialView{typeof(a),T,N}(a)
7474
@inline drop_sdims(a::StaticArrayLike) = TrivialView(a)
7575
@inline drop_sdims(a) = a
7676

77-
Base.@propagate_inbounds function invperm(p::StaticVector)
78-
# in difference to base, this does not check if p is a permutation (every value unique)
79-
ip = similar(p)
80-
ip[p] = 1:length(p)
81-
similar_type(p)(ip)
82-
end
77+
import Base: invperm
8378

79+
# 0 allocations invperm
80+
@inline function invperm(p::StaticVector{N,T}) where {N,T<:Integer}
81+
ip = zeros(MVector{N,T})
82+
@inbounds for i in 1:N
83+
j = p[i]
84+
0 < j <= N && iszero(ip[j]) || throw(ArgumentError("argument is not a permuation"))
85+
ip[j] = i
86+
end
87+
SVector(ip)
88+
end

test/qr.jl

+28-22
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ end
8080

8181

8282
@testset "#1192 The following functions are available for the QR objects: inv, size, and \\." begin
83-
@testset "pivot=$pivot" for pivot in [Val(true), Val(false)] #, ColumnNorm()]
83+
function test_pivot(pivot)
8484
y = @SVector rand(5)
8585
Y = @SMatrix rand(5,5)
8686
A = @SMatrix rand(5,5)
@@ -106,17 +106,15 @@ end
106106
end
107107

108108
@testset "solve linear system" begin
109-
x = Matrix(A) \ Vector(y)
110-
@test x A \ y F \ y F \ Vector(y)
109+
gold_x = Matrix(A) \ Vector(y)
110+
@test gold_x A \ y F \ y F \ Vector(y)
111+
@test 0 == @allocated F \ y
111112

112-
x_under = Matrix(A_under) \ Vector(y)
113-
@test x_under == A_under \ y
114-
@test x_under F_under \ y
113+
gold_x_under = Matrix(A_under) \ Vector(y)
114+
@test gold_x_under == A_under \ y
115+
@test gold_x_under F_under \ y
115116
@test F_under \ y == F_under \ Vector(y)
116-
117-
x_over = Matrix(A_over) \ Vector(y)
118-
@test x_over A_over \ y
119-
@test A_over * x_over y
117+
@test 0 == @allocated F_under \ y
120118

121119
@test_throws DimensionMismatch F_over \ y
122120
@test_throws DimensionMismatch qr(Matrix(A_over)) \ y
@@ -125,38 +123,46 @@ end
125123
@testset "solve several linear systems" begin
126124
@test F \ Y A \ Y
127125
@test F_under \ Y A_under \ Y
126+
@test 0 == @allocated F \ Y
128127
end
129128

130129
@testset "ldiv!" begin
131130
x = @MVector zeros(5)
132-
ldiv!(x, F, y)
131+
@test 0 == @allocated ldiv!(x, F, y)
133132
@test x A \ y
134133

135134
X = @MMatrix zeros(5,5)
136135
Y = @SMatrix rand(5,5)
137-
ldiv!(X, F, Y)
136+
@test 0 == @allocated ldiv!(X, F, Y)
137+
@test 0 == @allocated A \ Y
138138
@test X A \ Y
139139
end
140140

141141
@testset "invperm" begin
142-
x = @SVector [10,15,3,7]
142+
v = @SVector [10,15,3,7]
143143
p = @SVector [4,2,1,3]
144-
@test x == x[p][invperm(p)]
144+
@test 0 == @allocated invperm(p)
145+
@test v == v[p][invperm(p)]
145146
@test StaticArrays.is_identity_perm(p[invperm(p)])
146-
@test_throws Union{BoundsError,ArgumentError} invperm(x)
147+
@test_throws Union{BoundsError,ArgumentError} invperm(v)
147148
end
148149

149150
@testset "10x faster" begin
150-
time_to_test = @elapsed (function()
151-
y2 = @SVector rand(50)
152-
A2 = @SMatrix rand(50,5)
151+
function speed_test(n, iter)
152+
y2 = @SVector rand(n)
153+
A2 = @SMatrix rand(n,5)
153154
F2 = qr(A2, pivot)
154155

155-
min_time_to_solve = minimum(@elapsed(A2 \ y2) for _ in 1:1_000)
156-
min_time_to_solve_qr = minimum(@elapsed(F2 \ y2) for _ in 1:1_000)
156+
min_time_to_solve = minimum(@elapsed(A2 \ y2) for _ in 1:iter)
157+
min_time_to_solve_qr = minimum(@elapsed(F2 \ y2) for _ in 1:iter)
157158
@test 10min_time_to_solve_qr < min_time_to_solve
158-
end)()
159-
@test time_to_test < 10
159+
end
160+
speed_test(50, 1_000)
161+
@test @elapsed(speed_test(50, 1_000)) < 1
160162
end
161163
end
164+
165+
@testset "pivot=$pivot" for pivot in [Val(true), Val(false)] #, ColumnNorm()]
166+
test_pivot(pivot)
167+
end
162168
end

0 commit comments

Comments
 (0)