From d0170b178e8d43903f8aa7c40060c3e2b1ff5496 Mon Sep 17 00:00:00 2001
From: Ke Wen
Date: Mon, 22 Sep 2025 16:03:29 -0700
Subject: [PATCH] [DSV3] Replace 1D A2A with 2D A2A

---
 torchtitan/__init__.py                      |   6 +-
 torchtitan/experiments/__init__.py          |   8 +-
 torchtitan/experiments/deepseek_v3/model.py | 189 ++++++++------------
 3 files changed, 78 insertions(+), 125 deletions(-)

diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py
index 73c51446d2..b96146ac5c 100644
--- a/torchtitan/__init__.py
+++ b/torchtitan/__init__.py
@@ -5,9 +5,9 @@
 # LICENSE file in the root directory of this source tree.
 
 # Import to register quantization modules.
-import torchtitan.components.quantization  # noqa: F401
+#import torchtitan.components.quantization  # noqa: F401
 
 # Import the built-in models here so that the corresponding register_model_spec()
 # will be called.
-import torchtitan.experiments  # noqa: F401
-import torchtitan.models  # noqa: F401
+#import torchtitan.experiments  # noqa: F401
+#import torchtitan.models  # noqa: F401
diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py
index d11ef99d88..0c2e95ee34 100644
--- a/torchtitan/experiments/__init__.py
+++ b/torchtitan/experiments/__init__.py
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-import torchtitan.experiments.llama4  # noqa: F401
-import torchtitan.experiments.qwen3
-import torchtitan.experiments.simple_fsdp  # noqa: F401
-import torchtitan.experiments.vlm  # noqa: F401
+#import torchtitan.experiments.llama4  # noqa: F401
+#import torchtitan.experiments.qwen3
+#import torchtitan.experiments.simple_fsdp  # noqa: F401
+#import torchtitan.experiments.vlm  # noqa: F401
diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py
index 5ee68524c6..03e02b3710 100644
--- a/torchtitan/experiments/deepseek_v3/model.py
+++ b/torchtitan/experiments/deepseek_v3/model.py
@@ -44,22 +44,12 @@
 from attn_mask_utils import _prepare_4d_causal_attention_mask
 
-from group_gemms import (
-    DSGroupGEMM,
-    ManualLoopGroupGEMM,
-    TorchAOBF16GroupGEMM,
-    TorchBF16GroupGEMM,
-    TorchFP8GroupGEMM,
-    TritonCGBF16GroupGEMM,
-)
-
 from model_config import ModelArgs
 from symm_mem_recipes import OnDeviceAllToAllV
 
 from torch import nn
 from torch.distributed._functional_collectives import all_to_all_single_autograd
 
-from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
-from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import ALIGN_SIZE_M
+ALIGN_SIZE_M = 8
 
 
 # Get model parallel subgroup by name:
@@ -472,11 +462,6 @@ class MoE(nn.Module):
     token_send_buf: Optional[torch.Tensor] = None
     token_gather_buf: Optional[torch.Tensor] = None
 
-    # Group GEMM strategies
-    group_gemm_strategies = None
-    # which group gemm to use?
-    group_mm = "manual"  # fp8 options = ["torchfp8", "dsgemm"] bf16 = ["torch", , "torchao", "tritoncg", "manual"]
-
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -512,50 +497,8 @@ def __init__(self, config):
             config=config, intermediate_size=intermediate_size
         )
 
-        # Group Gemm
-        # Initialize group GEMM strategies if not already loaded
-        if MoE.group_gemm_strategies is None:
-            MoE._initialize_group_gemm_strategies()
-
-        assert (
-            MoE.group_mm in MoE.group_gemm_strategies
-        ), f"selected group gemm {self.group_mm} is not available!"
-        # keep active gg ready
-        self.group_gemm_instance = MoE.group_gemm_strategies[MoE.group_mm]
 
         self._buffer_initialized = False
 
-    @classmethod
-    def _initialize_group_gemm_strategies(cls):
-        """Initialize available group GEMM strategies"""
-        cls.group_gemm_strategies = {
-            # torch._group_MM
-            "torch": TorchBF16GroupGEMM(MLP.act_fn),
-            # torch.mm with looping
-            "manual": ManualLoopGroupGEMM(MLP.act_fn),
-            "torchao": (
-                TorchAOBF16GroupGEMM(MLP.act_fn)
-                if TorchAOBF16GroupGEMM.is_available()
-                else None
-            ),
-            "torchfp8": (
-                TorchFP8GroupGEMM(MLP.act_fn)
-                if TorchFP8GroupGEMM.is_available()
-                else None
-            ),
-            "dsgemm": (
-                DSGroupGEMM(MLP.act_fn, use_triton_quant=True)
-                if DSGroupGEMM.is_available()
-                else None
-            ),
-            "tritoncg": (
-                TritonCGBF16GroupGEMM(
-                    MLP.act_fn,
-                )
-                if TritonCGBF16GroupGEMM.is_available()
-                else None
-            ),
-        }
-
     def combine_experts(self, submod_name: str):
         all_weights = []
         for expert in self.experts.values():
@@ -565,12 +508,7 @@ def combine_experts(self, submod_name: str):
             lin.weight = None
 
         # let the group gemm strategy prep the final weight layout
-        combined_weight = self.group_gemm_instance.arrange_expert_weights(
-            all_weights, submod_name, self
-        )
-
-        if combined_weight is None:
-            raise NotImplementedError("expert weights not handled by group gemmm")
+        combined_weight = torch.stack(all_weights)
 
         self.register_parameter(f"{submod_name}_weight", nn.Parameter(combined_weight))
 
@@ -599,10 +537,17 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
         if MoE.token_send_buf is not None:
             return
 
+        self.group_name = self.ep_group.group_name
+
+        symm_mem.set_backend("NVSHMEM")
+        symm_mem.enable_symm_mem_for_group("0")
+        symm_mem.enable_symm_mem_for_group(self.group_name)
+
         # Input buffer for DP-to-EP shuffle
         MoE.token_send_buf = symm_mem.empty(
             self.config.max_seq_len
-            * self.num_experts_per_tok,  # seq len * top k (flattened)
+            * self.num_experts_per_tok  # seq len * top k (flattened)
+            * overflow,
             self.config.hidden_size,  # hidden dim
             dtype=dtype,
             device=device,
@@ -617,6 +562,15 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
             device=device,
         )
 
+        nsplits = self.config.n_routed_experts
+        MoE.in_splits = symm_mem.empty(nsplits, dtype=torch.int64, device=device)
+        MoE.out_splits_offsets = symm_mem.empty(
+            (2, nsplits), dtype=torch.int64, device=device
+        )
+        MoE.combine_out_splits_offsets = symm_mem.empty(
+            (2, nsplits), dtype=torch.int64, device=device
+        )
+
     def get_send_buf(self):
         # [Why detach?] During a first forward-backward step, the buffer would
         # be included in a computational graph. In a second step, autograd will
@@ -791,13 +745,37 @@ def sort_tokens(self, x, topk_ids, topk_weights):
 
     def _run_group_gemm(self, contig_tokens, m_sizes, m_offsets):
        """Run the appropriate group GEMM implementation based on configuration"""
-        try:
-            return self.group_gemm_strategies[self.group_mm].execute(
-                contig_tokens, m_sizes, m_offsets, self
-            )
-        except Exception as e:
-            # Flag the error
-            print(f"Error using {self.group_mm} strategy: {e}")
+        # Get weights
+        w_gate = self.get_parameter("gate_proj_weight")
+        w_up = self.get_parameter("up_proj_weight")
+        w_down = self.get_parameter("down_proj_weight")
+
+        # Run first two GEMMs (gate and up projections)
+        gate_proj = torch._grouped_mm(
+            contig_tokens,
+            w_gate.transpose(-2, -1),
+            m_offsets,
+            out_dtype=torch.bfloat16,
+        )
+        up_proj = torch._grouped_mm(
+            contig_tokens,
+            w_up.transpose(-2, -1),
+            m_offsets,
+            out_dtype=torch.bfloat16,
+        )
+
+        # Apply activation
+        hidden_outputs = self.activation_function(gate_proj) * up_proj
+
+        # Run the third GEMM (down projection)
+        hidden_outputs = torch._grouped_mm(
+            hidden_outputs,
+            w_down.transpose(-2, -1),
+            m_offsets,
+            out_dtype=torch.bfloat16,
+        )
+
+        return hidden_outputs
 
     def moe_on_device(self, x, topk_ids, topk_weight):
         (
@@ -814,65 +792,40 @@ def moe_on_device(self, x, topk_ids, topk_weight):
         # band", which is not part of the actual data. Thus no gradient is
         # needed.
 
-        # Sum the tokens over local experts, then we get tokens per EP rank,
-        # which is the input splits
-        with torch.no_grad():
-            tokens_per_expert_group = tokens_per_expert.new_empty(
-                tokens_per_expert.shape[0]
-            )
-            dist.all_to_all_single(
-                tokens_per_expert_group, tokens_per_expert, group=self.ep_group
-            )
-            input_splits = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
+        MoE.in_splits.copy_(tokens_per_expert.view(-1))
 
         # Move input to the `token_send_buf` symm mem
         token_send_buf = self.get_send_buf()
         token_send_buf[: token_indices.shape[0]].copy_(sorted_tokens)
-        # Note: `out=` avoids copy, but it is not differentiable
-        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=token_send_buf[: idxs.shape[0]])
 
-        token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
-            token_send_buf,
-            input_splits,
-            self.ep_group,
-        )
+        token_gather_buf = self.get_gather_buf()
 
-        # We need to permute the received tokens so that tokens for the same expert are contiguous.
-        # This part prepares a 1D tensor `permuted_indices` for such permutation.
-        # This part doesn't need gradient.
-        with torch.no_grad():
-            permuted_indices, m_sizes, m_offsets = generate_permute_indices(
-                tokens_per_expert_group,
-                self.experts_per_rank,
-                self.ep_size,
-                token_gather_buf.shape[0],
-                ALIGN_SIZE_M,
-            )
+        # Dispatch the tokens
+        torch.ops.symm_mem.all_to_all_vdev_2d(
+            token_send_buf, token_gather_buf, MoE.in_splits, MoE.out_splits_offsets, self.group_name, major_align=ALIGN_SIZE_M
+        )
 
-        # Permute the received tokens so that tokens for the same expert are contiguous.
-        contig_tokens = token_gather_buf[permuted_indices]
+        m_offsets = torch.empty(self.experts_per_rank, dtype=MoE.in_splits.dtype, device=MoE.in_splits.device)
+        exclusive_offsets = MoE.out_splits_offsets[1].view(self.experts_per_rank, -1)
+        m_offsets[:-1].copy_(exclusive_offsets[1:, 0])
+        # end of the last expert group = offset of the last received chunk + its size
+        m_offsets[-1] = MoE.out_splits_offsets[1][-1] + MoE.out_splits_offsets[0][-1]
 
         # group gemm - handle all three group gemms (up, gate, down for all experts)
-        hidden_outputs = self._run_group_gemm(
-            contig_tokens,
-            m_sizes,
-            m_offsets,
+        processed_tokens = self._run_group_gemm(
+            token_gather_buf,
+            MoE.out_splits_offsets[0],
+            m_offsets.to(torch.int32),
         )
 
-        # Prepare buffer for tokens processed by experts
-        processed_tokens = self.get_gather_buf()
-
-        # Move into Symmetric Memory for the return shuffle
-        processed_tokens[permuted_indices] = hidden_outputs
+        token_gather_buf.copy_(processed_tokens)
 
-        # Now shuffle the tokens back to their original owner, i.e. EP to DP shuffle.
-        # The input/output splits are just a reverse of the previous shuffle.
-        token_return_buf, _ = OnDeviceAllToAllV.apply(
-            processed_tokens,
-            output_splits,
-            self.ep_group,
+        # Combine the tokens
+        # `out_splits_offsets` from shuffle is exactly the `input_splits_offsets` for combine
+        # `out` data from shuffle is exactly the `input` data for combine
+        torch.ops.symm_mem.all_to_all_vdev_2d_offset(
+            token_gather_buf, token_send_buf, MoE.out_splits_offsets, MoE.combine_out_splits_offsets, self.group_name
        )
 
-        returned_tokens = token_return_buf[:seqlen_sorted_tokens]
+        returned_tokens = token_send_buf[:seqlen_sorted_tokens]
 
         output_tokens = torch.empty_like(returned_tokens)
         output_tokens[token_indices] = returned_tokens
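
Usage sketch (reviewer-added, not part of the patch): the standalone script below exercises the same dispatch -> grouped-offsets -> combine round trip that the patch wires into moe_on_device(), using only the calls that appear in the diff (symm_mem.empty, set_backend, enable_symm_mem_for_group, all_to_all_vdev_2d, all_to_all_vdev_2d_offset). It assumes a CUDA build of PyTorch whose NVSHMEM symmetric-memory backend is available and a torchrun launch with one GPU per rank; the file name, experts_per_rank = 2, the buffer sizes, and the uniform 3-tokens-per-expert routing are illustrative assumptions only, and the grouped GEMMs themselves are elided.

# moe_2d_a2a_sketch.py -- illustrative only; run with:
#   torchrun --nproc-per-node=<ep_size> moe_2d_a2a_sketch.py
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

ALIGN_SIZE_M = 8  # same value the patch passes as `major_align`


def main():
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device("cuda", local_rank)
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    ep_size = dist.get_world_size()
    group_name = dist.group.WORLD.group_name

    symm_mem.set_backend("NVSHMEM")
    symm_mem.enable_symm_mem_for_group(group_name)

    experts_per_rank = 2                 # illustrative
    ne = experts_per_rank * ep_size      # total routed experts
    hidden, max_tokens = 16, 4096        # illustrative sizes

    # Symmetric-memory buffers, mirroring setup_symm_mem() in the patch
    send = symm_mem.empty(max_tokens, hidden, dtype=torch.bfloat16, device=device)
    recv = symm_mem.empty(max_tokens, hidden, dtype=torch.bfloat16, device=device)
    in_splits = symm_mem.empty(ne, dtype=torch.int64, device=device)
    out_splits_offsets = symm_mem.empty((2, ne), dtype=torch.int64, device=device)
    combine_splits_offsets = symm_mem.empty((2, ne), dtype=torch.int64, device=device)

    # Pretend the router sent 3 tokens from this rank to every global expert
    in_splits.fill_(3)
    send[: 3 * ne].fill_(float(dist.get_rank()))

    # DP -> EP dispatch: output is grouped by local expert, rank-major inside each group,
    # with each group's start aligned to `major_align`
    torch.ops.symm_mem.all_to_all_vdev_2d(
        send, recv, in_splits, out_splits_offsets, group_name, major_align=ALIGN_SIZE_M
    )

    # Cumulative end offset of each local expert group, as torch._grouped_mm expects:
    # group e ends where group e+1 starts (so alignment padding is swallowed by group e),
    # and the last group ends at (offset of its last chunk) + (size of its last chunk).
    offsets = out_splits_offsets[1].view(experts_per_rank, -1)
    m_offsets = torch.empty(experts_per_rank, dtype=torch.int64, device=device)
    m_offsets[:-1].copy_(offsets[1:, 0])
    m_offsets[-1] = out_splits_offsets[1][-1] + out_splits_offsets[0][-1]

    # ... the three grouped GEMMs would run here on recv, driven by m_offsets ...

    # EP -> DP combine: the dispatch's output splits/offsets are the combine's input
    torch.ops.symm_mem.all_to_all_vdev_2d_offset(
        recv, send, out_splits_offsets, combine_splits_offsets, group_name
    )

    dist.destroy_process_group()


if __name__ == "__main__":
    main()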