Non-record: Depth Recurrence 5x3 — Weight-Shared Looping Transformer (6xH200, val_bpb=1.2716)#319
Open

Arth-Singh wants to merge 1 commit into openai:main

Conversation
Explores depth recurrence as a parameter-efficient approach to increase effective model depth. 5 unique layers looped 3 times = 15 effective depth at 15M params. Includes learnable loop embeddings and loop gates for per-pass conditioning. Run terminated early due to compute budget.
Summary
val_bpb: 1.2716 on 6xH200 (4500 steps, run terminated early — still improving). Non-record — exploring depth recurrence as an orthogonal axis to existing techniques.
5 unique transformer layers looped 3 times = 15 effective depth at only 15M params, compared to the baseline's 9 unique layers at 17M params. The idea: store weights once, use them N times, and spend the freed parameter budget on model width (dim=640 vs the baseline's dim=512).
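The looped forward pass described above can be sketched as follows. This is a minimal toy illustration, not the PR's actual code: the class name `LoopedBlockStack`, the use of `nn.TransformerEncoderLayer` as the shared block, and the gate/embedding shapes are all assumptions made for clarity.

```python
import torch
import torch.nn as nn

class LoopedBlockStack(nn.Module):
    """Toy sketch of depth recurrence: `num_unique` shared layers applied
    `num_loops` times. A learnable loop embedding is added per pass, and a
    scalar loop gate mixes each pass's output back toward the initial state.
    Names and block choice are illustrative, not the PR's implementation."""

    def __init__(self, dim=640, num_unique=5, num_loops=3, init_gate=1.0):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(dim, nhead=4, dim_feedforward=4 * dim,
                                       batch_first=True)
            for _ in range(num_unique)
        )
        # One embedding vector per pass: tells the shared layers which pass this is.
        self.loop_emb = nn.Parameter(torch.zeros(num_loops, dim))
        # Per-pass gates; init at 1.0 keeps each pass's full output
        # (see "Negative Results" on why 1/num_loops is harmful).
        self.loop_gates = nn.Parameter(torch.full((num_loops,), init_gate))
        self.num_loops = num_loops

    def forward(self, x):
        x0 = x
        for k in range(self.num_loops):
            h = x + self.loop_emb[k]        # per-pass conditioning
            for layer in self.layers:       # same 5 weight tensors every pass
                h = layer(h)
            g = self.loop_gates[k]
            x = g * h + (1 - g) * x0        # gate mixes back toward the input
        return x
```

With 5 unique layers and 3 loops, the input traverses 15 layer applications while only 5 layers' worth of weights are stored.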
Technique Stack
Research Pipeline
Stage 1: Architecture Design (Local, Apple M4)
Started from the observation that current SOTA stacks tricks on a fixed 9-10 layer architecture. Nobody had tried depth recurrence — using fewer unique layers looped multiple times. The neural scaling law framing: optimize L(N) by getting more effective compute per stored parameter.
Implemented a RecurrentGPT class with loop embeddings + loop gates in the MLX training script. Verified the forward/backward pass, parameter counts, and that 5×3 at dim=640 fits in ~15M params (within the 16MB artifact budget).

Stage 2: Single H200 Validation (Nebius)
Stage 3: 6xH200 Distributed Run (RunPod)
Scaled to 6 GPUs for faster iteration. PyTorch 2.8 + NCCL, grad_accum=1, ~197ms/step.
Run terminated at ~step 5300 when the RunPod compute budget expired. val_bpb was still improving (~0.006 per 500 steps at the end).
Negative Results (saving others time)
Loop gating kills the recurrence benefit
This is the main finding. Loop gates initialized at `1/num_loops = 0.33` create the update rule `x = 0.33 * loop_output + 0.67 * x0`. This means that after loop 0, the model's representation is pulled ~67% back toward the initial embedding. The model effectively gets ~1.3 loops of useful computation instead of 3.

The fix (not tested due to the compute budget): initialize the gates at 1.0 (identity), or remove gating entirely and rely on the residual stream's natural stability.
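The attenuation can be made concrete with a toy linearization. Treat each pass as adding one unit of update `d` before the gate is applied, i.e. `x_{k+1} = g*(x_k + d) + (1-g)*x0`, and track the coefficient on `d` after all loops. This is purely illustrative arithmetic, not the training code, and the exact numbers depend on how the real (nonlinear) layers behave.

```python
def accumulated_update(gate, num_loops):
    """Coefficient on the per-pass update d after num_loops gated passes,
    under the toy linear model x_{k+1} = g*(x_k + d) + (1-g)*x0.
    Illustrative only: real transformer layers are not additive like this."""
    c = 0.0
    for _ in range(num_loops):
        c = gate * (c + 1.0)  # one unit of new computation, then gated toward x0
    return c

weak = accumulated_update(1 / 3, 3)   # gate at 1/num_loops: ~0.48 of one pass
full = accumulated_update(1.0, 3)     # gate at 1.0: all 3 passes retained
```

Under this toy model, `gate = 0.33` retains under half a pass's worth of accumulated update while `gate = 1.0` retains all three, which is consistent with the qualitative finding that low-initialized gates throw away most of the recurrent computation.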
Removing skip connections hurts more than recurrence helps
The baseline's encoder-decoder (U-Net style) skip connections were removed to keep the recurrence loop clean. This was a mistake — the skip connections provide critical information shortcuts that the recurrence doesn't compensate for. At 15 effective depth, the model needs these shortcuts even more than at 9 layers.
Width vs depth tradeoff at this scale
At 15M params with 16MB artifact constraint, going wider (dim=640) while reducing unique layer count doesn't win. The baseline's 9 unique layers at dim=512 provides more representational diversity than 5 unique layers at dim=640 looped 3x, even though the latter has more effective depth.
What We'd Do With More Compute
Phase 1: Fix the gating — initialize loop_gates at 1.0 or remove entirely. This is the single highest-leverage change. Rerun on 8xH100 within 10-min budget.
Phase 2: Add skip connections within each loop pass. Each loop iteration gets its own encoder-decoder skip pattern using the same skip weights.
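The per-pass skip pattern in Phase 2 could be structured as below. This is a hypothetical sketch: the split of the 5 unique layers into encoder/middle/decoder halves, the function name, and the weighted-add skip form are all assumptions, and `layers` is any list of callables sharing weights across passes.

```python
def loop_pass_with_skips(layers, skip_weights, x):
    """One loop pass with U-Net style skips inside the pass (Phase 2 sketch).
    The first half of the shared layers acts as an 'encoder' whose activations
    are saved, and the last half as a 'decoder' that adds them back, weighted.
    The same skip_weights are reused on every pass. Illustrative only."""
    n = len(layers)
    enc = layers[: n // 2]
    mid = layers[n // 2 : n - n // 2]
    dec = layers[n - n // 2 :]

    saved = []
    for layer in enc:                 # encoder half: save activations
        x = layer(x)
        saved.append(x)
    for layer in mid:                 # middle layer(s): no skip
        x = layer(x)
    for layer, w in zip(dec, reversed(skip_weights)):
        x = layer(x + w * saved.pop())  # decoder half: add weighted skip
    return x
```

With 5 layers this gives a 2-1-2 encoder/middle/decoder split per pass, so each of the 3 passes gets its own shortcut structure while still storing only one set of skip weights.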
Phase 3: Try 2 loops × 7 unique layers (dim=576) instead of 3×5. Less recurrence, more layer diversity. The hypothesis: 2 passes through 7 unique layers beats 3 passes through 5.
Phase 4: Combine with known winning techniques — sliding window eval (~0.03 BPB), FP16 embeddings, int6 mixed quantization, longer context training (seq=2048+).
Phase 5: Progressive loop training — start with 1 loop for first 50% of steps, add loop 2 at 50%, loop 3 at 75%. Lets the model learn basic features before asking it to refine through recurrence.
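The Phase 5 schedule above reduces to a small step-to-loop-count mapping. A minimal sketch, assuming the 50%/75% thresholds stated in the plan; the function name is illustrative:

```python
def active_loops(step, total_steps, max_loops=3):
    """Progressive loop schedule (Phase 5 sketch): 1 loop for the first 50%
    of training, 2 loops from 50%, all loops from 75%. Thresholds follow the
    plan above; in training this would cap the recurrence loop's range."""
    frac = step / total_steps
    if frac < 0.50:
        return 1
    if frac < 0.75:
        return min(2, max_loops)
    return max_loops
```

The trainer would then run the recurrence for `active_loops(step, total_steps)` passes instead of a fixed 3, letting early training behave like a plain 5-layer model.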
Hardware & Reproduction
Test Plan