Commit 069bae2

add configuration for float8 with rowwise scaling, via recipe lookup (#808)
Summary: Exposes the float8 config from recipe name API from torchao, so we can use it to easily configure float8 with rowwise scaling from torchtitan.

Usage:

```
with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh \
  --model.converters float8 \
  --training.compile \
  --float8.recipe_name rowwise
```

Example, pretraining LLaMa 3 8B on 8 H100s:

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --training.compile
...
step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.converters float8 --training.compile --float8.recipe_name rowwise
...
step: 30 loss: 7.7109 memory: 47.77GiB(50.28%) tps: 6,468 mfu: 37.88%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
> with-proxy CONFIG_FILE="./torchtitan/models/llama/train_configs/llama3_8b.toml" ./run_train.sh --model.converters float8 --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent 0047aa2 commit 069bae2
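
The logged throughput numbers above imply the following back-of-envelope speedups over the bf16 baseline. This is a quick aside, not part of the commit; the tps values are copied from the logs:

```python
# Back-of-envelope speedups from the tps numbers logged above
# (8x H100, LLaMa 3 8B; values copied from the commit message).
bf16_tps = 5760        # baseline: bf16 + compile
rowwise_tps = 6468     # rowwise float8 + compile
tensorwise_tps = 7186  # tensorwise float8 + float8 all-gather + compile

rowwise_speedup = rowwise_tps / bf16_tps        # ~1.12x over bf16
tensorwise_speedup = tensorwise_tps / bf16_tps  # ~1.25x over bf16
print(f"rowwise: {rowwise_speedup:.2f}x, tensorwise: {tensorwise_speedup:.2f}x")
```

So in this run, rowwise scaling trades some of tensorwise's throughput for its higher accuracy, while still beating bf16.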

File tree

4 files changed: +81 −31 lines

docs/float8.md

Lines changed: 14 additions & 5 deletions
````diff
@@ -5,20 +5,29 @@ Please install latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/
 USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
 ```
 
-Launch training job with the following command (or alternatively set configs in toml files)
+For float8 with tensorwise scaling, launch training job with the following command (or alternatively set configs in toml files)
 ```
-CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd
+CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --training.compile
 ```
 * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
 * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
 * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter.
 * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
+* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
 
-For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`).
+For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files)
+```
+CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.recipe_name rowwise --training.compile
+```
+* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
+* `--float8.recipe_name="rowwise"`: use the rowwise scaling recipe for higher accuracy compared to tensorwise scaling
+* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
+
+For parallelisms, for float8 with tensorwise scaling we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). For float8 with rowwise scaling, all distributed communication is done in high precision.
 
-For scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
+For scaling strategy, we currently support tensorwise dynamic scaling (stable) and rowwise dynamic scaling (alpha).
 
-## Why Composing Float8 with `torch.distributed`
+## Benefits of composing of float8 with tensorwise scaling with `torch.distributed`
 **Float8 vs Bfloat16/Float32**: In float8 E4M3 format, we only have 3 bits for mantissa, it becomes user's responsibility to maintain consistent scales across operations (summation, multiplication) to balance between precision and range. For bfloat16/float32, exponent range is large enough and users do not need to maintain such scales. When using float8 in FSDP and TP, tensors are sharded across ranks. To keep single device semantics, it's critical to communicate scales across ranks.
 
 As shown below, for float8 for matmul, `torch._scaled_mm` requires both float8 tensors and their scales. Scales are calculated from `max(abs)` of a high precision tensor.
````
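
To make the tensorwise vs rowwise distinction in the doc concrete, here is a minimal pure-Python sketch (illustrative only, not torchao's implementation; the only assumed fact is that float8 e4m3 can represent magnitudes up to 448, which the dynamic scale maps `max(abs)` onto):

```python
# Illustrative sketch of dynamic scaling granularities; not torchao code.
# float8 e4m3 represents magnitudes up to 448, so a scale maps a tensor's
# (or a row's) max(abs) onto that range.
E4M3_MAX = 448.0

def tensorwise_scale(tensor):
    # one scale shared by every element: a single outlier shrinks
    # the effective precision of the whole tensor
    amax = max(abs(v) for row in tensor for v in row)
    return E4M3_MAX / amax

def rowwise_scales(tensor):
    # one scale per row: each row keeps its own dynamic range
    return [E4M3_MAX / max(abs(v) for v in row) for row in tensor]

t = [[0.5, -2.0], [100.0, 1.0]]
print(tensorwise_scale(t))  # 4.48  (448 / 100, dominated by the outlier row)
print(rowwise_scales(t))    # [224.0, 4.48]
```

The per-row scales are why the rowwise recipe tends to be more accurate: the small-magnitude row gets a much larger scale (224.0) instead of inheriting the outlier-driven 4.48.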

torchtitan/components/float8.py

Lines changed: 36 additions & 15 deletions
```diff
@@ -49,25 +49,46 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
                 "torchao is not installed. Please install it to use float8 linear layers."
             ) from e
 
-        # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
-        enable_fsdp_float8_all_gather = (
-            parallel_dims.dp_shard_enabled
-            and float8_config.enable_fsdp_float8_all_gather
-        )
-        self.config = Float8LinearConfig(
-            enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
-            force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
-        )
+        if float8_config.recipe_name is not None and not hasattr(
+            Float8LinearConfig, "from_recipe_name"
+        ):
+            logger.warning(
+                "Failed to swap to Float8Linear with recipe lookup because the torchao version "
+                + "is too old, please install torchao v0.9.0 or later and try again",
+            )
+            return
 
         self.enabled = True
 
-        # for precompute_float8_dynamic_scale_for_fsdp
-        self.precompute_scale = (
-            enable_fsdp_float8_all_gather
-            and float8_config.precompute_float8_dynamic_scale_for_fsdp
-        )
+        if float8_config.recipe_name is not None:
+            assert (
+                not float8_config.enable_fsdp_float8_all_gather
+            ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
+            assert (
+                not float8_config.force_recompute_fp8_weight_in_bwd
+            ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
+            self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
+            self.precompute_scale = False
+            logger.info(
+                f"Float8 training active with recipe {float8_config.recipe_name}"
+            )
 
-        logger.info("Float8 training active")
+        else:
+            # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
+            enable_fsdp_float8_all_gather = (
+                parallel_dims.dp_shard_enabled
+                and float8_config.enable_fsdp_float8_all_gather
+            )
+            self.config = Float8LinearConfig(
+                enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
+                force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
+            )
+            # for precompute_float8_dynamic_scale_for_fsdp
+            self.precompute_scale = (
+                enable_fsdp_float8_all_gather
+                and float8_config.precompute_float8_dynamic_scale_for_fsdp
+            )
+            logger.info("Float8 tensorwise scaled training active")
 
     def convert(self, model: nn.Module):
         return self.convert_to_float8_training(model)
```
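
The `hasattr` guard in this diff is a feature-probe pattern: check for the classmethod before calling it, so an older torchao degrades to a logged warning instead of raising `AttributeError`. A standalone sketch of the same pattern, with hypothetical stand-in classes (`LegacyConfig` and `RecipeConfig` are illustrative, not torchao names):

```python
# Hypothetical stand-ins to illustrate the feature probe; not torchao classes.
class LegacyConfig:
    """Models an old library version: no recipe lookup classmethod."""
    pass

class RecipeConfig:
    """Models a new library version that supports recipe lookup."""
    def __init__(self, name=None):
        self.name = name

    @classmethod
    def from_recipe_name(cls, name):
        return cls(name)

def build_config(config_cls, recipe_name):
    """Return a config, or None when recipe lookup was requested but is
    unavailable (the caller warns and leaves the feature disabled)."""
    if recipe_name is not None and not hasattr(config_cls, "from_recipe_name"):
        return None  # version too old for recipe lookup
    if recipe_name is not None:
        return config_cls.from_recipe_name(recipe_name)
    return config_cls()  # legacy path: construct the config directly

print(build_config(LegacyConfig, "rowwise"))       # None
print(build_config(RecipeConfig, "rowwise").name)  # rowwise
```

Probing the class rather than pinning a version string keeps the check robust to backports and avoids a hard dependency on version parsing.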

torchtitan/config_manager.py

Lines changed: 15 additions & 4 deletions
```diff
@@ -613,20 +613,31 @@ def __init__(self):
         self.parser.add_argument(
             "--float8.enable_fsdp_float8_all_gather",
             action="store_true",
-            help="Whether enable float8 all-gather in FSDP",
+            help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling",
         )
         self.parser.add_argument(
             "--float8.precompute_float8_dynamic_scale_for_fsdp",
             action="store_true",
-            help="Whether precompute float8 scales dynamically for FSDP",
+            help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling",
         )
         self.parser.add_argument(
             "--float8.force_recompute_fp8_weight_in_bwd",
             action="store_true",
             help="""
                 Whether to force the recomputation of FP8 weights during backward pass.
-                When using FSDP, it is recommended to enable `force_recompute_fp8_weight_in_bwd`
-                to prevent saving unsharded FP8 weights for backward computation.
+                When using FSDP with tensorwise scaling, it is recommended to enable
+                `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
+                for backward computation.
+            """,
+        )
+        self.parser.add_argument(
+            "--float8.recipe_name",
+            type=str,
+            default=None,
+            choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
+            help="""
+                If specified, creates float8 config from recipe name, valid choices are
+                `tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
             """,
         )
```
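
The new flag can be exercised standalone with stdlib `argparse`; a minimal sketch mirroring the diff above. One subtlety: argparse keeps the literal dot in the derived attribute name (`float8.recipe_name`), so the value must be read back with `getattr` rather than plain attribute access:

```python
import argparse

# Minimal sketch of the new flag, mirroring the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--float8.recipe_name",
    type=str,
    default=None,
    choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
    help="If specified, creates float8 config from recipe name.",
)

args = parser.parse_args(["--float8.recipe_name", "rowwise"])
# dest contains a literal dot, so attribute access (args.float8...) won't work
print(getattr(args, "float8.recipe_name"))  # rowwise
```

An unlisted value such as `--float8.recipe_name colwise` is rejected by `choices` at parse time, which is how the CLI surfaces the three supported recipes.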

torchtitan/models/llama/parallelize_llama.py

Lines changed: 16 additions & 7 deletions
```diff
@@ -56,12 +56,23 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")
+
         enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        # For now, float8 all-gather with TP is only supported for tensorwise
+        # float8 scaling recipes. For rowwise recipes, we use regular TP and
+        # all-gather happens in high precision.
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+
         apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=enable_float8_linear,
+            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )
 
@@ -115,7 +126,7 @@ def apply_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
-    enable_float8: bool,
+    enable_float8_tensorwise_tp: bool,
     enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
@@ -141,10 +152,8 @@ def apply_tp(
     )
 
     # Parallel styles used for transformer block linear weights and their
-    # inputs may be different for float8 linears
-    if enable_float8:
-        # TODO(vkuzo): once float8 configuration supports delayed scaling,
-        # add a check here to enforce supported float8 all-gather configurations
+    # inputs may be different for float8 linears with tensorwise scaling.
+    if enable_float8_tensorwise_tp:
        # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
@@ -202,7 +211,7 @@ def apply_tp(
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
     logger.info(
-        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
     )
```
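
The gating in the first hunk reduces to a small predicate; a standalone restatement (the function name is illustrative, not a torchtitan symbol):

```python
# Illustrative restatement of the TP gating logic from the diff above.
ROWWISE_RECIPES = ("rowwise", "rowwise_with_gw_hp")

def enable_float8_tensorwise_tp(converters, recipe_name):
    """Float8 all-gather with TP is only supported for tensorwise recipes;
    rowwise recipes fall back to regular TP with high-precision all-gather."""
    enable_float8_linear = "float8" in converters
    float8_is_rowwise = recipe_name in ROWWISE_RECIPES
    return enable_float8_linear and not float8_is_rowwise

print(enable_float8_tensorwise_tp(["float8"], None))       # True
print(enable_float8_tensorwise_tp(["float8"], "rowwise"))  # False
print(enable_float8_tensorwise_tp([], None))               # False
```

Note that `recipe_name=None` (the legacy tensorwise path) keeps float8 TP enabled, so existing tensorwise configurations are unaffected by this commit.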
