Fix apply_compile called multiple times in PP initialization #2135
The PR adds a new unit test (an 81-line file) that builds two tiny models, calls `apply_compile` on each of them, as pipeline parallelism does during initialization, and then checks that the patched `_run_experts_grouped_mm` still runs:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
import torch.nn as nn

from torchtitan.config.job_config import Compile as CompileConfig
from torchtitan.models.llama4.infra.parallelize import apply_compile


class TransformerBlock(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        self.attention = nn.Linear(dim, dim, bias=False)
        self.mlp = nn.Linear(dim, dim, bias=False)
        self.moe_enabled = False

    def forward(self, x):
        x = self.attention(x)
        x = self.mlp(x)
        return x


class TinyModel(nn.Module):
    def __init__(self, num_layers=2, dim=512):
        super().__init__()
        self.layers = nn.ModuleDict(
            {str(i): TransformerBlock(dim) for i in range(num_layers)}
        )

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x


class TestApplyCompile(unittest.TestCase):
    def test_patched_once(self):
        """
        Calls apply_compile multiple times, as in the case with PP,
        but patches should only happen once.
        """
        unused_model1 = TinyModel(num_layers=2, dim=128)
        unused_model2 = TinyModel(num_layers=2, dim=128)
        compile_config = CompileConfig(backend="eager")

        apply_compile(unused_model1, compile_config, ep_enabled=True)
        apply_compile(unused_model2, compile_config, ep_enabled=True)

        from torchtitan.models.moe import moe as moe_module

        # Generate sample inputs for _run_experts_grouped_mm
        num_experts = 8
        dim = 128
        hidden_dim = 256
        w1 = torch.randn(num_experts, hidden_dim, dim)
        w2 = torch.randn(num_experts, dim, hidden_dim)
        w3 = torch.randn(num_experts, hidden_dim, dim)
        num_tokens_per_expert = torch.tensor(
            [10, 8, 12, 9, 11, 7, 10, 13], dtype=torch.int32
        )
        total_tokens = num_tokens_per_expert.sum().item()
        x = torch.randn(total_tokens, dim)

        # Call the function, should not error
        output = moe_module._run_experts_grouped_mm(
            w1, w2, w3, x, num_tokens_per_expert
        )

        print(f"Input shape: {x.shape}")
        print(f"Output shape: {output.shape}")
        print(f"Num tokens per expert: {num_tokens_per_expert}")


if __name__ == "__main__":
    unittest.main()
```
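For readers unfamiliar with the patched function, the sketch below is a loop-based reference of what a function with `_run_experts_grouped_mm`'s signature and the shapes used in the test plausibly computes; the SwiGLU-style gating and the helper name are assumptions, and torchtitan's actual implementation (a fused grouped matmul) may differ.

```python
import torch
import torch.nn.functional as F


def run_experts_reference(w1, w2, w3, x, num_tokens_per_expert):
    # Per-expert loop equivalent of a grouped matmul: each expert processes its
    # contiguous slice of tokens. Shapes follow the test above:
    #   w1, w3: (num_experts, hidden_dim, dim), w2: (num_experts, dim, hidden_dim)
    #   x: (total_tokens, dim)  ->  output: (total_tokens, dim)
    outputs, start = [], 0
    for e, n in enumerate(num_tokens_per_expert.tolist()):
        xe = x[start : start + n]
        h = F.silu(xe @ w1[e].T) * (xe @ w3[e].T)  # SwiGLU-style gating (assumed)
        outputs.append(h @ w2[e].T)
        start += n
    return torch.cat(outputs)
```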
The fix itself lands in `apply_compile` (imported in the test from `torchtitan.models.llama4.infra.parallelize`), hunk `@@ -572,27 +572,34 @@`:

```diff
         model.layers.register_module(layer_id, transformer_block)

-    moe_module._run_experts_grouped_mm = torch.compile(
-        moe_module._run_experts_grouped_mm,
-        backend=compile_config.backend,
-        fullgraph=True,
-    )
+    # Patch some globals only once (apply_compile is called multiple times for PP setup)
+    already_patched = (
+        "_run_experts_grouped_mm_dynamic"
+        in moe_module._run_experts_grouped_mm.__qualname__
+    )
+    if not already_patched:
+        moe_module._run_experts_grouped_mm = torch.compile(
+            moe_module._run_experts_grouped_mm,
+            backend=compile_config.backend,
+            fullgraph=True,
+        )

-    if ep_enabled:
-        compiled_fn = moe_module._run_experts_grouped_mm
-
-        def _run_experts_grouped_mm_dynamic(
-            w1: torch.Tensor,
-            w2: torch.Tensor,
-            w3: torch.Tensor,
-            x: torch.Tensor,
-            num_tokens_per_expert: torch.Tensor,
-        ) -> torch.Tensor:
-            # dynamic number of tokens in expert parallel
-            torch._dynamo.mark_dynamic(x, 0)
-            return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
-
-        moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
+        if ep_enabled:
+            compiled_fn = moe_module._run_experts_grouped_mm
+
+            # keep function logic in sync with `already_patched` above
+            def _run_experts_grouped_mm_dynamic(
+                w1: torch.Tensor,
+                w2: torch.Tensor,
+                w3: torch.Tensor,
+                x: torch.Tensor,
+                num_tokens_per_expert: torch.Tensor,
+            ) -> torch.Tensor:
+                # dynamic number of tokens in expert parallel
+                torch._dynamo.mark_dynamic(x, 0)
+                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic

     # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
     # https://github.com/pytorch/pytorch/issues/166460
```

Review thread attached to the `torch._dynamo.mark_dynamic(x, 0)` line:

Contributor: Maybe not relevant to this PR: how are you going to deal with dynamism in the AOT approach?

Author: Depends on the finalized API. You could do explicit dynamic-shape annotations like here, and error in guard evaluation when unexpected dynamic shapes are encountered.
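The guard pattern itself can be exercised standalone. The sketch below uses illustrative names (not torchtitan code) to show why checking `__qualname__` makes a second patch attempt a no-op:

```python
import torch


def run_experts(w, x):
    return x @ w


def patch_once(backend="eager"):
    """Wrap run_experts in torch.compile plus a mark_dynamic shim, but only once."""
    global run_experts
    if "_run_experts_dynamic" in run_experts.__qualname__:
        return  # already patched; later calls (e.g. one per PP model part) are no-ops
    compiled = torch.compile(run_experts, backend=backend, fullgraph=True)

    def _run_experts_dynamic(w, x):
        torch._dynamo.mark_dynamic(x, 0)  # the token dimension varies between calls
        return compiled(w, x)

    run_experts = _run_experts_dynamic


patch_once()
patch_once()  # no-op: the wrapper's __qualname__ already contains "_run_experts_dynamic"
print(run_experts(torch.randn(4, 4), torch.randn(3, 4)).shape)  # torch.Size([3, 4])
```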
Review thread on the patch-once approach:

Reviewer: This sounds like a temporary workaround. Will there be a "permanent" solution?

Author: Do you mean (1) the need to mark dynamic, or (2) the need to define a global patched method? For (1), AFAIK marking dynamic is the permanent solution to avoid an initial recompile. For (2), patching was chosen to avoid writing this into the model code. Two alternatives:
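On point (1), a minimal sketch (toy function, not torchtitan code) of how `torch._dynamo.mark_dynamic` avoids the recompile that a changing token count would otherwise trigger:

```python
import torch


@torch.compile(backend="eager", fullgraph=True)
def double(x):
    return x * 2


x = torch.randn(8, 4)
torch._dynamo.mark_dynamic(x, 0)  # declare dim 0 (the token count) dynamic up front
double(x)                   # compiles once with a symbolic size for dim 0
double(torch.randn(16, 4))  # different token count, same compiled graph
double(torch.randn(3, 4))   # still no recompile
```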