diff --git a/mlx/backend/cpu/primitives.cpp b/mlx/backend/cpu/primitives.cpp index b5d9d7ef35..126ebe2d54 100644 --- a/mlx/backend/cpu/primitives.cpp +++ b/mlx/backend/cpu/primitives.cpp @@ -141,8 +141,10 @@ void Concatenate::eval_cpu(const std::vector& inputs, array& out) { void Contiguous::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous)) { + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { out.copy_shared_buffer(in); } else { copy(in, out, CopyType::General); diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index df8638012f..20d8409ada 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -252,8 +252,10 @@ void Concatenate::eval_gpu(const std::vector& inputs, array& out) { void Contiguous::eval_gpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); auto& in = inputs[0]; - if (in.flags().row_contiguous || - (allow_col_major_ && in.flags().col_contiguous)) { + constexpr size_t extra_bytes = 16384; + if (in.buffer_size() <= out.nbytes() + extra_bytes && + (in.flags().row_contiguous || + (allow_col_major_ && in.flags().col_contiguous))) { move_or_copy(in, out); } else { copy_gpu(in, out, CopyType::General); diff --git a/mlx/ops.cpp b/mlx/ops.cpp index 4e147487d5..5a64a78521 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -993,6 +993,9 @@ array concatenate( throw std::invalid_argument( "[concatenate] No arrays provided for concatenation"); } + if (arrays.size() == 1) { + return arrays[0]; + } auto ax = normalize_axis_index(axis, arrays[0].ndim(), "[concatenate] "); diff --git a/python/mlx/distributed_run.py b/python/mlx/distributed_run.py index 5d6bc4383d..3391ba620d 100644 --- a/python/mlx/distributed_run.py +++ b/python/mlx/distributed_run.py @@ -761,6 +761,8 @@ def main(): "--cwd", help="Set the working directory on each node to the provided one" ) args, rest = parser.parse_known_args() + if rest[0] == "--": + rest.pop(0) if args.print_python: print(sys.executable) diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py index c1d89fed9f..26f77917fc 100644 --- a/python/mlx/nn/layers/__init__.py +++ b/python/mlx/nn/layers/__init__.py @@ -60,6 +60,12 @@ ConvTranspose2d, ConvTranspose3d, ) +from mlx.nn.layers.distributed import ( + AllToShardedLinear, + QuantizedAllToShardedLinear, + QuantizedShardedToAllLinear, + ShardedToAllLinear, +) from mlx.nn.layers.dropout import Dropout, Dropout2d, Dropout3d from mlx.nn.layers.embedding import Embedding from mlx.nn.layers.linear import Bilinear, Identity, Linear diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py new file mode 100644 index 0000000000..4e9a665c60 --- /dev/null +++ b/python/mlx/nn/layers/distributed.py @@ -0,0 +1,553 @@ +# Copyright © 2024 Apple Inc. + +import math +from functools import lru_cache +from typing import Optional, Union + +import mlx.core as mx +from mlx.nn.layers.base import Module +from mlx.nn.layers.linear import Linear +from mlx.nn.layers.quantized import QuantizedLinear + + +@lru_cache +def sum_gradients(group): + if group.size() == 1: + return lambda x: x + + @mx.custom_function + def f(x): + return x + + @f.vjp + def f(x, dx, _): + return mx.distributed.all_sum(dx, group=group) + + return f + + +def _split(weight, groups, axis): + if isinstance(groups, int) or isinstance(groups[0], int): + return mx.split(weight, groups, axis=axis) + + N = weight.shape[axis] + indices = [int(g * N) for g in groups] + return mx.split(weight, indices, axis=axis) + + +def _all_to_sharded( + parameters: dict, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, +): + group = group or mx.distributed.init() + N = group.size() + r = group.rank() + + # The multiplication with 1 forces a copy, perhaps change to + # something better when available. + for k in parameters: + if not isinstance(parameters[k], mx.array): + continue + + axis = max(parameters[k].ndim - 2, 0) + parameters[k] = mx.contiguous( + mx.concatenate( + [ + _split(part, N, axis)[r] + for part in _split(parameters[k], groups, axis) + ], + axis=axis, + ) + ) + + return parameters + + +def _sharded_to_all( + parameters: dict, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, +): + group = group or mx.distributed.init() + N = group.size() + r = group.rank() + + # The multiplication with 1 forces a copy, perhaps change to + # something better when available. + for k in parameters: + if not isinstance(parameters[k], mx.array): + continue + if k == "bias": + continue + + parameters[k] = mx.contiguous( + mx.concatenate( + [_split(part, N, -1)[r] for part in _split(parameters[k], groups, -1)], + axis=-1, + ) + ) + + return parameters + + +def _check_sharding(sharding): + if sharding not in ("all-to-sharded", "sharded-to-all"): + raise ValueError( + ( + f"Sharding type {sharding=} not supported, " + "choose one of 'all-to-sharded' or 'sharded-to-all'" + ) + ) + + +def shard_inplace( + module: Module, + sharding: str, + *, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, +): + _check_sharding(sharding) + shard_function = ( + _all_to_sharded if sharding == "all-to-sharded" else _sharded_to_all + ) + module.update(shard_function(module.parameters(), groups=groups, group=group)) + + +def shard_linear( + module: Module, + sharding: str, + *, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, +): + _check_sharding(sharding) + fns = { + ("all-to-sharded", True): AllToShardedLinear.from_linear, + ("all-to-sharded", False): QuantizedAllToShardedLinear.from_quantized_linear, + ("sharded-to-all", True): ShardedToAllLinear.from_linear, + ("sharded-to-all", False): QuantizedShardedToAllLinear.from_quantized_linear, + } + return fns[sharding, isinstance(module, Linear)](module, groups=groups, group=group) + + +class AllToShardedLinear(Module): + """Each member of the group applies part of the affine transformation such + that the result is sharded across the group. + + The gradients are automatically aggregated from each member of the group. + + Args: + input_dims (int): The dimensionality of the input features + output_dims (int): The dimensionality of the output features + bias (bool, optional): If set to ``False`` the the layer will not use a + bias. Default is ``True``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Initialize the parameters + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (output_dims % N) != 0: + raise ValueError( + f"Cannot shard the output of size {output_dims} across {N} devices." + ) + + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N, input_dims), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N,), + ) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + N = self.group.size() + out_dims *= N + return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" + + def __call__(self, x: mx.array) -> mx.array: + # Aggregate the gradients coming from each shard + if self.group.size() > 1: + x = sum_gradients(self.group)(x) + + # Compute the affine projection + if "bias" in self: + x = mx.addmm(self["bias"], x, self["weight"].T) + else: + x = x @ self["weight"].T + return x + + @classmethod + def from_linear( + cls, + linear_layer: Module, + *, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = linear_layer.weight.shape + + sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) + sl.update( + _all_to_sharded(linear_layer.parameters(), groups=groups, group=group) + ) + + return sl + + +class ShardedToAllLinear(Module): + """Each member of the group applies part of the affine transformation and + then aggregates the results. + + All nodes will have the same exact result after this layer. + + :class:`ShardedToAllLinear` provides a classmethod :meth:`from_linear` to + convert linear layers to sharded :obj:`ShardedToAllLinear` layers. + + Args: + input_dims (int): The dimensionality of the input features + output_dims (int): The dimensionality of the output features + bias (bool, optional): If set to ``False`` the the layer will not use a + bias. Default is ``True``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Initialize the parameters + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (input_dims % N) != 0: + raise ValueError( + f"The input of size {input_dims} cannot be sharded across {N} devices." + ) + + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims // N), + ) + if bias: + self.bias = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims,), + ) + + def _extra_repr(self) -> str: + N = self.group.size() + out_dims, in_dims = self.weight.shape + in_dims *= N + return f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}" + + def __call__(self, x: mx.array) -> mx.array: + if self.group.size() > 1: + # Perform the local projection and aggregate the results + x = x @ self["weight"].T + x = mx.distributed.all_sum(x, group=self.group) + + # Add the bias if we have one + if "bias" in self: + x = x + self["bias"] + else: + # Normal linear layer as we are not in a distributed setting. + if "bias" in self: + x = mx.addmm(self["bias"], x, self["weight"].T) + else: + x = x @ self["weight"].T + return x + + @classmethod + def from_linear( + cls, + linear_layer: Module, + *, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = linear_layer.weight.shape + + sl = cls(input_dims, output_dims, hasattr(linear_layer, "bias"), group) + sl.update( + _sharded_to_all(linear_layer.parameters(), groups=groups, group=group) + ) + + return sl + + +class QuantizedAllToShardedLinear(Module): + """Each member of the group applies part of the affine transformation with + a quantized matrix such that the result is sharded across the group. + + It is the quantized equivalent of :class:`mlx.nn.AllToShardedLinear`. + Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and + will not be included in any gradient computation. + + Args: + input_dims (int): The dimensionality of the input features. + output_dims (int): The dimensionality of the output features. + bias (bool, optional): If set to ``False`` then the layer will not use + a bias. Default: ``True``. + group_size (int, optional): The group size to use for the quantized + weight. See :func:`~mlx.core.quantize`. Default: ``64``. + bits (int, optional): The bit width to use for the quantized weight. + See :func:`~mlx.core.quantize`. Default: ``4``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Quantization config + self.group_size = group_size + self.bits = bits + + # Initialize the quantized weight + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (output_dims % N) != 0: + raise ValueError( + f"Cannot shard the output of size {output_dims} across {N} devices." + ) + + weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims // N, input_dims), + ) + self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + + # And bias if needed + if bias: + self.bias = mx.zeros((output_dims // N,)) + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + in_dims *= 32 // self.bits + out_dims *= self.group.size() + return ( + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " + f"group_size={self.group_size}, bits={self.bits}" + ) + + def __call__(self, x: mx.array) -> mx.array: + # Aggregate the gradients coming from each shard + if self.group.size() > 1: + x = sum_gradients(self.group)(x) + + x = mx.quantized_matmul( + x, + self["weight"], + scales=self["scales"], + biases=self["biases"], + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + if "bias" in self: + x = x + self["bias"] + return x + + @classmethod + def from_quantized_linear( + cls, + quantized_linear_layer: Module, + *, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = quantized_linear_layer.weight.shape + input_dims *= 32 // quantized_linear_layer.bits + + sl = cls( + input_dims, + output_dims, + hasattr(quantized_linear_layer, "bias"), + group_size=quantized_linear_layer.group_size, + bits=quantized_linear_layer.bits, + group=group, + ) + sl.update( + _all_to_sharded( + quantized_linear_layer.parameters(), groups=groups, group=group + ) + ) + + return sl + + +class QuantizedShardedToAllLinear(Module): + """Each member of the group applies part of the affine transformation using + the quantized matrix and then aggregates the results. + + All nodes will have the same exact result after this layer. + + It is the quantized equivalent of :class:`mlx.nn.ShardedToAllLinear`. + Similar to :class:`mlx.nn.QuantizedLinear` its parameters are frozen and + will not be included in any gradient computation. + + Args: + input_dims (int): The dimensionality of the input features. + output_dims (int): The dimensionality of the output features. + bias (bool, optional): If set to ``False`` then the layer will not use + a bias. Default: ``True``. + group_size (int, optional): The group size to use for the quantized + weight. See :func:`~mlx.core.quantize`. Default: ``64``. + bits (int, optional): The bit width to use for the quantized weight. + See :func:`~mlx.core.quantize`. Default: ``4``. + group (mx.distributed.Group, optional): The sharding will happen across + this group. If not set then the global group is used. Default is + ``None``. + """ + + def __init__( + self, + input_dims: int, + output_dims: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + group: Optional[mx.distributed.Group] = None, + ): + super().__init__() + + # Quantization config + self.group_size = group_size + self.bits = bits + + # Initialize the quantized weight + scale = math.sqrt(1.0 / input_dims) + self.group = group or mx.distributed.init() + N = self.group.size() + + if (input_dims % N) != 0: + raise ValueError( + f"The input of size {input_dims} cannot be sharded across {N} devices." + ) + + weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(output_dims, input_dims // N), + ) + self.weight, self.scales, self.biases = mx.quantize(weight, group_size, bits) + + # And bias if needed + if bias: + self.bias = mx.zeros((output_dims,)) + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + def _extra_repr(self) -> str: + out_dims, in_dims = self.weight.shape + in_dims *= (32 // self.bits) * self.group.size() + return ( + f"input_dims={in_dims}, output_dims={out_dims}, bias={'bias' in self}, " + f"group_size={self.group_size}, bits={self.bits}" + ) + + def __call__(self, x: mx.array) -> mx.array: + x = mx.quantized_matmul( + x, + self["weight"], + scales=self["scales"], + biases=self["biases"], + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + if self.group.size() > 1: + x = mx.distributed.all_sum(x, group=self.group) + if "bias" in self: + x = x + self["bias"] + return x + + @classmethod + def from_quantized_linear( + cls, + quantized_linear_layer: Module, + *, + groups: Union[int, list] = 1, + group: Optional[mx.distributed.Group] = None, + ): + group = group or mx.distributed.init() + output_dims, input_dims = quantized_linear_layer.weight.shape + input_dims *= 32 // quantized_linear_layer.bits + + sl = cls( + input_dims, + output_dims, + hasattr(quantized_linear_layer, "bias"), + group_size=quantized_linear_layer.group_size, + bits=quantized_linear_layer.bits, + group=group, + ) + sl.update( + _sharded_to_all( + quantized_linear_layer.parameters(), groups=groups, group=group + ) + ) + + return sl diff --git a/python/src/ops.cpp b/python/src/ops.cpp index 1577cae185..6de580d1b9 100644 --- a/python/src/ops.cpp +++ b/python/src/ops.cpp @@ -5124,4 +5124,23 @@ void init_ops(nb::module_& m) { [0, 1, 0], [0, 1, 0]], dtype=float32) )pbdoc"); + m.def( + "contiguous", + &mx::contiguous, + nb::arg(), + "allow_col_major"_a = false, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def contiguous(a: array, /, allow_col_major: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Force an array to be row contiguous. Copy if necessary. + + Args: + a (array): The input to make contiguous + allow_col_major (bool): Consider column major as contiguous and don't copy + + Returns: + array: The row or col contiguous output. + )pbdoc"); }