From 0e2ef842e423e7c490037b0e0023fbc43cfe66b8 Mon Sep 17 00:00:00 2001
From: Sharan Yalburgi
Date: Fri, 22 Mar 2024 22:29:30 +0000
Subject: [PATCH] Add a dispatch for LinearAlgebra.norm2
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

`norm(@view x[..], 2)` previously dispatched to `LinearAlgebra.generic_norm2`,
which falls back to scalar indexing on CUDA arrays. This patch catches such
CUDA SubArray norm2 calls earlier and routes them to CUBLAS `nrm2`.

The Inf-norm and 1-norm with CUDA SubArrays still fall through to the generic
implementations:

```julia
LinearAlgebra.generic_normInf(x) = float(mapreduce(norm, max, x))
LinearAlgebra.generic_norm1(x) = mapreduce(float ∘ norm, +, x)
```

I am not sure whether there is a better way to dispatch those.

Should resolve https://github.com/JuliaGPU/CUDA.jl/issues/2280; a minimal
reproduction sketch follows the patch below.
---
 lib/cublas/linalg.jl     |  4 ++++
 test/libraries/cublas.jl | 11 +++++++++++
 2 files changed, 15 insertions(+)

diff --git a/lib/cublas/linalg.jl b/lib/cublas/linalg.jl
index 763b32aafa..ff9263532b 100644
--- a/lib/cublas/linalg.jl
+++ b/lib/cublas/linalg.jl
@@ -138,6 +138,10 @@ function LinearAlgebra.norm(x::DenseCuArray{<:Union{Float16, ComplexF16, CublasF
     end
 end
 
+function LinearAlgebra.norm2(x::SubArray{T,N,P} where {T<:Union{Float16, ComplexF16, CublasFloat}, N, P<:DenseCuArray{<:T}})
+    return nrm2(x)
+end
+
 LinearAlgebra.BLAS.asum(x::StridedCuArray{<:CublasFloat}) = asum(length(x), x)
 
 function LinearAlgebra.axpy!(alpha::Number, x::StridedCuArray{T}, y::StridedCuArray{T}) where T<:Union{Float16, ComplexF16, CublasFloat}
diff --git a/test/libraries/cublas.jl b/test/libraries/cublas.jl
index 804abb219c..d3d858730c 100644
--- a/test/libraries/cublas.jl
+++ b/test/libraries/cublas.jl
@@ -1767,6 +1767,17 @@ end
             @view(p[reshape(1:(out*inn),out,inn)]) * x
         end
     end
+
+    @testset "nrm2 with strided inputs" begin # JuliaGPU/CUDA.jl#2280
+        cudaTypes = (Float16, ComplexF16, Float32, Float64, ComplexF32, ComplexF64)
+        for CT in cudaTypes
+            x = rand(CT, 10, 10, 10)
+            dx = CuArray(x)
+            dx_ = @view dx[3:6, 1:5, :]
+            x_ = @view x[3:6, 1:5, :]
+            @test norm(dx_, 2) ≈ norm(x_, 2)
+        end
+    end
 end
 
 ############################################################################################
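
For reference, a minimal sketch of the call pattern this patch targets, adapted
from the description above rather than copied verbatim from
JuliaGPU/CUDA.jl#2280; it assumes a CUDA-capable device with scalar indexing
disallowed (the default outside the REPL):

```julia
using CUDA, LinearAlgebra

x  = CUDA.rand(Float32, 10, 10, 10)   # dense device array
xv = @view x[3:6, 1:5, :]             # non-contiguous SubArray of a CuArray

# Without the LinearAlgebra.norm2 method added above, this call falls back to
# LinearAlgebra.generic_norm2, which indexes the GPU view element by element
# and hits the scalar-indexing error (or is very slow when scalar indexing is
# allowed).
norm(xv, 2)
```

With the new method in place, the same call should route to the CUBLAS `nrm2`
wrapper instead of the generic fallback.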