# Pre-Enrichment + EMA-GPU + SmearGate + XSA4

val_bpb: **1.1478** (sliding window, stride=64) | 14.94 MB | 8xH100 SXM, 600s

### 3-Seed Results

| Seed | Steps | val_bpb (sliding) | Artifact |
|---|---|---|---|
| 1337 | 9,268 | 1.1478 | 14,942,971 |
| 42 | 9,318 | 1.1468 | 14,922,769 |
| 3011 | 9,322 | 1.1463 | 14,939,305 |
| **Mean** | — | **1.1470** | — |
| **Std** | — | **0.0008** | — |

## Architecture

- **Model**: 10L, 512d, 8H/4KV GQA, MLP 3x, tied embeddings
- **GELU Pre-Enrichment** (512->768->512): wider nonlinear transformation before transformer blocks
- **XSA** on last 4 layers: removes self-value bias (arXiv:2603.09078)
- **SmearGate**: per-dim gate blending each token with previous token
- **BigramHash** (2048x128): hash-table embedding for token bigrams
- **U-Net skip connections**: encoder-decoder with learned skip weights
- **EMA** (decay=0.997) on GPU: 37% faster training (64.7ms vs 101ms/step)
- **Int6 QAT + lzma**: 14.94 MB artifact, quant gap 0.004
- **Sliding window eval**: stride=64, seq_len=2048
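
The pre-enrichment block can be read as a single wider GELU MLP applied to the embedded tokens before the first transformer block. A minimal PyTorch sketch; the class name, the bias defaults, and the residual path are assumptions, not the submission's exact code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PreEnrichment(nn.Module):
    """Widen 512 -> 768, apply GELU, project back 768 -> 512."""
    def __init__(self, d_model=512, d_wide=768):
        super().__init__()
        self.up = nn.Linear(d_model, d_wide)
        self.down = nn.Linear(d_wide, d_model)

    def forward(self, x):
        # residual path assumed, so the block starts near-identity
        return x + self.down(F.gelu(self.up(x)))

x = torch.randn(2, 16, 512)        # (batch, seq, d_model)
print(PreEnrichment()(x).shape)    # torch.Size([2, 16, 512])
```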
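
SmearGate, per the bullet above, learns one gate per channel that mixes each token with the previous token's activation. A hedged sketch; the zero init (gate starts at 0.5 after sigmoid) and zero-padding of the first position are assumptions:

```python
import torch
import torch.nn as nn

class SmearGate(nn.Module):
    """Per-dimension gate blending each token with the previous token."""
    def __init__(self, d_model=512):
        super().__init__()
        self.gate = nn.Parameter(torch.zeros(d_model))  # sigmoid(0) = 0.5

    def forward(self, x):                 # x: (batch, seq, d_model)
        prev = torch.roll(x, shifts=1, dims=1)
        prev[:, 0] = 0.0                  # first token has no predecessor
        g = torch.sigmoid(self.gate)      # (d_model,), broadcast over batch/seq
        return (1 - g) * x + g * prev

y = SmearGate(4)(torch.ones(1, 3, 4))
print(y[0, :, 0].detach())  # tensor([0.5000, 1.0000, 1.0000])
```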
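
The EMA speedup comes from keeping the shadow weights on the same device as the model, so the per-step update is one in-place `lerp_` per tensor with no host transfers. A minimal sketch of that idea (the submission's exact update loop is not shown here):

```python
import torch

@torch.no_grad()
def ema_update(shadow, params, decay=0.997):
    """In-place EMA: shadow = decay*shadow + (1-decay)*param.
    Shadow tensors stay on the model's device, avoiding per-step
    host<->device copies."""
    for s, p in zip(shadow, params):
        s.lerp_(p, 1.0 - decay)

# toy usage on CPU; in the 8xH100 run these would be CUDA tensors
params = [torch.ones(3)]
shadow = [torch.zeros(3)]
ema_update(shadow, params)
print(shadow[0])  # tensor([0.0030, 0.0030, 0.0030])
```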
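
Sliding-window eval gives every token up to ~2K of left context by advancing a 2048-token window 64 tokens at a time and scoring only the targets each window hasn't scored yet; bpb then follows from the mean NLL via the tokens-per-byte ratio. A sketch under those assumptions, where the hypothetical `model` is any callable returning `(batch, seq, vocab)` logits:

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def sliding_window_nll(model, tokens, seq_len=2048, stride=64):
    """Mean next-token NLL (nats) with overlapping windows; each target
    is scored exactly once, with up to seq_len-1 tokens of context."""
    total, counted, scored_upto = 0.0, 0, 0
    for start in range(0, tokens.size(0) - 1, stride):
        end = min(start + seq_len, tokens.size(0) - 1)
        window = tokens[start:end + 1]
        logits = model(window[:-1].unsqueeze(0))        # (1, T, vocab)
        loss = F.cross_entropy(logits[0], window[1:], reduction="none")
        new = end - max(start, scored_upto)             # unseen targets only
        total += loss[-new:].sum().item()
        counted += new
        scored_upto = end
        if end == tokens.size(0) - 1:
            break
    return total / counted

# sanity check: a uniform model over a 7-token vocab must score ln(7) nats
uniform = lambda x: torch.zeros(x.size(0), x.size(1), 7)
toks = torch.randint(0, 7, (300,))
print(round(sliding_window_nll(uniform, toks, seq_len=32, stride=8), 4))  # 1.9459
```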

## Training

Muon + AdamW, WD=0.04, matrix_lr=0.025, warmdown=3500, batch=524K tokens, seq=2048.
Stopped at the 600s wallclock cap after 9,268 steps (64.7ms/step).
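
Under the baseline's usual convention, warmdown=3500 means the learning rate holds at its peak and then decays linearly over the final 3,500 of the 20,000 scheduled iterations; the ramp shape and how the warmdown window is anchored (scheduled iterations vs. expected wallclock stop) are assumptions. A minimal sketch:

```python
def lr_scale(step, total_steps=20000, warmup=20, warmdown=3500):
    """Multiplier on the base LR: linear warmup, flat plateau,
    linear warmdown to zero at the scheduled end."""
    if step < warmup:
        return (step + 1) / warmup
    if step > total_steps - warmdown:
        return max(0.0, (total_steps - step) / warmdown)
    return 1.0

print(lr_scale(0), lr_scale(10000), lr_scale(20000))  # 0.05 1.0 0.0
```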

## Key Metrics

| Metric | Value |
|---|---|
| val_bpb (sliding window) | 1.1478 |
| Post-quant val_bpb (standard) | 1.1690 |
| Pre-quant val_bpb | 1.1646 |
| Quant gap | 0.004 |
| Training time | 600,031ms (9,268 steps at 64.7ms) |
| Artifact size | 14,942,971 bytes |
| Model parameters | 25,254,992 |
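
The loss and bpb rows in the table are linked by the tokenizer's compression ratio: bits/byte = (nats/token) / ln 2 × (tokens/byte). A quick consistency check under that assumption, backing the implied ratio out of the pre-quant row:

```python
import math

def nats_per_token_to_bpb(val_loss, tokens_per_byte):
    """bits/byte = nats/token / ln(2) * tokens/byte."""
    return val_loss / math.log(2) * tokens_per_byte

# implied tokens-per-byte ratio from pre-quant val_loss 1.9663 -> bpb 1.1646
ratio = 1.1646 / (1.9663 / math.log(2))
print(round(ratio, 3))                                 # 0.411
print(round(nats_per_token_to_bpb(1.9663, ratio), 4))  # 1.1646
```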

## Credits

- Muon optimizer — modded-nanogpt baseline (kellerjordan)
- SmearGate + BigramHash — PR #65 (@aquariouseworkman)
- XSA — arXiv:2603.09078; GQA-aware PR #265 (@unnir)
- EMA + GPTQ-lite + warmdown tuning — PR #414 (@signalrush)
- Overtone init — modded-nanogpt baseline
- GELU Pre-Enrichment — original to this submission
- EMA on GPU — original to this submission

## Reproduction

```bash
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 80
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
8xH100 SXM, 600s training + ~120s eval.
{
"author": "Idanr",
"github_id": "idan3011",
"name": "Pre-Enrichment + EMA-GPU + SmearGate + XSA4",
"blurb": "GELU pre-enrichment (512->768->512), EMA on GPU (64.7ms/step, 37% faster), SmearGate, BigramHash(2048x128), XSA on last 4 layers, U-Net skip connections. 10L 512d, int6 QAT + lzma.",
"date": "2026-03-28T00:00:00Z",
"val_loss": 1.93793804,
"val_bpb": 1.14775606,
"pre_quant_val_loss": 1.9663,
"pre_quant_val_bpb": 1.1646,
"step_stop": 9268,
"wallclock_seconds": 600.031,
"eval_time_seconds": 120.0,
"bytes_total": 14942971,
"bytes_model_int6_lzma": 14878748,
"bytes_code": 64223
}
W0326 02:39:19.172000 34413 torch/distributed/run.py:803]
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] *****************************************
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0326 02:39:19.172000 34413 torch/distributed/run.py:803] *****************************************
logs/0d771539-26db-4427-b5a8-0a4c24bd56ad.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:25254992
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=True flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:524288 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9319 val_bpb:4.1055 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9318 train_time:62ms step_avg:61.75ms
step:2/20000 train_loss:7.1516 train_time:121ms step_avg:60.53ms
step:3/20000 train_loss:6.1791 train_time:185ms step_avg:61.59ms
step:4/20000 train_loss:6.4189 train_time:249ms step_avg:62.18ms
step:5/20000 train_loss:6.5862 train_time:313ms step_avg:62.55ms
step:6/20000 train_loss:6.2277 train_time:377ms step_avg:62.78ms
step:7/20000 train_loss:5.4960 train_time:441ms step_avg:62.97ms
step:8/20000 train_loss:5.2973 train_time:505ms step_avg:63.10ms
step:9/20000 train_loss:5.0005 train_time:569ms step_avg:63.20ms
step:10/20000 train_loss:4.8514 train_time:633ms step_avg:63.30ms
step:200/20000 train_loss:2.7511 train_time:12872ms step_avg:64.36ms
step:400/20000 train_loss:2.2579 train_time:25781ms step_avg:64.45ms
step:600/20000 train_loss:2.4713 train_time:38736ms step_avg:64.56ms
step:800/20000 train_loss:2.2316 train_time:51722ms step_avg:64.65ms
step:1000/20000 train_loss:2.3340 train_time:64727ms step_avg:64.73ms
step:1000/20000 val_loss:2.2855 val_bpb:1.3536 train_time:64739ms step_avg:64.74ms
step:1200/20000 train_loss:2.3620 train_time:77744ms step_avg:64.79ms
step:1400/20000 train_loss:2.3964 train_time:90750ms step_avg:64.82ms
step:1600/20000 train_loss:2.0689 train_time:103750ms step_avg:64.84ms
step:1800/20000 train_loss:2.1729 train_time:116742ms step_avg:64.86ms
step:2000/20000 train_loss:2.2158 train_time:129716ms step_avg:64.86ms
step:2000/20000 val_loss:2.1975 val_bpb:1.3015 train_time:129728ms step_avg:64.86ms
step:2200/20000 train_loss:2.0324 train_time:142686ms step_avg:64.86ms
step:2400/20000 train_loss:2.1624 train_time:155641ms step_avg:64.85ms
step:2600/20000 train_loss:2.3841 train_time:168596ms step_avg:64.84ms
step:2800/20000 train_loss:2.2002 train_time:181543ms step_avg:64.84ms
step:3000/20000 train_loss:2.1908 train_time:194474ms step_avg:64.82ms
step:3000/20000 val_loss:2.1539 val_bpb:1.2757 train_time:194486ms step_avg:64.83ms
step:3200/20000 train_loss:2.1563 train_time:207406ms step_avg:64.81ms
step:3400/20000 train_loss:2.1250 train_time:220338ms step_avg:64.81ms
step:3600/20000 train_loss:2.0721 train_time:233268ms step_avg:64.80ms
step:3800/20000 train_loss:2.1786 train_time:246196ms step_avg:64.79ms
step:4000/20000 train_loss:2.1419 train_time:259115ms step_avg:64.78ms
step:4000/20000 val_loss:2.1367 val_bpb:1.2655 train_time:259127ms step_avg:64.78ms
step:4200/20000 train_loss:2.1372 train_time:272101ms step_avg:64.79ms
step:4400/20000 train_loss:2.0839 train_time:285022ms step_avg:64.78ms
step:4600/20000 train_loss:1.9446 train_time:297946ms step_avg:64.77ms
step:4800/20000 train_loss:2.2371 train_time:310856ms step_avg:64.76ms
step:5000/20000 train_loss:1.9905 train_time:323763ms step_avg:64.75ms
step:5000/20000 val_loss:2.1285 val_bpb:1.2606 train_time:323775ms step_avg:64.76ms
step:5200/20000 train_loss:2.1516 train_time:336678ms step_avg:64.75ms
step:5400/20000 train_loss:2.1670 train_time:349585ms step_avg:64.74ms
step:5600/20000 train_loss:2.1609 train_time:362500ms step_avg:64.73ms
step:5800/20000 train_loss:2.1178 train_time:375416ms step_avg:64.73ms
step:6000/20000 train_loss:2.1963 train_time:388331ms step_avg:64.72ms
step:6000/20000 val_loss:2.1194 val_bpb:1.2552 train_time:388343ms step_avg:64.72ms
step:6200/20000 train_loss:2.0618 train_time:401239ms step_avg:64.72ms
step:6400/20000 train_loss:2.1328 train_time:414152ms step_avg:64.71ms
step:6600/20000 train_loss:2.0839 train_time:427067ms step_avg:64.71ms
step:6800/20000 train_loss:2.1327 train_time:439971ms step_avg:64.70ms
step:7000/20000 train_loss:2.1739 train_time:452890ms step_avg:64.70ms
step:7000/20000 val_loss:2.0766 val_bpb:1.2299 train_time:452903ms step_avg:64.70ms
step:7200/20000 train_loss:2.1442 train_time:465802ms step_avg:64.69ms
step:7400/20000 train_loss:2.0575 train_time:478715ms step_avg:64.69ms
step:7600/20000 train_loss:1.9264 train_time:491637ms step_avg:64.69ms
step:7800/20000 train_loss:2.0683 train_time:504556ms step_avg:64.69ms
step:8000/20000 train_loss:2.0304 train_time:517550ms step_avg:64.69ms
step:8000/20000 val_loss:2.0324 val_bpb:1.2037 train_time:517563ms step_avg:64.70ms
step:8200/20000 train_loss:2.1001 train_time:530461ms step_avg:64.69ms
step:8400/20000 train_loss:2.0298 train_time:543436ms step_avg:64.69ms
step:8600/20000 train_loss:2.0308 train_time:556429ms step_avg:64.70ms
step:8800/20000 train_loss:1.9809 train_time:569549ms step_avg:64.72ms
step:9000/20000 train_loss:1.8848 train_time:582572ms step_avg:64.73ms
step:9000/20000 val_loss:1.9773 val_bpb:1.1711 train_time:582573ms step_avg:64.73ms
step:9200/20000 train_loss:1.9494 train_time:595634ms step_avg:64.74ms
step:9268/20000 val_loss:1.9663 val_bpb:1.1646 train_time:600031ms step_avg:64.74ms
stopping_early: wallclock_cap train_time:600031ms step:9268/20000
peak memory allocated: 13058 MiB reserved: 13280 MiB
swa: averaging 14 checkpoints on top of EMA
ema: loading weights
Serialized model: 99486509 bytes
Code size: 64223 bytes
Total submission size: 99550732 bytes
Serialized model int6+lzma: 14878748 bytes (payload:25993024 raw_torch:26045291 payload_ratio:3.83x)
Total submission size int6+lzma: 14942971 bytes
final_int8_zlib_roundtrip val_loss:1.9738 val_bpb:1.1690 eval_time:2054ms
final_int8_zlib_roundtrip_exact val_loss:1.97382834 val_bpb:1.16901232
final_sliding_window val_bpb:1.1478 eval_time:120000ms
final_sliding_window_exact val_bpb:1.14775606