Commit 988c5c9
fix tensor parallelism for float8 training with rowwise scaling (#1718)
Summary:
1. add a test for toy model + TP + float8 rowwise scaling training
2. fix underlying issues to make the test pass:
   a. add fast path for tensor view where the new shape is the same as
      old shape, for rowwise scaled float8 (this is needed for DTensor)
   b. modify the fake grad dependency workaround to work when grad is a
      DTensor
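
A minimal sketch of the idea behind fix (a), not the torchao implementation: `RowwiseScaledToy` is a hypothetical stand-in that tracks only shapes. It assumes that a view to the identical shape can reuse the existing per-row scales unchanged, while any other reshape is simply rejected here for brevity.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class RowwiseScaledToy:
    """Hypothetical stand-in for a rowwise-scaled float8 tensor (shapes only)."""

    data_shape: Tuple[int, ...]   # shape of the float8 payload
    scale_shape: Tuple[int, ...]  # one scale per row, e.g. (rows, 1)

    def view(self, *new_shape: int) -> "RowwiseScaledToy":
        # Fast path: viewing to the identical shape cannot move elements
        # across rows, so the per-row scales remain valid as they are.
        if tuple(new_shape) == self.data_shape:
            return RowwiseScaledToy(self.data_shape, self.scale_shape)
        # Assumption of this sketch: other reshapes may mix rows, which would
        # invalidate the scales, so they are rejected rather than handled.
        raise NotImplementedError(
            f"view {self.data_shape} -> {tuple(new_shape)} is not supported in this sketch"
        )


t = RowwiseScaledToy(data_shape=(4, 8), scale_shape=(4, 1))
same = t.view(4, 8)  # takes the fast path and keeps both shapes
```

The commit message only states that the same-shape case is needed for DTensor; how other shapes are treated in torchao is not shown here.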
Test Plan:
1. ./test/float8/test_everything.sh (one transient failure:
https://www.internalfb.com/phabricator/paste/view/P1733103301)
2. verified that float8 rowwise scaling behaves sanely in torchtitan on
LLaMa 3 8B on 8 H100s, with tp 2:
```
// requires pytorch/torchtitan#808
// baseline - bfloat16 + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:41:16,175 - root - INFO - step: 40 loss: 7.4240 memory: 35.56GiB(37.43%) tps: 1,669 mfu: 9.77%
// float8 baseline - float8 tensorwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile
[rank0]:2025-02-14 13:44:07,806 - root - INFO - step: 40 loss: 7.4993 memory: 35.57GiB(37.44%) tps: 2,141 mfu: 12.54%
// float8 rowwise without zero fake dep (for sanity) + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:47:51,400 - root - INFO - step: 40 loss: 7.3472 memory: 35.55GiB(37.42%) tps: 1,858 mfu: 10.88%
// float8 rowwise + compile + tp 2
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.tensor_parallel_degree 2 --training.compile --float8.recipe_name all_axiswise
[rank0]:2025-02-14 13:51:20,864 - root - INFO - step: 40 loss: 9.4211 memory: 35.55GiB(37.42%) tps: 1,820 mfu: 10.66%
```
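
To compare the four runs above at a glance, the step-40 log lines can be parsed and the throughput expressed relative to the bfloat16 baseline. This is a rough sketch: the regular expression, the `parse_metrics` helper, and the run labels are illustrative choices, not part of torchtitan.

```python
import re

# Matches the metric fields of a log line like the ones quoted above.
LOG_RE = re.compile(
    r"step:\s*(\d+)\s+loss:\s*([\d.]+).*?tps:\s*([\d,]+)\s+mfu:\s*([\d.]+)%"
)


def parse_metrics(line):
    """Extract step, loss, tokens/sec, and MFU from one log line."""
    m = LOG_RE.search(line)
    if m is None:
        raise ValueError(f"unrecognized log line: {line!r}")
    return {
        "step": int(m.group(1)),
        "loss": float(m.group(2)),
        "tps": int(m.group(3).replace(",", "")),
        "mfu": float(m.group(4)),
    }


# Log lines copied from the test plan; the dictionary keys are informal labels.
RUNS = {
    "bfloat16": "step: 40 loss: 7.4240 memory: 35.56GiB(37.43%) tps: 1,669 mfu: 9.77%",
    "float8 tensorwise": "step: 40 loss: 7.4993 memory: 35.57GiB(37.44%) tps: 2,141 mfu: 12.54%",
    "float8 rowwise, no zero fake dep": "step: 40 loss: 7.3472 memory: 35.55GiB(37.42%) tps: 1,858 mfu: 10.88%",
    "float8 rowwise": "step: 40 loss: 9.4211 memory: 35.55GiB(37.42%) tps: 1,820 mfu: 10.66%",
}

metrics = {name: parse_metrics(line) for name, line in RUNS.items()}
baseline_tps = metrics["bfloat16"]["tps"]
speedups = {name: m["tps"] / baseline_tps for name, m in metrics.items()}

for name, ratio in speedups.items():
    print(f"{name}: loss {metrics[name]['loss']:.4f}, {ratio:.2f}x baseline tps")
```

A single step-40 reading per run is only a sanity check, so the ratios should be treated as indicative rather than as benchmark results.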
Reviewers:
Subscribers:
Tasks:
Tags:

1 parent 7b37eb0, commit 988c5c9
4 files changed (+113, -40) in test/float8 and torchao/float8