diff --git a/docs/source/user_guide/block.md b/docs/source/user_guide/block.md index 3f498b19fd..fcacd30d12 100644 --- a/docs/source/user_guide/block.md +++ b/docs/source/user_guide/block.md @@ -8,21 +8,28 @@ The closely-related device-scope memory fence is documented separately in [grid] ## What's available -| Op | CUDA | AMDGPU | Vulkan | Metal | -|-------------------------------------------------|------|--------|--------|-------| -| `block.sync()` | yes | yes | yes | yes | -| `block.sync_all_nonzero(predicate)` | yes | yes\* | yes\* | yes\* | -| `block.sync_any_nonzero(predicate)` | yes | yes\* | yes\* | yes\* | -| `block.sync_count_nonzero(predicate)` | yes | yes\* | yes\* | yes\* | -| `block.mem_fence()` | yes | yes | yes | yes | -| `block.SharedArray(shape, dtype)` | yes | yes | yes | yes | -| `block.global_thread_idx()` | yes | yes | yes | yes | -| `block.thread_idx()` | yes | yes | yes | yes | +| Op | CUDA | AMDGPU | Vulkan | Metal | +|-------------------------------------------------|------|---------|--------|-------| +| `block.sync()` | yes | yes | yes | yes | +| `block.sync_all_nonzero(predicate)` | yes | yes\* | yes\* | yes\* | +| `block.sync_any_nonzero(predicate)` | yes | yes\* | yes\* | yes\* | +| `block.sync_count_nonzero(predicate)` | yes | yes\* | yes\* | yes\* | +| `block.mem_fence()` | yes | yes | yes | yes | +| `block.SharedArray(shape, dtype)` | yes | yes | yes | yes | +| `block.global_thread_idx()` | yes | yes | yes | yes | +| `block.thread_idx()` | yes | yes | yes | yes | +| `block.reduce_{add,min,max}(v, block_dim, dtype)` | yes | yes | yes | yes | +| `block.reduce_all_{add,min,max}(v, block_dim, dtype)` | yes | yes | yes | yes | +| `block.inclusive_{add,min,max}(v, block_dim, dtype)` | yes | yes | yes | yes | +| `block.exclusive_{add,min,max}(v, block_dim, ...)` | yes | yes | yes | yes | +| `block.radix_rank_match_atomic_or(...)` | yes | yes | yes | yes | Vulkan and Metal share a SPIR-V codegen path (Metal goes through MoltenVK → MSL); they are listed as separate columns because a couple of ops have Metal-specific caveats called out below. Footnoted entries are still functional, just with the limitations the footnote describes. \* On AMDGPU, Vulkan, and Metal the `block.sync_{all,any,count}_nonzero(p)` ops are *emulated* via shared memory (one shared `i32` slot + 2 block barriers + a single `atomic_add` per contributing thread) rather than a single hardware-fused barrier-with-reduction. CUDA has the fused NVPTX `barrier.cta.red.{and,or,popc}.aligned.all.sync` family of intrinsics so it stays on the fast path; the other backends do not have a direct analog (in particular, SPIR-V `OpGroupNonUniform*` only operates at subgroup scope reliably across Vulkan + Metal). All three reductions are routed through `atomic_add` rather than `atomic_or` / `atomic_and`: the latter trip a Metal-specific bug where `OpAtomicOr` on threadgroup memory silently no-ops via MoltenVK / SPIRV-Cross. The emulation is correct and portable but costs two `block.sync()`s plus one shared-memory atomic per call instead of a single barrier instruction; if you have an inner loop calling these ops millions of times, consider whether you can batch the predicate before reducing it. +`block.radix_rank_match_atomic_or` is portable across wave32 (CUDA, Vulkan-on-NVIDIA, Metal) and wave64 (AMDGPU — Quadrants pins every AMDGPU target to `+wavefrontsize64`). 
The match-mask shared-memory region picks its dtype at compile time: `i32` on wave32 (32-lane ballot fits in a single `i32`, with `subgroup.lanemask_le` and `clz` / `popcnt` on `u32`) and `i64` on wave64 (64-lane ballot needs 64 bits, with an inline u64 `lanemask_le` and `clz` / `popcnt` on `u64`). The two paths share steps 1–4 (per-subgroup histograms, column-sum upsweep, block exclusive scan, downsweep) and step 6 (publish bins + exclusive prefix); only the per-key match phase (step 5) diverges. Atomic `or` on `i64` shared memory is native on AMDGPU LDS; wave32 backends never see the `i64` path, so portability does not depend on SPIR-V / Metal supporting 64-bit threadgroup atomics. + Naming note: `block.mem_sync()` was recently renamed to `block.mem_fence()` for consistency with the project's "fence vs barrier" terminology. The old name is still available as a deprecated alias that emits `DeprecationWarning` on first use; new code should use `block.mem_fence()`. ## Barrier vs fence: the distinction that matters @@ -118,6 +125,98 @@ This is the thread's index *within its own block / workgroup*. To get the across Today only the X dimension is exposed (1-D blocks). For 2-D / 3-D blocks the calling code should compute the linear index from `block.thread_idx()` and the block-Y / Z dimensions itself, or stick to 1-D blocks (the dominant Quadrants idiom — `qd.loop_config(block_dim=N)` always sets the X extent). +### `block.reduce_{add,min,max}(value, block_dim, dtype)` + +Block-scope reductions following the standard two-stage subgroup-reduction strategy: each subgroup reduces its lanes via a `shuffle_down` tree, lane 0 of each subgroup publishes the subgroup aggregate to shared memory, then thread 0 sequentially folds the subgroup aggregates with the same operator. The result is valid in **thread 0 only**; other threads retain partial values. For the broadcast-to-every-thread variants see `block.reduce_all_{add,min,max}` below. + +Arguments: + +- `value`: per-thread input. +- `block_dim`: threads per block (compile-time `template()`). Must be a positive multiple of `subgroup.group_size()`, which resolves to 32 on CUDA / Metal / Vulkan-on-NVIDIA and 64 on AMDGPU. Passing a `block_dim` that is not a multiple of the subgroup size raises a compile-time error. +- `dtype`: scalar dtype for the inter-subgroup shared-memory staging slot; must match `value`'s type. + +The calling thread's block-local index is read internally via `block.thread_idx()`; the subgroup size is read from `subgroup.group_size()` at compile time. Neither is plumbed through as an argument. + +Cost: `log2(subgroup_size)` shuffles + 1 shared-memory write/read per subgroup + 1 `block.sync()` + `(block_dim / subgroup_size) - 1` ops on thread 0. When the block is exactly one subgroup the shared-memory path is short-circuited at compile time. + +```python +@qd.kernel +def kern(src: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)): + qd.loop_config(block_dim=128) + for i in range(N): + agg = qd.simt.block.reduce_add(src[i], 128, qd.f32) + if qd.simt.block.thread_idx() == 0: + out[i // 128] = agg +``` + +A generic `block.reduce(value, block_dim, op, dtype)` is also available for custom associative operators (e.g. bitwise ops, custom monoids). It accepts an `op: template()` `@qd.func` taking `(a, b)` and returning the same type as `value`. + +### `block.reduce_all_{add,min,max}(value, block_dim, dtype)` + +The broadcast variants of the above. 
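+For example, a minimal sketch reusing the 1-D `src` / `N` / `block_dim=128` setup from the `reduce_add` example above, normalising each element by its block's sum:
+
+```python
+@qd.kernel
+def kern(src: qd.types.ndarray(ndim=1), normed: qd.types.ndarray(ndim=1)):
+    qd.loop_config(block_dim=128)
+    for i in range(N):
+        total = qd.simt.block.reduce_all_add(src[i], 128, qd.f32)
+        normed[i] = src[i] / total  # every thread reads the same block-wide sum
+```
+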
Identical semantics, but the result is published to a one-slot `SharedArray` and read back by every thread after a second `block.sync()`. Use this when downstream code on every thread needs the block-wide aggregate (e.g. normalising each thread's value by the block sum). Cost: one extra `block.sync()` plus one shared-memory hop vs. the lane-0-only variants. The corresponding generic form is `block.reduce_all(value, block_dim, op, dtype)`. + +### `block.inclusive_{add,min,max}(value, block_dim, dtype)` + +Block-scope inclusive prefix scans via the standard two-stage subgroup-scan strategy: each subgroup does a Hillis-Steele scan via `subgroup` shuffles, the last lane of each subgroup publishes the subgroup aggregate to shared memory, then every thread sequentially folds the cross-subgroup prefix and applies its own subgroup's prefix to its scan value. **All threads receive a valid result.** After the call, thread `i` holds `op(v[0], v[1], ..., v[i])`. + +Args match `block.reduce_add` (`value, block_dim, dtype`). Cost: per-subgroup Hillis-Steele tree (`log2(subgroup_size)` shuffles) + 1 shared-memory write/read per subgroup + 1 `block.sync()` + `(block_dim / subgroup_size) - 1` ops on every thread (the cross-subgroup prefix is computed redundantly to avoid a second barrier). When the block is exactly one subgroup the shared-memory path is short-circuited at compile time. + +```python +@qd.kernel +def kern(src: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)): + qd.loop_config(block_dim=128) + for i in range(N): + out[i] = qd.simt.block.inclusive_add(src[i], 128, qd.i32) +``` + +The corresponding generic form is `block.inclusive_scan(value, block_dim, op, dtype)` for custom monoids. + +### `block.exclusive_{add,min,max}(value, block_dim[, identity], dtype)` + +Block-scope exclusive prefix scans. Same strategy and cost profile as `inclusive_*`, but each thread receives the prefix `op(v[0], ..., v[i-1])` instead — and thread 0 receives the operator's identity. + +- `exclusive_add`: identity is the additive zero; derived from `value - value` so callers do not need to pass it. After the call, thread 0 holds 0. +- `exclusive_min(..., identity, dtype)`: pass `identity` greater than or equal to every legal element of the input — typically `+∞` for floats or the dtype's maximum for integers. Thread 0 holds `identity`. There is no portable type-extreme derivable from `value` alone, so this op takes an explicit `identity` argument (mirrors `subgroup.exclusive_min`). +- `exclusive_max(..., identity, dtype)`: pass `identity` less than or equal to every legal element of the input — typically `-∞` for floats or the dtype's minimum for integers. Thread 0 holds `identity`. + +The corresponding generic form is `block.exclusive_scan(value, block_dim, op, identity, dtype)`. + +### `block.radix_rank_match_atomic_or(key, block_dim, radix_bits, bit_start, num_bits, bins, excl_prefix)` + +Block-level radix ranking via the atomic-OR match-and-count strategy (the workhorse of an SM90-style onesweep radix sort). Each thread holds one `u32` key; the function returns the key's stable rank within the block under the digit `(key >> bit_start) & ((1 << num_bits) - 1)`, and writes the per-digit count and exclusive-prefix arrays to two caller-supplied `block.SharedArray` outparams. + +Constraints (currently): + +- `block_dim` must equal `1 << radix_bits` (each digit gets exactly one thread for the per-thread bin / exclusive-prefix output). Typical configuration is `radix_bits=8, block_dim=256`. 
+- `subgroup.group_size()` must be 32 (CUDA / Metal / Vulkan-on-NVIDIA) or 64 (AMDGPU). The match path picks its ballot dtype at compile time — `i32` on wave32, `i64` on wave64 — and the function `static_assert`s this. +- One key per thread (`items_per_thread = 1`). Multi-item per thread is a future extension. +- `num_bits <= radix_bits`; `bit_start` is the offset of the digit's low bit. + +Args: + +- `key`: per-thread `u32` input. +- `block_dim`, `radix_bits`, `bit_start`, `num_bits`: all compile-time `template()`. +- `bins`: `block.SharedArray((1 << radix_bits,), qd.i32)`. After the call, `bins[d]` holds the count of keys whose digit equals `d`. +- `excl_prefix`: `block.SharedArray((1 << radix_bits,), qd.i32)`. After the call, `excl_prefix[d]` holds the exclusive prefix sum of `bins` up to digit `d`. + +The calling thread's block-local index is read internally via `block.thread_idx()`. + +Cost: 3 `block.sync()` (including the retiring barrier before return) + a handful of `subgroup.sync()` calls + 1 block exclusive scan + per-key `atomic_or` + leader-only `atomic_add` on shared memory. Shared-memory footprint at the default `radix_bits=8` configuration: 8 KiB (wave32) or 4 KiB (wave64) of `i32` for the per-subgroup offsets + a match-mask region whose dtype is wave-size-specific — 8 KiB `i32` on wave32 (8 subgroups × 256 digits × 4 B) or 8 KiB `i64` on wave64 (4 subgroups × 256 digits × 8 B). So 16 KiB total on wave32, 12 KiB total on wave64. + +```python +@qd.kernel +def kern(keys_in: qd.types.ndarray(ndim=1), ranks_out: qd.types.ndarray(ndim=1)): + qd.loop_config(block_dim=256) + for i in range(256): + bins = qd.simt.block.SharedArray((256,), qd.i32) + excl = qd.simt.block.SharedArray((256,), qd.i32) + ranks_out[i] = qd.simt.block.radix_rank_match_atomic_or( + keys_in[i], 256, 8, 0, 8, bins, excl + ) +``` + +The function inserts the necessary `block.sync()` retires before returning, so callers can read `bins` / `excl_prefix` immediately after the call without an extra barrier. + ## Related - [grid](grid.md) — the device-scope counterpart of `block.mem_fence()`. For coordination within a single block, prefer `block.mem_fence()` — it is cheaper. diff --git a/python/quadrants/lang/simt/block.py b/python/quadrants/lang/simt/block.py index 1a0f29ca74..e9b37bbc67 100644 --- a/python/quadrants/lang/simt/block.py +++ b/python/quadrants/lang/simt/block.py @@ -1,4 +1,5 @@ # type: ignore +# pyright: reportInvalidTypeForm=false, reportOperatorIssue=false, reportArgumentType=false import warnings @@ -7,8 +8,14 @@ from quadrants.lang import ops as _ops from quadrants.lang.expr import make_expr_group from quadrants.lang.kernel_impl import func as _func +from quadrants.lang.simt import subgroup as _subgroup +from quadrants.lang.simt.subgroup import _bin_add, _bin_max, _bin_min from quadrants.lang.util import quadrants_scope +from quadrants.types.annotations import template from quadrants.types.primitive_types import i32 as _i32 +from quadrants.types.primitive_types import i64 as _i64 +from quadrants.types.primitive_types import u32 as _u32 +from quadrants.types.primitive_types import u64 as _u64 def arch_uses_spv(arch): @@ -119,6 +126,598 @@ def subscript(self, *indices): ) + +# --- Block reductions ------------------------------------------------------------------------------------------------ +# +# Two-stage block reduce: each subgroup reduces its lanes via `shuffle_down`, lane 0 of every subgroup publishes the +# subgroup aggregate to shared memory, a `block.sync()` retires the publish, and thread 0 sequentially folds the +# subgroup aggregates with `op`.
Cost: `log2(subgroup_size)` shuffles + 1 shared-mem write/read per subgroup + 1 +# `block.sync` + (NUM_SUBGROUPS - 1) ops on thread 0. The subgroup size is read from `subgroup.group_size()` (a +# compile-time Python int) at the top of every block op, so callers never plumb it in. +# +# The per-subgroup step delegates to `subgroup._reduce`, the generic-op private helper that mirrors +# `subgroup.reduce_add` / `_min` / `_max` but takes a caller-supplied template operator -- so the same block skeleton +# covers add / min / max / mul / bitwise / custom monoids. + + +@_func +def reduce(value, block_dim: template(), op: template(), dtype: template()): + """Block-scope reduction under a generic associative ``op``. Result is valid in **thread 0 only**; other threads + retain partial values. Use `reduce_all` if you need the result on every thread. + + Args: + value: per-thread input. + block_dim: threads per block (template). Must be a positive multiple of ``subgroup.group_size()`` (32 on CUDA + / Metal / Vulkan-on-NVIDIA, 64 on AMDGPU). + op: ``@qd.func`` taking two values and returning the same type as ``value``; callers can plug in custom + associative monoids (bitwise ops, multiplicative, matrix-multiply, etc.) without re-implementing the + per-subgroup + shared-mem skeleton. See `reduce_add` for the standard sum specialization. + dtype: scalar dtype for the inter-subgroup shared-memory staging slot (must match ``value``'s type). + + The calling thread's block-local index is read internally via `block.thread_idx()`; the subgroup size is read from + `subgroup.group_size()` at compile time. When the block is exactly one subgroup the shared-memory path is + short-circuited at compile time and the call costs only the per-subgroup tree. + """ + SUBGROUP_SIZE = impl.static(_subgroup.group_size()) + log2_subgroup = impl.static(_subgroup.log2_group_size()) + impl.static_assert( + impl.static(block_dim % SUBGROUP_SIZE == 0 and block_dim >= SUBGROUP_SIZE), + "block.reduce: block_dim must be a positive multiple of subgroup size", + ) + NUM_SUBGROUPS = impl.static(block_dim // SUBGROUP_SIZE) + + subgroup_agg = _subgroup._reduce(value, op, log2_subgroup) + + if impl.static(NUM_SUBGROUPS == 1): + return subgroup_agg + + tid = thread_idx() + subgroup_id = tid // SUBGROUP_SIZE + lane_id = tid & impl.static(SUBGROUP_SIZE - 1) + + shared = SharedArray(impl.static((NUM_SUBGROUPS,)), dtype) + if lane_id == 0: + shared[subgroup_id] = subgroup_agg + sync() + + result = subgroup_agg + if tid == 0: + result = shared[0] + for w in impl.static(range(1, NUM_SUBGROUPS)): + result = op(result, shared[impl.static(w)]) + return result + + +@_func +def reduce_all(value, block_dim: template(), op: template(), dtype: template()): + """Block-scope reduction under a generic associative ``op``, broadcast to every thread. Costs one extra + ``block.sync()`` plus a one-slot shared-memory broadcast vs. `reduce`. See `reduce` for the operator contract. + """ + result = reduce(value, block_dim, op, dtype) + bcast = SharedArray((1,), dtype) + if thread_idx() == 0: + bcast[0] = result + sync() + return bcast[0] + + +@_func +def reduce_add(value, block_dim: template(), dtype: template()): + """Block-scope sum reduction. Result valid in **thread 0 only**. See `reduce` for the argument contract.""" + return reduce(value, block_dim, _bin_add, dtype) + + +@_func +def reduce_min(value, block_dim: template(), dtype: template()): + """Block-scope min reduction. Result valid in **thread 0 only**. 
See `reduce` for the argument contract.""" + return reduce(value, block_dim, _bin_min, dtype) + + +@_func +def reduce_max(value, block_dim: template(), dtype: template()): + """Block-scope max reduction. Result valid in **thread 0 only**. See `reduce` for the argument contract.""" + return reduce(value, block_dim, _bin_max, dtype) + + +@_func +def reduce_all_add(value, block_dim: template(), dtype: template()): + """Block-scope sum reduction with the result broadcast to every thread. See `reduce_add` for the cheaper + thread-0-only variant and `reduce` for the argument contract. + """ + return reduce_all(value, block_dim, _bin_add, dtype) + + +@_func +def reduce_all_min(value, block_dim: template(), dtype: template()): + """Block-scope min reduction broadcast to every thread. See `reduce_all_add`.""" + return reduce_all(value, block_dim, _bin_min, dtype) + + +@_func +def reduce_all_max(value, block_dim: template(), dtype: template()): + """Block-scope max reduction broadcast to every thread. See `reduce_all_add`.""" + return reduce_all(value, block_dim, _bin_max, dtype) + + +# --- Block scans ----------------------------------------------------------------------- +# +# Two-stage block scan. Each subgroup does a Hillis-Steele scan via +# `subgroup.{_inclusive_scan, _exclusive_scan}`, the last lane of every subgroup publishes the +# subgroup aggregate to shared memory, then every thread sequentially folds the subgroup prefixes +# and applies its own subgroup's prefix to its scan value. All threads receive a valid result; +# cost: one subgroup scan + 1 shared-mem write/read per subgroup + 1 `block.sync()` + (NUM_SUBGROUPS - 1) +# ops on every thread (the cross-subgroup prefix is computed redundantly to avoid a second +# barrier). +# +# Inclusive: subgroup aggregate at the last lane is just the inclusive value, written directly. +# Exclusive: subgroup aggregate = `op(exclusive[last_lane], value[last_lane])`, since the +# exclusive scan does not include the last lane's input — we recover the inclusive total +# with one extra `op` on the publish path. + + +@_func +def inclusive_scan(value, block_dim: template(), op: template(), dtype: template()): + """Block-scope inclusive scan under a generic associative ``op``. Every thread receives a valid result. + + Args: + value: per-thread input. + block_dim: threads per block (template). Must be a positive multiple of ``subgroup.group_size()`` (32 on CUDA + / Metal / Vulkan-on-NVIDIA, 64 on AMDGPU). + op: ``@qd.func`` taking two values and returning the same type as ``value``; callers can plug in custom + associative monoids without re-implementing the per-subgroup + shared-mem skeleton. See `inclusive_add` + for the standard sum specialization. + dtype: scalar dtype for the inter-subgroup shared-memory staging slot; must match ``value``'s type. + + The calling thread's block-local index is read internally via `block.thread_idx()`; the subgroup size is read from + `subgroup.group_size()` at compile time. When the block is exactly one subgroup the cross-subgroup shared-memory path is + short-circuited at compile time and the call costs only the per-subgroup Hillis-Steele tree. 
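+    For a custom monoid, e.g. ``prefix = block.inclusive_scan(x, 128, my_xor, qd.i32)`` where ``my_xor`` is any
+    associative ``@qd.func`` taking ``(a, b)`` (``my_xor`` is an illustrative name, not a built-in).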
+ """ + SUBGROUP_SIZE = impl.static(_subgroup.group_size()) + log2_subgroup = impl.static(_subgroup.log2_group_size()) + impl.static_assert( + impl.static(block_dim % SUBGROUP_SIZE == 0 and block_dim >= SUBGROUP_SIZE), + "block.inclusive_scan: block_dim must be a positive multiple of subgroup size", + ) + NUM_SUBGROUPS = impl.static(block_dim // SUBGROUP_SIZE) + + inclusive = _subgroup._inclusive_scan(value, op, log2_subgroup) + + if impl.static(NUM_SUBGROUPS == 1): + return inclusive + + tid = thread_idx() + subgroup_id = tid // SUBGROUP_SIZE + lane_id = tid & impl.static(SUBGROUP_SIZE - 1) + + shared = SharedArray(impl.static((NUM_SUBGROUPS,)), dtype) + if lane_id == impl.static(SUBGROUP_SIZE - 1): + shared[subgroup_id] = inclusive + sync() + + # Sequential exclusive prefix scan over subgroup aggregates; each thread captures its own subgroup's prefix. Subgroup 0's + # prefix is unused (its inclusive value is already the prefix sum from the start of the block), so we never read + # `subgroup_prefix` on subgroup 0; the placeholder there exists only to give the variable a definite type. + block_aggregate = shared[0] + subgroup_prefix = block_aggregate + for w in impl.static(range(1, NUM_SUBGROUPS)): + if subgroup_id == impl.static(w): + subgroup_prefix = block_aggregate + addend = shared[impl.static(w)] + block_aggregate = op(block_aggregate, addend) + + if subgroup_id != 0: + inclusive = op(subgroup_prefix, inclusive) + return inclusive + + +@_func +def exclusive_scan(value, block_dim: template(), op: template(), identity, dtype: template()): + """Block-scope exclusive scan under a generic associative ``op`` with explicit ``identity``. Every thread receives + a valid result; thread 0 holds ``identity`` and thread ``i > 0`` holds ``op(v[0], ..., v[i-1])``. + + See `inclusive_scan` for the per-arg contract; in addition this op takes an explicit ``identity`` because exclusive + scan needs a definite value for thread 0 (and for the sentinel paths in `exclusive_min` / `exclusive_max`). See + `exclusive_add` for the additive specialization which derives a zero identity automatically. + """ + SUBGROUP_SIZE = impl.static(_subgroup.group_size()) + log2_subgroup = impl.static(_subgroup.log2_group_size()) + impl.static_assert( + impl.static(block_dim % SUBGROUP_SIZE == 0 and block_dim >= SUBGROUP_SIZE), + "block.exclusive_scan: block_dim must be a positive multiple of subgroup size", + ) + NUM_SUBGROUPS = impl.static(block_dim // SUBGROUP_SIZE) + + exclusive = _subgroup._exclusive_scan(value, op, identity, log2_subgroup) + + if impl.static(NUM_SUBGROUPS == 1): + return exclusive + + tid = thread_idx() + subgroup_id = tid // SUBGROUP_SIZE + lane_id = tid & impl.static(SUBGROUP_SIZE - 1) + + shared = SharedArray(impl.static((NUM_SUBGROUPS,)), dtype) + if lane_id == impl.static(SUBGROUP_SIZE - 1): + # Subgroup aggregate = inclusive at last lane = exclusive[last] + value[last] under `op`. 
+ shared[subgroup_id] = op(exclusive, value) + sync() + + block_aggregate = shared[0] + subgroup_prefix = ( + identity # subgroup 0's prefix is the identity; subsequent subgroups overwrite this in their own iteration + ) + for w in impl.static(range(1, NUM_SUBGROUPS)): + if subgroup_id == impl.static(w): + subgroup_prefix = block_aggregate + addend = shared[impl.static(w)] + block_aggregate = op(block_aggregate, addend) + + if subgroup_id != 0: + exclusive = op(subgroup_prefix, exclusive) + return exclusive + + +@_func +def inclusive_add(value, block_dim: template(), dtype: template()): + """Block-scope inclusive prefix sum. After the call, thread ``i`` holds ``v[0] + v[1] + ... + v[i]``. See + `inclusive_scan` for the argument contract. + """ + return inclusive_scan(value, block_dim, _bin_add, dtype) + + +@_func +def inclusive_min(value, block_dim: template(), dtype: template()): + """Block-scope inclusive prefix min. See `inclusive_scan` for the argument contract.""" + return inclusive_scan(value, block_dim, _bin_min, dtype) + + +@_func +def inclusive_max(value, block_dim: template(), dtype: template()): + """Block-scope inclusive prefix max. See `inclusive_scan` for the argument contract.""" + return inclusive_scan(value, block_dim, _bin_max, dtype) + + +@_func +def exclusive_add(value, block_dim: template(), dtype: template()): + """Block-scope exclusive prefix sum. After the call, thread ``i > 0`` holds ``v[0] + v[1] + ... + v[i-1]`` and + thread 0 holds the additive identity (zero, in ``value``'s dtype, derived as ``value - value``). See + `exclusive_scan` for the argument contract. + """ + return exclusive_scan(value, block_dim, _bin_add, value - value, dtype) + + +@_func +def exclusive_min(value, block_dim: template(), identity, dtype: template()): + """Block-scope exclusive prefix min. Thread 0 holds ``identity``: the caller must supply a value that is ``>=`` + every legal element of the input (typically ``+∞`` for floats, the dtype's maximum for integers). See + `subgroup.exclusive_min` for why this op alone takes an explicit identity. + """ + return exclusive_scan(value, block_dim, _bin_min, identity, dtype) + + +@_func +def exclusive_max(value, block_dim: template(), identity, dtype: template()): + """Block-scope exclusive prefix max. Thread 0 holds ``identity``: the caller must supply a value that is ``<=`` + every legal element of the input (typically ``-∞`` for floats, the dtype's minimum for integers). See + `exclusive_min`. + """ + return exclusive_scan(value, block_dim, _bin_max, identity, dtype) + + +# --- Block radix rank ------------------------------------------------------------------ +# +# Block-level radix ranking via the atomic-OR match-and-count strategy. Each thread holds a single ``u32`` key; the +# function returns the key's stable rank within the block under the digit `(key >> bit_start) & ((1 << num_bits) - 1)`, +# and writes the per-digit count and exclusive-prefix arrays to caller-supplied shared-memory outparams. +# +# The algorithm runs in six steps: +# +# 1. ComputeHistogramsSubgroup: each subgroup builds a private digit histogram in shared memory via ``atomic_add``. +# 2. ComputeOffsetsSubgroupUpsweep: every thread sums per-subgroup histograms column-wise to produce a block-wide bin count +# for digit ``= tid``, while rewriting the subgroup histogram entries into per-subgroup running exclusive prefixes. +# 3. ExclusiveSum on the per-thread bin counts — uses the block exclusive scan defined above. +# 4. 
ComputeOffsetsSubgroupDownsweep: add the block-wide exclusive prefix into every subgroup's offset entry. +# 5. ComputeRanksItem (atomic-OR match): per-subgroup match via ``atomic_or`` on a per-digit lane-mask, then leader +# (highest set lane) does a single ``atomic_add`` on the subgroup offset and broadcasts via ``subgroup.shuffle``; each +# thread's rank is ``subgroup_offset + popc(bin_mask & lanemask_le) - 1``. +# 6. Write bin count + exclusive prefix to the outparam shared arrays. +# +# Shared-memory layout (all i32, total ``2 * BLOCK_SUBGROUPS * RADIX_DIGITS`` ints, 4096 ints = 16 KiB at the default +# 8-subgroup / 256-digit configuration): +# +# subgroup_offsets / subgroup_histograms : [0, BLOCK_SUBGROUPS * RADIX_DIGITS) (union backing) +# match_masks : [BLOCK_SUBGROUPS * RADIX_DIGITS, 2 * ...) +# +# Subgroup-scope barriers use ``subgroup.sync()`` (lowers to ``__syncwarp`` on CUDA, +# ``OpControlBarrier(ScopeSubgroup, ...)`` on SPIR-V, ``s_barrier`` on AMDGPU). ``LaneMaskLe()`` (the PTX intrinsic +# that gives a lane its less-than-or-equal lane mask) is replaced by ``subgroup.lanemask_le(lane)`` from the portable +# subgroup primitives. + + +@_func +def _subgroup_sync_fence(): + """Subgroup-scope barrier + memory fence — CUDA ``__syncwarp`` semantics across every backend. + + Why both ops: on CUDA, `subgroup.sync()` already lowers to `__syncwarp` which folds in a memory fence, so the + extra `subgroup.mem_fence()` is redundant (a `__threadfence_block`). On SPIR-V, however, the codegen emits + `subgroupBarrier` as `OpControlBarrier(ScopeSubgroup, ScopeSubgroup, 0)` — i.e. with **no** memory semantics — so + a bare `subgroup.sync()` does *not* publish prior shared-memory writes to other lanes. The radix rank algorithm + relies on the `__syncwarp` invariant that, after the barrier, every lane sees every other lane's prior + `atomic_or` / `atomic_add` to shared memory; pairing the barrier with `subgroup.mem_fence()` (which emits a real + `OpMemoryBarrier(ScopeSubgroup, AcquireRelease | UniformMemory | WorkgroupMemory)`) restores that invariant. + """ + _subgroup.sync() + _subgroup.mem_fence() + + +@_func +def _radix_rank_match_atomic_or_wave32( + key, + block_dim: template(), + radix_bits: template(), + bit_start: template(), + num_bits: template(), + bins, + excl_prefix, +): + """Wave32 implementation of `radix_rank_match_atomic_or`. See the public wrapper for the contract. + + Match-mask region is ``i32``; atomic_or, ballot, clz, popcnt all operate on 32 bits. This path is taken on CUDA, + Vulkan-on-NVIDIA, and Metal — none of which require ``i64`` threadgroup atomics. + """ + SUBGROUP_THREADS = impl.static(_subgroup.group_size()) + RADIX_DIGITS = impl.static(1 << radix_bits) + BLOCK_SUBGROUPS = impl.static(block_dim // SUBGROUP_THREADS) + NUM_BITS_MASK = impl.static((1 << num_bits) - 1) + BINS_PER_LANE = impl.static(RADIX_DIGITS // SUBGROUP_THREADS) + + # ``smem_offsets`` (i32) backs the per-subgroup histograms (step 1), in-place column-sum upsweep (step 2), folded + # prefixes (step 4), and the leader's atomic_add slot (step 5). ``smem_match`` (i32) backs the per-digit ballot + # mask in step 5. These were previously unioned into a single ``i32`` SharedArray; splitting them keeps the + # offsets path independent of the match-mask width so the wave64 sibling can pick ``i64`` for its match region. 
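+    # At the default radix_bits=8 / block_dim=256 wave32 configuration each array has 8 * 256 = 2048 slots, i.e.
+    # 8 KiB of i32 apiece (16 KiB of shared memory in total).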
+ smem_offsets = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i32) + smem_match = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i32) + + tid = thread_idx() + subgroup_idx = tid // SUBGROUP_THREADS + lane = _ops.cast(_subgroup.invocation_id(), _i32) + + # Step 1: zero per-subgroup histograms and match_masks. + for b in impl.static(range(BINS_PER_LANE)): + bin_idx = lane + impl.static(b * SUBGROUP_THREADS) + smem_offsets[subgroup_idx * RADIX_DIGITS + bin_idx] = _i32(0) + smem_match[subgroup_idx * RADIX_DIGITS + bin_idx] = _i32(0) + _subgroup_sync_fence() + + # Each thread atomic-adds 1 to its subgroup's bin for ``digit``. + digit = _ops.cast(_ops.bit_and(_ops.bit_shr(key, _u32(bit_start)), _u32(NUM_BITS_MASK)), _i32) + _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _i32(1)) + + sync() # Publish per-subgroup histograms before column-sum. + + # Step 2: per-thread column sum across subgroups for digit == tid. Each thread collects the running exclusive prefix + # into ``bin_count`` while overwriting the subgroup histogram entries with their per-subgroup exclusive prefix. After the + # loop, ``bin_count`` is the block-wide total for digit == tid. + bin_count = _i32(0) + for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)): + subgroup_count = smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] + smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = bin_count + bin_count = bin_count + subgroup_count + + # Step 3: block-wide exclusive sum on the per-thread bin counts. + exclusive_digit_prefix = exclusive_add(bin_count, block_dim, _i32) + + # Step 4: ComputeOffsetsSubgroupDownsweep — fold the block-wide exclusive prefix into every subgroup's offset. + for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)): + smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = ( + smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] + exclusive_digit_prefix + ) + + sync() # Publish subgroup offsets before the per-key match phase. + + # Step 5: per-key atomic-OR match. ``items_per_thread == 1``, so this runs once per thread. + lane_mask = _i32(1) << lane + lane_mask_le_v = _subgroup.lanemask_le(_subgroup.invocation_id()) + + match_idx = subgroup_idx * RADIX_DIGITS + digit + + # Every thread ORs its lane_mask into the per-digit match mask of its subgroup. Threads with the same digit collide + # on the same shared-memory cell and produce a bitmask of "lanes in this subgroup that share this digit". + _ops.atomic_or(smem_match[match_idx], lane_mask) + _subgroup_sync_fence() + + # Read the bin_mask back and find the leader (highest matching lane) + intra-subgroup rank. ``clz`` here MUST run on + # the u32 (FindUMsb on SPIR-V): casting to i32 first triggers SPIR-V's FindSMsb, which for negative i32 (top bit + # set) returns the most-significant 0-bit instead of MSB-of-1, giving a leader that's one less than the actual + # highest matching lane. Concretely, with lane 31 holding the only key for its digit, bin_mask = 0x80000000; + # FindSMsb on -2147483648 returns 30 (highest 0-bit), so 31 - 30 = 1 elects lane 1 instead of lane 31, and lane + # 31's shuffle reads from lane 1 (= 0) — observed as last-lane ranks off by one on Vulkan / Metal. Now that the + # subgroup layer dispatches FindUMsb for unsigned ``clz``, passing the u32 directly emits the right intrinsic on + # every backend. 
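+    # leader = the highest matching lane (index of the ballot's most-significant set bit); popc counts the matching
+    # lanes at or below this lane, so `popc - 1` is this lane's 0-based intra-subgroup rank for this digit.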
+ bin_mask = _ops.cast(smem_match[match_idx], _u32) + leader = _i32(31) - _ops.cast(_ops.clz(bin_mask), _i32) + popc = _ops.popcnt(_ops.bit_and(bin_mask, lane_mask_le_v)) + + # Leader claims `popc` slots from this subgroup's slice of the subgroup_offsets entry. + subgroup_offset = _i32(0) + if lane == leader: + subgroup_offset = _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _ops.cast(popc, _i32)) + + # Leader broadcasts its claimed offset to every lane in the subgroup. + subgroup_offset = _subgroup.shuffle(subgroup_offset, _ops.cast(leader, _u32)) + + # Leader resets the match mask so subsequent passes (or items_per_thread > 1) start clean. + if lane == leader: + smem_match[match_idx] = _i32(0) + _subgroup_sync_fence() + + rank = subgroup_offset + _ops.cast(popc, _i32) - _i32(1) + + # Step 6: publish bins + exclusive_digit_prefix to the caller-supplied outparams. ``block_dim == RADIX_DIGITS`` so + # every thread writes exactly one digit. Followed by a ``block.sync()`` so the caller can read these arrays + # without having to add their own retiring barrier. + bins[tid] = bin_count + excl_prefix[tid] = exclusive_digit_prefix + sync() + + return rank + + +@_func +def _radix_rank_match_atomic_or_wave64( + key, + block_dim: template(), + radix_bits: template(), + bit_start: template(), + num_bits: template(), + bins, + excl_prefix, +): + """Wave64 implementation of `radix_rank_match_atomic_or`. See the public wrapper for the contract. + + Match-mask region is ``i64``; atomic_or on shared ``i64`` is native on AMDGPU LDS. Subgroup ``lanemask_le`` is + u32-only by contract (see ``subgroup.py``: "lane_id in [0, 31]"), so the 64-lane form is synthesized inline as + ``one_at_lane | (one_at_lane - 1)`` — avoids the UB of shifting by 64 when lane == 63. + + Structural twin of the wave32 path; duplicated rather than parameterised because Quadrants' AST transformer + doesn't carry locals across ``if impl.static`` branches and the smem_match dtype + match-phase widths are the only + things that differ. + """ + SUBGROUP_THREADS = impl.static(_subgroup.group_size()) + RADIX_DIGITS = impl.static(1 << radix_bits) + BLOCK_SUBGROUPS = impl.static(block_dim // SUBGROUP_THREADS) + NUM_BITS_MASK = impl.static((1 << num_bits) - 1) + BINS_PER_LANE = impl.static(RADIX_DIGITS // SUBGROUP_THREADS) + + smem_offsets = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i32) + smem_match = SharedArray(impl.static((BLOCK_SUBGROUPS * RADIX_DIGITS,)), _i64) + + tid = thread_idx() + subgroup_idx = tid // SUBGROUP_THREADS + lane = _ops.cast(_subgroup.invocation_id(), _i32) + + # Step 1: zero per-subgroup histograms and match_masks. 
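+    # Identical to the wave32 step 1, except the match-mask cells are 64-bit (one i64 ballot per digit per subgroup).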
+ for b in impl.static(range(BINS_PER_LANE)): + bin_idx = lane + impl.static(b * SUBGROUP_THREADS) + smem_offsets[subgroup_idx * RADIX_DIGITS + bin_idx] = _i32(0) + smem_match[subgroup_idx * RADIX_DIGITS + bin_idx] = _i64(0) + _subgroup_sync_fence() + + digit = _ops.cast(_ops.bit_and(_ops.bit_shr(key, _u32(bit_start)), _u32(NUM_BITS_MASK)), _i32) + _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _i32(1)) + + sync() + + bin_count = _i32(0) + for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)): + subgroup_count = smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] + smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = bin_count + bin_count = bin_count + subgroup_count + + exclusive_digit_prefix = exclusive_add(bin_count, block_dim, _i32) + + for j_subgroup in impl.static(range(BLOCK_SUBGROUPS)): + smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] = ( + smem_offsets[impl.static(j_subgroup * RADIX_DIGITS) + tid] + exclusive_digit_prefix + ) + + sync() + + # Step 5 — wave64 specifics: u64 ballot mask via inline ``one_at_lane | (one_at_lane - 1)`` (avoids UB on lane=63), + # atomic_or on the i64 match cell, clz / popcnt on u64. Leader formula is ``63 - clz(u64)``. + lane_u64 = _ops.cast(lane, _u64) + lane_mask = _u64(1) << lane_u64 + lane_mask_le_v = lane_mask | (lane_mask - _u64(1)) + + match_idx = subgroup_idx * RADIX_DIGITS + digit + + _ops.atomic_or(smem_match[match_idx], _ops.cast(lane_mask, _i64)) + _subgroup_sync_fence() + + # u64 clz via FindUMsb-equivalent on every backend; the wave32 path's caveat about FindSMsb vs FindUMsb on i64 + # would apply on SPIR-V wave64 devices if those existed (today wave64 = AMDGPU only). + bin_mask = _ops.cast(smem_match[match_idx], _u64) + leader = _i32(63) - _ops.cast(_ops.clz(bin_mask), _i32) + popc = _ops.popcnt(_ops.bit_and(bin_mask, lane_mask_le_v)) + + subgroup_offset = _i32(0) + if lane == leader: + subgroup_offset = _ops.atomic_add(smem_offsets[subgroup_idx * RADIX_DIGITS + digit], _ops.cast(popc, _i32)) + + subgroup_offset = _subgroup.shuffle(subgroup_offset, _ops.cast(leader, _u32)) + + if lane == leader: + smem_match[match_idx] = _i64(0) + _subgroup_sync_fence() + + rank = subgroup_offset + _ops.cast(popc, _i32) - _i32(1) + + bins[tid] = bin_count + excl_prefix[tid] = exclusive_digit_prefix + sync() + + return rank + + +@_func +def radix_rank_match_atomic_or( + key, + block_dim: template(), + radix_bits: template(), + bit_start: template(), + num_bits: template(), + bins, + excl_prefix, +): + """Block-level radix rank via the atomic-OR match-and-count strategy. + + Returns the calling thread's stable rank within the block under digit ``(key >> bit_start) & ((1 << num_bits) - 1)``. + + Args: + key: ``u32`` key, one per thread. + block_dim: threads per block (template). Must equal ``RADIX_DIGITS = 1 << radix_bits``: each digit gets exactly + one thread for the per-thread bin/excl_prefix output. + radix_bits: number of bits in the digit (template). Typical onesweep value is 8, giving 256 digits. + bit_start: starting bit of the digit (template). Used as ``key >> bit_start``. + num_bits: actual digit width in bits (template), with ``num_bits <= radix_bits``. Bits ``[bit_start, bit_start + + num_bits)`` of ``key`` are extracted. + bins: ``block.SharedArray((1 << radix_bits,), qd.i32)`` outparam. After the call, ``bins[d]`` holds the count + of keys whose digit equals ``d``. Caller is responsible for allocating this array exactly once per kernel. 
+ excl_prefix: ``block.SharedArray((1 << radix_bits,), qd.i32)`` outparam. After the call, ``excl_prefix[d]`` holds + the exclusive prefix sum of ``bins`` up to digit ``d``. Caller allocates as for ``bins``. + + The calling thread's block-local index is read internally via `block.thread_idx()`; the subgroup size is read from + `subgroup.group_size()` at compile time. Supports both wave32 (CUDA, Vulkan-on-NVIDIA, Metal) and wave64 + (AMDGPU — Quadrants pins every AMDGPU target to ``+wavefrontsize64``). Dispatches to one of two private + implementations at compile time based on subgroup size; the match-mask shared-memory region's dtype is the only + semantic difference (``i32`` on wave32, ``i64`` on wave64), but Quadrants' AST transformer doesn't carry locals + across ``if impl.static`` branches so the two paths are written as separate ``@func`` bodies. Atomic ``or`` on + ``i64`` shared memory is native on AMDGPU's LDS; wave32 backends never see the ``i64`` path so portability does + not depend on SPIR-V / Metal supporting 64-bit threadgroup atomics. + + Pre/post: caller must guarantee uniform control flow on entry; the function inserts the necessary ``block.sync()`` + and ``subgroup.sync()`` retires. After the call, ``bins`` and ``excl_prefix`` are visible to every thread without a + further ``block.sync()`` (we sync internally before exit). + + Cost: ``~items_per_thread`` atomic_or + atomic_add per pass on shared memory + 3 ``block.sync()`` (including the + retiring barrier before return) + 1 block exclusive scan + ``BLOCK_SUBGROUPS`` ops per thread for the column-sum + upsweep. Shared-memory footprint at the default ``radix_bits=8``: 8 KiB (wave32) or 4 KiB (wave64) of ``i32`` for + subgroup offsets + 8 KiB ``i32`` (wave32) or 8 KiB ``i64`` (wave64) for the match-mask region — so 16 KiB total on + wave32, 12 KiB on wave64. + """ + SUBGROUP_THREADS = impl.static(_subgroup.group_size()) + impl.static_assert( + impl.static(SUBGROUP_THREADS == 32 or SUBGROUP_THREADS == 64), + "block.radix_rank_match_atomic_or: subgroup size must be 32 or 64", + ) + RADIX_DIGITS = impl.static(1 << radix_bits) + impl.static_assert( + impl.static(block_dim == RADIX_DIGITS), + "block.radix_rank_match_atomic_or: block_dim must equal RADIX_DIGITS (1 << radix_bits)", + ) + if impl.static(SUBGROUP_THREADS == 32): + return _radix_rank_match_atomic_or_wave32(key, block_dim, radix_bits, bit_start, num_bits, bins, excl_prefix) + return _radix_rank_match_atomic_or_wave64(key, block_dim, radix_bits, bit_start, num_bits, bins, excl_prefix) + + # Shared-memory emulation of CUDA's hardware-fused barrier-with-reduction ops, used on backends that lack a direct # equivalent (AMDGPU has no NVPTX `barrier.cta.red.*` analog; SPIR-V's `OpGroupNonUniform*` only operate at subgroup # scope reliably across Vulkan + Metal). diff --git a/python/quadrants/lang/simt/subgroup.py b/python/quadrants/lang/simt/subgroup.py index 3eb57d3613..8e42f46f32 100644 --- a/python/quadrants/lang/simt/subgroup.py +++ b/python/quadrants/lang/simt/subgroup.py @@ -314,6 +314,27 @@ def invocation_id(): return impl.call_internal("subgroupInvocationId", with_runtime_context=False) +@func +def _reduce(value, op: template(), log2_size: template()): + """Tree-reduce ``value`` across ``2**log2_size`` consecutive lanes via ``shuffle_down`` under a caller-supplied + binary ``op``.
Mirrors the operator-specialized public ``reduce_add`` / ``reduce_min`` / ``reduce_max`` but takes a + template operator so cross-module callers (currently ``block.reduce`` and the typed ``block.reduce_{add,min,max}``) + can compose the per-subgroup step with custom monoids without reimplementing the shuffle tree. + + Result is valid in lane 0 of each ``2**log2_size`` group; other lanes hold partial values. ``log2_size`` is a + compile-time template, so the body unrolls into ``log2_size`` shuffle+op pairs. Caller must ensure + ``2**log2_size`` does not exceed the active subgroup size on the target. + + Underscore-prefixed because the generic-op contract is fragile (``op`` must be associative and side-effect-free) + and we don't want to invite ad-hoc subgroup-scope reductions from arbitrary kernels; the typed ``reduce_{add,min, + max}`` cover the common cases. + """ + for i in impl.static(range(log2_size)): + offset = impl.static(1 << (log2_size - 1 - i)) + value = op(value, shuffle_down(value, u32(offset))) + return value + + @func def reduce_add(value, log2_size: template()): """Sum ``value`` across ``2**log2_size`` consecutive lanes via a ``shuffle_down`` tree. diff --git a/tests/python/test_atomic.py b/tests/python/test_atomic.py index 5ed3e03d52..fdff6c5de8 100644 --- a/tests/python/test_atomic.py +++ b/tests/python/test_atomic.py @@ -643,8 +643,8 @@ def kern(): assert int(f[None]) == expected -# Pins the doc claim that bitwise atomics on float dtypes raise a type error at trace time (atomics page: "Integer -# dtypes only -- passing f32 / f64 raises a type error at trace time"). Enforced by the is_integral check in +# Pins the doc claim that bitwise atomics on float dtypes raise a type error at compile time (atomics page: "Integer +# dtypes only -- passing f32 / f64 raises a type error at compile time"). Enforced by the is_integral check in # AtomicOpExpression::type_check (quadrants/ir/frontend_ir.cpp) for bit_and / bit_or / bit_xor. @pytest.mark.parametrize("op", ["and", "or", "xor"]) @pytest.mark.parametrize("dtype", [qd.f32, qd.f64]) diff --git a/tests/python/test_simt.py b/tests/python/test_simt.py index 02a5c50e67..0c0ca43f12 100644 --- a/tests/python/test_simt.py +++ b/tests/python/test_simt.py @@ -5,7 +5,7 @@ from pytest import approx import quadrants as qd -from quadrants.lang.simt import subgroup +from quadrants.lang.simt import block, subgroup from tests import test_utils @@ -811,6 +811,663 @@ def _init_field(field, n, dtype): field[i] = (i + 1) if dtype in int_dtypes else 1.0000000000001 * (i + 1) +# --- Block reduce tests ---------------------------------------------------------------- +# +# `qd.simt.block.reduce_{add,min,max}` is a two-stage block reduce: per-subgroup +# `shuffle_down` tree, lane 0 of each subgroup publishes the subgroup aggregate to shared +# memory, then thread 0 sequentially folds the subgroup aggregates. Result is valid +# in thread 0 only; the `reduce_all_*` variants broadcast it to every thread via +# one extra `block.sync()` plus a one-slot shared-memory hop. +# +# We exercise three regimes per arch by parameterizing on subgroups-per-block rather +# than absolute block_dim: 1 subgroup (single-subgroup short-circuit path — no +# shared memory, no cross-subgroup fold), 4 subgroups (multi-subgroup), 8 subgroups +# (multi-subgroup, larger). 
The host-side ``_arch_subgroup_size()`` maps to +# ``block_dim`` at test-body entry, so wave32 archs (CUDA / Metal / NVIDIA Vulkan) +# get ``[32, 128, 256]`` and wave64 (AMDGPU) gets ``[64, 256, 512]`` — both cover +# the single-subgroup short-circuit + multi-subgroup paths without skipping +# anything at collection time. Inside the kernel, the subgroup size is still read +# from ``subgroup.group_size()`` at compile time, so the same source compiles +# correctly on every backend without an API knob. + +_BLOCK_REDUCE_DTYPES = [qd.i32, qd.f32] +_BLOCK_REDUCE_SG_PER_BLOCK = [1, 4, 8] + + +def _arch_subgroup_size(): + """Return the subgroup size for the active arch (host side). + + AMDGPU is pinned to wave64 in Quadrants; every other supported arch is wave32. This is the host-side mirror of + the kernel-side ``subgroup.group_size()`` and is used by block-* tests to derive ``block_dim`` from a + subgroups-per-block parameter so each arch tests its own canonical sizes. + """ + return 64 if qd.lang.impl.current_cfg().arch == qd.amdgpu else 32 + + +def _ref_reduce_add(values): + return sum(values) + + +def _ref_reduce_min(values): + return min(values) + + +def _ref_reduce_max(values): + return max(values) + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_reduce_add(dtype, sg_per_block): + """Block sum-reduce: thread 0 of each block holds `sum(src[block_base:block_base+block_dim])`.""" + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=NUM_BLOCKS) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + agg = block.reduce_add(src[i], block_dim, dtype) + if tid == 0: + dst[i // block_dim] = agg + + _init_field(src, N, dtype) + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_reduce_add(block_vals) + if dtype == qd.i32: + assert dst[b] == expected, f"block {b}: got {dst[b]}, expected {expected}" + else: + assert abs(dst[b] - expected) < 1e-4 * abs(expected), f"block {b}: got {dst[b]}, expected {expected}" + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_reduce_min(dtype, sg_per_block): + """Block min-reduce: thread 0 of each block holds `min(src[block_base:block_base+block_dim])`.""" + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=NUM_BLOCKS) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + agg = block.reduce_min(src[i], block_dim, dtype) + if tid == 0: + dst[i // block_dim] = agg + + # Permuted (non-monotone) initialisation so the min depends on lanes other than the first / last. 
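+    # ((i * 1009) % 997) + 1 is injective over any 997 consecutive i (997 is prime), so every block's slice holds
+    # distinct values and the minimum is not pinned to a fixed lane.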
+ int_dtypes = (qd.i32, qd.i64, qd.u64) + for i in range(N): + v = ((i * 1009) % 997) + 1 # in [1, 997]; stable hash, no collisions w/ block_dim values up to 256 + src[i] = v if dtype in int_dtypes else 1.0 * v + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_reduce_min(block_vals) + if dtype == qd.i32: + assert dst[b] == expected, f"block {b}: got {dst[b]}, expected {expected}" + else: + assert abs(dst[b] - expected) < 1e-5, f"block {b}: got {dst[b]}, expected {expected}" + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_reduce_max(dtype, sg_per_block): + """Block max-reduce: thread 0 of each block holds `max(src[block_base:block_base+block_dim])`.""" + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=NUM_BLOCKS) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + tid = i % block_dim + agg = block.reduce_max(src[i], block_dim, dtype) + if tid == 0: + dst[i // block_dim] = agg + + int_dtypes = (qd.i32, qd.i64, qd.u64) + for i in range(N): + v = ((i * 1009) % 997) + 1 + src[i] = v if dtype in int_dtypes else 1.0 * v + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_reduce_max(block_vals) + if dtype == qd.i32: + assert dst[b] == expected, f"block {b}: got {dst[b]}, expected {expected}" + else: + assert abs(dst[b] - expected) < 1e-5, f"block {b}: got {dst[b]}, expected {expected}" + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_reduce_all_add(dtype, sg_per_block): + """Block sum-reduce broadcast: every thread of each block holds the block-wide sum. + + Verifies the broadcast variant by writing the per-thread output to a flat field, then asserting every thread of a + given block reads the same aggregate. 
+ """ + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + dst[i] = block.reduce_all_add(src[i], block_dim, dtype) + + _init_field(src, N, dtype) + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_reduce_add(block_vals) + for j in range(block_dim): + actual = dst[b * block_dim + j] + if dtype == qd.i32: + assert actual == expected, f"block {b} thread {j}: got {actual}, expected {expected}" + else: + assert abs(actual - expected) < 1e-4 * abs( + expected + ), f"block {b} thread {j}: got {actual}, expected {expected}" + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_reduce_all_min(dtype, sg_per_block): + """Block min-reduce broadcast: every thread reads the block-wide min.""" + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + dst[i] = block.reduce_all_min(src[i], block_dim, dtype) + + int_dtypes = (qd.i32, qd.i64, qd.u64) + for i in range(N): + v = ((i * 1009) % 997) + 1 + src[i] = v if dtype in int_dtypes else 1.0 * v + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_reduce_min(block_vals) + for j in range(block_dim): + actual = dst[b * block_dim + j] + if dtype == qd.i32: + assert actual == expected, f"block {b} thread {j}: got {actual}, expected {expected}" + else: + assert abs(actual - expected) < 1e-5, f"block {b} thread {j}: got {actual}, expected {expected}" + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_reduce_all_max(dtype, sg_per_block): + """Block max-reduce broadcast: every thread reads the block-wide max.""" + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + dst[i] = block.reduce_all_max(src[i], block_dim, dtype) + + int_dtypes = (qd.i32, qd.i64, qd.u64) + for i in range(N): + v = ((i * 1009) % 997) + 1 + src[i] = v if dtype in int_dtypes else 1.0 * v + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_reduce_max(block_vals) + for j in range(block_dim): + actual = dst[b * block_dim + j] + if dtype == qd.i32: + assert actual == expected, f"block {b} thread {j}: got {actual}, expected {expected}" + else: + assert abs(actual - expected) < 1e-5, f"block {b} thread {j}: got {actual}, expected {expected}" + + +# --- Block scan tests ------------------------------------------------------------------ +# +# `qd.simt.block.{inclusive,exclusive}_{add,min,max}` is a two-stage block scan: per-subgroup +# Hillis-Steele scan via shuffle, last lane of each subgroup publishes the subgroup aggregate to +# shared memory, then every thread sequentially folds the cross-subgroup prefix and applies its +# 
own subgroup's prefix. Every thread receives a valid result. +# +# We exercise the same three block sizes as block reduce (32 single-subgroup short-circuit, 128 +# / 256 multi-subgroup shared-mem) and assert per-thread against a sequential CPU oracle. The +# min / max tests use a permuted (non-monotone) input so the scan result genuinely depends +# on every prefix step, not just the trailing or leading element. + + +def _ref_inclusive_scan_add(values): + out = [] + acc = 0 + for v in values: + acc = acc + v + out.append(acc) + return out + + +def _ref_exclusive_scan_add(values): + out = [] + acc = 0 + for v in values: + out.append(acc) + acc = acc + v + return out + + +def _ref_inclusive_scan_op(values, op, identity): + out = [] + acc = identity + first = True + for v in values: + acc = v if first else op(acc, v) + first = False + out.append(acc) + return out + + +def _ref_exclusive_scan_op(values, op, identity): + out = [] + acc = identity + for v in values: + out.append(acc) + acc = op(acc, v) + return out + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_inclusive_add(dtype, sg_per_block): + """Block inclusive prefix sum: thread `i` holds `sum(src[block_base..i])`.""" + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + dst[i] = block.inclusive_add(src[i], block_dim, dtype) + + _init_field(src, N, dtype) + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_inclusive_scan_add(block_vals) + for j in range(block_dim): + actual = dst[b * block_dim + j] + if dtype == qd.i32: + assert actual == expected[j], f"block {b} thread {j}: got {actual}, expected {expected[j]}" + else: + assert abs(actual - expected[j]) < 1e-4 * abs( + expected[j] + 1.0 + ), f"block {b} thread {j}: got {actual}, expected {expected[j]}" + + +@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES) +@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK) +@test_utils.test(arch=qd.gpu) +def test_block_exclusive_add(dtype, sg_per_block): + """Block exclusive prefix sum: thread `i` holds `sum(src[block_base..i-1])`; thread 0 holds 0.""" + block_dim = sg_per_block * _arch_subgroup_size() + NUM_BLOCKS = 4 + N = NUM_BLOCKS * block_dim + src = qd.field(dtype=dtype, shape=N) + dst = qd.field(dtype=dtype, shape=N) + + @qd.kernel + def foo(): + qd.loop_config(block_dim=block_dim) + for i in range(N): + dst[i] = block.exclusive_add(src[i], block_dim, dtype) + + _init_field(src, N, dtype) + foo() + + for b in range(NUM_BLOCKS): + block_vals = [src[b * block_dim + j] for j in range(block_dim)] + expected = _ref_exclusive_scan_add(block_vals) + for j in range(block_dim): + actual = dst[b * block_dim + j] + if dtype == qd.i32: + assert actual == expected[j], f"block {b} thread {j}: got {actual}, expected {expected[j]}" + else: + # First thread's expected is 0; gate the relative tolerance so it doesn't blow up. 
+                tol_base = abs(expected[j]) if abs(expected[j]) > 1.0 else 1.0
+                assert (
+                    abs(actual - expected[j]) < 1e-4 * tol_base
+                ), f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+
+
+@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES)
+@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK)
+@test_utils.test(arch=qd.gpu)
+def test_block_inclusive_min(dtype, sg_per_block):
+    """Block inclusive prefix min."""
+    block_dim = sg_per_block * _arch_subgroup_size()
+    NUM_BLOCKS = 4
+    N = NUM_BLOCKS * block_dim
+    src = qd.field(dtype=dtype, shape=N)
+    dst = qd.field(dtype=dtype, shape=N)
+
+    @qd.kernel
+    def foo():
+        qd.loop_config(block_dim=block_dim)
+        for i in range(N):
+            dst[i] = block.inclusive_min(src[i], block_dim, dtype)
+
+    int_dtypes = (qd.i32, qd.i64, qd.u64)
+    for i in range(N):
+        v = ((i * 1009) % 997) + 1
+        src[i] = v if dtype in int_dtypes else 1.0 * v
+    foo()
+
+    py_min = lambda a, b: a if a < b else b  # noqa: E731 (intentional 1-line lambda for ref oracle)
+    for b in range(NUM_BLOCKS):
+        block_vals = [src[b * block_dim + j] for j in range(block_dim)]
+        expected = _ref_inclusive_scan_op(block_vals, py_min, 0)
+        for j in range(block_dim):
+            actual = dst[b * block_dim + j]
+            if dtype == qd.i32:
+                assert actual == expected[j], f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+            else:
+                assert abs(actual - expected[j]) < 1e-5, f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+
+
+@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES)
+@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK)
+@test_utils.test(arch=qd.gpu)
+def test_block_inclusive_max(dtype, sg_per_block):
+    """Block inclusive prefix max."""
+    block_dim = sg_per_block * _arch_subgroup_size()
+    NUM_BLOCKS = 4
+    N = NUM_BLOCKS * block_dim
+    src = qd.field(dtype=dtype, shape=N)
+    dst = qd.field(dtype=dtype, shape=N)
+
+    @qd.kernel
+    def foo():
+        qd.loop_config(block_dim=block_dim)
+        for i in range(N):
+            dst[i] = block.inclusive_max(src[i], block_dim, dtype)
+
+    int_dtypes = (qd.i32, qd.i64, qd.u64)
+    for i in range(N):
+        v = ((i * 1009) % 997) + 1
+        src[i] = v if dtype in int_dtypes else 1.0 * v
+    foo()
+
+    py_max = lambda a, b: a if a > b else b  # noqa: E731
+    for b in range(NUM_BLOCKS):
+        block_vals = [src[b * block_dim + j] for j in range(block_dim)]
+        expected = _ref_inclusive_scan_op(block_vals, py_max, 0)
+        for j in range(block_dim):
+            actual = dst[b * block_dim + j]
+            if dtype == qd.i32:
+                assert actual == expected[j], f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+            else:
+                assert abs(actual - expected[j]) < 1e-5, f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+
+
+@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES)
+@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK)
+@test_utils.test(arch=qd.gpu)
+def test_block_exclusive_min(dtype, sg_per_block):
+    """Block exclusive prefix min; thread 0 holds the supplied identity."""
+    block_dim = sg_per_block * _arch_subgroup_size()
+    NUM_BLOCKS = 4
+    N = NUM_BLOCKS * block_dim
+    src = qd.field(dtype=dtype, shape=N)
+    dst = qd.field(dtype=dtype, shape=N)
+
+    SENTINEL_INT = 1_000_000  # > every value we initialise (max is ~997 from the permuted hash)
+    SENTINEL_FLOAT = 1e9
+
+    @qd.kernel
+    def foo():
+        qd.loop_config(block_dim=block_dim)
+        for i in range(N):
+            if dtype == qd.i32:
+                dst[i] = block.exclusive_min(src[i], block_dim, SENTINEL_INT, dtype)
+            else:
+                dst[i] = block.exclusive_min(src[i], block_dim, SENTINEL_FLOAT, dtype)
+
+    int_dtypes = (qd.i32, qd.i64, qd.u64)
+    for i in range(N):
+        v = ((i * 1009) % 997) + 1
+        src[i] = v if dtype in int_dtypes else 1.0 * v
+    foo()
+
+    sentinel = SENTINEL_INT if dtype == qd.i32 else SENTINEL_FLOAT
+    py_min = lambda a, b: a if a < b else b  # noqa: E731
+    for b in range(NUM_BLOCKS):
+        block_vals = [src[b * block_dim + j] for j in range(block_dim)]
+        expected = _ref_exclusive_scan_op(block_vals, py_min, sentinel)
+        for j in range(block_dim):
+            actual = dst[b * block_dim + j]
+            if dtype == qd.i32:
+                assert actual == expected[j], f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+            else:
+                assert abs(actual - expected[j]) < 1e-5, f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+
+
+# --- Block radix rank tests ------------------------------------------------------------
+#
+# `qd.simt.block.radix_rank_match_atomic_or` implements the atomic-OR match-and-count
+# radix-rank strategy on top of the portable subgroup primitives (lanemask_le, sync,
+# shuffle) and the block exclusive scan defined above. Block size and digit count are
+# both 256 (one digit per thread); each thread contributes one u32 key.
+#
+# We test the algorithm end-to-end against a CPU oracle:
+#
+# - rank[i] = excl_prefix[digit[i]] + (#j < i with digit[j] == digit[i])
+# - bins[d] = count of keys whose digit equals d
+# - excl_prefix[d] = sum(bins[0..d-1])
+#
+# (A small worked example of these formulas is spelled out in a comment at the top of the
+# test body below.)
+#
+# Inputs are mixed: a low-entropy distribution that concentrates the 256 keys on 16 distinct
+# digits (16 keys each, so the leader-election + atomic_or match path actually has work to do)
+# and a uniform random distribution (covers the case where most digits see zero or one key).
+# Both distributions also probe the subgroup-level dedup logic, with multiple keys per subgroup
+# landing in the same digit bin.
+
+
+_RADIX_BITS = 8
+_RADIX_DIGITS = 1 << _RADIX_BITS  # 256
+_BLOCK_DIM_RR = _RADIX_DIGITS  # algorithm requires block_dim == RADIX_DIGITS
+
+
+def _ref_radix_rank(keys, bit_start, num_bits):
+    """CPU oracle for `block.radix_rank_match_atomic_or`.
+
+    Returns ``(ranks, bins, excl_prefix)`` over a single tile of ``len(keys)`` u32 keys. ``ranks[i]`` is the stable
+    rank of ``keys[i]`` when keys are sorted by their ``[bit_start, bit_start + num_bits)`` digit; threads with the
+    same digit are ordered by their original index.
+    """
+    n = len(keys)
+    digits_count = 1 << num_bits
+    mask = (1 << num_bits) - 1
+    digits = [(int(k) >> bit_start) & mask for k in keys]
+    bins = [0] * digits_count
+    for d in digits:
+        bins[d] += 1
+    excl_prefix = [0] * digits_count
+    for d in range(1, digits_count):
+        excl_prefix[d] = excl_prefix[d - 1] + bins[d - 1]
+    ranks = [0] * n
+    seen = [0] * digits_count
+    for i in range(n):
+        d = digits[i]
+        ranks[i] = excl_prefix[d] + seen[d]
+        seen[d] += 1
+    return ranks, bins, excl_prefix
+
+
+@pytest.mark.parametrize(
+    "key_pattern,bit_start,num_bits",
+    [
+        ("low_entropy", 0, 8),  # 16 distinct digits each appearing 16 times — heavy match path traffic
+        ("uniform", 0, 8),  # full 8-bit uniform — most digits get 1 key, some get 0
+        ("uniform_high_bits", 8, 8),  # digit drawn from bits [8, 16) — exercises bit_start > 0
+    ],
+)
+@test_utils.test(arch=qd.gpu)
+def test_block_radix_rank_match_atomic_or(key_pattern, bit_start, num_bits):
+    """End-to-end test of `block.radix_rank_match_atomic_or` against a CPU oracle.
+
+    Single block of ``RADIX_DIGITS == 256`` threads with one key each; we verify per-thread ``rank`` plus the per-digit
+    ``bins`` and ``excl_prefix`` outparams.
+    """
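+    # Worked example of the oracle's formulas (illustrative numbers only, not the test data):
+    # with num_bits = 3, bit_start = 0 and a 5-key tile whose digits are [5, 2, 5, 5, 2]:
+    #   bins        -> bins[2] = 2, bins[5] = 3, every other digit 0
+    #   excl_prefix -> excl_prefix[2] = 0, excl_prefix[5] = 2 (keys with a strictly smaller digit)
+    #   ranks       -> [2, 0, 3, 4, 1]  (digit-2 keys take ranks 0..1, digit-5 keys take 2..4,
+    #                                    ties ordered by original thread index)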
+    keys_in = qd.field(dtype=qd.u32, shape=_BLOCK_DIM_RR)
+    ranks_out = qd.field(dtype=qd.i32, shape=_BLOCK_DIM_RR)
+    bins_out = qd.field(dtype=qd.i32, shape=_RADIX_DIGITS)
+    excl_prefix_out = qd.field(dtype=qd.i32, shape=_RADIX_DIGITS)
+
+    rng = np.random.default_rng(seed=1234)
+    if key_pattern == "low_entropy":
+        # Pick 16 distinct digit values and scatter 16 copies of each across the tile, so several
+        # keys in every subgroup share a digit and the match phase has real work to do.
+        base_digits = rng.choice(_RADIX_DIGITS, size=16, replace=False)
+        keys_py = np.repeat(base_digits.astype(np.uint32), 16)
+        rng.shuffle(keys_py)
+        # Place the digit in the low 8 bits and randomise the upper bits so the keys differ from
+        # the raw digit values.
+        upper = rng.integers(0, 1 << 16, size=_BLOCK_DIM_RR, dtype=np.uint32)
+        keys_py = ((upper << np.uint32(8)) | keys_py.astype(np.uint32)).astype(np.uint32)
+    elif key_pattern == "uniform":
+        keys_py = rng.integers(0, 1 << 16, size=_BLOCK_DIM_RR, dtype=np.uint32)
+    elif key_pattern == "uniform_high_bits":
+        keys_py = rng.integers(0, 1 << 24, size=_BLOCK_DIM_RR, dtype=np.uint32)
+    else:
+        raise ValueError(key_pattern)
+
+    for i in range(_BLOCK_DIM_RR):
+        keys_in[i] = int(keys_py[i])
+
+    @qd.kernel
+    def kern():
+        qd.loop_config(block_dim=_BLOCK_DIM_RR)
+        for i in range(_BLOCK_DIM_RR):
+            tid = i % _BLOCK_DIM_RR
+            bins_smem = block.SharedArray((_RADIX_DIGITS,), qd.i32)
+            excl_smem = block.SharedArray((_RADIX_DIGITS,), qd.i32)
+            rank = block.radix_rank_match_atomic_or(
+                keys_in[i],
+                _BLOCK_DIM_RR,
+                _RADIX_BITS,
+                bit_start,
+                num_bits,
+                bins_smem,
+                excl_smem,
+            )
+            ranks_out[i] = rank
+            if tid < _RADIX_DIGITS:
+                bins_out[tid] = bins_smem[tid]
+                excl_prefix_out[tid] = excl_smem[tid]
+
+    kern()
+
+    ref_ranks, ref_bins, ref_excl = _ref_radix_rank(keys_py.tolist(), bit_start, num_bits)
+
+    actual_bins = [bins_out[d] for d in range(_RADIX_DIGITS)]
+    assert actual_bins == ref_bins, f"bins mismatch (pattern={key_pattern})"
+
+    actual_excl = [excl_prefix_out[d] for d in range(_RADIX_DIGITS)]
+    assert actual_excl == ref_excl, f"excl_prefix mismatch (pattern={key_pattern})"
+
+    actual_ranks = [ranks_out[i] for i in range(_BLOCK_DIM_RR)]
+    # Ranks must form a permutation of [0, n); check that first so a duplicated or out-of-range
+    # rank produces a clear failure before the element-wise comparison against the oracle below.
+    assert sorted(actual_ranks) == list(
+        range(_BLOCK_DIM_RR)
+    ), f"ranks not a permutation of [0, {_BLOCK_DIM_RR}) for pattern={key_pattern}"
+    assert actual_ranks == ref_ranks, f"ranks mismatch (pattern={key_pattern})"
+
+
+@pytest.mark.parametrize("dtype", _BLOCK_REDUCE_DTYPES)
+@pytest.mark.parametrize("sg_per_block", _BLOCK_REDUCE_SG_PER_BLOCK)
+@test_utils.test(arch=qd.gpu)
+def test_block_exclusive_max(dtype, sg_per_block):
+    """Block exclusive prefix max; thread 0 holds the supplied identity."""
+    block_dim = sg_per_block * _arch_subgroup_size()
+    NUM_BLOCKS = 4
+    N = NUM_BLOCKS * block_dim
+    src = qd.field(dtype=dtype, shape=N)
+    dst = qd.field(dtype=dtype, shape=N)
+
+    SENTINEL_INT = -1_000_000
+    SENTINEL_FLOAT = -1e9
+
+    @qd.kernel
+    def foo():
+        qd.loop_config(block_dim=block_dim)
+        for i in range(N):
+            if dtype == qd.i32:
+                dst[i] = block.exclusive_max(src[i], block_dim, SENTINEL_INT, dtype)
+            else:
+                dst[i] = block.exclusive_max(src[i], block_dim, SENTINEL_FLOAT, dtype)
+
+    int_dtypes = (qd.i32, qd.i64, qd.u64)
+    for i in range(N):
+        v = ((i * 1009) % 997) + 1
+        src[i] = v if dtype in int_dtypes else 1.0 * v
+    foo()
+
+    sentinel = SENTINEL_INT if dtype == qd.i32 else SENTINEL_FLOAT
+    py_max = lambda a, b: a if a > b else b  # noqa: E731
+    for b in range(NUM_BLOCKS):
+        block_vals = [src[b * block_dim + j] for j in range(block_dim)]
+        expected = _ref_exclusive_scan_op(block_vals, py_max, sentinel)
+        for j in range(block_dim):
+            actual = dst[b * block_dim + j]
+            if dtype == qd.i32:
+                assert actual == expected[j], f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+            else:
+                assert abs(actual - expected[j]) < 1e-5, f"block {b} thread {j}: got {actual}, expected {expected[j]}"
+
+
 @pytest.mark.parametrize("dtype", [qd.i32, qd.f32, qd.f64])
 @test_utils.test(arch=qd.gpu)
 def test_subgroup_shuffle_broadcast(dtype):