diff --git a/src/CUDAKernels.jl b/src/CUDAKernels.jl index c3d014b351..ee8f904bd7 100644 --- a/src/CUDAKernels.jl +++ b/src/CUDAKernels.jl @@ -57,14 +57,17 @@ end ## device operations function KA.ndevices(::CUDABackend) - return ndevices() + return Int(ndevices()) end -function KA.device(::CUDABackend) +function KA.device(::CUDABackend)::Int deviceid(CUDA.active_state().device) + 1 end -function KA.device!(::CUDABackend, id::Int32) +function KA.device!(backend::CUDABackend, id::Int) + if !(0 < id <= KA.ndevices(backend)) + throw(ArgumentError("Device id $id out of bounds.")) + end device!(id - 1) end