# Depth Recurrence: 5 Unique Layers x 3 Loops (val_bpb=1.2716)

**Non-record submission** exploring depth recurrence as a parameter-efficient way to increase effective model depth.

## Core Idea

Instead of N unique transformer blocks, use fewer unique blocks and loop through them multiple times. This gives more effective depth for the same parameter budget, freeing parameters to increase model width.

- **Config:** 5 unique layers looped 3 times = 15 effective layers at dim=640 (15M params)
- **Baseline comparison:** 9 unique layers at dim=512 (17M params, 9 effective layers)

## Architecture Changes

Three modifications to the baseline GPT:

1. **Shared-weight loops**: The model has `num_unique_layers=5` transformer blocks. During forward pass, it loops through all 5 blocks `num_loops=3` times, giving 15 effective layers of computation while only storing 5 layers of parameters.

2. **Loop embeddings**: Learnable per-loop vectors (`loop_embeds`, shape `[num_loops, dim]`) added to the residual stream at the start of each loop. These let the model differentiate between passes through the same weights. Initialized to zero so the first loop behaves like vanilla.

3. **Loop gates**: Learnable per-loop scalars (`loop_gates`) that control how much each loop contributes vs. reverting to the initial representation x0. After loop `i > 0`: `x = gate_i * x + (1 - gate_i) * x0`. Initialized to `1/num_loops`.

The encoder-decoder skip connection pattern from the baseline was removed to keep the recurrence clean.
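The three modifications above can be sketched as a single module. This is a minimal illustration assuming standard PyTorch conventions and made-up names (`DepthRecurrentStack`, the `blocks` argument); it is not the actual `train_gpt.py` code.

```python
import torch
import torch.nn as nn

class DepthRecurrentStack(nn.Module):
    """Loop num_loops times over a shared stack of transformer blocks."""

    def __init__(self, blocks: nn.ModuleList, num_loops: int, dim: int):
        super().__init__()
        self.blocks = blocks  # num_unique_layers shared blocks
        self.num_loops = num_loops
        # Per-loop conditioning vectors, zero-init so loop 0 is vanilla.
        self.loop_embeds = nn.Parameter(torch.zeros(num_loops, dim))
        # Per-loop gates, initialized to 1/num_loops.
        self.loop_gates = nn.Parameter(torch.full((num_loops,), 1.0 / num_loops))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x0 = x
        for i in range(self.num_loops):
            x = x + self.loop_embeds[i]  # mark which pass this is
            for block in self.blocks:    # same blocks every loop
                x = block(x)
            if i > 0:                    # gate back toward x0 after loop i > 0
                g = self.loop_gates[i]
                x = g * x + (1 - g) * x0
        return x
```

With zero-initialized embeddings, the first pass is identical to a plain (non-recurrent) forward through the shared blocks, which is the intended warm start.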

## Results

| Step | val_bpb | val_loss |
|------|---------|----------|
| 500 | 1.4720 | 2.4853 |
| 1000 | 1.3730 | 2.3182 |
| 2000 | 1.3138 | 2.2184 |
| 3000 | 1.2916 | 2.1809 |
| 4000 | 1.2775 | 2.1570 |
| 4500 | 1.2716 | 2.1471 |

The run was terminated early at ~step 5300 by the RunPod wallclock limit; val_bpb was still improving.

**Baseline comparison (9L/512dim, same optimizer settings):**
- Step 9000: val_bpb = 1.2507

At comparable training FLOPs, depth recurrence underperformed the baseline by ~0.02 BPB.

## Analysis: Why It Didn't Beat Baseline

1. **Conservative loop gating**: Gates initialized at 1/3 pull the representation back toward x0 after each loop. This effectively limits the model to ~1-1.5 loops of useful computation. The gating was intended to stabilize training through deep recurrence, but it over-regularized.

2. **No skip connections**: The baseline's encoder-decoder skip pattern (U-Net style) provides important information shortcuts. We removed these to simplify the recurrence, but this likely hurt.

3. **Gradient amplification through loops**: The same weights receive gradients from all 3 loops, which changes the effective learning rate. We didn't compensate for this — a lower matrix_lr might help.

4. **Fewer unique representations**: 5 unique layers means less representational diversity per loop pass compared to 9 unique layers, even with loop conditioning.
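Point 1 can be quantified with toy linear bookkeeping (an illustration only: it tracks path weights and ignores what the blocks themselves compute). With gates at 1/3, the signal that passes through *all three* loops is attenuated by the gate at each of the two gated mixes, while the final gate injects 2/3 raw x0:

```python
# Toy path-weight bookkeeping for x = gate*x + (1-gate)*x0 after each loop i > 0.
num_loops = 3
gate = 1 / num_loops

full_depth_weight = 1.0        # weight of the path processed by every loop
for i in range(1, num_loops):  # gates fire after loops 1 and 2
    full_depth_weight *= gate
direct_x0_weight = 1 - gate    # raw x0 injected at the final gate

print(f"full-depth path weight: {full_depth_weight:.3f}")  # 0.111
print(f"direct x0 weight:       {direct_x0_weight:.3f}")   # 0.667
```

The fully recurrent path carries only ~11% of the output, consistent with the "~1-1.5 loops of useful computation" estimate above.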

## What Would Improve This

- Remove loop gating entirely, or initialize gates closer to 1.0
- Add skip connections within each loop pass (encoder-decoder pattern per loop)
- Use 2 loops with 7 unique layers instead of 3 loops with 5 — less recurrence, more layer diversity
- Scale down learning rate proportional to num_loops (each weight gets gradient from all loops)
- Progressive loop training: start with 1 loop, add more during training
- Combine with techniques from current SOTA: sliding window eval, FP16 embeddings, int6 quantization
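The learning-rate compensation could look like the sketch below. The split into matrix vs. scalar parameters and the `matrix_lr`/`scalar_lr` values follow the training log; the optimizer choice and grouping here are assumptions, not the actual training script.

```python
import torch

def build_optimizer(model, matrix_lr=0.04, scalar_lr=0.04, num_loops=3):
    # Shared (looped) matrix weights accumulate gradients from every loop,
    # so scale their LR down by num_loops; leave scalars (e.g. gates) alone.
    matrix_params = [p for p in model.parameters() if p.ndim >= 2]
    scalar_params = [p for p in model.parameters() if p.ndim < 2]
    return torch.optim.AdamW([
        {"params": matrix_params, "lr": matrix_lr / num_loops},
        {"params": scalar_params, "lr": scalar_lr},
    ])
```

Whether a full 1/num_loops scaling is right is an open question; since the gating already attenuates later loops, the effective gradient multiplier is somewhere between 1 and num_loops.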

## Hardware & Training

- **Hardware**: 6x NVIDIA H200 (141GB HBM3e each), RunPod
- **Torch**: 2.8.0+cu128
- **Training**: Distributed across 6 GPUs, ~197ms/step, FlashAttention-2, bfloat16
- **Total training time**: ~888 seconds (14.8 minutes) for 4500 steps
- **Run was wallclock-limited, not converged**

## How to Run

```bash
RUN_ID=recurrent_5x3_d640 \
NUM_UNIQUE_LAYERS=5 NUM_LOOPS=3 \
MODEL_DIM=640 NUM_HEADS=8 NUM_KV_HEADS=4 \
torchrun --standalone --nproc_per_node=6 train_gpt.py
```
{
"name": "Arth Singh",
"github_id": "Arth-Singh",
"run_name": "Depth Recurrence 5x3 dim=640",
"val_bpb": 1.2716,
"val_loss": 2.1471,
"model_params": 15006123,
"compressed_model_bytes": null,
"train_steps": 4500,
"train_time_seconds": 888,
"hardware": "6x NVIDIA H200 (RunPod)",
"notes": "Non-record submission exploring depth recurrence: 5 unique transformer layers looped 3 times = 15 effective depth. Novel loop embeddings and learnable loop gates for per-pass conditioning. Run terminated early due to compute budget — val_bpb was still improving. See README for analysis."
}
logs/recurrent_5x3_d640_6gpu.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:15006123
world_size:6 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.04 scalar_lr:0.04
train_batch_tokens:786432 train_seq_len:1024 iterations:20000 warmup_steps:20 max_wallclock_seconds:2400.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9442 val_bpb:4.1127 train_time:0ms step_avg:0.02ms
step:1/20000 train_loss:6.9449 train_time:96ms step_avg:96.05ms
step:2/20000 train_loss:19.7308 train_time:260ms step_avg:130.24ms
step:3/20000 train_loss:9.2896 train_time:442ms step_avg:147.40ms
step:4/20000 train_loss:6.3345 train_time:621ms step_avg:155.29ms
step:5/20000 train_loss:6.1266 train_time:799ms step_avg:159.74ms
step:6/20000 train_loss:6.1883 train_time:979ms step_avg:163.10ms
step:7/20000 train_loss:6.1662 train_time:1164ms step_avg:166.23ms
step:8/20000 train_loss:6.1170 train_time:1345ms step_avg:168.16ms
step:9/20000 train_loss:6.0177 train_time:1523ms step_avg:169.24ms
step:10/20000 train_loss:5.9612 train_time:1705ms step_avg:170.45ms
step:50/20000 train_loss:4.2888 train_time:8842ms step_avg:176.85ms
step:100/20000 train_loss:3.5005 train_time:17799ms step_avg:177.99ms
step:150/20000 train_loss:3.1377 train_time:29249ms step_avg:195.00ms
step:200/20000 train_loss:2.6386 train_time:38219ms step_avg:191.10ms
step:250/20000 train_loss:2.6223 train_time:47218ms step_avg:188.87ms
step:300/20000 train_loss:2.6910 train_time:58662ms step_avg:195.54ms
step:350/20000 train_loss:2.6512 train_time:246646ms step_avg:704.70ms
step:400/20000 train_loss:2.4294 train_time:281865ms step_avg:704.66ms
step:450/20000 train_loss:2.5145 train_time:87999ms step_avg:195.55ms
step:500/20000 train_loss:2.5140 train_time:96955ms step_avg:193.91ms
step:500/20000 val_loss:2.4853 val_bpb:1.4720 train_time:97049ms step_avg:194.10ms
step:550/20000 train_loss:2.4472 train_time:107848ms step_avg:196.09ms
step:600/20000 train_loss:2.4336 train_time:116822ms step_avg:194.70ms
step:1000/20000 train_loss:2.4097 train_time:195507ms step_avg:195.51ms
step:1000/20000 val_loss:2.3182 val_bpb:1.3730 train_time:195599ms step_avg:195.60ms
step:1500/20000 val_loss:2.2573 val_bpb:1.3369 train_time:294372ms step_avg:196.25ms
step:2000/20000 val_loss:2.2184 val_bpb:1.3138 train_time:393272ms step_avg:196.64ms
step:2500/20000 val_loss:2.1971 val_bpb:1.3012 train_time:492043ms step_avg:196.82ms
step:3000/20000 val_loss:2.1809 val_bpb:1.2916 train_time:591026ms step_avg:197.01ms
step:3500/20000 val_loss:2.1691 val_bpb:1.2847 train_time:690009ms step_avg:197.15ms
step:4000/20000 val_loss:2.1570 val_bpb:1.2775 train_time:789520ms step_avg:197.38ms
step:4500/20000 val_loss:2.1471 val_bpb:1.2716 train_time:888434ms step_avg:197.43ms
stopping_early: wallclock_cap train_time:~1046000ms step:~5300/20000