Distributed layers #1270
Conversation
I kind of like this design. I like that it's all quite simple and easy to follow, and we have a lot of control over how to shard the model (as in ml-explore/mlx-examples#890). We could possibly find a way to reduce the code needed for adding a new custom linear-like layer, but the simplicity is nice; I wouldn't want to give that up.
I am marking this ready for review. The main things that are new since I started the branch:
- Exposing
- The sharding functions now also take a
# The multiplication with 1 forces a copy, perhaps change to
# something better when available.
Nit: remove comment?
Oops
# The multiplication with 1 forces a copy, perhaps change to
# something better when available.
Same.
if not isinstance(parameters[k], mx.array):
    continue

axis = max(parameters[k].ndim - 2, 0)
Should that just always be 0? Maybe it would work for conv in that case as well?
Well, it assumes linear layers (as in fully connected layers). It isn't 0 so that it can work with Switch layers and their variants. Perhaps, combined with your comment below, we could make this more general and support both.
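For illustration, a minimal sketch of how that axis choice plays out on a plain 2D linear weight versus a stacked (Switch-style) 3D weight; the shapes below are assumptions, not taken from this PR:

import mlx.core as mx

# Assumed shapes: a plain linear weight and a stacked expert weight.
linear_w = mx.zeros((8, 4))     # (out_dims, in_dims), ndim == 2
switch_w = mx.zeros((2, 8, 4))  # (num_experts, out_dims, in_dims), ndim == 3

for w in (linear_w, switch_w):
    axis = max(w.ndim - 2, 0)
    # axis is 0 for the 2D weight (its output rows) and 1 for the
    # stacked 3D weight (the per-expert output rows).
    print(w.shape, "-> shard along axis", axis)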
def shard_inplace(
    module: Module,
    sharding: str,
    *,
    groups: Union[int, list] = 1,
    group: Optional[mx.distributed.Group] = None,
):
    _check_sharding(sharding)
    shard_function = (
        _all_to_sharded if sharding == "all-to-sharded" else _sharded_to_all
    )
    module.update(shard_function(module.parameters(), groups=groups, group=group))
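A rough usage sketch for the function above, assuming a distributed group is already initialized; the layer sizes are made up, and the column-then-row pairing follows the usual pattern so only one all-reduce is needed per pair:

import mlx.core as mx
import mlx.nn as nn

group = mx.distributed.init()

# Column-shard the up projection and row-shard the down projection.
up_proj = nn.Linear(512, 2048)
down_proj = nn.Linear(2048, 512)
shard_inplace(up_proj, "all-to-sharded", group=group)
shard_inplace(down_proj, "sharded-to-all", group=group)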
Would it make sense to have this take a callable which returns a sharding based on a key? It would be more like nn.quantize
and capable of one-shot sharding a Module in place with a given policy.
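A hypothetical sketch of that idea; the predicate signature and the traversal are assumptions patterned after nn.quantize's class_predicate, not part of this PR:

import mlx.nn as nn

def shard_predicate(path: str, module: nn.Module):
    # Return a sharding (or None to skip) based on the key path.
    if not isinstance(module, nn.Linear):
        return None
    return "sharded-to-all" if path.endswith("o_proj") else "all-to-sharded"

def shard(model: nn.Module, predicate=shard_predicate, group=None):
    def _maybe_shard(path, m):
        sharding = predicate(path, m)
        if sharding is not None:
            shard_inplace(m, sharding, group=group)
    model.apply_to_modules(_maybe_shard)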
Possibly you are right. Ideally we can keep the same interface but provide this as an extra. Returning a tuple (or named tuple) with axis, groups (renamed) given path and weight would be a nice interface, I think.
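For concreteness, a sketch of what that return type could look like; ShardSpec and the field names are hypothetical:

from typing import NamedTuple, Optional, Union

import mlx.core as mx

class ShardSpec(NamedTuple):
    axis: int
    segments: Union[int, list] = 1  # the "groups" argument under a new name

def default_policy(path: str, weight: mx.array) -> Optional[ShardSpec]:
    # Skip non-matrix parameters, otherwise shard the output rows.
    if weight.ndim < 2:
        return None
    return ShardSpec(axis=max(weight.ndim - 2, 0))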
The purpose there is to allow uneven shardings? I think it would be good to think of a name that is more different from
Totally agree that we should name it something different. It isn't for uneven shardings in the sense that one node can take 70% of the computation; that isn't supported in this API. It is for weights that comprise several concatenated weights. In this case, for the sharded linear to be valid we need to split, shard, and concatenate. Otherwise one node would get all the queries and no keys, and so on.
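For example (the shapes and the 2-rank setup are assumptions), sharding a fused QKV weight by splitting first so that every rank keeps a slice of each of Q, K, and V:

import mlx.core as mx

n_ranks, rank = 2, 0         # assumed group size and rank
qkv = mx.zeros((3 * 8, 4))   # three concatenated (8, 4) projections

parts = mx.split(qkv, 3, axis=0)                              # split back into Q, K, V
shards = [mx.split(p, n_ranks, axis=0)[rank] for p in parts]  # this rank's rows of each
local_qkv = mx.concatenate(shards, axis=0)                    # re-concatenate locally
print(local_qkv.shape)  # (12, 4): half of Q, half of K, half of V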
Ah that makes sense now. Some suggestions on alternative names:
Maybe it makes sense to prefix
Adds linear layers that allow training and inference of a model sharded across several devices. The main things added are
Simply changing linear layers to the above results in a model that works out of the box with distributed inference and training.
I am starting it as a draft so that we can iterate a bit on the design. The negative aspect of the above design is that we have yet another linear layer to think about when implementing LoRA and friends, or weird new quantizations, for instance. Perhaps it would be better to build the above layers around an internal linear layer, so that model surgery that swaps linear layers would still work out of the box.
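As a rough sketch of that alternative (the class name, the bias handling, and the exact reduction are assumptions, not this PR's implementation), the sharded layer could wrap a plain nn.Linear so existing surgery that targets nn.Linear still has something to swap:

import mlx.core as mx
import mlx.nn as nn

class ShardedToAllLinearWrapper(nn.Module):
    # Hypothetical: each rank holds a slice of the input dimension in a
    # regular nn.Linear and the partial outputs are summed across ranks.
    def __init__(self, input_dims: int, output_dims: int, group=None):
        super().__init__()
        self.group = group or mx.distributed.init()
        n = self.group.size()
        # Bias handling is omitted for brevity.
        self.inner = nn.Linear(input_dims // n, output_dims, bias=False)

    def __call__(self, x):
        return mx.distributed.all_sum(self.inner(x), group=self.group)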