
Commit 38dff06

Add openai#315/openai#388 full stack: 11L, XSA4, Partial RoPE, LN Scale, EMA, Late QAT, TTT
Major rewrite targeting top-5 leaderboard:
- 11 layers (from 10); BigramHash reduced to 10240 to fit 16MB
- XSA (Exclusive Self-Attention) on last 4 layers
- Partial RoPE: 16/64 head dims get position encoding
- LN Scale: 1/sqrt(layer+1) dampening on deeper layers
- EMA (decay=0.997) replaces SWA
- Late QAT: STE int6 enabled only in final 4% of training
- TTT: 25-epoch SGD on val data post-quantization
- FA3 auto-detection with SDPA fallback
- Reverted SwiGLU back to relu² (confirmed worse by openai#340, openai#344)
1 parent a3b1212 commit 38dff06

1 file changed

Lines changed: 169 additions & 22 deletions

File tree

  • records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32

records/track_10min_16mb/2026-03-21_QAT_BigramHash12K_Stride32/train_gpt.py

@@ -33,6 +33,12 @@
 from torch import Tensor, nn
 from torch.nn.parallel import DistributedDataParallel as DDP
 
+try:
+    from flash_attn.flash_attn_interface import flash_attn_func as _flash_attn_func
+    _HAS_FA3 = True
+except ImportError:
+    _HAS_FA3 = False
+
 # -----------------------------
 # HYPERPARAMETERS
 # -----------------------------
@@ -58,7 +64,7 @@ class Hyperparameters:
     qk_gain_init = float(os.environ.get("QK_GAIN_INIT", 1.5))
 
     vocab_size = int(os.environ.get("VOCAB_SIZE", 1024))
-    num_layers = int(os.environ.get("NUM_LAYERS", 10))
+    num_layers = int(os.environ.get("NUM_LAYERS", 11))
     num_kv_heads = int(os.environ.get("NUM_KV_HEADS", 4))
     model_dim = int(os.environ.get("MODEL_DIM", 512))
     num_heads = int(os.environ.get("NUM_HEADS", 8))
@@ -86,13 +92,33 @@ class Hyperparameters:
     eval_stride = int(os.environ.get("EVAL_STRIDE", 32))
     eval_batch_seqs = int(os.environ.get("EVAL_BATCH_SEQS", 32))
 
-    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 12288))
+    bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 10240))
     bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
 
-    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "1")))
+    swa_enabled = bool(int(os.environ.get("SWA_ENABLED", "0")))  # disabled, using EMA instead
     swa_start_frac = float(os.environ.get("SWA_START_FRAC", 0.4))
     swa_every = int(os.environ.get("SWA_EVERY", 25))
 
+    # EMA (replaces SWA for #315-style training)
+    ema_enabled = bool(int(os.environ.get("EMA_ENABLED", "1")))
+    ema_decay = float(os.environ.get("EMA_DECAY", 0.997))
+
+    # TTT: Test-Time Training on validation data after quantization
+    ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1")))
+    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 25))
+    ttt_lr = float(os.environ.get("TTT_LR", 0.008))
+    ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
+    ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
+
+    # XSA: Exclusive Self-Attention on last N layers
+    xsa_last_n = int(os.environ.get("XSA_LAST_N", 4))
+
+    # Partial RoPE: only apply RoPE to first N dims of each head
+    rope_dims = int(os.environ.get("ROPE_DIMS", 16))
+
+    # LN Scale: scale norm output by 1/sqrt(layer+1)
+    ln_scale = bool(int(os.environ.get("LN_SCALE", "1")))
+
 # -----------------------------
 # MUON OPTIMIZER
 # -----------------------------
@@ -503,9 +529,10 @@ def restore_low_dim_params_to_fp32(module: nn.Module) -> None:
 
 
 class Rotary(nn.Module):
-    def __init__(self, dim: int, base: float = 10000.0):
+    def __init__(self, dim: int, base: float = 10000.0, rope_dims: int = 0):
         super().__init__()
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
+        self.rope_dims = rope_dims if rope_dims > 0 else dim
+        inv_freq = 1.0 / (base ** (torch.arange(0, self.rope_dims, 2, dtype=torch.float32) / self.rope_dims))
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self._seq_len_cached = 0
         self._cos_cached: Tensor | None = None
@@ -527,13 +554,20 @@ def forward(self, seq_len: int, device: torch.device, dtype: torch.dtype) -> tup
 
 
 def apply_rotary_emb(x: Tensor, cos: Tensor, sin: Tensor) -> Tensor:
+    rd = cos.size(-1) * 2  # number of RoPE dims
+    if rd < x.size(-1):
+        x_rope, x_pass = x[..., :rd], x[..., rd:]
+        half = rd // 2
+        x1, x2 = x_rope[..., :half], x_rope[..., half:]
+        x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
+        return torch.cat((x_rot, x_pass), dim=-1)
     half = x.size(-1) // 2
     x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
 
 
 class CausalSelfAttention(nn.Module):
-    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float):
+    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float, qk_gain_init: float, rope_dims: int = 0, use_xsa: bool = False):
         super().__init__()
         if dim % num_heads != 0:
             raise ValueError("model_dim must be divisible by num_heads")
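A note on the partial-RoPE path added to apply_rotary_emb above: only the first rd = 2 * cos.size(-1) dims of each head are rotated; the remaining dims pass through untouched. A minimal standalone check of that behaviour, mirroring the diff's logic (the cos/sin tensors here are random placeholders used purely for shape checking, not values produced by Rotary):

import torch

def apply_rotary_emb(x, cos, sin):
    # same split as the diff above: rotate the first rd dims, pass the rest through
    rd = cos.size(-1) * 2
    if rd < x.size(-1):
        x_rope, x_pass = x[..., :rd], x[..., rd:]
        half = rd // 2
        x1, x2 = x_rope[..., :half], x_rope[..., half:]
        x_rot = torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)
        return torch.cat((x_rot, x_pass), dim=-1)
    half = x.size(-1) // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((x1 * cos + x2 * sin, x1 * (-sin) + x2 * cos), dim=-1)

B, H, T, D, rope_dims = 1, 2, 5, 64, 16
x = torch.randn(B, H, T, D)
cos = torch.randn(T, rope_dims // 2)  # placeholder values, shapes only
sin = torch.randn(T, rope_dims // 2)
y = apply_rotary_emb(x, cos, sin)
assert y.shape == x.shape
assert torch.equal(y[..., rope_dims:], x[..., rope_dims:])  # last 48 dims untouched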
@@ -542,6 +576,7 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float
         self.num_heads = num_heads
         self.num_kv_heads = num_kv_heads
         self.head_dim = dim // num_heads
+        self.use_xsa = use_xsa
         if self.head_dim % 2 != 0:
             raise ValueError("head_dim must be even for RoPE")
         kv_dim = self.num_kv_heads * self.head_dim
@@ -551,24 +586,54 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, rope_base: float
         self.proj = CastedLinear(dim, dim, bias=False)
         self.proj._zero_init = True
         self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
-        self.rotary = Rotary(self.head_dim, base=rope_base)
+        self.rotary = Rotary(self.head_dim, base=rope_base, rope_dims=rope_dims)
+
+    def _xsa_efficient(self, y: Tensor, v: Tensor) -> Tensor:
+        """Remove self-value component from attention output via orthogonal projection."""
+        # y: (B, H, T, D), v: (B, Hkv, T, D)
+        B, H, T, D = y.shape
+        Hkv = v.size(1)
+        group = H // Hkv
+        y_g = y.reshape(B, Hkv, group, T, D)
+        vn = F.normalize(v, dim=-1).unsqueeze(2)  # (B, Hkv, 1, T, D)
+        proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn
+        return (y_g - proj).reshape(B, H, T, D)
 
     def forward(self, x: Tensor) -> Tensor:
         bsz, seqlen, dim = x.shape
-        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
-        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
-        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+        q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim)
+        k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim)
+        # (B, T, H, D) -> (B, H, T, D) for norm/rope
+        q = q.transpose(1, 2)
+        k = k.transpose(1, 2)
+        v = v.transpose(1, 2)
         q = F.rms_norm(q, (q.size(-1),))
         k = F.rms_norm(k, (k.size(-1),))
         cos, sin = self.rotary(seqlen, x.device, q.dtype)
         q = apply_rotary_emb(q, cos, sin)
         k = apply_rotary_emb(k, cos, sin)
         q = q * self.q_gain.to(dtype=q.dtype)[None, :, None, None]
-        y = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=None, is_causal=True,
-            enable_gqa=(self.num_kv_heads != self.num_heads),
-        )
-        y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+        if _HAS_FA3:
+            # FA3 expects (B, T, H, D)
+            q_fa = q.transpose(1, 2)
+            k_fa = k.transpose(1, 2)
+            v_fa = v.transpose(1, 2)
+            y = _flash_attn_func(q_fa, k_fa, v_fa, causal=True)
+            # y is (B, T, H, D), convert to (B, H, T, D) for XSA
+            if self.use_xsa:
+                y = self._xsa_efficient(y.transpose(1, 2), v)
+                y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
+            else:
+                y = y.contiguous().reshape(bsz, seqlen, dim)
+        else:
+            y = F.scaled_dot_product_attention(
+                q, k, v, attn_mask=None, is_causal=True,
+                enable_gqa=(self.num_kv_heads != self.num_heads),
+            )
+            if self.use_xsa:
+                y = self._xsa_efficient(y, v)
+            y = y.transpose(1, 2).contiguous().reshape(bsz, seqlen, dim)
         return self.proj(y)
 
 
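For intuition on _xsa_efficient above: per kv-group, it removes the component of each attention output vector along the (normalized) value vector at the same position, so the result is exactly orthogonal to that value. A small numerical check with random tensors in the same GQA layout (8 query heads sharing 4 kv heads), independent of the rest of the file:

import torch
import torch.nn.functional as F

def xsa_efficient(y, v):
    # y: (B, H, T, D) attention output, v: (B, Hkv, T, D) values in grouped-query layout
    B, H, T, D = y.shape
    Hkv = v.size(1)
    group = H // Hkv
    y_g = y.reshape(B, Hkv, group, T, D)
    vn = F.normalize(v, dim=-1).unsqueeze(2)           # (B, Hkv, 1, T, D)
    proj = (y_g * vn).sum(dim=-1, keepdim=True) * vn   # component of y along v
    return (y_g - proj).reshape(B, H, T, D)

B, H, Hkv, T, D = 2, 8, 4, 7, 64
y = torch.randn(B, H, T, D)
v = torch.randn(B, Hkv, T, D)
out = xsa_efficient(y, v)
# after the projection, each head's output is orthogonal to its group's value vector
out_g = out.reshape(B, Hkv, H // Hkv, T, D)
dots = (out_g * F.normalize(v, dim=-1).unsqueeze(2)).sum(-1)
assert dots.abs().max() < 1e-4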
@@ -625,11 +690,12 @@ def forward(self, token_ids: Tensor) -> Tensor:
 
 
 class Block(nn.Module):
-    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float):
+    def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float, rope_base: float, qk_gain_init: float, layer_idx: int = 0, ln_scale: bool = False, rope_dims: int = 0, use_xsa: bool = False):
         super().__init__()
         self.attn_norm = RMSNorm()
         self.mlp_norm = RMSNorm()
-        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init)
+        self.ln_scale_factor = 1.0 / math.sqrt(layer_idx + 1) if ln_scale else 1.0
+        self.attn = CausalSelfAttention(dim, num_heads, num_kv_heads, rope_base, qk_gain_init, rope_dims=rope_dims, use_xsa=use_xsa)
         self.mlp = MLP(dim, mlp_mult)
         self.attn_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
         self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
@@ -638,9 +704,10 @@ def __init__(self, dim: int, num_heads: int, num_kv_heads: int, mlp_mult: float,
     def forward(self, x: Tensor, x0: Tensor) -> Tensor:
         mix = self.resid_mix.to(dtype=x.dtype)
         x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
-        attn_out = self.attn(self.attn_norm(x))
+        s = self.ln_scale_factor
+        attn_out = self.attn(self.attn_norm(x) * s)
         x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
-        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
+        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x) * s)
         return x
 
 
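With ln_scale enabled, Block.forward above damps the normalized input to both the attention and MLP branches by 1/sqrt(layer_idx + 1), so deeper blocks inject progressively smaller updates into the residual stream. For the 11-layer config in this commit the factors work out to:

import math

# per-layer dampening factors for ln_scale with 11 layers
factors = [1.0 / math.sqrt(i + 1) for i in range(11)]
print([f"{f:.3f}" for f in factors])
# ['1.000', '0.707', '0.577', '0.500', '0.447', '0.408', '0.378', '0.354', '0.333', '0.316', '0.302']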
@@ -660,6 +727,9 @@ def __init__(
         qk_gain_init: float,
         bigram_vocab_size: int = 0,
         bigram_dim: int = 128,
+        xsa_last_n: int = 0,
+        rope_dims: int = 0,
+        ln_scale: bool = False,
     ):
         super().__init__()
         if logit_softcap <= 0.0:
@@ -676,8 +746,10 @@ def __init__(
         self.smear = SmearGate(model_dim)
         self.blocks = nn.ModuleList(
             [
-                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init)
-                for _ in range(num_layers)
+                Block(model_dim, num_heads, num_kv_heads, mlp_mult, rope_base, qk_gain_init,
+                      layer_idx=i, ln_scale=ln_scale, rope_dims=rope_dims,
+                      use_xsa=(i >= num_layers - xsa_last_n) if xsa_last_n > 0 else False)
+                for i in range(num_layers)
             ]
         )
         self.final_norm = RMSNorm()
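The use_xsa flag in the comprehension above marks the last xsa_last_n blocks. With this commit's defaults (num_layers=11, XSA_LAST_N=4) that selects block indices 7 through 10:

num_layers, xsa_last_n = 11, 4
xsa_layers = [i for i in range(num_layers) if i >= num_layers - xsa_last_n]
print(xsa_layers)  # [7, 8, 9, 10] -- only the last four blocks use exclusive self-attention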
@@ -929,17 +1001,23 @@ def log0(msg: str, console: bool = True) -> None:
     qk_gain_init=args.qk_gain_init,
     bigram_vocab_size=args.bigram_vocab_size,
     bigram_dim=args.bigram_dim,
+    xsa_last_n=args.xsa_last_n,
+    rope_dims=args.rope_dims,
+    ln_scale=args.ln_scale,
 ).to(device).bfloat16()
 for module in base_model.modules():
     if isinstance(module, CastedLinear):
         module.float()
 restore_low_dim_params_to_fp32(base_model)
 # QAT: fake-quantize during training so weights learn to be quantization-friendly
+# Late-stage QAT: only enable in the final 4% of training (qat_start_frac) to avoid hurting convergence
 qat_enabled = bool(int(os.environ.get("QAT_ENABLED", "1")))
+qat_start_frac = float(os.environ.get("QAT_START_FRAC", "0.96"))  # enable QAT after this fraction of steps (final 4%)
+qat_activated = False
 if qat_enabled:
     for name, module in base_model.named_modules():
         if isinstance(module, CastedLinear):
-            module._qat = True
+            module._qat = False  # start with QAT disabled
             module._qat_int5 = ".mlp." in name
 compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
 model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
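The hunk above only toggles module._qat; the fake-quantize kernel inside CastedLinear lives elsewhere in the file and is not part of this diff. As a rough, generic sketch of what STE (straight-through estimator) int6 fake quantization usually looks like (the symmetric per-tensor scaling below is an assumption for illustration, not the repo's actual implementation):

import torch

def fake_quant_ste(w: torch.Tensor, bits: int = 6) -> torch.Tensor:
    # symmetric per-tensor fake quantization with a straight-through estimator:
    # the forward pass sees rounded weights, the backward pass sees the identity
    qmax = 2 ** (bits - 1) - 1                      # 31 for int6
    scale = w.abs().max().clamp(min=1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax - 1, qmax) * scale
    return w + (w_q - w).detach()                   # STE: grad flows as if unquantized

w = torch.randn(256, 256, requires_grad=True)
fake_quant_ste(w).sum().backward()
assert torch.allclose(w.grad, torch.ones_like(w))   # gradient passes straight through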
@@ -1060,6 +1138,11 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 model.require_backward_grad_sync = True
 train_loader = DistributedTokenLoader(args.train_files, rank, world_size, device)
 
+# EMA state
+ema_state: dict[str, Tensor] | None = None
+if args.ema_enabled:
+    ema_state = {name: t.detach().cpu().clone() for name, t in base_model.state_dict().items()}
+
 # MAIN TRAINING LOOP
 training_time_ms = 0.0
 stop_after_step: int | None = None
@@ -1127,6 +1210,26 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     step += 1
     approx_training_time_ms = training_time_ms + 1000.0 * (time.perf_counter() - t0)
 
+    # Late-stage QAT: enable fake-quantize after qat_start_frac of training
+    if qat_enabled and not qat_activated:
+        # Estimate progress: use wallclock fraction if available, else step fraction
+        if max_wallclock_ms is not None:
+            progress = approx_training_time_ms / max_wallclock_ms
+        else:
+            progress = step / args.iterations
+        if progress >= qat_start_frac:
+            for module in base_model.modules():
+                if isinstance(module, CastedLinear):
+                    module._qat = True
+            qat_activated = True
+            log0(f"qat:enabled at step:{step} progress:{progress:.3f}")
+
+    # EMA: update exponential moving average every step
+    if args.ema_enabled and ema_state is not None:
+        decay = args.ema_decay
+        for name, t in base_model.state_dict().items():
+            ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1.0 - decay)
+
     # SWA: collect checkpoints during warmdown
     if args.swa_enabled and scale < args.swa_start_frac and step % args.swa_every == 0:
         if swa_state is None:
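The EMA update above keeps a CPU shadow copy of every state-dict tensor and blends in the live weights each step; with decay 0.997 the effective averaging horizon is roughly 1/(1 - 0.997) ≈ 333 steps. A minimal self-contained version of the same update, assuming a plain dict of float tensors:

import torch

def ema_update(ema_state: dict[str, torch.Tensor],
               model_state: dict[str, torch.Tensor],
               decay: float = 0.997) -> None:
    # ema <- decay * ema + (1 - decay) * current, in place per tensor
    for name, t in model_state.items():
        ema_state[name].mul_(decay).add_(t.detach().cpu(), alpha=1.0 - decay)

model_state = {"w": torch.randn(4)}
ema_state = {k: v.detach().cpu().clone() for k, v in model_state.items()}
for _ in range(10):
    model_state["w"] += 0.01            # stand-in for an optimizer step
    ema_update(ema_state, model_state)  # shadow weights lag behind the live ones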
@@ -1161,6 +1264,16 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     f"reserved: {torch.cuda.max_memory_reserved() // 1024 // 1024} MiB"
 )
 
+# Apply EMA if enabled
+if args.ema_enabled and ema_state is not None:
+    log0(f"ema:applying decay={args.ema_decay}")
+    current_state = base_model.state_dict()
+    ema_applied = {
+        name: tensor.to(dtype=current_state[name].dtype)
+        for name, tensor in ema_state.items()
+    }
+    base_model.load_state_dict(ema_applied, strict=True)
+
 # Apply SWA if collected
 if args.swa_enabled and swa_state is not None and swa_count > 1:
     log0(f"swa:applying averaged {swa_count} checkpoints")
@@ -1218,6 +1331,40 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
 deq_state = dequantize_mixed_int6(quant_state["w"], quant_state["m"], sd_cpu)
 base_model.load_state_dict(deq_state, strict=True)
 
+# TTT: Test-Time Training — adapt quantized model on val data before final eval
+if args.ttt_enabled:
+    torch.cuda.synchronize()
+    t_ttt = time.perf_counter()
+    log0(f"ttt:start epochs:{args.ttt_epochs} lr:{args.ttt_lr} batch_seqs:{args.ttt_batch_seqs}")
+    base_model.train()
+    ttt_optimizer = torch.optim.SGD(
+        base_model.parameters(), lr=args.ttt_lr, momentum=args.ttt_momentum,
+    )
+    seq_len = args.train_seq_len
+    n_val = val_tokens.numel() - 1
+    n_seqs = n_val // seq_len
+    for ttt_ep in range(args.ttt_epochs):
+        perm = torch.randperm(n_seqs)
+        ttt_loss_sum = 0.0
+        ttt_loss_count = 0
+        for batch_start in range(0, n_seqs, args.ttt_batch_seqs):
+            batch_end = min(batch_start + args.ttt_batch_seqs, n_seqs)
+            indices = perm[batch_start:batch_end]
+            batch_x = torch.stack([val_tokens[i * seq_len : i * seq_len + seq_len] for i in indices]).to(device=device, dtype=torch.int64)
+            batch_y = torch.stack([val_tokens[i * seq_len + 1 : i * seq_len + seq_len + 1] for i in indices]).to(device=device, dtype=torch.int64)
+            ttt_optimizer.zero_grad(set_to_none=True)
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = base_model(batch_x, batch_y)
+            loss.backward()
+            ttt_optimizer.step()
+            ttt_loss_sum += loss.item() * (batch_end - batch_start)
+            ttt_loss_count += batch_end - batch_start
+        if ttt_ep == 0 or (ttt_ep + 1) % 5 == 0 or ttt_ep == args.ttt_epochs - 1:
+            log0(f"ttt:epoch:{ttt_ep + 1}/{args.ttt_epochs} loss:{ttt_loss_sum / max(ttt_loss_count, 1):.4f}")
+    base_model.eval()
+    torch.cuda.synchronize()
+    log0(f"ttt:done time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms")
+
 # Sliding window eval on int6-roundtripped weights
 torch.cuda.synchronize()
 t_qeval = time.perf_counter()
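The TTT loop runs ttt_epochs full SGD passes over non-overlapping train_seq_len-token sequences cut from the validation split (the trailing partial sequence is dropped by the integer division). A small helper for estimating the resulting optimizer-step count; the 1M-token validation size below is a placeholder, not a value taken from this repo:

import math

def ttt_optimizer_steps(n_val_tokens: int, seq_len: int,
                        batch_seqs: int = 32, epochs: int = 25) -> int:
    n_seqs = (n_val_tokens - 1) // seq_len        # full non-overlapping sequences
    batches_per_epoch = math.ceil(n_seqs / batch_seqs)
    return epochs * batches_per_epoch

# e.g. a hypothetical 1M-token validation split with 1024-token sequences:
print(ttt_optimizer_steps(1_000_000, 1024))  # 25 epochs * 31 batches = 775 steps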
