From 9f846554c935c29a44d80b5ce3d8370be86d1011 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 27 Feb 2025 12:50:33 -0800 Subject: [PATCH] Support `__cuda_array_interface__` types without `len()` --- .../distributed/comm/__rdd_patch_ucx.py | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py b/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py index 974e256..51da536 100644 --- a/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py +++ b/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py @@ -10,6 +10,7 @@ import functools import logging +import operator import os import struct import weakref @@ -30,7 +31,7 @@ has_cuda_context, ) from distributed.protocol.utils import host_array -from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes +from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes as _nbytes logger = logging.getLogger(__name__) @@ -51,6 +52,31 @@ cuda_context_created = False +def nbytes(f): + try: + return _nbytes(f) + except TypeError as e: + if hasattr(f, "__cuda_array_interface__"): + interface = f.__cuda_array_interface__ + shape = interface["shape"] + typestr = interface["typestr"] + strides = interface.get("strides") + + # Get element size from typestr, format is two character specifying + # the type and the latter part is the number of bytes. E.g., '