Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distributed layers #1270

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open

Distributed layers #1270

wants to merge 9 commits into from

Conversation

angeloskath
Copy link
Member

Adds linear layers that allow training and inference of a model sharded across several devices. The main things added are

  • float16/bfloat16 reductions for MPI
  • AllToShardedLinear and its quantized sibling
  • ShardedToAllLinear and its quantized sibling

simply changing linear layers to the above results in a model that works out of the box with distributed inference and training.

I am starting it as a draft so that we can iterate a bit on the design. The negative aspects of the above design are that we have yet another linear layer to think about when implementing LoRA and friends or weird new quantizations for instance. Perhaps it would be better to make the above layers with an internal linear layer so model surgery that swaps linear layers would still work out of the box.

@awni
Copy link
Member

awni commented Jul 17, 2024

I kind of like this design. I like that it's all quite simple and easy to follow and we have a lot of control over how to shard the model (as in ml-explore/mlx-examples#890). We could possibly find a way to reduce the code needed for adding a new custom linear-like layer.. but the simplicity is nice, I wouldn't want to give that up.

@angeloskath angeloskath force-pushed the distributed-layers branch 2 times, most recently from 061d214 to b32ce2c Compare August 29, 2024 08:20
@angeloskath angeloskath force-pushed the distributed-layers branch 2 times, most recently from ab26116 to 3d431c0 Compare September 6, 2024 18:03
@awni awni mentioned this pull request Sep 16, 2024
@angeloskath angeloskath force-pushed the distributed-layers branch 5 times, most recently from 2298954 to 1697581 Compare November 5, 2024 19:35
@awni awni force-pushed the distributed-layers branch 3 times, most recently from 31ba022 to 60e7e02 Compare January 18, 2025 14:06
@awni awni force-pushed the distributed-layers branch 2 times, most recently from 07b5bd5 to 794eb42 Compare February 6, 2025 15:36
@angeloskath angeloskath force-pushed the distributed-layers branch 3 times, most recently from 517eb95 to a323642 Compare March 4, 2025 21:32
@angeloskath
Copy link
Member Author

I am marking this ready for review. The main things that are new since I started the branch:

Exposing mx.contiguous. This ensures both that the array is contiguous and that it occupies at most x.size() * x.itemsize() + 16384 bytes. Mainly a contiguous slice is still going to be copied.

shard_linear convenience function and shard_inplace. The first one just creates the appropriate linear layer quantized or not. The second actually shards the parameters in place which allows us to shard any layer and apply the collective operations as we see fit. It is used for instance to shard the single stream transformer blocks in FLUX but only perform one communication (ml-explore/mlx-examples#1325).

The sharding functions now also take a groups argument. This assumes the linear layer is a fused one and splits it according to the groups argument (evenly or percentage wise). I think the argument name may need improving here.

@angeloskath angeloskath marked this pull request as ready for review March 6, 2025 23:29
Comment on lines +47 to +48
# The multiplication with 1 forces a copy, perhaps change to
# something better when available.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit remove comment?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops

Comment on lines +76 to +77
# The multiplication with 1 forces a copy, perhaps change to
# something better when available.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

if not isinstance(parameters[k], mx.array):
continue

axis = max(parameters[k].ndim - 2, 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should that just be always 0? Maybe would work for conv in that case as well?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well it assumes linear layers (as in fully connected layers). It isn't 0 so that it can work with Switch layers and their variants. Perhaps combining with your comment below we could make this more general and support both.

Comment on lines +104 to +115
def shard_inplace(
module: Module,
sharding: str,
*,
groups: Union[int, list] = 1,
group: Optional[mx.distributed.Group] = None,
):
_check_sharding(sharding)
shard_function = (
_all_to_sharded if sharding == "all-to-sharded" else _sharded_to_all
)
module.update(shard_function(module.parameters(), groups=groups, group=group))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to have the this take a callable which returns a sharding based on a key? It would be more like nn.quantize and capable of one-shot sharding a Module in place with a given policy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly you are right. Ideally we can keep the same interface but provide this as an extra. Returning a tuple (or named tuple) with axis, groups (renamed) given path and weight would a nice interface I think.

@awni
Copy link
Member

awni commented Mar 11, 2025

The sharding functions now also take a groups argument. This assumes the linear layer is a fused one and splits it according to the groups argument (evenly or percentage wise)

The purpose there is to allow uneven shardings? I think it would be good to think on a name that is more different from group.

@angeloskath
Copy link
Member Author

The purpose there is to allow uneven shardings?

Totally agree that we should name it something different. It isn't for uneven shadings in the sense that one node can take 70% of the computation. This isn't supported in this API. It is for weights that comprise several concatenated weights. In this case for the sharded linear to be valid we need to split, shard and concatenate. Otherwise one node will get all the queries and no keys and so on.

@awni
Copy link
Member

awni commented Mar 11, 2025

Otherwise one node will get all the queries and no keys and so on.

Ah that makes sense now. Some suggestions on alternative names:

  • shards
  • segments
  • sections
  • splits

Maybe it makes sense to prefix sub with one of those like sub_shards?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants