Skip to content

Specialize 5-arg Hermitian-Adjoint multiplication #1394

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,12 @@ end

# Special handling for adj/trans vec
matprod_dest(A::Diagonal, B::AdjOrTransAbsVec, TS) = similar(B, TS)
# Hermitian and Adjoint multiplication is handled by conjugating the terms
matprod_dest(H::Hermitian, A::AdjointAbsMat, TS) = adjoint(matprod_dest(adjoint(A), H, TS))
matprod_dest(A::AdjointAbsMat, H::Hermitian, TS) = adjoint(matprod_dest(H, adjoint(A), TS))
# Symmetric and Transpose multiplication is handled by transposing the terms
matprod_dest(S::Symmetric, T::TransposeAbsMat, TS) = transpose(matprod_dest(transpose(T), S, TS))
matprod_dest(T::TransposeAbsMat, S::Symmetric, TS) = transpose(matprod_dest(S, transpose(T), TS))

# General fallback definition for handling under- and overdetermined system as well as square problems
# While this definition is pretty general, it does e.g. promote to common element type of lhs and rhs
Expand Down
33 changes: 19 additions & 14 deletions src/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ const RealHermSymComplexHerm{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S},
const RealHermSymComplexSym{T<:Real,S} = Union{Hermitian{T,S}, Symmetric{T,S}, Symmetric{Complex{T},S}}
const RealHermSymSymTriComplexHerm{T<:Real} = Union{RealHermSymComplexSym{T}, SymTridiagonal{T}}
const SelfAdjoint = Union{SymTridiagonal{<:Real}, Symmetric{<:Real}, Hermitian}
const SelfAdjointRealOrComplex = Union{SymTridiagonal{<:Real}, Symmetric{<:Real}, Hermitian{<:Union{Real,Complex}}}

wrappertype(::Union{Symmetric, SymTridiagonal}) = Symmetric
wrappertype(::Hermitian) = Hermitian
Expand Down Expand Up @@ -721,20 +722,24 @@ for f in (:+, :-)
end

mul(A::HermOrSym, B::HermOrSym) = A * copyto!(similar(parent(B)), B)
# catch a few potential BLAS-cases
function mul(A::HermOrSym{<:BlasFloat,<:StridedMatrix}, B::AdjOrTrans{<:BlasFloat,<:StridedMatrix})
matmul_size_check(size(A), size(B))
T = promote_type(eltype(A), eltype(B))
mul!(similar(B, T, (size(A, 1), size(B, 2))),
convert(AbstractMatrix{T}, A),
copy_oftype(B, T)) # make sure the AdjOrTrans wrapper is resolved
end
function mul(A::AdjOrTrans{<:BlasFloat,<:StridedMatrix}, B::HermOrSym{<:BlasFloat,<:StridedMatrix})
matmul_size_check(size(A), size(B))
T = promote_type(eltype(A), eltype(B))
mul!(similar(B, T, (size(A, 1), size(B, 2))),
copy_oftype(A, T), # make sure the AdjOrTrans wrapper is resolved
convert(AbstractMatrix{T}, B))

# Multiplication of Hermitian and Adjoint with an Adjoint destination
# may conjugate the terms to delegate the multiplication to the parents of the adjoints
# Only defined for commutative numbers
for (AdjTransT, SymHermT) in (
(:(Adjoint{<:Union{Real,Complex}}), :SelfAdjointRealOrComplex),
(:(Transpose{<:Union{Real,Complex}}), :RealHermSymComplexSym))

@eval begin
function mul!(C::$AdjTransT, A::$SymHermT, B::$AdjTransT, α::Union{Real,Complex}, β::Union{Real,Complex})
mul!(wrapperop(C)(C), wrapperop(B)(B), A, wrapperop(C)(α), wrapperop(C)(β))
return C
end
function mul!(C::$AdjTransT, A::$AdjTransT, B::$SymHermT, α::Union{Real,Complex}, β::Union{Real,Complex})
mul!(wrapperop(C)(C), B, wrapperop(A)(A), wrapperop(C)(α), wrapperop(C)(β))
return C
end
end
end

function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector)
Expand Down
66 changes: 56 additions & 10 deletions test/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -876,21 +876,67 @@ end
end
end

