Skip to content

Commit

Permalink
adopt latest mlx-c and mlx v0.3.0 (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkoski committed Feb 22, 2024
1 parent fc4c1af commit 7838f8c
Show file tree
Hide file tree
Showing 10 changed files with 133 additions and 51 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ endif()
FetchContent_Declare(
mlx-c
GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
GIT_TAG "main")
GIT_TAG "v0.0.2")
FetchContent_MakeAvailable(mlx-c)

# MLX package
Expand Down
2 changes: 1 addition & 1 deletion Source/Cmlx/mlx
Submodule mlx updated 88 files
+95 −6 .circleci/config.yml
+1 −1 .pre-commit-config.yaml
+2 −2 ACKNOWLEDGMENTS.md
+4 −4 CMakeLists.txt
+2 −2 README.md
+2 −4 benchmarks/python/comparative/compare.py
+1 −12 benchmarks/python/gather_bench.py
+35 −0 benchmarks/python/rope_bench.py
+56 −0 benchmarks/python/scatter_bench.py
+13 −1 benchmarks/python/time_utils.py
+1 −0 docs/.gitignore
+1 −0 docs/src/conf.py
+1 −1 docs/src/dev/extensions.rst
+2 −1 docs/src/python/devices_and_streams.rst
+4 −0 docs/src/python/nn/layers.rst
+3 −17 docs/src/python/optimizers.rst
+20 −0 docs/src/python/optimizers/common_optimizers.rst
+0 −0 docs/src/python/optimizers/optimizer.rst
+13 −0 docs/src/python/optimizers/schedulers.rst
+2 −1 mlx/CMakeLists.txt
+7 −0 mlx/array.cpp
+3 −0 mlx/array.h
+1 −39 mlx/backend/accelerate/primitives.cpp
+6 −1 mlx/backend/accelerate/softmax.cpp
+1 −0 mlx/backend/common/CMakeLists.txt
+21 −3 mlx/backend/common/binary.cpp
+14 −0 mlx/backend/common/rope.cpp
+6 −1 mlx/backend/common/softmax.cpp
+1 −0 mlx/backend/metal/CMakeLists.txt
+0 −9 mlx/backend/metal/device.cpp
+72 −146 mlx/backend/metal/indexing.cpp
+4 −1 mlx/backend/metal/kernels/CMakeLists.txt
+19 −9 mlx/backend/metal/kernels/binary.h
+23 −4 mlx/backend/metal/kernels/binary_two.metal
+6 −0 mlx/backend/metal/kernels/complex.h
+187 −0 mlx/backend/metal/kernels/gather.metal
+54 −0 mlx/backend/metal/kernels/indexing.h
+0 −290 mlx/backend/metal/kernels/indexing.metal
+44 −26 mlx/backend/metal/kernels/quantized.metal
+68 −0 mlx/backend/metal/kernels/rope.metal
+194 −0 mlx/backend/metal/kernels/scatter.metal
+2 −2 mlx/backend/metal/kernels/utils.h
+1 −1 mlx/backend/metal/make_compiled_preamble.sh
+1 −1 mlx/backend/metal/quantized.cpp
+55 −0 mlx/backend/metal/rope.cpp
+6 −1 mlx/backend/metal/softmax.cpp
+0 −14 mlx/backend/metal/utils.h
+5 −0 mlx/backend/no_metal/primitives.cpp
+128 −0 mlx/fast.cpp
+82 −0 mlx/fast.h
+17 −12 mlx/io.h
+6 −9 mlx/io/gguf.cpp
+17 −12 mlx/io/safetensor.cpp
+1 −0 mlx/mlx.h
+7 −10 mlx/ops.cpp
+1 −5 mlx/ops.h
+10 −0 mlx/types/complex.h
+10 −0 mlx/utils.cpp
+26 −0 mlx/utils.h
+1 −0 python/mlx/nn/layers/__init__.py
+308 −0 python/mlx/nn/layers/pooling.py
+10 −68 python/mlx/nn/layers/positional_encoding.py
+4 −0 python/mlx/optimizers/__init__.py
+48 −23 python/mlx/optimizers/optimizers.py
+86 −0 python/mlx/optimizers/schedulers.py
+0 −1 python/mlx/utils.py
+2 −0 python/src/CMakeLists.txt
+6 −0 python/src/array.cpp
+11 −3 python/src/device.cpp
+59 −0 python/src/fast.cpp
+33 −16 python/src/load.cpp
+8 −10 python/src/load.h
+5 −0 python/src/mlx.cpp
+3 −1 python/src/ops.cpp
+1 −1 python/src/random.cpp
+29 −4 python/src/stream.cpp
+81 −0 python/src/utils.cpp
+11 −0 python/tests/test_device.py
+158 −0 python/tests/test_fast.py
+59 −58 python/tests/test_fft.py
+19 −6 python/tests/test_load.py
+341 −0 python/tests/test_nn.py
+25 −3 python/tests/test_ops.py
+54 −7 python/tests/test_optimizers.py
+64 −0 python/tests/test_quantized.py
+1 −1 setup.py
+18 −0 tests/array_tests.cpp
+15 −9 tests/load_tests.cpp
44 changes: 42 additions & 2 deletions Source/MLX/Cmlx+Util.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func mlx_vector_array_values(_ vector_array: mlx_vector_array) -> [MLXArray] {
}
}

