Skip to content

inv, size, and \ for QR objects #1300

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/StaticArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import Base: getindex, setindex!, size, similar, vec, show, length, convert, pro
promote_rule, map, map!, reduce, mapreduce, foldl, mapfoldl, broadcast,
broadcast!, conj, hcat, vcat, ones, zeros, one, reshape, fill, fill!, inv,
iszero, sum, prod, count, any, all, minimum, maximum, extrema,
copy, read, read!, write, reverse
copy, read, read!, write, reverse, invperm

using Random
import Random: rand, randn, randexp, rand!, randn!, randexp!
Expand Down
46 changes: 46 additions & 0 deletions src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ Base.iterate(S::QR, ::Val{:R}) = (S.R, Val(:p))
Base.iterate(S::QR, ::Val{:p}) = (S.p, Val(:done))
Base.iterate(S::QR, ::Val{:done}) = nothing

size(F::QR) = (size(F.Q,1), size(F.R,2))

pivot_options = if isdefined(LinearAlgebra, :PivotingStrategy) # introduced in Julia v1.7
(:(Val{true}), :(Val{false}), :NoPivot, :ColumnNorm)
else
Expand Down Expand Up @@ -62,6 +64,9 @@ function identity_perm(R::StaticMatrix{N,M,T}) where {N,M,T}
return similar_type(R, Int, Size((M,)))(ntuple(x -> x, Val{M}()))
end

is_identity_perm(p::StaticVector{M}) where {M} = all(i->i==p[i], 1:M)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
is_identity_perm(p::StaticVector{M}) where {M} = all(i->i==p[i], 1:M)
is_identity_perm(p::StaticVector{M}) where {M} = all(i->i==p[i], SOneTo(M))



_qreltype(::Type{T}) where T = typeof(zero(T)/sqrt(abs2(one(T))))

