15 changes: 8 additions & 7 deletions torchtitan/experiments/simple_fsdp/README.md
@@ -51,13 +51,14 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing

2. auto optimization: performs auto-bucketing & reordering without user input. **Note: users are not guaranteed to get the most optimized training performance.**
- "aot_eager_autobucketing": performs autobucketing at the aten FX level and executes the code with the aot_eager backend.


Users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:

```bash
--compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_autobucketing"
```
3. manual optimization: performs manual bucketing & reordering based on user-provided FQNs.
- "aot_eager_manualbucketing": performs manual bucketing at the aten FX level and executes the code with the aot_eager backend.
```bash
--compile.backend "aot_eager" --job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config --compile.model_backend_override "aot_eager_manualbucketing" --compile.manual_bucketed_modules "tok_embeddings,layers.[0-5],norm+output"
```

### Citation

38 changes: 35 additions & 3 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -7,12 +7,18 @@
from typing import Any, Union

import torch
from .job_config import Compile as CompileConfig


def get_compile_backend(compile_config: CompileConfig) -> Union[str, callable]:
# Return the compile backend used in SimpleFSDP training.
# Step 1: check if backend_name is one of the available torch.compile backends.
# Step 2: check if backend_name has been registered as a customized backend.
backend_name = (
getattr(compile_config, "model_backend_override", None)
or compile_config.backend
)

available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
if backend_name in available_torch_backend:
return backend_name
@@ -25,8 +31,6 @@ def get_compile_backend(backend_name: str) -> Union[str, callable]:
schedule_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False

def aten_autobucketing_reordering_pass(
Expand All @@ -41,6 +45,34 @@ def aten_autobucketing_reordering_pass(
bw_compiler=aten_autobucketing_reordering_pass,
keep_inference_input_mutations=True,
)
elif backend_name == "aot_eager_manualbucketing":
# Perform manual optimization at the aten FX level and execute code with the aot_eager backend.
# The manual bucketing logic is in torch._inductor.fx_passes.overlap_manual_scheduling:
bucketing_modules = compile_config.manual_bucketed_modules
> **Contributor:** This local variable is not used.

from functools import partial

from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._inductor.fx_passes.overlap_manual_scheduling import (
manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False
> **Member:** What happens by default?

> **Contributor Author:** In bucketing we shouldn't allow buffer reuse; otherwise the newly created comm copy-in/copy-out buffers reuse previous buffers, which corrupts the copied-out values and makes the loss NaN.
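As a toy illustration of that hazard (plain PyTorch aliasing, not Inductor's buffer planner):

```python
import torch

# Toy sketch of the buffer-reuse hazard described above (not Inductor internals):
# if a "copy-out" merely aliases the comm buffer instead of owning its storage,
# reusing that buffer for the next op silently corrupts the earlier result.
src = torch.ones(4)
comm_buf = torch.empty(4)
comm_buf.copy_(src)              # copy-in for a pretend collective
out = comm_buf                   # copy-out that aliases the buffer
comm_buf.copy_(torch.zeros(4))   # the buffer gets reused for the next op
print(out)                       # prints zeros, not ones
```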

manual_overlap_bucketing = partial(
manual_overlap_bucketing,
module_bucket_plans=compile_config.manual_bucketed_modules,
)

def aten_manualbucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
manual_overlap_bucketing(gm)
return gm

backend = aot_autograd_backend(
fw_compiler=aten_manualbucketing_reordering_pass,
bw_compiler=aten_manualbucketing_reordering_pass,
keep_inference_input_mutations=True,
)
else:
raise AssertionError(f"Unsupported customized backend: {backend_name}")

@@ -162,6 +162,6 @@ def parallelize_deepseekv3(
if job_config.compile.enable:
torch._inductor.config.reorder_for_peak_memory = False
torch._dynamo.config.capture_scalar_outputs = True
model = torch.compile(model, backend=get_compile_backend(job_config.compile), fullgraph=True)

return model
13 changes: 12 additions & 1 deletion torchtitan/experiments/simple_fsdp/job_config.py
@@ -10,7 +10,18 @@
@dataclass
class Compile:
model_backend_override: str | None = None
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing """
> **Member:** Should make this subclass torchtitan.config.job_config.Compile.

> **Contributor Author:** It's an additional config extended from job_config.Compile; not sure what you mean here.

> **Member:** Something like class Compile(torchtitan.config.job_config.Compile).
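A minimal sketch of the suggested subclassing (hypothetical, not part of this PR; BaseCompile stands for torchtitan.config.job_config.Compile as named in the discussion):

```python
from dataclasses import dataclass, field

from torchtitan.config.job_config import Compile as BaseCompile


@dataclass
class Compile(BaseCompile):
    # Inherits the base fields (e.g., backend) instead of duplicating them.
    model_backend_override: str | None = None
    manual_bucketed_modules: list[str] = field(default_factory=list)
```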


manual_bucketed_modules: list[str] = field(default_factory=list)
> **Contributor:** We need instructions about this field. E.g., it's not super obvious what "tok_embeddings,layers.[0-5],norm+output" means; since it involves regex I have a guess, but users might not. By the way, is the list separated by ","?

> **Contributor Author:** The list is separated by ",", but I didn't do explicit splitting here; essentially, it's similar to filter_fqns.

> **Contributor:** Should we add an fsdp_ prefix? Or do we imagine this field will be used for other use cases? If so, what are they?
"""
Manual bucket modules based on user specified FQNs
Abbreviations are supported to make specifying modules easier.
Currently, the following abbreviations are available:
(1) layers.[0-2] -> [layers.0], [layers.1], [layers.2]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now user has to know how many layer a particular flavor of model has, when applying manual bucketing. Do you think we can improve the UX by automatically resolving the number of layers?

I even think we shouldn't expose this option in toml. In toml user should just need to specify bucketing_mode = "none", "transformer_block", "auto"
And if it's transformer_block, we explicitly iterate over all the transformerblocks and pass the expanded fqns in manual_overlap_bucketing. That means manual_overlap_bucketing don't need to be smart about abbreviations.

Happy to hear people's thoughts.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean we could have another "manual" mode supporting Manual bucket modules if people really want to override, but a good default of transformer block level bucketing should be enabled more easily.

(layers are split three separate buckets)
(2) norm+output -> [norm, output]
(norm and output are in one bucket)
"""


@dataclass
6 changes: 1 addition & 5 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -125,13 +125,9 @@ def parallelize_llama(

if job_config.compile.enable and "model" in job_config.compile.components:
torch._inductor.config.reorder_for_peak_memory = False
model = torch.compile(
model,
backend=get_compile_backend(job_config.compile),
fullgraph=True,
)

25 changes: 13 additions & 12 deletions torchtitan/experiments/simple_fsdp/tests/integration_tests.py
@@ -29,18 +29,19 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"1D",
"1d",
),
# TODO(ruisizhang123): add back after autobucketing pass is mature
> **Contributor:** Shall we add a manual bucketing test? We should also add one in the loss unit tests.

> **Contributor Author:** I have a few to-do items for reordering. I think it'd be better to add the tests after the API is stable?

# OverrideDefinitions(
# [
# [
# "--model.name simple_fsdp.llama3",
# "--compile.enable",
# "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
# "--compile.model_backend_override aot_eager_autobucketing",
# ],
# ],
# "1D+aot_eager_autobucketing",
# "1d_aot_eager_autobucketing",
# ),
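# A hypothetical manual-bucketing entry could mirror the one above once the
# API stabilizes (sketch only; flags taken from the README, not part of this PR):
# OverrideDefinitions(
#     [
#         [
#             "--model.name simple_fsdp.llama3",
#             "--compile.enable",
#             "--compile.backend aot_eager",
#             "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
#             "--compile.model_backend_override aot_eager_manualbucketing",
#             "--compile.manual_bucketed_modules tok_embeddings,layers.[0-5],norm+output",
#         ],
#     ],
#     "1D+aot_eager_manualbucketing",
#     "1d_aot_eager_manualbucketing",
# ),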
OverrideDefinitions(
[
[