Merged

Changes from 30 commits (34 commits total)
85d4a3a  Add StridedLayout (stiepan, Nov 21, 2025)
23b35fe  Support wrapping ptr in Buffer, create SMV from buffer and layout, dl… (stiepan, Nov 17, 2025)
2470617  Documentation, linting, minor fixes (stiepan, Nov 18, 2025)
cf7eff5  Add NotImplemented copy_from/copy_to (stiepan, Nov 21, 2025)
79010b8  Adjust flattening scalars to numpy/cupy behavior, fix shape validatio… (stiepan, Nov 26, 2025)
4ca3567  Add StridedLayout tests (stiepan, Nov 26, 2025)
acdd6f8  Use explicit int32_t instead of int in integer fused type (stiepan, Nov 26, 2025)
60a0d66  Disable (for now) exporting the SMV via dlpack (stiepan, Nov 26, 2025)
1fa43d4  Revert dlpack changes (stiepan, Nov 26, 2025)
67c6c5e  Support layouts up to 64 dims (stiepan, Nov 27, 2025)
a96bec5  Use cydriver to query memory attributes, fix managed memory handling,… (stiepan, Nov 27, 2025)
91387b0  Test owner and mr cannot be specified together (stiepan, Nov 27, 2025)
91c0af9  Test Buffer.close with owner (stiepan, Nov 27, 2025)
b74ef2c  Add envelope checks (rquires_size_in_bytes, offset_bounds) (stiepan, Nov 27, 2025)
2c0343f  Docs, annotation fixes, remove dlpack export mentions (stiepan, Nov 27, 2025)
598a2f1  Add SMV.from_buffer/view tests (stiepan, Nov 27, 2025)
bbb227b  Layout tests for SMV created from CAI (stiepan, Nov 28, 2025)
26dfe3b  Fix missing host unregister call in buffer test (stiepan, Dec 1, 2025)
3adae5c  Fix num attrib on re-try (stiepan, Dec 1, 2025)
7554164  Call int on the buffer.handle (stiepan, Dec 1, 2025)
68b7a79  Merge branch 'main' into introduce_strided_layout_memview (stiepan, Dec 3, 2025)
edace66  Don't enforce Buffer having an owner when creating SMV (stiepan, Dec 3, 2025)
9f86322  Use np._s instead of a custom helper in the tests (stiepan, Dec 3, 2025)
4335b2e  Take lanes into account when computing the itemsize (stiepan, Dec 3, 2025)
6568e27  Merge branch 'main' into introduce_strided_layout_memview (stiepan, Dec 4, 2025)
cbf1d17  Move layout validation out of get_data_ptr helper (stiepan, Dec 4, 2025)
4767fbb  Disambiguate all_axes mask for layout flattening, add range flattenin… (stiepan, Dec 4, 2025)
5765a22  Bring back the intptr_t in SMV (stiepan, Dec 8, 2025)
7ec8961  Merge branch 'main' into introduce_strided_layout_memview (stiepan, Dec 8, 2025)
db75aa0  Reorder methods, adjust SMV tests to from_dlpack/form_cai methods (stiepan, Dec 8, 2025)
639ee5f  Move the Device import to top-level imports (stiepan, Dec 8, 2025)
9fb5dfb  Merge branch 'main' into introduce_strided_layout_memview (stiepan, Dec 8, 2025)
3375b4d  Merge branch 'main' into introduce_strided_layout_memview (stiepan, Dec 9, 2025)
66fc6e8  Merge branch 'main' into introduce_strided_layout_memview (leofang, Dec 10, 2025)
693 changes: 693 additions & 0 deletions cuda_core/cuda/core/experimental/_layout.pxd

Large diffs are not rendered by default.

1,323 changes: 1,323 additions & 0 deletions cuda_core/cuda/core/experimental/_layout.pyx

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions cuda_core/cuda/core/experimental/_memory/_buffer.pxd
@@ -7,14 +7,23 @@ from libc.stdint cimport uintptr_t
from cuda.core.experimental._stream cimport Stream


cdef struct _MemAttrs:
    int device_id
    bint is_device_accessible
    bint is_host_accessible


cdef class Buffer:
    cdef:
        uintptr_t _ptr
        size_t _size
        MemoryResource _memory_resource
        object _ipc_data
        object _owner
        object _ptr_obj
        Stream _alloc_stream
        _MemAttrs _mem_attrs
        bint _mem_attrs_inited


cdef class MemoryResource:
111 changes: 102 additions & 9 deletions cuda_core/cuda/core/experimental/_memory/_buffer.pyx
@@ -4,6 +4,7 @@

from __future__ import annotations

cimport cython
from libc.stdint cimport uintptr_t

from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource
@@ -12,7 +13,9 @@ from cuda.core.experimental._memory cimport _ipc
from cuda.core.experimental._stream cimport Stream_accept, Stream
from cuda.core.experimental._utils.cuda_utils cimport (
    _check_driver_error as raise_if_driver_error,
    HANDLE_RETURN,
)
from cuda.bindings cimport cydriver

import abc
from typing import TypeVar, Union
@@ -48,6 +51,8 @@ cdef class Buffer:
        self._ipc_data = None
        self._ptr_obj = None
        self._alloc_stream = None
        self._owner = None
        self._mem_attrs_inited = False

    def __init__(self, *args, **kwargs):
        raise RuntimeError("Buffer objects cannot be instantiated directly. "
@@ -56,15 +61,19 @@
    @classmethod
    def _init(
        cls, ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None,
-       stream: Stream | None = None, ipc_descriptor: IPCBufferDescriptor | None = None
+       stream: Stream | None = None, ipc_descriptor: IPCBufferDescriptor | None = None,
+       owner : object | None = None
    ):
        cdef Buffer self = Buffer.__new__(cls)
        self._ptr = <uintptr_t>(int(ptr))
        self._ptr_obj = ptr
        self._size = size
        if mr is not None and owner is not None:
            raise ValueError("owner and memory resource cannot be both specified together")
        self._memory_resource = mr
        self._ipc_data = IPCDataForBuffer(ipc_descriptor, True) if ipc_descriptor is not None else None
        self._alloc_stream = <Stream>(stream) if stream is not None else None
        self._owner = owner
        return self

    def __dealloc__(self):
@@ -76,7 +85,8 @@

    @staticmethod
    def from_handle(
-       ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None
+       ptr: DevicePointerT, size_t size, mr: MemoryResource | None = None,
+       owner: object | None = None,
    ) -> Buffer:
        """Create a new :class:`Buffer` object from a pointer.

@@ -88,9 +98,13 @@
            Memory size of the buffer
        mr : :obj:`~_memory.MemoryResource`, optional
            Memory resource associated with the buffer
        owner : object, optional
            An object holding external allocation that the ``ptr`` points to.
            The reference is kept as long as the buffer is alive.
            The ``owner`` and ``mr`` cannot be specified together.
        """
        # TODO: It is better to take a stream for latter deallocation
-       return Buffer._init(ptr, size, mr=mr)
+       return Buffer._init(ptr, size, mr=mr, owner=owner)

    @classmethod
    def from_ipc_descriptor(
@@ -228,7 +242,9 @@ cdef class Buffer:
        """Return the device ordinal of this buffer."""
        if self._memory_resource is not None:
            return self._memory_resource.device_id
-       raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource")
+       else:
+           Buffer_init_mem_attrs(self)
+           return self._mem_attrs.device_id

    @property
    def handle(self) -> DevicePointerT:
@@ -252,14 +268,18 @@
        """Return True if this buffer can be accessed by the GPU, otherwise False."""
        if self._memory_resource is not None:
            return self._memory_resource.is_device_accessible
-       raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource")
+       else:
+           Buffer_init_mem_attrs(self)
+           return self._mem_attrs.is_device_accessible

    @property
    def is_host_accessible(self) -> bool:
        """Return True if this buffer can be accessed by the CPU, otherwise False."""
        if self._memory_resource is not None:
            return self._memory_resource.is_host_accessible
-       raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource")
+       else:
+           Buffer_init_mem_attrs(self)
+           return self._mem_attrs.is_host_accessible

    @property
    def is_mapped(self) -> bool:
@@ -277,20 +297,93 @@
        """Return the memory size of this buffer."""
        return self._size

    @property
    def owner(self) -> object:
        """Return the object holding external allocation."""
        return self._owner


# Buffer Implementation
# ---------------------
cdef inline void Buffer_close(Buffer self, stream):
    cdef Stream s
-   if self._ptr and self._memory_resource is not None:
-       s = Stream_accept(stream) if stream is not None else self._alloc_stream
-       self._memory_resource.deallocate(self._ptr, self._size, s)
+   if self._ptr:
+       if self._memory_resource is not None:
+           s = Stream_accept(stream) if stream is not None else self._alloc_stream
+           self._memory_resource.deallocate(self._ptr, self._size, s)
    self._ptr = 0
    self._memory_resource = None
    self._owner = None
    self._ptr_obj = None
    self._alloc_stream = None


cdef Buffer_init_mem_attrs(Buffer self):
Review comment (Contributor): Nit: cdef void Buffer_init_mem_attrs(Buffer self):

Reply (Member Author): Hmm, I recall some weird issues with a void return type when it comes to exception propagation with Cython. Won't this require an except * clause?

Reply (Member): Yeah, it'd have to be

cdef int Buffer_init_mem_attrs(Buffer self) except?-1:
    ...
    return 0

if we want to do this and gain a bit of perf. I am fine with the status quo.
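For reference, a minimal sketch (not from the PR) contrasting the two declaration styles discussed in this thread. How the defaults behave depends on the Cython version: Cython 3 propagates exceptions from cdef functions by default, while older releases silently ignored exceptions escaping a void-returning function unless the except clause was spelled out.

cdef void init_attrs_checked() except *:
    # With `except *`, the caller must consult PyErr_Occurred() after
    # every call, because a void return leaves no room for an error
    # sentinel.
    raise ValueError("propagates only with the except clause")

cdef int init_attrs_sentinel(bint fail) except? -1:
    # With `except? -1`, -1 acts as a "maybe an error" sentinel: the
    # PyErr_Occurred() check runs only when -1 is actually returned,
    # which is the small perf win mentioned above.
    if fail:
        raise ValueError("signalled via the -1 sentinel")
    return 0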

    if not self._mem_attrs_inited:
        query_memory_attrs(self._mem_attrs, self._ptr)
        self._mem_attrs_inited = True


cdef int query_memory_attrs(_MemAttrs &out, uintptr_t ptr) except -1 nogil:
    cdef unsigned int memory_type = 0
    cdef int is_managed = 0
    cdef int device_id = 0
    _query_memory_attrs(memory_type, is_managed, device_id, <cydriver.CUdeviceptr>ptr)

    if memory_type == 0:
        # unregistered host pointer
        out.is_host_accessible = True
        out.is_device_accessible = False
        out.device_id = -1
    # for managed memory, the memory type can be CU_MEMORYTYPE_DEVICE,
    # so we need to check it first not to falsely claim it is not
    # host accessible.
    elif (
        is_managed
        or memory_type == cydriver.CUmemorytype.CU_MEMORYTYPE_HOST
    ):
        # For pinned memory allocated with cudaMallocHost or page-locked
        # with cudaHostRegister, the memory_type is
        # cydriver.CUmemorytype.CU_MEMORYTYPE_HOST.
        # TODO(ktokarski): In some cases, the registered memory requires
        # using different ptrs for device and host; we could check
        # cuMemHostGetDevicePointer and
        # CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM
        # to double-check the device accessibility.
Review comment on lines +416 to +420 (Collaborator): Do you happen to know what cases these are? This used to be the case with non-unified addressing, but I don't think any platforms that CUDA supports are non-unified addressing these days.

Reply (Member Author, stiepan, Dec 8, 2025): I did not find a comprehensive list, but digging a bit I learnt of one notable exception for modern GPUs: running on WSL. Indeed, trying to access a cudaHostRegister-ed ptr on WSL fails (if the memory is allocated with CUDA from the start, using the same pointer is fine).

import cuda.core.experimental as ccx
from cuda.bindings import runtime
from cuda.bindings import driver
import cupy as cp
import numpy as np

d = ccx.Device()
d.set_current()

def query_memory_attrs(ptr):
    attrs = (
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_IS_MANAGED,
    )
    ret, attr = driver.cuPointerGetAttributes(len(attrs), attrs, ptr)
    assert ret == 0
    return attr

a_np = np.empty(5, dtype=np.int32)
cpu_ptr = a_np.ctypes.data
ret, = runtime.cudaHostRegister(cpu_ptr, 20, 0)
assert ret == 0
assert query_memory_attrs(cpu_ptr)[0] == driver.CUmemorytype.CU_MEMORYTYPE_HOST
ret, attr = runtime.cudaPointerGetAttributes(cpu_ptr)
assert ret == 0
print(attr.devicePointer == cpu_ptr)
# On WSL, accessing cpu_ptr instead of attr.devicePointer fails
um = cp.cuda.UnownedMemory(cpu_ptr, 20, a_np, 0)
mem = cp.cuda.MemoryPointer(um, 0)
a_cp = cp.ndarray(shape=(5,), dtype=cp.int32, memptr=mem)
a_cp[:] = 1
print(a_np)
print(a_cp)

Reply (Member Author): At the same time, the driver's cuPointerGetAttributes still reports that pointer as CU_MEMORYTYPE_HOST.

Reply (Member Author, stiepan, Dec 8, 2025): So, what should be the meaning of is_device_accessible and is_host_accessible in this case?

  1. Should we check the device attribute and, if the attribute is 0, follow up by retrieving host_ptr and device_ptr and set is_host_accessible = (host_ptr == ptr), is_device_accessible = (device_ptr == ptr)?
  2. Or expect the user to pass the correct pointer in the correct context, i.e. if the buffer is to be consumed on the GPU, the user is expected to pass the device ptr?
  3. Or (not a fan) have buffer.device_ptr and buffer.host_ptr attributes?

Reply (Member): Let's table this discussion for now. I'll create an issue to track this. I think the strided layout itself is already big enough that we want to keep the scope limited.

Review comment (Collaborator): If we checked memory_type == cydriver.CUmemorytype.CU_MEMORYTYPE_HOST and CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, I would assume the first check would return True, and the second check would return True if an allocation made from cudaMallocHost can use the same ptr for device and host, so it would still return True for is_device_accessible?

Reply (Member Author, stiepan, Dec 8, 2025): On WSL, CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM is False.

If a ptr comes from cudaMallocHost or was passed to cudaHostRegister, memory_type == cydriver.CUmemorytype.CU_MEMORYTYPE_HOST is True.

For cudaMallocHost, the ptr is truly device and host accessible; only the cudaHostRegister-ed one is troublesome. Even though the memory type is CU_MEMORYTYPE_HOST, it cannot be used to access the memory from the device. So my point was that if we were to say is_device_accessible is False whenever CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM is False, we would break otherwise valid cudaMallocHost usages.

import cuda.core.experimental as ccx
from cuda.bindings import runtime
from cuda.bindings import driver
import cupy as cp
import numpy as np
import ctypes


def query_memory_attrs(ptr):
    attrs = (
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_MEMORY_TYPE,
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_IS_MANAGED,
    )
    ret, attr = driver.cuPointerGetAttributes(len(attrs), attrs, ptr)
    assert ret == 0
    return attr

def as_numpy(ptr, shape, dtype):
    size = np.prod(shape) * dtype.itemsize
    return np.ndarray(
        shape=shape,
        dtype=dtype,
        buffer=memoryview((ctypes.c_char * size).from_address(ptr))
    )

def as_cupy(ptr, shape, dtype):
    size = np.prod(shape) * dtype.itemsize
    um = cp.cuda.UnownedMemory(ptr, size, owner=None, device_id=0)
    mem = cp.cuda.MemoryPointer(um, 0)
    return cp.ndarray(shape=shape, dtype=dtype, memptr=mem)

d = ccx.Device()
d.set_current()

# On WSL this is 0
print(driver.cuDeviceGetAttribute(driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM, 0))

shape = (5,)
dtype = np.dtype(np.int32)
size = np.prod(shape) * dtype.itemsize

# But this works
l = ccx.LegacyPinnedMemoryResource()
alloc_mem = l.allocate(np.prod(shape) * dtype.itemsize)
alloc_ptr = int(alloc_mem.handle)
# the pinned ptr's CU_POINTER_ATTRIBUTE_MEMORY_TYPE is, as expected, 1 (aka CU_MEMORYTYPE_HOST)
assert query_memory_attrs(alloc_ptr)[0] == driver.CUmemorytype.CU_MEMORYTYPE_HOST

a_np = as_numpy(alloc_ptr, shape, dtype)
a_cp = as_cupy(alloc_ptr, shape, dtype)
a_np[:] = 1
print(a_np)
print(a_cp)

# The problem is when we register the memory
a_np = np.empty(shape, dtype=dtype)
cpu_ptr = a_np.ctypes.data
ret, = runtime.cudaHostRegister(cpu_ptr, size, 0)
assert ret == 0
assert query_memory_attrs(cpu_ptr)[0] == driver.CUmemorytype.CU_MEMORYTYPE_HOST
reg_np = as_numpy(cpu_ptr, shape, dtype)
reg_cp = as_cupy(cpu_ptr, shape, dtype)
reg_np[:] = 2
print(reg_np)
# Here we end up with invalid access
print(reg_cp)

Review comment (Collaborator): Just to confirm my understanding is correct: on WSL, CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM is False for both cudaMallocHost memory as well as cudaHostRegister memory, but the ptr returned from cudaMallocHost is in fact usable in device code while the ptr used for cudaHostRegister is not usable in device code?

Review comment (Collaborator): Alternatively, we could query the CU_POINTER_ATTRIBUTE_DEVICE_POINTER and CU_POINTER_ATTRIBUTE_HOST_POINTER attributes. On my local WSL setup this yields:

  • The same ptr for pinned host memory
  • The same ptr for managed memory
  • Different ptrs for device memory (0 for the CU_POINTER_ATTRIBUTE_HOST_POINTER attribute, as expected)
  • Different ptrs for registered host memory (neither is 0)

Our logic could be to return is_device_accessible == True only when the ptr equals the ptr returned from CU_POINTER_ATTRIBUTE_DEVICE_POINTER, and is_host_accessible == True when the ptr equals the ptr returned from CU_POINTER_ATTRIBUTE_HOST_POINTER.

That being said, querying these attributes is expensive and I am not sure we want to pay this penalty...

Reply (Member Author):

> Just to confirm my understanding is correct, on WSL CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM is False

That's right.

> for both cudaMallocHost memory as well as cudaHostRegister memory

CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM is a device attribute, not a memory ptr attribute.

> the ptr returned from cudaMallocHost is in fact usable in device code while the ptr used for cudaHostRegister is not usable in device code

That's right. And using the memory type is not enough to distinguish the two.

> Alternatively, we could query the CU_POINTER_ATTRIBUTE_DEVICE_POINTER and CU_POINTER_ATTRIBUTE_HOST_POINTER attributes.

Yeah, I've been thinking about a similar approach. According to cuMemHostGetDevicePointer, there is still a catch, though. In some cases device_ptr != host_ptr even though the memory can be accessed through the host pointer from the device. 🥲 If I read the docs right (and assuming that's the only edge case), we'd need to bundle it with the CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM check, so that CU_DEVICE_ATTRIBUTE_CAN_USE_HOST_POINTER_FOR_REGISTERED_MEM or ptr == device_ptr would be accurate enough.
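For reference, a hedged sketch (not part of the PR) of the alternative classification floated in this thread: compare the pointer against the queried CU_POINTER_ATTRIBUTE_DEVICE_POINTER and CU_POINTER_ATTRIBUTE_HOST_POINTER values, mirroring the call pattern of the scripts above. It deliberately ignores the cuMemHostGetDevicePointer caveat just mentioned.

from cuda.bindings import driver

def classify_ptr(ptr):
    attrs = (
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_POINTER,
        driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_HOST_POINTER,
    )
    ret, (device_ptr, host_ptr) = driver.cuPointerGetAttributes(len(attrs), attrs, ptr)
    assert ret == 0
    # A pointer that round-trips unchanged is directly usable in that
    # address space; a 0 value (e.g. no host alias for device memory)
    # simply compares unequal.
    return {
        "is_device_accessible": int(device_ptr) == int(ptr),
        "is_host_accessible": int(host_ptr) == int(ptr),
    }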

        out.is_host_accessible = True
        out.is_device_accessible = True
        out.device_id = device_id
    elif memory_type == cydriver.CUmemorytype.CU_MEMORYTYPE_DEVICE:
        out.is_host_accessible = False
        out.is_device_accessible = True
        out.device_id = device_id
    else:
        raise ValueError(f"Unsupported memory type: {memory_type}")
    return 0


cdef inline int _query_memory_attrs(unsigned int& memory_type, int& is_managed, int& device_id, cydriver.CUdeviceptr ptr) except -1 nogil:
    cdef cydriver.CUpointer_attribute attrs[3]
    cdef uintptr_t vals[3]
    attrs[0] = cydriver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_MEMORY_TYPE
    attrs[1] = cydriver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_IS_MANAGED
    attrs[2] = cydriver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL
    vals[0] = <uintptr_t><void*>&memory_type
    vals[1] = <uintptr_t><void*>&is_managed
    vals[2] = <uintptr_t><void*>&device_id

    cdef cydriver.CUresult ret
    ret = cydriver.cuPointerGetAttributes(3, attrs, <void**>vals, ptr)
    if ret == cydriver.CUresult.CUDA_ERROR_NOT_INITIALIZED:
        with cython.gil:
            # Device class handles the cuInit call internally
            from cuda.core.experimental import Device
            Device()
Review comment (Contributor): Can we call cuInit if that's what we're after, instead of calling Device?

Reply (Member Author): Device has a lock around the cuInit call, so I wanted to stick to that single section handling the call. I guess calling cuInit more than once wouldn't hurt us either. Or I could move that part from Device.__new__ into a separate utility. I just thought this is a one-time thing, trying to recover from an error, so being a bit lavish about the import and the Device call is not that bad. What do you think?

Reply (Member): My 2c is let's keep this for now, I don't have clever thoughts here 😅
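For reference, a minimal sketch (not the PR's Cython code) of the same retry-on-CUDA_ERROR_NOT_INITIALIZED pattern using the Python-level bindings; `attrs` and `ptr` are assumed to be prepared as in the scripts quoted in the earlier review thread.

from cuda.bindings import driver

def pointer_attrs_with_lazy_init(attrs, ptr):
    ret, vals = driver.cuPointerGetAttributes(len(attrs), attrs, ptr)
    if ret == driver.CUresult.CUDA_ERROR_NOT_INITIALIZED:
        # The PR goes through Device() so that cuInit runs under the same
        # lock as the rest of cuda.core; calling cuInit directly is the
        # bare-bones equivalent.
        (err,) = driver.cuInit(0)
        assert err == driver.CUresult.CUDA_SUCCESS
        ret, vals = driver.cuPointerGetAttributes(len(attrs), attrs, ptr)
    assert ret == driver.CUresult.CUDA_SUCCESS
    return vals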

        ret = cydriver.cuPointerGetAttributes(3, attrs, <void**>vals, ptr)
    HANDLE_RETURN(ret)
    return 0


cdef class MemoryResource:
    """Abstract base class for memory resources that manage allocation and
    deallocation of buffers.
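Finally, a hypothetical usage sketch (not from the PR's tests) tying the pieces together: Buffer.from_handle with the new owner argument keeps an external allocation alive, and the pointer-attribute query backs the accessibility properties for buffers without a memory resource.

import numpy as np
import cuda.core.experimental as ccx

d = ccx.Device()
d.set_current()

a = np.empty(5, dtype=np.int32)
# Keep the array alive through the buffer instead of a memory resource.
buf = ccx.Buffer.from_handle(a.ctypes.data, a.nbytes, owner=a)
assert buf.owner is a
assert buf.is_host_accessible        # unregistered host pointer
assert not buf.is_device_accessible  # memory_type == 0 branch above
assert buf.device_id == -1
buf.close()  # drops the owner reference; nothing to deallocate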