Skip to content

feat(compile): expose torch_compile_options for Inductor flag tuning#34

Open
thad0ctor wants to merge 10 commits into
mainfrom
torch-compile-options-clean
Open

feat(compile): expose torch_compile_options for Inductor flag tuning#34
thad0ctor wants to merge 10 commits into
mainfrom
torch-compile-options-clean

Conversation

@thad0ctor

@thad0ctor thad0ctor commented May 30, 2026

Copy link
Copy Markdown
Owner

Description

Adds torch_compile_options: dict[str, bool|int|float|str] | None — an opt-in allowlist of 10 torch._inductor.config knobs that HF Trainer's TorchDynamoPlugin (transformers 5.8.1) never forwards. Default None → behavior identical to today. The allowlist lives in enums.py (INDUCTOR_COMPILE_OPTIONS_ALLOWLIST) as the single source of truth for the validators and the runtime apply.

torch_compile: true
torch_compile_mode: max-autotune
torch_compile_options:
  max_autotune_gemm: true   # −2.6% on sm_120 (≥32 GB); regresses on sm_86 — arch-specific
  • Validated at axolotl preprocess: a field_validator rejects non-allowlisted keys; a model_validator rejects options without torch_compile; a second model_validator warns on triton.cudagraphs + sample_packing (cudagraphs need static shapes, packed sequences are dynamic).
  • Runtime apply (TrainerBuilderBase._apply_torch_compile_options) setattrs each validated key on torch._inductor.config before compile; dotted keys (triton.cudagraphs) work natively.
  • Adds an example YAML and updates the docs/optimizations.qmd sm_86 note. Purely additive on the existing torch_compile* fields.

Motivation and Context

HF Trainer doesn't forward kwargs to torch._inductor.config, so flags like max_autotune_gemm previously required monkey-patching before axolotl train. These flags have no single correct default — their payoff is architecture-specific (see below) — so they're exposed as opt-in config rather than baked in.

How has this been tested?

Unit tests (torch 2.11.0+cu130 / transformers 5.8.1):

  • Schema (TestTorchCompileValidation, 9): allowlist round-trip, rejects disallowed key, requires torch_compile, and the new triton.cudagraphs+sample_packing warning.
  • Runtime (TestApplyTorchCompileOptions, 7): flags actually flip torch._inductor.config, all 10 allowlist keys verified present (rename sentinel, incl. dotted key), dotted/multi/empty-dict apply, wiring.

Benchmarks — Qwen3-8B + LoRA r=16 + bf16 + sdpa + seq 2048 + grad_ckpt + alpaca, 2 seeds, fresh process per (cond, seed), tail-100-median ms/step. Conds: a eager · b fused_attn_kernel (baseline) · c b + torch_compile · then single inductor flags on c.

torch_compile (c vs b) is a large win that scales with arch, and cuts peak memory 19,992 → 19,614 MB on every card:

c vs b vanilla 3090 (sm_86) 3090 Ti (sm_86) 5090 (sm_120)
compile win −14.25% −16.55% −19.06%

Of the 10 flags, only max_autotune_gemm has a real, arch-dependent effect (Δ vs plain compile); the rest are washes or torch-2.11-default no-ops on both arches:

flag sm_86 (3090 Ti) sm_120 (5090)
max_autotune_gemm +13% regression −2.57% gain
coordinate_descent_tuning wash wash
triton.cudagraphs wash (+mem pool) wash (+mem pool)

max_autotune_gemm is best run alone — a 6-combo f+X probe on sm_120 found nothing beats it (coordinate_descent_tuning and triton.cudagraphs each erode it). It hits the same 101 KB shared-mem limit on both arches (consumer Blackwell ≈ Ampere here), but the autotuner nets better tiles on sm_120, and its +2.2 GB peak needs ≥32 GB. Recommend on sm_120 + ≥32 GB; leave off on sm_86.

Full sm_120 single-flag sweep (RTX 5090), showing what does and doesn't help:

