Commit 3ebdf05
committed
[not for land yet] example of float8 with rowwise scaling
Summary:
This is an example of how to call float8 training with rowwise scaling
from torchao.
TODO: finalize API in torchao, and finalize how we want to expose it in
torchtitan, and optimize performance.
```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%
// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
step: 40 loss: 7.3818 memory: 66.81GiB(70.33%) tps: 6,412 mfu: 37.55%
// for comparison, tensorwise float8 with float8 all-gather (on main branch)
with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%
```
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:1 parent eddce12 commit 3ebdf05
2 files changed
+11
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
42 | 42 | | |
43 | 43 | | |
44 | 44 | | |
| 45 | + | |
45 | 46 | | |
46 | 47 | | |
47 | 48 | | |
| |||
55 | 56 | | |
56 | 57 | | |
57 | 58 | | |
| 59 | + | |
58 | 60 | | |
59 | 61 | | |
60 | 62 | | |
61 | 63 | | |
62 | 64 | | |
| 65 | + | |
63 | 66 | | |
64 | 67 | | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
65 | 75 | | |
66 | 76 | | |
67 | 77 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
114 | 114 | | |
115 | 115 | | |
116 | 116 | | |
| 117 | + | |
117 | 118 | | |
118 | 119 | | |
119 | 120 | | |
| |||
0 commit comments