func mlx_map_values(_ mlx_map: mlx_map_string_to_array) -> [String: MLXArray] {
func mlx_map_array_values(_ mlx_map: mlx_map_string_to_array) -> [String: MLXArray] {
var result = [String: MLXArray]()

let iterator = mlx_map_string_to_array_iterate(mlx_map)!
Expand All @@ -58,7 +58,31 @@ func mlx_map_values(_ mlx_map: mlx_map_string_to_array) -> [String: MLXArray] {
return result
}

func new_mlx_map(_ dictionary: [String: MLXArray]) -> mlx_map_string_to_array {
func mlx_map_string_values(_ mlx_map: mlx_map_string_to_string) -> [String: String] {
var result = [String: String]()

let iterator = mlx_map_string_to_string_iterate(mlx_map)!
defer { mlx_free(iterator) }

while !mlx_map_string_to_string_iterator_end(iterator) {
let mlx_key = mlx_map_string_to_string_iterator_key(iterator)!
defer { mlx_free(mlx_key) }
let key = String(cString: mlx_string_data(mlx_key))

// note: transfer ownership
let mlx_value = mlx_map_string_to_string_iterator_value(iterator)!
defer { mlx_free(mlx_value) }
let value = String(cString: mlx_string_data(mlx_value))

result[key] = value

mlx_map_string_to_string_iterator_next(iterator)
}

return result
}

func new_mlx_array_map(_ dictionary: [String: MLXArray]) -> mlx_map_string_to_array {
let mlx_map = mlx_map_string_to_array_new()!

for (key, array) in dictionary {
Expand All @@ -71,6 +95,22 @@ func new_mlx_map(_ dictionary: [String: MLXArray]) -> mlx_map_string_to_array {
return mlx_map
}

func new_mlx_string_map(_ dictionary: [String: String]) -> mlx_map_string_to_string {
let mlx_map = mlx_map_string_to_string_new()!

for (key, value) in dictionary {
let mlx_key = mlx_string_new(key.cString(using: .utf8))!
defer { mlx_free(mlx_key) }

let mlx_value = mlx_string_new(value.cString(using: .utf8))!
defer { mlx_free(mlx_value) }

mlx_map_string_to_string_insert(mlx_map, mlx_key, mlx_value)
}

return mlx_map
}

func new_mlx_closure(_ f: @escaping ([MLXArray]) -> [MLXArray]) -> mlx_closure {

// holds reference to `f()` as capture state for the mlx_closure
Expand Down
2 changes: 1 addition & 1 deletion Source/MLX/Documentation.docc/Articles/lazy-evaluation.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ for batch in dataset {

An important behavior to be aware of is when the graph will be implicitly
evaluated. Anytime you `print` an array, or otherwise access it's memory,
the graph will be evaluated. Saving arrays via ``save(arrays:url:stream:)``
the graph will be evaluated. Saving arrays via ``save(arrays:metadata:url:stream:)``
(or any other MLX saving functions) will also evaluate the array.


Expand Down
3 changes: 2 additions & 1 deletion Source/MLX/Documentation.docc/free-functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,9 @@ operations as methods for convenience.

- ``loadArray(url:stream:)``
- ``loadArrays(url:stream:)``
- ``loadArraysAndMetadata(url:stream:)``
- ``save(array:url:stream:)``
- ``save(arrays:url:stream:)``
- ``save(arrays:metadata:url:stream:)``

### Logical

Expand Down
67 changes: 56 additions & 11 deletions Source/MLX/Ops.swift
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ enum LoadSaveError: Error {
/// ### See Also
/// - ``loadArrays(url:stream:)``
/// - ``save(array:url:stream:)``
/// - ``save(arrays:url:stream:)``
/// - ``save(arrays:metadata:url:stream:)``
public func loadArray(url: URL, stream: StreamOrDevice = .default) throws -> MLXArray {
precondition(url.isFileURL)
let path = url.path(percentEncoded: false)
Expand Down Expand Up @@ -809,8 +809,9 @@ public func loadArray(url: URL, stream: StreamOrDevice = .default) throws -> MLX
///
/// ### See Also
/// - ``loadArray(url:stream:)``
/// - ``loadArraysAndMetadata(url:stream:)``
/// - ``save(array:url:stream:)``
/// - ``save(arrays:url:stream:)``
/// - ``save(arrays:metadata:url:stream:)``
public func loadArrays(url: URL, stream: StreamOrDevice = .default) throws -> [String: MLXArray] {
precondition(url.isFileURL)
let path = url.path(percentEncoded: false)
Expand All @@ -819,10 +820,47 @@ public func loadArrays(url: URL, stream: StreamOrDevice = .default) throws -> [S

switch url.pathExtension {
case "safetensors":
let mlx_map = mlx_load_safetensors(filename, stream.ctx)!
defer { mlx_free(mlx_map) }
let mlx_safetensors = mlx_load_safetensors(filename, stream.ctx)!
defer { mlx_free(mlx_safetensors) }

return mlx_map_values(mlx_map)
let mlx_arrays = mlx_safetensors_data(mlx_safetensors)!
defer { mlx_free(mlx_arrays) }

return mlx_map_array_values(mlx_arrays)
default:
throw LoadSaveError.unknownExtension(url.pathExtension)
}
}

/// Load dictionary of ``MLXArray`` and metadata `[String:String]` from a `safetensors` file.
///
/// - Parameters:
/// - url: URL of file to load
/// - stream: stream or device to evaluate on
///
/// ### See Also
/// - ``loadArrays(url:stream:)``
/// - ``loadArray(url:stream:)``
public func loadArraysAndMetadata(url: URL, stream: StreamOrDevice = .default) throws -> (
[String: MLXArray], [String: String]
) {
precondition(url.isFileURL)
let path = url.path(percentEncoded: false)
let filename = mlx_string_new(path.cString(using: .utf8))!
defer { mlx_free(filename) }

switch url.pathExtension {
case "safetensors":
let mlx_safetensors = mlx_load_safetensors(filename, stream.ctx)!
defer { mlx_free(mlx_safetensors) }

let mlx_arrays = mlx_safetensors_data(mlx_safetensors)!
defer { mlx_free(mlx_arrays) }

let mlx_metadata = mlx_safetensors_metadata(mlx_safetensors)!
defer { mlx_free(mlx_metadata) }

return (mlx_map_array_values(mlx_arrays), mlx_map_string_values(mlx_metadata))
default:
throw LoadSaveError.unknownExtension(url.pathExtension)
}
Expand Down Expand Up @@ -1162,7 +1200,7 @@ public func remainder<A: ScalarOrArray, B: ScalarOrArray>(
/// - stream: stream or device to evaluate on
///
/// ### See Also
/// - ``save(arrays:url:stream:)``
/// - ``save(arrays:metadata:url:stream:)``
/// - ``loadArray(url:stream:)``
/// - ``loadArrays(url:stream:)``
public func save(array: MLXArray, url: URL, stream: StreamOrDevice = .default) throws {
Expand All @@ -1188,26 +1226,33 @@ public func save(array: MLXArray, url: URL, stream: StreamOrDevice = .default) t
///
/// - Parameters:
/// - a: array to save
/// - metadata: metadata to save
/// - url: URL of file to load
/// - stream: stream or device to evaluate on
///
/// ### See Also
/// - ``save(arrays:url:stream:)``
/// - ``save(arrays:metadata:url:stream:)``
/// - ``loadArray(url:stream:)``
/// - ``loadArrays(url:stream:)``
public func save(arrays: [String: MLXArray], url: URL, stream: StreamOrDevice = .default) throws {
public func save(
arrays: [String: MLXArray], metadata: [String: String] = [:], url: URL,
stream: StreamOrDevice = .default
) throws {
precondition(url.isFileURL)
let path = url.path(percentEncoded: false)

let mlx_map = new_mlx_map(arrays)
defer { mlx_free(mlx_map) }
let mlx_arrays = new_mlx_array_map(arrays)
defer { mlx_free(mlx_arrays) }

let mlx_metadata = new_mlx_string_map(metadata)
defer { mlx_free(mlx_metadata) }

switch url.pathExtension {
case "safetensors":
if let fp = fopen(path, "r") {
defer { fclose(fp) }

mlx_save_safetensors_file(fp, mlx_map)
mlx_save_safetensors_file(fp, mlx_arrays, mlx_metadata)

} else {
let message = String(cString: strerror(errno))
Expand Down
56 changes: 28 additions & 28 deletions Tests/MLXTests/IntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -362,11 +362,11 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(result.shape, [4, 3])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(
result.mean().item(Float.self), -0.14204177260398865,
accuracy: -0.002840835452079773)
result.mean().item(Float.self), 0.28858500719070435,
accuracy: 0.005771700143814087)
XCTAssertEqual(
result.sum().item(Float.self), -1.7045011520385742,
accuracy: -0.03409002304077149)
result.sum().item(Float.self), 3.463019847869873,
accuracy: 0.06926039695739747)
}

func testModOp1() {
Expand All @@ -385,11 +385,11 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(result.shape, [4, 3])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(
result.mean().item(Float.self), 0.39097175002098083,
accuracy: 0.007819435000419617)
result.mean().item(Float.self), -0.029992982745170593,
accuracy: -0.0005998596549034119)
XCTAssertEqual(
result.sum().item(Float.self), 4.6916608810424805,
accuracy: 0.09383321762084962)
result.sum().item(Float.self), -0.3599157929420471,
accuracy: -0.007198315858840942)
}

func testModOp2() {
Expand All @@ -408,11 +408,11 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(result.shape, [4, 3])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(
result.mean().item(Float.self), -0.018092799931764603,
accuracy: -0.00036185599863529204)
result.mean().item(Float.self), 0.6319072246551514,
accuracy: 0.012638144493103028)
XCTAssertEqual(
result.sum().item(Float.self), -0.21711358428001404,
accuracy: -0.0043422716856002805)
result.sum().item(Float.self), 7.582886219024658,
accuracy: 0.15165772438049316)
}

func testPowOp() {
Expand Down Expand Up @@ -3851,11 +3851,11 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(result.shape, [4, 3])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(
result.mean().item(Float.self), 0.09767231345176697,
accuracy: 0.0019534462690353393)
result.mean().item(Float.self), -0.007391604594886303,
accuracy: -0.00014783209189772606)
XCTAssertEqual(
result.sum().item(Float.self), 1.1720677614212036,
accuracy: 0.023441355228424072)
result.sum().item(Float.self), -0.08869925141334534,
accuracy: -0.0017739850282669067)
}

func testSubtract() {
Expand Down Expand Up @@ -6459,11 +6459,11 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(result.shape, [2, 8, 16])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(
result.mean().item(Float.self), 0.982857882976532,
accuracy: 0.01965715765953064)
result.mean().item(Float.self), 0.9828577637672424,
accuracy: 0.01965715527534485)
XCTAssertEqual(
result.sum().item(Float.self), 251.6116180419922,
accuracy: 5.032232360839844)
result.sum().item(Float.self), 251.61158752441406,
accuracy: 5.032231750488282)
}

func testSoftsign() {
Expand Down Expand Up @@ -6569,11 +6569,11 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(result.shape, [2, 8, 16])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(
result.mean().item(Float.self), -0.47959864139556885,
accuracy: -0.009591972827911377)
result.mean().item(Float.self), -0.4795985519886017,
accuracy: -0.009591971039772034)
XCTAssertEqual(
result.sum().item(Float.self), -122.77725219726562,
accuracy: -2.4555450439453126)
result.sum().item(Float.self), -122.77722930908203,
accuracy: -2.4555445861816407)
}

func testPReLU() {
Expand Down Expand Up @@ -6613,11 +6613,11 @@ class MLXIntegrationTests: XCTestCase {
XCTAssertEqual(result.shape, [2, 8, 16])
XCTAssertEqual(result.dtype, .float32)
XCTAssertEqual(
result.mean().item(Float.self), 0.3656384348869324,
accuracy: 0.007312768697738648)
result.mean().item(Float.self), 0.3656383752822876,
accuracy: 0.007312767505645752)
XCTAssertEqual(
result.sum().item(Float.self), 93.60343933105469,
accuracy: 1.8720687866210939)
result.sum().item(Float.self), 93.60342407226562,
accuracy: 1.8720684814453126)
}

func testTanh1() {
Expand Down
3 changes: 0 additions & 3 deletions Tests/MLXTests/TransformTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -243,9 +243,6 @@ class TransformTests: XCTestCase {
func testCompilePerformance() {
// this is the code from compilation.md

// disabling until we pick up the fix for https://github.com/ml-explore/mlx/issues/31
return

func measure(_ f: (MLXArray) -> MLXArray, _ x: MLXArray) {
// warm up
for _ in 0 ..< 10 {
Expand Down
3 changes: 1 addition & 2 deletions Tests/MLXTests/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,5 @@ func assertEqual(
}

func setDefaultDevice() {
// run tests on CPU for now until we pick up https://github.com/ml-explore/mlx/issues/31
MLX.Device.setDefault(device: .cpu)
MLX.Device.setDefault(device: .gpu)
}

0 comments on commit 7838f8c

Please sign in to comment.