Skip to content

Commit ba5f832

Browse files
Add sm87 to aarch64 CUDA wheel targets (Jetson Orin)
Adds sm_87 (NVIDIA Jetson Orin: Nano / NX / AGX) to the aarch64 build_capability list in .github/scripts/build-cuda.sh and documents the addition in installation.mdx. Why an explicit cubin is needed: the CMake arch logic at CMakeLists.txt:226-230 only emits PTX for the latest capability. sm87 hardware can't JIT from sm90+ PTX (forward-compat is upward- only), so aarch64 wheels currently targeting sm75/sm80/sm90 ship PTX only for sm90 and Jetson Orin users fall back to slow or unsupported paths. This rebuts the "sm80 should cover sm87" reasoning that closed #1781. Wheel size impact (measured on Linux aarch64, CUDA 12.6.68, source HEAD a57d8e2): baseline (sm75;80;90): 5,710,520 bytes (5.45 MiB) with sm87 (sm75;80;87;90): 7,353,064 bytes (7.01 MiB) delta: +1,642,544 bytes (+1.57 MiB, +28.76%) Adds tests/test_linear4bit_sm87_multishape_regression.py — pytest reproducer for the multi-shape Linear4bit cold-start fault on sm_87 (#1936). The test runs the historical failing recipe (NF4 + bf16 quant_storage + bf16 compute + double_quant + ABC shape order + no hygiene + batch=1) at sm_87 cold-state. The fault is cold-start-specific; the test docstring documents the warm/cold distinction so CI runners can configure accordingly. Closes #1930 Closes #1218 Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 2de5ec3 commit ba5f832

3 files changed

Lines changed: 111 additions & 5 deletions

File tree

.github/scripts/build-cuda.sh

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,16 @@ set -xeuo pipefail
99
if [[ -v cuda_targets ]]; then
1010
build_capability="${cuda_targets}"
1111
elif [ "${build_arch}" = "aarch64" ]; then
12-
build_capability="75;80;90"
12+
# sm87 (Jetson Orin) is aarch64-only and needs an explicit cubin: only
13+
# the latest capability gets PTX, so sm87 hardware cannot JIT from
14+
# sm90+ PTX. See bitsandbytes-foundation/bitsandbytes#1930.
15+
build_capability="75;80;87;90"
1316

