diff --git a/src/symmetric.jl b/src/symmetric.jl index ed5fe3b5..061abb49 100644 --- a/src/symmetric.jl +++ b/src/symmetric.jl @@ -289,7 +289,7 @@ Base.dataids(A::HermOrSym) = Base.dataids(parent(A)) Base.unaliascopy(A::Hermitian) = Hermitian(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) Base.unaliascopy(A::Symmetric) = Symmetric(Base.unaliascopy(parent(A)), sym_uplo(A.uplo)) -_conjugation(::Union{Symmetric, Hermitian{<:Real}}) = transpose +_conjugation(::Symmetric) = transpose _conjugation(::Hermitian) = adjoint diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo)) @@ -470,25 +470,25 @@ Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo) # tril/triu function tril(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(A.data'),k) + return tril_maybe_inplace(copy(A.data'),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(A.data'),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(A.data'),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(A.data')),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(A.data')),k) end end function tril(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k <= 0 - return tril!(copy(transpose(A.data)),k) + return tril_maybe_inplace(copy(transpose(A.data)),k) elseif A.uplo == 'U' && k > 0 - return tril!(copy(transpose(A.data)),-1) + tril!(triu(A.data),k) + return tril_maybe_inplace(copy(transpose(A.data)),-1) + tril_maybe_inplace(triu(A.data),k) elseif A.uplo == 'L' && k <= 0 return tril(A.data,k) else - return tril(A.data,-1) + tril!(triu!(copy(transpose(A.data))),k) + return tril(A.data,-1) + tril_maybe_inplace(triu_maybe_inplace(copy(transpose(A.data))),k) end end @@ -496,11 +496,11 @@ function triu(A::Hermitian, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(A.data')),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(A.data')),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(A.data'),k) + return triu_maybe_inplace(copy(A.data'),k) else - return triu!(copy(A.data'),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(A.data'),1) + triu_maybe_inplace(tril(A.data),k) end end @@ -508,11 +508,11 @@ function triu(A::Symmetric, k::Integer=0) if A.uplo == 'U' && k >= 0 return triu(A.data,k) elseif A.uplo == 'U' && k < 0 - return triu(A.data,1) + triu!(tril!(copy(transpose(A.data))),k) + return triu(A.data,1) + triu_maybe_inplace(tril_maybe_inplace(copy(transpose(A.data))),k) elseif A.uplo == 'L' && k >= 0 - return triu!(copy(transpose(A.data)),k) + return triu_maybe_inplace(copy(transpose(A.data)),k) else - return triu!(copy(transpose(A.data)),1) + triu!(tril(A.data),k) + return triu_maybe_inplace(copy(transpose(A.data)),1) + triu_maybe_inplace(tril(A.data),k) end end diff --git a/src/triangular.jl b/src/triangular.jl index 5b476d24..906a25aa 100644 --- a/src/triangular.jl +++ b/src/triangular.jl @@ -484,6 +484,11 @@ function tril!(A::UnitLowerTriangular, k::Integer=0) return tril!(LowerTriangular(A.data), k) end +tril_maybe_inplace(A, k::Integer=0) = tril(A, k) +triu_maybe_inplace(A, k::Integer=0) = triu(A, k) +tril_maybe_inplace(A::StridedMatrix, k::Integer=0) = tril!(A, k) +triu_maybe_inplace(A::StridedMatrix, k::Integer=0) = triu!(A, k) + adjoint(A::LowerTriangular) = UpperTriangular(adjoint(A.data)) adjoint(A::UpperTriangular) = LowerTriangular(adjoint(A.data)) adjoint(A::UnitLowerTriangular) = UnitUpperTriangular(adjoint(A.data)) diff --git a/test/symmetric.jl b/test/symmetric.jl index 9f727b8c..5de3a853 100644 --- a/test/symmetric.jl +++ b/test/symmetric.jl @@ -1199,4 +1199,27 @@ end end end +@testset "triu/tril with immutable arrays" begin + struct ImmutableMatrix{T,A<:AbstractMatrix{T}} <: AbstractMatrix{T} + a :: A + end + Base.size(A::ImmutableMatrix) = size(A.a) + Base.getindex(A::ImmutableMatrix, i::Int, j::Int) = getindex(A.a, i, j) + Base.copy(A::ImmutableMatrix) = A + LinearAlgebra.adjoint(A::ImmutableMatrix) = ImmutableMatrix(adjoint(A.a)) + LinearAlgebra.transpose(A::ImmutableMatrix) = ImmutableMatrix(transpose(A.a)) + + A = ImmutableMatrix([1 2; 3 4]) + for T in (Symmetric, Hermitian), uplo in (:U, :L) + H = T(A, uplo) + MH = Matrix(H) + @test triu(H,-1) == triu(MH,-1) + @test triu(H) == triu(MH) + @test triu(H,1) == triu(MH,1) + @test tril(H,1) == tril(MH,1) + @test tril(H) == tril(MH) + @test tril(H,-1) == tril(MH,-1) + end +end + end # module TestSymmetric