forked from KellerJordan/modded-nanogpt
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_gpt_medium.py
More file actions
1805 lines (1553 loc) · 75.8 KB
/
train_gpt_medium.py
File metadata and controls
1805 lines (1553 loc) · 75.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import os
import sys
# Capture this script's own source text before anything else runs, so it can be logged verbatim.
with open(sys.argv[0]) as f:
    code = f.read() # read the code of this file ASAP, for logging
import copy
import glob
import math
import threading
import time
import uuid
from dataclasses import dataclass
from collections import defaultdict
from itertools import accumulate
from pathlib import Path
import gc
# Allocator configuration is placed in the environment before `import torch` below.
# NOTE(review): presumably it must precede torch initialization to take effect — keep this order.
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
import torch
# Run a trivial backward on this process's GPU. The device index comes from the LOCAL_RANK
# environment variable, which must be set (KeyError otherwise), e.g. by a distributed launcher.
torch.empty(
    1, device=f"cuda:{os.environ['LOCAL_RANK']}", requires_grad=True
).backward() # prevents a bug on some systems
import torch._dynamo as dynamo
import torch.distributed as dist
import torch.nn.functional as F
# torch._inductor.config.coordinate_descent_tuning = True # we have banned this flag for new records because it causes compilation to take 30min
import triton
import triton.language as tl
from kernels import get_kernel
from torch import Tensor, nn
# Raise dynamo's recompile limit to 64 so compiled functions may be re-specialized more often
# before the limit is hit.
dynamo.config.recompile_limit = 64
# -----------------------------------------------------------------------------
# Custom operators: FP8 matmul by @YouJiacheng
@torch.library.custom_op("nanogpt::mm", mutates_args=())
def mm_op(x: Tensor, w: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor, Tensor]:
    # FP8 forward matmul for the custom op nanogpt::mm: computes x @ w.T with both operands
    # quantized to float8_e4m3fn.
    #   x, w   : contiguous tensors sharing their inner dimension (asserted below).
    #   x_s/w_s: the inputs are divided by these before the fp8 cast, and the same values are
    #            passed to torch._scaled_mm as scale_a/scale_b (presumably undoing the division).
    #   grad_s : unused in the forward; it is part of the signature so setup_context can keep it
    #            for the backward op.
    # Returns (out in bfloat16, x_f8, w_f8); the fp8 copies are the tensors saved for backward.
    # NOTE: custom_op derives the operator schema from these type annotations — do not edit them
    # casually.
    @torch.compile
    def impl(x: Tensor, w: Tensor):
        assert x.is_contiguous() and w.is_contiguous()
        x_f8 = x.div(x_s).to(torch.float8_e4m3fn)
        w_f8 = w.div(w_s).to(torch.float8_e4m3fn)
        out = torch._scaled_mm(
            x_f8,
            w_f8.T,
            out_dtype=torch.bfloat16,
            scale_a=x.new_tensor(x_s, dtype=torch.float32),
            scale_b=x.new_tensor(w_s, dtype=torch.float32),
            use_fast_accum=True,
        )
        return out, x_f8, w_f8
    return impl(x, w)
@mm_op.register_fake
def _(x: Tensor, w: Tensor, *_):
    """Fake (tracing-time) implementation of nanogpt::mm.

    Checks that both operands are 2-D, contiguous, on the same device and share
    their inner dimension, then returns placeholders shaped like the real op's
    outputs: the product x @ w.T and float8_e4m3fn copies of x and w.
    """
    assert x.ndim == 2 and w.ndim == 2
    assert x.shape[1] == w.shape[1]
    assert x.device == w.device
    assert x.is_contiguous() and w.is_contiguous()
    product = x @ w.T
    x_fp8 = x.to(torch.float8_e4m3fn)
    w_fp8 = w.to(torch.float8_e4m3fn)
    return product, x_fp8, w_fp8
@torch.library.custom_op("nanogpt::mm_backward", mutates_args=())
def mm_backward_op(g: Tensor, x_f8: Tensor, w_f8: Tensor, x_s: float, w_s: float, grad_s: float) -> tuple[Tensor, Tensor]:
    # FP8 backward for nanogpt::mm (out = x @ w.T). Given the upstream gradient g and the fp8
    # operands saved by the forward op, returns:
    #   grad_x = g @ w        (bfloat16)
    #   grad_w = (x.T @ g).T  (float32, returned as a transposed view)
    # g is divided by grad_s and cast to float8_e5m2; x_s, w_s, grad_s are passed to
    # torch._scaled_mm as the scale factors. Both matmuls use use_fast_accum=False.
    # NOTE: custom_op derives the operator schema from these type annotations.
    @torch.compile
    def impl(grad: Tensor, x_f8: Tensor, w_f8: Tensor):
        assert grad.is_contiguous()
        # Despite the *_inv_s names, these hold x_s, w_s and grad_s themselves.
        x_inv_s = grad.new_tensor(x_s, dtype=torch.float32)
        w_inv_s = grad.new_tensor(w_s, dtype=torch.float32)
        grad_inv_s = grad.new_tensor(grad_s, dtype=torch.float32)
        grad_f8 = grad.div(grad_s).to(torch.float8_e5m2)
        # `.T.contiguous().T` keeps the values of w_f8 but changes its memory layout
        # (presumably to the column-major layout _scaled_mm wants for its second operand).
        grad_x = torch._scaled_mm(
            grad_f8,
            w_f8.T.contiguous().T,
            out_dtype=torch.bfloat16,
            scale_a=grad_inv_s,
            scale_b=w_inv_s,
            use_fast_accum=False,
        )
        # faster than grad_f8_t @ x_f8, for (d_out, d_in) == (50304, 768)
        grad_w = torch._scaled_mm(
            x_f8.T.contiguous(),
            grad_f8.T.contiguous().T,
            out_dtype=torch.float32,
            scale_a=x_inv_s,
            scale_b=grad_inv_s,
            use_fast_accum=False,
        ).T
        return grad_x, grad_w
    return impl(g, x_f8, w_f8)
@mm_backward_op.register_fake
def _(g: Tensor, x_f8: Tensor, w_f8: Tensor, *_):
    """Fake (tracing-time) implementation of nanogpt::mm_backward.

    Returns placeholders for (grad_x, grad_w): a bfloat16 tensor shaped like x_f8,
    and a float32 tensor shaped like w_f8 with the layout produced by .T.contiguous().T.
    """
    grad_x_meta = x_f8.to(torch.bfloat16)
    grad_w_meta = w_f8.T.contiguous().T.to(torch.float32)
    return grad_x_meta, grad_w_meta
def backward(ctx, grad_out: Tensor, *_):
    """Autograd formula for nanogpt::mm.

    Only the gradient of the first output is consumed; gradients of the fp8 side
    outputs are ignored. Returns gradients for (x, w) and None for the three float
    scale arguments.
    """
    saved_x_f8, saved_w_f8 = ctx.saved_tensors
    scale_x, scale_w, scale_grad = ctx.scales
    grads = torch.ops.nanogpt.mm_backward(
        grad_out, saved_x_f8, saved_w_f8, scale_x, scale_w, scale_grad
    )
    grad_x, grad_w = grads
    return grad_x, grad_w, None, None, None
