 """
-11L EMA + AdamW TTT + TrigramHash + Value Residual + Gradient-Guided Quantization
-
-Built on PR #398/#442 baseline with three novel additions:
-1. TrigramHash(4096): hash-based trigram embeddings extending BigramHash to 3-token context
-2. Value Residual (ResFormer, arXiv:2410.17897): cache V from layer 0, blend into all layers
-3. Gradient-Guided Quantization: adaptive Int5/6/7 per-tensor based on gradient sensitivity
-
-Mean val_bpb: 1.1132 (3 seeds), best: 1.1101
+train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE +
+fp16 embed + late-K passthrough + sliding window eval.
 """

 from __future__ import annotations
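
The value-residual feature kept in the config below follows ResFormer (arXiv:2410.17897), which the removed docstring summarized as "cache V from layer 0, blend into all layers." As a rough sketch of that idea, assuming a learnable per-layer blend weight (the class and attribute names here are illustrative, not the identifiers used in train_gpt_submit.py):

```python
import torch
import torch.nn as nn

class ValueResidualAttention(nn.Module):
    """Toy value projection with a ResFormer-style residual from layer 0."""

    def __init__(self, dim: int, layer_idx: int):
        super().__init__()
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.layer_idx = layer_idx
        # one learnable blend weight per layer; sigmoid keeps it in (0, 1)
        self.lambda_v = nn.Parameter(torch.zeros(1))

    def forward(self, x: torch.Tensor, v_first: torch.Tensor | None):
        v = self.v_proj(x)
        if self.layer_idx == 0:
            v_first = v  # layer 0 caches its values for every later layer
        else:
            alpha = torch.sigmoid(self.lambda_v)
            v = alpha * v + (1.0 - alpha) * v_first
        return v, v_first
```

A block stack would thread v_first through the forward pass, so the overhead is one cached tensor plus one scalar blend parameter per layer.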
@@ -117,18 +111,18 @@ class Hyperparameters:
     # TTT (Test-Time Training)
     ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1")))
     ttt_lr = float(os.environ.get("TTT_LR", 0.0005))
-    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 10))
+    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30))
     ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
     ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
     ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0))
     bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048))
     bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
-    # TrigramHash embedding (3-token context via hash table)
+    # TrigramHash (our unique addition)
     trigram_vocab_size = int(os.environ.get("TRIGRAM_VOCAB_SIZE", 4096))
     trigram_dim = int(os.environ.get("TRIGRAM_DIM", 128))
-    # Value Residual (ResFormer, arXiv:2410.17897)
+    # Value Residual (from PR #413, ResFormer — -0.015 BPB for 18 params)
     value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "1")))
-    # Gradient-Guided Adaptive Quantization
+    # Gradient-Guided Quantization (from PR #332)
     grad_quant = bool(int(os.environ.get("GRAD_QUANT", "1")))

 # -----------------------------
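
TRIGRAM_VOCAB_SIZE and TRIGRAM_DIM above parameterize the hash-based trigram embedding that the removed docstring describes as extending BigramHash to 3-token context. A minimal sketch of the general technique, hashing each (t-2, t-1, t) triple into a fixed-size table; the hash multipliers and class name are assumptions for illustration, not the script's actual code:

```python
import torch
import torch.nn as nn

class TrigramHashEmbedding(nn.Module):
    """Toy hash-based trigram embedding; defaults match the config above."""

    def __init__(self, table_size: int = 4096, dim: int = 128):
        super().__init__()
        self.table_size = table_size
        self.embed = nn.Embedding(table_size, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # tokens: (batch, seq) int64 ids; form the (t-2, t-1, t) triple per position
        t2 = torch.roll(tokens, shifts=2, dims=1)
        t1 = torch.roll(tokens, shifts=1, dims=1)
        # mix the three ids with fixed odd multipliers, then bucket into the table
        h = (t2 * 1000003 + t1 * 10007 + tokens) % self.table_size
        out = self.embed(h)
        # the first two positions have no full trigram context; zero them out
        valid = torch.arange(tokens.size(1), device=tokens.device) >= 2
        return out * valid[None, :, None]
```

The result would typically be added to the ordinary token embeddings at the model input; collisions in the 4096-entry table simply share an embedding slot.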
@@ -1199,10 +1193,12 @@ def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
 # -----------------------------

 def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn=None):
-    """Full-weight TTT: SGD adaptation on val data with DDP across all GPUs."""
+    """Full-weight TTT with cosine LR decay and per-layer LR (from PR #481)."""
     seq_len = args.train_seq_len
     total_seqs = (val_tokens.numel() - 1) // seq_len
     batch_seqs = args.ttt_batch_seqs
+    ttt_cosine = bool(int(os.environ.get("TTT_COSINE", "1")))
+    ttt_perlayer = bool(int(os.environ.get("TTT_PERLAYER", "1")))

     frozen_params = set()
     if args.ttt_freeze_blocks > 0:
@@ -1212,24 +1208,62 @@ def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn
             p.requires_grad_(False)
             frozen_params.add(id(p))

-    ttt_params = [p for p in base_model.parameters() if p.requires_grad]
+    # Per-layer LR: MLP output projections get 3× LR (most quant damage),
+    # MLP input projections get 0.5× LR (least damage)
     ttt_use_adamw = bool(int(os.environ.get("TTT_ADAMW", "1")))
+    if ttt_perlayer:
+        proj_params = [p for n, p in base_model.named_parameters()
+                       if "mlp.proj" in n and p.requires_grad and id(p) not in frozen_params]
+        fc_params = [p for n, p in base_model.named_parameters()
+                     if "mlp.fc" in n and p.requires_grad and id(p) not in frozen_params]
+        other_params = [p for p in base_model.parameters()
+                        if p.requires_grad and id(p) not in frozen_params
+                        and id(p) not in {id(q) for q in proj_params + fc_params}]
+        param_groups = [g for g in [
+            {"params": proj_params, "lr": args.ttt_lr * 3.0},
+            {"params": fc_params, "lr": args.ttt_lr * 0.5},
+            {"params": other_params, "lr": args.ttt_lr},
+        ] if g["params"]]
+        ttt_params = proj_params + fc_params + other_params
+    else:
+        ttt_params = [p for p in base_model.parameters() if p.requires_grad and id(p) not in frozen_params]
+        param_groups = [{"params": ttt_params, "lr": args.ttt_lr}]
+
     if ttt_use_adamw:
-        optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0)
+        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
     else:
-        optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum)
+        optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum)
+
+    # Store initial LR for cosine schedule
+    if ttt_cosine:
+        for g in optimizer.param_groups:
+            g["initial_lr"] = g["lr"]

     my_start = (total_seqs * rank) // world_size
     my_end = (total_seqs * (rank + 1)) // world_size
+    steps_per_epoch = (my_end - my_start) // max(batch_seqs, 1)
+    total_steps = args.ttt_epochs * steps_per_epoch
+    global_step = 0

     base_model.train()
     t0 = time.perf_counter()

+    if log_fn:
+        n_ttt = sum(p.numel() for p in ttt_params)
+        log_fn(f"ttt:config params:{n_ttt} cosine:{ttt_cosine} perlayer:{ttt_perlayer}")
+
     for epoch in range(args.ttt_epochs):
         epoch_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
         epoch_tokens = torch.zeros((), device=device, dtype=torch.float64)

         for batch_start in range(my_start, my_end, batch_seqs):
+            # Cosine LR decay
+            if ttt_cosine and total_steps > 0:
+                progress = global_step / total_steps
+                mul = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
+                for g in optimizer.param_groups:
+                    g["lr"] = g["initial_lr"] * mul
+
             batch_end = min(batch_start + batch_seqs, my_end)
             raw_start = batch_start * seq_len
             raw_end = batch_end * seq_len + 1
@@ -1252,6 +1286,7 @@ def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn

             epoch_loss_sum += loss.detach().to(torch.float64) * y.numel()
             epoch_tokens += float(y.numel())
+            global_step += 1

         if world_size > 1:
             dist.all_reduce(epoch_loss_sum, op=dist.ReduceOp.SUM)
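
The gradient-guided quantization path (GRAD_QUANT above, dequantize_mixed_int6 in the hunk header) is not itself shown in this diff. A rough sketch of the idea from the removed docstring, picking Int5/6/7 per tensor from gradient sensitivity; the thresholds and function names here are illustrative assumptions, not the script's implementation:

```python
import torch

def pick_bits(grad_norm: float, lo: float = 1e-4, hi: float = 1e-2) -> int:
    # tensors whose loss gradient is larger are more sensitive and keep more bits
    if grad_norm >= hi:
        return 7
    if grad_norm >= lo:
        return 6
    return 5

def fake_quantize(w: torch.Tensor, bits: int) -> torch.Tensor:
    # symmetric per-tensor quantization to the chosen bit-width, then dequantize
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max().clamp(min=1e-8) / qmax
    return torch.round(w / scale).clamp(-qmax - 1, qmax) * scale

def quantize_by_gradient(named_params, grad_norms: dict[str, float]) -> dict[str, torch.Tensor]:
    # grad_norms: per-tensor gradient norms collected during or after training
    return {name: fake_quantize(p.detach(), pick_bits(grad_norms.get(name, 0.0)))
            for name, p in named_params}
```

TTT adaptation as implemented above would then run on the dequantized weights to recover part of the quantization loss, which is why the MLP projections (the most damaged tensors per the per-layer-LR comment) get the largest learning rate.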