Skip to content

fix: make NF4 dequantize torch.compile-safe via local custom_op#25

Open
thad0ctor wants to merge 11 commits into
mainfrom
compile-safe-bnb-dequant
Open

fix: make NF4 dequantize torch.compile-safe via local custom_op#25
thad0ctor wants to merge 11 commits into
mainfrom
compile-safe-bnb-dequant

Conversation

@thad0ctor
Copy link
Copy Markdown
Owner

@thad0ctor thad0ctor commented May 24, 2026

Description

Registers axolotl's NF4 dequant fast path as a torch.library.custom_op (torch.ops.axolotl.nf4_dequantize) with register_fake. dequantize() branches on torch.compiler.is_compiling(): eager calls the ctypes body directly (zero overhead), tracing dispatches through the opaque op so Dynamo can compile around it. Also drops dead legacy-list quant_state branch, unused out= parameter, and the HAS_CUDA_STREAM version gate. Test fixtures rewritten to use real bnb.functional.quantize_4bit artifacts (the old synthetic fixtures used blocksize=32, which is bnb-invalid — the old ctypes path silently skipped validation).

Motivation and Context

torch_compile: true on QLoRA appears to "work" on current bnb (≥0.46 ships registered ops) but Dynamo still encounters the Unsloth-derived axolotl fast path's raw ctypes calls. It can't trace ctypes.c_int(...) (_SimpleCData.__new__) or the FFI symbols, and ends up recompiling bitsandbytes.functional.get_ptr against different tensor dtypes (Byte for packed weight, Float for absmax) until it hits recompile_limit (8) and falls back to eager for that frame. The result is silent compile inefficiency — no crash, just lost throughput.

Wrapping the fast path in a Dynamo-opaque custom op removes both the ctypes trace failure and the get_ptr recompile thrash. Eager path is byte-identical to main.

How has this been tested?

Unit tests (tests/e2e/kernels/test_quantize.py, 6 cases): all pass on torch 2.11.0 + bnb 0.49.1. Cases cover null state, shape preservation, non-square transposed input with value comparison vs bnb.functional.dequantize_4bit(...).t(), non-nested (single-quant) fallback, and torch.compile correctness on both nested and non-nested paths.

Test environment:

  • GPU: NVIDIA RTX PRO 6000 Blackwell Workstation (96 GB)
  • torch 2.11.0+cu130, bitsandbytes 0.49.1, transformers 5.5.4, Linux 6.14, CUDA 13.0
  • Model: Qwen3.5-0.8B
  • Config: QLoRA, lora_mlp_kernel/lora_qkv_kernel/lora_o_kernel: true, micro_batch=2, seq_len=512

Kernel-level dequant bench (eager), main vs branch, 1000 iters per shape:

Shape main median branch median Δ
Qwen 1024² 31.7 µs 31.6 µs -0.3%
Qwen 3072×1024 34.1 µs 33.5 µs -1.7%
Llama8B 4096² 44.8 µs 44.2 µs -1.4%
Llama8B 14336×4096 133.5 µs 133.2 µs -0.3%
Llama8B 4096×14336 135.1 µs 134.0 µs -0.8%
Llama70B 8192² 151.4 µs 150.7 µs -0.5%
Llama70B 8192×28672 485.3 µs 485.3 µs 0.0%

Eager is at parity across the matrix (all within ±2% run-to-run noise).

End-to-end QLoRA training (300 steps, two reversed-order rounds):

Variant r1 wall r2 wall avg sec/step vs main+compile
main + eager 948 s 705 s 2.755
main + compile 737 s 755 s 2.487 baseline
branch + eager 795 s 686 s 2.468
branch + compile 603 s 558 s 1.935 −22.2%

Branch + torch_compile: true is fastest in both rounds, including the round where it ran first (cold caches). Equivalent to ≈ +28% throughput in steps/sec vs main + compile.

Mechanism validation (3-step diagnostic with relevant Dynamo warnings counted):

Dynamo event main branch
_SimpleCData.__new__ (ctypes) trace failure 1 0
recompile_limit (8) hit on bnb.functional.get_ptr 1 0
Recompile reason tensor 'A' dtype mismatch. expected Float, actual Byte (none)
Throughput at step 3 242 tok/s/gpu 315 tok/s/gpu

