diff --git a/torchtitan/experiments/auto_parallel/__init__.py b/torchtitan/experiments/auto_parallel/__init__.py
index a67dfe18a..8d397c614 100644
--- a/torchtitan/experiments/auto_parallel/__init__.py
+++ b/torchtitan/experiments/auto_parallel/__init__.py
@@ -39,7 +39,7 @@
 )
 register_train_spec(
     TrainSpec(
-        name="deepseekv3_auto_parallel",
+        name="deepseek_v3_auto_parallel",
         model_cls=DeepSeekV3Model,
         model_args=deepseekv3_configs,
         parallelize_fn=parallelize_deepseekv3,
diff --git a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
index cf69511e0..4220a8599 100644
--- a/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
+++ b/torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
@@ -19,13 +19,12 @@
 from torchtitan.tools.logging import logger
 
 
-def apply_local_map_to_moe():
+def apply_local_map_to_moe(mesh):
     """
     TODO: fix HOPs not restoring the original signature.
     TODO: fix tracing with local shapes so that we can use Shard placements
 
-    Current HOP signature we get:
-
+    Current HOP signature we get 9 inputs:
     class subgraph_0(torch.nn.Module):
         def forward(self,
             rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0",
@@ -43,28 +42,42 @@ def forward(self,
     moe._moe_forward = local_map(
         moe._moe_forward,
         out_placements=(
-            (Replicate(),),  # (Shard(0),),
+            # (Shard(0),),
+            (Replicate(),),
             (Replicate(),),
         ),
         in_placements=(
-            (Replicate(),),  # (Shard(0),),
+            # x
+            # (Shard(0),),
+            (Replicate(),),
+            # router weight
             (Replicate(),),
+            # expert bias
             (Replicate(),),
+            # expert
+            # (Shard(0),),
+            # (Shard(0),),
+            # (Shard(0),),
             (Replicate(),),
             (Replicate(),),
             (Replicate(),),
+            # shared
             (Replicate(),),
             (Replicate(),),
             (Replicate(),),
+            None,
+            None
         ),
         redistribute_inputs=True,
         in_grad_placements=None,
-        device_mesh=None,
+        device_mesh=mesh,
     )
 
 
-# Run workflow with:
-# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel
+# Run AP with:
+# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseek_v3_auto_parallel
+# Non-AP command:
+# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseek_v3
 def parallelize_deepseekv3(
     model,
     parallel_dims: ParallelDims,
@@ -100,7 +113,7 @@ def input_fn():
     assert parallel_dims.pp_enabled is False, "PP not supported yet"
 
     # apply local_map to MoE
-    apply_local_map_to_moe()
+    apply_local_map_to_moe(world_mesh)
 
     # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
     #     lambda bucket_idx: 500 / parallel_dims.tp
@@ -217,9 +230,16 @@ def get_moe_modules(model):
     # Copy custom attributes from original to parallel MoE modules
     # This is fine to do since these attributes are read only
     for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules):
+        if hasattr(orig_moe, 'moe_enabled'):
             par_moe.load_balance_coeff = orig_moe.load_balance_coeff
 
+        # piggyback in some asserts needed for our modified forward
+        # these can't be in the model code because export thinks they are data-dependent
+        assert orig_moe.score_before_experts
+        assert orig_moe.experts.use_grouped_mm
+        assert orig_moe.shared_experts is not None
+
         # Copy load_balance_coeff
         if hasattr(orig_moe, 'load_balance_coeff'):
             par_moe.load_balance_coeff = orig_moe.load_balance_coeff
 
diff --git a/torchtitan/models/moe.py b/torchtitan/models/moe.py
index 1ec8e3b23..bd020f6d5 100644
--- a/torchtitan/models/moe.py
+++ b/torchtitan/models/moe.py
@@ -31,6 +31,8 @@ class MoEArgs:
     use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation
     load_balance_coeff: float | None = 1e-3
 
+def functional_feed_forward(w1, w2, w3, x):
+    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)
 
 # can be used as dense FFN layer or shared experts in MoE layers
 class FeedForward(nn.Module):
@@ -191,7 +193,7 @@ def __init__(
         self.route_scale = route_scale
 
     def forward(
-        self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
+        self, x: torch.Tensor, gate_weight: torch.nn.Parameter, expert_bias: torch.Tensor | None = None
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -209,7 +211,8 @@ def forward(
                 Number of tokens assigned to each expert with shape ``(num_experts,)``.
         """
         # scores shape (bs*slen, num_experts)
-        scores = self.gate(x)
+        # scores = self.gate(x)
+        scores = torch.nn.functional.linear(x, gate_weight)
 
         # By default, sigmoid or softmax is performed in float32 to avoid loss explosion
         if self.score_func == "sigmoid":
@@ -311,7 +314,29 @@ def forward(
             num_tokens_per_expert,
         )
 
-def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts):
+# def forward(self,
+#     rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____modules__router____modules__gate____parameters__weight: "f32[8, 256][256, 1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____buffers__expert_bias: "f32[8][1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w1: "f32[8, 256, 256][65536, 256, 1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w3: "f32[8, 256, 256][65536, 256, 1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w2: "f32[8, 256, 256][65536, 256, 1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w1____parameters__weight: "f32[512, 256][256, 1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0",
+#     self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"):
+def _moe_forward(
+    x,
+    router_gate_weight,
+    expert_bias,
+    experts_w1,
+    experts_w3,
+    experts_w2,
+    shared_w1,
+    shared_w3,
+    shared_w2,
+    router,  # None
+    reorderer,  # None
+):
     # x: 64, 2048, 256
     bs, slen, dim = x.shape
     x = x.view(-1, dim)
@@ -322,7 +347,7 @@ def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, expert
         top_scores,
         selected_experts_indices,
         num_tokens_per_expert,
-    ) = router(x, expert_bias)
+    ) = router(x, router_gate_weight, expert_bias)
 
     # tokens_per_expert will be used to update the expert bias for load balancing.
     # and also to count the expert usage
@@ -354,27 +379,23 @@ def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, expert
 
     # shape (bs*slen*top_k, dim)
     routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)
-
-    if score_before_experts:
-        routed_input = (
-            routed_input.to(torch.float32)
-            * top_scores_experts_sorted.reshape(-1, 1)
-        ).to(x.dtype)
+    routed_input = (
+        routed_input.to(torch.float32)
+        * top_scores_experts_sorted.reshape(-1, 1)
+    ).to(x.dtype)
 
     # shape (bs*slen*top_k, dim)
-    routed_output = experts(routed_input, num_tokens_per_expert)
-
-    if not score_before_experts:
-        routed_output = (
-            routed_output.to(torch.float32)
-            * top_scores_experts_sorted.reshape(-1, 1)
-        ).to(x.dtype)
+    # routed_output = experts(routed_input, num_tokens_per_expert)
+    routed_output = _run_experts_grouped_mm(
+        experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert
+    )
 
     # shared expert
-    if shared_experts is not None:
-        out = shared_experts(x)
-    else:
-        out = torch.zeros_like(x)
+    # if shared_experts is not None:
+    #     out = shared_experts(x)
+    # else:
+    #     out = torch.zeros_like(x)
+    out = functional_feed_forward(shared_w1, shared_w2, shared_w3, x)
 
     out = out.scatter_add(
         dim=0, index=token_indices_experts_sorted, src=routed_output
@@ -439,7 +460,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         Returns:
             out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
         """
-        out, num_tokens_per_expert = _moe_forward(x, self.router, self.expert_bias, self.reorderer, self.score_before_experts, self.experts, self.shared_experts)
+        experts_w1, experts_w2, experts_w3 = self.experts.parameters()
+        shared_w1, shared_w2, shared_w3 = self.shared_experts.parameters()
+        out, num_tokens_per_expert = _moe_forward(
+            x,
+            self.router.gate.weight,
+            self.expert_bias,
+            experts_w1,
+            experts_w3,
+            experts_w2,
+            shared_w1,
+            shared_w3,
+            shared_w2,
+            self.router,  # None
+            self.reorderer,  # None
+        )
 
         # HOPs don't support buffer mutations, keep this outside
         with torch.no_grad():
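Editor's note: the two sketches below are appended for illustration only and are not part of the patch. This first one shows, with made-up sizes, what the new functional_feed_forward helper computes: the SwiGLU feed-forward expressed over raw weight tensors (so that only plain tensors cross the local_map boundary), checked against an equivalent bias-free nn.Linear formulation.

import torch
import torch.nn.functional as F


def functional_feed_forward(w1, w2, w3, x):
    # SwiGLU over raw weights: project with w1 and w3, gate with SiLU, project back with w2.
    return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)


dim, hidden = 16, 32
x = torch.randn(8, dim)

# Module-style reference, mirroring a bias-free SwiGLU feed-forward block.
lin1 = torch.nn.Linear(dim, hidden, bias=False)
lin3 = torch.nn.Linear(dim, hidden, bias=False)
lin2 = torch.nn.Linear(hidden, dim, bias=False)
ref = lin2(F.silu(lin1(x)) * lin3(x))

assert torch.allclose(functional_feed_forward(lin1.weight, lin2.weight, lin3.weight, x), ref)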
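The second sketch illustrates the routed-experts math that the rewritten _moe_forward now hard-codes: score-before-experts scaling, an expert-sorted gather, a per-expert feed-forward, and a scatter_add back onto the token buffer. A plain Python loop stands in for the grouped-mm call, and the (num_experts, dim, hidden) weight layout plus all variable names here are illustrative assumptions, not the torchtitan internals.

import torch
import torch.nn.functional as F

num_experts, top_k, dim, hidden, n_tokens = 4, 2, 16, 32, 8
tokens = torch.randn(n_tokens, dim)
# Assumed per-expert weight layout: (num_experts, dim, hidden) / (num_experts, hidden, dim).
w1 = torch.randn(num_experts, dim, hidden)
w3 = torch.randn(num_experts, dim, hidden)
w2 = torch.randn(num_experts, hidden, dim)

# Toy router: top-k experts per token with softmax scores.
scores = torch.softmax(torch.randn(n_tokens, num_experts), dim=-1)
top_scores, selected = scores.topk(top_k, dim=-1)                # (n_tokens, top_k)

# Sort the (token, expert) slots by expert id, as the token reorderer does.
flat_expert = selected.reshape(-1)                               # (n_tokens * top_k,)
order = flat_expert.argsort()
token_idx = torch.arange(n_tokens).repeat_interleave(top_k)[order]
sorted_scores = top_scores.reshape(-1)[order]
num_tokens_per_expert = torch.bincount(flat_expert, minlength=num_experts)

# Score-before-experts: scale the gathered tokens by their router scores.
routed_input = tokens[token_idx] * sorted_scores.unsqueeze(-1)

# Per-expert SwiGLU, looping where the patch uses a grouped mm.
routed_output = torch.empty_like(routed_input)
start = 0
for e in range(num_experts):
    end = start + int(num_tokens_per_expert[e])
    xe = routed_input[start:end]
    routed_output[start:end] = (F.silu(xe @ w1[e]) * (xe @ w3[e])) @ w2[e]
    start = end

# Scatter the expert outputs back onto the token buffer (zeros here; the patch
# starts from the shared-expert output instead).
out = torch.zeros_like(tokens).scatter_add(
    0, token_idx.unsqueeze(-1).expand(-1, dim), routed_output
)
print(out.shape)  # torch.Size([8, 16])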