
Full-Depth MLP Megakernel + Fused Attention Preprocessing (non-record)#1316

Open
AR6420 wants to merge 2 commits into openai:main from AR6420:megakernel-fusion

Conversation


@AR6420 AR6420 commented Apr 3, 2026

Full-Depth MLP Megakernel + Fused Attention Preprocessing

val_bpb: 1.1310 (1-seed, SEED=1337) | 15.6 MB | 8xH100 SXM, 600s | Track: non_record_16mb

The Idea: What if a Video Rendering Engine Architecture Could Train Transformers Faster?

This submission started with a question from a different domain entirely.

While designing a tile-based GPU rendering engine for real-time video rendering -- where 4K frames are split into tiles that fit in L2 cache and multiple operations (color correction, blur, sharpen) are fused within each tile to avoid VRAM bandwidth bottlenecks -- I realized the same memory hierarchy problem exists in transformer training: intermediate activations are written to HBM between every operation, even when the next operation immediately reads them back.

The rendering engine's solution: keep data in fast on-chip L2 cache, apply all operations there, write once. The transformer equivalent: keep the 1536-dim MLP intermediate in GPU registers, process it via tiled accumulation through the gate projection -> activation -> down projection chain, and never let it touch HBM.
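A minimal NumPy sketch of that tiled-accumulation idea (shapes, the 0.01 leaky slope, and the exact squared-leaky-ReLU form are illustrative assumptions, not the kernel itself; in the real Triton kernel each `BLOCK_K` chunk lives in registers):

```python
import numpy as np

def fused_mlp_tiled(x, w_gate, w_down, block_k=64):
    """Produce and consume the 1536-dim intermediate in BLOCK_K-sized
    chunks, so the full intermediate is never materialized."""
    d_inter = w_gate.shape[1]  # e.g. 1536
    out = np.zeros_like(x)
    for k0 in range(0, d_inter, block_k):
        h_blk = x @ w_gate[:, k0:k0 + block_k]        # partial gate projection
        h_blk = np.where(h_blk > 0, h_blk * h_blk,    # assumed LeakyReLU^2 form
                         0.01 * h_blk)
        out += h_blk @ w_down[k0:k0 + block_k, :]     # accumulate down projection
    return x + out                                    # residual add

def mlp_reference(x, w_gate, w_down):
    """Unfused reference: materializes the full intermediate."""
    h = x @ w_gate
    h = np.where(h > 0, h * h, 0.01 * h)
    return x + h @ w_down
```

Because the activation is element-wise, applying it per column block is exact, so the tiled and reference paths agree up to summation order.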

This cross-domain transfer produced two novel contributions, an honest failure, and a key insight about GPU computing.

H100 Results (8xH100 SXM, 600s)

| Seed | Steps | ms/step | Pre-quant BPB | Sliding BPB | Artifact (bytes) |
|------|-------|---------|---------------|-------------|------------------|
| 1337 | 4,917 | 122.0   | 1.1500        | 1.1310      | 15,597,863       |

Seeds 42 and 2025 blocked by compute budget exhaustion. Awaiting grant approval for additional validation runs.

Comparison vs SOTA (PR #1019, 1.1147 BPB): +0.0163 BPB. The gap comes from the 41% slower step time (122 ms vs 86.7 ms), which cost 2,005 training steps within the 600 s budget.

What Worked

  • Full-depth MLP megakernel: 5 operations (RMSNorm -> gate projection -> LeakyReLU^2 -> down projection -> residual) fused into 1 Triton kernel. The 1536-dim intermediate is never written to HBM -- processed via tiled register accumulation in BLOCK_K=64 chunks. Deeper fusion than PR Record: Fused LeakyReLU² + Online GPTQ + Parallel Muon — val_bpb 1.117 (1-seed) #1072 (which fuses adjacent element-wise ops but still materializes the intermediate between groups).
  • Attention preprocessing fusion: QK RMSNorm + partial RoPE + q_gain fused into 2 Triton kernels, down from 6+. Nobody in the competition fuses these post-projection operations.
  • 41% memory reduction (1562 MiB vs 2656 MiB local; 15.7 GiB / 80 GiB on H100 = 19.6% utilization) -- hardware-independent, reproducible.
  • Near-perfect numerical accuracy: MLP cos_sim=0.99998, attention Q/K cos_sim=0.99999.
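The attention-preprocessing fusion above can be described by a NumPy reference for its semantics (function names, the `rot_dim` split, and placing `q_gain` after RoPE on Q only are assumptions for illustration, not the exact kernel):

```python
import numpy as np

def qk_preprocess(q, k, rope_cos, rope_sin, q_gain, rot_dim, eps=1e-6):
    """Assumed semantics of the fused path: QK RMSNorm -> partial RoPE -> q_gain."""
    def rmsnorm(t):
        return t / np.sqrt(np.mean(t * t, axis=-1, keepdims=True) + eps)

    def partial_rope(t):
        # Rotate only the first rot_dim dims (partial RoPE); pass the rest through.
        rot, keep = t[..., :rot_dim], t[..., rot_dim:]
        x1, x2 = rot[..., ::2], rot[..., 1::2]
        out = np.empty_like(rot)
        out[..., ::2] = x1 * rope_cos - x2 * rope_sin
        out[..., 1::2] = x1 * rope_sin + x2 * rope_cos
        return np.concatenate([out, keep], axis=-1)

    q = partial_rope(rmsnorm(q)) * q_gain  # per-query gain, Q only
    k = partial_rope(rmsnorm(k))
    return q, k
```

Since RoPE is a rotation, it preserves the unit RMS produced by the norm, which makes the fused path easy to sanity-check numerically.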

What Didn't Work

  • Step time: 41% slower on H100 (122ms vs 86.7ms). The megakernel's 24 small tl.dot calls cannot compete with cuBLAS's single large GEMM, which has decades of per-architecture tensor core optimization. This is worse than the 15% slowdown on consumer GPUs -- H100's stronger cuBLAS widens the gap.
  • The Tile Engine hypothesis was wrong: We expected H100's larger SRAM (228KB vs 101KB) would make the megakernel competitive. It didn't -- more SRAM doesn't overcome the structural disadvantage of replacing cuBLAS.
  • Fully fused attention preprocessing: Attempted fusing RMSNorm -> QKV projection -> QK norm -> RoPE -> gain into one kernel. Triton's block tensor model can't do the half-dimension register slicing that RoPE requires. Achievable in raw CUDA, not in Triton today.

The Key Insight

The Tile Engine metaphor works perfectly for element-wise operations but not for matmul-dominated workloads. In video processing (all per-pixel ops), tiling into SRAM is optimal -- there are no matrix multiplications to compete with cuBLAS. In transformers (90% matmul by compute), the matmuls should be delegated to hardware-optimized libraries while tiling handles only the element-wise glue between them. The right strategy isn't to replace cuBLAS -- it's to partner with it.
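That "partner, don't replace" split can be sketched in a few lines (illustrative names; the real code would route the GEMMs through cuBLAS via PyTorch and fuse only the glue in Triton):

```python
import numpy as np

def leaky_relu_sq(h, slope=0.01):
    # The element-wise, bandwidth-bound stage: the only part worth a custom fused kernel.
    return np.where(h > 0, h * h, slope * h)

def mlp_partnered(x, w_gate, w_down):
    h = x @ w_gate          # delegate the large GEMM to BLAS / cuBLAS
    h = leaky_relu_sq(h)    # fuse only this (plus norm/residual) in Triton
    return x + h @ w_down   # delegate the second GEMM to BLAS / cuBLAS
```

The matmuls keep their hardware-tuned tensor-core paths; tiling is reserved for the element-wise glue between them.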

Local Benchmarks (RTX 5070 Ti, 1 GPU, 500 steps)

| Metric            | Baseline (PR #1019) | This Submission     | Delta         |
|-------------------|---------------------|---------------------|---------------|
| step_avg (steady) | 392.7 ms            | 451.9 ms            | +15.1% slower |
| val_bpb@500       | 1.9266              | 2.0269              | +0.1003       |
| peak_memory       | 2656 MiB            | 1562 MiB            | -41%          |
| reproducibility   | --                  | SubV1-SubV2: 0.0001 | Deterministic |

Architecture

Based on PR #1019 (@abaybektursun). Same 11L/512d/8H/4KV architecture with all existing components (XSA-all, BigramHash 3072x112, EMA+SWA, Self-Gen GPTQ, Parallel Muon). Added: Triton MLP megakernel + fused attention preprocessing.

Credits

PR #1019 @abaybektursun (base), PR #1072 (Triton fusion baseline), PR #1105 (backward epilogue), PR #399 @abaybektursun (Parallel Muon), PR #493 @parinzee (LeakyReLU^2), PR #478 @gowtham0992 (XSA), PR #315 @jfprincz (Partial RoPE), PR #374 @unnir (VE), PR #162 @raahilshah (BigramHash), PR #535 @raahilshah (GPTQ)

See README.md for full technical details, learnings, and future directions.

AR6420 and others added 2 commits April 3, 2026 15:37
Tile engine-inspired block-level Triton fusion for the 10min/16MB track:

- Full-depth MLP megakernel: 5 ops (RMSNorm → UpProj → LeakyReLU² → DownProj → Residual)
  fused into 1 Triton kernel. The 1536-dim intermediate is processed via tiled register
  accumulation and never materializes in HBM. Deeper than PR openai#1072.

- Fused attention preprocessing: QK RMSNorm + partial RoPE + q_gain in 2 Triton kernels
  (down from 6+). Novel — nobody in competition fuses post-projection ops.

- 41% memory reduction (1562 MiB vs 2656 MiB). Numerically exact (cos_sim>0.99998).

- Based on PR openai#1019 (abaybektursun). H100 results PENDING.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
H100 results (8xH100 SXM, 600s):
- Sliding BPB: 1.1310 (seed 1337)
- Steps: 4,917 at 122.0 ms/step
- Artifact: 15.6 MB

Key finding: megakernel is 41% slower on H100 (122ms vs SOTA 86.7ms).
cuBLAS tensor cores outperform tiled tl.dot even with larger SRAM.
41% memory reduction confirmed (15.7 GiB / 80 GiB).

Single-seed submission — seeds 42/2025 blocked by compute budget.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
