Skip to content

Commit 11bed80

Browse files
authored
Support mul!(Diagonal, A, B) (#2977)
1 parent 1d1be49 commit 11bed80

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

lib/cublas/linalg.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,12 @@ function LinearAlgebra.mul!(C::CuMatrix{T}, A::Diagonal{T,<:CuVector}, B::Adjoin
537537
return C
538538
end
539539

540+
function LinearAlgebra.mul!(C::Diagonal{T, <:CuVector}, A::Union{<:CuMatrix{T}, Adjoint{T, <:CuMatrix}, Transpose{T, <:CuMatrix}}, B::Union{<:CuMatrix{T}, Adjoint{T, <:CuMatrix}, Transpose{T, <:CuMatrix}}) where {T<:CublasFloat}
541+
Cfull = A*B
542+
C.diag .= diag(Cfull)
543+
return C
544+
end
545+
540546
function LinearAlgebra.lmul!(A::Diagonal{T,<:CuVector{T}}, B::CuMatrix{T}) where {T<:CublasFloat}
541547
return dgmm!('L', B, A.diag, B)
542548
end

test/libraries/cublas/extensions.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,15 @@ k = 13
585585
@test Array(d_AX) A' * Diagonal(x)
586586

587587
@test Array(d_X) == Diagonal(Array(d_x))
588+
589+
d_X = Diagonal(copy(d_x))
590+
diagA = diagm(rand(elty, m))
591+
d_diagA = CuArray(diagA)
592+
diagB = diagm(rand(elty, m))
593+
d_diagB = CuArray(diagB)
594+
diagAdiagB = diagA * diagB'
595+
mul!(d_X, d_diagA, d_diagB')
596+
@test Diagonal(collect(d_X.diag)) Diagonal(diagAdiagB)
588597
end
589598
end # extensions
590599

0 commit comments

Comments
 (0)