-
Notifications
You must be signed in to change notification settings - Fork 270
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add configuration for float8 with rowwise scaling, via recipe lookup #808
Open
vkuzo
wants to merge
1
commit into
main
Choose a base branch
from
20250127_float8_rowwise_example
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -42,29 +42,43 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): | |||||
return | ||||||
try: | ||||||
from torchao.float8 import Float8LinearConfig | ||||||
|
||||||
# we should update this code after torchao exposes this publically | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
from torchao.float8.config import ( | ||||||
Float8LinearRecipeName, | ||||||
recipe_name_to_linear_config, | ||||||
) | ||||||
except ImportError as e: | ||||||
raise ImportError( | ||||||
"torchao is not installed. Please install it to use float8 linear layers." | ||||||
) from e | ||||||
|
||||||
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear | ||||||
enable_fsdp_float8_all_gather = ( | ||||||
parallel_dims.dp_shard_enabled | ||||||
and float8_config.enable_fsdp_float8_all_gather | ||||||
) | ||||||
self.config = Float8LinearConfig( | ||||||
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, | ||||||
) | ||||||
|
||||||
self.enabled = True | ||||||
|
||||||
# for precompute_float8_dynamic_scale_for_fsdp | ||||||
self.precompute_scale = ( | ||||||
enable_fsdp_float8_all_gather | ||||||
and float8_config.precompute_float8_dynamic_scale_for_fsdp | ||||||
) | ||||||
|
||||||
logger.info("Float8 training active") | ||||||
if float8_config.recipe_name is not None: | ||||||
# TODO(future PR): the recipe lookup by name is currently a private API, we'll need | ||||||
# to expose it publically in torchao before a PR similar to this one can be | ||||||
# landed in torchtitan | ||||||
recipe = Float8LinearRecipeName(float8_config.recipe_name) | ||||||
self.config = recipe_name_to_linear_config(recipe) | ||||||
self.precompute_scale = False | ||||||
logger.info(f"Float8 training active with recipe {recipe}") | ||||||
|
||||||
else: | ||||||
# Mutates the model inplace replacing instances of torch.nn.Linear with Float8Linear | ||||||
enable_fsdp_float8_all_gather = ( | ||||||
parallel_dims.dp_shard_enabled | ||||||
and float8_config.enable_fsdp_float8_all_gather | ||||||
) | ||||||
self.config = Float8LinearConfig( | ||||||
enable_fsdp_float8_all_gather=enable_fsdp_float8_all_gather, | ||||||
) | ||||||
# for precompute_float8_dynamic_scale_for_fsdp | ||||||
self.precompute_scale = ( | ||||||
enable_fsdp_float8_all_gather | ||||||
and float8_config.precompute_float8_dynamic_scale_for_fsdp | ||||||
) | ||||||
logger.info("Float8 tensorwise scaled training active") | ||||||
|
||||||
def convert_to_float8_training(self, model: nn.Module): | ||||||
""" | ||||||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please:
choices
arg, similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L292There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Related - imo, this naming should be made more intuitive for users. I understand you want exact breakout within your code but users won't necessarily grasp it vs calling it in more direct terms.