End-to-end LoRA pipeline: trained 4 steps on Qwen3.5-0.8B with axolotl LoRA fast-path kernels enabled (which exercise dequantize in both forward and backward via autograd.Function), then axolotl merge-lora, then loaded the merged model and generated coherent output. Pipeline clean.

Risk validation matrix (4-step training runs on physical RTX 3090s, confirming no regression across architectural / config dimensions):

# Scenario Result Notes
1 fp16 compute dtype (Qwen3.5-0.8B QLoRA) ✅ 4 steps, no crash exercises if dtype == torch.float16 branch in _ctypes_nf4_dequant
2 DoRA adapter (Qwen3.5-0.8B, peft_use_dora: true) ✅ 4 steps, losses 1.74→0.61 hits kernels/dora.py:122 dequantize() call
3 Different architecture + larger model (Qwen3-Embedding-8B, Qwen3ForCausalLM — standard attention, 8B params) ✅ 4 steps, 11 GiB / 24 GiB on RTX 3090 confirms dequant works on non-Qwen3.5 architectures and at 8B scale
4 FSDP 2-rank multi-GPU (Qwen3.5-0.8B on 2× RTX 3090) ✅ 4 steps, losses 1.74→1.27, per-rank mem logging cross-reduces cleanly custom op axolotl::nf4_dequantize registers idempotently across ranks; per-process CUDA_STREAM cache works correctly

All four scenarios train to completion, exercise dequantize in forward + backward, and report sane losses / finite grad norms.

AI Usage Disclaimer

Yes — Claude Code (Anthropic) was used for implementation, bench scripting, skeptical review passes, and PR drafting. All numbers above are from local runs on the test environment listed; the diff and tests were reviewed before push.

Screenshots (if appropriate)

n/a

Types of changes

  • Bug fix (non-breaking change which fixes an issue)
  • Performance improvement
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Refactor (no functional changes)

Social Handles (Optional)

n/a

Wrap the Unsloth-derived NF4 dequant fast path in a torch.library.custom_op
(axolotl::nf4_dequantize) with a register_fake impl. dequantize() branches on
torch.compiler.is_compiling(): eager calls the ctypes body directly (zero
op-dispatch overhead), while tracing dispatches through the opaque op so
Dynamo can compile around it without graph-breaking on ctypes.c_int(...) or
the foreign-function calls.

Previously, torch.compile on any QLoRA model crashed with ctypes.ArgumentError
the first time a Linear4bit forward fell into the fast path.

Also:
- Drop dead legacy-list quant_state branch (not produced by bnb 0.40+).
- Drop the unused `out=` parameter (no production callers; grep-confirmed).
- Drop the HAS_CUDA_STREAM version gate (axolotl pins bnb >> 0.43.3).
- Rewrite tests/e2e/kernels/test_quantize.py to use real bnb.functional.quantize_4bit
  fixtures instead of synthetic ones whose blocksize=32 was silently invalid
  (the old ctypes path skipped validation). Adds compile regression tests for
  both nested (double-quant) and non-nested paths.

Validated:
- Eager bench vs main at 1024², 4096², 4096x11008 within ±2% noise.
- Unit tests pass on torch 2.11.0 + bnb 0.49.1.
- End-to-end: Qwen3.5-0.8B QLoRA train (lora_mlp/qkv/o_kernel=true, exercises
  dequantize in autograd Function fwd+bwd) + merge-lora pipeline succeed.
@coderabbitai
Copy link
Copy Markdown

coderabbitai Bot commented May 24, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 10dd05b1-3d39-4bde-beeb-ebcf46bc5636

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR refactors NF4 dequantization to use a torch.library.custom_op with torch.compile support and a ctypes fast path, removes legacy API parameters, and rewrites tests to use real bitsandbytes artifacts instead of mocked QuantState objects.

Changes

NF4 Dequantization Refactor

