Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions apply_tent_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""Apply TENT patch to train_gpt_pr398.py -> train_gpt_ours.py"""

with open("train_gpt_ours.py", "r") as f:
code = f.read()

# 1. Add TENT hyperparameters
old = ' ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0))'
new = ''' ttt_freeze_blocks = int(os.environ.get("TTT_FREEZE_BLOCKS", 0))
ttt_tent_enabled = bool(int(os.environ.get("TTT_TENT_ENABLED", "1")))
ttt_tent_epochs = int(os.environ.get("TTT_TENT_EPOCHS", 30))
ttt_tent_lr = float(os.environ.get("TTT_TENT_LR", 0.01))'''
code = code.replace(old, new)

# 2. Add tent_norm_recalib function before INT6 section
tent_fn = '''
NORM_PARAM_PATTERNS = ("attn_scale", "mlp_scale", "q_gain", "skip_weight")

def tent_norm_recalib(args, base_model, device, val_tokens, rank=0, world_size=1, log_fn=None):
seq_len = args.train_seq_len
total_seqs = (val_tokens.numel() - 1) // seq_len
batch_seqs = args.ttt_batch_seqs
for p in base_model.parameters():
p.requires_grad_(False)
norm_params = []
for name, p in base_model.named_parameters():
if any(k in name for k in NORM_PARAM_PATTERNS):
p.requires_grad_(True)
norm_params.append(p)
if log_fn:
log_fn(f"tent:start params={sum(p.numel() for p in norm_params)} epochs={args.ttt_tent_epochs}")
optimizer = torch.optim.Adam(norm_params, lr=args.ttt_tent_lr)
my_start = (total_seqs * rank) // world_size
my_end = (total_seqs * (rank + 1)) // world_size
base_model.train()
t0 = time.perf_counter()
for epoch in range(args.ttt_tent_epochs):
for bs in range(my_start, my_end, batch_seqs):
be = min(bs + batch_seqs, my_end)
local = val_tokens[bs*seq_len:be*seq_len+1].to(device=device, dtype=torch.int64)
x = local[:-1].reshape(-1, seq_len)
y = local[1:].reshape(-1, seq_len)
optimizer.zero_grad(set_to_none=True)
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
loss = base_model(x, y)
loss.backward()
if world_size > 1:
for p in norm_params:
if p.grad is not None:
dist.all_reduce(p.grad, op=dist.ReduceOp.AVG)
optimizer.step()
if log_fn and (epoch+1) % max(1, args.ttt_tent_epochs//5) == 0:
log_fn(f"tent_epoch:{epoch+1}/{args.ttt_tent_epochs} time:{time.perf_counter()-t0:.1f}s")
for p in base_model.parameters():
p.requires_grad_(True)
if log_fn:
log_fn(f"tent:done elapsed={time.perf_counter()-t0:.1f}s")

'''
marker = "# -----------------------------\n# INT6 MIXED QUANTIZATION"
code = code.replace(marker, tent_fn + marker)

# 3. Insert TENT call before ttt_adapt
old_call = "ttt_adapt(args, base_model, device, val_tokens, rank, world_size, log)"
new_call = """if args.ttt_tent_enabled:
tent_norm_recalib(args, base_model, device, val_tokens, rank, world_size, log)
ttt_adapt(args, base_model, device, val_tokens, rank, world_size, log)"""
code = code.replace(old_call, new_call)

with open("train_gpt_ours.py", "w") as f:
f.write(code)
print("Patch applied OK")
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
## PR315 Recipe Reproduction on 1xH100 PCIe

Reproduction of PR #315 recipe on a single H100 PCIe GPU (RunPod, $2.39/hr).
Uses Flash Attention 2 instead of FA3.

### Results

| Metric | Value |
|--------|-------|
| Sliding-window BPB | **1.8338** |
| Pre-quant BPB | 1.4192 |
| Post-quant roundtrip BPB | 1.8398 |
| Steps | 492 |
| Wallclock | 600s (10min) |
| Artifact size | 10.0MB (int6+zstd) |
| Peak memory | 20785 MiB |

### Notes

- A single H100 PCIe completes only ~492 steps within the 10-minute wallclock cap (vs ~6200 steps on 8xH100 SXM)
- QAT enabled at step 452 with only ~40 steps of adaptation, causing higher quantization loss (+0.42 BPB)
- On 8xH100 SXM this recipe achieves ~1.13 BPB with ~6200 steps and proper QAT convergence

### Configuration

- 11 layers, 512 dim, 8 heads, 4 KV heads, 3x MLP
- XSA (last 4 layers), EMA 0.997, Partial RoPE (16/64 dims)
- LN Scale, Late QAT (threshold 0.1)
- BigramHash 2048, SmearGate
- Muon optimizer + weight decay 0.04

### Run Command

```bash
pip install zstandard flash-attn

NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=2048 XSA_LAST_N=4 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 \
ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 QAT_THRESHOLD=0.1 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=200 WARMDOWN_ITERS=400 \
ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \
torchrun --standalone --nproc_per_node=1 train_gpt.py
```

### Hardware

- RunPod 1xH100 PCIe 80GB ($2.39/hr)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"author": "Sungjoon Park",
"github_id": "sjp611",
"name": "PR315 1xH100 PCIe reproduction",
"blurb": "Reproduction of #315 recipe (XSA+EMA+PartialRoPE+LateQAT) on 1xH100 PCIe with FA2. Sliding-window val_bpb=1.8338 in 10 minutes.",
"date": "2026-03-21T17:00:00Z",
"track": "non-record-10min-16mb",
"val_bpb": 1.8338,
"pre_quant_val_bpb": 1.4192,
"step_stop": 492,
"wallclock_seconds": 600,
"bytes_total": 10064451,
"bytes_model_int6_zstd": 9996844,
"bytes_code": 67607
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26829913
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
world_size:1 grad_accum_steps:8
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:9000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
step:0/9000 val_loss:6.9303 val_bpb:4.1045 train_time:0ms step_avg:0.02ms
step:1/9000 train_loss:6.9307 train_time:1209ms step_avg:1208.55ms
step:2/9000 train_loss:8.7402 train_time:2410ms step_avg:1205.09ms
step:3/9000 train_loss:7.9338 train_time:3613ms step_avg:1204.28ms
step:4/9000 train_loss:7.2458 train_time:4816ms step_avg:1203.91ms
step:5/9000 train_loss:6.9648 train_time:6020ms step_avg:1204.08ms
step:6/9000 train_loss:6.8266 train_time:7225ms step_avg:1204.22ms
step:7/9000 train_loss:6.7704 train_time:8427ms step_avg:1203.83ms
step:8/9000 train_loss:6.6462 train_time:9633ms step_avg:1204.10ms
step:9/9000 train_loss:6.4067 train_time:10836ms step_avg:1204.04ms
step:10/9000 train_loss:6.1323 train_time:12045ms step_avg:1204.50ms
step:200/9000 train_loss:2.6860 train_time:243995ms step_avg:1219.98ms
step:400/9000 train_loss:2.4791 train_time:487824ms step_avg:1219.56ms
late_qat:enabled step:452 scale:0.0999
step:492/9000 val_loss:2.3963 val_bpb:1.4192 train_time:600108ms step_avg:1219.73ms
stopping_early: wallclock_cap train_time:600108ms step:492/9000
peak memory allocated: 20785 MiB reserved: 20800 MiB
ema:applying EMA weights
Serialized model: 105783807 bytes
Code size: 67607 bytes
Serialized model int6+zstd: 9996844 bytes
Total submission size int6+zstd: 10064451 bytes
final_int6_roundtrip val_loss:3.1064 val_bpb:1.8398 eval_time:79713ms
final_int6_roundtrip_exact val_loss:3.10643932 val_bpb:1.83980834
final_int6_sliding_window val_loss:3.0963 val_bpb:1.8338 stride:64 eval_time:973581ms
final_int6_sliding_window_exact val_loss:3.09629109 val_bpb:1.83380284
Loading