add configuration for float8 with rowwise scaling, via recipe lookup #808

Merged 1 commit on Feb 27, 2025
19 changes: 14 additions & 5 deletions docs/float8.md
@@ -5,20 +5,29 @@ Please install latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/
USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git
```

Launch training job with the following command (or alternatively set configs in toml files)
For float8 with tensorwise scaling, launch a training job with the following command (or alternatively set the corresponding configs in toml files):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --training.compile
```
* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of many small all-reduces, one per parameter.
* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during the backward pass, preventing unsharded fp8 weights from being saved for backward.
* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels.
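
For a rough sense of what these flags configure under the hood, the sketch below uses torchao's `Float8LinearConfig` and `convert_to_float8_training` (the same APIs the converter code in this diff calls); it is an illustration under those assumptions, not part of this PR's changes.

```python
# Hedged sketch: roughly what the tensorwise flags map to in torchao.
# Mirrors the Float8LinearConfig usage in torchtitan/components/float8.py.
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

config = Float8LinearConfig(
    enable_fsdp_float8_all_gather=True,      # --float8.enable_fsdp_float8_all_gather
    force_recompute_fp8_weight_in_bwd=True,  # --float8.force_recompute_fp8_weight_in_bwd
)

model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
# Swaps eligible nn.Linear modules for Float8Linear in place
# (what --model.converters="float8" triggers); running the float8 matmuls
# still requires an fp8-capable GPU.
convert_to_float8_training(model, config=config)
```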

For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`).
For float8 with rowwise scaling, launch a training job with the following command (or alternatively set the corresponding configs in toml files):
```
CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.recipe_name rowwise --training.compile
```
* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul.
* `--float8.recipe_name="rowwise"`: use the rowwise scaling recipe for higher accuracy compared to tensorwise scaling.
* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels.
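
Similarly, a hedged sketch of the recipe-lookup path this PR adds (assumes torchao v0.9.0 or later, where `Float8LinearConfig.from_recipe_name` is available):

```python
# Hedged sketch: rowwise scaling via recipe lookup, mirroring the
# Float8LinearConfig.from_recipe_name call in this PR (torchao >= v0.9.0).
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

config = Float8LinearConfig.from_recipe_name("rowwise")  # --float8.recipe_name rowwise
model = nn.Sequential(nn.Linear(4096, 4096), nn.Linear(4096, 4096))
convert_to_float8_training(model, config=config)
```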

For parallelisms, for float8 with tensorwise scaling we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). For float8 with rowwise scaling, all distributed communication is done in high precision.

For scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration.
For scaling strategy, we currently support tensorwise dynamic scaling (stable) and rowwise dynamic scaling (alpha).

## Why Composing Float8 with `torch.distributed`
## Benefits of composing float8 with tensorwise scaling with `torch.distributed`
**Float8 vs Bfloat16/Float32**: In the float8 E4M3 format, there are only 3 mantissa bits, so it becomes the user's responsibility to maintain consistent scales across operations (summation, multiplication) to balance precision and range. For bfloat16/float32, the exponent range is large enough that users do not need to maintain such scales. When using float8 with FSDP and TP, tensors are sharded across ranks. To keep single-device semantics, it is critical to communicate scales across ranks.

As shown below, for float8 matmul, `torch._scaled_mm` requires both float8 tensors and their scales. Scales are calculated from the `max(abs)` of a high-precision tensor.
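
To make the scale handling concrete, here is a minimal tensorwise sketch. It assumes a recent PyTorch and an fp8-capable GPU; `torch._scaled_mm` is a private API whose exact signature and return type have varied across releases, so treat this as illustrative only.

```python
# Hedged sketch: per-tensor (tensorwise) scaling feeding torch._scaled_mm.
import torch

def to_float8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    # The scale is derived from max(abs) of the high-precision tensor.
    amax = x.abs().max().to(torch.float32)
    scale = torch.finfo(dtype).max / torch.clamp(amax, min=1e-12)
    x_fp8 = (x * scale).clamp(min=torch.finfo(dtype).min, max=torch.finfo(dtype).max).to(dtype)
    return x_fp8, scale

a = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
b = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = to_float8(a)
b_fp8, b_scale = to_float8(b)

# _scaled_mm expects the second operand in column-major layout and the
# reciprocal (dequantization) scales of each operand.
out = torch._scaled_mm(
    a_fp8,
    b_fp8.t().contiguous().t(),
    scale_a=1.0 / a_scale,
    scale_b=1.0 / b_scale,
    out_dtype=torch.bfloat16,
)
```
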
51 changes: 36 additions & 15 deletions torchtitan/components/float8.py
@@ -49,25 +49,46 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
"torchao is not installed. Please install it to use float8 linear layers."
) from e

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
)
if float8_config.recipe_name is not None and not hasattr(
Float8LinearConfig, "from_recipe_name"
):
logger.warning(
"Failed to swap to Float8Linear with recipe lookup because the torchao version "
+ "is too old, please install torchao v0.9.0 or later and try again",
)
return

self.enabled = True

# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)
if float8_config.recipe_name is not None:
assert (
not float8_config.enable_fsdp_float8_all_gather
), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported"
assert (
not float8_config.force_recompute_fp8_weight_in_bwd
), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported"
self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
Contributor: What happens if `recipe_name` is not in ("rowwise", "rowwise_with_gw_hp")? Shall we use `choices` in config_manager.py to guard, similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L191?

Contributor Author: We will throw a detailed error message from torchao (https://github.com/pytorch/ao/blob/1ab1b77ad744115c3fee62716e7d2083c57d80a1/torchao/float8/config.py#L301) with the valid choices displayed in the error message, so this code gets that for free.

Contributor Author: Never mind, moved to "choices" like you suggested.

self.precompute_scale = False
logger.info(
f"Float8 training active with recipe {float8_config.recipe_name}"
)

logger.info("Float8 training active")
else:
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd,
)
# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)
logger.info("Float8 tensorwise scaled training active")

def convert(self, model: nn.Module):
return self.convert_to_float8_training(model)
19 changes: 15 additions & 4 deletions torchtitan/config_manager.py
@@ -613,20 +613,31 @@ def __init__(self):
self.parser.add_argument(
"--float8.enable_fsdp_float8_all_gather",
action="store_true",
help="Whether enable float8 all-gather in FSDP",
help="Whether to enable float8 all-gather in FSDP; recommended for tensorwise scaling",
)
self.parser.add_argument(
"--float8.precompute_float8_dynamic_scale_for_fsdp",
action="store_true",
help="Whether precompute float8 scales dynamically for FSDP",
help="Whether to precompute float8 scales dynamically for FSDP; recommended for tensorwise scaling",
)
self.parser.add_argument(
"--float8.force_recompute_fp8_weight_in_bwd",
action="store_true",
help="""
Whether to force the recomputation of FP8 weights during backward pass.
When using FSDP, it is recommended to enable `force_recompute_fp8_weight_in_bwd`
to prevent saving unsharded FP8 weights for backward computation.
When using FSDP with tensorwise scaling, it is recommended to enable
`force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights
for backward computation.
""",
)
self.parser.add_argument(
"--float8.recipe_name",
type=str,
default=None,
choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
help="""
If specified, creates float8 config from recipe name, valid choices are
`tensorwise`, `rowwise` and `rowwise_with_gw_hp`.
""",
)

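
For reference on the `choices` guard adopted above, argparse rejects out-of-list values at parse time. The toy parser below is a hypothetical standalone sketch, not torchtitan's actual `JobConfig` wiring:

```python
# Hedged sketch: how argparse `choices` guards --float8.recipe_name.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--float8.recipe_name",
    type=str,
    default=None,
    choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"],
)

print(parser.parse_args(["--float8.recipe_name", "rowwise"]))
# An out-of-list value fails fast with an "invalid choice" error, e.g.:
# parser.parse_args(["--float8.recipe_name", "colwise"])  # raises SystemExit
```
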
23 changes: 16 additions & 7 deletions torchtitan/models/llama/parallelize_llama.py
@@ -56,12 +56,23 @@ def parallelize_llama(
and not job_config.training.compile
):
raise RuntimeError("Async TP requires --training.compile")

enable_float8_linear = "float8" in job_config.model.converters
float8_is_rowwise = job_config.float8.recipe_name in (
"rowwise",
"rowwise_with_gw_hp",
)
Comment on lines +61 to +64

Contributor: Since they are the only two valid options, can we do an assertion instead of using `in`, unless you plan to add other options soon?

Contributor Author (@vkuzo, Feb 26, 2025): Could you explain what you mean by that? Not sure I follow; `foo in ("a", "b")` seems like a clean way to enforce that `foo` is either `a` or `b` to me.

Contributor: Oh, I was saying, does it make sense to do

    float8_is_rowwise = job_config.float8.recipe_name is not None
    # and if you don't use `choices` to guard, we should
    assert job_config.float8.recipe_name in ("rowwise", "rowwise_with_gw_hp")

if we only support the two rowwise recipes? NVM if torchao plans to support other recipes soon.

Contributor Author: Ah, got it. Currently this PR is written as "support all recipes, including future ones"; it looks like you are asking for "only support recipes which have been tested in torchtitan". I agree with that feedback, let me change this PR to accommodate.

Contributor Author: Regarding

    float8_is_rowwise = job_config.float8.recipe_name is not None

I would prefer not to do this, as it forces the code reader to go look at "which recipes are valid" to understand what's going on. On the other hand,

    assert job_config.float8.recipe_name in ("rowwise", "rowwise_with_gw_hp")

displays the valid recipes here without making the code reader go look for them.


# For now, float8 all-gather with TP is only supported for tensorwise
# float8 scaling recipes. For rowwise recipes, we use regular TP and
# all-gather happens in high precision.
enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise

apply_tp(
model,
world_mesh["tp"],
loss_parallel=parallel_dims.loss_parallel_enabled,
enable_float8=enable_float8_linear,
enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
)

@@ -115,7 +126,7 @@ def apply_tp(
model: nn.Module,
tp_mesh: DeviceMesh,
loss_parallel: bool,
enable_float8: bool,
enable_float8_tensorwise_tp: bool,
enable_async_tp: bool,
):
"""Apply tensor parallelism."""
@@ -141,10 +152,8 @@
)

# Parallel styles used for transformer block linear weights and their
# inputs may be different for float8 linears
if enable_float8:
# TODO(vkuzo): once float8 configuration supports delayed scaling,
# add a check here to enforce supported float8 all-gather configurations
# inputs may be different for float8 linears with tensorwise scaling.
if enable_float8_tensorwise_tp:
# TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
from torchao.float8.float8_tensor_parallel import (
Float8ColwiseParallel,
@@ -202,7 +211,7 @@ def apply_tp(
enable_symm_mem_for_group(tp_mesh.get_group().group_name)

logger.info(
f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
"Tensor Parallelism to the model"
)
