add configuration for float8 with rowwise scaling, via recipe lookup #808
```diff
@@ -56,12 +56,23 @@ def parallelize_llama(
             and not job_config.training.compile
         ):
             raise RuntimeError("Async TP requires --training.compile")

         enable_float8_linear = "float8" in job_config.model.converters
+        float8_is_rowwise = job_config.float8.recipe_name in (
+            "rowwise",
+            "rowwise_with_gw_hp",
+        )
+
+        # For now, float8 all-gather with TP is only supported for tensorwise
+        # float8 scaling recipes. For rowwise recipes, we use regular TP and
+        # all-gather happens in high precision.
+        enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise
+
         apply_tp(
             model,
             world_mesh["tp"],
             loss_parallel=parallel_dims.loss_parallel_enabled,
-            enable_float8=enable_float8_linear,
+            enable_float8_tensorwise_tp=enable_float8_tensorwise_tp,
             enable_async_tp=job_config.experimental.enable_async_tensor_parallel,
         )

```

Review thread on lines +61 to +64 (the `float8_is_rowwise` check):

Reviewer: Since they are the only two valid options, can we do an assertion instead of using this check?

Author: Could you explain what you mean by that? Not sure I follow.

Reviewer: Oh, I was saying: does it make sense to do an assertion, if we only support two rowwise recipes?

Author: Ah, got it. Currently this PR is written as "support all recipes, including future ones"; it looks like you are asking for "only support recipes which have been tested in torchtitan". I agree with that feedback, let me change this PR to accommodate.

Author: I would prefer not to do this, though, as it forces the code reader to go look up which recipes are valid to understand what's going on. I like this as is: it displays the valid recipes here without making the reader go look for them.
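A minimal sketch of the two options discussed in this thread, using an assumed helper name rather than the PR's inline code; the `strict` assertion path only illustrates the reviewer's suggestion, while the plain membership check mirrors what the diff above does:

```python
# Sketch only: helper name and the strict/assert path are illustrative,
# not part of the PR.
_ROWWISE_RECIPES = ("rowwise", "rowwise_with_gw_hp")


def is_rowwise_recipe(recipe_name: str | None, strict: bool = False) -> bool:
    """Return True if the float8 recipe name uses rowwise scaling.

    strict=True behaves like the reviewer's assertion idea: any name outside
    the known set fails here instead of silently being treated as tensorwise.
    """
    if strict and recipe_name is not None and recipe_name not in _ROWWISE_RECIPES:
        raise AssertionError(
            f"unknown float8 recipe {recipe_name!r}; expected one of {_ROWWISE_RECIPES}"
        )
    return recipe_name in _ROWWISE_RECIPES


assert is_rowwise_recipe("rowwise_with_gw_hp")
assert not is_rowwise_recipe(None)  # no recipe name means the tensorwise default
```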
```diff
@@ -115,7 +126,7 @@ def apply_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh,
     loss_parallel: bool,
-    enable_float8: bool,
+    enable_float8_tensorwise_tp: bool,
     enable_async_tp: bool,
 ):
     """Apply tensor parallelism."""
```

```diff
@@ -141,10 +152,8 @@ def apply_tp(
     )

     # Parallel styles used for transformer block linear weights and their
-    # inputs may be different for float8 linears
-    if enable_float8:
-        # TODO(vkuzo): once float8 configuration supports delayed scaling,
-        # add a check here to enforce supported float8 all-gather configurations
+    # inputs may be different for float8 linears with tensorwise scaling.
+    if enable_float8_tensorwise_tp:
         # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there
         from torchao.float8.float8_tensor_parallel import (
             Float8ColwiseParallel,
```

```diff
@@ -202,7 +211,7 @@ def apply_tp(
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)

     logger.info(
-        f"Applied {'Float8 ' if enable_float8 else ''}{'Async ' if enable_async_tp else ''}"
+        f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}"
         "Tensor Parallelism to the model"
     )

```
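For context on the "tensorwise only" restriction, here is a rough sketch of how a flag like `enable_float8_tensorwise_tp` can select between float8-aware and regular TP parallel styles. It is not the PR's exact code, and it assumes `Float8RowwiseParallel` and `PrepareFloat8ModuleInput` exist alongside `Float8ColwiseParallel` in `torchao.float8.float8_tensor_parallel`:

```python
# Rough sketch: pick TP parallel styles based on the tensorwise-float8 flag.
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def pick_tp_styles(enable_float8_tensorwise_tp: bool):
    if enable_float8_tensorwise_tp:
        # Float8-aware styles: weights/inputs are cast to float8 before the
        # TP collectives, so the all-gather itself happens in float8.
        from torchao.float8.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput

    # Regular styles: the all-gather stays in high precision; rowwise-scaled
    # float8 linears still run, they just do not get float8 all-gather.
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput


rowwise_parallel, colwise_parallel, prepare_module_input = pick_tp_styles(False)
```

With rowwise recipes the regular styles are used, so the TP all-gather stays in high precision and the float8 conversion only affects the matmuls, matching the comment added in the first hunk.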
Review thread on the `recipe_name` lookup:

Reviewer: What happens if `recipe_name` is not in `("rowwise", "rowwise_with_gw_hp")`? Shall we use `choices` in config_manager.py to guard, similar to https://github.com/pytorch/torchtitan/blob/main/torchtitan/config_manager.py#L191?

Author: We will throw a detailed error message from torchao (https://github.com/pytorch/ao/blob/1ab1b77ad744115c3fee62716e7d2083c57d80a1/torchao/float8/config.py#L301), with the valid choices displayed in the error message - this code is getting that for free.

Author: Never mind, moved to `choices` like you suggested.
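For reference, a sketch of what a `choices`-based guard could look like. The flag name follows the `--float8.recipe_name` discussion above, but the exact wiring and help text in torchtitan's config_manager.py are assumed here, not copied:

```python
# Illustrative argparse guard; torchtitan's config_manager.py uses dotted flag
# names like --float8.recipe_name, so the attribute is read via getattr.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--float8.recipe_name",
    type=str,
    default=None,
    choices=["rowwise", "rowwise_with_gw_hp"],
    help="build the float8 config from a recipe name instead of individual flags",
)

args = parser.parse_args(["--float8.recipe_name", "rowwise"])
print(getattr(args, "float8.recipe_name"))  # -> rowwise

# An unknown name now fails at parse time, with the valid choices in the message:
#   argument --float8.recipe_name: invalid choice: 'some_other_recipe'
#   (choose from 'rowwise', 'rowwise_with_gw_hp')
```

The trade-off matches the thread: the `choices` guard rejects bad names before training starts, whereas relying on torchao's validation (the linked config.py) reports the error later, when the float8 config is actually built.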