diff --git a/src/array.jl b/src/array.jl index d66396bfce..03078e0024 100644 --- a/src/array.jl +++ b/src/array.jl @@ -1,4 +1,4 @@ -export CuArray, CuVector, CuMatrix, CuVecOrMat, cu, is_device, is_unified, is_host +export CuArray, CuScalar, CuVector, CuMatrix, CuVecOrMat, cu, is_device, is_unified, is_host ## array type @@ -123,6 +123,7 @@ end ## convenience constructors +const CuScalar{T} = CuArray{T,0} const CuVector{T} = CuArray{T,1} const CuMatrix{T} = CuArray{T,2} const CuVecOrMat{T} = Union{CuVector{T},CuMatrix{T}} @@ -371,7 +372,7 @@ is_host(a::CuArray) = memory_type(a) == HostMemory export DenseCuArray, DenseCuVector, DenseCuMatrix, DenseCuVecOrMat, StridedCuArray, StridedCuVector, StridedCuMatrix, StridedCuVecOrMat, - AnyCuArray, AnyCuVector, AnyCuMatrix, AnyCuVecOrMat + AnyCuArray, AnyCuScalar, AnyCuVector, AnyCuMatrix, AnyCuVecOrMat # dense arrays: stored contiguously in memory # @@ -426,6 +427,7 @@ end # anything that's (secretly) backed by a CuArray const AnyCuArray{T,N} = Union{CuArray{T,N}, WrappedArray{T,N,CuArray,CuArray{T,N}}} +const AnyCuScalar{T} = AnyCuArray{T,0} const AnyCuVector{T} = AnyCuArray{T,1} const AnyCuMatrix{T} = AnyCuArray{T,2} const AnyCuVecOrMat{T} = Union{AnyCuVector{T}, AnyCuMatrix{T}} diff --git a/src/device/array.jl b/src/device/array.jl index 59322349c5..2d3f4e350b 100644 --- a/src/device/array.jl +++ b/src/device/array.jl @@ -33,9 +33,18 @@ struct CuDeviceArray{T,N,A} <: DenseArray{T,N} new(ptr, maxsize, dims, prod(dims)) end -const CuDeviceVector = CuDeviceArray{T,1,A} where {T,A} -const CuDeviceMatrix = CuDeviceArray{T,2,A} where {T,A} - +const CuDeviceScalar{T} = CuDeviceArray{T,0,A} where {A} +const CuDeviceVector{T} = CuDeviceArray{T,1,A} where {A} +const CuDeviceMatrix{T} = CuDeviceArray{T,2,A} where {A} + +# anything that's (secretly) backed by a CuDeviceArray +export AnyCuDeviceArray, AnyCuDeviceScalar, AnyCuDeviceVector, AnyCuDeviceMatrix, AnyCuDeviceVecOrMat + +const AnyCuDeviceArray{T,N} = Union{CuDeviceArray{T,N},WrappedArray{T,N,CuDeviceArray,CuDeviceArray{T,N,A}}} where {A} +const AnyCuDeviceScalar{T} = AnyCuDeviceArray{T,0} +const AnyCuDeviceVector{T} = AnyCuDeviceArray{T,1} +const AnyCuDeviceMatrix{T} = AnyCuDeviceArray{T,2} +const AnyCuDeviceVecOrMat{T} = Union{AnyCuDeviceVector{T},AnyCuDeviceMatrix{T}} ## array interface