23 commits:
- `8bd2336` [feat] Add block.reduce_{add,min,max} + reduce_all_* (CUB-style block… (hughperkins, May 10, 2026)
- `ee98d5b` [feat] Add block.{inclusive,exclusive}_{add,min,max} (CUB-style block… (hughperkins, May 10, 2026)
- `5cfdfbc` [feat] Add block.radix_rank_match_atomic_or (CUB ATOMIC_OR path) (hughperkins, May 10, 2026)
- `3b34815` [fix] block.radix_rank: pair subgroup.sync() with subgroup.mem_fence() (hughperkins, May 10, 2026)
- `d3e73ba` [fix] block.radix_rank: feed u32 to clz, not i32 (SPIR-V FindSMsb bug) (hughperkins, May 10, 2026)
- `bac3f9f` [fix] block.reduce/scan: use logical lane (tid & WARP_SIZE-1), not in… (hughperkins, May 10, 2026)
- `351017b` Merge branch 'hp/new-qipc-ops-subgroup' into hp/new-qipc-ops-block (hughperkins, May 10, 2026)
- `b97f308` Merge remote-tracking branch 'origin/hp/new-qipc-ops-subgroup' into h… (hughperkins, May 12, 2026)
- `8d9f0aa` Merge remote-tracking branch 'origin/hp/new-qipc-ops-subgroup' into h… (hughperkins, May 12, 2026)
- `dfc9729` [block] Drop tid arg; fold internal _reduce/_scan helpers into public… (hughperkins, May 12, 2026)
- `b805112` [block] Drop log2_warp arg from public API; read subgroup size intern… (hughperkins, May 12, 2026)
- `f4f8db2` [doc] Sweep 'trace time' -> 'compile time' in block / test_atomic (hughperkins, May 12, 2026)
- `bbf4a11` [block] Wrap template arithmetic in impl.static() so static_assert se… (hughperkins, May 12, 2026)
- `e5207eb` [subgroup] Add private generic-op _reduce(value, op, log2_size); bloc… (hughperkins, May 12, 2026)
- `3b9052f` [block] Rename warp -> subgroup for the cross-GPU naming convention (hughperkins, May 12, 2026)
- `cd9e546` [block] radix_rank_match_atomic_or: add wave64 path (hughperkins, May 12, 2026)
- `3f1e5bd` Merge origin/hp/new-qipc-ops-subgroup into hp/new-qipc-ops-block (hughperkins, May 12, 2026)
- `456f002` [block] test: skip block_dim < subgroup_size on wave64 (hughperkins, May 12, 2026)
- `20a68d7` Merge origin/hp/new-qipc-ops-subgroup into hp/new-qipc-ops-block (hughperkins, May 12, 2026)
- `e6a5b71` [block] test: parametrize block tests by subgroups-per-block, derive … (hughperkins, May 13, 2026)
- `60265cf` [block] style: black 25.1.0 reformatting of merged subgroup tests (hughperkins, May 13, 2026)
- `87e43e9` [block] doc: refresh radix_rank constraints + memory footprint for wa… (hughperkins, May 13, 2026)
- `59f831a` Merge branch 'hp/new-qipc-ops-subgroup' into hp/new-qipc-ops-block (hughperkins, May 13, 2026)
119 changes: 109 additions & 10 deletions docs/source/user_guide/block.md
@@ -8,21 +8,28 @@ The closely-related device-scope memory fence is documented separately in [grid]

## What's available

| Op | CUDA | AMDGPU | Vulkan | Metal |
|-------------------------------------------------|------|---------|--------|-------|
| `block.sync()` | yes | yes | yes | yes |
| `block.sync_all_nonzero(predicate)` | yes | yes\* | yes\* | yes\* |
| `block.sync_any_nonzero(predicate)` | yes | yes\* | yes\* | yes\* |
| `block.sync_count_nonzero(predicate)` | yes | yes\* | yes\* | yes\* |
| `block.mem_fence()` | yes | yes | yes | yes |
| `block.SharedArray(shape, dtype)` | yes | yes | yes | yes |
| `block.global_thread_idx()` | yes | yes | yes | yes |
| `block.thread_idx()` | yes | yes | yes | yes |
| `block.reduce_{add,min,max}(v, block_dim, dtype)` | yes | yes | yes | yes |
| `block.reduce_all_{add,min,max}(v, block_dim, dtype)` | yes | yes | yes | yes |
| `block.inclusive_{add,min,max}(v, block_dim, dtype)` | yes | yes | yes | yes |
| `block.exclusive_{add,min,max}(v, block_dim, ...)` | yes | yes | yes | yes |
| `block.radix_rank_match_atomic_or(...)` | yes | yes | yes | yes |

Vulkan and Metal share a SPIR-V codegen path (Metal goes through MoltenVK → MSL); they are listed as separate columns because a couple of ops have Metal-specific caveats called out below. Footnoted entries are still functional, just with the limitations the footnote describes.

\* On AMDGPU, Vulkan, and Metal the `block.sync_{all,any,count}_nonzero(p)` ops are *emulated* via shared memory (one shared `i32` slot + 2 block barriers + a single `atomic_add` per contributing thread) rather than a single hardware-fused barrier-with-reduction. CUDA has the fused NVPTX `barrier.cta.red.{and,or,popc}.aligned.all.sync` family of intrinsics so it stays on the fast path; the other backends do not have a direct analog (in particular, SPIR-V `OpGroupNonUniform*` only operates at subgroup scope reliably across Vulkan + Metal). All three reductions are routed through `atomic_add` rather than `atomic_or` / `atomic_and`: the latter trip a Metal-specific bug where `OpAtomicOr` on threadgroup memory silently no-ops via MoltenVK / SPIRV-Cross. The emulation is correct and portable but costs two `block.sync()`s plus one shared-memory atomic per call instead of a single barrier instruction; if you have an inner loop calling these ops millions of times, consider whether you can batch the predicate before reducing it.
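
For intuition, here is a minimal sketch of the emulated path written as kernel code. The helper name and the `qd.simt.atomic_add(array, index, value)` spelling are assumptions for illustration, not the generated code:

```python
@qd.func
def count_nonzero_emulated(predicate: qd.i32) -> qd.i32:
    # One shared i32 slot for the whole block.
    slot = qd.simt.block.SharedArray((1,), qd.i32)
    if qd.simt.block.thread_idx() == 0:
        slot[0] = 0
    qd.simt.block.sync()  # barrier 1: every thread sees the zeroed slot
    if predicate != 0:
        qd.simt.atomic_add(slot, 0, 1)  # one shared atomic per contributing thread
    qd.simt.block.sync()  # barrier 2: all contributions are visible
    return slot[0]  # all == (count == block_dim); any == (count != 0)
```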

`block.radix_rank_match_atomic_or` is portable across wave32 (CUDA, Vulkan-on-NVIDIA, Metal) and wave64 (AMDGPU — Quadrants pins every AMDGPU target to `+wavefrontsize64`). The match-mask shared-memory region picks its dtype at compile time: `i32` on wave32 (32-lane ballot fits in a single `i32`, with `subgroup.lanemask_le` and `clz` / `popcnt` on `u32`) and `i64` on wave64 (64-lane ballot needs 64 bits, with an inline u64 `lanemask_le` and `clz` / `popcnt` on `u64`). The two paths share steps 1–4 (per-subgroup histograms, column-sum upsweep, block exclusive scan, downsweep) and step 6 (publish bins + exclusive prefix); only the per-key match phase (step 5) diverges. Atomic `or` on `i64` shared memory is native on AMDGPU LDS; wave32 backends never see the `i64` path, so portability does not depend on SPIR-V / Metal supporting 64-bit threadgroup atomics.

Naming note: `block.mem_sync()` was recently renamed to `block.mem_fence()` for consistency with the project's "fence vs barrier" terminology. The old name is still available as a deprecated alias that emits `DeprecationWarning` on first use; new code should use `block.mem_fence()`.

## Barrier vs fence: the distinction that matters
@@ -118,6 +125,98 @@ This is the thread's index *within its own block / workgroup*. To get the across

Today only the X dimension is exposed (1-D blocks). For 2-D / 3-D blocks the calling code should compute the linear index from `block.thread_idx()` and the block-Y / Z dimensions itself, or stick to 1-D blocks (the dominant Quadrants idiom — `qd.loop_config(block_dim=N)` always sets the X extent).

### `block.reduce_{add,min,max}(value, block_dim, dtype)`

Block-scope reductions following the standard two-stage subgroup-reduction strategy: each subgroup reduces its lanes via a `shuffle_down` tree, lane 0 of each subgroup publishes the subgroup aggregate to shared memory, then thread 0 sequentially folds the subgroup aggregates with the same operator. The result is valid in **thread 0 only**; other threads retain partial values. For the broadcast-to-every-thread variants see `block.reduce_all_{add,min,max}` below.

Arguments:

- `value`: per-thread input.
- `block_dim`: threads per block (compile-time `template()`). Must be a positive multiple of `subgroup.group_size()`, which resolves to 32 on CUDA / Metal / Vulkan-on-NVIDIA and 64 on AMDGPU. Passing a `block_dim` that is not a multiple of the subgroup size raises a compile-time error.
- `dtype`: scalar dtype for the inter-subgroup shared-memory staging slot; must match `value`'s type.

The calling thread's block-local index is read internally via `block.thread_idx()`; the subgroup size is read from `subgroup.group_size()` at compile time. Neither is plumbed through as an argument.

Cost: `log2(subgroup_size)` shuffles + 1 shared-memory write/read per subgroup + 1 `block.sync()` + `(block_dim / subgroup_size) - 1` ops on thread 0. When the block is exactly one subgroup the shared-memory path is short-circuited at compile time.

```python
N = 65536  # element count; assumed a multiple of the block size

@qd.kernel
def kern(src: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=128)
    for i in range(N):
        agg = qd.simt.block.reduce_add(src[i], 128, qd.f32)
        if qd.simt.block.thread_idx() == 0:
            out[i // 128] = agg  # one aggregate per 128-thread block
```

A generic `block.reduce(value, block_dim, op, dtype)` is also available for custom associative operators (e.g. bitwise ops, custom monoids). It accepts an `op: template()` `@qd.func` taking `(a, b)` and returning the same type as `value`.
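
As a sketch, assuming custom ops are written as `@qd.func`s with a typed `(a, b)` signature (an illustrative spelling, not a documented one), a bitwise-or reduction:

```python
@qd.func
def bit_or(a: qd.i32, b: qd.i32) -> qd.i32:
    return a | b  # any associative op over the value type works

N = 65536  # element count; assumed a multiple of the block size

@qd.kernel
def kern(src: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=128)
    for i in range(N):
        agg = qd.simt.block.reduce(src[i], 128, bit_or, qd.i32)
        if qd.simt.block.thread_idx() == 0:
            out[i // 128] = agg  # OR of the block's 128 values
```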

### `block.reduce_all_{add,min,max}(value, block_dim, dtype)`

The broadcast variants of the above. Identical semantics, but the result is published to a one-slot `SharedArray` and read back by every thread after a second `block.sync()`. Use this when downstream code on every thread needs the block-wide aggregate (e.g. normalising each thread's value by the block sum). Cost: one extra `block.sync()` plus one shared-memory hop vs. the lane-0-only variants. The corresponding generic form is `block.reduce_all(value, block_dim, op, dtype)`.
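
For example, normalising each thread's value by its block's sum (a sketch; `N` is assumed to be a multiple of the block size):

```python
@qd.kernel
def normalise(src: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=128)
    for i in range(N):
        total = qd.simt.block.reduce_all_add(src[i], 128, qd.f32)
        out[i] = src[i] / total  # valid on every thread; no extra sync needed
```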

### `block.inclusive_{add,min,max}(value, block_dim, dtype)`

Block-scope inclusive prefix scans via the standard two-stage subgroup-scan strategy: each subgroup does a Hillis-Steele scan via `subgroup` shuffles, the last lane of each subgroup publishes the subgroup aggregate to shared memory, then every thread sequentially folds the cross-subgroup prefix and applies its own subgroup's prefix to its scan value. **All threads receive a valid result.** After the call, thread `i` holds `op(v[0], v[1], ..., v[i])`.

Args match `block.reduce_add` (`value, block_dim, dtype`). Cost: per-subgroup Hillis-Steele tree (`log2(subgroup_size)` shuffles) + 1 shared-memory write/read per subgroup + 1 `block.sync()` + `(block_dim / subgroup_size) - 1` ops on every thread (the cross-subgroup prefix is computed redundantly to avoid a second barrier). When the block is exactly one subgroup the shared-memory path is short-circuited at compile time.

```python
N = 65536  # element count; assumed a multiple of the block size

@qd.kernel
def kern(src: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=128)
    for i in range(N):
        out[i] = qd.simt.block.inclusive_add(src[i], 128, qd.i32)
```

The corresponding generic form is `block.inclusive_scan(value, block_dim, op, dtype)` for custom monoids.
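
Reusing the illustrative `bit_or` op sketched in the reduce section, a running bitwise-or (e.g. flag propagation) would look like:

```python
@qd.kernel
def kern(flags: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=128)
    for i in range(N):
        # Thread j receives the OR of its block's flags 0..j.
        out[i] = qd.simt.block.inclusive_scan(flags[i], 128, bit_or, qd.i32)
```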

### `block.exclusive_{add,min,max}(value, block_dim[, identity], dtype)`

Block-scope exclusive prefix scans. Same strategy and cost profile as `inclusive_*`, but each thread receives the prefix `op(v[0], ..., v[i-1])` instead — and thread 0 receives the operator's identity.

- `exclusive_add`: identity is the additive zero; derived from `value - value` so callers do not need to pass it. After the call, thread 0 holds 0.
- `exclusive_min(..., identity, dtype)`: pass `identity` greater than or equal to every legal element of the input — typically `+∞` for floats or the dtype's maximum for integers. Thread 0 holds `identity`. There is no portable type-extreme derivable from `value` alone, so this op takes an explicit `identity` argument (mirrors `subgroup.exclusive_min`).
- `exclusive_max(..., identity, dtype)`: pass `identity` less than or equal to every legal element of the input — typically `-∞` for floats or the dtype's minimum for integers. Thread 0 holds `identity`.

The corresponding generic form is `block.exclusive_scan(value, block_dim, op, identity, dtype)`.
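
A sketch of the add and min variants; using `float("inf")` as the min identity is illustrative, since any value greater than or equal to every legal input works:

```python
@qd.kernel
def kern(src: qd.types.ndarray(ndim=1), out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=128)
    for i in range(N):
        # Thread j receives the sum over threads 0..j-1; thread 0 receives 0.
        out[i] = qd.simt.block.exclusive_add(src[i], 128, qd.f32)
        # The min/max forms take the identity argument before the dtype:
        #   qd.simt.block.exclusive_min(src[i], 128, float("inf"), qd.f32)
```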

### `block.radix_rank_match_atomic_or(key, block_dim, radix_bits, bit_start, num_bits, bins, excl_prefix)`

Block-level radix ranking via the atomic-OR match-and-count strategy (the workhorse of an SM90-style onesweep radix sort). Each thread holds one `u32` key; the function returns the key's stable rank within the block under the digit `(key >> bit_start) & ((1 << num_bits) - 1)`, and writes the per-digit count and exclusive-prefix arrays to two caller-supplied `block.SharedArray` outparams.

Constraints (currently):

- `block_dim` must equal `1 << radix_bits` (each digit gets exactly one thread for the per-thread bin / exclusive-prefix output). Typical configuration is `radix_bits=8, block_dim=256`.
- `subgroup.group_size()` must be 32 (CUDA / Metal / Vulkan-on-NVIDIA) or 64 (AMDGPU). The match path picks its ballot dtype accordingly — `i32` on wave32, `i64` on wave64 — and `static_assert`s the subgroup size at compile time.
- One key per thread (`items_per_thread = 1`). Multi-item per thread is a future extension.
- `num_bits <= radix_bits`; `bit_start` is the offset of the digit's low bit.

Args:

- `key`: per-thread `u32` input.
- `block_dim`, `radix_bits`, `bit_start`, `num_bits`: all compile-time `template()`.
- `bins`: `block.SharedArray((1 << radix_bits,), qd.i32)`. After the call, `bins[d]` holds the count of keys whose digit equals `d`.
- `excl_prefix`: `block.SharedArray((1 << radix_bits,), qd.i32)`. After the call, `excl_prefix[d]` holds the exclusive prefix sum of `bins` up to digit `d`.

The calling thread's block-local index is read internally via `block.thread_idx()`.

Cost: 2 `block.sync()` + a handful of `subgroup.sync()` calls + 1 block exclusive scan + per-key `atomic_or` + leader-only `atomic_add` on shared memory. Shared-memory footprint at the default `radix_bits=8` configuration: 4 KiB `i32` for the per-subgroup offsets + a match-mask region whose dtype is wave-size-specific — 8 KiB `i32` on wave32 (8 subgroups × 256 digits × 4 B) or 8 KiB `i64` on wave64 (4 subgroups × 256 digits × 8 B). So 12 KiB total on either wave size.

```python
@qd.kernel
def kern(keys_in: qd.types.ndarray(ndim=1), ranks_out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=256)
    for i in range(256):
        bins = qd.simt.block.SharedArray((256,), qd.i32)
        excl = qd.simt.block.SharedArray((256,), qd.i32)
        ranks_out[i] = qd.simt.block.radix_rank_match_atomic_or(
            keys_in[i], 256, 8, 0, 8, bins, excl
        )
```

The function inserts the necessary `block.sync()` barriers before returning, so callers can read `bins` / `excl_prefix` immediately after the call without an extra barrier.
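
In a sort pass the returned rank typically drives a scatter. A sketch, assuming (as "stable rank within the block" suggests) that the digit's exclusive prefix is already folded into the returned rank:

```python
@qd.kernel
def scatter(keys_in: qd.types.ndarray(ndim=1), keys_out: qd.types.ndarray(ndim=1)):
    qd.loop_config(block_dim=256)
    for i in range(256):
        bins = qd.simt.block.SharedArray((256,), qd.i32)
        excl = qd.simt.block.SharedArray((256,), qd.i32)
        rank = qd.simt.block.radix_rank_match_atomic_or(
            keys_in[i], 256, 8, 0, 8, bins, excl
        )
        # bins / excl are readable here with no extra block.sync().
        keys_out[rank] = keys_in[i]  # digit-ordered, stable within each digit
```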

## Related

- [grid](grid.md) — the device-scope counterpart of `block.mem_fence()`. For coordination within a single block, prefer `block.mem_fence()` — it is cheaper.