def setup_context(ctx: torch.autograd.function.FunctionCtx, inputs, output):
    """Record what nanogpt::mm's backward needs on ``ctx``.

    The fp8 operands (second and third forward outputs) are saved for backward, and
    the three float scales (last three forward inputs) are stored as ``ctx.scales``.
    Grad materialization is turned off, per ``set_materialize_grads(False)``.
    """
    x_scale, w_scale, g_scale = inputs[-3:]
    _unused_out, x_f8, w_f8 = output
    ctx.scales = (x_scale, w_scale, g_scale)
    ctx.save_for_backward(x_f8, w_f8)
    ctx.set_materialize_grads(False)
# Attach the autograd formula to nanogpt::mm: setup_context stores the fp8 operands and scales,
# and backward dispatches to nanogpt::mm_backward.
mm_op.register_autograd(backward, setup_context=setup_context)
# -----------------------------------------------------------------------------
# Triton kernel for symmetric matrix multiplication by @byronxu99
def _get_autotune_configs():
    """Build the Triton autotuning search space for the symmetric matmul kernels.

    Enumerates BLOCK_SIZE_M in (64, 128), BLOCK_SIZE_N in (64, 128, 256),
    BLOCK_SIZE_K in (64, 128) and (num_stages, num_warps) in
    ((3, 4), (3, 8), (4, 4)). It keeps only tiles whose integer-divided M/N
    ratio is at most 2 in either direction. GROUP_SIZE_M is fixed at 8 and
    LOWER_UPPER at 1.
    """
    configs = []
    for bm in (64, 128):
        for bn in (64, 128, 256):
            # The aspect-ratio filter depends only on (bm, bn), so apply it before the inner loops.
            if bm // bn > 2 or bn // bm > 2:
                continue
            for bk in (64, 128):
                for stages, warps in ((3, 4), (3, 8), (4, 4)):
                    meta = {
                        "BLOCK_SIZE_M": bm,
                        "BLOCK_SIZE_N": bn,
                        "BLOCK_SIZE_K": bk,
                        "GROUP_SIZE_M": 8,
                        "LOWER_UPPER": 1,
                    }
                    configs.append(triton.Config(meta, num_stages=stages, num_warps=warps))
    return configs
@triton.jit
def _pid_to_block(
    pid,
    M,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
):
    # Map a flat 1-D program id to (batch index, first row, first column) of the
    # (BLOCK_SIZE_M, BLOCK_SIZE_N) tile this program owns in an M x M output matrix.
    # Split output matrix into blocks of size (BLOCK_SIZE_M, BLOCK_SIZE_N)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(M, BLOCK_SIZE_N)
    # Map PID to a single matrix in batch
    batch_idx = pid // (num_pid_m * num_pid_n)
    pid = pid % (num_pid_m * num_pid_n)
    # Map PID to 2D grid of blocks
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    # Reorder tiles in groups of GROUP_SIZE_M rows via tl.swizzle2d (presumably for cache reuse)
    pid_m, pid_n = tl.swizzle2d(pid_m, pid_n, num_pid_m, num_pid_n, GROUP_SIZE_M)
    m_idx = pid_m * BLOCK_SIZE_M
    n_idx = pid_n * BLOCK_SIZE_N
    return batch_idx, m_idx, n_idx
@triton.autotune(
    configs=_get_autotune_configs(),
    key=["M", "K", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
)
@triton.jit
def XXT_kernel(
    A_ptr, C_ptr,
    M, K,
    a_stride_b, a_stride_r, a_stride_c,
    c_stride_b, c_stride_r, c_stride_c,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    LOWER_UPPER: tl.constexpr,
):
    # Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N) tile of C = A @ A.T for an (M, K)
    # matrix A (one matrix of a batch, selected via the batch strides). C is symmetric, so only
    # tiles on one side of the diagonal are computed, and each computed tile is stored both in
    # place and mirrored across the diagonal.
    pid = tl.program_id(axis=0)
    batch_idx, m_idx, n_idx = _pid_to_block(
        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
    )
    # Skip blocks that don't need to be computed
    # LOWER_UPPER == 0 skips tiles lying entirely below the diagonal, otherwise tiles lying
    # entirely above it; those entries are filled by the mirrored stores of computed tiles.
    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
    if skip_block_below_diag or skip_block_above_diag:
        return
    # Index into one matrix of batch
    A_ptr += batch_idx * a_stride_b
    C_ptr += batch_idx * c_stride_b
    # Create pointer arrays for A and A.T
    # Row indices wrap with % M so loads past the last row stay in bounds; the affected output
    # rows/columns are masked out when storing.
    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
    # The A.T tile is read from A itself by swapping which index uses the row stride.
    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Accumulate over blocks of K
    # (the masks zero-fill the ragged final K block)
    for k in tl.range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
        at = tl.load(at_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, at, accumulator)
        a_ptrs += BLOCK_SIZE_K * a_stride_c
        at_ptrs += BLOCK_SIZE_K * a_stride_c
    # Accumulation is in float32; cast to C's element type for the store.
    out_dtype = C_ptr.dtype.element_ty
    output = accumulator.to(out_dtype)
    # Store block of C
    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, output, mask=c_mask)
    # Store block of C mirrored across the diagonal
    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
def XXT(A: torch.Tensor, out: torch.Tensor):
    """
    Launch Triton kernel to compute C = A @ A.T

    A is (M, K) or a batch (B, M, K); ``out`` must have trailing shape (M, M)
    and is written in place and returned. When ``out`` is not 3-D its batch
    stride is treated as 0.
    """
    assert A.ndim in (2, 3)
    rows, inner = A.shape[-2:]
    assert out.size(-2) == rows, "Output matrix has incorrect shape"
    assert out.size(-1) == rows, "Output matrix has incorrect shape"
    batch_size, input_batch_stride = (A.size(0), A.stride(0)) if A.ndim == 3 else (1, 0)
    output_batch_stride = 0 if out.ndim != 3 else out.stride(0)
    # One program per output tile per batch entry; tile sizes come from the autotuner.
    grid = lambda meta: (
        batch_size
        * triton.cdiv(rows, meta["BLOCK_SIZE_M"])
        * triton.cdiv(rows, meta["BLOCK_SIZE_N"]),
    )
    XXT_kernel[grid](
        A_ptr=A,
        C_ptr=out,
        M=rows,
        K=inner,
        a_stride_b=input_batch_stride,
        a_stride_r=A.stride(-2),
        a_stride_c=A.stride(-1),
        c_stride_b=output_batch_stride,
        c_stride_r=out.stride(-2),
        c_stride_c=out.stride(-1),
    )
    return out