@testset "Multiplications symmetric/hermitian for $T and $S" for T in
(Float16, Float32, Float64, BigFloat), S in (ComplexF16, ComplexF32, ComplexF64)
let A = transpose(Symmetric(rand(S, 3, 3))), Bv = Vector(rand(T, 3)), Bm = Matrix(rand(T, 3,3))
@testset "Multiplications symmetric/hermitian for T=$T and S=$S for size n=$n" for T in
(Float16, Float32, Float64, BigFloat, Quaternion{Float64}),
S in (T <: Quaternion ? (Quaternion{Float64},) : (ComplexF16, ComplexF32, ComplexF64, Quaternion{Float64})),
n in (2, 3, 4)
let A = transpose(Symmetric(rand(S, n, n))), Bv = Vector(rand(T, n)), Bm = Matrix(rand(T, n,n))
@test A * Bv ≈ Matrix(A) * Bv
@test A * Bm ≈ Matrix(A) * Bm
@test A * transpose(Bm) ≈ Matrix(A) * transpose(Bm)
@test A * adjoint(Bm) ≈ Matrix(A) * adjoint(Bm)
@test Bm * A ≈ Bm * Matrix(A)
@test transpose(Bm) * A ≈ transpose(Bm) * Matrix(A)
@test adjoint(Bm) * A ≈ adjoint(Bm) * Matrix(A)
C = similar(Bm, promote_type(T, S))
@test mul!(C, A, Bm) ≈ A * Bm
@test mul!(adjoint(C), A, adjoint(Bm)) ≈ A * adjoint(Bm)
@test mul!(transpose(C), A, transpose(Bm)) ≈ A * transpose(Bm)
rand!(C)
@test mul!(copy(C), A, Bm, 2, 3) ≈ A * Bm * 2 + C * 3
@test mul!(copy(C), Bm, A, 2, 3) ≈ Bm * A * 2 + C * 3
@test mul!(adjoint(copy(C)), A, adjoint(Bm), 2, 3) ≈ A * adjoint(Bm) * 2 + adjoint(C) * 3
@test mul!(adjoint(copy(C)), adjoint(Bm), A, 2, 3) ≈ adjoint(Bm) * A * 2 + adjoint(C) * 3
@test mul!(transpose(copy(C)), A, transpose(Bm), 2, 3) ≈ A * transpose(Bm) * 2 + transpose(C) * 3
@test mul!(transpose(copy(C)), transpose(Bm), A, 2, 3) ≈ transpose(Bm) * A * 2 + transpose(C) * 3
if eltype(C) <: Complex
alpha, beta = 4+2im, 3+im
@test mul!(adjoint(copy(C)), A, adjoint(Bm), alpha, beta) ≈ A * adjoint(Bm) * alpha + adjoint(C) * beta
@test mul!(adjoint(copy(C)), adjoint(Bm), A, alpha, beta) ≈ adjoint(Bm) * A * alpha + adjoint(C) * beta
@test mul!(transpose(copy(C)), A, transpose(Bm), alpha, beta) ≈ A * transpose(Bm) * alpha + transpose(C) * beta
@test mul!(transpose(copy(C)), transpose(Bm), A, alpha, beta) ≈ transpose(Bm) * A * alpha + transpose(C) * beta
end
end
let A = adjoint(Hermitian(rand(S, 3,3))), Bv = Vector(rand(T, 3)), Bm = Matrix(rand(T, 3,3))
let A = adjoint(Hermitian(rand(S, n,n))), Bv = Vector(rand(T, n)), Bm = Matrix(rand(T, n,n))
@test A * Bv ≈ Matrix(A) * Bv
@test A * Bm ≈ Matrix(A) * Bm
@test A * transpose(Bm) ≈ Matrix(A) * transpose(Bm)
@test A * adjoint(Bm) ≈ Matrix(A) * adjoint(Bm)
@test Bm * A ≈ Bm * Matrix(A)
@test transpose(Bm) * A ≈ transpose(Bm) * Matrix(A)
@test adjoint(Bm) * A ≈ adjoint(Bm) * Matrix(A)
C = similar(Bm, promote_type(T, S))
@test mul!(C, A, Bm) ≈ A * Bm
@test mul!(adjoint(C), A, adjoint(Bm)) ≈ A * adjoint(Bm)
@test mul!(transpose(C), A, transpose(Bm)) ≈ A * transpose(Bm)
rand!(C)
@test mul!(copy(C), A, Bm, 2, 3) ≈ A * Bm * 2 + C * 3
@test mul!(copy(C), Bm, A, 2, 3) ≈ Bm * A * 2 + C * 3
@test mul!(adjoint(copy(C)), A, adjoint(Bm), 2, 3) ≈ A * adjoint(Bm) * 2 + adjoint(C) * 3
@test mul!(adjoint(copy(C)), adjoint(Bm), A, 2, 3) ≈ adjoint(Bm) * A * 2 + adjoint(C) * 3
@test mul!(transpose(copy(C)), A, transpose(Bm), 2, 3) ≈ A * transpose(Bm) * 2 + transpose(C) * 3
@test mul!(transpose(copy(C)), transpose(Bm), A, 2, 3) ≈ transpose(Bm) * A * 2 + transpose(C) * 3
if eltype(C) <: Complex
alpha, beta = 4+2im, 3+im
@test mul!(adjoint(copy(C)), A, adjoint(Bm), alpha, beta) ≈ A * adjoint(Bm) * alpha + adjoint(C) * beta
@test mul!(adjoint(copy(C)), adjoint(Bm), A, alpha, beta) ≈ adjoint(Bm) * A * alpha + adjoint(C) * beta
@test mul!(transpose(copy(C)), A, transpose(Bm), alpha, beta) ≈ A * transpose(Bm) * alpha + transpose(C) * beta
@test mul!(transpose(copy(C)), transpose(Bm), A, alpha, beta) ≈ transpose(Bm) * A * alpha + transpose(C) * beta
end
end
let Ahrs = transpose(Hermitian(Symmetric(rand(T, 3, 3)))),
Acs = transpose(Symmetric(rand(S, 3, 3))),
Ahcs = transpose(Hermitian(Symmetric(rand(S, 3, 3))))
let Ahrs = transpose(Hermitian(Symmetric(rand(T, n, n)))),
Acs = transpose(Symmetric(rand(S, n, n))),
Ahcs = transpose(Hermitian(Symmetric(rand(S, n, n))))

