diff --git a/docs/source/user_guide/algorithms.md b/docs/source/user_guide/algorithms.md index 859573171f..665de83643 100644 --- a/docs/source/user_guide/algorithms.md +++ b/docs/source/user_guide/algorithms.md @@ -1,21 +1,258 @@ # Algorithms -Device-wide algorithms — primitives that consume and produce whole arrays, executed as one or more kernel launches under the hood. They sit one tier above grid-scope synchronization: they *use* block, subgroup, and grid primitives internally and expose a high-level entry point that the user calls from host (Python) code, not from inside a kernel. +Device-wide algorithms are primitives that consume and produce whole arrays, executed as one or more kernel launches under the hood. They sit one tier above block and subgroup primitives: they *use* `block.reduce`, `block.exclusive_scan`, `block.radix_rank_match_atomic_or`, and `subgroup` reductions internally, and rely on the kernel-launch boundary (plus `atomic_add` in a few places) for cross-block synchronization rather than any in-kernel grid-scope barrier. The user calls them from host (Python) code, not from inside a kernel. ## What's available -| Op | What it does | CUDA | AMDGPU | Vulkan | Metal | -|---------------------------------|---------------------------------------------|------|--------|--------|-------| -| `qd.algorithms.parallel_sort` | Odd-even merge sort (in-place, key or key-value) | yes | yes\* | yes | yes\* | -| `qd.algorithms.PrefixSumExecutor` | Inclusive in-place prefix sum (i32 only) | yes | no | yes | no | +| Op | What it does | CUDA | AMDGPU | Vulkan | Metal | +|-------------------------------------------------------------|--------------------------------------------------------------------|------|--------|--------|-------| +| `qd.algorithms.device_reduce_{add,min,max}(arr, out)` | `out[0] = sum/min/max(arr)` (two-or-more-pass tree reduction; identity derived from `arr.dtype` for min / max) | yes | yes\* | yes | yes\* | +| `qd.algorithms.device_exclusive_scan_{add,min,max}(arr, out)` | `out[i] = sum/min/max(arr[0:i])` (three-pass Blelloch-style scan; 32-bit + 64-bit scalars; identity derived from `arr.dtype` for min / max) | yes | yes\* | yes | yes\* | +| `qd.algorithms.device_select(arr, flags, out, num_out)` | Stream compaction: copy `arr[i]` to a dense prefix of `out` for every `flags[i] == 1` (`flags` must be exactly 0/1). | yes | yes\* | yes | yes\* | +| `qd.algorithms.device_radix_sort(keys, tmp_keys, values=None, tmp_values=None, end_bit=None)` | LSB radix sort for 32-bit or 64-bit scalar keys (optional key-value). | yes | yes\* | yes | yes\* | +| `qd.algorithms.device_reduce_by_key_add(keys_in, values_in, keys_out, values_out, num_runs)` | Collapse each consecutive run of equal keys into `(key, sum_of_values)`. | yes | yes\* | yes | yes\* | +| `qd.algorithms.parallel_sort` | Odd-even merge sort (in-place, key or key-value). **Deprecated**: prefer `device_radix_sort`. | yes | yes\* | yes | yes\* | +| `qd.algorithms.PrefixSumExecutor` | Inclusive in-place prefix sum (i32 only). **Deprecated**: prefer `device_exclusive_scan_add`. | yes | no | yes | no | -\* `parallel_sort` runs anywhere a Quadrants kernel runs; portability is inherited from the underlying kernel infrastructure. AMDGPU and Metal coverage is exercised less heavily than CUDA / Vulkan; report any failures. 
+\* `device_reduce_*`, `device_exclusive_scan_*`, `device_select`, `device_radix_sort`, `device_reduce_by_key_add`, and `parallel_sort` run anywhere a Quadrants kernel runs; portability is inherited from the underlying block / subgroup primitives. + +## Scratch space + +Every device-wide algorithm in this module decomposes into "per-block partial → cross-block combine → finalize" passes (tree reduction, three-pass Blelloch scan, four-pass radix sort, scan-then-scatter compaction). The per-block partials need somewhere to live between kernel launches - that buffer is called **scratch**. Rather than ask each algorithm to allocate its own (forcing a `qd.field(...)` per call and undermining the no-implicit-allocation contract of the rest of the API), `qd.algorithms` shares a single set of module-level scratch fields across every call. + +There are **two scratch fields**, one per element width that algorithm partials need to live in: + +- `Field(u32)` - used by every 4-byte algorithm: `i32` / `u32` / `f32` reduce + scan, `device_select` indices, `device_reduce_by_key_add` flags + values, `device_radix_sort` tile histograms (regardless of key width). 4-byte values are `bit_cast` to / from `u32` on the way in and out. +- `Field(u64)` - used by every 8-byte algorithm: `i64` / `u64` / `f64` reduce + scan, `u64` radix-sort keys. Same `bit_cast` story, just at 8-byte width. + +Sizing: each field defaults to **5 MB** (`DEFAULT_SCRATCH_BYTES = 5 << 20`). That covers `N` up to ~1.3M elements for `device_select` / `device_radix_sort` / `device_reduce_by_key_add` (`~N` u32 slots, qipc's hot path), and well past `N = 64M` for `device_reduce_*` / `device_exclusive_scan_*` (`~N / BLOCK_DIM` u32 slots, `BLOCK_DIM = 256`). The u64 field is sized to the same byte budget, so it covers half as many elements. + +**Allocation is lazy.** A scratch field is only allocated on its first `get_scratch_*()` call from inside an algorithm. Programs that never touch `qd.algorithms.*` pay nothing; programs that only touch 4-byte algorithms never allocate the u64 buffer. (The default budget is therefore a per-field worst case, not a fixed cost: a 4-byte-only caller pays 5 MB, not 10 MB.) + +**`qd.reset()` invalidates every scratch field** via an `impl.on_reset` hook, and resets the byte budget back to `DEFAULT_SCRATCH_BYTES`. The next algorithm call after a `qd.init()` reallocates against the fresh runtime at the default capacity. This keeps `qd.init` / `qd.reset` a "clean slate" - all runtime-scoped state (resource handles *and* config) goes away on reset, by design. Apps that need a persistent bump should call `set_scratch_bytes` immediately after each `qd.init`. + +**Tuning the budget.** Call `quadrants._scratch.set_scratch_bytes(N)` before any algorithm runs (or before any algorithm runs after a `qd.reset()`). Pass a larger value to cover bigger `N`, or a smaller value to reduce the resident footprint on memory-constrained devices: + +```python +from quadrants import _scratch +_scratch.set_scratch_bytes(20 << 20) # 20 MB; covers N up to ~5M for device_select / radix sort +``` + +`set_scratch_bytes` raises `RuntimeError` if any scratch field has already been allocated in the current runtime cycle (re-`qd.init`-ing wipes that constraint). `scratch_bytes` must be a positive multiple of 8. + +The per-algorithm sections below mention scratch only to call out per-algo footprint (so you can size the budget for a known `N`); the mechanics live here. 
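+
+One more mechanic worth showing concretely: because the budget resets with the runtime, a bump has to be re-applied after every `qd.init()`. A minimal sketch of that pattern, assuming your app centralizes runtime setup in a helper (the helper name and the 20 MB figure are illustrative, not part of the API):
+
+```python
+import quadrants as qd
+from quadrants import _scratch
+
+def init_runtime():
+    qd.init()  # whatever init arguments your app already uses
+    # Re-apply the budget on every init: qd.reset() restores DEFAULT_SCRATCH_BYTES,
+    # and set_scratch_bytes must run before the first qd.algorithms.* call.
+    _scratch.set_scratch_bytes(20 << 20)
+
+init_runtime()
+# ... run algorithms ...
+qd.reset()
+init_runtime()  # fresh runtime cycle; budget re-applied
+```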
## Semantics

+### `qd.algorithms.device_reduce_{add,min,max}(arr, out)`
+
+Device-wide tree reduction over a 1-D tensor: `out[0]` holds `sum(arr)` / `min(arr)` / `max(arr)`. The monoid identity is derived from `arr.dtype` automatically (`0` for `add`; largest representable value for `min` - `+inf` for floats, `INT{32,64}_MAX` for signed ints, `UINT{32,64}_MAX` for unsigned; smallest representable value for `max` - `-inf` for floats, `INT{32,64}_MIN` for signed ints, `0` for unsigned), mirroring the `block.reduce_min` / `subgroup.reduce_min` typed wrappers which don't take an identity for the same reason.
+
+```python
+import quadrants as qd
+
+inp = qd.field(qd.f32, shape=N)
+out = qd.field(qd.f32, shape=1)
+# ... fill inp ...
+
+qd.algorithms.device_reduce_add(inp, out=out)
+total = float(out.to_numpy()[0])  # explicit device->host hop
+```
+
+Arguments:
+
+- `arr`: 1-D input tensor. Pass a `qd.field`, `qd.ndarray`, or `qd.Tensor` wrapper around either - the kernels are polymorphic via the `qd.Tensor` annotation.
+- `out`: 1-element tensor with the same dtype as `arr`. Caller-supplied so the call is fully asynchronous - there is no implicit device→host sync. To get a Python scalar, do `out.to_numpy()[0]` explicitly after the call. This makes the host hop visible at the call site rather than hidden inside the algorithm.
+
+Constraints:
+
+- **Dtypes:** scalar `qd.i32`, `qd.u32`, `qd.f32`, `qd.i64`, `qd.u64`, `qd.f64`. Narrower / wider scalar dtypes (e.g. `qd.i16`, `qd.f16`) and struct dtypes raise `NotImplementedError`. 4-byte dtypes stage through the shared u32 scratch and 8-byte dtypes through the shared u64 scratch; see [Scratch space](#scratch-space) for the mechanics.
+- **Shape:** `arr` must be 1-D; `out.shape` must be `(1,)`. Both must share the same dtype.
+- **f32 / f64 non-associativity:** `device_reduce_add` on a floating-point dtype is not bitwise-reproducible across `N` changes, nor bitwise-equal to host `numpy.sum`. Tests tolerate a small relative error rather than asserting bitwise equality.
+
+Implementation:
+
+- Two-or-more-pass tree reduction. Each pass uses `BLOCK_DIM = 256` threads per block and reduces 256 elements per block via `block.reduce_{add,min,max}`. For `N <= 256` one pass suffices; for `N` up to `256^2 = 65536`, two passes; for larger `N`, additional intermediate passes are added until the reduction terminates in a single block.
+- Per-block partials are written to the shared scratch field (u32 for 4-byte dtypes, u64 for 8-byte dtypes; see [Scratch space](#scratch-space)).
+- The last pass writes the final value to `out[0]` directly. The kernel launches are pipelined back-to-back; correctness relies on the kernel-boundary serialization that Quadrants provides between host-launched kernels.
+
+Scratch footprint: `ceil(N / BLOCK_DIM)` slots, where `BLOCK_DIM = 256`. The 5 MB default (~1.3M u32 slots) covers `N` up to a few hundred million elements; see [Scratch space](#scratch-space) if you need a different budget.
+
+### `qd.algorithms.device_exclusive_scan_{add,min,max}(arr, out)`
+
+Device-wide exclusive prefix scan over a 1-D tensor: `out[i]` holds the reduction (`sum` / `min` / `max`) of `arr[0:i]`.
`out[0]` is always the monoid identity, which is derived from `arr.dtype` automatically (`0` for `add`; largest representable value for `min` - `+inf` for floats, `INT{32,64}_MAX` for signed ints, `UINT{32,64}_MAX` for unsigned; smallest representable value for `max` - `-inf` for floats, `INT{32,64}_MIN` for signed ints, `0` for unsigned), mirroring the `block.exclusive_min` / `subgroup.exclusive_min_tiled` typed wrappers. + +```python +import quadrants as qd + +N = 1_000_000 +inp = qd.field(qd.f32, shape=N) +out = qd.field(qd.f32, shape=N) +# ... fill inp ... + +qd.algorithms.device_exclusive_scan_add(inp, out=out) +# out[0] == 0.0; out[i] == sum(inp[0:i]) for i > 0. +``` + +Constraints: + +- **Dtypes:** scalar `qd.i32`, `qd.u32`, `qd.f32`, `qd.i64`, `qd.u64`, `qd.f64`. Narrower / wider scalar dtypes (e.g. `qd.i16`, `qd.f16`) and struct dtypes raise `NotImplementedError`. 4-byte dtypes stage through the shared u32 scratch and 8-byte dtypes through the shared u64 scratch; see [Scratch space](#scratch-space) for the mechanics. +- **Shape:** `arr` and `out` must both be 1-D with the same shape and dtype. +- **No in-place scan:** `out` must be a distinct buffer from `arr`. Calling with `out is arr` raises `ValueError`. (The kernels do not protect against same-buffer aliasing; allocating one extra buffer once is cheap relative to the scan itself.) +- **Float non-associativity:** the order of additions inside a scan tree is not the same as a left-to-right host scan, so `f32` / `f64` results are *not* bitwise-equal to `numpy.cumsum`. Tests tolerate a small relative error (scaled by dtype precision). + +Implementation: + +- Blelloch 1990 three-pass exclusive scan: + 1. **Pass 1** - per-block tile reduce into the shared scratch (one slot per block). + 2. **Pass 2** - exclusive-scan the partials buffer in place. For `N ≤ BLOCK_DIM²` (= 65536) a single block does this. For larger `N`, the driver recurses: another tile-reduce on the partials, a recursive scan, then a downsweep that applies the higher-level prefixes. + 3. **Pass 3** - per-block tile scan + add the block prefix from scratch. Each block re-reads its tile from `arr`, runs `block.exclusive_scan` to get per-thread tile prefixes, and adds its `block_prefix` from the scanned partials. +- `BLOCK_DIM = 256`. Total scratch usage at `N = 1M` is `4096 + 16 = 4112` slots (~16 KB for 4-byte dtypes, ~32 KB for 8-byte), trivial relative to the 5 MB default. See [Scratch space](#scratch-space) for budget mechanics. + +### `qd.algorithms.device_select(arr, flags, out, num_out)` + +Stream compaction. Copy every `arr[i]` whose corresponding `flags[i]` is `1` into a dense prefix of `out`, in stable input order, and write the count of selected elements to `num_out[0]`. Flags must be exactly `0` or `1` - see the constraints below. + +```python +import quadrants as qd + +N = 100_000 +inp = qd.field(qd.f32, shape=N) +flags = qd.field(qd.i32, shape=N) # caller fills with 0 / 1 +out = qd.field(qd.f32, shape=N) # large enough for worst case +num_out = qd.field(qd.i32, shape=1) + +# ... fill inp + flags via a separate kernel ... + +qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + +# Only out[0 : count] is meaningful; copy out the count host-side explicitly: +count = int(num_out.to_numpy()[0]) +selected = out.to_numpy()[:count] +``` + +Constraints: + +- **Dtypes:** `arr.dtype` is any scalar dtype in `{qd.i32, qd.u32, qd.f32, qd.i64, qd.u64, qd.f64}` *or* any `qd.types.struct(...)` / `qd.Struct.field({...})` composite (e.g. 
libuipc `Vector2i` / `Vector3i` / `Vector4i` / `LinearBVHAABB`-style structs). The scatter is `dst[idx] = src[i]`, which lowers per-field, so the algorithm is dtype-agnostic - no scratch reinterpretation needed for wider or composite element types.
+- **`flags`:** 1-D `qd.i32` tensor with the same shape as `arr`. **Every entry must be exactly `0` or `1`** (`1` selects). The algorithm prefix-sums `flags` directly as counts, so non-0/1 values produce wrong indices and a wrong `num_out` count - the caller is responsible for normalization; no implicit normalization pass is performed. `flags` is caller-built - populate it with a kernel applying whatever predicate you want, writing exactly `1` for selected and `0` otherwise.
+- **`out`:** 1-D tensor, same dtype as `arr`, with `len(out) >= len(arr)` so the worst-case all-selected run is safe. Only `out[0 : num_out[0]]` carries meaningful data on return; the tail is left untouched (whatever was in `out` before the call remains).
+- **`num_out`:** 1-element `qd.i32` tensor. Same explicit-host-hop rule: do `int(num_out.to_numpy()[0])` after the call to get the count as a Python scalar.
+
+Algorithm: the textbook scan-based compaction.
+
+1. **Exclusive scan of `flags`** into the shared u32 scratch, producing per-element write indices. Same three-pass internals as `device_exclusive_scan_add`.
+2. **Scatter:** one parallel kernel reads each `(arr[i], flags[i], indices[i])` and, if the flag is set, writes `out[indices[i]] = arr[i]`. No races by construction of the exclusive scan over 0 / 1 flags.
+3. **Count tail:** one-thread kernel computes `indices[N-1] + flags[N-1]` and stores it in `num_out[0]`.
+
+Scratch footprint: ~`N` u32 slots (one write index per input element). The default 5 MB scratch covers `N` up to ~1.3M (qipc's hot path lands here out of the box); bump the budget per [Scratch space](#scratch-space) for larger inputs.
+
+### `qd.algorithms.device_radix_sort(keys, tmp_keys, values=None, tmp_values=None, end_bit=None)`
+
+Ascending in-place radix sort over a 1-D tensor of 32-bit or 64-bit scalar keys (`u32` / `i32` / `f32` / `u64` / `i64` / `f64`), with optional lock-step permutation of a `values` tensor (key-value sort).
+
+```python
+import quadrants as qd
+
+N = 100_000
+keys = qd.field(qd.f32, shape=N)
+tmp_keys = qd.field(qd.f32, shape=N)  # workspace; contents on return are garbage
+# ... fill keys ...
+
+qd.algorithms.device_radix_sort(keys, tmp_keys=tmp_keys)
+# keys is now ascending; tmp_keys holds intermediate state.
+
+# Key-value sort:
+values = qd.field(qd.i32, shape=N)
+tmp_values = qd.field(qd.i32, shape=N)
+# ... fill values (e.g. with original indices) ...
+
+qd.algorithms.device_radix_sort(
+    keys, tmp_keys=tmp_keys, values=values, tmp_values=tmp_values,
+)
+# keys ascending; values permuted so values[k] corresponds to keys[k].
+```
+
+Arguments:
+
+- `keys`: 1-D tensor. Sorted **in place**. Pass `qd.field`, `qd.ndarray`, or `qd.Tensor`.
+- `tmp_keys`: ping-pong workspace, same shape & dtype as `keys`, distinct buffer. Contents on return are intermediate and should be considered garbage.
+- `values`: optional 1-D tensor of any supported scalar dtype (the value dtype is independent of the key dtype), same shape as `keys`. If provided, permuted in lock-step with the keys.
+- `tmp_values`: required iff `values` is provided. Same shape & dtype as `values`, distinct buffer. Same workspace semantics as `tmp_keys`.
+- `end_bit`: number of low bits of the key to consider.
Defaults to the full key width (32 for 4-byte keys, 64 for 8-byte keys). Must be a positive multiple of `8` (the radix-digit width). An even number of digit passes is required so the result lands back in `keys`; with the default `end_bit` this is automatic. Pass a smaller value when the high bits are known to be zero (e.g. `end_bit=16` for keys with values `< 2**16`) to save passes. + +Constraints: + +- **Dtypes:** `keys.dtype` and `values.dtype` are each independently one of `{qd.u32, qd.i32, qd.f32, qd.u64, qd.i64, qd.f64}`. Narrower scalar dtypes (`qd.i16`, `qd.f16`, ...) and struct dtypes raise `NotImplementedError`. 8-byte keys run 8 digit passes per sort; 4-byte keys run 4. Scratch footprint is the same for both widths (the per-tile histograms are `u32` regardless). +- **Aliasing:** `keys` and `tmp_keys` must be distinct buffers; same for `values` / `tmp_values`. Calling with the same buffer raises `ValueError`. +- **Stability:** stable sort - equal keys keep their original input order in the output. +- **NaN handling (f32):** matches `numpy.sort` (NaNs land at the end). NaNs are not tested separately and should not be relied on for ordering invariants beyond `numpy.sort`. + +Implementation: + +- Classical LSB radix sort with 8-bit digits, four passes for `u32` / `i32` / `f32`. Each digit pass is three internal kernels: + 1. **Histogram** - every block computes its per-digit count into shared memory, then publishes the 256-bin tile histogram to the shared u32 scratch (digit-major layout: `tile_histograms[d * num_blocks + b]`). + 2. **Scan** - in-place exclusive scan over the flat tile_histograms buffer. The digit-major layout makes a single 1-D scan enough to produce per-(digit, block) global offsets. + 3. **Scatter** - each block ranks its keys via `block.radix_rank_match_atomic_or` (wave32 + wave64 clean), looks up the per-(digit, block) global offset from the scan output, and scatters keys (and values, if provided) to the destination buffer. +- After each pass we swap `keys` ↔ `tmp_keys`. Four passes is even, so the sorted keys end up back in `keys`. +- Signed-integer (`i32` / `i64`) and floating-point (`f32` / `f64`) keys are mapped to a sortable unsigned representation (`u32` / `u64`) before the first pass and mapped back after the last pass via in-place "twiddle" kernels (signed: XOR sign bit; float: flip sign bit on positives, flip all bits on negatives - the standard sortable-key transform). `u32` / `u64` keys are sorted directly with no twiddle. + +Scratch footprint: `num_blocks * 256 + ...` u32 slots per pass (re-used across passes), where `num_blocks = ceil(N / 256)`. The default 5 MB scratch covers `N` up to ~1.3M (qipc's hot path lands here out of the box); bump the budget per [Scratch space](#scratch-space) for larger inputs. + +### `qd.algorithms.device_reduce_by_key_add(keys_in, values_in, keys_out, values_out, num_runs)` + +Collapse every **consecutive run of equal keys** into a single output entry `(unique_key, sum_of_values_in_run)`. Keys that compare equal but are separated by other keys form separate runs. For a global per-key sum, sort by key first (e.g. with `qd.algorithms.device_radix_sort`) and then reduce-by-key. + +```python +import quadrants as qd + +N = 100_000 +keys_in = qd.field(qd.i32, shape=N) # sorted by key beforehand +values_in = qd.field(qd.f32, shape=N) +keys_out = qd.field(qd.i32, shape=N) # capacity = N (worst case: all unique) +values_out = qd.field(qd.f32, shape=N) +num_runs = qd.field(qd.i32, shape=1) + +# ... fill keys_in + values_in ... 
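+
+# If keys_in is not already grouped by key, sort it first so equal keys form
+# consecutive runs (hypothetical tmp buffers shown; see device_radix_sort above):
+#   qd.algorithms.device_radix_sort(keys_in, tmp_keys=tmp_k, values=values_in, tmp_values=tmp_v)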
+ +qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs, +) + +count = int(num_runs.to_numpy()[0]) +uniq_k = keys_out.to_numpy()[:count] +sums = values_out.to_numpy()[:count] +``` + +Arguments: + +- `keys_in`: 1-D tensor of `u32` / `i32` / `f32`. Pass a `qd.field`, `qd.ndarray`, or `qd.Tensor`. +- `values_in`: 1-D tensor of `u32` / `i32` / `f32`, same shape as `keys_in`. +- `keys_out`: 1-D tensor of the same dtype as `keys_in`, with `len(keys_out) >= len(keys_in)` so the worst-case-all-unique run is safe. Only `keys_out[0 : num_runs[0]]` carries meaningful data on return; the tail is untouched. +- `values_out`: 1-D tensor of the same dtype as `values_in`, same length requirement. The first `num_runs[0]` slots are overwritten; the tail past that prefix is left untouched. +- `num_runs`: 1-element `qd.i32` tensor receiving the number of runs. Same explicit-host-hop rule: do `int(num_runs.to_numpy()[0])` after the call to get the count as a Python scalar. + +Constraints: + +- **Dtypes (first land):** `keys_in.dtype` and `values_in.dtype` ∈ {`qd.i32`, `qd.u32`, `qd.f32`}. Other dtypes raise `NotImplementedError`. +- **Reduction:** only `add` is exposed for first land. `min` / `max` variants need `atomic_min` / `atomic_max` for `f32`, which has spottier cross-backend support; defer to a follow-up gated on real qipc usage. +- **f32 non-associativity:** the order of additions inside a run is set by hardware atomic ordering, not host order, so `f32` results are *not* bitwise-equal to a serial scan. Tests tolerate a small relative error. +- **NaN handling (f32 keys):** `NaN != NaN` is true, so each NaN-keyed element becomes its own run. Consistent with treating NaN as "different from everything", which matches the run-length-encoding spirit. + +Algorithm: scan + scatter + atomic_add - no segmented-scan primitive needed. + +1. **Head-flag pass.** `head_flags[i] = 1` if `i == 0` or `keys[i] != keys[i-1]`, else `0`. Written to the shared u32 scratch (bit-cast from `i32`). +2. **In-place exclusive scan** of `head_flags` (using the same three-pass internals as `device_exclusive_scan_add`). After this, `scratch[i] = sum(head_flags[0:i])`. +3. **Zero-init `values_out[0:N]`.** The scatter uses `atomic_add`; slots must start at the additive identity `0`. +4. **Scatter.** For each `i`, recompute `head_flag(i)` from `keys[i]` / `keys[i-1]`, derive the run index `pos = scratch[i] + head_flag(i) - 1` (inclusive scan minus 1), and write `keys_out[pos] = keys[i]` + `atomic_add(values_out[pos], values[i])`. +5. **Count.** `num_runs[0] = scratch[N-1] + head_flag(N-1)`. + +Scratch footprint: ~`1.004 * N` u32 slots. The default 5 MB scratch covers `N` up to ~1.3M; bump the budget per [Scratch space](#scratch-space) for larger inputs. + ### `qd.algorithms.parallel_sort(keys, values=None)` -In-place sort. Reorders `keys` ascending; if `values` is provided, applies the same permutation to `values` (key-value sort). Both arguments must be 1-D `qd.field` — `parallel_sort` reaches into `snode.ptr.offset` internally, so `ndarray` is **not** supported and will fail at compile time with an `AttributeError`. +> **Deprecated.** New code should call `qd.algorithms.device_radix_sort(keys, tmp_keys, values=..., tmp_values=...)` instead. 
`device_radix_sort` is asymptotically `O(N log_radix N)` rather than `O(N log^2 N)`, is **stable** (odd-even merge sort is not), supports 32-bit and 64-bit scalar keys across CUDA / AMDGPU / Vulkan / Metal, and accepts `qd.field`, `qd.ndarray`, and `qd.Tensor` (`parallel_sort` is field-only). The only thing `parallel_sort` is competitive on is very small N (~4K and below); even there the radix path is comparable on modern hardware. To migrate, allocate a `tmp_keys` field of the same shape and dtype as `keys`, then call `device_radix_sort`. `parallel_sort` is kept for one release cycle for backward compat and will be removed thereafter. + +In-place sort. Reorders `keys` ascending; if `values` is provided, applies the same permutation to `values` (key-value sort). Both arguments must be 1-D `qd.field` - `parallel_sort` reaches into `snode.ptr.offset` internally, so `ndarray` is **not** supported and will fail at compile time with an `AttributeError`. ```python import quadrants as qd @@ -32,12 +269,14 @@ qd.algorithms.parallel_sort(keys, vals) - **Algorithm.** Batcher's odd-even merge sort. Time complexity `O(N log² N)`, work-efficient for small / mid-sized arrays. - **Key dtype.** Whatever the key field's dtype is, as long as `<` is meaningful for it (integer and float types). -- **Stability.** Odd-even merge sort is *not* a stable sort — equal keys may be reordered relative to one another. If stability matters, encode tiebreakers into the keys (e.g. pack the original index into the low bits). -- **Memory.** Strictly in-place — no auxiliary buffers from the caller's perspective. +- **Stability.** Odd-even merge sort is *not* a stable sort - equal keys may be reordered relative to one another. If stability matters, encode tiebreakers into the keys (e.g. pack the original index into the low bits). +- **Memory.** Strictly in-place - no auxiliary buffers from the caller's perspective. - **Performance characteristic.** Beats radix-style sorts for small N (roughly N ≲ 4K). ### `qd.algorithms.PrefixSumExecutor` +> **Deprecated.** New code should call `qd.algorithms.device_exclusive_scan_add(arr, out)` instead. `PrefixSumExecutor` is **inclusive**-only, **`i32`**-only, and **CUDA / Vulkan**-only; the new functional API covers `{i32, u32, f32, i64, u64, f64}` on every supported backend and runs the exclusive variant directly. To migrate from inclusive in-place to exclusive out-of-place, drop the `Executor` wrapper, allocate a distinct `out` field, and post-process if you actually need the inclusive form (`inclusive[i] = exclusive[i] + arr[i]`). `PrefixSumExecutor` is kept for one release cycle for backward compat and will be removed in a future release. + Inclusive in-place prefix sum (scan) over a 1-D `i32` field. Construct once with the array length, then call `.run(field)` to scan. ```python @@ -50,11 +289,11 @@ psum.run(arr) Constructor: -- `length: int` — the **fixed** number of elements the executor will scan on every `.run()` call. Internally allocates an auxiliary `qd.field(i32, shape=padded_length)` sized to the Kogge-Stone hierarchy (block size = 64). +- `length: int` - the **fixed** number of elements the executor will scan on every `.run()` call. Internally allocates an auxiliary `qd.field(i32, shape=padded_length)` sized to the Kogge-Stone hierarchy (block size = 64). `run(input_arr)`: -- `input_arr` must be a 1-D `qd.field(qd.i32, shape=(length,))` — its length must match the constructor's `length` exactly. 
`run()` always blits `length` elements between `input_arr` and the internal buffer; passing a shorter field results in out-of-bounds reads / writes (no runtime check today). +- `input_arr` must be a 1-D `qd.field(qd.i32, shape=(length,))` - its length must match the constructor's `length` exactly. `run()` always blits `length` elements between `input_arr` and the internal buffer; passing a shorter field results in out-of-bounds reads / writes (no runtime check today). - Returns nothing; `input_arr` is overwritten with the scan result. Constraints: @@ -73,8 +312,10 @@ No explicit fence is required between a kernel that writes the input and the sub ```python N = 1000 -keys = qd.field(qd.f32, shape=(N,)) -indices = qd.field(qd.i32, shape=(N,)) +keys = qd.field(qd.f32, shape=(N,)) +tmp_keys = qd.field(qd.f32, shape=(N,)) +indices = qd.field(qd.i32, shape=(N,)) +tmp_idx = qd.field(qd.i32, shape=(N,)) @qd.kernel def init() -> None: @@ -83,8 +324,10 @@ def init() -> None: indices[i] = i init() -qd.algorithms.parallel_sort(keys, indices) -# keys is now ascending; indices[k] is the original index of the k-th smallest key. +qd.algorithms.device_radix_sort( + keys, tmp_keys=tmp_keys, values=indices, tmp_values=tmp_idx, +) +# keys is now ascending; indices[k] is the original index of the k-th smallest key. (Stable: ties between equal keys preserve their input-order indices.) ``` ### Compact-array offsets via prefix sum @@ -116,7 +359,7 @@ The compact-output kernel reads `offsets[i]` (or `offsets[i] - flags[i]` for 0-b ## Related -- `qd.simt.block.*` — the block-scope reductions and shared-memory primitives that algorithm kernels build on. -- `qd.simt.subgroup.*` — `inclusive_add` and friends, what the per-block scan stage of `PrefixSumExecutor` actually calls. -- `qd.simt.grid.mem_fence()` — the grid-scope memory fence that decoupled-look-back scans (a more efficient alternative to Kogge-Stone) require. -- [parallelization](parallelization.md) — broader synchronization story, including how `qd.algorithms` operations compose with hand-written kernels. +- `qd.simt.block.*` - the block-scope reductions and shared-memory primitives that algorithm kernels build on. +- `qd.simt.subgroup.*` - `inclusive_add` and friends, what the per-block scan stage of `PrefixSumExecutor` actually calls. +- `qd.simt.grid.mem_fence()` - the grid-scope memory fence that decoupled-look-back scans (a more efficient alternative to Kogge-Stone) require. +- [parallelization](parallelization.md) - broader synchronization story, including how `qd.algorithms` operations compose with hand-written kernels. diff --git a/python/quadrants/_scratch.py b/python/quadrants/_scratch.py new file mode 100644 index 0000000000..e9fe67af1d --- /dev/null +++ b/python/quadrants/_scratch.py @@ -0,0 +1,105 @@ +"""Quadrants-level scratch buffer for device-wide algorithms. + +Two scratch fields - one ``Field(u32)`` and one ``Field(u64)`` - shared by every ``qd.algorithms.*`` device kernel. +Algorithms ``qd.bit_cast`` to / from these buffers to support every supported scalar dtype: 4-byte ``i32`` / ``u32`` +/ ``f32`` go through the u32 scratch; 8-byte ``i64`` / ``u64`` / ``f64`` go through the u64 scratch. Sized to +comfortably cover device-wide reduce, exclusive scan, select / compact, radix sort, and reduce-by-key on inputs up +to ``N = 1M`` out of the box (qipc's hot path), per the design doc at +``perso_hugh/doc/qipc/qipc_device_algos_design.md``. 
+ +Sizing rationale: ``device_select`` / ``device_radix_sort`` need ~``N`` u32 slots per call (one write index / +tile-histogram entry per input element). At ``N = 1M`` that is 4 MB of u32 slots; we round up to 5 MB to leave +headroom for the recursion overhead (``ceil(N / BLOCK_DIM)`` extra slots) and the second-level scan partials. +``device_reduce_*`` / ``device_exclusive_scan_*`` need only ~``N / BLOCK_DIM`` u32 slots, so the same 5 MB +covers them well past ``N = 64M``. The u64 scratch sees half as many slots at the same byte budget. + +Allocation strategy: lazy on first use, invalidated on ``qd.reset()`` via the ``impl.on_reset`` hook. This avoids +paying the 5 MB-per-width allocation cost in programs that never touch ``qd.algorithms``, and avoids coupling +``qd.init()``'s argument surface to the device-algos work for the first land. Programs that only touch 4-byte +algorithms never pay for the u64 buffer. A future change can add ``qd.init(scratch_bytes=...)`` if a caller needs +to override the default before any allocation has happened. +""" + +from quadrants.lang.impl import field, on_reset +from quadrants.types.primitive_types import u32, u64 + +DEFAULT_SCRATCH_BYTES: int = 5 * (1 << 20) + +_scratch_field = None +_scratch_field_u64 = None +_scratch_bytes: int = DEFAULT_SCRATCH_BYTES + + +def set_scratch_bytes(scratch_bytes: int) -> None: + """Set the scratch capacity in bytes for the next allocation. + + Must be called before the first ``get_scratch_u32()`` / ``get_scratch_u64()`` call in the current runtime cycle. + Has no effect on an already-allocated scratch field; users wishing to enlarge an existing scratch must + ``qd.reset()`` and ``qd.init()`` again, then re-call ``set_scratch_bytes`` (capacity resets to + ``DEFAULT_SCRATCH_BYTES`` on every ``qd.reset()``). + """ + global _scratch_bytes + if _scratch_field is not None or _scratch_field_u64 is not None: + raise RuntimeError( + "set_scratch_bytes called after scratch was already allocated; " + "call before any qd.algorithms.* op runs, or qd.reset() first" + ) + if scratch_bytes <= 0 or scratch_bytes % 8 != 0: + raise ValueError(f"scratch_bytes must be a positive multiple of 8; got {scratch_bytes}") + _scratch_bytes = scratch_bytes + + +def get_scratch_u32(): + """Return the shared scratch ``Field(u32)``, allocating on first use. + + The field is invalidated automatically by the ``impl.on_reset`` hook registered below, so a subsequent call + after ``qd.reset()`` will reallocate against the fresh runtime. + """ + global _scratch_field + if _scratch_field is None: + _scratch_field = field(u32, shape=_scratch_bytes // 4) + return _scratch_field + + +def get_scratch_u64(): + """Return the shared scratch ``Field(u64)``, allocating on first use. + + Used by 8-byte algorithms (``i64`` / ``u64`` / ``f64`` reduce + scan, ``u64`` radix-sort keys). Lives alongside + the u32 scratch rather than overlaying it: a u64 backing aliasing into u32-sized half-cells would require + dtype-punning fields, which Quadrants doesn't expose. Same byte budget, half as many slots. 
+ """ + global _scratch_field_u64 + if _scratch_field_u64 is None: + _scratch_field_u64 = field(u64, shape=_scratch_bytes // 8) + return _scratch_field_u64 + + +def scratch_capacity_u32() -> int: + """Return the scratch capacity in u32 slots for the *next* allocation.""" + return _scratch_bytes // 4 + + +def scratch_capacity_u64() -> int: + """Return the scratch capacity in u64 slots for the *next* allocation.""" + return _scratch_bytes // 8 + + +def _invalidate() -> None: + """Drop the cached scratch handles *and* reset the capacity setting back to ``DEFAULT_SCRATCH_BYTES``. Registered + as an ``impl.on_reset`` hook so every ``qd.reset()`` -> ``qd.init()`` transaction is a clean slate: the next + ``get_scratch_*()`` call reallocates against the fresh runtime at the default capacity, and any prior + ``set_scratch_bytes(...)`` bump has to be re-applied before the new runtime's first algorithm call. + + The persistence-vs-clean-slate trade-off was explicitly resolved in favour of clean slate: ``qd.init`` / + ``qd.reset`` is meant to be "free to use whenever, no constraints", which only holds if all module state tied to + a runtime cycle (resource handles *and* runtime-scoped config) goes away on reset. Apps that want a persistent + bump (or persistent shrink, for apps that know their N is small and don't want to pay 10 MB across the two + scratch fields) should call ``set_scratch_bytes`` immediately after each ``qd.init``. + """ + global _scratch_field, _scratch_field_u64, _scratch_bytes + _scratch_field = None + _scratch_field_u64 = None + _scratch_bytes = DEFAULT_SCRATCH_BYTES + + +on_reset(_invalidate) diff --git a/python/quadrants/algorithms/__init__.py b/python/quadrants/algorithms/__init__.py index 4f521c3ecf..8c2701a59e 100644 --- a/python/quadrants/algorithms/__init__.py +++ b/python/quadrants/algorithms/__init__.py @@ -1,3 +1,26 @@ # type: ignore from ._algorithms import * +from ._radix_sort import device_radix_sort +from ._reduce import device_reduce_add, device_reduce_max, device_reduce_min +from ._reduce_by_key import device_reduce_by_key_add +from ._scan import ( + device_exclusive_scan_add, + device_exclusive_scan_max, + device_exclusive_scan_min, +) +from ._select import device_select + +__all__ = [ + "PrefixSumExecutor", + "device_exclusive_scan_add", + "device_exclusive_scan_max", + "device_exclusive_scan_min", + "device_radix_sort", + "device_reduce_add", + "device_reduce_by_key_add", + "device_reduce_max", + "device_reduce_min", + "device_select", + "parallel_sort", +] diff --git a/python/quadrants/algorithms/_algorithms.py b/python/quadrants/algorithms/_algorithms.py index ea77fcf9cc..c4ae8dcf70 100644 --- a/python/quadrants/algorithms/_algorithms.py +++ b/python/quadrants/algorithms/_algorithms.py @@ -17,12 +17,28 @@ def parallel_sort(keys, values=None): - """Odd-even merge sort + """Odd-even merge sort (deprecated). + + .. deprecated:: + Prefer ``qd.algorithms.device_radix_sort(keys, *, tmp_keys, values=..., tmp_values=...)``. The new + functional API is asymptotically ``O(N log_radix N)`` rather than ``O(N log^2 N)``, supports + ``{u32, i32, f32}`` keys across CUDA / AMDGPU / Vulkan / Metal, and takes a caller-supplied tmp buffer so + the call stays fully async. ``parallel_sort`` is kept for one release cycle for backward compat and will be + removed thereafter. See ``docs/source/user_guide/algorithms.md`` for the migration recipe. 
References: https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort """ + import warnings # pylint: disable=import-outside-toplevel + + warnings.warn( + "qd.algorithms.parallel_sort is deprecated. Use " + "qd.algorithms.device_radix_sort(keys, tmp_keys=..., values=..., tmp_values=...) " + "instead. See docs/source/user_guide/algorithms.md for migration.", + DeprecationWarning, + stacklevel=2, + ) N = keys.shape[0] num_stages = 0 @@ -43,7 +59,14 @@ def parallel_sort(keys, values=None): @data_oriented class PrefixSumExecutor: - """Parallel Prefix Sum (Scan) Helper + """Parallel Prefix Sum (Scan) Helper. + + .. deprecated:: + Prefer ``qd.algorithms.device_exclusive_scan_add(arr, out)``. The new functional API supports + ``{i32, u32, f32}`` on every backend (CUDA, AMDGPU, Vulkan, Metal) and runs the exclusive variant directly. + ``PrefixSumExecutor`` is inclusive-only, ``i32``-only, and limited to CUDA / Vulkan; it is kept for one + release cycle for backward compat and will be removed thereafter. See ``docs/source/user_guide/algorithms.md`` + for the migration recipe. Use this helper to perform an inclusive in-place's parallel prefix sum. @@ -53,6 +76,18 @@ class PrefixSumExecutor: """ def __init__(self, length): + import warnings # pylint: disable=import-outside-toplevel + + warnings.warn( + "qd.algorithms.PrefixSumExecutor is deprecated. Use " + "qd.algorithms.device_exclusive_scan_add(arr, out) instead. " + "See docs/source/user_guide/algorithms.md for migration.", + DeprecationWarning, + stacklevel=2, + ) + self._init(length) + + def _init(self, length): self.sorting_length = length BLOCK_SZ = 64 diff --git a/python/quadrants/algorithms/_radix_sort.py b/python/quadrants/algorithms/_radix_sort.py new file mode 100644 index 0000000000..8540ddb0ad --- /dev/null +++ b/python/quadrants/algorithms/_radix_sort.py @@ -0,0 +1,506 @@ +# type: ignore +"""Device-wide LSB radix sort. + +Implements ``qd.algorithms.device_radix_sort`` on top of the block-tier ``block.radix_rank_match_atomic_or`` +primitive (which is wave32 + wave64 clean since ``cd9e546851``). See the design doc at +``perso_hugh/doc/qipc/qipc_device_algos_design.md`` for the broader context and the choice of *not* using single-pass +Onesweep for first land. + +Algorithm (classical histogram-scan-scatter LSB radix sort, Knuth Volume 3, Blelloch 1990 sort chapter): + +Sort proceeds digit-by-digit from the least significant byte upward. Each digit pass is three internal kernel launches: + +1. **Histogram pass** (``_radix_histogram_pass``). Every block computes its per-digit count for the current digit (8 + bits per pass at ``radix_bits=8``) into a 256-bin shared-memory histogram, then publishes it to global scratch laid + out **digit-major**: ``tile_histograms[d * num_blocks + b]``. +2. **Scan pass** (reuses ``_exclusive_scan_inplace_u32`` from ``_scan.py``). In-place exclusive scan of the flat + ``tile_histograms`` buffer. After this, ``tile_histograms[d * num_blocks + b]`` holds the global output position + of the first key in block ``b`` whose digit equals ``d``. The digit-major layout means a single 1-D scan over the + array suffices: the ordering "all digit-0 keys first, then digit-1, ..., and within each digit, in tile order" + is naturally encoded. +3. **Scatter pass** (``_radix_scatter_pass``). 
Each block re-reads its tile, computes per-thread ranks via
+   ``block.radix_rank_match_atomic_or``, looks up the per-(digit, block) global offset from the scanned
+   tile_histograms, subtracts the block-local ``excl_prefix[digit]`` to obtain the intra-digit offset, and scatters
+   ``keys_in[i] -> keys_out[offset + rank]``. Values, if provided, are scattered with the same indices.
+
+After each pass we swap (``keys_in`` ↔ ``keys_out``). Four passes for ``u32`` covering bits 0-31 - even, so the final
+result lands back in the caller's ``keys`` buffer.
+
+**Twiddle for i32 / f32.** Radix sort sorts u32 bit patterns lexicographically. To get ascending ``i32`` and
+``f32`` order, we apply the standard "sortable key" bit transforms before the first pass and inverse-transform after
+the last pass:
+
+- ``u32``: identity.
+- ``i32``: XOR sign bit (``0x80000000``) - maps two's-complement to monotone u32.
+- ``f32``: if the sign bit is clear (positive), XOR ``0x80000000``; if set (negative), XOR ``0xFFFFFFFF``. Inverse uses
+  the *output* sign bit to pick the same masks back.
+
+Both twiddle and untwiddle are in-place over ``keys``; the user's data is restored to the same dtype on return. (NaN
+handling is consistent with ``numpy.sort`` for the same input, but is not separately tested as part of first land.)
+
+**Out-of-range threads in the tail block.** When ``N % BLOCK_DIM != 0``, the final block has fewer valid keys than
+threads. Out-of-range threads participate in the rank computation with a sentinel ``u32(0xFFFFFFFF)`` key (digit
+``0xFF`` for any byte position), ensuring uniform control flow into ``block.radix_rank_match_atomic_or`` (which
+requires every thread to participate). The histogram pass gates its atomic_add behind ``i < N``, so the sentinels do
+not pollute the global histogram. The scatter pass gates its store behind ``i < N``, so the sentinels do not write
+past ``keys_out``.
+
+The ranks of valid digit-``0xFF`` keys in the tail block are unaffected by sentinels because the sentinels occupy the
+highest thread indices and the rank is computed stably by thread index.
+
+**Scratch budget.** Each digit pass uses ``num_blocks * RADIX_DIGITS = N`` (rounded up to ``BLOCK_DIM`` granularity)
+u32 slots in scratch for the tile_histograms, plus the partials buffers that the in-place exclusive scan introduces.
+Total scratch footprint: ``≈ N * (1 + 1/256) u32 slots``. The default 5 MB scratch budget covers ``N`` up to ~1.3M
+(qipc's hot path lands under this out of the box); for larger ``N`` the caller must call
+``quadrants._scratch.set_scratch_bytes(...)`` with a correspondingly larger budget before any algorithm runs. We
+raise a clear error when scratch is short rather than silently scaling.
+""" + +from quadrants._scratch import get_scratch_u32, scratch_capacity_u32 +from quadrants.lang.impl import static +from quadrants.lang.kernel_impl import kernel +from quadrants.lang.misc import loop_config +from quadrants.lang.ops import atomic_add, bit_cast +from quadrants.lang.simt import block as _block +from quadrants.lang.simt.reductions import _bin_add +from quadrants.types.annotations import template +from quadrants.types.primitive_types import f32, f64, i32, i64, u32, u64 + +from ._reduce import BLOCK_DIM, _identity_bits +from ._scan import _exclusive_scan_inplace_u32, _scan_total_scratch_slots + +_SUPPORTED_KEY_DTYPES_32 = (u32, i32, f32) +_SUPPORTED_KEY_DTYPES_64 = (u64, i64, f64) +_SUPPORTED_KEY_DTYPES = _SUPPORTED_KEY_DTYPES_32 + _SUPPORTED_KEY_DTYPES_64 +_SUPPORTED_VALUE_DTYPES = (u32, i32, f32, u64, i64, f64) + + +def _key_width_bits(dtype) -> int: + if dtype in _SUPPORTED_KEY_DTYPES_32: + return 32 + if dtype in _SUPPORTED_KEY_DTYPES_64: + return 64 + raise NotImplementedError(f"device_radix_sort key dtype {dtype} not supported") + + +RADIX_BITS = 8 +"""Bits per digit. Matches the ``block.radix_rank_match_atomic_or`` constraint that ``block_dim == 1 << radix_bits``; +with ``BLOCK_DIM = 256`` this is the only legal value.""" + +RADIX_DIGITS = 1 << RADIX_BITS # 256 + + +@kernel +def _twiddle_pass(keys: template(), N: i32, dtype: template(), do_twiddle: template()): + """In-place transform between caller-dtype keys and "sortable u32" keys. + + Set ``do_twiddle=True`` to map dtype -> u32 sort order at start of sort; ``False`` for the inverse at the end of + sort. Both directions write through ``bit_cast`` so the storage dtype is preserved. + + The two directions are encoded by the same kernel because their bodies differ only in which sign-bit (input's or + output's) selects the XOR mask - see the docstring on ``_radix_sort.py`` for the bit-twiddle table. + """ + loop_config(block_dim=BLOCK_DIM) + for i in range(N): + if static(dtype == u32): + pass + elif static(dtype == i32): + v = bit_cast(keys[i], u32) + keys[i] = bit_cast(v ^ u32(0x80000000), dtype) + else: + v = bit_cast(keys[i], u32) + if static(do_twiddle): + # f32 -> sort-u32: pick mask from *input* sign bit. + if (v & u32(0x80000000)) != u32(0): + keys[i] = bit_cast(v ^ u32(0xFFFFFFFF), dtype) + else: + keys[i] = bit_cast(v ^ u32(0x80000000), dtype) + else: + # sort-u32 -> f32: pick mask from *output* sign bit, which is the *opposite* of the sort-u32 sign + # bit (twiddle swaps them). + if (v & u32(0x80000000)) != u32(0): + keys[i] = bit_cast(v ^ u32(0x80000000), dtype) + else: + keys[i] = bit_cast(v ^ u32(0xFFFFFFFF), dtype) + + +@kernel +def _twiddle_pass_u64(keys: template(), N: i32, dtype: template(), do_twiddle: template()): + """64-bit sibling of :func:`_twiddle_pass`. + + Same monotonic-bit-pattern rules as the 32-bit case, just with the 64-bit sign bit / all-ones masks: + + - ``u64``: identity (no-op). + - ``i64``: XOR sign bit ``0x8000000000000000`` - maps two's-complement to monotone u64. + - ``f64``: positives XOR sign bit; negatives XOR all-ones. Inverse uses the *output* sign bit (same as f32). 
+ """ + loop_config(block_dim=BLOCK_DIM) + for i in range(N): + if static(dtype == u64): + pass + elif static(dtype == i64): + v = bit_cast(keys[i], u64) + keys[i] = bit_cast(v ^ u64(0x8000000000000000), dtype) + else: + v = bit_cast(keys[i], u64) + if static(do_twiddle): + if (v & u64(0x8000000000000000)) != u64(0): + keys[i] = bit_cast(v ^ u64(0xFFFFFFFFFFFFFFFF), dtype) + else: + keys[i] = bit_cast(v ^ u64(0x8000000000000000), dtype) + else: + if (v & u64(0x8000000000000000)) != u64(0): + keys[i] = bit_cast(v ^ u64(0x8000000000000000), dtype) + else: + keys[i] = bit_cast(v ^ u64(0xFFFFFFFFFFFFFFFF), dtype) + + +@kernel +def _radix_histogram_pass( + keys: template(), + tile_histograms: template(), + histograms_off: i32, + N: i32, + num_blocks: i32, + bit_start: i32, + dtype: template(), +): + """Per-block histogram of digit ``(key >> bit_start) & 0xFF``. + + Writes to ``tile_histograms[histograms_off + d * num_blocks + b]`` (digit-major layout - see module docstring on + why). + + Out-of-range threads (in the tail block when ``N % BLOCK_DIM != 0``) do not contribute to the histogram. The + shared-mem zeroing and final write-out still cover all 256 digits. + """ + loop_config(block_dim=BLOCK_DIM) + total_threads = num_blocks * BLOCK_DIM + for i in range(total_threads): + tid = i % BLOCK_DIM + block_id = i // BLOCK_DIM + hist = _block.SharedArray((RADIX_DIGITS,), i32) + if tid < RADIX_DIGITS: + hist[tid] = i32(0) + _block.sync() + if i < N: + key = bit_cast(keys[i], u32) + digit = i32((key >> u32(bit_start)) & u32(RADIX_DIGITS - 1)) + atomic_add(hist[digit], i32(1)) + _block.sync() + if tid < RADIX_DIGITS: + tile_histograms[histograms_off + tid * num_blocks + block_id] = bit_cast(hist[tid], u32) + + +@kernel +def _radix_histogram_pass_u64( + keys: template(), + tile_histograms: template(), + histograms_off: i32, + N: i32, + num_blocks: i32, + bit_start: i32, + dtype: template(), +): + """64-bit sibling of :func:`_radix_histogram_pass`. + + Same algorithm, but the digit is extracted from a 64-bit key (``(key_u64 >> bit_start) & 0xFF``). The tile + histograms are still ``u32`` slots (each digit count fits in u32 since it's bounded by ``BLOCK_DIM = 256``); only + the key dtype changes. + """ + loop_config(block_dim=BLOCK_DIM) + total_threads = num_blocks * BLOCK_DIM + for i in range(total_threads): + tid = i % BLOCK_DIM + block_id = i // BLOCK_DIM + hist = _block.SharedArray((RADIX_DIGITS,), i32) + if tid < RADIX_DIGITS: + hist[tid] = i32(0) + _block.sync() + if i < N: + key = bit_cast(keys[i], u64) + digit = i32((key >> u64(bit_start)) & u64(RADIX_DIGITS - 1)) + atomic_add(hist[digit], i32(1)) + _block.sync() + if tid < RADIX_DIGITS: + tile_histograms[histograms_off + tid * num_blocks + block_id] = bit_cast(hist[tid], u32) + + +@kernel +def _radix_scatter_pass( + keys_in: template(), + keys_out: template(), + values_in: template(), + values_out: template(), + tile_histograms: template(), + histograms_off: i32, + N: i32, + num_blocks: i32, + bit_start: i32, + dtype: template(), + value_dtype: template(), + has_values: template(), +): + """Per-block radix rank + scatter to the global output position. + + For each thread: + - Read its key (or sentinel ``0xFFFFFFFF`` if past the tail). + - Compute its block-local rank via ``block.radix_rank_match_atomic_or``, which also fills shared ``bins`` and + ``excl_prefix`` arrays. + - Compute the global destination as + ``tile_histograms[digit * num_blocks + block_id] + (rank - excl_prefix[digit])``. 
(The subtraction normalizes + ``rank`` from "position among all keys of any digit in this block" to "position among only the digit-d keys + in this block".) + - Scatter ``keys_in[i] -> keys_out[dst]`` and, if values were passed, ``values_in[i] -> values_out[dst]``. + """ + loop_config(block_dim=BLOCK_DIM) + total_threads = num_blocks * BLOCK_DIM + for i in range(total_threads): + tid = i % BLOCK_DIM + block_id = i // BLOCK_DIM + bins = _block.SharedArray((RADIX_DIGITS,), i32) + excl_prefix = _block.SharedArray((RADIX_DIGITS,), i32) + block_offsets = _block.SharedArray((RADIX_DIGITS,), i32) + key = u32(0xFFFFFFFF) + if i < N: + key = bit_cast(keys_in[i], u32) + rank = _block.radix_rank_match_atomic_or(key, BLOCK_DIM, RADIX_BITS, bit_start, RADIX_BITS, bins, excl_prefix) + digit = i32((key >> u32(bit_start)) & u32(RADIX_DIGITS - 1)) + if tid < RADIX_DIGITS: + global_off = bit_cast(tile_histograms[histograms_off + tid * num_blocks + block_id], i32) + block_offsets[tid] = global_off - excl_prefix[tid] + _block.sync() + if i < N: + dst = block_offsets[digit] + rank + keys_out[dst] = bit_cast(key, dtype) + if static(has_values): + values_out[dst] = values_in[i] + + +@kernel +def _radix_scatter_pass_u64( + keys_in: template(), + keys_out: template(), + values_in: template(), + values_out: template(), + tile_histograms: template(), + histograms_off: i32, + N: i32, + num_blocks: i32, + bit_start: i32, + dtype: template(), + value_dtype: template(), + has_values: template(), +): + """64-bit sibling of :func:`_radix_scatter_pass`. + + The block primitive ``block.radix_rank_match_atomic_or`` only looks at the 8-bit digit, so we extract the digit + from the ``u64`` key into a ``u32`` and feed it to the existing block primitive unchanged (with ``bit_start = 0`` + on the primitive side - we've already shifted). The full ``u64`` key is scattered to ``keys_out``. + """ + loop_config(block_dim=BLOCK_DIM) + total_threads = num_blocks * BLOCK_DIM + for i in range(total_threads): + tid = i % BLOCK_DIM + block_id = i // BLOCK_DIM + bins = _block.SharedArray((RADIX_DIGITS,), i32) + excl_prefix = _block.SharedArray((RADIX_DIGITS,), i32) + block_offsets = _block.SharedArray((RADIX_DIGITS,), i32) + # Sentinel ``0xFFFFFFFFFFFFFFFF`` for out-of-range threads (digit ``0xFF`` for every byte position). + key = u64(0xFFFFFFFFFFFFFFFF) + if i < N: + key = bit_cast(keys_in[i], u64) + digit_only_u32 = u32((key >> u64(bit_start)) & u64(RADIX_DIGITS - 1)) + # Feed the pre-extracted digit at ``bit_start=0`` so the u32-key block primitive sees a digit-in-low-byte + # u32 and ranks correctly without needing a u64-aware variant of itself. 
+ rank = _block.radix_rank_match_atomic_or( + digit_only_u32, BLOCK_DIM, RADIX_BITS, 0, RADIX_BITS, bins, excl_prefix + ) + digit = i32(digit_only_u32) + if tid < RADIX_DIGITS: + global_off = bit_cast(tile_histograms[histograms_off + tid * num_blocks + block_id], i32) + block_offsets[tid] = global_off - excl_prefix[tid] + _block.sync() + if i < N: + dst = block_offsets[digit] + rank + keys_out[dst] = bit_cast(key, dtype) + if static(has_values): + values_out[dst] = values_in[i] + + +def _validate_inputs(keys, tmp_keys, values, tmp_values, end_bit): + if not hasattr(keys, "shape") or len(keys.shape) != 1: + raise TypeError(f"device_radix_sort expects 1-D keys; got shape {getattr(keys, 'shape', None)}") + if not hasattr(tmp_keys, "shape") or tmp_keys.shape != keys.shape: + raise TypeError( + f"device_radix_sort expects tmp_keys.shape == keys.shape; got " + f"keys={keys.shape}, tmp_keys={tmp_keys.shape}" + ) + if tmp_keys.dtype != keys.dtype: + raise TypeError(f"device_radix_sort dtype mismatch: keys={keys.dtype}, tmp_keys={tmp_keys.dtype}") + if keys is tmp_keys: + raise ValueError("device_radix_sort requires keys and tmp_keys to be distinct buffers") + if keys.dtype not in _SUPPORTED_KEY_DTYPES: + raise NotImplementedError( + f"device_radix_sort key dtype {keys.dtype} not in first-land set " + f"{[d for d in _SUPPORTED_KEY_DTYPES]}; see design doc dtype matrix" + ) + + if (values is None) != (tmp_values is None): + raise ValueError( + "device_radix_sort: values and tmp_values must be passed together (both or neither). " + f"Got values={'provided' if values is not None else 'None'}, " + f"tmp_values={'provided' if tmp_values is not None else 'None'}" + ) + if values is not None: + if not hasattr(values, "shape") or values.shape != keys.shape: + raise TypeError( + f"device_radix_sort expects values.shape == keys.shape; got " + f"keys={keys.shape}, values={values.shape}" + ) + if tmp_values.shape != values.shape: + raise TypeError( + f"device_radix_sort expects tmp_values.shape == values.shape; got " + f"values={values.shape}, tmp_values={tmp_values.shape}" + ) + if tmp_values.dtype != values.dtype: + raise TypeError(f"device_radix_sort dtype mismatch: values={values.dtype}, tmp_values={tmp_values.dtype}") + if values is tmp_values: + raise ValueError("device_radix_sort requires values and tmp_values to be distinct buffers") + if values.dtype not in _SUPPORTED_VALUE_DTYPES: + raise NotImplementedError( + f"device_radix_sort value dtype {values.dtype} not in first-land set " + f"{[d for d in _SUPPORTED_VALUE_DTYPES]}; see design doc dtype matrix" + ) + + key_width = _key_width_bits(keys.dtype) + if end_bit <= 0 or end_bit > key_width: + raise ValueError( + f"device_radix_sort end_bit must satisfy 0 < end_bit <= {key_width} (key dtype width); got {end_bit}" + ) + if end_bit % RADIX_BITS != 0: + raise ValueError( + f"device_radix_sort end_bit must be a multiple of {RADIX_BITS} so that an even number of digit passes " + f"leaves the result back in `keys`; got end_bit={end_bit}" + ) + num_passes = end_bit // RADIX_BITS + if num_passes % 2 != 0: + raise ValueError( + f"device_radix_sort needs an even number of digit passes (so the ping-pong lands back in `keys`); " + f"got num_passes={num_passes} for end_bit={end_bit}, RADIX_BITS={RADIX_BITS}" + ) + + +def device_radix_sort( + keys, tmp_keys, values=None, tmp_values=None, end_bit=None +): # pylint: disable=too-many-locals,too-many-branches,too-many-statements + """Sort ``keys`` ascending on the device using LSB radix sort. 
+ + Args: + keys: 1-D tensor of ``u32`` / ``i32`` / ``f32`` (4-byte key path) or ``u64`` / ``i64`` / ``f64`` (8-byte key + path). Sorted in place. Pass a ``qd.field``, ``qd.ndarray``, or ``qd.Tensor`` wrapper. + tmp_keys: 1-D tensor with the same shape and dtype as ``keys``, distinct buffer. Used as a ping-pong workspace; + its contents at return are intermediate and should be considered garbage. + values: optional 1-D tensor of any supported scalar dtype, same shape as ``keys`` (the value dtype is + independent of the key dtype). If provided, values are permuted in lock-step with keys (key-value sort), + in place. + tmp_values: required iff ``values`` is provided. Same shape and dtype as ``values``, distinct buffer; same + workspace semantics as ``tmp_keys``. + end_bit: number of low bits of the key to consider. Defaults to the full key width (32 for 4-byte keys, 64 + for 8-byte keys). Must be a non-zero multiple of ``RADIX_BITS = 8`` so that an even number of digit + passes leaves the result in ``keys``. Pass a smaller value if the high bits are known to be zero (saves + passes). + + Sort order matches ``numpy.sort`` for ascending sort (signed-int two's-complement, IEEE-754 floats with negatives + ordered before positives, NaN handling matches numpy). + + Built on ``block.radix_rank_match_atomic_or`` (which is wave64-clean as of ``cd9e546851``) + the shared + ``Field(u32)`` scratch. The first land is classical histogram-scan-scatter LSB; a single-pass decoupled-lookback + variant (Onesweep) is a perf follow-up if profiling shows sort in the top of qipc's frame budget. + + **Scratch budget**: requires ``ceil(N / BLOCK_DIM) * RADIX_DIGITS + ...`` u32 slots in the shared scratch (see + module docstring on ``_radix_sort.py`` for the exact formula). The histograms are u32 regardless of key width, + so 8-byte-key sorts have the same scratch footprint as 4-byte ones (the key dtype only affects digit extraction + and scatter). The default 5 MB scratch caps ``N`` at ~1.3M; raise the budget via ``set_scratch_bytes`` for larger. + """ + if end_bit is None: + end_bit = _key_width_bits(keys.dtype) if keys.dtype in _SUPPORTED_KEY_DTYPES else 32 + _validate_inputs(keys, tmp_keys, values, tmp_values, end_bit) + N = keys.shape[0] + if N <= 1: + return + + key_dtype = keys.dtype + key_width = _key_width_bits(key_dtype) + has_values = values is not None + # Provide a non-None placeholder for values_* even when has_values=False so the kernel's template-key includes a + # real tensor type; the kernel body itself guards on `has_values` so the tensors are never actually dereferenced. + values_in_arg = values if has_values else keys + tmp_values_arg = tmp_values if has_values else tmp_keys + value_dtype = values.dtype if has_values else key_dtype + + num_blocks = (N + BLOCK_DIM - 1) // BLOCK_DIM + hist_len = num_blocks * RADIX_DIGITS # u32 slots for the per-pass tile_histograms + + scratch = get_scratch_u32() + cap = scratch_capacity_u32() + # Scratch layout: scratch[0 : hist_len] = current pass's tile_histograms. The in-place scan over + # scratch[0 : hist_len] sub-allocates partials from scratch[hist_len : ...] for *all* of its recursive levels. 
+ # We must account for the full recursive footprint up front (via ``_scan_total_scratch_slots``): otherwise we + # accept budgets that pass a single-level estimate but blow up mid-recursion, and ``_twiddle_pass`` below will + # have already mutated the user's keys in place by the time the recursive ``RuntimeError`` fires - leaving the + # caller with corrupted ``keys`` and no recovery path. + needed = _scan_total_scratch_slots(hist_len, partials_cursor=hist_len) + if needed > cap: + raise RuntimeError( + f"device_radix_sort on N={N} needs >= {needed} u32 scratch slots " + f"({needed * 4} bytes, including all levels of the in-place scan recursion), but only {cap} are " + f"configured ({cap * 4} bytes). Call quadrants._scratch.set_scratch_bytes(...) before any algorithm " + f"runs to raise the cap. For N=1M expect to need ~5 MB; for N=10M ~50 MB." + ) + + # Pre-twiddle keys (in-place) for signed-int / float. Unsigned-int path is a no-op. + if key_dtype in (i32, f32): + _twiddle_pass(keys, N, key_dtype, True) + elif key_dtype in (i64, f64): + _twiddle_pass_u64(keys, N, key_dtype, True) + + identity_bits = _identity_bits(0, u32) + src = keys + dst = tmp_keys + src_values = values_in_arg + dst_values = tmp_values_arg + num_passes = end_bit // RADIX_BITS + histogram_kernel = _radix_histogram_pass if key_width == 32 else _radix_histogram_pass_u64 + scatter_kernel = _radix_scatter_pass if key_width == 32 else _radix_scatter_pass_u64 + for p in range(num_passes): + bit_start = p * RADIX_BITS + # Pass A: per-block histograms into scratch[0 : hist_len]. + histogram_kernel(src, scratch, 0, N, num_blocks, bit_start, key_dtype) + # Pass B: in-place exclusive scan of scratch[0 : hist_len]. + _exclusive_scan_inplace_u32(scratch, 0, hist_len, identity_bits, _bin_add, u32, hist_len) + # Pass C: scatter from src -> dst using the scanned histograms. + scatter_kernel( + src, + dst, + src_values, + dst_values, + scratch, + 0, + N, + num_blocks, + bit_start, + key_dtype, + value_dtype, + has_values, + ) + src, dst = dst, src + src_values, dst_values = dst_values, src_values + + # After an even number of swaps, the sorted result is back in `keys`. + if key_dtype in (i32, f32): + _twiddle_pass(keys, N, key_dtype, False) + elif key_dtype in (i64, f64): + _twiddle_pass_u64(keys, N, key_dtype, False) + + +__all__ = ["device_radix_sort"] diff --git a/python/quadrants/algorithms/_reduce.py b/python/quadrants/algorithms/_reduce.py new file mode 100644 index 0000000000..3660d5d700 --- /dev/null +++ b/python/quadrants/algorithms/_reduce.py @@ -0,0 +1,402 @@ +# type: ignore +"""Device-wide reduce primitives. + +Implements ``qd.algorithms.device_reduce_{add,min,max}`` on top of the block-tier ``block.reduce_{add,min,max}`` +primitives. See the design doc at ``perso_hugh/doc/qipc/qipc_device_algos_design.md`` for the algorithmic rationale. + +Layout (host driver builds a recursion plan, kernels are the per-pass workers): + +- **First pass** reads the caller's input tensor (of the algorithm's ``dtype``) and writes per-block partials to the + shared scratch field as ``u32`` via ``qd.bit_cast``. +- **Intermediate passes** (only needed when ``N`` is large enough to require more than two passes total - i.e. + ``B0 > BLOCK_DIM``) read from one slice of scratch (``u32`` → ``dtype`` via ``qd.bit_cast``) and write to another + slice (``dtype`` → ``u32`` via ``qd.bit_cast``). +- **Last pass** reduces to a single value and writes it directly to the caller's ``out`` tensor as ``dtype`` (no + bit_cast on the write side). 
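+
+As a worked illustration (numbers only, assuming ``BLOCK_DIM = 256``): ``N = 2**20`` gives per-pass sizes
+``1_048_576 -> 4_096 -> 16 -> 1`` - one first pass into scratch, one intermediate scratch-to-scratch pass, and a
+last pass that writes the single result to ``out``; three launches and 4_112 scratch slots in total.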
+ +A single generic kernel handles every pass; ``src_is_u32`` and ``dst_is_u32`` are compile-time template flags +selecting between the bit_cast and direct-read / direct-write paths. + +The shared scratch field is owned by ``quadrants._scratch`` (see that module). The default 5 MB capacity covers +reductions well past ``N = 64M`` elements at ``BLOCK_DIM=256`` (reduce uses only ``~N / BLOCK_DIM`` slots). + +The reduce monoid identity (e.g. ``+inf`` for ``min`` over ``f32``, ``2**31 - 1`` for ``min`` over ``i32``) is passed +to the kernel as its raw 4-byte bit pattern in a ``u32`` runtime arg, then ``qd.bit_cast``-ed to ``dtype`` inside the +kernel. This bypasses the ``default_ip`` overflow check that ``cast(literal, dtype)`` would otherwise hit on the wider +unsigned identities, and keeps ``identity`` out of the kernel template key (one fewer axis of cache fragmentation). +""" + +import struct + +from quadrants._scratch import ( + get_scratch_u32, + get_scratch_u64, + scratch_capacity_u32, + scratch_capacity_u64, +) +from quadrants.lang.impl import static +from quadrants.lang.kernel_impl import kernel +from quadrants.lang.misc import loop_config +from quadrants.lang.ops import bit_cast +from quadrants.lang.simt import block as _block +from quadrants.lang.simt.reductions import _bin_add, _bin_max, _bin_min +from quadrants.types.annotations import template +from quadrants.types.primitive_types import ( + f32, + f64, + i32, + i64, + u32, + u64, +) + +BLOCK_DIM = 256 +"""Threads per block for every device reduce / scan kernel. + +Chosen as a portable default: a multiple of every supported subgroup size (32 on CUDA / Vulkan-on-NV / Metal, 64 on +AMDGPU), and small enough to fit comfortably in shared memory budgets across backends. Re-tune (128 / 512) once +benchmarks land per the design doc's open questions. +""" + +_SUPPORTED_DTYPES_4B = (i32, u32, f32) +_SUPPORTED_DTYPES_8B = (i64, u64, f64) +_SUPPORTED_DTYPES = _SUPPORTED_DTYPES_4B + _SUPPORTED_DTYPES_8B + + +def _dtype_width_bytes(dtype) -> int: + """Return the byte width of ``dtype``: 4 for ``{i32, u32, f32}``, 8 for ``{i64, u64, f64}``. Raises for any + other dtype. + """ + if dtype in _SUPPORTED_DTYPES_4B: + return 4 + if dtype in _SUPPORTED_DTYPES_8B: + return 8 + raise NotImplementedError(f"device reduce dtype {dtype} not supported") + + +def _identity_bits(value, dtype) -> int: + """Reinterpret-cast ``value`` to its unsigned bit pattern: ``u32`` for 4-byte dtypes, ``u64`` for 8-byte. + + Used to ferry monoid identities (e.g. ``+inf`` for ``min`` over ``f32``, ``2**31 - 1`` for ``min`` over ``i32``, + ``+inf`` over ``f64``, ``2**63 - 1`` over ``i64``) into the reduce kernel as a runtime arg, sidestepping the + ``default_ip`` overflow check that ``cast(literal, dtype)`` would hit on wide unsigned identities. + """ + if dtype == u32: + return int(value) & 0xFFFFFFFF + if dtype == i32: + return struct.unpack(" 1: + prev = sizes[-1] + sizes.append((prev + BLOCK_DIM - 1) // BLOCK_DIM) + num_passes = len(sizes) - 1 + dst_offsets = [] + cumul = 0 + for k in range(num_passes): + if k == num_passes - 1: + dst_offsets.append(-1) + else: + dst_offsets.append(cumul) + cumul += sizes[k + 1] + return sizes, dst_offsets, cumul + + +def _device_reduce(arr, *, out, op, identity_value): + """Internal driver shared by ``device_reduce_{add,min,max}``. + + Dispatches on ``arr.dtype`` width: 4-byte dtypes go through the ``Field(u32)`` scratch and ``_reduce_pass``; + 8-byte dtypes go through the ``Field(u64)`` scratch and ``_reduce_pass_u64``. 
Everything else (control flow, + recursion plan, identity ferrying) is shared. + """ + if not hasattr(arr, "shape") or len(arr.shape) != 1: + raise TypeError(f"device reduce expects a 1-D input tensor; got shape {getattr(arr, 'shape', None)}") + if not hasattr(out, "shape") or out.shape != (1,): + raise TypeError(f"device reduce expects out.shape == (1,); got {out.shape}") + if arr.dtype != out.dtype: + raise TypeError(f"device reduce dtype mismatch: arr={arr.dtype}, out={out.dtype}") + dtype = arr.dtype + if dtype not in _SUPPORTED_DTYPES: + raise NotImplementedError( + f"device reduce dtype {dtype} not supported (need one of " + f"{[d for d in _SUPPORTED_DTYPES]}); see design doc dtype matrix" + ) + width = _dtype_width_bytes(dtype) + + N = arr.shape[0] + sizes, dst_offsets, total_scratch = _plan_levels(N) + + scratch_cap = scratch_capacity_u32() if width == 4 else scratch_capacity_u64() + if total_scratch > scratch_cap: + raise RuntimeError( + f"device reduce on N={N} (dtype={dtype}) needs {total_scratch} " + f"u{width * 8} scratch slots, but only {scratch_cap} are configured. " + f"Call quadrants._scratch.set_scratch_bytes(...) before any algorithm runs to raise the cap." + ) + + num_passes = len(sizes) - 1 + identity_bits = _identity_bits(identity_value, dtype) + + if num_passes == 0: + # Trivially short input (N == 0 or N == 1): no reduce kernel needed. N == 0: write `identity` to out[0]; + # N == 1: out[0] = arr[0]. + _device_reduce_trivial(arr, out=out, identity_bits=identity_bits) + return + + scratch = get_scratch_u32() if width == 4 else get_scratch_u64() + pass_kernel = _reduce_pass if width == 4 else _reduce_pass_u64 + + for k in range(num_passes): + n_in = sizes[k] + n_out = sizes[k + 1] + total_threads = n_out * BLOCK_DIM + is_first = k == 0 + is_last = k == num_passes - 1 + src = arr if is_first else scratch + dst = out if is_last else scratch + src_off = 0 if is_first else _src_off(k, dst_offsets) + dst_off = 0 if is_last else dst_offsets[k] + pass_kernel( + src, + dst, + src_off, + dst_off, + n_in, + total_threads, + identity_bits, + op, + dtype, + not is_first, + not is_last, + ) + + +def _src_off(k: int, dst_offsets): + """Source offset for pass ``k`` (k >= 1): equals the dst offset that pass ``k - 1`` wrote to.""" + return dst_offsets[k - 1] + + +@kernel +def _trivial_write_arr(arr: template(), out: template()): + """N == 1 path: copy arr[0] to out[0]. Two-element kernel keeps the host driver loop-free for the trivial case.""" + for _ in range(1): + out[0] = arr[0] + + +@kernel +def _trivial_write_identity(out: template(), identity_bits: u32, dtype: template()): + """N == 0 path: write the monoid identity (as a u32 bit pattern) to out[0]. + + Quadrants doesn't support 0-shape tensors today, so this path is currently unreachable from a caller - left in + place for defensiveness against future 0-length support. + """ + for _ in range(1): + out[0] = bit_cast(identity_bits, dtype) + + +def _device_reduce_trivial(arr, *, out, identity_bits): + N = arr.shape[0] + if N == 0: + _trivial_write_identity(out, identity_bits, out.dtype) + elif N == 1: + _trivial_write_arr(arr, out) + else: + raise AssertionError(f"_device_reduce_trivial called with N={N}") + + +def device_reduce_add(arr, out): + """Compute ``out[0] = sum(arr)`` on the device. + + Args: + arr: 1-D tensor of any supported scalar dtype - ``{i32, u32, f32, i64, u64, f64}``. Pass a ``qd.field``, + ``qd.ndarray``, or ``qd.Tensor`` wrapper around either. + out: 1-element tensor of the same dtype as ``arr``. 
Caller-supplied so the call is fully asynchronous - no + implicit device-to-host sync. To get a Python scalar, do ``out.to_numpy()[0]`` explicitly after this + call. + + The implementation is a two-or-more-pass tree reduction built on ``block.reduce_add``. Scratch is drawn from the + quadrants-level shared scratch field (``Field(u32)`` for 4-byte dtypes, ``Field(u64)`` for 8-byte); no per-call + allocation. See the design doc at ``perso_hugh/doc/qipc/qipc_device_algos_design.md`` for the recursion plan and + the ``bit_cast``-into-scratch scheme. + """ + _device_reduce(arr, out=out, op=_bin_add, identity_value=0) + + +def device_reduce_min(arr, out): + """Compute ``out[0] = min(arr)`` on the device. + + Args: + arr: see ``device_reduce_add`` (any of ``{i32, u32, f32, i64, u64, f64}``). + out: see ``device_reduce_add``. + + The monoid identity is derived from ``arr.dtype`` automatically (the largest representable value: + ``+inf`` for ``f32`` / ``f64``, ``INT32_MAX`` / ``INT64_MAX`` for signed ints, ``UINT32_MAX`` / ``UINT64_MAX`` + for unsigned). Mirrors the ``block.reduce_min`` / ``subgroup.reduce_min`` contract: the typed reduce + primitives do not take an identity argument because (op, dtype) fixes it. + """ + _device_reduce(arr, out=out, op=_bin_min, identity_value=_min_identity(arr.dtype)) + + +def device_reduce_max(arr, out): + """Compute ``out[0] = max(arr)`` on the device. Mirror of :func:`device_reduce_min` with ``max`` and the + dtype's *negative* extremum (``-inf`` for floats, ``INT32_MIN`` / ``INT64_MIN`` for signed ints, ``0`` for + unsigned ints), again derived from ``arr.dtype`` automatically. + """ + _device_reduce(arr, out=out, op=_bin_max, identity_value=_max_identity(arr.dtype)) + + +__all__ = ["device_reduce_add", "device_reduce_min", "device_reduce_max"] diff --git a/python/quadrants/algorithms/_reduce_by_key.py b/python/quadrants/algorithms/_reduce_by_key.py new file mode 100644 index 0000000000..8f1cfe4273 --- /dev/null +++ b/python/quadrants/algorithms/_reduce_by_key.py @@ -0,0 +1,282 @@ +# type: ignore +"""Device-wide reduce-by-key. + +Implements ``qd.algorithms.device_reduce_by_key_add`` on top of the existing device exclusive scan internals and the +shared ``Field(u32)`` scratch. + +Reduce-by-key takes two parallel 1-D tensors - ``keys`` and ``values`` - and collapses every **consecutive run of +equal keys** into a single output entry ``(unique_key, sum_of_values_in_run)``. Keys that are equal but separated by +other keys are treated as separate runs. To compute a global per-key sum, sort by key first (e.g. via +``qd.algorithms.device_radix_sort``) and then reduce-by-key. + +Algorithm (scan + scatter; no segmented-scan primitive needed): + +1. **Head-flag pass** (``_rbk_head_flags``). Compute ``head_flags[i] = 1`` if ``i == 0 or keys[i] != keys[i-1]``, else + ``0``, directly into the shared ``Field(u32)`` scratch ``scratch[0:N]`` (storing the ``i32`` flag bit-cast to + ``u32``). +2. **Exclusive scan of head_flags** (in-place over ``scratch[0:N]``, using ``_reduce_pass`` + + ``_exclusive_scan_inplace_u32`` + ``_scan_pass3`` reused from ``_reduce.py`` / ``_scan.py``). After this, + ``scratch[i] = exclusive_scan(head_flags)[i] = sum(head_flags[0:i])``. The 0-indexed run index of element ``i`` is + then ``positions[i] = scratch[i] + head_flag(i) - 1`` (i.e. ``inclusive_scan(head_flags)[i] - 1``); the scatter + pass recomputes ``head_flag(i)`` from the two keys at ``i`` and ``i - 1`` so the ``head_flags`` array itself does + not need to survive the scan. 
This lets the scan run in place, holding scratch to ~``1.004 * N`` slots. +3. **Zero-init values_out**. The scatter step uses ``atomic_add`` on ``values_out[positions[i]]``; the slots must + start at the additive identity ``0``. +4. **Scatter pass** (``_rbk_scatter``). For every ``i``: + - Recompute ``head_flag(i)`` from ``i == 0 or keys[i] != keys[i-1]`` and compute the run index ``pos = scratch[i] + + head_flag(i) - 1``. + - ``keys_out[pos] = keys[i]`` - race-free because every thread in a run writes the same key to the same slot. + - ``atomic_add(values_out[pos], values[i])`` folds the run's values into the run's output slot. +5. **Count pass** (``_rbk_count``). Computes ``num_runs[0] = scratch[N-1] + head_flag(N-1)`` where the head flag at + ``N-1`` is recomputed from ``keys[N-1] != keys[N-2]`` for ``N >= 2`` (``1`` for ``N == 1``). + +This first-land scope supports only the ``add`` reduction. ``min`` / ``max`` variants would need ``atomic_min`` / +``atomic_max``, which have spottier cross-backend support for ``f32`` - defer to a follow-up gated on real qipc usage. + +Scratch budget: ``N + ceil(N / 256) + ...`` ``u32`` slots, ≈ ``1.004 * N``. The default 5 MB scratch covers ``N`` up +to ~1.3M. For larger ``N``, raise via ``quadrants._scratch.set_scratch_bytes(...)`` before any algorithm call. +""" + +from quadrants._scratch import get_scratch_u32, scratch_capacity_u32 +from quadrants.lang.kernel_impl import kernel +from quadrants.lang.misc import loop_config +from quadrants.lang.ops import atomic_add, bit_cast +from quadrants.lang.simt.reductions import _bin_add +from quadrants.types.annotations import template +from quadrants.types.primitive_types import f32, i32, u32 + +from ._reduce import BLOCK_DIM, _identity_bits, _reduce_pass +from ._scan import _exclusive_scan_inplace_u32, _scan_pass3 + +_SUPPORTED_KEY_DTYPES = (u32, i32, f32) +_SUPPORTED_VALUE_DTYPES = (u32, i32, f32) + + +@kernel +def _rbk_head_flags(keys_in: template(), head_flags: template(), head_flags_off: i32, N: i32): + """Write ``head_flags[i] = 1 if (i == 0 or keys[i] != keys[i-1]) else 0`` to ``head_flags[head_flags_off + i]`` + (as the u32 bit pattern of i32). + + Linear-time, embarrassingly parallel: each thread reads at most two key elements (``keys[i]`` and ``keys[i-1]``) + and writes one flag. The boundary thread at ``i == 0`` always writes ``1`` since there is no predecessor and a run + trivially starts there. + """ + loop_config(block_dim=BLOCK_DIM) + for i in range(N): + flag = i32(0) + if i == 0: + flag = i32(1) + else: + if keys_in[i] != keys_in[i - 1]: + flag = i32(1) + head_flags[head_flags_off + i] = bit_cast(flag, u32) + + +@kernel +def _rbk_zero_values_out(values_out: template(), N: i32, dtype: template()): + """Set ``values_out[0 : N] = 0`` so the scatter ``atomic_add`` lands onto a clean additive identity. ``N`` is the + upper bound on ``num_runs``; the caller-supplied ``values_out`` may be longer but we only need the prefix that the + scatter can touch. + + We write ``bit_cast(u32(0), dtype)`` rather than relying on ``v - v == 0`` because the latter compiles to a real + subtract for ``f32`` (and yields NaN if the slot held NaN garbage from a prior allocation), whereas the bit-cast + lowers to a plain store. 
+ """ + for i in range(N): + values_out[i] = bit_cast(u32(0), dtype) + + +@kernel +def _rbk_scatter( + keys_in: template(), + values_in: template(), + positions: template(), + positions_off: i32, + keys_out: template(), + values_out: template(), + N: i32, +): + """Per-element scatter: + + - Compute ``head_flag(i)`` on the fly from ``i == 0 or keys[i] != keys[i-1]`` and combine with the in-place + exclusive scan stored in ``positions`` to recover the inclusive run index + ``pos = positions[i] + head_flag(i) - 1``. + - ``keys_out[pos] = keys_in[i]`` - race-free because every thread in a run writes the same key to the same slot. + - ``atomic_add(values_out[pos], values_in[i])`` - folds the run's values into the run's output slot. + ``values_out`` must be pre-zeroed (see ``_rbk_zero_values_out``). + """ + for i in range(N): + head_i = i32(0) + if i == 0: + head_i = i32(1) + else: + if keys_in[i] != keys_in[i - 1]: + head_i = i32(1) + pos = bit_cast(positions[positions_off + i], i32) + head_i - i32(1) + keys_out[pos] = keys_in[i] + atomic_add(values_out[pos], values_in[i]) + + +@kernel +def _rbk_count(keys_in: template(), positions: template(), positions_off: i32, N: i32, num_runs: template()): + """One-thread tail kernel: write ``num_runs[0] = total head_flag count``. + + Equivalently: ``num_runs = exclusive_scan_at(N-1) + head_flag(N-1) = inclusive_scan_at(N-1) = + total_head_flags``. We can't read ``scratch[N-1]`` for the original head flag (the in-place scan overwrote it + with the exclusive prefix), so we recompute the flag from the last two keys. For ``N == 1``, + ``head_flag(0) == 1`` so ``num_runs = 0 + 1 = 1``. + """ + for _ in range(1): + pos_last = bit_cast(positions[positions_off + N - 1], i32) + head_last = i32(0) + if N == 1: + head_last = i32(1) + else: + if keys_in[N - 1] != keys_in[N - 2]: + head_last = i32(1) + num_runs[0] = pos_last + head_last + + +def _validate_inputs(keys_in, values_in, keys_out, values_out, num_runs): + if not hasattr(keys_in, "shape") or len(keys_in.shape) != 1: + raise TypeError(f"device_reduce_by_key_add expects 1-D keys_in; got shape {getattr(keys_in, 'shape', None)}") + if not hasattr(values_in, "shape") or values_in.shape != keys_in.shape: + raise TypeError( + f"device_reduce_by_key_add expects values_in.shape == keys_in.shape; got " + f"keys_in={keys_in.shape}, values_in={values_in.shape}" + ) + if not hasattr(keys_out, "shape") or len(keys_out.shape) != 1: + raise TypeError(f"device_reduce_by_key_add expects 1-D keys_out; got shape {getattr(keys_out, 'shape', None)}") + if keys_out.dtype != keys_in.dtype: + raise TypeError(f"device_reduce_by_key_add dtype mismatch: keys_in={keys_in.dtype}, keys_out={keys_out.dtype}") + if not hasattr(values_out, "shape") or len(values_out.shape) != 1: + raise TypeError( + f"device_reduce_by_key_add expects 1-D values_out; got shape {getattr(values_out, 'shape', None)}" + ) + if values_out.dtype != values_in.dtype: + raise TypeError( + f"device_reduce_by_key_add dtype mismatch: values_in={values_in.dtype}, values_out={values_out.dtype}" + ) + if keys_out.shape[0] < keys_in.shape[0]: + raise ValueError( + f"device_reduce_by_key_add keys_out.shape[0]={keys_out.shape[0]} < keys_in.shape[0]={keys_in.shape[0]}; " + f"keys_out must hold at least N entries (worst case: every key is unique)" + ) + if values_out.shape[0] < values_in.shape[0]: + raise ValueError( + f"device_reduce_by_key_add values_out.shape[0]={values_out.shape[0]} < values_in.shape[0]={values_in.shape[0]}; " + f"values_out must hold at least N entries" + ) 
+ if not hasattr(num_runs, "shape") or num_runs.shape != (1,): + raise TypeError(f"device_reduce_by_key_add expects num_runs.shape == (1,); got {num_runs.shape}") + if num_runs.dtype != i32: + raise TypeError(f"device_reduce_by_key_add expects num_runs.dtype == qd.i32; got {num_runs.dtype}") + if keys_in.dtype not in _SUPPORTED_KEY_DTYPES: + raise NotImplementedError( + f"device_reduce_by_key_add keys dtype {keys_in.dtype} not in first-land set " + f"{[d for d in _SUPPORTED_KEY_DTYPES]}; see design doc dtype matrix" + ) + if values_in.dtype not in _SUPPORTED_VALUE_DTYPES: + raise NotImplementedError( + f"device_reduce_by_key_add values dtype {values_in.dtype} not in first-land set " + f"{[d for d in _SUPPORTED_VALUE_DTYPES]}; see design doc dtype matrix" + ) + + +def device_reduce_by_key_add(keys_in, values_in, keys_out, values_out, num_runs): + """Collapse every consecutive run of equal ``keys_in`` into ``(key, sum_of_values)``. + + Args: + keys_in: 1-D tensor of ``u32`` / ``i32`` / ``f32``. Sort by key beforehand (e.g. via + ``qd.algorithms.device_radix_sort``) if you need a global per-key sum rather than a per-run sum. + values_in: 1-D tensor of ``u32`` / ``i32`` / ``f32``, same shape as ``keys_in``. + keys_out: 1-D tensor of the same dtype as ``keys_in``, capacity ``>= N``. Receives the unique-run keys at + indices ``[0 : num_runs[0])``; the tail is left untouched. + values_out: 1-D tensor of the same dtype as ``values_in``, capacity ``>= N``. Receives the per-run sums. The + first ``num_runs[0]`` slots are overwritten; if ``values_out`` was longer, the tail past that prefix is + left untouched. + num_runs: 1-element ``i32`` tensor receiving the number of runs. + + Same async / no-implicit-sync contract as the rest of ``qd.algorithms.*``: ``num_runs`` is a tensor (not a Python + int); fetch the count with ``int(num_runs.to_numpy()[0])`` after the call. + + **NaN handling for f32 keys**: NaN ``!=`` NaN is true, so each NaN becomes its own run. This is consistent with + treating NaN as "different from everything", which matches the run-length-encoding spirit of reduce-by-key. + + **Scratch budget**: ~``1.004 * N`` u32 slots. Default 5 MB covers ``N`` up to ~1.3M; raise via + ``quadrants._scratch.set_scratch_bytes(...)`` for larger inputs. + """ + _validate_inputs(keys_in, values_in, keys_out, values_out, num_runs) + N = keys_in.shape[0] + if N == 0: + return + + scratch = get_scratch_u32() + cap = scratch_capacity_u32() + B0 = (N + BLOCK_DIM - 1) // BLOCK_DIM + positions_off = 0 + partials_off = N + if partials_off + B0 > cap: + raise RuntimeError( + f"device_reduce_by_key_add on N={N} needs >= {partials_off + B0} u32 scratch slots, " + f"but only {cap} are configured. Call quadrants._scratch.set_scratch_bytes(...) " + f"before any algorithm runs." + ) + + identity_bits = _identity_bits(0, i32) + op = _bin_add + dtype = i32 + + # Step 1: head_flags -> scratch[0:N]. + _rbk_head_flags(keys_in, scratch, positions_off, N) + + # Step 2: in-place exclusive scan of head_flags -> positions (still in scratch[0:N]). Mirrors the 3-pass dance in + # _select.py but with scratch as both source and dest for Pass 1 / Pass 3 (the existing kernels support src == + # dst aliasing). 
+ if N > BLOCK_DIM: + _reduce_pass( + scratch, + scratch, + positions_off, + partials_off, + N, + B0 * BLOCK_DIM, + identity_bits, + op, + dtype, + True, + True, + ) + _exclusive_scan_inplace_u32(scratch, partials_off, B0, identity_bits, op, dtype, partials_off + B0) + _scan_pass3( + scratch, + positions_off, + scratch, + partials_off, + scratch, + positions_off, + N, + B0 * BLOCK_DIM, + identity_bits, + op, + dtype, + True, + True, + ) + else: + # Single-tile fast path: one block scans scratch[0:N] in place. Pass 1 still writes a single partial that is + # then trivially scanned, but it's cheaper to inline a 1-block scan kernel that reads + writes scratch + # directly. Reuse _exclusive_scan_inplace_u32's base case here. + _exclusive_scan_inplace_u32(scratch, positions_off, N, identity_bits, op, dtype, partials_off) + + # Step 3: zero-init values_out (only the prefix that the scatter can touch). + _rbk_zero_values_out(values_out, N, values_in.dtype) + + # Step 4: scatter keys + atomic-add values. + _rbk_scatter(keys_in, values_in, scratch, positions_off, keys_out, values_out, N) + + # Step 5: write num_runs. + _rbk_count(keys_in, scratch, positions_off, N, num_runs) + + +__all__ = ["device_reduce_by_key_add"] diff --git a/python/quadrants/algorithms/_scan.py b/python/quadrants/algorithms/_scan.py new file mode 100644 index 0000000000..6cbbda2184 --- /dev/null +++ b/python/quadrants/algorithms/_scan.py @@ -0,0 +1,528 @@ +# type: ignore +"""Device-wide exclusive-scan primitives. + +Implements ``qd.algorithms.device_exclusive_scan_{add,min,max}`` on top of the block-tier ``block.exclusive_scan`` +primitive. See the design doc at ``perso_hugh/doc/qipc/qipc_device_algos_design.md`` for the algorithmic rationale +(Blelloch 1990 / Harris-Sengupta-Owens 2007, three-pass formulation). + +Algorithm (three-pass, multi-level when needed): + +- **Pass 1: per-block tile reduce.** Each block reads ``BLOCK_DIM`` input elements, reduces them via + ``block.reduce(op, dtype)``, thread 0 writes the per-block aggregate into the shared ``u32`` scratch field + (``qd.bit_cast`` on write). Identical to ``_reduce_pass`` in ``_reduce.py``; we reuse that kernel. +- **Pass 2: exclusive-scan the partials.** Once the partials buffer is built, exclusive-scan it in place. For + ``B <= BLOCK_DIM`` a single block does it in one kernel launch (``_scan_block_inplace_u32``). For ``B > BLOCK_DIM`` + the driver recurses: it runs another tile-reduce on the partials buffer to produce a smaller partials-of-partials + buffer, recursively scans that, then runs a downsweep over the partials buffer to apply the per-tile prefixes. +- **Pass 3: per-block tile scan + block-prefix.** Each block re-reads its tile from the input source, computes + per-thread tile prefixes via ``block.exclusive_scan(op, identity, dtype)``, fetches its block prefix from the + scanned partials buffer, and writes ``out[i] = op(block_prefix, tile_prefix)``. + +Total scratch usage at ``N = 1M`` and ``BLOCK_DIM = 256``: ``B0 = 4096`` plus ``B1 = 16`` u32 slots = 4112 slots = +~16 KB, trivial relative to the 5 MB default. + +The ``PrefixSumExecutor`` class in ``_algorithms.py`` predates this work; it is kept for backward compat. The new +functional API is preferred for new code - see ``docs/source/user_guide/algorithms.md``. 
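+
+Illustrative usage sketch (the ``qd.ndarray(dtype, shape=...)`` construction is an assumed spelling; any 1-D
+``qd.field`` / ``qd.ndarray`` / ``qd.Tensor`` buffer works)::
+
+    src = qd.ndarray(qd.i32, shape=(1_000_000,))
+    dst = qd.ndarray(qd.i32, shape=(1_000_000,))   # must be a distinct buffer - in-place scan is rejected
+    # ... fill `src` ...
+    qd.algorithms.device_exclusive_scan_add(src, dst)   # dst[i] = sum(src[0:i]); dst[0] = 0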
+""" + +from quadrants._scratch import ( + get_scratch_u32, + get_scratch_u64, + scratch_capacity_u32, + scratch_capacity_u64, +) +from quadrants.lang.impl import static +from quadrants.lang.kernel_impl import kernel +from quadrants.lang.misc import loop_config +from quadrants.lang.ops import bit_cast +from quadrants.lang.simt import block as _block +from quadrants.lang.simt.reductions import _bin_add, _bin_max, _bin_min +from quadrants.types.annotations import template +from quadrants.types.primitive_types import i32, u32, u64 + +from ._reduce import ( + _SUPPORTED_DTYPES as _REDUCE_SUPPORTED_DTYPES, +) +from ._reduce import ( + BLOCK_DIM, + _dtype_width_bytes, + _identity_bits, + _max_identity, + _min_identity, + _reduce_pass, + _reduce_pass_u64, +) + +_SUPPORTED_DTYPES = _REDUCE_SUPPORTED_DTYPES # {i32, u32, f32, i64, u64, f64} + + +@kernel +def _scan_block_inplace_u32( + buf: template(), + buf_off: i32, + n_valid: i32, + identity_bits: u32, + op: template(), + dtype: template(), +): + """Single-block in-place exclusive scan of ``buf[buf_off : buf_off + n_valid]`` (4-byte dtype path). + + Used at the recursion base of the scan driver, when the buffer being scanned fits in a single block. ``buf`` is + the shared ``Field(u32)`` scratch; the per-thread read / write go through ``qd.bit_cast`` to / from ``dtype``. + + Threads with ``i >= n_valid`` participate with ``identity`` (so the block-scope scan algorithm sees a clean + monoid) but do not write back. + """ + loop_config(block_dim=BLOCK_DIM) + for i in range(BLOCK_DIM): + identity = bit_cast(identity_bits, dtype) + v = identity + if i < n_valid: + v = bit_cast(buf[buf_off + i], dtype) + prefix = _block.exclusive_scan(v, BLOCK_DIM, op, identity, dtype) + if i < n_valid: + buf[buf_off + i] = bit_cast(prefix, u32) + + +@kernel +def _scan_block_inplace_u64( + buf: template(), + buf_off: i32, + n_valid: i32, + identity_bits: u64, + op: template(), + dtype: template(), +): + """8-byte sibling of :func:`_scan_block_inplace_u32`. Stages through the ``Field(u64)`` scratch.""" + loop_config(block_dim=BLOCK_DIM) + for i in range(BLOCK_DIM): + identity = bit_cast(identity_bits, dtype) + v = identity + if i < n_valid: + v = bit_cast(buf[buf_off + i], dtype) + prefix = _block.exclusive_scan(v, BLOCK_DIM, op, identity, dtype) + if i < n_valid: + buf[buf_off + i] = bit_cast(prefix, u64) + + +@kernel +def _scan_pass3( + src: template(), + src_off: i32, + prefixes: template(), + prefixes_off: i32, + dst: template(), + dst_off: i32, + n_valid: i32, + total_threads: i32, + identity_bits: u32, + op: template(), + dtype: template(), + src_is_u32: template(), + dst_is_u32: template(), +): + """Pass-3 downsweep: per-block tile scan + apply block prefix from scratch. + + Reads ``src[src_off : src_off + n_valid]`` (template-switched between the dtype tensor path and the u32-scratch + ``bit_cast`` path), computes per-thread tile prefixes via ``block.exclusive_scan``, looks up the block prefix at + ``prefixes[prefixes_off + block_id]`` (always a u32 scratch slot holding the dtype value bit-cast to u32, written + by Pass 2), and writes ``op(block_prefix, tile_prefix)`` to ``dst[dst_off + i]``. + + ``dst`` may alias ``src`` (in-place recursion case); the read-modify-write is per-thread and the + block.exclusive_scan internally barriers, so threads in a block see consistent values and writes by other blocks + land in disjoint tiles. 
+ """ + loop_config(block_dim=BLOCK_DIM) + for i in range(total_threads): + tid = i % BLOCK_DIM + block_id = i // BLOCK_DIM + identity = bit_cast(identity_bits, dtype) + v = identity + if i < n_valid: + if static(src_is_u32): + v = bit_cast(src[src_off + i], dtype) + else: + v = src[src_off + i] + tile_prefix = _block.exclusive_scan(v, BLOCK_DIM, op, identity, dtype) + block_prefix = bit_cast(prefixes[prefixes_off + block_id], dtype) + if i < n_valid: + scanned = op(block_prefix, tile_prefix) + if static(dst_is_u32): + dst[dst_off + i] = bit_cast(scanned, u32) + else: + dst[dst_off + i] = scanned + + +@kernel +def _scan_pass3_u64( + src: template(), + src_off: i32, + prefixes: template(), + prefixes_off: i32, + dst: template(), + dst_off: i32, + n_valid: i32, + total_threads: i32, + identity_bits: u64, + op: template(), + dtype: template(), + src_is_u64: template(), + dst_is_u64: template(), +): + """8-byte sibling of :func:`_scan_pass3`. Stages through the ``Field(u64)`` scratch.""" + loop_config(block_dim=BLOCK_DIM) + for i in range(total_threads): + tid = i % BLOCK_DIM + block_id = i // BLOCK_DIM + identity = bit_cast(identity_bits, dtype) + v = identity + if i < n_valid: + if static(src_is_u64): + v = bit_cast(src[src_off + i], dtype) + else: + v = src[src_off + i] + tile_prefix = _block.exclusive_scan(v, BLOCK_DIM, op, identity, dtype) + block_prefix = bit_cast(prefixes[prefixes_off + block_id], dtype) + if i < n_valid: + scanned = op(block_prefix, tile_prefix) + if static(dst_is_u64): + dst[dst_off + i] = bit_cast(scanned, u64) + else: + dst[dst_off + i] = scanned + + +def _scan_total_scratch_slots(n: int, partials_cursor: int) -> int: + """Return the high-water-mark scratch slot index that ``_exclusive_scan_inplace_{u32,u64}`` will use to scan + ``n`` elements with its partials starting at ``partials_cursor`` (i.e. the smallest required ``capacity`` such + that ``capacity >= return_value`` is sufficient for the entire recursion). + + Mirrors the level-by-level allocation that the recursion does internally: at each level we bump + ``partials_cursor`` by ``B = ceil(n / BLOCK_DIM)`` and recurse on ``B``, until ``B <= BLOCK_DIM`` (base case, no + further partials). Callers (e.g. ``device_radix_sort``) should use this helper for their *up-front* scratch + check so they refuse the call before any in-place mutation runs (see PR 693 review: a single-level estimate + misses deeper recursion levels and lets ``_twiddle_pass`` corrupt the user's keys before the recursive + ``RuntimeError`` fires). + + The check inside ``_exclusive_scan_inplace_*`` itself stays as a defensive backstop; this helper is the + contract that the *outer* drivers should honour first. + """ + cursor = partials_cursor + while n > BLOCK_DIM: + B = (n + BLOCK_DIM - 1) // BLOCK_DIM + cursor += B + n = B + return cursor + + +def _exclusive_scan_inplace_u32(scratch, off: int, n: int, identity_bits: int, op, dtype, partials_cursor: int): + """Exclusive-scan ``scratch[off : off + n]`` in place. Recursive. + + ``partials_cursor`` is the next free u32 slot to use for the partials of the recursive level; the driver bumps it + down each level. 
+ """ + if n <= BLOCK_DIM: + _scan_block_inplace_u32(scratch, off, n, identity_bits, op, dtype) + return + + B = (n + BLOCK_DIM - 1) // BLOCK_DIM + partials_off = partials_cursor + if partials_off + B > scratch_capacity_u32(): + raise RuntimeError( + f"device exclusive scan ran out of scratch at recursion level " + f"n={n}, B={B}, partials_off={partials_off}, capacity=" + f"{scratch_capacity_u32()}. Call _scratch.set_scratch_bytes(...) " + f"before any algorithm runs." + ) + + _reduce_pass( + scratch, + scratch, + off, + partials_off, + n, + B * BLOCK_DIM, + identity_bits, + op, + dtype, + True, + True, + ) + + _exclusive_scan_inplace_u32(scratch, partials_off, B, identity_bits, op, dtype, partials_off + B) + + _scan_pass3( + scratch, + off, + scratch, + partials_off, + scratch, + off, + n, + B * BLOCK_DIM, + identity_bits, + op, + dtype, + True, + True, + ) + + +def _exclusive_scan_inplace_u64(scratch, off: int, n: int, identity_bits: int, op, dtype, partials_cursor: int): + """8-byte sibling of :func:`_exclusive_scan_inplace_u32`. Stages through the ``Field(u64)`` scratch. + + Used internally by the 64-bit ``device_exclusive_scan_*`` path. Mirrors the 32-bit recursion shape: tile-reduce + into ``scratch[partials_off : partials_off + B]``, recurse on those partials, then downsweep back over the + original ``scratch[off : off + n]`` to apply per-tile prefixes. + """ + if n <= BLOCK_DIM: + _scan_block_inplace_u64(scratch, off, n, identity_bits, op, dtype) + return + + B = (n + BLOCK_DIM - 1) // BLOCK_DIM + partials_off = partials_cursor + if partials_off + B > scratch_capacity_u64(): + raise RuntimeError( + f"device exclusive scan ran out of u64 scratch at recursion level n={n}, B={B}, " + f"partials_off={partials_off}, capacity={scratch_capacity_u64()}. " + f"Call _scratch.set_scratch_bytes(...) before any algorithm runs." + ) + + _reduce_pass_u64( + scratch, + scratch, + off, + partials_off, + n, + B * BLOCK_DIM, + identity_bits, + op, + dtype, + True, + True, + ) + + _exclusive_scan_inplace_u64(scratch, partials_off, B, identity_bits, op, dtype, partials_off + B) + + _scan_pass3_u64( + scratch, + off, + scratch, + partials_off, + scratch, + off, + n, + B * BLOCK_DIM, + identity_bits, + op, + dtype, + True, + True, + ) + + +def _device_exclusive_scan(arr, *, out, op, identity_value): + """Internal driver shared by ``device_exclusive_scan_{add,min,max}``.""" + if not hasattr(arr, "shape") or len(arr.shape) != 1: + raise TypeError(f"device exclusive scan expects a 1-D input tensor; got shape {getattr(arr, 'shape', None)}") + if not hasattr(out, "shape") or out.shape != arr.shape: + raise TypeError(f"device exclusive scan expects out.shape == arr.shape; got arr={arr.shape}, out={out.shape}") + if arr.dtype != out.dtype: + raise TypeError(f"device exclusive scan dtype mismatch: arr={arr.dtype}, out={out.dtype}") + if arr is out: + # See design doc: in-place scan is rejected (no benefit when the caller already allocates `out` once and + # reuses it; protecting against same-buffer aliasing would just complicate the kernels). 
+ raise ValueError( + "device exclusive scan does not support in-place operation; " + "pass a distinct `out` buffer (the API is designed around " + "caller-supplied out, see qipc_device_algos_design.md)" + ) + + dtype = arr.dtype + if dtype not in _SUPPORTED_DTYPES: + raise NotImplementedError( + f"device exclusive scan dtype {dtype} not supported (need one of " + f"{[d for d in _SUPPORTED_DTYPES]}); see design doc dtype matrix" + ) + width = _dtype_width_bytes(dtype) + + N = arr.shape[0] + identity_bits = _identity_bits(identity_value, dtype) + + if N == 0: + return + if N == 1: + if width == 4: + _scan_trivial_n1(out, identity_bits, dtype) + else: + _scan_trivial_n1_u64(out, identity_bits, dtype) + return + + if N <= BLOCK_DIM: + if width == 4: + _scan_single_tile_input_to_out(arr, out, N, identity_bits, op, dtype) + else: + _scan_single_tile_input_to_out_u64(arr, out, N, identity_bits, op, dtype) + return + + if width == 4: + scratch = get_scratch_u32() + scratch_cap = scratch_capacity_u32() + else: + scratch = get_scratch_u64() + scratch_cap = scratch_capacity_u64() + B0 = (N + BLOCK_DIM - 1) // BLOCK_DIM + # Reserve scratch slots: scratch[0:B0] for the top-level partials. The recursive scan sub-allocates from + # scratch[B0:] for any deeper levels. Use ``_scan_total_scratch_slots`` to account for the *full* recursion + # up front, so we refuse the call before pass 1 instead of partway through pass 2. + needed = _scan_total_scratch_slots(B0, partials_cursor=B0) + if needed > scratch_cap: + raise RuntimeError( + f"device exclusive scan on N={N} (dtype={dtype}) needs >= {needed} {scratch.dtype} scratch slots " + f"(top-level partials B0={B0} plus deeper recursion levels), but only {scratch_cap} are configured. " + f"Call _scratch.set_scratch_bytes(...) before any algorithm runs." + ) + + reduce_pass_kernel = _reduce_pass if width == 4 else _reduce_pass_u64 + scan_inplace_driver = _exclusive_scan_inplace_u32 if width == 4 else _exclusive_scan_inplace_u64 + pass3_kernel = _scan_pass3 if width == 4 else _scan_pass3_u64 + + # Pass 1: tile-reduce arr -> scratch[0:B0] + reduce_pass_kernel( + arr, + scratch, + 0, + 0, + N, + B0 * BLOCK_DIM, + identity_bits, + op, + dtype, + False, + True, + ) + + # Pass 2: exclusive-scan scratch[0:B0] in place (recursive if B0 > BLOCK_DIM). + scan_inplace_driver(scratch, 0, B0, identity_bits, op, dtype, B0) + + # Pass 3: arr + scratch[0:B0] -> out + pass3_kernel( + arr, + 0, + scratch, + 0, + out, + 0, + N, + B0 * BLOCK_DIM, + identity_bits, + op, + dtype, + False, + False, + ) + + +@kernel +def _scan_single_tile_input_to_out( + src: template(), + dst: template(), + n_valid: i32, + identity_bits: u32, + op: template(), + dtype: template(), +): + """Fast path for ``N <= BLOCK_DIM`` (4-byte dtype): one block reads the input tile, exclusive-scans, writes + ``out``. 
No scratch needed.""" + loop_config(block_dim=BLOCK_DIM) + for i in range(BLOCK_DIM): + identity = bit_cast(identity_bits, dtype) + v = identity + if i < n_valid: + v = src[i] + prefix = _block.exclusive_scan(v, BLOCK_DIM, op, identity, dtype) + if i < n_valid: + dst[i] = prefix + + +@kernel +def _scan_single_tile_input_to_out_u64( + src: template(), + dst: template(), + n_valid: i32, + identity_bits: u64, + op: template(), + dtype: template(), +): + """8-byte sibling of :func:`_scan_single_tile_input_to_out`.""" + loop_config(block_dim=BLOCK_DIM) + for i in range(BLOCK_DIM): + identity = bit_cast(identity_bits, dtype) + v = identity + if i < n_valid: + v = src[i] + prefix = _block.exclusive_scan(v, BLOCK_DIM, op, identity, dtype) + if i < n_valid: + dst[i] = prefix + + +@kernel +def _scan_trivial_n1(dst: template(), identity_bits: u32, dtype: template()): + """N == 1 path (4-byte dtype): write the identity to out[0]. Exclusive scan of a single element is just the + identity.""" + for _ in range(1): + dst[0] = bit_cast(identity_bits, dtype) + + +@kernel +def _scan_trivial_n1_u64(dst: template(), identity_bits: u64, dtype: template()): + """8-byte sibling of :func:`_scan_trivial_n1`.""" + for _ in range(1): + dst[0] = bit_cast(identity_bits, dtype) + + +def device_exclusive_scan_add(arr, out): + """Compute ``out[i] = sum(arr[0:i])`` (exclusive prefix sum) on the device. + + Args: + arr: 1-D tensor of any supported scalar dtype - ``{i32, u32, f32, i64, u64, f64}``. Pass a ``qd.field``, + ``qd.ndarray``, or ``qd.Tensor`` wrapper around either. + out: 1-D tensor with the same dtype and shape as ``arr``. Must be a distinct buffer (no in-place scan). + + The implementation is the three-pass Blelloch-style scan built on ``block.exclusive_scan`` and the shared + scratch fields (``Field(u32)`` for 4-byte dtypes, ``Field(u64)`` for 8-byte). Recurses on the partials buffer + when ``N`` is large enough that the partials count exceeds ``BLOCK_DIM``. + + See the design doc at ``perso_hugh/doc/qipc/qipc_device_algos_design.md`` for the algorithmic background and + the ``bit_cast``-into-scratch scheme. + """ + _device_exclusive_scan(arr, out=out, op=_bin_add, identity_value=0) + + +def device_exclusive_scan_min(arr, out): + """Compute ``out[i] = min(arr[0:i])`` (exclusive prefix min) on the device. + + Args: + arr: see ``device_exclusive_scan_add`` (any of ``{i32, u32, f32, i64, u64, f64}``). + out: see ``device_exclusive_scan_add``. + + The monoid identity is derived from ``arr.dtype`` automatically (largest representable value: ``+inf`` for + floats, ``INT{32,64}_MAX`` for signed ints, ``UINT{32,64}_MAX`` for unsigned). Mirrors the + ``block.exclusive_min`` / ``subgroup.exclusive_min_tiled`` contract: the typed scan primitives do not take an + identity argument because (op, dtype) fixes it. + """ + _device_exclusive_scan(arr, out=out, op=_bin_min, identity_value=_min_identity(arr.dtype)) + + +def device_exclusive_scan_max(arr, out): + """Compute ``out[i] = max(arr[0:i])`` (exclusive prefix max) on the device. Mirror of + :func:`device_exclusive_scan_min` with ``max`` and the dtype's *negative* extremum (``-inf`` for floats, + ``INT{32,64}_MIN`` for signed ints, ``0`` for unsigned), again derived from ``arr.dtype`` automatically. 
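+
+    For example (illustrative): with ``arr = [3, 1, 4]`` as ``i32``, the result is ``out = [INT32_MIN, 3, 3]`` -
+    the identity for the first element, then the running maximum of everything strictly before index ``i``.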
+ """ + _device_exclusive_scan(arr, out=out, op=_bin_max, identity_value=_max_identity(arr.dtype)) + + +__all__ = [ + "device_exclusive_scan_add", + "device_exclusive_scan_max", + "device_exclusive_scan_min", +] diff --git a/python/quadrants/algorithms/_select.py b/python/quadrants/algorithms/_select.py new file mode 100644 index 0000000000..008d2fc33d --- /dev/null +++ b/python/quadrants/algorithms/_select.py @@ -0,0 +1,199 @@ +# type: ignore +"""Device-wide stream compaction (``select`` / ``compact``). + +``qd.algorithms.device_select(arr, flags, out, num_out)`` packs the elements of ``arr`` for which the corresponding +``flags`` entry is set into a dense prefix of ``out``, in stable input order, and writes the count of selected +elements to ``num_out[0]``. Each ``flags[i]`` must be exactly ``0`` or ``1`` (``1`` selects); the algorithm +prefix-sums ``flags`` directly as counts, so non-0/1 values produce wrong indices and counts (caller's responsibility, +no normalization pass). + +Algorithm (textbook scan-based compaction): + +1. Exclusive prefix-sum the ``flags`` (treated as 0 / 1) into the shared ``Field(u32)`` scratch, producing per-element + write indices. This reuses the same three-pass scan internals as ``device_exclusive_scan_add`` but targets a + scratch slice for the output instead of a caller-supplied ``out`` tensor. +2. A single fused "scatter" kernel reads ``arr[i]`` and ``flags[i]``, and if the flag is set, writes + ``out[indices[i]] = arr[i]``. +3. A 1-thread tail kernel sums ``indices[N-1] + flags[N-1]`` (= total count) and stores it in ``num_out[0]``. + +Scratch layout for the scan: ``scratch[0 : N]`` holds the per-element indices (i32 bit-cast to u32). +``scratch[N : N + B0]`` holds the level-0 partials, ``scratch[N + B0 : ...]`` deeper recursion levels (mirrors the +device scan). The scratch is *always* u32 regardless of the element dtype, because the scan operates on +flags-as-counts (i32) which always fit in u32; the element dtype only shows up at scatter time as +``dst[idx] = src[i]``, which lowers per-field for struct dtypes without any scratch reinterpretation. + +This is why ``device_select`` works on any element dtype Quadrants supports for field assignment - scalars +(``i32`` / ``u32`` / ``f32`` / ``i64`` / ``u64`` / ``f64``) and structs (libuipc ``Vector{2,3,4}i``, +``LinearBVHAABB``, etc.). + +Constraints (first land): ``N`` must fit comfortably within the configured scratch budget - the indices + partials +together must not exceed ``scratch_capacity_u32()``. For the default 5 MB budget that's +``N + ceil(N / 256) + ... <= ~1.3M``, which covers qipc's hot path (``N = 1M``) out of the box. Raise the budget via +``_scratch.set_scratch_bytes(...)`` before any algorithm runs to unlock larger inputs. +""" + +from quadrants._scratch import get_scratch_u32, scratch_capacity_u32 +from quadrants.lang.kernel_impl import kernel +from quadrants.lang.ops import bit_cast +from quadrants.lang.simt.reductions import _bin_add +from quadrants.types.annotations import template +from quadrants.types.primitive_types import i32 + +from ._reduce import BLOCK_DIM, _identity_bits, _reduce_pass +from ._scan import _exclusive_scan_inplace_u32, _scan_pass3 + + +@kernel +def _select_scatter( + src: template(), + flags: template(), + indices: template(), + indices_off: i32, + dst: template(), + n_valid: i32, +): + """Scatter pass: write ``dst[indices[i]] = src[i]`` for every ``i`` where ``flags[i] != 0``. ``indices`` is the + u32 scratch slice holding the exclusive scan of ``flags`` (i.e. 
the destination index of each surviving element); + we ``bit_cast`` it back to ``i32`` before indexing. + + No race: each surviving thread writes to a distinct ``dst`` slot (by construction of the exclusive scan over + 0 / 1 flags). Dropped threads do not write. + """ + for i in range(n_valid): + if flags[i] != 0: + dst_idx = bit_cast(indices[indices_off + i], i32) + dst[dst_idx] = src[i] + + +@kernel +def _select_count( + flags: template(), + indices: template(), + indices_off: i32, + n_valid: i32, + num_out: template(), +): + """1-thread tail kernel: ``num_out[0] = indices[N-1] + flags[N-1]``. + + Split into its own launch so the host driver doesn't have to insert a grid sync after the scatter; the kernel + boundary serializes against the preceding scan writes. + """ + for _ in range(1): + last_idx = bit_cast(indices[indices_off + n_valid - 1], i32) + last_flag = flags[n_valid - 1] + last_inc = i32(0) + if last_flag != 0: + last_inc = i32(1) + num_out[0] = last_idx + last_inc + + +def device_select(arr, flags, out, num_out): + """Stream-compact ``arr`` by ``flags``: copy ``arr[i]`` to a dense prefix of ``out`` for every ``i`` where + ``flags[i] == 1``, in stable input order. Write the count of selected elements to ``num_out[0]``. + + Args: + arr: 1-D tensor of any element dtype that Quadrants supports field-element assignment for: scalars + (``i32`` / ``u32`` / ``f32`` / ``i64`` / ``u64`` / ``f64``) and structs (``qd.Struct.field({...})`` or + ``qd.types.struct(...)`` - e.g. the libuipc ``Vector{2,3,4}i`` shapes). The scatter is + ``dst[idx] = src[i]``, which lowers per-field for struct dtypes, so no scratch reinterpretation is + needed for wider / composite element types. + flags: 1-D ``i32`` tensor, same shape as ``arr``. **Every entry must be exactly ``0`` or ``1``** (``1`` + selects). Non-0/1 values produce incorrect results - the algorithm prefix-sums ``flags`` directly as + counts, so a stray ``2`` would advance the destination cursor by 2 and break the dense-output / count + contract. Caller-built: populate with a separate kernel that applies your predicate, writing exactly + ``1`` for selected and ``0`` otherwise. + out: 1-D tensor with the same dtype as ``arr``. Must hold at least ``N`` elements (so a + worst-case-everyone-selected run fits); only the prefix ``out[0 : num_out[0]]`` is meaningful on return. + num_out: 1-element ``i32`` tensor receiving the selected count. + + Same async / no-implicit-sync contract as ``device_reduce_*`` and ``device_exclusive_scan_*``: ``num_out`` is a + tensor, not a Python scalar - call ``num_out.to_numpy()[0]`` explicitly to get the count host-side. + + See the design doc at ``perso_hugh/doc/qipc/qipc_device_algos_design.md`` for the scratch-into-indices layout and + the algorithm reference. 
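+
+    Example (illustrative sketch only; the ``qd.ndarray(dtype, shape=...)`` construction is an assumed spelling)::
+
+        n = 1_000_000
+        arr = qd.ndarray(qd.f32, shape=(n,))
+        flags = qd.ndarray(qd.i32, shape=(n,))   # fill with exactly 0 / 1 in a caller-side predicate kernel
+        out = qd.ndarray(qd.f32, shape=(n,))     # capacity >= n covers the all-selected case
+        num_out = qd.ndarray(qd.i32, shape=(1,))
+        qd.algorithms.device_select(arr, flags, out, num_out)
+        count = int(num_out.to_numpy()[0])       # explicit fetch; no implicit sync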
+ """ + if not hasattr(arr, "shape") or len(arr.shape) != 1: + raise TypeError(f"device_select expects a 1-D arr; got shape {getattr(arr, 'shape', None)}") + if not hasattr(flags, "shape") or flags.shape != arr.shape: + raise TypeError(f"device_select expects flags.shape == arr.shape; got arr={arr.shape}, flags={flags.shape}") + if flags.dtype != i32: + raise TypeError(f"device_select expects flags.dtype == qd.i32; got {flags.dtype}") + if not hasattr(out, "shape") or len(out.shape) != 1: + raise TypeError(f"device_select expects a 1-D out; got shape {getattr(out, 'shape', None)}") + if out.dtype != arr.dtype: + raise TypeError(f"device_select dtype mismatch: arr={arr.dtype}, out={out.dtype}") + if out.shape[0] < arr.shape[0]: + raise ValueError( + f"device_select out.shape[0]={out.shape[0]} < arr.shape[0]={arr.shape[0]}; " + "out must hold at least the input size to be safe in the all-selected case" + ) + if not hasattr(num_out, "shape") or num_out.shape != (1,): + raise TypeError(f"device_select expects num_out.shape == (1,); got {num_out.shape}") + if num_out.dtype != i32: + raise TypeError(f"device_select expects num_out.dtype == qd.i32; got {num_out.dtype}") + + N = arr.shape[0] + if N == 0: + return + + # Scratch layout: scratch[0:N] = indices, scratch[N : N + B0] = level-0 partials, then deeper levels above. + scratch = get_scratch_u32() + cap = scratch_capacity_u32() + B0 = (N + BLOCK_DIM - 1) // BLOCK_DIM + indices_off = 0 + partials_off = N + if partials_off + B0 > cap: + raise RuntimeError( + f"device_select on N={N} needs >= {partials_off + B0} u32 scratch slots, " + f"but only {cap} are configured. Call _scratch.set_scratch_bytes(...) " + f"before any algorithm runs." + ) + + identity_bits = _identity_bits(0, i32) + op = _bin_add + dtype = i32 + + # Three-pass scan of flags into scratch[0:N] (i32 indices, stored as u32 bit pattern). The general 3-pass path + # collapses gracefully when N <= BLOCK_DIM: pass 1 writes a single partial, pass 2 in-place-scans it, pass 3 + # applies the (trivial) prefix to the single-tile scan. + # Pass 1: per-block tile reduce of flags -> scratch[partials_off:] + _reduce_pass( + flags, + scratch, + 0, + partials_off, + N, + B0 * BLOCK_DIM, + identity_bits, + op, + dtype, + False, + True, + ) + # Pass 2: in-place scan of the partials (recursive if B0 > BLOCK_DIM). + _exclusive_scan_inplace_u32(scratch, partials_off, B0, identity_bits, op, dtype, partials_off + B0) + # Pass 3: flags + scanned partials -> scratch[0:N] (u32 indices) + _scan_pass3( + flags, + 0, + scratch, + partials_off, + scratch, + indices_off, + N, + B0 * BLOCK_DIM, + identity_bits, + op, + dtype, + False, + True, + ) + + # Step 2: scatter src[i] -> dst[indices[i]] for every selected i. + _select_scatter(arr, flags, scratch, indices_off, out, N) + + # Step 3: write the count. + _select_count(flags, scratch, indices_off, N, num_out) + + +__all__ = ["device_select"] diff --git a/quadrants/ir/frontend_ir.cpp b/quadrants/ir/frontend_ir.cpp index 51c896a493..032c0263f4 100644 --- a/quadrants/ir/frontend_ir.cpp +++ b/quadrants/ir/frontend_ir.cpp @@ -980,7 +980,7 @@ void AtomicOpExpression::type_check(const CompileConfig *config) { // Reject tensor (vector / matrix) destinations explicitly. The other atomic ops fan out to per-component // scalar AtomicOpStmts via scalarize / lower_matrix_ptr, but those passes use the 3-arg AtomicOpStmt // constructor and would silently drop `expected`, tripping QD_ASSERT(stmt->expected) in codegen. 
Until the - // scalarizers grow a 4-arg path that threads `expected_i` through, refuse tensor CAS at trace time. + // scalarizers grow a 4-arg path that threads `expected_i` through, refuse tensor CAS at compile time. if (dest_dtype->is()) { ErrorEmitter(QuadrantsTypeError(), this, fmt::format("'atomic_cas' on tensor (vector / matrix) destinations is not supported; got " diff --git a/quadrants/runtime/llvm/llvm_context.cpp b/quadrants/runtime/llvm/llvm_context.cpp index e8b1f89a41..76ae4b5a0e 100644 --- a/quadrants/runtime/llvm/llvm_context.cpp +++ b/quadrants/runtime/llvm/llvm_context.cpp @@ -524,7 +524,32 @@ std::unique_ptr QuadrantsLLVMContext::module_from_file(const std:: patch_intrinsic("thread_idx", llvm::Intrinsic::amdgcn_workitem_id_x); patch_intrinsic("block_thread_idx", llvm::Intrinsic::amdgcn_workitem_id_x); patch_intrinsic("block_idx", llvm::Intrinsic::amdgcn_workgroup_id_x); - patch_intrinsic("block_barrier", llvm::Intrinsic::amdgcn_s_barrier, false); + + // Synthesize ``block_barrier`` as ``fence release "workgroup" -> s_barrier -> fence acquire "workgroup"`` (the + // same sequence HIP's ``__syncthreads()`` emits via ``__work_group_barrier`` in + // ``hip/amd_detail/amd_device_functions.h``). The bare ``llvm.amdgcn.s.barrier`` intrinsic only emits the + // ``s_barrier`` instruction; without the surrounding fences, LDS writes issued before the barrier are not + // guaranteed to be visible to other lanes after the barrier on RDNA3 because the AMDGCN backend has no + // memory-model edge tying the barrier to prior stores. Symptom that motivated this fix: at ``BLOCK_DIM=256`` / + // ``NUM_SUBGROUPS=4``, ``block.reduce`` reads stale garbage from its inter-subgroup ``SharedArray`` at + // ``NBLOCKS>=~200`` (intermittent below, near-deterministic above) -- exactly the "publish to LDS, barrier, + // read LDS" pattern that needs the release/acquire pair to be sound. + { + auto func = module->getFunction("block_barrier"); + if (func) { + func->deleteBody(); + auto bb = llvm::BasicBlock::Create(*ctx, "entry", func); + IRBuilder<> builder(*ctx); + builder.SetInsertPoint(bb); + llvm::SyncScope::ID workgroup = ctx->getOrInsertSyncScopeID("workgroup"); + builder.CreateFence(llvm::AtomicOrdering::Release, workgroup); + builder.CreateIntrinsic(llvm::Intrinsic::amdgcn_s_barrier, llvm::ArrayRef{}, + llvm::ArrayRef{}); + builder.CreateFence(llvm::AtomicOrdering::Acquire, workgroup); + builder.CreateRetVoid(); + QuadrantsLLVMContext::mark_inline(func); + } + } patch_intrinsic("amdgpu_clock_i64", llvm::Intrinsic::amdgcn_s_memtime); patch_intrinsic("amdgpu_ds_bpermute", llvm::Intrinsic::amdgcn_ds_bpermute); // ``llvm.amdgcn.permlane64`` exchanges a 32-bit value between lanes ``i`` and ``i ^ 32`` in a single instruction. @@ -727,7 +752,20 @@ void QuadrantsLLVMContext::link_module_with_cuda_libdevice(std::unique_ptrsetTargetTriple(llvm::Triple("nvptx64-nvidia-cuda")); strip_nvvmir_version(libdevice_module.get()); - module->setDataLayout(libdevice_module->getDataLayout()); + + // `slim_libdevice.10.bc` ships without an explicit `target datalayout` line (only the `nvptx64-nvidia-gpulibs` + // triple), so its `getDataLayout()` returns the empty LLVM-default DL where `i64` ABI alignment is 4 bytes. 
+ // Previously we copied that empty DL straight into the kernel module, which made every CreateStore / CreateLoad + // of i64 emit `align 4` -> the NVPTX backend then split each `align 4` i64 store into two `st.b32` halves, and + // ptxas in turn mis-combined those halves into a single `ST.E.64` that dropped the low 32 bits of values produced + // by f64 / i64 arithmetic (single 64-bit virtual reg holding the full bit pattern). End result: silent precision + // loss for `bit_cast(scan_result_f64, u64)` and friends. + // + // Pin the canonical NVPTX64 DL (matches LLVM's `NVPTXTargetMachine::computeDataLayout(is64Bit=true, + // UseShortPointers=false)`) so CreateStore / CreateLoad see `i64:64` and emit single aligned `st.b64` / `ld.b64`. + static const char *kNVPTX64DataLayout = "e-p6:32:32-i64:64-i128:128-v16:16-v32:32-n16:32:64"; + module->setDataLayout(kNVPTX64DataLayout); + libdevice_module->setDataLayout(kNVPTX64DataLayout); bool failed = llvm::Linker::linkModules(*module, std::move(libdevice_module)); if (failed) { diff --git a/tests/python/test_algorithms.py b/tests/python/test_algorithms.py new file mode 100644 index 0000000000..e4b4ac9960 --- /dev/null +++ b/tests/python/test_algorithms.py @@ -0,0 +1,1620 @@ +"""Tests for ``qd.algorithms.*`` device-wide primitives. + +Covers: + +- ``quadrants._scratch`` - the shared ``Field(u32)`` scratch buffer that backs every device algorithm. +- ``qd.algorithms.device_reduce_{add,min,max}`` - two-or-more-pass tree reduction with shared scratch + ``bit_cast``. +- ``qd.algorithms.device_exclusive_scan_{add,min,max}`` - three-pass scan. +- ``qd.algorithms.device_select`` - scan-based stream compaction. +- ``qd.algorithms.device_radix_sort`` - LSB radix sort built on ``block.radix_rank_match_atomic_or``. +- ``qd.algorithms.device_reduce_by_key_add`` - scan + scatter + atomic_add reduce-by-key. + +Each test runs across the full ``arch=qd.gpu`` parametrization so the kernels are exercised on CUDA, AMDGPU, Vulkan, +and Metal (where the host supports each). +""" + +import math +import platform +import struct + +import numpy as np +import pytest + +import quadrants as qd +from quadrants import _scratch +from quadrants.lang.util import to_numpy_type + +from tests import test_utils + +# --------------------------------------------------------------------------- +# Module-level constants: dtype sets, size sweeps, identity tables. +# --------------------------------------------------------------------------- + +# Supported scalar dtypes per algorithm. Reduce / scan / select / RBK share the same 6-dtype set; radix sort uses a +# slightly different ordering (u32 first because that's the natural histogram dtype). +_REDUCE_DTYPES = [qd.i32, qd.u32, qd.f32, qd.i64, qd.u64, qd.f64] +_SCAN_DTYPES = [qd.i32, qd.u32, qd.f32, qd.i64, qd.u64, qd.f64] +_SELECT_DTYPES = [qd.i32, qd.u32, qd.f32, qd.i64, qd.u64, qd.f64] +_RADIX_KEY_DTYPES = [qd.u32, qd.i32, qd.f32, qd.u64, qd.i64, qd.f64] + +# Numpy-dtype lookup. Used by every test that allocates a host buffer for ``from_numpy``. +_DTYPE_TO_NP = { + qd.i32: np.int32, + qd.u32: np.uint32, + qd.f32: np.float32, + qd.i64: np.int64, + qd.u64: np.uint64, + qd.f64: np.float64, +} + +# Identities for device_reduce_min / max (passed by tests that initialize an "all-identity" input). Floats use the +# +/- inf extremum; ints use the dtype's positive / negative range extreme. 
+_MIN_IDENTITY = { + qd.i32: 2**31 - 1, + qd.u32: 2**32 - 1, + qd.f32: float("inf"), + qd.i64: 2**63 - 1, + qd.u64: 2**64 - 1, + qd.f64: float("inf"), +} +_MAX_IDENTITY = { + qd.i32: -(2**31), + qd.u32: 0, + qd.f32: float("-inf"), + qd.i64: -(2**63), + qd.u64: 0, + qd.f64: float("-inf"), +} + +# Size sweeps. Chosen to cover (across algorithms): single-block path, on-block-boundary, off-by-one tile, two-block, +# many-block recursion. Reduce / scan / select / RBK share the structure with minor variations (radix and +# select-struct trim a few sizes to keep test runtime bounded). The 1M size only appears in scan / scratch / +# qipc-hot-path tests; the others top out at 200K within the default 5 MB scratch budget. +_REDUCE_SIZES = [1, 7, 255, 256, 257, 1023, 1024, 1025, 65536, 65537, 200_000] +_SCAN_SIZES = [1, 7, 255, 256, 257, 1023, 1024, 1025, 65536, 65537, 200_000, 1_000_000] +_SELECT_SIZES = [1, 7, 255, 256, 257, 1023, 1024, 1025, 65536, 65537, 200_000] +_SELECT_STRUCT_SIZES = [1, 7, 256, 1024, 65537] +_SELECT_STRUCT_NFIELDS = [2, 3, 4] # mirrors libuipc Vector2i / Vector3i / Vector4i +_RADIX_SORT_SIZES = [1, 7, 256, 257, 1023, 1024, 1025, 65536, 200_000] +_RBK_SIZES = [1, 2, 3, 7, 255, 256, 257, 1023, 1024, 1025, 65536, 65537, 200_000] + +# 64 KB; ~16K u32 slots. Used by the scratch-budget rejection tests so each one trips the budget guard with a tiny N +# (cheap to allocate, runtime-independent of the DEFAULT_SCRATCH_BYTES knob). +_TINY_SCRATCH_BYTES = 64 << 10 + + +# --------------------------------------------------------------------------- +# Backend dtype-support matrix + skip helpers. +# --------------------------------------------------------------------------- +# +# Anything outside the ``supported`` column is unsupported at the lang tier (the spirv / metal IR builders bail with +# "Type X not supported"), so every device-tier test must skip those (arch, platform, dtype) triples. +# +# arch | platform | supported dtypes +# ----------------------|----------|------------------------------------ +# qd.cuda | any | i32, u32, f32, i64, u64, f64 +# qd.amdgpu | any | i32, u32, f32, i64, u64, f64 +# qd.vulkan | Linux | i32, u32, f32, i64, u64, f64 +# qd.vulkan (MoltenVK) | Darwin | i32, u32, f32 (no i64 / u64 / f64) +# qd.metal | any | i32, u32, f32 (no i64 / u64 / f64) +# +# We encode the matrix as the *unsupported* set per backend, since that's what the skip predicates need. + + +def _is_apple_gpu_backend(): + """Metal or MoltenVK (Vulkan-on-Darwin). These two share the same dtype-support gaps in buffer-backed I/O.""" + arch = qd.lang.impl.current_cfg().arch + return arch == qd.metal or (arch == qd.vulkan and platform.system() == "Darwin") + + +def _unsupported_dtype_reason(dtype): + """Return a human-readable reason if ``dtype`` is unsupported on the current backend, else None. + + Single source of truth for "should we skip this dtype?". ``_skip_if_dtype_unsupported`` wraps it for + parametrized-test skipping; tests that iterate dtypes inside the body check the return value directly to + ``continue`` past unsupported dtypes. + """ + if _is_apple_gpu_backend(): + if dtype in (qd.i64, qd.u64): + return f"64-bit integer type {dtype} not supported on the current backend" + if dtype == qd.f64: + return "f64 not supported on the current backend" + return None + + +def _skip_if_dtype_unsupported(dtype): + """Skip the calling test if ``dtype`` is unsupported on the current backend. 
Mirrors the gate used in + ``test_simt.py`` so device-tier dtype coverage matches block / subgroup-tier coverage.""" + reason = _unsupported_dtype_reason(dtype) + if reason is not None: + pytest.skip(reason) + + +def _skip_if_radix_sort_large_n_on_apple_gpu(N): + """Skip large-N ``device_radix_sort`` calls on Metal / MoltenVK. + + *Why this skip exists.* On Apple GPUs (Metal directly, and MoltenVK / Vulkan-on-Darwin), ``device_radix_sort`` + produces incorrect results once N crosses ``BLOCK_DIM**2 = 65_536``: the ``test_device_radix_sort_keys_only`` + parametrization at N=200_000 reports 50-90% of elements in the wrong position on those backends. CUDA, AMDGPU, + and Linux Vulkan all pass at every tested size on the same code, so the regression is in the Apple-GPU codegen / + runtime path of one of the building blocks (most likely the histogram pass's threadgroup-shared atomic_or + + barrier sequence at high block counts), not in the radix-sort algorithm itself. Smaller N (N <= 65_536, the + single- and few-block paths) pass cleanly on Apple GPUs. + + Tracked as a follow-up; not blocking the device-algos first land. Tests that *transitively* hit this path + (radix-sort-then-RBK at N=1M, ``test_scratch_round_trip_across_qd_reset`` at ~2.6M) also need this guard. + """ + if N >= 200_000 and _is_apple_gpu_backend(): + pytest.skip("device_radix_sort produces incorrect results on Metal / MoltenVK at N >= 200_000") + + +# --------------------------------------------------------------------------- +# Tolerance contract for f32 / f64 reduce / scan / RBK assertions. +# --------------------------------------------------------------------------- +# +# Block-tree reduce: error scales as ``O(log N * eps_f32)``. At N=1M, log2(N)*eps_f32 ~ 2e-6, well under any +# tolerance below; we use ``_F32_REDUCE_*`` for the parametrized N<=200K dtype sweep and ``_F32_LARGE_N_*`` for the +# qipc N=1M hot path (slightly looser to absorb MoltenVK fast-math reordering headroom on the big sums). +# +# Block-tree scan: error scales as ``O(sqrt(N) * eps_f32)`` (Higham 2002, "Accuracy and Stability of Numerical +# Algorithms", §4.2 on pairwise / tree summation). The ``2e-5`` constant in ``_f32_scan_tol`` is a 2x headroom over +# the strict-IEEE bound (``eps_f32 ~ 1.2e-7``), there to absorb MoltenVK's more-aggressive fast-math reordering of +# f32 partial sums without papering over actual algorithmic regressions on CUDA / Linux Vulkan / AMDGPU. Two asserts +# inside the function pin the contract in plain language: at <= 100 elements f32 stays under 0.1% rel; at <= 100K +# under 1% rel; the qipc hot-path N=1M lands at ~2% rel. +# +# Reduce-by-key (f32 values): adds an atomic_add reordering layer on top of scan-style scatter; uses the +# ``_F32_LARGE_N_*`` floor so MoltenVK's reordering stays comfortably bounded. +# +# f64: strict-IEEE ``eps_f64 ~ 2.2e-16`` dominates everything; reordering is irrelevant at f64 precision for any +# tested N. + +_F32_REDUCE_RTOL = 1e-4 # tree reduce: log(N)*eps_f32 ~ 2e-6 at N=1M; 1e-4 is plenty for the N<=200K dtype sweep +_F32_REDUCE_ATOL = 1e-4 +_F32_LARGE_N_RTOL = 1e-3 # qipc hot path: N=1M reduce / scan-head / RBK; covers MoltenVK reorder headroom +_F32_LARGE_N_ATOL = 1e-3 +_F64_RTOL = 1e-12 # eps_f64 dominates; tight bound across every tested N +_F64_ATOL = 1e-9 + + +def _f32_scan_tol(N): + """Return ``(rtol, atol)`` for the f32 scan ``assert_allclose``. Scales rtol with sqrt(N); atol is constant. 
+ + See the module-level comment block above for the derivation of the constant and the contract asserts below.""" + rtol = 2e-5 * math.sqrt(N) + assert rtol <= 1e-3 or N > 100, f"f32 scan rtol={rtol:g} too loose for small N={N}" + assert rtol <= 1e-2 or N > 100_000, f"f32 scan rtol={rtol:g} too loose for medium N={N}" + return rtol, 1e-3 + + +@test_utils.test(arch=qd.gpu) +def test_scratch_allocates_with_expected_capacity(): + """First call returns a Field(u32) sized to the configured byte budget.""" + s = _scratch.get_scratch_u32() + assert s.dtype == qd.u32 + assert s.shape == (_scratch.DEFAULT_SCRATCH_BYTES // 4,) + assert _scratch.scratch_capacity_u32() == _scratch.DEFAULT_SCRATCH_BYTES // 4 + + +@test_utils.test(arch=qd.gpu) +def test_scratch_is_shared_across_calls(): + """The same Field instance is returned on repeated calls within a runtime.""" + s1 = _scratch.get_scratch_u32() + s2 = _scratch.get_scratch_u32() + assert s1 is s2 + + +@test_utils.test(arch=qd.gpu) +def test_scratch_round_trips_bit_cast_f32(): + """Smoke: write f32 values into the u32 scratch via qd.bit_cast and read them back. Verifies the bit_cast pattern + used by every algorithm.""" + s = _scratch.get_scratch_u32() + N = 64 + + @qd.kernel + def write(): + for i in range(N): + v = qd.f32(i) * qd.f32(0.5) - qd.f32(7.25) + s[i] = qd.bit_cast(v, qd.u32) + + out = qd.field(qd.f32, shape=N) + + @qd.kernel + def read(): + for i in range(N): + out[i] = qd.bit_cast(s[i], qd.f32) + + write() + read() + for i in range(N): + expected = i * 0.5 - 7.25 + assert out[i] == expected, f"slot {i}: got {out[i]}, expected {expected}" + + +@test_utils.test(arch=qd.gpu) +def test_scratch_u64_allocates_with_expected_capacity(): + """First call to ``get_scratch_u64`` returns a Field(u64) sized to the same byte budget as the u32 scratch.""" + s = _scratch.get_scratch_u64() + assert s.dtype == qd.u64 + assert s.shape == (_scratch.DEFAULT_SCRATCH_BYTES // 8,) + assert _scratch.scratch_capacity_u64() == _scratch.DEFAULT_SCRATCH_BYTES // 8 + + +@test_utils.test(arch=qd.gpu) +def test_scratch_u64_is_shared_across_calls(): + s1 = _scratch.get_scratch_u64() + s2 = _scratch.get_scratch_u64() + assert s1 is s2 + + +@test_utils.test(arch=qd.gpu) +def test_scratch_round_trips_bit_cast_f64(): + """Smoke: feed exact-known f64 bit patterns into the kernel, bit_cast through the u64 scratch, read back. Mirrors + ``test_scratch_round_trips_bit_cast_f32`` for the 8-byte-dtype path used by 64-bit ``device_reduce_*``. + + We push the host-computed bit pattern in via a u64 source field rather than arithmetic on f64 literals to dodge + kernel-side fp-contract / FMA-reassociation that can offset the result by 1 ulp from the host-side value. 
+ """ + _skip_if_dtype_unsupported(qd.f64) + N = 64 + s = _scratch.get_scratch_u64() + src_bits = qd.field(qd.u64, shape=N) + out = qd.field(qd.f64, shape=N) + + expected = [i * 0.5 - 7.25 + 1.0e-100 * i for i in range(N)] + bits_host = np.array([struct.unpack(" out is a copy of input, num_out = N.""" + N = 1024 + inp = qd.field(qd.i32, shape=N) + flags = qd.field(qd.i32, shape=N) + out = qd.field(qd.i32, shape=N) + num_out = qd.field(qd.i32, shape=1) + + rng = np.random.default_rng(seed=42) + host = rng.integers(-100, 100, size=N, dtype=np.int32) + _fill_field(inp, host) + _fill_field(flags, np.ones(N, dtype=np.int32)) + + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + assert int(num_out.to_numpy()[0]) == N + np.testing.assert_array_equal(out.to_numpy(), host) + + +@test_utils.test(arch=qd.gpu) +def test_device_select_none_selected(): + """flags = all 0 -> nothing written, num_out = 0.""" + N = 1024 + inp = qd.field(qd.i32, shape=N) + flags = qd.field(qd.i32, shape=N) + out = qd.field(qd.i32, shape=N) + num_out = qd.field(qd.i32, shape=1) + + _fill_field(inp, np.arange(N, dtype=np.int32)) + _fill_field(flags, np.zeros(N, dtype=np.int32)) + + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + assert int(num_out.to_numpy()[0]) == 0 + + +@test_utils.test(arch=qd.gpu) +def test_device_select_zero_one_flag_contract(): + """Pin the 0/1 flag contract documented in ``algorithms.md`` and the ``device_select`` docstring. + + ``device_select`` prefix-sums ``flags`` *directly* as counts (no implicit normalization), so the contract is + that every entry is exactly ``0`` or ``1`` and ``1`` selects. This regression test pins the contract by + construction: an interleaved ``[1, 0, 1, 0, ...]`` pattern of length ``N`` selects exactly the even indices + (``N/2`` elements, in input order). If a future change accidentally re-introduces an implicit normalization + or breaks the prefix-sum-as-count semantics, this test is the canary. 
+ """ + N = 1024 + flags_host = np.zeros(N, dtype=np.int32) + flags_host[::2] = 1 # exactly 0 or 1, interleaved -> selects the even indices + inp_host = np.arange(N, dtype=np.int32) + + inp = qd.field(qd.i32, shape=N) + flags = qd.field(qd.i32, shape=N) + out = qd.field(qd.i32, shape=N) + num_out = qd.field(qd.i32, shape=1) + _fill_field(inp, inp_host) + _fill_field(flags, flags_host) + + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + got_n = int(num_out.to_numpy()[0]) + assert got_n == N // 2, f"interleaved 0/1 flags should select N/2 = {N // 2} entries, got {got_n}" + expected = inp_host[::2] + np.testing.assert_array_equal(out.to_numpy()[:got_n], expected) + + +@test_utils.test(arch=qd.gpu) +def test_device_select_rejects_shape_mismatch(): + inp = qd.field(qd.i32, shape=4) + flags = qd.field(qd.i32, shape=5) + out = qd.field(qd.i32, shape=4) + num_out = qd.field(qd.i32, shape=1) + with pytest.raises(TypeError): + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + + +@test_utils.test(arch=qd.gpu) +def test_device_select_rejects_flags_wrong_dtype(): + inp = qd.field(qd.i32, shape=4) + flags = qd.field(qd.f32, shape=4) + out = qd.field(qd.i32, shape=4) + num_out = qd.field(qd.i32, shape=1) + with pytest.raises(TypeError): + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + + +@test_utils.test(arch=qd.gpu) +def test_device_select_rejects_dtype_mismatch(): + inp = qd.field(qd.i32, shape=4) + flags = qd.field(qd.i32, shape=4) + out = qd.field(qd.f32, shape=4) + num_out = qd.field(qd.i32, shape=1) + with pytest.raises(TypeError): + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + + +@test_utils.test(arch=qd.gpu) +def test_device_select_rejects_short_out(): + """out must hold at least N elements (worst-case all-selected).""" + inp = qd.field(qd.i32, shape=8) + flags = qd.field(qd.i32, shape=8) + out = qd.field(qd.i32, shape=4) # < input size + num_out = qd.field(qd.i32, shape=1) + with pytest.raises(ValueError): + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + + +# --------------------------------------------------------------------------- +# Device radix sort +# --------------------------------------------------------------------------- + + +def _gen_keys(rng, dtype, N): + """Generate sortable test inputs for every supported key dtype. The float paths sprinkle a few signed-zero / + inf / denormal specials at the front of the array to exercise the sort-twiddle pattern.""" + if dtype == qd.u32: + return rng.integers(0, 2**32, size=N, dtype=np.uint32) + if dtype == qd.i32: + return rng.integers(-(2**31), 2**31 - 1, size=N, dtype=np.int32) + if dtype == qd.f32: + arr = rng.standard_normal(N).astype(np.float32) * 1e3 + if N >= 6: + arr[0] = -0.0 + arr[1] = 0.0 + arr[2] = np.float32(np.inf) + arr[3] = np.float32(-np.inf) + arr[4] = np.float32(1e-30) + arr[5] = np.float32(-1e-30) + return arr + if dtype == qd.u64: + # Span the high half of the u64 range too so all 8 byte-passes see non-zero histograms. 
+ return rng.integers(0, 2**63, size=N, dtype=np.uint64).astype(np.uint64) * np.uint64(2) + if dtype == qd.i64: + return rng.integers(-(2**62), 2**62, size=N, dtype=np.int64) + if dtype == qd.f64: + arr = rng.standard_normal(N).astype(np.float64) * 1e6 + if N >= 6: + arr[0] = -0.0 + arr[1] = 0.0 + arr[2] = np.float64(np.inf) + arr[3] = np.float64(-np.inf) + arr[4] = np.float64(1e-300) + arr[5] = np.float64(-1e-300) + return arr + raise ValueError(dtype) + + +@pytest.mark.parametrize("N", _RADIX_SORT_SIZES) +@pytest.mark.parametrize("dtype", _RADIX_KEY_DTYPES) +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_keys_only(dtype, N): + """device_radix_sort matches numpy.sort for every supported key dtype ({u32, i32, f32, u64, i64, f64}).""" + _skip_if_dtype_unsupported(dtype) + _skip_if_radix_sort_large_n_on_apple_gpu(N) + rng = np.random.default_rng(seed=1234) + host = _gen_keys(rng, dtype, N) + + keys = qd.field(dtype, shape=N) + tmp = qd.field(dtype, shape=N) + _fill_field(keys, host) + + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + got = keys.to_numpy() + want = np.sort(host, kind="stable") + np.testing.assert_array_equal(got, want, err_msg=f"{dtype} radix_sort(N={N})") + + +@pytest.mark.parametrize("N", _RADIX_SORT_SIZES) +@pytest.mark.parametrize("dtype", _RADIX_KEY_DTYPES) +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_key_value(dtype, N): + """Key-value sort: values permute in lock-step with keys; sort is stable. Exercises the libuipc-shaped u64-key + + i32-value path (``MatrixConverter::ij_hash`` sorted with ``sort_index``) among the parametrized cases.""" + _skip_if_dtype_unsupported(dtype) + _skip_if_radix_sort_large_n_on_apple_gpu(N) + rng = np.random.default_rng(seed=1234) + host = _gen_keys(rng, dtype, N) + + keys = qd.field(dtype, shape=N) + tmp_keys = qd.field(dtype, shape=N) + values = qd.field(qd.i32, shape=N) + tmp_values = qd.field(qd.i32, shape=N) + _fill_field(keys, host) + _fill_field(values, np.arange(N, dtype=np.int32)) + + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp_keys, values=values, tmp_values=tmp_values) + + got_keys = keys.to_numpy() + got_values = values.to_numpy() + # Stable argsort gives the values permutation we expect. 
+ want_idx = np.argsort(host, kind="stable") + want_keys = host[want_idx] + np.testing.assert_array_equal(got_keys, want_keys, err_msg=f"{dtype} keys(N={N})") + np.testing.assert_array_equal(got_values, want_idx.astype(np.int32), err_msg=f"{dtype} values(N={N})") + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_already_sorted(): + """No-op-ish input: already-sorted keys still come back sorted.""" + N = 5000 + keys = qd.field(qd.u32, shape=N) + tmp = qd.field(qd.u32, shape=N) + host = np.arange(N, dtype=np.uint32) * 7 + _fill_field(keys, host) + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + np.testing.assert_array_equal(keys.to_numpy(), host) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_reverse_sorted(): + """Worst-case-for-comparison-sort input is just normal work for radix.""" + N = 5000 + keys = qd.field(qd.i32, shape=N) + tmp = qd.field(qd.i32, shape=N) + host = (np.arange(N, dtype=np.int32) * -7).astype(np.int32) # decreasing + _fill_field(keys, host) + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + np.testing.assert_array_equal(keys.to_numpy(), np.sort(host)) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_all_same(): + """Many duplicates: radix rank still groups + scatters them correctly.""" + N = 5000 + keys = qd.field(qd.i32, shape=N) + tmp = qd.field(qd.i32, shape=N) + host = np.full(N, 42, dtype=np.int32) + _fill_field(keys, host) + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + np.testing.assert_array_equal(keys.to_numpy(), host) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_n1(): + """N=1 is the trivial early-return path.""" + keys = qd.field(qd.i32, shape=1) + tmp = qd.field(qd.i32, shape=1) + _fill_field(keys, np.asarray([42], dtype=np.int32)) + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + assert int(keys.to_numpy()[0]) == 42 + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_rejects_dtype_mismatch(): + keys = qd.field(qd.i32, shape=8) + tmp = qd.field(qd.u32, shape=8) + with pytest.raises(TypeError): + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_rejects_shape_mismatch(): + keys = qd.field(qd.i32, shape=8) + tmp = qd.field(qd.i32, shape=4) + with pytest.raises(TypeError): + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_rejects_aliasing(): + keys = qd.field(qd.i32, shape=8) + with pytest.raises(ValueError): + qd.algorithms.device_radix_sort(keys, tmp_keys=keys) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_rejects_unsupported_dtype(): + """Supported set is {u32, i32, f32, u64, i64, f64}; narrower / wider dtypes raise NotImplementedError.""" + keys = qd.field(qd.i16, shape=8) + tmp = qd.field(qd.i16, shape=8) + with pytest.raises(NotImplementedError): + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_rejects_missing_tmp_values(): + """values requires tmp_values.""" + keys = qd.field(qd.i32, shape=8) + tmp_keys = qd.field(qd.i32, shape=8) + values = qd.field(qd.i32, shape=8) + with pytest.raises(ValueError): + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp_keys, values=values) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_rejects_odd_passes(): + """end_bit must yield an even number of passes so the result lands in keys.""" + keys = qd.field(qd.i32, shape=8) + tmp = qd.field(qd.i32, shape=8) + with pytest.raises(ValueError): + 
qd.algorithms.device_radix_sort(keys, tmp_keys=tmp, end_bit=8) # 1 pass - odd + + +# --------------------------------------------------------------------------- +# Device reduce-by-key (add) +# --------------------------------------------------------------------------- + + +def _ref_rbk_add(keys, values): + """Reference reduce-by-key: collapse consecutive runs of equal keys, returning ``(unique_keys, sums)``.""" + if len(keys) == 0: + return np.array([], dtype=keys.dtype), np.array([], dtype=values.dtype) + uniq_keys = [keys[0]] + sums = [values[0]] + for i in range(1, len(keys)): + if keys[i] != keys[i - 1]: + uniq_keys.append(keys[i]) + sums.append(values[i]) + else: + sums[-1] = sums[-1] + values[i] + return np.asarray(uniq_keys, dtype=keys.dtype), np.asarray(sums, dtype=values.dtype) + + +def _gen_run_keys(rng, dtype, N): + """Build a key vector of size N with a realistic run-length distribution. + + Runs are drawn from a small alphabet of 5-15 distinct values and repeated 1-8 times, then concatenated and + truncated to N. This guarantees both multi-element runs (so the scatter's atomic_add path is exercised) and + single-element runs (so the position math is exercised at boundary). + """ + np_t = to_numpy_type(dtype) + if dtype == qd.f32: + alphabet = rng.standard_normal(15).astype(np_t) + elif dtype == qd.u32: + alphabet = rng.integers(0, 100, size=15, dtype=np_t) + else: + alphabet = rng.integers(-50, 50, size=15, dtype=np_t) + run_keys = rng.choice(alphabet, size=N // 3 + 2) + run_lengths = rng.integers(1, 8, size=len(run_keys)) + keys = np.repeat(run_keys, run_lengths) + if len(keys) < N: + # pad with the last key to fill N elements (extends the final run) + keys = np.concatenate([keys, np.full(N - len(keys), keys[-1], dtype=np_t)]) + return keys[:N].astype(np_t) + + +@pytest.mark.parametrize("N", _RBK_SIZES) +@pytest.mark.parametrize("key_dtype", [qd.i32, qd.u32, qd.f32]) +@pytest.mark.parametrize("val_dtype", [qd.i32, qd.u32, qd.f32]) +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add(key_dtype, val_dtype, N): + """Cross-product of key dtype × value dtype × size, against a CPU oracle. + + Values are bounded so ``f32`` accumulation error stays controlled - the tolerance is ``_F32_LARGE_N_*`` for f32 + (atomic_add reorder layered on the scan-style scatter) and bit-exact for integer types. 
+ """ + rng = np.random.default_rng(seed=1234) + keys_host = _gen_run_keys(rng, key_dtype, N) + val_np = to_numpy_type(val_dtype) + if val_dtype == qd.f32: + values_host = rng.uniform(-1.0, 1.0, size=N).astype(val_np) + elif val_dtype == qd.u32: + values_host = rng.integers(0, 100, size=N, dtype=val_np) + else: + values_host = rng.integers(-100, 100, size=N, dtype=val_np) + + keys_in = qd.field(key_dtype, shape=N) + values_in = qd.field(val_dtype, shape=N) + keys_out = qd.field(key_dtype, shape=N) + values_out = qd.field(val_dtype, shape=N) + num_runs = qd.field(qd.i32, shape=1) + _fill_field(keys_in, keys_host) + _fill_field(values_in, values_host) + + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + nr = int(num_runs.to_numpy()[0]) + want_keys, want_vals = _ref_rbk_add(keys_host, values_host) + + assert nr == len(want_keys), f"{key_dtype}/{val_dtype} N={N}: num_runs {nr} vs {len(want_keys)}" + got_keys = keys_out.to_numpy()[:nr] + got_vals = values_out.to_numpy()[:nr] + np.testing.assert_array_equal(got_keys, want_keys, err_msg=f"{key_dtype}/{val_dtype} N={N}: keys") + if val_dtype == qd.f32: + np.testing.assert_allclose( + got_vals, + want_vals, + rtol=_F32_LARGE_N_RTOL, + atol=_F32_LARGE_N_ATOL, + err_msg=f"{key_dtype}/{val_dtype} N={N}: values", + ) + else: + np.testing.assert_array_equal(got_vals, want_vals, err_msg=f"{key_dtype}/{val_dtype} N={N}: values") + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_all_same(): + """All keys equal -> single run, values_out[0] = sum of all values.""" + N = 1024 + keys_in = qd.field(qd.i32, shape=N) + values_in = qd.field(qd.i32, shape=N) + keys_out = qd.field(qd.i32, shape=N) + values_out = qd.field(qd.i32, shape=N) + num_runs = qd.field(qd.i32, shape=1) + _fill_field(keys_in, np.full(N, 42, dtype=np.int32)) + rng = np.random.default_rng(seed=42) + vals = rng.integers(-100, 100, size=N, dtype=np.int32) + _fill_field(values_in, vals) + + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + assert int(num_runs.to_numpy()[0]) == 1 + assert int(keys_out.to_numpy()[0]) == 42 + assert int(values_out.to_numpy()[0]) == int(vals.astype(np.int64).sum()) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_all_unique(): + """No two consecutive keys equal -> num_runs == N, values_out is a copy of values.""" + N = 1024 + keys_in = qd.field(qd.i32, shape=N) + values_in = qd.field(qd.i32, shape=N) + keys_out = qd.field(qd.i32, shape=N) + values_out = qd.field(qd.i32, shape=N) + num_runs = qd.field(qd.i32, shape=1) + keys_host = np.arange(N, dtype=np.int32) * 7 + vals_host = np.arange(N, dtype=np.int32) * 11 - 3 + _fill_field(keys_in, keys_host) + _fill_field(values_in, vals_host) + + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + assert int(num_runs.to_numpy()[0]) == N + np.testing.assert_array_equal(keys_out.to_numpy(), keys_host) + np.testing.assert_array_equal(values_out.to_numpy(), vals_host) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_rejects_shape_mismatch(): + keys_in = qd.field(qd.i32, shape=8) + values_in = qd.field(qd.i32, shape=4) # wrong length + keys_out = qd.field(qd.i32, shape=8) + values_out = qd.field(qd.i32, shape=8) + num_runs = qd.field(qd.i32, shape=1) + with pytest.raises(TypeError): + qd.algorithms.device_reduce_by_key_add( + keys_in, 
values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_rejects_dtype_mismatch(): + keys_in = qd.field(qd.i32, shape=8) + values_in = qd.field(qd.i32, shape=8) + keys_out = qd.field(qd.f32, shape=8) # dtype != keys_in + values_out = qd.field(qd.i32, shape=8) + num_runs = qd.field(qd.i32, shape=1) + with pytest.raises(TypeError): + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_rejects_short_out(): + """keys_out and values_out must hold at least N entries (worst case: all unique).""" + keys_in = qd.field(qd.i32, shape=16) + values_in = qd.field(qd.i32, shape=16) + keys_out = qd.field(qd.i32, shape=8) # too short + values_out = qd.field(qd.i32, shape=16) + num_runs = qd.field(qd.i32, shape=1) + with pytest.raises(ValueError): + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_rejects_unsupported_dtype(): + keys_in = qd.field(qd.i64, shape=8) + values_in = qd.field(qd.i64, shape=8) + keys_out = qd.field(qd.i64, shape=8) + values_out = qd.field(qd.i64, shape=8) + num_runs = qd.field(qd.i32, shape=1) + with pytest.raises(NotImplementedError): + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + + +# --------------------------------------------------------------------------- +# Cross-cutting: runtime lifecycle, ndarray polymorphism, deprecation, scratch-capacity errors, end_bit, pipeline +# composition, N=1M. +# --------------------------------------------------------------------------- + + +@test_utils.test(arch=qd.gpu) +def test_scratch_invalidate_resets_bytes_to_default(): + """``_scratch._invalidate`` (hooked into ``qd.reset()``) resets BOTH the cached field handle AND ``_scratch_bytes`` + to the default. + + Pins the invariant: every ``qd.init`` starts with a pristine scratch config, exactly as a fresh process would. We + test ``_invalidate`` directly (rather than going through ``qd.reset()``) because we want to assert the post-reset + state *inside* a single test without fighting the conftest's per-test ``init`` / ``reset`` pairing. + + The ``arch=qd.gpu`` parametrization is for uniformity with the rest of the file - the assertion itself only + touches Python module-level state, not the GPU, so the per-arch loop is redundant but harmless. + """ + assert _scratch._scratch_bytes == _scratch.DEFAULT_SCRATCH_BYTES, ( + "test prerequisite: scratch_bytes starts at default; the previous " + "test's qd.reset() teardown should have left it that way" + ) + saved_field = _scratch._scratch_field + saved_field_u64 = _scratch._scratch_field_u64 + try: + _scratch._scratch_bytes = 8 << 20 + _scratch._invalidate() + assert _scratch._scratch_bytes == _scratch.DEFAULT_SCRATCH_BYTES + assert _scratch._scratch_field is None + assert _scratch._scratch_field_u64 is None, "_scratch_field_u64 must also be invalidated on qd.reset()" + finally: + _scratch._scratch_bytes = _scratch.DEFAULT_SCRATCH_BYTES + _scratch._scratch_field = saved_field + _scratch._scratch_field_u64 = saved_field_u64 + + +@pytest.fixture +def big_scratch(): + """Bump scratch to 8 MB for the duration of the test. 
+ + No teardown - the conftest's per-test ``qd.reset()`` fires ``_scratch._invalidate``, which sets ``_scratch_bytes`` + back to ``DEFAULT_SCRATCH_BYTES`` and drops the field handle. That is what delivers test isolation for the next + test. Restoring here via ``set_scratch_bytes`` would fail anyway: once the test has run an algorithm, the scratch + field is allocated, and ``set_scratch_bytes`` rejects post-allocation bumps by design. + """ + _scratch.set_scratch_bytes(8 << 20) + yield + + +@pytest.mark.parametrize("dtype", _RADIX_KEY_DTYPES) +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_n_1m(dtype, big_scratch): # pylint: disable=unused-argument,redefined-outer-name + """N = 1_000_000 - qipc's hot-path size. Requires scratch bumped to ~5 MB; the ``big_scratch`` fixture supplies + 8 MB (the conftest's per-test ``qd.reset()`` restores the default afterwards). 8-byte key dtypes run twice as many passes (8 instead of 4) for the same N. Scratch + requirement is unchanged - the histograms are always u32 - so the same ``big_scratch`` covers both widths.""" + _skip_if_dtype_unsupported(dtype) + N = 1_000_000 + _skip_if_radix_sort_large_n_on_apple_gpu(N) + rng = np.random.default_rng(seed=1234) + host = _gen_keys(rng, dtype, N) + + keys = qd.field(dtype, shape=N) + tmp = qd.field(dtype, shape=N) + _fill_field(keys, host) + + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + np.testing.assert_array_equal(keys.to_numpy(), np.sort(host, kind="stable")) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_n_1m(big_scratch): # pylint: disable=unused-argument,redefined-outer-name + """N = 1_000_000 reduce-by-key. Same scratch requirement as the 1M radix sort; the kernel sequence is different + (just scan + scatter) but the in-place scan over scratch[0:N] needs the bump.""" + N = 1_000_000 + _skip_if_radix_sort_large_n_on_apple_gpu(N) + rng = np.random.default_rng(seed=1234) + keys_host = _gen_run_keys(rng, qd.i32, N) + values_host = rng.integers(-100, 100, size=N, dtype=np.int32) + + keys_in = qd.field(qd.i32, shape=N) + values_in = qd.field(qd.i32, shape=N) + keys_out = qd.field(qd.i32, shape=N) + values_out = qd.field(qd.i32, shape=N) + num_runs = qd.field(qd.i32, shape=1) + _fill_field(keys_in, keys_host) + _fill_field(values_in, values_host) + + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + nr = int(num_runs.to_numpy()[0]) + want_keys, want_vals = _ref_rbk_add(keys_host, values_host) + assert nr == len(want_keys) + np.testing.assert_array_equal(keys_out.to_numpy()[:nr], want_keys) + np.testing.assert_array_equal(values_out.to_numpy()[:nr], want_vals) + + +# --- Polymorphic-tensor coverage (Field vs Ndarray) for the algorithm surface is currently Field-only - every kernel +# param is annotated ``template()``, which only accepts Field-like storage. The design doc captures the +# future-direction switch to ``qd.Tensor`` kernel annotations (which would let bare Ndarrays through unchanged). +# That's a follow-up; the unwrap path for ``qd.Tensor(field)`` is exercised at the kernel-API level in +# ``test_tensor_wrapper_kernel.py`` so we don't repeat it here (passing ``qd.Tensor(...)`` through ``device_*`` would +# also pin the per-kernel ``_tensor_unwrap_indices`` cache and break subsequent bare-Field tests, by design of the +# kernel.py fast-path optimisation). + + +# --- Deprecation warnings on the legacy executor / parallel_sort surfaces. We added the warnings; assert they +# actually fire so an accidental rebase that drops them is caught.
+ + +@test_utils.test(arch=qd.gpu) +def test_prefix_sum_executor_emits_deprecation_warning(): + """`PrefixSumExecutor(N)` must emit `DeprecationWarning` per the migration plan in algorithms.md.""" + with pytest.warns(DeprecationWarning, match="device_exclusive_scan_add"): + qd.algorithms.PrefixSumExecutor(64) + + +@test_utils.test(arch=qd.gpu) +def test_parallel_sort_emits_deprecation_warning(): + """`parallel_sort` must emit `DeprecationWarning` per the migration plan in algorithms.md.""" + keys = qd.field(qd.i32, shape=8) + with pytest.warns(DeprecationWarning, match="device_radix_sort"): + qd.algorithms.parallel_sort(keys) + + +# --- Scratch-capacity error paths. Each algorithm raises a clear RuntimeError when N would push the scratch budget +# over the configured capacity, rather than corrupting data. Tests shrink scratch to ``_TINY_SCRATCH_BYTES`` so the +# trip point is reachable with a small N (cheap to allocate, runtime-independent of the DEFAULT_SCRATCH_BYTES knob). + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_rejects_oversized_n(): + """``device_radix_sort`` raises ``RuntimeError`` pointing the caller at ``set_scratch_bytes`` when N exceeds the + scratch budget. Shrink scratch first so the trip point is reachable with a tiny N.""" + _scratch.set_scratch_bytes(_TINY_SCRATCH_BYTES) + N = 4 * _scratch.scratch_capacity_u32() # comfortably over the tiny-scratch ceiling + keys = qd.field(qd.i32, shape=N) + tmp = qd.field(qd.i32, shape=N) + with pytest.raises(RuntimeError, match="scratch"): + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_recursive_scratch_check_keys_unchanged(): + """Regression: ``device_radix_sort`` must refuse the call *before* ``_twiddle_pass`` mutates the user's keys + when the scratch budget is too small for the *recursive* in-place scan footprint. + + The bug this guards against (PR 693 review): the up-front scratch check counted only one level of scan + partials (``hist_len + ceil(hist_len/BLOCK_DIM)``). For ``N`` large enough to force the in-place exclusive + scan to recurse (``hist_len > BLOCK_DIM**2``), a budget that's just a few slots too small slipped past that + single-level check, then ``_twiddle_pass`` ran (in-place XOR of sign bits for ``i32`` / ``f32`` keys), and + only *then* did the recursive scan raise a ``RuntimeError`` - leaving the caller's ``keys`` corrupted with + no recovery path. After the fix, the check uses ``_scan_total_scratch_slots`` to account for the full + recursion up front, so we refuse the call before any side effect runs. + + Setup picks a budget in the (single-level-pass, full-recursion-fail) window so the test would have *failed* + against the buggy old check (twiddle would have run, ``keys`` would be XOR'd) and *passes* against the fixed + check (``keys`` are byte-identical to what the user wrote in). 
+ """ + from quadrants.algorithms._radix_sort import BLOCK_DIM, RADIX_DIGITS + from quadrants.algorithms._scan import _scan_total_scratch_slots + + N = 1_000_000 # large enough that hist_len > BLOCK_DIM**2 = 65_536, forcing the scan to recurse one level + num_blocks = (N + BLOCK_DIM - 1) // BLOCK_DIM + hist_len = num_blocks * RADIX_DIGITS + old_needed = hist_len + (hist_len + BLOCK_DIM - 1) // BLOCK_DIM # buggy single-level estimate + new_needed = _scan_total_scratch_slots(hist_len, partials_cursor=hist_len) # full recursive footprint + assert new_needed > old_needed, ( + "test setup invariant: scan must recurse for the test to discriminate against the bug; " + f"got old_needed={old_needed}, new_needed={new_needed} at N={N} - increase N if BLOCK_DIM grew" + ) + # Budget in the bug window: passes the buggy old check, fails the fixed one. (new - old is small, ~16 slots + # at N=1M, so any midpoint works.) Round to even so the bytes count is a multiple of 8 (a ``set_scratch_bytes`` + # precondition that holds because the u64 scratch field shares the same byte budget). + cap_target = old_needed + (new_needed - old_needed) // 2 + cap_target += cap_target & 1 # snap up to even + assert old_needed < cap_target < new_needed, ( + f"bug-window selection invariant: old_needed={old_needed} < cap_target={cap_target} < " + f"new_needed={new_needed} should hold for the test to discriminate against the bug" + ) + _scratch.set_scratch_bytes(cap_target * 4) + + rng = np.random.default_rng(seed=1234) + host = rng.integers(-(2**30), 2**30, size=N, dtype=np.int32) # signed -> hits the in-place twiddle path + keys = qd.field(qd.i32, shape=N) + tmp = qd.field(qd.i32, shape=N) + _fill_field(keys, host) + + with pytest.raises(RuntimeError, match="scratch"): + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp) + + # The crucial assertion: keys are still the user's original bit pattern, not XOR'd by twiddle. + np.testing.assert_array_equal(keys.to_numpy(), host) + + +@test_utils.test(arch=qd.gpu) +def test_device_select_rejects_oversized_n(): + """Same scratch-capacity error path for device_select.""" + _scratch.set_scratch_bytes(_TINY_SCRATCH_BYTES) + N = 4 * _scratch.scratch_capacity_u32() + inp = qd.field(qd.i32, shape=N) + flags = qd.field(qd.i32, shape=N) + out = qd.field(qd.i32, shape=N) + num_out = qd.field(qd.i32, shape=1) + with pytest.raises(RuntimeError, match="scratch"): + qd.algorithms.device_select(inp, flags, out=out, num_out=num_out) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_by_key_add_rejects_oversized_n(): + """Same scratch-capacity error path for reduce-by-key.""" + _scratch.set_scratch_bytes(_TINY_SCRATCH_BYTES) + N = 4 * _scratch.scratch_capacity_u32() + keys_in = qd.field(qd.i32, shape=N) + values_in = qd.field(qd.i32, shape=N) + keys_out = qd.field(qd.i32, shape=N) + values_out = qd.field(qd.i32, shape=N) + num_runs = qd.field(qd.i32, shape=1) + with pytest.raises(RuntimeError, match="scratch"): + qd.algorithms.device_reduce_by_key_add( + keys_in, values_in, keys_out=keys_out, values_out=values_out, num_runs=num_runs + ) + + +@test_utils.test(arch=qd.gpu) +def test_device_reduce_add_rejects_oversized_n(): + """``device_reduce_*`` needs ~(B + B/256 + …) u32 slots where ``B = ceil(N / BLOCK_DIM)``; the trip point in N is + ``BLOCK_DIM * capacity_u32``. With the tiny scratch budget that's ~256 * 16K = 4M; use 5M to be comfortably over. 
+ The kernel itself never launches; the validate-budget check trips first.""" + _scratch.set_scratch_bytes(_TINY_SCRATCH_BYTES) + N = 256 * _scratch.scratch_capacity_u32() + 100_000 + inp = qd.field(qd.i32, shape=N) + out = qd.field(qd.i32, shape=1) + with pytest.raises(RuntimeError, match="scratch"): + qd.algorithms.device_reduce_add(inp, out=out) + + +@test_utils.test(arch=qd.gpu) +def test_device_exclusive_scan_add_rejects_oversized_n(): + """Same scratch-capacity error path for device_exclusive_scan_add. ``device_exclusive_scan_*`` needs ``B`` u32 + partials slots at the top level (plus recursive); trip point in N is ``BLOCK_DIM * capacity_u32``.""" + _scratch.set_scratch_bytes(_TINY_SCRATCH_BYTES) + N = 256 * _scratch.scratch_capacity_u32() + 100_000 + inp = qd.field(qd.i32, shape=N) + out = qd.field(qd.i32, shape=N) + with pytest.raises(RuntimeError, match="scratch"): + qd.algorithms.device_exclusive_scan_add(inp, out=out) + + +# --- Reduce / scan at N = 1M alongside the radix sort + RBK 1M coverage. Reduce / scan's scratch budget at 1M is +# small (4K + recursion ~ 16 u32 slots), trivially below the default 5 MB, so no ``big_scratch`` fixture is needed - +# included here just to round out the qipc-hot-path coverage on the same dtypes as the other 1M tests. + + +@pytest.mark.parametrize("dtype", _REDUCE_DTYPES) +@test_utils.test(arch=qd.gpu) +def test_device_reduce_add_n_1m(dtype): + """N = 1_000_000 reduce over the full dtype matrix. 4-byte dtypes use the u32 scratch (4K slots for top-level + partials, recursion adds ~16); 8-byte dtypes use the u64 scratch with the same slot count at half the byte cost. + Default 5 MB capacity covers both by a wide margin.""" + _skip_if_dtype_unsupported(dtype) + N = 1_000_000 + rng = np.random.default_rng(seed=1234) + host = _rand_reduce_host(rng, dtype, N) + + inp = qd.field(dtype, shape=N) + out = qd.field(dtype, shape=1) + _fill_field(inp, host) + qd.algorithms.device_reduce_add(inp, out=out) + + got = out.to_numpy()[0] + if _is_float(dtype): + expected = float(np.sum(host.astype(np.float64))) + rtol, atol = (_F32_LARGE_N_RTOL, _F32_LARGE_N_ATOL) if dtype == qd.f32 else (_F64_RTOL, _F64_ATOL) + assert math.isclose(got, expected, rel_tol=rtol, abs_tol=atol) + else: + # Promote to a wide enough Python int / numpy int for the reference, then mask both to dtype width. + if dtype in (qd.i32, qd.i64): + expected = int(np.sum(host.astype(np.int64))) + else: + expected = int(np.sum(host.astype(np.uint64))) + got_int = int(got) + if dtype == qd.u32: + got_int &= 0xFFFFFFFF + expected &= 0xFFFFFFFF + elif dtype == qd.u64: + got_int &= 0xFFFFFFFFFFFFFFFF + expected &= 0xFFFFFFFFFFFFFFFF + assert got_int == expected + + +@pytest.mark.parametrize("dtype", _SCAN_DTYPES) +@test_utils.test(arch=qd.gpu) +def test_device_exclusive_scan_add_n_1m(dtype): + """N = 1_000_000 exclusive scan over the full dtype matrix. 4-byte dtypes go through the u32 scratch; 8-byte + dtypes through the u64 scratch (4K slots at the top level for both, recursion adds ~16). 
Both fit in default + 5 MB by a wide margin.""" + _skip_if_dtype_unsupported(dtype) + N = 1_000_000 + rng = np.random.default_rng(seed=1234) + np_dt = _DTYPE_TO_NP[dtype] + if dtype == qd.f32: + host = rng.uniform(-0.01, 0.01, size=N).astype(np_dt) + elif dtype == qd.f64: + host = rng.uniform(-0.01, 0.01, size=N).astype(np_dt) + elif dtype in (qd.u32, qd.u64): + host = rng.integers(0, 10, size=N, dtype=np_dt) + else: + host = rng.integers(-5, 5, size=N, dtype=np_dt) + + inp = qd.field(dtype, shape=N) + out = qd.field(dtype, shape=N) + _fill_field(inp, host) + qd.algorithms.device_exclusive_scan_add(inp, out=out) + + got = out.to_numpy() + if _is_float(dtype): + ref = np.concatenate([[0.0], np.cumsum(host.astype(np.float64))[:-1]]).astype(np_dt) + # f32: cumulative drift over 1M adds is real; check head only with a generous rtol, then verify finite at + # the tail. f64: 12 orders of magnitude more precision, can check the whole array with tight tolerance. + if dtype == qd.f32: + np.testing.assert_allclose(got[:64], ref[:64], rtol=_F32_LARGE_N_RTOL, atol=_F32_LARGE_N_ATOL) + assert np.isfinite(got[-1]) + else: + np.testing.assert_allclose(got, ref, rtol=_F64_RTOL, atol=_F64_ATOL) + else: + promote = np.int64 if dtype in (qd.i32, qd.u32, qd.i64) else np.uint64 + ref = np.concatenate([[promote(0)], np.cumsum(host.astype(promote))[:-1]]).astype(promote) + np.testing.assert_array_equal(got.astype(promote), ref) + + +# --- End-to-end round-trip: bump scratch, run a 1M algorithm, qd.reset + qd.init, then run a default-scratch-sized +# algorithm. This directly validates the principle "after reset+init, everything works as if there was nothing +# before it" - the bumped capacity from the first cycle must NOT leak into the second cycle's scratch. + + +@test_utils.test(arch=qd.gpu) +def test_scratch_round_trip_across_qd_reset(req_arch): + """Run a bumped-scratch algorithm; ``qd.reset()`` + ``qd.init()``; then run another algorithm at default scratch. + + The bumped capacity from cycle 1 must be gone in cycle 2 - otherwise the second ``qd.init()`` would over-allocate + against the unwanted bump. This is the *behavioural* version of ``test_scratch_invalidate_resets_bytes_to_default`` + (which only manipulates module state directly). + """ + # Pick a "too big" N relative to the default scratch, so the cycle-2 retry trips the budget guard regardless of + # what ``DEFAULT_SCRATCH_BYTES`` happens to be. ``2 * capacity_u32`` overshoots the default by 2x. + default_capacity_u32 = _scratch.DEFAULT_SCRATCH_BYTES // 4 + N1 = 2 * default_capacity_u32 # comfortably over the default scratch ceiling for radix sort + _skip_if_radix_sort_large_n_on_apple_gpu(N1) + + # --- Cycle 1: bump scratch enough to cover N1, run the sort. + _scratch.set_scratch_bytes(4 * _scratch.DEFAULT_SCRATCH_BYTES) + rng = np.random.default_rng(seed=1234) + host1 = rng.integers(0, 2**31 - 1, size=N1, dtype=np.int32) + keys1 = qd.field(qd.i32, shape=N1) + tmp1 = qd.field(qd.i32, shape=N1) + _fill_field(keys1, host1) + qd.algorithms.device_radix_sort(keys1, tmp_keys=tmp1) + np.testing.assert_array_equal(keys1.to_numpy(), np.sort(host1)) + + # --- Cross the qd.reset() + qd.init() boundary. After this, everything should behave as if cycle 1 never ran. + qd.reset() + qd.init(arch=req_arch, enable_fallback=False, device_memory_GB=0.3, print_full_traceback=True) + + # Post-reset invariants on the scratch module. 
+ assert ( + _scratch._scratch_bytes == _scratch.DEFAULT_SCRATCH_BYTES + ), "_scratch_bytes did not reset to default across qd.reset() + qd.init() - the very leak this test pins" + assert _scratch._scratch_field is None, "_scratch_field handle was not invalidated across qd.reset()" + assert _scratch._scratch_field_u64 is None, "_scratch_field_u64 handle was not invalidated across qd.reset()" + + # --- Cycle 2: run a small algorithm with default scratch. Should just work - and crucially, attempting an + # over-budget sort NOW (without re-bumping) should *raise* RuntimeError because the bumped capacity is gone. + N2 = 1024 + host2 = rng.integers(0, 100, size=N2, dtype=np.int32) + keys2 = qd.field(qd.i32, shape=N2) + tmp2 = qd.field(qd.i32, shape=N2) + _fill_field(keys2, host2) + qd.algorithms.device_radix_sort(keys2, tmp_keys=tmp2) + np.testing.assert_array_equal(keys2.to_numpy(), np.sort(host2)) + + # Re-attempting the over-budget sort without re-bumping must fail - proves the capacity really did drop back. + keys3 = qd.field(qd.i32, shape=N1) + tmp3 = qd.field(qd.i32, shape=N1) + with pytest.raises(RuntimeError, match="scratch"): + qd.algorithms.device_radix_sort(keys3, tmp_keys=tmp3) + + +# --- end_bit on radix sort. Default 32; lower values let callers sort by only the low bits when they know the high +# bits are zero (qipc's case for some small-value sorts). + + +@test_utils.test(arch=qd.gpu) +def test_device_radix_sort_end_bit_16(): + """end_bit=16 sorts by only the low 16 bits; high bits are ignored. Build keys where the low 16 bits and the high + 16 bits disagree on order, then verify sort is by low 16.""" + N = 1024 + rng = np.random.default_rng(seed=1234) + low = rng.integers(0, 1 << 16, size=N, dtype=np.uint32) + # High bits decreasing, so a sort by high bits would reverse the array; if the algorithm correctly ignores the + # high bits, sort key is `low`. + high = (np.arange(N, dtype=np.uint32)[::-1]).astype(np.uint32) + host = ((high << np.uint32(16)) | low).astype(np.uint32) + + keys = qd.field(qd.u32, shape=N) + tmp = qd.field(qd.u32, shape=N) + _fill_field(keys, host) + + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp, end_bit=16) + got = keys.to_numpy() + # `got` should be sorted by the low 16 bits, in stable order of the original input. Tie-breaking on the low 16 + # bits keeps the original input index order. + got_low = got & 0xFFFF + assert np.all(np.diff(got_low.astype(np.int64)) >= 0), "low-16 not non-decreasing" + + +# --- Full pipeline: radix sort, then reduce-by-key. The qipc-shaped composition (unsorted (key, value) pairs -> +# global per-key sums). + + +@pytest.mark.parametrize("dtype", [qd.i32, qd.u32, qd.f32]) +@test_utils.test(arch=qd.gpu) +def test_radix_sort_then_reduce_by_key_pipeline(dtype): + """Sort by key, then reduce-by-key, to produce a global per-key sum. Cross-checked against numpy.unique + + numpy.add.reduceat.""" + N = 4096 + rng = np.random.default_rng(seed=1234) + if dtype == qd.f32: + # Use a small set of f32 values to maximise repeats; this also keeps the f32 atomic_add accumulation tolerance + # comfortable. 
+ alphabet = rng.standard_normal(20).astype(np.float32) + elif dtype == qd.u32: + alphabet = rng.integers(0, 100, size=20, dtype=np.uint32) + else: + alphabet = rng.integers(-50, 50, size=20, dtype=np.int32) + keys_host = rng.choice(alphabet, size=N) + values_host = rng.integers(-10, 10, size=N, dtype=np.int32) + + keys = qd.field(dtype, shape=N) + tmp_keys = qd.field(dtype, shape=N) + values = qd.field(qd.i32, shape=N) + tmp_values = qd.field(qd.i32, shape=N) + _fill_field(keys, keys_host) + _fill_field(values, values_host) + + qd.algorithms.device_radix_sort(keys, tmp_keys=tmp_keys, values=values, tmp_values=tmp_values) + # After sort, keys is ascending; values is permuted to match. Now RBK collapses runs of equal keys into per-key + # sums. + keys_out = qd.field(dtype, shape=N) + values_out = qd.field(qd.i32, shape=N) + num_runs = qd.field(qd.i32, shape=1) + qd.algorithms.device_reduce_by_key_add(keys, values, keys_out=keys_out, values_out=values_out, num_runs=num_runs) + + nr = int(num_runs.to_numpy()[0]) + got_keys = keys_out.to_numpy()[:nr] + got_vals = values_out.to_numpy()[:nr] + + # CPU reference: numpy.unique with sum aggregation, matching the device's sort + RBK semantics. + uniq, idx = np.unique(keys_host, return_inverse=True) + sums = np.zeros(len(uniq), dtype=np.int64) + np.add.at(sums, idx, values_host.astype(np.int64)) + + assert nr == len(uniq), f"num_runs mismatch: got {nr}, expected {len(uniq)}" + if dtype == qd.f32: + np.testing.assert_allclose(got_keys, uniq, rtol=0, atol=0) + else: + np.testing.assert_array_equal(got_keys, uniq) + np.testing.assert_array_equal(got_vals.astype(np.int64), sums) diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 83ffba483b..368e12bc54 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -264,7 +264,19 @@ def _get_expected_matrix_apis(): "grad_replaced", "no_grad", ] -user_api[qd.algorithms] = ["PrefixSumExecutor", "parallel_sort"] +user_api[qd.algorithms] = [ + "PrefixSumExecutor", + "device_exclusive_scan_add", + "device_exclusive_scan_max", + "device_exclusive_scan_min", + "device_radix_sort", + "device_reduce_add", + "device_reduce_by_key_add", + "device_reduce_max", + "device_reduce_min", + "device_select", + "parallel_sort", +] user_api[qd.Field] = [ "copy_from", "dtype", diff --git a/tests/python/test_atomic.py b/tests/python/test_atomic.py index 20fd2bd0ca..59ab29fac4 100644 --- a/tests/python/test_atomic.py +++ b/tests/python/test_atomic.py @@ -812,7 +812,7 @@ def func(): # Pins the documented semantics of qd.atomic_cas: returns the prior value of `dest`, swaps in `desired` only when # the prior value equals `expected`. Single-thread sanity covering both the success path (prior == expected) and # the failure path (prior != expected) for every integer dtype the codegen path supports today (i32 / u32 / -# i64 / u64). f32 / f64 CAS is currently rejected at trace time -- a separate negative test pins that. +# i64 / u64). f32 / f64 CAS is currently rejected at compile time -- a separate negative test pins that. @pytest.mark.parametrize("dtype", [qd.i32, qd.u32, qd.i64, qd.u64]) @test_utils.test(arch=qd.gpu) def test_atomic_cas_returns_old_value(dtype): @@ -943,7 +943,7 @@ def func(): assert y[i] == 100 + i, f"failure-path CAS demoted: expected prior {100 + i}, got {y[i]}" -# Pins the doc claim that atomic_cas on float dtypes raises a type error at trace time. f32 / f64 CAS is not +# Pins the doc claim that atomic_cas on float dtypes raises a type error at compile time. 
f32 / f64 CAS is not # yet wired up (would need the same uint-bitcast trick xchg uses); the type_check carve-out in # AtomicOpExpression::type_check rejects it cleanly until the lowering lands. @pytest.mark.parametrize("dtype", [qd.f32, qd.f64]) @@ -961,7 +961,7 @@ def kern(): kern() -# Pins that atomic_cas on a Vector / Matrix destination is rejected at trace time. The other atomic ops fan +# Pins that atomic_cas on a Vector / Matrix destination is rejected at compile time. The other atomic ops fan # out to per-component scalar AtomicOpStmts via scalarize / lower_matrix_ptr, but those passes use the 3-arg # AtomicOpStmt constructor that drops `expected`. Until the scalarizers grow a 4-arg path, refusing tensor # CAS up front is the correct behaviour. Codex / alanray-tech P1 from PR #690 review.