
Commit 5707c3d

[DeepSeek] Add option for torch._grouped_mm (#1086)
We now have two options for Group GEMM: Option 1: `torch._grouped_mm`; Option 2: `grouped_gemm_forward` (torchao).
1 parent: dbb34cc

File tree: 2 files changed, +40 −13 lines

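Both options compute the same operation: one independent GEMM per expert, applied to a contiguous row-block of the expert-sorted token matrix, fused into a single kernel launch. A minimal pure-PyTorch sketch of that semantics, with hypothetical shapes and names (not the fused kernels this commit dispatches to):

```python
import torch

def grouped_mm_reference(tokens, stacked_w, m_offsets):
    # tokens:    [total_m, k]       expert-sorted activations
    # stacked_w: [n_experts, n, k]  per-expert weights, torch.stack layout
    # m_offsets: [n_experts]        cumulative (end) row offset of each group
    out = tokens.new_empty(tokens.shape[0], stacked_w.shape[1])
    start = 0
    for e, end in enumerate(m_offsets.tolist()):
        # each group is one plain GEMM: [m_e, k] @ [k, n] -> [m_e, n]
        out[start:end] = tokens[start:end] @ stacked_w[e].T
        start = end
    return out

# Tiny example: 2 experts handling 3 and 2 tokens, k=4 -> n=5
tokens = torch.randn(5, 4)
weights = torch.stack([torch.randn(5, 4) for _ in range(2)])
print(grouped_mm_reference(tokens, weights, torch.tensor([3, 5])).shape)  # [5, 5]
```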

torchtitan/experiments/deepseek_v3/model.py

Lines changed: 35 additions & 9 deletions
```diff
@@ -453,6 +453,8 @@ class MoE(nn.Module):
     # 1. "torch_all_to_all"
     # 2. "symm_mem" (see `setup_symm_mem` below)
     shuffle_method = "torch_all_to_all"
+    # Group GEMM method, "torch" or "torchao"
+    group_mm = "torch"
 
     # Symmetric memory buffers shared by all MoE instances across layers
     token_send_buf: Optional[torch.Tensor] = None
@@ -490,15 +492,21 @@ def __init__(self, config):
                 config=config, intermediate_size=intermediate_size
             )
 
-    def combine_experts(self, submod_name):
+    def combine_experts(self, submod_name: str):
         all_weights = []
         for expert in self.experts.values():
             lin = expert.get_submodule(submod_name)
             all_weights.append(lin.weight)
             lin.weight = None
 
-        concat_weight = torch.cat(all_weights)
-        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))
+        if self.group_mm == "torch":
+            combined_weight = torch.stack(all_weights)
+        elif self.group_mm == "torchao":
+            combined_weight = torch.cat(all_weights)
+        else:
+            raise RuntimeError(f"Unknown Group GEMM method: {self.group_mm}")
+
+        self.register_parameter(f"{submod_name}_weight", nn.Parameter(combined_weight))
 
     # This function is used to create a symm mem buffer for MoE's. It is for
     # shuffling tokens fully "on-device", as compared to traditional torch
```
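The `torch.stack` / `torch.cat` split above is a layout difference, not a semantic one: `torch._grouped_mm` consumes the experts as a 3D batch of matrices, while torchao's `grouped_gemm_forward` takes them concatenated into one 2D matrix. A shape-only illustration with made-up sizes:

```python
import torch

n_experts, out_features, in_features = 4, 8, 16
all_weights = [torch.randn(out_features, in_features) for _ in range(n_experts)]

stacked = torch.stack(all_weights)  # "torch" path:   [4, 8, 16], one slice per expert
combined = torch.cat(all_weights)   # "torchao" path: [32, 16], experts along the row dim

assert stacked.shape == (n_experts, out_features, in_features)
assert combined.shape == (n_experts * out_features, in_features)
```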
```diff
@@ -510,7 +518,6 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
         self.shuffle_method = "symm_mem"
 
         # Combine expert weights
-        print("Combining expert weights for Group GEMM")
         self.combine_experts("gate_proj")
         self.combine_experts("up_proj")
         self.combine_experts("down_proj")
@@ -544,6 +551,7 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
             device=device,
         )
         print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")
+        print("Combining expert weights for Group GEMM")
 
     def get_send_buf(self):
         # [Why detach?] During a first forward-backward step, the buffer would
@@ -735,7 +743,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
         token_send_buf = self.get_send_buf()
         token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
         # Note: `out=` avoids copy, but it is not differentiable
-        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
+        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=token_send_buf[: idxs.shape[0]])
         token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
             token_send_buf,
             input_splits,
@@ -746,7 +754,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
         # This part prepares a 1D tensor `permuted_indices` for such permutation.
         # This part doesn't need gradient.
         with torch.no_grad():
-            permuted_indices, m_sizes = generate_permute_indices(
+            permuted_indices, m_sizes, m_offsets = generate_permute_indices(
                 tokens_per_expert_group,
                 self.experts_per_rank,
                 self.ep_size,
@@ -759,18 +767,36 @@ def moe_on_device(self, x, topk_ids, topk_weight):
 
         # Run the first grouped GEMM
         w1 = self.get_parameter("gate_proj_weight")
-        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
+        if self.group_mm == "torchao":
+            gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
+        else:  # "torch"
+            gate_proj = torch._grouped_mm(
+                contig_tokens, w1.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
+            )
 
         # Run the second grouped GEMM
         w3 = self.get_parameter("up_proj_weight")
-        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
+        if self.group_mm == "torchao":
+            up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
+        else:  # "torch"
+            up_proj = torch._grouped_mm(
+                contig_tokens, w3.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
+            )
 
         # Apply activation
         hidden_outputs = MLP.act_fn(gate_proj) * up_proj
 
         # Run the third grouped GEMM
         w2 = self.get_parameter("down_proj_weight")
-        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
+        if self.group_mm == "torchao":
+            hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
+        else:  # "torch"
+            hidden_outputs = torch._grouped_mm(
+                hidden_outputs,
+                w2.transpose(-2, -1),
+                m_offsets,
+                out_dtype=torch.bfloat16,
+            )
 
         # Prepare buffer for tokens processed by experts
         # Take necessary space from `token_gather_buf` symm mem because we are
```
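One subtlety in the `torch` branches above: `combine_experts` stacked each weight as `[n_experts, out_features, in_features]`, so every expert slice must be transposed to `[in, out]` before the per-group product, hence the `w.transpose(-2, -1)`. A standalone shape check with made-up sizes (plain matmuls standing in for the private `torch._grouped_mm`):

```python
import torch

n_experts, m_per_group, k, n = 2, 3, 4, 5
contig_tokens = torch.randn(n_experts * m_per_group, k)  # expert-sorted rows
w = torch.randn(n_experts, n, k)   # stacked like "gate_proj_weight"
w_t = w.transpose(-2, -1)          # [n_experts, k, n], the operand layout used above

# Emulate one grouped GEMM with equal-size groups: [m, k] @ [k, n] per expert
out = torch.cat(
    [contig_tokens[e * m_per_group : (e + 1) * m_per_group] @ w_t[e] for e in range(n_experts)]
)
assert out.shape == (n_experts * m_per_group, n)
```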

torchtitan/experiments/kernels/moe/indices.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -139,7 +139,8 @@ def generate_permute_indices(
         torch.int32
     )
     # Perform another prefix sum to get the write offset of each expert in `permuted_indices`
-    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
+    m_offsets = torch.cumsum(m_sizes, 0)
+    write_offsets = m_offsets - m_sizes
     # Select the method to fill the permuted indices
     fill_fn = fill_indices_cpu if use_cpu else fill_indices
     # Fill the permuted indices
@@ -151,7 +152,7 @@ def generate_permute_indices(
         num_ranks,
         max_len,
     )
-    return permuted_indices, m_sizes
+    return permuted_indices, m_sizes, m_offsets.to(torch.int32)
 
 
 # Below is for testing only
@@ -167,11 +168,11 @@ def test():
     max_len = 128
     alignment = 32
     # Use the GPU kernel
-    permuted_indices_gpu, m_sizes = generate_permute_indices(
+    permuted_indices_gpu, m_sizes, _ = generate_permute_indices(
         tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
     )
     # Use the CPU method
-    permuted_indices_cpu, _ = generate_permute_indices(
+    permuted_indices_cpu, _, _ = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
```
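The indices change is a one-liner of bookkeeping: the cumulative sum that previously only fed `write_offsets` is now kept and returned as `m_offsets`, which is exactly the group-end offset vector the `torch._grouped_mm` path consumes. A tiny worked example with made-up group sizes:

```python
import torch

m_sizes = torch.tensor([3, 1, 2], dtype=torch.int32)  # rows per expert group
m_offsets = torch.cumsum(m_sizes, 0)                  # tensor([3, 4, 6]): where each group ends
write_offsets = m_offsets - m_sizes                   # tensor([0, 3, 4]): where each group starts
```

The `.to(torch.int32)` on the returned `m_offsets` is there because `torch.cumsum` promotes integer inputs to int64.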
