-
Notifications
You must be signed in to change notification settings - Fork 591
[SimpleFSDP] add manual bucketing pass #1881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,12 +7,18 @@ | |
| from typing import Any, Union | ||
|
|
||
| import torch | ||
| from .job_config import Compile as CompileConfig | ||
|
|
||
|
|
||
| def get_compile_backend(backend_name: str) -> Union[str, callable]: | ||
| def get_compile_backend(compile_config: CompileConfig) -> Union[str, callable]: | ||
| # return the compile backends used in SimpleFSDP training | ||
| # Step1: check if backend_name is inside available torch.compile backends | ||
| # Step2: check if the backend_name has been registered as a customized backend | ||
| backend_name = ( | ||
| getattr(compile_config, "model_backend_override", None) | ||
| or compile_config.backend | ||
| ) | ||
|
|
||
| available_torch_backend = torch._dynamo.list_backends(exclude_tags=()) | ||
| if backend_name in available_torch_backend: | ||
| return backend_name | ||
|
|
@@ -25,8 +31,6 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]: | |
| schedule_overlap_bucketing, | ||
| ) | ||
|
|
||
| torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = True | ||
| torch._inductor.config.test_configs.aten_fx_overlap_insert_overlap_deps = False | ||
| torch._inductor.config.allow_buffer_reuse = False | ||
|
|
||
| def aten_autobucketing_reordering_pass( | ||
|
|
@@ -41,6 +45,34 @@ def aten_autobucketing_reordering_pass( | |
| bw_compiler=aten_autobucketing_reordering_pass, | ||
| keep_inference_input_mutations=True, | ||
| ) | ||
| elif backend_name == "aot_eager_manualbucketing": | ||
| # Perform manual optimization in aten fx-level and execute code in aot_eager backend | ||
| # The manualbucketing logic is here: | ||
| bucketing_modules = compile_config.manual_bucketed_modules | ||
| from functools import partial | ||
|
|
||
| from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend | ||
| from torch._inductor.fx_passes.overlap_manual_scheduling import ( | ||
| manual_overlap_bucketing, | ||
| ) | ||
|
|
||
| torch._inductor.config.allow_buffer_reuse = False | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens by default? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In bucketing, we shouldn't allow buffer reuse; otherwise newly created comm copy-in/copy-out buffers will reuse prev buffer, which messed up the copied out data value and made the loss nan. |
||
| manual_overlap_bucketing = partial( | ||
| manual_overlap_bucketing, | ||
| module_bucket_plans=compile_config.manual_bucketed_modules, | ||
| ) | ||
|
|
||
| def aten_manualbucketing_reordering_pass( | ||
| gm: torch.fx.GraphModule, example_inputs: Any | ||
| ) -> torch.fx.GraphModule: | ||
| manual_overlap_bucketing(gm) | ||
| return gm | ||
|
|
||
| backend = aot_autograd_backend( | ||
| fw_compiler=aten_manualbucketing_reordering_pass, | ||
| bw_compiler=aten_manualbucketing_reordering_pass, | ||
| keep_inference_input_mutations=True, | ||
| ) | ||
| else: | ||
| raise AssertionError(f"Unsupported customized backend: {backend_name}") | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,7 +10,18 @@ | |
| @dataclass | ||
| class Compile: | ||
| model_backend_override: str | None = None | ||
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" | ||
| """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """ | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. should make this subclass torchtitan.config.job_config.Compile There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's additional config extended from job_config.Comfile. not sure wdym here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. something like |
||
|
|
||
| manual_bucketed_modules: list[str] = field(default_factory=list) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we need to have instructions about this field. E.g. it's not super obvious what this means btw, are the list separated by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The list is separated by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we add |
||
| """ | ||
| Manual bucket modules based on user specified FQNs | ||
| Abbreviations are supported to make specifying modules easier. | ||
| Currently, the following abbreviations are available: | ||
| (1) layers.[0-2] -> [layers.0], [layers.1], [layers.2] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right now user has to know how many layer a particular flavor of model has, when applying manual bucketing. Do you think we can improve the UX by automatically resolving the number of layers? I even think we shouldn't expose this option in toml. In toml user should just need to specify bucketing_mode = "none", "transformer_block", "auto" Happy to hear people's thoughts. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I mean we could have another "manual" mode supporting Manual bucket modules if people really want to override, but a good default of transformer block level bucketing should be enabled more easily. |
||
| (layers are split three separate buckets) | ||
| (2) norm+output -> [norm, output] | ||
| (norm and output are in one bucket) | ||
| """ | ||
|
|
||
|
|
||
| @dataclass | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -29,18 +29,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: | |
| "1D", | ||
| "1d", | ||
| ), | ||
| OverrideDefinitions( | ||
| [ | ||
| [ | ||
| "--model.name simple_fsdp.llama3", | ||
| "--compile.enable", | ||
| "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config", | ||
| "--compile.model_backend_override aot_eager_autobucketing", | ||
| ], | ||
| ], | ||
| "1D+aot_eager_autobucketing", | ||
| "1d_aot_eager_autobucketing", | ||
| ), | ||
| # TODO(ruisizhang123): add back after autobucketing pass is mature | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. shall we add a manual bucketing test? we should also add one in the loss unit test. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a few to do items for reordering. I think it'd be better to add the tests after the API is stable? |
||
| # OverrideDefinitions( | ||
| # [ | ||
| # [ | ||
| # "--model.name simple_fsdp.llama3", | ||
| # "--compile.enable", | ||
| # "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config", | ||
| # "--compile.model_backend_override aot_eager_autobucketing", | ||
| # ], | ||
| # ], | ||
| # "1D+aot_eager_autobucketing", | ||
| # "1d_aot_eager_autobucketing", | ||
| # ), | ||
| OverrideDefinitions( | ||
| [ | ||
| [ | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This local variable is not used.