Layer / File(s) Summary
Custom op and fast path infrastructure
src/axolotl/kernels/quantize.py
Registers axolotl::nf4_dequantize custom op backed by a cached CUDA_STREAM and _ctypes_nf4_dequant fast path; includes register_fake for FakeTensor shape inference.
Public API and FP8 dequantization updates
src/axolotl/kernels/quantize.py
Removes out parameter and legacy list quant_state format; adds early returns for None and FP8; adds bnb.functional fallback for non-double-quant; dispatches to custom op during compilation or ctypes otherwise. Updates dequantize_fp8 documentation and tail-block handling.
Test helper and bitsandbytes integration
tests/e2e/kernels/test_quantize.py
Introduces _nf4_pair() to generate real NF4 artifacts via bnb.functional.quantize_4bit with configurable double-quantization, replacing manual QuantState construction.
Dequantization test coverage
tests/e2e/kernels/test_quantize.py
Tests validate null state passthrough, shape/dtype/device preservation, transposition, single-quant fallback, and torch.compile compatibility on double-quant and single-quant paths with numerical closeness assertions.

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 75.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title directly and accurately summarizes the main change: registering the NF4 dequantize implementation as a torch.compile-safe custom op to fix crashes during compilation.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch compile-safe-bnb-dequant

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🧹 Nitpick comments (2)
tests/e2e/kernels/test_quantize.py (2)

57-57: ⚡ Quick win

Make the compile regressions fail on graph breaks.

torch.compile(dequantize) can still succeed by graph-breaking around this code, which means these tests would pass even if the custom op stopped being traceable. Using fullgraph=True would keep the tests aligned with the contract this PR is adding.

Proposed change
-    compiled = torch.compile(dequantize)(packed, quant_state)
+    compiled = torch.compile(dequantize, fullgraph=True)(packed, quant_state)
-    compiled = torch.compile(dequantize)(packed, quant_state)
+    compiled = torch.compile(dequantize, fullgraph=True)(packed, quant_state)

