Decoupled Momentum Optimization #771
base: main
Conversation
This looks interesting, but I don't actually know what this optimizer is. Can you give some background?
Edit: There is a description at the top. Am blind.
    compression_topk=cfg.optimizer.compression_topk,
    compression_chunk=cfg.optimizer.compression_chunk,
    weight_decay=cfg.optimizer.weight_decay,
    process_group=None,  # TODO: fix for hybrid sharding
This seems important? Hybrid is necessary for big models.
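For reference, a minimal sketch of what wiring in an explicit group might look like, assuming the `DeMo` constructor takes the keyword arguments shown in this diff; the replica-group construction (`dist.new_group`, `replica_ranks`) and the `lr`/`compression_decay` plumbing are assumptions, not something this PR provides:

```python
# Hedged sketch: restrict DeMo's compression collectives to the data-parallel
# replica group under hybrid sharding. `replica_ranks` and the group
# construction are assumptions; only the keyword names come from this diff.
import torch.distributed as dist

replica_group = dist.new_group(ranks=replica_ranks)  # ranks that hold full replicas

optimizer = DeMo(
    model.parameters(),
    lr=cfg.optimizer.learning_rate,
    compression_decay=cfg.optimizer.compression_decay,
    compression_topk=cfg.optimizer.compression_topk,
    compression_chunk=cfg.optimizer.compression_chunk,
    weight_decay=cfg.optimizer.weight_decay,
    process_group=replica_group,  # instead of None / the default world group
)
```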
### DeMo parameters

compression_decay: float = 0.999

compression_topk: int = 32
"""
Number of top-k values to transmit per chunk; if dynamic top-k is enabled, this is the initial value.
"""

compression_chunk: int = 64
"""
Size of the gradient chunks. Note that 2D gradients are chunked in 2D, so the top-k sparsity is squared compared to 1D.
"""
Prefix these with `demo_`?
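For intuition, a back-of-the-envelope reading of the "squared" note in the docstring above; the per-dimension interpretation is an assumption about the implementation, not something stated in this diff:

```python
# Rough arithmetic only (an assumption about what "squared" means here, not the
# PR's code): with the defaults, a 1D tensor keeps topk/chunk of each chunk's
# coefficients, while a 2D tensor chunked into chunk x chunk tiles keeps the
# square of that ratio.
compression_chunk = 64
compression_topk = 32

keep_1d = compression_topk / compression_chunk  # 32/64     -> 50% of a 1D chunk
keep_2d = keep_1d ** 2                          # 32²/64²   -> 25% of a 2D chunk
print(f"1D keep ratio: {keep_1d:.0%}, 2D keep ratio: {keep_2d:.0%}")
```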
disable_grad_sync: bool = False
I see this setting twice, once here, and once in `DDPGradSyncMode`?
@@ -647,6 +649,177 @@ def get_post_step_metrics(
        return metrics


class DeMo(torch.optim.SGD, Optimizer):
It seems like the organization would make more sense if this class and `demo_utils.py` were in their own file together, and then we use `__all__` to make this optimizer appear the same as the others.
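Something like the following layout, roughly; the module path and import targets here are assumptions, only `DeMo`, `demo_utils`, and `__all__` come from the discussion:

```python
# Hedged sketch of the suggested layout (file path and import locations are
# assumptions): keep the optimizer next to its helpers and export it via
# __all__ so it is surfaced the same way as the other optimizers.

# hypothetical new module, e.g. olmo/optim_demo.py
import torch

from .optim import Optimizer   # assumed location of the repo's Optimizer base class
from . import demo_utils       # the helper module mentioned above

__all__ = ["DeMo"]  # export only the optimizer itself


class DeMo(torch.optim.SGD, Optimizer):
    ...
```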
Oh, I see. You put a reference in the description 🙈. Paper says you pushed this to 1B/100B tokens. Can you go further? Experience says, things like this stop working if you go really big.
Cleaned-up version of https://github.com/bloc97/DeMo for integrating efficient distributed training a la Decoupled Momentum Optimization (https://arxiv.org/abs/2411.19870)