
Commit b39ce05

update: add cosine TTT + per-layer LR (from PR openai#481)
1 parent e2ead99 commit b39ce05

1 file changed: 51 additions & 16 deletions

File tree

  • records/track_10min_16mb/2026-03-23_11L_TrigramHash_ValueResidual_GradQuant_TTT

records/track_10min_16mb/2026-03-23_11L_TrigramHash_ValueResidual_GradQuant_TTT/train_gpt.py

@@ -1,12 +1,6 @@
 """
-11L EMA + AdamW TTT + TrigramHash + Value Residual + Gradient-Guided Quantization
-
-Built on PR #398/#442 baseline with three novel additions:
-1. TrigramHash(4096): hash-based trigram embeddings extending BigramHash to 3-token context
-2. Value Residual (ResFormer, arXiv:2410.17897): cache V from layer 0, blend into all layers
-3. Gradient-Guided Quantization: adaptive Int5/6/7 per-tensor based on gradient sensitivity
-
-Mean val_bpb: 1.1132 (3 seeds), best: 1.1101
+train_gpt_submit.py — Submission v2: wider MLP + STE int6 QAT + MTP + seq2048 + NTK RoPE +
+fp16 embed + late-K passthrough + sliding window eval.
 """
 
 from __future__ import annotations
@@ -117,18 +111,18 @@ class Hyperparameters:
     # TTT (Test-Time Training)
     ttt_enabled = bool(int(os.environ.get("TTT_ENABLED", "1")))
     ttt_lr = float(os.environ.get("TTT_LR", 0.0005))
-    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 10))
+    ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30))
     ttt_momentum = float(os.environ.get("TTT_MOMENTUM", 0.9))
     ttt_batch_seqs = int(os.environ.get("TTT_BATCH_SEQS", 32))
     ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0))
     bigram_vocab_size = int(os.environ.get("BIGRAM_VOCAB_SIZE", 2048))
     bigram_dim = int(os.environ.get("BIGRAM_DIM", 128))
-    # TrigramHash embedding (3-token context via hash table)
+    # TrigramHash (our unique addition)
     trigram_vocab_size = int(os.environ.get("TRIGRAM_VOCAB_SIZE", 4096))
     trigram_dim = int(os.environ.get("TRIGRAM_DIM", 128))
-    # Value Residual (ResFormer, arXiv:2410.17897)
+    # Value Residual (from PR #413, ResFormer — -0.015 BPB for 18 params)
     value_residual = bool(int(os.environ.get("VALUE_RESIDUAL", "1")))
-    # Gradient-Guided Adaptive Quantization
+    # Gradient-Guided Quantization (from PR #332)
     grad_quant = bool(int(os.environ.get("GRAD_QUANT", "1")))
 
 # -----------------------------
@@ -1199,10 +1193,12 @@ def dequantize_mixed_int6(result: dict[str, Tensor], meta: dict[str, object],
 # -----------------------------
 
 def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn=None):
-    """Full-weight TTT: SGD adaptation on val data with DDP across all GPUs."""
+    """Full-weight TTT with cosine LR decay and per-layer LR (from PR #481)."""
     seq_len = args.train_seq_len
     total_seqs = (val_tokens.numel() - 1) // seq_len
     batch_seqs = args.ttt_batch_seqs
+    ttt_cosine = bool(int(os.environ.get("TTT_COSINE", "1")))
+    ttt_perlayer = bool(int(os.environ.get("TTT_PERLAYER", "1")))
 
     frozen_params = set()
     if args.ttt_freeze_blocks > 0:
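The two new switches are read inside ttt_adapt rather than on Hyperparameters, but they follow the same bool(int(os.environ.get(...))) convention as the existing TTT knobs, so "0" disables and "1" enables. A minimal usage sketch of that convention, not part of this commit and with hypothetical values:

import os

os.environ["TTT_EPOCHS"] = "30"    # new default in this commit
os.environ["TTT_COSINE"] = "1"     # cosine LR decay on
os.environ["TTT_PERLAYER"] = "0"   # hypothetical ablation: per-layer LR off

ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30))
ttt_cosine = bool(int(os.environ.get("TTT_COSINE", "1")))
ttt_perlayer = bool(int(os.environ.get("TTT_PERLAYER", "1")))
print(ttt_epochs, ttt_cosine, ttt_perlayer)  # 30 True False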
@@ -1212,24 +1208,62 @@ def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn
             p.requires_grad_(False)
             frozen_params.add(id(p))
 
-    ttt_params = [p for p in base_model.parameters() if p.requires_grad]
+    # Per-layer LR: MLP output projections get 3× LR (most quant damage),
+    # MLP input projections get 0.5× LR (least damage)
     ttt_use_adamw = bool(int(os.environ.get("TTT_ADAMW", "1")))
+    if ttt_perlayer:
+        proj_params = [p for n, p in base_model.named_parameters()
+                       if "mlp.proj" in n and p.requires_grad and id(p) not in frozen_params]
+        fc_params = [p for n, p in base_model.named_parameters()
+                     if "mlp.fc" in n and p.requires_grad and id(p) not in frozen_params]
+        other_params = [p for p in base_model.parameters()
+                        if p.requires_grad and id(p) not in frozen_params
+                        and id(p) not in {id(q) for q in proj_params + fc_params}]
+        param_groups = [g for g in [
+            {"params": proj_params, "lr": args.ttt_lr * 3.0},
+            {"params": fc_params, "lr": args.ttt_lr * 0.5},
+            {"params": other_params, "lr": args.ttt_lr},
+        ] if g["params"]]
+        ttt_params = proj_params + fc_params + other_params
+    else:
+        ttt_params = [p for p in base_model.parameters() if p.requires_grad and id(p) not in frozen_params]
+        param_groups = [{"params": ttt_params, "lr": args.ttt_lr}]
+
     if ttt_use_adamw:
-        optimizer = torch.optim.AdamW(ttt_params, lr=args.ttt_lr, weight_decay=0.0)
+        optimizer = torch.optim.AdamW(param_groups, weight_decay=0.0)
     else:
-        optimizer = torch.optim.SGD(ttt_params, lr=args.ttt_lr, momentum=args.ttt_momentum)
+        optimizer = torch.optim.SGD(param_groups, momentum=args.ttt_momentum)
+
+    # Store initial LR for cosine schedule
+    if ttt_cosine:
+        for g in optimizer.param_groups:
+            g["initial_lr"] = g["lr"]
 
     my_start = (total_seqs * rank) // world_size
     my_end = (total_seqs * (rank + 1)) // world_size
+    steps_per_epoch = (my_end - my_start) // max(batch_seqs, 1)
+    total_steps = args.ttt_epochs * steps_per_epoch
+    global_step = 0
 
     base_model.train()
     t0 = time.perf_counter()
 
+    if log_fn:
+        n_ttt = sum(p.numel() for p in ttt_params)
+        log_fn(f"ttt:config params:{n_ttt} cosine:{ttt_cosine} perlayer:{ttt_perlayer}")
+
     for epoch in range(args.ttt_epochs):
         epoch_loss_sum = torch.zeros((), device=device, dtype=torch.float64)
         epoch_tokens = torch.zeros((), device=device, dtype=torch.float64)
 
         for batch_start in range(my_start, my_end, batch_seqs):
+            # Cosine LR decay
+            if ttt_cosine and total_steps > 0:
+                progress = global_step / total_steps
+                mul = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
+                for g in optimizer.param_groups:
+                    g["lr"] = g["initial_lr"] * mul
+
             batch_end = min(batch_start + batch_seqs, my_end)
             raw_start = batch_start * seq_len
             raw_end = batch_end * seq_len + 1
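The grouping rule and the cosine multiplier in this hunk are self-contained enough to exercise outside the training script. Below is a minimal sketch, not part of this commit: a toy module whose parameter names contain "mlp.fc" and "mlp.proj" gets the same 3x/0.5x/1x AdamW groups, and the same per-step multiplier is applied to each group's stored initial_lr. The module, sizes, and step count are hypothetical.

import math
import torch
from torch import nn

class ToyBlock(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        # Names chosen so named_parameters() yields "mlp.fc.*" and "mlp.proj.*".
        self.mlp = nn.ModuleDict({"fc": nn.Linear(dim, 4 * dim), "proj": nn.Linear(4 * dim, dim)})

    def forward(self, x):
        return self.mlp["proj"](torch.relu(self.mlp["fc"](x)))

model = ToyBlock()
base_lr = 5e-4

# Same grouping rule as the diff: "mlp.proj" at 3x LR, "mlp.fc" at 0.5x, everything else at 1x.
proj = [p for n, p in model.named_parameters() if "mlp.proj" in n]
fc = [p for n, p in model.named_parameters() if "mlp.fc" in n]
other = [p for p in model.parameters() if id(p) not in {id(q) for q in proj + fc}]
groups = [g for g in [
    {"params": proj, "lr": base_lr * 3.0},
    {"params": fc, "lr": base_lr * 0.5},
    {"params": other, "lr": base_lr},
] if g["params"]]

opt = torch.optim.AdamW(groups, weight_decay=0.0)
for g in opt.param_groups:
    g["initial_lr"] = g["lr"]   # remembered so the decay always rescales from the start value

total_steps = 100
for step in range(total_steps):
    mul = 0.5 * (1.0 + math.cos(math.pi * min(step / total_steps, 1.0)))
    for g in opt.param_groups:
        g["lr"] = g["initial_lr"] * mul
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad(set_to_none=True)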
@@ -1252,6 +1286,7 @@ def ttt_adapt(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn
 
             epoch_loss_sum += loss.detach().to(torch.float64) * y.numel()
             epoch_tokens += float(y.numel())
+            global_step += 1
 
         if world_size > 1:
             dist.all_reduce(epoch_loss_sum, op=dist.ReduceOp.SUM)
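For a sense of scale, here is a worked example of the bookkeeping that drives the decay. The per-rank sequence count below is hypothetical; only the steps_per_epoch, total_steps, and multiplier formulas come from the diff.

import math

my_start, my_end, batch_seqs, ttt_epochs = 0, 3200, 32, 30
steps_per_epoch = (my_end - my_start) // max(batch_seqs, 1)   # 100
total_steps = ttt_epochs * steps_per_epoch                    # 3000

for global_step in (0, 750, 1500, 2250, 2999):
    progress = global_step / total_steps
    mul = 0.5 * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    print(global_step, round(mul, 4))
# 0 -> 1.0, 750 -> 0.8536, 1500 -> 0.5, 2250 -> 0.1464, 2999 -> 0.0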
