Record: Varlen attention + fused MLP + doc-independent TTT (1.07336) #1530
samacqua wants to merge 4 commits into openai:main from
Conversation
…g + Muon 0.97 — val_bpb 1.07747 (3-seed mean)

- 3-seed mean: 1.07747 BPB (std 0.00064) / 2.78321 nats
- ~15.99 MB artifact, 8×H100 SXM, 600s
- VarLen attention (within-document only), doc-independent LoRA TTT
- Parameter banking + triple depth recurrence + parallel residuals
- PyTorch MLP fallback (no Triton/CUTLASS dependency)
- Based on PR openai#1530, PR openai#1523, PR openai#1514
I may be missing something, but I think there is one higher-scrutiny concern. In the TTT path, the compile warmup appears to use actual validation tokens before the main eval loop, and it also runs backward+step on them. My read of the current guidance is:

If that reading is right, would you be willing to switch the warmup to random/synthetic tokens of identical shape?

That would make the legality story much cleaner.
@dexhunter it could honestly just be commented out, given that warmup + eval time is still < 600s. But it shouldn't matter -- the training warmup does the same thing, and parameters and optimizer states are reset. As a sanity check I re-ran TTT on seed 2 w/ the warmup commented out, and the loss was within expected variance between runs (it actually did slightly better). But given that making a change + re-running would take an hour of 8xH100, I will only do so if it is a blocker.
Community Review — VarLen attention + fused MLP + doc-independent TTT

Thanks @samacqua. Doc-independent TTT via cu_seqlens boundary isolation is a genuinely interesting approach to the causal-dependence question the SLOT cluster has been bouncing around. One import blocker, then a deeper question on the doc-independence claim.

What I found (head SHA

"Doc-independent TTT" — the interesting idea. My read is that if the LoRA (or whatever TTT-like adaptation you're running) respects the same cu_seqlens document boundaries, the cross-document dependence concern goes away. The open question is whether the adaptation state itself is per-document or per-batch; I couldn't find an explicit answer in the diff.

Import blocker (smoke test). The CPU smoke on CT2038 hit an import failure; my flash_attn stub covers only part of the required surface.

Questions
Compliance summary (partial)
Verdict: LOOKS INTERESTING, NEEDS AUTHOR CLARIFICATION on the TTT adaptation path.

Recommendation to @cocohearts @valerio-oai @0hq @yuzhougu-oai @notapplica: HOLD pending author clarification on where and how the doc-independent TTT runs. If the adaptation respects cu_seqlens boundaries and the temporal ordering is score-before-adapt at the document level, this is a genuinely clean path out of the SLOT compliance bind, and I'd flip to MERGE.

Reviewed by @MatoTeziTanka — The Agora.

CPU smoke test (CT2038 proteus-engine, 2026-04-11): IMPORT_FAIL due to
@MatoTeziTanka look at
So yes, it respects the same document boundaries. It is strictly harder (and more valid imo) than TTT on the full validation sequence autoregressively. See the "Methods" section of this blog for clarity.
… (3-seed mean)

PR openai#1530 v2 base + warmdown_frac=0.75 + TTT_CHUNK_SIZE=48 + Muon 0.97. 3-seed mean: 1.07406 BPB (std 0.00132), 2.77441 nats. Delta vs merged SOTA (openai#1493): -0.01491 nats (clears the 0.005 bar by 3.0x). All artifacts < 16 MB, train < 600s, eval < 225s.
The pattern structurally matches what @valerio-oai flagged as invalid in #677 ("adapt on validation before the reported eval pass"). Even though LoRA resets per batch, the compile warmup still runs backward+step on val tokens before the eval loop. Since you confirmed the fix is within variance, it would be worth switching to random/synthetic tokens to avoid any ambiguity during review.
This submission actually looks good to me. They don't "adapt on validation before the reported eval pass", as the warmup/compilation throws away the updates. The final result wouldn't change at all if they replaced those validation tokens in the warmup with any other tokens. The author even notes that the result is unchanged when the warmup is commented out.
Hey btw @samacqua the training script will crash without the
@msisovic fixed. Thanks!
…nto SP8192 stack

Adds records/track_10min_16mb/2026-04-15_SP8192_VarLen/train_gpt.py (readable 1446 lines):

- flash_attn_varlen_func with cu_seqlens document packing in CausalSelfAttention
- DocumentPackingLoader replacing per-sequence shuffling for training batches
- Triton linear+LeakyReLU(0.5)^2 fused MLP kernel with two-lane output split
- cu_seqlens threaded through Block / GPT forward; max_seqlen pinned to train_seq_len to avoid torch.compile recompilation on varying ints

Retains full SP8192 stack: depth recurrence (2 loops, layers 3-5), parallel residuals from layer 7, QK-Gain 5.0, GPTQ INT6 + INT8 embed + SDClip 12.85, score-first chunk TTT, fused-softcap-ce eval kernel, SP8192 tokenizer. Eval paths unchanged (ShuffledSequenceLoader + flash_attn_3_func when cu_seqlens is None). New knobs: USE_VARLEN, USE_FUSED_MLP, CU_BUCKET_SIZE, MAX_DOC_LEN. Requires flash_attn_3 wheels (cu128_torch291) and Triton 3.2+ for the TensorDescriptor API. Compiles clean locally. Awaiting 8xH100 smoke test to validate end-to-end.
…is regressive on our SP8192 + depth recurrence stack

Three configs tested at seed 42 on 8xH100 SXM:

- VarLen + Fused MLP: 1.93 pre-quant val_bpb, 1440 steps, 2.3M tok/s (3.4x slower)
- Fused MLP only: 1.110 pre-quant val_bpb, 2581 steps, 3.4M tok/s (2.3x slower)
- Pure baseline reproduction: pod terminated mid-run before completion

Root cause: VarLen + depth recurrence + fullgraph torch.compile triggers cascading shape recompilations (a combinatorial explosion of loop_iter x cu_seqlens shapes) that overflows even a 64-entry compile cache. The fused MLP Triton kernel has per-call TensorDescriptor allocation overhead that doesn't amortize for our hidden_dim=2048.

Conclusion: do not ship this port. PR openai#1572 (1.07974) remains the best submission. Move 2 (per-layer GPTQ from PR openai#1586) and Move 3 (LoRA TTT from PR openai#1530, eval-only so no torch.compile recompile concern) are still viable next directions.
…192 stack

Config-level changes only, no kernel/compile changes that could interact with our depth recurrence stack (unlike the VarLen port in submission/sp8192-varlen-frontier):

- MLP_CLIP_SIGMAS 12.0 (tight, preserve MLP precision)
- ATTN_CLIP_SIGMAS 13.0 (looser, save bytes on attention weights)
- EMBED_BITS 8 -> 7 with EMBED_CLIP_SIGMAS 20.0 -> 15.0 (~530 KB artifact savings)
- MATRIX_LR 0.022 -> 0.026 (dexhunter 6-point sweep optimum)
- WARMDOWN_FRAC 0.72 -> 0.75 (longer peak LR window)

Dexhunter measured 1.07493 BPB (3-seed mean) applying these against the PR openai#1530 base. Against our 1.07974 SP8192 baseline the expected delta is in the 0.003-0.005 BPB range; the adaptive clip is stack-independent and the embed-bits + LR tweaks are universal. Fresh branch from upstream/main per PR hygiene (PR openai#1572 untouched).
Replaces chunk-based score-first SGD TTT with doc-independent batched LoRA adaptation at eval time. Eval-only, training path unchanged, so none of the torch.compile recompile concerns from VarLen apply here.

New machinery:

- BatchedLinearLoRA: per-document LoRA factors (bsz, rank, in_features)
- BatchedTTTLoRA: module holding Q/K/V/O/MLP-up/lm_head LoRAs per block
- CausalSelfAttention.forward accepts optional lora_q/k/v/o (adds to projections)
- MLP.forward accepts optional lora_up (adds to fc projection)
- Block.forward threads the LoRA args
- GPT.forward_ttt runs the full forward stack with LoRAs injected, returns per-token loss (reshaped to input shape)
- ttt_lora_evaluate orchestrates score-first doc batches with distributed counter-based work stealing across ranks

Compliance: each doc is fully scored BEFORE its LoRA adapts (score-first). Each doc gets fresh LoRA weights (doc-independent, no cross-doc leakage). Standard causal attention throughout. No SLOT, no pre-quant TTT, no ETLB, no n-gram.

Samacqua reports ~-0.008 BPB vs sliding-window eval on his stack. If it translates to our stack, it would put us at ~1.072-1.073, below the current 1.0728 frontier.

TTT_MODE=lora is the default. Set TTT_MODE=chunk to fall back to the old chunk-based score-first TTT.
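For intuition, the rank-r arithmetic behind a per-document LoRA adapter can be sketched in plain Python. This is an illustrative toy, not the PR's BatchedLinearLoRA (which is batched over documents and runs on GPU tensors); all names and dimensions here are hypothetical:

```python
# Illustrative sketch of a rank-r LoRA update y = W x + B (A x).
# Hypothetical toy, not the PR's actual BatchedLinearLoRA code.

def matvec(M, v):
    """Plain-Python matrix-vector product."""
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]

def lora_apply(x, W, A, B):
    """Frozen weight W plus low-rank update B @ A, applied without
    ever materializing the dense matrix (W + B @ A)."""
    base = matvec(W, x)               # frozen projection
    delta = matvec(B, matvec(A, x))   # rank-r correction, r << dim
    return [b + d for b, d in zip(base, delta)]

# Tiny example: 2x2 frozen identity weight, rank-1 adapter.
W = [[1.0, 0.0], [0.0, 1.0]]
A = [[1.0, 1.0]]        # shape (r=1, in=2)
B = [[0.5], [0.5]]      # shape (out=2, r=1)
x = [2.0, 4.0]

y = lora_apply(x, W, A, B)  # [5.0, 7.0]: base [2, 4] + delta [3, 3]
```

The point of the factorization is that only A and B (a handful of parameters per document) are adapted at eval time, while W stays frozen and shared.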
Adds flash_attn_varlen_func path for within-document attention during training. Attention is restricted to doc boundaries detected via BOS token positions in each batch, eliminating cross-doc attention noise.

Changes:

- Import flash_attn_varlen_func alongside flash_attn_3_func
- Add VARLEN_ENABLED and BOS_TOKEN_ID env var hyperparams
- Add _build_cu_seqlens_from_batch helper (detects BOS, builds cu_seqlens)
- Thread cu_seqlens/max_seqlen through CausalSelfAttention -> Block -> GPT
- Branch in attention: varlen when cu_seqlens provided, else flash_attn_3
- Switch torch.compile to fullgraph=False when VARLEN_ENABLED=1 (data-dependent branch)
- Training step builds cu_seqlens per batch and passes them to the model

Eval path unchanged. When VARLEN_ENABLED=0 (default) behavior is identical to the PR openai#1493 reference. Compliance unchanged (training-only change, causality preserved by the causal=True flag).

Reference: PR openai#1530 @samacqua, PR openai#1536 @dexhunter

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
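For intuition, here is a simplified stand-in for the boundary detection this commit describes, assuming (per the commit message) that each document starts at a BOS token. The real helper operates on int32 GPU tensors; this sketch uses plain lists:

```python
# Simplified stand-in for a helper like _build_cu_seqlens_from_batch;
# illustrative only, not the record's actual implementation.

def build_cu_seqlens(tokens, bos_token_id):
    """Return cumulative sequence lengths for a flat packed token buffer.

    cu_seqlens[i] is the start offset of document i, and the final entry
    is the total token count, so document i spans
    tokens[cu_seqlens[i]:cu_seqlens[i + 1]].
    """
    starts = [i for i, t in enumerate(tokens) if t == bos_token_id]
    if not starts or starts[0] != 0:
        starts = [0] + starts  # leading partial document
    return starts + [len(tokens)]

# Two documents packed flat: [BOS a b][BOS c]
BOS = 1
cu = build_cu_seqlens([BOS, 5, 6, BOS, 7], BOS)          # [0, 3, 5]
max_seqlen = max(b - a for a, b in zip(cu, cu[1:]))      # 3
```

This pair (cu_seqlens, max_seqlen) is exactly what flash_attn_varlen_func consumes so that attention never crosses a document boundary.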
…zation

Council consensus across 3 models (Gemini, Sonnet, Nemotron) + follow-up analysis identified these as the high-EV targeted fixes without a lineage switch.

LoRA semantics (samacqua PR openai#1530 exact match):

- mlp_loras: dim -> dim (was dim -> hidden_dim), applied as a parallel residual-level bypass at Block forward (was an inner tweak inside MLP.forward)
- o_lora: input is the pre-attention normalized residual n (was the attention output y)
- MLP.forward reverted to the no-lora signature (cleaner; mlp_lora lives at Block level)
- CausalSelfAttention.forward now only takes lora_q/k/v (o_lora moved to Block)

Pod speedgate at step 20 (env var POD_SPEEDGATE_MS, default 0 = disabled):

- Measures ms/step at step 20
- RuntimeError abort if above threshold
- Saves ~$5 per bad pod per council recommendation

Looped-layer quantization (env var LOOP_CLIP_SIGMAS, default 10.0):

- Tighter clip_sigmas for blocks.3/4/5 (the NUM_LOOPS=2 recurrent layers)
- Motivation per Sonnet: quantization error compounds 2x through recurrence, and GPTQ error amplifies ~900x over 3 cycles per Issue openai#140
- Only active when NUM_LOOPS > 0

No training changes; all three fixes are eval-only behavior + a safety gate. Training path semantics unchanged from baseline.
Drops in samacqua's varlen + fused MLP + doc-independent batched LoRA TTT verbatim, with one targeted change to the eval-side TTT compile warmup: swap real validation-token slicing for torch.randint of identical shape. Eliminates validation token exposure during compile cache population without changing the cache itself (shapes drive recompilation, not contents). See record README for the full diff and reasoning.
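The swap described above can be sketched as follows, using Python's random module in place of torch.randint; the batch shape and names are illustrative assumptions, not the record's actual code:

```python
import random

def warmup_batch(batch_size, seq_len, vocab_size, seed=0):
    """Random token ids with the same shape the eval batches use.

    torch.compile keys its cache on shapes/dtypes, not tensor contents,
    so warming up on random ids populates the same compiled graphs
    without the warmup ever touching validation data.
    """
    rng = random.Random(seed)
    return [[rng.randrange(vocab_size) for _ in range(seq_len)]
            for _ in range(batch_size)]

# Hypothetical shapes for illustration.
batch = warmup_batch(batch_size=2, seq_len=8, vocab_size=50257)
```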
…al_bpb 1.07193 (3-seed mean)

Novel multi-phase global SGD during phased TTT evaluation. Builds on PR openai#1530 (@samacqua) + PR openai#1610 (@romeerp) phased TTT concept. 3-seed mean: 1.07193 BPB (2.76890 nats), std 0.00063. Seeds: 42, 0, 1234. All artifacts < 16 MB.
…ropy regularization)

Adapted from SOTA PR openai#1530 (samacqua):

- collect_hessians(): activation statistics from real calibration data
- gptq_quantize_weight(): full GPTQ with Cholesky decomposition + error propagation
- gptq_mixed_quantize(): per-layer adaptive clip (matrix_clip_sigmas=12.85, embed=20.0)
- entropy_regularization_loss(): CAT loss for compressibility training
- Supports int5/int6/int8 mixed precision
Adapted from PR openai#1530 @samacqua (linear_leaky_relu_square_kernel). The kernel fuses matmul(x, W_up.T) with LeakyReLU(0.5)**2 activation into a single Triton kernel using TMA (Hopper H100). Saves the (B, T, 4D) pre-activation HBM round-trip in the forward; in backward, reuses the same kernel to apply the activation gradient to the incoming grad_output before the weight-gradient matmul. Gated by FUSED_MLP_ENABLED=1. When set, every Block's MLP uses the fused kernel during training. Falls back gracefully if Triton or TMA unavailable. Reference: PR openai#1530 @samacqua. Expected: 5-10% training speedup on MLP-dominated blocks, more steps in the 600s cap, ~0.002-0.005 BPB improvement from additional training. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
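As an eager reference for the activation this kernel fuses (the matmul itself is omitted here), a minimal sketch of the LeakyReLU(0.5)**2 semantics, not the Triton code:

```python
def leaky_relu_squared(v, slope=0.5):
    """LeakyReLU with negative slope 0.5, then squared: the activation
    the fused kernel applies to the up-projection output."""
    out = v if v > 0 else slope * v
    return out * out

# Positive inputs square directly; negative inputs are halved first.
assert leaky_relu_squared(2.0) == 4.0    # 2 -> 2 -> 4
assert leaky_relu_squared(-2.0) == 1.0   # -2 -> -1 -> 1
```

Note that the activation is smooth at zero and everywhere non-negative, which is why the backward pass can cheaply reapply its gradient to grad_output inside the same kernel.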
This is a from-scratch Triton kernel (not just a copy) that fuses THREE operations into one kernel: RMSNorm (per-row inverse rms) multiplied by ln_scale, then matmul with W_up, then LeakyReLU(0.5)^2 activation. Saves the (B*T, D=512) x_normed HBM round-trip that PR openai#1530 leaves on the table. Two new kernels: - _rms_inv_kernel: per-row inverse-rms reduction (small) - _fused_rms_linear_lrs_kernel: takes inv_rms + ln_scale, applies the rmsnorm scaling row-wise during the K loop, then matmul + activation (extends PR openai#1530's persistent-TMA structure) Custom backward implements the full RMSNorm chain rule: dx = ln_scale * inv_rms * (dx_normed - x * inv_rms^2 * mean(dx_normed*x)) This makes the backward correct without saving x_normed (which would defeat the HBM savings). Block.forward branches on mlp.use_fused: when fused, it skips the eager mlp_norm() call and passes raw x + ln_scale_factor to MLP, which then runs the fused kernel that does normalization internally. Gated by FUSED_MLP_ENABLED=1. Eager fallback unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
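The quoted chain rule can be verified numerically. Below is a small pure-Python finite-difference check, assuming a scalar ln_scale (so it factors out of the mean); variable names are illustrative, not the kernel's:

```python
import math

def rms_backward(x, g, ln_scale):
    """dx = ln_scale * inv_rms * (g - x * inv_rms^2 * mean(g * x)),
    for y = ln_scale * x * inv_rms with inv_rms = 1/sqrt(mean(x^2))."""
    n = len(x)
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / n)
    m = sum(gi * xi for gi, xi in zip(g, x)) / n
    return [ln_scale * inv_rms * (gi - xi * inv_rms ** 2 * m)
            for gi, xi in zip(g, x)]

def loss(x, g, ln_scale):
    """Scalar surrogate L = sum(g * y), whose dL/dx is rms_backward."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x))
    return sum(gi * ln_scale * xi * inv_rms for gi, xi in zip(g, x))

x, g, s, eps = [0.3, -1.2, 2.0], [1.0, -0.5, 0.25], 1.7, 1e-6
analytic = rms_backward(x, g, s)
for j in range(len(x)):
    xp = x[:]; xp[j] += eps
    xm = x[:]; xm[j] -= eps
    numeric = (loss(xp, g, s) - loss(xm, g, s)) / (2 * eps)
    assert abs(numeric - analytic[j]) < 1e-5  # central difference agrees
```

The check confirms the formula only needs x, inv_rms, and the incoming gradient, i.e. x_normed never has to be saved, which is the stated HBM saving.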
Adds _FusedSimpleMLPFn alongside _FusedRMSMLPFn, selectable by FUSED_MLP_FULL=1 env var. The simple variant does RMSNorm in eager PyTorch (like PR openai#1530) and only fuses matmul + LeakyReLU^2; my v1 variant (_FusedRMSMLPFn) additionally fuses per-row inv_rms * ln_scale scaling into the K-loop. Purpose: A/B test whether my RMSNorm fusion addition is counterproductive. If simple > v1, per-K scaling overhead eats HBM savings. If simple == v1, kernel choice is saturated. Reuses same Triton kernel via FUSE_RMS constexpr branch. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Record: Varlen attention + fused MLP + TTT
val_loss: 2.77261 | val_bpb: 1.07336 | ~15.99 MB | 8×H100 SXM, 587s train + ~340s TTT eval
Best PR bpb (PR #1529): bpb=1.0753 (delta=0.0019), loss=2.7776 (delta=0.0050)
Merged record bpb (PR #1493): bpb=1.0810 (delta=0.0076), loss=2.7923 (delta=0.0197)
Increased training speed ~5% via variable-length attention, a fused MLP Triton kernel (no cutlass_evt_fusion dependency), and grouping together small parameters, yielding ~.002 nats improvement under sliding-window eval. Re-added document-based LoRA TTT, which has no inter-sequence dependence and improves over strided evaluation by ~.008 nats.

Main changes
Applied changes from my old PR to a recent record PR: #1523. But PR #1552 beat my previous bpb before I submitted the PR, so I incorporated their (orthogonal) improvements. Most of what follows is copied from my previous PR #1354.
This involves 3 things:
1. Variable length attention (~2% faster training, ~0.001 nats)
Replaced dense causal attention with Flash Attention 3's flash_attn_varlen_func. During training, documents are packed into flat token buffers with cu_seqlens boundaries so attention is computed within documents only — the model never attends across unrelated documents that happen to be adjacent in a batch.

This does two things:

- Removes cross-document attention noise, since each token only sees its own document.
- Cuts attention FLOPs: 100 documents of 100 tokens cost 100 * 100**2 = 1M attention FLOPs, vs 10 * 1000**2 = 10M with dense attention over the same tokens packed into 10 sequences of 1000.

2. Fused MLP + grouped small params (~3% faster training, ~0.001 nats)
A custom Triton kernel (linear_leaky_relu_square_kernel) fuses the up-projection with the LeakyReLU(0.5)² activation (leaky ReLU followed by squaring) into a single kernel. Based on similar kernels from modded-nanogpt. I also group the many tiny replicated scalar/control gradients into a single all-reduce to avoid a pile of tiny collectives.

3. Doc-based test-time training (TTT) (~0.008 nats over sliding window)
Although it is technically legal in this competition to train on tokens from previous documents in the dataset, I am spiritually opposed to this. Under the current formulation, if the eval set were bigger, the expected loss would be lower, which seems broken. So in this implementation, score-first TTT is applied to each sequence in the validation set independently (and efficiently, using batched LoRAs), which is strictly harder.
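The per-sequence, score-first discipline described above can be sketched as a loop. The helper names below are hypothetical; the real implementation batches documents and adapts via LoRA, but the ordering constraint is the same:

```python
def score_first_ttt(documents, make_fresh_adapter, score, adapt):
    """Score-first, doc-independent TTT: each document is scored with a
    FRESH adapter before any adaptation on that document, and the
    adapted state is discarded afterward, so there is no cross-document
    leakage and no dependence on eval-set size or ordering."""
    losses = []
    for doc in documents:
        adapter = make_fresh_adapter()       # reset per document
        losses.append(score(doc, adapter))   # scored BEFORE adapting
        adapt(doc, adapter)                  # updates never reach other docs
    return losses

# Toy check: an adapter that memorizes a document lowers its loss only
# if it has already seen that exact document, which score-first forbids.
out = score_first_ttt(
    ["a", "b", "a"],
    make_fresh_adapter=lambda: {"memory": set()},
    score=lambda d, ad: 0.0 if d in ad["memory"] else 1.0,
    adapt=lambda d, ad: ad["memory"].add(d),
)
assert out == [1.0, 1.0, 1.0]  # fresh adapter per doc: repeat "a" gains nothing
```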
Re-adds LoRA-based TTT, based on my old implementation, but > 2x faster, which allows smaller chunk sizes and therefore better performance. This is an instance of "Case 3" according to this classification.
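The token-efficiency arithmetic quoted in this PR for TTT vs pre-training (~0.003 nats per ~800 validation tokens from TTT, vs ~0.002 nats from 215 extra pre-training steps at 786432 tokens/step) can be checked directly; the numbers are taken from the PR discussion:

```python
# TTT gain per validation token vs pre-training gain per training token.
pretrain_tokens = 786432 * 215     # tokens consumed by 215 extra steps
pretrain_gain = 0.002              # nats gained from those steps
ttt_gain, ttt_tokens = 0.003, 800  # nats per ~800 validation tokens

ratio = (ttt_gain / ttt_tokens) / (pretrain_gain / pretrain_tokens)

assert pretrain_tokens == 169_082_880
assert ratio >= 300_000  # ">= 300k times more token efficient"
```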
It's interesting to note that adding test-time training improves loss more than adding ~215 training steps. These 215 steps train on 786432 * 215 = 169,082,880 tokens to gain ~.002 nats. The average sequence length in the validation set is ~200 tokens, which means test-time training here gains ~.003 nats / 800 tokens on average (valid because sequences are trained independently). So, in a way, TTT is ~(.003/800) / (.002/169082880) >= 300k times more token efficient than pre-training: it helps to be in distribution :)

Other small changes
Made some changes to make replication and dev based on this PR easier:
Replicating runs + dev