Also applies to: 70-70

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/e2e/kernels/test_quantize.py` at line 57, Replace the calls to
torch.compile(dequantize) so the compiler runs in full-graph mode; specifically
update the compilation at the assignment to compiled (currently "compiled =
torch.compile(dequantize)(packed, quant_state)") to use fullgraph=True (e.g.,
torch.compile(dequantize, fullgraph=True)(...)), and make the same change for
the second occurrence around line 70; this ensures the test exercises true-graph
compilation for the dequantize function.

37-40: ⚡ Quick win

Assert the transposed values, not just the swapped shape.

A shape-only check would still pass if this path returned the wrong layout/content. Comparing against bnb.functional.dequantize_4bit(...).t() would turn this into a behavioral regression test.

Proposed test tightening
-    result = dequantize(packed.t(), quant_state)
-    assert tuple(result.shape) == (shape[1], shape[0])
+    expected = bnb.functional.dequantize_4bit(
+        packed, quant_state, quant_type="nf4"
+    ).t()
+    result = dequantize(packed.t(), quant_state)
+    torch.testing.assert_close(result, expected)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/e2e/kernels/test_quantize.py` around lines 37 - 40, The test currently
only checks shapes; update it to assert the transposed dequantized values match
the reference implementation by comparing the returned tensor from
dequantize(packed.t(), quant_state) against
bnb.functional.dequantize_4bit(packed, quant_state).t() (or equivalently
bnb.functional.dequantize_4bit(packed.t(), quant_state).t()) using an
element-wise comparison such as torch.allclose/torch.equal to catch
layout/content regressions; keep the existing shape assertion if desired but add
the value comparison using the symbols result, packed, quant_state, dequantize,
and bnb.functional.dequantize_4bit.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@src/axolotl/kernels/quantize.py`:
- Around line 44-47: The global caching of CUDA_STREAM causes stale stream
handles to be reused across devices and context changes; update the
dequantize/quantize callsite to query the current stream per call instead of
using the global CUDA_STREAM: remove or stop referencing the global CUDA_STREAM
and call torch.cuda.current_stream(target_device) inside the function (or
maintain a cache keyed by device+current-stream if you must cache) so the stream
passed into the bitsandbytes ctypes kernels always matches W.device and the
callerʼs active torch.cuda.stream context.

---

Nitpick comments:
In `@tests/e2e/kernels/test_quantize.py`:
- Line 57: Replace the calls to torch.compile(dequantize) so the compiler runs
in full-graph mode; specifically update the compilation at the assignment to
compiled (currently "compiled = torch.compile(dequantize)(packed, quant_state)")
to use fullgraph=True (e.g., torch.compile(dequantize, fullgraph=True)(...)),
and make the same change for the second occurrence around line 70; this ensures
the test exercises true-graph compilation for the dequantize function.
- Around line 37-40: The test currently only checks shapes; update it to assert
the transposed dequantized values match the reference implementation by
comparing the returned tensor from dequantize(packed.t(), quant_state) against
bnb.functional.dequantize_4bit(packed, quant_state).t() (or equivalently
bnb.functional.dequantize_4bit(packed.t(), quant_state).t()) using an
element-wise comparison such as torch.allclose/torch.equal to catch
layout/content regressions; keep the existing shape assertion if desired but add
the value comparison using the symbols result, packed, quant_state, dequantize,
and bnb.functional.dequantize_4bit.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 7ac3d340-e621-409d-b3a8-d8588ba27ec0

📥 Commits

Reviewing files that changed from the base of the PR and between dc8f7c7 and 8cb1694.

📒 Files selected for processing (2)
  • src/axolotl/kernels/quantize.py
  • tests/e2e/kernels/test_quantize.py

Comment thread src/axolotl/kernels/quantize.py Outdated
Comment on lines +44 to +47
global CUDA_STREAM
if CUDA_STREAM is None:
CUDA_STREAM = torch.cuda.current_stream(target_device)
stream = CUDA_STREAM
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
set -euo pipefail

# Locate and inspect CUDA_STREAM usage in the file in question
FILE="src/axolotl/kernels/quantize.py"

echo "== Lines around CUDA_STREAM cache =="
python3 - <<'PY'
import itertools
path="src/axolotl/kernels/quantize.py"
start=1
# print around line 30-80
with open(path,'r',encoding='utf-8') as f:
    lines=f.readlines()
for i in range(1, min(len(lines)+1, 120)):
    if 35 <= i <= 90:
        pass
for i in range(35,91):
    if i<=len(lines):
        print(f"{i}: {lines[i-1].rstrip()}")
PY

echo "== Grep full file for CUDA_STREAM occurrences =="
rg -n "CUDA_STREAM" "$FILE"

echo "== Show dequantize function (best effort) =="
python3 - <<'PY'
import re
path="src/axolotl/kernels/quantize.py"
with open(path,'r',encoding='utf-8') as f:
    s=f.read()
# crude extraction: show any function named dequantize*
m=re.search(r"def\s+dequantize\w*\s*\([^)]*\)\s*:", s)
print("Found dequantize* start:", bool(m))
PY

# Use ast-grep if available to extract dequantize implementation (fallback to sed if not)
if command -v ast-grep >/dev/null 2>&1; then
  echo "== ast-grep: extract def dequantize* =="
  ast-grep --pattern 'def dequantize($_, $_): { $$$ }' "$FILE" || true
  ast-grep --pattern 'def dequantize($_): { $$$ }' "$FILE" || true
  ast-grep --pattern 'def dequantize::$_($_, $_): { $$$ }' "$FILE" || true
else
  echo "ast-grep not available; skipping function extraction"
fi

echo "== Inspect call sites where stream is passed to custom ops =="
rg -n "stream|CUDA_STREAM|current_stream|capture|CUDAGraph" "$FILE"

Repository: thad0ctor/axolotl

Length of output: 2645


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="src/axolotl/kernels/quantize.py"

echo "== File header / imports (first 60 lines) =="
sed -n '1,80p' "$FILE" | nl -ba

echo "== Show rest of dequant/op logic around CUDA_STREAM usage (lines 80-170) =="
sed -n '80,180p' "$FILE" | nl -ba

echo "== Search for CUDA graph/capture handling in this file =="
rg -n "capture|CUDAGraph|is_current_stream_capturing|is_capturing|cuda_graph|graph" "$FILE" || true

echo "== Search repo-wide for is_current_stream_capturing usage (quick) =="
rg -n "is_current_stream_capturing|CUDAGraph" . || true

Repository: thad0ctor/axolotl

Length of output: 167


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="src/axolotl/kernels/quantize.py"

echo "== File header / imports (first 90 lines) =="
cat -n "$FILE" | sed -n '1,90p'

echo "== Dequant/op logic around CUDA_STREAM usage (lines 90-220) =="
cat -n "$FILE" | sed -n '90,220p'

echo "== Search for CUDA graph/capture handling in this file =="
rg -n "capture|CUDAGraph|is_current_stream_capturing|is_captur|cuda_graph|graph" "$FILE" || true

echo "== Search repo-wide for is_current_stream_capturing / CUDAGraph (quick) =="
rg -n "is_current_stream_capturing|CUDAGraph" . || true

Repository: thad0ctor/axolotl

Length of output: 7389


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Repo-wide CUDA graph capture usage (CUDAGraph / cuda_graph / graph) =="
rg -n "CUDAGraph|cuda_graph|torch\.cuda\.graphs|graph capture|capture_begin|begin_capture|end_capture" . || true

echo "== Repo-wide stream management (current_stream / StreamContext / stream= / use of CUDA_STREAM-like patterns) =="
rg -n "current_stream\(|torch\.cuda\.Stream\(|stream\=.*current_stream|torch\.cuda\.stream\(" . || true

echo "== Find dequantize usage (to see typical call sites) =="
rg -n "kernels\.quantize import .*dequantize|dequantize\(" . || true

Repository: thad0ctor/axolotl

Length of output: 3083


🏁 Script executed:

#!/bin/bash
set -euo pipefail

echo "== Inspect tests for quantize/dequantize (device assumptions) =="
ls -R tests/e2e/kernels | head -n 50
cat -n tests/e2e/kernels/test_quantize.py | sed -n '1,220p'

echo "== Inspect CUDA graph health check location =="
cat -n src/axolotl/utils/optimizers/adopt.py | sed -n '150,260p'

echo "== Search for current_stream() usage repo-wide (non-regex errors) =="
rg -n "current_stream\(" . || true
rg -n "torch\.cuda\.Stream" . || true

echo "== Search for CUDA graph capture APIs repo-wide =="
rg -n "CUDAGraph|cuda_graph|torch\.cuda\.graphs|capture_begin|begin_capture|end_capture" . || true

Repository: thad0ctor/axolotl

Length of output: 9055


🏁 Script executed:

#!/bin/bash
set -euo pipefail
FILE="src/axolotl/utils/optimizers/adopt.py"

rg -n "_cuda_graph_capture_health_check" "$FILE"
start=$(rg -n "_cuda_graph_capture_health_check" "$FILE" | head -n1 | cut -d: -f1)
echo "== Showing surrounding health check definition (~$start) =="
# print range around start-80 to start+120
from=$((start-80)); if [ $from -lt 1 ]; then from=1; fi
to=$((start+160))
cat -n "$FILE" | sed -n "${from},${to}p"

Repository: thad0ctor/axolotl

Length of output: 10101


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="src/axolotl/utils/optimizers/adopt.py"
rg -n "def _cuda_graph_capture_health_check" "$FILE" || true
line=$(rg -n "def _cuda_graph_capture_health_check" "$FILE" | head -n1 | cut -d: -f1)
if [ -n "${line:-}" ]; then
  from=$((line-120)); if [ $from -lt 1 ]; then from=1; fi
  to=$((line+200))
  echo "== Showing health check definition around line $line =="
  cat -n "$FILE" | sed -n "${from},${to}p"
fi

echo "== Also search for 'capturable' / 'cuda graph' related checks near health check =="
# narrow search window by printing matches with surrounding lines
python3 - <<'PY'
import re
path="src/axolotl/utils/optimizers/adopt.py"
with open(path,'r',encoding='utf-8') as f: lines=f.readlines()
keys=["capturable","cuda graph","CUDAGraph","capture","health_check"]
for i,l in enumerate(lines, start=1):
    if "_cuda_graph_capture_health_check" in l or ("capturable" in l and i>1 and i<400) and False:
        pass
PY

Repository: thad0ctor/axolotl

Length of output: 43


🏁 Script executed:

#!/bin/bash
set -euo pipefail

FILE="src/axolotl/utils/optimizers/adopt.py"

echo "== Where _cuda_graph_capture_health_check is referenced =="
rg -n "_cuda_graph_capture_health_check" "$FILE" || true

echo "== Search for its definition in the repo =="
rg -n "def _cuda_graph_capture_health_check" . || true

echo "== If it’s not defined here, inspect ADOPT optimizer class header/base =="
# show top of file and class definitions
cat -n "$FILE" | sed -n '1,120p'
rg -n "class " "$FILE" || true

Repository: thad0ctor/axolotl

Length of output: 5560


🏁 Script executed:

#!/bin/bash
set -euo pipefail
python3 - <<'PY'
import inspect, torch
from torch.optim import optimizer as opt
# Try common locations
candidates = []
for name in ["Optimizer", "optimizer.Optimizer"]:
    pass
# Get Optimizer class
Optimizer = torch.optim.Optimizer
meth = getattr(Optimizer, "_cuda_graph_capture_health_check", None)
print("torch version:", torch.__version__)
print("Has _cuda_graph_capture_health_check:", meth is not None)
if meth:
    try:
        src = inspect.getsource(meth)
        print("---- _cuda_graph_capture_health_check source (truncated to 200 lines) ----")
        lines=src.splitlines()
        for i,l in enumerate(lines[:200],1):
            print(f"{i:03d}: {l}")
        if len(lines)>200:
            print("... (truncated)")
    except OSError as e:
        print("Could not get source:", e)
PY

Repository: thad0ctor/axolotl

Length of output: 179


🌐 Web query:

PyTorch Optimizer _cuda_graph_capture_health_check implementation current_stream stream handle stale CUDA graph capture

💡 Result:

PyTorch’s implementation detail (as of today’s current upstream sources) is that the CUDAGraph capture “health check” is meant to detect when CUDA graph capture state would become invalid, commonly due to stale CUDA stream references. In particular, a known failure mode is when an autograd node holds a reference to a non-capturing stream from a prior warmup iteration; during CUDA graph capture, the runtime can end up issuing cudaStreamWaitEvent on that stale stream, which pulls that non-capturing stream into the capture and invalidates the whole capture with cudaErrorStreamCaptureInvalidated. PyTorch addresses this by detecting stale references during capture and either raising a clear runtime error (default behavior) or (opt-in) redirecting stale non-capturing streams to the producer’s capturing stream so the capture can proceed. [1] How this relates to “current_stream / stream handle stale capture”: 1) During capture, PyTorch relies on the correct “current stream” semantics: CUDAGraph capture begins and ends on a specific stream (and capture must end on the same stream it began on). [2] 2) If some component (notably autograd) retains a stale CUDA stream handle from warmup, then when capture is running on a different stream, the stale handle can cause synchronization calls that involve a stream not actually participating in the capture. This “pulls” the stale stream into the capture and invalidates it. [1] 3) The “health check” is thus effectively guarding against these stream-mismatch / stale-stream situations: stale capture-related stream handles are detected, and PyTorch either errors (especially when the stale stream is the default stream, stream 0) or overrides by adopting the capturing stream. [1] Where in the codebase / docs this behavior is described: - The CUDAGraph API documentation describes that capture begins/ends on the current stream and that you generally shouldn’t manually call capture_begin/capture_end; these internal capture points depend on the current stream. [3] - Upstream discussion of stale stream references during CUDA graph capture is explicitly documented in the PyTorch PR about “Detect and fix stale stream references in autograd during CUDA graph capture”, including the mechanism (cudaStreamWaitEvent) and the two behaviors (default RuntimeError vs opt-in override). [1] - A separate PR explicitly mentions enabling “graph capture check in optimizer” (including XPU support), indicating that optimizer-side graph capture health checks exist and are an active part of PyTorch’s work. [4] Practical implications / mitigations (what you typically do to avoid stale-stream health-check failures): - Ensure that any warmup that creates autograd nodes / related stream-dependent state does not cause those nodes to permanently bind to a non-capturing stream that later becomes stale for the capture stream. The stale reference scenario described above arises when nodes are created during warmup with one stream current (e.g., default stream), then capture happens later on a side stream. [1] - If you can’t easily restructure so that capturing happens without stale stream bindings, PyTorch provides an opt-in override that redirects stale non-capturing streams to the producer’s capturing stream. [1] If you share your exact stack trace / PyTorch version, I can map the specific “_cuda_graph_capture_health_check” failure message to the specific stale-stream mechanism (e.g., whether it’s default-stream stream 0 vs non-default) and suggest the most targeted fix.

Citations:


Avoid caching torch.cuda.current_stream globally in quantize.py.

src/axolotl/kernels/quantize.py:44-47 caches CUDA_STREAM = torch.cuda.current_stream(target_device) on the first call and reuses it for all later dequantize calls, even when W.device (and/or the caller’s current stream via torch.cuda.stream(...)) changes. That stale stream handle is passed into the bitsandbytes ctypes kernels, risking wrong-device execution and incorrect stream ordering (and potentially interfering with CUDA graph capture). Refresh the stream per call (or at least key/refresh the cache by device and current-stream changes) instead of pinning the first stream forever.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/axolotl/kernels/quantize.py` around lines 44 - 47, The global caching of
CUDA_STREAM causes stale stream handles to be reused across devices and context
changes; update the dequantize/quantize callsite to query the current stream per
call instead of using the global CUDA_STREAM: remove or stop referencing the
global CUDA_STREAM and call torch.cuda.current_stream(target_device) inside the
function (or maintain a cache keyed by device+current-stream if you must cache)
so the stream passed into the bitsandbytes ctypes kernels always matches
W.device and the callerʼs active torch.cuda.stream context.

