Skip to content

Commit

Permalink
V0.15.2 mlx/mlx-c update (#101)
Browse files Browse the repository at this point in the history
v0.15.2 update

- this is mlx-c aligned with v0.14.0 and mlx v0.15.2

- fix x86 release builds -- conditional compile for neon code
- add bitwise ops from previous releases (now present in mlx-c)
- big change is JIT metal shaders -- the build should be faster and smaller
  • Loading branch information
davidkoski authored Jul 1, 2024
1 parent c11212b commit 084597e
Show file tree
Hide file tree
Showing 254 changed files with 99,426 additions and 1,317 deletions.
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,17 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG "v0.0.7")
GIT_TAG "v0.0.8")
FetchContent_MakeAvailable(mlx-c)

# TEMPORARY OVERRIDE -- 0.0.8 depends on v0.14.0 but we need v0.15.2 for iOS /
# float16 issues
FetchContent_Declare(
mlx
GIT_REPOSITORY "https://github.com/ml-explore/mlx.git"
GIT_TAG v0.15.2)
FetchContent_MakeAvailable(mlx)

# swift-numerics
set(swift_numerics_patch git apply
${CMAKE_CURRENT_SOURCE_DIR}/cmake/swift-numerics.patch)
Expand Down
27 changes: 26 additions & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,13 @@ let package = Package(
// vendored library, do not include driver
"gguf-tools/gguf-tools.c",

// vendored library
"fmt/test",
"fmt/doc",
"fmt/support",
"fmt/src/os.cc",
"fmt/src/fmt.cc",

// these are selected conditionally
// via mlx-conditional/compiled_conditional.cpp
"mlx/mlx/backend/common/compiled_nocpu.cpp",
Expand All @@ -77,10 +84,27 @@ let package = Package(

// opt-out of these backends (using metal)
"mlx/mlx/backend/no_metal",
"mlx/mlx/backend/accelerate",
"mlx/mlx/backend/no_cpu",

"mlx/mlx/backend/common/default_primitives.cpp",

// this uses neon code and will not build on x86 (e.g. via Release).
// see mlx-conditional/accelerate-softmax.cpp
"mlx/mlx/backend/accelerate/softmax.cpp",

// build variants (we are opting _out_ of these)
"mlx/mlx/io/no_safetensors.cpp",
"mlx/mlx/io/gguf.cpp",
"mlx/mlx/io/gguf_quants.cpp",

// see PrepareMetalShaders -- don't build the kernels in place
"mlx/mlx/backend/metal/kernels",
"mlx/mlx/backend/metal/nojit_kernels.cpp",

// do not build distributed support (yet)
"mlx/mlx/distributed/mpi",
"mlx/mlx/distributed/ops.cpp",
"mlx/mlx/distributed/primitives.cpp",
],

cSettings: [
Expand All @@ -94,6 +118,7 @@ let package = Package(
.headerSearchPath("metal-cpp"),
.headerSearchPath("json/single_include/nlohmann"),
.headerSearchPath("gguf-tools"),
.headerSearchPath("fmt/include"),

.define("ACCELERATE_NEW_LAPACK"),
.define("_METAL_"),
Expand Down
41 changes: 41 additions & 0 deletions Plugins/PrepareMetalShaders/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ struct PrepareMetalShaders: BuildToolPlugin {
/// pattern to rewrite
private let include = try! Regex("#include \"mlx/backend/metal/kernels/([^\"]*)\"")

// see Source/Cmlx/mlx/mlx/backend/metal/kernels/CMakeLists.txt
let kernels: Set = [
"arg_reduce.metal",
"conv.metal",
"gemv.metal",
"gemv_masked.metal",
"random.metal",
"rms_norm.metal",
"layer_norm.metal",
"rope.metal",
"scaled_dot_product_attention.metal",
]

func transformIncludes(url: URL) throws {
let contents = try String(contentsOf: url, encoding: .utf8)

Expand Down Expand Up @@ -79,6 +92,12 @@ struct PrepareMetalShaders: BuildToolPlugin {
continue
}

if url.pathExtension == "h" || kernels.contains(url.lastPathComponent) {
// ok
} else {
continue
}

let modDate = resourceValues.contentModificationDate ?? Date()

// these will be moved to the top level (see below in building)
Expand Down Expand Up @@ -184,6 +203,28 @@ struct PrepareMetalShaders: BuildToolPlugin {
}
}

// remove any kernels that are not in the list
if let enumerator = FileManager.default.enumerator(
at: destination, includingPropertiesForKeys: [.isRegularFileKey],
options: [.skipsHiddenFiles, .skipsPackageDescendants])
{
for case let url as URL in enumerator {
let isRegularFile =
try url.resourceValues(forKeys: [.isRegularFileKey]).isRegularFile ?? false
guard isRegularFile else {
continue
}

if url.pathExtension == "h" || kernels.contains(url.lastPathComponent) {
// keep it
print("keeping \(url.lastPathComponent)")
} else {
print("removing \(url.lastPathComponent)")
try FileManager.default.removeItem(at: url)
}
}
}

// foreach file, transform the #includes
if let enumerator = FileManager.default.enumerator(
at: destination, includingPropertiesForKeys: [.isRegularFileKey],
Expand Down
1 change: 1 addition & 0 deletions Source/Cmlx/fmt
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
Submodule mlx updated 284 files
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx-c
19 changes: 19 additions & 0 deletions Source/Cmlx/mlx-conditional/accelerate-softmax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
// Copyright © 2024 Apple Inc.

// Note: this stubs out accelerate/softmax on x86 (e.g. via Release builds)

#if defined(__aarch64__)

#include "../mlx/mlx/backend/accelerate/softmax.cpp"

#else

#include "mlx/primitives.h"

namespace mlx::core {
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
Softmax::eval(inputs, out);
}
}

#endif
17 changes: 17 additions & 0 deletions Source/Cmlx/mlx-generated/arange.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace mlx::core::metal {

const char* arange() {
return R"preamble(
template <typename T>
[[kernel]] void arange(
constant const T& start,
constant const T& step,
device T* out,
uint index [[thread_position_in_grid]]) {
out[index] = start + index * step;
}
)preamble";
}

} // namespace mlx::core::metal
111 changes: 111 additions & 0 deletions Source/Cmlx/mlx-generated/binary.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
namespace mlx::core::metal {

const char* binary() {
return R"preamble(
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
device const T* a,
device const T* b,
device U* c,
uint index [[thread_position_in_grid]]) {
c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
device const T* a,
device const T* b,
device U* c,
constant const size_t& a_stride,
constant const size_t& b_stride,
uint index [[thread_position_in_grid]]) {
auto a_idx = elem_to_loc_1(index, a_stride);
auto b_idx = elem_to_loc_1(index, b_stride);
c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[2],
constant const size_t b_strides[2],
uint2 index [[thread_position_in_grid]],
uint2 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_2(index, a_strides);
auto b_idx = elem_to_loc_2(index, b_strides);
size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
device const T* a,
device const T* b,
device U* c,
constant const size_t a_strides[3],
constant const size_t b_strides[3],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto a_idx = elem_to_loc_3(index, a_strides);
auto b_idx = elem_to_loc_3(index, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
device const T* a,
device const T* b,
device U* c,
constant const int shape[DIM],
constant const size_t a_strides[DIM],
constant const size_t b_strides[DIM],
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
size_t out_idx =
index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
device const T* a,
device const T* b,
device U* c,
constant const int* shape,
constant const size_t* a_strides,
constant const size_t* b_strides,
constant const int& ndim,
uint3 index [[thread_position_in_grid]],
uint3 grid_dim [[threads_per_grid]]) {
auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
)preamble";
}

} // namespace mlx::core::metal
Loading

0 comments on commit 084597e

Please sign in to comment.