
Commit 0f26c08

icryo and claude committed
Rebase on PR openai#1060 (1.1122 BPB) + TTT reset sweep
Competition moved while we were experimenting locally:
  PR openai#634:  1.1178 BPB (Full GPTQ + XSA-all + selective pruning)
  PR openai#1060: 1.1122 BPB (+ coprime loader + BigramHash 2816)

Our contribution: TTT periodic reset on the PR openai#1060 base.
PR openai#1060 found TTT unnecessary with Full GPTQ, but they didn't
test TTT with anti-drift reset. If TTT drift was the reason it stopped
helping, reset could unlock further gains.

Files:
  train_gpt_ours.py   — PR openai#1060 + TTT reset mechanism
  train_gpt_pr634.py  — Full GPTQ reference (for study)
  train_gpt_pr1060.py — Original PR openai#1060 (for comparison)
  run_h100.sh         — Train once, sweep 4 TTT configs

TTT configs tested:
  A: SOTA (lr=0.002, 3ep) — baseline TTT
  B: PR openai#1039 (lr=0.0025, 4ep) — tuned TTT
  C: B + reset/100 — anti-drift, moderate
  D: B + reset/50 — anti-drift, aggressive

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
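The mechanism under test, in outline: during sliding-window eval, TTT keeps updating the weights chunk after chunk, so late chunks are scored by a model that has drifted far from the trained checkpoint; a periodic reset bounds that drift. A minimal sketch of the idea, assuming a chunked TTT eval loop (the real implementation lives inside eval_val_sliding_ttt in train_gpt_ours.py; chunks, make_optimizer, and adapt_and_score are illustrative names here, not the repo's API):

def ttt_with_periodic_reset(model, chunks, make_optimizer, adapt_and_score,
                            reset_every=100):
    # Snapshot the pre-TTT weights once, on CPU, so each reset is cheap.
    base_state = {k: v.detach().cpu().clone()
                  for k, v in model.state_dict().items()}
    opt = make_optimizer(model)
    total_bpb, n = 0.0, 0
    for i, chunk in enumerate(chunks):
        if reset_every and i > 0 and i % reset_every == 0:
            model.load_state_dict(base_state)  # back to trained weights
            opt = make_optimizer(model)        # discard optimizer state too
        total_bpb += adapt_and_score(model, opt, chunk)  # TTT step + score chunk
        n += 1
    return total_bpb / max(n, 1)  # reset_every=0 degenerates to plain TTT

Configs C and D in the sweep are this loop with reset_every=100 and reset_every=50; A and B are the reset_every=0 case.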
1 parent 3b1c62b commit 0f26c08

File tree

4 files changed (+6413, -126 lines)


run_h100.sh

Lines changed: 51 additions & 126 deletions
@@ -1,32 +1,36 @@
 #!/bin/bash
-# Full H100 experiment: train once, eval many configs
-# Maximizes information per dollar on rented 8xH100
+# Full H100 experiment: PR #1060 base (1.1122) + TTT with reset
 #
-# Budget: ~1 hour total (~$25-30)
-# - Setup: 10 min
-# - Training: 10 min (identical to SOTA)
-# - Eval sweep: 40 min (6-8 configs on same model)
+# What we test (train once, eval many):
+# 1. PR #1060 baseline reproduction (no TTT): ~87s eval
+# 2. TTT with SOTA config (lr=0.002, 3ep): ~410s
+# 3. TTT with PR #1039 config (lr=0.0025, 4ep): ~410s
+# 4. TTT with periodic reset every 100 chunks: ~410s
+# 5. TTT with periodic reset every 50 chunks: ~410s
+#
+# Total: ~10min train + 87s baseline + 4×410s TTT ≈ 37 min
+# Cost: ~$15-20 on 8xH100
 #
 # Usage:
 #   git clone https://github.com/icryo/parameter-golf.git
 #   cd parameter-golf && git checkout experiments/triton-kernels-qat-fix
 #   pip install sentencepiece huggingface-hub datasets flash-attn
 #   python3 data/cached_challenge_fineweb.py --variant sp1024
 #   ./run_h100.sh 2>&1 | tee full_experiment.log
-
 set -euo pipefail
 SEED="${1:-1337}"
 NPROC=8
 
 echo "============================================================"
-echo "PARAMETER GOLF: Full H100 Experiment Suite"
+echo "PARAMETER GOLF: PR #1060 base + TTT reset experiments"
 echo "Seed: $SEED | GPUs: $NPROC | $(date)"
 echo "============================================================"
 
-# === PHASE 1: Train (identical to SOTA, ~10 min) ===
+# === PHASE 1: Train with PR #1060 config ===
 echo ""
-echo "=== PHASE 1: Training (SOTA reproduction) ==="
-export RUN_ID="h100_s${SEED}"
+echo "=== PHASE 1: Training (PR #1060: coprime loader + Full GPTQ + XSA-all) ==="
+
+export RUN_ID="pr1060_ttt_s${SEED}"
 export SEED="$SEED"
 export DATA_PATH="./data/datasets/fineweb10B_sp1024"
 export TOKENIZER_PATH="./data/tokenizers/fineweb_1024_bpe.model"
@@ -35,35 +39,33 @@ export ITERATIONS=9000 WARMUP_STEPS=20 WARMDOWN_ITERS=3500
 export TRAIN_BATCH_TOKENS=786432 TRAIN_SEQ_LEN=2048
 export EVAL_SEQ_LEN=2048 EVAL_STRIDE=64
 export NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3
-export TIE_EMBEDDINGS=1 XSA_LAST_N=4 ROPE_DIMS=16 LN_SCALE=1
+export TIE_EMBEDDINGS=1 ROPE_DIMS=16 LN_SCALE=1
 export VE_ENABLED=1 VE_DIM=128 VE_LAYERS="9,10"
-export BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 LOGIT_SOFTCAP=30.0
+export LOGIT_SOFTCAP=30.0
 export MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035
 export MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500
 export MUON_WD=0.04 ADAM_WD=0.04 GRAD_CLIP_NORM=0.3
 export SWA_ENABLED=1 SWA_EVERY=50
 export LATE_QAT_THRESHOLD=0.15
-export EVAL_TEMPERATURE=0
-# Disable TTT in training script — we'll run it manually below
+export XSA_LAST_N=11
+export BIGRAM_VOCAB_SIZE=2816
+export BIGRAM_DIM=112
+export USE_GPTQ=1
+export GPTQ_RESERVE_MS=14000
 export TTT_ENABLED=0
 
-torchrun --standalone --nproc_per_node=$NPROC train_gpt_safe.py
+torchrun --standalone --nproc_per_node=$NPROC train_gpt_ours.py
 
 echo ""
-echo "=== Training complete. Model saved to final_model.int6.ptz ==="
-echo "=== Now running eval sweep on saved checkpoint ==="
-
-# === PHASE 2: Eval sweep (multiple configs, ~40 min) ===
-# Each eval_val_sliding takes ~74s, each TTT takes ~410s
-# Uses all 8 GPUs for fast eval via torchrun
+echo "=== Training complete. Now running TTT sweep. ==="
 
-cat > /tmp/eval_sweep.py << 'PYEOF'
+# === PHASE 2: TTT sweep on saved checkpoint ===
+cat > /tmp/ttt_sweep.py << 'PYEOF'
 import torch, time, math, io, lzma, sys, os
 import torch.distributed as dist
-
 sys.path.insert(0, '.')
-from train_gpt_safe import (
-    Hyperparameters, GPT, CastedLinear, eval_val_sliding, eval_val_sliding_ttt,
+from train_gpt_ours import (
+    Hyperparameters, GPT, CastedLinear, eval_val_sliding_ttt,
     dequantize_mixed_int6, _rebank_state_dict, _unbank_state_dict,
     build_sentencepiece_luts, load_validation_tokens, restore_low_dim_params_to_fp32,
 )
@@ -82,27 +84,26 @@ master = (rank == 0)
 def log(msg):
     if master: print(msg, flush=True)
 
-# Load tokenizer and validation data
 sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
 val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
 base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
     sp, args.vocab_size, device)
 
-# Load quantized model
 with open('final_model.int6.ptz', 'rb') as f:
     quant_blob = f.read()
 quant_state = torch.load(io.BytesIO(lzma.decompress(quant_blob)), map_location='cpu')
-sd_cpu = {k: v.detach().cpu() for k, v in torch.load('final_model.pt', map_location='cpu', weights_only=True).items()}
+sd_cpu = {k: v.detach().cpu() for k, v in
+          torch.load('final_model.pt', map_location='cpu', weights_only=True).items()}
 unbanked = _unbank_state_dict(sd_cpu, args.num_layers)
 deq = dequantize_mixed_int6(quant_state['w'], quant_state['m'], unbanked)
 deq_banked = _rebank_state_dict(deq, args.num_layers, sd_cpu)
 
-def load_eval_model(softcap_scale=1.0):
+def load_fresh():
     m = GPT(
         vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
         num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
         tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
-        logit_softcap=args.logit_softcap * softcap_scale, rope_base=args.rope_base,
+        logit_softcap=args.logit_softcap, rope_base=args.rope_base,
         qk_gain_init=args.qk_gain_init, mtp_num_heads=0, mtp_loss_weight=0.0,
         bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
         xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale,
@@ -120,25 +121,9 @@ def load_eval_model(softcap_scale=1.0):
     m.load_state_dict(deq_banked, strict=True)
     return m
 
-def run_sliding(label, temp=1.0, stride=64):
-    log(f' [{label}] T={temp:.2f} stride={stride}...')
-    model = load_eval_model(softcap_scale=temp)
-    torch.cuda.synchronize()
-    t0 = time.perf_counter()
-    loss, bpb = eval_val_sliding(
-        args, model, rank, world_size, device, val_tokens,
-        base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
-        stride=stride, eval_seq_len=args.eval_seq_len,
-    )
-    torch.cuda.synchronize()
-    elapsed = time.perf_counter() - t0
-    log(f' [{label}] val_bpb={bpb:.8f} time={elapsed:.1f}s')
-    del model; torch.cuda.empty_cache()
-    return bpb, elapsed
-
-def run_ttt(label, temp=1.0, ttt_lr=0.002, ttt_epochs=3, freeze_blocks=0, stride=64, reset_every=0):
-    log(f' [{label}] T={temp:.2f} lr={ttt_lr} ep={ttt_epochs} freeze={freeze_blocks} reset={reset_every}...')
-    model = load_eval_model(softcap_scale=temp)
+def run_ttt(label, ttt_lr, ttt_epochs, freeze_blocks=0, reset_every=0):
+    log(f' [{label}] lr={ttt_lr} ep={ttt_epochs} reset={reset_every}...')
+    model = load_fresh()
     args.ttt_lr = ttt_lr
     args.ttt_epochs = ttt_epochs
     args.ttt_freeze_blocks = freeze_blocks
@@ -152,103 +137,43 @@ def run_ttt(label, temp=1.0, ttt_lr=0.002, ttt_epochs=3, freeze_blocks=0, stride
     loss, bpb = eval_val_sliding_ttt(
         args, model, rank, world_size, device, val_tokens,
         base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
-        stride=stride, log0=log,
+        stride=64, log0=log,
     )
     torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
     log(f' [{label}] val_bpb={bpb:.8f} time={elapsed:.1f}s')
     del model; torch.cuda.empty_cache()
-    return bpb, elapsed
+    return bpb
 
 log('')
 log('=' * 60)
-log('EVAL SWEEP: Temperature + Stride + TTT configs')
+log('TTT SWEEP on PR #1060 (Full GPTQ + XSA-all) quantized model')
 log('=' * 60)
 results = {}
 
-# --- Block 1: Temperature sweep at stride=64 (~6 x 74s = 444s) ---
-log('')
-log('--- Temperature sweep (stride=64, no TTT) ---')
-for t in [0.85, 0.88, 0.90, 0.92, 0.95, 1.00]:
-    bpb, _ = run_sliding(f'temp_{t:.2f}', temp=t, stride=64)
-    results[f'sliding_T{t:.2f}_s64'] = bpb
-
-# Find best temperature
-best_t = min([0.85, 0.88, 0.90, 0.92, 0.95, 1.00],
-             key=lambda t: results[f'sliding_T{t:.2f}_s64'])
-best_bpb = results[f'sliding_T{best_t:.2f}_s64']
-baseline_bpb = results['sliding_T1.00_s64']
-temp_delta = best_bpb - baseline_bpb
-log(f' >>> Best T={best_t:.2f} bpb={best_bpb:.8f} delta={temp_delta:+.8f} <<<')
+results['sota_ttt'] = run_ttt('sota_ttt', 0.002, 3)
+results['pr1039'] = run_ttt('pr1039', 0.0025, 4)
+results['reset100'] = run_ttt('reset100', 0.0025, 4, reset_every=100)
+results['reset50'] = run_ttt('reset50', 0.0025, 4, reset_every=50)
 
-# --- Block 2: TTT sweep with best temperature (~3 x 410s = 1230s) ---
-log('')
-log(f'--- TTT sweep (T={best_t:.2f}, stride=64) ---')
-
-# Config A: SOTA TTT (baseline)
-bpb_a, t_a = run_ttt('ttt_sota', temp=best_t, ttt_lr=0.002, ttt_epochs=3, freeze_blocks=0)
-results['ttt_sota'] = bpb_a
-
-# Config B: PR #1039 recipe (claimed 1.1184 BPB — potential record)
-bpb_b, t_b = run_ttt('ttt_pr1039', temp=best_t, ttt_lr=0.0025, ttt_epochs=4, freeze_blocks=0)
-results['ttt_pr1039'] = bpb_b
-
-# Config C: More epochs (deeper adaptation)
-bpb_c, t_c = run_ttt('ttt_5ep', temp=best_t, ttt_lr=0.002, ttt_epochs=5, freeze_blocks=0)
-results['ttt_5ep'] = bpb_c
-
-# Config D: PR #1039 + periodic reset every 100 chunks (anti-drift)
-bpb_d, t_d = run_ttt('ttt_reset100', temp=best_t, ttt_lr=0.0025, ttt_epochs=4, freeze_blocks=0, reset_every=100)
-results['ttt_reset100'] = bpb_d
-
-# Config E: Reset every 50 chunks (more aggressive anti-drift)
-bpb_e, t_e = run_ttt('ttt_reset50', temp=best_t, ttt_lr=0.0025, ttt_epochs=4, freeze_blocks=0, reset_every=50)
-results['ttt_reset50'] = bpb_e
-
-# --- Summary ---
 log('')
 log('=' * 60)
-log('RESULTS SUMMARY')
+log('RESULTS')
 log('=' * 60)
-
-log('')
-log('Temperature sweep (stride=64, no TTT):')
-for t in [0.85, 0.88, 0.90, 0.92, 0.95, 1.00]:
-    bpb = results[f'sliding_T{t:.2f}_s64']
-    delta = bpb - baseline_bpb
-    marker = ' <<<' if t == best_t else ''
-    log(f' T={t:.2f} bpb={bpb:.8f} delta={delta:+.8f}{marker}')
-
+log(f'PR #1060 no-TTT baseline: 1.1122 (their submission)')
 log('')
-log('TTT configurations:')
-log(f' SOTA (lr=0.002, 3ep): bpb={bpb_a:.8f}')
-log(f' PR1039 (lr=0.0025, 4ep): bpb={bpb_b:.8f} delta={bpb_b-bpb_a:+.8f}')
-log(f' 5 epochs (lr=0.002, 5ep): bpb={bpb_c:.8f} delta={bpb_c-bpb_a:+.8f}')
-log(f' PR1039 + reset/100 (anti-drift): bpb={bpb_d:.8f} delta={bpb_d-bpb_a:+.8f}')
-log(f' PR1039 + reset/50 (anti-drift): bpb={bpb_e:.8f} delta={bpb_e-bpb_a:+.8f}')
-
-log('')
-best_ttt = min(bpb_a, bpb_b, bpb_c, bpb_d, bpb_e)
-log(f'SOTA reference (seed 1337): 1.11922988')
-log(f'Our best result: {best_ttt:.8f}')
-log(f'Delta vs SOTA: {best_ttt - 1.11922988:+.8f}')
-log(f'Record threshold: 1.1144')
-log(f'Gap to record: {best_ttt - 1.1144:+.8f}')
-
-if best_ttt < 1.1144:
-    log('>>> RECORD TERRITORY! Run 2 more seeds to confirm. <<<')
-elif best_ttt < 1.1192:
-    log('>>> Better than SOTA seed, but need 3-seed mean. Worth 2 more seeds. <<<')
-elif best_ttt < 1.1200:
-    log('>>> Close to SOTA. Temperature or TTT tuning helping marginally. <<<')
-else:
-    log('>>> At or worse than SOTA. No path to record from these changes. <<<')
+for label, bpb in sorted(results.items(), key=lambda x: x[1]):
    log(f' {label:<20} bpb={bpb:.8f} vs_noTTT={bpb-1.1122:+.6f} vs_merged_SOTA={bpb-1.1194:+.6f}')
+best = min(results, key=results.get)
+log(f'\nBest: {best} = {results[best]:.8f}')
+log(f'Record threshold (vs PR #1060): <= 1.1072')
+log(f'Gap: {results[best] - 1.1072:+.8f}')
 
 if world_size > 1:
     dist.destroy_process_group()
 PYEOF
 
-torchrun --standalone --nproc_per_node=$NPROC /tmp/eval_sweep.py
+torchrun --standalone --nproc_per_node=$NPROC /tmp/ttt_sweep.py
 
 echo ""
 echo "============================================================"
