Skip to content

Commit 413c397

Browse files
committed
Add a dispatch for LinearAlgebra.norm2
`norm(@view x[..], 2)` was previously leading to a call of `LinearAlgebra.generic_norm2` which led to a scalar indexing. This catches such cuda subarray norm2 calls earlier. Inf-norm and p-norm with cuda subarrays still leads to the following dispatches: ```julia LinearAlgebra.generic_normInf(x) = float(mapreduce(norm, max, x)) LinearAlgebra.generic_norm1(x) = mapreduce(float ∘ norm, +, x) ``` I am not sure if there is a better way to dispatch them.
1 parent f5100a1 commit 413c397

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

lib/cublas/linalg.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,10 @@ function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasF
138138
end
139139
end
140140

141+
function LinearAlgebra.norm2(x::SubArray{T,N,P} where {T<:Union{Float16, ComplexF16, CublasFloat}, N, P<:DenseCuArray{<:T}})
142+
return nrm2(x)
143+
end
144+
141145
LinearAlgebra.BLAS.asum(x::StridedCuArray{<:CublasFloat}) = asum(length(x), x)
142146

143147
function LinearAlgebra.axpy!(alpha::Number, x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, ComplexF16, CublasFloat}

test/libraries/cublas.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1767,6 +1767,20 @@ end
17671767
@view(p[reshape(1:(out*inn),out,inn)]) * x
17681768
end
17691769
end
1770+
1771+
@testset "nrm2 with strided inputs" begin # JuliaGPU/CUDA.jl#2280
1772+
cudaTypes = (Float16, Complex{Float16}, BFloat16, Complex{BFloat16}, Float32, Complex{Float32},
1773+
Float64, Complex{Float64}, Int8, Complex{Int8}, UInt8, Complex{UInt8},
1774+
Int16, Complex{Int16}, UInt16, Complex{UInt16}, Int32, Complex{Int32},
1775+
UInt32, Complex{UInt32}, Int64, Complex{Int64}, UInt64, Complex{UInt64})
1776+
for CT in cudaTypes
1777+
x = rand(CT, 10, 10, 10)
1778+
dx = CuArray(x)
1779+
dx_ = @view dx[3:6, 1:5, :]
1780+
x_ = @view x[3:6, 1:5, :]
1781+
@test norm(dx_, 2) norm(x_, 2)
1782+
end
1783+
end
17701784
end
17711785

17721786
############################################################################################

0 commit comments

Comments
 (0)