Commit
v0.15.2 update

- mlx-c aligned with v0.14.0 and mlx v0.15.2
- fix x86 release builds: the NEON-only code is now conditionally compiled (a minimal sketch of the guard pattern follows the change summary below)
- add the bitwise ops from previous releases (now present in mlx-c)
- the big change is JIT Metal shaders, which should make the build faster and smaller
1 parent c11212b, commit 084597e
Showing 254 changed files with 99,426 additions and 1,317 deletions.
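The x86 fix called out in the message comes down to an architecture guard around the NEON/Accelerate-only translation units. Below is a minimal sketch of that pattern with illustrative file names, not the ones in this diff; the real instance is the Accelerate softmax stub further down.

#if defined(__aarch64__)
// arm64 builds: pull in the NEON/Accelerate-optimized implementation.
#include "backend/accelerate/some_op.cpp"
#else
// x86 builds: compile a portable fallback instead, so Release builds
// no longer trip over NEON-only code.
#include "backend/common/some_op.cpp"
#endif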
@@ -0,0 +1 @@
../../vendor/fmt
Submodule mlx-c updated 11 files:
  +1 −1   CMakeLists.txt
  +4 −0   mlx/c/linalg.cpp
  +1 −0   mlx/c/linalg.h
  +3 −0   mlx/c/metal.cpp
  +1 −0   mlx/c/metal.h
  +90 −0  mlx/c/ops.cpp
  +45 −0  mlx/c/ops.h
  +2 −0   mlx/c/private/utils.h
  +3 −0   mlx/c/stream.cpp
  +5 −0   mlx/c/stream.h
  +14 −0  python/c.py
@@ -0,0 +1,19 @@
// Copyright © 2024 Apple Inc.

// Note: this stubs out accelerate/softmax on x86 (e.g. via Release builds)

#if defined(__aarch64__)

#include "../mlx/mlx/backend/accelerate/softmax.cpp"

#else

#include "mlx/primitives.h"

namespace mlx::core {
void Softmax::eval_cpu(const std::vector<array>& inputs, array& out) {
  Softmax::eval(inputs, out);
}
}

#endif
@@ -0,0 +1,17 @@
namespace mlx::core::metal {

const char* arange() {
  return R"preamble(
template <typename T>
[[kernel]] void arange(
    constant const T& start,
    constant const T& step,
    device T* out,
    uint index [[thread_position_in_grid]]) {
  out[index] = start + index * step;
}
)preamble";
}

} // namespace mlx::core::metal
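With the JIT change, kernel source like the arange() preamble above ships as plain strings instead of a precompiled metallib; a concrete kernel is produced on demand by appending an explicit template instantiation for the requested dtype and compiling the combined source at run time. The helper below is only a sketch of that idea: build_arange_source and the host_name format are assumptions, not mlx's actual JIT interface.

#include <string>

// Sketch only: concatenate the JIT preamble with one explicit instantiation.
// mlx's real Metal backend assembles and caches these sources itself; the
// function name and instantiation format here are assumptions.
std::string build_arange_source(const std::string& host_name,
                                const std::string& type_name) {
  std::string source = mlx::core::metal::arange();  // preamble defined above
  source += "\ntemplate [[host_name(\"" + host_name + "\")]] [[kernel]] "
            "decltype(arange<" + type_name + ">) arange<" + type_name + ">;\n";
  return source;  // handed to the Metal compiler (e.g. newLibraryWithSource:)
}

Because only the kernels that are actually requested get compiled, and only at run time, this is the sense in which the commit message expects the build to be faster and the binary smaller.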
@@ -0,0 +1,111 @@
namespace mlx::core::metal {

const char* binary() {
  return R"preamble(
template <typename T, typename U, typename Op>
[[kernel]] void binary_ss(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_sv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[0], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vs(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[index], b[0]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_vv(
    device const T* a,
    device const T* b,
    device U* c,
    uint index [[thread_position_in_grid]]) {
  c[index] = Op()(a[index], b[index]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd1(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t& a_stride,
    constant const size_t& b_stride,
    uint index [[thread_position_in_grid]]) {
  auto a_idx = elem_to_loc_1(index, a_stride);
  auto b_idx = elem_to_loc_1(index, b_stride);
  c[index] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd2(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[2],
    constant const size_t b_strides[2],
    uint2 index [[thread_position_in_grid]],
    uint2 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_2(index, a_strides);
  auto b_idx = elem_to_loc_2(index, b_strides);
  size_t out_idx = index.x + (size_t)grid_dim.x * index.y;
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g_nd3(
    device const T* a,
    device const T* b,
    device U* c,
    constant const size_t a_strides[3],
    constant const size_t b_strides[3],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto a_idx = elem_to_loc_3(index, a_strides);
  auto b_idx = elem_to_loc_3(index, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op()(a[a_idx], b[b_idx]);
}
template <typename T, typename U, typename Op, int DIM>
[[kernel]] void binary_g_nd(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int shape[DIM],
    constant const size_t a_strides[DIM],
    constant const size_t b_strides[DIM],
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd<DIM>(index, shape, a_strides, b_strides);
  size_t out_idx =
      index.x + (size_t)grid_dim.x * (index.y + (size_t)grid_dim.y * index.z);
  c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
template <typename T, typename U, typename Op>
[[kernel]] void binary_g(
    device const T* a,
    device const T* b,
    device U* c,
    constant const int* shape,
    constant const size_t* a_strides,
    constant const size_t* b_strides,
    constant const int& ndim,
    uint3 index [[thread_position_in_grid]],
    uint3 grid_dim [[threads_per_grid]]) {
  auto idx = elem_to_loc_2_nd(index, shape, a_strides, b_strides, ndim);
  size_t out_idx = index.x + grid_dim.x * (index.y + grid_dim.y * index.z);
  c[out_idx] = Op()(a[idx.x], b[idx.y]);
}
)preamble";
}

} // namespace mlx::core::metal
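The _g ("general") variants above handle arbitrarily strided, broadcast inputs; they rely on the elem_to_loc_* helpers from mlx's Metal utils, which the JIT build is expected to prepend to this preamble. As a rough sketch of what those helpers are assumed to compute (a grid position dotted with the per-dimension strides; the exact definitions in mlx may differ in types and qualifiers):

// Sketch only: map a 1/2/3-D grid position to a strided memory offset.
inline size_t elem_to_loc_1(uint elem, constant const size_t& stride) {
  return elem * stride;
}
inline size_t elem_to_loc_2(uint2 elem, constant const size_t strides[2]) {
  // strides[1] is the fastest-moving (innermost) axis.
  return elem.x * strides[1] + elem.y * strides[0];
}
inline size_t elem_to_loc_3(uint3 elem, constant const size_t strides[3]) {
  return elem.x * strides[2] + elem.y * strides[1] + elem.z * strides[0];
}

Because the input offsets are decoupled from the contiguous output index, one compiled binary_g kernel can serve any stride or broadcast combination for a given Op and type pair, which keeps the number of JIT-compiled variants small.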