@test Ahrs * Ahrs ≈ Ahrs * Matrix(Ahrs)
@test Ahrs * Acs ≈ Ahrs * Matrix(Acs)
Expand All @@ -899,9 +945,9 @@ end
@test Ahrs * Ahcs ≈ Matrix(Ahrs) * Ahcs
@test Ahcs * Ahrs ≈ Ahcs * Matrix(Ahrs)
end
let Ahrs = adjoint(Hermitian(Symmetric(rand(T, 3, 3)))),
Acs = adjoint(Symmetric(rand(S, 3, 3))),
Ahcs = adjoint(Hermitian(Symmetric(rand(S, 3, 3))))
let Ahrs = adjoint(Hermitian(Symmetric(rand(T, n, n)))),
Acs = adjoint(Symmetric(rand(S, n, n))),
Ahcs = adjoint(Hermitian(Symmetric(rand(S, n, n))))

@test Ahrs * Ahrs ≈ Ahrs * Matrix(Ahrs)
@test Ahcs * Ahcs ≈ Matrix(Ahcs) * Matrix(Ahcs)
Expand Down
1 change: 1 addition & 0 deletions test/testhelpers/Quaternions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ struct Quaternion{T<:Real} <: Number
end
Quaternion{T}(s::Real) where {T<:Real} = Quaternion{T}(T(s), zero(T), zero(T), zero(T))
Quaternion(s::Real, v1::Real, v2::Real, v3::Real) = Quaternion(promote(s, v1, v2, v3)...)
Quaternion{T}(q::Quaternion) where {T<:Real} = Quaternion{T}(T(q.s), T(q.v1), T(q.v2), T(q.v3))
Base.convert(::Type{Quaternion{T}}, s::Real) where {T <: Real} =
Quaternion{T}(convert(T, s), zero(T), zero(T), zero(T))
Base.promote_rule(::Type{Quaternion{T}}, ::Type{S}) where {T <: Real, S <: Real} =
Expand Down
6 changes: 5 additions & 1 deletion test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,13 @@ end
end
fds = [abs.(d) for d in ds]
@test abs.(A)::mat_type == mat_type(fds...)
@testset "Multiplication with strided matrix/vector" begin
@testset "Multiplication with strided matrix/vector, and their adjoint/transpose" begin
@test (x = fill(1.,n); A*x ≈ Array(A)*x)
@test (X = fill(1.,n,2); A*X ≈ Array(A)*X)
@test (X = fill(1.,2,n); A * X' ≈ Array(A) * X')
@test (X = fill(1.,n,2); X' * A ≈ X' * Array(A))
@test (X = fill(1.,2,n); A * transpose(X) ≈ Array(A) * transpose(X))
@test (X = fill(1.,n,2); transpose(X) * A ≈ transpose(X) * Array(A))
end
@testset "Binary operations" begin
B = mat_type == Tridiagonal ? mat_type(a, b, c) : mat_type(b, a)
Expand Down