81 changes: 81 additions & 0 deletions tests/unit_tests/test_compile.py
@@ -0,0 +1,81 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn

from torchtitan.config.job_config import Compile as CompileConfig
from torchtitan.models.llama4.infra.parallelize import apply_compile


class TransformerBlock(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.attention = nn.Linear(dim, dim, bias=False)
        self.mlp = nn.Linear(dim, dim, bias=False)
        self.moe_enabled = False

    def forward(self, x):
        x = self.attention(x)
        x = self.mlp(x)
        return x


class TinyModel(nn.Module):
    def __init__(self, num_layers=2, dim=512):
        super().__init__()
        self.layers = nn.ModuleDict(
            {str(i): TransformerBlock(dim) for i in range(num_layers)}
        )

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x


class TestApplyCompile(unittest.TestCase):
    def test_patched_once(self):
        """
        Call apply_compile multiple times, as happens in the PP setup.
        The global patch should only be applied once.
        """
        unused_model1 = TinyModel(num_layers=2, dim=128)
        unused_model2 = TinyModel(num_layers=2, dim=128)
        compile_config = CompileConfig(backend="eager")

        apply_compile(unused_model1, compile_config, ep_enabled=True)
        apply_compile(unused_model2, compile_config, ep_enabled=True)

        from torchtitan.models.moe import moe as moe_module

        # Generate sample inputs for _run_experts_grouped_mm
        num_experts = 8
        dim = 128
        hidden_dim = 256
        w1 = torch.randn(num_experts, hidden_dim, dim)
        w2 = torch.randn(num_experts, dim, hidden_dim)
        w3 = torch.randn(num_experts, hidden_dim, dim)
        num_tokens_per_expert = torch.tensor(
            [10, 8, 12, 9, 11, 7, 10, 13], dtype=torch.int32
        )
        total_tokens = num_tokens_per_expert.sum().item()
        x = torch.randn(total_tokens, dim)

        # Call the function; it should not error
        output = moe_module._run_experts_grouped_mm(
            w1, w2, w3, x, num_tokens_per_expert
        )

        print(f"Input shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        print(f"Num tokens per expert: {num_tokens_per_expert}")


if __name__ == "__main__":
    unittest.main()
45 changes: 26 additions & 19 deletions torchtitan/models/llama4/infra/parallelize.py
@@ -572,27 +572,34 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b

model.layers.register_module(layer_id, transformer_block)

moe_module._run_experts_grouped_mm = torch.compile(
moe_module._run_experts_grouped_mm,
backend=compile_config.backend,
fullgraph=True,
# Patch some globals only once (apply_compile is called multiple times for PP setup)
already_patched = (
"_run_experts_grouped_mm_dynamic"
in moe_module._run_experts_grouped_mm.__qualname__
)
if not already_patched:
Comment on lines +575 to +580
Contributor:

This sounds like a temp workaround. Will there be a "permanent" solution?

@xmfan (Member, Author) Dec 10, 2025:

Do you mean (1) the need to mark dynamic, or (2) the need to define a global patched method?

(1) afaik marking dynamic is the permanent solution to avoid an initial recompile
(2) patching was chosen to avoid writing this into the model code. Two alternatives:

  • we could mark dynamic the outputs of token dispatch when ep is enabled
  • we could have a global parallelize function for pp to put code that can only run once
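
A minimal sketch of the first alternative, assuming a hypothetical token_dispatch helper that returns the routed tokens (not the actual torchtitan API); the idea is to mark the token dimension dynamic at the dispatch site instead of wrapping the compiled function:

    # Hypothetical sketch: mark dynamism on the token-dispatch output when EP is
    # enabled, so the compiled expert computation never specializes on the
    # per-rank token count. `token_dispatch` is a placeholder name.
    routed_input, num_tokens_per_expert = token_dispatch(x)
    torch._dynamo.mark_dynamic(routed_input, 0)  # dim 0 = number of routed tokens
    out = moe_module._run_experts_grouped_mm(
        w1, w2, w3, routed_input, num_tokens_per_expert
    )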

moe_module._run_experts_grouped_mm = torch.compile(
moe_module._run_experts_grouped_mm,
backend=compile_config.backend,
fullgraph=True,
)

if ep_enabled:
compiled_fn = moe_module._run_experts_grouped_mm

def _run_experts_grouped_mm_dynamic(
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
# dynamic number of tokens in expert parallel
torch._dynamo.mark_dynamic(x, 0)
return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)

moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
if ep_enabled:
compiled_fn = moe_module._run_experts_grouped_mm

# keep function logic in sync with `already_patched` above
def _run_experts_grouped_mm_dynamic(
w1: torch.Tensor,
w2: torch.Tensor,
w3: torch.Tensor,
x: torch.Tensor,
num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
# dynamic number of tokens in expert parallel
torch._dynamo.mark_dynamic(x, 0)
Contributor:

Maybe not relevant to this PR: how are you going to deal with dynamism in the AOT approach?

@xmfan (Member, Author) Dec 11, 2025:

Depends on the finalized API; you could do explicit dynamic-shape annotations like here, and error during guard evaluation when unexpected dynamic shapes are encountered.
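
For illustration only, an explicit dynamic-shape annotation in an export/AOT-style flow could look roughly like the sketch below; moe_mlp, example_tokens, and the "x" argument name are placeholders, and the finalized API may differ:

    import torch
    from torch.export import Dim, export

    # Declare the token dimension as dynamic up front; unexpected dynamism then
    # surfaces as a guard/shape error instead of a silent recompile.
    num_tokens = Dim("num_tokens")
    exported = export(
        moe_mlp,                # placeholder nn.Module
        (example_tokens,),      # example input with a concrete token count
        dynamic_shapes={"x": {0: num_tokens}},
    )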

return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)

moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic

# NOTE: We don't compile for loop code path due to an issue with unbacked symints:
# https://github.com/pytorch/pytorch/issues/166460