Commits (26):
3ccd12c  [WIP] Integrate autoparallel into torchtitan (wconstab, Jun 13, 2025)
e6d2caf  Autoparallel support for DP-only, DP+TP, or TP-only (wconstab, Jun 27, 2025)
68476b3  Update CLI inductor configs for bucketing/reordering (wconstab, Jul 25, 2025)
9ee9f75  add back llama3_autoparallel_init_fn (wconstab, Jul 25, 2025)
f6e4099  Track API change from new AOTAutograd interface (ezyang, Jul 28, 2025)
4d7ee8a  Support forcing the model into bf16 for perf debugging (wconstab, Jul 28, 2025)
b801d0b  Integrate MixedPrecision with AutoParallel and fix example_inputs (wconstab, Jul 29, 2025)
b099cf9  Use in-place compile API (ezyang, Jul 29, 2025)
b3587d9  Fix bucketing pass configs (wconstab, Jul 29, 2025)
42c2c07  Support both eager and autoparallel init based on model.name (wconstab, Jul 30, 2025)
d93845e  Remove llama3 init weights hack (wconstab, Aug 6, 2025)
60f5f11  Print profiling manifold url (wconstab, Aug 7, 2025)
6c782eb  Support new compile API from autoparallel PR #77 (wconstab, Aug 8, 2025)
4712163  Fix bucket sizes for AutoParallel 1D (#1545) (fmassa, Aug 8, 2025)
3f04d22  Add support for loss parallel (#1546) (fmassa, Aug 10, 2025)
8e50870  Add config for running simple-fsdp bucketing/reordering passes (wconstab, Aug 18, 2025)
91c5639  Hook up deepseekv3_auto_parallel (wconstab, Aug 19, 2025)
1233902  [dsv3] patch graph break fix, works up until sharding rules (xmfan, Aug 19, 2025)
4f8677b  update simplefsdp pass config (ruisizhang123, Aug 21, 2025)
714cc5b  [dsv3] disable MoE while we fix local_map, works up until optimizer (xmfan, Aug 22, 2025)
45647b3  Merge branch 'main' into whc/merge_autoparallel (wconstab, Aug 28, 2025)
bfa9f7f  tweak ds3 model.py to reflect main branch for DS3 baseline can run (#… (bdhirsh, Sep 5, 2025)
75fb2eb  add simplefsdp's autobucketing pass entry (#1658) (ruisizhang123, Sep 6, 2025)
8769396  [dsv3] 1D AP w/ local_map (xmfan, Sep 11, 2025)
db22479  [dsv3] Turn off Flex for AP (xmfan, Sep 17, 2025)
45b15f6  [autoparallel] Add experimental config to enable autoparallel_asynctp (IvanKobzarev, Sep 12, 2025)
43 changes: 34 additions & 9 deletions torchtitan/components/optimizer.py
@@ -351,12 +351,35 @@ def _update_expert_bias(
dp_cp_mesh = (
parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None
)

################################################################
# AP friendly methods

def is_moe_block(block):
moe_enabled = getattr(block, "moe_enabled", False)
has_moe_submod = hasattr(block, "moe") # AP
return moe_enabled or has_moe_submod

def get_transformer_blocks(model_part):
if isinstance(model_part.layers, nn.ModuleDict):
# regular torchtitan
blocks = model_part.layers.values()
else:
# TODO: fix autoparallel to preserve the module dict
blocks = model_part.layers.children()
return blocks

def should_manual_allreduce(tokens_per_expert_by_layer):
return not isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor)
################################################################

# TODO: Currently this sync is blocking (thus exposed) and happens on the
# default compute stream. Need to assess if this is OK performance-wise.
tokens_per_expert_list = []
for model_part in model_parts:
for transformer_block in model_part.layers.values():
if not transformer_block.moe_enabled:
blocks = get_transformer_blocks(model_part)
for transformer_block in blocks:
if not is_moe_block(transformer_block):
continue
if transformer_block.moe.load_balance_coeff is None:
return
@@ -372,17 +395,19 @@ def _update_expert_bias(
tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list)

if dp_cp_mesh is not None:
# Perform single all-reduce to get global statistics across all processes
pg = dp_cp_mesh.get_group()
torch.distributed.all_reduce(
tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
)
if should_manual_allreduce(tokens_per_expert_by_layer):
# Perform single all-reduce to get global statistics across all processes
pg = dp_cp_mesh.get_group()
torch.distributed.all_reduce(
tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM
)

moe_layer_idx = 0
with torch.no_grad():
for model_part in model_parts:
for transformer_block in model_part.layers.values():
if not transformer_block.moe_enabled:
blocks = get_transformer_blocks(model_part)
for transformer_block in blocks:
if not is_moe_block(transformer_block):
continue
moe = transformer_block.moe

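The diff above has to handle two module layouts: eager torchtitan keeps transformer blocks in an nn.ModuleDict and marks MoE blocks with `moe_enabled`, while the AutoParallel-transformed model currently loses the ModuleDict (see the TODO) and is recognized via its `moe` submodule instead. A minimal standalone sketch of that detection logic, not from the PR (ToyBlock, ToyMoE, and the nn.Sequential stand-in for AP's container are made up for illustration):

```python
import torch.nn as nn


class ToyMoE(nn.Module):
    pass


class ToyBlock(nn.Module):
    def __init__(self, use_moe: bool):
        super().__init__()
        self.moe_enabled = use_moe  # eager-mode marker
        if use_moe:
            self.moe = ToyMoE()  # submodule marker, still visible after AP transforms


def is_moe_block(block: nn.Module) -> bool:
    # Either marker identifies an MoE block.
    return getattr(block, "moe_enabled", False) or hasattr(block, "moe")


eager_layers = nn.ModuleDict({"0": ToyBlock(False), "1": ToyBlock(True)})
ap_layers = nn.Sequential(ToyBlock(False), ToyBlock(True))  # stand-in container

for layers in (eager_layers, ap_layers):
    blocks = layers.values() if isinstance(layers, nn.ModuleDict) else layers.children()
    print([is_moe_block(b) for b in blocks])  # [False, True] in both cases
```

The `should_manual_allreduce` check follows the same pattern: the explicit all-reduce is only issued when `tokens_per_expert_by_layer` is a plain tensor, not a DTensor.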
36 changes: 36 additions & 0 deletions torchtitan/config/job_config.py
@@ -707,6 +707,42 @@ class Experimental:
needs to ensure that the path can be imported.
"""

# "none", "all", "only_fsdp"
bucket_all_gathers_fx: str = "none"

# "none", "all"
bucket_reduce_scatters_fx: str = "none"

reorder_for_compute_comm_overlap: bool = False
"""
Whether to enable inductor comm reordering passes
"""

reorder_for_compute_comm_overlap_passes: list[str] = field(
default_factory=lambda: [
"sink_waits_iterative",
"reorder_communication_preserving_peak_memory",
]
)
"""
Sequence of reordering passes (names of functions inside _inductor.comms) to call,
if reorder_for_compute_comm_overlap is enabled.
"""

reorder_prefetch_limit: int | None = None
"""
How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory'
pass is enabled. default of None means unlimited
"""

autop_force_bf16: bool = False
"""Force the model into bf16 under AutoParallel, for perf debugging."""

enable_simplefsdp_passes: bool = False
"""Run simple-fsdp's bucketing/reordering passes (see torchtitan/experiments/auto_parallel/README.md)."""

enable_inductor_aten_fx_overlap_scheduler: bool = False
enable_inductor_aten_fx_overlap_scheduler_bucketing: bool = False
enable_autoparallel_asynctp: bool = False


@dataclass
class Validation:
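The reordering fields mirror inductor's compute/comm overlap knobs. A hedged sketch of how a trainer could forward them, assuming the matching `torch._inductor.config` fields (`reorder_for_compute_comm_overlap` and `reorder_for_compute_comm_overlap_passes`) and that pass names are accepted as strings; `apply_comm_overlap_config` is a hypothetical helper, not part of this PR:

```python
import torch._inductor.config as inductor_config


def apply_comm_overlap_config(experimental) -> None:
    # Hypothetical glue: copy the Experimental fields onto inductor's
    # comm-reordering config before the model is compiled.
    if experimental.reorder_for_compute_comm_overlap:
        inductor_config.reorder_for_compute_comm_overlap = True
        # Per the docstring above, entries name functions inside torch._inductor.comms.
        inductor_config.reorder_for_compute_comm_overlap_passes = list(
            experimental.reorder_for_compute_comm_overlap_passes
        )
```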
1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torchtitan.experiments.auto_parallel # noqa: F401
import torchtitan.experiments.llama4 # noqa: F401
import torchtitan.experiments.qwen3
import torchtitan.experiments.simple_fsdp # noqa: F401
11 changes: 11 additions & 0 deletions torchtitan/experiments/auto_parallel/README.md
@@ -0,0 +1,11 @@
## Auto Parallel

Requires installing `git@github.com:pytorch-labs/autoparallel.git`.

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4`

Use simplefsdp's autobucketing pass:

`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable`

(or llama3-8b.toml)
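Assuming the same pattern (the exact DeepSeek-V3 config path is not given here), the `deepseekv3_auto_parallel` spec registered in this PR should run the same way by swapping `--model.name deepseekv3_auto_parallel` and pointing `CONFIG_FILE` at a DeepSeek-V3 config.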
54 changes: 54 additions & 0 deletions torchtitan/experiments/auto_parallel/__init__.py
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import (
    build_optimizers,
    build_optimizers_with_moe_load_balancing,
)
from torchtitan.components.validate import build_validator
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer
from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter
from torchtitan.models.deepseek_v3.model.state_dict_adapter import DeepSeekV3StateDictAdapter
from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
from torchtitan.models.deepseek_v3 import deepseekv3_configs, DeepSeekV3Model
from .parallelize_llama import parallelize_llama
from .parallelize_deepseekv3 import parallelize_deepseekv3


register_train_spec(
TrainSpec(
name="llama3_auto_parallel",
model_cls=Transformer,
model_args=llama3_configs,
parallelize_fn=parallelize_llama,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
build_validator_fn=build_validator,
state_dict_adapter=Llama3StateDictAdapter,
)
)
register_train_spec(
TrainSpec(
name="deepseekv3_auto_parallel",
model_cls=DeepSeekV3Model,
model_args=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=pipeline_llama,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_hf_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
state_dict_adapter=DeepSeekV3StateDictAdapter,
)
)
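A hedged sketch of what these registrations enable (not in the diff): the trainer resolves a TrainSpec from `--model.name`; `get_train_spec` is assumed to be the lookup counterpart of `register_train_spec` in `torchtitan.protocols.train_spec`.

```python
import torchtitan.experiments  # noqa: F401  # runs the register_train_spec calls above
from torchtitan.protocols.train_spec import get_train_spec

# Resolve specs the same way the trainer would for --model.name.
llama_spec = get_train_spec("llama3_auto_parallel")
dsv3_spec = get_train_spec("deepseekv3_auto_parallel")

print(llama_spec.parallelize_fn.__name__)      # parallelize_llama
print(dsv3_spec.build_optimizers_fn.__name__)  # build_optimizers_with_moe_load_balancing
```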