
Commit dbb34cc

[llama4] add Grouped GEMM support for MoE (#1084)
This PR
1. adds grouped GEMM support (pytorch/pytorch#150374) for the Llama 4 MoE. In general, it avoids device/host syncs; for TP, it also avoids the sharding-propagation overhead caused by the varying number of tokens per expert. The speedup on the debug model is ~4x, both with and without TP. I'm deliberately keeping the for-loop implementation for now, for comparison and readability purposes.
2. moves the MoE indices kernel from the deepseek folder to the kernel folder.

In order for TP to work, it requires some pytorch-side changes (e.g. DTensor support for `torch._grouped_mm`), for which I will submit PRs soon. A known issue is that the grouped GEMM version doesn't work well with the AdamW optimizer, which is to be investigated.

cc: @janeyx99
1 parent ab08612 commit dbb34cc
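To make the change concrete, here is a minimal standalone sketch (not code from this PR) contrasting the two expert-computation paths the diffs below keep side by side. It assumes a CUDA device and a PyTorch build that includes pytorch/pytorch#150374 (which adds `torch._grouped_mm`); sizes are made up, and the per-expert token counts are already multiples of 16 so no padding is needed.

import torch

num_experts, dim, hidden_dim = 4, 64, 128
tokens_per_expert = [16, 32, 16, 48]  # one entry per expert, already aligned
x = torch.randn(sum(tokens_per_expert), dim, device="cuda", dtype=torch.bfloat16)
w1 = torch.randn(num_experts, dim, hidden_dim, device="cuda", dtype=torch.bfloat16)

# for-loop path: one matmul per expert, driven by Python-side split sizes
out_loop = torch.cat(
    [x_e @ w1[i] for i, x_e in enumerate(torch.split(x, tokens_per_expert, dim=0))]
)

# grouped-GEMM path: one kernel call; group boundaries come from on-device
# int32 offsets, so the sizes never need to be read back to the host
offsets = torch.cumsum(
    torch.tensor(tokens_per_expert, device="cuda"), dim=0, dtype=torch.int32
)
out_grouped = torch._grouped_mm(x, w1, offs=offsets)

# the two paths agree up to bf16 rounding
torch.testing.assert_close(out_loop, out_grouped, rtol=1e-2, atol=1e-2)

In the real model the split sizes come from the router, so the for-loop path has to call `.tolist()` on a device tensor (a device-to-host sync), while the grouped path keeps the counts on the GPU.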

File tree

9 files changed, +127 -68 lines

torchtitan/experiments/deepseek_v3/model.py

Lines changed: 1 addition & 1 deletion
@@ -37,12 +37,12 @@
 import torch.utils.checkpoint

 from attn_mask_utils import _prepare_4d_causal_attention_mask
-from indices import generate_permute_indices
 from model_config import ModelArgs
 from symm_mem_recipes import OnDeviceAllToAllV
 from torch import nn
 from torch.distributed._functional_collectives import all_to_all_single_autograd

+from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
 from torchtitan.experiments.kernels.triton_mg_group_gemm.torchao_pr import (
     ALIGN_SIZE_M,
     grouped_gemm_forward,

torchtitan/experiments/llama4/README.md

Lines changed: 6 additions & 6 deletions
@@ -1,8 +1,8 @@
 **The Llama 4 folder is still under development.**

 #### Available features
-- Llama 4 model definition (text-only), including the MoE architecture with token-choice routing
-- Basic FSDP, TP, PP, CP support
+- Llama 4 model definition (text-only), including the MoE architecture with token-choice routing using efficient bfloat16 Grouped MM kernels
+- FSDP, TP, PP, CP support
 - DCP checkpoint conversion scripts

 #### Download Llama 4 tokenizer
@@ -17,13 +17,13 @@ python scripts/download_tokenizer.py --repo_id meta-llama/Llama-4-Scout-17B-16E
 - load balance loss for token-choice MoE
 - alternative expert-choice MoE
 - multimodal support
-- Kernel integration
-    - efficient bfloat16 GroupedGEMM kernels (from PyTorch core)
-    - efficient float8 GroupedGEMM kernels (from torchao)
 - Parallelism
-    - performant TP implementation and torch.compile support for MoE layers
     - Context Parallel support for FlexAttention, iRoPE, and multimodal inputs
     - Expert Parallel support
+- torch.compile
+    - for MoE layers
+- Quantization
+    - efficient float8 GroupedGEMM kernels (from torchao)
 - Testing
     - perfomance and loss converging tests
     - CI integration

torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 20 additions & 22 deletions
@@ -8,21 +8,21 @@
 from functools import partial
 from typing import Optional, Tuple

+import torch
 import torch.nn as nn
 from torch.distributed.tensor import (
     DeviceMesh,
     distribute_module,
     distribute_tensor,
     DTensor,
-    Partial,
     Replicate,
     Shard,
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 from torch.distributed.tensor.placement_types import Placement


-# implementation of Tensor Parallel on the non-shared experts in MoE
+# implementation of Tensor Parallel for the GroupedExperts in MoE
 class TensorParallel(ParallelStyle):
     def __init__(
         self,
@@ -32,33 +32,31 @@ def __init__(
         use_local_output: bool = True,
     ):
         super().__init__()
-        self.input_layouts = input_layouts or (Replicate(), None)
-        self.output_layout = output_layout or Partial()
-        self.desired_input_layouts = (Replicate(), None)
+        self.input_layouts = input_layouts or (Replicate(), Replicate())
+        self.output_layout = output_layout or Replicate()
+        self.desired_input_layouts = (Replicate(), Replicate())
         self.use_local_output = use_local_output

     @staticmethod
     def _prepare_input_fn(
         input_layouts, desired_input_layouts, mod, inputs, device_mesh
     ):
-        # TODO: figure out dynamo support for instance method and switch this to instance method
-
+        prepared_inputs = []
         # annotate module input placements/sharding with input_layouts
-        input_tensor, input_layout, desired_input_layout = (
-            inputs[0],
-            input_layouts[0],
-            desired_input_layouts[0],
-        )
-        if not isinstance(input_tensor, DTensor):
-            input_tensor = DTensor.from_local(
-                input_tensor, device_mesh, (input_layout,), run_check=False
-            )
-
-        if input_layouts != desired_input_layouts:
-            input_tensor = input_tensor.redistribute(
-                placements=(desired_input_layout,), async_op=True
-            )
-        return (input_tensor, *inputs[1:])
+        for inp, input_layout, desired_input_layout in zip(
+            inputs, input_layouts, desired_input_layouts
+        ):
+            if isinstance(inp, torch.Tensor):
+                if not isinstance(inp, DTensor):
+                    inp = DTensor.from_local(
+                        inp, device_mesh, (input_layout,), run_check=False
+                    )
+                if input_layout != desired_input_layout:
+                    inp = inp.redistribute(
+                        placements=(desired_input_layout,), async_op=True
+                    )
+            prepared_inputs.append(inp)
+        return tuple(prepared_inputs)

     def _partition_fn(self, name, module, device_mesh):
         module.register_parameter(

torchtitan/experiments/llama4/infra/parallelize_llama.py

Lines changed: 2 additions & 2 deletions
@@ -149,8 +149,8 @@ def apply_moe_tp(
         # replicate computation for the router
         "moe.router.gate": NoParallel(),
         # input Replicate, output Partial
-        "moe.experts": TensorParallel(),
-        "moe.shared_expert": TensorParallel(),
+        "moe.experts": TensorParallel(output_layout=Partial()),
+        "moe.shared_expert": TensorParallel(output_layout=Partial()),
     }
     parallelize_module(
         module=transformer_block,
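To unpack the "input Replicate, output Partial" comment: with `output_layout=Partial()`, each TP rank produces a partial sum of the experts' output, and the cross-rank reduction only happens when that `Partial` DTensor is redistributed to `Replicate` (or another placement). The following standalone sketch (not torchtitan code) shows this DTensor behavior; it assumes a 2-process run launched via `torchrun --nproc_per_node=2` and uses the gloo backend so it runs on CPU.

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

# pretend each rank computed a partial expert output of shape (tokens, dim)
local_partial = torch.ones(4, 8) * (dist.get_rank() + 1)
out = DTensor.from_local(local_partial, mesh, (Partial(),), run_check=False)

# converting Partial -> Replicate performs the all-reduce (sum) across the TP group
replicated = out.redistribute(placements=(Replicate(),))
print(replicated.to_local()[0, 0])  # tensor(3.) with 2 ranks: 1 + 2

dist.destroy_process_group()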

torchtitan/experiments/llama4/model/args.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ class TransformerModelArgs(BaseModelArgs):
     interleave_moe_layer_step: int = 2
     # token-choice
     top_k: int = 1
+    use_grouped_mm: bool = True  # grouped mm or for-loop for the experts computation

     def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None:
         self.norm_type = job_config.model.norm_type
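A quick, hypothetical usage sketch of the new flag: constructing the experts module with `use_grouped_mm=False` keeps the original for-loop/bmm path around as a correctness and readability baseline (dimensions below are made up, and `init_weights` still has to be called before the parameters are usable).

from torchtitan.experiments.llama4.model.moe import GroupedExperts

# for-loop / bmm baseline
experts_loop = GroupedExperts(dim=64, hidden_dim=128, num_experts=4, use_grouped_mm=False)
# grouped-GEMM path (the new default)
experts_grouped = GroupedExperts(dim=64, hidden_dim=128, num_experts=4, use_grouped_mm=True)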

torchtitan/experiments/llama4/model/moe.py

Lines changed: 93 additions & 35 deletions
@@ -17,53 +17,72 @@ def __init__(
         dim: int,
         hidden_dim: int,
         num_experts: int,
+        use_grouped_mm: bool,
     ):
         super().__init__()
         self.num_experts = num_experts
         self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
         self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+        self.use_grouped_mm = use_grouped_mm

     def forward(
         self,
         x: torch.Tensor,
-        num_local_tokens_per_expert: torch.Tensor | None = None,
+        num_local_tokens_per_expert: torch.Tensor | list[int] | None = None,
     ) -> torch.Tensor:
-        if num_local_tokens_per_expert is not None:
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x,
-                split_size_or_sections=num_local_tokens_per_expert.tolist(),
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                w1, w2, w3 = (
-                    self.w1[expert_idx],
-                    self.w2[expert_idx],
-                    self.w3[expert_idx],
+        # TODO: keeping this for loop implementation for comparison
+        # and readability, will remove later
+        if not self.use_grouped_mm:
+            if num_local_tokens_per_expert is not None:
+                # a tuple of tensors indexed by experts
+                # each with shape (tokens_per_expert(varying), dim)
+                x = torch.split(
+                    x,
+                    split_size_or_sections=num_local_tokens_per_expert,
+                    dim=0,
                 )
-                h = F.silu(torch.matmul(x_expert, w1))
-                h = h * torch.matmul(x_expert, w3)
-                h = torch.matmul(h, w2)
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # TODO:optimize with GroupedGEMM
+                out_experts_splits = []
+                for expert_idx, x_expert in enumerate(x):
+                    w1, w2, w3 = (
+                        self.w1[expert_idx],
+                        self.w2[expert_idx],
+                        self.w3[expert_idx],
+                    )
+                    h = F.silu(torch.matmul(x_expert, w1))
+                    h = h * torch.matmul(x_expert, w3)
+                    h = torch.matmul(h, w2)
+                    # h shape (tokens_per_expert(varying), dim)
+                    out_experts_splits.append(h)
+                out = torch.cat(out_experts_splits, dim=0)
+            else:
+                # x shape (num_experts, tokens_per_expert, dim)
+                h = F.silu(torch.bmm(x, self.w1))
+                h = h * torch.bmm(x, self.w3)
+                # out shape (num_experts, tokens_per_expert, dim)
+                out = torch.bmm(h, self.w2)
+
+            return out
+
+        # grouped mm implementation
+        if num_local_tokens_per_expert is not None:
             # https://github.com/pytorch/pytorch/pull/150374
-            # _gouped_mm requires shapes to be multiple of 8
-            # offsets = torch.cumsum(num_local_tokens_per_expert, dim=0, dtype=torch.int32)
-            # h = F.silu(torch._grouped_mm(x, self.w1.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16))
-            # h = h * torch._grouped_mm(x, self.w3.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
-            # out = torch._grouped_mm(h, self.w2.transpose(-2, -1), offs=offsets, out_dtype=torch.bfloat16)
+            # NOTE: torch._gouped_mm requires bf16 dtypes
+            # and shapes to be multiple of 8
+            offsets = torch.cumsum(
+                num_local_tokens_per_expert, dim=0, dtype=torch.int32
+            )
+            # grouped mm between a 2D tensor and a 3D tensor
+            assert x.dim() == 2
         else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, self.w1))
-            h = h * torch.bmm(x, self.w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, self.w2)
+            offsets = None
+            # fall back to regular bmm between 3D tensors
+            assert x.dim() == 3
+
+        h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets))
+        h = h * torch._grouped_mm(x, self.w3, offs=offsets)
+        out = torch._grouped_mm(h, self.w2, offs=offsets)
+
         return out

     def init_weights(self, init_std: float):
@@ -166,14 +185,23 @@ def __init__(self, model_args: TransformerModelArgs):
             hidden_dim = int(hidden_dim / hidden_dim_denom)
         hidden_dim += -hidden_dim % model_args.multiple_of

+        self.use_grouped_mm = model_args.use_grouped_mm
         self.experts = GroupedExperts(
-            dim=dim, hidden_dim=hidden_dim, num_experts=num_experts
+            dim=dim,
+            hidden_dim=hidden_dim,
+            num_experts=num_experts,
+            use_grouped_mm=self.use_grouped_mm,
         )
         self.router = TokenChoiceTopKRouter(
             dim=dim, num_experts=num_experts, top_k=model_args.top_k
         )
         self.shared_expert = (
-            GroupedExperts(dim=dim, hidden_dim=hidden_dim, num_experts=1)
+            GroupedExperts(
+                dim=dim,
+                hidden_dim=hidden_dim,
+                num_experts=1,
+                use_grouped_mm=self.use_grouped_mm,
+            )
             if model_args.use_shared_expert
             else None
         )
@@ -206,6 +234,36 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         )
         routed_input = routed_input * top_scores.reshape(-1, 1)

+        if self.use_grouped_mm:
+            # NOTE: In order to use torch._grouped_mm, we need to make sure
+            # the number of tokens each expert gets is a multiple of 16.
+            # The following kernel helps achieve this via padding, without
+            # incurring synchronization between device and host.
+            from torchtitan.experiments.kernels.moe.indices import (
+                generate_permute_indices,
+            )
+
+            ALIGN_SIZE_M = 16
+
+            with torch.no_grad():
+                permuted_indices, m_sizes = generate_permute_indices(
+                    num_local_tokens_per_expert,
+                    self.experts.num_experts,
+                    1,
+                    token_indices.shape[0] + self.experts.num_experts * ALIGN_SIZE_M,
+                    ALIGN_SIZE_M,
+                )
+            num_local_tokens_per_expert = m_sizes
+            token_indices = torch.vstack(
+                (token_indices, token_indices.new_zeros((dim)))
+            )
+            token_indices = token_indices[permuted_indices, :]
+            routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim))))
+            routed_input = routed_input[permuted_indices, :]
+        else:
+            # NOTE: this would incur a synchronization between device and host
+            num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()
+
         # shape (bs*slen*top_k, dim)
         routed_output = self.experts(routed_input, num_local_tokens_per_expert)
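The NOTE in the hunk above is the key constraint: `torch._grouped_mm` wants each expert's token count padded up to a multiple of 16, and the `generate_permute_indices` kernel produces both the padded sizes and the permuted token indices without reading the counts back to the host. The snippet below is an illustrative sketch of just the padding/offsets arithmetic (it is not the kernel itself, and the real kernel may treat empty experts or padding slots differently); the `new_zeros` + `vstack` in the hunk above appends a dummy row for the padding slots to point at.

import torch

ALIGN_SIZE_M = 16

def padded_offsets(num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
    # round each expert's token count up to a multiple of ALIGN_SIZE_M ...
    padded = (
        (num_tokens_per_expert + ALIGN_SIZE_M - 1) // ALIGN_SIZE_M
    ) * ALIGN_SIZE_M
    # ... and return the cumulative int32 group-end offsets that
    # torch._grouped_mm consumes
    return torch.cumsum(padded, dim=0, dtype=torch.int32)

counts = torch.tensor([5, 0, 23, 16])  # tokens routed to each of 4 experts
print(padded_offsets(counts))          # tensor([16, 16, 48, 64], dtype=torch.int32)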

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 3 additions & 1 deletion
@@ -29,7 +29,9 @@ use_flex_attn = false
 attn_mask_type = "causal" # causal / block_causal

 [optimizer]
-name = "AdamW"
+# TODO: currently grouped mm in MoE doesn't work with AdamW, need to investigate
+# name = "AdamW"
+name = "Adam"
 lr = 4e-3
 eps = 1e-15

torchtitan/train.py

Lines changed: 1 addition & 1 deletion
@@ -391,7 +391,7 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
                 dist_utils.dist_max(loss, world_mesh["dp_cp"]),
             )
         else:
-            global_avg_loss = global_max_loss = loss.item()
+            global_avg_loss = global_max_loss = loss.detach().item()

         self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
