@@ -453,6 +453,8 @@ class MoE(nn.Module):
     # 1. "torch_all_to_all"
     # 2. "symm_mem" (see `setup_symm_mem` below)
     shuffle_method = "torch_all_to_all"
+    # Group GEMM method, "torch" or "torchao"
+    group_mm = "torch"
 
     # Symmetric memory buffers shared by all MoE instances across layers
     token_send_buf: Optional[torch.Tensor] = None
@@ -490,15 +492,21 @@ def __init__(self, config):
                 config=config, intermediate_size=intermediate_size
             )
 
-    def combine_experts(self, submod_name):
+    def combine_experts(self, submod_name: str):
         all_weights = []
         for expert in self.experts.values():
             lin = expert.get_submodule(submod_name)
             all_weights.append(lin.weight)
             lin.weight = None
 
-        concat_weight = torch.cat(all_weights)
-        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))
+        if self.group_mm == "torch":
+            combined_weight = torch.stack(all_weights)
+        elif self.group_mm == "torchao":
+            combined_weight = torch.cat(all_weights)
+        else:
+            raise RuntimeError(f"Unknown Group GEMM method: {self.group_mm}")
+
+        self.register_parameter(f"{submod_name}_weight", nn.Parameter(combined_weight))
 
     # This function is used to create a symm mem buffer for MoE's. It is for
     # shuffling tokens fully "on-device", as compared to traditional torch
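For reference, here is a small standalone sketch (toy sizes, not taken from the commit) of the two weight layouts combine_experts can register: torch.stack keeps one 2D slice per expert in a 3D tensor, which is what the torch._grouped_mm path consumes after a transpose, while torch.cat flattens all experts into a single 2D matrix for torchao's grouped_gemm_forward.

import torch

# Toy sizes for illustration only; the real values come from the model config.
num_experts, in_dim, out_dim = 4, 8, 16
per_expert = [torch.randn(out_dim, in_dim) for _ in range(num_experts)]  # like each lin.weight

stacked = torch.stack(per_expert)  # [num_experts, out_dim, in_dim] -- "torch" path
flat = torch.cat(per_expert)       # [num_experts * out_dim, in_dim] -- "torchao" path

print(stacked.shape)  # torch.Size([4, 16, 8])
print(flat.shape)     # torch.Size([64, 8])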
@@ -510,7 +518,6 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
         self.shuffle_method = "symm_mem"
 
         # Combine expert weights
-        print("Combining expert weights for Group GEMM")
         self.combine_experts("gate_proj")
         self.combine_experts("up_proj")
         self.combine_experts("down_proj")
@@ -544,6 +551,7 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
             device=device,
         )
         print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")
+        print("Combining expert weights for Group GEMM")
 
     def get_send_buf(self):
         # [Why detach?] During a first forward-backward step, the buffer would
@@ -735,7 +743,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
         token_send_buf = self.get_send_buf()
         token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
         # Note: `out=` avoids copy, but it is not differentiable
-        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self. token_send_buf[: idxs.shape[0]])
+        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=token_send_buf[: idxs.shape[0]])
         token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
             token_send_buf,
             input_splits,
@@ -746,7 +754,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
         # This part prepares a 1D tensor `permuted_indices` for such permutation.
         # This part doesn't need gradient.
         with torch.no_grad():
-            permuted_indices, m_sizes = generate_permute_indices(
+            permuted_indices, m_sizes, m_offsets = generate_permute_indices(
                 tokens_per_expert_group,
                 self.experts_per_rank,
                 self.ep_size,
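The torch path below consumes per-group offsets rather than sizes, which is why generate_permute_indices now also returns m_offsets. A minimal sketch of the assumed relationship (not spelled out in this commit): the offsets are the running sum of the per-expert token counts.

import torch

# Toy per-expert token counts; the real values come from the routing step.
m_sizes = torch.tensor([3, 0, 5, 2])
m_offsets = torch.cumsum(m_sizes, dim=0)  # tensor([3, 3, 8, 10])
# Under this assumption, expert i owns the contiguous token rows
# [m_offsets[i] - m_sizes[i], m_offsets[i]).
print(m_offsets)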
@@ -759,18 +767,36 @@ def moe_on_device(self, x, topk_ids, topk_weight):
 
         # Run the first grouped GEMM
         w1 = self.get_parameter("gate_proj_weight")
-        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
+        if self.group_mm == "torchao":
+            gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
+        else:  # "torch"
+            gate_proj = torch._grouped_mm(
+                contig_tokens, w1.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
+            )
 
         # Run the second grouped GEMM
         w3 = self.get_parameter("up_proj_weight")
-        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
+        if self.group_mm == "torchao":
+            up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
+        else:  # "torch"
+            up_proj = torch._grouped_mm(
+                contig_tokens, w3.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
+            )
 
         # Apply activation
         hidden_outputs = MLP.act_fn(gate_proj) * up_proj
 
         # Run the third grouped GEMM
         w2 = self.get_parameter("down_proj_weight")
-        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
+        if self.group_mm == "torchao":
+            hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
+        else:  # "torch"
+            hidden_outputs = torch._grouped_mm(
+                hidden_outputs,
+                w2.transpose(-2, -1),
+                m_offsets,
+                out_dtype=torch.bfloat16,
+            )
 
         # Prepare buffer for tokens processed by experts
         # Take necessary space from `token_gather_buf` symm mem because we are
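As a sanity sketch of what each grouped GEMM in the torch path computes (not part of the commit; toy shapes, and it assumes m_offsets marks the end of each expert's contiguous block of rows in contig_tokens): one plain matmul per expert over its slice of tokens, concatenated back in order. torch._grouped_mm(contig_tokens, w.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16) is expected to perform the same computation in a single grouped kernel call.

import torch

# Toy stand-ins for contig_tokens and one stacked weight (see combine_experts above).
num_experts, hidden, inter = 4, 8, 16
m_sizes = torch.tensor([3, 0, 5, 2])              # tokens per local expert
m_offsets = torch.cumsum(m_sizes, dim=0)          # assumed end offset of each block
tokens = torch.randn(int(m_sizes.sum()), hidden)  # [total_tokens, hidden]
w = torch.randn(num_experts, inter, hidden)       # [E, out_dim, in_dim], as stacked above

# Reference semantics: per-expert matmul over that expert's contiguous token rows.
out_ref = torch.cat([
    tokens[int(m_offsets[i] - m_sizes[i]) : int(m_offsets[i])] @ w[i].T
    for i in range(num_experts)
])
print(out_ref.shape)  # torch.Size([10, 16]) -- [total_tokens, inter]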