
Commit d30daab

committed
R13 base: R12 full TTT (483 lines) + PR openai#1420 tilt reference
1 parent 3f2866b commit d30daab

File tree

9 files changed: +4624 -1117 lines changed

evaluate.py

Lines changed: 1009 additions & 0 deletions
Large diffs are not rendered by default.

ngram_ref/eval_ngram.py

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
"""Eval-only: run sliding window + n-gram tilt on an existing quantized model.
Usage: torchrun --standalone --nproc_per_node=8 eval_ngram.py --model final_model.int6.ptz
"""
import argparse, glob, io, math, os, time
from pathlib import Path
import numpy as np
import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn.functional as F

def load_data_shard(file):
    header = np.fromfile(file, dtype="<i4", count=256)
    return torch.from_numpy(
        np.fromfile(file, dtype="<u2", count=int(header[2]),
                    offset=256 * np.dtype("<i4").itemsize).astype(np.uint16, copy=False))

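# Lookup tables used by the bits-per-byte normalization at the end of main():
#   bb: UTF-8 byte length of each piece (1 for raw byte tokens),
#   ls: whether the piece starts a new word (carries the "\u2581" prefix),
#   bd: control/unknown/unused ids, which never act as word boundaries.
# Each scored token contributes bb[tgt] bytes, plus one byte for the implied
# space when a word-initial piece follows a non-special previous token.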
def build_luts(sp, vocab_size, device):
    sp_vs = int(sp.vocab_size())
    sz = max(sp_vs, vocab_size)
    bb = np.zeros(sz, dtype=np.int16)
    ls = np.zeros(sz, dtype=np.bool_)
    bd = np.ones(sz, dtype=np.bool_)
    for tid in range(sp_vs):
        if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
            continue
        bd[tid] = False
        if sp.is_byte(tid):
            bb[tid] = 1
            continue
        piece = sp.id_to_piece(tid)
        if piece.startswith("\u2581"):
            ls[tid] = True
            piece = piece[1:]
        bb[tid] = len(piece.encode("utf-8"))
    return (torch.tensor(bb, dtype=torch.int16, device=device),
            torch.tensor(ls, dtype=torch.bool, device=device),
            torch.tensor(bd, dtype=torch.bool, device=device))

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--code", default="train_gpt.py")
    parser.add_argument("--model", default="final_model.int6.ptz")
    parser.add_argument("--val-pattern", default="./data/datasets/fineweb10B_sp4096/fineweb_val_*.bin")
    parser.add_argument("--tokenizer", default="./data/tokenizers/fineweb_4096_bpe.model")
    parser.add_argument("--stride", type=int, default=64)
    parser.add_argument("--seq-len", type=int, default=2048)
    parser.add_argument("--batch-seqs", type=int, default=32)
    parser.add_argument("--base-beta", type=float, default=1.0)
    parser.add_argument("--agree-bonus", type=float, default=0.5)
    parser.add_argument("--within-threshold", type=float, default=0.25)
    parser.add_argument("--within-beta", type=float, default=0.55)
    parser.add_argument("--word-threshold", type=float, default=0.80)
    parser.add_argument("--word-beta", type=float, default=0.50)
    # Model architecture args (must match training)
    parser.add_argument("--vocab-size", type=int, default=4096)
    parser.add_argument("--num-layers", type=int, default=11)
    parser.add_argument("--model-dim", type=int, default=512)
    parser.add_argument("--num-heads", type=int, default=8)
    parser.add_argument("--num-kv-heads", type=int, default=4)
    parser.add_argument("--mlp-mult", type=float, default=4.0)
    parser.add_argument("--logit-softcap", type=float, default=30.0)
    parser.add_argument("--rope-base", type=float, default=10000.0)
    parser.add_argument("--qk-gain-init", type=float, default=5.0)
    parser.add_argument("--xsa-last-n", type=int, default=11)
    parser.add_argument("--rope-dims", type=int, default=16)
    parser.add_argument("--ve-enabled", type=int, default=1)
    parser.add_argument("--ve-dim", type=int, default=128)
    parser.add_argument("--ve-layers", default="9,10")
    parser.add_argument("--recur-layers", default="4,5")
    parser.add_argument("--parallel-start-layer", type=int, default=7)
    args = parser.parse_args()

    # Distributed init
    distributed = "RANK" in os.environ
    rank = int(os.environ.get("RANK", "0"))
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    if distributed:
        dist.init_process_group(backend="nccl", device_id=device)
        dist.barrier()
    master = rank == 0

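    # train_gpt.py is imported as a module ("tg") so its GPT class and
    # quantization helpers can be reused verbatim instead of duplicated here.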
    # Import training code as module
    import importlib.util
    os.environ.setdefault("MODEL_NAME", "eval")
    os.environ.setdefault("SEED", "42")
    import sys
    spec = importlib.util.spec_from_file_location("tg", args.code)
    tg = importlib.util.module_from_spec(spec)
    sys.modules["tg"] = tg
    spec.loader.exec_module(tg)

    # Load val tokens
    val_files = sorted(glob.glob(args.val_pattern))
    val_tokens = torch.cat([load_data_shard(Path(f)) for f in val_files]).contiguous()
    total_tokens = val_tokens.numel() - 1
    if master:
        print(f"Val tokens: {total_tokens:,}")

    # Build LUTs
    sp = spm.SentencePieceProcessor(model_file=args.tokenizer)
    bb_lut, ls_lut, bd_lut = build_luts(sp, args.vocab_size, device)

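    # The checkpoint is a brotli-compressed torch payload: decompress,
    # optionally undo the byte shuffle, then dequantize the mixed-int6
    # weights against the model's own state dict as the shape reference.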
    # Load model
    model = tg.GPT(tg.Hyperparameters()).to(device).bfloat16()
    tg.restore_fp32_params(model)
    with open(args.model, "rb") as f:
        blob = f.read()
    import brotli
    dec = brotli.decompress(blob)
    if hasattr(tg, "_byte_unshuffle"):
        dec = tg._byte_unshuffle(dec)
    qs = torch.load(io.BytesIO(dec), map_location="cpu")
    sd = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    dq = tg.dequantize_mixed_int6(qs["w"], qs["m"], sd)
    model.load_state_dict(dq, strict=True)
    if hasattr(model, "set_recurrence_active"):
        model.set_recurrence_active(True)
    model.eval()
    if master:
        print("Model loaded.")

    # Compile
    logits_fn = torch.compile(model.forward_logits, dynamic=False, fullgraph=True)

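    # Sliding-window scoring: each window spans seq_len tokens, but only its
    # last `stride` targets are scored (the first context_size = seq_len -
    # stride tokens serve as context only), so every token is scored exactly
    # once with at least context_size tokens of context. The window at ws == 0
    # is the exception and scores all of its targets.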
    # Sliding window setup
    seq_len = args.seq_len
    stride = args.stride
    context_size = seq_len - stride
    window_starts = [ws for ws in range(0, total_tokens, stride)
                     if ws + context_size < total_tokens]
    total_windows = len(window_starts)
    my_s = (total_windows * rank) // world_size
    my_e = (total_windows * (rank + 1)) // world_size
    my_windows = window_starts[my_s:my_e]

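    # N-gram hints come from the fused_expert_ext ContextMixer: rank 0
    # computes one hint token id and tilt strength (beta) per position of the
    # val stream, then broadcasts both arrays so all ranks score against the
    # same hints. A negative hint id marks positions without a hint.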
    # Precompute n-gram hints
    all_hints = np.zeros(total_tokens + 1, dtype=np.int32)
    all_betas = np.zeros(total_tokens + 1, dtype=np.float64)
    if master:
        from fused_expert_ext import ContextMixer
        sp_vs = int(sp.vocab_size())
        sz = max(sp_vs, args.vocab_size)
        bb_np = np.zeros(sz, dtype=np.int16)
        ls_np = np.zeros(sz, dtype=np.uint8)
        bd_np = np.ones(sz, dtype=np.uint8)
        for tid in range(sp_vs):
            if sp.is_control(tid) or sp.is_unknown(tid) or sp.is_unused(tid):
                continue
            bd_np[tid] = 0
            if sp.is_byte(tid):
                bb_np[tid] = 1
                continue
            piece = sp.id_to_piece(tid)
            if piece.startswith("\u2581"):
                ls_np[tid] = 1
                piece = piece[1:]
            bb_np[tid] = len(piece.encode("utf-8"))
        val_np = val_tokens.numpy().astype(np.int64)
        ngram = ContextMixer(
            base_beta=args.base_beta, agree_bonus=args.agree_bonus,
            within_threshold=args.within_threshold, within_beta=args.within_beta,
            word_threshold=args.word_threshold, word_beta=args.word_beta,
            open_table_bits=26, token_threshold_scale=1.0, order_stride=2)
        ngram.set_tokens(val_np)
        ngram.set_luts(bb_np, ls_np, bd_np)
        positions = np.arange(1, total_tokens + 1, dtype=np.int64)
        ngram.get_hints_batch(positions, all_hints[1:], all_betas[1:])
        print(f"N-gram precomputed for {total_tokens} positions")
    if distributed:
        hints_t = torch.from_numpy(all_hints).to(device)
        betas_t = torch.from_numpy(all_betas).to(device)
        dist.broadcast(hints_t, src=0)
        dist.broadcast(betas_t, src=0)
    else:
        hints_t = torch.from_numpy(all_hints).to(device)
        betas_t = torch.from_numpy(all_betas).to(device)

    if master:
        print(f"Windows: {total_windows:,}, my_windows: {len(my_windows):,}")

    # Run eval: compute both base SW and n-gram tilted in one pass
    val_gpu = val_tokens.to(device=device, dtype=torch.int64)
    base_loss = torch.zeros((), device=device, dtype=torch.float64)
    tilt_loss = torch.zeros((), device=device, dtype=torch.float64)
    tc = torch.zeros((), device=device, dtype=torch.float64)
    bc = torch.zeros((), device=device, dtype=torch.float64)
    t0 = time.perf_counter()

    with torch.inference_mode():
        for bi in range(0, len(my_windows), args.batch_seqs):
            batch_ws = my_windows[bi:bi + args.batch_seqs]
            bsz = len(batch_ws)
            x = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            y = torch.zeros(bsz, seq_len, dtype=torch.int64, device=device)
            wlens = []
            for i, ws in enumerate(batch_ws):
                we = min(ws + seq_len, total_tokens)
                wlen = we - ws
                wlens.append(wlen)
                chunk = val_gpu[ws:we + 1]
                x[i, :wlen] = chunk[:-1]
                y[i, :wlen] = chunk[1:]
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                logits = logits_fn(x)
            logits_f = logits.float()
            nll_all = F.cross_entropy(
                logits_f.reshape(-1, logits_f.size(-1)),
                y.reshape(-1), reduction="none",
            ).reshape(bsz, seq_len)
            for i, ws in enumerate(batch_ws):
                wlen = wlens[i]
                s = 0 if ws == 0 else context_size
                scored_nll = nll_all[i, s:wlen].to(torch.float64)
                base_loss += scored_nll.sum()
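                # Tilt rescoring: q(t) is proportional to p(t) * exp(beta * [t == hint]), so
                #   -log q(tgt) = -log p(tgt) + log Z - beta * [tgt == hint]
                # with Z = 1 + p(hint) * (exp(beta) - 1). Only p(tgt) and
                # p(hint) are needed, never a second full softmax.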
                # N-gram tilt
                gp = torch.arange(ws + s + 1, ws + wlen + 1, device=device, dtype=torch.int64)
                hint = hints_t[gp]
                beta = betas_t[gp]
                has_hint = (hint >= 0).to(torch.float64)
                scored_logits = logits_f[i, s:wlen]
                tgt = y[i, s:wlen]
                safe_h = hint.clamp(min=0)
                logit_tgt = scored_logits.gather(-1, tgt.unsqueeze(-1)).squeeze(-1).to(torch.float64)
                logit_hint = scored_logits.gather(-1, safe_h.unsqueeze(-1)).squeeze(-1).to(torch.float64)
                lse = scored_nll + logit_tgt
                p_hint = (logit_hint - lse).exp().clamp(0.0, 1.0)
                Z = 1.0 + p_hint * (beta.exp() - 1.0)
                is_hit = (tgt == hint).to(torch.float64)
                mixed_nll = scored_nll + has_hint * (Z.log() - beta * is_hit)
                tilt_loss += mixed_nll.sum()
                tc += float(wlen - s)
                prev = x[i, s:wlen]
                tb = bb_lut[tgt].to(torch.float64)
                tb += (ls_lut[tgt] & ~bd_lut[prev]).to(torch.float64)
                bc += tb.sum()

    if distributed:
        for t in (base_loss, tilt_loss, tc, bc):
            dist.all_reduce(t, op=dist.ReduceOp.SUM)

    elapsed = time.perf_counter() - t0
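    # Bits per byte: total nats / ln(2) gives total bits, divided by the byte
    # count bc; written here as (nats per token / ln 2) * tokens-per-byte.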
    tpb = tc.item() / bc.item()
    base_bpb = (base_loss.item() / tc.item() / math.log(2)) * tpb
    tilt_bpb = (tilt_loss.item() / tc.item() / math.log(2)) * tpb

    if master:
        print(f"\nbase_sw_bpb: {base_bpb:.8f}")
        print(f"ngram_tilt_bpb: {tilt_bpb:.8f}")
        print(f"delta: {tilt_bpb - base_bpb:+.8f}")
        print(f"eval_time: {elapsed:.1f}s")

    if distributed:
        dist.barrier()
        dist.destroy_process_group()

if __name__ == "__main__":
    main()
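As a quick sanity check on the tilt algebra (a standalone sketch, not part of this commit): the closed form used above, computed from p(tgt) and p(hint) alone, should agree with an explicit softmax over logits in which beta has been added to the hint logit.

# Standalone sanity check of the tilt identity (not part of this commit).
import math
import torch

torch.manual_seed(0)
logits = torch.randn(4096)
hint, tgt, beta = 17, 42, 0.5
p = logits.softmax(-1)

# Closed form, as in eval_ngram.py: only p(tgt) and p(hint) are needed.
Z = 1.0 + p[hint] * (math.exp(beta) - 1.0)
closed = -p[tgt].log() + Z.log() - beta * float(tgt == hint)

# Explicit form: add beta to the hint logit and renormalize.
tilted = logits.clone()
tilted[hint] += beta
explicit = -tilted.log_softmax(-1)[tgt]

assert torch.allclose(closed, explicit, atol=1e-5)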
