# Record: SP8192 + Parallel Residuals + Coprime-Stride + Legal Score-First TTT

**val_bpb = 1.08286** (3-seed mean, std 0.00070) | 15.99 MB | 8xH100 SXM | ~1750s total (588s train + ~430s TTT eval)

## Results (3-seed)

| Seed | Pre-TTT bpb | Post-TTT bpb | TTT Gain | val_loss (nats) | Artifact (bytes) |
|------|-------------|--------------|----------|-----------------|------------------|
| 1337 | ~1.084 | **1.08255** | -0.0015 | 2.79633 | 15,988,547 |
| 42 | ~1.084 | **1.08237** | -0.0016 | 2.79588 | 15,990,325 |
| 2025 | ~1.084 | **1.08366** | -0.0007 | 2.79921 | 15,989,566 |
| **Mean** | **~1.084** | **1.08286** | **-0.0013** | **2.79714** | |

Merged SOTA (PR #1019): **2.88218 nats** (1.1147 bpb). This run: **2.79714 nats**. Delta: **-0.0850 nats**, clearing the 0.005-nat improvement threshold.
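As a sanity check, the summary numbers follow directly from the per-seed results (values copied from the results JSON; a minimal sketch):

```python
import statistics

# Per-seed results, copied from the 3-seed results JSON.
val_loss_nats = [2.79633260, 2.79588432, 2.79920546]
val_bpb = [1.08254749, 1.08237395, 1.08365966]

mean_nats = statistics.mean(val_loss_nats)
mean_bpb = statistics.mean(val_bpb)
std_bpb = statistics.stdev(val_bpb)   # sample std (n-1), matching the reported 0.00070
delta = mean_nats - 2.88218           # vs merged SOTA (PR #1019)

print(round(mean_bpb, 5), round(mean_nats, 5), round(delta, 4))
```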

## Changes from Base (PR #1394)

### 1. Parallel Residuals (from layer 7)
Layers 7-10 execute attention and MLP in parallel (PaLM-style) rather than sequentially, at zero additional parameters. Measured delta: **-0.0016 bpb** (R10).
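A minimal sketch of the parallel-residual pattern, using stock `nn.LayerNorm` and `nn.MultiheadAttention` as stand-ins for this record's RMSNorm and GQA attention (illustrative modules only, not the actual blocks):

```python
import torch
import torch.nn as nn

class ParallelBlock(nn.Module):
    """PaLM-style parallel residual: x + attn(h) + mlp(h), with h = norm(x).
    Stand-in modules: LayerNorm instead of RMSNorm, plain MHA instead of GQA."""
    def __init__(self, dim: int, num_heads: int = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm(x)  # one shared pre-norm feeds both branches
        a, _ = self.attn(h, h, h, need_weights=False)
        return x + a + self.mlp(h)  # both branches sum into a single residual
```

Because the two branches read the same normalized input, their matmuls can overlap on the GPU, which is where the speed/quality trade comes from.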

### 2. Coprime-Stride Data Loader
Each shard is traversed with a stride coprime to its number of sequences, so every sequence is visited exactly once in a pseudo-random order. This improves data diversity without the cost of a full shuffle. Measured delta: **-0.0016 bpb** (R10).
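The traversal idea in a minimal sketch (hypothetical names; the real loader's details may differ). Because the stride is coprime to the shard length, the walk is a full cycle through all indices:

```python
from math import gcd

def coprime_stride_order(n_sequences: int, seed: int = 1337):
    """Yield indices 0..n-1 exactly once each, in pseudo-random order,
    by stepping with a stride coprime to n_sequences."""
    stride = (seed % n_sequences) | 1      # start from an odd candidate
    while gcd(stride, n_sequences) != 1:   # bump until coprime
        stride += 2
    idx = seed % n_sequences
    for _ in range(n_sequences):
        yield idx
        idx = (idx + stride) % n_sequences
```

For the logged 128 shards, any odd stride is automatically coprime to a power-of-two sequence count, so the search loop exits immediately.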

### 3. Legal Score-First TTT (eval-time)
Score-first test-time training on the quantized model. Each sliding-window chunk is scored under `torch.inference_mode()` *before* any gradient update; training on a chunk happens only *after* it has been scored, and the last chunk is score-only. Config: SGD with momentum 0.9, LR 0.005, 3 epochs per chunk, 32,768 tokens per chunk. Measured delta: **-0.0015 bpb** (R12).

Pattern follows PR #549 precedent:
```python
for i, chunk in enumerate(chunks):
    # Phase 1: SCORE (no grad); every chunk is scored before any update
    with torch.inference_mode():
        nll = model(batch)
        loss_sum += nll.sum()
    # Phase 2: TRAIN, only on the chunk just scored; last chunk is score-only
    if i < len(chunks) - 1:
        for epoch in range(3):
            optimizer.zero_grad()
            loss = model(x, y)
            loss.backward()
            optimizer.step()
```
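The reported bpb values relate to the accumulated NLL by the usual nats-to-bits-per-byte conversion (a sketch; `total_bytes` is the raw byte length of the scored validation text):

```python
import math

def nll_to_bpb(nll_sum_nats: float, total_bytes: int) -> float:
    """Bits per byte from a summed next-token NLL measured in nats."""
    return nll_sum_nats / (math.log(2) * total_bytes)

# Sanity check: ln(2) nats per byte is exactly 1 bit per byte.
assert nll_to_bpb(math.log(2) * 4096, 4096) == 1.0
```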

## Architecture
- SP8192 (8192 BPE tokens via SentencePiece)
- 11 layers, dim 512, MLP 4x, 8 heads / 4 KV heads (GQA)
- Depth recurrence: layers 4-5 looped 2x (effective 13 layers)
- XSA-all, skip gates, RMSNorm, LeakyReLU(0.5)^2
- MuonEq-R optimizer, EMA (0.997)
- GPTQ int6 weights + int8 embeddings + brotli + SDClip
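One way the 13 effective layers can arise from the logged `loop_start=4`, `loop_end=5`, `num_loops=2` (a sketch assuming inclusive loop indices; the helper name is hypothetical, and the logged encoder/decoder split suggests the real schedule is more involved):

```python
def layer_schedule(num_layers=11, loop_start=4, loop_end=5, num_loops=2):
    """Effective layer execution order under depth recurrence:
    layers before the loop, the looped span repeated, then the rest."""
    pre = list(range(loop_start))
    loop = list(range(loop_start, loop_end + 1)) * num_loops
    post = list(range(loop_end + 1, num_layers))
    return pre + loop + post
```

With the defaults this yields `[0, 1, 2, 3, 4, 5, 4, 5, 6, 7, 8, 9, 10]`: 11 distinct layers, 13 effective, with no extra parameters.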

## Compliance
- All training-side techniques are architecture or data-ordering changes. LEGAL.
- TTT is score-first: strict score-before-update ordering per PR #549.
- `torch.inference_mode()` during scoring prevents gradient accumulation.
- No SLOT, no pre-quant TTT, no n-gram caches, no eval-time logit bias.
- GPTQ calibration uses AR self-generated training data (not validation).

## Reproduction
```bash
pip install brotli
pip install flash_attn_3 --no-deps --find-links https://windreamer.github.io/flash-attention3-wheels/cu128_torch291/
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
No env vars needed. SP8192 data downloads automatically.

## Credits
Base: PR #1394 (@clarkkev). Parallel residuals: PR #1334 (@aryanbhosale). TTT pattern: PR #549 (@abaybektursun), PR #1413 (@dexhunter).
brotli
{
"track": "10min_16mb",
"val_bpb_mean": 1.08286,
"val_bpb_std": 0.00070,
"seeds": [1337, 42, 2025],
"results": {
"1337": {"val_bpb": 1.08254749, "val_loss": 2.79633260, "bytes_total": 15988547},
"42": {"val_bpb": 1.08237395, "val_loss": 2.79588432, "bytes_total": 15990325},
"2025": {"val_bpb": 1.08365966, "val_loss": 2.79920546, "bytes_total": 15989566}
},
"base_pr": 1394,
"hardware": "8xH100 SXM",
"training_time_seconds": 588,
"eval_method": "legal_score_first_ttt"
}

From https://github.com/resouer/parameter-golf
* branch exp/r12-3seed-w1 -> FETCH_HEAD
Note: switching to 'f5cd8f2480186942eae49e259409313b6058fc22'.
You are in 'detached HEAD' state. You can look around, make experimental
changes and commit them, and you can discard any commits you make in this
state without impacting any branches by switching back to a branch.
If you want to create a new branch to retain commits you create, you may
do so (now or later) by using -c with the switch command. Example:
git switch -c <new-branch-name>
Or undo this operation with:
git switch -
Turn off this advice by setting config variable advice.detachedHead to false
HEAD is now at f5cd8f2 Fix: use heredoc for vocab detection (shell escaping)
data_setup: vocab=8192 shards=128
Warning: You are sending unauthenticated requests to the HF Hub. Please set a HF_TOKEN to enable higher rate limits and faster downloads.
W0406 18:40:59.328000 231 torch/distributed/run.py:803]
W0406 18:40:59.328000 231 torch/distributed/run.py:803] *****************************************
W0406 18:40:59.328000 231 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
Hyperparameters:
adam_eps: 1e-08
adam_wd: 0.02
beta1: 0.9
beta2: 0.95
compressor: brotli
data_dir: ./data/
datasets_dir: ./data/datasets/fineweb10B_sp8192
distributed: True
ema_decay: 0.997
embed_bits: 8
embed_clip_sigmas: 20.0
embed_lr: 0.6
embed_wd: 0.085
embedding_dim: 512
enable_looping_at: 0.5
eval_seq_len: 2048
eval_stride: 64
gptq_calibration_batches: 64
gptq_reserve_seconds: 12.0
grad_accum_steps: 1
grad_clip_norm: 0.3
head_lr: 0.008
is_main_process: True
iterations: 20000
ln_scale: True
local_rank: 0
logfile: logs/0e75b012-595d-4dff-9dbf-a0ae07701c95.txt
logit_softcap: 30.0
loop_end: 5
loop_start: 4
matrix_bits: 6
matrix_clip_sigmas: 12.85
matrix_lr: 0.02
max_wallclock_seconds: 600.0
min_lr: 0.0
mlp_mult: 4.0
model_dim: 512
model_path: final_model.pt
muon_backend_steps: 5
muon_beta2: 0.95
muon_momentum: 0.99
muon_momentum_warmup_start: 0.92
muon_momentum_warmup_steps: 1500
muon_row_normalize: True
muon_wd: 0.085
num_heads: 8
num_kv_heads: 4
num_layers: 11
num_loops: 2
parallel_start_layer: 7
qk_gain_init: 4.0
quantized_model_path: final_model.int6.ptz
rank: 0
rope_base: 10000.0
rope_dims: 16
rope_train_seq_len: 2048
run_id: 0e75b012-595d-4dff-9dbf-a0ae07701c95
scalar_lr: 0.02
seed: 1337
skip_gates_enabled: True
sliding_window_enabled: True
tie_embeddings: True
tied_embed_init_std: 0.005
tied_embed_lr: 0.03
tokenizer_path: ./data/tokenizers/fineweb_8192_bpe.model
train_batch_tokens: 786432
train_files: ./data/datasets/fineweb10B_sp8192/fineweb_train_*.bin
train_log_every: 500
train_seq_len: 2048
ttt_batch_seqs: 32
ttt_chunk_tokens: 32768
ttt_enabled: True
ttt_epochs: 3
ttt_freeze_blocks: 0
ttt_grad_clip: 1.0
ttt_loop_only: False
ttt_lr: 0.005
ttt_momentum: 0.9
val_batch_tokens: 524288
val_files: ./data/datasets/fineweb10B_sp8192/fineweb_val_*.bin
val_loss_every: 4000
vocab_size: 8192
warmdown_frac: 0.667
warmup_steps: 20
world_size: 8
xsa_last_n: 11
train_shards: 128
val_tokens: 40540160
model_params:35943512
gptq:reserving 12s, effective=588000ms
warmup_step: 1/20
warmup_step: 2/20
warmup_step: 3/20
warmup_step: 4/20
warmup_step: 5/20
warmup_step: 6/20
warmup_step: 10/20
warmup_step: 20/20
loop_warmup:enabled encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10]
loop_warmup_step: 1/20
loop_warmup_step: 2/20
loop_warmup_step: 3/20
loop_warmup_step: 4/20
loop_warmup_step: 5/20
loop_warmup_step: 6/20
loop_warmup_step: 10/20
loop_warmup_step: 20/20
0/20000 val_loss: 9.0047 val_bpb: 3.4860
1/20000 train_loss: 9.0084 train_time: 0.0m tok/s: 7900535
2/20000 train_loss: 12.3253 train_time: 0.0m tok/s: 7913763
3/20000 train_loss: 11.0666 train_time: 0.0m tok/s: 7865955
4/20000 train_loss: 9.4822 train_time: 0.0m tok/s: 7843616
5/20000 train_loss: 8.4563 train_time: 0.0m tok/s: 7823176
500/20000 train_loss: 3.4002 train_time: 0.9m tok/s: 7641034
1000/20000 train_loss: 3.2027 train_time: 1.7m tok/s: 7635729
1500/20000 train_loss: 3.2049 train_time: 2.6m tok/s: 7635948
2000/20000 train_loss: 3.1246 train_time: 3.4m tok/s: 7637470
2500/20000 train_loss: 2.9917 train_time: 4.3m tok/s: 7640055
layer_loop:enabled step:2855 frac:0.500 encoder:[0, 1, 2, 3, 4, 5, 4] decoder:[5, 4, 5, 6, 7, 8, 9, 10]
3000/20000 train_loss: 2.9911 train_time: 5.2m tok/s: 7525107
3500/20000 train_loss: 3.0115 train_time: 6.4m tok/s: 7222866
4000/20000 train_loss: 2.9545 train_time: 7.5m tok/s: 7012856
4000/20000 val_loss: 2.9181 val_bpb: 1.1297
4500/20000 train_loss: 2.9286 train_time: 8.6m tok/s: 6847018
5000/20000 train_loss: 2.9144 train_time: 9.8m tok/s: 6719493
5021/20000 val_loss: 2.8156 val_bpb: 1.0900
stopping_early: wallclock_cap train_time: 588011ms step: 5021/20000
peak memory allocated: 34604 MiB reserved: 34708 MiB
ema:applying EMA weights
pre-quantization post-ema val_loss:2.81312613 val_bpb:1.08904879 eval_time:6303ms
Serialized model: 135426937 bytes
Code size: 17115 bytes
GPTQ:collecting Hessians from calibration data...
GPTQ:collected 67 Hessians in 11.4s
Quantized weights:
gptq (int6): blocks.attn.c_k.weight, blocks.attn.c_q.weight, blocks.attn.c_v.weight, blocks.attn.proj.weight, blocks.mlp.fc.weight, blocks.mlp.proj.weight
gptq (int8): tok_emb.weight
passthrough (float16): blocks.attn.q_gain, blocks.attn_scale, blocks.mlp_scale, blocks.resid_mix, skip_gates, skip_weights
Serialized model quantized+brotli: 15971432 bytes
Total submission size quantized+brotli: 15988547 bytes
final_int6_roundtrip_exact val_loss:2.84360773 val_bpb:1.10084917 eval_time:28686ms
final_int6_sliding_window val_loss:2.80047283 val_bpb:1.08415030 eval_time:115336ms
ttt_sliding:start chunks=1238 chunk_tokens=32768 total_windows=633409 stride=64 ttt_lr=0.005 ttt_epochs=3 freeze_blocks=0 loop_only=False
ttt_sliding:params unfrozen=35943512 frozen=0
ttt_chunk [1/1238] bpb=1.118163 time=37.2s
ttt_chunk [11/1238] bpb=1.071228 time=42.2s
ttt_chunk [21/1238] bpb=1.109543 time=45.1s
ttt_chunk [31/1238] bpb=1.104130 time=48.1s
ttt_chunk [41/1238] bpb=1.097377 time=51.0s
ttt_chunk [51/1238] bpb=1.090972 time=53.9s
ttt_chunk [61/1238] bpb=1.082549 time=56.8s
ttt_chunk [71/1238] bpb=1.089584 time=59.7s
ttt_chunk [81/1238] bpb=1.083170 time=62.6s
ttt_chunk [91/1238] bpb=1.079738 time=65.5s
ttt_chunk [101/1238] bpb=1.079437 time=68.4s
ttt_chunk [111/1238] bpb=1.077892 time=71.3s
ttt_chunk [121/1238] bpb=1.080594 time=74.2s
ttt_chunk [131/1238] bpb=1.084258 time=77.1s
ttt_chunk [141/1238] bpb=1.084778 time=80.0s
ttt_chunk [151/1238] bpb=1.084592 time=82.8s
ttt_chunk [161/1238] bpb=1.085142 time=85.8s
ttt_chunk [171/1238] bpb=1.085088 time=88.7s
ttt_chunk [181/1238] bpb=1.083454 time=91.6s
ttt_chunk [191/1238] bpb=1.083201 time=94.5s
ttt_chunk [201/1238] bpb=1.080695 time=97.4s
ttt_chunk [211/1238] bpb=1.085108 time=100.3s
ttt_chunk [221/1238] bpb=1.085479 time=103.2s
ttt_chunk [231/1238] bpb=1.087191 time=106.1s
ttt_chunk [241/1238] bpb=1.085559 time=109.0s
ttt_chunk [251/1238] bpb=1.085599 time=111.9s
ttt_chunk [261/1238] bpb=1.086578 time=114.8s
ttt_chunk [271/1238] bpb=1.086969 time=118.5s
ttt_chunk [281/1238] bpb=1.086274 time=121.4s
ttt_chunk [291/1238] bpb=1.087461 time=124.3s
ttt_chunk [301/1238] bpb=1.087628 time=127.2s
ttt_chunk [311/1238] bpb=1.086459 time=130.1s
ttt_chunk [321/1238] bpb=1.086353 time=133.0s
ttt_chunk [331/1238] bpb=1.086633 time=135.9s
ttt_chunk [341/1238] bpb=1.085745 time=138.8s
ttt_chunk [351/1238] bpb=1.086443 time=141.7s
ttt_chunk [361/1238] bpb=1.085377 time=144.6s
ttt_chunk [371/1238] bpb=1.083843 time=147.5s
ttt_chunk [381/1238] bpb=1.084234 time=150.4s
ttt_chunk [391/1238] bpb=1.083874 time=153.3s
ttt_chunk [401/1238] bpb=1.083945 time=156.2s
ttt_chunk [411/1238] bpb=1.084521 time=159.1s
ttt_chunk [421/1238] bpb=1.084030 time=162.0s
ttt_chunk [431/1238] bpb=1.084207 time=164.9s
ttt_chunk [441/1238] bpb=1.084283 time=167.8s
ttt_chunk [451/1238] bpb=1.085485 time=170.7s
ttt_chunk [461/1238] bpb=1.083717 time=173.6s
ttt_chunk [471/1238] bpb=1.083674 time=176.5s
ttt_chunk [481/1238] bpb=1.083824 time=179.5s
ttt_chunk [491/1238] bpb=1.084296 time=182.4s
ttt_chunk [501/1238] bpb=1.083920 time=185.3s
ttt_chunk [511/1238] bpb=1.083562 time=188.1s
ttt_chunk [521/1238] bpb=1.083066 time=191.0s
ttt_chunk [531/1238] bpb=1.083023 time=193.9s
ttt_chunk [541/1238] bpb=1.083115 time=196.8s
ttt_chunk [551/1238] bpb=1.082652 time=199.8s
ttt_chunk [561/1238] bpb=1.081962 time=202.6s
ttt_chunk [571/1238] bpb=1.081408 time=205.5s
ttt_chunk [581/1238] bpb=1.081743 time=208.4s
ttt_chunk [591/1238] bpb=1.082003 time=211.3s
ttt_chunk [601/1238] bpb=1.081961 time=214.2s
ttt_chunk [611/1238] bpb=1.082529 time=217.1s
ttt_chunk [621/1238] bpb=1.083403 time=220.0s
ttt_chunk [631/1238] bpb=1.083433 time=222.9s
ttt_chunk [641/1238] bpb=1.083868 time=225.8s
ttt_chunk [651/1238] bpb=1.084205 time=228.7s
ttt_chunk [661/1238] bpb=1.083537 time=231.5s
ttt_chunk [671/1238] bpb=1.083241 time=234.4s
ttt_chunk [681/1238] bpb=1.084590 time=237.3s
ttt_chunk [691/1238] bpb=1.084769 time=240.2s
ttt_chunk [701/1238] bpb=1.084572 time=243.8s
ttt_chunk [711/1238] bpb=1.085272 time=246.7s
ttt_chunk [721/1238] bpb=1.085589 time=249.6s
ttt_chunk [731/1238] bpb=1.084944 time=252.5s
ttt_chunk [741/1238] bpb=1.084656 time=255.4s
ttt_chunk [751/1238] bpb=1.083765 time=258.3s
ttt_chunk [761/1238] bpb=1.083148 time=261.2s
ttt_chunk [771/1238] bpb=1.082176 time=264.9s
ttt_chunk [781/1238] bpb=1.082138 time=267.8s
ttt_chunk [791/1238] bpb=1.082483 time=270.7s
ttt_chunk [801/1238] bpb=1.082786 time=273.7s
ttt_chunk [811/1238] bpb=1.082288 time=277.2s
ttt_chunk [821/1238] bpb=1.081116 time=280.1s
ttt_chunk [831/1238] bpb=1.080805 time=283.0s
ttt_chunk [841/1238] bpb=1.080375 time=285.9s
ttt_chunk [851/1238] bpb=1.080081 time=288.8s
ttt_chunk [861/1238] bpb=1.079724 time=291.8s
ttt_chunk [871/1238] bpb=1.079586 time=294.7s
ttt_chunk [881/1238] bpb=1.079109 time=297.7s
ttt_chunk [891/1238] bpb=1.078579 time=302.3s
ttt_chunk [901/1238] bpb=1.078963 time=305.2s
ttt_chunk [911/1238] bpb=1.078647 time=308.2s
ttt_chunk [921/1238] bpb=1.078943 time=311.3s
ttt_chunk [931/1238] bpb=1.079601 time=314.3s
ttt_chunk [941/1238] bpb=1.079993 time=317.3s
ttt_chunk [951/1238] bpb=1.079927 time=320.3s
ttt_chunk [961/1238] bpb=1.080753 time=323.3s
ttt_chunk [971/1238] bpb=1.081152 time=326.3s
ttt_chunk [981/1238] bpb=1.081519 time=329.2s
ttt_chunk [991/1238] bpb=1.081320 time=332.2s
ttt_chunk [1001/1238] bpb=1.081367 time=335.2s
ttt_chunk [1011/1238] bpb=1.081707 time=338.1s
ttt_chunk [1021/1238] bpb=1.082415 time=341.1s
ttt_chunk [1031/1238] bpb=1.082910 time=344.1s
ttt_chunk [1041/1238] bpb=1.083364 time=347.1s
ttt_chunk [1051/1238] bpb=1.083286 time=350.0s
ttt_chunk [1061/1238] bpb=1.083282 time=353.0s
ttt_chunk [1071/1238] bpb=1.083455 time=355.9s
ttt_chunk [1081/1238] bpb=1.083350 time=358.9s
ttt_chunk [1091/1238] bpb=1.083541 time=361.8s
ttt_chunk [1101/1238] bpb=1.084084 time=364.8s
ttt_chunk [1111/1238] bpb=1.084391 time=367.7s
ttt_chunk [1121/1238] bpb=1.084575 time=370.7s
ttt_chunk [1131/1238] bpb=1.084238 time=373.6s
ttt_chunk [1141/1238] bpb=1.083910 time=376.6s
ttt_chunk [1151/1238] bpb=1.083957 time=379.6s
ttt_chunk [1161/1238] bpb=1.084112 time=382.6s
ttt_chunk [1171/1238] bpb=1.083888 time=385.5s
ttt_chunk [1181/1238] bpb=1.083408 time=388.5s
ttt_chunk [1191/1238] bpb=1.083560 time=391.5s
ttt_chunk [1201/1238] bpb=1.083597 time=394.4s
ttt_chunk [1211/1238] bpb=1.083289 time=397.4s
ttt_chunk [1221/1238] bpb=1.082814 time=400.3s
ttt_chunk [1231/1238] bpb=1.082444 time=403.3s
ttt_chunk [1238/1238] bpb=1.082442 time=422.7s
ttt_sliding:done val_loss=2.796333 val_bpb=1.082547 elapsed=423.5s
legal_ttt_exact val_loss:2.79633260 val_bpb:1.08254749 eval_time:423701ms
results_json: {"val_bpb": 1.08254749, "val_loss": 2.7963326, "bytes_total": 15988547, "peak_memory_mib": 34604}