Skip to content

Commit 00547db

Browse files
committed
1 parent ad52409 commit 00547db

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

lib/cusolver/linalg.jl

+21-8
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
3434
if n < m
3535
# LQ decomposition
3636
At = CuMatrix(A')
37-
F, tau = CUSOLVER.geqrf!(At) # A = RᴴQᴴ
37+
F, tau = geqrf!(At) # A = RᴴQᴴ
3838
if B isa CuVector{T}
3939
CUBLAS.trsv!('U', 'C', 'N', view(F,1:n,1:n), B)
4040
X = CUDA.zeros(T, m)
@@ -45,15 +45,15 @@ function Base.:\(_A::CuMatOrAdj, _B::CuOrAdj)
4545
X = CUDA.zeros(T, m, p)
4646
view(X, 1:n, :) .= B
4747
end
48-
CUSOLVER.ormqr!('L', 'N', F, tau, X)
48+
ormqr!('L', 'N', F, tau, X)
4949
elseif n == m
5050
# LU decomposition with partial pivoting
51-
F, p, info = CUSOLVER.getrf!(A) # PA = LU
52-
X = CUSOLVER.getrs!('N', F, p, B)
51+
F, p, info = getrf!(A) # PA = LU
52+
X = getrs!('N', F, p, B)
5353
else
5454
# QR decomposition
55-
F, tau = CUSOLVER.geqrf!(A) # A = QR
56-
CUSOLVER.ormqr!('L', 'C', F, tau, B)
55+
F, tau = geqrf!(A) # A = QR
56+
ormqr!('L', 'C', F, tau, B)
5757
if B isa CuVector{T}
5858
X = B[1:m]
5959
CUBLAS.trsv!('U', 'N', 'N', view(F,1:m,1:m), X)
@@ -307,9 +307,22 @@ end
307307

308308
## LU
309309

310-
function LinearAlgebra.lu!(A::StridedCuMatrix{T}, ::RowMaximum; check::Bool = true) where {T}
310+
function _check_lu_success(info, allowsingular)
311+
if VERSION >= v"1.11.0-DEV.1535"
312+
if info < 0 # zero pivot error from unpivoted LU
313+
LinearAlgebra.checknozeropivot(-info)
314+
else
315+
allowsingular || LinearAlgebra.checknonsingular(info)
316+
end
317+
else
318+
LinearAlgebra.checknonsingular(info)
319+
end
320+
end
321+
322+
function LinearAlgebra.lu!(A::StridedCuMatrix{T}, ::RowMaximum;
323+
check::Bool=true, allowsingular::Bool=false) where {T}
311324
lpt = getrf!(A)
312-
check && LinearAlgebra.checknonsingular(lpt[3])
325+
check && _check_lu_success(lpt[3], allowsingular)
313326
return LU(lpt[1], lpt[2], Int(lpt[3]))
314327
end
315328

0 commit comments

Comments
 (0)