
Non-record: Depth Recurrence 5x3 — Weight-Shared Looping Transformer (6xH200, val_bpb=1.2716)#319

Open
Arth-Singh wants to merge 1 commit into openai:main from Arth-Singh:depth-recurrence-5x3-640

Conversation


@Arth-Singh Arth-Singh commented Mar 21, 2026

Summary

val_bpb: 1.2716 on 6xH200 (4500 steps, run terminated early — still improving). Non-record — exploring depth recurrence as an orthogonal axis to existing techniques.

5 unique transformer layers looped 3 times = 15 effective depth at only 15M params, compared to baseline's 9 unique layers at 17M params. The idea: store weights once, use them N times, spend the freed parameter budget on model width (dim=640 vs baseline dim=512).
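The weight-sharing idea can be sketched in a few lines of PyTorch. Class and argument names below are illustrative, not the PR's actual train_gpt.py code; the property that matters is that the five blocks are stored once and traversed three times, with a zero-initialized per-loop embedding letting the model tell the passes apart.

```python
import torch
import torch.nn as nn

class LoopedBlockStack(nn.Module):
    """Sketch of weight-shared depth recurrence: 5 unique blocks, 3 loops.
    Hypothetical names; the PR's real blocks use GQA + FlashAttention-2."""
    def __init__(self, dim=640, num_unique=5, num_loops=3):
        super().__init__()
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True)
             for _ in range(num_unique)])
        # One learnable vector per loop, initialized at zero (as in the PR)
        self.loop_emb = nn.Parameter(torch.zeros(num_loops, dim))
        self.num_loops = num_loops

    def forward(self, x):
        for loop in range(self.num_loops):
            x = x + self.loop_emb[loop]   # per-pass conditioning
            for block in self.blocks:     # same weights reused every loop
                x = block(x)
        return x
```

Looping is free in parameter terms: only `loop_emb` grows with `num_loops`, so 15 effective layers cost 5 layers of stored weights plus three dim-sized vectors.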

Technique Stack

| Technique | Rationale | Status |
| --- | --- | --- |
| Weight-shared looping (5 unique × 3 loops) | 15 effective depth for 5 layers of stored params | Working, but underperforms baseline by ~0.02 BPB |
| Loop embeddings (learnable per-loop vectors) | Let model differentiate passes through same weights | Working, initialized at zero |
| Loop gates (learnable per-loop scalars) | Control contribution of each loop vs initial repr x0 | Over-regularized — key negative finding |
| Wider model (dim=640 vs 512) | Spend saved params on width | Working |
| Muon + Adam split optimizer | Inherited from baseline | Working |
| FlashAttention-2 + GQA (8h/4kv) | Inherited from baseline | Working |
| torch.compile + bfloat16 | Standard H200 optimization | Working, ~197ms/step on 6xH200 |

Research Pipeline

Stage 1: Architecture Design (Local, Apple M4)

Started from the observation that current SOTA stacks tricks on a fixed 9-10 layer architecture. Nobody had tried depth recurrence — using fewer unique layers looped multiple times. The neural scaling law framing: optimize L(N) by getting more effective compute per stored parameter.

  • Prototype on MLX: Built a RecurrentGPT class with loop embeddings + loop gates in the MLX training script. Verified the forward/backward pass, parameter counts, and that 5×3 at dim=640 fits in ~15M params (within the 16MB artifact budget).
  • Parameter sweep: Tested configs from 4×4 (16 depth, 12.1M params) to 6×3 (18 depth, 14.5M params). Selected 5×3 dim=640 as best depth-width tradeoff.
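A quick way to sanity-check such sweeps is to count stored parameters for proxy blocks and confirm that looping contributes nothing to the count. These are plain TransformerEncoderLayers rather than the PR's GQA blocks, and the dims shown for the 4×4 and 6×3 configs are placeholders (the PR only gives their totals), so absolute numbers will not match the 12.1M-14.5M figures above.

```python
import torch.nn as nn

def stored_params(num_unique, dim, nhead=8, ffn_mult=4):
    """Count *stored* parameters for a candidate config. Only unique
    layers occupy the artifact budget; loop count never appears here."""
    blocks = nn.ModuleList(
        [nn.TransformerEncoderLayer(dim, nhead=nhead,
                                    dim_feedforward=ffn_mult * dim,
                                    batch_first=True)
         for _ in range(num_unique)])
    return sum(p.numel() for p in blocks.parameters())

# Sweep shape only; dims for 4x4 and 6x3 are assumptions, not PR values.
for name, (u, d) in {"4x4": (4, 640), "5x3": (5, 640), "6x3": (6, 576)}.items():
    print(f"{name} dim={d}: {stored_params(u, d) / 1e6:.1f}M stored params")
```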

Stage 2: Single H200 Validation (Nebius)

  • Ran 5×3 recurrence and vanilla 9L baseline on identical hardware/hyperparameters.
  • Recurrence step 500: val_bpb=1.4720, train_loss=2.51
  • Baseline step 9000: val_bpb=1.2507
  • At comparable training FLOPs, recurrence tracked ~0.02-0.03 BPB behind baseline. Loss curves healthy but convergence slower.

Stage 3: 6xH200 Distributed Run (RunPod)

Scaled to 6 GPUs for faster iteration. PyTorch 2.8 + NCCL, grad_accum=1, ~197ms/step.

| Step | val_bpb | val_loss | wall_time |
| --- | --- | --- | --- |
| 500 | 1.4720 | 2.4853 | 97s |
| 1000 | 1.3730 | 2.3182 | 196s |
| 2000 | 1.3138 | 2.2184 | 393s |
| 3000 | 1.2916 | 2.1809 | 591s |
| 4000 | 1.2775 | 2.1570 | 790s |
| 4500 | 1.2716 | 2.1471 | 888s |

Run terminated at ~step 5300 due to RunPod compute budget expiring. val_bpb was still improving (~0.006/500 steps at end).

Negative Results (saving others time)

Loop gating kills the recurrence benefit

This is the main finding. Loop gates initialized at 1/num_loops = 0.33 create the update rule: x = 0.33 * loop_output + 0.67 * x0. This means after loop 0, the model's representation is pulled ~67% back toward the initial embedding. The model effectively gets ~1.3 loops of useful computation instead of 3.

The fix (not tested due to compute budget): Initialize gates at 1.0 (identity), or remove gating entirely and rely on the residual stream's natural stability.
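A toy scalar version of the update rule makes the over-regularization concrete. The `bump` function standing in for "one unit of useful computation per pass" is an illustration, not the model's actual dynamics, but it shows how pulling back toward x0 each loop discards most of the work:

```python
def gated_update(x0, f, num_loops=3, gate=1/3):
    """The update rule described above: each loop's output is blended
    back toward the *initial* representation x0 with weight (1 - gate)."""
    x = x0
    for _ in range(num_loops):
        x = gate * f(x) + (1 - gate) * x0
    return x

bump = lambda x: x + 1.0   # toy f: one "unit of computation" per pass

print(gated_update(0.0, bump))             # 13/27 ~= 0.48: most work discarded
print(gated_update(0.0, bump, gate=1.0))   # 3.0: identity-init keeps every pass
```

With gates at 1/3 the three passes accumulate to only ~0.48 units versus 3.0 with identity-initialized gates, which is the intuition behind initializing at 1.0 instead.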

Removing skip connections hurts more than recurrence helps

The baseline's encoder-decoder (U-Net style) skip connections were removed to keep the recurrence loop clean. This was a mistake — the skip connections provide critical information shortcuts that the recurrence doesn't compensate for. At 15 effective depth, the model needs these shortcuts even more than at 9 layers.

Width vs depth tradeoff at this scale

At 15M params with 16MB artifact constraint, going wider (dim=640) while reducing unique layer count doesn't win. The baseline's 9 unique layers at dim=512 provides more representational diversity than 5 unique layers at dim=640 looped 3x, even though the latter has more effective depth.

What We'd Do With More Compute

Phase 1: Fix the gating — initialize loop_gates at 1.0 or remove entirely. This is the single highest-leverage change. Rerun on 8xH100 within 10-min budget.

Phase 2: Add skip connections within each loop pass. Each loop iteration gets its own encoder-decoder skip pattern using the same skip weights.
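One way Phase 2 could look, as a sketch only: save encoder-half activations inside each pass and mix them back in the decoder half with learnable weights shared across loops. The names, the odd-layer-count split, and the scalar mixing scheme are all assumptions, not the baseline's actual skip implementation.

```python
import torch
import torch.nn as nn

class LoopWithSkips(nn.Module):
    """Sketch: U-Net style skips *inside* each loop pass, reusing the
    same skip-mix weights every loop. Hypothetical layout."""
    def __init__(self, dim=640, num_unique=5):
        super().__init__()
        assert num_unique % 2 == 1   # encoder / middle / decoder split
        self.blocks = nn.ModuleList(
            [nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True)
             for _ in range(num_unique)])
        self.half = num_unique // 2
        # One learnable mixing scalar per decoder block, shared across loops
        self.skip_w = nn.Parameter(torch.ones(self.half))

    def forward_one_loop(self, x):
        saved = []
        for i, block in enumerate(self.blocks):
            if i < self.half:                        # encoder: stash input
                saved.append(x)
            elif i >= len(self.blocks) - self.half:  # decoder: mix it back
                x = x + self.skip_w[len(self.blocks) - 1 - i] * saved.pop()
            x = block(x)
        return x

    def forward(self, x, num_loops=3):
        for _ in range(num_loops):
            x = self.forward_one_loop(x)
        return x
```

The pairing mirrors the usual U-Net pattern: the first encoder block's activation feeds the last decoder block, and so on inward.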

Phase 3: Try 2 loops × 7 unique layers (dim=576) instead of 3×5. Less recurrence, more layer diversity. The hypothesis: 2 passes through 7 unique layers beats 3 passes through 5.

Phase 4: Combine with known winning techniques — sliding window eval (~0.03 BPB), FP16 embeddings, int6 mixed quantization, longer context training (seq=2048+).

Phase 5: Progressive loop training — start with 1 loop for first 50% of steps, add loop 2 at 50%, loop 3 at 75%. Lets the model learn basic features before asking it to refine through recurrence.
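A minimal sketch of that schedule, with the 50%/75% thresholds taken from the text (the function name and call site are hypothetical):

```python
def loops_at_step(step, total_steps, max_loops=3):
    """Phase 5 schedule sketch: 1 loop for the first 50% of training,
    2 loops from 50%, all 3 from 75%."""
    frac = step / total_steps
    if frac < 0.5:
        return 1
    if frac < 0.75:
        return min(2, max_loops)
    return max_loops

# e.g. over a 4500-step run:
print([loops_at_step(s, 4500) for s in (0, 2000, 2250, 3000, 3375, 4000)])
# [1, 1, 2, 2, 3, 3]
```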

Hardware & Reproduction

```bash
# 6xH200 (RunPod), PyTorch 2.8, ~197ms/step
RUN_ID=recurrent_5x3_d640 \
NUM_UNIQUE_LAYERS=5 NUM_LOOPS=3 \
MODEL_DIM=640 NUM_HEADS=8 NUM_KV_HEADS=4 \
TRAIN_BATCH_TOKENS=786432 \
torchrun --standalone --nproc_per_node=6 train_gpt.py

# Single GPU
RUN_ID=recurrent_5x3_d640 \
NUM_UNIQUE_LAYERS=5 NUM_LOOPS=3 \
MODEL_DIM=640 NUM_HEADS=8 NUM_KV_HEADS=4 \
torchrun --standalone --nproc_per_node=1 train_gpt.py
```

Test Plan

  • Local MLX prototype — forward pass, parameter counts, architecture validation
  • Single H200 training — confirmed loss convergence, healthy training dynamics
  • 6xH200 distributed training — 4500 steps, val_bpb=1.2716
  • Baseline comparison on same hardware — recurrence ~0.02 BPB behind at comparable FLOPs
  • All submission files: README.md, submission.json, train.log, train_gpt.py
  • 8xH100 full 10-min run (awaiting compute grant)
  • Loop gate ablation (gates=1.0 vs gates=1/N vs no gates)
  • 3-seed statistical validation

Explores depth recurrence as a parameter-efficient approach to increase
effective model depth. 5 unique layers looped 3 times = 15 effective
depth at 15M params. Includes learnable loop embeddings and loop gates
for per-pass conditioning. Run terminated early due to compute budget.