-
Notifications
You must be signed in to change notification settings - Fork 632
[Autoparallel] Add local_map variant of DSv3 and 2D mesh AP #2129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
stack-info: PR: #2129, branch: xmfan/stack/7
stack-info: PR: #2129, branch: xmfan/stack/7
stack-info: PR: #2129, branch: xmfan/stack/7
stack-info: PR: #2129, branch: xmfan/stack/7
stack-info: PR: #2129, branch: xmfan/stack/7
stack-info: PR: #2129, branch: xmfan/stack/7
| for layer in model.layers.values(): | ||
| if layer.moe_enabled: | ||
| layer.moe.mesh = world_mesh | ||
| layer.moe.axis_name = "dp_shard_in_ep" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems okay for now, but when we enable TP, this should change so just add a comment that modify this when enabling TP.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I added an assert to check for dp2ep specifically, because i don't think we're handling the mesh setup well...
| parallel_dims.tp_enabled | ||
| and not job_config.parallelism.disable_loss_parallel | ||
| ) | ||
| assert not loss_parallel_enabled |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Instead of failing here why not disable loss parallel and give a warning?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warnings are easy to miss, you won't know loss parallel is disable if you don't spot it before titan starts dumping its logs
sanketpurandare
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just some minor comments that can be addressed optionally
stack-info: PR: #2129, branch: xmfan/stack/7
Stacked PRs:
[Autoparallel] Add local_map variant of DSv3 and 2D mesh AP
Currently, the AP experiment monkey patches Titan's main DSv3 implementation. But this is prone to breakage from both model definition changes in titan and from HOP/partitioner related changes in core. When these breaks happen, people are usually blocked until I find the root cause.
I'm going on PTO for the rest of the year, so I'm adding an integration to AP's DSv3 model in an attempt to make the development more stable for the upcoming PP integration.
Test: https://gist.github.com/xmfan/db15fda1e1bc1df7cd523005fe0baf33