From 3537666a9f97f3222ddf39f5d61d7ca25b3c4899 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 13:22:47 -0800 Subject: [PATCH] [not for land yet] example of float8 with rowwise scaling Summary: This is an example of how to call float8 training with rowwise scaling from torchao. TODO: finalize API in torchao, and finalize how we want to expose it in torchtitan, and optimize performance. ``` // baseline (bf16 + compile) > with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile ... step: 20 loss: 8.4931 memory: 47.65GiB(50.16%) tps: 5,760 mfu: 33.73% // experiment (rowwise float8 + compile) > with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile ... // torchao main branch step: 40 loss: 7.3818 memory: 66.81GiB(70.33%) tps: 6,412 mfu: 37.55% // torchao with https://github.com/pytorch/ao/pull/1629 step: 20 loss: 8.3823 memory: 58.55GiB(61.63%) tps: 6,424 mfu: 37.62% // for comparison, tensorwise float8 with float8 all-gather (on main branch) with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp ... step: 20 loss: 8.4258 memory: 47.32GiB(49.81%) tps: 7,186 mfu: 42.08% ``` Test Plan: Reviewers: Subscribers: Tasks: Tags: --- docs/float8.md | 19 ++++++-- torchtitan/components/float8.py | 51 ++++++++++++++------ torchtitan/config_manager.py | 19 ++++++-- torchtitan/models/llama/parallelize_llama.py | 23 ++++++--- 4 files changed, 81 insertions(+), 31 deletions(-) diff --git a/docs/float8.md b/docs/float8.md index f05be2b60..1a35c2aee 100644 --- a/docs/float8.md +++ b/docs/float8.md @@ -5,20 +5,29 @@ Please install latest [TorchAO](https://github.com/pytorch/ao/tree/main/torchao/ USE_CPP=0 python -m pip install git+https://github.com/pytorch/ao.git ``` -Launch training job with the following command (or alternatively set configs in toml files) +For float8 with tensorwise scaling, launch training job with the following command (or alternatively set configs in toml files) ``` -CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd +CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp --float8.force_recompute_fp8_weight_in_bwd --training.compile ``` * `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. * `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth. * `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of doing many small all-reduce for each parameter. * `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward. +* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels -For parallelisms, we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). +For float8 with rowwise scaling, launch training job with the following command (or alternatively set configs in toml files) +``` +CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_train.sh --model.converters="float8" --float8.recipe_name rowwise --training.compile +``` +* `--model.converters="float8"`: swap `nn.Linear` with `Float8Linear` to perform float8 matmul. +* `--float8.recipe_name="rowwise"`: use the rowwise scaling recipe for higher accuracy compared to tensorwise scaling +* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels + +For parallelisms, for float8 with tensorwise scaling we support float8 all-gather for FSDP (optional) and for TP (by default for `Float8Linear`). For float8 with rowwise scaling, all distributed communication is done in high precision. -For scaling strategy, we currently support tensor-wise scaling with dynamic scales, and are actively working on tensor-wise scaling with delayed scales. Row-wise scaling is under exploration. +For scaling strategy, we currently support tensorwise dynamic scaling (stable) and rowwise dynamic scaling (alpha). -## Why Composing Float8 with `torch.distributed` +## Benefits of composing of float8 with tensorwise scaling with `torch.distributed` **Float8 vs Bfloat16/Float32**: In float8 E4M3 format, we only have 3 bits for mantissa, it becomes user's responsibility to maintain consistent scales across operations (summation, multiplication) to balance between precision and range. For bfloat16/float32, exponent range is large enough and users do not need to maintain such scales. When using float8 in FSDP and TP, tensors are sharded across ranks. To keep single device semantics, it's critical to communicate scales across ranks. As shown below, for float8 for matmul, `torch._scaled_mm` requires both float8 tensors and their scales. Scales are calculated from `max(abs)` of a high precision tensor. diff --git a/torchtitan/components/float8.py b/torchtitan/components/float8.py index 7f148c001..1d5eccd22 100644 --- a/torchtitan/components/float8.py +++ b/torchtitan/components/float8.py @@ -49,25 +49,46 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): "torchao is not installed. Please install it to use float8 linear layers." ) from e - # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear - enable_fsdp_float8_all_gather = ( - parallel_dims.dp_shard_enabled - and float8_config.enable_fsdp_float8_all_gather - ) - self.config = Float8LinearConfig( - enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, - force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd, - ) + if float8_config.recipe_name is not None and not hasattr( + Float8LinearConfig, "from_recipe_name" + ): + logger.warning( + "Failed to swap to Float8Linear with recipe lookup because the torchao version " + + "is too old, please install torchao v0.9.0 or later and try again", + ) + return self.enabled = True - # for precompute_float8_dynamic_scale_for_fsdp - self.precompute_scale = ( - enable_fsdp_float8_all_gather - and float8_config.precompute_float8_dynamic_scale_for_fsdp - ) + if float8_config.recipe_name is not None: + assert ( + not float8_config.enable_fsdp_float8_all_gather + ), "using `float8_config.enable_fsdp_float8_all_gather` together with `float8_config.recipe_name` is not supported" + assert ( + not float8_config.force_recompute_fp8_weight_in_bwd + ), "using `float8_config.force_recompute_fp8_weight_in_bwd` together with `float8_config.recipe_name` is not supported" + self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name) + self.precompute_scale = False + logger.info( + f"Float8 training active with recipe {float8_config.recipe_name}" + ) - logger.info("Float8 training active") + else: + # Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear + enable_fsdp_float8_all_gather = ( + parallel_dims.dp_shard_enabled + and float8_config.enable_fsdp_float8_all_gather + ) + self.config = Float8LinearConfig( + enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, + force_recompute_fp8_weight_in_bwd=float8_config.force_recompute_fp8_weight_in_bwd, + ) + # for precompute_float8_dynamic_scale_for_fsdp + self.precompute_scale = ( + enable_fsdp_float8_all_gather + and float8_config.precompute_float8_dynamic_scale_for_fsdp + ) + logger.info("Float8 tensorwise scaled training active") def convert(self, model: nn.Module): return self.convert_to_float8_training(model) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d8a98a745..1045d88c3 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -613,20 +613,31 @@ def __init__(self): self.parser.add_argument( "--float8.enable_fsdp_float8_all_gather", action="store_true", - help="Whether enable float8 all-gather in FSDP", + help="Whether enable float8 all-gather in FSDP, recommended for tensorwise scaling", ) self.parser.add_argument( "--float8.precompute_float8_dynamic_scale_for_fsdp", action="store_true", - help="Whether precompute float8 scales dynamically for FSDP", + help="Whether precompute float8 scales dynamically for FSDP, recommended for tensorwise scaling", ) self.parser.add_argument( "--float8.force_recompute_fp8_weight_in_bwd", action="store_true", help=""" Whether to force the recomputation of FP8 weights during backward pass. - When using FSDP, it is recommended to enable `force_recompute_fp8_weight_in_bwd` - to prevent saving unsharded FP8 weights for backward computation. + When using FSDP with tensorwise scaling, it is recommended to enable + `force_recompute_fp8_weight_in_bwd` to prevent saving unsharded FP8 weights + for backward computation. + """, + ) + self.parser.add_argument( + "--float8.recipe_name", + type=str, + default=None, + choices=["tensorwise", "rowwise", "rowwise_with_gw_hp"], + help=""" + If specified, creates float8 config from recipe name, valid choices are + `tensorwise`, `rowwise` and `rowwise_with_gw_hp`. """, ) diff --git a/torchtitan/models/llama/parallelize_llama.py b/torchtitan/models/llama/parallelize_llama.py index 18f380e60..e8ec2ef9e 100644 --- a/torchtitan/models/llama/parallelize_llama.py +++ b/torchtitan/models/llama/parallelize_llama.py @@ -56,12 +56,23 @@ def parallelize_llama( and not job_config.training.compile ): raise RuntimeError("Async TP requires --training.compile") + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + apply_tp( model, world_mesh["tp"], loss_parallel=parallel_dims.loss_parallel_enabled, - enable_float8=enable_float8_linear, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.experimental.enable_async_tensor_parallel, ) @@ -115,7 +126,7 @@ def apply_tp( model: nn.Module, tp_mesh: DeviceMesh, loss_parallel: bool, - enable_float8: bool, + enable_float8_tensorwise_tp: bool, enable_async_tp: bool, ): """Apply tensor parallelism.""" @@ -141,10 +152,8 @@ def apply_tp( ) # Parallel styles used for transformer block linear weights and their - # inputs may be different for float8 linears - if enable_float8: - # TODO(vkuzo): once float8 configuration supports delayed scaling, - # add a check here to enforce supported float8 all-gather configurations + # inputs may be different for float8 linears with tensorwise scaling. + if enable_float8_tensorwise_tp: # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there from torchao.float8.float8_tensor_parallel import ( Float8ColwiseParallel, @@ -202,7 +211,7 @@ def apply_tp( enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info( - f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}" + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" "Tensor Parallelism to the model" )