Skip to content

Commit b9c4106

Browse files
RoyiR authored and claude committed
feat(transformer): add int5 GPTQ quantization with Hessian error compensation
Implement GPTQ (Hessian-aware) quantization for int5 (31 levels, clip=15). Uses Cholesky-based error redistribution across columns for minimal quant damage. Calibrates on 256 training sequences. Enables fitting 12L+ models within 16MB artifact limit. Controlled by GPTQ_ENABLED=1 (default: off). Based on PR openai#576's technique (1.1162 BPB with 33.6M int5 params). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent f7eff92 commit b9c4106

1 file changed

Lines changed: 148 additions & 1 deletion

File tree

submission_2026-03-23/train_gpt.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ class Hyperparameters:
109109
ttt_max_doc_len = int(os.environ.get("TTT_MAX_DOC_LEN", 0)) # 0 = no cap
110110
ttt_batch_docs = int(os.environ.get("TTT_BATCH_DOCS", 64))
111111
ttt_temp = float(os.environ.get("TTT_TEMP", 1.0)) # Post-TTT temperature calibration
112+
# GPTQ: Hessian-aware quantization for int5 (0 = use naive int6)
113+
gptq_enabled = bool(int(os.environ.get("GPTQ_ENABLED", "0")))
114+
gptq_clip_range = int(os.environ.get("GPTQ_CLIP_RANGE", 15)) # 15 = int5, 31 = int6
115+
gptq_samples = int(os.environ.get("GPTQ_SAMPLES", 256))
112116
# Hyper-connections: mix k previous hidden states (0 = disabled)
113117
hyper_k = int(os.environ.get("HYPER_K", 0))
114118
hyper_layers = int(os.environ.get("HYPER_LAYERS", 4)) # apply to top N layers
@@ -1389,6 +1393,141 @@ def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tens
13891393
scale = torch.tensor(amax / clip_range if amax > 0 else 1.0, dtype=torch.float16)
13901394
q = torch.clamp(torch.round(t32 / scale.float()), -clip_range, clip_range).to(torch.int8)
13911395
return q, scale
1396+
def _find_best_row_scales(W: Tensor, clip_range: int = 15) -> Tensor:
1397+
"""Find optimal per-row scales by searching percentile clipping thresholds."""
1398+
t32 = W.float()
1399+
best_s = t32.abs().amax(dim=1) / clip_range
1400+
best_s = best_s.clamp_min(1.0 / clip_range)
1401+
best_err = torch.full((t32.shape[0],), float('inf'))
1402+
for pct in [0.9990, 0.9995, 0.9999, 0.99999, 1.0]:
1403+
if pct < 1.0:
1404+
row_clip = torch.quantile(t32.abs(), pct, dim=1)
1405+
else:
1406+
row_clip = t32.abs().amax(dim=1)
1407+
s = (row_clip / clip_range).clamp_min(1.0 / clip_range)
1408+
q = torch.clamp(torch.round(t32 / s[:, None]), -clip_range, clip_range)
1409+
recon = q * s[:, None]
1410+
err = (t32 - recon).pow(2).mean(dim=1)
1411+
improved = err < best_err
1412+
best_s[improved] = s[improved]
1413+
best_err[improved] = err[improved]
1414+
return best_s
1415+
1416+
def gptq_quantize_weight(W: Tensor, H: Tensor, clip_range: int = 15,
                         block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]:
    """GPTQ: quantize ``W`` (rows x cols) against Hessian ``H = X^T X``.

    Columns are processed in blocks; each column's quantization error is
    redistributed onto the not-yet-quantized columns using rows of the
    *upper Cholesky factor* of ``H^-1`` — the exact GPTQ recursion. (Using
    entries of the raw inverse, as a naive reading suggests, is wrong: the
    factor rows already account for conditioning on previously quantized
    columns.)

    Args:
        W: 2-D float weight matrix.
        H: (cols, cols) Hessian of the layer inputs.
        clip_range: symmetric integer clip bound (15 -> int5).
        block_size: lazy-batch column block width.
        percdamp: diagonal damping as a fraction of mean(diag(H)).

    Returns:
        (Q, row_scale): int8 codes in [-clip_range, clip_range] and per-row
        float16 scales (dequant: ``Q.float() * row_scale[:, None]``).
    """
    W = W.float().clone()
    rows, cols = W.shape
    row_scale = _find_best_row_scales(W, clip_range)
    H = H.float().clone()
    # Dampen the diagonal for numerical stability (standard GPTQ percdamp).
    damp = percdamp * H.diag().mean()
    H.diagonal().add_(damp)
    # Act-order (fix): quantize the most salient columns (largest Hessian
    # diagonal) FIRST so their error can still be absorbed by the remaining
    # columns; the original ascending sort processed them last.
    perm = torch.argsort(H.diag(), descending=True)
    invperm = torch.argsort(perm)
    W = W[:, perm]
    H = H[perm][:, perm]
    try:
        L = torch.linalg.cholesky(H)
        # fix: take the upper Cholesky factor of H^-1; its rows supply the
        # sequential-update coefficients used in the loop below.
        Hinv = torch.linalg.cholesky(torch.cholesky_inverse(L), upper=True)
    except torch.linalg.LinAlgError:  # public alias of torch._C._LinAlgError
        # Singular Hessian (e.g. dead inputs): diagonal factor of the
        # clamped inverse, which disables cross-column compensation.
        Hinv = torch.diag(torch.rsqrt(H.diag().clamp_min(1e-6)))
    Q = torch.zeros(rows, cols, dtype=torch.int8, device=W.device)
    for i1 in range(0, cols, block_size):
        i2 = min(i1 + block_size, cols)
        W_block = W[:, i1:i2].clone()
        Hinv_block = Hinv[i1:i2, i1:i2]
        Err = torch.zeros_like(W_block)
        for j in range(i2 - i1):
            w_col = W_block[:, j]
            h_inv_jj = Hinv_block[j, j].clamp_min(1e-8)
            q_col = torch.clamp(torch.round(w_col / row_scale), -clip_range, clip_range)
            Q[:, i1 + j] = q_col.to(torch.int8)
            err = (w_col - q_col * row_scale) / h_inv_jj
            Err[:, j] = err
            if j + 1 < i2 - i1:
                # Push this column's scaled error onto the rest of the block.
                W_block[:, j + 1:] -= err.unsqueeze(1) * Hinv_block[j, j + 1:].unsqueeze(0)
        if i2 < cols:
            # Lazy batch update: propagate the block's accumulated error to
            # all remaining columns in one matmul.
            W[:, i2:] -= Err @ Hinv[i1:i2, i2:]
    Q = Q[:, invperm]  # undo act-order so codes line up with the original W
    return Q, row_scale.to(torch.float16)
1454+
1455+
def gptq_calibrate(model: nn.Module, train_pattern: str, device: torch.device,
                   n_samples: int = 256, seq_len: int = 2048) -> dict[str, Tensor]:
    """Collect per-layer Hessians ``H = X^T X / n`` on calibration data.

    Registers forward hooks on every linear layer, streams ``n_samples``
    training sequences through the model, and accumulates the outer product
    of each layer's input activations (normalized by rows seen).

    Args:
        model: trained model exposing ``forward_logits``.
        train_pattern: file pattern understood by ``TokenStream``.
        device: device the calibration batches are moved to.
        n_samples: number of sequences to stream.
        seq_len: tokens per sequence (one extra token is read for the shift).

    Returns:
        Mapping from module name to (in_features, in_features) float32 Hessian.
    """
    hessians: dict[str, Tensor] = {}
    n_seen: dict[str, int] = {}
    hooks = []
    def make_hook(name: str):
        def hook_fn(module, inp, out):
            # NOTE(review): assumes the layer input is positional (inp[0]);
            # a layer invoked with keyword-only inputs would be missed.
            x = inp[0].detach().float()
            if x.ndim == 3:  # (batch, seq, dim) -> (batch*seq, dim)
                x = x.reshape(-1, x.shape[-1])
            if name not in hessians:
                # Lazily sized from the first activation's feature dim.
                hessians[name] = torch.zeros(x.shape[1], x.shape[1], device=x.device, dtype=torch.float32)
                n_seen[name] = 0
            hessians[name].addmm_(x.t(), x)  # H += X^T X
            n_seen[name] += x.shape[0]
        return hook_fn
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, CastedLinear)):
            hooks.append(module.register_forward_hook(make_hook(name)))
    stream = TokenStream(train_pattern)
    was_training = model.training  # fix: remember mode instead of leaving eval()
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(n_samples):
                tokens = stream.take(seq_len + 1).to(device=device, dtype=torch.int64)
                x = tokens[:-1].unsqueeze(0)
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    model.forward_logits(x)
    finally:
        # fix: always detach hooks and restore the training mode, even if
        # the forward pass raises mid-calibration.
        for h in hooks:
            h.remove()
        model.train(was_training)
    for name in hessians:
        hessians[name] /= max(n_seen[name], 1)
    return hessians