@triton.autotune(
    configs=_get_autotune_configs(),
    key=["M", "a_stride_r", "a_stride_c", "c_stride_r", "c_stride_c"],
)
@triton.jit
def ba_plus_cAA_kernel(
    A_ptr, C_ptr,
    M,
    a_stride_b, a_stride_r, a_stride_c,
    c_stride_b, c_stride_r, c_stride_c,
    alpha, beta,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    LOWER_UPPER: tl.constexpr,
):
    # Each program computes one tile of C = alpha * A @ A.T + beta * A for a square (M, M)
    # matrix A, so the reduction dimension is also M. As in XXT_kernel, only one side of the
    # diagonal is computed and each tile is also stored mirrored. That mirrored store is only
    # correct when A itself is symmetric (then C is symmetric).
    # This is mostly duplicated from XXT_kernel, but also loads and adds a block of A
    # Performance is slightly slower than XXT_kernel, so we use two separate kernels
    pid = tl.program_id(axis=0)
    batch_idx, m_idx, n_idx = _pid_to_block(
        pid, M, BLOCK_SIZE_M, BLOCK_SIZE_N, GROUP_SIZE_M
    )
    # Skip blocks that don't need to be computed
    skip_block_below_diag = (LOWER_UPPER == 0) and (n_idx + BLOCK_SIZE_N <= m_idx)
    skip_block_above_diag = (LOWER_UPPER != 0) and (m_idx + BLOCK_SIZE_M <= n_idx)
    if skip_block_below_diag or skip_block_above_diag:
        return
    # Index into one matrix of batch
    A_ptr += batch_idx * a_stride_b
    C_ptr += batch_idx * c_stride_b
    # Create pointer arrays for A and A.T
    # (row indices wrap with % M to keep loads in bounds; stores are masked)
    offs_m = (m_idx + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (n_idx + tl.arange(0, BLOCK_SIZE_N)) % M
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = A_ptr + (offs_m[:, None] * a_stride_r + offs_k[None, :] * a_stride_c)
    at_ptrs = A_ptr + (offs_k[:, None] * a_stride_c + offs_n[None, :] * a_stride_r)
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    # Accumulate over blocks of K
    # (K == M here because A is square)
    for k in tl.range(0, tl.cdiv(M, BLOCK_SIZE_K)):
        a = tl.load(a_ptrs, mask=offs_k[None, :] < M - k * BLOCK_SIZE_K, other=0.0)
        at = tl.load(at_ptrs, mask=offs_k[:, None] < M - k * BLOCK_SIZE_K, other=0.0)
        accumulator = tl.dot(a, at, accumulator)
        a_ptrs += BLOCK_SIZE_K * a_stride_c
        at_ptrs += BLOCK_SIZE_K * a_stride_c
    # Load block of A to add (corresponds to the current block of C)
    offs_am = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_an = n_idx + tl.arange(0, BLOCK_SIZE_N)
    a_add_ptrs = A_ptr + (offs_am[:, None] * a_stride_r + offs_an[None, :] * a_stride_c)
    a_add_mask = (offs_am[:, None] < M) & (offs_an[None, :] < M)
    a_add = tl.load(a_add_ptrs, mask=a_add_mask, other=0.0).to(tl.float32)
    # Apply alpha and beta
    accumulator *= alpha
    accumulator += a_add * beta
    out_dtype = C_ptr.dtype.element_ty
    output = accumulator.to(out_dtype)
    # Store block of C
    offs_cm = m_idx + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = n_idx + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = C_ptr + (offs_cm[:, None] * c_stride_r + offs_cn[None, :] * c_stride_c)
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < M)
    tl.store(c_ptrs, output, mask=c_mask)
    # Store block of C mirrored across the diagonal
    c_ptrs_t = C_ptr + (offs_cn[:, None] * c_stride_r + offs_cm[None, :] * c_stride_c)
    c_mask_t = (offs_cn[:, None] < M) & (offs_cm[None, :] < M)
    tl.store(c_ptrs_t, output.T, mask=c_mask_t)
def ba_plus_cAA(A: torch.Tensor, alpha: float, beta: float, out: torch.Tensor):
    """
    Launch Triton kernel to compute C = alpha * A @ A.T + beta * A

    A must be square, (M, M) or batched (B, M, M); ``out`` must have trailing
    shape (M, M) and is written in place and returned. The kernel computes one
    triangle and mirrors it, so the result is only correct for symmetric A
    (e.g. the output of XXT), where it also equals alpha * A @ A + beta * A.
    """
    assert A.ndim in (2, 3)
    n_rows, n_cols = A.shape[-2:]
    assert n_rows == n_cols, "Input matrix must be square"
    assert out.size(-2) == n_rows
    assert out.size(-1) == n_rows
    batch_size, input_batch_stride = (A.size(0), A.stride(0)) if A.ndim == 3 else (1, 0)
    output_batch_stride = 0 if out.ndim != 3 else out.stride(0)
    # One program per output tile per batch entry; tile sizes come from the autotuner.
    grid = lambda meta: (
        batch_size
        * triton.cdiv(n_rows, meta["BLOCK_SIZE_M"])
        * triton.cdiv(n_rows, meta["BLOCK_SIZE_N"]),
    )
    ba_plus_cAA_kernel[grid](
        A_ptr=A,
        C_ptr=out,
        M=n_rows,
        a_stride_b=input_batch_stride,
        a_stride_r=A.stride(-2),
        a_stride_c=A.stride(-1),
        c_stride_b=output_batch_stride,
        c_stride_r=out.stride(-2),
        c_stride_c=out.stride(-1),
        alpha=alpha,
        beta=beta,
    )
    return out
