diff --git a/docs/src/python/linalg.rst b/docs/src/python/linalg.rst
index 227711c22..f6c51ed0b 100644
--- a/docs/src/python/linalg.rst
+++ b/docs/src/python/linalg.rst
@@ -16,3 +16,5 @@ Linear Algebra
    cross
    qr
    svd
+   eigvalsh
+   eigh
diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp
index eee93f2ab..352ca9f2f 100644
--- a/mlx/backend/accelerate/primitives.cpp
+++ b/mlx/backend/accelerate/primitives.cpp
@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(EighPrimitive)
 
 void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 1);
diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt
index 925f4731c..a1bf8c5ce 100644
--- a/mlx/backend/common/CMakeLists.txt
+++ b/mlx/backend/common/CMakeLists.txt
@@ -31,6 +31,7 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/common.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/eigvalsh.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp
index f8932c5f8..2418b18a6 100644
--- a/mlx/backend/common/default_primitives.cpp
+++ b/mlx/backend/common/default_primitives.cpp
@@ -114,6 +114,7 @@ DEFAULT(Tanh)
 DEFAULT(Transpose)
 DEFAULT(Inverse)
 DEFAULT(Cholesky)
+DEFAULT_MULTI(EighPrimitive)
 
 namespace {
diff --git a/mlx/backend/common/eigvalsh.cpp b/mlx/backend/common/eigvalsh.cpp
new file mode 100644
index 000000000..c21f2bcb4
--- /dev/null
+++ b/mlx/backend/common/eigvalsh.cpp
@@ -0,0 +1,184 @@
+// Copyright © 2023-2024 Apple Inc.
+
+#include "mlx/array.h"
+#include "mlx/allocator.h"
+#include "mlx/backend/common/copy.h"
+#include "mlx/linalg.h"
+#include "mlx/primitives.h"
+
+#ifdef ACCELERATE_NEW_LAPACK
+#include <Accelerate/Accelerate.h>
+#else
+#include <lapack.h>
+#endif
+
+namespace mlx::core {
+
+namespace {
+
+// Delegate to the eigenvalue decomposition, taking into account differences
+// in LAPACK implementations (specifically, how the 'jobz' and 'uplo' flags
+// are passed to Fortran).
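+//
+// ssyevd follows the standard LAPACK two-call workspace protocol: the first
+// call below passes lwork = liwork = -1 to query the optimal float and
+// integer workspace sizes, and the second call performs the actual
+// decomposition using the allocated buffers.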
+int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) {
+  int info;
+  int lwork = -1;
+  int liwork = -1;
+  float work_query;
+  int iwork_query;
+
+  // Query for optimal work array sizes
+#ifdef LAPACK_FORTRAN_STRLEN_END
+  ssyevd_(
+      /* jobz = */ &jobz,
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* w = */ w,
+      /* work = */ &work_query,
+      /* lwork = */ &lwork,
+      /* iwork = */ &iwork_query,
+      /* liwork = */ &liwork,
+      /* info = */ &info,
+      /* jobz_len = */ static_cast<size_t>(1),
+      /* uplo_len = */ static_cast<size_t>(1));
+#else
+  ssyevd_(
+      /* jobz = */ &jobz,
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* w = */ w,
+      /* work = */ &work_query,
+      /* lwork = */ &lwork,
+      /* iwork = */ &iwork_query,
+      /* liwork = */ &liwork,
+      /* info = */ &info);
+#endif
+
+  lwork = static_cast<int>(work_query);
+  liwork = iwork_query;
+
+  std::vector<float> work(lwork);
+  std::vector<int> iwork(liwork);
+
+  // Compute eigenvalues (and optionally eigenvectors)
+#ifdef LAPACK_FORTRAN_STRLEN_END
+  ssyevd_(
+      /* jobz = */ &jobz,
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* w = */ w,
+      /* work = */ work.data(),
+      /* lwork = */ &lwork,
+      /* iwork = */ iwork.data(),
+      /* liwork = */ &liwork,
+      /* info = */ &info,
+      /* jobz_len = */ static_cast<size_t>(1),
+      /* uplo_len = */ static_cast<size_t>(1));
+#else
+  ssyevd_(
+      /* jobz = */ &jobz,
+      /* uplo = */ &uplo,
+      /* n = */ &N,
+      /* a = */ matrix,
+      /* lda = */ &N,
+      /* w = */ w,
+      /* work = */ work.data(),
+      /* lwork = */ &lwork,
+      /* iwork = */ iwork.data(),
+      /* liwork = */ &liwork,
+      /* info = */ &info);
+#endif
+
+  return info;
+}
+
+} // namespace
+
+void eigh_impl(
+    const array& a,
+    array& values,
+    array& vectors,
+    bool upper,
+    bool compute_eigenvectors) {
+  char jobz = compute_eigenvectors ? 'V' : 'N';
+  char uplo = upper ? 'U' : 'L';
+
+  array buffer = copy(a);
+
+  const int N = static_cast<int>(a.shape(-1));
+  const int num_matrices = static_cast<int>(a.size() / (N * N));
+
+  std::vector<int> values_shape = {num_matrices, N};
+  values = array(
+      allocator::malloc(num_matrices * N * size_of(a.dtype())),
+      values_shape,
+      a.dtype());
+
+  float* matrix = buffer.data<float>();
+  float* w = values.data<float>();
+
+  if (compute_eigenvectors) {
+    std::vector<int> vectors_shape = a.shape();
+    vectors = array(
+        allocator::malloc(a.size() * size_of(a.dtype())),
+        vectors_shape,
+        a.dtype());
+  }
+
+  float* vecs = compute_eigenvectors ? vectors.data<float>() : nullptr;
+
+  for (int i = 0; i < num_matrices; i++) {
+    int info = ssyevd_wrapper(jobz, uplo, matrix, w, N);
+
+    if (info != 0) {
+      std::stringstream msg;
+      msg << "[eigh] Eigenvalue decomposition failed with error code " << info;
+      throw std::runtime_error(msg.str());
+    }
+
+    if (compute_eigenvectors) {
+      // Copy eigenvectors to the output array
+      std::copy(matrix, matrix + N * N, vecs);
+      vecs += N * N;
+    }
+
+    matrix += N * N;
+    w += N;
+  }
+}
+
+void EighPrimitive::eval(
+    const std::vector<array>& inputs,
+    std::vector<array>& outputs) {
+  // Validate the number of inputs
+  if (inputs.size() != 1) {
+    throw std::invalid_argument(
+        "[EighPrimitive::eval] Expected exactly one input array.");
+  }
+
+  const array& input = inputs[0];
+
+  // Ensure the input array is evaluated before accessing its data
+  const_cast<array&>(input).eval();
+
+  // Validate the data type
+  if (input.dtype() != float32) {
+    throw std::invalid_argument(
+        "[EighPrimitive::eval] Input array must be of type float32.");
+  }
+
+  // Validate the number of dimensions (expecting at least 2D)
+  if (input.ndim() < 2) {
+    throw std::invalid_argument(
+        "[EighPrimitive::eval] Input array must be at least 2-dimensional.");
+  }
+
+  // Placeholder arrays; eigh_impl allocates and assigns the real outputs.
+  array values(0.0f);
+  array vectors(0.0f);
+  eigh_impl(input, values, vectors, upper_, compute_eigenvectors_);
+
+  // Ensure the output arrays are evaluated
+  values.eval();
+  if (compute_eigenvectors_) {
+    vectors.eval();
+    outputs = {values, vectors};
+  } else {
+    outputs = {values};
+  }
+}
+
+} // namespace mlx::core
\ No newline at end of file
diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp
index d9607efce..ebe878a60 100644
--- a/mlx/backend/metal/primitives.cpp
+++ b/mlx/backend/metal/primitives.cpp
@@ -395,6 +395,10 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
       "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
 }
 
+void EighPrimitive::eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) {
+  throw std::runtime_error("[EighPrimitive::eval_gpu] Metal EighPrimitive NYI.");
+}
+
 void View::eval_gpu(const std::vector<array>& inputs, array& out) {
   auto& in = inputs[0];
   auto ibytes = size_of(in.dtype());
diff --git a/mlx/backend/no_cpu/primitives.cpp b/mlx/backend/no_cpu/primitives.cpp
index ff60e4d22..490dd89d3 100644
--- a/mlx/backend/no_cpu/primitives.cpp
+++ b/mlx/backend/no_cpu/primitives.cpp
@@ -48,6 +48,7 @@ NO_CPU(Divide)
 NO_CPU_MULTI(DivMod)
 NO_CPU(NumberOfElements)
 NO_CPU(Remainder)
+NO_CPU_MULTI(EighPrimitive)
 NO_CPU(Equal)
 NO_CPU(Erf)
 NO_CPU(ErfInv)
diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp
index 544a2c6f2..0271fba17 100644
--- a/mlx/backend/no_metal/primitives.cpp
+++ b/mlx/backend/no_metal/primitives.cpp
@@ -110,6 +110,7 @@ NO_GPU(Tanh)
 NO_GPU(Transpose)
 NO_GPU(Inverse)
 NO_GPU(Cholesky)
+NO_GPU_MULTI(EighPrimitive)
 NO_GPU(View)
 
 namespace fast {
diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp
index a64f98aa8..124998e7e 100644
--- a/mlx/linalg.cpp
+++ b/mlx/linalg.cpp
@@ -382,6 +382,71 @@ array cholesky_inv(
   }
 }
 
+array eigvalsh(
+    const array& a,
+    bool upper /* = false */,
+    StreamOrDevice s /* = {} */) {
+  if (a.dtype() != float32) {
+    std::ostringstream msg;
+    msg << "[linalg::eigvalsh] Arrays must be type float32. Received array "
+        << "with type " << a.dtype() << ".";
+    throw std::invalid_argument(msg.str());
+  }
+
+  if (a.ndim() < 2) {
+    std::ostringstream msg;
+    msg << "[linalg::eigvalsh] Arrays must have >= 2 dimensions. Received array "
Received array " + "with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != a.shape(-2)) { + throw std::invalid_argument( + "[linalg::eigvalsh] Eigenvalues are only defined for square matrices."); + } + + std::vector out_shape(a.shape().begin(), a.shape().end() - 1); + out_shape.back() = a.shape(-1); + + return array( + out_shape, + a.dtype(), + std::make_shared(to_stream(s), upper, false), + {astype(a, a.dtype(), s)}); +} + +std::pair eigh( + const array& a, + bool upper /* = false */, + StreamOrDevice s /* = {} */) { + if (a.dtype() != float32) { + std::ostringstream msg; + msg << "[linalg::eigh] Arrays must be type float32. Received array " + << "with type " << a.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (a.ndim() < 2) { + std::ostringstream msg; + msg << "[linalg::eigh] Arrays must have >= 2 dimensions. Received array " + "with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != a.shape(-2)) { + throw std::invalid_argument( + "[linalg::eigh] Eigenvectors are only defined for square matrices."); + } + + auto out = array::make_arrays( + {std::vector(a.shape().begin(), a.shape().end() - 1), a.shape()}, + {a.dtype(), a.dtype()}, + std::make_shared(to_stream(s), upper, true), + {astype(a, a.dtype(), s)}); + return std::make_pair(out[0], out[1]); + array cross( const array& a, const array& b, diff --git a/mlx/linalg.h b/mlx/linalg.h index acfcc1a41..f119b5f65 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -74,6 +74,10 @@ array pinv(const array& a, StreamOrDevice s = {}); array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); +array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {}); + +std::pair eigh(const array& a, bool upper = false, StreamOrDevice s = {}); + /** * Compute the cross product of two arrays along the given axis. */ diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index a1549fa6f..279bc3663 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -737,6 +737,28 @@ std::pair, std::vector> Cholesky::vmap( return {{linalg::cholesky(a, upper_, stream())}, {ax}}; } +std::pair, std::vector> EighPrimitive::vmap( + const std::vector& inputs, + const std::vector& axes) { + assert(inputs.size() == 1); + assert(axes.size() == 1); + + auto ax = axes[0] >= 0 ? 0 : -1; + auto a = axes[0] > 0 ? 
+
+  std::vector<array> outputs;
+  if (compute_eigenvectors_) {
+    auto [values, vectors] = linalg::eigh(a, upper_, stream());
+    outputs = {values, vectors};
+  } else {
+    outputs = {linalg::eigvalsh(a, upper_, stream())};
+  }
+
+  std::vector<int> out_axes(outputs.size(), ax);
+
+  return {outputs, out_axes};
+}
+
 std::vector<array> Concatenate::vjp(
     const std::vector<array>& primals,
     const std::vector<array>& cotangents,
diff --git a/mlx/primitives.h b/mlx/primitives.h
index 5e5bda7c0..82b45a333 100644
--- a/mlx/primitives.h
+++ b/mlx/primitives.h
@@ -2166,4 +2166,39 @@ class Cholesky : public UnaryPrimitive {
   bool upper_;
 };
 
+class EighPrimitive : public Primitive {
+ public:
+  explicit EighPrimitive(Stream stream, bool upper, bool compute_eigenvectors)
+      : Primitive(stream),
+        upper_(upper),
+        compute_eigenvectors_(compute_eigenvectors) {}
+
+  void eval_cpu(const std::vector<array>& inputs, std::vector<array>& outputs) override;
+  void eval_gpu(const std::vector<array>& inputs, std::vector<array>& outputs) override;
+
+  DEFINE_VMAP()
+  DEFINE_PRINT(EighPrimitive)
+
+  std::vector<std::vector<int>> output_shapes(
+      const std::vector<array>& inputs) override {
+    auto shape = inputs[0].shape();
+    shape.pop_back(); // Remove last dimension for eigenvalues
+    if (compute_eigenvectors_) {
+      return {shape, inputs[0].shape()}; // Eigenvalues and eigenvectors
+    } else {
+      return {shape}; // Only eigenvalues
+    }
+  }
+
+  bool is_equivalent(const Primitive& other) const override {
+    if (auto* p = dynamic_cast<const EighPrimitive*>(&other)) {
+      return upper_ == p->upper_ &&
+          compute_eigenvectors_ == p->compute_eigenvectors_;
+    }
+    return false;
+  }
+
+ private:
+  void eval(const std::vector<array>& inputs, std::vector<array>& outputs);
+  bool upper_;
+  bool compute_eigenvectors_;
+};
+
 } // namespace mlx::core
diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp
index 65dd8d0e4..482716888 100644
--- a/python/src/linalg.cpp
+++ b/python/src/linalg.cpp
@@ -353,6 +353,89 @@ void init_linalg(nb::module_& parent_module) {
         Returns:
             array: :math:`\mathbf{A^{-1}}` where :math:`\mathbf{A} = \mathbf{L}\mathbf{L}^T`.
       )pbdoc");
+  m.def(
+      "eigvalsh",
+      &eigvalsh,
+      "a"_a,
+      "upper"_a = true,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def eigvalsh(a: array, upper: bool = True, *, stream: Union[None, Stream, Device] = None) -> array"),
+      R"pbdoc(
+        Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
+
+        This function supports arrays with at least 2 dimensions. When the input
+        has more than two dimensions, the eigenvalues are computed for each matrix
+        in the last two dimensions of ``a``.
+
+        Args:
+            a (array): Input array. Must be a real symmetric or complex Hermitian matrix.
+            upper (bool, optional): Whether to use the upper or lower triangle of the
+              matrix. Default is ``True`` (upper triangle).
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            array: The eigenvalues in ascending order.
+
+        Note:
+            The input matrix is assumed to be symmetric (or Hermitian). Only the
+            upper triangle (if ``upper=True``) or lower triangle (if ``upper=False``)
+            is used. No checks for symmetry are performed.
+
+        Example:
+            >>> A = mx.array([[1., -2.], [-2., 1.]])
+            >>> eigenvalues = mx.linalg.eigvalsh(A)
+            >>> eigenvalues
+            array([-1., 3.], dtype=float32)
+      )pbdoc");
+  m.def(
+      "eigh",
+      [](const array& a, bool upper, StreamOrDevice s) {
+        auto result = eigh(a, upper, s);
+        return nb::make_tuple(result.first, result.second);
+      },
+      "a"_a,
+      "upper"_a = true,
+      nb::kw_only(),
+      "stream"_a = nb::none(),
+      nb::sig(
+          "def eigh(a: array, upper: bool = True, *, stream: Union[None, Stream, Device] = None) -> Tuple[array, array]"),
+      R"pbdoc(
+        Compute the eigenvalues and eigenvectors of a complex Hermitian or real
+        symmetric matrix.
+
+        This function supports arrays with at least 2 dimensions. When the input
+        has more than two dimensions, the eigenvalues and eigenvectors are
+        computed for each matrix in the last two dimensions of ``a``.
+
+        Args:
+            a (array): Input array. Must be a real symmetric or complex Hermitian matrix.
+            upper (bool, optional): Whether to use the upper or lower triangle of the
+              matrix. Default is ``True`` (upper triangle).
+            stream (Stream, optional): Stream or device. Defaults to ``None``
+              in which case the default stream of the default device is used.
+
+        Returns:
+            Tuple[array, array]: A tuple containing:
+
+            - The eigenvalues in ascending order.
+            - The normalized eigenvectors. The column ``v[:, i]`` is the
+              eigenvector corresponding to the i-th eigenvalue.
+
+        Note:
+            The input matrix is assumed to be symmetric (or Hermitian). Only the
+            upper triangle (if ``upper=True``) or lower triangle (if ``upper=False``)
+            is used. No checks for symmetry are performed.
+
+        Example:
+            >>> A = mx.array([[1., -2.], [-2., 1.]])
+            >>> w, v = mx.linalg.eigh(A)
+            >>> w
+            array([-1., 3.], dtype=float32)
+            >>> v
+            array([[ 0.707107, -0.707107],
+                   [ 0.707107,  0.707107]], dtype=float32)
+      )pbdoc");
   m.def(
       "pinv",
       &pinv,
diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py
index 6051beef7..d4da3814e 100644
--- a/python/tests/test_linalg.py
+++ b/python/tests/test_linalg.py
@@ -220,6 +220,54 @@ def test_cholesky_inv(self):
             for M, M_inv in zip(AB, AB_inv):
                 self.assertTrue(mx.allclose(M @ M_inv, mx.eye(N), atol=1e-4))
 
+    def test_eigvalsh(self):
+        # Test a simple 2x2 symmetric matrix
+        A_mx = mx.array([[1.0, 2.0], [2.0, 4.0]], dtype=mx.float32)
+        A_np = np.array([[1.0, 2.0], [2.0, 4.0]], dtype=np.float32)
+
+        eigenvalues_mx = mx.linalg.eigvalsh(A_mx)
+        eigenvalues_np = np.linalg.eigvalsh(A_np)
+
+        self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5))
+
+        # Test a larger random symmetric matrix
+        n = 5
+        rng = np.random.default_rng(42)
+        B_np = rng.random((n, n)).astype(np.float32)
+        B_np = (B_np + B_np.T) / 2  # Make sure B is symmetric
+        B_mx = mx.array(B_np)
+
+        eigenvalues_mx = mx.linalg.eigvalsh(B_mx)
+        eigenvalues_np = np.linalg.eigvalsh(B_np)
+
+        self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5))
+
+        # Test that eigenvalues are in ascending order
+        self.assertTrue(mx.all(eigenvalues_mx[1:] >= eigenvalues_mx[:-1]))
+
+        # Test with upper=False
+        eigenvalues_mx_lower = mx.linalg.eigvalsh(B_mx, upper=False)
+        eigenvalues_np_lower = np.linalg.eigvalsh(B_np, UPLO="L")
+
+        self.assertTrue(mx.allclose(eigenvalues_mx_lower, mx.array(eigenvalues_np_lower), atol=1e-5))
+
+        # Test with batched input
+        C_np = rng.random((3, n, n)).astype(np.float32)
+        C_np = (C_np + np.transpose(C_np, (0, 2, 1))) / 2  # Symmetrize each batch
+        C_mx = mx.array(C_np)
+
+        eigenvalues_mx = mx.linalg.eigvalsh(C_mx)
+        eigenvalues_np = np.linalg.eigvalsh(C_np)
+
+        self.assertTrue(mx.allclose(eigenvalues_mx, mx.array(eigenvalues_np), atol=1e-5))
+
+        # Test error cases
+        with self.assertRaises(ValueError):
+            mx.linalg.eigvalsh(mx.array([1.0, 2.0]))  # 1D array
+
+        with self.assertRaises(ValueError):
+            mx.linalg.eigvalsh(mx.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))  # Non-square matrix
+
     def test_cross_product(self):
         a = mx.array([1.0, 2.0, 3.0])
         b = mx.array([4.0, 5.0, 6.0])
diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp
index e9e196583..ecdac9c51 100644
--- a/tests/linalg_tests.cpp
+++ b/tests/linalg_tests.cpp
@@ -349,6 +349,57 @@ TEST_CASE("test matrix cholesky") {
           .item<bool>());
 }
 
+TEST_CASE("test matrix eigvalsh") {
+  // 0D and 1D throw
+  CHECK_THROWS(linalg::eigvalsh(array(0.0), /* upper = */ true, Device::cpu));
+  CHECK_THROWS(
+      linalg::eigvalsh(array({0.0, 1.0}), /* upper = */ true, Device::cpu));
+
+  // Unsupported types throw
+  CHECK_THROWS(
+      linalg::eigvalsh(array({0, 1}, {1, 2}), /* upper = */ true, Device::cpu));
+
+  // Non-square throws
+  CHECK_THROWS(linalg::eigvalsh(
+      array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ true, Device::cpu));
+
+  // Test a simple 2x2 symmetric matrix
+  array A = array({1.0, 2.0, 2.0, 4.0}, {2, 2});
+  array eigenvalues = linalg::eigvalsh(A, /* upper = */ true, Device::cpu);
+
+  // Expected eigenvalues (calculated analytically)
+  array expected_eigenvalues = array({0.0, 5.0});
+
+  CHECK(allclose(
+            eigenvalues,
+            expected_eigenvalues,
+            /* rtol = */ 1e-5,
+            /* atol = */ 1e-5)
+            .item<bool>());
+
+  // Test a larger symmetric matrix
+  const auto prng_key = random::key(42);
+  const auto B = random::normal({5, 5}, prng_key);
+  const auto B_sym = 0.5 * (B + transpose(B)); // Make sure B is symmetric
+  const auto B_eigenvalues =
+      linalg::eigvalsh(B_sym, /* upper = */ true, Device::cpu);
+
+  // Check that eigenvalues are real and in ascending order
+  CHECK(B_eigenvalues.dtype() == float32);
+  CHECK(B_eigenvalues.shape() == std::vector<int>{5});
+  CHECK(all(isfinite(B_eigenvalues)).item<bool>());
+  CHECK(all(slice(B_eigenvalues, {1}, {5}) >= slice(B_eigenvalues, {0}, {4})).item<bool>());
+
+  // Reconstruct the matrix using eigendecomposition and check if it's close to
+  // the original
+  const auto D = diag(B_eigenvalues);
+  const auto V = linalg::eigh(B_sym, /* upper = */ true, Device::cpu)
+                     .second; // Assuming eigh is implemented
+  const auto B_reconstructed = matmul(matmul(V, D), transpose(V));
+
+  CHECK(allclose(B_reconstructed, B_sym, /* rtol = */ 1e-5, /* atol = */ 1e-5)
+            .item<bool>());
+}
+
 TEST_CASE("test matrix pseudo-inverse") {
   // 0D and 1D throw
   CHECK_THROWS(linalg::pinv(array(0.0), Device::cpu));
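
For reference, a minimal usage sketch of the Python API this diff adds (a hypothetical session, not part of the patch; it assumes the bindings above are built, and runs on the CPU explicitly since the Metal kernels are not yet implemented):

    import mlx.core as mx

    # Symmetric 2x2 matrix with eigenvalues -1 and 3.
    A = mx.array([[1.0, -2.0], [-2.0, 1.0]])

    # Eigenvalues only.
    w = mx.linalg.eigvalsh(A, stream=mx.cpu)

    # Eigenvalues and eigenvectors; column v[:, i] pairs with w[i].
    w, v = mx.linalg.eigh(A, stream=mx.cpu)

    # Sanity check: A @ v should equal v * w column-wise (up to tolerance).
    print(mx.allclose(A @ v, v * w, atol=1e-5))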