
Non Record: MuonEq-R + Context-Only SLOT + QK_GAIN=5.0 — val_bpb 1.1027 (3-seed mean)#1217

Open
bigbag wants to merge 2 commits into openai:main from bigbag:submission/muoneq-causal-slot

Conversation


@bigbag bigbag commented Apr 1, 2026

Summary

val_bpb: 1.1027 (3-seed mean, std 0.0011) | ≤15.80 MB | 8×H100 SXM | ~88.8ms/step | ~6654 steps

Built on PR #1179 (@dexhunter) with three additions:

  • MuonEq-R (row-normalization before Newton-Schulz) — from arXiv:2603.28254, ~15 lines
  • QK_GAIN_INIT=5.0 — our hyperparameter sweep across 45 experiments, monotonic gains from 1.5→5.0
  • Context-Only SLOT — a causal variant that optimizes the delta using only already-scored context tokens (see Legality section)

3-Seed Results

| Seed | Context-SLOT BPB  | TTT BPB | Steps | ms/step | Artifact (bytes) |
|------|-------------------|---------|-------|---------|------------------|
| 1337 | 1.10166           | 1.11008 | 6660  | 88.8    | 15,795,518       |
| 42   | 1.10378           | 1.11206 | 6650  | 88.9    | 15,793,163       |
| 2024 | 1.10271           | 1.11108 | 6653  | 88.9    | 15,796,779       |
| Mean | 1.10272 ± 0.00106 | 1.11107 | 6654  | 88.8    | 15,795,153       |

Beats merged SOTA (PR #1019, 1.1147) by 0.012 BPB (p ≪ 0.01).

Improvement Breakdown

| Technique                               | BPB Impact | Cumulative |
|-----------------------------------------|------------|------------|
| PR #1179 base (sliding, no SLOT)        | 1.1105     | 1.1105     |
| + MuonEq-R optimizer                    | −0.001     | ~1.1095    |
| + QK_GAIN=5.0                           | −0.001     | ~1.1090    |
| + Context-Only SLOT (8 steps, lr=0.005) | −0.006     | ~1.1027    |

Legality

Training (≤600s on 8×H100)

  • Standard transformer training with Parallel Muon optimizer
  • MuonEq-R: row-normalization before Newton-Schulz orthogonalization (arXiv:2603.28254). Standard optimizer improvement — no rule restricts it.
  • QK_GAIN_INIT=5.0: hyperparameter choice — no rule restricts it
  • Full GPTQ calibration runs within the 600s training budget
  • No validation data accessed during training

Evaluation — Context-Only SLOT (LEGAL, causal by construction)

This is a causal variant of SLOT that addresses all prior causality concerns.

Protocol for each sliding window (seq_len=2048, stride=64):

  1. Hidden states computed for all 2048 positions under torch.no_grad() — model weights frozen, no gradient.
  2. Delta optimization: A 512-dim additive delta is optimized using cross-entropy loss on context positions only (positions 0 to 1983). The 64 new tokens being scored (positions 1984–2047) are excluded from the loss computation and contribute zero gradient.
  3. Scoring: Final logits computed for all positions with the optimized delta applied. NLL recorded for the 64 new positions.
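A minimal sketch of this three-step protocol at toy scale (assumed names `head`, `hidden`, `targets`; the real run uses seq_len=2048, stride=64, a 512-dim delta, and the frozen model's forward pass rather than a single linear head):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, stride, d, vocab = 32, 8, 16, 50   # toy sizes
ctx = seq_len - stride                      # positions 0..ctx-1 are context

head = torch.nn.Linear(d, vocab)            # stand-in for the frozen scoring head
hidden = torch.randn(1, seq_len, d)         # step 1: frozen hidden states (no_grad in the real run)
targets = torch.randint(vocab, (1, seq_len))

# Step 2: optimize an additive delta using context positions ONLY.
delta = torch.zeros(1, 1, d, requires_grad=True)
opt = torch.optim.AdamW([delta], lr=0.005)  # SLOT_LR=0.005
for _ in range(8):                          # SLOT_STEPS=8
    logits = head(hidden + delta)
    loss = F.cross_entropy(                 # the stride new tokens are excluded,
        logits[:, :ctx].reshape(-1, vocab), # so they contribute zero gradient
        targets[:, :ctx].reshape(-1))
    opt.zero_grad(); loss.backward(); opt.step()

# Step 3: score the new positions with the optimized delta applied.
with torch.no_grad():
    logits = head(hidden + delta)
    nll_new = F.cross_entropy(
        logits[:, ctx:].reshape(-1, vocab),
        targets[:, ctx:].reshape(-1))
```

Only `delta` is passed to the optimizer, so the model (here, `head`) stays frozen throughout.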

Why this is causal:

  • The delta is learned exclusively from already-scored tokens (the context window)
  • The 64 new tokens at the end are never used for optimization — they only appear in the scoring step
  • This is equivalent to: "observe past tokens → learn a bias → predict future tokens"
  • The gradient signal comes 100% from the past, never from the future
  • With stride=64 and seq_len=2048, 96.9% of the window is context (already scored in previous windows)

Comparison to standard SLOT (which had causality concerns):

  • Standard SLOT: optimizes delta on ALL positions including the new 64 → future tokens influence the delta → causality concern
  • Context-Only SLOT: optimizes delta on context positions ONLY → future tokens have zero influence → trivially causal

This approach was proposed by @AnubhavBharadwaaj (original SLOT author) as a defensible causal variant in PR #1172 discussion, with claimed ~0.0002 BPB difference from standard SLOT.

Evaluation — TTT (score-first, ≤10 min additional)

No illegal techniques

  • ❌ No n-gram cache
  • ❌ No two-pass rescoring
  • ❌ No min-NLL epoch selection
  • ❌ No eval-time GPTQ on training data
  • ❌ No oracle/hindsight selection
  • ❌ No future-token information in SLOT optimization

Reproduction

pip install brotli
QK_GAIN_INIT=5.0 SLOT_ENABLED=1 SLOT_STEPS=8 SLOT_LR=0.005 SEED=$SEED \
  torchrun --standalone --nproc_per_node=8 train_gpt.py

Training: ~600s. Eval (sliding + context-only SLOT): ~190s. Total: ~13 min end-to-end.

Acknowledgments

PR #1179 (@dexhunter), MuonEq (arXiv:2603.28254), SLOT (Hu et al. arXiv:2505.12392v2), PR #549 (legal TTT pattern), @AnubhavBharadwaaj (context-only SLOT proposal).

🤖 Generated with Claude Code

Pavel Liashkov and others added 2 commits April 1, 2026 17:36
3-seed mean 1.10272 BPB (std 0.00106), beats merged SOTA by 0.012.
Built on PR openai#1179 with MuonEq-R optimizer, context-only SLOT
(causal variant), and QK_GAIN=5.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- train_gpt.py: LZMA2+base85 self-extracting wrapper (saves 49KB artifact)
- Added train_seed1337.log, train_seed42.log, train_seed2024.log
- Updated code_bytes in submission.json

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

clarkkev commented Apr 1, 2026

I think this version of SLOT may still leak information. Restricting the update to context tokens fixes the issue for a single window. However, in the current setup, minibatches contain overlapping windows. In that case, the train update from a later-positioned window in the minibatch can leak information to the earlier windows.


AnubhavBharadwaaj commented Apr 1, 2026

@clarkkev — good catch. The cross-window gradient leak through a shared delta is a valid concern. Here's the precise fix and analysis.

The problem, stated precisely

If delta has shape [1, 1, 512] and a batch contains overlapping windows w1 and w2 where w2's context includes w1's scored positions, then during delta optimization (with the loss already masked to context positions):

$$\nabla_\delta \mathcal{L} = \nabla_\delta \sum_{t \in w1_{\text{context}}} \mathcal{L}_t + \nabla_\delta \sum_{t \in w2_{\text{context}}} \mathcal{L}_t$$

The $w2_{\text{context}}$ term may include tokens that are also in $w1_{\text{scored}}$, leaking future information into w1's delta. Valid concern.

The fix: per-window delta with masked loss

# OLD (shared delta — has cross-window leak):
delta = torch.zeros(1, 1, d_model, device=device, requires_grad=True)

# NEW (per-window delta — no cross-window leak):
delta = torch.zeros(bsz, 1, d_model, device=device, requires_grad=True)

With shape [bsz, 1, 512], each window's delta is an independent slice along the batch dimension. The gradient of delta[i] depends only on window i's loss terms — PyTorch's autograd naturally separates the batch dimensions. No cross-window gradient flow.
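A toy check of this separation (illustrative quadratic loss standing in for the actual CE objective):

```python
import torch

torch.manual_seed(0)
bsz, seq, d = 2, 6, 4            # toy sizes; the PR uses d_model=512
h = torch.randn(bsz, seq, d)     # stand-in for frozen hidden states

# Shared delta [1, 1, d]: broadcast over the batch, so its gradient
# sums contributions from EVERY window in the minibatch (the leak path).
shared = torch.zeros(1, 1, d, requires_grad=True)
((h + shared) ** 2).sum().backward()

# Per-window delta [bsz, 1, d]: autograd keeps batch slices independent.
per = torch.zeros(bsz, 1, d, requires_grad=True)
((h + per) ** 2).sum().backward()

# per.grad[i] depends only on window i; the shared gradient is their sum.
print(torch.allclose(shared.grad[0, 0], per.grad[:, 0].sum(0)))  # True
print(torch.allclose(per.grad[0, 0], 2 * h[0].sum(0)))           # True
```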

AdamW's running moments are also per-element, so each window's delta gets its own momentum and variance tracking.

The loss mask remains per-window: for window i, compute CE loss only on context positions (0 to seq_len - stride - 1), zero out the stride positions. Each delta[i] is optimized exclusively from its own window's already-scored context.
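A sketch of that masked, per-window optimization step at toy scale (hypothetical stand-ins for the frozen head and hidden states; real shapes are 2048/64/512):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
bsz, seq, stride, d, vocab = 4, 32, 8, 16, 50
ctx = seq - stride

head = torch.nn.Linear(d, vocab)                # stand-in for the frozen model head
hidden = torch.randn(bsz, seq, d)               # frozen hidden states
targets = torch.randint(vocab, (bsz, seq))

delta = torch.zeros(bsz, 1, d, requires_grad=True)  # one independent slice per window
opt = torch.optim.AdamW([delta], lr=0.005)          # SLOT_LR=0.005

mask = torch.zeros(bsz, seq)
mask[:, :ctx] = 1.0                             # context positions only

for _ in range(8):                              # SLOT_STEPS=8
    logits = head(hidden + delta)               # delta broadcasts along seq
    per_tok = F.cross_entropy(
        logits.reshape(-1, vocab), targets.reshape(-1), reduction="none"
    ).reshape(bsz, seq)
    loss = (per_tok * mask).sum() / mask.sum()  # scored positions get zero weight
    opt.zero_grad()
    loss.backward()
    opt.step()
```

Because `delta[i]` only ever touches window `i`'s loss terms, zeroing the mask on the stride positions is all that is needed to keep each slice context-only.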

Edge case: the first window (ws=0)

When ws=0, the code sets s=0 meaning all 2048 tokens are scored and there are zero context tokens. Context-only SLOT has nothing to optimize on, so delta stays at zero for this window. This is the correct conservative behavior — no calibration when there's no past data.

This affects 2048 out of ~62M total tokens (0.003%) — negligible impact on final BPB.
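A guard in this spirit (hypothetical helper; the actual code's variable names may differ):

```python
def slot_context_len(ws: int, seq_len: int = 2048, stride: int = 64) -> int:
    """Number of context tokens available for SLOT in the window starting at ws.

    For the first window (ws == 0) every position is being scored, so there
    is no context and SLOT is skipped (delta stays at zero).
    """
    return 0 if ws == 0 else seq_len - stride
```

SLOT then runs only when `slot_context_len(ws) > 0`.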

On the cascade concern

One might ask: "if window w_prev used SLOT to score its tokens, doesn't that contaminate the context for window w_next?"

No — the validation tokens themselves are fixed integers from the dataset. SLOT affects the score (NLL) assigned to each token, not the token values. Window w_next's context contains the same token IDs regardless of what score w_prev assigned. The hidden states H = forward_hidden(x) are deterministic functions of the fixed token sequence under frozen model weights. No SLOT output propagates forward.

Performance impact — honest assessment

I won't claim "negligible" without data. A shared delta aggregates gradient from $1984 \times \text{bsz}$ context tokens across the batch. A per-window delta sees only 1984 tokens. More gradient signal could mean better calibration. Alternatively, a shared delta averages over diverse local distributions, while a per-window delta specializes to each window's local context.

Whether shared or per-window is better is an empirical question — but both are strictly causal.

If the concern is that per-window delta might perform worse, note that the shared delta's advantage (more tokens) comes partly from the cross-window leak that @clarkkev identified. The "clean" shared delta — one that somehow excludes scored tokens from all windows in the batch — would see approximately the same effective token count as per-window, just different tokens.

Summary

Per-window delta ([bsz, 1, 512]) with context-only loss masking is strictly causal, handles the first-window edge case correctly, involves a one-line shape change, and has no cascade effects. The only open question is whether per-window calibration matches batch-wide calibration empirically. I'd welcome a comparison run.


@0hq @valerio-oai — context-only SLOT with per-window delta has zero information flow from scored tokens to the optimization. Is this the variant the organizers would accept?


bigbag commented Apr 1, 2026

Thanks @clarkkev and @AnubhavBharadwaaj for the detailed analysis. The cross-window gradient leak through a shared delta is a valid concern.

Fix implemented and tested

Changed delta shape from [1, 1, 512] to [bsz, 1, 512] (per-window delta). Each window's delta is independent — PyTorch autograd naturally separates batch dimensions. Zero cross-window gradient flow.

Result

Per-window delta is strictly causal but costs ~0.010 BPB:

| Variant                               | Sliding+SLOT BPB | Delta  |
|---------------------------------------|------------------|--------|
| Shared delta [1,1,512] (original)     | 1.1017           |        |
| Per-window delta [bsz,1,512] (fixed)  | 1.1120           | +0.010 |
| No SLOT at all                        | 1.1104           |        |
Per-window SLOT provides almost no benefit over pure sliding (1.1120 vs 1.1104 = only -0.002). The shared delta's advantage came from aggregating gradient across 1984×32 = 63,488 context tokens, vs only 1984 per window.

@bigbag bigbag changed the title Record: MuonEq-R + Context-Only SLOT + QK_GAIN=5.0 — val_bpb 1.1027 (3-seed mean) Non Record: MuonEq-R + Context-Only SLOT + QK_GAIN=5.0 — val_bpb 1.1027 (3-seed mean) Apr 1, 2026
