
Conversation

@xmfan (Member) commented Dec 10, 2025

Stacked PRs:


PP initialization calls apply_compile multiple times, once per PP stage, but apply_compile does some global patching. So I added an already_patched check to avoid patching the same method multiple times (sketch below).

If we patch multiple times, the second pass wraps _run_experts_grouped_mm_dynamic in another torch.compile(fullgraph=True), leading to the error in the issue below.

FIXES #2124
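
For context, a minimal sketch of the patching pattern and the guard, under stated assumptions: the already_patched check and the _run_experts_grouped_mm / _run_experts_grouped_mm_dynamic names come from the diff below, while the body of apply_compile and the wrapper's signature are simplified illustrations, not the actual torchtitan code.

    import torch

    def apply_compile(moe_module):
        # PP setup calls apply_compile once per stage, but the patch below mutates
        # a shared attribute, so it must only run once.
        already_patched = (
            "_run_experts_grouped_mm_dynamic"
            in moe_module._run_experts_grouped_mm.__qualname__
        )
        if not already_patched:
            # compile the original grouped-mm expert path once
            compiled = torch.compile(
                moe_module._run_experts_grouped_mm, fullgraph=True
            )

            def _run_experts_grouped_mm_dynamic(*args, **kwargs):
                # assume, for illustration, the token tensor is the first arg;
                # its dim 0 (tokens per step) varies, so mark it dynamic up front
                torch._dynamo.mark_dynamic(args[0], 0)
                return compiled(*args, **kwargs)

            # Without the guard, a second call would re-wrap this method in
            # torch.compile(fullgraph=True), hitting the error in #2124.
            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic

On a second call the qualname check sees _run_experts_grouped_mm_dynamic and the patching block is skipped.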

Comment on lines +575 to +580
# Patch some globals only once (apply_compile is called multiple times for PP setup)
already_patched = (
    "_run_experts_grouped_mm_dynamic"
    in moe_module._run_experts_grouped_mm.__qualname__
)
if not already_patched:
Contributor:
This sounds like a temp workaround. Will there be a "permanent" solution?

@xmfan (Member, Author) Dec 10, 2025:

do you mean (1) the need to mark dynamic, or (2) the need to define a globally patched method?

(1) afaik, marking dynamic is the permanent solution to avoid an initial recompile (see the sketch after this list)
(2) patching was chosen to avoid writing this into the model code. Two alternatives:

  • we could mark the outputs of token dispatch as dynamic when EP is enabled
  • we could have a global parallelize function for PP to hold code that can only run once
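
To illustrate point (1), a tiny standalone sketch (a toy function, not the torchtitan expert path) of how marking a dimension dynamic before the first call avoids the recompile that a changed token count would otherwise trigger:

    import torch

    @torch.compile(fullgraph=True)
    def run_tokens(x: torch.Tensor) -> torch.Tensor:
        # stand-in for the grouped-mm expert computation
        return x * 2

    x = torch.randn(16, 8)
    # dim 0 is the per-step token count; declare it dynamic before compiling
    torch._dynamo.mark_dynamic(x, 0)
    run_tokens(x)                   # compiles once with a symbolic dim 0
    run_tokens(torch.randn(64, 8))  # different token count, no recompile

Without the mark_dynamic call, the second invocation would recompile, since automatic dynamic shapes only kick in after the first shape change is observed.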

@tianyu-l (Contributor) left a comment:


we should add a test to the integration tests to guard against the repro in #2124

I'm a bit confused why we were not hitting this error in https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/models.py#L96

    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    # dynamic number of tokens in expert parallel
    torch._dynamo.mark_dynamic(x, 0)
Contributor:

maybe not relevant to this PR: how are you going to deal with dynamism in the AOT approach?

@xmfan (Member, Author) Dec 11, 2025:

depends on the finalized API; you could do explicit dynamic-shape annotations like the one here, and error during guard evaluation when unexpected dynamic shapes are encountered (see the sketch below)
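
For instance, a rough sketch of explicit dynamic-shape annotation in an ahead-of-time flow using torch.export; the Experts module and the num_tokens bounds are made-up illustrations, and this is not necessarily the finalized AOT API being discussed:

    import torch
    from torch.export import Dim, export

    class Experts(torch.nn.Module):  # hypothetical stand-in module
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * 2

    # declare the token dimension as dynamic, with an explicit allowed range
    num_tokens = Dim("num_tokens", min=2, max=8192)
    ep = export(
        Experts(),
        (torch.randn(16, 8),),
        dynamic_shapes={"x": {0: num_tokens}},
    )

    ep.module()(torch.randn(64, 8))    # ok: dim 0 falls within the declared range
    # ep.module()(torch.randn(16, 9))  # raises: dim 1 was exported as static

Shapes outside the declared ranges (or dynamism that was never annotated) fail the exported program's guard checks at call time rather than silently recompiling.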


Closes #2124: DeepSeekV3 model component is not compilable when using Interleaved1F1B pipeline parallelism
