[moe] brings batch/sequence-wise load balance loss #2061
Conversation
…d seq-wise aux loss for load balance
torchtitan/train.py
Outdated
    job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
)

self.loss_fn = functools.partial(
We can add a condition here to decide whether to wrap the loss for MoE. All models in torchtitan currently return a single output, so this is fine for now.
If we subsume this MoE loss wrapper into build_loss_fn, we can avoid adding the logic here.
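One way the suggested gating could look. This is a hypothetical sketch, not torchtitan's actual code: the function name build_wrapped_loss and the is_moe flag are assumptions; the only idea taken from the thread is that MoE models return (output, aux_loss) while dense models return a single output, so the wrapper can be chosen once at build time.

```python
# Hypothetical sketch: choose at build time whether the loss needs to
# unpack an MoE (output, aux_loss) pair, instead of branching per step.
def build_wrapped_loss(base_loss_fn, is_moe):
    def moe_loss(pred_and_aux, labels):
        pred, aux_loss = pred_and_aux          # MoE models return (output, aux_loss)
        return base_loss_fn(pred, labels) + aux_loss
    return moe_loss if is_moe else base_loss_fn
```

The dense path is untouched: when is_moe is False the original loss function is returned unchanged.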
wwwjn
left a comment
Thank you! @shuhuayu is working on a more formal review, and I have some housekeeping comments.
torchtitan/config/job_config.py
Outdated
@dataclass
class ExtraLosses:
This section is specifically for the MoE load-balancing loss for now. Do you foresee any other loss-related params being used here? If not, let's make the name more descriptive and specific.
Follow-up here: should we merge these configs into the Model dataclass?
shuhuayu
left a comment
Thanks a lot for the PR @rakkit! Made some comments here.
torchtitan/train.py
Outdated
    job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
)

self.loss_fn = functools.partial(
If we subsume this MoE loss wrapper into build_loss_fn, we can avoid adding the logic here.
Thanks a lot for the feedback, @wwwjn @shuhuayu (sorry for the late update)! Summary of new changes:
Be aware that PP with the aux loss still does not work.
torchtitan/models/moe/moe.py
Outdated
        self.load_balance_loss_weight,
    )
else:
    load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype)
As far as I can see, out is not defined in this scope yet.
Fixed, thanks :)
@staticmethod
def sequence_wise_aux_loss(
    scores: torch.Tensor,
    indices: torch.Tensor,
This will use the biased topk(scores + expert_bias) instead of the unbiased topk(scores) from DSv3 Eq. 18.
Nope, that's top_scores.
Ah yeah, scores is the raw sigmoid output. But isn't indices (= selected_experts_indices) derived as topk(scores + expert_bias)?
Hmm, good question. Need to think about this.
I think you might be right; in Eq. 18 the topk doesn't have the bias.
Thanks. I fixed this and reran the two aux-loss types and the no-aux-loss case in the PR description.
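For reference, the quantity under discussion is the DeepSeek-V3 sequence-wise auxiliary balance loss (Eqs. 17-20). Below is a minimal illustrative sketch in plain Python, not torchtitan's implementation: it assumes post-sigmoid affinity scores and, per the thread's conclusion, uses the unbiased top-k over raw scores (no expert_bias term) when counting expert loads.

```python
def sequence_wise_aux_loss(scores, top_k, alpha):
    """Illustrative DSv3-style sequence-wise balance loss for one sequence.

    scores: T x N list of per-token expert affinities (post-sigmoid),
    top_k: experts selected per token, alpha: loss weight.
    Returns alpha * sum_i f_i * P_i (Eq. 17).
    """
    T = len(scores)
    N = len(scores[0])
    f = [0.0] * N  # per-expert load fraction (Eq. 18)
    P = [0.0] * N  # per-expert mean normalized affinity (Eqs. 19-20)
    for row in scores:
        # Unbiased top-k on raw scores -- no expert_bias here (Eq. 18).
        topk = sorted(range(N), key=lambda i: row[i], reverse=True)[:top_k]
        denom = sum(row) or 1.0  # normalize affinities across experts
        for i in topk:
            f[i] += N / (top_k * T)
        for i in range(N):
            P[i] += row[i] / denom / T
    return alpha * sum(fi * pi for fi, pi in zip(f, P))
```

Under perfectly uniform routing each f_i is 1 and each P_i is 1/N, so the loss reduces to alpha; imbalanced routing pushes it above that floor, which is what the regularizer penalizes.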
This is a draft PR for:
For now, it only applies to the DeepSeek model, but I can add it to all other MoE models at the end.
(Also, we don't log the aux loss, but I can add an optimizer hook to do this if you want.)
The main concern is that the aux loss does not work well with PP. From what I have tested, it works only with 1F1B; it is broken for ZBV and interleaved 1F1B.
To test it:

sequence-wise (default):

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --training.extra_losses.load_balance_loss_weight=0.001

batch-wise (need to pick this in ModelArgs):

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --training.extra_losses.load_balance_loss_weight=0.001

turned off:

CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh