@@ -109,6 +109,10 @@ class Hyperparameters:
109109 ttt_max_doc_len = int (os .environ .get ("TTT_MAX_DOC_LEN" , 0 )) # 0 = no cap
110110 ttt_batch_docs = int (os .environ .get ("TTT_BATCH_DOCS" , 64 ))
111111 ttt_temp = float (os .environ .get ("TTT_TEMP" , 1.0 )) # Post-TTT temperature calibration
112+ # GPTQ: Hessian-aware quantization for int5 (0 = use naive int6)
113+ gptq_enabled = bool (int (os .environ .get ("GPTQ_ENABLED" , "0" )))
114+ gptq_clip_range = int (os .environ .get ("GPTQ_CLIP_RANGE" , 15 )) # 15 = int5, 31 = int6
115+ gptq_samples = int (os .environ .get ("GPTQ_SAMPLES" , 256 ))
112116 # Hyper-connections: mix k previous hidden states (0 = disabled)
113117 hyper_k = int (os .environ .get ("HYPER_K" , 0 ))
114118 hyper_layers = int (os .environ .get ("HYPER_LAYERS" , 4 )) # apply to top N layers
@@ -1389,6 +1393,141 @@ def quantize_int6_per_row(t: Tensor, clip_range: int = 31) -> tuple[Tensor, Tens
13891393 scale = torch .tensor (amax / clip_range if amax > 0 else 1.0 , dtype = torch .float16 )
13901394 q = torch .clamp (torch .round (t32 / scale .float ()), - clip_range , clip_range ).to (torch .int8 )
13911395 return q , scale
1396+ def _find_best_row_scales (W : Tensor , clip_range : int = 15 ) -> Tensor :
1397+ """Find optimal per-row scales by searching percentile clipping thresholds."""
1398+ t32 = W .float ()
1399+ best_s = t32 .abs ().amax (dim = 1 ) / clip_range
1400+ best_s = best_s .clamp_min (1.0 / clip_range )
1401+ best_err = torch .full ((t32 .shape [0 ],), float ('inf' ))
1402+ for pct in [0.9990 , 0.9995 , 0.9999 , 0.99999 , 1.0 ]:
1403+ if pct < 1.0 :
1404+ row_clip = torch .quantile (t32 .abs (), pct , dim = 1 )
1405+ else :
1406+ row_clip = t32 .abs ().amax (dim = 1 )
1407+ s = (row_clip / clip_range ).clamp_min (1.0 / clip_range )
1408+ q = torch .clamp (torch .round (t32 / s [:, None ]), - clip_range , clip_range )
1409+ recon = q * s [:, None ]
1410+ err = (t32 - recon ).pow (2 ).mean (dim = 1 )
1411+ improved = err < best_err
1412+ best_s [improved ] = s [improved ]
1413+ best_err [improved ] = err [improved ]
1414+ return best_s
1415+
def gptq_quantize_weight(W: Tensor, H: Tensor, clip_range: int = 15,
                         block_size: int = 128, percdamp: float = 0.01) -> tuple[Tensor, Tensor]:
    """GPTQ: quantize W using the calibration Hessian H = X^T X for error compensation.

    Fixes vs. the previous version: act-order now processes columns with the
    LARGEST Hessian diagonal first (descending sort), and the error-feedback
    coefficients are taken from the upper Cholesky factor of H^{-1} — using the
    full inverse is incorrect because it skips the per-column sub-matrix
    downdates the GPTQ recursion requires.

    Args:
        W: 2-D weight matrix (rows x cols).
        H: cols x cols Gram matrix of layer inputs.
        clip_range: symmetric integer bound (15 = int5, 31 = int6).
        block_size: lazy-batch column block width.
        percdamp: diagonal damping as a fraction of mean(diag(H)).

    Returns:
        (Q, scale): int8 quantized weights and per-row float16 scales.
    """
    W = W.float().clone()
    rows, cols = W.shape
    row_scale = _find_best_row_scales(W, clip_range)
    H = H.float().clone()
    # Dampen the diagonal for numerical stability before inverting.
    H.diagonal().add_(percdamp * H.diag().mean())
    # Act-order: most sensitive columns (largest curvature) see the least
    # accumulated rounding error if quantized first.
    perm = torch.argsort(H.diag(), descending=True)
    invperm = torch.argsort(perm)
    W = W[:, perm]
    H = H[perm][:, perm]
    try:
        # Canonical GPTQ: propagate error via the upper Cholesky factor of H^{-1}.
        Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
        Hinv = torch.linalg.cholesky(Hinv, upper=True)
    except torch.linalg.LinAlgError:
        # Degenerate Hessian: diagonal factor => no cross-column compensation
        # (reduces to plain round-to-nearest).
        Hinv = torch.diag(1.0 / H.diag().clamp_min(1e-6).sqrt())
    Q = torch.zeros(rows, cols, dtype=torch.int8)
    for i1 in range(0, cols, block_size):
        i2 = min(i1 + block_size, cols)
        W_block = W[:, i1:i2].clone()
        Hinv_block = Hinv[i1:i2, i1:i2]
        Err = torch.zeros_like(W_block)
        for j in range(i2 - i1):
            w_col = W_block[:, j]
            d = Hinv_block[j, j].clamp_min(1e-8)  # diag of the upper factor is > 0
            q_col = torch.clamp(torch.round(w_col / row_scale), -clip_range, clip_range)
            Q[:, i1 + j] = q_col.to(torch.int8)
            err = (w_col - q_col * row_scale) / d
            Err[:, j] = err
            if j + 1 < i2 - i1:
                # Fold this column's quantization error into the rest of the block.
                W_block[:, j + 1:] -= err.unsqueeze(1) * Hinv_block[j, j + 1:].unsqueeze(0)
        if i2 < cols:
            # Lazy batch update: propagate the block's error to later columns.
            W[:, i2:] -= Err @ Hinv[i1:i2, i2:]
    return Q[:, invperm], row_scale.to(torch.float16)
def gptq_calibrate(model: nn.Module, train_pattern: str, device: torch.device,
                   n_samples: int = 256, seq_len: int = 2048) -> dict[str, Tensor]:
    """Accumulate per-layer Hessians H = X^T X over calibration data.

    Registers a forward hook on every linear layer, streams n_samples
    sequences of seq_len tokens through the model, then returns the averaged
    input Gram matrix for each hooked module, keyed by module name.
    """
    hessians: dict[str, Tensor] = {}
    counts: dict[str, int] = {}

    def make_hook(name: str):
        def hook_fn(module, inp, out):
            x = inp[0].detach().float()
            if x.ndim == 3:
                # Flatten (batch, seq, dim) activations into rows of X.
                x = x.reshape(-1, x.shape[-1])
            if name not in hessians:
                hessians[name] = torch.zeros(x.shape[1], x.shape[1],
                                             device=x.device, dtype=torch.float32)
                counts[name] = 0
            hessians[name].addmm_(x.t(), x)
            counts[name] += x.shape[0]
        return hook_fn

    handles = [module.register_forward_hook(make_hook(name))
               for name, module in model.named_modules()
               if isinstance(module, (nn.Linear, CastedLinear))]
    stream = TokenStream(train_pattern)
    model.eval()
    with torch.no_grad():
        for _ in range(n_samples):
            tokens = stream.take(seq_len + 1).to(device=device, dtype=torch.int64)
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                model.forward_logits(tokens[:-1].unsqueeze(0))
    for handle in handles:
        handle.remove()
    # Normalize by rows seen so H's scale is independent of sample count.
    for name, seen in counts.items():
        hessians[name] /= max(seen, 1)
    return hessians
def mixed_quantize_int5_gptq(state_dict: dict[str, Tensor], int5_cats: set[str],
                             hessians: dict[str, Tensor],
                             clip_range: int = 15) -> tuple[dict, dict]:
    """Int5 GPTQ quantization (clip_range=15, 31 levels) with Hessian error compensation.

    2-D weights in int5_cats with a matching calibration Hessian go through
    GPTQ; other int5_cats tensors fall back to naive per-row rounding; small
    and control tensors pass through; everything else is int8.

    Args:
        state_dict: parameter name -> tensor.
        int5_cats: parameter categories (per _classify_param) quantized at int5.
        hessians: module name -> H = X^T X from gptq_calibrate.
        clip_range: symmetric integer bound (15 = int5, 31 = int6). Previously
            hard-coded to 15, so GPTQ_CLIP_RANGE had no effect; now threaded
            through (default preserves old behavior).

    Returns:
        (result, meta): tensors stored under "<name>.q"/"<name>.scale" plus
        passthroughs, and per-name metadata describing the encoding.
    """
    result: dict[str, Tensor] = {}
    meta: dict[str, object] = {}
    # NOTE(review): these tallies are not reported anywhere yet — consider
    # logging them at the call site.
    gptq_count, naive_count = 0, 0
    for name, tensor in state_dict.items():
        t = tensor.detach().cpu().contiguous()
        cat = _classify_param(name)
        # Small / non-float tensors are cheap to keep at higher precision.
        if not t.is_floating_point() or t.numel() <= 65536:
            result[name] = t.to(torch.float16) if t.is_floating_point() else t
            meta[name] = "passthrough"
            continue
        # Control tensors must stay exact.
        if any(p in name for p in CONTROL_TENSOR_NAME_PATTERNS):
            result[name] = t.float()
            meta[name] = "passthrough_ctrl"
            continue
        if cat in int5_cats and t.ndim == 2:
            module_name = name.rsplit(".weight", 1)[0] if name.endswith(".weight") else name
            H = hessians.get(module_name)
            if H is not None and H.shape[0] == t.shape[1]:
                q, s = gptq_quantize_weight(t, H.cpu(), clip_range=clip_range)
                gptq_count += 1
            else:
                # Missing or shape-mismatched Hessian: naive per-row rounding.
                q, s = quantize_int6_per_row(t, clip_range=clip_range)
                naive_count += 1
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
        elif cat in int5_cats and t.ndim >= 1:
            q, s = quantize_int6_per_row(t, clip_range=clip_range)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int6"}
            naive_count += 1
        else:
            q, s = quantize_float_tensor(t)
            result[name + ".q"] = q
            result[name + ".scale"] = s
            meta[name] = {"type": "int8"}
    return result, meta
13921531def mixed_quantize_int6 (state_dict : dict [str , Tensor ], int6_cats : set [str ]):
13931532 num_layers_total = max (
13941533 (int (k .split ("." )[1 ]) for k in state_dict if k .startswith ("blocks." )),
@@ -1837,7 +1976,15 @@ def lr_mul(step: int, elapsed_ms: float) -> float:
18371976 log0 (f"Serialized model: { model_bytes } bytes" )
18381977 log0 (f"Code size: { code_bytes } bytes" )
18391978 sd_cpu = {k : v .detach ().cpu () for k , v in export_sd .items ()}
1840- quant_result , quant_meta = mixed_quantize_int6 (sd_cpu , {"mlp" , "attn" })
1979+ if args .gptq_enabled :
1980+ log0 (f"gptq:calibrating (samples={ args .gptq_samples } , clip_range={ args .gptq_clip_range } )..." )
1981+ t_gptq = time .perf_counter ()
1982+ gptq_hessians = gptq_calibrate (base_model , args .train_files , device ,
1983+ n_samples = args .gptq_samples , seq_len = args .train_seq_len )
1984+ log0 (f"gptq:calibrated { len (gptq_hessians )} layers in { time .perf_counter ()- t_gptq :.1f} s" )
1985+ quant_result , quant_meta = mixed_quantize_int5_gptq (sd_cpu , {"mlp" , "attn" }, gptq_hessians )
1986+ else :
1987+ quant_result , quant_meta = mixed_quantize_int6 (sd_cpu , {"mlp" , "attn" })
18411988 quant_buf = io .BytesIO ()
18421989 torch .save ({"w" : quant_result , "m" : quant_meta }, quant_buf )
18431990 # Save quantized model for fast eval-only iterations
0 commit comments