cond config mean ms peak MB Δ vs b Δ vs c
a eager 928.9 19,992 −0.78%
b fused (baseline) 936.2 19,992 0.00%
c + torch_compile 757.8 19,614 −19.06% 0.00%
d + coordinate_descent_tuning 763.0 19,614 −18.50% +0.69% wash
e + shape_padding+epilogue_fusion 756.0 19,614 −19.24% −0.23% wash
f + max_autotune_gemm 738.3 21,836 −21.14% −2.57%
g + triton.cudagraphs 757.7 20,207 −19.06% −0.00% wash
l + all 10 keys 752.6 21,828 −19.60% −0.67%

AI Usage Disclaimer

Yes — Opus 4.8 used throughout

Screenshots (if appropriate)

Types of changes

  • New feature (non-breaking) — torch_compile_options config field; opt-in Inductor flag plumbing

Social Handles (Optional)

Adds a `torch_compile_options: dict[str, bool | int | float | str] | None`
schema field that lets users opt into a small allowlist of
`torch._inductor.config` knobs that HF Trainer otherwise does not expose.

Current allowlist (INDUCTOR_COMPILE_OPTIONS_ALLOWLIST in
`src/axolotl/utils/schemas/enums.py`):

  - coordinate_descent_tuning             [benched on 8B; +0.5%]
  - coordinate_descent_check_all_directions
  - shape_padding                         [torch 2.11 default]
  - epilogue_fusion                       [torch 2.11 default]
  - max_autotune_gemm
  - fx_graph_cache                        [torch 2.11 default]
  - assume_aligned_inputs
  - comprehensive_padding                 [torch 2.11 default]
  - decompose_mem_bound_mm                [experimental]
  - triton.cudagraphs                     [breaks with dynamic shapes]

HF Trainer's TorchDynamoPlugin (transformers 5.8.1) only forwards
torch_compile / torch_compile_backend / torch_compile_mode; it does not
forward kwargs / options to torch._inductor.config. Users wanting to flip
e.g. coordinate_descent_tuning previously had to monkey-patch
torch._inductor.config themselves before invoking axolotl. This change
gives them a config-driven path. ConfigModule supports dotted setattr
natively, so triton.cudagraphs and other nested keys work without any
path-walking in the runtime apply.

Validation is enforced at schema time so failures surface during
`axolotl preprocess`, not after model load:

  - field_validator rejects keys outside the allowlist with a clear
    "Allowed: ..." error.
  - model_validator rejects torch_compile_options set without
    torch_compile enabled (`auto` allowed through, since the
    auto-resolver may flip it true downstream).

The runtime apply path (`_apply_torch_compile_options` in
core/builders/base.py) trusts the validated schema and just iterates the
dict setattr-ing on torch._inductor.config.

Bench (Qwen3-8B + LoRA r=16 [q,k,v,o] + sdpa + bf16 + grad_ckpt +
seq=2048 + torch 2.11.0+cu130 + sm_86 / RTX 3090 Ti, 5 conditions x
2 seeds x 2000 steps, fresh process per replicate):

  | cond | mean ms/step | delta vs fused-only baseline |
  |---|---|---|
  | eager (no fused, no compile)        | 2520.6 | +2.8 % |
  | fused_attn_kernel only              | 2452.6 |  0.0 % |
  | fused + torch_compile (default)     | 2055.1 | -16.2 % |
  | + coordinate_descent_tuning         | 2045.0 | -16.6 % |
  | + shape_padding + epilogue_fusion   | 2044.0 | -16.7 % |

- coordinate_descent_tuning adds ~0.5 % steady-state on top of plain
  compile on this rig (consistent across both seeds and across
  tail-100 / tail-500 / tail-1000 / tail-2000 windows); cold-start
  adds 10-30 s per unique kernel signature on the first compile.
- shape_padding and epilogue_fusion are already True by default in torch
  2.11; explicit setting is a no-op vs cond d on this version. The
  example YAML therefore only sets coordinate_descent_tuning to avoid
  misleading readers into thinking the other two contributed to the delta.
- The other 6 allowlist keys are exposed for user experimentation; a
  follow-up bench (conds f-l) is in progress on the same rig.

Files touched:

- src/axolotl/utils/schemas/enums.py: new
  INDUCTOR_COMPILE_OPTIONS_ALLOWLIST frozenset (10 keys).
