#!/bin/bash
- # Full H100 experiment: train once, eval many configs
- # Maximizes information per dollar on rented 8xH100
+ # Full H100 experiment: PR #1060 base (1.1122) + TTT with reset
#
- # Budget: ~1 hour total (~$25-30)
- # - Setup: 10 min
- # - Training: 10 min (identical to SOTA)
- # - Eval sweep: 40 min (6-8 configs on same model)
+ # What we test (train once, eval many):
+ # 1. PR #1060 baseline reproduction (no TTT): ~87s eval
+ # 2. TTT with SOTA config (lr=0.002, 3ep): ~410s
+ # 3. TTT with PR #1039 config (lr=0.0025, 4ep): ~410s
+ # 4. TTT with periodic reset every 100 chunks: ~410s
+ # 5. TTT with periodic reset every 50 chunks: ~410s
+ #
+ # Total: ~10 min train + 87s baseline + 4×410s TTT ≈ 39 min
+ # Cost: ~$15-20 on 8xH100
#
# Usage:
# git clone https://github.com/icryo/parameter-golf.git
# cd parameter-golf && git checkout experiments/triton-kernels-qat-fix
# pip install sentencepiece huggingface-hub datasets flash-attn
# python3 data/cached_challenge_fineweb.py --variant sp1024
# ./run_h100.sh 2>&1 | tee full_experiment.log
-
set -euo pipefail
SEED="${1:-1337}"
NPROC=8

echo "============================================================"
- echo "PARAMETER GOLF: Full H100 Experiment Suite"
+ echo "PARAMETER GOLF: PR #1060 base + TTT reset experiments"
echo "Seed: $SEED | GPUs: $NPROC | $(date)"
echo "============================================================"

- # === PHASE 1: Train (identical to SOTA, ~10 min) ===
+ # === PHASE 1: Train with PR #1060 config ===
echo ""
- echo "=== PHASE 1: Training (SOTA reproduction) ==="
- export RUN_ID="h100_s${SEED}"
+ echo "=== PHASE 1: Training (PR #1060: coprime loader + Full GPTQ + XSA-all) ==="
+
+ export RUN_ID="pr1060_ttt_s${SEED}"
export SEED="$SEED"
export DATA_PATH="./data/datasets/fineweb10B_sp1024"
export TOKENIZER_PATH="./data/tokenizers/fineweb_1024_bpe.model"
@@ -35,35 +39,33 @@ export ITERATIONS=9000 WARMUP_STEPS=20 WARMDOWN_ITERS=3500
export TRAIN_BATCH_TOKENS=786432 TRAIN_SEQ_LEN=2048
export EVAL_SEQ_LEN=2048 EVAL_STRIDE=64
export NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3
- export TIE_EMBEDDINGS=1 XSA_LAST_N=4 ROPE_DIMS=16 LN_SCALE=1
+ export TIE_EMBEDDINGS=1 ROPE_DIMS=16 LN_SCALE=1
export VE_ENABLED=1 VE_DIM=128 VE_LAYERS="9,10"
- export BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 LOGIT_SOFTCAP=30.0
+ export LOGIT_SOFTCAP=30.0
export MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035
export MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500
export MUON_WD=0.04 ADAM_WD=0.04 GRAD_CLIP_NORM=0.3
export SWA_ENABLED=1 SWA_EVERY=50
export LATE_QAT_THRESHOLD=0.15
- export EVAL_TEMPERATURE=0
- # Disable TTT in training script; we'll run it manually below
+ export XSA_LAST_N=11
+ export BIGRAM_VOCAB_SIZE=2816
+ export BIGRAM_DIM=112
+ export USE_GPTQ=1
+ export GPTQ_RESERVE_MS=14000
export TTT_ENABLED=0

- torchrun --standalone --nproc_per_node=$NPROC train_gpt_safe.py
+ torchrun --standalone --nproc_per_node=$NPROC train_gpt_ours.py

echo ""
- echo "=== Training complete. Model saved to final_model.int6.ptz ==="
- echo "=== Now running eval sweep on saved checkpoint ==="
-
- # === PHASE 2: Eval sweep (multiple configs, ~40 min) ===
- # Each eval_val_sliding takes ~74s, each TTT takes ~410s
- # Uses all 8 GPUs for fast eval via torchrun
+ echo "=== Training complete. Now running TTT sweep. ==="

- cat > /tmp/eval_sweep.py << 'PYEOF'
+ # === PHASE 2: TTT sweep on saved checkpoint ===
+ cat > /tmp/ttt_sweep.py << 'PYEOF'
import torch, time, math, io, lzma, sys, os
import torch.distributed as dist
-
sys.path.insert(0, '.')
- from train_gpt_safe import (
-     Hyperparameters, GPT, CastedLinear, eval_val_sliding, eval_val_sliding_ttt,
+ from train_gpt_ours import (
+     Hyperparameters, GPT, CastedLinear, eval_val_sliding_ttt,
    dequantize_mixed_int6, _rebank_state_dict, _unbank_state_dict,
    build_sentencepiece_luts, load_validation_tokens, restore_low_dim_params_to_fp32,
)
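
# A rough sketch (not part of this diff) of how `args` presumably picks up the
# variables exported above: the hunk below elides its construction, but an
# env-driven Hyperparameters would look roughly like this. The field names and
# defaults here are hypothetical stand-ins for train_gpt_ours.Hyperparameters.
import os
from dataclasses import dataclass

@dataclass
class EnvHyperparams:  # hypothetical illustration only
    xsa_last_n: int = int(os.environ.get('XSA_LAST_N', '11'))
    bigram_vocab_size: int = int(os.environ.get('BIGRAM_VOCAB_SIZE', '2816'))
    bigram_dim: int = int(os.environ.get('BIGRAM_DIM', '112'))
    ttt_enabled: bool = os.environ.get('TTT_ENABLED', '0') == '1'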
@@ -82,27 +84,26 @@ master = (rank == 0)
def log(msg):
    if master: print(msg, flush=True)

- # Load tokenizer and validation data
sp = spm.SentencePieceProcessor(model_file=args.tokenizer_path)
val_tokens = load_validation_tokens(args.val_files, args.train_seq_len)
base_bytes_lut, has_leading_space_lut, is_boundary_token_lut = build_sentencepiece_luts(
    sp, args.vocab_size, device)

- # Load quantized model
with open('final_model.int6.ptz', 'rb') as f:
    quant_blob = f.read()
quant_state = torch.load(io.BytesIO(lzma.decompress(quant_blob)), map_location='cpu')
- sd_cpu = {k: v.detach().cpu() for k, v in torch.load('final_model.pt', map_location='cpu', weights_only=True).items()}
+ sd_cpu = {k: v.detach().cpu() for k, v in
+           torch.load('final_model.pt', map_location='cpu', weights_only=True).items()}
unbanked = _unbank_state_dict(sd_cpu, args.num_layers)
deq = dequantize_mixed_int6(quant_state['w'], quant_state['m'], unbanked)
deq_banked = _rebank_state_dict(deq, args.num_layers, sd_cpu)

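# A minimal sketch of the inverse of the load path above, assuming
# final_model.int6.ptz is simply lzma-compressed torch.save output with packed
# weights under 'w' and quant metadata under 'm' (the two keys the loader
# reads); `packed_w` and `meta` are whatever the quantizer produces.
def save_int6_blob(path, packed_w, meta):
    buf = io.BytesIO()
    torch.save({'w': packed_w, 'm': meta}, buf)
    with open(path, 'wb') as f:
        f.write(lzma.compress(buf.getvalue()))
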
- def load_eval_model(softcap_scale=1.0):
+ def load_fresh():
    m = GPT(
        vocab_size=args.vocab_size, num_layers=args.num_layers, model_dim=args.model_dim,
        num_heads=args.num_heads, num_kv_heads=args.num_kv_heads, mlp_mult=args.mlp_mult,
        tie_embeddings=args.tie_embeddings, tied_embed_init_std=args.tied_embed_init_std,
-         logit_softcap=args.logit_softcap * softcap_scale, rope_base=args.rope_base,
+         logit_softcap=args.logit_softcap, rope_base=args.rope_base,
        qk_gain_init=args.qk_gain_init, mtp_num_heads=0, mtp_loss_weight=0.0,
        bigram_vocab_size=args.bigram_vocab_size, bigram_dim=args.bigram_dim,
        xsa_last_n=args.xsa_last_n, rope_dims=args.rope_dims, ln_scale=args.ln_scale,
@@ -120,25 +121,9 @@ def load_eval_model(softcap_scale=1.0):
    m.load_state_dict(deq_banked, strict=True)
    return m

- def run_sliding(label, temp=1.0, stride=64):
-     log(f' [{label}] T={temp:.2f} stride={stride}...')
-     model = load_eval_model(softcap_scale=temp)
-     torch.cuda.synchronize()
-     t0 = time.perf_counter()
-     loss, bpb = eval_val_sliding(
-         args, model, rank, world_size, device, val_tokens,
-         base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
-         stride=stride, eval_seq_len=args.eval_seq_len,
-     )
-     torch.cuda.synchronize()
-     elapsed = time.perf_counter() - t0
-     log(f' [{label}] val_bpb={bpb:.8f} time={elapsed:.1f}s')
-     del model; torch.cuda.empty_cache()
-     return bpb, elapsed
-
- def run_ttt(label, temp=1.0, ttt_lr=0.002, ttt_epochs=3, freeze_blocks=0, stride=64, reset_every=0):
-     log(f' [{label}] T={temp:.2f} lr={ttt_lr} ep={ttt_epochs} freeze={freeze_blocks} reset={reset_every}...')
-     model = load_eval_model(softcap_scale=temp)
+ def run_ttt(label, ttt_lr, ttt_epochs, freeze_blocks=0, reset_every=0):
+     log(f' [{label}] lr={ttt_lr} ep={ttt_epochs} reset={reset_every}...')
+     model = load_fresh()
    args.ttt_lr = ttt_lr
    args.ttt_epochs = ttt_epochs
    args.ttt_freeze_blocks = freeze_blocks
@@ -152,103 +137,43 @@ def run_ttt(label, temp=1.0, ttt_lr=0.002, ttt_epochs=3, freeze_blocks=0, stride
    loss, bpb = eval_val_sliding_ttt(
        args, model, rank, world_size, device, val_tokens,
        base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
-         stride=stride, log0=log,
+         stride=64, log0=log,
    )
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - t0
    log(f' [{label}] val_bpb={bpb:.8f} time={elapsed:.1f}s')
    del model; torch.cuda.empty_cache()
-     return bpb, elapsed
+     return bpb

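# A minimal sketch of the reset mechanism under test, assuming
# eval_val_sliding_ttt walks the validation stream chunk by chunk, scoring
# each chunk before adapting on it, and that ttt_reset_every > 0 restores the
# base weights every N chunks to cap drift. The chunk helpers are hypothetical
# stand-ins for the real implementation.
def ttt_with_reset(model, chunks, reset_every, score_chunk, adapt_chunk):
    base_sd = {k: v.clone() for k, v in model.state_dict().items()}  # snapshot
    total = 0.0
    for i, chunk in enumerate(chunks):
        if reset_every and i > 0 and i % reset_every == 0:
            model.load_state_dict(base_sd)   # periodic anti-drift reset
        total += score_chunk(model, chunk)   # evaluate before adapting
        adapt_chunk(model, chunk)            # test-time gradient step(s)
    return total / max(len(chunks), 1)
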
log('')
log('=' * 60)
- log('EVAL SWEEP: Temperature + Stride + TTT configs')
+ log('TTT SWEEP on PR #1060 (Full GPTQ + XSA-all) quantized model')
log('=' * 60)
results = {}

- # --- Block 1: Temperature sweep at stride=64 (~6 x 74s = 444s) ---
- log('')
- log('--- Temperature sweep (stride=64, no TTT) ---')
- for t in [0.85, 0.88, 0.90, 0.92, 0.95, 1.00]:
-     bpb, _ = run_sliding(f'temp_{t:.2f}', temp=t, stride=64)
-     results[f'sliding_T{t:.2f}_s64'] = bpb
-
- # Find best temperature
- best_t = min([0.85, 0.88, 0.90, 0.92, 0.95, 1.00],
-              key=lambda t: results[f'sliding_T{t:.2f}_s64'])
- best_bpb = results[f'sliding_T{best_t:.2f}_s64']
- baseline_bpb = results['sliding_T1.00_s64']
- temp_delta = best_bpb - baseline_bpb
- log(f' >>> Best T={best_t:.2f} bpb={best_bpb:.8f} delta={temp_delta:+.8f} <<<')
+ results['sota_ttt'] = run_ttt('sota_ttt', 0.002, 3)
+ results['pr1039'] = run_ttt('pr1039', 0.0025, 4)
+ results['reset100'] = run_ttt('reset100', 0.0025, 4, reset_every=100)
+ results['reset50'] = run_ttt('reset50', 0.0025, 4, reset_every=50)

- # --- Block 2: TTT sweep with best temperature (~3 x 410s = 1230s) ---
- log('')
- log(f'--- TTT sweep (T={best_t:.2f}, stride=64) ---')
-
- # Config A: SOTA TTT (baseline)
- bpb_a, t_a = run_ttt('ttt_sota', temp=best_t, ttt_lr=0.002, ttt_epochs=3, freeze_blocks=0)
- results['ttt_sota'] = bpb_a
-
- # Config B: PR #1039 recipe (claimed 1.1184 BPB, potential record)
- bpb_b, t_b = run_ttt('ttt_pr1039', temp=best_t, ttt_lr=0.0025, ttt_epochs=4, freeze_blocks=0)
- results['ttt_pr1039'] = bpb_b
-
- # Config C: More epochs (deeper adaptation)
- bpb_c, t_c = run_ttt('ttt_5ep', temp=best_t, ttt_lr=0.002, ttt_epochs=5, freeze_blocks=0)
- results['ttt_5ep'] = bpb_c
-
- # Config D: PR #1039 + periodic reset every 100 chunks (anti-drift)
- bpb_d, t_d = run_ttt('ttt_reset100', temp=best_t, ttt_lr=0.0025, ttt_epochs=4, freeze_blocks=0, reset_every=100)
- results['ttt_reset100'] = bpb_d
-
- # Config E: Reset every 50 chunks (more aggressive anti-drift)
- bpb_e, t_e = run_ttt('ttt_reset50', temp=best_t, ttt_lr=0.0025, ttt_epochs=4, freeze_blocks=0, reset_every=50)
- results['ttt_reset50'] = bpb_e
-
- # --- Summary ---
log('')
log('=' * 60)
- log('RESULTS SUMMARY')
+ log('RESULTS')
log('=' * 60)
-
- log('')
- log('Temperature sweep (stride=64, no TTT):')
- for t in [0.85, 0.88, 0.90, 0.92, 0.95, 1.00]:
-     bpb = results[f'sliding_T{t:.2f}_s64']
-     delta = bpb - baseline_bpb
-     marker = ' <<<' if t == best_t else ''
-     log(f' T={t:.2f} bpb={bpb:.8f} delta={delta:+.8f}{marker}')
-
+ log('PR #1060 no-TTT baseline: 1.1122 (their submission)')
log('')
- log('TTT configurations:')
- log(f' SOTA (lr=0.002, 3ep): bpb={bpb_a:.8f}')
- log(f' PR1039 (lr=0.0025, 4ep): bpb={bpb_b:.8f} delta={bpb_b-bpb_a:+.8f}')
- log(f' 5 epochs (lr=0.002, 5ep): bpb={bpb_c:.8f} delta={bpb_c-bpb_a:+.8f}')
- log(f' PR1039 + reset/100 (anti-drift): bpb={bpb_d:.8f} delta={bpb_d-bpb_a:+.8f}')
- log(f' PR1039 + reset/50 (anti-drift): bpb={bpb_e:.8f} delta={bpb_e-bpb_a:+.8f}')
-
- log('')
- best_ttt = min(bpb_a, bpb_b, bpb_c, bpb_d, bpb_e)
- log('SOTA reference (seed 1337): 1.11922988')
- log(f'Our best result: {best_ttt:.8f}')
- log(f'Delta vs SOTA: {best_ttt - 1.11922988:+.8f}')
- log('Record threshold: 1.1144')
- log(f'Gap to record: {best_ttt - 1.1144:+.8f}')
-
- if best_ttt < 1.1144:
-     log('>>> RECORD TERRITORY! Run 2 more seeds to confirm. <<<')
- elif best_ttt < 1.1192:
-     log('>>> Better than SOTA seed, but need 3-seed mean. Worth 2 more seeds. <<<')
- elif best_ttt < 1.1200:
-     log('>>> Close to SOTA. Temperature or TTT tuning helping marginally. <<<')
- else:
-     log('>>> At or worse than SOTA. No path to record from these changes. <<<')
+ for label, bpb in sorted(results.items(), key=lambda x: x[1]):
+     log(f' {label:<20} bpb={bpb:.8f} vs_noTTT={bpb-1.1122:+.6f} vs_merged_SOTA={bpb-1.1194:+.6f}')
+ best = min(results, key=results.get)
+ log(f'\nBest: {best} = {results[best]:.8f}')
+ log('Record threshold (vs PR #1060): <= 1.1072')
+ log(f'Gap: {results[best] - 1.1072:+.8f}')
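
# The previous version of this sweep gated record claims on a 3-seed mean; a
# minimal sketch of that aggregation, collecting per-seed bests from separate
# runs of this script (the values passed in come from each seed's log):
import statistics
def three_seed_mean(seed_bpbs):
    # seed_bpbs: {seed: best_bpb}, e.g. from ./run_h100.sh 1337 / 1338 / 1339
    return statistics.mean(seed_bpbs.values())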

if world_size > 1:
    dist.destroy_process_group()
PYEOF

- torchrun --standalone --nproc_per_node=$NPROC /tmp/eval_sweep.py
+ torchrun --standalone --nproc_per_node=$NPROC /tmp/ttt_sweep.py

echo ""
echo "============================================================"