# 10L CountInitBigram + XSA + PartialRoPE + LN Scale

**val_bpb: 1.1522** (sliding window stride=64, post int5/int6+zstd quantization roundtrip)

## Run Command

```bash
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

All parameters are set as defaults in `train_gpt.py`. No env vars needed.

## Key Techniques

### 1. Count-Initialized Exact Bigram Logit Head (Novel)
A 1024×1024 lookup table providing exact bigram logit biases: every token pair has its own entry, so there are no hash collisions.
Initialized from corpus transition probabilities before training:

```
B[a,b] = log p(b|a) - log p(b)
```

Computed from the first 16M training tokens with additive smoothing (α=0.25) and clipped to [-4, 4].
Applied BEFORE the logit softcap so the bias stays properly bounded.
The table is quantized to **int4 with nibble packing** (524KB vs 1MB at int8).

This gives the model a strong count-based language model prior from step 0, which the neural
network only needs to refine — valuable under a 10-minute training budget.
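
A minimal sketch of the count-based initialization in NumPy, assuming a stream of token ids in [0, 1024); the function name `count_init_bigram` is illustrative, not the one in `train_gpt.py`:

```python
import numpy as np

def count_init_bigram(tokens, vocab_size=1024, alpha=0.25, clip=4.0):
    """Build bigram logit biases B[a, b] = log p(b|a) - log p(b)
    from a token stream, with additive smoothing and clipping."""
    tokens = np.asarray(tokens, dtype=np.int64)
    # Pair counts C[a, b] over consecutive tokens (unbuffered, so
    # repeated index pairs accumulate correctly).
    counts = np.zeros((vocab_size, vocab_size), dtype=np.float64)
    np.add.at(counts, (tokens[:-1], tokens[1:]), 1.0)
    counts += alpha  # additive smoothing
    # Conditional and unigram log-probabilities.
    log_p_b_given_a = np.log(counts) - np.log(counts.sum(axis=1, keepdims=True))
    unigram = counts.sum(axis=0)
    log_p_b = np.log(unigram) - np.log(unigram.sum())
    return np.clip(log_p_b_given_a - log_p_b[None, :], -clip, clip)

# At inference/training time the bias row for the previous token is
# added to the logits before any softcap: logits += B[prev_token]
```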

### 2. Int4 Nibble Packing (Novel)
Custom `pack_i4` / `unpack_i4` functions pack signed int4 values [-8,7] into uint8 bytes
(two values per byte). Applied to the bigram logit table, halving its storage cost.
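
A sketch of what such packing functions might look like in NumPy; the actual `pack_i4` / `unpack_i4` in the submission may differ in layout (e.g. which value lands in the high nibble):

```python
import numpy as np

def pack_i4(x):
    """Pack signed int4 values in [-8, 7] into uint8, two per byte."""
    x = np.asarray(x, dtype=np.int8).ravel()
    assert x.size % 2 == 0 and x.min() >= -8 and x.max() <= 7
    u = (x + 8).astype(np.uint8)       # shift to unsigned [0, 15]
    return (u[0::2] << 4) | u[1::2]    # high nibble, then low nibble

def unpack_i4(packed):
    """Inverse of pack_i4: recover the signed int4 values."""
    packed = np.asarray(packed, dtype=np.uint8)
    out = np.empty(packed.size * 2, dtype=np.int8)
    out[0::2] = (packed >> 4).astype(np.int8) - 8
    out[1::2] = (packed & 0x0F).astype(np.int8) - 8
    return out
```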

### 3. XSA (Exclusive Self Attention) — Last 4 Layers
Removes the self-value component from attention output, forcing attention to model only
contextual information orthogonal to the self-representation (arxiv:2603.09078).

```python
# Project out the component of the attention output y that lies along
# each token's own normalized value vector, keeping only the context.
vn = F.normalize(v, dim=-1)
y = y - (y * vn).sum(dim=-1, keepdim=True) * vn
```

### 4. Partial RoPE (16 of 64 dims)
Apply rotary position embeddings to only 25% of head dimensions. The remaining 75%
attend without positional bias, acting as position-independent feature detectors.
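
A sketch of partial RoPE under the half-split rotation convention (the submission's exact pairing of dimensions is an assumption), rotating only the first 16 of 64 dims:

```python
import torch

def partial_rope(x, rotary_dims=16, base=10000.0):
    """Apply RoPE to the first `rotary_dims` of the head dimension,
    leaving the remaining dims untouched.
    x: (batch, heads, seq, head_dim)."""
    T = x.shape[-2]
    x_rot, x_pass = x[..., :rotary_dims], x[..., rotary_dims:]
    half = rotary_dims // 2
    # Standard RoPE frequency schedule over the rotated dims.
    freqs = base ** (-torch.arange(half, dtype=torch.float32) / half)
    angles = torch.arange(T, dtype=torch.float32)[:, None] * freqs  # (T, half)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x_rot[..., :half], x_rot[..., half:]
    rotated = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)
```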

### 5. LN Scale
Block outputs scaled by `1/sqrt(layer_idx + 1)`. Damps deeper layers' contributions,
stabilizing training. Zero parameters.
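
The scaling amounts to a fixed, parameter-free factor per block; whether it multiplies the residual branch exactly as sketched here is an assumption:

```python
import math

def block_scale(layer_idx):
    """Parameter-free damping factor for block `layer_idx` (0-indexed)."""
    return 1.0 / math.sqrt(layer_idx + 1)

# Applied on the residual branch, e.g.:
#   x = x + block_scale(i) * block(x)
# so layer 0 contributes at full strength and layer 9 at 1/sqrt(10).
```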

### 6. Higher Learning Rates
- matrix_lr: 0.025 (up from 0.02)
- scalar_lr: 0.025
- tied_embed_lr: 0.035

## Architecture
- 10 layers, 512 dim, 8 heads, 4 KV heads (GQA)
- MLP 3x expansion (hidden=1536), relu² activation
- SmearGate + exact BigramLogitHead (count-initialized, int4 packed)
- Orthogonal init with muP-scaled output projections
- U-Net skip connections, tied embeddings

## Training Hyperparameters
- Muon optimizer: matrix_lr=0.025, WD=0.04, momentum=0.99
- AdamW for embeddings/scalars: WD=0.04
- warmdown=2800 iters, warmup=20 steps
- seq_len=2048, batch=786K tokens
- grad_clip=0.3, 3% magnitude pruning
- SWA: start_frac=0.4, every=50 steps (22 checkpoints)
- Sliding window eval: stride=64
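
The SWA step above (uniformly averaging the saved checkpoints) can be sketched as follows; `swa_average` is an illustrative name, not the submission's function:

```python
import torch

def swa_average(state_dicts):
    """Uniformly average a list of checkpoint state_dicts (SWA)."""
    avg = {k: v.clone().float() for k, v in state_dicts[0].items()}
    for sd in state_dicts[1:]:
        for k in avg:
            avg[k] += sd[k].float()
    for k in avg:
        avg[k] /= len(state_dicts)
    return avg
```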

## Quantization
- Int5 [-16,15] for MLP weights
- Int6 [-32,31] for attention weights
- Int4 [-8,7] nibble-packed for bigram logit table
- FP16 for tied embeddings
- zstd-22 compression
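
A minimal sketch of a symmetric quantization roundtrip matching the int5/int6 ranges above; the submission's exact scaling scheme (per-tensor vs. per-channel) is not specified, so per-tensor scaling here is an assumption:

```python
import numpy as np

def quantize_roundtrip(w, bits):
    """Symmetric per-tensor quantization to signed `bits`-bit integers
    and back, e.g. bits=5 -> levels [-16, 15], bits=6 -> [-32, 31]."""
    lo, hi = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / hi          # map the largest weight to `hi`
    q = np.clip(np.round(w / scale), lo, hi).astype(np.int8)
    return q, q.astype(np.float32) * scale
```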

## Results
```
Steps completed: 6267 (wallclock capped at 600s)
Step time: 95.75 ms/step
Peak memory: 19609 MiB allocated, 19878 MiB reserved

Pre-SWA val_bpb: 1.1563
Post-SWA (22 checkpoints): improved
Post-quant roundtrip val_bpb: 1.1522
Quant gap: 0.004 bpb

Artifact size: 15,322,709 bytes (int6+zstd)
Code size: 61,523 bytes
Total: 15,384,232 bytes (under 16,000,000 limit)
```

## What's Novel in This Submission
1. **Count-initialized bigram logit head** — no other submission uses corpus-derived
transition probabilities as logit biases. Provides a strong Markov prior from step 0.
2. **Int4 nibble packing** — custom bit-packing for the bigram table, halving storage.
3. **Combined with XSA + Partial RoPE + LN Scale** — these are adopted from recent
papers and pending PRs, but the combination with count-init bigram is unique.

Built on the SOTA baseline by @thwu1 (PR #180); adopts XSA from arxiv:2603.09078 and
both Partial RoPE and LN Scale from PR #315.
{
"name": "10L CountInitBigram + XSA + PartialRoPE + LN Scale",
"val_loss": 1.94544622,
"val_bpb": 1.15220588,
"bytes_total": 15384232,
"bytes_model_int6_zstd": 15322709,
"bytes_code": 61523,
"blurb": "10 layers with count-initialized exact bigram logit head (corpus log-prob residuals, int4 nibble-packed), XSA on last 4 layers, Partial RoPE (16/64 dims), LN Scale (1/sqrt(layer+1)), higher LR (0.025). SWA over 22 checkpoints, sliding window eval stride=64.",
"author": "Sri Harsha Gouru",
"github_id": "harsha-gouru",
"date": "2026-03-22"
}
Final summary: {'run_id': 'submission-10L-v1', 'gpus': 8, 'exit_code': 0, 'elapsed_s': 1140.4, 'artifact_bytes': 15322709, 'code_bytes': 61523, 'total_bytes': 15384232, 'under_limit': True, 'headroom_bytes': 615768, 'final_line': 'final_int8_zlib_roundtrip_exact val_loss:1.94544622 val_bpb:1.15220588', 'val_bpb': 1.15220588}
val_bpb = 1.152206
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/data/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=/data/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
bigram_init:count_based time:0.22s
model_params:25779282
world_size:8 grad_accum_steps:1
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:42
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:7.3822 val_bpb:4.3721 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:7.3777 train_time:223ms step_avg:223.20ms
step:2/20000 train_loss:5.1389 train_time:303ms step_avg:151.25ms
step:3/20000 train_loss:6.8711 train_time:399ms step_avg:132.99ms
step:4/20000 train_loss:7.9386 train_time:498ms step_avg:124.43ms
step:5/20000 train_loss:7.0619 train_time:593ms step_avg:118.64ms
step:6/20000 train_loss:6.4111 train_time:689ms step_avg:114.78ms
step:7/20000 train_loss:5.8449 train_time:785ms step_avg:112.11ms
step:8/20000 train_loss:5.4955 train_time:879ms step_avg:109.93ms
step:9/20000 train_loss:5.4268 train_time:974ms step_avg:108.26ms
step:10/20000 train_loss:5.1444 train_time:1069ms step_avg:106.90ms
step:100/20000 train_loss:3.1036 train_time:9588ms step_avg:95.88ms
step:200/20000 train_loss:2.3712 train_time:19185ms step_avg:95.93ms
step:300/20000 train_loss:2.5539 train_time:28785ms step_avg:95.95ms
step:400/20000 train_loss:2.4180 train_time:38370ms step_avg:95.92ms
step:500/20000 train_loss:2.4018 train_time:47882ms step_avg:95.76ms
step:500/20000 val_loss:2.3625 val_bpb:1.3992 train_time:47906ms step_avg:95.81ms
step:600/20000 train_loss:2.3373 train_time:57505ms step_avg:95.84ms
step:700/20000 train_loss:2.3546 train_time:67225ms step_avg:96.04ms
step:800/20000 train_loss:2.2471 train_time:76829ms step_avg:96.04ms
step:900/20000 train_loss:2.1361 train_time:86414ms step_avg:96.02ms
step:1000/20000 train_loss:2.2881 train_time:95886ms step_avg:95.89ms
step:1000/20000 val_loss:2.2347 val_bpb:1.3235 train_time:95910ms step_avg:95.91ms
step:1100/20000 train_loss:2.3264 train_time:105475ms step_avg:95.89ms
step:1200/20000 train_loss:2.3602 train_time:115065ms step_avg:95.89ms
step:1300/20000 train_loss:2.1075 train_time:124651ms step_avg:95.89ms
step:1400/20000 train_loss:2.1932 train_time:134238ms step_avg:95.88ms
step:1500/20000 train_loss:2.2303 train_time:143696ms step_avg:95.80ms
step:1500/20000 val_loss:2.1931 val_bpb:1.2989 train_time:143719ms step_avg:95.81ms
step:1600/20000 train_loss:2.0843 train_time:153282ms step_avg:95.80ms
step:1700/20000 train_loss:2.1487 train_time:162862ms step_avg:95.80ms
step:1800/20000 train_loss:2.1646 train_time:172451ms step_avg:95.81ms
step:1900/20000 train_loss:2.1335 train_time:181904ms step_avg:95.74ms
step:2000/20000 train_loss:2.0755 train_time:191501ms step_avg:95.75ms
step:2000/20000 val_loss:2.1413 val_bpb:1.2682 train_time:191524ms step_avg:95.76ms
step:2100/20000 train_loss:2.0542 train_time:201099ms step_avg:95.76ms
step:2200/20000 train_loss:2.1418 train_time:210687ms step_avg:95.77ms
step:2300/20000 train_loss:2.1218 train_time:220300ms step_avg:95.78ms
step:2400/20000 train_loss:2.0761 train_time:229761ms step_avg:95.73ms
step:2500/20000 train_loss:2.1798 train_time:239357ms step_avg:95.74ms
step:2500/20000 val_loss:2.1160 val_bpb:1.2532 train_time:239380ms step_avg:95.75ms
step:2600/20000 train_loss:2.1206 train_time:248938ms step_avg:95.75ms
step:2700/20000 train_loss:2.1137 train_time:258488ms step_avg:95.74ms
step:2800/20000 train_loss:2.1632 train_time:268049ms step_avg:95.73ms
step:2900/20000 train_loss:2.0369 train_time:277499ms step_avg:95.69ms
step:3000/20000 train_loss:2.1721 train_time:287094ms step_avg:95.70ms
step:3000/20000 val_loss:2.1034 val_bpb:1.2457 train_time:287118ms step_avg:95.71ms
step:3100/20000 train_loss:2.0467 train_time:296656ms step_avg:95.70ms
step:3200/20000 train_loss:2.1853 train_time:306283ms step_avg:95.71ms
step:3300/20000 train_loss:2.0843 train_time:315738ms step_avg:95.68ms
step:3400/20000 train_loss:2.0352 train_time:325337ms step_avg:95.69ms
step:3500/20000 train_loss:2.1954 train_time:334966ms step_avg:95.70ms
step:3500/20000 val_loss:2.0956 val_bpb:1.2411 train_time:334990ms step_avg:95.71ms
step:3600/20000 train_loss:2.1086 train_time:344546ms step_avg:95.71ms
step:3700/20000 train_loss:2.1098 train_time:354112ms step_avg:95.71ms
step:3800/20000 train_loss:2.0836 train_time:363568ms step_avg:95.68ms
step:3900/20000 train_loss:2.0833 train_time:373128ms step_avg:95.67ms
step:4000/20000 train_loss:1.9800 train_time:382712ms step_avg:95.68ms
step:4000/20000 val_loss:2.0729 val_bpb:1.2277 train_time:382735ms step_avg:95.68ms
step:4100/20000 train_loss:2.0146 train_time:392308ms step_avg:95.68ms
step:4200/20000 train_loss:2.1597 train_time:401902ms step_avg:95.69ms
step:4300/20000 train_loss:2.0659 train_time:411352ms step_avg:95.66ms
step:4400/20000 train_loss:2.0390 train_time:420921ms step_avg:95.66ms
step:4500/20000 train_loss:2.1277 train_time:430528ms step_avg:95.67ms
step:4500/20000 val_loss:2.0468 val_bpb:1.2122 train_time:430552ms step_avg:95.68ms
step:4600/20000 train_loss:1.8398 train_time:440101ms step_avg:95.67ms
step:4700/20000 train_loss:2.2341 train_time:449553ms step_avg:95.65ms
step:4800/20000 train_loss:2.4262 train_time:459158ms step_avg:95.66ms
step:4900/20000 train_loss:2.0473 train_time:468724ms step_avg:95.66ms
step:5000/20000 train_loss:2.1047 train_time:478331ms step_avg:95.67ms
step:5000/20000 val_loss:2.0222 val_bpb:1.1976 train_time:478355ms step_avg:95.67ms
step:5100/20000 train_loss:2.1237 train_time:487891ms step_avg:95.66ms
swa:start step:5200
step:5200/20000 train_loss:2.0397 train_time:497349ms step_avg:95.64ms
step:5300/20000 train_loss:2.0038 train_time:507084ms step_avg:95.68ms
step:5400/20000 train_loss:2.0424 train_time:516727ms step_avg:95.69ms
step:5500/20000 train_loss:2.0107 train_time:526366ms step_avg:95.70ms
step:5500/20000 val_loss:1.9943 val_bpb:1.1811 train_time:526416ms step_avg:95.71ms
step:5600/20000 train_loss:1.9457 train_time:535996ms step_avg:95.71ms
step:5700/20000 train_loss:2.0044 train_time:545480ms step_avg:95.70ms
step:5800/20000 train_loss:1.9851 train_time:555138ms step_avg:95.71ms
step:5900/20000 train_loss:1.8864 train_time:564743ms step_avg:95.72ms
step:6000/20000 train_loss:1.9267 train_time:574385ms step_avg:95.73ms
step:6000/20000 val_loss:1.9648 val_bpb:1.1636 train_time:574435ms step_avg:95.74ms
step:6100/20000 train_loss:1.9044 train_time:583877ms step_avg:95.72ms
step:6200/20000 train_loss:1.9320 train_time:593509ms step_avg:95.73ms
step:6267/20000 val_loss:1.9523 val_bpb:1.1563 train_time:600046ms step_avg:95.75ms
stopping_early: wallclock_cap train_time:600046ms step:6267/20000
peak memory allocated: 19609 MiB reserved: 19878 MiB
swa:applying averaged 22 checkpoints
Serialized model: 98962351 bytes
Serialized model int6+zstd: 15322709 bytes
final_eval_mode:sliding_window stride:64 batch_seqs:32
final_int8_zlib_roundtrip val_loss:1.9454 val_bpb:1.1522 eval_time:183742ms
final_int8_zlib_roundtrip_exact val_loss:1.94544622 val_bpb:1.15220588