@@ -1382,6 +1382,82 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
13821382 f"eval_time:{ 1000.0 * (time .perf_counter () - t_qeval ):.0f} ms"
13831383 )
13841384 log0 (f"final_int6_roundtrip_exact val_loss:{ q_val_loss :.8f} val_bpb:{ q_val_bpb :.8f} " )
1385+
1386+ # cosine pre-eval TTT (from PR #481/#486 — 30 epochs AdamW with cosine LR + per-layer LR)
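+# TTT = test-time training: briefly fine-tune eval_model on the validation tokens
+# themselves before the subsequent eval; set TTT_EPOCHS=0 to skip this block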
+ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30))
+ttt_lr = float(os.environ.get("TTT_LR", 0.0005))
+if ttt_epochs > 0:
+    torch.cuda.synchronize()
+    t_ttt = time.perf_counter()
+    log0(f"ttt: starting {ttt_epochs} epochs, lr={ttt_lr}, cosine+perlayer")
+    # per-layer LR groups: 3x for MLP output projections, 0.5x for MLP input
+    proj_params, fc_params, other_params = [], [], []
+    for name, p in eval_model.named_parameters():
+        p.requires_grad_(True)
+        if "mlp.proj" in name:
+            proj_params.append(p)
+        elif "mlp.fc" in name:
+            fc_params.append(p)
+        else:
+            other_params.append(p)
+    ttt_opt = torch.optim.AdamW([
+        {"params": proj_params, "lr": ttt_lr * 3.0},
+        {"params": fc_params, "lr": ttt_lr * 0.5},
+        {"params": other_params, "lr": ttt_lr},
+    ], weight_decay=0.0)
+    total_val = val_tokens.numel() - 1
+    ttt_batch = 32
+    rank_tokens = total_val // world_size
+    rank_start = rank * rank_tokens
+    rank_end = rank_start + rank_tokens
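+    # each rank fine-tunes on its own contiguous shard of the val tokens; when running
+    # distributed, gradients are averaged across ranks each step (all_reduce below)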
+    steps_per_epoch = max(1, (rank_end - rank_start - args.train_seq_len) // (ttt_batch * args.train_seq_len))
+    total_steps = ttt_epochs * steps_per_epoch
+    global_step = 0
+    eval_model.train()
+    for ep in range(ttt_epochs):
+        ep_loss, ep_steps = 0.0, 0
+        for bs in range(rank_start, rank_end - args.train_seq_len, ttt_batch * args.train_seq_len):
+            be = min(bs + ttt_batch * args.train_seq_len + 1, rank_end + 1)
+            local = val_tokens[bs:be].to(device=device, dtype=torch.int64)
+            n = (local.numel() - 1) // args.train_seq_len
+            if n == 0:
+                continue
+            x = local[:n * args.train_seq_len].reshape(n, args.train_seq_len)
+            y = local[1:n * args.train_seq_len + 1].reshape(n, args.train_seq_len)
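+            # x/y are next-token prediction pairs: targets are the inputs shifted by one token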
+            # cosine LR schedule: decay each group's base LR toward 0 over total_steps
+            if global_step == 0:
+                for g in ttt_opt.param_groups:
+                    g["initial_lr"] = g["lr"]
+            progress = global_step / max(total_steps, 1)
+            cos_mul = 0.5 * (1.0 + math.cos(math.pi * progress))
+            for g in ttt_opt.param_groups:
+                g["lr"] = g["initial_lr"] * cos_mul
+            ttt_opt.zero_grad()
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = eval_model(x, y)
+            loss.backward()
+            # sync gradients across ranks
+            if distributed:
+                for p in eval_model.parameters():
+                    if p.grad is not None:
+                        dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
+            torch.nn.utils.clip_grad_norm_(eval_model.parameters(), 1.0)
+            ttt_opt.step()
+            ep_loss += loss.item()
+            ep_steps += 1
+            global_step += 1
+        if master_process and (ep + 1) % 5 == 0:
+            log0(f"ttt_epoch:{ep + 1}/{ttt_epochs} avg_loss:{ep_loss / max(ep_steps, 1):.4f}")
+    del ttt_opt
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    log0(f"ttt: completed in {1000.0 * (time.perf_counter() - t_ttt):.0f}ms")
+
 sw_seq_len = effective_eval_seq_len
 if args.eval_stride > 0 and args.eval_stride < sw_seq_len:
     torch.cuda.synchronize()