Fix apply_compile called multiple times in PP initialization #2135
base: main
Conversation
stack-info: PR: #2135, branch: xmfan/stack/8
# Patch some globals only once (apply_compile is called multiple times for PP setup)
already_patched = (
    "_run_experts_grouped_mm_dynamic"
    in moe_module._run_experts_grouped_mm.__qualname__
)
if not already_patched:
This sounds like a temporary workaround. Will there be a "permanent" solution?
Do you mean (1) the need to mark dynamic, or (2) the need to define a global patched method?
(1) AFAIK marking dynamic is the permanent solution to avoid an initial recompile.
(2) Patching was chosen to avoid writing this into the model code. Two alternatives:
- we could mark the outputs of token dispatch as dynamic when EP is enabled
- we could have a global parallelize function for PP to hold code that may only run once (a minimal sketch of this follows below)
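To make alternative (2) concrete, here is a minimal sketch under assumed names (parallelize_pp_globals and _pp_globals_patched are hypothetical, not torchtitan APIs): once-only global setup is funneled through a single hook guarded by a process-wide flag, so per-stage apply_compile calls stay safe.

```python
# Minimal sketch of alternative (2); all names are hypothetical illustrations.
_pp_globals_patched = False  # process-wide flag, flipped on the first call


def parallelize_pp_globals(moe_module) -> None:
    """Run global, once-only setup (e.g. patching) exactly once per process."""
    global _pp_globals_patched
    if _pp_globals_patched:
        return
    # ... patch moe_module here, exactly once ...
    _pp_globals_patched = True


def apply_compile(stage_model, moe_module) -> None:
    # Safe to call once per PP stage: only the global setup is guarded.
    parallelize_pp_globals(moe_module)
    # ... per-stage torch.compile of stage_model as before ...
```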
tianyu-l left a comment
We should add a case to the integration tests to guard against the repro in #2124.
I'm a bit confused why we weren't hitting this error in https://github.com/pytorch/torchtitan/blob/main/tests/integration_tests/models.py#L96
    num_tokens_per_expert: torch.Tensor,
) -> torch.Tensor:
    # dynamic number of tokens in expert parallel
    torch._dynamo.mark_dynamic(x, 0)
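For context on what marking dim 0 dynamic buys here, a minimal, self-contained illustration (toy function, not the torchtitan code): annotating the token dimension before the first compile means varying token counts reuse the same graph instead of triggering a recompile.

```python
import torch


@torch.compile(fullgraph=True)
def toy_experts(x: torch.Tensor) -> torch.Tensor:
    # stand-in for the grouped-mm expert computation
    return torch.nn.functional.silu(x) * 2.0


x = torch.randn(128, 64)
torch._dynamo.mark_dynamic(x, 0)   # number of routed tokens varies per step
toy_experts(x)                      # compiles once with a symbolic dim 0
toy_experts(torch.randn(300, 64))   # reuses the same graph, no recompile
```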
Maybe not relevant to this PR: how are you going to deal with dynamism in the AOT approach?
It depends on the finalized API. You could do explicit dynamic shape annotations like here, and error during guard evaluation when unexpected dynamic shapes are encountered.
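As a hedged sketch of what that could look like in an export-style AOT flow (illustrative only, not the finalized API being discussed): dynamic dims are annotated explicitly, and inputs that violate the recorded guards fail at call time.

```python
import torch
from torch.export import Dim, export


class ToyExperts(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.silu(x)


num_tokens = Dim("num_tokens")  # declare dim 0 as dynamic up front
ep = export(
    ToyExperts(),
    (torch.randn(128, 64),),
    dynamic_shapes={"x": {0: num_tokens}},
)

ep.module()(torch.randn(300, 64))      # ok: matches the dynamic annotation
try:
    ep.module()(torch.randn(300, 32))  # dim 1 was specialized to 64
except Exception as err:
    print("guard check failed:", err)
```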
Stacked PRs:
PP initialization calls apply_compile multiple times, once per PP stage. But apply_compile does some global patching, so I add already_patched to avoid patching the same method multiple times. If we patch multiple times, the second time wraps _run_experts_grouped_mm_dynamic in a torch.compile(fullgraph=True), leading to the error in the issue below.

FIXES #2124
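To make the failure mode and the guard easier to follow, here is a simplified, self-contained sketch; the SimpleNamespace and function bodies are stand-ins for the torchtitan internals, not the exact code. Without the guard, the second apply_compile call would wrap the already-compiled function in torch.compile again; the __qualname__ check makes the patch idempotent.

```python
import types

import torch

# Stand-in for the MoE module whose global method gets patched.
moe_module = types.SimpleNamespace()


def _run_experts_grouped_mm(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0


moe_module._run_experts_grouped_mm = _run_experts_grouped_mm


def apply_compile() -> None:
    # After the first patch, the attribute points at the compiled dynamic
    # variant, whose __qualname__ contains "_run_experts_grouped_mm_dynamic";
    # in that case, skip re-patching.
    already_patched = (
        "_run_experts_grouped_mm_dynamic"
        in moe_module._run_experts_grouped_mm.__qualname__
    )
    if already_patched:
        return

    orig = moe_module._run_experts_grouped_mm

    def _run_experts_grouped_mm_dynamic(x: torch.Tensor) -> torch.Tensor:
        torch._dynamo.mark_dynamic(x, 0)  # token count varies per step
        return orig(x)

    moe_module._run_experts_grouped_mm = torch.compile(
        _run_experts_grouped_mm_dynamic, fullgraph=True
    )


apply_compile()                        # first PP stage: patch applied
patched = moe_module._run_experts_grouped_mm
apply_compile()                        # second PP stage: guard skips re-wrapping
assert moe_module._run_experts_grouped_mm is patched
```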