Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 117 additions & 36 deletions cuda_core/cuda/core/experimental/_kernel_arg_handler.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -128,67 +128,123 @@ cdef inline int prepare_ctypes_arg(
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, ctypes_bool):
cdef object arg_type = type(arg)
if arg_type is ctypes_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
elif arg_type is ctypes_int8:
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
elif arg_type is ctypes_int16:
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
elif arg_type is ctypes_int32:
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
elif arg_type is ctypes_int64:
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
elif arg_type is ctypes_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
elif arg_type is ctypes_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
elif arg_type is ctypes_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
elif arg_type is ctypes_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
elif arg_type is ctypes_float:
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
elif arg_type is ctypes_double:
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1
# If no exact type matches, fall back to the slower `isinstance` check
if isinstance(arg, ctypes_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int8):
return prepare_arg[int8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int16):
return prepare_arg[int16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int32):
return prepare_arg[int32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_int64):
return prepare_arg[int64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_float):
return prepare_arg[float](data, data_addresses, arg.value, idx)
elif isinstance(arg, ctypes_double):
return prepare_arg[double](data, data_addresses, arg.value, idx)
else:
return 1


cdef inline int prepare_numpy_arg(
vector.vector[void*]& data,
vector.vector[void*]& data_addresses,
arg,
const size_t idx) except -1:
if isinstance(arg, numpy_bool):
cdef object arg_type = type(arg)
if arg_type is numpy_bool:
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
elif arg_type is numpy_int8:
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
elif arg_type is numpy_int16:
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
elif arg_type is numpy_int32:
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
elif arg_type is numpy_int64:
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
elif arg_type is numpy_uint8:
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
elif arg_type is numpy_uint16:
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
elif arg_type is numpy_uint32:
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
elif arg_type is numpy_uint64:
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
elif arg_type is numpy_float16:
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
elif arg_type is numpy_float32:
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
elif arg_type is numpy_float64:
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
elif arg_type is numpy_complex64:
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
elif arg_type is numpy_complex128:
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1
# If no exact type matches, fall back to the slower `isinstance` check
if isinstance(arg, numpy_bool):
return prepare_arg[cpp_bool](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int8):
return prepare_arg[int8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int16):
return prepare_arg[int16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int32):
return prepare_arg[int32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_int64):
return prepare_arg[int64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint8):
return prepare_arg[uint8_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint16):
return prepare_arg[uint16_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint32):
return prepare_arg[uint32_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_uint64):
return prepare_arg[uint64_t](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float16):
return prepare_arg[__half_raw](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float32):
return prepare_arg[float](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_float64):
return prepare_arg[double](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex64):
return prepare_arg[cpp_single_complex](data, data_addresses, arg, idx)
elif isinstance(arg, numpy_complex128):
return prepare_arg[cpp_double_complex](data, data_addresses, arg, idx)
else:
return 1


cdef class ParamHolder:
Expand All @@ -207,43 +263,68 @@ cdef class ParamHolder:
cdef size_t n_args = len(kernel_args)
cdef size_t i
cdef int not_prepared
cdef object arg_type
self.data = vector.vector[voidptr](n_args, nullptr)
self.data_addresses = vector.vector[voidptr](n_args)
for i, arg in enumerate(kernel_args):
if isinstance(arg, Buffer):
arg_type = type(arg)
if arg_type is Buffer:
# we need the address of where the actual buffer address is stored
if isinstance(arg.handle, int):
if type(arg.handle) is int:
# see note below on handling int arguments
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
# it's a CUdeviceptr:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, int):
elif arg_type is bool:
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue
elif arg_type is int:
# Here's the dilemma: We want to have a fast path to pass in Python
# integers as pointer addresses, but one could also (mistakenly) pass
# it with the intention of passing a scalar integer. It's a mistake
# because a Python int is ambiguous (arbitrary width). Our judgement
# call here is to treat it as a pointer address, without any warning!
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
elif arg_type is float:
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
elif arg_type is complex:
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, bool):
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue

not_prepared = prepare_numpy_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
not_prepared = prepare_ctypes_arg(self.data, self.data_addresses, arg, i)
if not_prepared:
# TODO: revisit this treatment if we decide to cythonize cuda.core
if isinstance(arg, driver.CUgraphConditionalHandle):
if arg_type is driver.CUgraphConditionalHandle:
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
continue
# If no exact type matches, fall back to the slower `isinstance` check
elif isinstance(arg, Buffer):
if isinstance(arg.handle, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg.handle, i)
continue
else:
self.data_addresses[i] = <void*><intptr_t>(arg.handle.getPtr())
continue
elif isinstance(arg, bool):
prepare_arg[cpp_bool](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, int):
prepare_arg[intptr_t](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, float):
prepare_arg[double](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, complex):
prepare_arg[cpp_double_complex](self.data, self.data_addresses, arg, i)
continue
elif isinstance(arg, driver.CUgraphConditionalHandle):
prepare_arg[intptr_t](self.data, self.data_addresses, <intptr_t>int(arg), i)
continue
# TODO: support ctypes/numpy struct
Expand Down
39 changes: 39 additions & 0 deletions cuda_core/docs/source/release/0.5.x-notes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
.. SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
.. SPDX-License-Identifier: Apache-2.0
.. currentmodule:: cuda.core.experimental

``cuda.core`` 0.5.x Release Notes
=================================


Highlights
----------

None.


Breaking Changes
----------------

- Python ``bool`` objects are now converted to C++ ``bool`` type when passed as kernel
arguments. Previously, they were converted to ``int``. This brings them in line
with ``ctypes.c_bool`` and ``numpy.bool_``.


New features
------------

None.


New examples
------------

None.


Fixes and enhancements
----------------------

None.
14 changes: 11 additions & 3 deletions cuda_core/tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
#
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE

import ctypes

import numpy as np
import pytest

Expand Down Expand Up @@ -41,7 +43,9 @@ def _common_kernels_conditional():
unsigned int value);
__global__ void empty_kernel() {}
__global__ void add_one(int *a) { *a += 1; }
__global__ void set_handle(cudaGraphConditionalHandle handle, int value) { cudaGraphSetConditional(handle, value); }
__global__ void set_handle(cudaGraphConditionalHandle handle, bool value) {
cudaGraphSetConditional(handle, value);
}
__global__ void loop_kernel(cudaGraphConditionalHandle handle)
{
static int count = 10;
Expand Down Expand Up @@ -216,7 +220,9 @@ def test_graph_capture_errors(init_cuda):
gb.end_building().complete()


@pytest.mark.parametrize("condition_value", [True, False])
@pytest.mark.parametrize(
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False)]
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_conditional_if(init_cuda, condition_value):
mod = _common_kernels_conditional()
Expand Down Expand Up @@ -278,7 +284,9 @@ def test_graph_conditional_if(init_cuda, condition_value):
b.close()


@pytest.mark.parametrize("condition_value", [True, False])
@pytest.mark.parametrize(
"condition_value", [True, False, ctypes.c_bool(True), ctypes.c_bool(False), np.bool_(True), np.bool_(False)]
)
@pytest.mark.skipif(tuple(int(i) for i in np.__version__.split(".")[:2]) < (2, 1), reason="need numpy 2.1.0+")
def test_graph_conditional_if_else(init_cuda, condition_value):
mod = _common_kernels_conditional()
Expand Down
Loading