diff --git a/src/tridiag.jl b/src/tridiag.jl index 0e8a5119..02b4c352 100644 --- a/src/tridiag.jl +++ b/src/tridiag.jl @@ -174,16 +174,13 @@ end isreal(S::SymTridiagonal) = isreal(S.dv) && isreal(S.ev) transpose(S::SymTridiagonal) = S -adjoint(S::SymTridiagonal{<:Number}) = SymTridiagonal(vec(adjoint(S.dv)), vec(adjoint(S.ev))) -adjoint(S::SymTridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = - SymTridiagonal(adjoint(parent(S.dv)), adjoint(parent(S.ev))) +adjoint(S::SymTridiagonal) = SymTridiagonal(_vecadjoint(S.dv), _vecadjoint(S.ev)) permutedims(S::SymTridiagonal) = S function permutedims(S::SymTridiagonal, perm) Base.checkdims_perm(axes(S), axes(S), perm) NTuple{2}(perm) == (2, 1) ? permutedims(S) : S end -Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...) ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S)) issymmetric(S::SymTridiagonal) = true @@ -683,23 +680,17 @@ for func in (:conj, :copy, :real, :imag) end isreal(T::Tridiagonal) = isreal(T.dl) && isreal(T.d) && isreal(T.du) -adjoint(S::Tridiagonal{<:Number}) = Tridiagonal(vec(adjoint(S.du)), vec(adjoint(S.d)), vec(adjoint(S.dl))) -adjoint(S::Tridiagonal{<:Number, <:Base.ReshapedArray{<:Number,1,<:Adjoint}}) = - Tridiagonal(adjoint(parent(S.du)), adjoint(parent(S.d)), adjoint(parent(S.dl))) -transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl) +adjoint(S::Tridiagonal) = Tridiagonal(_vecadjoint(S.du), _vecadjoint(S.d), _vecadjoint(S.dl)) +transpose(S::Tridiagonal) = Tridiagonal(_vectranspose(S.du), _vectranspose(S.d), _vectranspose(S.dl)) permutedims(T::Tridiagonal) = Tridiagonal(T.du, T.d, T.dl) function permutedims(T::Tridiagonal, perm) Base.checkdims_perm(axes(T), axes(T), perm) NTuple{2}(perm) == (2, 1) ? permutedims(T) : T end -Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...)) -Base.copy(tS::Transpose{<:Any,<:Tridiagonal}) = (S = tS.parent; Tridiagonal(map(x -> copy.(transpose.(x)), (S.du, S.d, S.dl))...)) ishermitian(S::Tridiagonal) = all(ishermitian, S.d) && all(Iterators.map((x, y) -> x == y', S.du, S.dl)) issymmetric(S::Tridiagonal) = all(issymmetric, S.d) && all(Iterators.map((x, y) -> x == transpose(y), S.du, S.dl)) -\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:AbstractVecOrMat}) = copy(A) \ B - function diag(M::Tridiagonal, n::Integer=0) # every branch call similar(..., ::Int) to make sure the # same vector type is returned independent of n diff --git a/test/tridiag.jl b/test/tridiag.jl index effea2b0..7eaa4128 100644 --- a/test/tridiag.jl +++ b/test/tridiag.jl @@ -1207,4 +1207,26 @@ end @test_throws BoundsError S[LinearAlgebra.BandIndex(0,size(S,1)+1)] end +@testset "lazy adjtrans" begin + dv = fill([1 2; 3 4], 3) + ev = fill([5 6; 7 8], 2) + T = Tridiagonal(ev, dv, ev) + S = SymTridiagonal(dv, ev) + m = [2 4; 4 2] + for B in (copy(T), copy(S)) + for op in (transpose, adjoint) + C = op(B) + el = op(m) + C[1,1] = el + @test B[1,1] == m + if B isa Tridiagonal + C[2,1] = el + @test B[1,2] == m + end + @test (@allocated op(B)) == 0 + @test (@allocated op(op(B))) == 0 + end + end +end + end # module TestTridiagonal