Skip to content

Record: Varlen attention + fused MLP + doc-independent TTT (1.07336)#1530

Open
samacqua wants to merge 4 commits intoopenai:mainfrom
samacqua:varlen-fused-ttt-v2
Open

Record: Varlen attention + fused MLP + doc-independent TTT (1.07336)#1530
samacqua wants to merge 4 commits intoopenai:mainfrom
samacqua:varlen-fused-ttt-v2

Conversation

@samacqua
Copy link
Copy Markdown
Contributor

@samacqua samacqua commented Apr 11, 2026

Record: Varlen attention + fused MLP + TTT

val_loss: 2.77261 | val_bpb: 1.07336 | ~15.99 MB | 8×H100 SXM, 587s train + ~340s TTT eval

Seed BPB Loss
0 1.07258208 2.77059090
1 1.07324696 2.77230836
2 1.07426259 2.77493185
Mean 1.07336388 2.77261037
Std 0.00084633 0.00218618

Best PR bpb (PR #1529): bpb=1.0753 (delta=0.0019), loss=2.7776 (delta=0.0050)

Merged record bpb (PR #1493): bpb=1.0810 (delta=0.0076), loss=2.7923 (delta=0.0197)

Increased training speed ~5% via variable length attention, a fused MLP triton kernel (no cutlass_evt_fusion dep), and grouping together small parameters, yielding ~.002 nats when comparing sliding window eval. Re-added document-based LoRA TTT which has no inter-sequence dependence and improves over strided evaluation by ~.008 nats.

Main changes

Applied changes from my old PR to a recent record PR: #1523. But PR #1552 beat my previous bpb before I submitted the PR, so I incorporated their (orthogonal) improvements. Most of below is copied from my previous PR #1354.

This involves 3 things:

1. Variable length attention (~2% faster training, ~0.001 nats)

Replaced dense causal attention with Flash Attention 3's flash_attn_varlen_func. During training, documents are packed into flat token buffers with cu_seqlens boundaries so attention is computed within documents only — the model never attends across unrelated documents that happen to be adjacent in a batch.

This does two things:

  • Removes the need for the model to learn to ignore pre-BOS content from unrelated documents
  • Reduces wasted FLOPs: e.g. 10 short (100-token) docs packed into a 1k-token buffer cost proportional to 100 * 100**2 = 1M attention FLOPs vs 10 * 1000**2 = 10M with dense attention.

2. Fused MLP + grouped small params (~3% faster training, ~0.001 nats)

A custom Triton kernel (linear_leaky_relu_square_kernel) fuses the up-projection, LeakyReLU(0.5)² activation, and squaring into a single kernel. Based on similar kernels from modded-nanogpt. I also group the many tiny replicated scalar/control gradients into a single all-reduce to avoid a pile of tiny collectives.

3. Doc-based test-time training (TTT) (~0.008 nats over sliding window)

Blog explaining LoRA-based TTT from past record

Although it is technically legal in this competition to train on tokens from previous documents in the dataset, I am spiritually opposed to this. Under the current formulation, if the eval set was bigger, the expectation of the loss would be lower which seems broken. So in this implementation, there is score-first TTT applied to each sequence in the validation set independently (and efficiently using batched LoRAs), which is strictly harder.

Re-adds LoRA-based TTT, based on my old implementation, but > 2x faster which allows for using smaller chunk sizes which leads to better performance. This is an instance of "Case 3" according to this classification.

It's interesting to note that adding test-time training improves loss more than adding ~215 steps. These 215 steps train on 786432*215=169,082,880 tokens to gain ~.002 nats. The average sequence length in the validation set is ~200 tokens which means test-time training here gains ~.003 nats / 800 tokens on average (valid bc sequences are trained independently). So, in a way, TTT is ~(.003/800) / (.002/169082880) >= 300k times more token efficient than pre-training: it helps to be in distribution :)

Other small changes

Made some changes to make replication and dev based on this PR easier:

  • Load from a checkpoint just for eval
  • Didn't submit minified code, instead wrote that utility into the script when calculating file size so that it is easier for people to build off of this
  • Store unminified code in logs

Replicating runs + dev

# setup
uv venv
source .venv/bin/activate
uv pip install -r records/track_10min_16mb/2026-04-10_VarLenAttn/requirements.txt
uv pip install --break-system-packages flash_attn_3 --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291
uv pip install torch==2.9.1+cu128 --extra-index-url https://download.pytorch.org/whl/cu128

MATCHED_FINEWEB_REPO_ID=kevclark/parameter-golf \
  python3 data/cached_challenge_fineweb.py --variant sp8192 --train-shards  128

# train + eval
SEED=0
ARTIFACT_DIR="runs/varlen${SEED}" SEED=$SEED \
    torchrun --standalone --nproc_per_node=8 \
    records/track_10min_16mb/2026-04-10_VarLenAttn/train_gpt.py

# eval saved checkpoint w/ TTT (useful for dev)
EVAL_ONLY_PATH="runs/varlen${SEED}/final_model.pt" SEED=$SEED \
    torchrun --standalone --nproc_per_node=8 \
    records/track_10min_16mb/2026-04-10_VarLenAttn/train_gpt.py

@samacqua samacqua changed the title Varlen attention + fused MLP + doc-independent TTT Varlen attention + fused MLP + doc-independent TTT Apr 11, 2026
@samacqua samacqua changed the title Varlen attention + fused MLP + doc-independent TTT Varlen attention + fused MLP + doc-independent TTT (1.07643) Apr 11, 2026
@samacqua samacqua changed the title Varlen attention + fused MLP + doc-independent TTT (1.07643) Record: Varlen attention + fused MLP + doc-independent TTT (1.07643) Apr 11, 2026
dexhunter added a commit to dexhunter/parameter-golf that referenced this pull request Apr 11, 2026
…g + Muon 0.97 — val_bpb 1.07747 (3-seed mean)

- 3-seed mean: 1.07747 BPP (std 0.00064) / 2.78321 nats
- ~15.99 MB artifact, 8×H100 SXM, 600s
- VarLen attention (within-document only), doc-independent LoRA TTT
- Parameter banking + triple depth recurrence + parallel residuals
- PyTorch MLP fallback (no Triton/CUTLASS dependency)
- Based on PR openai#1530, PR openai#1523, PR openai#1514
@dexhunter
Copy link
Copy Markdown
Contributor

I may be missing something, but I think there is one higher-scrutiny #1017 / README issue worth clarifying.

In the TTT path, the compile warmup appears to use actual validation tokens before the main eval loop, and it also does backward() / step() inside that warmup block. The main score loop itself looks score-first, so this is not a claim about the core TTT logic; the concern is specifically the pre-eval warmup.

My read of the current guidance is:

  • Track B allows score-first adaptation using previously scored eval tokens
  • but not adaptation on validation tokens before they are scored

If that reading is right, would you be willing to switch the warmup to:

  • synthetic tokens / shape-only warmup, or
  • training tokens, or
  • a no-update warmup

That would make the legality story much cleaner.

@samacqua samacqua changed the title Record: Varlen attention + fused MLP + doc-independent TTT (1.07643) Record: Varlen attention + fused MLP + doc-independent TTT (1.07336) Apr 11, 2026
@samacqua
Copy link
Copy Markdown
Contributor Author

samacqua commented Apr 11, 2026

@dexhunter it could honestly just be commented out, given that warmup + eval time is still < 600s. But it shouldn't matter -- training warmup does the same thing, parameters and optimizer states are reset. As a sanity check I re-ran TTT on seed 2 w/ warmup commented out, and the loss was within expected variance between runs (actually did slightly better): quantized_ttt_lora val_loss:2.77492177 val_bpb:1.07425869 eval_time:338465ms.

But given that making a change + re-running what take an hour of 8xh100, I will only if it is a blocker.

@MatoTeziTanka
Copy link
Copy Markdown

Community Review — VarLen attention + fused MLP + doc-independent TTT

Thanks @samacqua. Doc-independent TTT via cu_seqlens boundary isolation is a genuinely interesting approach to the causal-dependence question the SLOT cluster has been bouncing around. One import blocker, then a deeper question on the doc-independence claim.

What I found (head SHA 161d64428159c61f2d42dd6d415ec1386599ef90, records/track_10min_16mb/2026-04-10_VarLenAttn/train_gpt.py, 116,694 bytes of actual source — not shim-compressed, directly readable):

  • Imports (L1-12): from flash_attn_interface import (flash_attn_varlen_func, ...) at the top — hard import, no fallback
  • eval_val_sliding at L1948 — uses BOS_ID to find document boundaries via (chunk_cpu[:-1] == BOS_ID).nonzero(), builds cu_seqlens from those boundaries, then calls the FA3 varlen kernel with that cu_seqlens argument. This is standard document-packed varlen attention — within a batch, attention can't cross a BOS boundary.
  • _build_cu_seqlens(bos_pos, total_len, device, max_doc_len=0, bucket_size=64) at L260 — helper that builds the cu_seqlens tensor with a bucket-size 64 padding
  • DocumentPackingLoader at L283 — the training-time document-packed loader
  • Standard GPT at L730 with varlen-aware blocks

"Doc-independent TTT" — the interesting idea. My read is that if the LoRA (or whatever TTT-like adaptation you're running) respects the same cu_seqlens boundaries as attention, then when token t in document D is scored, the adaptation state derived from document D' ≠ D doesn't influence t's scoring through the attention path — because attention physically can't cross the boundary. That's a clean causal isolation argument IF the adaptation state also respects document boundaries.

The open question is whether the adaptation state itself is per-document or per-batch. I couldn't find an eval_val_sliding_ttt / eval_val_ttt / eval_val_slot function in the 116KB source via my structural grep — could you point me at where the TTT adaptation lives in this codebase? "doc-independent TTT" in the title suggests there's an eval-time adaptation somewhere, but the function I was expecting by name doesn't show up. Is the TTT-like adaptation integrated into the main eval path, or does it use a different function name?

Import blocker (smoke test). The CPU smoke on CT2038 hit:

IMPORT_FAIL error=ImportError("cannot import name 'flash_attn_varlen_func' from 'flash_attn_interface' (unknown location)")

My flash_attn stub covers flash_attn_func but not the varlen variant. This is a CPU-stub limitation, not a PR defect — FA3 is available on H100s where this is intended to run.

Questions

  1. Where is the doc-independent TTT adaptation loop? eval_val_sliding at L1948 looks like standard no-grad scoring. Is the adaptation inlined into the forward pass, or does it live in a separately-named function I missed?
  2. Per-batch vs per-doc adaptation state: if documents D_i and D_j are packed into the same batch with disjoint cu_seqlens, does the TTT state they each produce stay isolated, or does it mix before being used to score the next batch?
  3. Cross-document information flow through DocumentPackingLoader: the training-time loader packs multiple documents. Does the TTT adapt at training time (which would be legal as a training-side technique) or at eval time (which needs the doc-independence argument to close)?

Compliance summary (partial)

  • N-gram family bug: not present (no full_key / ctx_hash ^ target * primes)
  • Scored-region SLOT: not present (no slot_loss / mask = scored region)
  • Pre-Quant TTT on val_tokens: not present (no prequant_ttt_adapt_adamw)
  • Varlen attention via FA3 cu_seqlens: legitimate hardware optimization per Issue Are HW optimization solutions also welcome? #1409
  • Doc-independent TTT causal argument: pending clarification on the adaptation loop location

Verdict: LOOKS INTERESTING, NEEDS AUTHOR CLARIFICATION on the TTT adaptation path.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: HOLD pending author clarification on where and how the doc-independent TTT runs. If the adaptation respects cu_seqlens boundaries and the temporal ordering is score-before-adapt at the document level, this is a genuinely clean path out of the SLOT compliance bind, and I'd flip to MERGE.


Reviewed by @MatoTeziTankaThe Agora. CPU smoke test (CT2038 proteus-engine, 2026-04-11): IMPORT_FAIL due to flash_attn_varlen_func missing from my flash_attn stub (known stub limitation, not a PR defect). Static review of the 116,694-byte source performed against the compliance axes above — 4 of 5 clean on the standard audits, 1 needs author clarification on the TTT adaptation location. AI tooling: review drafted with Claude Code (Opus); batch-9 subagent quota exhausted mid-batch so this review was authored in the main session with reduced audit depth — the flagged questions above are places I'd want confirmation from the author rather than statements I'm making from full verification. SHA 161d64428159c61f2d42dd6d415ec1386599ef90.

@samacqua
Copy link
Copy Markdown
Contributor Author

@MatoTeziTanka look at eval_val_ttt_lora here:

  • each sequence (split by BOS) has it's own LoRA, so there is no dependence between sequences
  • each sequence is split into 32-token chunks. We iterate through the chunks in order, evaluating on chunk i, then training on chunk i before moving on to chunk i+1.

So yes, it respects the same document boundaries. It is strictly harder (and more valid imo) than TTT on the full validation sequence autoregressively.

See the "Methods" section of this blog for clarity.

dexhunter added a commit to dexhunter/parameter-golf that referenced this pull request Apr 12, 2026
… (3-seed mean)

PR openai#1530 v2 base + warmdown_frac=0.75 + TTT_CHUNK_SIZE=48 + Muon 0.97.
3-seed mean: 1.07406 (std 0.00132), 2.77441 nats.
Delta vs merged SOTA (openai#1493): -0.01491 nats (clears 0.005 bar by 3.0x).
All artifacts < 16 MB, train < 600s, eval < 225s.
@dexhunter
Copy link
Copy Markdown
Contributor

The pattern structurally matches what @valerio-oai flagged as invalid in #677 ("adapt on validation before the reported eval pass"). Even though LoRA resets per batch, the compile warmup still runs backward+step on val tokens before the eval loop. Since you confirmed the fix is within variance, it would be worth switching to random/synthetic tokens to avoid any ambiguity during review.

@msisovic
Copy link
Copy Markdown
Contributor

The pattern structurally matches what @valerio-oai flagged as invalid in #677 ("adapt on validation before the reported eval pass"). Even though LoRA resets per batch, the compile warmup still runs backward+step on val tokens before the eval loop. Since you confirmed the fix is within variance, it would be worth switching to random/synthetic tokens to avoid any ambiguity during review.

This submission actually looks good to me. They don't "adapt on validation before the reported eval pass", as the warmup/compilation throws away the updates. The final result wouldn't change at all if they replaced those validation tokens in warmup with any other tokens. The author even notes that the result is unchanged, even when they comment out the warmup.

@msisovic
Copy link
Copy Markdown
Contributor

Hey btw @samacqua the training script will crash without the pyminify CLI tool installed on the machine, so you might want to add a step to the README that it should be installed (maybe I have missed it though).

@samacqua
Copy link
Copy Markdown
Contributor Author

@msisovic fixed. Thanks!

anthony-maio added a commit to anthony-maio/parameter-golf that referenced this pull request Apr 14, 2026
…nto SP8192 stack

Adds records/track_10min_16mb/2026-04-15_SP8192_VarLen/train_gpt.py (readable 1446 lines):

- flash_attn_varlen_func with cu_seqlens document packing in CausalSelfAttention
- DocumentPackingLoader replacing per-sequence shuffling for training batches
- Triton linear+LeakyReLU(0.5)^2 fused MLP kernel with two-lane output split
- cu_seqlens threaded through Block / GPT forward; max_seqlen pinned to
  train_seq_len to avoid torch.compile recompilation on varying ints

Retains full SP8192 stack: depth recurrence (2 loops, layers 3-5), parallel
residuals from layer 7, QK-Gain 5.0, GPTQ INT6 + INT8 embed + SDClip 12.85,
score-first chunk TTT, fused-softcap-ce eval kernel, SP8192 tokenizer.

Eval paths unchanged (ShuffledSequenceLoader + flash_attn_3_func when
cu_seqlens is None). New knobs: USE_VARLEN, USE_FUSED_MLP, CU_BUCKET_SIZE,
MAX_DOC_LEN. Requires flash_attn_3 wheels (cu128_torch291) and Triton 3.2+
for TensorDescriptor API.

Compiles clean locally. Awaiting 8xH100 smoke test to validate end-to-end.
anthony-maio added a commit to anthony-maio/parameter-golf that referenced this pull request Apr 14, 2026
…is regressive on our SP8192 + depth recurrence stack

Three configs tested at seed 42 on 8xH100 SXM:
- VarLen + Fused MLP: 1.93 pre-quant val_bpb, 1440 steps, 2.3M tok/s (3.4x slower)
- Fused MLP only: 1.110 pre-quant val_bpb, 2581 steps, 3.4M tok/s (2.3x slower)
- Pure baseline reproduction: pod terminated mid-run before completion

Root cause: VarLen + depth recurrence + fullgraph torch.compile triggers cascading
shape recompilations (combinatorial explosion of loop_iter x cu_seqlens shape)
that overflow even a 64-entry compile cache. Fused MLP Triton kernel has per-call
TensorDescriptor allocation overhead that doesn't amortize for our hidden_dim=2048.

Conclusion: do not ship this port. PR openai#1572 (1.07974) remains best submission.
Move 2 (per-layer GPTQ from PR openai#1586) and Move 3 (LoRA TTT from PR openai#1530, eval-only
so no torch.compile recompile concern) are still viable next directions.
anthony-maio added a commit to anthony-maio/parameter-golf that referenced this pull request Apr 14, 2026
…192 stack

Config-level changes only, no kernel/compile changes that could interact with
our depth recurrence stack (unlike VarLen port in submission/sp8192-varlen-frontier):

- MLP_CLIP_SIGMAS 12.0 (tight, preserve MLP precision)
- ATTN_CLIP_SIGMAS 13.0 (looser, save bytes on attention weights)
- EMBED_BITS 8 -> 7 with EMBED_CLIP_SIGMAS 20.0 -> 15.0 (~530 KB artifact savings)
- MATRIX_LR 0.022 -> 0.026 (dexhunter 6-point sweep optimum)
- WARMDOWN_FRAC 0.72 -> 0.75 (longer peak LR window)

Dexhunter measured 1.07493 BPB (3-seed mean) applying these against PR openai#1530 base.
Against our 1.07974 SP8192 baseline the expected delta is in the 0.003-0.005 BPB
range; the adaptive clip is stack-independent and the embed-bits + LR tweaks are
universal. Fresh branch from upstream/main per PR hygiene (PR openai#1572 untouched).
anthony-maio added a commit to anthony-maio/parameter-golf that referenced this pull request Apr 14, 2026
Replaces chunk-based score-first SGD TTT with doc-independent batched LoRA
adaptation at eval time. Eval-only, training path unchanged, so none of the
torch.compile recompile concerns from VarLen apply here.

New machinery:
- BatchedLinearLoRA: per-document LoRA factors (bsz, rank, in_features)
- BatchedTTTLoRA: module holding Q/K/V/O/MLP-up/lm_head LoRAs per block
- CausalSelfAttention.forward accepts optional lora_q/k/v/o (adds to projections)
- MLP.forward accepts optional lora_up (adds to fc projection)
- Block.forward threads the LoRA args
- GPT.forward_ttt runs the full forward stack with LoRAs injected, returns
  per-token loss (reshaped to input shape)
- ttt_lora_evaluate orchestrates score-first doc batches with distributed
  counter-based work stealing across ranks

Compliance: each doc fully scored BEFORE its LoRA adapts (score-first). Each
doc gets fresh LoRA weights (doc-independent, no cross-doc leakage). Standard
causal attention throughout. No SLOT, no pre-quant TTT, no ETLB, no n-gram.

Samacqua reports ~-0.008 BPB vs sliding-window eval on his stack. If it
translates to our stack, would put us ~1.072-1.073, below the current 1.0728
frontier.

TTT_MODE=lora is default. Set TTT_MODE=chunk to fall back to the old chunk-
based score-first TTT.
ChideraIbe123 pushed a commit to ChideraIbe123/parameter-golf that referenced this pull request Apr 14, 2026
Adds flash_attn_varlen_func path for within-document attention during
training. Attention is restricted to doc boundaries detected via BOS
token positions in each batch, eliminating cross-doc attention noise.

Changes:
- Import flash_attn_varlen_func alongside flash_attn_3_func
- Add VARLEN_ENABLED and BOS_TOKEN_ID env var hyperparams
- Add _build_cu_seqlens_from_batch helper (detects BOS, builds cu_seqlens)
- Thread cu_seqlens/max_seqlen through CausalSelfAttention -> Block -> GPT
- Branch in attention: varlen when cu_seqlens provided, else flash_attn_3
- Switch torch.compile to fullgraph=False when VARLEN_ENABLED=1 (data-dep branch)
- Training step builds cu_seqlens per batch and passes to model

Eval path unchanged. When VARLEN_ENABLED=0 (default) behavior is identical
to PR openai#1493 reference. Compliance unchanged (training-only change, causality
preserved by causal=True flag).

Reference: PR openai#1530 @samacqua, PR openai#1536 @dexhunter

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
anthony-maio added a commit to anthony-maio/parameter-golf that referenced this pull request Apr 14, 2026
…zation

Council consensus across 3 models (Gemini, Sonnet, Nemotron) + followup analysis
identified these as the high-EV targeted fixes without lineage switch.

LoRA semantics (samacqua PR openai#1530 exact match):
- mlp_loras: dim -> dim (was dim -> hidden_dim), applied as parallel residual-level
  bypass at Block forward (was inner tweak inside MLP.forward)
- o_lora: input is pre-attention normalized residual n (was attention output y)
- MLP.forward reverted to no-lora signature (cleaner; mlp_lora lives at Block level)
- CausalSelfAttention.forward now only takes lora_q/k/v (o_lora moved to Block)

Pod speedgate at step 20 (env var POD_SPEEDGATE_MS, default 0 = disabled):
- Measures ms/step at step 20
- RuntimeError abort if above threshold
- Saves ~$5 per bad pod per council recommendation

Looped-layer quantization (env var LOOP_CLIP_SIGMAS, default 10.0):
- Tighter clip_sigmas for blocks.3/4/5 (the NUM_LOOPS=2 recurrent layers)
- Motivation per Sonnet: quantization error compounds 2x through recurrence,
  and GPTQ error amplifies ~900x over 3 cycles per Issue openai#140
- Only active when NUM_LOOPS > 0

No training changes; all three fixes are eval-only behavior + a safety gate.
Training path semantics unchanged from baseline.
anthony-maio added a commit to anthony-maio/parameter-golf that referenced this pull request Apr 14, 2026
Drops in samacqua's varlen + fused MLP + doc-independent batched LoRA TTT
verbatim, with one targeted change to the eval-side TTT compile warmup:
swap real validation-token slicing for torch.randint of identical shape.
Eliminates validation token exposure during compile cache population
without changing the cache itself (shapes drive recompilation, not contents).

See record README for the full diff and reasoning.
dexhunter added a commit to dexhunter/parameter-golf that referenced this pull request Apr 14, 2026
…al_bpb 1.07193 (3-seed mean)

Novel multi-phase global SGD during phased TTT evaluation.
Builds on PR openai#1530 (@samacqua) + PR openai#1610 (@romeerp) phased TTT concept.
3-seed mean: 1.07193 BPB (2.76890 nats), std 0.00063.
Seeds: 42, 0, 1234. All artifacts <16 MB.
kailean added a commit to kailean/parameter-golf that referenced this pull request Apr 15, 2026
…ropy regularization)

Adapted from SOTA PR openai#1530 (samacqua):
- collect_hessians(): activation statistics from real calibration data
- gptq_quantize_weight(): full GPTQ with Cholesky decomposition + error propagation
- gptq_mixed_quantize(): per-layer adaptive clip (matrix_clip_sigmas=12.85, embed=20.0)
- entropy_regularization_loss(): CAT loss for compressibility training
- Supports int5/int6/int8 mixed precision
ChideraIbe123 pushed a commit to ChideraIbe123/parameter-golf that referenced this pull request Apr 15, 2026
Adapted from PR openai#1530 @samacqua (linear_leaky_relu_square_kernel).
The kernel fuses matmul(x, W_up.T) with LeakyReLU(0.5)**2 activation
into a single Triton kernel using TMA (Hopper H100). Saves the
(B, T, 4D) pre-activation HBM round-trip in the forward; in backward,
reuses the same kernel to apply the activation gradient to the
incoming grad_output before the weight-gradient matmul.

Gated by FUSED_MLP_ENABLED=1. When set, every Block's MLP uses the
fused kernel during training. Falls back gracefully if Triton or TMA
unavailable.

Reference: PR openai#1530 @samacqua. Expected: 5-10% training speedup on
MLP-dominated blocks, more steps in the 600s cap, ~0.002-0.005 BPB
improvement from additional training.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
ChideraIbe123 pushed a commit to ChideraIbe123/parameter-golf that referenced this pull request Apr 15, 2026
This is a from-scratch Triton kernel (not just a copy) that fuses
THREE operations into one kernel: RMSNorm (per-row inverse rms)
multiplied by ln_scale, then matmul with W_up, then LeakyReLU(0.5)^2
activation. Saves the (B*T, D=512) x_normed HBM round-trip that
PR openai#1530 leaves on the table.

Two new kernels:
- _rms_inv_kernel: per-row inverse-rms reduction (small)
- _fused_rms_linear_lrs_kernel: takes inv_rms + ln_scale, applies
  the rmsnorm scaling row-wise during the K loop, then matmul +
  activation (extends PR openai#1530's persistent-TMA structure)

Custom backward implements the full RMSNorm chain rule:
  dx = ln_scale * inv_rms * (dx_normed - x * inv_rms^2 * mean(dx_normed*x))
This makes the backward correct without saving x_normed (which would
defeat the HBM savings).

Block.forward branches on mlp.use_fused: when fused, it skips the
eager mlp_norm() call and passes raw x + ln_scale_factor to MLP,
which then runs the fused kernel that does normalization internally.

Gated by FUSED_MLP_ENABLED=1. Eager fallback unchanged.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
ChideraIbe123 pushed a commit to ChideraIbe123/parameter-golf that referenced this pull request Apr 15, 2026
Adds _FusedSimpleMLPFn alongside _FusedRMSMLPFn, selectable by
FUSED_MLP_FULL=1 env var. The simple variant does RMSNorm in eager
PyTorch (like PR openai#1530) and only fuses matmul + LeakyReLU^2; my v1
variant (_FusedRMSMLPFn) additionally fuses per-row inv_rms * ln_scale
scaling into the K-loop.

Purpose: A/B test whether my RMSNorm fusion addition is counterproductive.
If simple > v1, per-K scaling overhead eats HBM savings.
If simple == v1, kernel choice is saturated.

Reuses same Triton kernel via FUSE_RMS constexpr branch.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants