# 11L Partial RoPE + XSA4 + VE128 + Tight SWA + Late QAT + GPTQ-lite

## Score: val_bpb = 1.1804 (post-quant, single seed)

Trained on 8×H100 SXM in 615 seconds. 15.95MB artifact (int6+zstd-22).

## Approach

Combines the PR #374 SOTA stack with MLP width reduction (1408 vs 1536) to fit under 16MB, plus GPTQ-lite quantization optimization.

### Architecture
- 11 transformer layers, 512-dim, 8 heads (4 KV heads, GQA)
- MLP hidden=1408 (2.75× expansion), relu-squared activation
- **Partial RoPE** (16/64 dims): Only 25% of head dims get rotary embeddings. The remaining 75% are position-free, improving generalization.
- **LN Scale** (1/sqrt(layer_idx+1)): Damps RMSNorm output in deeper layers, stabilizing gradient flow.
- **XSA** on last 4 layers: Exclusive Self Attention removes self-value bias via GQA-aware orthogonal projection. Zero new parameters, ~2ms/step.
- **Shared Value Embedding** (dim=128, layers 9,10): Single embedding table projected to KV dim, added to V in selected layers. Per-layer learned scales.
- SmearGate: Learned per-dim gate blending current + previous token embeddings.
- U-Net skip connections (5 encoder, 6 decoder), tied embeddings, logit softcap 30.
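
As a minimal sketch of the partial-RoPE idea (illustrative numpy helper, not the actual training code): rotary embeddings are applied to only the first 16 of the 64 head dims, and everything past `rot_dims` passes through with no positional signal at all.

```python
import numpy as np

def partial_rope(x, rot_dims=16, base=10000.0):
    """Rotate only the first `rot_dims` of each head dim; the remaining
    dims stay position-free. x: (seq_len, head_dim)."""
    seq_len, head_dim = x.shape
    half = rot_dims // 2
    inv_freq = base ** (-np.arange(half) / half)        # (half,)
    angles = np.arange(seq_len)[:, None] * inv_freq     # (seq_len, half)
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[:, :half], x[:, half:rot_dims]
    rotated = np.concatenate([x1 * cos - x2 * sin,
                              x1 * sin + x2 * cos], axis=-1)
    # dims rot_dims..head_dim-1 carry no positional information
    return np.concatenate([rotated, x[:, rot_dims:]], axis=-1)
```

At position 0 the rotation is the identity, and the untouched 48 dims are identical at every position.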

### Training
- Muon optimizer: lr=0.025, momentum=0.99 (warmup 0.92→0.99 over 1500 steps), WD=0.04
- AdamW: embed_lr=0.035, scalar_lr=0.025, WD=0.04
- Batch: 786,432 tokens/step, seq_len=2048
- Warmdown: 3000 iters (wallclock-based), grad_clip=0.3
- **Tight SWA**: Uniform average of checkpoints collected every 50 steps when lr_scale < 0.2 (6 checkpoints total). Zero quality penalty vs non-SWA.
- **Late QAT**: STE int6 fake-quantization activated when lr_scale < 0.1 (step 4070). LR halved at activation to avoid disrupting converged weights.
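
The tight-SWA schedule above can be sketched as follows (hypothetical helper names; numpy arrays stand in for the actual PyTorch state dicts):

```python
import numpy as np

def collect(swa_buffer, state_dict, step, lr_scale):
    """Collect a checkpoint every 50 steps, but only once the LR
    schedule has decayed below 0.2 (the 'tight' window)."""
    if lr_scale < 0.2 and step % 50 == 0:
        swa_buffer.append({k: v.copy() for k, v in state_dict.items()})

def swa_average(swa_buffer):
    """Uniform average of all collected checkpoints."""
    n = len(swa_buffer)
    return {k: sum(c[k] for c in swa_buffer) / n for k in swa_buffer[0]}
```

The gate on `lr_scale` is what distinguishes this from standard SWA: only late-warmdown weights enter the average.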

### Quantization
- **GPTQ-lite**: Per-tensor clip-ratio search over 5 candidates (0.9999, 0.99999, 0.999999, 0.9999984, 1.0), selecting the clip percentile that minimizes L2 reconstruction error. Zero training cost.
- Int6 step=4 rounding on layers 1-9 (64 distinct values for better compression)
- Int8 on layers 0 and 10 (input/output quality)
- FP16 tied embeddings (never quantized)
- zstd level 22 compression
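
A sketch of the clip-ratio search (hypothetical helper names; per-tensor round-to-nearest quantization rather than full GPTQ):

```python
import numpy as np

CANDIDATES = (0.9999, 0.99999, 0.999999, 0.9999984, 1.0)

def quant_dequant(w, clip_ratio, bits=6):
    """Symmetric round-to-nearest quantize-dequantize, with the scale
    set by the given percentile of |w| instead of the raw absmax."""
    qmax = 2 ** (bits - 1) - 1                      # 31 for int6
    clip = np.quantile(np.abs(w), clip_ratio)
    scale = max(clip / qmax, 1e-12)
    q = np.clip(np.round(w / scale), -qmax - 1, qmax)
    return q * scale

def best_clip(w, candidates=CANDIDATES):
    """Pick the clip ratio that minimizes L2 reconstruction error."""
    errs = [np.linalg.norm(w - quant_dequant(w, c)) for c in candidates]
    return candidates[int(np.argmin(errs))]
```

Because the search only re-runs quantization per candidate, the cost is a few seconds for the whole model and no gradient steps.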

## Key Metrics

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb | 1.1770 |
| **Post-quant val_bpb** | **1.1804** |
| Quant gap | +0.0034 |
| Steps completed | 4,071 |
| Step time | 137ms avg (151ms after Late QAT) |
| Model parameters | 25,224,291 |
| Artifact size | 15,949,473 bytes (15.95 MB) |
| Peak GPU memory | 20,590 MiB |

## Convergence

| Step | val_bpb | train_time |
|------|---------|-----------|
| 1000 | 1.3246 | 136s |
| 2000 | 1.2551 | 274s |
| 3000 | 1.2139 | 413s |
| 4000 | 1.1793 | 551s |
| 4071 | 1.1770 | 615s (cap) |

## Lessons Learned

1. **MLP hidden=1408 > 1536 for artifact-constrained models**: The narrower MLP fits in 16MB with int6+zstd while enabling ~30% more training steps in the same wallclock budget (137ms vs 178ms per step). The extra steps more than compensate for the reduced per-step capacity.

2. **Late QAT timing matters**: Activating at lr_scale<0.1 (last ~1% of training) gives only 1 step of QAT adaptation. Earlier activation (lr_scale<0.2) would give more adaptation time but risks disrupting Muon momentum.

3. **Tight SWA (scale<0.2) eliminates the SWA quality penalty**: Standard SWA (scale<0.5) averages in stale early-warmdown checkpoints that hurt final quality. Restricting collection to scale<0.2 keeps the averaging benefit with zero quality loss.

4. **GPTQ-lite clip search is free**: Trying 5 clip ratios per tensor during quantization costs ~2s total and reduces reconstruction error without any training cost.
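
The fake-quantization forward pass behind Late QAT can be sketched as below (numpy; with a straight-through estimator, the backward pass simply copies gradients through, treating the rounding as identity):

```python
import numpy as np

def fake_quant_int6(w):
    """Quantize-dequantize onto the symmetric int6 grid [-32, 31] so
    the forward pass trains against quantization noise. Under STE the
    gradient bypasses round() and clip() entirely."""
    scale = np.abs(w).max() / 31.0
    q = np.clip(np.round(w / scale), -32, 31)
    return q * scale
```

The dequantized weights differ from the originals by at most half a quantization step, which is exactly the noise the final int6 export will introduce.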

## Setup

```bash
pip install --break-system-packages zstandard
# or: pip install -r requirements.txt
```

## Command

```bash
RUN_ID=pr374_8x_v2 MLP_HIDDEN=1408 \
DATA_PATH=../../../data/datasets/fineweb10B_sp1024/ \
TOKENIZER_PATH=../../../data/tokenizers/fineweb_1024_bpe.model \
VOCAB_SIZE=1024 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Status

**Non-record submission.** Single seed (1337). val_bpb 1.1804 is above the SOTA of 1.1428, so it misses the required 0.005 improvement margin.

Submitted to document the systematic combination of frontier techniques (Partial RoPE, LN Scale, XSA, Shared VE, Tight SWA, Late QAT, GPTQ-lite), together with the finding that MLP hidden=1408 (vs 1536) performs better under the 16MB constraint: the faster step time yields more training steps within the wallclock budget.
## requirements.txt

```
zstandard
```
## Submission metadata

```json
{
  "author": "Roberto Arce",
  "github_id": "rarce",
  "name": "11L Partial RoPE + XSA4 + VE128 + Tight SWA + Late QAT + GPTQ-lite",
  "blurb": "11-layer 512-dim model with MLP hidden=1408 (2.75x). Partial RoPE (16/64 dims) leaves 75% of head dims position-free. LN Scale (1/sqrt(layer+1)) damps deeper layers. XSA on last 4 layers removes self-value bias via GQA-aware orthogonal projection. Shared Value Embedding (dim=128) on layers 9,10. SmearGate for bigram context. Tight SWA (scale<0.2, 6 checkpoints) for zero-penalty weight averaging. Late QAT (STE int6 at lr_scale<0.1) avoids Muon momentum corruption. GPTQ-lite per-tensor clip ratio search minimizes reconstruction error at quantization time. Int6 step=4 on layers 1-9, int8 on layers 0,10. FP16 tied embeddings. zstd-22 compression. Muon lr=0.025, momentum=0.99, WD=0.04. Batch=786K, seq=2048, warmdown=3000, grad_clip=0.3.",
  "date": "2026-03-23T11:30:00Z",
  "val_loss": 1.99308832,
  "val_bpb": 1.18041917,
  "pre_quant_val_loss": 1.9873,
  "pre_quant_val_bpb": 1.1770,
  "step_stop": 4071,
  "wallclock_seconds": 615.394,
  "eval_time_seconds": 2.215,
  "bytes_total": 15949473,
  "bytes_model_int6_zstd": 15896909,
  "bytes_code": 52564,
  "seeds": [1337],
  "seed_results": {
    "1337": {"val_loss": 1.99308832, "val_bpb": 1.18041917}
  },
  "notes": "Single seed run. Non-record submission — documents combination of PR #374 stack techniques (Partial RoPE, LN Scale, XSA, Shared VE, Tight SWA, Late QAT) with novel GPTQ-lite clip search and MLP hidden=1408 for artifact size optimization."
}
```
## Training log

```
model_params:25224291
step:0/20000 val_loss:6.9312 val_bpb:4.1050 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9328 train_time:150ms step_avg:150.20ms
step:2/20000 train_loss:8.8303 train_time:263ms step_avg:131.58ms
step:3/20000 train_loss:7.7380 train_time:387ms step_avg:128.98ms
step:4/20000 train_loss:7.3025 train_time:523ms step_avg:130.73ms
step:5/20000 train_loss:7.0590 train_time:705ms step_avg:140.90ms
step:6/20000 train_loss:6.9447 train_time:894ms step_avg:148.94ms
step:7/20000 train_loss:6.8608 train_time:1026ms step_avg:146.62ms
step:8/20000 train_loss:6.7036 train_time:1195ms step_avg:149.32ms
step:9/20000 train_loss:6.3999 train_time:1399ms step_avg:155.50ms
step:10/20000 train_loss:6.1058 train_time:1614ms step_avg:161.42ms
step:200/20000 train_loss:2.4242 train_time:26353ms step_avg:131.77ms
step:400/20000 train_loss:2.4424 train_time:56529ms step_avg:141.32ms
step:600/20000 train_loss:2.3537 train_time:82232ms step_avg:137.05ms
step:800/20000 train_loss:2.2541 train_time:109933ms step_avg:137.42ms
step:1000/20000 train_loss:2.2857 train_time:136310ms step_avg:136.31ms
step:1000/20000 val_loss:2.2366 val_bpb:1.3246 train_time:136329ms step_avg:136.33ms
step:1200/20000 train_loss:2.3652 train_time:165361ms step_avg:137.80ms
step:1400/20000 train_loss:2.1893 train_time:193954ms step_avg:138.54ms
step:1600/20000 train_loss:2.0746 train_time:219091ms step_avg:136.93ms
step:1800/20000 train_loss:2.1419 train_time:247744ms step_avg:137.64ms
step:2000/20000 train_loss:2.0570 train_time:273795ms step_avg:136.90ms
step:2000/20000 val_loss:2.1191 val_bpb:1.2551 train_time:273809ms step_avg:136.90ms
step:2200/20000 train_loss:2.1214 train_time:302110ms step_avg:137.32ms
step:2400/20000 train_loss:2.0412 train_time:328279ms step_avg:136.78ms
step:2600/20000 train_loss:2.0787 train_time:357420ms step_avg:137.47ms
step:2800/20000 train_loss:2.1207 train_time:386817ms step_avg:138.15ms
step:3000/20000 train_loss:2.1246 train_time:412728ms step_avg:137.58ms
step:3000/20000 val_loss:2.0497 val_bpb:1.2139 train_time:412740ms step_avg:137.58ms
step:3200/20000 train_loss:2.1237 train_time:440873ms step_avg:137.77ms
step:3400/20000 train_loss:1.9635 train_time:466572ms step_avg:137.23ms
step:3600/20000 train_loss:2.0342 train_time:494726ms step_avg:137.42ms
step:3800/20000 train_loss:2.0025 train_time:521158ms step_avg:137.15ms
step:4000/20000 train_loss:1.9024 train_time:550682ms step_avg:137.67ms
step:4000/20000 val_loss:1.9912 val_bpb:1.1793 train_time:550743ms step_avg:137.69ms
late_qat:activated at step 4070 (lr_scale=0.095)
step:4071/20000 val_loss:1.9873 val_bpb:1.1770 train_time:615394ms step_avg:151.17ms
stopping_early: wallclock_cap train_time:615394ms step:4071/20000
peak memory allocated: 20590 MiB reserved: 21382 MiB
swa:applying uniform average of 6 checkpoints
Serialized model: 99630867 bytes
Total submission size: 99683431 bytes
Serialized model int8+zstd: 15896909 bytes (payload:25935478 raw_torch:25994685 payload_ratio:3.84x)
Total submission size int8+zlib: 15949473 bytes
final_int8_zlib_roundtrip val_loss:1.9931 val_bpb:1.1804 eval_time:2215ms
final_int8_zlib_roundtrip_exact val_loss:1.99308832 val_bpb:1.18041917
```