Commit dbb34cc
authored
[llama4] add Grouped GEMM support for MoE (#1084)
This PR
1. adds grouped gemm support
(pytorch/pytorch#150374) for llama4 MoE. In
general, it avoids device/host syncs; for TP, it avoided the sharding
prop overhead caused by varying number of tokens for individual experts.
The speedup on the debug model is ~4x with/without TP. I'm deliberately
keeping the for-loop implementation for now, for comparison and
readability purposes.
2. moves the MoE indices kernel from the deepseek folder to the kernel
folder.
In order for TP to work, it requires some pytorch-side changes (e.g.
DTensor support for `torch._grouped_mm`), for which I will submit PRs
soon.
A issue is that the grouped gemm version doesn't work well with AdamW
optimizer, which is to be investigated. cc: @janeyx991 parent ab08612 commit dbb34cc
File tree
9 files changed
+127
-68
lines changed- torchtitan
- experiments
- deepseek_v3
- kernels/moe
- llama4
- infra
- model
- train_configs
9 files changed
+127
-68
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
37 | 37 | | |
38 | 38 | | |
39 | 39 | | |
40 | | - | |
41 | 40 | | |
42 | 41 | | |
43 | 42 | | |
44 | 43 | | |
45 | 44 | | |
| 45 | + | |
46 | 46 | | |
47 | 47 | | |
48 | 48 | | |
| |||
File renamed without changes.
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1 | 1 | | |
2 | 2 | | |
3 | 3 | | |
4 | | - | |
5 | | - | |
| 4 | + | |
| 5 | + | |
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
20 | | - | |
21 | | - | |
22 | | - | |
23 | 20 | | |
24 | | - | |
25 | 21 | | |
26 | 22 | | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
27 | 27 | | |
28 | 28 | | |
29 | 29 | | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
8 | 8 | | |
9 | 9 | | |
10 | 10 | | |
| 11 | + | |
11 | 12 | | |
12 | 13 | | |
13 | 14 | | |
14 | 15 | | |
15 | 16 | | |
16 | 17 | | |
17 | | - | |
18 | 18 | | |
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
22 | 22 | | |
23 | 23 | | |
24 | 24 | | |
25 | | - | |
| 25 | + | |
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
| |||
32 | 32 | | |
33 | 33 | | |
34 | 34 | | |
35 | | - | |
36 | | - | |
37 | | - | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
38 | 38 | | |
39 | 39 | | |
40 | 40 | | |
41 | 41 | | |
42 | 42 | | |
43 | 43 | | |
44 | | - | |
45 | | - | |
| 44 | + | |
46 | 45 | | |
47 | | - | |
48 | | - | |
49 | | - | |
50 | | - | |
51 | | - | |
52 | | - | |
53 | | - | |
54 | | - | |
55 | | - | |
56 | | - | |
57 | | - | |
58 | | - | |
59 | | - | |
60 | | - | |
61 | | - | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
62 | 60 | | |
63 | 61 | | |
64 | 62 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
149 | 149 | | |
150 | 150 | | |
151 | 151 | | |
152 | | - | |
153 | | - | |
| 152 | + | |
| 153 | + | |
154 | 154 | | |
155 | 155 | | |
156 | 156 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
47 | 47 | | |
48 | 48 | | |
49 | 49 | | |
| 50 | + | |
50 | 51 | | |
51 | 52 | | |
52 | 53 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
| 20 | + | |
20 | 21 | | |
21 | 22 | | |
22 | 23 | | |
23 | 24 | | |
24 | 25 | | |
25 | 26 | | |
| 27 | + | |
26 | 28 | | |
27 | 29 | | |
28 | 30 | | |
29 | 31 | | |
30 | | - | |
| 32 | + | |
31 | 33 | | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | | - | |
39 | | - | |
40 | | - | |
41 | | - | |
42 | | - | |
43 | | - | |
44 | | - | |
45 | | - | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
46 | 44 | | |
47 | | - | |
48 | | - | |
49 | | - | |
50 | | - | |
51 | | - | |
52 | | - | |
53 | | - | |
54 | | - | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
55 | 69 | | |
56 | | - | |
57 | | - | |
58 | | - | |
59 | | - | |
60 | | - | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
61 | 77 | | |
62 | | - | |
63 | | - | |
64 | | - | |
65 | | - | |
66 | | - | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
67 | 86 | | |
68 | 87 | | |
69 | 88 | | |
| |||
166 | 185 | | |
167 | 186 | | |
168 | 187 | | |
| 188 | + | |
169 | 189 | | |
170 | | - | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
171 | 194 | | |
172 | 195 | | |
173 | 196 | | |
174 | 197 | | |
175 | 198 | | |
176 | | - | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
177 | 205 | | |
178 | 206 | | |
179 | 207 | | |
| |||
206 | 234 | | |
207 | 235 | | |
208 | 236 | | |
| 237 | + | |
| 238 | + | |
| 239 | + | |
| 240 | + | |
| 241 | + | |
| 242 | + | |
| 243 | + | |
| 244 | + | |
| 245 | + | |
| 246 | + | |
| 247 | + | |
| 248 | + | |
| 249 | + | |
| 250 | + | |
| 251 | + | |
| 252 | + | |
| 253 | + | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
209 | 267 | | |
210 | 268 | | |
211 | 269 | | |
| |||
Lines changed: 3 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
29 | 29 | | |
30 | 30 | | |
31 | 31 | | |
32 | | - | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
33 | 35 | | |
34 | 36 | | |
35 | 37 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
391 | 391 | | |
392 | 392 | | |
393 | 393 | | |
394 | | - | |
| 394 | + | |
395 | 395 | | |
396 | 396 | | |
397 | 397 | | |
| |||
0 commit comments