[EXPERIMENT] Replace numba-cuda with numba-cuda-mlir#9421
Conversation
cuda.compute is migrating its JIT/struct machinery from numba-cuda to numba-cuda-mlir (the MLIR-based successor). This first step adds the dependency and a single import surface; no behavior changes yet. - pyproject: add numba-cuda-mlir[cu12]/[cu13] to the cu12/cu13/sysctk runtime extras, alongside numba-cuda (which is still needed by _compile_op_to_llvm_bitcode on the v2/HostJIT path). Ignore numba_cuda_mlir.* in mypy like numba.*. - _mlir.py: central re-export of the numba-cuda-mlir symbols the migration uses (cuda.compile, type system, typing/lowering extension API, data models, MLIR builder + llvm/arith dialects), plus small from_numpy_dtype/as_numpy_dtype/struct_field_position helpers. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Move the user-operator compilation path off numba-cuda and onto numba-cuda-mlir (the MLIR-based successor). The gpu_struct typing/ lowering machinery still uses numba-cuda and is migrated separately. _odr_helpers.py: the void* operator wrappers are now ordinary Python device functions compiled with abi="c", instead of hand-written @intrinsic LLVM-IR codegen. A void* argument is a typed CPointer parameter (ABI-identical to void*); loads/stores are ptr[0] indexing; numba-cuda-mlir inlines the user op into the wrapper. The unused iterator advance/dereference wrappers are dropped (iterators compile their device code via C++, not numba). Stateful state is unpacked from the packed void* via a CPointer(CPointer(dtype)) view; heterogeneous state dtypes are rejected (no pure-Python int->typed-pointer cast). _compile_op_to_llvm_bitcode: numba-cuda-mlir's cuda.compile only emits ptx/ltoir, so the v2 (HostJIT) LLVM bitcode is produced by extracting LLVM IR from its internal MLIR -> LLVM translation (one step before libnvvm) and lowering that to bitcode with llvmlite. _jit.py: op compilation, return-type inference, stateful-op compilation, and the POD/pointer TypeDescriptor<->numba conversions now use numba-cuda-mlir (via the _mlir surface). Both v1 (ltoir) and v2 (bitcode) compile the same numba-cuda-mlir-jitted wrapper. _mlir.py: add compile_to_llvm_ir() encapsulating the MLIR->LLVM IR extraction. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Reimplement the gpu_struct type machinery against numba-cuda-mlir's MLIR extension API, completing the migration: _jit.py no longer imports classic numba at all (numba-cuda stays a dependency only for cuda.coop). Typing: - The struct type subclasses numba_cuda_mlir.types.Type; tuple/struct conversions still use can_convert_from + Conversion.safe. - Data model: register_model + PrimitiveModel building the backend type as an MLIR llvm.StructType.new_identified over the fields' MLIR value types (replaces numba-cuda's models.StructModel). - Field access typing uses an AttributeTemplate via typing_registry (replaces make_attribute_wrapper, which has no MLIR equivalent). - Constructor typing uses a ConcreteTemplate registered with typing_registry.register_global (replaces the numba.cuda cudadecl registry). Lowering (MLIR instead of llvmlite/cgutils): - Field getattr: lower_getattr_generic + llvm.extractvalue. - Constructor: lower() + llvm.UndefOp/insertvalue. - tuple->struct and struct->struct casts: lower_cast + llvm extract/insertvalue (a tuple value is a Python sequence of MLIR values pre-concretization). _mlir.py: export Conversion. Validated headlessly: struct construct + field access compile to LTO-IR via the same MLIR pattern. The aggregate casts are the area most in need of validation against the struct test suite. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
- test_void_ptr_wrapper_validation.py: rewrite against the new _make_wrapper_name API (the @intrinsic-era _ArgMode / _ArgSpec / _create_void_ptr_wrapper internals were removed); keep the sanitize_identifier coverage. - test_merge_sort.py: xfail the unsigned compare_op cases. The test comparator np.uint8(lhs < rhs) hits a numba-cuda-mlir bug that miscompiles unsigned integer comparison as signed; signed/float comparators are unaffected. Remove once numba-cuda-mlir fixes it. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
OverviewThis experimental PR begins migrating NVIDIA CCCL's JIT compilation and struct machinery from the classic Key ChangesNew Central Export Surface
JIT Compilation Migration
Wrapper Generation Refactoring
Dependency and Configuration
Test Infrastructure Updates
Technical DetailsCompilation PipelineThe migration from classic Numba to
Struct Type HandlingStruct types now use MLIR's native Known IssuesA merge_sort test has been xfailed for specific unsigned-comparison cases due to a known Breaking Changes
NotesThe PR introduction marks this as experimental. While the public API surface (e.g., WalkthroughThis PR migrates CCCL's JIT compilation infrastructure from direct Numba APIs to numba-cuda-mlir. The changes include a new ChangesCCCL JIT migration from Numba to numba-cuda-mlir
Possibly related issues
Suggested reviewers
Comment |
There was a problem hiding this comment.
Actionable comments posted: 5
🧹 Nitpick comments (3)
python/cuda_cccl/tests/conftest.py (1)
35-45: ⚡ Quick winsuggestion: Name extraction logic is duplicated in
tests/compute/conftest.pyline 238. Consider extractinggetattr(item, "originalname", None) or item.name.split("[")[0]into a shared helper function to ensure consistency if the logic changes.python/cuda_cccl/tests/compute/conftest.py (2)
136-229: ⚖️ Poor tradeoffsuggestion: The
_upstream_xfail_reasonfunction spans 94 lines with nested conditionals checking five distinct issues. Consider splitting into issue-specific helper functions (e.g.,_check_issue_123,_check_issue_121) to improve maintainability and testability. Each helper could return the reason string or None.
238-238: ⚡ Quick winsuggestion: Name extraction logic
getattr(item, "originalname", None) or item.name.split("[")[0]is duplicated fromtests/conftest.pyline 37. Consider consolidating into a shared helper to avoid drift.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: c3c453b3-c3dd-4683-b958-f22ef9e34b75
📒 Files selected for processing (7)
python/cuda_cccl/cuda/compute/_jit.pypython/cuda_cccl/cuda/compute/_mlir.pypython/cuda_cccl/cuda/compute/_odr_helpers.pypython/cuda_cccl/pyproject.tomlpython/cuda_cccl/tests/compute/conftest.pypython/cuda_cccl/tests/compute/test_void_ptr_wrapper_validation.pypython/cuda_cccl/tests/conftest.py
| def can_convert_from(self, typingctx, other): | ||
| if isinstance(other, types.UniTuple): | ||
| if isinstance(other, _mlir.types.UniTuple): | ||
| tuple_size = other.count | ||
| if tuple_size == len(field_types): | ||
| return Conversion.safe | ||
| return _mlir.Conversion.safe | ||
|
|
There was a problem hiding this comment.
important: Validate UniTuple element convertibility before returning Conversion.safe.
This branch accepts any same-length UniTuple, even when its element type cannot convert to every target field type. That makes heterogeneous struct casts type-check here and fail later in the cast lowering, unlike the _mlir.types.Tuple branch which already does the per-field check.
| # State arrays are passed to the (transformed) op as typed pointers; the op | ||
| # body indexes them (``state[i]``), which works on a CPointer. See | ||
| # _odr_helpers.create_stateful_op_void_ptr_wrapper for how the packed state | ||
| # void* is unpacked into one CPointer per state array. | ||
| state_dtypes = [_mlir.from_numpy_dtype(get_dtype(s)) for s in state_arrays] | ||
| state_ptr_types = [_mlir.types.CPointer(dt) for dt in state_dtypes] | ||
|
|
There was a problem hiding this comment.
important: Typing captured state arrays as CPointer silently drops array semantics.
The transformed operator body is unchanged, but each captured device array is now compiled as a bare pointer. Stateful operators that use .shape, multidimensional indexing, slicing, or any array attribute/method will stop typing/lowering because those operations are not available on CPointer. As per coding guidelines, python/cuda_cccl/**/*: Focus on Python API stability, CUDA array interoperability, memory ownership, JIT/NVRTC/nvJitLink behavior, package boundaries, user-defined operator correctness, tests, and examples.
Source: Coding guidelines
| _, return_type = _mlir.cuda.compile( | ||
| op, | ||
| all_numba_input_types, | ||
| device=True, | ||
| abi_info={"abi_name": abi_name}, | ||
| output="ltoir", | ||
| ) | ||
| # Convert return type to TypeDescriptor | ||
| output_type = cccl_types.from_numpy_dtype( | ||
| numba.np.numpy_support.as_dtype(return_type) | ||
| ) | ||
| output_type = cccl_types.from_numpy_dtype(_mlir.as_numpy_dtype(return_type)) |
There was a problem hiding this comment.
important: Stateful return-type inference no longer supports gpu_struct results.
This path converts return_type through _mlir.as_numpy_dtype, which only works for NumPy-backed scalar/POD types. The stateless path already uses _numba_type_to_type_descriptor; stateful operators returning a registered struct will now fail here or be reported as the wrong type.
| unique_state_dtypes = set(state_dtypes) | ||
| if len(unique_state_dtypes) > 1: | ||
| raise NotImplementedError( | ||
| "stateful operators that capture device arrays of differing dtypes " | ||
| f"are not supported (got {sorted(map(str, unique_state_dtypes))}); " | ||
| "all captured arrays must share a dtype" | ||
| ) | ||
| state_dtype = state_dtypes[0] | ||
|
|
||
| op_device = cuda.jit(device=True)(op) | ||
|
|
||
| def create_input_dereference_void_ptr_wrapper(deref_fn, state_ptr_type, value_type): | ||
| """Creates a wrapper function for input iterator dereference method. | ||
| # sig.args == (state_0, ..., state_{num_states-1}, input_0, ..., input_{K-1}) | ||
| input_types = list(sig.args)[num_states:] | ||
| return_type = sig.return_type | ||
|
|
||
| The wrapper takes 2 void* arguments: | ||
| - state pointer | ||
| - result pointer (function writes result here) | ||
| """ | ||
| arg_specs = [ | ||
| _ArgSpec(state_ptr_type, _ArgMode.PTR), | ||
| _ArgSpec(types.CPointer(value_type), _ArgMode.PTR), | ||
| ] | ||
| inner_sig = types.void(state_ptr_type, types.CPointer(value_type)) | ||
| return _create_void_ptr_wrapper(deref_fn, deref_fn.__name__, arg_specs, inner_sig) | ||
| wrapper_name = _make_wrapper_name(op.__name__) | ||
| input_names = [f"arg_{i}" for i in range(len(input_types))] | ||
|
|
||
| # states[j] reinterprets the j-th packed pointer as CPointer(state_dtype). | ||
| state_args = ", ".join(f"states[{j}]" for j in range(num_states)) | ||
| input_args = ", ".join(f"{name}[0]" for name in input_names) | ||
| call_args = ", ".join(a for a in (state_args, input_args) if a) | ||
| reconstruct = _is_gpu_struct_type(return_type) and _op_returns_tuple( | ||
| op_device, sig.args | ||
| ) | ||
| body, extra_namespace = _result_store_body(call_args, return_type, reconstruct) | ||
|
|
||
| wrapper_func = _build_wrapper( | ||
| wrapper_name, | ||
| ["states", *input_names, "result"], | ||
| body, | ||
| op_device, | ||
| extra_namespace, | ||
| ) | ||
|
|
||
| def create_output_dereference_void_ptr_wrapper(deref_fn, state_ptr_type, value_type): | ||
| """Creates a wrapper function for output iterator dereference method. | ||
|
|
||
| The wrapper takes 2 void* arguments: | ||
| - state pointer | ||
| - value pointer (value to write) | ||
| """ | ||
| arg_specs = [ | ||
| _ArgSpec(state_ptr_type, _ArgMode.PTR), | ||
| _ArgSpec(value_type, _ArgMode.LOAD), | ||
| ] | ||
| inner_sig = types.void(state_ptr_type, value_type) | ||
| return _create_void_ptr_wrapper(deref_fn, deref_fn.__name__, arg_specs, inner_sig) | ||
| wrapper_sig = types.void( | ||
| types.CPointer(types.CPointer(state_dtype)), | ||
| *(types.CPointer(t) for t in input_types), | ||
| types.CPointer(return_type), | ||
| ) |
There was a problem hiding this comment.
important: Heterogeneous captured state arrays become unsupported in this wrapper path. _jit._compile_stateful_op still builds one pointer type per captured array, but this code collapses the packed states blob to CPointer(CPointer(state_dtype)) and raises on mixed state_dtypes. That makes previously valid stateful ops fail as soon as they capture arrays with different element types. Preserve per-slot pointer reconstruction here instead of requiring a single shared dtype, and add a mixed-dtype captured-state regression test once this is fixed. As per coding guidelines, "Focus on Python API stability, CUDA array interoperability, memory ownership, JIT/NVRTC/nvJitLink behavior, package boundaries, user-defined operator correctness, tests, and examples."
Source: Coding guidelines
| cu12 = [ | ||
| "cuda-cccl[minimal-cu12]", | ||
| # numba / numba-cuda: used by cuda.coop (Numba-CUDA cooperative primitives). | ||
| "numba>=0.60.0", | ||
| "numba-cuda[cu12]>=0.23.0,!=0.27.*,!=0.28.*,!=0.29.*,!=0.30.0", | ||
| # numba-cuda-mlir: backend that JIT-compiles cuda.compute user operators and | ||
| # gpu_struct types (the MLIR-based successor to numba-cuda). | ||
| "numba-cuda-mlir[cu12]>=0.3.0", | ||
| ] | ||
| cu13 = [ | ||
| "cuda-cccl[minimal-cu13]", | ||
| # numba / numba-cuda: used by cuda.coop (Numba-CUDA cooperative primitives). | ||
| "numba>=0.60.0", | ||
| "numba-cuda[cu13]>=0.23.0,!=0.27.*,!=0.28.*,!=0.29.*,!=0.30.0" | ||
| "numba-cuda[cu13]>=0.23.0,!=0.27.*,!=0.28.*,!=0.29.*,!=0.30.0", | ||
| # numba-cuda-mlir: backend that JIT-compiles cuda.compute user operators and | ||
| # gpu_struct types (the MLIR-based successor to numba-cuda). | ||
| "numba-cuda-mlir[cu13]>=0.3.0", | ||
| ] | ||
| sysctk12 = [ | ||
| "cuda-cccl[minimal-sysctk12]", | ||
| "numba>=0.60.0", | ||
| "numba-cuda[cu12]>=0.23.0,!=0.27.*,!=0.28.*,!=0.29.*,!=0.30.0" | ||
| "numba-cuda[cu12]>=0.23.0,!=0.27.*,!=0.28.*,!=0.29.*,!=0.30.0", | ||
| "numba-cuda-mlir[cu12]>=0.3.0", | ||
| ] | ||
| sysctk13 = [ | ||
| "cuda-cccl[minimal-sysctk13]", | ||
| "numba>=0.60.0", | ||
| "numba-cuda[cu13]>=0.23.0,!=0.27.*,!=0.28.*,!=0.29.*,!=0.30.0", | ||
| "numba-cuda-mlir[cu13]>=0.3.0", | ||
| ] |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify numba-cuda-mlir package, version, and extras exist on PyPI
echo "=== Checking numba-cuda-mlir on PyPI ==="
curl -s https://pypi.org/pypi/numba-cuda-mlir/json | jq -r '
"Latest version: " + .info.version,
"Available versions: " + ([.releases | keys[] | select(. >= "0.3.0")] | join(", ")),
""
'
echo "=== Checking for extras in version 0.3.0 ==="
curl -s https://pypi.org/pypi/numba-cuda-mlir/0.3.0/json | jq -r '
.releases["0.3.0"][] |
select(.packagetype == "bdist_wheel") |
.filename
' | head -5
echo ""
echo "=== Checking for cu12/cu13 in package metadata ==="
curl -s https://pypi.org/pypi/numba-cuda-mlir/0.3.0/json | jq -r '
.info.requires_dist // [] |
map(select(contains("cu12") or contains("cu13"))) |
.[]
' | head -10Repository: NVIDIA/cccl
Length of output: 597
Add an upper bound for numba-cuda-mlir in cu12/cu13 extras
python/cuda_cccl/pyproject.toml currently specifies numba-cuda-mlir[cu12/cu13] >= 0.3.0 without an upper cap; add an upper bound (e.g., <0.4.0) or pin to ==0.3.0 to avoid future breaking dependency changes.
😬 CI Workflow Results🟥 Finished in 46m 07s: Pass: 29%/51 | Total: 3h 48m | Max: 36m 34sSee results here. |
Description
closes
Checklist