fix: make NF4 dequantize torch.compile-safe via local custom_op #25
thad0ctor wants to merge 11 commits into
Wrap the Unsloth-derived NF4 dequant fast path in a torch.library.custom_op (axolotl::nf4_dequantize) with a register_fake impl. dequantize() branches on torch.compiler.is_compiling(): eager calls the ctypes body directly (zero op-dispatch overhead), while tracing dispatches through the opaque op so Dynamo can compile around it without graph-breaking on ctypes.c_int(...) or the foreign-function calls. Previously, torch.compile on any QLoRA model crashed with ctypes.ArgumentError the first time a Linear4bit forward fell into the fast path.
Also:
- Drop dead legacy-list quant_state branch (not produced by bnb 0.40+).
- Drop the unused `out=` parameter (no production callers; grep-confirmed).
- Drop the HAS_CUDA_STREAM version gate (axolotl pins bnb >> 0.43.3).
- Rewrite tests/e2e/kernels/test_quantize.py to use real bnb.functional.quantize_4bit fixtures instead of synthetic ones whose blocksize=32 was silently invalid (the old ctypes path skipped validation). Adds compile regression tests for both nested (double-quant) and non-nested paths.
Validated:
- Eager bench vs main at 1024², 4096², 4096x11008 within ±2% noise.
- Unit tests pass on torch 2.11.0 + bnb 0.49.1.
- End-to-end: Qwen3.5-0.8B QLoRA train (lora_mlp/qkv/o_kernel=true, exercises dequantize in autograd Function fwd+bwd) + merge-lora pipeline succeed.
📝 Walkthrough
This PR refactors NF4 dequantization to use a torch.library.custom_op with torch.compile support and a ctypes fast path, removes legacy API parameters, and rewrites tests to use real bitsandbytes artifacts instead of mocked QuantState objects.
Changes
NF4 Dequantization Refactor
🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 1
🧹 Nitpick comments (2)
tests/e2e/kernels/test_quantize.py (2)
57-57: ⚡ Quick win
Make the compile regressions fail on graph breaks.
torch.compile(dequantize) can still succeed by graph-breaking around this code, which means these tests would pass even if the custom op stopped being traceable. Using fullgraph=True would keep the tests aligned with the contract this PR is adding.
Proposed change:
- compiled = torch.compile(dequantize)(packed, quant_state)
+ compiled = torch.compile(dequantize, fullgraph=True)(packed, quant_state)
Also applies to: 70-70
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/e2e/kernels/test_quantize.py` at line 57, Replace the calls to torch.compile(dequantize) so the compiler runs in full-graph mode; specifically update the compilation at the assignment to compiled (currently "compiled = torch.compile(dequantize)(packed, quant_state)") to use fullgraph=True (e.g., torch.compile(dequantize, fullgraph=True)(...)), and make the same change for the second occurrence around line 70; this ensures the test exercises true-graph compilation for the dequantize function.
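To illustrate why the reviewer suggests fullgraph=True, the sketch below compiles a toy function that contains an explicit graph break. The function toy_with_break and the choice of backend="eager" (picked so no C++ toolchain is needed) are assumptions made for illustration; they are not part of this PR's tests.

```python
import torch
import torch._dynamo


def toy_with_break(x: torch.Tensor) -> torch.Tensor:
    y = x * 2
    torch._dynamo.graph_break()  # force Dynamo to split the graph here
    return y + 1


x = torch.ones(3)

# Default mode: Dynamo splits around the break and still returns a result,
# so a test that only checks the output would not notice the break.
lenient = torch.compile(toy_with_break, backend="eager")
lenient_out = lenient(x)

# Clear cached compilations so the strict run re-traces the function.
torch._dynamo.reset()

# fullgraph=True: the same graph break is reported as an error.
strict = torch.compile(toy_with_break, backend="eager", fullgraph=True)
try:
    strict(x)
    strict_raised = False
except Exception:
    strict_raised = True
```

A regression test written with fullgraph=True therefore fails loudly if the compiled function stops being traceable as a single graph, which is the contract the reviewer wants the tests to enforce.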
37-40: ⚡ Quick win
Assert the transposed values, not just the swapped shape.
A shape-only check would still pass if this path returned the wrong layout/content. Comparing against bnb.functional.dequantize_4bit(...).t() would turn this into a behavioral regression test.
Proposed test tightening:
- result = dequantize(packed.t(), quant_state)
- assert tuple(result.shape) == (shape[1], shape[0])
+ expected = bnb.functional.dequantize_4bit(
+     packed, quant_state, quant_type="nf4"
+ ).t()
+ result = dequantize(packed.t(), quant_state)
+ torch.testing.assert_close(result, expected)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/e2e/kernels/test_quantize.py` around lines 37 - 40, The test currently only checks shapes; update it to assert the transposed dequantized values match the reference implementation by comparing the returned tensor from dequantize(packed.t(), quant_state) against bnb.functional.dequantize_4bit(packed, quant_state).t() (or equivalently bnb.functional.dequantize_4bit(packed.t(), quant_state).t()) using an element-wise comparison such as torch.allclose/torch.equal to catch layout/content regressions; keep the existing shape assertion if desired but add the value comparison using the symbols result, packed, quant_state, dequantize, and bnb.functional.dequantize_4bit.
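The gap between a shape-only assertion and a value assertion can be seen in a small CPU-only sketch. The tensors reference and wrong_layout are made-up stand-ins for the bnb reference output and a buggy fast-path result; no bitsandbytes call is made here.

```python
import torch

# Stand-in for the reference dequantized weight (2 rows, 3 columns).
reference = torch.arange(6, dtype=torch.float32).reshape(2, 3)

# Correct transposed result: shape (3, 2) with rows and columns swapped.
expected = reference.t()

# Buggy result: same (3, 2) shape, but the elements are merely re-chunked.
wrong_layout = reference.reshape(3, 2)

# A shape-only check cannot tell the two apart.
shape_check_passes = tuple(wrong_layout.shape) == tuple(expected.shape)

# A value check does catch the layout error.
try:
    torch.testing.assert_close(wrong_layout, expected)
    value_check_passes = True
except AssertionError:
    value_check_passes = False
```

Here the shape comparison succeeds while the value comparison fails, which is the kind of layout/content regression the tightened test is meant to catch.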
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@src/axolotl/kernels/quantize.py`:
- Around line 44-47: The global caching of CUDA_STREAM causes stale stream
handles to be reused across devices and context changes; update the
dequantize/quantize callsite to query the current stream per call instead of
using the global CUDA_STREAM: remove or stop referencing the global CUDA_STREAM
and call torch.cuda.current_stream(target_device) inside the function (or
maintain a cache keyed by device+current-stream if you must cache) so the stream
passed into the bitsandbytes ctypes kernels always matches W.device and the
callerʼs active torch.cuda.stream context.
---
Nitpick comments:
In `@tests/e2e/kernels/test_quantize.py`:
- Line 57: Replace the calls to torch.compile(dequantize) so the compiler runs
in full-graph mode; specifically update the compilation at the assignment to
compiled (currently "compiled = torch.compile(dequantize)(packed, quant_state)")
to use fullgraph=True (e.g., torch.compile(dequantize, fullgraph=True)(...)),
and make the same change for the second occurrence around line 70; this ensures
the test exercises true-graph compilation for the dequantize function.
- Around line 37-40: The test currently only checks shapes; update it to assert
the transposed dequantized values match the reference implementation by
comparing the returned tensor from dequantize(packed.t(), quant_state) against
bnb.functional.dequantize_4bit(packed, quant_state).t() (or equivalently
bnb.functional.dequantize_4bit(packed.t(), quant_state).t()) using an
element-wise comparison such as torch.allclose/torch.equal to catch
layout/content regressions; keep the existing shape assertion if desired but add
the value comparison using the symbols result, packed, quant_state, dequantize,
and bnb.functional.dequantize_4bit.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 7ac3d340-e621-409d-b3a8-d8588ba27ec0
📒 Files selected for processing (2)
src/axolotl/kernels/quantize.py
tests/e2e/kernels/test_quantize.py
global CUDA_STREAM
if CUDA_STREAM is None:
    CUDA_STREAM = torch.cuda.current_stream(target_device)
stream = CUDA_STREAM
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
set -euo pipefail
# Locate and inspect CUDA_STREAM usage in the file in question
FILE="src/axolotl/kernels/quantize.py"
echo "== Lines around CUDA_STREAM cache =="
python3 - <<'PY'
import itertools
path="src/axolotl/kernels/quantize.py"
start=1
# print around line 30-80
with open(path,'r',encoding='utf-8') as f:
    lines=f.readlines()
for i in range(1, min(len(lines)+1, 120)):
    if 35 <= i <= 90:
        pass
for i in range(35,91):
    if i<=len(lines):
        print(f"{i}: {lines[i-1].rstrip()}")
PY
echo "== Grep full file for CUDA_STREAM occurrences =="
rg -n "CUDA_STREAM" "$FILE"
echo "== Show dequantize function (best effort) =="
python3 - <<'PY'
import re
path="src/axolotl/kernels/quantize.py"
with open(path,'r',encoding='utf-8') as f:
    s=f.read()
# crude extraction: show any function named dequantize*
m=re.search(r"def\s+dequantize\w*\s*\([^)]*\)\s*:", s)
print("Found dequantize* start:", bool(m))
PY
# Use ast-grep if available to extract dequantize implementation (fallback to sed if not)
if command -v ast-grep >/dev/null 2>&1; then
echo "== ast-grep: extract def dequantize* =="
ast-grep --pattern 'def dequantize($_, $_): { $$$ }' "$FILE" || true
ast-grep --pattern 'def dequantize($_): { $$$ }' "$FILE" || true
ast-grep --pattern 'def dequantize::$_($_, $_): { $$$ }' "$FILE" || true
else
echo "ast-grep not available; skipping function extraction"
fi
echo "== Inspect call sites where stream is passed to custom ops =="
rg -n "stream|CUDA_STREAM|current_stream|capture|CUDAGraph" "$FILE"
Repository: thad0ctor/axolotl
Length of output: 2645
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="src/axolotl/kernels/quantize.py"
echo "== File header / imports (first 60 lines) =="
sed -n '1,80p' "$FILE" | nl -ba
echo "== Show rest of dequant/op logic around CUDA_STREAM usage (lines 80-170) =="
sed -n '80,180p' "$FILE" | nl -ba
echo "== Search for CUDA graph/capture handling in this file =="
rg -n "capture|CUDAGraph|is_current_stream_capturing|is_capturing|cuda_graph|graph" "$FILE" || true
echo "== Search repo-wide for is_current_stream_capturing usage (quick) =="
rg -n "is_current_stream_capturing|CUDAGraph" . || true
Repository: thad0ctor/axolotl
Length of output: 167
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="src/axolotl/kernels/quantize.py"
echo "== File header / imports (first 90 lines) =="
cat -n "$FILE" | sed -n '1,90p'
echo "== Dequant/op logic around CUDA_STREAM usage (lines 90-220) =="
cat -n "$FILE" | sed -n '90,220p'
echo "== Search for CUDA graph/capture handling in this file =="
rg -n "capture|CUDAGraph|is_current_stream_capturing|is_captur|cuda_graph|graph" "$FILE" || true
echo "== Search repo-wide for is_current_stream_capturing / CUDAGraph (quick) =="
rg -n "is_current_stream_capturing|CUDAGraph" . || true
Repository: thad0ctor/axolotl
Length of output: 7389
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Repo-wide CUDA graph capture usage (CUDAGraph / cuda_graph / graph) =="
rg -n "CUDAGraph|cuda_graph|torch\.cuda\.graphs|graph capture|capture_begin|begin_capture|end_capture" . || true
echo "== Repo-wide stream management (current_stream / StreamContext / stream= / use of CUDA_STREAM-like patterns) =="
rg -n "current_stream\(|torch\.cuda\.Stream\(|stream\=.*current_stream|torch\.cuda\.stream\(" . || true
echo "== Find dequantize usage (to see typical call sites) =="
rg -n "kernels\.quantize import .*dequantize|dequantize\(" . || true
Repository: thad0ctor/axolotl
Length of output: 3083
🏁 Script executed:
#!/bin/bash
set -euo pipefail
echo "== Inspect tests for quantize/dequantize (device assumptions) =="
ls -R tests/e2e/kernels | head -n 50
cat -n tests/e2e/kernels/test_quantize.py | sed -n '1,220p'
echo "== Inspect CUDA graph health check location =="
cat -n src/axolotl/utils/optimizers/adopt.py | sed -n '150,260p'
echo "== Search for current_stream() usage repo-wide (non-regex errors) =="
rg -n "current_stream\(" . || true
rg -n "torch\.cuda\.Stream" . || true
echo "== Search for CUDA graph capture APIs repo-wide =="
rg -n "CUDAGraph|cuda_graph|torch\.cuda\.graphs|capture_begin|begin_capture|end_capture" . || true
Repository: thad0ctor/axolotl
Length of output: 9055
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="src/axolotl/utils/optimizers/adopt.py"
rg -n "_cuda_graph_capture_health_check" "$FILE"
start=$(rg -n "_cuda_graph_capture_health_check" "$FILE" | head -n1 | cut -d: -f1)
echo "== Showing surrounding health check definition (~$start) =="
# print range around start-80 to start+120
from=$((start-80)); if [ $from -lt 1 ]; then from=1; fi
to=$((start+160))
cat -n "$FILE" | sed -n "${from},${to}p"
Repository: thad0ctor/axolotl
Length of output: 10101
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="src/axolotl/utils/optimizers/adopt.py"
rg -n "def _cuda_graph_capture_health_check" "$FILE" || true
line=$(rg -n "def _cuda_graph_capture_health_check" "$FILE" | head -n1 | cut -d: -f1)
if [ -n "${line:-}" ]; then
from=$((line-120)); if [ $from -lt 1 ]; then from=1; fi
to=$((line+200))
echo "== Showing health check definition around line $line =="
cat -n "$FILE" | sed -n "${from},${to}p"
fi
echo "== Also search for 'capturable' / 'cuda graph' related checks near health check =="
# narrow search window by printing matches with surrounding lines
python3 - <<'PY'
import re
path="src/axolotl/utils/optimizers/adopt.py"
with open(path,'r',encoding='utf-8') as f: lines=f.readlines()
keys=["capturable","cuda graph","CUDAGraph","capture","health_check"]
for i,l in enumerate(lines, start=1):
    if "_cuda_graph_capture_health_check" in l or ("capturable" in l and i>1 and i<400) and False:
        pass
PY
Repository: thad0ctor/axolotl
Length of output: 43
🏁 Script executed:
#!/bin/bash
set -euo pipefail
FILE="src/axolotl/utils/optimizers/adopt.py"
echo "== Where _cuda_graph_capture_health_check is referenced =="
rg -n "_cuda_graph_capture_health_check" "$FILE" || true
echo "== Search for its definition in the repo =="
rg -n "def _cuda_graph_capture_health_check" . || true
echo "== If it’s not defined here, inspect ADOPT optimizer class header/base =="
# show top of file and class definitions
cat -n "$FILE" | sed -n '1,120p'
rg -n "class " "$FILE" || true
Repository: thad0ctor/axolotl
Length of output: 5560
🏁 Script executed:
#!/bin/bash
set -euo pipefail
python3 - <<'PY'
import inspect, torch
from torch.optim import optimizer as opt
# Try common locations
candidates = []
for name in ["Optimizer", "optimizer.Optimizer"]:
    pass
# Get Optimizer class
Optimizer = torch.optim.Optimizer
meth = getattr(Optimizer, "_cuda_graph_capture_health_check", None)
print("torch version:", torch.__version__)
print("Has _cuda_graph_capture_health_check:", meth is not None)
if meth:
    try:
        src = inspect.getsource(meth)
        print("---- _cuda_graph_capture_health_check source (truncated to 200 lines) ----")
        lines=src.splitlines()
        for i,l in enumerate(lines[:200],1):
            print(f"{i:03d}: {l}")
        if len(lines)>200:
            print("... (truncated)")
    except OSError as e:
        print("Could not get source:", e)
PY
Repository: thad0ctor/axolotl
Length of output: 179
🌐 Web query:
PyTorch Optimizer _cuda_graph_capture_health_check implementation current_stream stream handle stale CUDA graph capture
💡 Result:
PyTorch’s implementation detail (as of today’s current upstream sources) is that the CUDAGraph capture “health check” is meant to detect when CUDA graph capture state would become invalid, commonly due to stale CUDA stream references. In particular, a known failure mode is when an autograd node holds a reference to a non-capturing stream from a prior warmup iteration; during CUDA graph capture, the runtime can end up issuing cudaStreamWaitEvent on that stale stream, which pulls that non-capturing stream into the capture and invalidates the whole capture with cudaErrorStreamCaptureInvalidated. PyTorch addresses this by detecting stale references during capture and either raising a clear runtime error (default behavior) or (opt-in) redirecting stale non-capturing streams to the producer’s capturing stream so the capture can proceed. [1]
How this relates to “current_stream / stream handle stale capture”:
1) During capture, PyTorch relies on the correct “current stream” semantics: CUDAGraph capture begins and ends on a specific stream (and capture must end on the same stream it began on). [2]
2) If some component (notably autograd) retains a stale CUDA stream handle from warmup, then when capture is running on a different stream, the stale handle can cause synchronization calls that involve a stream not actually participating in the capture. This “pulls” the stale stream into the capture and invalidates it. [1]
3) The “health check” is thus effectively guarding against these stream-mismatch / stale-stream situations: stale capture-related stream handles are detected, and PyTorch either errors (especially when the stale stream is the default stream, stream 0) or overrides by adopting the capturing stream. [1]
Where in the codebase / docs this behavior is described:
- The CUDAGraph API documentation describes that capture begins/ends on the current stream and that you generally shouldn’t manually call capture_begin/capture_end; these internal capture points depend on the current stream. [3]
- Upstream discussion of stale stream references during CUDA graph capture is explicitly documented in the PyTorch PR about “Detect and fix stale stream references in autograd during CUDA graph capture”, including the mechanism (cudaStreamWaitEvent) and the two behaviors (default RuntimeError vs opt-in override). [1]
- A separate PR explicitly mentions enabling “graph capture check in optimizer” (including XPU support), indicating that optimizer-side graph capture health checks exist and are an active part of PyTorch’s work. [4]
Practical implications / mitigations (what you typically do to avoid stale-stream health-check failures):
- Ensure that any warmup that creates autograd nodes / related stream-dependent state does not cause those nodes to permanently bind to a non-capturing stream that later becomes stale for the capture stream. The stale reference scenario described above arises when nodes are created during warmup with one stream current (e.g., default stream), then capture happens later on a side stream. [1]
- If you can’t easily restructure so that capturing happens without stale stream bindings, PyTorch provides an opt-in override that redirects stale non-capturing streams to the producer’s capturing stream. [1]
If you share your exact stack trace / PyTorch version, I can map the specific “_cuda_graph_capture_health_check” failure message to the specific stale-stream mechanism (e.g., whether it’s default-stream stream 0 vs non-default) and suggest the most targeted fix.
Citations:
- 1: Detect and fix stale stream references in autograd during CUDA graph capture pytorch/pytorch#180090
- 2: > if graph capture is thread local pytorch/pytorch#137844
- 3: https://docs.pytorch.org/docs/2.11/generated/torch.cuda.graphs.CUDAGraph.html
- 4: Enable XPU support to graph capture check in optimizer pytorch/pytorch#172759
Avoid caching torch.cuda.current_stream globally in quantize.py.
src/axolotl/kernels/quantize.py:44-47 caches CUDA_STREAM = torch.cuda.current_stream(target_device) on the first call and reuses it for all later dequantize calls, even when W.device (and/or the caller’s current stream via torch.cuda.stream(...)) changes. That stale stream handle is passed into the bitsandbytes ctypes kernels, risking wrong-device execution and incorrect stream ordering (and potentially interfering with CUDA graph capture). Refresh the stream per call (or at least key/refresh the cache by device and current-stream changes) instead of pinning the first stream forever.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@src/axolotl/kernels/quantize.py` around lines 44 - 47, The global caching of
CUDA_STREAM causes stale stream handles to be reused across devices and context
changes; update the dequantize/quantize callsite to query the current stream per
call instead of using the global CUDA_STREAM: remove or stop referencing the
global CUDA_STREAM and call torch.cuda.current_stream(target_device) inside the
function (or maintain a cache keyed by device+current-stream if you must cache)
so the stream passed into the bitsandbytes ctypes kernels always matches
W.device and the callerʼs active torch.cuda.stream context.
Addresses CodeRabbit nits on PR #25:
- Docstrings on _nf4_dequantize_op and its register_fake impl to lift docstring coverage above the 80% threshold.
- test_dequantize_transposed now asserts value equality vs bnb.functional.dequantize_4bit(...).t() in addition to shape, so a layout/content regression would actually fail the test.
Skipped CodeRabbit's fullgraph=True nit: tracing under fullgraph=True currently fails on both main and branch (same _SimpleCData.__new__ trace error before reaching the is_compiling() branch). Investigating is out of scope for this PR; default-mode torch.compile is what the fix targets and is verified working.
Skipped CodeRabbit's CUDA_STREAM caching concern: the cache pattern is byte-identical to main and removing it costs +5-21% at the kernel level. Stream-management refactor is a separate concern.
- Module + dequantize() docstrings condensed to single line.
- CUDA_STREAM comment dropped benchmark numbers and version reference.
- Removed task-history language ("no longer crashes", "vs the prior
implementation", "Regression:") from docstrings; comments now describe
permanent invariants, not the PR diff.
📖 Documentation Preview: Deployed on Netlify from commit 9506828 |
Per CodeRabbit: the single-global stream cache could in principle return the wrong device's stream for a multi-device-single-process caller. In practice axolotl always runs one process per CUDA device, so this is defensive rather than fixing a known reproducer. Using `setdefault` keeps the cache race benign under threading (worst case: redundant current_stream() call) and drops the `global` keyword since the dict is mutated in place, never rebound.
Description
Registers axolotl's NF4 dequant fast path as a torch.library.custom_op (torch.ops.axolotl.nf4_dequantize) with register_fake. dequantize() branches on torch.compiler.is_compiling(): eager calls the ctypes body directly (zero overhead), tracing dispatches through the opaque op so Dynamo can compile around it. Also drops dead legacy-list quant_state branch, unused out= parameter, and the HAS_CUDA_STREAM version gate. Test fixtures rewritten to use real bnb.functional.quantize_4bit artifacts (the old synthetic fixtures used blocksize=32, which is bnb-invalid; the old ctypes path silently skipped validation).
Motivation and Context
torch_compile: true on QLoRA appears to "work" on current bnb (≥0.46 ships registered ops) but Dynamo still encounters the Unsloth-derived axolotl fast path's raw ctypes calls. It can't trace ctypes.c_int(...) (_SimpleCData.__new__) or the FFI symbols, and ends up recompiling bitsandbytes.functional.get_ptr against different tensor dtypes (Byte for packed weight, Float for absmax) until it hits recompile_limit (8) and falls back to eager for that frame. The result is silent compile inefficiency: no crash, just lost throughput.
Wrapping the fast path in a Dynamo-opaque custom op removes both the ctypes trace failure and the get_ptr recompile thrash. Eager path is byte-identical to main.
How has this been tested?
Unit tests (tests/e2e/kernels/test_quantize.py, 6 cases): all pass on torch 2.11.0 + bnb 0.49.1. Cases cover null state, shape preservation, non-square transposed input with value comparison vs bnb.functional.dequantize_4bit(...).t(), non-nested (single-quant) fallback, and torch.compile correctness on both nested and non-nested paths.
Test environment:
- lora_mlp_kernel / lora_qkv_kernel / lora_o_kernel: true, micro_batch=2, seq_len=512
Kernel-level dequant bench (eager), main vs branch, 1000 iters per shape:
Eager is at parity across the matrix (all within ±2% run-to-run noise).
End-to-end QLoRA training (300 steps, two reversed-order rounds):
Branch + torch_compile: true is fastest in both rounds, including the round where it ran first (cold caches). Equivalent to ≈ +28% throughput in steps/sec vs main + compile.
Mechanism validation (3-step diagnostic with relevant Dynamo warnings counted):
- _SimpleCData.__new__ (ctypes) trace failure
- recompile_limit (8) hit on bnb.functional.get_ptr
- tensor 'A' dtype mismatch. expected Float, actual Byte
End-to-end LoRA pipeline: trained 4 steps on Qwen3.5-0.8B with axolotl LoRA fast-path kernels enabled (which exercise dequantize in both forward and backward via autograd.Function), then axolotl merge-lora, then loaded the merged model and generated coherent output. Pipeline clean.
Risk validation matrix (4-step training runs on physical RTX 3090s, confirming no regression across architectural / config dimensions):
- if dtype == torch.float16 branch in _ctypes_nf4_dequant
- peft_use_dora: true)
- kernels/dora.py:122 dequantize() call
- Qwen3ForCausalLM, standard attention, 8B params)
- axolotl::nf4_dequantize registers idempotently across ranks; per-process CUDA_STREAM cache works correctly
All four scenarios train to completion, exercise dequantize in forward + backward, and report sane losses / finite grad norms.
AI Usage Disclaimer
Yes — Claude Code (Anthropic) was used for implementation, bench scripting, skeptical review passes, and PR drafting. All numbers above are from local runs on the test environment listed; the diff and tests were reviewed before push.
Screenshots (if appropriate)
n/a
Types of changes
Social Handles (Optional)
n/a