diff --git a/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py b/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py index 974e256..51da536 100644 --- a/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py +++ b/rapids_dask_dependency/patches/distributed/comm/__rdd_patch_ucx.py @@ -10,6 +10,7 @@ import functools import logging +import operator import os import struct import weakref @@ -30,7 +31,7 @@ has_cuda_context, ) from distributed.protocol.utils import host_array -from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes +from distributed.utils import ensure_ip, get_ip, get_ipv6, log_errors, nbytes as _nbytes logger = logging.getLogger(__name__) @@ -51,6 +52,31 @@ cuda_context_created = False +def nbytes(f): + try: + return _nbytes(f) + except TypeError as e: + if hasattr(f, "__cuda_array_interface__"): + interface = f.__cuda_array_interface__ + shape = interface["shape"] + typestr = interface["typestr"] + strides = interface.get("strides") + + # Get element size from typestr, format is two character specifying + # the type and the latter part is the number of bytes. E.g., '