- src/axolotl/utils/schemas/config.py: new schema field +
  validate_torch_compile_options field_validator +
  check_torch_compile_options_requires_compile model_validator.
- src/axolotl/core/builders/base.py: _apply_torch_compile_options
  staticmethod called from _configure_torch_compile when
  torch_compile_options is set.
- tests/conftest.py: lifted capture_axolotl_warnings helper out of
  tests/test_attn_implementation.py (previously duplicated).
- tests/test_attn_implementation.py: imports capture_axolotl_warnings
  from tests.conftest instead of defining its own copy.
- tests/core/test_builders.py: TestApplyTorchCompileOptions covers the
  runtime path (allowlisted apply, dotted-key apply, multi-flag apply,
  empty-dict no-op, mock-based wiring assertions, and a regression
  sentinel that fails if torch renames any allowlisted flag).
- tests/patched/test_validation.py: TestTorchCompileValidation extended
  with rejects_disallowed_key, requires_torch_compile_enabled,
  rejects_when_torch_compile_false, and with_auto_compile_passes.
- examples/qwen3/8b-lora-fused-attn-compile.yaml: example stacking
  fused_attn_kernel + torch_compile + torch_compile_options, with
  commented optional flags for the other allowlist keys.
- docs/optimizations.qmd: hardware note split into baseline /
  cold-start / tuning bullets (replaces an inherited single
  long sentence).