1488+
1489+
def mixed_quantize_int5_gptq(state_dict: dict[str, Tensor], int5_cats: set[str],
                             hessians: dict[str, Tensor],
                             clip_range: int = 15) -> tuple[dict, dict]:
    """Int5 quantization with GPTQ error compensation where a Hessian exists.

    2-D weights in ``int5_cats`` with a matching calibration Hessian go
    through GPTQ; everything else in those categories falls back to naive
    per-row rounding. Remaining large float tensors get int8; small or
    non-float tensors pass through.

    Args:
        state_dict: CPU-able tensors keyed by parameter name.
        int5_cats: parameter categories (from ``_classify_param``) to
            quantize at ``clip_range``.
        hessians: module-name -> Hessian from ``gptq_calibrate``.
        clip_range: symmetric clip bound (fix: was hard-coded to 15, which
            made the GPTQ_CLIP_RANGE hyperparameter dead; default preserves
            the old behavior).

    Returns:
        (result, meta): packed tensors (``name.q`` / ``name.scale``) and
        per-name metadata describing how each entry was stored.
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    gptq_count, naive_count = 0, 0
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        # Small or non-float tensors: store as-is (fp16 for floats).
        if not t.is_floating_point() or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        # Control tensors stay full precision regardless of size.
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if cat in int5_cats and t.ndim == 2:
            module_name = name.rsplit(".weight", 1)[0] if name.endswith(".weight") else name
            H = hessians.get(module_name)
            if H is not None and H.shape[0] == t.shape[1]:
                q, s = gptq_quantize_weight(t, H.cpu(), clip_range=clip_range)
                gptq_count += 1
            else:
                # No (or shape-mismatched) Hessian: plain per-row rounding.
                q, s = quantize_int6_per_row(t, clip_range=clip_range)
                naive_count += 1
            result[name + ".q"] = q
            result[name + ".scale"] = s
            # NOTE(review): labeled "int6" so the existing loader path
            # applies; only the clip range differs. Confirm loader intent.
            meta[name] = {"type": "int6"}
        elif cat in int5_cats and t.ndim >= 1:
            # Non-matrix tensors in int5 categories: naive per-row rounding.
            q, s = quantize_int6_per_row(t, clip_range=clip_range)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
            naive_count += 1
        else:
            # Everything else large and float: int8 per-tensor quantization.
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta
1530+
13921531
def mixed_quantize_int6(state_dict: dict[str, Tensor], int6_cats: set[str]):
13931532
num_layers_total = max(
13941533
(int(k.split(".")[1]) for k in state_dict if k.startswith("blocks.")),
@@ -1837,7 +1976,15 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
18371976
log0(f"Serialized model: {model_bytes} bytes")
18381977
log0(f"Code size: {code_bytes} bytes")
18391978
sd_cpu = {k: v.detach().cpu() for k, v in export_sd.items()}
1840-
quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"})
1979+
if args.gptq_enabled:
1980+
log0(f"gptq:calibrating (samples={args.gptq_samples}, clip_range={args.gptq_clip_range})...")
1981+
t_gptq = time.perf_counter()
1982+
gptq_hessians = gptq_calibrate(base_model, args.train_files, device,
1983+
n_samples=args.gptq_samples, seq_len=args.train_seq_len)
1984+
log0(f"gptq:calibrated {len(gptq_hessians)} layers in {time.perf_counter()-t_gptq:.1f}s")
1985+
quant_result, quant_meta = mixed_quantize_int5_gptq(sd_cpu, {"mlp", "attn"}, gptq_hessians)
1986+
else:
1987+
quant_result, quant_meta = mixed_quantize_int6(sd_cpu, {"mlp", "attn"})
18411988
quant_buf = io.BytesIO()
18421989
torch.save({"w": quant_result, "m": quant_meta}, quant_buf)
18431990
# Save quantized model for fast eval-only iterations

0 commit comments

Comments
 (0)