feat(compile): expose torch_compile_options for Inductor flag tuning#34
feat(compile): expose torch_compile_options for Inductor flag tuning#34thad0ctor wants to merge 10 commits into
Conversation
Adds a `torch_compile_options: dict[str, bool | int | float | str] | None`
schema field that lets users opt into a small allowlist of
`torch._inductor.config` knobs that HF Trainer otherwise does not expose.
Current allowlist (INDUCTOR_COMPILE_OPTIONS_ALLOWLIST in
`src/axolotl/utils/schemas/enums.py`):
- coordinate_descent_tuning [benched on 8B; +0.5%]
- coordinate_descent_check_all_directions
- shape_padding [torch 2.11 default]
- epilogue_fusion [torch 2.11 default]
- max_autotune_gemm
- fx_graph_cache [torch 2.11 default]
- assume_aligned_inputs
- comprehensive_padding [torch 2.11 default]
- decompose_mem_bound_mm [experimental]
- triton.cudagraphs [breaks with dynamic shapes]
HF Trainer's TorchDynamoPlugin (transformers 5.8.1) only forwards
torch_compile / torch_compile_backend / torch_compile_mode; it does not
forward kwargs / options to torch._inductor.config. Users wanting to flip
e.g. coordinate_descent_tuning previously had to monkey-patch
torch._inductor.config themselves before invoking axolotl. This change
gives them a config-driven path. ConfigModule supports dotted setattr
natively, so triton.cudagraphs and other nested keys work without any
path-walking in the runtime apply.
Validation is enforced at schema time so failures surface during
`axolotl preprocess`, not after model load:
- field_validator rejects keys outside the allowlist with a clear
"Allowed: ..." error.
- model_validator rejects torch_compile_options set without
torch_compile enabled (`auto` allowed through, since the
auto-resolver may flip it true downstream).
The runtime apply path (`_apply_torch_compile_options` in
core/builders/base.py) trusts the validated schema and just iterates the
dict setattr-ing on torch._inductor.config.
Bench (Qwen3-8B + LoRA r=16 [q,k,v,o] + sdpa + bf16 + grad_ckpt +
seq=2048 + torch 2.11.0+cu130 + sm_86 / RTX 3090 Ti, 5 conditions x
2 seeds x 2000 steps, fresh process per replicate):
| cond | mean ms/step | delta vs fused-only baseline |
|---|---|---|
| eager (no fused, no compile) | 2520.6 | +2.8 % |
| fused_attn_kernel only | 2452.6 | 0.0 % |
| fused + torch_compile (default) | 2055.1 | -16.2 % |
| + coordinate_descent_tuning | 2045.0 | -16.6 % |
| + shape_padding + epilogue_fusion | 2044.0 | -16.7 % |
- coordinate_descent_tuning adds ~0.5 % steady-state on top of plain
compile on this rig (consistent across both seeds and across
tail-100 / tail-500 / tail-1000 / tail-2000 windows); cold-start
adds 10-30 s per unique kernel signature on the first compile.
- shape_padding and epilogue_fusion are already True by default in torch
2.11; explicit setting is a no-op vs cond d on this version. The
example YAML therefore only sets coordinate_descent_tuning to avoid
misleading readers into thinking the other two contributed to the delta.
- The other 6 allowlist keys are exposed for user experimentation; a
follow-up bench (conds f-l) is in progress on the same rig.
Files touched:
- src/axolotl/utils/schemas/enums.py: new
INDUCTOR_COMPILE_OPTIONS_ALLOWLIST frozenset (10 keys).
- src/axolotl/utils/schemas/config.py: new schema field +
validate_torch_compile_options field_validator +
check_torch_compile_options_requires_compile model_validator.
- src/axolotl/core/builders/base.py: _apply_torch_compile_options
staticmethod called from _configure_torch_compile when
torch_compile_options is set.
- tests/conftest.py: lifted capture_axolotl_warnings helper out of
tests/test_attn_implementation.py (previously duplicated).
- tests/test_attn_implementation.py: imports capture_axolotl_warnings
from tests.conftest instead of defining its own copy.
- tests/core/test_builders.py: TestApplyTorchCompileOptions covers the
runtime path (allowlisted apply, dotted-key apply, multi-flag apply,
empty-dict no-op, mock-based wiring assertions, and a regression
sentinel that fails if torch renames any allowlisted flag).
- tests/patched/test_validation.py: TestTorchCompileValidation extended
with rejects_disallowed_key, requires_torch_compile_enabled,
rejects_when_torch_compile_false, and with_auto_compile_passes.
- examples/qwen3/8b-lora-fused-attn-compile.yaml: example stacking
fused_attn_kernel + torch_compile + torch_compile_options, with
commented optional flags for the other allowlist keys.
- docs/optimizations.qmd: hardware note split into baseline /
cold-start / tuning bullets (replaces an inherited single
long sentence).
Note: built on top of the qwen-fused-kernels branch (PR #27) so the
example YAML can reference fused_attn_kernel. The schema field +
wiring itself is independent of PR #27 and would also work on plain
origin/main.
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the ⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Run ID: You can disable this status message by setting the Use the checkbox below for a quick retry:
📝 WalkthroughWalkthroughThis PR introduces a configurable PyTorch/Inductor compile options feature that allows users to tune compiler behavior through ChangesTorch Compile Options
🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
|
📖 Documentation Preview: Deployed on Netlify from commit fc18d11 |
There was a problem hiding this comment.
Actionable comments posted: 2
🧹 Nitpick comments (2)
tests/core/test_builders.py (1)
635-642: ⚡ Quick winTest may not validate dotted-key attributes correctly.
hasattr(obj, "a.b")checks for a literal attribute named"a.b", not nestedobj.a.b. For dotted keys inINDUCTOR_COMPILE_OPTIONS_ALLOWLIST(e.g.,"triton.cudagraphs"), this test may pass even if the nested structure doesn't exist, because it's only checking for a literal"triton.cudagraphs"attribute.To properly validate nested attributes, you'd need path-walking:
def has_nested_attr(obj, dotted_key): try: parts = dotted_key.split('.') for part in parts: obj = getattr(obj, part) return True except AttributeError: return FalseHowever, if the verification script I suggested above confirms that
torch._inductor.confighandles dotted keys specially, then this test is fine as-is.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/core/test_builders.py` around lines 635 - 642, The test test_all_allowlisted_flags_are_attributes_on_inductor_config incorrectly uses hasattr(_inductor_cfg, key) which treats dotted keys like "triton.cudagraphs" as a single literal attribute; replace this check with a nested-path walk (e.g., implement and call a helper has_nested_attr(obj, dotted_key) that splits key on '.' and iteratively getattr each part, returning False on AttributeError) and use it to validate each entry in INDUCTOR_COMPILE_OPTIONS_ALLOWLIST against _inductor_cfg.tests/test_attn_implementation.py (1)
17-17: 💤 Low valueOptional: prefer a shared helper module over importing from
conftest.py.Importing utilities directly from
conftest.pyis discouraged by pytest, since conftest files are special-cased during collection and can be imported under different module paths depending on rootdir, risking duplicate instances. Consider movingcapture_axolotl_warningsinto a plain module (e.g.,tests/utils.py) and importing from there. Works fine as-is, so this is non-blocking.🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/test_attn_implementation.py` at line 17, Move the helper function capture_axolotl_warnings out of conftest.py into a regular test utility module (e.g., tests/utils.py) and update imports in tests/test_attn_implementation.py to import capture_axolotl_warnings from that module; locate the reference to capture_axolotl_warnings in tests/test_attn_implementation.py and replace the import from tests.conftest with an import from the new utils module, ensuring any other tests that use capture_axolotl_warnings are updated similarly.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@examples/qwen3/8b-lora-fused-attn-compile.yaml`:
- Around line 28-37: Add a semantic validation that rejects or warns when
sample_packing is enabled together with triton.cudagraphs in
torch_compile_options: inside AxolotlInputConfig implement a model validator (or
add a check in its post-init/validation hook) that inspects self.sample_packing
and self.torch_compile_options and raises a ValueError or issues a runtime
warning if sample_packing is true and torch_compile_options contains
triton.cudagraphs: true; alternatively add the same check at runtime in
_apply_torch_compile_options before calling setattr on inductor flags to prevent
silently applying the incompatible triton.cudagraphs setting.
In `@src/axolotl/utils/schemas/enums.py`:
- Around line 159-173: The INDUCTOR_COMPILE_OPTIONS_ALLOWLIST contains keys that
may not exist on torch._inductor.config for PyTorch 2.9; add a unit test (e.g.,
test_inductor_allowlist_keys_exist) that imports
INDUCTOR_COMPILE_OPTIONS_ALLOWLIST and asserts every key is an attribute on
torch._inductor.config (using hasattr) so the CI will catch invalid names
(including "triton.cudagraphs"); then update INDUCTOR_COMPILE_OPTIONS_ALLOWLIST
to only include keys that pass that test for the pinned PyTorch version (remove
or rename "coordinate_descent_check_all_directions", "decompose_mem_bound_mm",
"assume_aligned_inputs", "comprehensive_padding" if they are not present) and
document the expected PyTorch compatibility in a comment above
INDUCTOR_COMPILE_OPTIONS_ALLOWLIST.
---
Nitpick comments:
In `@tests/core/test_builders.py`:
- Around line 635-642: The test
test_all_allowlisted_flags_are_attributes_on_inductor_config incorrectly uses
hasattr(_inductor_cfg, key) which treats dotted keys like "triton.cudagraphs" as
a single literal attribute; replace this check with a nested-path walk (e.g.,
implement and call a helper has_nested_attr(obj, dotted_key) that splits key on
'.' and iteratively getattr each part, returning False on AttributeError) and
use it to validate each entry in INDUCTOR_COMPILE_OPTIONS_ALLOWLIST against
_inductor_cfg.
In `@tests/test_attn_implementation.py`:
- Line 17: Move the helper function capture_axolotl_warnings out of conftest.py
into a regular test utility module (e.g., tests/utils.py) and update imports in
tests/test_attn_implementation.py to import capture_axolotl_warnings from that
module; locate the reference to capture_axolotl_warnings in
tests/test_attn_implementation.py and replace the import from tests.conftest
with an import from the new utils module, ensuring any other tests that use
capture_axolotl_warnings are updated similarly.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c2f47365-91ef-47d5-a9d8-b103d755eca3
📒 Files selected for processing (9)
docs/optimizations.qmdexamples/qwen3/8b-lora-fused-attn-compile.yamlsrc/axolotl/core/builders/base.pysrc/axolotl/utils/schemas/config.pysrc/axolotl/utils/schemas/enums.pytests/conftest.pytests/core/test_builders.pytests/patched/test_validation.pytests/test_attn_implementation.py
9b82d76 to
c264720
Compare
- config.py: add check_cudagraphs_wo_static_shapes model_validator warning on triton.cudagraphs + sample_packing (cudagraphs need static shapes). +2 tests. - enums.py: one-line note that allowlist keys must resolve on torch._inductor.config (all 10 verified on torch 2.11); existing test_all_allowlisted_flags_are_attributes_on_inductor_config sentinel guards renames. - docs/optimizations.qmd + example YAML: correct the tuning guidance to the cross-arch findings — torch_compile win scales with arch (~-14 to -16.5% sm_86, ~-19% sm_120); coordinate_descent_tuning is within noise (not +0.5%); max_autotune_gemm is the one measured gain, on sm_120 + >=32 GB only (regresses on sm_86). List the full allowlist. - one-line docstrings/comments.
c264720 to
4cd9309
Compare
…ns-clean # Conflicts: # tests/core/test_builders.py
…ge-E MoEs (axolotl-ai-cloud#3712) * perf(scattermoe-lora): grouped-Gram dA/dB + sync-free dX_lora for large-E MoEs The dA/dB split kernel recomputes XA=X@A^T (resp. YB=dY@B) in registers once per output dim-block to avoid materializing it; and the non-fused dX_lora path is a Python per-expert loop with an expert_offsets[e].item() device sync per expert. Both scale badly with the expert count -- and modern MoEs run E>=128. The intermediates are rank-sized ([M*k, R], R~16), so materializing them once is near-free (the split kernel saves only ~2-8 MB across these shapes -- smaller than the dA/dB gradient it produces). Backport the multi-adapter wins to the single-adapter path: - grouped_gram.py: _grouped_gram_kernel + grouped_lora_weight_grads -- dA/dB as recompute-free grouped Gram products over precomputed XA/YB (bit-identical to group_bwd_lora, err 0.0). - ScatterMoELoRA.backward non-fused dA/dB: compute YB, XA once, then grouped-Gram; YB is reused by the non-fused dX_lora. - _compute_lora_input_grad: two grouped scatter2scatter GEMMs (sync-free, single launch each) when routing ids are given, reusing YB; the per-expert loop is kept as a fallback but with one host sync for the whole offset array, not O(E). Before/after, full ScatterMoELoRA fwd+bwd (E>=128, M=4096): fused-dX path (production, #dA/dB only): 1.08-1.15x non-fused path (#dA/dB + #dX_lora): up to 2.2x (Qwen3-MoE, DeepSeek) at +5..30 MB peak (the XA/YB buffers). The dA/dB kernel alone is 2-17x faster. MXFP4 unaffected: dX takes the is_mx branch (scatter2scatter_lora_dX_mx) and never the rewritten non-fused path; dA/dB never touch the (frozen) base, so LoRA grads are bit-identical for an MX base vs its bf16 dequantization at a workload above the fuse-gather threshold (new regression test). #1 (sonicmoe materialize) is already on main via baddbmm. * chore: lint * chore: lint
…MXFP4, sonicmoe fallback (axolotl-ai-cloud#3714) * perf(scattermoe-ep): skip DeepEP -1 sentinels in the local kernel Under DeepEP the local scattermoe kernel received the full N*K dispatched routing with remote slots mapped to expert 0 / weight 0, so the grouped GEMM + per-row LoRA processed every sentinel row (compute-and-mask). Worse, routing all sentinels to expert 0 piles ~half the rows into one bucket -> pathological load imbalance. scattermoe_experts_forward_ep drops the -1 sentinel rows before the GEMMs: runs both projections fully grouped over only the valid routed rows (the compacted routing breaks scatter2scatter's L == X.rows*k fan-out contract) and does the weighted token-combine via index_add_. Output is identical to the masked path since sentinel slots carry weight 0. The deep_ep dispatch now passes the raw -1-tagged routing through; each local kernel handles sentinels its own way (eager/scattermoe skip, grouped_mm masks). scattermoe LoRA fwd+bwd, RTX PRO 6000 (Blackwell), skip vs mask: Qwen3-30B ep2 2.9x/0.67x mem ep4 6.0x/0.37x ep8 10.3x/0.25x Qwen3-235B ep2 2.5x/0.65x ep4 4.9x/0.37x ep8 8.4x/0.24x DeepSeek ep2 2.1x/0.60x ep4 3.9x/0.34x ep8 6.8x/0.23x Validated bit-equivalent to the masked path (fp32 ~1e-7, bf16 ~6-9e-3) on output, dX and LoRA dA/dB for base + LoRA at ep2/ep4. * feat(scattermoe): gpt_oss layout + sm_120 sonicmoe fallback The sonic-moe CUTLASS kernel can't compile on consumer Blackwell (sm_120): its bundled quack GemmSm120 predates the concat_layout arg the dispatcher passes. Make the vendored scattermoe Triton path cover the gap. 1. gpt_oss layout in scattermoe_experts_forward: dispatch on the layout flags and handle the gpt_oss-style experts (is_transposed, not is_concatenated, has_bias) in a dedicated path -- weights are already [E, in, out] (no transpose), gate/up are interleaved ([..., ::2] / [..., 1::2]), per-expert bias is folded into the grouped GEMM, and the activation is the clamped sigmoid-GLU. LoRA fuses exactly as in the standard path (same in/out dims, same scatter2scatter_lora). Threads expert_biases through _parallel_linear_maybe_lora. 2. sm_120 fallback: when the sonic-moe kernel can't run (Blackwell) and the experts use a standard layout, sonicmoe_experts_forward_with_lora transparently routes to the scattermoe path, which runs there. Validated vs an eager reference (fp32 with TF32 off, ~1e-6): output, dX, and every gradient (gpt_oss base: d_gate_up/d_down + both biases; LoRA: dA/dB for gate_up and down) for base + LoRA, fp32 + bf16. gpt_oss MoE+LoRA fwd+bwd, vendored Triton vs eager, RTX PRO 6000 (Blackwell): gpt-oss-20b N=2048 9.5x / 0.10x mem gpt-oss-20b N=8192 3.6x / 0.32x mem gpt-oss-120b N=4096 48.8x / 0.05x mem MXFP4 gpt_oss weights dequantize on the fly via parallel_linear_lora (bf16 compute), as with other MXFP4 models on this path; fused MXFP4 for gpt_oss is not yet wired. * perf(scattermoe): use fused MXFP4 kernel in experts forward, no dequant scattermoe_experts_forward dequantized base MXFP4 weights to bf16 every step (the old `_get_base_param(...).transpose(2,1)` actually crashed on an MXTensor, so MXFP4 only ever ran via the dequant-based HFScatterMoEGatedMLP path). The fused MX kernel (scatter2scatter_lora_mx, dequant inside the K-loop) and all the plumbing (selective_mx_weights_fwd / get_active_experts / remap_expert_indices / selective_lora_weights) existed but were only exercised in tests. Wire them in: when the base param is MXFP4 and LoRA is active on both projections, keep the weights packed (4-bit) and route through the fused MX kernel. Gated by is_mxfp4_param, which is False for bf16/fp16 params and when torchao is absent, so non-MXFP4 models fall through to the unchanged bf16 path. The MX kernel is pure-Triton software dequant (LUT), so it needs no hardware MXFP4 support. MXFP4-without-LoRA dequantizes explicitly (the fused kernel is LoRA-only; MXTensor has no transpose). Memory (E=32, H=2048, I=768, N=2048, rank=16, RTX PRO 6000): persistent expert weights: MXFP4 80MB vs bf16 302MB (3.8x) fwd+bwd transient: fused-MX 489MB vs full-dequant 3020MB (6.2x) Validated vs the dequantized-weight reference within MX rounding tol (~1e-3..1e-2) on output, dX, and the active-expert slice of LoRA dA/dB. * feat(ep): sonicmoe+LoRA+EP on sm_120 via scattermoe, fused MXFP4 in EP path Factor the weight/LoRA prep (incl. the fused-MXFP4 active-expert selection) into _prepare_weights_and_lora and use it in both scattermoe_experts_forward and the EP sentinel-skip forward, so base MXFP4 stays packed (no dequant) under EP too. Wire _sonicmoe_local: on a device where the sonic-moe CUTLASS kernel can't run (sm_120) and the experts use a standard layout, route to scattermoe_experts_forward_ep -- giving sonicmoe + LoRA + EP on Blackwell (sentinel-skip + fused MXFP4). Elsewhere sonicmoe+EP still raises (needs the upstream EP-sentinel kernel). Validated MXFP4+EP vs the bf16-dequant reference within MX tol on output, dX, and the active-expert slice of LoRA dA/dB; _sonicmoe_local routing covered. bf16 EP path and the EP plugin suite unchanged (28 passed). * style: pre-commit lint (ruff format/check) for the EP+MXFP4 changes ruff-format reflow + fixes: drop redundant torch importorskip (F811), rename ambiguous I->IM (E741), lambda->def in tests (E731), zip(..., strict=True) (B905). * feat(scattermoe): fused NVFP4 base + LoRA, no dequant Mirror the MXFP4 fused path for NVFP4 so base NVFP4 weights stay 4-bit packed instead of dequantizing to bf16. NVFP4 shares the FP4 E2M1 packing with MXFP4 but scales per 16-element block with an E4M3 (fp8) value times an optional per-tensor scale, so the dequant is codebook(nibble) * e4m3_block_scale * per_tensor -- not a power of two. - is_nvfp4_param + selective_nvfp4_weights_fwd: slice the NVFP4Tensor to active experts and fold the E4M3 block scale (* per-tensor) into a linear fp32 scale, returned as an MXWeights with block_size=16, scale_is_linear=True. The packed buffer stays 4-bit; only the small [a, N, K/16] scale tensor is materialized. - Kernel: the fwd + dX MXFP4 inner loops gain a SCALE_LINEAR branch -- load the scale as a linear fp multiplier instead of E8M0 exp2(byte-127) -- threaded through both scatter2scatter_lora_mx / _dX_mx wrappers, which now read block_size + scale_is_linear off the container. MXFP4 (SCALE_LINEAR=False) is byte-identical to before. - Reuses the MXWeights container, so parallel_linear_lora's isinstance(_, MXWeights) check routes NVFP4 through the fused kernel with no change; _prepare_weights_and_lora adds NVFP4 alongside MXFP4 for both the standard and EP forwards. Validated vs the dequantized-weight reference within NV rounding tol (dX/dA/dB ~1e-3.. 1e-2) through the standard and EP forwards; MXFP4 suite still green (14 passed).
…#3677) * fix: make NF4 dequantize torch.compile-safe via local custom_op Wrap the Unsloth-derived NF4 dequant fast path in a torch.library.custom_op (axolotl::nf4_dequantize) with a register_fake impl. dequantize() branches on torch.compiler.is_compiling(): eager calls the ctypes body directly (zero op-dispatch overhead), while tracing dispatches through the opaque op so Dynamo can compile around it without graph-breaking on ctypes.c_int(...) or the foreign-function calls. Previously, torch.compile on any QLoRA model crashed with ctypes.ArgumentError the first time a Linear4bit forward fell into the fast path. Also: - Drop dead legacy-list quant_state branch (not produced by bnb 0.40+). - Drop the unused `out=` parameter (no production callers; grep-confirmed). - Drop the HAS_CUDA_STREAM version gate (axolotl pins bnb >> 0.43.3). - Rewrite tests/e2e/kernels/test_quantize.py to use real bnb.functional.quantize_4bit fixtures instead of synthetic ones whose blocksize=32 was silently invalid (the old ctypes path skipped validation). Adds compile regression tests for both nested (double-quant) and non-nested paths. Validated: - Eager bench vs main at 1024², 4096², 4096x11008 within ±2% noise. - Unit tests pass on torch 2.11.0 + bnb 0.49.1. - End-to-end: Qwen3.5-0.8B QLoRA train (lora_mlp/qkv/o_kernel=true, exercises dequantize in autograd Function fwd+bwd) + merge-lora pipeline succeed. * review: add docstrings to internal helpers + tighten transposed test Addresses CodeRabbit nits on PR #25: - Docstrings on _nf4_dequantize_op and its register_fake impl to lift docstring coverage above the 80% threshold. - test_dequantize_transposed now asserts value equality vs bnb.functional.dequantize_4bit(...).t() in addition to shape, so a layout/content regression would actually fail the test. Skipped CodeRabbit's fullgraph=True nit: tracing under fullgraph=True currently fails on both main and branch (same _SimpleCData.__new__ trace error before reaching the is_compiling() branch). Investigating is out of scope for this PR; default-mode torch.compile is what the fix targets and is verified working. Skipped CodeRabbit's CUDA_STREAM caching concern: the cache pattern is byte-identical to main and removing it costs +5-21% at the kernel level. Stream-management refactor is a separate concern. * docs: trim docstrings/comments to one-line WHY only - Module + dequantize() docstrings condensed to single line. - CUDA_STREAM comment dropped benchmark numbers and version reference. - Removed task-history language ("no longer crashes", "vs the prior implementation", "Regression:") from docstrings; comments now describe permanent invariants, not the PR diff. * style: ruff format dequantize() call arg lists per pre-commit * review: key CUDA_STREAM cache by device Per CodeRabbit: the single-global stream cache could in principle return the wrong device's stream for a multi-device-single-process caller. In practice axolotl always runs one process per CUDA device, so this is defensive rather than fixing a known reproducer. Using `setdefault` keeps the cache race benign under threading (worst case: redundant current_stream() call) and drops the `global` keyword since the dict is mutated in place, never rebound. * fix broken MX tests from transformers 5.8.1 upgrade --------- Co-authored-by: Wing Lian <wing@axolotl.ai>
…dn't be flag as an warning (axolotl-ai-cloud#3718)
Description
Adds
torch_compile_options: dict[str, bool|int|float|str] | None— an opt-in allowlist of 10torch._inductor.configknobs that HF Trainer'sTorchDynamoPlugin(transformers 5.8.1) never forwards. DefaultNone→ behavior identical to today. The allowlist lives inenums.py(INDUCTOR_COMPILE_OPTIONS_ALLOWLIST) as the single source of truth for the validators and the runtime apply.axolotl preprocess: afield_validatorrejects non-allowlisted keys; amodel_validatorrejects options withouttorch_compile; a secondmodel_validatorwarns ontriton.cudagraphs+sample_packing(cudagraphs need static shapes, packed sequences are dynamic).TrainerBuilderBase._apply_torch_compile_options)setattrs each validated key ontorch._inductor.configbefore compile; dotted keys (triton.cudagraphs) work natively.docs/optimizations.qmdsm_86 note. Purely additive on the existingtorch_compile*fields.Motivation and Context
HF Trainer doesn't forward kwargs to
torch._inductor.config, so flags likemax_autotune_gemmpreviously required monkey-patching beforeaxolotl train. These flags have no single correct default — their payoff is architecture-specific (see below) — so they're exposed as opt-in config rather than baked in.How has this been tested?
Unit tests (torch 2.11.0+cu130 / transformers 5.8.1):
TestTorchCompileValidation, 9): allowlist round-trip, rejects disallowed key, requirestorch_compile, and the newtriton.cudagraphs+sample_packingwarning.TestApplyTorchCompileOptions, 7): flags actually fliptorch._inductor.config, all 10 allowlist keys verified present (rename sentinel, incl. dotted key), dotted/multi/empty-dict apply, wiring.Benchmarks — Qwen3-8B + LoRA r=16 + bf16 + sdpa + seq 2048 + grad_ckpt + alpaca, 2 seeds, fresh process per (cond, seed), tail-100-median ms/step. Conds: a eager · b
fused_attn_kernel(baseline) · c b +torch_compile· then single inductor flags on c.torch_compile(c vs b) is a large win that scales with arch, and cuts peak memory 19,992 → 19,614 MB on every card:Of the 10 flags, only
max_autotune_gemmhas a real, arch-dependent effect (Δ vs plain compile); the rest are washes or torch-2.11-default no-ops on both arches:max_autotune_gemmcoordinate_descent_tuningtriton.cudagraphsmax_autotune_gemmis best run alone — a 6-combof+Xprobe on sm_120 found nothing beats it (coordinate_descent_tuningandtriton.cudagraphseach erode it). It hits the same 101 KB shared-mem limit on both arches (consumer Blackwell ≈ Ampere here), but the autotuner nets better tiles on sm_120, and its +2.2 GB peak needs ≥32 GB. Recommend on sm_120 + ≥32 GB; leave off on sm_86.Full sm_120 single-flag sweep (RTX 5090), showing what does and doesn't help:
torch_compilecoordinate_descent_tuningshape_padding+epilogue_fusionmax_autotune_gemmtriton.cudagraphsAI Usage Disclaimer
Yes — Opus 4.8 used throughout
Screenshots (if appropriate)
Types of changes
torch_compile_optionsconfig field; opt-in Inductor flag plumbingSocial Handles (Optional)