@generated function _qr(::Size{sA}, A::StaticMatrix{<:Any, <:Any, TA},
Expand Down Expand Up @@ -245,3 +250,44 @@ end
# end
#end


LinearAlgebra.ldiv!(x::AbstractVecOrMat, F::QR, y::AbstractVecOrMat) = (x .= F \ y)


function \(F::QR, y::AbstractVecOrMat)
checksquare(F.R)
v = F.Q' * y

x = UpperTriangular(F.R) \ v

@inbounds invpivot(x, F.p)
end


@inline Base.@propagate_inbounds function invpivot(x::AbstractVecOrMat, p)
if is_identity_perm(p)
x
else
extra = ntuple(_ -> Colon(), ndims(x) - 1)
x[invperm(p), extra...]
end
end


function inv(F::QR)
checksquare(F.R)
R⁻¹ = inv(UpperTriangular(F.R))
A⁻¹ = R⁻¹ * F.Q'
A⁻¹ = @inbounds invpivot(A⁻¹, F.p)

n = size(F.R, 1)
m = size(F.Q, 1)
if n < m
# Add zeros to enable inv(F)*A ≈ I(m)[:,1:n] like LinearAlgebra.inv(qr(::Matrix))
#
# This is different from LinearAlgebra which instead completes the Householder reflections
return [A⁻¹; zeros(SMatrix{m-n,m,eltype(A⁻¹)})]
end

A⁻¹
end
36 changes: 36 additions & 0 deletions src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,39 @@ end
function _first_zero_on_diagonal(A::StaticULT)
_first_zero_on_diagonal(A.data)
end


inv(R::UpperTriangular{T, <:StaticMatrix}) where T =
(@inline; UpperTriangular(_inv_upper_triangular(R.data)))

_inv_upper_triangular(R::StaticMatrix{n, m}) where {n, m} =
checksquare(R)

@generated function _inv_upper_triangular(R::StaticMatrix{n, n, T}) where {n, T}
ex = quote
R_inv = MMatrix{n,n,T}(undef)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should either use similar (to properly handle non-isbits T) or the same trick as _A_ldiv_B. Also, make sure the method works to T being an integer.

end
for i in n:-1:1
append!(ex.args, (quote
r = 1 / R[$i, $i]
for j in 1:$((i)-1)
R_inv[$i, j] = 0
end
R_inv[$i, $i] = r
end).args)

for j in (i+1):n
s = :(0r)
for k in (i+1):j
s = :( $s + R[$i, $k] * R_inv[$k, $j] )
end
push!(ex.args, :(
R_inv[$i, $j] = -r * $s
))
end
end
push!(ex.args, :(
return SMatrix(R_inv)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use similar_type similarly to _A_ldiv_B above.

))
ex
end
14 changes: 8 additions & 6 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,12 @@ TrivialView(a::AbstractArray{T,N}) where {T,N} = TrivialView{typeof(a),T,N}(a)
@inline drop_sdims(a::StaticArrayLike) = TrivialView(a)
@inline drop_sdims(a) = a

Base.@propagate_inbounds function invperm(p::StaticVector)
# in difference to base, this does not check if p is a permutation (every value unique)
ip = similar(p)
ip[p] = 1:length(p)
similar_type(p)(ip)
@inline function invperm(p::StaticVector{N,T}) where {N,T<:Integer}
ip = zeros(MVector{N,T})
@inbounds for i in SOneTo(N)
j = p[i]
1 <= j <= N && iszero(ip[j]) || throw(ArgumentError("argument is not a permutation"))
ip[j] = i
end
SVector(ip)
Comment on lines -77 to +84
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quite likely someone relied on this being faster thanks to not performing the checks so I'm not sure about this change.

end

146 changes: 145 additions & 1 deletion test/qr.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using StaticArrays, Test, LinearAlgebra, Random

macro test_noalloc(ex)
esc(quote
$ex
@test(@allocated($ex) == 0)
end)
end

broadenrandn(::Type{BigFloat}) = BigFloat(randn(Float64))
broadenrandn(::Type{Int}) = rand(-9:9)
broadenrandn(::Type{Complex{T}}) where T = Complex{T}(broadenrandn(T), broadenrandn(T))
Expand Down Expand Up @@ -69,10 +76,147 @@ Random.seed!(42)
end
end


@testset "QR method ambiguity" begin
# Issue #931; just test that methods do not throw an ambiguity error when called
A = @SMatrix [1.0 2.0 3.0; 4.0 5.0 6.0]
@test isa(qr(A), StaticArrays.QR)
@test isa(qr(A, Val(true)), StaticArrays.QR)
@test isa(qr(A, Val(false)), StaticArrays.QR)
end
end

@testset "invperm" begin
p = @SVector [9,8,7,6,5,4,2,1,3]
v = @SVector [15,14,13,12,11,10,15,3,7]
@test StaticArrays.is_identity_perm(p[invperm(p)])
@test v == v[p][invperm(p)]
@test_throws ArgumentError invperm(v)
expect0 = Base.JLOptions().check_bounds != 1
# expect0 && @test_noalloc @inbounds invperm(p)

@test StaticArrays.invpivot(v, p) == v[invperm(p)]
@test_noalloc StaticArrays.invpivot(v, p)
end

@testset "#1192 QR inv, size, and \\" begin
function test_pivot(pivot, MatrixType)
Random.seed!(42)
A = rand(MatrixType)
n, m = size(A)
y = @SVector rand(size(A, 1))
Y = @SMatrix rand(n, 2)
F = @inferred QR qr(A, pivot)
F_gold = @inferred LinearAlgebra.QRCompactWY qr(Matrix(A), pivot)

expect0 = pivot isa NoPivot || Base.JLOptions().check_bounds != 1

@test StaticArrays.is_identity_perm(F.p) == (pivot isa NoPivot)
@test size(F) == size(A)

@testset "inv UpperTriangular StaticMatrix" begin
if m <= n
invR = @inferred StaticMatrix inv(UpperTriangular(F.R))
@test invR*F.R ≈ I(m)

expect0 && @eval @test_noalloc inv(UpperTriangular($F.R))
else
@test_throws DimensionMismatch inv(UpperTriangular(F.R))
end
end

@testset "qr inversion" begin
if m <= n
inv_F_gold = inv(qr(Matrix(A)))
inv_F = @inferred StaticMatrix inv(F)
@test size(inv_F) == size(inv_F_gold)
@test inv_F[1:m,:] ≈ inv_F_gold[1:m,:] # equal except for the nullspace
@test inv_F * A ≈ I(n)[:,1:m]

expect0 && @eval @test_noalloc inv($F)
else
@test_throws DimensionMismatch inv(F)
@test_throws DimensionMismatch inv(qr(Matrix(A)))
end
end

@testset "QR \\ StaticVector" begin
if m <= n
x_gold = Matrix(A) \ Vector(y)
x = @inferred StaticVector F \ y
@test x_gold ≈ x

expect0 && @eval @test_noalloc $F \ $y
else
@test_throws DimensionMismatch F \ y

if pivot isa Val{false}
@test_throws DimensionMismatch F_gold \ Vector(y)
end
end
end

@testset "QR \\ StaticMatrix" begin
if m <= n
@test F \ Y ≈ A \ Y

expect0 && @eval @test_noalloc $F \ $Y
else
@test_throws DimensionMismatch F \ Y
end
end

@testset "ldiv!" begin
x = @MVector zeros(m)
X = @MMatrix zeros(m, size(Y, 2))

if m <= n
ldiv!(x, F, y)
@test x ≈ A \ y

ldiv!(X, F, Y)
@test X ≈ A \ Y

expect0 && @test_noalloc ldiv!(x, F, y)
expect0 && @test_noalloc ldiv!(X, F, Y)
else
@test_throws DimensionMismatch ldiv!(x, F, y)
@test_throws DimensionMismatch ldiv!(X, F, Y)

if pivot isa Val{false}
@test_throws DimensionMismatch ldiv!(zeros(size(x)), F_gold, Array(y))
@test_throws DimensionMismatch ldiv!(zeros(size(X)), F_gold, Array(Y))
end
end
end
end

@testset "pivot=$pivot" for pivot in [NoPivot(), ColumnNorm()]
@testset "$label ($n,$m)" for (label,n,m) in [
(:square,3,3),
(:overdetermined,6,3),
(:underdetermined,3,4)
]
test_pivot(pivot, SMatrix{n,m,Float64})
end

@testset "performance" begin
function speed_test(n, iter)
y2 = @SVector rand(n)
A2 = @SMatrix rand(n,5)
F2 = qr(A2, pivot)
iA = pinv(A2)

min_time_to_solve = minimum(@elapsed(A2 \ y2) for _ in 1:iter)
min_time_to_solve_qr = minimum(@elapsed(F2 \ y2) for _ in 1:iter)
min_time_to_solve_inv = minimum(@elapsed(iA * y2) for _ in 1:iter)

if 1 != Base.JLOptions().check_bounds
@test 10min_time_to_solve_qr < min_time_to_solve
@test 2min_time_to_solve_inv < min_time_to_solve_qr
end
end
speed_test(100, 100)
@test @elapsed(speed_test(100, 100)) < 1
end
end
end
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ include("testutil.jl")
#
# Pkg.test("StaticArrays", test_args=["MVector", "SVector"])
#
# To tests with normal bounds checking use:
#
# Pkg.test("StaticArrays", julia_args=["--check-bounds=auto"], test_args=["..."])
#
enabled_tests = lowercase.(ARGS)
function addtests(fname)
key = lowercase(splitext(fname)[1])
Expand Down
Loading