thad0ctor added 4 commits May 23, 2026 18:15
Addresses CodeRabbit nits on PR #25:
- Docstrings on _nf4_dequantize_op and its register_fake impl to lift
  docstring coverage above the 80% threshold.
- test_dequantize_transposed now asserts value equality vs
  bnb.functional.dequantize_4bit(...).t() in addition to shape, so a
  layout/content regression would actually fail the test.

Skipped CodeRabbit's fullgraph=True nit: tracing under fullgraph=True
currently fails on both main and branch (same _SimpleCData.__new__ trace
error before reaching the is_compiling() branch). Investigating is out
of scope for this PR; default-mode torch.compile is what the fix targets
and is verified working.

Skipped CodeRabbit's CUDA_STREAM caching concern: the cache pattern is
byte-identical to main and removing it costs +5-21% at the kernel level.
Stream-management refactor is a separate concern.
- Module + dequantize() docstrings condensed to single line.
- CUDA_STREAM comment dropped benchmark numbers and version reference.
- Removed task-history language ("no longer crashes", "vs the prior
  implementation", "Regression:") from docstrings; comments now describe
  permanent invariants, not the PR diff.
@github-actions
Copy link
Copy Markdown

github-actions Bot commented May 24, 2026

📖 Documentation Preview:

Deployed on Netlify from commit 9506828

Per CodeRabbit: the single-global stream cache could in principle return
the wrong device's stream for a multi-device-single-process caller. In
practice axolotl always runs one process per CUDA device, so this is
defensive rather than fixing a known reproducer.

Using `setdefault` keeps the cache race benign under threading (worst
case: redundant current_stream() call) and drops the `global` keyword
since the dict is mutated in place, never rebound.
thad0ctor added a commit that referenced this pull request May 24, 2026
Addresses CodeRabbit nits on PR #25:
- Docstrings on _nf4_dequantize_op and its register_fake impl to lift
  docstring coverage above the 80% threshold.
- test_dequantize_transposed now asserts value equality vs
  bnb.functional.dequantize_4bit(...).t() in addition to shape, so a
  layout/content regression would actually fail the test.

Skipped CodeRabbit's fullgraph=True nit: tracing under fullgraph=True
currently fails on both main and branch (same _SimpleCData.__new__ trace
error before reaching the is_compiling() branch). Investigating is out
of scope for this PR; default-mode torch.compile is what the fix targets
and is verified working.

Skipped CodeRabbit's CUDA_STREAM caching concern: the cache pattern is
byte-identical to main and removing it costs +5-21% at the kernel level.
Stream-management refactor is a separate concern.
thad0ctor added a commit that referenced this pull request May 28, 2026
Addresses CodeRabbit nits on PR #25:
- Docstrings on _nf4_dequantize_op and its register_fake impl to lift
  docstring coverage above the 80% threshold.
- test_dequantize_transposed now asserts value equality vs
  bnb.functional.dequantize_4bit(...).t() in addition to shape, so a
  layout/content regression would actually fail the test.

Skipped CodeRabbit's fullgraph=True nit: tracing under fullgraph=True
currently fails on both main and branch (same _SimpleCData.__new__ trace
error before reaching the is_compiling() branch). Investigating is out
of scope for this PR; default-mode torch.compile is what the fix targets
and is verified working.

Skipped CodeRabbit's CUDA_STREAM caching concern: the cache pattern is
byte-identical to main and removing it costs +5-21% at the kernel level.
Stream-management refactor is a separate concern.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants