
Commit 3ebdf05

[not for land yet] example of float8 with rowwise scaling
Summary: This is an example of how to call float8 training with rowwise scaling from torchao. TODO: finalize the API in torchao, decide how we want to expose it in torchtitan, and optimize performance.

```
// baseline (bf16 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --training.compile
...
step: 20  loss:  8.4931  memory: 47.65GiB(50.16%)  tps: 5,760  mfu: 33.73%

// experiment (rowwise float8 + compile)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile
...
step: 40  loss:  7.3818  memory: 66.81GiB(70.33%)  tps: 6,412  mfu: 37.55%

// for comparison, tensorwise float8 with float8 all-gather (on main branch)
> with-proxy CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh --float8.enable_float8_linear --training.compile --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp
...
step: 20  loss:  8.4258  memory: 47.32GiB(49.81%)  tps: 7,186  mfu: 42.08%
```

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
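For context on the comparison above: tensorwise scaling uses a single scale for the whole tensor, while rowwise (axiswise) scaling keeps one scale per row, preserving more dynamic range at some extra overhead. A minimal illustrative sketch of the difference, not torchao's actual kernels; the toy shapes are assumptions, and the constant assumes the `torch.float8_e4m3fn` format:

```python
# Illustrative sketch only: tensorwise vs. rowwise (axiswise) float8 scale
# computation. torchao's real implementation differs; E4M3_MAX assumes the
# torch.float8_e4m3fn format, whose max representable value is 448.
import torch

E4M3_MAX = 448.0
x = torch.randn(4, 8)

# tensorwise: one scale for the whole tensor
scale_tensorwise = E4M3_MAX / x.abs().max()
x_fp8_tensorwise = (x * scale_tensorwise).to(torch.float8_e4m3fn)

# rowwise: one scale per row, so an outlier in one row does not
# compress the dynamic range of every other row
scale_rowwise = E4M3_MAX / x.abs().amax(dim=1, keepdim=True)  # shape (4, 1)
x_fp8_rowwise = (x * scale_rowwise).to(torch.float8_e4m3fn)
```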
1 parent eddce12 commit 3ebdf05

File tree

2 files changed: +11 -0 lines changed


torchtitan/float8.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -42,6 +42,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             return
         try:
             from torchao.float8 import CastConfig, Float8LinearConfig, ScalingType
+            from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config
         except ImportError as e:
             raise ImportError(
                 "torchao is not installed. Please install it to use float8 linear layers."
@@ -55,13 +56,22 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
         scaling_type_input = ScalingType(float8_config.scaling_type_input)
         scaling_type_weight = ScalingType(float8_config.scaling_type_weight)
         scaling_type_grad_output = ScalingType(float8_config.scaling_type_grad_output)
+        # Note: this is overridden below
         self.config = Float8LinearConfig(
             enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
             cast_config_input=CastConfig(scaling_type=scaling_type_input),
             cast_config_weight=CastConfig(scaling_type=scaling_type_weight),
             cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
+            # force_recompute_fp8_weight_in_bwd=True,
         )

+        # Note: the recipe lookup by name is currently a private API; we'll need
+        # to expose it publicly in torchao before a PR similar to this one can be
+        # landed in torchtitan
+        recipe = "all_axiswise"
+        recipe = Float8LinearRecipeName(recipe)
+        self.config = recipe_name_to_linear_config(recipe)
+
         self.enabled = True

         # for precompute_float8_dynamic_scale_for_fsdp
```
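For reference, a standalone sketch of the call sequence this handler sets up, using the imports added above. This is a sketch under assumptions: `recipe_name_to_linear_config` is a private API per the note in the diff, and the toy model and dimensions are illustrative only.

```python
# Sketch of using the torchao APIs imported above, outside of torchtitan.
# Assumes a torchao build with the (currently private) recipe-lookup API.
import torch
import torch.nn as nn

from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearRecipeName, recipe_name_to_linear_config

# build a full Float8LinearConfig from the "all_axiswise" (rowwise) recipe
config = recipe_name_to_linear_config(Float8LinearRecipeName("all_axiswise"))

# a toy model; in torchtitan this would be the Llama 3 transformer
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# swap eligible nn.Linear modules for Float8Linear with rowwise scaling
convert_to_float8_training(model, config=config)

# rowwise float8 relies on torch.compile for performant casts and gemms
model = torch.compile(model)
```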

train.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -114,6 +114,7 @@ def main(job_config: JobConfig):
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
     float8_handler.convert_to_float8_training(model)
+    print(model)

     # log model size
     model_param_count = utils.get_num_params(model)
```
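The added `print(model)` is a debug aid to confirm that `nn.Linear` modules were swapped for `Float8Linear`. In a multi-rank run it would print once per process; a hypothetical refinement (not part of this commit) would gate it on rank 0:

```python
import torch.distributed as dist

# hypothetical refinement, not part of this commit: print the converted
# module tree once, from rank 0 only
if not dist.is_initialized() or dist.get_rank() == 0:
    print(model)
```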
