Commit 81ea3ef

Non-record: 11L PR374 Stack + GPTQ-lite (val_bpb=1.1804, 15.95MB)
1 parent 0b34042 commit 81ea3ef

File tree: 5 files changed, +1339 −0 lines changed
Lines changed: 90 additions & 0 deletions

# 11L Partial RoPE + XSA4 + VE128 + Tight SWA + Late QAT + GPTQ-lite

## Score: val_bpb = 1.1804 (post-quant, single seed)

Trained on 8×H100 SXM in 615 seconds. 15.95MB artifact (int6+zstd-22).

## Approach

Combines the PR #374 SOTA stack with MLP width reduction (1408 vs 1536) to fit under 16MB, plus GPTQ-lite quantization optimization.

### Architecture
- 11 transformer layers, 512-dim, 8 heads (4 KV heads, GQA)
- MLP hidden=1408 (2.75× expansion), relu-squared activation
- **Partial RoPE** (16/64 dims): Only 25% of head dims get rotary embeddings. The remaining 75% are position-free, improving generalization.
- **LN Scale** (1/sqrt(layer_idx+1)): Damps RMSNorm output in deeper layers, stabilizing gradient flow.
- **XSA** on last 4 layers: Exclusive Self Attention removes self-value bias via GQA-aware orthogonal projection. Zero new parameters, ~2ms/step.
- **Shared Value Embedding** (dim=128, layers 9,10): Single embedding table projected to KV dim, added to V in selected layers. Per-layer learned scales.
- SmearGate: Learned per-dim gate blending current + previous token embeddings.
- U-Net skip connections (5 encoder, 6 decoder), tied embeddings, logit softcap 30.
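
The Partial RoPE bullet above can be sketched in a few lines; this is a minimal NumPy illustration assuming the standard pairwise rotation layout (`partial_rope` and its argument names are illustrative, not taken from the training code):

```python
import numpy as np

def partial_rope(x, pos, rope_dims=16):
    """Apply rotary embeddings to only the first `rope_dims` of head_dim.

    x: (seq, head_dim) activations for one head; pos: (seq,) token positions.
    Dims rope_dims..head_dim are left untouched, i.e. position-free.
    """
    out = x.copy()
    half = rope_dims // 2
    # Standard RoPE frequency schedule, but over the rotated slice only.
    freqs = 1.0 / (10000.0 ** (np.arange(half) / half))
    angles = np.outer(pos, freqs)                # (seq, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:rope_dims]
    out[:, :half] = x1 * cos - x2 * sin          # rotate each (x1, x2) pair
    out[:, half:rope_dims] = x1 * sin + x2 * cos
    return out
```

With 16 of 64 dims rotated, 75% of each head carries no positional signal, which is the generalization argument made above.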

### Training

- Muon optimizer: lr=0.025, momentum=0.99 (warmup 0.92→0.99 over 1500 steps), WD=0.04
- AdamW: embed_lr=0.035, scalar_lr=0.025, WD=0.04
- Batch: 786,432 tokens/step, seq_len=2048
- Warmdown: 3000 iters (wallclock-based), grad_clip=0.3
- **Tight SWA**: Uniform average of checkpoints collected every 50 steps when lr_scale < 0.2 (6 checkpoints total). Zero quality penalty vs non-SWA.
- **Late QAT**: STE int6 fake-quantization activated when lr_scale < 0.1 (step 4070). LR halved at activation to avoid disrupting converged weights.
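
The fake-quantization step in Late QAT can be sketched as a quantize-dequantize forward pass; a minimal NumPy sketch assuming symmetric per-tensor max scaling (the run's exact scaling scheme is not documented here). In the actual STE setup, gradients bypass the rounding, e.g. `w + (fake_quant(w) - w).detach()` in PyTorch:

```python
import numpy as np

def fake_quant_int6(w):
    """Quantize-dequantize to a symmetric int6 grid (STE forward pass).

    Per-tensor max scaling is an assumption for illustration only.
    """
    qmax = 31                               # symmetric int6 levels: -31..31
    scale = np.abs(w).max() / qmax
    q = np.clip(np.round(w / scale), -qmax, qmax)
    return q * scale                        # the values the forward pass sees
```

Training against these rounded values lets the weights adapt to the int6 grid before the real quantization happens at serialization time.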

### Quantization

- **GPTQ-lite**: Per-tensor clip ratio search (5 candidates: 0.9999, 0.99999, 0.999999, 0.9999984, 1.0). Selects the clip percentile that minimizes L2 reconstruction error. Zero training cost.
- Int6 step=4 rounding on layers 1-9 (64 distinct values for better compression)
- Int8 on layers 0 and 10 (input/output quality)
- FP16 tied embeddings (never quantized)
- zstd level 22 compression
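
The clip ratio search above can be sketched as follows; a NumPy illustration assuming a plain symmetric int6 quantizer (the submission's step=4 rounding and per-layer bit widths are not reproduced), with `best_clip` as a hypothetical helper name:

```python
import numpy as np

CANDIDATES = (0.9999, 0.99999, 0.999999, 0.9999984, 1.0)

def best_clip(w, ratios=CANDIDATES, qmax=31):
    """Return the clip percentile minimizing int6 L2 reconstruction error."""
    flat = np.abs(w).ravel()
    best_ratio, best_err = None, np.inf
    for r in ratios:
        clip_val = np.quantile(flat, r)      # clip threshold at this percentile
        scale = clip_val / qmax
        q = np.clip(np.round(w / scale), -qmax, qmax)
        err = np.linalg.norm(q * scale - w)  # error after dequantization
        if err < best_err:
            best_ratio, best_err = r, err
    return best_ratio
```

Because 1.0 (no clipping) is among the candidates, the search can never do worse than plain max-scaling, which is why it is effectively free.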

## Key Metrics

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb | 1.1770 |
| **Post-quant val_bpb** | **1.1804** |
| Quant gap | +0.0034 |
| Steps completed | 4,071 |
| Step time | 137ms avg (151ms after Late QAT) |
| Model parameters | 25,224,291 |
| Artifact size | 15,949,473 bytes (15.95 MB) |
| Peak GPU memory | 20,590 MiB |

## Convergence

| Step | val_bpb | train_time |
|------|---------|------------|
| 1000 | 1.3246 | 136s |
| 2000 | 1.2551 | 274s |
| 3000 | 1.2139 | 413s |
| 4000 | 1.1793 | 551s |
| 4071 | 1.1770 | 615s (cap) |

## Lessons Learned

1. **MLP hidden=1408 > 1536 for artifact-constrained models**: The narrower MLP fits in 16MB with int6+zstd while enabling ~30% more training steps (137ms vs 178ms/step). The extra steps more than compensate for reduced per-step capacity.
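
The step-count claim is simple arithmetic under the fixed 615 s wallclock cap, using the per-step times stated above:

```python
# Steps that fit in the 615 s budget at each measured step time.
cap_s = 615.0
steps_1408 = cap_s / 0.137      # MLP hidden=1408 at 137 ms/step
steps_1536 = cap_s / 0.178      # MLP hidden=1536 at 178 ms/step
extra = steps_1408 / steps_1536 - 1.0   # fraction of additional steps, ~0.30
```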

2. **Late QAT timing matters**: Activating at lr_scale<0.1 (last ~1% of training) gives only 1 step of QAT adaptation. Earlier activation (lr_scale<0.2) would give more adaptation time but risks disrupting Muon momentum.

3. **Tight SWA (scale<0.2) eliminates SWA quality penalty**: Standard SWA (scale<0.5) averages stale early-warmdown checkpoints that hurt final quality. Restricting to scale<0.2 produces weight averaging with zero quality loss.
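
Tight SWA's filter-then-average can be sketched in a few lines; a NumPy sketch assuming checkpoints are dicts mapping parameter names to arrays (a format assumed for illustration, not taken from the run):

```python
import numpy as np

def tight_swa(checkpoints, lr_scales, threshold=0.2):
    """Uniformly average only checkpoints saved while lr_scale < threshold.

    Checkpoints collected earlier in warmdown (higher lr_scale) are
    discarded rather than averaged in.
    """
    kept = [c for c, s in zip(checkpoints, lr_scales) if s < threshold]
    return {name: np.mean([c[name] for c in kept], axis=0) for name in kept[0]}
```

Lowering the threshold from 0.5 to 0.2 is what drops the stale early-warmdown checkpoints described above.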

4. **GPTQ-lite clip search is free**: Trying 5 clip ratios per tensor during quantization costs ~2s total and reduces reconstruction error without any training cost.

## Setup

```bash
pip install --break-system-packages zstandard
# or: pip install -r requirements.txt
```

## Command

```bash
RUN_ID=pr374_8x_v2 MLP_HIDDEN=1408 \
DATA_PATH=../../../data/datasets/fineweb10B_sp1024/ \
TOKENIZER_PATH=../../../data/tokenizers/fineweb_1024_bpe.model \
VOCAB_SIZE=1024 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Status

**Non-record submission.** Single seed (1337). val_bpb 1.1804 does not beat SOTA 1.1428 by the required 0.005 margin.

Submitted to document the systematic combination of frontier techniques (Partial RoPE, LN Scale, XSA, Shared VE, Tight SWA, Late QAT, GPTQ-lite) with the novel insight that MLP hidden=1408 (vs 1536) produces better results under the 16MB constraint because faster step time yields more training steps.
Lines changed: 1 addition & 0 deletions
zstandard
Lines changed: 22 additions & 0 deletions
{
  "author": "Roberto Arce",
  "github_id": "rarce",
  "name": "11L Partial RoPE + XSA4 + VE128 + Tight SWA + Late QAT + GPTQ-lite",
  "blurb": "11-layer 512-dim model with MLP hidden=1408 (2.75x). Partial RoPE (16/64 dims) leaves 75% of head dims position-free. LN Scale (1/sqrt(layer+1)) damps deeper layers. XSA on last 4 layers removes self-value bias via GQA-aware orthogonal projection. Shared Value Embedding (dim=128) on layers 9,10. SmearGate for bigram context. Tight SWA (scale<0.2, 6 checkpoints) for zero-penalty weight averaging. Late QAT (STE int6 at lr_scale<0.1) avoids Muon momentum corruption. GPTQ-lite per-tensor clip ratio search minimizes reconstruction error at quantization time. Int6 step=4 on layers 1-9, int8 on layers 0,10. FP16 tied embeddings. zstd-22 compression. Muon lr=0.025, momentum=0.99, WD=0.04. Batch=786K, seq=2048, warmdown=3000, grad_clip=0.3.",
  "date": "2026-03-23T11:30:00Z",
  "val_loss": 1.99308832,
  "val_bpb": 1.18041917,
  "pre_quant_val_loss": 1.9873,
  "pre_quant_val_bpb": 1.1770,
  "step_stop": 4071,
  "wallclock_seconds": 615.394,
  "eval_time_seconds": 2.215,
  "bytes_total": 15949473,
  "bytes_model_int6_zstd": 15896909,
  "bytes_code": 52564,
  "seeds": [1337],
  "seed_results": {
    "1337": {"val_loss": 1.99308832, "val_bpb": 1.18041917}
  },
  "notes": "Single seed run. Non-record submission — documents combination of PR #374 stack techniques (Partial RoPE, LN Scale, XSA, Shared VE, Tight SWA, Late QAT) with novel GPTQ-lite clip search and MLP hidden=1408 for artifact size optimization."
}
Lines changed: 47 additions & 0 deletions
model_params:25224291
step:0/20000 val_loss:6.9312 val_bpb:4.1050 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9328 train_time:150ms step_avg:150.20ms
step:2/20000 train_loss:8.8303 train_time:263ms step_avg:131.58ms
step:3/20000 train_loss:7.7380 train_time:387ms step_avg:128.98ms
step:4/20000 train_loss:7.3025 train_time:523ms step_avg:130.73ms
step:5/20000 train_loss:7.0590 train_time:705ms step_avg:140.90ms
step:6/20000 train_loss:6.9447 train_time:894ms step_avg:148.94ms
step:7/20000 train_loss:6.8608 train_time:1026ms step_avg:146.62ms
step:8/20000 train_loss:6.7036 train_time:1195ms step_avg:149.32ms
step:9/20000 train_loss:6.3999 train_time:1399ms step_avg:155.50ms
step:10/20000 train_loss:6.1058 train_time:1614ms step_avg:161.42ms
step:200/20000 train_loss:2.4242 train_time:26353ms step_avg:131.77ms
step:400/20000 train_loss:2.4424 train_time:56529ms step_avg:141.32ms
step:600/20000 train_loss:2.3537 train_time:82232ms step_avg:137.05ms
step:800/20000 train_loss:2.2541 train_time:109933ms step_avg:137.42ms
step:1000/20000 train_loss:2.2857 train_time:136310ms step_avg:136.31ms
step:1000/20000 val_loss:2.2366 val_bpb:1.3246 train_time:136329ms step_avg:136.33ms
step:1200/20000 train_loss:2.3652 train_time:165361ms step_avg:137.80ms
step:1400/20000 train_loss:2.1893 train_time:193954ms step_avg:138.54ms
step:1600/20000 train_loss:2.0746 train_time:219091ms step_avg:136.93ms
step:1800/20000 train_loss:2.1419 train_time:247744ms step_avg:137.64ms
step:2000/20000 train_loss:2.0570 train_time:273795ms step_avg:136.90ms
step:2000/20000 val_loss:2.1191 val_bpb:1.2551 train_time:273809ms step_avg:136.90ms
step:2200/20000 train_loss:2.1214 train_time:302110ms step_avg:137.32ms
step:2400/20000 train_loss:2.0412 train_time:328279ms step_avg:136.78ms
step:2600/20000 train_loss:2.0787 train_time:357420ms step_avg:137.47ms
step:2800/20000 train_loss:2.1207 train_time:386817ms step_avg:138.15ms
step:3000/20000 train_loss:2.1246 train_time:412728ms step_avg:137.58ms
step:3000/20000 val_loss:2.0497 val_bpb:1.2139 train_time:412740ms step_avg:137.58ms
step:3200/20000 train_loss:2.1237 train_time:440873ms step_avg:137.77ms
step:3400/20000 train_loss:1.9635 train_time:466572ms step_avg:137.23ms
step:3600/20000 train_loss:2.0342 train_time:494726ms step_avg:137.42ms
step:3800/20000 train_loss:2.0025 train_time:521158ms step_avg:137.15ms
step:4000/20000 train_loss:1.9024 train_time:550682ms step_avg:137.67ms
step:4000/20000 val_loss:1.9912 val_bpb:1.1793 train_time:550743ms step_avg:137.69ms
late_qat:activated at step 4070 (lr_scale=0.095)
step:4071/20000 val_loss:1.9873 val_bpb:1.1770 train_time:615394ms step_avg:151.17ms
stopping_early: wallclock_cap train_time:615394ms step:4071/20000
peak memory allocated: 20590 MiB reserved: 21382 MiB
swa:applying uniform average of 6 checkpoints
Serialized model: 99630867 bytes
Total submission size: 99683431 bytes
Serialized model int8+zstd: 15896909 bytes (payload:25935478 raw_torch:25994685 payload_ratio:3.84x)
Total submission size int8+zlib: 15949473 bytes
final_int8_zlib_roundtrip val_loss:1.9931 val_bpb:1.1804 eval_time:2215ms
final_int8_zlib_roundtrip_exact val_loss:1.99308832 val_bpb:1.18041917
