Skip to content

Commit

Permalink
Support __cuda_array_interface__ types without len()
Browse files Browse the repository at this point in the history
  • Loading branch information
pentschev committed Feb 27, 2025
1 parent 2993770 commit 9f84655
Showing 1 changed file with 27 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import functools
import logging
import operator
import os
import struct
import weakref
Expand All @@ -30,7 +31,7 @@
has_cuda_context,
)
from distributed.protocol.utils import host_array
from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes
from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes as _nbytes

logger = logging.getLogger(__name__)

Expand All @@ -51,6 +52,31 @@
cuda_context_created = False


def nbytes(f):
try:
return _nbytes(f)
except TypeError as e:
if hasattr(f, "__cuda_array_interface__"):
interface = f.__cuda_array_interface__
shape = interface["shape"]
typestr = interface["typestr"]
strides = interface.get("strides")

# Get element size from typestr, format is two character specifying
# the type and the latter part is the number of bytes. E.g., '<f4' for
# 32-bit (4-byte) float.
element_size = int(typestr[2:])

if strides is not None:
raise NotImplementedError("Support for strides is not implemented")
else:
num_elements = functools.reduce(operator.mul, shape)

return num_elements * element_size
else:
raise e


_warning_suffix = (
"This is often the result of a CUDA-enabled library calling a CUDA runtime function before "
"Dask-CUDA can spawn worker processes. Please make sure any such function calls don't happen "
Expand Down

0 comments on commit 9f84655

Please sign in to comment.