add configuration for float8 with rowwise scaling, via recipe lookup #808

Open: wants to merge 1 commit into base: main
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -577,6 +577,12 @@ def __init__(self):
action="store_true",
help="Whether precompute float8 scales dynamically for FSDP",
)
self.parser.add_argument(
"--float8.recipe_name",
type=str,
default=None,
help="If specified, creates float8 config from recipe name, choices are `all_axiswise` and `lw_axiswise_with_gw_hp`",
@tianyu-l (Contributor) commented on Feb 7, 2025:
Could you please:

  1. add a `choices` arg, similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L292
  2. add more info on the recipes, including references, e.g. where users can learn more if they are interested

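For reference, a minimal sketch of what the requested `choices` constraint could look like (argument name, default, and recipe names are taken from this diff; the shortened help text is an assumption, not the author's wording):

    self.parser.add_argument(
        "--float8.recipe_name",
        type=str,
        default=None,
        # restrict input to the recipe names listed in this PR's help text
        choices=["all_axiswise", "lw_axiswise_with_gw_hp"],
        help="If specified, creates float8 config from the named torchao recipe",
    )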
@lessw2020 (Contributor) commented on Feb 9, 2025:
Related: IMO, this naming should be made more intuitive for users. I understand you want the exact breakdown within your code, but users won't necessarily grasp it versus calling it in more direct terms.

)

# communications library settings
self.parser.add_argument(
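As a quick usage illustration of the new flag, one could exercise the parser directly. This is a hypothetical snippet: it assumes JobConfig exposes a parse_args(list) entry point and groups `--float8.*` flags under `config.float8`, which matches how the handler below reads its float8 config.

    from torchtitan.config_manager import JobConfig

    # Parse a command line selecting the recipe added in this PR, then read it back.
    config = JobConfig()
    config.parse_args(["--float8.recipe_name", "all_axiswise"])
    print(config.float8.recipe_name)  # -> "all_axiswise"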
46 changes: 30 additions & 16 deletions torchtitan/float8.py
@@ -42,29 +42,43 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
return
try:
from torchao.float8 import Float8LinearConfig

# we should update this code after torchao exposes this publically
A Contributor suggested a change:

- # we should update this code after torchao exposes this publically
+ # TODO: we should update this code after torchao exposes this publically

from torchao.float8.config import (
Float8LinearRecipeName,
recipe_name_to_linear_config,
)
except ImportError as e:
raise ImportError(
"torchao is not installed. Please install it to use float8 linear layers."
) from e

# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
)

self.enabled = True

# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)

logger.info("Float8 training active")
if float8_config.recipe_name is not None:
# TODO(future PR): the recipe lookup by name is currently a private API, we'll need
# to expose it publically in torchao before a PR similar to this one can be
# landed in torchtitan
recipe = Float8LinearRecipeName(float8_config.recipe_name)
self.config = recipe_name_to_linear_config(recipe)
self.precompute_scale = False
logger.info(f"Float8 training active with recipe {recipe}")

else:
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
parallel_dims.dp_shard_enabled
and float8_config.enable_fsdp_float8_all_gather
)
self.config = Float8LinearConfig(
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather,
)
# for precompute_float8_dynamic_scale_for_fsdp
self.precompute_scale = (
enable_fsdp_float8_all_gather
and float8_config.precompute_float8_dynamic_scale_for_fsdp
)
logger.info("Float8 tensorwise scaled training active")

def convert_to_float8_training(self, model: nn.Module):
"""
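For context on the private torchao API used above, here is a minimal standalone sketch of the recipe lookup, based only on the calls appearing in this diff; the recipe name string comes from the `--float8.recipe_name` help text:

    from torchao.float8.config import (
        Float8LinearRecipeName,
        recipe_name_to_linear_config,
    )

    # Resolve a recipe name string into a full Float8LinearConfig,
    # mirroring the lookup done in this PR's __init__.
    recipe = Float8LinearRecipeName("all_axiswise")
    config = recipe_name_to_linear_config(recipe)
    print(config)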