Record: Fused MLP (Triton+CUTLASS EVT) + Fast Causal N-Gram Tilt & Subword Certainty (3-seed mean)#1105
Conversation
ba665dd to
64ce201
Compare
9b27cf4 to
c27131c
Compare
c27131c to
0df40cc
Compare
|
@abaybektursun - this is a fantastic write-up! Congrats on the SLOT improvement. If you need to free up even more room, you should check out the shrink.py script I used in PR 1089. I was able to shrink the train_gpt.py file by ~100KB. That might let you reduce pruning and/or promote one more group to int6. |
|
Ohhh I think with newer Pytroch performance and speed will be even better! I will try it when I can get my hands around 8xH100s |
| @@ -0,0 +1,110 @@ | |||
| # Record: Fused MLP (Triton+CUTLASS EVT) + MLP 3.5× + Mixed int5/int6 + Brotli | |||
…seed mean) Based on PR openai#1105 (abaybektursun) with improvements: - Window attention (size=512) on layers 2,4,6,8,10 via FA3 - Mixed seq_len training: 5 GPUs at 2048x36 + 3 GPUs at 6144x10 - Train-data GPTQ calibration (14s vs 220s AR self-gen) - Auto eval_seq_len detection from max train seq_len - Causal n-gram fix (within_hint/word_hint prefix-only) - Sliding window eval at seq_len=6144, stride=128 3-seed results (sliding window bpb): seed 1337: 1.1077 seed 42: 1.1083 seed 7: 1.1091 mean: 1.1084 (vs leader 1.1147) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Based on PR openai#1105 (abaybektursun) with improvements: - Window attention (size=512) on layers 2,4,6,8,10 via FA3 - Mixed seq_len training: 5 GPUs at 2048x36 + 3 GPUs at 6144x10 - Train-data GPTQ calibration (14s vs 220s AR self-gen) - Auto eval_seq_len detection from max train seq_len - Causal n-gram fix (within_hint/word_hint prefix-only) - Sliding window eval at seq_len=6144, stride=128 3-seed results (sliding window bpb): seed 1337: 1.1077 seed 42: 1.1083 seed 7: 1.1091 mean: 1.1084 (vs leader 1.1147) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
3b4bcf1 to
30c83c7
Compare
|
@abaybektursun - This is a great submission! If you haven't experimented with it yet, I'd recommend trying different negative slopes for the squared leaky ReLU activation (leaky_relu(x, slope).square()). I found ~0.3 was optimal on average with a similar architecture, but there's a depth-dependent pattern worth exploiting. I added a learned per-channel alpha parameter (initialized at 0.3) to each MLP layer. After three iterative runs — each warm-started from the prior run's endpoints — the values converged: layer 0 settled near 0 (essentially ReLU²), middle layers stayed around 0.3, and the deepest layers preferred 0.4–0.47. One gotcha: since the activation squares the output, alpha and -alpha produce identical results — only the magnitude matters. Once converged, I hardcoded the per-layer slopes so I wasn't wasting parameters on values that had already stabilized. Also, I wanted to share these results/learnings in case you are thinking about tokenizer experiments: |
|
@abaybektursun - since you have almost 500KB of headroom on the artifact you might be able to increase MLP to 3.625x which should help given your finding that the model was parameter starved in MLP. You can also get another ~100KB of headroom using this script to shrink your train_gpt.py before submission: |
30c83c7 to
17d5028
Compare
|
I am sleep deprived so things are bit messy rn, but data is accurate will clean up in the morning. |
…b 1.0962 (3-seed mean) SP4608 tokenizer (vocab 4608), MLP 3.5×, all-int6 GPTQ (66 layers), QK_GAIN=5.0, fast causal n-gram tilt (~295× speedup), Brotli-11, LR floor 0.05. 3-seed mean (42/314/1337): 1.0962 BPB submission, 1.0987 post-quant sliding. Eval: ~87s wall on 8×H100. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
dc03ff5 to
f734fda
Compare

