From 3ff8fa33b68e1690c3b7e575f9830dc819d05c4a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Aug 2024 17:46:34 +0200 Subject: [PATCH 01/19] initial eigvalsh --- docs/src/python/linalg.rst | 1 + mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/CMakeLists.txt | 1 + mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/eigvalsh.cpp | 162 ++++++++++++++++++++++ mlx/backend/metal/primitives.cpp | 4 + mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_metal/primitives.cpp | 1 + mlx/linalg.cpp | 34 +++++ mlx/linalg.h | 2 + mlx/primitives.cpp | 8 ++ mlx/primitives.h | 16 +++ python/src/linalg.cpp | 37 +++++ python/tests/test_linalg.py | 48 +++++++ tests/linalg_tests.cpp | 53 +++++++ 15 files changed, 370 insertions(+) create mode 100644 mlx/backend/common/eigvalsh.cpp diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 227711c22..0ae00da17 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -16,3 +16,4 @@ Linear Algebra cross qr svd + eigvalsh diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index eee93f2ab..4350ebd8d 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) +DEFAULT(Eigvalsh) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index 925f4731c..a1bf8c5ce 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -31,6 +31,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eigvalsh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index f8932c5f8..ecc7ee553 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -114,6 +114,7 @@ DEFAULT(Tanh) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) +DEFAULT(Eigvalsh) namespace { diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp new file mode 100644 index 000000000..4fa6c138f --- /dev/null +++ b/mlx/backend/common/eigvalsh.cpp @@ -0,0 +1,162 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/allocator.h" +#include "mlx/backend/common/copy.h" +#include "mlx/linalg.h" +#include "mlx/primitives.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#else +#include +#endif + +namespace mlx::core { + +namespace { + +// Delegate to the eigenvalue decomposition taking into account differences in +// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings +// to fortran). 
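+// LAPACK's workspace-query convention is used below: the first ssyevd_ call
+// passes lwork = -1 and liwork = -1, asking the routine to report the optimal
+// workspace sizes in work_query and iwork_query instead of computing
+// anything; the second call then performs the actual decomposition with
+// buffers of those sizes.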
+int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) { + int info; + int lwork = -1; + int liwork = -1; + float work_query; + int iwork_query; + + // Query for optimal work array sizes +#ifdef LAPACK_FORTRAN_STRLEN_END + ssyevd_( + /* jobz = */ &jobz, + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* w = */ w, + /* work = */ &work_query, + /* lwork = */ &lwork, + /* iwork = */ &iwork_query, + /* liwork = */ &liwork, + /* info = */ &info, + /* jobz_len = */ static_cast(1), + /* uplo_len = */ static_cast(1)); +#else + ssyevd_( + /* jobz = */ &jobz, + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* w = */ w, + /* work = */ &work_query, + /* lwork = */ &lwork, + /* iwork = */ &iwork_query, + /* liwork = */ &liwork, + /* info = */ &info); +#endif + + lwork = static_cast(work_query); + liwork = iwork_query; + + std::vector work(lwork); + std::vector iwork(liwork); + + // Compute eigenvalues (and optionally eigenvectors) +#ifdef LAPACK_FORTRAN_STRLEN_END + ssyevd_( + /* jobz = */ &jobz, + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* w = */ w, + /* work = */ work.data(), + /* lwork = */ &lwork, + /* iwork = */ iwork.data(), + /* liwork = */ &liwork, + /* info = */ &info, + /* jobz_len = */ static_cast(1), + /* uplo_len = */ static_cast(1)); +#else + ssyevd_( + /* jobz = */ &jobz, + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* w = */ w, + /* work = */ work.data(), + /* lwork = */ &lwork, + /* iwork = */ iwork.data(), + /* liwork = */ &liwork, + /* info = */ &info); +#endif + + return info; +} + +} // namespace + +void eigenvalues_impl( + const array& a, + array& values, + array& vectors, + bool compute_vectors) { + char jobz = compute_vectors ? 'V' : 'N'; + char uplo = 'U'; // Use upper triangle of the matrix + + // Copy input to a buffer for in-place computation + array buffer; + copy( + a, + buffer, + a.flags().row_contiguous ? CopyType::Vector : CopyType::General); + + const int N = a.shape(-1); + const size_t num_matrices = a.size() / (N * N); + + // Allocate output arrays + values = array::empty({num_matrices, N}, float32); + if (compute_vectors) { + vectors = array::empty(a.shape(), float32); + } + + float* matrix = buffer.data(); + float* w = values.data(); + float* vecs = compute_vectors ? 
vectors.data() : nullptr; + + for (int i = 0; i < num_matrices; i++) { + // Compute eigenvalue decomposition + int info = ssyevd_wrapper(jobz, uplo, matrix, w, N); + + if (info != 0) { + std::stringstream msg; + msg << "[eigenvalues] Eigenvalue decomposition failed with error code " + << info; + throw std::runtime_error(msg.str()); + } + + // Copy eigenvectors if computed + if (compute_vectors) { + std::copy(matrix, matrix + N * N, vecs); + vecs += N * N; + } + + // Move to next matrix + matrix += N * N; + w += N; + } +} + +void Eigenvalues::eval( + const std::vector& inputs, + array& values, + array& vectors) { + if (inputs[0].dtype() != float32) { + throw std::runtime_error("[Eigenvalues::eval] only supports float32."); + } + eigenvalues_impl(inputs[0], values, vectors, compute_vectors_); +} + +} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 31f2248d7..88ff4919d 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -401,6 +401,10 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } +void Eigvalsh::eval_gpu(const std::vector& inputs, array& out) { + throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigvalsh NYI."); +} + void View::eval_gpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; auto ibytes = size_of(in.dtype()); diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index fd15c403b..fbff70673 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -48,6 +48,7 @@ NO_CPU(Divide) NO_CPU_MULTI(DivMod) NO_CPU(NumberOfElements) NO_CPU(Remainder) +NO_CPU(Eigvalsh) NO_CPU(Equal) NO_CPU(Erf) NO_CPU(ErfInv) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 5270a6fdd..0c1e37bf3 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -112,6 +112,7 @@ NO_GPU(Tanh) NO_GPU(Transpose) NO_GPU(Inverse) NO_GPU(Cholesky) +NO_GPU(Eigvalsh) NO_GPU(View) namespace fast { diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index a64f98aa8..ec0afd600 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -454,4 +454,38 @@ array cross( return concatenate(outputs, axis, s); } +array eigvalsh( + const array& a, + bool upper /* = true */, + StreamOrDevice s /* = {} */) { + if (a.dtype() != float32) { + std::ostringstream msg; + msg << "[linalg::eigvalsh] Arrays must be type float32. Received array " + << "with type " << a.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (a.ndim() < 2) { + std::ostringstream msg; + msg << "[linalg::eigvalsh] Arrays must have >= 2 dimensions. 
Received array " + "with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != a.shape(-2)) { + throw std::invalid_argument( + "[linalg::eigvalsh] Eigenvalues are only defined for square matrices."); + } + + std::vector out_shape(a.shape().begin(), a.shape().end() - 1); + out_shape.back() = a.shape(-1); + + return array( + out_shape, + a.dtype(), + std::make_shared(to_stream(s), upper), + {a}); +} + } // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index acfcc1a41..fa943d5a9 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -83,4 +83,6 @@ array cross( int axis = -1, StreamOrDevice s = {}); +array eigvalsh(const array& a, bool upper = true, StreamOrDevice s = {}); + } // namespace mlx::core::linalg diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 8aa0392b7..d36af5651 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -767,6 +767,14 @@ std::pair, std::vector> Cholesky::vmap( return {{linalg::cholesky(a, upper_, stream())}, {ax}}; } +std::pair, std::vector> Eigvalsh::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto ax = axes[0] >= 0 ? 0 : -1; + auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; + return {{linalg::eigvalsh(a, upper_, stream())}, {ax}}; +} + std::vector Concatenate::vjp( const std::vector& primals, const std::vector& cotangents, diff --git a/mlx/primitives.h b/mlx/primitives.h index 4bec71445..b6fc1fc1c 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2196,4 +2196,20 @@ class Cholesky : public UnaryPrimitive { bool upper_; }; +class Eigvalsh : public UnaryPrimitive { + public: + explicit Eigvalsh(Stream stream, bool upper) + : UnaryPrimitive(stream), upper_(upper) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_PRINT(Eigvalsh) + + private: + void eval(const std::vector& inputs, array& output); + bool upper_; +}; + } // namespace mlx::core diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 13d61e980..9c6e3e761 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -405,4 +405,41 @@ void init_linalg(nb::module_& parent_module) { Returns: array: The cross product of ``a`` and ``b`` along the specified axis. )pbdoc"); + m.def( + "eigvalsh", + &eigvalsh, + "a"_a, + "upper"_a = true, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def eigvalsh(a: array, upper: bool = True, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the eigenvalues of a complex Hermitian or real symmetric matrix. + + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the eigenvalues are computed for each matrix + in the last two dimensions of ``a``. + + Args: + a (array): Input array. Must be a real symmetric or complex Hermitian matrix. + upper (bool, optional): Whether to use the upper or lower triangle of the matrix. + Default is True (upper triangle). + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The eigenvalues in ascending order. + + Note: + The input matrix is assumed to be symmetric (or Hermitian). Only the + upper triangle (if upper=True) or lower triangle (if upper=False) is used. + No checks for symmetry are performed. 
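+
+            For example, with the default ``upper=True`` the strictly lower
+            triangle is ignored, so the input below is read as the symmetric
+            matrix ``[[1., 2.], [2., 1.]]`` (a small sketch of the triangle
+            convention; ``B`` is just an illustrative matrix):
+
+            >>> B = mx.array([[1., 2.], [0., 1.]])
+            >>> mx.linalg.eigvalsh(B)
+            array([-1., 3.], dtype=float32)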
+ + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> eigenvalues = mx.linalg.eigvalsh(A) + >>> eigenvalues + array([-1., 3.], dtype=float32) + )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 6051beef7..1d336ee21 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -268,6 +268,54 @@ def test_cross_product(self): with self.assertRaises(ValueError): mx.linalg.cross(a, b) + def test_eigvalsh(self): + # Test a simple 2x2 symmetric matrix + A_mx = mx.array([[1.0, 2.0], [2.0, 4.0]], dtype=mx.float32) + A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32) + + eigenvalues_mx = mx.linalg.eigvalsh(A_mx) + eigenvalues_np = np.linalg.eigvalsh(A_np) + + self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + + # Test a larger random symmetric matrix + n = 5 + rng = np.random.default_rng(42) + B_np = rng.random((n, n)).astype(np.float32) + B_np = (B_np + B_np.T) / 2 # Make sure B is symmetric + B_mx = mx.array(B_np) + + eigenvalues_mx = mx.linalg.eigvalsh(B_mx) + eigenvalues_np = np.linalg.eigvalsh(B_np) + + self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + + # Test that eigenvalues are in ascending order + self.assertTrue(mx.all(eigenvalues_mx[1:] >= eigenvalues_mx[:-1])) + + # Test with upper=False + eigenvalues_mx_lower = mx.linalg.eigvalsh(B_mx, upper=False) + eigenvalues_np_lower = np.linalg.eigvalsh(B_np, UPLO='L') + + self.assertTrue(mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5)) + + # Test with batched input + C_np = rng.random((3, n, n)).astype(np.float32) + C_np = (C_np + np.transpose(C_np, (0, 2, 1))) / 2 # Make sure C is symmetric for each batch + C_mx = mx.array(C_np) + + eigenvalues_mx = mx.linalg.eigvalsh(C_mx) + eigenvalues_np = np.linalg.eigvalsh(C_np) + + self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + + # Test error cases + with self.assertRaises(ValueError): + mx.linalg.eigvalsh(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eigvalsh(mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) # Non-square matrix + if __name__ == "__main__": unittest.main() diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index e9e196583..0a89bcf22 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -349,6 +349,7 @@ TEST_CASE("test matrix cholesky") { .item()); } +<<<<<<< HEAD TEST_CASE("test matrix pseudo-inverse") { // 0D and 1D throw CHECK_THROWS(linalg::pinv(array(0.0), Device::cpu)); @@ -435,3 +436,55 @@ TEST_CASE("test cross product") { result = cross(a, b); CHECK(allclose(result, expected).item()); } + +TEST_CASE("test matrix eigvalsh") { + // 0D and 1D throw + CHECK_THROWS(linalg::eigvalsh(array(0.0), /* upper = */ true, Device::cpu)); + CHECK_THROWS( + linalg::eigvalsh(array({0.0, 1.0}), /* upper = */ true, Device::cpu)); + + // Unsupported types throw + CHECK_THROWS( + linalg::eigvalsh(array({0, 1}, {1, 2}), /* upper = */ true, Device::cpu)); + + // Non-square throws + CHECK_THROWS(linalg::eigvalsh( + array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu)); + + // Test a simple 2x2 symmetric matrix + array A = array({{1.0, 2.0}, {2.0, 4.0}}); + array eigenvalues = linalg::eigvalsh(A, /* upper = */ true, Device::cpu); + + // Expected eigenvalues (calculated analytically) + array expected_eigenvalues = array({0.0, 5.0}); + + CHECK(allclose( + eigenvalues, + expected_eigenvalues, + /* rtol = */ 1e-5, + /* atol = */ 1e-5) + 
.item()); + + // Test a larger symmetric matrix + const auto prng_key = random::key(42); + const auto B = random::normal({5, 5}, prng_key); + const auto B_sym = 0.5 * (B + transpose(B)); // Make sure B is symmetric + const auto B_eigenvalues = + linalg::eigvalsh(B_sym, /* upper = */ true, Device::cpu); + + // Check that eigenvalues are real and in ascending order + CHECK(B_eigenvalues.dtype() == float32); + CHECK(B_eigenvalues.shape() == std::vector{5}); + CHECK(all(isfinite(B_eigenvalues)).item()); + CHECK(all(B_eigenvalues [1:] >= B_eigenvalues[:-1]).item()); + + // Reconstruct the matrix using eigendecomposition and check if it's close to + // the original + const auto D = diag(B_eigenvalues); + const auto V = linalg::eigh(B_sym, /* upper = */ true, Device::cpu) + .second; // Assuming eigh is implemented + const auto B_reconstructed = matmul(matmul(V, D), transpose(V)); + + CHECK(allclose(B_reconstructed, B_sym, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); +} From 508b6c1beca2114541e453d7e34300c3af66ae6c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Aug 2024 18:27:08 +0200 Subject: [PATCH 02/19] add compute_vectors --- mlx/backend/common/eigvalsh.cpp | 29 +++++++++++++---------------- mlx/primitives.h | 13 ++++++++----- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index 4fa6c138f..f1e47dce7 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ b/mlx/backend/common/eigvalsh.cpp @@ -98,28 +98,25 @@ int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) { } // namespace -void eigenvalues_impl( +void eigvalsh_impl( const array& a, array& values, array& vectors, + bool upper, bool compute_vectors) { char jobz = compute_vectors ? 'V' : 'N'; - char uplo = 'U'; // Use upper triangle of the matrix + char uplo = (upper) ? 'L' : 'U'; // Use upper triangle of the matrix - // Copy input to a buffer for in-place computation - array buffer; - copy( - a, - buffer, - a.flags().row_contiguous ? 
CopyType::Vector : CopyType::General); + // Create a copy of the input array for in-place computation + array buffer = copy(a); - const int N = a.shape(-1); - const size_t num_matrices = a.size() / (N * N); + const int N = static_cast(a.shape(-1)); + const int num_matrices = static_cast(a.size() / (N * N)); // Allocate output arrays - values = array::empty({num_matrices, N}, float32); + values = zeros({num_matrices, N}, float32); if (compute_vectors) { - vectors = array::empty(a.shape(), float32); + vectors = zeros(a.shape(), float32); } float* matrix = buffer.data(); @@ -132,7 +129,7 @@ void eigenvalues_impl( if (info != 0) { std::stringstream msg; - msg << "[eigenvalues] Eigenvalue decomposition failed with error code " + msg << "[eigvalsh] Eigenvalue decomposition failed with error code " << info; throw std::runtime_error(msg.str()); } @@ -149,14 +146,14 @@ void eigenvalues_impl( } } -void Eigenvalues::eval( +void Eigvalsh::eval( const std::vector& inputs, array& values, array& vectors) { if (inputs[0].dtype() != float32) { - throw std::runtime_error("[Eigenvalues::eval] only supports float32."); + throw std::runtime_error("[Eigvalsh::eval] only supports float32."); } - eigenvalues_impl(inputs[0], values, vectors, compute_vectors_); + eigvalsh_impl(inputs[0], values, vectors, upper_, compute_vectors_); } } // namespace mlx::core \ No newline at end of file diff --git a/mlx/primitives.h b/mlx/primitives.h index b6fc1fc1c..c72d798de 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2198,18 +2198,21 @@ class Cholesky : public UnaryPrimitive { class Eigvalsh : public UnaryPrimitive { public: - explicit Eigvalsh(Stream stream, bool upper) - : UnaryPrimitive(stream), upper_(upper) {} + explicit Eigvalsh(Stream stream, bool upper, bool compute_vectors) + : UnaryPrimitive(stream), + upper_(upper), + compute_vectors_(compute_vectors) {} - void eval_cpu(const std::vector& inputs, array& out) override; - void eval_gpu(const std::vector& inputs, array& out) override; + void eval_cpu(const std::vector& inputs, array& values) override; + void eval_gpu(const std::vector& inputs, array& values) override; DEFINE_VMAP() DEFINE_PRINT(Eigvalsh) private: - void eval(const std::vector& inputs, array& output); + void eval(const std::vector& inputs, array& values); bool upper_; + bool compute_vectors_; }; } // namespace mlx::core From 6fb8d7f402e13401afaa6b2d1795afb8f2cc5e9c Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 18 Aug 2024 18:43:00 +0200 Subject: [PATCH 03/19] add compute_vectors_ --- mlx/backend/common/eigvalsh.cpp | 6 ++++-- mlx/linalg.cpp | 3 ++- mlx/linalg.h | 5 +++++ mlx/primitives.cpp | 2 +- mlx/primitives.h | 6 +++--- 5 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index f1e47dce7..69b57112e 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ b/mlx/backend/common/eigvalsh.cpp @@ -114,9 +114,11 @@ void eigvalsh_impl( const int num_matrices = static_cast(a.size() / (N * N)); // Allocate output arrays - values = zeros({num_matrices, N}, float32); + std::vector values_shape = {num_matrices, N}; + values = array(values_shape, a.dtype()); + if (compute_vectors) { - vectors = zeros(a.shape(), float32); + vectors = array(a.shape(), a.dtype()); } float* matrix = buffer.data(); diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index ec0afd600..201102d45 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -457,6 +457,7 @@ array cross( array eigvalsh( const array& a, bool upper /* = true */, + bool compute_vectors /* = 
false */, StreamOrDevice s /* = {} */) { if (a.dtype() != float32) { std::ostringstream msg; @@ -484,7 +485,7 @@ array eigvalsh( return array( out_shape, a.dtype(), - std::make_shared(to_stream(s), upper), + std::make_shared(to_stream(s), upper, compute_vectors), {a}); } diff --git a/mlx/linalg.h b/mlx/linalg.h index fa943d5a9..1cc27509c 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -84,5 +84,10 @@ array cross( StreamOrDevice s = {}); array eigvalsh(const array& a, bool upper = true, StreamOrDevice s = {}); +array eigvalsh( + const array& a, + bool upper = true, + bool compute_vectors = false, + StreamOrDevice s = {}); } // namespace mlx::core::linalg diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index d36af5651..9b049d2bc 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -772,7 +772,7 @@ std::pair, std::vector> Eigvalsh::vmap( const std::vector& axes) { auto ax = axes[0] >= 0 ? 0 : -1; auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - return {{linalg::eigvalsh(a, upper_, stream())}, {ax}}; + return {{linalg::eigvalsh(a, upper_, compute_vectors_, stream())}, {ax}}; } std::vector Concatenate::vjp( diff --git a/mlx/primitives.h b/mlx/primitives.h index c72d798de..2bc5c3741 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2203,14 +2203,14 @@ class Eigvalsh : public UnaryPrimitive { upper_(upper), compute_vectors_(compute_vectors) {} - void eval_cpu(const std::vector& inputs, array& values) override; - void eval_gpu(const std::vector& inputs, array& values) override; + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; DEFINE_VMAP() DEFINE_PRINT(Eigvalsh) private: - void eval(const std::vector& inputs, array& values); + void eval(const std::vector& inputs, array& output); bool upper_; bool compute_vectors_; }; From b254961f866d81b2b0afe1b215442560498fe97f Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Mon, 19 Aug 2024 20:57:42 +0200 Subject: [PATCH 04/19] return a pair --- mlx/backend/common/eigvalsh.cpp | 11 +++++------ mlx/linalg.cpp | 11 ++++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index 69b57112e..0841a451f 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ b/mlx/backend/common/eigvalsh.cpp @@ -105,7 +105,7 @@ void eigvalsh_impl( bool upper, bool compute_vectors) { char jobz = compute_vectors ? 'V' : 'N'; - char uplo = (upper) ? 'L' : 'U'; // Use upper triangle of the matrix + char uplo = (upper) ? 
'U' : 'L'; // Use upper triangle of the matrix // Create a copy of the input array for in-place computation array buffer = copy(a); @@ -115,10 +115,10 @@ void eigvalsh_impl( // Allocate output arrays std::vector values_shape = {num_matrices, N}; - values = array(values_shape, a.dtype()); + values = array({}, values_shape, a.dtype()); if (compute_vectors) { - vectors = array(a.shape(), a.dtype()); + vectors = array({}, a.shape(), a.dtype()); } float* matrix = buffer.data(); @@ -150,12 +150,11 @@ void eigvalsh_impl( void Eigvalsh::eval( const std::vector& inputs, - array& values, - array& vectors) { + std::vector& outputs) { if (inputs[0].dtype() != float32) { throw std::runtime_error("[Eigvalsh::eval] only supports float32."); } - eigvalsh_impl(inputs[0], values, vectors, upper_, compute_vectors_); + eigvalsh_impl(inputs[0], output[0], output[1], upper_, compute_vectors_); } } // namespace mlx::core \ No newline at end of file diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 201102d45..949053d32 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -454,7 +454,7 @@ array cross( return concatenate(outputs, axis, s); } -array eigvalsh( +std::pair eigvalsh( const array& a, bool upper /* = true */, bool compute_vectors /* = false */, @@ -482,11 +482,12 @@ array eigvalsh( std::vector out_shape(a.shape().begin(), a.shape().end() - 1); out_shape.back() = a.shape(-1); - return array( - out_shape, - a.dtype(), + auto out = array::make_arrays( + {out_shape, compute_vectors ? a.shape() : std::vector()}, + {a.dtype(), a.dtype()}, std::make_shared(to_stream(s), upper, compute_vectors), - {a}); + {astype(a, a.dtype(), s)}); + return std::make_pair(out[0], out[1]); } } // namespace mlx::core::linalg From fa81b3f1b209096abe1b691523899ddce1913f62 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 29 Aug 2024 17:35:03 +0200 Subject: [PATCH 05/19] add eigh to return only eigenvectors --- docs/src/python/linalg.rst | 1 + mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/common/eigvalsh.cpp | 77 +++++++++++++++-------- mlx/backend/metal/primitives.cpp | 4 ++ mlx/backend/no_cpu/primitives.cpp | 1 + mlx/backend/no_metal/primitives.cpp | 1 + mlx/linalg.cpp | 46 +++++++++++--- mlx/linalg.h | 9 +-- mlx/primitives.cpp | 10 ++- mlx/primitives.h | 23 +++++-- python/src/linalg.cpp | 39 ++++++++++++ 12 files changed, 168 insertions(+), 45 deletions(-) diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst index 0ae00da17..f6c51ed0b 100644 --- a/docs/src/python/linalg.rst +++ b/docs/src/python/linalg.rst @@ -17,3 +17,4 @@ Linear Algebra qr svd eigvalsh + eigh diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 4350ebd8d..e4f07ccfb 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -82,6 +82,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT(Eigvalsh) +DEFAULT(Eigh) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index ecc7ee553..7588596fc 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -115,6 +115,7 @@ DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) DEFAULT(Eigvalsh) +DEFAULT(Eigh) namespace { diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index 0841a451f..e7a97cdc5 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ 
b/mlx/backend/common/eigvalsh.cpp @@ -101,60 +101,87 @@ int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) { void eigvalsh_impl( const array& a, array& values, - array& vectors, - bool upper, - bool compute_vectors) { - char jobz = compute_vectors ? 'V' : 'N'; - char uplo = (upper) ? 'U' : 'L'; // Use upper triangle of the matrix + bool upper) { + char jobz = 'N'; // Only compute eigenvalues + char uplo = (upper) ? 'U' : 'L'; - // Create a copy of the input array for in-place computation array buffer = copy(a); const int N = static_cast(a.shape(-1)); const int num_matrices = static_cast(a.size() / (N * N)); - // Allocate output arrays std::vector values_shape = {num_matrices, N}; - values = array({}, values_shape, a.dtype()); - - if (compute_vectors) { - vectors = array({}, a.shape(), a.dtype()); - } + values = array(allocator::malloc(num_matrices * N * size_of(a.dtype())), values_shape, a.dtype()); float* matrix = buffer.data(); float* w = values.data(); - float* vecs = compute_vectors ? vectors.data() : nullptr; for (int i = 0; i < num_matrices; i++) { - // Compute eigenvalue decomposition int info = ssyevd_wrapper(jobz, uplo, matrix, w, N); if (info != 0) { std::stringstream msg; - msg << "[eigvalsh] Eigenvalue decomposition failed with error code " - << info; + msg << "[eigvalsh] Eigenvalue decomposition failed with error code " << info; throw std::runtime_error(msg.str()); } - // Copy eigenvectors if computed - if (compute_vectors) { - std::copy(matrix, matrix + N * N, vecs); - vecs += N * N; + matrix += N * N; + w += N; + } +} + +void eigh_impl( + const array& a, + array& vectors, + bool upper) { + char jobz = 'V'; // Compute both eigenvalues and eigenvectors + char uplo = (upper) ? 'U' : 'L'; + + array buffer = copy(a); + + const int N = static_cast(a.shape(-1)); + const int num_matrices = static_cast(a.size() / (N * N)); + + std::vector vectors_shape = a.shape(); + vectors = array(allocator::malloc(a.size() * size_of(a.dtype())), vectors_shape, a.dtype()); + + float* matrix = buffer.data(); + float* vecs = vectors.data(); + + // Temporary buffer for eigenvalues (we don't return these) + std::vector w(N); + + for (int i = 0; i < num_matrices; i++) { + int info = ssyevd_wrapper(jobz, uplo, matrix, w.data(), N); + + if (info != 0) { + std::stringstream msg; + msg << "[eigh] Eigenvalue decomposition failed with error code " << info; + throw std::runtime_error(msg.str()); } - // Move to next matrix + // Copy eigenvectors to the output array + std::copy(matrix, matrix + N * N, vecs); + matrix += N * N; - w += N; + vecs += N * N; } } void Eigvalsh::eval( - const std::vector& inputs, - std::vector& outputs) { + const std::vector& inputs, array& output) { if (inputs[0].dtype() != float32) { throw std::runtime_error("[Eigvalsh::eval] only supports float32."); } - eigvalsh_impl(inputs[0], output[0], output[1], upper_, compute_vectors_); + eigvalsh_impl(inputs[0], output, upper_); +} + +void Eigh::eval( + const std::vector& inputs, array& output) { + if (inputs[0].dtype() != float32) { + throw std::runtime_error("[Eigh::eval] only supports float32."); + } + eigh_impl(inputs[0], output, upper_); } } // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 88ff4919d..43d45e618 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -405,6 +405,10 @@ void Eigvalsh::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error("[Eigvalsh::eval_gpu] 
Metal Eigvalsh NYI."); } +void Eigvh::eval_gpu(const std::vector& inputs, array& out) { + throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); +} + void View::eval_gpu(const std::vector& inputs, array& out) { auto& in = inputs[0]; auto ibytes = size_of(in.dtype()); diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index fbff70673..b5e62dab4 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -49,6 +49,7 @@ NO_CPU_MULTI(DivMod) NO_CPU(NumberOfElements) NO_CPU(Remainder) NO_CPU(Eigvalsh) +NO_CPU(Eigh) NO_CPU(Equal) NO_CPU(Erf) NO_CPU(ErfInv) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 0c1e37bf3..b4e2974b5 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -113,6 +113,7 @@ NO_GPU(Transpose) NO_GPU(Inverse) NO_GPU(Cholesky) NO_GPU(Eigvalsh) +NO_GPU(Eigh) NO_GPU(View) namespace fast { diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 949053d32..7840ddf01 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -382,6 +382,7 @@ array cholesky_inv( } } +<<<<<<< HEAD array cross( const array& a, const array& b, @@ -454,10 +455,9 @@ array cross( return concatenate(outputs, axis, s); } -std::pair eigvalsh( +array eigvalsh( const array& a, - bool upper /* = true */, - bool compute_vectors /* = false */, + bool upper /* = false */, StreamOrDevice s /* = {} */) { if (a.dtype() != float32) { std::ostringstream msg; @@ -482,12 +482,42 @@ std::pair eigvalsh( std::vector out_shape(a.shape().begin(), a.shape().end() - 1); out_shape.back() = a.shape(-1); - auto out = array::make_arrays( - {out_shape, compute_vectors ? a.shape() : std::vector()}, - {a.dtype(), a.dtype()}, - std::make_shared(to_stream(s), upper, compute_vectors), + return array( + out_shape, + a.dtype(), + std::make_shared(to_stream(s), upper), + {astype(a, a.dtype(), s)}); +} + +array eigh( + const array& a, + bool upper /* = false */, + StreamOrDevice s /* = {} */) { + if (a.dtype() != float32) { + std::ostringstream msg; + msg << "[linalg::eigh] Arrays must be type float32. Received array " + << "with type " << a.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (a.ndim() < 2) { + std::ostringstream msg; + msg << "[linalg::eigh] Arrays must have >= 2 dimensions. 
Received array " + "with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != a.shape(-2)) { + throw std::invalid_argument( + "[linalg::eigh] Eigenvectors are only defined for square matrices."); + } + + return array( + a.shape(), + a.dtype(), + std::make_shared(to_stream(s), upper), {astype(a, a.dtype(), s)}); - return std::make_pair(out[0], out[1]); } } // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index 1cc27509c..63328c637 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -83,11 +83,8 @@ array cross( int axis = -1, StreamOrDevice s = {}); -array eigvalsh(const array& a, bool upper = true, StreamOrDevice s = {}); -array eigvalsh( - const array& a, - bool upper = true, - bool compute_vectors = false, - StreamOrDevice s = {}); +array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {}); + +array eigh(const array& a, bool upper = false, StreamOrDevice s = {}); } // namespace mlx::core::linalg diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 9b049d2bc..87ed99d96 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -772,7 +772,15 @@ std::pair, std::vector> Eigvalsh::vmap( const std::vector& axes) { auto ax = axes[0] >= 0 ? 0 : -1; auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - return {{linalg::eigvalsh(a, upper_, compute_vectors_, stream())}, {ax}}; + return {{linalg::eigvalsh(a, upper_, stream())}, {ax}}; +} + +std::pair, std::vector> Eigh::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto ax = axes[0] >= 0 ? 0 : -1; + auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; + return {{linalg::eigh(a, upper_, stream())}, {ax}}; } std::vector Concatenate::vjp( diff --git a/mlx/primitives.h b/mlx/primitives.h index 2bc5c3741..328219140 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2198,10 +2198,8 @@ class Cholesky : public UnaryPrimitive { class Eigvalsh : public UnaryPrimitive { public: - explicit Eigvalsh(Stream stream, bool upper, bool compute_vectors) - : UnaryPrimitive(stream), - upper_(upper), - compute_vectors_(compute_vectors) {} + explicit Eigvalsh(Stream stream, bool upper) + : UnaryPrimitive(stream), upper_(upper) {} void eval_cpu(const std::vector& inputs, array& out) override; void eval_gpu(const std::vector& inputs, array& out) override; @@ -2212,7 +2210,22 @@ class Eigvalsh : public UnaryPrimitive { private: void eval(const std::vector& inputs, array& output); bool upper_; - bool compute_vectors_; +}; + +class Eigh : public UnaryPrimitive { + public: + explicit Eigh(Stream stream, bool upper) + : UnaryPrimitive(stream), upper_(upper) {} + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_PRINT(Eigh) + + private: + void eval(const std::vector& inputs, array& output); + bool upper_; }; } // namespace mlx::core diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index 9c6e3e761..b216a7ee0 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -442,4 +442,43 @@ void init_linalg(nb::module_& parent_module) { >>> eigenvalues array([-1., 3.], dtype=float32) )pbdoc"); + m.def( + "eigh", + &eigh, + "a"_a, + "upper"_a = true, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def eigh(a: array, upper: bool = True, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the eigenvectors of a complex Hermitian or real symmetric matrix. 
+ + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the eigenvectors are computed for each matrix + in the last two dimensions of ``a``. + + Args: + a (array): Input array. Must be a real symmetric or complex Hermitian matrix. + upper (bool, optional): Whether to use the upper or lower triangle of the matrix. + Default is True (upper triangle). + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: The normalized eigenvectors. The column v[:, i] is the + eigenvector corresponding to the i-th eigenvalue. + + Note: + The input matrix is assumed to be symmetric (or Hermitian). Only the + upper triangle (if upper=True) or lower triangle (if upper=False) is used. + No checks for symmetry are performed. + + Example: + >>> A = mx.array([[1., -2.], [-2., 1.]]) + >>> v = mx.linalg.eigh(A) + >>> v + array([[ 0.707107, -0.707107], + [ 0.707107, 0.707107]], dtype=float32) + )pbdoc"); } From a1b7593e09ebfef6a94806bd9caa4fa1626a579a Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 29 Aug 2024 17:35:54 +0200 Subject: [PATCH 06/19] fixed typo --- mlx/backend/metal/primitives.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 43d45e618..d08035b54 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -405,7 +405,7 @@ void Eigvalsh::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigvalsh NYI."); } -void Eigvh::eval_gpu(const std::vector& inputs, array& out) { +void Eigh::eval_gpu(const std::vector& inputs, array& out) { throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); } From ffe26c8348c58bd54171df95dd130c7632041487 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 21:21:49 +0200 Subject: [PATCH 07/19] merge merge Eighvalsh and Eigh into a single primitive --- mlx/backend/common/eigvalsh.cpp | 78 ++++++++++++--------------------- mlx/primitives.cpp | 30 ++++++++----- mlx/primitives.h | 31 ++++--------- 3 files changed, 56 insertions(+), 83 deletions(-) diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index e7a97cdc5..734eb3fd5 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ b/mlx/backend/common/eigvalsh.cpp @@ -98,12 +98,14 @@ int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) { } // namespace -void eigvalsh_impl( +void eigh_impl( const array& a, array& values, - bool upper) { - char jobz = 'N'; // Only compute eigenvalues - char uplo = (upper) ? 'U' : 'L'; + array& vectors, + bool upper, + bool compute_eigenvectors) { + char jobz = compute_eigenvectors ? 'V' : 'N'; + char uplo = upper ? 
'U' : 'L'; array buffer = copy(a); @@ -116,43 +118,15 @@ void eigvalsh_impl( float* matrix = buffer.data(); float* w = values.data(); - for (int i = 0; i < num_matrices; i++) { - int info = ssyevd_wrapper(jobz, uplo, matrix, w, N); - - if (info != 0) { - std::stringstream msg; - msg << "[eigvalsh] Eigenvalue decomposition failed with error code " << info; - throw std::runtime_error(msg.str()); - } - - matrix += N * N; - w += N; + if (compute_eigenvectors) { + std::vector vectors_shape = a.shape(); + vectors = array(allocator::malloc(a.size() * size_of(a.dtype())), vectors_shape, a.dtype()); } -} -void eigh_impl( - const array& a, - array& vectors, - bool upper) { - char jobz = 'V'; // Compute both eigenvalues and eigenvectors - char uplo = (upper) ? 'U' : 'L'; - - array buffer = copy(a); - - const int N = static_cast(a.shape(-1)); - const int num_matrices = static_cast(a.size() / (N * N)); - - std::vector vectors_shape = a.shape(); - vectors = array(allocator::malloc(a.size() * size_of(a.dtype())), vectors_shape, a.dtype()); - - float* matrix = buffer.data(); - float* vecs = vectors.data(); - - // Temporary buffer for eigenvalues (we don't return these) - std::vector w(N); + float* vecs = compute_eigenvectors ? vectors.data() : nullptr; for (int i = 0; i < num_matrices; i++) { - int info = ssyevd_wrapper(jobz, uplo, matrix, w.data(), N); + int info = ssyevd_wrapper(jobz, uplo, matrix, w, N); if (info != 0) { std::stringstream msg; @@ -160,28 +134,32 @@ void eigh_impl( throw std::runtime_error(msg.str()); } - // Copy eigenvectors to the output array - std::copy(matrix, matrix + N * N, vecs); + if (compute_eigenvectors) { + // Copy eigenvectors to the output array + std::copy(matrix, matrix + N * N, vecs); + vecs += N * N; + } matrix += N * N; - vecs += N * N; + w += N; } } -void Eigvalsh::eval( - const std::vector& inputs, array& output) { +void EighPrimitive::eval( + const std::vector& inputs, + std::vector& outputs) { if (inputs[0].dtype() != float32) { - throw std::runtime_error("[Eigvalsh::eval] only supports float32."); + throw std::runtime_error("[EighPrimitive::eval] only supports float32."); } - eigvalsh_impl(inputs[0], output, upper_); -} -void Eigh::eval( - const std::vector& inputs, array& output) { - if (inputs[0].dtype() != float32) { - throw std::runtime_error("[Eigh::eval] only supports float32."); + array values, vectors; + eigh_impl(inputs[0], values, vectors, upper_, compute_eigenvectors_); + + if (compute_eigenvectors_) { + outputs = {values, vectors}; + } else { + outputs = {values}; } - eigh_impl(inputs[0], output, upper_); } } // namespace mlx::core \ No newline at end of file diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 87ed99d96..52da9c2f6 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -767,20 +767,30 @@ std::pair, std::vector> Cholesky::vmap( return {{linalg::cholesky(a, upper_, stream())}, {ax}}; } -std::pair, std::vector> Eigvalsh::vmap( - const std::vector& inputs, - const std::vector& axes) { - auto ax = axes[0] >= 0 ? 0 : -1; - auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - return {{linalg::eigvalsh(a, upper_, stream())}, {ax}}; -} - -std::pair, std::vector> Eigh::vmap( +std::pair, std::vector> EighPrimitive::vmap( const std::vector& inputs, const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + auto ax = axes[0] >= 0 ? 0 : -1; auto a = axes[0] > 0 ? 
moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - return {{linalg::eigh(a, upper_, stream())}, {ax}}; + + array values, vectors; + linalg::eigh_impl(a, values, vectors, upper_, compute_eigenvectors_); + + std::vector outputs; + std::vector out_axes; + + outputs.push_back(values); + out_axes.push_back(ax); + + if (compute_eigenvectors_) { + outputs.push_back(vectors); + out_axes.push_back(ax); + } + + return {outputs, out_axes}; } std::vector Concatenate::vjp( diff --git a/mlx/primitives.h b/mlx/primitives.h index 328219140..25b80e010 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2196,36 +2196,21 @@ class Cholesky : public UnaryPrimitive { bool upper_; }; -class Eigvalsh : public UnaryPrimitive { +class EighPrimitive : public Primitive { public: - explicit Eigvalsh(Stream stream, bool upper) - : UnaryPrimitive(stream), upper_(upper) {} - - void eval_cpu(const std::vector& inputs, array& out) override; - void eval_gpu(const std::vector& inputs, array& out) override; - - DEFINE_VMAP() - DEFINE_PRINT(Eigvalsh) - - private: - void eval(const std::vector& inputs, array& output); - bool upper_; -}; + explicit EighPrimitive(Stream stream, bool upper, bool compute_eigenvectors) + : Primitive(stream), upper_(upper), compute_eigenvectors_(compute_eigenvectors) {} -class Eigh : public UnaryPrimitive { - public: - explicit Eigh(Stream stream, bool upper) - : UnaryPrimitive(stream), upper_(upper) {} - - void eval_cpu(const std::vector& inputs, array& out) override; - void eval_gpu(const std::vector& inputs, array& out) override; + void eval_cpu(const std::vector& inputs, std::vector& outputs) override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_VMAP() - DEFINE_PRINT(Eigh) + DEFINE_PRINT(EighPrimitive) private: - void eval(const std::vector& inputs, array& output); + void eval(const std::vector& inputs, std::vector& outputs); bool upper_; + bool compute_eigenvectors_; }; } // namespace mlx::core From d6e8cab816a70fd0879285e7594c42f98af39ab5 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 21:25:02 +0200 Subject: [PATCH 08/19] use the same primate with the flag --- mlx/linalg.cpp | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 7840ddf01..d27731961 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -485,11 +485,11 @@ array eigvalsh( return array( out_shape, a.dtype(), - std::make_shared(to_stream(s), upper), + std::make_shared(to_stream(s), upper, false), {astype(a, a.dtype(), s)}); } -array eigh( +std::pair eigh( const array& a, bool upper /* = false */, StreamOrDevice s /* = {} */) { @@ -513,11 +513,12 @@ array eigh( "[linalg::eigh] Eigenvectors are only defined for square matrices."); } - return array( - a.shape(), - a.dtype(), - std::make_shared(to_stream(s), upper), + auto out = array::make_arrays( + {std::vector(a.shape().begin(), a.shape().end() - 1), a.shape()}, + {a.dtype(), a.dtype()}, + std::make_shared(to_stream(s), upper, true), {astype(a, a.dtype(), s)}); + return std::make_pair(out[0], out[1]); } } // namespace mlx::core::linalg From 052e49db35868227ea38a77a2ccae1eaea27b75e Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 21:33:51 +0200 Subject: [PATCH 09/19] fix primatives --- mlx/backend/accelerate/primitives.cpp | 3 +-- mlx/backend/common/default_primitives.cpp | 3 +-- mlx/backend/no_cpu/primitives.cpp | 3 +-- mlx/backend/no_metal/primitives.cpp | 3 +-- mlx/linalg.h | 2 +- 5 files changed, 5 insertions(+), 9 deletions(-) diff 
--git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index e4f07ccfb..3c1b88903 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -81,8 +81,7 @@ DEFAULT_MULTI(SVD) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) -DEFAULT(Eigvalsh) -DEFAULT(Eigh) +DEFAULT(EighPrimitive) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 7588596fc..c8207732f 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -114,8 +114,7 @@ DEFAULT(Tanh) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) -DEFAULT(Eigvalsh) -DEFAULT(Eigh) +DEFAULT(EighPrimitive) namespace { diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index b5e62dab4..4429ad68d 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -48,8 +48,7 @@ NO_CPU(Divide) NO_CPU_MULTI(DivMod) NO_CPU(NumberOfElements) NO_CPU(Remainder) -NO_CPU(Eigvalsh) -NO_CPU(Eigh) +NO_CPU(EighPrimitive) NO_CPU(Equal) NO_CPU(Erf) NO_CPU(ErfInv) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index b4e2974b5..bf8bcbb22 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -112,8 +112,7 @@ NO_GPU(Tanh) NO_GPU(Transpose) NO_GPU(Inverse) NO_GPU(Cholesky) -NO_GPU(Eigvalsh) -NO_GPU(Eigh) +NO_GPU(EighPrimitive) NO_GPU(View) namespace fast { diff --git a/mlx/linalg.h b/mlx/linalg.h index 63328c637..ff7203b8f 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -85,6 +85,6 @@ array cross( array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {}); -array eigh(const array& a, bool upper = false, StreamOrDevice s = {}); +std::pair eigh(const array& a, bool upper = false, StreamOrDevice s = {}); } // namespace mlx::core::linalg From 273da20bfc8bfcc8c3e203e12eb392cff7a675c0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 22:22:28 +0200 Subject: [PATCH 10/19] use MULTI --- mlx/backend/accelerate/primitives.cpp | 2 +- mlx/backend/common/default_primitives.cpp | 2 +- mlx/backend/common/eigvalsh.cpp | 27 +++++++++++++++++++---- mlx/backend/no_cpu/primitives.cpp | 2 +- mlx/backend/no_metal/primitives.cpp | 2 +- mlx/primitives.cpp | 16 +++++--------- mlx/primitives.h | 19 +++++++++++++++- 7 files changed, 51 insertions(+), 19 deletions(-) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 3c1b88903..352ca9f2f 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -81,7 +81,7 @@ DEFAULT_MULTI(SVD) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) -DEFAULT(EighPrimitive) +DEFAULT_MULTI(EighPrimitive) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index c8207732f..2418b18a6 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -114,7 +114,7 @@ DEFAULT(Tanh) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) -DEFAULT(EighPrimitive) +DEFAULT_MULTI(EighPrimitive) namespace { diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index 734eb3fd5..c21f2bcb4 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ b/mlx/backend/common/eigvalsh.cpp @@ -1,5 
+1,6 @@ // Copyright © 2023-2024 Apple Inc. +#include "mlx/array.h" #include "mlx/allocator.h" #include "mlx/backend/common/copy.h" #include "mlx/linalg.h" @@ -148,14 +149,32 @@ void eigh_impl( void EighPrimitive::eval( const std::vector& inputs, std::vector& outputs) { - if (inputs[0].dtype() != float32) { - throw std::runtime_error("[EighPrimitive::eval] only supports float32."); + // Validate the number of inputs + if (inputs.size() != 1) { + throw std::invalid_argument("[EighPrimitive::eval] Expected exactly one input array."); } - array values, vectors; - eigh_impl(inputs[0], values, vectors, upper_, compute_eigenvectors_); + const array& input = inputs[0]; + // Ensure the input array is evaluated before accessing its data + const_cast(input).eval(); + + // Validate the data type + Dtype input_dtype = input.dtype(); // Changed from 'dtype_t' to 'Dtype' + + // Validate the number of dimensions (expecting at least 2D) + if (input.ndim() < 2) { + throw std::invalid_argument("[EighPrimitive::eval] Input array must be at least 2-dimensional."); + } + + array values{}; + array vectors{}; + eigh_impl(input, values, vectors, upper_, compute_eigenvectors_); + + // Ensure the output arrays are evaluated + values.eval(); if (compute_eigenvectors_) { + vectors.eval(); outputs = {values, vectors}; } else { outputs = {values}; diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp index 4429ad68d..180ae3b17 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -48,7 +48,7 @@ NO_CPU(Divide) NO_CPU_MULTI(DivMod) NO_CPU(NumberOfElements) NO_CPU(Remainder) -NO_CPU(EighPrimitive) +NO_CPU_MULTI(EighPrimitive) NO_CPU(Equal) NO_CPU(Erf) NO_CPU(ErfInv) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index bf8bcbb22..77ca47ab4 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -112,7 +112,7 @@ NO_GPU(Tanh) NO_GPU(Transpose) NO_GPU(Inverse) NO_GPU(Cholesky) -NO_GPU(EighPrimitive) +NO_GPU_MULTI(EighPrimitive) NO_GPU(View) namespace fast { diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 52da9c2f6..64224b876 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -776,20 +776,16 @@ std::pair, std::vector> EighPrimitive::vmap( auto ax = axes[0] >= 0 ? 0 : -1; auto a = axes[0] > 0 ? 
moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - array values, vectors; - linalg::eigh_impl(a, values, vectors, upper_, compute_eigenvectors_); - std::vector outputs; - std::vector out_axes; - - outputs.push_back(values); - out_axes.push_back(ax); - if (compute_eigenvectors_) { - outputs.push_back(vectors); - out_axes.push_back(ax); + auto [values, vectors] = linalg::eigh(a, upper_, stream()); + outputs = {values, vectors}; + } else { + outputs = {linalg::eigvalsh(a, upper_, stream())}; } + std::vector out_axes(outputs.size(), ax); + return {outputs, out_axes}; } diff --git a/mlx/primitives.h b/mlx/primitives.h index 25b80e010..91f869ffc 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2195,7 +2195,6 @@ class Cholesky : public UnaryPrimitive { void eval(const std::vector& inputs, array& output); bool upper_; }; - class EighPrimitive : public Primitive { public: explicit EighPrimitive(Stream stream, bool upper, bool compute_eigenvectors) @@ -2207,6 +2206,24 @@ class EighPrimitive : public Primitive { DEFINE_VMAP() DEFINE_PRINT(EighPrimitive) + std::vector> output_shapes( + const std::vector& inputs) override { + auto shape = inputs[0].shape(); + shape.pop_back(); // Remove last dimension for eigenvalues + if (compute_eigenvectors_) { + return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors + } else { + return {shape}; // Only eigenvalues + } + } + + bool is_equivalent(const Primitive& other) const override { + if (auto* p = dynamic_cast(&other)) { + return upper_ == p->upper_ && compute_eigenvectors_ == p->compute_eigenvectors_; + } + return false; + } + private: void eval(const std::vector& inputs, std::vector& outputs); bool upper_; From c76b8a194da7f4f133b2f20eac241520af7ed9dd Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 22:26:30 +0200 Subject: [PATCH 11/19] fix eval_gpu --- mlx/backend/metal/primitives.cpp | 8 ++------ mlx/primitives.h | 1 + 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index d08035b54..ee3d5bb5e 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -401,12 +401,8 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } -void Eigvalsh::eval_gpu(const std::vector& inputs, array& out) { - throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigvalsh NYI."); -} - -void Eigh::eval_gpu(const std::vector& inputs, array& out) { - throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI."); +void EighPrimitive::eval_gpu(const std::vector& inputs, array& out) { + throw std::runtime_error("[Eigvalsh::eval_gpu] Metal EighPrimitive NYI."); } void View::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/primitives.h b/mlx/primitives.h index 91f869ffc..f5ebf95e7 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2195,6 +2195,7 @@ class Cholesky : public UnaryPrimitive { void eval(const std::vector& inputs, array& output); bool upper_; }; + class EighPrimitive : public Primitive { public: explicit EighPrimitive(Stream stream, bool upper, bool compute_eigenvectors) From e07c9e8a819c37903f2c80a51a2fef7f9346b959 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Sun, 15 Sep 2024 22:29:51 +0200 Subject: [PATCH 12/19] fix decleration --- mlx/backend/metal/primitives.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index ee3d5bb5e..c7fbfd7b5 
100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -401,7 +401,7 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } -void EighPrimitive::eval_gpu(const std::vector& inputs, array& out) { +void EighPrimitive::eval_gpu(const std::vector& inputs, std::vector& outputs) { throw std::runtime_error("[Eigvalsh::eval_gpu] Metal EighPrimitive NYI."); } From 5cee17f09ba047e71c0b5f8350dd7e56fece1791 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 1 Oct 2024 17:56:58 +0200 Subject: [PATCH 13/19] rename EighPrimitive to Eigh --- mlx/backend/accelerate/primitives.cpp | 2 +- mlx/backend/common/default_primitives.cpp | 2 +- mlx/backend/common/eigvalsh.cpp | 6 +++--- mlx/backend/metal/primitives.cpp | 4 ++-- mlx/backend/no_cpu/primitives.cpp | 2 +- mlx/backend/no_metal/primitives.cpp | 2 +- mlx/linalg.cpp | 1 - mlx/primitives.cpp | 2 +- mlx/primitives.h | 8 ++++---- 9 files changed, 14 insertions(+), 15 deletions(-) diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 352ca9f2f..1f80224ad 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -81,7 +81,7 @@ DEFAULT_MULTI(SVD) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) -DEFAULT_MULTI(EighPrimitive) +DEFAULT_MULTI(Eigh) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index 2418b18a6..edc7192bd 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -114,7 +114,7 @@ DEFAULT(Tanh) DEFAULT(Transpose) DEFAULT(Inverse) DEFAULT(Cholesky) -DEFAULT_MULTI(EighPrimitive) +DEFAULT_MULTI(Eigh) namespace { diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index c21f2bcb4..c49b291b0 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ b/mlx/backend/common/eigvalsh.cpp @@ -146,12 +146,12 @@ void eigh_impl( } } -void EighPrimitive::eval( +void Eigh::eval( const std::vector& inputs, std::vector& outputs) { // Validate the number of inputs if (inputs.size() != 1) { - throw std::invalid_argument("[EighPrimitive::eval] Expected exactly one input array."); + throw std::invalid_argument("[Eigh::eval] Expected exactly one input array."); } const array& input = inputs[0]; @@ -164,7 +164,7 @@ void EighPrimitive::eval( // Validate the number of dimensions (expecting at least 2D) if (input.ndim() < 2) { - throw std::invalid_argument("[EighPrimitive::eval] Input array must be at least 2-dimensional."); + throw std::invalid_argument("[Eigh::eval] Input array must be at least 2-dimensional."); } array values{}; diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index c7fbfd7b5..be7c3469e 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -401,8 +401,8 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } -void EighPrimitive::eval_gpu(const std::vector& inputs, std::vector& outputs) { - throw std::runtime_error("[Eigvalsh::eval_gpu] Metal EighPrimitive NYI."); +void Eigh::eval_gpu(const std::vector& inputs, std::vector& outputs) { + throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI."); } void View::eval_gpu(const std::vector& inputs, array& out) { diff --git a/mlx/backend/no_cpu/primitives.cpp 
b/mlx/backend/no_cpu/primitives.cpp index 180ae3b17..c87fcc8bb 100644 --- a/mlx/backend/no_cpu/primitives.cpp +++ b/mlx/backend/no_cpu/primitives.cpp @@ -48,7 +48,7 @@ NO_CPU(Divide) NO_CPU_MULTI(DivMod) NO_CPU(NumberOfElements) NO_CPU(Remainder) -NO_CPU_MULTI(EighPrimitive) +NO_CPU_MULTI(Eigh) NO_CPU(Equal) NO_CPU(Erf) NO_CPU(ErfInv) diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 77ca47ab4..aaee51d83 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -112,7 +112,7 @@ NO_GPU(Tanh) NO_GPU(Transpose) NO_GPU(Inverse) NO_GPU(Cholesky) -NO_GPU_MULTI(EighPrimitive) +NO_GPU_MULTI(Eigh) NO_GPU(View) namespace fast { diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index d27731961..71bfed053 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -382,7 +382,6 @@ array cholesky_inv( } } -<<<<<<< HEAD array cross( const array& a, const array& b, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 64224b876..0ebed3bd1 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -767,7 +767,7 @@ std::pair, std::vector> Cholesky::vmap( return {{linalg::cholesky(a, upper_, stream())}, {ax}}; } -std::pair, std::vector> EighPrimitive::vmap( +std::pair, std::vector> Eigh::vmap( const std::vector& inputs, const std::vector& axes) { assert(inputs.size() == 1); diff --git a/mlx/primitives.h b/mlx/primitives.h index f5ebf95e7..082ae105d 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2196,16 +2196,16 @@ class Cholesky : public UnaryPrimitive { bool upper_; }; -class EighPrimitive : public Primitive { +class Eigh : public Primitive { public: - explicit EighPrimitive(Stream stream, bool upper, bool compute_eigenvectors) + explicit Eigh(Stream stream, bool upper, bool compute_eigenvectors) : Primitive(stream), upper_(upper), compute_eigenvectors_(compute_eigenvectors) {} void eval_cpu(const std::vector& inputs, std::vector& outputs) override; void eval_gpu(const std::vector& inputs, std::vector& outputs) override; DEFINE_VMAP() - DEFINE_PRINT(EighPrimitive) + DEFINE_PRINT(Eigh) std::vector> output_shapes( const std::vector& inputs) override { @@ -2219,7 +2219,7 @@ class EighPrimitive : public Primitive { } bool is_equivalent(const Primitive& other) const override { - if (auto* p = dynamic_cast(&other)) { + if (auto* p = dynamic_cast(&other)) { return upper_ == p->upper_ && compute_eigenvectors_ == p->compute_eigenvectors_; } return false; From d6a5fb037e5257fe7c084731806f047765d189b8 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 1 Oct 2024 17:59:10 +0200 Subject: [PATCH 14/19] tests --- python/tests/test_linalg.py | 67 +++++++++++++++++++++++++++++++++++++ tests/linalg_tests.cpp | 51 ++++++++++++++++++++++------ 2 files changed, 108 insertions(+), 10 deletions(-) diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 1d336ee21..39c0449b4 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -309,6 +309,9 @@ def test_eigvalsh(self): self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + # Test that eigenvalues are in ascending order for each batch + self.assertTrue(mx.all(eigenvalues_mx[:, 1:] >= eigenvalues_mx[:, :-1])) + # Test error cases with self.assertRaises(ValueError): mx.linalg.eigvalsh(mx.array([1.0, 2.0])) # 1D array @@ -316,6 +319,70 @@ def test_eigvalsh(self): with self.assertRaises(ValueError): mx.linalg.eigvalsh(mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) # Non-square matrix + def test_eigh(self): + # Test a simple 2x2 
symmetric matrix + A_mx = mx.array([[1.0, 2.0], [2.0, 4.0]], dtype=mx.float32) + A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32) + + eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(A_mx) + eigenvalues_np, eigenvectors_np = np.linalg.eigh(A_np) + + self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5)) + + # Test a larger random symmetric matrix + n = 5 + rng = np.random.default_rng(42) + B_np = rng.random((n, n)).astype(np.float32) + B_np = (B_np + B_np.T) / 2 # Make sure B is symmetric + B_mx = mx.array(B_np) + + eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(B_mx) + eigenvalues_np, eigenvectors_np = np.linalg.eigh(B_np) + + self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5)) + + # Test that eigenvalues are in ascending order + self.assertTrue(mx.all(eigenvalues_mx[1:] >= eigenvalues_mx[:-1])) + + # Test orthogonality of eigenvectors + identity = mx.eye(n) + self.assertTrue(mx.allclose(eigenvectors_mx.T @ eigenvectors_mx, identity, atol=1e-5)) + + # Test with upper=False + eigenvalues_mx_lower, eigenvectors_mx_lower = mx.linalg.eigh(B_mx, upper=False) + eigenvalues_np_lower, eigenvectors_np_lower = np.linalg.eigh(B_np, UPLO='L') + + self.assertTrue(mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5)) + self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx_lower), mx.abs(mx.array(eigenvectors_np_lower)), atol=1e-5)) + + # Test with batched input + C_np = rng.random((3, n, n)).astype(np.float32) + C_np = (C_np + np.transpose(C_np, (0, 2, 1))) / 2 # Make sure C is symmetric for each batch + C_mx = mx.array(C_np) + + eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(C_mx) + eigenvalues_np, eigenvectors_np = np.linalg.eigh(C_np) + + self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5)) + + # Test that eigenvalues are in ascending order for each batch + self.assertTrue(mx.all(eigenvalues_mx[:, 1:] >= eigenvalues_mx[:, :-1])) + + # Test orthogonality of eigenvectors for each batch + identity = mx.eye(n) + for i in range(3): + self.assertTrue(mx.allclose(eigenvectors_mx[i].T @ eigenvectors_mx[i], identity, atol=1e-5)) + + # Test error cases + with self.assertRaises(ValueError): + mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array + + with self.assertRaises(ValueError): + mx.linalg.eigh(mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) # Non-square matrix + if __name__ == "__main__": unittest.main() diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 0a89bcf22..242c7b851 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -349,7 +349,6 @@ TEST_CASE("test matrix cholesky") { .item()); } -<<<<<<< HEAD TEST_CASE("test matrix pseudo-inverse") { // 0D and 1D throw CHECK_THROWS(linalg::pinv(array(0.0), Device::cpu)); @@ -452,7 +451,7 @@ TEST_CASE("test matrix eigvalsh") { array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu)); // Test a simple 2x2 symmetric matrix - array A = array({{1.0, 2.0}, {2.0, 4.0}}); + array A = array({{1.0f, 2.0f}, {2.0f, 4.0f}}, float32); array eigenvalues = linalg::eigvalsh(A, /* upper = */ true, Device::cpu); // Expected eigenvalues (calculated analytically) @@ -465,7 +464,7 @@ TEST_CASE("test matrix eigvalsh") { /* 
atol = */ 1e-5) .item()); - // Test a larger symmetric matrix + /// Test a larger symmetric matrix const auto prng_key = random::key(42); const auto B = random::normal({5, 5}, prng_key); const auto B_sym = 0.5 * (B + transpose(B)); // Make sure B is symmetric @@ -476,15 +475,47 @@ TEST_CASE("test matrix eigvalsh") { CHECK(B_eigenvalues.dtype() == float32); CHECK(B_eigenvalues.shape() == std::vector{5}); CHECK(all(isfinite(B_eigenvalues)).item()); - CHECK(all(B_eigenvalues [1:] >= B_eigenvalues[:-1]).item()); + CHECK(all(B_eigenvalues[slice(1, 5)] >= B_eigenvalues[slice(0, 4)]).item()); - // Reconstruct the matrix using eigendecomposition and check if it's close to - // the original - const auto D = diag(B_eigenvalues); - const auto V = linalg::eigh(B_sym, /* upper = */ true, Device::cpu) - .second; // Assuming eigh is implemented - const auto B_reconstructed = matmul(matmul(V, D), transpose(V)); + // Reconstruct the matrix using eigendecomposition and check if it's close to the original + const auto [eigenvalues, eigenvectors] = linalg::eigh(B_sym, /* upper = */ true, Device::cpu); + const auto B_reconstructed = matmul(matmul(eigenvectors, diag(eigenvalues)), transpose(eigenvectors)); CHECK(allclose(B_reconstructed, B_sym, /* rtol = */ 1e-5, /* atol = */ 1e-5) .item()); + + // Check that eigvalsh and eigh produce the same eigenvalues + CHECK(allclose(B_eigenvalues, eigenvalues, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); +} + +TEST_CASE("test matrix eigh") { + // 0D and 1D throw + CHECK_THROWS(linalg::eigh(array(0.0), /* upper = */ true, Device::cpu)); + CHECK_THROWS(linalg::eigh(array({0.0, 1.0}), /* upper = */ true, Device::cpu)); + + // Unsupported types throw + CHECK_THROWS(linalg::eigh(array({0, 1}, {1, 2}), /* upper = */ true, Device::cpu)); + + // Non-square throws + CHECK_THROWS(linalg::eigh(array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu)); + + // Test a simple 2x2 symmetric matrix + array A = array({{1.0f, 2.0f}, {2.0f, 4.0f}}, float32); + auto [eigenvalues, eigenvectors] = linalg::eigh(A, /* upper = */ true, Device::cpu); + + // Expected eigenvalues and eigenvectors (calculated analytically) + array expected_eigenvalues = array({0.0, 5.0}); + array expected_eigenvectors = array({{-0.4472136, 0.8944272}, {0.8944272, 0.4472136}}, float32); + + CHECK(allclose(eigenvalues, expected_eigenvalues, /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); + CHECK(allclose(abs(eigenvectors), abs(expected_eigenvectors), /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); + + // Verify orthogonality of eigenvectors + CHECK(allclose(matmul(transpose(eigenvectors), eigenvectors), eye(2), /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); + + // Verify eigendecomposition + CHECK(allclose(matmul(matmul(eigenvectors, diag(eigenvalues)), transpose(eigenvectors)), A, /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); } + + From f09376ef823fccc694dbd58535c6f37ce4b6ecd7 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Tue, 1 Oct 2024 18:24:08 +0200 Subject: [PATCH 15/19] tests --- tests/linalg_tests.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 242c7b851..1c4dea1f7 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -517,5 +517,3 @@ TEST_CASE("test matrix eigh") { // Verify eigendecomposition CHECK(allclose(matmul(matmul(eigenvectors, diag(eigenvalues)), transpose(eigenvectors)), A, /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); } - - From 25b79dba0e83ec145a139d3c0e4d7c57d9850f69 Mon Sep 17 00:00:00 2001 From: Awni 
Hannun Date: Tue, 22 Oct 2024 06:37:45 -0700 Subject: [PATCH 16/19] fix rebase and format --- mlx/backend/common/eigvalsh.cpp | 24 +++++---- mlx/backend/metal/primitives.cpp | 4 +- mlx/linalg.cpp | 10 ++-- mlx/linalg.h | 3 +- mlx/primitives.cpp | 8 +-- mlx/primitives.h | 19 ++++--- python/src/linalg.cpp | 23 ++++++--- python/tests/test_linalg.py | 88 ++++++++++++++++++++++++-------- tests/linalg_tests.cpp | 67 ++++++++++++++++++------ 9 files changed, 173 insertions(+), 73 deletions(-) diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp index c49b291b0..62eadf0dc 100644 --- a/mlx/backend/common/eigvalsh.cpp +++ b/mlx/backend/common/eigvalsh.cpp @@ -1,7 +1,7 @@ // Copyright © 2023-2024 Apple Inc. -#include "mlx/array.h" #include "mlx/allocator.h" +#include "mlx/array.h" #include "mlx/backend/common/copy.h" #include "mlx/linalg.h" #include "mlx/primitives.h" @@ -114,14 +114,20 @@ void eigh_impl( const int num_matrices = static_cast(a.size() / (N * N)); std::vector values_shape = {num_matrices, N}; - values = array(allocator::malloc(num_matrices * N * size_of(a.dtype())), values_shape, a.dtype()); + values = array( + allocator::malloc(num_matrices * N * size_of(a.dtype())), + values_shape, + a.dtype()); float* matrix = buffer.data(); float* w = values.data(); if (compute_eigenvectors) { std::vector vectors_shape = a.shape(); - vectors = array(allocator::malloc(a.size() * size_of(a.dtype())), vectors_shape, a.dtype()); + vectors = array( + allocator::malloc(a.size() * size_of(a.dtype())), + vectors_shape, + a.dtype()); } float* vecs = compute_eigenvectors ? vectors.data() : nullptr; @@ -146,12 +152,11 @@ void eigh_impl( } } -void Eigh::eval( - const std::vector& inputs, - std::vector& outputs) { +void Eigh::eval(const std::vector& inputs, std::vector& outputs) { // Validate the number of inputs if (inputs.size() != 1) { - throw std::invalid_argument("[Eigh::eval] Expected exactly one input array."); + throw std::invalid_argument( + "[Eigh::eval] Expected exactly one input array."); } const array& input = inputs[0]; @@ -160,11 +165,12 @@ void Eigh::eval( const_cast(input).eval(); // Validate the data type - Dtype input_dtype = input.dtype(); // Changed from 'dtype_t' to 'Dtype' + Dtype input_dtype = input.dtype(); // Changed from 'dtype_t' to 'Dtype' // Validate the number of dimensions (expecting at least 2D) if (input.ndim() < 2) { - throw std::invalid_argument("[Eigh::eval] Input array must be at least 2-dimensional."); + throw std::invalid_argument( + "[Eigh::eval] Input array must be at least 2-dimensional."); } array values{}; diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index be7c3469e..e5a7d885b 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -401,7 +401,9 @@ void Cholesky::eval_gpu(const std::vector& inputs, array& out) { "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } -void Eigh::eval_gpu(const std::vector& inputs, std::vector& outputs) { +void Eigh::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigh NYI."); } diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index 71bfed053..b7a58736e 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -484,14 +484,12 @@ array eigvalsh( return array( out_shape, a.dtype(), - std::make_shared(to_stream(s), upper, false), + std::make_shared(to_stream(s), upper, false), {astype(a, a.dtype(), s)}); } -std::pair eigh( - const array& a, - bool upper /* = false */, - 
StreamOrDevice s /* = {} */) { +std::pair +eigh(const array& a, bool upper /* = false */, StreamOrDevice s /* = {} */) { if (a.dtype() != float32) { std::ostringstream msg; msg << "[linalg::eigh] Arrays must be type float32. Received array " @@ -515,7 +513,7 @@ std::pair eigh( auto out = array::make_arrays( {std::vector(a.shape().begin(), a.shape().end() - 1), a.shape()}, {a.dtype(), a.dtype()}, - std::make_shared(to_stream(s), upper, true), + std::make_shared(to_stream(s), upper, true), {astype(a, a.dtype(), s)}); return std::make_pair(out[0], out[1]); } diff --git a/mlx/linalg.h b/mlx/linalg.h index ff7203b8f..1b29b5dc5 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -85,6 +85,7 @@ array cross( array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {}); -std::pair eigh(const array& a, bool upper = false, StreamOrDevice s = {}); +std::pair +eigh(const array& a, bool upper = false, StreamOrDevice s = {}); } // namespace mlx::core::linalg diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 0ebed3bd1..43148ad48 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -772,10 +772,10 @@ std::pair, std::vector> Eigh::vmap( const std::vector& axes) { assert(inputs.size() == 1); assert(axes.size() == 1); - + auto ax = axes[0] >= 0 ? 0 : -1; auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; - + std::vector outputs; if (compute_eigenvectors_) { auto [values, vectors] = linalg::eigh(a, upper_, stream()); @@ -783,9 +783,9 @@ std::pair, std::vector> Eigh::vmap( } else { outputs = {linalg::eigvalsh(a, upper_, stream())}; } - + std::vector out_axes(outputs.size(), ax); - + return {outputs, out_axes}; } diff --git a/mlx/primitives.h b/mlx/primitives.h index 082ae105d..64d9ca94e 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2199,10 +2199,14 @@ class Cholesky : public UnaryPrimitive { class Eigh : public Primitive { public: explicit Eigh(Stream stream, bool upper, bool compute_eigenvectors) - : Primitive(stream), upper_(upper), compute_eigenvectors_(compute_eigenvectors) {} + : Primitive(stream), + upper_(upper), + compute_eigenvectors_(compute_eigenvectors) {} - void eval_cpu(const std::vector& inputs, std::vector& outputs) override; - void eval_gpu(const std::vector& inputs, std::vector& outputs) override; + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override; + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; DEFINE_VMAP() DEFINE_PRINT(Eigh) @@ -2210,17 +2214,18 @@ class Eigh : public Primitive { std::vector> output_shapes( const std::vector& inputs) override { auto shape = inputs[0].shape(); - shape.pop_back(); // Remove last dimension for eigenvalues + shape.pop_back(); // Remove last dimension for eigenvalues if (compute_eigenvectors_) { - return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors + return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors } else { - return {shape}; // Only eigenvalues + return {shape}; // Only eigenvalues } } bool is_equivalent(const Primitive& other) const override { if (auto* p = dynamic_cast(&other)) { - return upper_ == p->upper_ && compute_eigenvectors_ == p->compute_eigenvectors_; + return upper_ == p->upper_ && + compute_eigenvectors_ == p->compute_eigenvectors_; } return false; } diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index b216a7ee0..869444cd3 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -444,18 +444,21 @@ void init_linalg(nb::module_& parent_module) { )pbdoc"); m.def( "eigh", - &eigh, + 
[](const array& a, bool upper, StreamOrDevice s) { + auto result = eigh(a, upper, s); + return nb::make_tuple(result.first, result.second); + }, "a"_a, "upper"_a = true, nb::kw_only(), "stream"_a = nb::none(), nb::sig( - "def eigh(a: array, upper: bool = True, *, stream: Union[None, Stream, Device] = None) -> array"), + "def eigh(a: array, upper: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"), R"pbdoc( - Compute the eigenvectors of a complex Hermitian or real symmetric matrix. + Compute the eigenvalues and eigenvectors of a complex Hermitian or real symmetric matrix. This function supports arrays with at least 2 dimensions. When the input - has more than two dimensions, the eigenvectors are computed for each matrix + has more than two dimensions, the eigenvalues and eigenvectors are computed for each matrix in the last two dimensions of ``a``. Args: @@ -466,8 +469,10 @@ void init_linalg(nb::module_& parent_module) { in which case the default stream of the default device is used. Returns: - array: The normalized eigenvectors. The column v[:, i] is the - eigenvector corresponding to the i-th eigenvalue. + Tuple[array, array]: A tuple containing: + - The eigenvalues in ascending order. + - The normalized eigenvectors. The column v[:, i] is the + eigenvector corresponding to the i-th eigenvalue. Note: The input matrix is assumed to be symmetric (or Hermitian). Only the @@ -476,9 +481,11 @@ void init_linalg(nb::module_& parent_module) { Example: >>> A = mx.array([[1., -2.], [-2., 1.]]) - >>> v = mx.linalg.eigh(A) + >>> w, v = mx.linalg.eigh(A) + >>> w + array([-1., 3.], dtype=float32) >>> v array([[ 0.707107, -0.707107], - [ 0.707107, 0.707107]], dtype=float32) + [ 0.707107, 0.707107]], dtype=float32) )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index 39c0449b4..edf671b88 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -276,7 +276,9 @@ def test_eigvalsh(self): eigenvalues_mx = mx.linalg.eigvalsh(A_mx) eigenvalues_np = np.linalg.eigvalsh(A_np) - self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) + ) # Test a larger random symmetric matrix n = 5 @@ -288,26 +290,34 @@ def test_eigvalsh(self): eigenvalues_mx = mx.linalg.eigvalsh(B_mx) eigenvalues_np = np.linalg.eigvalsh(B_np) - self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) + ) # Test that eigenvalues are in ascending order self.assertTrue(mx.all(eigenvalues_mx[1:] >= eigenvalues_mx[:-1])) # Test with upper=False eigenvalues_mx_lower = mx.linalg.eigvalsh(B_mx, upper=False) - eigenvalues_np_lower = np.linalg.eigvalsh(B_np, UPLO='L') + eigenvalues_np_lower = np.linalg.eigvalsh(B_np, UPLO="L") - self.assertTrue(mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5) + ) # Test with batched input C_np = rng.random((3, n, n)).astype(np.float32) - C_np = (C_np + np.transpose(C_np, (0, 2, 1))) / 2 # Make sure C is symmetric for each batch + C_np = ( + C_np + np.transpose(C_np, (0, 2, 1)) + ) / 2 # Make sure C is symmetric for each batch C_mx = mx.array(C_np) eigenvalues_mx = mx.linalg.eigvalsh(C_mx) eigenvalues_np = np.linalg.eigvalsh(C_np) - self.assertTrue(mx.allclose(eigenvalues_mx, 
mx.array(eigenvalues_np), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) + ) # Test that eigenvalues are in ascending order for each batch self.assertTrue(mx.all(eigenvalues_mx[:, 1:] >= eigenvalues_mx[:, :-1])) @@ -317,7 +327,9 @@ def test_eigvalsh(self): mx.linalg.eigvalsh(mx.array([1.0, 2.0])) # 1D array with self.assertRaises(ValueError): - mx.linalg.eigvalsh(mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) # Non-square matrix + mx.linalg.eigvalsh( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix def test_eigh(self): # Test a simple 2x2 symmetric matrix @@ -327,8 +339,14 @@ def test_eigh(self): eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(A_mx) eigenvalues_np, eigenvectors_np = np.linalg.eigh(A_np) - self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) - self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) + ) + self.assertTrue( + mx.allclose( + mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5 + ) + ) # Test a larger random symmetric matrix n = 5 @@ -340,33 +358,57 @@ def test_eigh(self): eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(B_mx) eigenvalues_np, eigenvectors_np = np.linalg.eigh(B_np) - self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) - self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) + ) + self.assertTrue( + mx.allclose( + mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5 + ) + ) # Test that eigenvalues are in ascending order self.assertTrue(mx.all(eigenvalues_mx[1:] >= eigenvalues_mx[:-1])) # Test orthogonality of eigenvectors identity = mx.eye(n) - self.assertTrue(mx.allclose(eigenvectors_mx.T @ eigenvectors_mx, identity, atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvectors_mx.T @ eigenvectors_mx, identity, atol=1e-5) + ) # Test with upper=False eigenvalues_mx_lower, eigenvectors_mx_lower = mx.linalg.eigh(B_mx, upper=False) - eigenvalues_np_lower, eigenvectors_np_lower = np.linalg.eigh(B_np, UPLO='L') + eigenvalues_np_lower, eigenvectors_np_lower = np.linalg.eigh(B_np, UPLO="L") - self.assertTrue(mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5)) - self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx_lower), mx.abs(mx.array(eigenvectors_np_lower)), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5) + ) + self.assertTrue( + mx.allclose( + mx.abs(eigenvectors_mx_lower), + mx.abs(mx.array(eigenvectors_np_lower)), + atol=1e-5, + ) + ) # Test with batched input C_np = rng.random((3, n, n)).astype(np.float32) - C_np = (C_np + np.transpose(C_np, (0, 2, 1))) / 2 # Make sure C is symmetric for each batch + C_np = ( + C_np + np.transpose(C_np, (0, 2, 1)) + ) / 2 # Make sure C is symmetric for each batch C_mx = mx.array(C_np) eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(C_mx) eigenvalues_np, eigenvectors_np = np.linalg.eigh(C_np) - self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)) - self.assertTrue(mx.allclose(mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5)) + self.assertTrue( + mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) + ) + self.assertTrue( + mx.allclose( + 
mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5 + ) + ) # Test that eigenvalues are in ascending order for each batch self.assertTrue(mx.all(eigenvalues_mx[:, 1:] >= eigenvalues_mx[:, :-1])) @@ -374,14 +416,20 @@ def test_eigh(self): # Test orthogonality of eigenvectors for each batch identity = mx.eye(n) for i in range(3): - self.assertTrue(mx.allclose(eigenvectors_mx[i].T @ eigenvectors_mx[i], identity, atol=1e-5)) + self.assertTrue( + mx.allclose( + eigenvectors_mx[i].T @ eigenvectors_mx[i], identity, atol=1e-5 + ) + ) # Test error cases with self.assertRaises(ValueError): mx.linalg.eigh(mx.array([1.0, 2.0])) # 1D array with self.assertRaises(ValueError): - mx.linalg.eigh(mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])) # Non-square matrix + mx.linalg.eigh( + mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) + ) # Non-square matrix if __name__ == "__main__": diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 1c4dea1f7..f149b7c9f 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -451,7 +451,7 @@ TEST_CASE("test matrix eigvalsh") { array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu)); // Test a simple 2x2 symmetric matrix - array A = array({{1.0f, 2.0f}, {2.0f, 4.0f}}, float32); + array A = array({{1, 2}, {2, 4}}); array eigenvalues = linalg::eigvalsh(A, /* upper = */ true, Device::cpu); // Expected eigenvalues (calculated analytically) @@ -475,45 +475,78 @@ TEST_CASE("test matrix eigvalsh") { CHECK(B_eigenvalues.dtype() == float32); CHECK(B_eigenvalues.shape() == std::vector{5}); CHECK(all(isfinite(B_eigenvalues)).item()); - CHECK(all(B_eigenvalues[slice(1, 5)] >= B_eigenvalues[slice(0, 4)]).item()); + CHECK(all(slice(B_eigenvalues, {1}, {5}) >= slice(B_eigenvalues, {0}, {4})) + .item()); - // Reconstruct the matrix using eigendecomposition and check if it's close to the original - const auto [eigenvalues, eigenvectors] = linalg::eigh(B_sym, /* upper = */ true, Device::cpu); - const auto B_reconstructed = matmul(matmul(eigenvectors, diag(eigenvalues)), transpose(eigenvectors)); + // Reconstruct the matrix using eigendecomposition and check if it's close to + // the original + const auto [eigenvalues_eigh, eigenvectors] = + linalg::eigh(B_sym, /* upper = */ true, Device::cpu); + const auto B_reconstructed = matmul( + matmul(eigenvectors, diag(eigenvalues_eigh)), transpose(eigenvectors)); CHECK(allclose(B_reconstructed, B_sym, /* rtol = */ 1e-5, /* atol = */ 1e-5) .item()); // Check that eigvalsh and eigh produce the same eigenvalues - CHECK(allclose(B_eigenvalues, eigenvalues, /* rtol = */ 1e-5, /* atol = */ 1e-5) - .item()); + CHECK( + allclose( + B_eigenvalues, eigenvalues_eigh, /* rtol = */ 1e-5, /* atol = */ 1e-5) + .item()); } TEST_CASE("test matrix eigh") { // 0D and 1D throw CHECK_THROWS(linalg::eigh(array(0.0), /* upper = */ true, Device::cpu)); - CHECK_THROWS(linalg::eigh(array({0.0, 1.0}), /* upper = */ true, Device::cpu)); + CHECK_THROWS( + linalg::eigh(array({0.0, 1.0}), /* upper = */ true, Device::cpu)); // Unsupported types throw - CHECK_THROWS(linalg::eigh(array({0, 1}, {1, 2}), /* upper = */ true, Device::cpu)); + CHECK_THROWS( + linalg::eigh(array({0, 1}, {1, 2}), /* upper = */ true, Device::cpu)); // Non-square throws - CHECK_THROWS(linalg::eigh(array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu)); + CHECK_THROWS(linalg::eigh( + array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu)); // Test a simple 2x2 symmetric matrix - array A = array({{1.0f, 2.0f}, {2.0f, 4.0f}}, 
float32); - auto [eigenvalues, eigenvectors] = linalg::eigh(A, /* upper = */ true, Device::cpu); + array A = array({1.0, 2.0, 2.0, 4.0}, {2, 2}, float32); + auto [eigenvalues, eigenvectors] = + linalg::eigh(A, /* upper = */ true, Device::cpu); // Expected eigenvalues and eigenvectors (calculated analytically) array expected_eigenvalues = array({0.0, 5.0}); - array expected_eigenvectors = array({{-0.4472136, 0.8944272}, {0.8944272, 0.4472136}}, float32); + array expected_eigenvectors = + array({-0.4472136f, 0.8944272f, 0.8944272f, 0.4472136f}, {2, 2}, float32); - CHECK(allclose(eigenvalues, expected_eigenvalues, /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); - CHECK(allclose(abs(eigenvectors), abs(expected_eigenvectors), /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); + CHECK(allclose( + eigenvalues, + expected_eigenvalues, + /* rtol = */ 1e-5, + /* atol = */ 1e-5) + .item()); + CHECK(allclose( + abs(eigenvectors), + abs(expected_eigenvectors), + /* rtol = */ 1e-5, + /* atol = */ 1e-5) + .item()); // Verify orthogonality of eigenvectors - CHECK(allclose(matmul(transpose(eigenvectors), eigenvectors), eye(2), /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); + CHECK(allclose( + matmul(transpose(eigenvectors), eigenvectors), + eye(2), + /* rtol = */ 1e-5, + /* atol = */ 1e-5) + .item()); // Verify eigendecomposition - CHECK(allclose(matmul(matmul(eigenvectors, diag(eigenvalues)), transpose(eigenvectors)), A, /* rtol = */ 1e-5, /* atol = */ 1e-5).item()); + CHECK( + allclose( + matmul( + matmul(eigenvectors, diag(eigenvalues)), transpose(eigenvectors)), + A, + /* rtol = */ 1e-5, + /* atol = */ 1e-5) + .item()); } From 9b800fde54b7df53f449c4eb5b619f5c0ffdbd41 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Tue, 22 Oct 2024 11:34:13 -0700 Subject: [PATCH 17/19] cleanup lapack --- mlx/array.cpp | 5 +- mlx/backend/common/CMakeLists.txt | 2 +- mlx/backend/common/cholesky.cpp | 44 +--- mlx/backend/common/conv.cpp | 7 +- mlx/backend/common/default_primitives.cpp | 6 +- mlx/backend/common/eigh.cpp | 106 ++++++++++ mlx/backend/common/eigvalsh.cpp | 190 ------------------ mlx/backend/common/inverse.cpp | 25 +-- .../common/{lapack_helper.h => lapack.h} | 2 +- mlx/backend/common/masked_mm.cpp | 7 +- mlx/backend/common/qrf.cpp | 7 +- mlx/backend/common/svd.cpp | 2 +- mlx/linalg.cpp | 54 ++--- mlx/linalg.h | 4 +- mlx/primitives.cpp | 13 +- mlx/primitives.h | 8 +- python/src/linalg.cpp | 62 +++--- python/tests/test_linalg.py | 170 +++------------- tests/linalg_tests.cpp | 102 ++-------- 19 files changed, 227 insertions(+), 589 deletions(-) create mode 100644 mlx/backend/common/eigh.cpp delete mode 100644 mlx/backend/common/eigvalsh.cpp rename mlx/backend/common/{lapack_helper.h => lapack.h} (93%) diff --git a/mlx/array.cpp b/mlx/array.cpp index 374c2d36f..7eb69092f 100644 --- a/mlx/array.cpp +++ b/mlx/array.cpp @@ -178,9 +178,10 @@ void array::move_shared_buffer( array_desc_->flags = flags; array_desc_->data_size = data_size; auto char_offset = sizeof(char) * itemsize() * offset; - array_desc_->data_ptr = static_cast( - static_cast(other.array_desc_->data_ptr) + char_offset); + auto data_ptr = other.array_desc_->data_ptr; other.array_desc_->data_ptr = nullptr; + array_desc_->data_ptr = static_cast( + static_cast(data_ptr) + char_offset); } void array::move_shared_buffer(array other) { diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index a1bf8c5ce..4fca2274e 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -31,7 +31,7 @@ target_sources( 
${CMAKE_CURRENT_SOURCE_DIR}/common.cpp ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/eigvalsh.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/eigh.cpp ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp index 5fd9c8065..d94822c9d 100644 --- a/mlx/backend/common/cholesky.cpp +++ b/mlx/backend/common/cholesky.cpp @@ -2,46 +2,12 @@ #include "mlx/allocator.h" #include "mlx/backend/common/copy.h" +#include "mlx/backend/common/lapack.h" #include "mlx/linalg.h" #include "mlx/primitives.h" -#ifdef ACCELERATE_NEW_LAPACK -#include -#else -#include -#endif - namespace mlx::core { -namespace { - -// Delegate to the Cholesky factorization taking into account differences in -// LAPACK implementations (basically how to pass the 'uplo' string to fortran). -int spotrf_wrapper(char uplo, float* matrix, int N) { - int info; - -#ifdef LAPACK_FORTRAN_STRLEN_END - spotrf_( - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* info = */ &info, - /* uplo_len = */ static_cast(1)); -#else - spotrf_( - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* info = */ &info); -#endif - - return info; -} - -} // namespace - void cholesky_impl(const array& a, array& factor, bool upper) { // Lapack uses the column-major convention. We take advantage of the fact that // the matrix should be symmetric: @@ -66,7 +32,13 @@ void cholesky_impl(const array& a, array& factor, bool upper) { for (int i = 0; i < num_matrices; i++) { // Compute Cholesky factorization. - int info = spotrf_wrapper(uplo, matrix, N); + int info; + MLX_LAPACK_FUNC(spotrf)( + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* info = */ &info); // TODO: We do nothing when the matrix is not positive semi-definite // because throwing an error would result in a crash. If we figure out how diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp index 76edc9a27..57c90e250 100644 --- a/mlx/backend/common/conv.cpp +++ b/mlx/backend/common/conv.cpp @@ -3,12 +3,7 @@ #include #include -#ifdef ACCELERATE_NEW_LAPACK -#include -#else -#include -#endif - +#include "mlx/backend/common/lapack.h" #include "mlx/backend/common/copy.h" #include "mlx/primitives.h" #include "mlx/utils.h" diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index edc7192bd..547d8e25d 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -1,14 +1,10 @@ // Copyright © 2023-2024 Apple Inc. -#ifdef ACCELERATE_NEW_LAPACK -#include -#else -#include -#endif #include #include "mlx/array.h" #include "mlx/backend/common/copy.h" +#include "mlx/backend/common/lapack.h" #include "mlx/backend/common/utils.h" #include "mlx/primitives.h" diff --git a/mlx/backend/common/eigh.cpp b/mlx/backend/common/eigh.cpp new file mode 100644 index 000000000..c9af5bc9f --- /dev/null +++ b/mlx/backend/common/eigh.cpp @@ -0,0 +1,106 @@ +// Copyright © 2023-2024 Apple Inc. 
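+//
+// CPU path for eigh/eigvalsh via LAPACK's ssyevd. The input is copied into
+// the eigenvector output (or into a scratch array when only eigenvalues are
+// requested), the output strides are swapped so LAPACK's column-major
+// eigenvectors land in the columns, the work buffers are sized once with a
+// workspace query (lwork = liwork = -1), and ssyevd then runs in place on
+// each N x N matrix in the batch, producing eigenvalues in ascending order.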
+
+#include "mlx/allocator.h"
+#include "mlx/array.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
+#include "mlx/linalg.h"
+#include "mlx/primitives.h"
+
+namespace mlx::core {
+
+namespace {
+
+void ssyevd(char jobz, char uplo, float* a, int N, float* w, float* work, int lwork, int* iwork, int liwork) {
+  int info;
+  MLX_LAPACK_FUNC(ssyevd)(
+      /* jobz = */ &jobz,
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ a,
+      /* lda = */ &N,
+      /* w = */ w,
+      /* work = */ work,
+      /* lwork = */ &lwork,
+      /* iwork = */ iwork,
+      /* liwork = */ &liwork,
+      /* info = */ &info);
+  if (info != 0) {
+    std::stringstream msg;
+    msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
+        << info;
+    throw std::runtime_error(msg.str());
+  }
+}
+
+} // namespace
+
+void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
+  const auto& a = inputs[0];
+  auto& values = outputs[0];
+
+  auto vectors = compute_eigenvectors_ ? outputs[1] : array(a.shape(), a.dtype(), nullptr, {});
+
+  values.set_data(allocator::malloc_or_wait(values.nbytes()));
+
+  copy(a, vectors, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+
+  if (compute_eigenvectors_) {
+    // Set the strides and flags so the eigenvectors
+    // are in the columns of the output
+    auto flags = vectors.flags();
+    auto strides = vectors.strides();
+    auto ndim = a.ndim();
+    std::swap(strides[ndim - 1], strides[ndim - 2]);
+
+    if (a.size() > 1) {
+      flags.row_contiguous = false;
+      if (ndim > 2) {
+        flags.col_contiguous = false;
+      } else {
+        flags.col_contiguous = true;
+      }
+    }
+    vectors.move_shared_buffer(
+        vectors,
+        strides,
+        flags,
+        vectors.data_size());
+  }
+
+  auto vec_ptr = vectors.data<float>();
+  auto eig_ptr = values.data<float>();
+
+  char jobz = compute_eigenvectors_ ? 'V' : 'N';
+  auto N = a.shape(-1);
+
+  // Work query
+  int lwork;
+  int liwork;
+  {
+    float work;
+    int iwork;
+    ssyevd(jobz, uplo_[0], nullptr, N, nullptr, &work, -1, &iwork, -1);
+    lwork = static_cast<int>(work);
+    liwork = iwork;
+  }
+
+  auto work_buf = array::Data{allocator::malloc_or_wait(sizeof(float) * lwork)};
+  auto iwork_buf = array::Data{allocator::malloc_or_wait(sizeof(int) * liwork)};
+  for (size_t i = 0; i < a.size() / (N * N); ++i) {
+    ssyevd(
+        jobz,
+        uplo_[0],
+        vec_ptr,
+        N,
+        eig_ptr,
+        static_cast<float*>(work_buf.buffer.raw_ptr()),
+        lwork,
+        static_cast<int*>(iwork_buf.buffer.raw_ptr()),
+        liwork);
+    vec_ptr += N * N;
+    eig_ptr += N;
+  }
+}
+
+} // namespace mlx::core
diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp
deleted file mode 100644
index 62eadf0dc..000000000
--- a/mlx/backend/common/eigvalsh.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-// Copyright © 2023-2024 Apple Inc.
-
-#include "mlx/allocator.h"
-#include "mlx/array.h"
-#include "mlx/backend/common/copy.h"
-#include "mlx/linalg.h"
-#include "mlx/primitives.h"
-
-#ifdef ACCELERATE_NEW_LAPACK
-#include
-#else
-#include
-#endif
-
-namespace mlx::core {
-
-namespace {
-
-// Delegate to the eigenvalue decomposition taking into account differences in
-// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings
-// to fortran).
-int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) { - int info; - int lwork = -1; - int liwork = -1; - float work_query; - int iwork_query; - - // Query for optimal work array sizes -#ifdef LAPACK_FORTRAN_STRLEN_END - ssyevd_( - /* jobz = */ &jobz, - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* w = */ w, - /* work = */ &work_query, - /* lwork = */ &lwork, - /* iwork = */ &iwork_query, - /* liwork = */ &liwork, - /* info = */ &info, - /* jobz_len = */ static_cast(1), - /* uplo_len = */ static_cast(1)); -#else - ssyevd_( - /* jobz = */ &jobz, - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* w = */ w, - /* work = */ &work_query, - /* lwork = */ &lwork, - /* iwork = */ &iwork_query, - /* liwork = */ &liwork, - /* info = */ &info); -#endif - - lwork = static_cast(work_query); - liwork = iwork_query; - - std::vector work(lwork); - std::vector iwork(liwork); - - // Compute eigenvalues (and optionally eigenvectors) -#ifdef LAPACK_FORTRAN_STRLEN_END - ssyevd_( - /* jobz = */ &jobz, - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* w = */ w, - /* work = */ work.data(), - /* lwork = */ &lwork, - /* iwork = */ iwork.data(), - /* liwork = */ &liwork, - /* info = */ &info, - /* jobz_len = */ static_cast(1), - /* uplo_len = */ static_cast(1)); -#else - ssyevd_( - /* jobz = */ &jobz, - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* w = */ w, - /* work = */ work.data(), - /* lwork = */ &lwork, - /* iwork = */ iwork.data(), - /* liwork = */ &liwork, - /* info = */ &info); -#endif - - return info; -} - -} // namespace - -void eigh_impl( - const array& a, - array& values, - array& vectors, - bool upper, - bool compute_eigenvectors) { - char jobz = compute_eigenvectors ? 'V' : 'N'; - char uplo = upper ? 'U' : 'L'; - - array buffer = copy(a); - - const int N = static_cast(a.shape(-1)); - const int num_matrices = static_cast(a.size() / (N * N)); - - std::vector values_shape = {num_matrices, N}; - values = array( - allocator::malloc(num_matrices * N * size_of(a.dtype())), - values_shape, - a.dtype()); - - float* matrix = buffer.data(); - float* w = values.data(); - - if (compute_eigenvectors) { - std::vector vectors_shape = a.shape(); - vectors = array( - allocator::malloc(a.size() * size_of(a.dtype())), - vectors_shape, - a.dtype()); - } - - float* vecs = compute_eigenvectors ? 
vectors.data() : nullptr; - - for (int i = 0; i < num_matrices; i++) { - int info = ssyevd_wrapper(jobz, uplo, matrix, w, N); - - if (info != 0) { - std::stringstream msg; - msg << "[eigh] Eigenvalue decomposition failed with error code " << info; - throw std::runtime_error(msg.str()); - } - - if (compute_eigenvectors) { - // Copy eigenvectors to the output array - std::copy(matrix, matrix + N * N, vecs); - vecs += N * N; - } - - matrix += N * N; - w += N; - } -} - -void Eigh::eval(const std::vector& inputs, std::vector& outputs) { - // Validate the number of inputs - if (inputs.size() != 1) { - throw std::invalid_argument( - "[Eigh::eval] Expected exactly one input array."); - } - - const array& input = inputs[0]; - - // Ensure the input array is evaluated before accessing its data - const_cast(input).eval(); - - // Validate the data type - Dtype input_dtype = input.dtype(); // Changed from 'dtype_t' to 'Dtype' - - // Validate the number of dimensions (expecting at least 2D) - if (input.ndim() < 2) { - throw std::invalid_argument( - "[Eigh::eval] Input array must be at least 2-dimensional."); - } - - array values{}; - array vectors{}; - eigh_impl(input, values, vectors, upper_, compute_eigenvectors_); - - // Ensure the output arrays are evaluated - values.eval(); - if (compute_eigenvectors_) { - vectors.eval(); - outputs = {values, vectors}; - } else { - outputs = {values}; - } -} - -} // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp index 57d885c73..e3365c06c 100644 --- a/mlx/backend/common/inverse.cpp +++ b/mlx/backend/common/inverse.cpp @@ -2,39 +2,18 @@ #include "mlx/allocator.h" #include "mlx/backend/common/copy.h" +#include "mlx/backend/common/lapack.h" #include "mlx/primitives.h" -#ifdef ACCELERATE_NEW_LAPACK -#include -#else -#include -#endif - -// Wrapper to account for differences in -// LAPACK implementations (basically how to pass the 'uplo' string to fortran). int strtri_wrapper(char uplo, char diag, float* matrix, int N) { int info; - -#ifdef LAPACK_FORTRAN_STRLEN_END - strtri_( - /* uplo = */ &uplo, - /* diag = */ &diag, - /* N = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* info = */ &info, - /* uplo_len = */ static_cast(1), - /* diag_len = */ static_cast(1)); -#else - strtri_( + MLX_LAPACK_FUNC(strtri)( /* uplo = */ &uplo, /* diag = */ &diag, /* N = */ &N, /* a = */ matrix, /* lda = */ &N, /* info = */ &info); -#endif - return info; } diff --git a/mlx/backend/common/lapack_helper.h b/mlx/backend/common/lapack.h similarity index 93% rename from mlx/backend/common/lapack_helper.h rename to mlx/backend/common/lapack.h index bf0f76437..16a699a17 100644 --- a/mlx/backend/common/lapack_helper.h +++ b/mlx/backend/common/lapack.h @@ -1,4 +1,4 @@ -// Copyright © 2024 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #pragma once diff --git a/mlx/backend/common/masked_mm.cpp b/mlx/backend/common/masked_mm.cpp index 44a471168..d0286f0fd 100644 --- a/mlx/backend/common/masked_mm.cpp +++ b/mlx/backend/common/masked_mm.cpp @@ -1,15 +1,10 @@ // Copyright © 2024 Apple Inc. 
-#ifdef ACCELERATE_NEW_LAPACK
-#include
-#else
-#include
-#endif
-
 #include
 
 #include "mlx/array.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/utils.h"
 #include "mlx/primitives.h"
diff --git a/mlx/backend/common/qrf.cpp b/mlx/backend/common/qrf.cpp
index 4171398fd..9383f6c88 100644
--- a/mlx/backend/common/qrf.cpp
+++ b/mlx/backend/common/qrf.cpp
@@ -2,14 +2,9 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 
-#ifdef ACCELERATE_NEW_LAPACK
-#include
-#else
-#include
-#endif
-
 namespace mlx::core {
 
 template
diff --git a/mlx/backend/common/svd.cpp b/mlx/backend/common/svd.cpp
index 412f06297..1a6f1b1ad 100644
--- a/mlx/backend/common/svd.cpp
+++ b/mlx/backend/common/svd.cpp
@@ -2,7 +2,7 @@
 #include "mlx/allocator.h"
 #include "mlx/backend/common/copy.h"
-#include "mlx/backend/common/lapack_helper.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 
 namespace mlx::core {
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index b7a58736e..997279ed3 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -454,67 +454,47 @@ array cross(
   return concatenate(outputs, axis, s);
 }
 
-array eigvalsh(
-    const array& a,
-    bool upper /* = false */,
-    StreamOrDevice s /* = {} */) {
+void validate_eigh(const array& a, const std::string fname) {
   if (a.dtype() != float32) {
     std::ostringstream msg;
-    msg << "[linalg::eigvalsh] Arrays must be type float32. Received array "
+    msg << fname << " Arrays must have type float32. Received array "
        << "with type " << a.dtype() << ".";
     throw std::invalid_argument(msg.str());
   }
 
   if (a.ndim() < 2) {
     std::ostringstream msg;
-    msg << "[linalg::eigvalsh] Arrays must have >= 2 dimensions. Received array "
-           "with "
+    msg << fname << " Arrays must have >= 2 dimensions. Received array with "
        << a.ndim() << " dimensions.";
     throw std::invalid_argument(msg.str());
   }
 
   if (a.shape(-1) != a.shape(-2)) {
-    throw std::invalid_argument(
-        "[linalg::eigvalsh] Eigenvalues are only defined for square matrices.");
+    throw std::invalid_argument(fname + " Only defined for square matrices.");
   }
+}
 
+array eigvalsh(
+    const array& a,
+    std::string UPLO /* = "L" */,
+    StreamOrDevice s /* = {} */) {
+  validate_eigh(a, "[linalg::eigvalsh]");
   std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);
-  out_shape.back() = a.shape(-1);
-
   return array(
-      out_shape,
+      std::move(out_shape),
       a.dtype(),
-      std::make_shared<Eigh>(to_stream(s), upper, false),
-      {astype(a, a.dtype(), s)});
+      std::make_shared<Eigh>(to_stream(s), UPLO, false),
+      {a});
 }
 
 std::pair<array, array>
-eigh(const array& a, bool upper /* = false */, StreamOrDevice s /* = {} */) {
-  if (a.dtype() != float32) {
-    std::ostringstream msg;
-    msg << "[linalg::eigh] Arrays must be type float32. Received array "
-        << "with type " << a.dtype() << ".";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (a.ndim() < 2) {
-    std::ostringstream msg;
-    msg << "[linalg::eigh] Arrays must have >= 2 dimensions. Received array "
-           "with "
-        << a.ndim() << " dimensions.";
-    throw std::invalid_argument(msg.str());
-  }
-
-  if (a.shape(-1) != a.shape(-2)) {
-    throw std::invalid_argument(
-        "[linalg::eigh] Eigenvectors are only defined for square matrices.");
-  }
-
+eigh(const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) {
+  validate_eigh(a, "[linalg::eigh]");
   auto out = array::make_arrays(
       {std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
       {a.dtype(), a.dtype()},
-      std::make_shared<Eigh>(to_stream(s), upper, true),
-      {astype(a, a.dtype(), s)});
+      std::make_shared<Eigh>(to_stream(s), UPLO, true),
+      {a});
   return std::make_pair(out[0], out[1]);
 }
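(For reference, a minimal CPU-only usage sketch of the API finalized above, assuming a program linked against MLX with this series applied; it is not part of the diff, and the matrix and expected eigenvalues mirror the tests later in this patch:)

    #include "mlx/mlx.h"
    using namespace mlx::core;

    int main() {
      // 2x2 symmetric matrix; analytic eigenvalues are {0, 5}.
      array A = array({1.0f, 2.0f, 2.0f, 4.0f}, {2, 2}, float32);
      auto [w, v] = linalg::eigh(A, "L", Device::cpu);   // w ascending, eigenvectors in columns of v
      array w_only = linalg::eigvalsh(A, "L", Device::cpu);
      eval(w, v, w_only);  // w and w_only should both be close to {0, 5}
      return 0;
    }

(Passing UPLO as a string mirrors np.linalg.eigh, which is what the reworked Python tests below compare against.)

diff --git a/mlx/linalg.h b/mlx/linalg.h
index 1b29b5dc5..4ea81bef0 100644
--- a/mlx/linalg.h
+++ b/mlx/linalg.h
@@ -83,9 +83,9 @@ array cross(
     int axis = -1,
     StreamOrDevice s = {});
 
-array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {});
+array eigvalsh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
 
 std::pair<array, array>
-eigh(const array& a, bool upper = false, StreamOrDevice s = {});
+eigh(const array& a, std::string UPLO = "L", StreamOrDevice s = {});
 
 } // namespace mlx::core::linalg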
diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp
index 43148ad48..c9f839d4b 100644
--- a/mlx/primitives.cpp
+++ b/mlx/primitives.cpp
@@ -773,20 +773,19 @@ std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
   assert(inputs.size() == 1);
   assert(axes.size() == 1);
 
-  auto ax = axes[0] >= 0 ? 0 : -1;
-  auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  bool needs_move = axes[0] >= (inputs[0].ndim() - 2);
+  auto a = needs_move ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
+  auto ax = needs_move ? 0 : axes[0];
 
   std::vector<array> outputs;
   if (compute_eigenvectors_) {
-    auto [values, vectors] = linalg::eigh(a, upper_, stream());
+    auto [values, vectors] = linalg::eigh(a, uplo_, stream());
     outputs = {values, vectors};
   } else {
-    outputs = {linalg::eigvalsh(a, upper_, stream())};
+    outputs = {linalg::eigvalsh(a, uplo_, stream())};
   }
 
-  std::vector<int> out_axes(outputs.size(), ax);
-
-  return {outputs, out_axes};
+  return {outputs, std::vector<int>(outputs.size(), ax)};
 }
 
 std::vector<array> Concatenate::vjp(
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 64d9ca94e..f2b5bab7c 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2198,9 +2198,9 @@ class Cholesky : public UnaryPrimitive {
 
 class Eigh : public Primitive {
  public:
-  explicit Eigh(Stream stream, bool upper, bool compute_eigenvectors)
+  explicit Eigh(Stream stream, std::string uplo, bool compute_eigenvectors)
       : Primitive(stream),
-        upper_(upper),
+        uplo_(std::move(uplo)),
         compute_eigenvectors_(compute_eigenvectors) {}
 
   void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
   void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs)
       override;
@@ -2224,7 +2224,7 @@ class Eigh : public Primitive {
 
   bool is_equivalent(const Primitive& other) const override {
     if (auto* p = dynamic_cast<const Eigh*>(&other)) {
-      return upper_ == p->upper_ &&
+      return uplo_ == p->uplo_ &&
           compute_eigenvectors_ == p->compute_eigenvectors_;
     }
     return false;
@@ -2232,7 +2232,7 @@ class Eigh : public Primitive {
 
  private:
   void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
-  bool upper_;
+  std::string uplo_;
   bool compute_eigenvectors_;
 };
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 869444cd3..9fb40b687 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -409,22 +409,21 @@ void init_linalg(nb::module_& parent_module) {
       "eigvalsh",
       &eigvalsh,
       "a"_a,
-      "upper"_a = true,
+      "UPLO"_a = "L",
       nb::kw_only(),
       "stream"_a = nb::none(),
-
+ Tuple[array, array]: + A tuple containing the eigenvalues in ascending order and + the normalized eigenvectors. The column ``v[:, i]`` is the + eigenvector corresponding to the i-th eigenvalue. Note: - The input matrix is assumed to be symmetric (or Hermitian). Only the - upper triangle (if upper=True) or lower triangle (if upper=False) is used. - No checks for symmetry are performed. + The input matrix is assumed to be symmetric (or Hermitian). Only + the selected triangle is used. No checks for symmetry are performed. Example: >>> A = mx.array([[1., -2.], [-2., 1.]]) - >>> w, v = mx.linalg.eigh(A) + >>> w, v = mx.linalg.eigh(A, stream=mx.cpu) >>> w array([-1., 3.], dtype=float32) >>> v diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index edf671b88..ccc73cade 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -268,166 +268,54 @@ def test_cross_product(self): with self.assertRaises(ValueError): mx.linalg.cross(a, b) - def test_eigvalsh(self): - # Test a simple 2x2 symmetric matrix - A_mx = mx.array([[1.0, 2.0], [2.0, 4.0]], dtype=mx.float32) - A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32) + def test_eigh(self): + tols = {"atol": 1e-5, "rtol": 1e-5} - eigenvalues_mx = mx.linalg.eigvalsh(A_mx) - eigenvalues_np = np.linalg.eigvalsh(A_np) + def check_eigs_and_vecs(A_np, kwargs={}): + A = mx.array(A_np) + eig_vals, eig_vecs = mx.linalg.eigh(A, stream=mx.cpu, **kwargs) + eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs) + self.assertTrue(np.allclose(eig_vals, eig_vals_np, **tols)) + self.assertTrue(mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols)) - self.assertTrue( - mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) - ) + eig_vals_only = mx.linalg.eigvalsh(A, stream=mx.cpu, **kwargs) + self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols)) + + # Test a simple 2x2 symmetric matrix + A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32) + check_eigs_and_vecs(A_np) # Test a larger random symmetric matrix n = 5 - rng = np.random.default_rng(42) - B_np = rng.random((n, n)).astype(np.float32) - B_np = (B_np + B_np.T) / 2 # Make sure B is symmetric - B_mx = mx.array(B_np) + np.random.seed(1) + A_np = np.random.randn(n, n).astype(np.float32) + A_np = (A_np + A_np.T) / 2 + check_eigs_and_vecs(A_np) - eigenvalues_mx = mx.linalg.eigvalsh(B_mx) - eigenvalues_np = np.linalg.eigvalsh(B_np) - - self.assertTrue( - mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) - ) - - # Test that eigenvalues are in ascending order - self.assertTrue(mx.all(eigenvalues_mx[1:] >= eigenvalues_mx[:-1])) - - # Test with upper=False - eigenvalues_mx_lower = mx.linalg.eigvalsh(B_mx, upper=False) - eigenvalues_np_lower = np.linalg.eigvalsh(B_np, UPLO="L") - - self.assertTrue( - mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5) - ) + # Test with upper triangle + check_eigs_and_vecs(A_np, {"UPLO": "U"}) # Test with batched input - C_np = rng.random((3, n, n)).astype(np.float32) - C_np = ( - C_np + np.transpose(C_np, (0, 2, 1)) - ) / 2 # Make sure C is symmetric for each batch - C_mx = mx.array(C_np) - - eigenvalues_mx = mx.linalg.eigvalsh(C_mx) - eigenvalues_np = np.linalg.eigvalsh(C_np) - - self.assertTrue( - mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5) - ) - - # Test that eigenvalues are in ascending order for each batch - self.assertTrue(mx.all(eigenvalues_mx[:, 1:] >= eigenvalues_mx[:, :-1])) + A_np = np.random.randn(3, n, n).astype(np.float32) + A_np = ( + 
+        A_np = (
+            A_np + np.transpose(A_np, (0, 2, 1))
+        ) / 2
+        check_eigs_and_vecs(A_np)
 
         # Test error cases
         with self.assertRaises(ValueError):
-            mx.linalg.eigvalsh(mx.array([1.0, 2.0]))  # 1D array
+            mx.linalg.eigh(mx.array([1.0, 2.0]))  # 1D array
 
         with self.assertRaises(ValueError):
-            mx.linalg.eigvalsh(
+            mx.linalg.eigh(
                 mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
             )  # Non-square matrix
 
-    def test_eigh(self):
-        # Test a simple 2x2 symmetric matrix
-        A_mx = mx.array([[1.0, 2.0], [2.0, 4.0]], dtype=mx.float32)
-        A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32)
-
-        eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(A_mx)
-        eigenvalues_np, eigenvectors_np = np.linalg.eigh(A_np)
-
-        self.assertTrue(
-            mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)
-        )
-        self.assertTrue(
-            mx.allclose(
-                mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5
-            )
-        )
-
-        # Test a larger random symmetric matrix
-        n = 5
-        rng = np.random.default_rng(42)
-        B_np = rng.random((n, n)).astype(np.float32)
-        B_np = (B_np + B_np.T) / 2  # Make sure B is symmetric
-        B_mx = mx.array(B_np)
-
-        eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(B_mx)
-        eigenvalues_np, eigenvectors_np = np.linalg.eigh(B_np)
-
-        self.assertTrue(
-            mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)
-        )
-        self.assertTrue(
-            mx.allclose(
-                mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5
-            )
-        )
-
-        # Test that eigenvalues are in ascending order
-        self.assertTrue(mx.all(eigenvalues_mx[1:] >= eigenvalues_mx[:-1]))
-
-        # Test orthogonality of eigenvectors
-        identity = mx.eye(n)
-        self.assertTrue(
-            mx.allclose(eigenvectors_mx.T @ eigenvectors_mx, identity, atol=1e-5)
-        )
-
-        # Test with upper=False
-        eigenvalues_mx_lower, eigenvectors_mx_lower = mx.linalg.eigh(B_mx, upper=False)
-        eigenvalues_np_lower, eigenvectors_np_lower = np.linalg.eigh(B_np, UPLO="L")
-
-        self.assertTrue(
-            mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5)
-        )
-        self.assertTrue(
-            mx.allclose(
-                mx.abs(eigenvectors_mx_lower),
-                mx.abs(mx.array(eigenvectors_np_lower)),
-                atol=1e-5,
-            )
-        )
-
-        # Test with batched input
-        C_np = rng.random((3, n, n)).astype(np.float32)
-        C_np = (
-            C_np + np.transpose(C_np, (0, 2, 1))
-        ) / 2  # Make sure C is symmetric for each batch
-        C_mx = mx.array(C_np)
-
-        eigenvalues_mx, eigenvectors_mx = mx.linalg.eigh(C_mx)
-        eigenvalues_np, eigenvectors_np = np.linalg.eigh(C_np)
-
-        self.assertTrue(
-            mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5)
-        )
-        self.assertTrue(
-            mx.allclose(
-                mx.abs(eigenvectors_mx), mx.abs(mx.array(eigenvectors_np)), atol=1e-5
-            )
-        )
-
-        # Test that eigenvalues are in ascending order for each batch
-        self.assertTrue(mx.all(eigenvalues_mx[:, 1:] >= eigenvalues_mx[:, :-1]))
-
-        # Test orthogonality of eigenvectors for each batch
-        identity = mx.eye(n)
-        for i in range(3):
-            self.assertTrue(
-                mx.allclose(
-                    eigenvectors_mx[i].T @ eigenvectors_mx[i], identity, atol=1e-5
-                )
-            )
-
-        # Test error cases
         with self.assertRaises(ValueError):
-            mx.linalg.eigh(mx.array([1.0, 2.0]))  # 1D array
+            mx.linalg.eigvalsh(mx.array([1.0, 2.0]))  # 1D array
 
         with self.assertRaises(ValueError):
-            mx.linalg.eigh(
+            mx.linalg.eigvalsh(
                 mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
            )  # Non-square matrix
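One detail worth spelling out in check_eigs_and_vecs above: eig_vals[..., None, :] * eig_vecs scales column i of each eigenvector matrix by the i-th eigenvalue, which is the same as right-multiplying by diag(w). The test therefore verifies A V = V diag(w) without materializing any diagonal matrices, and the same expression works unchanged for batched inputs. A quick NumPy sketch of the identity (illustrative values only):

    import numpy as np

    w = np.array([2.0, 5.0], dtype=np.float32)  # eigenvalue-like vector
    V = np.array([[1.0, 3.0], [4.0, 6.0]], dtype=np.float32)
    # Broadcasting w across rows scales column i of V by w[i] ...
    # ... which equals V @ diag(w).
    assert np.allclose(w[None, :] * V, V @ np.diag(w))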
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index f149b7c9f..951185f12 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -436,117 +436,41 @@ TEST_CASE("test cross product") {
   CHECK(allclose(result, expected).item<bool>());
 }
 
-TEST_CASE("test matrix eigvalsh") {
-  // 0D and 1D throw
-  CHECK_THROWS(linalg::eigvalsh(array(0.0), /* upper = */ true, Device::cpu));
-  CHECK_THROWS(
-      linalg::eigvalsh(array({0.0, 1.0}), /* upper = */ true, Device::cpu));
-
-  // Unsupported types throw
-  CHECK_THROWS(
-      linalg::eigvalsh(array({0, 1}, {1, 2}), /* upper = */ true, Device::cpu));
-
-  // Non-square throws
-  CHECK_THROWS(linalg::eigvalsh(
-      array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu));
-
-  // Test a simple 2x2 symmetric matrix
-  array A = array({{1, 2}, {2, 4}});
-  array eigenvalues = linalg::eigvalsh(A, /* upper = */ true, Device::cpu);
-
-  // Expected eigenvalues (calculated analytically)
-  array expected_eigenvalues = array({0.0, 5.0});
-
-  CHECK(allclose(
-            eigenvalues,
-            expected_eigenvalues,
-            /* rtol = */ 1e-5,
-            /* atol = */ 1e-5)
-            .item<bool>());
-
-  /// Test a larger symmetric matrix
-  const auto prng_key = random::key(42);
-  const auto B = random::normal({5, 5}, prng_key);
-  const auto B_sym = 0.5 * (B + transpose(B)); // Make sure B is symmetric
-  const auto B_eigenvalues =
-      linalg::eigvalsh(B_sym, /* upper = */ true, Device::cpu);
-
-  // Check that eigenvalues are real and in ascending order
-  CHECK(B_eigenvalues.dtype() == float32);
-  CHECK(B_eigenvalues.shape() == std::vector<int>{5});
-  CHECK(all(isfinite(B_eigenvalues)).item<bool>());
-  CHECK(all(slice(B_eigenvalues, {1}, {5}) >= slice(B_eigenvalues, {0}, {4}))
-            .item<bool>());
-
-  // Reconstruct the matrix using eigendecomposition and check if it's close to
-  // the original
-  const auto [eigenvalues_eigh, eigenvectors] =
-      linalg::eigh(B_sym, /* upper = */ true, Device::cpu);
-  const auto B_reconstructed = matmul(
-      matmul(eigenvectors, diag(eigenvalues_eigh)), transpose(eigenvectors));
-
-  CHECK(allclose(B_reconstructed, B_sym, /* rtol = */ 1e-5, /* atol = */ 1e-5)
-            .item<bool>());
-
-  // Check that eigvalsh and eigh produce the same eigenvalues
-  CHECK(
-      allclose(
-          B_eigenvalues, eigenvalues_eigh, /* rtol = */ 1e-5, /* atol = */ 1e-5)
-          .item<bool>());
-}
 
 TEST_CASE("test matrix eigh") {
   // 0D and 1D throw
-  CHECK_THROWS(linalg::eigh(array(0.0), /* upper = */ true, Device::cpu));
-  CHECK_THROWS(
-      linalg::eigh(array({0.0, 1.0}), /* upper = */ true, Device::cpu));
+  CHECK_THROWS(linalg::eigh(array(0.0)));
+  CHECK_THROWS(linalg::eigh(array({0.0, 1.0})));
+  CHECK_THROWS(linalg::eigvalsh(array(0.0)));
+  CHECK_THROWS(linalg::eigvalsh(array({0.0, 1.0})));
 
   // Unsupported types throw
-  CHECK_THROWS(
-      linalg::eigh(array({0, 1}, {1, 2}), /* upper = */ true, Device::cpu));
+  CHECK_THROWS(linalg::eigh(array({0, 1}, {1, 2})));
 
   // Non-square throws
-  CHECK_THROWS(linalg::eigh(
-      array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu));
+  CHECK_THROWS(linalg::eigh(array({1, 2, 3, 4, 5, 6}, {2, 3})));
 
   // Test a simple 2x2 symmetric matrix
   array A = array({1.0, 2.0, 2.0, 4.0}, {2, 2}, float32);
-  auto [eigenvalues, eigenvectors] =
-      linalg::eigh(A, /* upper = */ true, Device::cpu);
-
-  // Expected eigenvalues and eigenvectors (calculated analytically)
-  array expected_eigenvalues = array({0.0, 5.0});
-  array expected_eigenvectors =
-      array({-0.4472136f, 0.8944272f, 0.8944272f, 0.4472136f}, {2, 2}, float32);
+  auto [eigvals, eigvecs] = linalg::eigh(A, "L", Device::cpu);
 
+  // Expected eigenvalues
+  array expected_eigvals = array({0.0, 5.0});
   CHECK(allclose(
-            eigenvalues,
-            expected_eigenvalues,
-            /* rtol = */ 1e-5,
-            /* atol = */ 1e-5)
-            .item<bool>());
-  CHECK(allclose(
-            abs(eigenvectors),
-            abs(expected_eigenvectors),
+            eigvals,
+            expected_eigvals,
            /* rtol = */ 1e-5,
            /* atol = */ 1e-5)
            .item<bool>());
 
   // Verify orthogonality of eigenvectors
   CHECK(allclose(
-            matmul(transpose(eigenvectors), eigenvectors),
+            matmul(eigvecs, transpose(eigvecs)),
             eye(2),
             /* rtol = */ 1e-5,
             /* atol = */ 1e-5)
             .item<bool>());
 
   // Verify eigendecomposition
-  CHECK(
-      allclose(
-          matmul(
-              matmul(eigenvectors, diag(eigenvalues)), transpose(eigenvectors)),
-          A,
-          /* rtol = */ 1e-5,
-          /* atol = */ 1e-5)
-          .item<bool>());
+  CHECK(allclose(matmul(A, eigvecs), eigvals * eigvecs).item<bool>());
 }
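Together, the Python and C++ tests above fix the shape contract: for an input of shape (..., n, n), eigvalsh returns eigenvalues of shape (..., n) and eigh additionally returns eigenvectors of shape (..., n, n), one decomposition per trailing matrix. A short shape-check sketch under that assumption (hypothetical sizes, CPU stream as in the tests):

    import mlx.core as mx
    import numpy as np

    A = np.random.randn(3, 4, 4).astype(np.float32)
    A = (A + A.transpose(0, 2, 1)) / 2  # symmetrize each matrix in the batch
    w, v = mx.linalg.eigh(mx.array(A), stream=mx.cpu)
    assert w.shape == (3, 4) and v.shape == (3, 4, 4)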
From aef4160e7e4302bf8e54af19a3e8e257e83659a1 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 22 Oct 2024 11:41:31 -0700
Subject: [PATCH 18/19] format

---
 mlx/array.cpp                   |  4 ++--
 mlx/backend/common/cholesky.cpp |  3 ++-
 mlx/backend/common/conv.cpp     |  2 +-
 mlx/backend/common/eigh.cpp     | 33 ++++++++++++++++++++++-----------
 mlx/backend/common/inverse.cpp  |  3 ++-
 mlx/linalg.cpp                  |  6 ++++--
 python/src/linalg.cpp           |  2 +-
 python/tests/test_linalg.py     |  8 ++++----
 tests/linalg_tests.cpp          |  1 -
 9 files changed, 38 insertions(+), 24 deletions(-)

diff --git a/mlx/array.cpp b/mlx/array.cpp
index 7eb69092f..bb92989c3 100644
--- a/mlx/array.cpp
+++ b/mlx/array.cpp
@@ -180,8 +180,8 @@ void array::move_shared_buffer(
   auto char_offset = sizeof(char) * itemsize() * offset;
   auto data_ptr = other.array_desc_->data_ptr;
   other.array_desc_->data_ptr = nullptr;
-  array_desc_->data_ptr = static_cast<void*>(
-      static_cast<char*>(data_ptr) + char_offset);
+  array_desc_->data_ptr =
+      static_cast<void*>(static_cast<char*>(data_ptr) + char_offset);
 }
 
 void array::move_shared_buffer(array other) {
diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp
index d94822c9d..62807e6dd 100644
--- a/mlx/backend/common/cholesky.cpp
+++ b/mlx/backend/common/cholesky.cpp
@@ -33,7 +33,8 @@ void cholesky_impl(const array& a, array& factor, bool upper) {
   for (int i = 0; i < num_matrices; i++) {
     // Compute Cholesky factorization.
     int info;
-    MLX_LAPACK_FUNC(spotrf)(
+    MLX_LAPACK_FUNC(spotrf)
+    (
         /* uplo = */ &uplo,
         /* n = */ &N,
         /* a = */ matrix,
diff --git a/mlx/backend/common/conv.cpp b/mlx/backend/common/conv.cpp
index 57c90e250..67bdaeefb 100644
--- a/mlx/backend/common/conv.cpp
+++ b/mlx/backend/common/conv.cpp
@@ -3,8 +3,8 @@
 #include <cassert>
 #include <numeric>
 
-#include "mlx/backend/common/lapack.h"
 #include "mlx/backend/common/copy.h"
+#include "mlx/backend/common/lapack.h"
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
diff --git a/mlx/backend/common/eigh.cpp b/mlx/backend/common/eigh.cpp
index c9af5bc9f..8a4e499a3 100644
--- a/mlx/backend/common/eigh.cpp
+++ b/mlx/backend/common/eigh.cpp
@@ -11,9 +11,19 @@ namespace mlx::core {
 
 namespace {
 
-void ssyevd(char jobz, char uplo, float* a, int N, float* w, float* work, int lwork, int* iwork, int liwork) {
+void ssyevd(
+    char jobz,
+    char uplo,
+    float* a,
+    int N,
+    float* w,
+    float* work,
+    int lwork,
+    int* iwork,
+    int liwork) {
   int info;
-  MLX_LAPACK_FUNC(ssyevd)(
+  MLX_LAPACK_FUNC(ssyevd)
+  (
       /* jobz = */ &jobz,
       /* uplo = */ &uplo,
       /* n = */ &N,
@@ -28,7 +38,7 @@ void ssyevd(char jobz, char uplo, float* a, int N, float* w, float* work, int lw
   if (info != 0) {
     std::stringstream msg;
     msg << "[Eigh::eval_cpu] Eigenvalue decomposition failed with error code "
-       << info;
+        << info;
     throw std::runtime_error(msg.str());
   }
 }
@@ -39,11 +49,16 @@
 void Eigh::eval(const std::vector<array>& inputs, std::vector<array>& outputs) {
   const auto& a = inputs[0];
   auto& values = outputs[0];
 
-  auto vectors = compute_eigenvectors_ ? outputs[1] : array(a.shape(), a.dtype(), nullptr, {});
+  auto vectors = compute_eigenvectors_
+      ? outputs[1]
+      : array(a.shape(), a.dtype(), nullptr, {});
 
   values.set_data(allocator::malloc_or_wait(values.nbytes()));
 
-  copy(a, vectors, a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
+  copy(
+      a,
+      vectors,
+      a.flags().row_contiguous ? CopyType::Vector : CopyType::General);
 
   if (compute_eigenvectors_) {
     // Set the strides and flags so the eigenvectors
@@ -51,7 +66,7 @@
     auto flags = vectors.flags();
     auto strides = vectors.strides();
     auto ndim = a.ndim();
-    std::swap(strides[ndim-1], strides[ndim-2]);
+    std::swap(strides[ndim - 1], strides[ndim - 2]);
 
     if (a.size() > 1) {
       flags.row_contiguous = false;
@@ -61,11 +76,7 @@
         flags.col_contiguous = true;
       }
     }
-    vectors.move_shared_buffer(
-        vectors,
-        strides,
-        flags,
-        vectors.data_size());
+    vectors.move_shared_buffer(vectors, strides, flags, vectors.data_size());
  }
 
   auto vec_ptr = vectors.data<float>();
diff --git a/mlx/backend/common/inverse.cpp b/mlx/backend/common/inverse.cpp
index e3365c06c..96dbfc001 100644
--- a/mlx/backend/common/inverse.cpp
+++ b/mlx/backend/common/inverse.cpp
@@ -7,7 +7,8 @@
 
 int strtri_wrapper(char uplo, char diag, float* matrix, int N) {
   int info;
-  MLX_LAPACK_FUNC(strtri)(
+  MLX_LAPACK_FUNC(strtri)
+  (
       /* uplo = */ &uplo,
       /* diag = */ &diag,
       /* N = */ &N,
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index 997279ed3..daf5573fc 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -487,8 +487,10 @@ array eigvalsh(
       {a});
 }
 
-std::pair<array, array>
-eigh(const array& a, std::string UPLO /* = "L" */, StreamOrDevice s /* = {} */) {
+std::pair<array, array> eigh(
+    const array& a,
+    std::string UPLO /* = "L" */,
+    StreamOrDevice s /* = {} */) {
   validate_eigh(a, "[linalg::eigh]");
   auto out = array::make_arrays(
       {std::vector<int>(a.shape().begin(), a.shape().end() - 1), a.shape()},
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 9fb40b687..e2c3aea23 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -443,7 +443,7 @@ void init_linalg(nb::module_& parent_module) {
   m.def(
       "eigh",
       [](const array& a, const std::string UPLO, StreamOrDevice s) {
-        // TODO avoid cast? 
+        // TODO avoid cast?
         auto result = eigh(a, UPLO, s);
         return nb::make_tuple(result.first, result.second);
       },
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index ccc73cade..695d7704f 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -276,7 +276,9 @@ def check_eigs_and_vecs(A_np, kwargs={}):
             eig_vals, eig_vecs = mx.linalg.eigh(A, stream=mx.cpu, **kwargs)
             eig_vals_np, _ = np.linalg.eigh(A_np, **kwargs)
             self.assertTrue(np.allclose(eig_vals, eig_vals_np, **tols))
-            self.assertTrue(mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols))
+            self.assertTrue(
+                mx.allclose(A @ eig_vecs, eig_vals[..., None, :] * eig_vecs, **tols)
+            )
 
             eig_vals_only = mx.linalg.eigvalsh(A, stream=mx.cpu, **kwargs)
             self.assertTrue(mx.allclose(eig_vals, eig_vals_only, **tols))
@@ -297,9 +299,7 @@ def check_eigs_and_vecs(A_np, kwargs={}):
 
         # Test with batched input
         A_np = np.random.randn(3, n, n).astype(np.float32)
-        A_np = (
-            A_np + np.transpose(A_np, (0, 2, 1))
-        ) / 2
+        A_np = (A_np + np.transpose(A_np, (0, 2, 1))) / 2
         check_eigs_and_vecs(A_np)
 
         # Test error cases
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index 951185f12..f0b34cc01 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -436,7 +436,6 @@ TEST_CASE("test cross product") {
   CHECK(allclose(result, expected).item<bool>());
 }
 
-
 TEST_CASE("test matrix eigh") {
   // 0D and 1D throw
   CHECK_THROWS(linalg::eigh(array(0.0)));

From 758e033968dd31b2afe04da4828d525c680c46e4 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Tue, 22 Oct 2024 20:54:50 +0200
Subject: [PATCH 19/19] add cblas.h

---
 mlx/backend/common/lapack.h | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlx/backend/common/lapack.h b/mlx/backend/common/lapack.h
index 16a699a17..b3bb7ebf0 100644
--- a/mlx/backend/common/lapack.h
+++ b/mlx/backend/common/lapack.h
@@ -5,6 +5,7 @@
 #ifdef ACCELERATE_NEW_LAPACK
 #include <Accelerate/Accelerate.h>
 #else
+#include <cblas.h>
 #include <lapack.h>
 #endif
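With the formatting pass and the cblas.h include in place, the series is complete: eigh and eigvalsh take a NumPy-style UPLO argument and run through the LAPACK ssyevd path on the CPU. As a closing sketch, the reconstruction identity exercised by the C++ test, written against the Python API (assuming a build that contains all 19 patches):

    import mlx.core as mx

    A = mx.array([[1.0, 2.0], [2.0, 4.0]])
    w, v = mx.linalg.eigh(A, UPLO="L", stream=mx.cpu)
    # V diag(w) V^T reconstructs A; scaling the columns of v by w
    # avoids forming the diagonal matrix explicitly.
    A_rec = (v * w[None, :]) @ v.T
    assert mx.allclose(A_rec, A, atol=1e-5)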