Commit 65e612a
exp49: cosine pre-eval TTT 30ep + per-layer LR (from PR openai#481/openai#486)
- 30 epochs AdamW(lr=0.0005) on val tokens with cosine LR decay
- per-layer LR: 3x for mlp.proj (high quant error), 0.5x for mlp.fc
- DDP gradient sync via all_reduce(AVG) + grad clip 1.0
- keep LeakyReLU(0.5)^2 from exp48 (sketched below)
- expected: ~0.06 BPB gain (1.127 → ~1.07)
- modal timeout 3600s for 30-epoch TTT
1 parent d127837 commit 65e612a
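
The exp48 activation carried over here is not shown in this diff. A minimal sketch of one plausible reading, where LeakyReLU(0.5)^2 means a LeakyReLU with negative slope 0.5 followed by an elementwise square (the helper name is made up):

```python
import torch
import torch.nn.functional as F

def leaky_relu_squared(x: torch.Tensor) -> torch.Tensor:
    # Plausible reading of exp48's "LeakyReLU(0.5)^2": LeakyReLU with
    # negative_slope=0.5, then square elementwise. Negative inputs map to
    # (0.5 * x) ** 2, so the output is nonnegative and smooth at zero.
    return F.leaky_relu(x, negative_slope=0.5).square()

# leaky_relu_squared(torch.tensor([-2.0, 0.0, 2.0])) -> tensor([1., 0., 4.])
```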

2 files changed: 72 additions & 1 deletion

File tree:
- modal_train.py
- train_gpt.py

modal_train.py (1 addition & 1 deletion)
```diff
@@ -38,7 +38,7 @@
 @app.function(
     image=image,
     gpu="H100:8",
-    timeout=1800,
+    timeout=3600,
 )
 def train(env_overrides: dict[str, str] | None = None):
     """8xh100 training"""
```

train_gpt.py (71 additions & 0 deletions)
```diff
@@ -1382,6 +1382,77 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
     f"eval_time:{1000.0 * (time.perf_counter() - t_qeval):.0f}ms"
 )
 log0(f"final_int6_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+
+# cosine pre-eval TTT (from PR #481/#486 — 30 epochs AdamW with cosine LR + per-layer LR)
+ttt_epochs = int(os.environ.get("TTT_EPOCHS", 30))
+ttt_lr = float(os.environ.get("TTT_LR", 0.0005))
+if ttt_epochs > 0:
+    torch.cuda.synchronize()
+    t_ttt = time.perf_counter()
+    log0(f"ttt: starting {ttt_epochs} epochs, lr={ttt_lr}, cosine+perlayer")
+    # per-layer LR groups: 3x for MLP output projections, 0.5x for MLP input
+    proj_params, fc_params, other_params = [], [], []
+    for name, p in eval_model.named_parameters():
+        p.requires_grad_(True)
+        if "mlp.proj" in name:
+            proj_params.append(p)
+        elif "mlp.fc" in name:
+            fc_params.append(p)
+        else:
+            other_params.append(p)
+    ttt_opt = torch.optim.AdamW([
+        {"params": proj_params, "lr": ttt_lr * 3.0},
+        {"params": fc_params, "lr": ttt_lr * 0.5},
+        {"params": other_params, "lr": ttt_lr},
+    ], weight_decay=0.0)
+    total_val = val_tokens.numel() - 1
+    ttt_batch = 32
+    rank_tokens = total_val // world_size
+    rank_start = rank * rank_tokens
+    rank_end = rank_start + rank_tokens
+    steps_per_epoch = max(1, (rank_end - rank_start - args.train_seq_len) // (ttt_batch * args.train_seq_len))
+    total_steps = ttt_epochs * steps_per_epoch
+    global_step = 0
+    eval_model.train()
+    for ep in range(ttt_epochs):
+        ep_loss, ep_steps = 0.0, 0
+        for bs in range(rank_start, rank_end - args.train_seq_len, ttt_batch * args.train_seq_len):
+            be = min(bs + ttt_batch * args.train_seq_len + 1, rank_end + 1)
+            local = val_tokens[bs:be].to(device=device, dtype=torch.int64)
+            n = (local.numel() - 1) // args.train_seq_len
+            if n == 0:
+                continue
+            x = local[:n * args.train_seq_len].reshape(n, args.train_seq_len)
+            y = local[1:n * args.train_seq_len + 1].reshape(n, args.train_seq_len)
+            # cosine LR schedule
+            progress = global_step / max(total_steps, 1)
+            cos_mul = 0.5 * (1.0 + math.cos(math.pi * progress))
+            for g in ttt_opt.param_groups:
+                g["lr"] = g.get("initial_lr", g["lr"]) * cos_mul
+            if global_step == 0:
+                for g in ttt_opt.param_groups:
+                    g["initial_lr"] = g["lr"]
+            ttt_opt.zero_grad()
+            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
+                loss = eval_model(x, y)
+            loss.backward()
+            # sync gradients across ranks
+            if distributed:
+                for p in eval_model.parameters():
+                    if p.grad is not None:
+                        dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
+            torch.nn.utils.clip_grad_norm_(eval_model.parameters(), 1.0)
+            ttt_opt.step()
+            ep_loss += loss.item()
+            ep_steps += 1
+            global_step += 1
+        if master_process and (ep + 1) % 5 == 0:
+            log0(f"ttt_epoch:{ep + 1}/{ttt_epochs} avg_loss:{ep_loss / max(ep_steps, 1):.4f}")
+    del ttt_opt
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize()
+    log0(f"ttt: completed in {1000.0 * (time.perf_counter() - t_ttt):.0f}ms")
+
 sw_seq_len = effective_eval_seq_len
 if args.eval_stride > 0 and args.eval_stride < sw_seq_len:
     torch.cuda.synchronize()
```
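
One subtlety in the schedule above: `g.get("initial_lr", g["lr"])` is read before `initial_lr` is stored, which is safe only because `cos_mul` is exactly 1.0 at `global_step == 0`, so the snapshot taken right after still holds the unscaled base rates. A standalone sketch of the per-layer cosine decay (step counts are illustrative, not from a real run):

```python
import math

# Base LRs mirror the commit's group multipliers: 3x mlp.proj, 0.5x mlp.fc.
ttt_lr = 0.0005
base_lrs = {"mlp.proj": ttt_lr * 3.0, "mlp.fc": ttt_lr * 0.5, "other": ttt_lr}
total_steps = 30 * 12  # 30 epochs x (assumed) 12 steps per rank per epoch

for step in (0, total_steps // 2, total_steps - 1):
    # Same half-cosine as the diff: 1.0 at step 0, ~0.0 at the final step.
    cos_mul = 0.5 * (1.0 + math.cos(math.pi * step / max(total_steps, 1)))
    print(step, {k: round(lr * cos_mul, 7) for k, lr in base_lrs.items()})
```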
