@@ -0,0 +1,34 @@
# Optimized Baseline with Mixed Quantization and BigramHash

Tried a bunch of things on top of the baseline script. Some worked, some didn't. Here's what stuck.

## What I changed

**Architecture:**
- Bumped to 10 layers (from 9) with U-Net skip connections
- ReLU² MLP with 3x expansion (hidden=1536) — same param count per layer as SwiGLU but faster since it's only 2 matmuls
- Added BigramHash embeddings — hash table of 10240 bigram pairs mapped to 128-dim vectors, projected to model dim. Gives the model cheap access to "what was the previous token" without burning attention compute on it
- Orthogonal init for all weight matrices, SVD-based init for embeddings
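The two architecture pieces above that differ most from a vanilla transformer block can be sketched roughly as follows. This is a minimal illustration, not the submission's code: the hash multiplier, the model dim of 512 (inferred from hidden=1536 at 3x expansion), and the class names are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BigramHashEmbedding(nn.Module):
    """Hash each (previous token, current token) pair into a small table of
    low-dim vectors, then project up to model dim. Table and hash-dim sizes
    match the writeup (10240 x 128); the hash function is illustrative."""
    def __init__(self, table_size=10240, hash_dim=128, model_dim=512):
        super().__init__()
        self.table_size = table_size
        self.table = nn.Embedding(table_size, hash_dim)
        self.proj = nn.Linear(hash_dim, model_dim, bias=False)

    def forward(self, tokens):  # tokens: (batch, seq) of token ids
        # previous token at each position (0 at the start of the sequence)
        prev = F.pad(tokens, (1, 0), value=0)[:, :-1]
        # cheap multiplicative hash of the bigram into the table
        idx = (prev * 1000003 + tokens) % self.table_size
        return self.proj(self.table(idx))  # (batch, seq, model_dim)

class ReLU2MLP(nn.Module):
    """ReLU^2 MLP with 3x expansion: two matmuls versus SwiGLU's three."""
    def __init__(self, dim=512, hidden=1536):
        super().__init__()
        self.up = nn.Linear(dim, hidden, bias=False)
        self.down = nn.Linear(hidden, dim, bias=False)

    def forward(self, x):
        return self.down(F.relu(self.up(x)) ** 2)
```

The bigram table adds only `10240 * 128` embedding parameters plus one small projection, which is why it is cheap relative to spending attention capacity on the previous-token lookup.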

**Quantization:**
- Mixed precision: INT5 for MLP weight matrices, INT6 for attention matrices, INT8 for embeddings
- Straight-Through Estimator during training so the model learns to deal with quantization from the start
- zstd level 22 compression instead of zlib — squeezes out a few more percent
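The STE trick amounts to fake quantization where the forward pass sees the quantized weights but the gradient passes through as identity. A minimal sketch, assuming symmetric per-tensor scaling (the submission's exact scheme may differ):

```python
import torch

def ste_quantize(w, bits=6):
    """Fake-quantize w to a signed `bits`-wide grid with a straight-through
    estimator: forward uses the quantized values, backward is identity."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().max().clamp(min=1e-8) / qmax
    w_q = (w / scale).round().clamp(-qmax - 1, qmax) * scale
    # STE: the (w_q - w) correction is detached, so d(out)/d(w) = 1
    return w + (w_q - w).detach()
```

Applying this to weights on every forward pass is what lets the model adapt to the quantization grid during training instead of being hit by it only at export time.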

**Training:**
- Muon optimizer with weight decay 0.04 and gradient clipping at 0.3
- Stochastic Weight Averaging over the last 50% of training (every 50 steps) — this smooths out the weight distribution which helps quantization a lot
- Momentum warmup from 0.85 to 0.99 over 1500 steps
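The SWA and momentum-warmup pieces are simple enough to sketch. `SWATracker` and `momentum_at` are hypothetical names for illustration; the intervals and endpoints mirror the writeup (average every 50 steps over the last 50%, momentum 0.85 to 0.99 over 1500 steps):

```python
import torch
import torch.nn as nn

class SWATracker:
    """Running average of weight snapshots taken during the tail of training."""
    def __init__(self, model):
        self.avg = {k: v.detach().clone() for k, v in model.state_dict().items()}
        self.n = 1

    def update(self, model):
        # incremental mean: avg += (x - avg) / n
        self.n += 1
        for k, v in model.state_dict().items():
            self.avg[k] += (v.detach() - self.avg[k]) / self.n

def momentum_at(step, warmup_steps=1500, start=0.85, end=0.99):
    """Linear momentum warmup, then held constant at `end`."""
    t = min(step / warmup_steps, 1.0)
    return start + t * (end - start)
```

Averaging snapshots pulls the final weights toward a flatter region of the loss surface, which plausibly explains why it interacts well with quantization: flat minima tolerate the rounding perturbation better.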

## Result

- **val_bpb: 1.2421** (post-roundtrip)
- Pre-roundtrip was 1.1924, so quantization costs about 0.05 bpb
- 11,070 steps in 600 seconds on 8xH100 SXM (~54 ms/step)
- Artifact: 13.28 MB (well under 16 MB limit)

## What I'd do differently

The main bottleneck is quantization degradation. The pre-roundtrip score of 1.19 would be competitive, but the mixed INT5/INT6 quantization adds ~0.05 bpb, while the top submissions get that gap down to 0.01-0.02. I think better STE scheduling or per-channel quantization could help here.
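Per-channel quantization means keeping one scale per output row instead of one per tensor, which shrinks the rounding error whenever rows have very different magnitudes. A sketch of the idea (hypothetical helper, not the submission's code):

```python
import torch

def per_channel_quantize(w, bits=6):
    """Symmetric quantization with one scale per output channel (row)."""
    qmax = 2 ** (bits - 1) - 1
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
    return (w / scale).round().clamp(-qmax - 1, qmax) * scale
```

The per-row scales add a small storage overhead (one value per output channel), which the zstd stage would need to absorb.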

Also didn't get to train on the full 80 shards — only used 10 due to disk constraints on the cloud setup. More data would probably help.
@@ -0,0 +1,6 @@
{
"author": "Serghei Brinza",
"github": "SergheiBrinza",
"val_bpb": 1.2421,
"date": "2026-03-21"
}
@@ -0,0 +1,121 @@
W0321 19:29:27.143000 813 torch/distributed/run.py:774]
W0321 19:29:27.143000 813 torch/distributed/run.py:774] *****************************************
W0321 19:29:27.143000 813 torch/distributed/run.py:774] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0321 19:29:27.143000 813 torch/distributed/run.py:774] *****************************************
logs/025_final.txt
val_bpb:enabled tokenizer_path=/workspace/parameter-golf/tok.model
train_shards:10 val_tokens:62021632
model_params:25516624 mlp_hidden:1536
world_size:8 grad_accum_steps:1
features: ReLU²MLP3x BigramHash(10240x128) OrthoInit MixedINT5/6STE SWA(0.5@50) GradClip(0.3) MuonWD(0.04)
compress:zstd-22 quant:INT5(MLP)/INT6(attn)/INT8(embed)
tie_embeddings:True embed_lr:0.05 matrix_lr:0.02
train_batch_tokens:524288 seq_len:1024 iterations:20000 warmup:20 wallclock:600s
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9315 val_bpb:4.1052 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9315 train_time:113ms step_avg:113.33ms
step:2/20000 train_loss:15.4686 train_time:156ms step_avg:77.88ms
step:3/20000 train_loss:7.1075 train_time:212ms step_avg:70.64ms
step:4/20000 train_loss:6.0783 train_time:268ms step_avg:66.96ms
step:5/20000 train_loss:6.7387 train_time:323ms step_avg:64.68ms
step:6/20000 train_loss:6.9151 train_time:386ms step_avg:64.25ms
step:7/20000 train_loss:6.0252 train_time:454ms step_avg:64.81ms
step:8/20000 train_loss:5.7453 train_time:509ms step_avg:63.64ms
step:9/20000 train_loss:5.5471 train_time:565ms step_avg:62.79ms
step:10/20000 train_loss:5.3234 train_time:620ms step_avg:62.03ms
step:200/20000 train_loss:2.9903 train_time:12751ms step_avg:63.76ms
step:400/20000 train_loss:2.3684 train_time:25490ms step_avg:63.73ms
step:600/20000 train_loss:2.5532 train_time:38247ms step_avg:63.74ms
step:800/20000 train_loss:2.3092 train_time:50992ms step_avg:63.74ms
step:1000/20000 train_loss:2.4088 train_time:63686ms step_avg:63.69ms
step:1000/20000 val_loss:2.3622 val_bpb:1.3990 train_time:63705ms step_avg:63.71ms
step:1200/20000 train_loss:2.4158 train_time:76331ms step_avg:63.61ms
step:1400/20000 train_loss:2.4674 train_time:88942ms step_avg:63.53ms
step:1600/20000 train_loss:2.1494 train_time:101577ms step_avg:63.49ms
step:1800/20000 train_loss:2.2301 train_time:114422ms step_avg:63.57ms
step:2000/20000 train_loss:2.2556 train_time:124884ms step_avg:62.44ms
step:2000/20000 val_loss:2.2527 val_bpb:1.3342 train_time:124902ms step_avg:62.45ms
step:2200/20000 train_loss:2.3549 train_time:135354ms step_avg:61.52ms
step:2400/20000 train_loss:2.3675 train_time:145869ms step_avg:60.78ms
step:2600/20000 train_loss:2.2148 train_time:156406ms step_avg:60.16ms
step:2800/20000 train_loss:2.1670 train_time:166926ms step_avg:59.62ms
step:3000/20000 train_loss:3.2060 train_time:177436ms step_avg:59.15ms
step:3000/20000 val_loss:2.1964 val_bpb:1.3008 train_time:177456ms step_avg:59.15ms
step:3200/20000 train_loss:2.2708 train_time:187960ms step_avg:58.74ms
step:3400/20000 train_loss:2.1020 train_time:198463ms step_avg:58.37ms
step:3600/20000 train_loss:2.2082 train_time:208980ms step_avg:58.05ms
step:3800/20000 train_loss:2.1517 train_time:219467ms step_avg:57.75ms
step:4000/20000 train_loss:2.2769 train_time:229968ms step_avg:57.49ms
step:4000/20000 val_loss:2.1703 val_bpb:1.2854 train_time:229986ms step_avg:57.50ms
step:4200/20000 train_loss:2.2248 train_time:240536ms step_avg:57.27ms
step:4400/20000 train_loss:2.1726 train_time:251010ms step_avg:57.05ms
step:4600/20000 train_loss:2.2063 train_time:261502ms step_avg:56.85ms
step:4800/20000 train_loss:2.1472 train_time:271977ms step_avg:56.66ms
step:5000/20000 train_loss:2.2387 train_time:282463ms step_avg:56.49ms
step:5000/20000 val_loss:2.1600 val_bpb:1.2793 train_time:282481ms step_avg:56.50ms
step:5200/20000 train_loss:2.2970 train_time:292927ms step_avg:56.33ms
swa_started step:5335
step:5400/20000 train_loss:2.2399 train_time:303399ms step_avg:56.19ms
step:5600/20000 train_loss:2.1548 train_time:313874ms step_avg:56.05ms
step:5800/20000 train_loss:2.1925 train_time:324342ms step_avg:55.92ms
step:6000/20000 train_loss:2.1065 train_time:334822ms step_avg:55.80ms
step:6000/20000 val_loss:2.1466 val_bpb:1.2713 train_time:334841ms step_avg:55.81ms
step:6200/20000 train_loss:2.0958 train_time:345297ms step_avg:55.69ms
step:6400/20000 train_loss:1.8782 train_time:355767ms step_avg:55.59ms
step:6600/20000 train_loss:2.1018 train_time:366227ms step_avg:55.49ms
step:6800/20000 train_loss:2.1489 train_time:376678ms step_avg:55.39ms
step:7000/20000 train_loss:2.0989 train_time:387145ms step_avg:55.31ms
step:7000/20000 val_loss:2.1409 val_bpb:1.2680 train_time:387163ms step_avg:55.31ms
step:7200/20000 train_loss:1.9839 train_time:397593ms step_avg:55.22ms
step:7400/20000 train_loss:1.9241 train_time:408043ms step_avg:55.14ms
step:7600/20000 train_loss:2.1767 train_time:418494ms step_avg:55.07ms
step:7800/20000 train_loss:2.1418 train_time:428942ms step_avg:54.99ms
step:8000/20000 train_loss:2.0734 train_time:439395ms step_avg:54.92ms
step:8000/20000 val_loss:2.1325 val_bpb:1.2630 train_time:439413ms step_avg:54.93ms
step:8200/20000 train_loss:2.2139 train_time:449842ms step_avg:54.86ms
step:8400/20000 train_loss:2.1893 train_time:460366ms step_avg:54.81ms
step:8600/20000 train_loss:2.2138 train_time:470809ms step_avg:54.75ms
step:8800/20000 train_loss:2.0684 train_time:481258ms step_avg:54.69ms
step:9000/20000 train_loss:2.0865 train_time:491704ms step_avg:54.63ms
step:9000/20000 val_loss:2.1004 val_bpb:1.2440 train_time:491724ms step_avg:54.64ms
step:9200/20000 train_loss:2.1596 train_time:502152ms step_avg:54.58ms
step:9400/20000 train_loss:2.0030 train_time:512596ms step_avg:54.53ms
step:9600/20000 train_loss:1.9800 train_time:523049ms step_avg:54.48ms
step:9800/20000 train_loss:2.0284 train_time:533581ms step_avg:54.45ms
step:10000/20000 train_loss:1.9803 train_time:544026ms step_avg:54.40ms
step:10000/20000 val_loss:2.0606 val_bpb:1.2204 train_time:544053ms step_avg:54.41ms
step:10200/20000 train_loss:2.0710 train_time:554489ms step_avg:54.36ms
step:10400/20000 train_loss:2.0309 train_time:564930ms step_avg:54.32ms
step:10600/20000 train_loss:1.9979 train_time:575372ms step_avg:54.28ms
step:10800/20000 train_loss:2.0367 train_time:585815ms step_avg:54.24ms
step:11000/20000 train_loss:1.9998 train_time:596280ms step_avg:54.21ms
step:11000/20000 val_loss:2.0144 val_bpb:1.1930 train_time:596298ms step_avg:54.21ms
step:11070/20000 val_loss:2.0132 val_bpb:1.1924 train_time:599998ms step_avg:54.20ms
stopping_early: wallclock train_time:599998ms step:11070/20000
peak_mem alloc:13930MiB reserved:14022MiB
loaded SWA weights (averaged 115 snapshots)
raw model:98434727 code:40698 total:98475425
quant+zstd: 13238730 bytes (payload:25745728 ratio:3.82x)
artifact: 13279428 bytes PASS
roundtrip val_loss:2.0972 val_bpb:1.2421 eval:1782ms
final_roundtrip val_loss:2.09723679 val_bpb:1.24210176