diff --git a/src/StaticArrays.jl b/src/StaticArrays.jl index 023c45f5..15884212 100644 --- a/src/StaticArrays.jl +++ b/src/StaticArrays.jl @@ -6,7 +6,7 @@ import Base: getindex, setindex!, size, similar, vec, show, length, convert, pro promote_rule, map, map!, reduce, mapreduce, foldl, mapfoldl, broadcast, broadcast!, conj, hcat, vcat, ones, zeros, one, reshape, fill, fill!, inv, iszero, sum, prod, count, any, all, minimum, maximum, extrema, - copy, read, read!, write, reverse + copy, read, read!, write, reverse, invperm using Random import Random: rand, randn, randexp, rand!, randn!, randexp! diff --git a/src/qr.jl b/src/qr.jl index cab74d86..d0aa5b9e 100644 --- a/src/qr.jl +++ b/src/qr.jl @@ -11,6 +11,8 @@ Base.iterate(S::QR, ::Val{:R}) = (S.R, Val(:p)) Base.iterate(S::QR, ::Val{:p}) = (S.p, Val(:done)) Base.iterate(S::QR, ::Val{:done}) = nothing +size(F::QR) = (size(F.Q,1), size(F.R,2)) + pivot_options = if isdefined(LinearAlgebra, :PivotingStrategy) # introduced in Julia v1.7 (:(Val{true}), :(Val{false}), :NoPivot, :ColumnNorm) else @@ -62,6 +64,9 @@ function identity_perm(R::StaticMatrix{N,M,T}) where {N,M,T} return similar_type(R, Int, Size((M,)))(ntuple(x -> x, Val{M}())) end +is_identity_perm(p::StaticVector{M}) where {M} = all(i->i==p[i], 1:M) + + _qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T)))) @generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA}, @@ -245,3 +250,44 @@ end # end #end + +LinearAlgebra.ldiv!(x::AbstractVecOrMat, F::QR, y::AbstractVecOrMat) = (x .= F \ y) + + +function \(F::QR, y::AbstractVecOrMat) + checksquare(F.R) + v = F.Q' * y + + x = UpperTriangular(F.R) \ v + + @inbounds invpivot(x, F.p) +end + + +@inline Base.@propagate_inbounds function invpivot(x::AbstractVecOrMat, p) + if is_identity_perm(p) + x + else + extra = ntuple(_ -> Colon(), ndims(x) - 1) + x[invperm(p), extra...] + end +end + + +function inv(F::QR) + checksquare(F.R) + R⁻¹ = inv(UpperTriangular(F.R)) + A⁻¹ = R⁻¹ * F.Q' + A⁻¹ = @inbounds invpivot(A⁻¹, F.p) + + n = size(F.R, 1) + m = size(F.Q, 1) + if n < m + # Add zeros to enable inv(F)*A ≈ I(m)[:,1:n] like LinearAlgebra.inv(qr(::Matrix)) + # + # This is different from LinearAlgebra which instead completes the Householder reflections + return [A⁻¹; zeros(SMatrix{m-n,m,eltype(A⁻¹)})] + end + + A⁻¹ +end diff --git a/src/triangular.jl b/src/triangular.jl index 0b08ca6e..3cff4603 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -61,3 +61,39 @@ end function _first_zero_on_diagonal(A::StaticULT) _first_zero_on_diagonal(A.data) end + + +inv(R::UpperTriangular{T, <:StaticMatrix}) where T = + (@inline; UpperTriangular(_inv_upper_triangular(R.data))) + +_inv_upper_triangular(R::StaticMatrix{n, m}) where {n, m} = + checksquare(R) + +@generated function _inv_upper_triangular(R::StaticMatrix{n, n, T}) where {n, T} + ex = quote + R_inv = MMatrix{n,n,T}(undef) + end + for i in n:-1:1 + append!(ex.args, (quote + r = 1 / R[$i, $i] + for j in 1:$((i)-1) + R_inv[$i, j] = 0 + end + R_inv[$i, $i] = r + end).args) + + for j in (i+1):n + s = :(0r) + for k in (i+1):j + s = :( $s + R[$i, $k] * R_inv[$k, $j] ) + end + push!(ex.args, :( + R_inv[$i, $j] = -r * $s + )) + end + end + push!(ex.args, :( + return SMatrix(R_inv) + )) + ex +end diff --git a/src/util.jl b/src/util.jl index 6079c36d..2cbbe6ce 100644 --- a/src/util.jl +++ b/src/util.jl @@ -74,10 +74,12 @@ TrivialView(a::AbstractArray{T,N}) where {T,N} = TrivialView{typeof(a),T,N}(a) @inline drop_sdims(a::StaticArrayLike) = TrivialView(a) @inline drop_sdims(a) = a -Base.@propagate_inbounds function invperm(p::StaticVector) - # in difference to base, this does not check if p is a permutation (every value unique) - ip = similar(p) - ip[p] = 1:length(p) - similar_type(p)(ip) +@inline function invperm(p::StaticVector{N,T}) where {N,T<:Integer} + ip = zeros(MVector{N,T}) + @inbounds for i in SOneTo(N) + j = p[i] + 1 <= j <= N && iszero(ip[j]) || throw(ArgumentError("argument is not a permutation")) + ip[j] = i + end + SVector(ip) end - diff --git a/test/qr.jl b/test/qr.jl index a6297abf..8a20adc2 100644 --- a/test/qr.jl +++ b/test/qr.jl @@ -1,5 +1,12 @@ using StaticArrays, Test, LinearAlgebra, Random +macro test_noalloc(ex) + esc(quote + $ex + @test(@allocated($ex) == 0) + end) +end + broadenrandn(::Type{BigFloat}) = BigFloat(randn(Float64)) broadenrandn(::Type{Int}) = rand(-9:9) broadenrandn(::Type{Complex{T}}) where T = Complex{T}(broadenrandn(T), broadenrandn(T)) @@ -69,10 +76,147 @@ Random.seed!(42) end end + @testset "QR method ambiguity" begin # Issue #931; just test that methods do not throw an ambiguity error when called A = @SMatrix [1.0 2.0 3.0; 4.0 5.0 6.0] @test isa(qr(A), StaticArrays.QR) @test isa(qr(A, Val(true)), StaticArrays.QR) @test isa(qr(A, Val(false)), StaticArrays.QR) -end \ No newline at end of file +end + +@testset "invperm" begin + p = @SVector [9,8,7,6,5,4,2,1,3] + v = @SVector [15,14,13,12,11,10,15,3,7] + @test StaticArrays.is_identity_perm(p[invperm(p)]) + @test v == v[p][invperm(p)] + @test_throws ArgumentError invperm(v) + expect0 = Base.JLOptions().check_bounds != 1 + # expect0 && @test_noalloc @inbounds invperm(p) + + @test StaticArrays.invpivot(v, p) == v[invperm(p)] + @test_noalloc StaticArrays.invpivot(v, p) +end + +@testset "#1192 QR inv, size, and \\" begin + function test_pivot(pivot, MatrixType) + Random.seed!(42) + A = rand(MatrixType) + n, m = size(A) + y = @SVector rand(size(A, 1)) + Y = @SMatrix rand(n, 2) + F = @inferred QR qr(A, pivot) + F_gold = @inferred LinearAlgebra.QRCompactWY qr(Matrix(A), pivot) + + expect0 = pivot isa NoPivot || Base.JLOptions().check_bounds != 1 + + @test StaticArrays.is_identity_perm(F.p) == (pivot isa NoPivot) + @test size(F) == size(A) + + @testset "inv UpperTriangular StaticMatrix" begin + if m <= n + invR = @inferred StaticMatrix inv(UpperTriangular(F.R)) + @test invR*F.R ≈ I(m) + + expect0 && @eval @test_noalloc inv(UpperTriangular($F.R)) + else + @test_throws DimensionMismatch inv(UpperTriangular(F.R)) + end + end + + @testset "qr inversion" begin + if m <= n + inv_F_gold = inv(qr(Matrix(A))) + inv_F = @inferred StaticMatrix inv(F) + @test size(inv_F) == size(inv_F_gold) + @test inv_F[1:m,:] ≈ inv_F_gold[1:m,:] # equal except for the nullspace + @test inv_F * A ≈ I(n)[:,1:m] + + expect0 && @eval @test_noalloc inv($F) + else + @test_throws DimensionMismatch inv(F) + @test_throws DimensionMismatch inv(qr(Matrix(A))) + end + end + + @testset "QR \\ StaticVector" begin + if m <= n + x_gold = Matrix(A) \ Vector(y) + x = @inferred StaticVector F \ y + @test x_gold ≈ x + + expect0 && @eval @test_noalloc $F \ $y + else + @test_throws DimensionMismatch F \ y + + if pivot isa Val{false} + @test_throws DimensionMismatch F_gold \ Vector(y) + end + end + end + + @testset "QR \\ StaticMatrix" begin + if m <= n + @test F \ Y ≈ A \ Y + + expect0 && @eval @test_noalloc $F \ $Y + else + @test_throws DimensionMismatch F \ Y + end + end + + @testset "ldiv!" begin + x = @MVector zeros(m) + X = @MMatrix zeros(m, size(Y, 2)) + + if m <= n + ldiv!(x, F, y) + @test x ≈ A \ y + + ldiv!(X, F, Y) + @test X ≈ A \ Y + + expect0 && @test_noalloc ldiv!(x, F, y) + expect0 && @test_noalloc ldiv!(X, F, Y) + else + @test_throws DimensionMismatch ldiv!(x, F, y) + @test_throws DimensionMismatch ldiv!(X, F, Y) + + if pivot isa Val{false} + @test_throws DimensionMismatch ldiv!(zeros(size(x)), F_gold, Array(y)) + @test_throws DimensionMismatch ldiv!(zeros(size(X)), F_gold, Array(Y)) + end + end + end + end + + @testset "pivot=$pivot" for pivot in [NoPivot(), ColumnNorm()] + @testset "$label ($n,$m)" for (label,n,m) in [ + (:square,3,3), + (:overdetermined,6,3), + (:underdetermined,3,4) + ] + test_pivot(pivot, SMatrix{n,m,Float64}) + end + + @testset "performance" begin + function speed_test(n, iter) + y2 = @SVector rand(n) + A2 = @SMatrix rand(n,5) + F2 = qr(A2, pivot) + iA = pinv(A2) + + min_time_to_solve = minimum(@elapsed(A2 \ y2) for _ in 1:iter) + min_time_to_solve_qr = minimum(@elapsed(F2 \ y2) for _ in 1:iter) + min_time_to_solve_inv = minimum(@elapsed(iA * y2) for _ in 1:iter) + + if 1 != Base.JLOptions().check_bounds + @test 10min_time_to_solve_qr < min_time_to_solve + @test 2min_time_to_solve_inv < min_time_to_solve_qr + end + end + speed_test(100, 100) + @test @elapsed(speed_test(100, 100)) < 1 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b86f343c..08b3b77d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,10 @@ include("testutil.jl") # # Pkg.test("StaticArrays", test_args=["MVector", "SVector"]) # +# To tests with normal bounds checking use: +# +# Pkg.test("StaticArrays", julia_args=["--check-bounds=auto"], test_args=["..."]) +# enabled_tests = lowercase.(ARGS) function addtests(fname) key = lowercase(splitext(fname)[1])