Results: val_bpb 1.0962 (3-seed mean) | 8×H100 SXM | 600s
sp4608 tokenizer, N_INT6=66 (all int6), QK_GAIN=5.0
Post-quant BPB and Submission BPB come from two separate evaluation runs of the same model, not two passes. The sliding-window eval measures the quantized neural model alone; the n-gram eval adds the causal tilt on top. Both are single-pass, single-run evaluations — we report both for ablation (isolating the n-gram contribution).
N-gram eval: ~87s wall (~50s loop, 3-seed mean). Tilted: 23.8% (mean) | Hits: 53.2% (mean).
What does the improvement look like? Side-by-side generation (temp=0.8)
Prompt (50 tokens): "Insurance Company Declares Living Man Dead George Johannesen is very much alive. Which is why it was so surpr"
The old model drifts into incoherence ("Rachel Drobles... techniques of the car industry... Lyon Man is dead"). The new model stays on topic — insurance, health measurement, living man — and maintains grammatical coherence throughout. Both are wrong (the real text is about a cancelled driver's license), but the new model's errors are at least topically plausible.
Changes vs our PR 1019
1. Fused MLP Kernels: Triton TMA Forward + CUTLASS EVT Backward
Forward (Triton TMA): Fuses
F.linear(x, up_w) → LeakyReLU(0.5) → squareinto a single kernel. The 302MB intermediate never touches HBM.Backward (CUTLASS EVT): Fuses
(go @ down_w.T) * act_gradinto a single CUTLASS 3.x kernel via Epilogue Visitor Tree. The elementwise multiply runs in the GEMM epilogue while tiles are still in registers — eliminating one 302MB write + read per layer.Key design insight — pre-computed activation gradient: We store the activation gradient in the forward pass instead of the pre-activation:
The identity
post = 0.5 · act_grad · preholds for both signs because:This eliminates all branching from the backward, reducing the CUTLASS EVT epilogue to a trivial 3-node tree:
Sm90EVT<multiplies, AccFetch, AuxLoad>. No conditionals in the kernel.CUTLASS EVT is a hard dependency — no silent fallback. See Appendix A.3 and A.4 for detailed benchmarks.
2. Fast Causal N-Gram Tilt & Subword Certainty (~0.0025 BPB, ~295× speedup)
Architecture Shift: Sparse Auxiliary Memory
This PR replaces the old eval-time n-gram mixing path with a fast, legal, single-pass causal n-gram tilt system. The core change is that the n-gram is no longer treated as a second language model. Instead, it acts as a sparse auxiliary memory that proposes a hinted token from the strict prefix, while the neural model remains the full normalized distribution. We then apply a one-token exponential tilt directly on the GPU.
Motivation & Interpretability
This work was guided by the interpretability results in our PR 1019 model autopsy and PR 1105 model autopsy. Those analyses showed that the model is not broadly weak at language modeling; it is specifically weak at exact copy/repetition. In particular, it has very limited induction capability, while much of the remaining loss is in categories like numbers, punctuation, and whitespace where generic short-order n-grams do not help much.
That changed the design target. Instead of building “better PPM everywhere,” we focused on the narrow places where n-grams are actually complementary:
The Key Insight: Mechanical Subword Certainty
Initially, within-word BPE completions seemed redundant since the neural baseline already assigns high probability to these tokens. However, the most significant BPB drop (~0.002, characterized on the MLP 3.0× model) was unlocked by aggressively lowering the
within_thresholdfrom 0.80 to 0.25, allowing the expert to fire on 35.7% of positions.Why it works: While the neural model knows subword patterns, it inherently hedges its bets by distributing probability mass across alternatives. The n-gram expert acts as a mechanical override, capturing the absolute certainty of BPE completions that the neural model refuses to assign a 1.0 probability to.
Measured eval time (8×H100): ~88s (setup 37s + loop 50s). Our C++ open-addressing hash achieves ~295× speedup over naive Python PPM implementations — the enabling constraint that makes causal n-gram tilt feasible within the 600s eval budget. See Appendix A.5 and A.6 for full engineering details and benchmarks.
3. Brotli-11 Compression (replaces LZMA-9)
−581 KB (−5.9%) vs LZMA-9. Independently discovered; PR 1089 (mikeapedia) also uses Brotli.
4. Memmap Multi-Shard Data Pipeline + GPU Prefetch
Coprime-stride sampling, daemon thread, CUDA stream prefetch. Credit: DeepReinforce (PR 726).
5. MLP 3.5× (1536 → 1792 hidden dim)
Motivated by mechanistic analysis: SVD analysis of our PR 1019 model showed MLP at 94.4% rank utilization (fully packed) while attention Q sat at 72.6% (spare capacity). The model was parameter-starved in MLP, not attention — so we made MLP wider.
Increases hidden dim from 3.0 × 512 = 1536 to 3.5 × 512 = 1792 (+2.88M params). With sp4608 (which removed BigramHash and SmearGate), the full model is 31.85M params and fits under 16 MB with uniform int6.
Impact: −0.003 BPB from capacity, +13ms/step on 2×H100 (bigger GEMMs). Credit: PR 185 (dttdrv), PR 344 (aryanbhosale).
6. LR Floor (0.05)
During warmdown, learning rate normally decays to 0. With
lr_floor=0.05, it stops at 5% of peak instead. Prevents the optimizer from stalling, which helps with quantization-sensitive weight distributions still being refined at end of training.Impact: ~0.001 BPB. Credit: PR 130 (mohosy).
7. Vocab 4608
Inspired by PR 1218 (Clark), which established 4096 as effective. We measured β(V) — bytes per token — across intermediate vocab sizes. Controlled comparison (same architecture, seed 42):
+2.2% bytes per token, +0.055 nats per-token loss. The larger vocab also freed ~1 MB by making BigramHash and SmearGate redundant (removed), enabling uniform int6 for all 66 layers.
Negative Results
Note: This model still has known inefficiencies — the sp4608 architecture has not been fully tuned (hyperparameters, layer count, MLP ratio, and quantization bit allocation were carried over from the sp1024 stack). We believe further BPB reductions are achievable.
Appendix
A.1 Prior Results
Prior results: sp1024, val_bpb 1.1052 (3-seed mean)
Mixed quantization: 10 layers int6, 56 layers int5, no pruning needed.
Calibration regression (sp1024 model)
ECE increased from 0.24% (PR 1019 model) to 1.26% (sp1024 model) — the mixed int5/int6 quantization introduces slight overconfidence. Model entropy dropped from 1.899 to 1.847 nats (more confident) while accuracy dropped from 54.99% to 54.46%. Not yet re-measured on the sp4608 all-int6 model.
Prior results (val_bpb 1.1125, 3-seed)
SLOT study (removed from submission — causality violation)
SLOT (Selective Logit Offset Tuning) optimizes a 512-dim delta vector at the last hidden layer using AdamW (lr=0.003, 5 steps) per sliding-window batch. It gave −0.0037 BPB (1.1125 → 1.1088), but violates causality: the delta has shape
[1,1,512]and is optimized using targets at all positions, then applied to all positions — so position t's prediction is influenced by future tokens through the shared delta. Removed from submission code; results below are for reference only.Credit: PR 609 (saml212).
Prior results: fused kernels + Brotli only (val_bpb 1.1138, 3-seed)
Delta vs PR 549: −0.00943 nats. Welch's t = −10.26, df ≈ 3.78, p < 0.01.
A.2 Throughput Recovery (sp1024 stack)
Throughput progression (sp1024, prior to sp4608 migration)
Our PR 1019 traded throughput for quality — full Hessian GPTQ and BigramHash 3072×112 added 3.3ms/step. Fused MLP kernels recover that regression. With sp4608, BigramHash was removed entirely (redundant at larger vocab), and all layers use int6.
A.3 Kernel Benchmarks
Kernel benchmarks + incremental deltas (2×H100)
Per-layer kernel timing:
CUTLASS vs Triton: +0.032 ms/layer, +0.347 ms/step kernel-level.
End-to-end training (35 steps, seed=42):
Kernel-level 0.347ms translates to 0.43ms end-to-end (cache/scheduling interactions).
8×H100: 86.7ms (our PR 1019, unfused) → 83.5ms (this PR) = −3.2ms/step (−3.7%).
A.4 Step-Time Profile
Step-time profile — where all 313ms goes (2×H100, Nsight)
Why surgical fusion, not full-MLP autograd.Function: The 21.6% from torch.compile's cross-layer fusions (RMSNorm backward, residual adds, RoPE backward) only exists because these ops are visible to the compiler. Wrapping the full MLP backward in
autograd.Functionmakes it opaque to Inductor — all backward GEMMs plus cross-layer fusion run in eager mode, 2.7× slower net (identified in our PR 670). We fuse only forward and one backward GEMM+pointwise, preserving the compiler's scope.Top individual kernels:
Wall-clock breakdown: forward+backward compute ~94%, NCCL ~1.6%, CPU overhead ~4.1%.
A.5 N-Gram Engineering Details
Engineering Overhaul
Previous attempts at n-gram blending using flat tables and Python/NumPy logic were bottlenecked by severe hash collisions and massive FFI overhead. Initial runs with a logistic mixer yielded a catastrophic +0.210 BPB degradation because collision noise was inflating token probabilities.
By migrating to an open-addressing scheme (64M entries, 26-bit) to store exact keys, we eliminated false positives, pushing token PPM accuracy to 82.3%. To solve the execution bottleneck, we deployed a highly optimized pipeline:
fused_expert_blend.cpp).A.6 N-Gram Benchmarks
The ~295× speedup vs naive Python was the enabling constraint: a brute-force per-token PPM would take hours on 44M+ tokens; our C++ open-addressing hash with batched nanobind calls runs in ~20s (n-gram lookup only), well within the 600s eval budget.
A.7 Architecture
BigramHashSmearGateCalibration legality: AR self-generated (64 seqs × 2048 tokens, temp=0.8). No val data, no train data accessed during quantization. Same method as our PR 1019.
A.8 Setup & Reproduction