Record: 11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)#307

Open
dennisimoo wants to merge 1 commit into openai:main from dennisimoo:dennisimoo/official-template-safe-xsa-ema-11357

Conversation

@dennisimoo commented Mar 21, 2026

Record: 11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)

val_bpb: 1.1357 (sliding window, stride=64) | 15.67 MB | 8xH100 SXM, 600s
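The headline val_bpb is measured with a sliding evaluation window at stride 64, where each new window rescores only the tokens not already covered by the previous one so that later tokens are evaluated with long left context. A minimal sketch of how such a schedule might partition the token stream (the function name and return shape are illustrative, not taken from the PR's code):

```python
def sliding_eval_spans(n_tokens, seq_len, stride):
    """Return (window_start, window_end, n_scored) per eval window.

    Each window covers up to `seq_len` tokens and advances by `stride`;
    only tokens not scored by an earlier window count toward the loss,
    so every token is scored exactly once.
    """
    spans = []
    prev_end = 0
    for begin in range(0, n_tokens, stride):
        end = min(begin + seq_len, n_tokens)
        spans.append((begin, end, end - prev_end))
        prev_end = end
        if end == n_tokens:
            break
    return spans
```

With EVAL_SEQ_LEN=2048 and EVAL_STRIDE=64, each window after the first contributes 64 newly scored tokens backed by 1984 tokens of context.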

Key changes

| Change | Impact |
| --- | --- |
| TRAIN_BATCH_TOKENS=524288 | Better fixed-budget step count than the larger-batch 11-layer XSA+EMA setting |
| SDPA fallback for flash_attn_interface | Runs cleanly when FA3 Python bindings are unavailable in the official image |
| torch.compile behind an env flag | Reliable eager smoke tests, faster compiled full run |
| zstd Python-or-CLI fallback | Keeps int6 export under 16 MB without depending on a specific Python package in the image |
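The SDPA fallback amounts to probing for the FA3 bindings at startup and routing attention through PyTorch's built-in kernel when they are missing. A hedged sketch of that selection logic (function name is hypothetical; only the backend probe is shown, not the PR's actual attention code):

```python
import importlib.util

def pick_attention_backend() -> str:
    """Prefer FA3's Python bindings when importable, else torch SDPA."""
    if importlib.util.find_spec("flash_attn_interface") is not None:
        return "fa3"
    return "sdpa"

# In the SDPA branch, the call would look roughly like:
#   out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
```

Probing with `find_spec` rather than a bare `import` inside a `try` avoids paying FA3's import cost twice and keeps the decision in one place.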

Results

| Metric | Value |
| --- | --- |
| Pre-quant val_bpb | 1.1529 |
| Int6 roundtrip val_bpb | 1.1580 |
| Int6 sliding val_bpb (stride 64) | 1.1357 |
| Steps completed (600 s cap) | 8,202 |
| Mean step time | 73.37 ms |
| Model params | 26,829,913 |
| Artifact size | 15,669,953 bytes |

Single-seed run; val_bpb is below the current merged README SOTA (1.1428).
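The zstd fallback keeps the artifact export working whether the image ships the `zstandard` Python package or only the `zstd` command-line tool. A minimal sketch of that two-tier fallback (function name and compression level are illustrative assumptions, not the PR's exact code):

```python
import importlib.util
import shutil
import subprocess

def compress_zstd(data: bytes, level: int = 19) -> bytes:
    """Compress with the zstandard package if present, else the zstd CLI."""
    if importlib.util.find_spec("zstandard") is not None:
        import zstandard
        return zstandard.ZstdCompressor(level=level).compress(data)
    if shutil.which("zstd") is not None:
        # zstd reads stdin and writes the compressed stream to stdout
        proc = subprocess.run(
            ["zstd", f"-{level}", "--stdout"],
            input=data, stdout=subprocess.PIPE, check=True,
        )
        return proc.stdout
    raise RuntimeError("no zstd backend available (package or CLI)")
```

Either path produces a standard zstd frame, so the decompression side does not need to know which backend wrote the artifact.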

Run command

```shell
NUM_LAYERS=11 MODEL_DIM=512 NUM_HEADS=8 NUM_KV_HEADS=4 MLP_MULT=3 \
TRAIN_BATCH_TOKENS=524288 TRAIN_SEQ_LEN=2048 EVAL_SEQ_LEN=2048 EVAL_STRIDE=64 \
BIGRAM_VOCAB_SIZE=2048 BIGRAM_DIM=128 \
XSA_LAST_N=4 EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=0 TTT_ENABLED=0 \
MUON_WD=0.04 ADAM_WD=0.04 MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 MUON_MOMENTUM_WARMUP_STEPS=1500 \
WARMDOWN_ITERS=3000 WARMUP_STEPS=20 ENABLE_TORCH_COMPILE=1 \
ITERATIONS=9000 MAX_WALLCLOCK_SECONDS=600 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```
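The `ENABLE_TORCH_COMPILE=1` flag in the command above gates compilation so smoke tests can run in eager mode while full runs get the compiled graph. A minimal sketch of how such a gate might be read (helper name and the default-off convention are assumptions, not confirmed details of train_gpt.py):

```python
import os

def compile_enabled(env=None) -> bool:
    """Gate torch.compile behind the ENABLE_TORCH_COMPILE env var (default off)."""
    env = os.environ if env is None else env
    return env.get("ENABLE_TORCH_COMPILE", "0") == "1"

# Usage sketch (torch deliberately not imported here):
#   model = torch.compile(model) if compile_enabled() else model
```

Keeping the flag out-of-band in the environment means the same script serves both modes without a code change, which is what makes a fast eager smoke test cheap to run.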

@dennisimoo force-pushed the dennisimoo/official-template-safe-xsa-ema-11357 branch from 57a1805 to 88828a6 on March 21, 2026 04:14
@dennisimoo changed the title from "Submission: 11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)" to "11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)" Mar 21, 2026
@dennisimoo changed the title from "11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)" to "Record Submission: 11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)" Mar 21, 2026
@dennisimoo force-pushed the dennisimoo/official-template-safe-xsa-ema-11357 branch from 88828a6 to 59fd45a on March 21, 2026 04:15
@dennisimoo changed the title from "Record Submission: 11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)" to "Record: 11L XSA4 + EMA + Batch524K + zstd fallback (val_bpb: 1.1357)" Mar 21, 2026
@dennisimoo force-pushed the dennisimoo/official-template-safe-xsa-ema-11357 branch from 59fd45a to ca19512 on March 21, 2026 04:16
@mohosy commented Mar 21, 2026

nice score, the zstd fallback is smart for portability. how many steps did you get in 600s? trying to see if there's a speed difference between FA2 and FA3 setups

@dennisimoo (Author)

8202 steps in ~600s, 73.37 ms/step avg from the training loop. I didn’t do a clean FA2 vs FA3 benchmark here, so I wouldn’t read too much into it beyond that.
