Skip to content

Commit 2dde909

Browse files
committed
Remove diagm in favour of GPUArrays
1 parent e0ec269 commit 2dde909

File tree

2 files changed

+3
-34
lines changed

2 files changed

+3
-34
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,6 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
9696
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
9797
SparseMatricesCSR = "a0a7dd2c-ebf4-11e9-1f05-cf50bc540ca1"
9898
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
99+
100+
[sources]
101+
GPUArrays = {url="https://github.com/JuliaGPU/GPUArrays.jl", rev="master"}

lib/cublas/linalg.jl

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -581,40 +581,6 @@ function LinearAlgebra.rmul!(A::Adjoint{Tt, <:CuMatrix{Tt}}, B::Diagonal{Td,<:Cu
581581
return adjoint(At)
582582
end
583583

584-
# diagm
585-
586-
LinearAlgebra.diagm(kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm(nothing, kv...)
587-
LinearAlgebra.diagm(m::Integer, n::Integer, kv::Pair{<:Integer,<:CuVector}...) = _cuda_diagm((Int(m),Int(n)), kv...)
588-
LinearAlgebra.diagm(v::CuVector) = LinearAlgebra.diagm(0 => v)
589-
LinearAlgebra.diagm(m::Integer, n::Integer, v::CuVector) = LinearAlgebra.diagm(m, n, 0 => v)
590-
591-
function _cuda_diagm(size, kv::Pair{<:Integer,<:CuVector}...)
592-
A = LinearAlgebra.diagm_container(size, kv...)
593-
for p in kv
594-
inds = LinearAlgebra.diagind(A, p.first)
595-
copyto!(view(A, inds), p.second)
596-
end
597-
return A
598-
end
599-
600-
function LinearAlgebra.diagm_container(size, kv::Pair{<:Integer,<:CuVector}...)
601-
T = promote_type(map(x -> eltype(x.second), kv)...)
602-
U = promote_type(T, typeof(zero(T)))
603-
return CUDA.zeros(U, LinearAlgebra.diagm_size(size, kv...)...)
604-
end
605-
606-
function LinearAlgebra.diagm_size(size::Nothing, kv::Pair{<:Integer,<:CuVector}...)
607-
mnmax = mapreduce(x -> length(x.second) + abs(Int(x.first)), max, kv; init=0)
608-
return mnmax, mnmax
609-
end
610-
function LinearAlgebra.diagm_size(size::Tuple{Int,Int}, kv::Pair{<:Integer,<:CuVector}...)
611-
mmax = mapreduce(x -> length(x.second) - min(0,Int(x.first)), max, kv; init=0)
612-
nmax = mapreduce(x -> length(x.second) + max(0,Int(x.first)), max, kv; init=0)
613-
m, n = size
614-
(m mmax && n nmax) || throw(DimensionMismatch(lazy"invalid size=$size"))
615-
return m, n
616-
end
617-
618584
# symmetric mul!
619585

620586
op_wrappers = ((identity, T -> 'N', identity),

0 commit comments

Comments
 (0)