Skip to content

Commit

Permalink
update to pick up mlx v0.16.0 (#115)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkoski committed Jul 15, 2024
1 parent 96fe763 commit 597aaa5
Show file tree
Hide file tree
Showing 9 changed files with 174 additions and 12 deletions.
10 changes: 1 addition & 9 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,9 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG "v0.0.8")
GIT_TAG "v0.0.9")
FetchContent_MakeAvailable(mlx-c)

# TEMPORARY OVERRIDE -- 0.0.8 depends on v0.14.0 but we need v0.15.2 for iOS /
# float16 issues
FetchContent_Declare(
mlx
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
GIT_TAG v0.15.2)
FetchContent_MakeAvailable(mlx)

# swift-numerics
set(swift_numerics_patch git apply
${CMAKE_CURRENT_SOURCE_DIR}/cmake/swift-numerics.patch)
Expand Down
4 changes: 4 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ let package = Package(
"mlx/mlx/distributed/mpi",
"mlx/mlx/distributed/ops.cpp",
"mlx/mlx/distributed/primitives.cpp",

// the mlx-c side of distributed
"include/mlx/c/distributed.cpp",
"include/mlx/c/distributed_group.cpp",
],

cSettings: [
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
128 changes: 128 additions & 0 deletions Source/Cmlx/mlx-generated/hadamard.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
namespace mlx::core::metal {

// Returns the Metal shader source for the Hadamard transform kernels as a
// single C string.
//
// The string defines three templates:
//   - radix_func<R>: in-place butterfly (a + b, a - b) passes over a
//     thread-local float array.
//   - hadamard_n<T, N, max_radix, read_width>: kernel that stages a row of N
//     elements in threadgroup memory, applies radix_func in steps, and writes
//     the result multiplied by `scale`.
//   - hadamard_m<T, N, M, read_width>: kernel that loads M strided values per
//     element, calls hadamard_radix_m on them, and writes the result
//     multiplied by `scale`.
//
// NOTE(review): hadamard_radix_m is called by hadamard_m but is not defined in
// this string; presumably it is supplied by separately generated source that
// is combined with this preamble -- confirm against the caller.
// NOTE(review): this file lives under mlx-generated, so it is presumably
// produced by tools/update-mlx.sh and should not be edited by hand -- confirm.
//
// Everything between R"preamble( and )preamble" is runtime data: any edit
// there, including adding comments, changes the shader source.
const char* hadamard() {
return R"preamble(
using namespace metal;
template <short R>
METAL_FUNC void radix_func(thread float* x) {
constexpr short logR = __builtin_ctz(R);
short h = 1;
#pragma clang loop unroll(full)
for (short s = 0; s < logR; s++) {
#pragma clang loop unroll(full)
for (short i = 0; i < R / 2; i++) {
short k = i & (h - 1);
short j = ((i - k) << 1) + k;
float a = x[j];
float b = x[j + h];
x[j] = a + b;
x[j + h] = a - b;
}
h <<= 1;
}
}
template <typename T, int N, int max_radix, int read_width>
[[kernel]] void hadamard_n(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
constexpr short num_threads = N / max_radix;
constexpr short logN = __builtin_ctz(N);
constexpr short logR = __builtin_ctz(max_radix);
constexpr short num_steps = logN / logR;
constexpr short logFinal = logN % logR;
constexpr short final_radix = 1 << (logFinal);
int batch_idx = elem.x * N;
short i = elem.y;
threadgroup T buf[N];
#pragma clang loop unroll(full)
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
buf[index + r] = in[batch_idx + index + r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
float x[max_radix];
short h = 1;
#pragma clang loop unroll(full)
for (short s = 0; s < num_steps; s++) {
short k = i & (h - 1);
short j = ((i - k) << logR) + k;
#pragma clang loop unroll(full)
for (short r = 0; r < max_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<max_radix>(x);
#pragma clang loop unroll(full)
for (short r = 0; r < max_radix; r++) {
buf[j + h * r] = x[r];
}
h <<= logR;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
if (final_radix > 1) {
#pragma clang loop unroll(full)
for (int t = 0; t < max_radix / final_radix; t++) {
short index = i + t * num_threads;
short k = index & (h - 1);
short j = ((index - k) << logFinal) + k;
#pragma clang loop unroll(full)
for (short r = 0; r < final_radix; r++) {
x[r] = buf[j + h * r];
}
radix_func<final_radix>(x);
#pragma clang loop unroll(full)
for (short r = 0; r < final_radix; r++) {
buf[j + h * r] = x[r];
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
#pragma clang loop unroll(full)
for (short j = 0; j < max_radix / read_width; j++) {
short index = j * read_width * num_threads + i * read_width;
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
out[batch_idx + index + r] = buf[index + r] * scale;
}
}
}
template <typename T, int N, int M, int read_width>
[[kernel]] void hadamard_m(
const device T* in [[buffer(0)]],
device T* out [[buffer(1)]],
constant const float& scale,
uint3 elem [[thread_position_in_grid]],
uint3 grid [[threads_per_grid]]) {
int index = elem.x * grid.y + elem.y;
short i = index % (N / read_width);
int batch_idx = index / (N / read_width) * M * N;
float x[read_width][M];
#pragma clang loop unroll(full)
for (short c = 0; c < M; c++) {
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
x[r][c] = in[batch_idx + c * N + i * read_width + r];
}
}
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
hadamard_radix_m(x[r]);
}
#pragma clang loop unroll(full)
for (short c = 0; c < M; c++) {
#pragma clang loop unroll(full)
for (short r = 0; r < read_width; r++) {
out[batch_idx + c * N + i * read_width + r] = x[r][c] * scale;
}
}
}
)preamble";
}

} // namespace mlx::core::metal
1 change: 1 addition & 0 deletions Source/MLX/Documentation.docc/free-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,4 @@ operations as methods for convenience.

- ``diag(_:k:stream:)``
- ``diagonal(_:offset:axis1:axis2:stream:)``
- ``view(_:dtype:stream:)``
16 changes: 16 additions & 0 deletions Source/MLX/MLXArray+Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2658,4 +2658,20 @@ extension MLXArray {
MLXArray(mlx_var_all(ctx, keepDims, ddof.int32, stream.ctx))
}

/// Reinterpret this array's contents as a different element type.
///
/// If the input array's type and the output array's type are not the same
/// size, the output array will differ from the input along the last axis.
///
/// Note: a view does not imply that the input and output arrays share their
/// underlying data. It only guarantees that the binary representation of
/// each element (or group of elements) is the same.
///
/// - Parameters:
///   - dtype: type to change to
///   - stream: stream or device to evaluate on
/// - Returns: array with the new type
public func view(dtype: DType, stream: StreamOrDevice = .default) -> MLXArray {
    let viewed = mlx_view(ctx, dtype.cmlxDtype, stream.ctx)
    return MLXArray(viewed)
}
}
19 changes: 19 additions & 0 deletions Source/MLX/Ops+Array.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1715,3 +1715,22 @@ public func variance(
) -> MLXArray {
MLXArray(mlx_var_all(array.ctx, keepDims, ddof.int32, stream.ctx))
}

/// Reinterpret the contents of `array` as a different element type.
///
/// If the input array's type and the output array's type are not the same
/// size, the output array will differ from the input along the last axis.
///
/// Note: a view does not imply that the input and output arrays share their
/// underlying data. It only guarantees that the binary representation of
/// each element (or group of elements) is the same.
///
/// - Parameters:
///   - array: array to view
///   - dtype: type to change to
///   - stream: stream or device to evaluate on
/// - Returns: array with the new type
///
/// ### See Also
/// - ``MLXArray/view(dtype:stream:)``
public func view(_ array: MLXArray, dtype: DType, stream: StreamOrDevice = .default) -> MLXArray {
    let viewed = mlx_view(array.ctx, dtype.cmlxDtype, stream.ctx)
    return MLXArray(viewed)
}
4 changes: 3 additions & 1 deletion tools/update-mlx.sh
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ cmake ../Source/Cmlx/mlx -DMLX_METAL_JIT=ON -DMACOS_VERSION=14.0

# NOTE:
# until mlx supports overriding the METAL_VERSION you will need to edit
# Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION.
# Source/Cmlx/mlx/mlx/backend/metal/CMakeLists.txt and manually set the METAL_VERSION
# to "3.0"
#
# Also, the kernels in Plugins/PrepareMetalShaders/main.swift need to be kept in sync.

Expand All @@ -34,6 +35,7 @@ make \
fft \
gather \
gemm \
hadamard \
quantized \
reduce \
reduce_utils \
Expand Down

0 comments on commit 597aaa5

Please sign in to comment.