1417
# CUDA 12.8-12.9: Add sm100/sm120
15-
[[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120"
18+
[[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;87;90;100;120"
1619

1720
# CUDA 13.0+: Add sm100/sm110/sm120
18-
[[ "${cuda_version}" == 13.*.* ]] && build_capability="75;80;90;100;110;120;121"
21+
[[ "${cuda_version}" == 13.*.* ]] && build_capability="75;80;87;90;100;110;120;121"
1922
else
2023
# By default, target Pascal through Hopper.
2124
build_capability="60;70;75;80;86;89;90"

docs/source/installation.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ The currently distributed `bitsandbytes` packages are built with the following c
5252
| **Linux x86-64** | 11.8 - 12.6 | GCC 11.2 | sm60, sm70, sm75, sm80, sm86, sm89, sm90
5353
| **Linux x86-64** | 12.8 - 12.9 | GCC 11.2 | sm70, sm75, sm80, sm86, sm89, sm90, sm100, sm120
5454
| **Linux x86-64** | 13.0 | GCC 11.2 | sm75, sm80, sm86, sm89, sm90, sm100, sm120
55-
| **Linux aarch64** | 11.8 - 12.6 | GCC 11.2 | sm75, sm80, sm90
56-
| **Linux aarch64** | 12.8 - 13.0 | GCC 11.2 | sm75, sm80, sm90, sm100, sm110, sm120, sm121
55+
| **Linux aarch64** | 11.8 - 12.6 | GCC 11.2 | sm75, sm80, sm87, sm90
56+
| **Linux aarch64** | 12.8 - 13.0 | GCC 11.2 | sm75, sm80, sm87, sm90, sm100, sm110, sm120, sm121
5757
| **Windows x86-64** | 11.8 - 12.6 | MSVC 19.43+ (VS2022) | sm50, sm60, sm75, sm80, sm86, sm89, sm90
5858
| **Windows x86-64** | 12.8 - 12.9 | MSVC 19.43+ (VS2022) | sm70, sm75, sm80, sm86, sm89, sm90, sm100, sm120
5959
| **Windows x86-64** | 13.0 | MSVC 19.43+ (VS2022) | sm75, sm80, sm86, sm89, sm90, sm100, sm120
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""Regression test for #1936: sm_87 multi-shape Linear4bit reboot.
2+
3+
Bug history
4+
-----------
5+
On bnb 0.46.1 against an NVIDIA Jetson Orin (sm_87) at nvpmodel
6+
MAXN_SUPER, the host kernel reboots when three Linear4bit layers are
7+
constructed and forwarded in sequence with the following recipe:
8+
9+
- shape order: monotonically-increasing by output-feature product (A → B → C below)
10+
- quant_type: NF4
11+
- quant_storage: torch.bfloat16 (FSDP-compatible)
12+
- compute_dtype: torch.bfloat16
13+
- compress_statistics (double_quant): True
14+
- no inter-layer memory hygiene (no `del layer` or `empty_cache`)
15+
- batch size: 1
16+
17+
The fault is overwhelmingly cold-start-specific (~78% reboot rate
18+
at cold-start, 0% at warm state across N=29+ warm-state samples).
19+
bnb 0.49.2 also reboots at cold-start (N=1) — the original "fixed
20+
in 0.49.2" framing was a warm-state artifact and was retracted in
21+
a 2026-05-05 comment on the issue. A 256x256 NF4 forward executed
22+
as the first GPU op after boot closes the cold-start race window
23+
(N=3 verified). This test runs the original failing recipe at
24+
sm_87 so any future regression surfaces in CI; CI runners targeting
25+
sm_87 are expected to provide cold-state (e.g., reboot before the
26+
test run) for it to fire reliably.
27+
28+
References
29+
----------
30+
- Issue: https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1936
31+
- Reproduced N=4 across two physically separate Jetson Orin Nano
32+
Super 8GB units; fault travels with silicon + bnb build, not one
33+
defective board.
34+
- A 13-test orthogonal-axis bisection found six axes (shape order,
35+
quant_type, quant_storage, compute_dtype, double_quant, hygiene)
36+
each independently sufficient to prevent the fault. The recipe in
37+
this test is the unique intersection that triggers it on 0.46.1.
38+
39+
Caveats
40+
-------
41+
- On a *broken* bnb the test crashes the host (system reboot), not
42+
just the test runner — failure mode is OS-level. pytest cannot
43+
capture this; absence of test output IS the regression signal.
44+
- The bug is timing-sensitive (race condition); lower-power
45+
nvpmodel modes (15W / 25W) prevent it. The test does not enforce
46+
a power mode — sm_87 CI is expected to run at MAXN.
47+
- Skipped on all non-sm_87 hardware. sm_87 is exclusive to the
48+
Jetson Orin family (Nano / NX / AGX).
49+
"""
50+
51+
import pytest
52+
import torch
53+
54+
import bitsandbytes as bnb
55+
56+
pytestmark = pytest.mark.skipif(
57+
not torch.cuda.is_available() or torch.cuda.get_device_capability() != (8, 7),
58+
reason="Regression test for sm_87 (NVIDIA Jetson Orin family) only",
59+
)
60+
61+
62+
# Real-world shapes from Llama-3 / Qwen-2.5 / Mistral lm_head dimensions,
63+
# selected to match the historical #1936 reproducer. Monotonically
64+
# increasing by output-feature product is load-bearing — bisection found
65+
# all five non-ABC permutations (ACB, BAC, BCA, CAB, CBA) pass cleanly.
66+
# Do not substitute toy shapes here without re-validating against the
67+
# historical repro.
68+
SHAPES_ABC = [
69+
(4096, 32768), # A
70+
(4096, 128256), # B
71+
(3584, 152064), # C
72+
]
73+
74+
75+
def test_linear4bit_multishape_bf16_storage_no_fault():
76+
"""Run the #1936 recipe; assert each forward completes without fault.
77+
78+
Cold-start race: at sm_87 cold-state on a buggy bnb the host reboots
79+
before the third forward returns. At warm-state (any prior bnb-NF4 op
80+
in the same session) every bnb version tested passes — see module
81+
docstring for the cold/warm characterization.
82+
"""
83+
for in_features, out_features in SHAPES_ABC:
84+
layer = bnb.nn.Linear4bit(
85+
in_features,
86+
out_features,
87+
bias=False,
88+
compute_dtype=torch.bfloat16,
89+
compress_statistics=True, # double_quant
90+
quant_type="nf4",
91+
quant_storage=torch.bfloat16, # FSDP-compatible storage path
92+
).to("cuda")
93+
94+
x = torch.randn(1, in_features, dtype=torch.bfloat16, device="cuda")
95+
y = layer(x)
96+
97+
assert y.shape == (1, out_features)
98+
assert y.dtype == torch.bfloat16
99+
100+
# Deliberately no `del layer`, `torch.cuda.empty_cache()`, or
101+
# `torch.cuda.synchronize()` between iterations. Any of those
102+
# individually prevents the historical fault and would mask a
103+
# regression of the 0.49.2 fix.

0 commit comments

Comments
 (0)