`README.md` (1 addition, 0 deletions)

| Run | Score | Author | Summary | Date | Info |
|-----|------:|--------|---------|------|------|
| 11L XSA4 + EMA + Batch524K | 1.1357 | dennisimoo | 11 layers, XSA on last 4 layers, EMA, batch 524288, official-template-safe zstd fallback | 2026-03-21 | [info](records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_B524K_ZstdFallback/README.md) |
| 10L Int5-MLP + BigramHash(10240) | 1.1428 | thwu1 | 10 layers, mixed int5/int6 quantization, BigramHash(10240), SWA(0.4), WD=0.04 | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_10L_Int5MLP_MuonWD04_SWA50/README.md) |
| Int6 MLP3x + SmearGate + BigramHash | 1.1458 | Raahil Shah | 3x MLP + SmearGate + BigramHash + OrthoInit + Muon WD + SWA | 2026-03-20 | [info](records/track_10min_16mb/2026-03-20_Int6_MLP3x_SmearGate_BigramHash_MuonWD_SWA/README.md) |
| 11L MLP3x + Int6 QAT | 1.1502 | aruniyer | 11 layers, 3x MLP, int6 QAT, zstd-22, WD=0.04, sliding eval | 2026-03-20 | [info](records/track_10min_16mb/2026-03-19_MLP3x_QAT_Int6_SlidingWindow/README.md) |
`records/track_10min_16mb/2026-03-21_11L_XSA4_EMA_B524K_ZstdFallback/README.md` (new file)

# Record: 11L XSA4 + EMA + Batch524K + zstd Fallback

**val_bpb = 1.1357** (sliding window, stride=64) | **15.67 MB** artifact | 8xH100 SXM, ~600s

Single-seed submission using an 11-layer int6 MLP3x model with XSA on the last 4 layers, EMA weight averaging, SmearGate, BigramHash, and a 524,288-token batch under the fixed 600 s budget.
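The EMA component keeps a decayed shadow copy of the weights for evaluation. A minimal sketch in plain Python (hypothetical helper name; the actual `train_gpt.py` implementation may differ), using this run's `EMA_DECAY=0.997`:

```python
def ema_update(ema_params, params, decay=0.997):
    """In-place exponential moving average: ema <- decay*ema + (1-decay)*p."""
    for i, p in enumerate(params):
        ema_params[i] = decay * ema_params[i] + (1.0 - decay) * p
    return ema_params

# Toy usage: a single scalar weight pinned at 1.0; the EMA copy
# drifts from 0 toward 1 at rate (1 - decay) per step.
ema = [0.0]
for _ in range(3):
    ema = ema_update(ema, [1.0])
```

Evaluation then runs with the EMA weights instead of the raw optimizer weights.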

## Result

| Metric | Value |
|--------|-------|
| Pre-quant val_bpb | 1.1529 |
| Int6 roundtrip val_bpb | 1.1580 |
| **Int6 sliding val_bpb (stride 64)** | **1.1357** |
| Model bytes (int6+zstd) | 15,603,062 |
| Code bytes | 66,891 |
| **Total submission bytes** | **15,669,953** |

This val_bpb is below the current merged README SOTA (`1.1428`), but as a single-seed run it is not a 3-seed validated record claim.
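The loss-to-bpb conversion is linear (bpb = loss in nats ÷ (ln 2 × average bytes per token)), so every loss/bpb pair reported in `submission.json` should share the same ratio. A quick consistency check on this record's numbers:

```python
import math

# (val_loss in nats/token, val_bpb) pairs reported in submission.json.
pairs = [
    (1.91760887, 1.13571899),   # final int6 sliding eval
    (1.9465, 1.1529),           # pre-quant
    (1.95527145, 1.15802189),   # int6 roundtrip
]

# bpb = loss / (ln(2) * bytes_per_token)  =>  loss/bpb is constant.
ratios = [loss / bpb for loss, bpb in pairs]

# Implied average bytes per token of the eval data.
bytes_per_token = ratios[0] / math.log(2)
```

The three ratios agree to within rounding of the reported figures, which is a useful sanity check that all three evals used the same tokenizer and eval set.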

## What's New

| Change | Impact |
|--------|--------|
| `TRAIN_BATCH_TOKENS=524288` | More optimizer steps within the fixed time budget than the larger-batch 11-layer XSA+EMA setting |
| SDPA fallback for `flash_attn_interface` | Runs cleanly when FA3 Python bindings are unavailable in the official image |
| `torch.compile` behind an env flag | Reliable eager smoke tests, faster compiled full run |
| `zstd` Python-or-CLI fallback | Keeps int6 export under 16MB without depending on a specific Python package in the image |
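The zstd fallback in the last row can be sketched as: prefer the `zstandard` Python package if importable, else shell out to the `zstd` CLI, else fail loudly. This is a hypothetical helper illustrating the pattern, not the record's exact export code:

```python
import shutil
import subprocess

def zstd_compress(data: bytes, level: int = 22) -> bytes:
    """Compress with the zstandard package if available, else the zstd CLI."""
    try:
        import zstandard  # optional dependency; may be absent in the image
        return zstandard.ZstdCompressor(level=level).compress(data)
    except ImportError:
        pass
    if shutil.which("zstd"):
        # --ultra unlocks levels above 19; -q quiet, -c write to stdout.
        proc = subprocess.run(
            ["zstd", "--ultra", f"-{level}", "-q", "-c"],
            input=data, stdout=subprocess.PIPE, check=True,
        )
        return proc.stdout
    raise RuntimeError("no zstd available (need `zstandard` package or zstd CLI)")
```

Either path produces a standard zstd frame, so the decompression side can use whichever implementation the evaluation image happens to have.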

## Configuration

```bash
NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \
TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=2048 EVAL_SEQ_LEN=2048 EVAL_STRIDE=64 \
BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 \
XSA_LAST_N=4 EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 TTT_ENABLED=0 \
MUON_WD=0.04 ADAM_WD=0.04 MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500 \
WARMDOWN_ITERS=3000 WARMUP_STEPS=20 ENABLE_TORCH_COMPILE=1 \
MAX_WALLCLOCK_SECONDS=600 torchrun --standalone --nproc_per_node=8 train_gpt.py
```
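Every knob above is a plain environment variable. A minimal sketch of how such a config can be read (hypothetical helper names; the actual parsing in `train_gpt.py` may differ), with defaults mirroring this run:

```python
import os

def env_int(name: str, default: int) -> int:
    """Read an integer knob from the environment, falling back to a default."""
    return int(os.environ.get(name, default))

def env_float(name: str, default: float) -> float:
    return float(os.environ.get(name, default))

NUM_LAYERS = env_int("NUM_LAYERS", 11)
TRAIN_BATCH_TOKENS = env_int("TRAIN_BATCH_TOKENS", 524288)
EMA_DECAY = env_float("EMA_DECAY", 0.997)
ENABLE_TORCH_COMPILE = env_int("ENABLE_TORCH_COMPILE", 1) == 1
```

Keeping the whole configuration in the environment is what lets the launch wrapper below override any knob without touching the training script.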

## Key Run Details

| Metric | Value |
|--------|-------|
| Steps reached | 8,202 |
| Average train step time | 73.37 ms |
| Peak memory allocated | 13,879 MiB |
| Peak memory reserved | 14,004 MiB |
| Final eval mode | Sliding window, stride 64 |
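Steps × average step time lands within about 0.3% of the `MAX_WALLCLOCK_SECONDS=600` budget, which is a quick arithmetic check that the reported numbers are mutually consistent:

```python
# Reported run metrics from the table above.
steps = 8202
avg_step_ms = 73.37

# Implied total training-loop time in seconds (~601.8 s).
train_seconds = steps * avg_step_ms / 1000.0
```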

## Included Files

- `train_gpt.py` — training, export, and eval script
- `run_hybrid_attempt.sh` — launch wrapper used for the run
- `train.log` — full log from the validated 600s attempt
- `submission.json` — metadata for the submission
`run_hybrid_attempt.sh` (new file)

#!/usr/bin/env bash
set -euo pipefail

SEED="${SEED:-1337}"
NPROC_PER_NODE="${NPROC_PER_NODE:-8}"

export RUN_ID="${RUN_ID:-hybrid_xsa_ema_safe}"
export SEED
export NUM_LAYERS="${NUM_LAYERS:-11}"
export MODEL_DIM="${MODEL_DIM:-512}"
export NUM_HEADS="${NUM_HEADS:-8}"
export NUM_KV_HEADS="${NUM_KV_HEADS:-4}"
export MLP_MULT="${MLP_MULT:-3}"
export TRAIN_SEQ_LEN="${TRAIN_SEQ_LEN:-2048}"
export EVAL_SEQ_LEN="${EVAL_SEQ_LEN:-2048}"
export TRAIN_BATCH_TOKENS="${TRAIN_BATCH_TOKENS:-524288}"
export EVAL_STRIDE="${EVAL_STRIDE:-64}"
export BIGRAM_VOCAB_SIZE="${BIGRAM_VOCAB_SIZE:-2048}"
export BIGRAM_DIM="${BIGRAM_DIM:-128}"
export MATRIX_LR="${MATRIX_LR:-0.025}"
export SCALAR_LR="${SCALAR_LR:-0.025}"
export TIED_EMBED_LR="${TIED_EMBED_LR:-0.035}"
export MUON_MOMENTUM="${MUON_MOMENTUM:-0.99}"
export MUON_MOMENTUM_WARMUP_START="${MUON_MOMENTUM_WARMUP_START:-0.92}"
export MUON_MOMENTUM_WARMUP_STEPS="${MUON_MOMENTUM_WARMUP_STEPS:-1500}"
export MUON_WD="${MUON_WD:-0.04}"
export ADAM_WD="${ADAM_WD:-0.04}"
export WARMDOWN_ITERS="${WARMDOWN_ITERS:-3000}"
export WARMUP_STEPS="${WARMUP_STEPS:-30}"
export ITERATIONS="${ITERATIONS:-9000}"
export MAX_WALLCLOCK_SECONDS="${MAX_WALLCLOCK_SECONDS:-600}"
export VAL_LOSS_EVERY="${VAL_LOSS_EVERY:-0}"
export TRAIN_LOG_EVERY="${TRAIN_LOG_EVERY:-200}"
export XSA_LAST_N="${XSA_LAST_N:-4}"
export EMA_ENABLED="${EMA_ENABLED:-1}"
export EMA_DECAY="${EMA_DECAY:-0.997}"
export SWA_ENABLED="${SWA_ENABLED:-0}"
export TTT_ENABLED="${TTT_ENABLED:-0}"
export ENABLE_TORCH_COMPILE="${ENABLE_TORCH_COMPILE:-0}"

torchrun --standalone --nproc_per_node="${NPROC_PER_NODE}" train_gpt.py
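Every variable in the wrapper uses the `${VAR:-default}` expansion, which falls back to the default only when the caller left the variable unset or empty, so individual knobs can be overridden per invocation without editing the script:

```shell
# ${VAR:-default}: the caller's value wins when set, otherwise the
# script-baked default applies.
unset SEED
echo "${SEED:-1337}"   # default applies
SEED=42
echo "${SEED:-1337}"   # caller's value wins
```

Hence a compiled full run can be launched as `SEED=1337 ENABLE_TORCH_COMPILE=1 bash run_hybrid_attempt.sh`, with all other knobs at their script defaults.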
`submission.json` (new file)

{
"author": "dennisimoo",
"github_id": "dennisimoo",
"name": "Record: 11L XSA4 + EMA + Batch524K + zstd fallback",
"blurb": "11-layer int6 MLP3x model with SmearGate, BigramHash(2048x128), XSA on the last 4 layers, EMA(0.997), WD=0.04, batch 524288, sliding-window eval stride 64, SDPA fallback when FA3 is unavailable, and official-template-safe compression via Python zstandard or zstd CLI fallback.",
"date": "2026-03-21T00:00:00Z",
"val_loss": 1.91760887,
"val_bpb": 1.13571899,
"pre_quant_val_loss": 1.9465,
"pre_quant_val_bpb": 1.1529,
"int6_zstd_val_loss": 1.95527145,
"int6_zstd_val_bpb": 1.15802189,
"bytes_total": 15669953,
"bytes_model_int6_zstd": 15603062,
"bytes_code": 66891
}
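The byte accounting in the metadata is internally consistent and under the track's 16 MB cap. A quick check with the reported values (the cap is taken here as 16,000,000 bytes; the total is under 16 MiB either way):

```python
# Byte figures reported in submission.json.
meta = {
    "bytes_total": 15669953,
    "bytes_model_int6_zstd": 15603062,
    "bytes_code": 66891,
}

# Model artifact + code must add up to the reported total...
assert meta["bytes_model_int6_zstd"] + meta["bytes_code"] == meta["bytes_total"]
# ...and the total must fit the 16 MB artifact limit.
assert meta["bytes_total"] < 16_000_000
```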