 from torch import Tensor, nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
+    _HAS_FA3 = True
+except ImportError:
+    _HAS_FA3 = False
+
 # -----------------------------
 # HYPERPARAMETERS
 # -----------------------------
@@ -58,7 +64,7 @@ class Hyperparameters:
     qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
 
     vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
-    num_layers = int(os.environ.get("NUM_LAYERS", 10))
+    num_layers = int(os.environ.get("NUM_LAYERS", 11))
     num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
     model_dim = int(os.environ.get("MODEL_DIM", 512))
    num_heads = int(os.environ.get("NUM_HEADS", 8))
@@ -86,13 +92,33 @@ class Hyperparameters:
     eval_stride = int(os.environ.get("EVAL_STRIDE", 32))
     eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32))
 
-    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 12288))
+    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240))
     bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
 
-    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
+    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0")))  # disabled, using EMA instead
     swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))
     swa_every = int(os.environ.get("SWA_EVERY", 25))
 
+    # EMA (replaces SWA for #315-style training)
+    ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1")))
+    ema_decay = float(os.environ.get("EMA_DECAY", 0.997))
+
+    # TTT: Test-Time Training on validation data after quantization
+    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1")))
+    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 25))
+    ttt_lr = float(os.environ.get("TTT_LR", 0.008))
+    ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
+    ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
+
+    # XSA: Exclusive Self-Attention on the last N layers
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))
+
+    # Partial RoPE: only apply RoPE to the first N dims of each head
+    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
+
+    # LN scale: scale norm output by 1/sqrt(layer+1)
+    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
+
 # -----------------------------
 # MUON OPTIMIZER
 # -----------------------------
@@ -503,9 +529,10 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
 
 
 class Rotary(nn.Module):
-    def __init__(self, dim: int, base: float = 10000.0):
+    def __init__(self, dim: int, base: float = 10000.0, rope_dims: int = 0):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.rope_dims = rope_dims if rope_dims > 0 else dim
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self._seq_len_cached = 0
         self._cos_cached: Tensor | None = None
@@ -527,13 +554,20 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup
 
 
 def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    rd = cos.size(-1) * 2  # number of RoPE dims
+    if rd < x.size(-1):
+        x_rope, x_pass = x[..., :rd], x[..., rd:]
+        half = rd // 2
+        x1, x2 = x_rope[..., :half], x_rope[..., half:]
+        x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+        return torch.cat((x_rot, x_pass), dim=-1)
     half = x.size(-1) // 2
     x1, x2 = x[..., :half], x[..., half:]
     return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
 
 
 class CausalSelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float):
+    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_dims: int = 0, use_xsa: bool = False):
         super().__init__()
         if dim % num_heads != 0:
             raise ValueError("model_dim must be divisible by num_heads")
@@ -542,6 +576,7 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float
         self.num_heads = num_heads
         self.num_kv_heads = num_kv_heads
         self.head_dim = dim // num_heads
+        self.use_xsa = use_xsa
         if self.head_dim % 2 != 0:
             raise ValueError("head_dim must be even for RoPE")
         kv_dim = self.num_kv_heads * self.head_dim
@@ -551,24 +586,54 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float
         self.proj = CastedLinear(dim, dim, bias=False)
         self.proj._zero_init = True
         self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rotary = Rotary(self.head_dim, base=rope_base)
+        self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims)
+
+    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
+        """Remove the self-value component from the attention output via orthogonal projection."""
+        # y: (B, H, T, D), v: (B, Hkv, T, D)
+        B, H, T, D = y.shape
+        Hkv = v.size(1)
+        group = H // Hkv
+        y_g = y.reshape(B, Hkv, group, T, D)
+        vn = F.normalize(v, dim=-1).unsqueeze(2)  # (B, Hkv, 1, T, D)
+        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
+        return (y_g - proj).reshape(B, H, T, D)
 
     def forward(self, x: Tensor) -> Tensor:
         bsz, seqlen, dim = x.shape
-        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
-        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
-        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        # (B, T, H, D) -> (B, H, T, D) for norm/RoPE
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
         q = F.rms_norm(q, (q.size(-1),))
         k = F.rms_norm(k, (k.size(-1),))
         cos, sin = self.rotary(seqlen, x.device, q.dtype)
         q = apply_rotary_emb(q, cos, sin)
         k = apply_rotary_emb(k, cos, sin)
         q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
-        y = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=None, is_causal=True,
-            enable_gqa=(self.num_kv_heads != self.num_heads),
-        )
-        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        if _HAS_FA3:
+            # FA3 expects (B, T, H, D)
+            q_fa = q.transpose(1, 2)
+            k_fa = k.transpose(1, 2)
+            v_fa = v.transpose(1, 2)
+            y = _flash_attn_func(q_fa, k_fa, v_fa, causal=True)
+            # y is (B, T, H, D); convert to (B, H, T, D) for XSA
+            if self.use_xsa:
+                y = self._xsa_efficient(y.transpose(1, 2), v)
+                y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+            else:
+                y = y.contiguous().reshape(bsz, seqlen, dim)
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, is_causal=True,
+                enable_gqa=(self.num_kv_heads != self.num_heads),
+            )
+            if self.use_xsa:
+                y = self._xsa_efficient(y, v)
+            y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
         return self.proj(y)
 
 
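
For intuition on what use_xsa does: for every query position, the component of the attention output that lies along that position's own (normalized) value vector is projected out, so the per-position output becomes orthogonal to v. A small self-contained check of that property, mirroring the _xsa_efficient math with arbitrary shapes rather than calling the module:

import torch
import torch.nn.functional as F

B, H, Hkv, T, D = 2, 8, 4, 6, 16
y = torch.randn(B, H, T, D)     # attention output per query head and position
v = torch.randn(B, Hkv, T, D)   # values; H // Hkv query heads share each kv head

group = H // Hkv
y_g = y.reshape(B, Hkv, group, T, D)
vn = F.normalize(v, dim=-1).unsqueeze(2)           # unit value direction per position
proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn   # component of y along v
y_xsa = (y_g - proj).reshape(B, H, T, D)

# Each output row now has (numerically) zero dot product with its own value vector.
dots = (y_xsa.reshape(B, Hkv, group, T, D) * vn).sum(dim=-1)
print(dots.abs().max())  # ~0
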
@@ -625,11 +690,12 @@ def forward(self, token_ids: Tensor) -> Tensor:
 
 
 class Block(nn.Module):
-    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
+    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, layer_idx: int = 0, ln_scale: bool = False, rope_dims: int = 0, use_xsa: bool = False):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims, use_xsa=use_xsa)
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -638,9 +704,10 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float,
     def forward(self, x: Tensor, x0: Tensor) -> Tensor:
         mix = self.resid_mix.to(dtype=x.dtype)
         x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x))
+        s = self.ln_scale_factor
+        attn_out = self.attn(self.attn_norm(x) * s)
         x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
-        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s)
         return x
 
 
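
With ln_scale enabled, the normalized residual stream fed to attention and MLP is damped more at deeper layers. Just evaluating the 1/sqrt(layer_idx + 1) factor for the default 11 layers gives a feel for the schedule:

import math

factors = [1.0 / math.sqrt(i + 1) for i in range(11)]
print([f"{f:.3f}" for f in factors])
# ['1.000', '0.707', '0.577', '0.500', '0.447', '0.408', '0.378', '0.354', '0.333', '0.316', '0.302']
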
@@ -660,6 +727,9 @@ def __init__(
         qk_gain_init: float,
         bigram_vocab_size: int = 0,
         bigram_dim: int = 128,
+        xsa_last_n: int = 0,
+        rope_dims: int = 0,
+        ln_scale: bool = False,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
@@ -676,8 +746,10 @@ def __init__(
         self.smear = SmearGate(model_dim)
         self.blocks = nn.ModuleList(
             [
-                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
-                for _ in range(num_layers)
+                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
+                      layer_idx=i, ln_scale=ln_scale, rope_dims=rope_dims,
+                      use_xsa=(i >= num_layers - xsa_last_n) if xsa_last_n > 0 else False)
+                for i in range(num_layers)
             ]
         )
         self.final_norm = RMSNorm()
@@ -929,17 +1001,23 @@ def log0(msg: str, console: bool = True) -> None:
     qk_gain_init=args.qk_gain_init,
     bigram_vocab_size=args.bigram_vocab_size,
     bigram_dim=args.bigram_dim,
+    xsa_last_n=args.xsa_last_n,
+    rope_dims=args.rope_dims,
+    ln_scale=args.ln_scale,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):
         module.float()
 restore_low_dim_params_to_fp32(base_model)
 # QAT: fake-quantize during training so weights learn to be quantization-friendly
+# Late-stage QAT: only enable near the end of training to avoid hurting convergence
 qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "1")))
+qat_start_frac = float(os.environ.get("QAT_START_FRAC", "0.96"))  # enable QAT after this fraction of steps (final 4% by default)
+qat_activated = False
 if qat_enabled:
     for name, module in base_model.named_modules():
         if isinstance(module, CastedLinear):
-            module._qat = True
+            module._qat = False  # start with QAT disabled; flipped on late in training
             module._qat_int5 = ".mlp." in name
 compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
 model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
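
The _qat / _qat_int5 flags gate a fake-quantization path inside CastedLinear that is defined elsewhere in this file and not shown in this hunk. The general pattern behind late-stage QAT is to round-trip the weights through the low-precision grid in the forward pass while letting gradients flow straight through; a generic sketch of that idea (symmetric per-tensor quantization with a straight-through estimator, not the actual CastedLinear code):

import torch

def fake_quantize(w: torch.Tensor, bits: int = 6) -> torch.Tensor:
    """Round-trip w through a symmetric int grid; backward sees the identity."""
    qmax = 2 ** (bits - 1) - 1                       # 31 for int6, 15 for int5
    scale = w.abs().amax().clamp_min(1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax, qmax) * scale
    return w + (w_q - w).detach()                    # straight-through estimator

w = torch.randn(64, 64, requires_grad=True)
fake_quantize(w, bits=6).sum().backward()
print(torch.allclose(w.grad, torch.ones_like(w)))    # True: gradient passes straight through
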
@@ -1060,6 +1138,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 model.require_backward_grad_sync = True
 train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
 
+# EMA state
+ema_state: dict[str, Tensor] | None = None
+if args.ema_enabled:
+    ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
+
 # MAIN TRAINING LOOP
 training_time_ms = 0.0
 stop_after_step: int | None = None
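
The EMA kept here is refreshed every training step further down as ema = decay * ema + (1 - decay) * theta over every state-dict tensor. For a feel of the 0.997 default: the effective averaging window is roughly 1 / (1 - decay) steps, and the weight of a parameter value seen k steps ago decays geometrically. A tiny arithmetic sketch:

decay = 0.997
print(f"effective window ~ {1.0 / (1.0 - decay):.0f} steps")   # ~333
for k in (50, 333, 1000):
    # relative weight carried by the value observed k steps ago
    print(k, (1.0 - decay) * decay ** k)
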
@@ -1127,6 +1210,26 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     step += 1
     approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
 
+    # Late-stage QAT: enable fake-quantize after qat_start_frac of training
+    if qat_enabled and not qat_activated:
+        # Estimate progress: use wallclock fraction if available, else step fraction
+        if max_wallclock_ms is not None:
+            progress = approx_training_time_ms / max_wallclock_ms
+        else:
+            progress = step / args.iterations
+        if progress >= qat_start_frac:
+            for module in base_model.modules():
+                if isinstance(module, CastedLinear):
+                    module._qat = True
+            qat_activated = True
+            log0(f"qat:enabled at step:{step} progress:{progress:.3f}")
+
+    # EMA: update exponential moving average every step
+    if args.ema_enabled and ema_state is not None:
+        decay = args.ema_decay
+        for name, t in base_model.state_dict().items():
+            ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1.0 - decay)
+
     # SWA: collect checkpoints during warmdown
     if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0:
         if swa_state is None:
@@ -1161,6 +1264,16 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
 )
 
+# Apply EMA if enabled
+if args.ema_enabled and ema_state is not None:
+    log0(f"ema:applying decay={args.ema_decay}")
+    current_state = base_model.state_dict()
+    ema_applied = {
+        name: tensor.to(dtype=current_state[name].dtype)
+        for name, tensor in ema_state.items()
+    }
+    base_model.load_state_dict(ema_applied, strict=True)
+
 # Apply SWA if collected
 if args.swa_enabled and swa_state is not None and swa_count > 1:
     log0(f"swa:applying averaged {swa_count} checkpoints")
@@ -1218,6 +1331,40 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
     base_model.load_state_dict(deq_state, strict=True)
 
+# TTT: Test-Time Training -- adapt the quantized model on val data before the final eval
+if args.ttt_enabled:
+    torch.cuda.synchronize()
+    t_ttt = time.perf_counter()
+    log0(f"ttt:start epochs:{args.ttt_epochs} lr:{args.ttt_lr} batch_seqs:{args.ttt_batch_seqs}")
+    base_model.train()
+    ttt_optimizer = torch.optim.SGD(
+        base_model.parameters(), lr=args.ttt_lr, momentum=args.ttt_momentum,
+    )
+    seq_len = args.train_seq_len
+    n_val = val_tokens.numel() - 1
+    n_seqs = n_val // seq_len
+    for ttt_ep in range(args.ttt_epochs):
+        perm = torch.randperm(n_seqs)
+        ttt_loss_sum = 0.0
+        ttt_loss_count = 0
+        for batch_start in range(0, n_seqs, args.ttt_batch_seqs):
+            batch_end = min(batch_start + args.ttt_batch_seqs, n_seqs)
+            indices = perm[batch_start:batch_end]
+            batch_x = torch.stack([val_tokens[i * seq_len : i * seq_len + seq_len] for i in indices]).to(device=device, dtype=torch.int64)
+            batch_y = torch.stack([val_tokens[i * seq_len + 1 : i * seq_len + seq_len + 1] for i in indices]).to(device=device, dtype=torch.int64)
+            ttt_optimizer.zero_grad(set_to_none=True)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = base_model(batch_x, batch_y)
+            loss.backward()
+            ttt_optimizer.step()
+            ttt_loss_sum += loss.item() * (batch_end - batch_start)
+            ttt_loss_count += batch_end - batch_start
+        if ttt_ep == 0 or (ttt_ep + 1) % 5 == 0 or ttt_ep == args.ttt_epochs - 1:
+            log0(f"ttt:epoch:{ttt_ep + 1}/{args.ttt_epochs} loss:{ttt_loss_sum / max(ttt_loss_count, 1):.4f}")
+    base_model.eval()
+    torch.cuda.synchronize()
+    log0(f"ttt:done time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms")
+
 # Sliding window eval on int6-roundtripped weights
 torch.cuda.synchronize()
 t_qeval = time.perf_counter()