# Computed for num_iters=5, safety_factor=2e-2, cushion=2
# One (a, b, c) tuple per iteration of polar_express, applied as
#   A = X @ X.mT;  B = b * A + c * A @ A;  X <- a * X + B @ X
# The 2e-2 safety factor presumably corresponds to the (1 + 2e-2) norm margin in polar_express.
polar_express_coeffs = [
    (8.156554524902461, -22.48329292557795, 15.878769915207462),
    (4.042929935166739, -2.808917465908714, 0.5000178451051316),
    (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
    (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
    (2.3465413258596377, -1.7097828382687081, 0.42323551169305323)
]
@torch.compile(dynamic=False, fullgraph=True) # Must use dynamic=False or else it's much slower
def polar_express(G: torch.Tensor, split_baddbmm: bool = False):
    """
    Polar Express Sign Method: https://arxiv.org/pdf/2505.16932
    by Noah Amsel, David Persson, Christopher Musco, Robert M. Gower.

    Approximately orthogonalizes G by running one iteration per entry of
    polar_express_coeffs in bfloat16.

    Args:
        G: a matrix (2-D) or a batch of matrices (3-D), as required by XXT.
        split_baddbmm: if True, compute a * X + B @ X as a (b)mm followed by an
            in-place add instead of one (b)addmm call (see the comment in the loop).

    Returns:
        A bfloat16 tensor with the same shape as G.
    """
    X = G.bfloat16()
    # Work with rows <= cols so the Gram buffers (X @ X.mT) use the smaller dimension.
    if G.size(-2) > G.size(-1):
        X = X.mT
    # Ensure spectral norm is at most 1
    # (divide by the Frobenius norm — an upper bound on the spectral norm — with a 2% margin
    # and a small epsilon guarding the all-zero case)
    X = X / (X.norm(dim=(-2, -1), keepdim=True) * (1 + 2e-2) + 1e-6)
    # Allocate buffers
    # A and B are (..., rows, rows) Gram-sized buffers; C is an X-sized output buffer.
    X = X.contiguous()
    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
    B = torch.empty_like(A)
    C = torch.empty_like(X)
    # Select batched vs unbatched
    if split_baddbmm:
        BX_matmul = torch.bmm if X.ndim > 2 else torch.mm
    else:
        aX_plus_BX = torch.baddbmm if X.ndim > 2 else torch.addmm
    # Perform the iterations
    for a, b, c in polar_express_coeffs:
        XXT(X, out=A) # A = X @ X.mT
        ba_plus_cAA(A, alpha=c, beta=b, out=B) # B = b * A + c * A @ A
        # Referencing X twice causes pytorch to make a defensive copy,
        # resulting in a cudaMemcpyAsync in baddbmm.
        # For large matrices (i.e., the mlp weights), it's faster to split
        # the operation into two kernels to avoid this.
        if split_baddbmm:
            BX_matmul(B, X, out=C) # C = B @ X
            C.add_(X, alpha=a) # C = C + a*X (in-place, X only read)
        else:
            aX_plus_BX(X, B, X, beta=a, out=C) # C = a * X + B @ X
        X, C = C, X # Swap references to avoid unnecessary copies
    # Undo the initial transpose so the result has G's orientation.
    if G.size(-2) > G.size(-1):
        X = X.mT
    return X
# -----------------------------------------------------------------------------
# Compiled helpers for NorMuon by @chrisjmccormick
@torch.compile(dynamic=False, fullgraph=True)
def cautious_wd_and_update_inplace(p, v, wd_tensor, lr_tensor):
    """Apply cautious weight decay plus the update ``v`` to ``p`` in place.

    The decay term ``p * wd * lr`` is applied only where ``v * p >= 0`` (update and
    parameter do not have opposite signs). The step ``v * lr`` is always applied.
    wd_tensor and lr_tensor are 0-D CPU tensors; both are cast to ``p``'s dtype.
    Returns None.
    """
    lr = lr_tensor.to(p.dtype)
    wd = wd_tensor.to(p.dtype)
    same_sign = (v * p) >= 0
    decay = p * same_sign * wd * lr
    step = v * lr
    p.copy_(p - decay - step)
@torch.compile(dynamic=False, fullgraph=True)
def apply_normuon_variance_reduction(v_chunk, second_momentum_buffer, beta2, red_dim):
    """NorMuon variance reduction, applied to ``v_chunk`` in place (and returned).

    Tracks an EMA (factor ``1 - beta2``) of the mean squared entries along
    ``red_dim`` in ``second_momentum_buffer``, rescales ``v_chunk`` by that
    EMA's inverse square root, and then renormalizes so each matrix keeps its
    original Frobenius norm. The normalization steps are fused algebraically
    to minimize memory traffic.
    """
    sq_mean = v_chunk.float().square().mean(dim=red_dim, keepdim=True)
    n_reduced = v_chunk.size(red_dim)
    # Original norm, recovered from the per-row mean squares without another pass over v_chunk.
    orig_norm = sq_mean.sum(dim=(-2, -1), keepdim=True).mul_(n_reduced).sqrt_()
    second_momentum_buffer.lerp_(sq_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
    inv_rms = second_momentum_buffer.clamp_min(1e-10).rsqrt_()
    # Norm the chunk would have after scaling by inv_rms, again computed from sq_mean.
    new_sq_sum = (sq_mean * n_reduced) * inv_rms.float().square()
    new_norm = new_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt_()
    scale = inv_rms * (orig_norm / new_norm.clamp_min_(1e-10))
    return v_chunk.mul_(scale.type_as(v_chunk))
# -----------------------------------------------------------------------------
# NorMuon optimizer
class NorMuon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz

    https://kellerjordan.github.io/posts/muon/

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix. To efficiently orthogonalize each update, we use a Newton-Schulz iteration, which has
    the advantage that it can be stably run in bfloat16 on the GPU.

    Warning: This optimizer should not be used for the embedding layer, the final fully connected layer,
    or any {0,1}-D parameters; those should all be optimized by a standard method (e.g., AdamW).

    Differences from standard Muon:
    - Newton-Schulz is replaced with Polar Express for the orthogonalization step
    - NorMuon adds a low-rank variance estimator similar to Adafactor. https://arxiv.org/pdf/2510.05491
    - small 1D parameters handled here instead of in Adam
    - Cautious weight decay, a gated version of decoupled weight decay
    - Custom distributed sizing:
      The model stores all attn and mlp weights in the same shape, and then updates the view as
      needed on the forward pass. This enables attn and mlp weights to be contained within the same
      dist.reduce_scatter_tensor() call. The model architecture has been customized to enable
      (n_attn_layers+n_mlp_layers*2)%8==0 for batching across 8 GPUs with zero padding on mlp and attn.

    Step scheduling (see step()):
    1. For every param group, in order: stack its grads (zero-padded to chunk_size * world_size)
       and launch an async reduce_scatter, so each rank receives the averaged grads of its shard.
    2. For each group, in the same order: wait on its reduce_scatter, compute the update of the
       local shard (attn shards are reshaped before orthogonalization) and launch an async all_gather.
    3. Wait on each all_gather and copy the gathered values back into the parameters.
    With custom_sizing on 8 GPUs the groups are, by label: attn_gate (16), value_embed_gate
    (10, zero-padded to 16), attn (16), mlp (16), mlp (16) — see generate_custom_param_groups.
    Empirically, leading with small params provides an additional 0.2s improvement.
    """

    def __init__(self, params, lr=0.02, weight_decay=0.01, momentum=0.95, beta2=0.95, custom_sizing=True):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum, beta2=beta2)
        self.world_size = dist.get_world_size() if dist.is_initialized() else 1
        # custom sizing requires 8 GPUs. Compare against the guarded self.world_size:
        # dist.get_world_size() raises when no process group has been initialized.
        if custom_sizing and self.world_size == 8:
            param_groups = self.generate_custom_param_groups(params)
        else:
            param_groups = self.generate_standard_param_groups(params)
        super().__init__(param_groups, defaults)

    def reset(self):
        """Zero the first and second momentum buffers of every group that has them."""
        for group in self.param_groups:
            # Both buffers are created together on the first step().
            if "momentum_buffer" in group:
                group["momentum_buffer"].zero_()
                group["second_momentum_buffer"].zero_()

    def generate_standard_param_groups(self, params):
        """
        Use this method if running on fewer than 8 GPUs or experimenting with additional attn or mlp modules.

        Creates one param group per distinct ``param.label`` (in first-seen order). Each group is
        split into ``chunk_size = ceil(len(group) / world_size)`` params per rank.
        """
        groups = defaultdict(list)
        for param in params:
            groups[param.label].append(param)
        param_groups = []
        for group_params in groups.values():
            chunk_size = (len(group_params) + self.world_size - 1) // self.world_size
            param_groups.append(dict(params=group_params, chunk_size=chunk_size))
        return param_groups

    def generate_custom_param_groups(self, params):
        """
        Implementation requires that a single GPU does not receive both attn
        and mlp params when a param group is split across GPUs.

        Params are sorted by label in module_group_order and sliced into groups of
        group_sizes. The total number of params must equal sum(group_sizes).
        """
        params_list = list(params)
        module_group_order = ['attn_gate', 'value_embed_gate', 'attn', 'mlp'] # 16, 10, 16, 32
        group_sizes = [16, 10, 16, 16, 16]  # the 32 mlp params are split into two groups of 16
        params_list.sort(key=lambda x: module_group_order.index(x.label))
        print0(len(params_list), console=True)
        idx = 0
        assert len(params_list) == sum(group_sizes), (
            f"expected {sum(group_sizes)} params for custom sizing, got {len(params_list)}"
        )
        param_groups = []
        for size in group_sizes:
            chunk_size = (size + self.world_size - 1) // self.world_size
            group_params = params_list[idx: idx + size]
            param_groups.append(dict(params=group_params, chunk_size=chunk_size))
            idx += size
        return param_groups

    @torch.no_grad()
    def step(self):
        """Run one distributed NorMuon update over all param groups (requires an initialized process group)."""
        # Efficient distributed step by @YouJiacheng, @KonstantinWilleke, @alexrgilbert,
        # @adricarda, @tuttyfrutyee, @vdlad, @ryanyang0, @vagrawal, @varunneal, @chrisjmccormick
        rank = dist.get_rank()
        # First pass: stack grads and launch async reduce_scatter per group. The group itself is
        # stored alongside its info so that skipped (empty) groups cannot misalign the passes below.
        group_infos = []
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            if not params:
                continue
            chunk_size = group["chunk_size"]
            padded_num_params = chunk_size * self.world_size
            stacked_grads = torch.empty(
                (padded_num_params, *params[0].shape),
                dtype=params[0].dtype,
                device=params[0].device
            )
            for i, p in enumerate(params):
                stacked_grads[i].copy_(p.grad, non_blocking=True)
            if len(params) < padded_num_params:
                stacked_grads[len(params):].zero_()
            grad_chunk = torch.empty_like(stacked_grads[:chunk_size])
            reduce_future = dist.reduce_scatter_tensor(
                grad_chunk, stacked_grads, op=dist.ReduceOp.AVG, async_op=True
            ).get_future()
            group_infos.append(dict(group=group, grad_chunk=grad_chunk, reduce_future=reduce_future))
        all_gather_infos = []
        # Second pass: wait for gradients, compute updates for the local shard of parameters,
        # and launch all async all_gather operations.
        for info in group_infos:
            group = info["group"]
            info["reduce_future"].wait()
            params = group["params"]
            grad_chunk = info["grad_chunk"]
            chunk_size = group["chunk_size"]
            padded_num_params = chunk_size * self.world_size
            start_idx = rank * chunk_size
            # A rank whose shard lies entirely in the zero padding owns no real params; it still
            # joins the collectives and uses params[0] only as a shape/label reference.
            module_idx = start_idx if start_idx < len(params) else 0
            num_params = min(chunk_size, max(0, len(params) - start_idx)) # num params for this rank
            if "momentum_buffer" not in group:
                group["momentum_buffer"] = torch.zeros_like(grad_chunk[:num_params])
            momentum_buffer = group["momentum_buffer"]
            # Apply momentum update to the persistent momentum buffer in-place
            momentum_buffer.lerp_(grad_chunk[:num_params], 1 - group["momentum"])
            # updated = (1 - momentum) * grad + momentum * buffer, written in place into grad_chunk
            updated_grads = grad_chunk[:num_params].lerp_(momentum_buffer, group["momentum"])
            grad_shape = updated_grads.shape
            if params[module_idx].label == 'attn':
                for p in params[module_idx:module_idx + num_params]:
                    assert p.label == 'attn'
                # Each attn param is viewed as 4 matrices stacked along dim 1, so each one is
                # orthogonalized separately.
                updated_grads = updated_grads.view(4 * grad_shape[0], grad_shape[1] // 4, grad_shape[2])
            ref_param = params[module_idx]
            param_shape = ref_param.shape
            # The below shape-based heuristic assumes that matrices have their input along the
            # row dimension and their output along the columns. Gates are an exception.
            is_gate = 'gate' in ref_param.label
            if "second_momentum_buffer" not in group:
                if is_gate:
                    group["second_momentum_buffer"] = torch.zeros_like(updated_grads[..., :, :1])
                else:
                    group["second_momentum_buffer"] = (
                        torch.zeros_like(updated_grads[..., :, :1])
                        if param_shape[-2] >= param_shape[-1]
                        else torch.zeros_like(updated_grads[..., :1, :])
                    )
            second_momentum_buffer = group["second_momentum_buffer"]
            if "param_lr_cpu" not in group:
                # Define multipliers for ALL params in this group (global, not per-shard)
                lr_mults = []
                wd_mults = []
                for p in params:
                    # Increase learning rate for modules with larger inputs than outputs.
                    # This shape check also assumes rows=input, columns=output, so take care
                    # when changing memory layouts. @chrisjmccormick
                    shape = p.shape
                    if len(shape) >= 2:
                        shape_mult = max(1.0, shape[-2] / shape[-1]) ** 0.5
                    else:
                        shape_mult = 1.0
                    lr_mults.append(shape_mult * getattr(p, "lr_mul", 1.0))
                    wd_mults.append(getattr(p, "wd_mul", 1.0))
                # Define as cpu tensors to enable Inductor constant folding
                group["param_lr_cpu"] = torch.tensor(lr_mults, dtype=torch.float32, device="cpu")
                group["param_wd_cpu"] = torch.tensor(wd_mults, dtype=torch.float32, device="cpu")
            eff_lr_all = group["param_lr_cpu"] * group["lr"]
            # Weight decay is scaled by lr here and multiplied by lr again inside
            # cautious_wd_and_update_inplace, so it follows the lr schedule.
            eff_wd_all = group["param_wd_cpu"] * group["weight_decay"] * group["lr"]
            # Slice the portion corresponding to this rank's shard
            eff_lr_cpu = eff_lr_all[module_idx:module_idx + num_params]
            eff_wd_cpu = eff_wd_all[module_idx:module_idx + num_params]
            # Compute zeropower for the entire chunk in a single, batched call.
            if num_params == 0:
                v_chunk = updated_grads
            else:
                v_chunk = polar_express(updated_grads, split_baddbmm=(ref_param.label == 'mlp'))
            # Note that the head orientation in O is transposed relative to QKV, so red_dim
            # is 'incorrect' for O. However, correcting this showed no improvement. @chrisjmccormick
            red_dim = -1 if (is_gate or param_shape[-2] >= param_shape[-1]) else -2
            v_chunk = apply_normuon_variance_reduction(
                v_chunk, second_momentum_buffer, group["beta2"], red_dim
            )
            v_chunk = v_chunk.view(grad_shape)
            # "Cautious" weight decay (https://arxiv.org/abs/2510.12402)
            updated_params = torch.empty_like(grad_chunk)
            if num_params > 0:
                # Work on a stacked copy to avoid touching original params
                param_chunk = torch.stack(params[module_idx:module_idx + num_params])
                for local_idx in range(num_params):
                    cautious_wd_and_update_inplace(
                        param_chunk[local_idx],
                        v_chunk[local_idx],
                        eff_wd_cpu[local_idx],
                        eff_lr_cpu[local_idx],
                    )
            else:
                param_chunk = torch.zeros_like(v_chunk)
            updated_params[:num_params].copy_(param_chunk)
            if num_params < chunk_size:
                updated_params[num_params:].zero_()
            stacked_params = torch.empty(
                (padded_num_params, *param_shape),
                dtype=updated_params.dtype,
                device=updated_params.device,
            )
            gather_future = dist.all_gather_into_tensor(
                stacked_params, updated_params, async_op=True
            ).get_future()
            all_gather_infos.append(
                {
                    "gather_future": gather_future,
                    "stacked_params": stacked_params,
                    "orig_params": params,
                }
            )
        # Final pass: wait for all_gather to complete and copy results back into original parameter tensors.
        # zip stops at len(orig_params), so the zero-padded tail of stacked_params is dropped.
        for info in all_gather_infos:
            info["gather_future"].wait()
            for p, new_value in zip(info["orig_params"], torch.unbind(info["stacked_params"])):
                p.copy_(new_value, non_blocking=True)
class DistAdam(torch.optim.Optimizer):
def __init__(self, params, label_order: list[str],lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
self.world_size = dist.get_world_size() if dist.is_initialized() else 1
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
params = list(params)
# Group by label, with explicit ordering for execution control.
params_by_label = defaultdict(list)
for p in params:
params_by_label[getattr(p, 'label', None)].append(p)
param_groups = []
for label in label_order:
if label in params_by_label:
param_groups.append(dict(params=params_by_label[label]))
# include any unlabeled params at the end (processed last)
if None in params_by_label:
param_groups.append(dict(params=params_by_label[None]))
super().__init__(param_groups, defaults)
# init state: small params (numel < 1024) use full-sized state, others use sharded
for p in params:
chunk = p if p.numel() < 1024 else p[:p.size(0) // self.world_size]
exp_avg = torch.zeros_like(chunk, dtype=torch.bfloat16, device=p.device)
self.state[p] = dict(step=0, exp_avg=exp_avg, exp_avg_sq=torch.zeros_like(exp_avg))
# DistributedAdam implementation by @vagrawal, @akash5474
self.should_sync = False
self._reduce_scatter_hooks = []
self._reduce_scatter_futures = {}
self.register_backward_hooks()
def register_backward_hooks(self):
for group in self.param_groups:
for param in group["params"]:
self._reduce_scatter_hooks.append(param.register_post_accumulate_grad_hook(self._sync_gradient))
@torch.no_grad()
def _sync_gradient(self, param):
if not self.should_sync:
return
grad = param.grad
if param.numel() < 1024:
# Small params: use all_reduce (no scatter/gather needed)
self._reduce_scatter_futures[param] = (
dist.all_reduce(grad, op=dist.ReduceOp.AVG, async_op=True).get_future(),
grad
)
else:
rank_size = grad.shape[0] // self.world_size
if grad is not None:
grad_slice = torch.empty_like(grad[:rank_size])
self._reduce_scatter_futures[param] = (
dist.reduce_scatter_tensor(grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True).get_future(),
grad_slice
)
def copy_lm_to_embed(self):
# run at 1/6 of training
lm_head = self.param_groups[0]['params'][0]
embed = self.param_groups[-1]['params'][0]
lm_head_state = self.state[lm_head]
embed_state = self.state[embed]
embed_state['step'] = lm_head_state['step']
embed_state['exp_avg'] = lm_head_state['exp_avg'].clone()
embed_state['exp_avg_sq'] = lm_head_state['exp_avg_sq'].clone()
embed.data.copy_(lm_head.data)
@torch.compile
@torch.no_grad()
def step(self):
rank = dist.get_rank()
all_gather_futures: list[torch.Future] = []
for group in self.param_groups:
beta1, beta2 = group['betas']
eps = group['eps']
wd = group['weight_decay']
for param in group['params']:
if param not in self._reduce_scatter_futures:
continue
fut, g_slice = self._reduce_scatter_futures[param]
fut.wait()
is_small = param.numel() < 1024
if is_small:
# Small params: g_slice is actually full grad, p_slice is full param
p_slice = param
else:
rank_size = param.shape[0] // self.world_size
p_slice = param[rank * rank_size:(rank + 1) * rank_size]
lr = group['lr'] * getattr(param, "lr_mul", 1.0)
state = self.state[param]
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]
state["step"] += 1
t = state["step"]
# update running averages
exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)
# bias corrections
bias1 = 1 - beta1 ** t
bias2 = 1 - beta2 ** t
# compute step
denom = exp_avg_sq.sqrt().add_(eps)
step_size = lr * (bias2 ** 0.5 / bias1)
update = exp_avg.div(denom).mul_(step_size)
# cautious weight decay
mask = (update * p_slice) > 0
# lr as weight decay schedule
eff_weight_decay = lr * wd * getattr(param, "wd_mul", 1.0)
update.addcmul_(p_slice, mask, value=eff_weight_decay * lr)
p_slice.add_(other=update, alpha=-1.0)
if not is_small:
all_gather_futures.append(dist.all_gather_into_tensor(param, p_slice, async_op=True).get_future())
self._reduce_scatter_futures.clear()
torch.futures.collect_all(all_gather_futures).wait()
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model
def norm(x: Tensor):
return F.rms_norm(x, (x.size(-1),))
class CastedLinear(nn.Linear):
    """Bias-free, zero-initialised linear layer that casts its weight to the input dtype.

    An FP8 matmul path (``torch.ops.nanogpt.mm``) is wired in, but it is hard-disabled.
    ``use_fp8`` is accepted for call-site compatibility and ignored, because the FP8
    scales (``x_s``, ``w_s``, ``grad_s``) have not been tuned for the medium track.
    The scales are still stored.
    """

    def __init__(self, in_features: int, out_features: int, use_fp8=False, x_s=1.0, w_s=1.0, grad_s=1.0):
        super().__init__(in_features, out_features, bias=False)
        # FP8 is forced off regardless of the argument: its scales have not been
        # tuned on the medium track yet.
        self.use_fp8 = False
        # input / weight / gradient scale factors for the FP8 matmul
        self.x_s, self.w_s, self.grad_s = x_s, w_s, grad_s

    def reset_parameters(self) -> None:
        # zero init @Grad62304977 and others
        nn.init.zeros_(self.weight)

    def forward(self, x: Tensor):
        fp8_path = self.use_fp8 and self.training
        if not fp8_path:
            return F.linear(x, self.weight.type_as(x))
        # the custom op works on 2-D input: fold the leading dims, then restore them
        flat = x.flatten(0, -2)
        result: Tensor = torch.ops.nanogpt.mm(flat, self.weight, x_s=self.x_s, w_s=self.w_s, grad_s=self.grad_s)[0]
        return result.reshape(*x.shape[:-1], -1)
# -----------------------------------------------------------------------------
# PyTorch nn.Module definitions for the model
# YaRN-style RoPE frequency scaling, implemented by @classiclarryd
class Yarn(nn.Module):
    """Rotary-embedding (RoPE) tables with YaRN-style frequency rescaling.

    ``reset`` builds bfloat16 ``cos``/``sin`` buffers of shape
    ``(max_seq_len, 2 * (head_dim // 4))``. ``apply(old_window, new_window)``
    rescales the frequencies in place when the attention window grows, and
    increases ``attn_scale``.

    ``nn.Module`` already defines ``apply(fn)``, and a parent module's
    ``apply(fn)`` calls it recursively on every child. This class overrides
    ``apply`` with a different signature, so such a recursive call arrives here
    with ``fn`` as ``old_window``. That call form is detected and forwarded to
    ``nn.Module.apply``, so module-wide ``model.apply(init_fn)`` keeps working.

    Note: ``reset`` reads the module-level ``device``, and ``apply`` reads the
    module-level ``args`` (``args.block_size``).
    """

    def __init__(self, head_dim, max_seq_len):
        super().__init__()
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.reset()

    def reset(self):
        """(Re)build the base frequency tables and set ``attn_scale`` back to 0.1."""
        # head_dim//4 frequencies, log-spaced from 1 down to 1/1024 (radians per position)
        angular_freq = (1 / 1024) ** torch.linspace(0, 1, steps=self.head_dim//4, dtype=torch.float32, device=device)
        # half-truncate RoPE by @YouJiacheng (w/ base freq tuning): the other
        # head_dim//4 frequencies are zero, so those dims are never rotated
        angular_freq = torch.cat([angular_freq, angular_freq.new_zeros(self.head_dim//4)])
        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=device)
        theta = torch.outer(t, angular_freq)  # (max_seq_len, 2 * (head_dim // 4))
        # non-persistent buffers: excluded from state_dict, rebuilt by reset()
        self.cos = nn.Buffer(
            theta.cos().to(torch.bfloat16), persistent=False
        )
        self.sin = nn.Buffer(
            theta.sin().to(torch.bfloat16), persistent=False
        )
        self.angular_freq = angular_freq
        # start with 0.1, inspired by 0.12 from @leloykun and learnable scalars used by @brendanh0gan https://x.com/hi_tysam/status/1879693583898591283
        self.attn_scale = 0.1

    def apply(self, old_window, new_window=None, alpha: int=1, beta: int=32):
        """Rescale RoPE frequencies for a window growing from ``old_window`` to ``new_window``.

        Each frequency's rotation count over the old window,
        ``rotations = args.block_size * old_window * freq / (2*pi)``, decides how much
        it is scaled:

        - at or below ``alpha`` rotations it is multiplied by ``old_window / new_window``;
        - at or above ``beta`` rotations it is left unchanged;
        - in between, the factor is interpolated linearly.

        Zero frequencies stay zero. The cos/sin buffers are rewritten in place, and
        ``attn_scale`` is multiplied by ``0.2 * ln(new_window / old_window) + 1``.

        Called as ``apply(fn)`` with a single callable (the ``nn.Module.apply``
        protocol), it forwards to ``nn.Module.apply`` and returns ``self``.

        Raises:
            TypeError: if ``new_window`` is omitted and ``old_window`` is not callable.
        """
        if new_window is None:
            if callable(old_window):
                # nn.Module.apply(fn) protocol, e.g. recursion from a parent module
                return super().apply(old_window)
            raise TypeError("Yarn.apply() missing required argument: 'new_window'")
        # presumably windows are counted in blocks of args.block_size tokens, making
        # this the number of full turns each frequency makes across the old window
        rotations = args.block_size * old_window * self.angular_freq / (2 * torch.pi)
        scaling_factor = old_window / new_window
        interpolation_weight = torch.clamp((rotations - alpha) / (beta - alpha), 0, 1)
        self.angular_freq *= scaling_factor + interpolation_weight * (1 - scaling_factor)
        t = torch.arange(self.max_seq_len, dtype=torch.float32, device=self.angular_freq.device)
        theta = torch.outer(t, self.angular_freq)
        # copy_ rewrites the existing bf16 buffers in place (casting from float32)
        self.cos.copy_(theta.cos())
        self.sin.copy_(theta.sin())
        self.attn_scale *= 0.2 * math.log(new_window / old_window) + 1
def rotary(x_BTHD: Tensor, cos: Tensor, sin: Tensor):
    """Apply rotary position embeddings to ``x_BTHD``.

    The last dimension is split into halves ``(x1, x2)``. Each pair
    ``(x1[i], x2[i])`` is rotated by the angle whose cosine/sine are given:
    ``y1 = x1*cos + x2*sin`` and ``y2 = -x1*sin + x2*cos``.

    Args:
        x_BTHD: tensor whose last three dims are (seq, heads, head_dim), e.g.
            (B, T, H, D); the last dim is split in two by ``chunk``.
        cos, sin: 2-D tables with at least ``seq`` rows and ``head_dim // 2``
            columns; the first ``seq`` rows are broadcast over batch and heads.

    Returns:
        The rotated tensor. For inputs with 4 or more dims it has the same shape
        as ``x_BTHD``.
    """
    seq_len = x_BTHD.size(-3)
    assert cos.size(0) >= seq_len
    cos, sin = (
        cos[None, :seq_len, None, :],
        sin[None, :seq_len, None, :],
    )
    x1, x2 = x_BTHD.chunk(2, dim=-1)
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    # rejoin along the same (last) axis chunk() split; identical to dim=3 for 4-D
    # input, and still correct when extra leading dims are present
    return torch.cat((y1, y2), -1)
@dataclass
class AttnArgs:
    """Per-call inputs bundled for ``CausalSelfAttention.forward``."""
    # gated value embedding added to v; forward skips it when this is None
    ve: torch.Tensor
    # sa_lambdas[0] scales the QKV weight, sa_lambdas[1] scales the output weight
    sa_lambdas: torch.Tensor
    # passed to flash_attn_varlen_func as cu_seqlens_q and cu_seqlens_k
    seqlens: torch.Tensor
    # left attention window: passed as window_size=(bm_size, 0)
    bm_size: int
    # rotary tables; the first T rows are used (see rotary())
    cos: torch.Tensor
    sin: torch.Tensor
    # passed to flash attention as softmax_scale
    attn_scale: float
    # if True, keys in the non-rotated head dims are shifted forward one position
    key_offset: bool
# Flash-Attention 3 kernel module, fetched by name through get_kernel (presumably the
# Hugging Face `kernels` package -- confirm against the file's imports). This runs once
# at module import; CausalSelfAttention.forward calls its flash_attn_varlen_func.
flash_attn_interface = get_kernel('varunneal/flash-attention-3').flash_attn_interface
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention over packed variable-length sequences.

    Q, K, V and the output projection share one merged weight ``qkvo_w`` of shape
    ``(4 * dim, hdim)``. Rows ``[0, 3*dim)`` project to Q, K and V; rows
    ``[3*dim, 4*dim)`` are the output projection, initialised to zero.

    Args:
        dim: model width; must equal ``num_heads * head_dim``.
        head_dim: per-head width.
        num_heads: number of attention heads.
        layer_idx: layer index. Only layers 0-4 and 11-15 get a
            ``value_embed_gate``, so only those layers may be passed a non-None
            ``ve`` (others would raise AttributeError in ``forward``).
    """

    def __init__(self, dim: int, head_dim: int, num_heads: int, layer_idx: int):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dim = dim
        self.hdim = num_heads * head_dim
        assert self.hdim == self.dim, "num_heads * head_dim must equal model_dim"
        std = self.dim ** -0.5
        # uniform(-sqrt(3)*std, sqrt(3)*std) has standard deviation std
        bound = (3 ** 0.5) * std # improved init scale by @YouJiacheng
        # merged QKVO weights: suggested by many, implemented by @fernbear.bsky.social, and further improved by @YouJiacheng
        # https://x.com/hi_tysam/status/1879699187107033311
        # Simplified layout by @chrisjmccormick
        self.qkvo_w = nn.Parameter(torch.empty(self.dim * 4, self.hdim))
        # label all modules for explicit optimizer grouping
        self.qkvo_w.label = 'attn'
        with torch.no_grad():
            self.qkvo_w[:self.dim * 3].uniform_(-bound, bound) # init QKV weights
            self.qkvo_w[self.dim * 3:].zero_() # init O weights to zero
        # sparse gated attention to enable context based no-op by @classiclarryd
        self.attn_gate = CastedLinear(16, num_heads)
        self.attn_gate.weight.label = 'attn_gate'
        self.attn_gate.weight.lr_mul = 0.1
        # only include gates on layers with value embeds used on forward pass
        if layer_idx in [0, 1, 2, 3, 4, 11, 12, 13, 14, 15]:
            self.value_embed_gate = CastedLinear(16, num_heads)
            self.value_embed_gate.weight.label = 'value_embed_gate'
            self.value_embed_gate.weight.lr_mul = 0.1

    def forward(self, x: Tensor, attn_args: AttnArgs):
        """Attend causally within the packed sequences of ``x``.

        Args:
            x: (B, T, dim) activations. B must be 1: sequences are packed along T,
                with boundaries given by ``attn_args.seqlens`` (passed as cu_seqlens).
                T must be a multiple of 16.
            attn_args: per-call tensors and settings (see ``AttnArgs``).

        Returns:
            (B, T, dim) output after per-head gating and the output projection.
        """
        B, T = x.size(0), x.size(1) # batch size, sequence length
        assert B == 1, "varlen sequences requires B == 1"
        assert T % 16 == 0
        # unpack attention args
        cos, sin = attn_args.cos, attn_args.sin
        ve, sa_lambdas, key_offset = attn_args.ve, attn_args.sa_lambdas, attn_args.key_offset
        seqlens, attn_scale, bm_size = attn_args.seqlens, attn_args.attn_scale, attn_args.bm_size
        # project with the (scaled) QKV rows, then split into q, k, v of shape (B, T, H, head_dim)
        q, k, v = F.linear(x, sa_lambdas[0] * self.qkvo_w[:self.dim * 3].type_as(x)).view(B, T, 3 * self.num_heads, self.head_dim).chunk(3, dim=-2)
        q, k = norm(q), norm(k) # QK norm @Grad62304977
        q, k = rotary(q, cos, sin), rotary(k, cos, sin)
        if key_offset:
            # shift keys forward for the stationary head dims. Enables 1-layer induction.
            # The sliced ranges are the dims whose RoPE frequency is zero, so they carry
            # no position rotation. Source and destination views overlap in memory;
            # without clone() an in-place copy between them depends on kernel iteration
            # order in eager mode (values can move more than one step, or race on GPU).
            # Cloning makes eager execution match the read-then-write semantics that
            # torch.compile's functionalized graph already has.
            k[:, 1:, :, self.head_dim // 4:self.head_dim // 2] = k[:, :-1, :, self.head_dim // 4:self.head_dim // 2].clone()
            k[:, 1:, :, 3 * self.head_dim // 4:] = k[:, :-1, :, 3 * self.head_dim // 4:].clone()
        if ve is not None:
            # per-head gate in (0, 2), computed from the first 16 channels of x
            ve_gate_out = 2 * torch.sigmoid(self.value_embed_gate(x[..., :self.value_embed_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
            v = v + ve_gate_out * ve.view_as(v) # @ KoszarskyB & @Grad62304977
        # longest single sequence flash-attn must handle (presumably tokens per rank per accumulation step at eval)
        max_len = args.train_max_seq_len if self.training else (args.val_batch_size // (grad_accum_steps * world_size))
        # use flash_attn over flex_attn @varunneal. flash_attn_varlen suggested by @YouJiacheng
        y = flash_attn_interface.flash_attn_varlen_func(q[0], k[0], v[0], cu_seqlens_q=seqlens, cu_seqlens_k=seqlens,
                                                        max_seqlen_q=max_len, max_seqlen_k=max_len,
                                                        causal=True, softmax_scale=attn_scale, window_size=(bm_size, 0))
        y = y.view(B, T, self.num_heads, self.head_dim)
        # per-head sigmoid output gate from the first 16 channels of x (allows a context-based no-op)
        y = y * torch.sigmoid(self.attn_gate(x[..., :self.attn_gate.weight.size(-1)])).view(B, T, self.num_heads, 1)
        y = y.contiguous().view(B, T, self.num_heads * self.head_dim) # re-assemble all head outputs side by side
        y = F.linear(y, sa_lambdas[1] * self.qkvo_w[self.dim * 3:].type_as(y)) # sa_lambdas[1] pre-multiplied to O @shenberg
        return y
class MLP(nn.Module):
def __init__(self, dim: int):
super().__init__()
hdim = 4 * dim
# Transposed layout to match attention weights
self.c_fc = nn.Parameter(torch.empty(hdim, dim))
self.c_proj = nn.Parameter(torch.empty(hdim, dim))
# label all modules for explicit optimizer grouping
self.c_fc.label = 'mlp'
self.c_proj.label = 'mlp'
self.c_proj.lr_mul = 2.