
Commit 0c5cd6f

Fix apply_compile called multiple times in PP initialization
stack-info: PR: #2135, branch: xmfan/stack/8
1 parent fbafd44 commit 0c5cd6f
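
Context for the fix: apply_compile monkeypatches a module-level function, and with pipeline parallelism it is invoked once per model part, so without a guard the patch is re-applied and the wrappers nest. Below is a minimal, self-contained sketch of that failure mode and of a __qualname__-based guard like the one this commit adds; the names (patch_once, fake_module, _wrapped) are illustrative and not part of torchtitan.

# Minimal sketch (illustrative names, not the torchtitan code): without a guard,
# each call re-wraps the module-level function, nesting one wrapper per call.
import types


def _run_experts(x):
    return x


def patch_once(module):
    # Same idea as the commit's `already_patched` check: inspect the __qualname__
    # of the currently installed function before wrapping it again.
    if "_wrapped" in module._run_experts.__qualname__:
        return
    original = module._run_experts

    def _wrapped(x):
        return original(x)

    module._run_experts = _wrapped


fake_module = types.SimpleNamespace(_run_experts=_run_experts)
patch_once(fake_module)  # installs the wrapper
patch_once(fake_module)  # no-op: the guard sees the wrapper is already in place
assert fake_module._run_experts.__qualname__ == "patch_once.<locals>._wrapped"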

File tree

2 files changed: +103 -19 lines changed


tests/unit_tests/test_compile.py

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import unittest
+
+import torch
+import torch.nn as nn
+
+from torchtitan.config.job_config import Compile as CompileConfig
+from torchtitan.models.llama4.infra.parallelize import apply_compile
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, dim=512):
+        super().__init__()
+        self.attention = nn.Linear(dim, dim, bias=False)
+        self.mlp = nn.Linear(dim, dim, bias=False)
+        self.moe_enabled = False
+
+    def forward(self, x):
+        x = self.attention(x)
+        x = self.mlp(x)
+        return x
+
+
+class TinyModel(nn.Module):
+    def __init__(self, num_layers=2, dim=512):
+        super().__init__()
+        self.layers = nn.ModuleDict(
+            {str(i): TransformerBlock(dim) for i in range(num_layers)}
+        )
+
+    def forward(self, x):
+        for layer in self.layers.values():
+            x = layer(x)
+        return x
+
+
+class TestApplyCompile(unittest.TestCase):
+    def test_patched_once(self):
+        """
+        Calls apply_compile multiple times, as in the case with PP.
+        But patches should only happen once
+        """
+        unused_model1 = TinyModel(num_layers=2, dim=128)
+        unused_model2 = TinyModel(num_layers=2, dim=128)
+        compile_config = CompileConfig(backend="eager")
+
+        apply_compile(unused_model1, compile_config, ep_enabled=True)
+        apply_compile(unused_model2, compile_config, ep_enabled=True)
+
+        from torchtitan.models.moe import moe as moe_module
+
+        # Generate sample inputs for _run_experts_grouped_mm
+        num_experts = 8
+        dim = 128
+        hidden_dim = 256
+        w1 = torch.randn(num_experts, hidden_dim, dim)
+        w2 = torch.randn(num_experts, dim, hidden_dim)
+        w3 = torch.randn(num_experts, hidden_dim, dim)
+        num_tokens_per_expert = torch.tensor([10, 8, 12, 9, 11, 7, 10, 13], dtype=torch.int32)
+        total_tokens = num_tokens_per_expert.sum().item()
+        x = torch.randn(total_tokens, dim)
+
+        # Call the function, should not error
+        output = moe_module._run_experts_grouped_mm(w1, w2, w3, x, num_tokens_per_expert)
+
+        print(f"Input shape: {x.shape}")
+        print(f"Output shape: {output.shape}")
+        print(f"Num tokens per expert: {num_tokens_per_expert}")
+
+
+if __name__ == "__main__":
+    unittest.main()
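Since the new test defines a standard unittest.main() entry point, it can presumably be run on its own, e.g. python tests/unit_tests/test_compile.py, or through the repository's usual unit-test runner.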

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 26 additions & 19 deletions
@@ -572,27 +572,34 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
 
         model.layers.register_module(layer_id, transformer_block)
 
-    moe_module._run_experts_grouped_mm = torch.compile(
-        moe_module._run_experts_grouped_mm,
-        backend=compile_config.backend,
-        fullgraph=True,
+    # Patch some globals only once (apply_compile is called multiple times for PP setup)
+    already_patched = (
+        "_run_experts_grouped_mm_dynamic"
+        in moe_module._run_experts_grouped_mm.__qualname__
     )
+    if not already_patched:
+        moe_module._run_experts_grouped_mm = torch.compile(
+            moe_module._run_experts_grouped_mm,
+            backend=compile_config.backend,
+            fullgraph=True,
+        )
 
-    if ep_enabled:
-        compiled_fn = moe_module._run_experts_grouped_mm
-
-        def _run_experts_grouped_mm_dynamic(
-            w1: torch.Tensor,
-            w2: torch.Tensor,
-            w3: torch.Tensor,
-            x: torch.Tensor,
-            num_tokens_per_expert: torch.Tensor,
-        ) -> torch.Tensor:
-            # dynamic number of tokens in expert parallel
-            torch._dynamo.mark_dynamic(x, 0)
-            return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
-
-        moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
+        if ep_enabled:
+            compiled_fn = moe_module._run_experts_grouped_mm
+
+            # keep function logic in sync with `already_patched` above
+            def _run_experts_grouped_mm_dynamic(
+                w1: torch.Tensor,
+                w2: torch.Tensor,
+                w3: torch.Tensor,
+                x: torch.Tensor,
+                num_tokens_per_expert: torch.Tensor,
+            ) -> torch.Tensor:
+                # dynamic number of tokens in expert parallel
+                torch._dynamo.mark_dynamic(x, 0)
+                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
 
     # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
     # https://github.com/pytorch/pytorch/issues/166460