Note: built on top of the qwen-fused-kernels branch (PR #27) so the
example YAML can reference fused_attn_kernel. The schema field +
wiring itself is independent of PR #27 and would also work on plain
origin/main.
@coderabbitai

coderabbitai Bot commented May 30, 2026

Copy link
Copy Markdown

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 2c97bf2b-d8cc-420d-8b0e-a9503d79ec3e

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This PR introduces a configurable PyTorch/Inductor compile options feature that allows users to tune compiler behavior through torch_compile_options configuration. The implementation includes schema validation, allowlist enforcement, runtime application, comprehensive tests, and documentation.

Changes

Torch Compile Options

Layer / File(s) Summary
Inductor compile options allowlist
src/axolotl/utils/schemas/enums.py
Defines INDUCTOR_COMPILE_OPTIONS_ALLOWLIST as a frozen set of permitted Inductor option identifiers to restrict which flags can be configured.
Configuration schema and validation
src/axolotl/utils/schemas/config.py
Adds torch_compile_options optional field to AxolotlInputConfig accepting allowlisted flag dictionaries; validates keys against allowlist and enforces that torch_compile must be enabled when options are provided.
Apply torch compile options to inductor config
src/axolotl/core/builders/base.py
Implements _apply_torch_compile_options static method to set allowlisted attributes on torch._inductor.config globally before compilation; updates _configure_torch_compile to invoke it when options are configured.
Builder unit tests for option application
tests/core/test_builders.py
Tests allowlist enforcement, dotted-key mutation, empty-dict no-ops, and integration with _configure_torch_compile using inductor_config_snapshot fixture to restore inductor state between tests.
Configuration validation tests
tests/patched/test_validation.py
Validates schema behavior: torch_compile_options defaults to None, accepts allowlisted dictionaries, rejects disallowed keys, and requires torch_compile to be enabled.
Shared test infrastructure
tests/conftest.py, tests/test_attn_implementation.py
Adds capture_axolotl_warnings context manager fixture in conftest.py for logging-aware test support; refactors test_attn_implementation.py to import and use the shared fixture.
Documentation and example configuration
docs/optimizations.qmd, examples/qwen3/8b-lora-fused-attn-compile.yaml
Updates optimization docs with steady-state and compile cold-start performance guidance; adds example YAML demonstrating Qwen3 8B fine-tuning with LoRA, fused attention, and compile options.

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 5.88% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically describes the main change: exposing torch_compile_options as a new feature for Inductor flag tuning, which aligns directly with the PR's core objective of adding a torch_compile_options config field.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch torch-compile-options-clean

Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

github-actions Bot commented May 30, 2026

Copy link
Copy Markdown

📖 Documentation Preview:

Deployed on Netlify from commit fc18d11

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

🧹 Nitpick comments (2)
tests/core/test_builders.py (1)

635-642: ⚡ Quick win

Test may not validate dotted-key attributes correctly.

hasattr(obj, "a.b") checks for a literal attribute named "a.b", not nested obj.a.b. For dotted keys in INDUCTOR_COMPILE_OPTIONS_ALLOWLIST (e.g., "triton.cudagraphs"), this test may pass even if the nested structure doesn't exist, because it's only checking for a literal "triton.cudagraphs" attribute.

To properly validate nested attributes, you'd need path-walking:

def has_nested_attr(obj, dotted_key):
    try:
        parts = dotted_key.split('.')
        for part in parts:
            obj = getattr(obj, part)
        return True
    except AttributeError:
        return False

However, if the verification script I suggested above confirms that torch._inductor.config handles dotted keys specially, then this test is fine as-is.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/core/test_builders.py` around lines 635 - 642, The test
test_all_allowlisted_flags_are_attributes_on_inductor_config incorrectly uses
hasattr(_inductor_cfg, key) which treats dotted keys like "triton.cudagraphs" as
a single literal attribute; replace this check with a nested-path walk (e.g.,
implement and call a helper has_nested_attr(obj, dotted_key) that splits key on
'.' and iteratively getattr each part, returning False on AttributeError) and
use it to validate each entry in INDUCTOR_COMPILE_OPTIONS_ALLOWLIST against
_inductor_cfg.
tests/test_attn_implementation.py (1)

17-17: 💤 Low value

Optional: prefer a shared helper module over importing from conftest.py.

Importing utilities directly from conftest.py is discouraged by pytest, since conftest files are special-cased during collection and can be imported under different module paths depending on rootdir, risking duplicate instances. Consider moving capture_axolotl_warnings into a plain module (e.g., tests/utils.py) and importing from there. Works fine as-is, so this is non-blocking.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/test_attn_implementation.py` at line 17, Move the helper function
capture_axolotl_warnings out of conftest.py into a regular test utility module
(e.g., tests/utils.py) and update imports in tests/test_attn_implementation.py
to import capture_axolotl_warnings from that module; locate the reference to
capture_axolotl_warnings in tests/test_attn_implementation.py and replace the
import from tests.conftest with an import from the new utils module, ensuring
any other tests that use capture_axolotl_warnings are updated similarly.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@examples/qwen3/8b-lora-fused-attn-compile.yaml`:
- Around line 28-37: Add a semantic validation that rejects or warns when
sample_packing is enabled together with triton.cudagraphs in
torch_compile_options: inside AxolotlInputConfig implement a model validator (or
add a check in its post-init/validation hook) that inspects self.sample_packing
and self.torch_compile_options and raises a ValueError or issues a runtime
warning if sample_packing is true and torch_compile_options contains
triton.cudagraphs: true; alternatively add the same check at runtime in
_apply_torch_compile_options before calling setattr on inductor flags to prevent
silently applying the incompatible triton.cudagraphs setting.

In `@src/axolotl/utils/schemas/enums.py`:
- Around line 159-173: The INDUCTOR_COMPILE_OPTIONS_ALLOWLIST contains keys that
may not exist on torch._inductor.config for PyTorch 2.9; add a unit test (e.g.,
test_inductor_allowlist_keys_exist) that imports
INDUCTOR_COMPILE_OPTIONS_ALLOWLIST and asserts every key is an attribute on
torch._inductor.config (using hasattr) so the CI will catch invalid names
(including "triton.cudagraphs"); then update INDUCTOR_COMPILE_OPTIONS_ALLOWLIST
to only include keys that pass that test for the pinned PyTorch version (remove
or rename "coordinate_descent_check_all_directions", "decompose_mem_bound_mm",
"assume_aligned_inputs", "comprehensive_padding" if they are not present) and
document the expected PyTorch compatibility in a comment above
INDUCTOR_COMPILE_OPTIONS_ALLOWLIST.

---

Nitpick comments:
In `@tests/core/test_builders.py`:
- Around line 635-642: The test
test_all_allowlisted_flags_are_attributes_on_inductor_config incorrectly uses
hasattr(_inductor_cfg, key) which treats dotted keys like "triton.cudagraphs" as
a single literal attribute; replace this check with a nested-path walk (e.g.,
implement and call a helper has_nested_attr(obj, dotted_key) that splits key on
'.' and iteratively getattr each part, returning False on AttributeError) and
use it to validate each entry in INDUCTOR_COMPILE_OPTIONS_ALLOWLIST against
_inductor_cfg.

In `@tests/test_attn_implementation.py`:
- Line 17: Move the helper function capture_axolotl_warnings out of conftest.py
into a regular test utility module (e.g., tests/utils.py) and update imports in
tests/test_attn_implementation.py to import capture_axolotl_warnings from that
module; locate the reference to capture_axolotl_warnings in
tests/test_attn_implementation.py and replace the import from tests.conftest
with an import from the new utils module, ensuring any other tests that use
capture_axolotl_warnings are updated similarly.
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: c2f47365-91ef-47d5-a9d8-b103d755eca3

📥 Commits

Reviewing files that changed from the base of the PR and between bf19bff and 95b7ce9.

📒 Files selected for processing (9)
  • docs/optimizations.qmd
  • examples/qwen3/8b-lora-fused-attn-compile.yaml
  • src/axolotl/core/builders/base.py
  • src/axolotl/utils/schemas/config.py
  • src/axolotl/utils/schemas/enums.py
  • tests/conftest.py
  • tests/core/test_builders.py
  • tests/patched/test_validation.py
  • tests/test_attn_implementation.py

Comment thread examples/qwen3/8b-lora-fused-attn-compile.yaml Outdated
Comment thread src/axolotl/utils/schemas/enums.py Outdated
@thad0ctor thad0ctor force-pushed the torch-compile-options-clean branch from 9b82d76 to c264720 Compare May 30, 2026 22:21
- config.py: add check_cudagraphs_wo_static_shapes model_validator warning
  on triton.cudagraphs + sample_packing (cudagraphs need static shapes).
  +2 tests.
- enums.py: one-line note that allowlist keys must resolve on
  torch._inductor.config (all 10 verified on torch 2.11); existing
  test_all_allowlisted_flags_are_attributes_on_inductor_config sentinel
  guards renames.
- docs/optimizations.qmd + example YAML: correct the tuning guidance to
  the cross-arch findings — torch_compile win scales with arch (~-14 to
  -16.5% sm_86, ~-19% sm_120); coordinate_descent_tuning is within noise
  (not +0.5%); max_autotune_gemm is the one measured gain, on sm_120 +
  >=32 GB only (regresses on sm_86). List the full allowlist.
- one-line docstrings/comments.
@thad0ctor thad0ctor force-pushed the torch-compile-options-clean branch from c264720 to 4cd9309 Compare May 30, 2026 22:27
thad0ctor and others added 8 commits June 4, 2026 12:21
…ns-clean

# Conflicts:
#	tests/core/test_builders.py
…ge-E MoEs (axolotl-ai-cloud#3712)

* perf(scattermoe-lora): grouped-Gram dA/dB + sync-free dX_lora for large-E MoEs

The dA/dB split kernel recomputes XA=X@A^T (resp. YB=dY@B) in registers once per
output dim-block to avoid materializing it; and the non-fused dX_lora path is a
Python per-expert loop with an expert_offsets[e].item() device sync per expert.
Both scale badly with the expert count -- and modern MoEs run E>=128.

The intermediates are rank-sized ([M*k, R], R~16), so materializing them once is
near-free (the split kernel saves only ~2-8 MB across these shapes -- smaller than
the dA/dB gradient it produces). Backport the multi-adapter wins to the
single-adapter path:

- grouped_gram.py: _grouped_gram_kernel + grouped_lora_weight_grads -- dA/dB as
  recompute-free grouped Gram products over precomputed XA/YB (bit-identical to
  group_bwd_lora, err 0.0).
- ScatterMoELoRA.backward non-fused dA/dB: compute YB, XA once, then grouped-Gram;
  YB is reused by the non-fused dX_lora.
- _compute_lora_input_grad: two grouped scatter2scatter GEMMs (sync-free, single
  launch each) when routing ids are given, reusing YB; the per-expert loop is kept
  as a fallback but with one host sync for the whole offset array, not O(E).

Before/after, full ScatterMoELoRA fwd+bwd (E>=128, M=4096):
  fused-dX path (production, #dA/dB only): 1.08-1.15x
  non-fused path (#dA/dB + #dX_lora):      up to 2.2x (Qwen3-MoE, DeepSeek)
at +5..30 MB peak (the XA/YB buffers). The dA/dB kernel alone is 2-17x faster.

MXFP4 unaffected: dX takes the is_mx branch (scatter2scatter_lora_dX_mx) and never
the rewritten non-fused path; dA/dB never touch the (frozen) base, so LoRA grads
are bit-identical for an MX base vs its bf16 dequantization at a workload above the
fuse-gather threshold (new regression test). #1 (sonicmoe materialize) is already
on main via baddbmm.

* chore: lint

* chore: lint
…MXFP4, sonicmoe fallback (axolotl-ai-cloud#3714)

* perf(scattermoe-ep): skip DeepEP -1 sentinels in the local kernel

Under DeepEP the local scattermoe kernel received the full N*K dispatched
routing with remote slots mapped to expert 0 / weight 0, so the grouped GEMM
+ per-row LoRA processed every sentinel row (compute-and-mask). Worse, routing
all sentinels to expert 0 piles ~half the rows into one bucket -> pathological
load imbalance.

scattermoe_experts_forward_ep drops the -1 sentinel rows before the GEMMs:
runs both projections fully grouped over only the valid routed rows (the
compacted routing breaks scatter2scatter's L == X.rows*k fan-out contract) and
does the weighted token-combine via index_add_. Output is identical to the
masked path since sentinel slots carry weight 0.

The deep_ep dispatch now passes the raw -1-tagged routing through; each local
kernel handles sentinels its own way (eager/scattermoe skip, grouped_mm masks).

scattermoe LoRA fwd+bwd, RTX PRO 6000 (Blackwell), skip vs mask:
  Qwen3-30B   ep2 2.9x/0.67x mem  ep4 6.0x/0.37x  ep8 10.3x/0.25x
  Qwen3-235B  ep2 2.5x/0.65x      ep4 4.9x/0.37x  ep8  8.4x/0.24x
  DeepSeek    ep2 2.1x/0.60x      ep4 3.9x/0.34x  ep8  6.8x/0.23x

Validated bit-equivalent to the masked path (fp32 ~1e-7, bf16 ~6-9e-3) on
output, dX and LoRA dA/dB for base + LoRA at ep2/ep4.

* feat(scattermoe): gpt_oss layout + sm_120 sonicmoe fallback

The sonic-moe CUTLASS kernel can't compile on consumer Blackwell (sm_120): its
bundled quack GemmSm120 predates the concat_layout arg the dispatcher passes.
Make the vendored scattermoe Triton path cover the gap.

1. gpt_oss layout in scattermoe_experts_forward: dispatch on the layout flags and
   handle the gpt_oss-style experts (is_transposed, not is_concatenated, has_bias)
   in a dedicated path -- weights are already [E, in, out] (no transpose), gate/up
   are interleaved ([..., ::2] / [..., 1::2]), per-expert bias is folded into the
   grouped GEMM, and the activation is the clamped sigmoid-GLU. LoRA fuses exactly
   as in the standard path (same in/out dims, same scatter2scatter_lora). Threads
   expert_biases through _parallel_linear_maybe_lora.

2. sm_120 fallback: when the sonic-moe kernel can't run (Blackwell) and the experts
   use a standard layout, sonicmoe_experts_forward_with_lora transparently routes to
   the scattermoe path, which runs there.

Validated vs an eager reference (fp32 with TF32 off, ~1e-6): output, dX, and every
gradient (gpt_oss base: d_gate_up/d_down + both biases; LoRA: dA/dB for gate_up and
down) for base + LoRA, fp32 + bf16.

gpt_oss MoE+LoRA fwd+bwd, vendored Triton vs eager, RTX PRO 6000 (Blackwell):
  gpt-oss-20b  N=2048  9.5x / 0.10x mem
  gpt-oss-20b  N=8192  3.6x / 0.32x mem
  gpt-oss-120b N=4096 48.8x / 0.05x mem

MXFP4 gpt_oss weights dequantize on the fly via parallel_linear_lora (bf16 compute),
as with other MXFP4 models on this path; fused MXFP4 for gpt_oss is not yet wired.

* perf(scattermoe): use fused MXFP4 kernel in experts forward, no dequant

scattermoe_experts_forward dequantized base MXFP4 weights to bf16 every step (the
old `_get_base_param(...).transpose(2,1)` actually crashed on an MXTensor, so MXFP4
only ever ran via the dequant-based HFScatterMoEGatedMLP path). The fused MX kernel
(scatter2scatter_lora_mx, dequant inside the K-loop) and all the plumbing
(selective_mx_weights_fwd / get_active_experts / remap_expert_indices /
selective_lora_weights) existed but were only exercised in tests.

Wire them in: when the base param is MXFP4 and LoRA is active on both projections,
keep the weights packed (4-bit) and route through the fused MX kernel. Gated by
is_mxfp4_param, which is False for bf16/fp16 params and when torchao is absent, so
non-MXFP4 models fall through to the unchanged bf16 path. The MX kernel is pure-Triton
software dequant (LUT), so it needs no hardware MXFP4 support. MXFP4-without-LoRA
dequantizes explicitly (the fused kernel is LoRA-only; MXTensor has no transpose).

Memory (E=32, H=2048, I=768, N=2048, rank=16, RTX PRO 6000):
  persistent expert weights:  MXFP4 80MB vs bf16 302MB (3.8x)
  fwd+bwd transient: fused-MX 489MB vs full-dequant 3020MB (6.2x)

Validated vs the dequantized-weight reference within MX rounding tol (~1e-3..1e-2)
on output, dX, and the active-expert slice of LoRA dA/dB.

* feat(ep): sonicmoe+LoRA+EP on sm_120 via scattermoe, fused MXFP4 in EP path

Factor the weight/LoRA prep (incl. the fused-MXFP4 active-expert selection) into
_prepare_weights_and_lora and use it in both scattermoe_experts_forward and the EP
sentinel-skip forward, so base MXFP4 stays packed (no dequant) under EP too.

Wire _sonicmoe_local: on a device where the sonic-moe CUTLASS kernel can't run
(sm_120) and the experts use a standard layout, route to scattermoe_experts_forward_ep
-- giving sonicmoe + LoRA + EP on Blackwell (sentinel-skip + fused MXFP4). Elsewhere
sonicmoe+EP still raises (needs the upstream EP-sentinel kernel).

Validated MXFP4+EP vs the bf16-dequant reference within MX tol on output, dX, and the
active-expert slice of LoRA dA/dB; _sonicmoe_local routing covered. bf16 EP path and
the EP plugin suite unchanged (28 passed).

* style: pre-commit lint (ruff format/check) for the EP+MXFP4 changes

ruff-format reflow + fixes: drop redundant torch importorskip (F811), rename
ambiguous I->IM (E741), lambda->def in tests (E731), zip(..., strict=True) (B905).

* feat(scattermoe): fused NVFP4 base + LoRA, no dequant

Mirror the MXFP4 fused path for NVFP4 so base NVFP4 weights stay 4-bit packed instead
of dequantizing to bf16. NVFP4 shares the FP4 E2M1 packing with MXFP4 but scales per
16-element block with an E4M3 (fp8) value times an optional per-tensor scale, so the
dequant is codebook(nibble) * e4m3_block_scale * per_tensor -- not a power of two.

- is_nvfp4_param + selective_nvfp4_weights_fwd: slice the NVFP4Tensor to active experts
  and fold the E4M3 block scale (* per-tensor) into a linear fp32 scale, returned as an
  MXWeights with block_size=16, scale_is_linear=True. The packed buffer stays 4-bit;
  only the small [a, N, K/16] scale tensor is materialized.
- Kernel: the fwd + dX MXFP4 inner loops gain a SCALE_LINEAR branch -- load the scale as
  a linear fp multiplier instead of E8M0 exp2(byte-127) -- threaded through both
  scatter2scatter_lora_mx / _dX_mx wrappers, which now read block_size + scale_is_linear
  off the container. MXFP4 (SCALE_LINEAR=False) is byte-identical to before.
- Reuses the MXWeights container, so parallel_linear_lora's isinstance(_, MXWeights)
  check routes NVFP4 through the fused kernel with no change; _prepare_weights_and_lora
  adds NVFP4 alongside MXFP4 for both the standard and EP forwards.

Validated vs the dequantized-weight reference within NV rounding tol (dX/dA/dB ~1e-3..
1e-2) through the standard and EP forwards; MXFP4 suite still green (14 passed).
…#3677)

* fix: make NF4 dequantize torch.compile-safe via local custom_op

Wrap the Unsloth-derived NF4 dequant fast path in a torch.library.custom_op
(axolotl::nf4_dequantize) with a register_fake impl. dequantize() branches on
torch.compiler.is_compiling(): eager calls the ctypes body directly (zero
op-dispatch overhead), while tracing dispatches through the opaque op so
Dynamo can compile around it without graph-breaking on ctypes.c_int(...) or
the foreign-function calls.

Previously, torch.compile on any QLoRA model crashed with ctypes.ArgumentError
the first time a Linear4bit forward fell into the fast path.

Also:
- Drop dead legacy-list quant_state branch (not produced by bnb 0.40+).
- Drop the unused `out=` parameter (no production callers; grep-confirmed).
- Drop the HAS_CUDA_STREAM version gate (axolotl pins bnb >> 0.43.3).
- Rewrite tests/e2e/kernels/test_quantize.py to use real bnb.functional.quantize_4bit
  fixtures instead of synthetic ones whose blocksize=32 was silently invalid
  (the old ctypes path skipped validation). Adds compile regression tests for
  both nested (double-quant) and non-nested paths.

Validated:
- Eager bench vs main at 1024², 4096², 4096x11008 within ±2% noise.
- Unit tests pass on torch 2.11.0 + bnb 0.49.1.
- End-to-end: Qwen3.5-0.8B QLoRA train (lora_mlp/qkv/o_kernel=true, exercises
  dequantize in autograd Function fwd+bwd) + merge-lora pipeline succeed.

* review: add docstrings to internal helpers + tighten transposed test

Addresses CodeRabbit nits on PR #25:
- Docstrings on _nf4_dequantize_op and its register_fake impl to lift
  docstring coverage above the 80% threshold.
- test_dequantize_transposed now asserts value equality vs
  bnb.functional.dequantize_4bit(...).t() in addition to shape, so a
  layout/content regression would actually fail the test.

Skipped CodeRabbit's fullgraph=True nit: tracing under fullgraph=True
currently fails on both main and branch (same _SimpleCData.__new__ trace
error before reaching the is_compiling() branch). Investigating is out
of scope for this PR; default-mode torch.compile is what the fix targets
and is verified working.

Skipped CodeRabbit's CUDA_STREAM caching concern: the cache pattern is
byte-identical to main and removing it costs +5-21% at the kernel level.
Stream-management refactor is a separate concern.

* docs: trim docstrings/comments to one-line WHY only

- Module + dequantize() docstrings condensed to single line.
- CUDA_STREAM comment dropped benchmark numbers and version reference.
- Removed task-history language ("no longer crashes", "vs the prior
  implementation", "Regression:") from docstrings; comments now describe
  permanent invariants, not the PR diff.

* style: ruff format dequantize() call arg lists per pre-commit

* review: key CUDA_STREAM cache by device

Per CodeRabbit: the single-global stream cache could in principle return
the wrong device's stream for a multi-device-single-process caller. In
practice axolotl always runs one process per CUDA device, so this is
defensive rather than fixing a known reproducer.

Using `setdefault` keeps the cache race benign under threading (worst
case: redundant current_stream() call) and drops the `global` keyword
since the dict is mutated in place, never rebound.

* fix broken MX tests from transformers 5.8.1 upgrade

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants