diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index d3e9628103..1cadefbf71 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -351,12 +351,35 @@ def _update_expert_bias( dp_cp_mesh = ( parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None ) + + ################################################################3 + # AP friendly methods + + def is_moe_block(block): + moe_enabled = getattr(block, "moe_enabled", False) + has_moe_submod = hasattr(block, "moe") # AP + return moe_enabled or has_moe_submod + + def get_transformer_blocks(model_part): + if isinstance(model_part.layers, nn.ModuleDict): + # regular torchtitan + blocks = model_part.layers.values() + else: + # TODO: fix autoparallel to preserve the module dict + blocks = model_part.layers.children() + return blocks + + def should_manual_allreduce(tokens_per_expert_by_layer): + return not isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor) + ################################################################3 + # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if not transformer_block.moe_enabled: + blocks = get_transformer_blocks(model_part) + for transformer_block in blocks: + if not is_moe_block(transformer_block): continue if transformer_block.moe.load_balance_coeff is None: return @@ -372,17 +395,19 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) if dp_cp_mesh is not None: - # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() - torch.distributed.all_reduce( - tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM - ) + if should_manual_allreduce(tokens_per_expert_by_layer): + # Perform single all-reduce to get global statistics across all processes + pg = dp_cp_mesh.get_group() + torch.distributed.all_reduce( + tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM + ) moe_layer_idx = 0 with torch.no_grad(): for model_part in model_parts: - for transformer_block in model_part.layers.values(): - if not transformer_block.moe_enabled: + blocks = get_transformer_blocks(model_part) + for transformer_block in blocks: + if not is_moe_block(transformer_block): continue moe = transformer_block.moe diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 14ff6e7e8a..8e10785bdf 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -707,6 +707,42 @@ class Experimental: needs to ensure that the path can be imported. """ + # "none", "all", "only_fsdp" + bucket_all_gathers_fx: str = "none" + + # "none", "all" + bucket_reduce_scatters_fx: str = "none" + + reorder_for_compute_comm_overlap: bool = False + """ + Whether to enable inductor comm reordering passes + """ + + reorder_for_compute_comm_overlap_passes: list[str] = field( + default_factory=lambda: [ + "sink_waits_iterative", + "reorder_communication_preserving_peak_memory", + ] + ) + """ + Sequence of reordering passes (names of functions inside _inductor.comms) to call, + if reorder_for_compute_comm_overlap is enabled. 
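+
+    For example, assuming these fields map onto the usual [experimental] TOML
+    section of a job config, the default passes could be requested explicitly with:
+
+        [experimental]
+        reorder_for_compute_comm_overlap = true
+        reorder_for_compute_comm_overlap_passes = [
+            "sink_waits_iterative",
+            "reorder_communication_preserving_peak_memory",
+        ]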
+ """ + + reorder_prefetch_limit: int | None = None + """ + How many ops to allow moving any individual collective, if 'reorder_communication_preserving_peak_memory' + pass is enabled. default of None means unlimited + """ + + autop_force_bf16: bool = False + + enable_simplefsdp_passes: bool = False + + enable_inductor_aten_fx_overlap_scheduler: bool = False + enable_inductor_aten_fx_overlap_scheduler_bucketing: bool = False + enable_autoparallel_asynctp: bool = False + @dataclass class Validation: diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 9d81f6b885..ffa7309dd4 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import torchtitan.experiments.auto_parallel # noqa: F401 import torchtitan.experiments.llama4 # noqa: F401 import torchtitan.experiments.qwen3 import torchtitan.experiments.simple_fsdp # noqa: F401 diff --git a/torchtitan/experiments/auto_parallel/README.md b/torchtitan/experiments/auto_parallel/README.md new file mode 100644 index 0000000000..7e112329b9 --- /dev/null +++ b/torchtitan/experiments/auto_parallel/README.md @@ -0,0 +1,11 @@ +## Auto Parallel + +requires installing git@github.com:pytorch-labs/autoparallel.git + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4` + +Use simplefsdp's autobucketing pass: + +`CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --model.name llama3_auto_parallel --parallelism.tensor_parallel_degree 4 --experimental.enable_simplefsdp_passes --compile.enable` + +(or llama3-8b.toml) diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py new file mode 100644 index 0000000000..a67dfe18aa --- /dev/null +++ b/torchtitan/experiments/auto_parallel/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
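+
+# This experiment registers AutoParallel-based train specs ("llama3_auto_parallel"
+# and "deepseekv3_auto_parallel") that reuse the standard torchtitan components and
+# swap in the AutoParallel parallelize functions defined in this directory.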
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers, build_optimizers_with_moe_load_balancing +from torchtitan.components.validate import build_validator +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.llama3 import llama3_configs, pipeline_llama, Transformer +from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter +from torchtitan.models.deepseek_v3.model.state_dict_adapter import DeepSeekV3StateDictAdapter +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from torchtitan.models.deepseek_v3 import deepseekv3_configs, DeepSeekV3Model +from .parallelize_llama import parallelize_llama +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +register_train_spec( + TrainSpec( + name="llama3_auto_parallel", + model_cls=Transformer, + model_args=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) +) +register_train_spec( + TrainSpec( + name="deepseekv3_auto_parallel", + model_cls=DeepSeekV3Model, + model_args=deepseekv3_configs, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) +) diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py new file mode 100644 index 0000000000..cf69511e0a --- /dev/null +++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py @@ -0,0 +1,225 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel + +from torch.distributed import DeviceMesh +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def apply_local_map_to_moe(): + """ + TODO: fix HOPs not restoring the original signature. 
+ TODO: fix tracing with local shapes so that we can use Shard placements + + Current HOP signature we get: + + class subgraph_0(torch.nn.Module): + def forward(self, + rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__router____modules__gate____parameters__weight: "f32[8, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____buffers__expert_bias: "f32[8][1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w1: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w3: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w2: "f32[8, 256, 256][65536, 256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w1____parameters__weight: "f32[512, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0", + self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"): + """ + from torchtitan.models import moe + from torch.distributed._tensor.experimental import local_map + moe._moe_forward = local_map( + moe._moe_forward, + out_placements=( + (Replicate(),), # (Shard(0),), + (Replicate(),), + ), + in_placements=( + (Replicate(),), # (Shard(0),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + (Replicate(),), + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=None, + ) + + +# Run workflow with: +# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel +def parallelize_deepseekv3( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply Autoparallel to the model + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + world_mesh = parallel_dims.world_mesh + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. 
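+            # i.e. global_batch_size == local_batch_size * dp_degree (computed
+            # below), so every optimizer step consumes exactly one global batch.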
+ dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + model.model_args.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # apply local_map to MoE + apply_local_map_to_moe() + + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + # lambda bucket_idx: 500 / parallel_dims.tp + # ) + # torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + # lambda bucket_idx: 1000 / parallel_dims.tp + # ) + + # if job_config.experimental.autop_force_bf16: + # logger.info("Forcing bf16 on model") + # model = model.bfloat16() + + # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + mp_policy = None + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. 
+        # We should just use AutoParallel to do this for us, but
+        # it would require putting the loss inside the model as well
+        def _return_as_dtensor_for_loss_parallel(module, args, output):
+            return torch.distributed.tensor.DTensor.from_local(
+                output, world_mesh["tp"], (Shard(2),)
+            )
+
+        # not keeping a reference to the hook; we don't plan on
+        # removing it at any point
+        parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)
+
+    _preserve_moe_attributes(model, parallel_mod)
+
+    return parallel_mod
+
+
+def _preserve_moe_attributes(original_model, parallel_model):
+    """
+    Preserve MoE custom attributes from the original model to the parallel model.
+    This is only needed for attributes that aren't used in the graph, so they aren't
+    lifted as graph inputs and fetched by the pre-graph runtime wrapper.
+
+    `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify
+    this block as a MoE block. This should be safe as they are read-only.
+    """
+    def get_moe_modules(model):
+        """Extract all MoE modules from the model."""
+        moe_modules = []
+        if hasattr(model, 'layers'):
+            if isinstance(model.layers, torch.nn.ModuleDict):
+                # regular torchtitan structure
+                blocks = model.layers.values()
+            else:
+                # autoparallel might change structure
+                blocks = model.layers.children() if hasattr(model.layers, 'children') else []
+
+            for block in blocks:
+                if hasattr(block, 'moe_enabled') and block.moe_enabled and hasattr(block, 'moe'):
+                    moe_modules.append(block.moe)
+                elif hasattr(block, 'moe'):  # fallback for autoparallel
+                    moe_modules.append(block.moe)
+        return moe_modules
+
+    original_moe_modules = get_moe_modules(original_model)
+    parallel_moe_modules = get_moe_modules(parallel_model)
+
+    # Copy custom attributes from original to parallel MoE modules.
+    # This is fine to do since these attributes are read-only.
+    for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules):
+        # load_balance_coeff is the only attribute the optimizer hook reads from the MoE module
+        if hasattr(orig_moe, 'load_balance_coeff'):
+            par_moe.load_balance_coeff = orig_moe.load_balance_coeff
diff --git a/torchtitan/experiments/auto_parallel/parallelize_llama.py b/torchtitan/experiments/auto_parallel/parallelize_llama.py
new file mode 100644
index 0000000000..9d53d9a755
--- /dev/null
+++ b/torchtitan/experiments/auto_parallel/parallelize_llama.py
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import time
+
+import torch
+
+from autoparallel.api import AutoParallel
+
+from torch.distributed.fsdp import MixedPrecisionPolicy
+from torch.distributed.tensor.placement_types import Replicate, Shard
+
+from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
+from torchtitan.distributed import ParallelDims
+
+from torchtitan.tools.logging import logger
+
+
+def parallelize_llama(
+    model,
+    parallel_dims: ParallelDims,
+    job_config: JobConfig,
+):
+    """
+    Apply AutoParallel to the model.
+
+    NOTE: The passed-in model preferably should be on meta device. Otherwise,
+    the model must fit on GPU or CPU memory.
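+
+    Example invocation (from the experiment README, debug model config):
+
+        CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" \
+            ./run_train.sh --model.name llama3_auto_parallel \
+            --parallelism.tensor_parallel_degree 4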
+ """ + world_mesh = parallel_dims.world_mesh + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + lambda bucket_idx: 500 / parallel_dims.tp + ) + torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + lambda bucket_idx: 1000 / parallel_dims.tp + ) + + enable_overlap_scheduling = ( + job_config.experimental.enable_inductor_aten_fx_overlap_scheduler + ) + enable_overlap_scheduling_bucketing = ( + job_config.experimental.enable_inductor_aten_fx_overlap_scheduler_bucketing + ) + if enable_overlap_scheduling_bucketing: + assert ( + enable_overlap_scheduling + ), "bucketing can not be used without overlap scheduling" + + if enable_overlap_scheduling: + from torch._inductor.fx_passes.overlap_scheduling import OverlapScheduler + + torch._inductor.config.test_configs.aten_fx_overlap_preserving_bucketing = ( + enable_overlap_scheduling_bucketing + ) + + def _overlap_bucketing_pass(graph): + overlap_scheduler = OverlapScheduler(graph.owning_module) + overlap_scheduler.run() + + torch._inductor.config.post_grad_custom_post_pass = _overlap_bucketing_pass + + if job_config.experimental.enable_autoparallel_asynctp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + assert "tp" in world_mesh.mesh_dim_names + enable_symm_mem_for_group(world_mesh["tp"].get_group().group_name) + torch._inductor.config._micro_pipeline_tp = False + # Disable inductor AsyncTP passes, in favor of using Autoparallel passes fork. 
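+        # The fork's micro_pipeline_tp_pass is chained after any previously
+        # registered post_grad_custom_post_pass (e.g. the overlap scheduling
+        # pass above), so both still run on the post-grad graph.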
+ from autoparallel.asynctp import micro_pipeline_tp_pass + + existing_post_grad_custom_post_pass = ( + torch._inductor.config.post_grad_custom_post_pass + ) + + def _pass(graph): + if existing_post_grad_custom_post_pass is not None: + existing_post_grad_custom_post_pass(graph) + + micro_pipeline_tp_pass(graph, None) + + torch._inductor.config.post_grad_custom_post_pass = _pass + + # bail out + # model = model_fn() + # return model + if job_config.experimental.autop_force_bf16: + logger.info("Forcing bf16 on model") + model = model.bfloat16() + + param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + with AutoParallel( + model, + input_fn, + world_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "dp_shard": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_shard": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in world_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in world_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in world_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. 
+ # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, world_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + return parallel_mod diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 1c3d2b19d2..bf3232fd04 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -100,8 +100,8 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( vocab_size=102400, @@ -127,8 +127,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( vocab_size=129280, @@ -154,8 +154,8 @@ qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, - attn_mask_type="block_causal", + # use_flex_attn=True, + # attn_mask_type="block_causal", ), } diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index e2c4bbeda9..9074919c99 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -270,8 +270,8 @@ def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): self.attention = Attention(model_args) self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) - self.moe_enabled = layer_id >= model_args.n_dense_layers + if self.moe_enabled: self.moe = MoE( model_args.moe_args, diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py index 8be14ecbf0..1ec8e3b23b 100644 --- a/torchtitan/models/moe.py +++ b/torchtitan/models/moe.py @@ -12,6 +12,7 @@ from torch import nn from torchtitan.distributed.expert_parallel import expert_parallel +from torch.distributed.tensor.placement_types import Shard, Replicate @dataclass @@ -310,6 +311,77 @@ def forward( num_tokens_per_expert, ) +def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts): + # x: 64, 2048, 256 + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + selected_experts_indices, + num_tokens_per_expert, + ) = router(x, expert_bias) + + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. 
+ # moved out to remove mutation + # with torch.no_grad(): + # tokens_per_expert.add_(num_tokens_per_expert) + + # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + token_indices_experts_sorted = token_indices_experts_sorted.reshape( + -1, 1 + ).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + + if score_before_experts: + routed_input = ( + routed_input.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + routed_output = experts(routed_input, num_tokens_per_expert) + + if not score_before_experts: + routed_output = ( + routed_output.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shared expert + if shared_experts is not None: + out = shared_experts(x) + else: + out = torch.zeros_like(x) + + out = out.scatter_add( + dim=0, index=token_indices_experts_sorted, src=routed_output + ) + out = out.reshape(bs, slen, dim) + return out, num_tokens_per_expert + class MoE(nn.Module): def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): @@ -367,72 +439,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: Returns: out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. """ - bs, slen, dim = x.shape - x = x.view(-1, dim) - - # top_scores and selected_experts_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - selected_experts_indices, - num_tokens_per_expert, - ) = self.router(x, self.expert_bias) + out, num_tokens_per_expert = _moe_forward(x, self.router, self.expert_bias, self.reorderer, self.score_before_experts, self.experts, self.shared_experts) - # tokens_per_expert will be used to update the expert bias for load balancing. - # and also to count the expert usage - # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- - # first in the forward pass, and then in the backward pass. However, this has no - # effect on the expert bias update thanks to the torch.sign() operator. + # HOPs don't support buffer mutations, keep this outside with torch.no_grad(): self.tokens_per_expert.add_(num_tokens_per_expert) - - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - # NOTE: the reason we need to compute num_tokens_per_expert again is: - # 1st computation in router is to update self.tokens_per_expert - # which would be the same across all TP ranks. - # 2nd computation in reorderer is for the actual routing and experts computation - # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. - # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. 
- ( - top_scores_experts_sorted, - token_indices_experts_sorted, - num_tokens_per_expert, - ) = self.reorderer(top_scores, selected_experts_indices) - - # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape( - -1, 1 - ).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) - - if self.score_before_experts: - routed_input = ( - routed_input.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - if not self.score_before_experts: - routed_output = ( - routed_output.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shared expert - if self.shared_experts is not None: - out = self.shared_experts(x) - else: - out = torch.zeros_like(x) - - out = out.scatter_add( - dim=0, index=token_indices_experts_sorted, src=routed_output - ) - out = out.reshape(bs, slen, dim) return out def init_weights( diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index 0e851d335a..af537c8504 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -19,7 +19,9 @@ # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 - +PERFETTO_UI_ROOT_URL = ( + "https://interncache-all.fbcdn.net/manifold/perfetto-artifacts/tree/ui/index.html" +) @contextlib.contextmanager def maybe_enable_profiling( @@ -46,12 +48,22 @@ def trace_handler(prof): logger.info(f"Dumping profiler traces at step {prof.step_num}") begin = time.monotonic() - output_file = os.path.join(curr_trace_dir, f"rank{rank}_trace.json") + prof.export_chrome_trace(output_file) - logger.info( - f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" - ) + log_str = f"Finished dumping profiler traces in {time.monotonic() - begin:.2f} seconds" + # not directly landable on upstream titan, + # but conveniently prints the internal url for perfetto on manifold for mast jobs + manifold_mount_prefix = "/mnt/mffuse/" + if output_file.find(manifold_mount_prefix) == 0: + manifold_path = os.path.join("torchtrain_datasets/tree", output_file.split(manifold_mount_prefix)[1]) + perfetto_url = ( + PERFETTO_UI_ROOT_URL + + "#!/?url=https://interncache-all.fbcdn.net/manifold/" + + manifold_path + ) + log_str += f": {perfetto_url}" + logger.info(log_str) logger.info(f"Profiling active. Traces will be saved at {trace_dir}") diff --git a/torchtitan/train.py b/torchtitan/train.py index 9b69fd6798..3396fa56d2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -8,10 +8,12 @@ import os import time from datetime import timedelta +from functools import partial from typing import Any, Generator, Iterable, Optional import torch from torch.distributed.elastic.multiprocessing.errors import record +from torch.distributed.tensor import DTensor import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager @@ -121,6 +123,54 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # TODO(whc) + # I do this becuase otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. 
+ torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + if job_config.experimental.enable_simplefsdp_passes: + # enable simplefsdp's autobucketing and reorder passes (original code in https://github.com/pytorch/pytorch/pull/160282) + from autoparallel.auto_bucketing import ( + simple_fsdp_autobucketing_reordering_pass, + simplefsdp_autobucketing_config, + ) + + torch._inductor.config.allow_buffer_reuse = False + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = True + simplefsdp_autobucketing_config.save_estimation_path = ( + "/tmp/torchtitan_simplefsdp_comm_estimation.pkl" + ) + simple_fsdp_autobucketing_reordering_pass = partial( + simple_fsdp_autobucketing_reordering_pass, + configs=simplefsdp_autobucketing_config, + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = [ + simple_fsdp_autobucketing_reordering_pass + ] + + # Don't use both sets of passes at the same time! + torch._inductor.config.bucket_all_gathers_fx = "none" + torch._inductor.config.bucket_reduce_scatters_fx = "none" + else: + torch._inductor.config.bucket_all_gathers_fx = ( + job_config.experimental.bucket_all_gathers_fx + ) + torch._inductor.config.bucket_reduce_scatters_fx = ( + job_config.experimental.bucket_reduce_scatters_fx + ) + torch._inductor.config.reorder_for_compute_comm_overlap = ( + job_config.experimental.reorder_for_compute_comm_overlap + ) + torch._inductor.config.reorder_for_compute_comm_overlap_passes = ( + job_config.experimental.reorder_for_compute_comm_overlap_passes + ) + torch._inductor.config.reorder_prefetch_limit = ( + job_config.experimental.reorder_prefetch_limit + ) + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -147,6 +197,7 @@ def __init__(self, job_config: JobConfig): # build model (using meta init) model_args = self.train_spec.model_args[job_config.model.flavor] + model_cls = self.train_spec.model_cls # set the model args from training job configs model_args.update_from_config(job_config) self.model_args = model_args @@ -154,12 +205,12 @@ def __init__(self, job_config: JobConfig): logger.info( f"Building {self.train_spec.name} {job_config.model.flavor} with {model_args}" ) - with torch.device("meta"): - model = self.train_spec.model_cls(model_args) - # Build the collection of model converters. No-op if `model.converters` empty - model_converters = build_model_converters(job_config, parallel_dims) - model_converters.convert(model) + with torch.device("meta"): + model = model_cls(model_args) + # Build the collection of model converters. No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) # metrics logging build_metrics_processor_fn = ( @@ -256,7 +307,7 @@ def __init__(self, job_config: JobConfig): # confirm that user will be able to view loss metrics on the console ensure_pp_loss_visible(parallel_dims, job_config, color) else: - # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + # apply Autoparallel model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) model.to_empty(device=init_device)