From fa4b6cb98bdf041936509426b8e952594aee2093 Mon Sep 17 00:00:00 2001 From: Luca Arnaboldi Date: Tue, 14 May 2024 15:29:43 +0200 Subject: [PATCH 1/4] Implemented Cholesky for symmetric positive semi-definite --- mlx/backend/accelerate/primitives.cpp | 1 + mlx/backend/common/CMakeLists.txt | 1 + mlx/backend/common/cholesky.cpp | 86 +++++++++++++++++++++++ mlx/backend/common/default_primitives.cpp | 1 + mlx/backend/metal/primitives.cpp | 4 ++ mlx/backend/no_metal/primitives.cpp | 1 + mlx/linalg.cpp | 31 ++++++++ mlx/linalg.h | 2 + mlx/primitives.h | 16 +++++ python/src/linalg.cpp | 29 ++++++++ python/tests/test_linalg.py | 17 +++++ tests/linalg_tests.cpp | 26 +++++++ 12 files changed, 215 insertions(+) create mode 100644 mlx/backend/common/cholesky.cpp diff --git a/mlx/backend/accelerate/primitives.cpp b/mlx/backend/accelerate/primitives.cpp index 9bf1868c2..b40075d36 100644 --- a/mlx/backend/accelerate/primitives.cpp +++ b/mlx/backend/accelerate/primitives.cpp @@ -79,6 +79,7 @@ DEFAULT(StopGradient) DEFAULT_MULTI(SVD) DEFAULT(Transpose) DEFAULT(Inverse) +DEFAULT(Cholesky) void Abs::eval_cpu(const std::vector& inputs, array& out) { assert(inputs.size() == 1); diff --git a/mlx/backend/common/CMakeLists.txt b/mlx/backend/common/CMakeLists.txt index ea0babf18..4fd573bb9 100644 --- a/mlx/backend/common/CMakeLists.txt +++ b/mlx/backend/common/CMakeLists.txt @@ -55,6 +55,7 @@ target_sources( ${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp ${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp ${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp ${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp ) diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp new file mode 100644 index 000000000..2100266d1 --- /dev/null +++ b/mlx/backend/common/cholesky.cpp @@ -0,0 +1,86 @@ +// Copyright © 2023-2024 Apple Inc. + +#include "mlx/allocator.h" +#include "mlx/backend/common/copy.h" +#include "mlx/linalg.h" +#include "mlx/primitives.h" + +#ifdef ACCELERATE_NEW_LAPACK +#include +#else +#include +#endif + +namespace mlx::core { + +void cholesky_impl(const array& a, array& T, bool upper) { + // Lapack uses the column-major convention. We take advantage of the fact that + // the matrix should be symmetric: + // (A)ᵀ = A + // and that a column-major lower triangular matrix is a row-major upper + // triangular matrix, so uplo is the opposite of what we would expect from + // upper + + char uplo; + if (upper) { + uplo = 'L'; + } else { + uplo = 'U'; + } + + // The decomposition is computed in place, so just copy the input to the + // output. + copy(a, T, a.flags().row_contiguous ? CopyType::Vector : CopyType::General); + + const int N = a.shape(-1); + const size_t num_matrices = a.size() / (N * N); + + int info; + + for (int i = 0; i < num_matrices; i++) { + // Compute Cholesky factorization. + spotrf_( + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ T.data() + N * N * i, + /* lda = */ &N, + /* info = */ &info); + + if (info != 0) { + std::stringstream ss; + if (info < 0) + ss << "cholesky_impl: failed with error code " << info; + else { + ss << "cholesky_impl: matrix is not positive definite."; + } + throw std::runtime_error(ss.str()); + } + + // Zero out the upper/lower triangle. + for (int j = 0; j < N; j++) { + for (int k = 0; k < j; k++) { + if (upper) + T.data()[N * N * i + j * N + k] = 0.; + else + T.data()[N * N * i + k * N + j] = 0.; + } + } + } +} + +void Cholesky::eval(const std::vector& inputs, array& output) { + if (inputs[0].dtype() != float32) { + throw std::runtime_error("[Cholesky::eval] only supports float32."); + } + cholesky_impl(inputs[0], output, upper_); +} + +std::pair, std::vector> Cholesky::vmap( + const std::vector& inputs, + const std::vector& axes) { + auto ax = axes[0] >= 0 ? 0 : -1; + auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; + return {{linalg::cholesky(a, upper_, stream())}, {ax}}; +} + +} // namespace mlx::core diff --git a/mlx/backend/common/default_primitives.cpp b/mlx/backend/common/default_primitives.cpp index ec5289d6a..a50d30ad2 100644 --- a/mlx/backend/common/default_primitives.cpp +++ b/mlx/backend/common/default_primitives.cpp @@ -112,6 +112,7 @@ DEFAULT(Tan) DEFAULT(Tanh) DEFAULT(Transpose) DEFAULT(Inverse) +DEFAULT(Cholesky) namespace { diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index d989b2197..0b5e6ef3b 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -1012,4 +1012,8 @@ void Inverse::eval_gpu(const std::vector& inputs, array& output) { throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI."); } +void Cholesky::eval_gpu(const std::vector& inputs, array& out) { + throw std::runtime_error("[Cholesky::eval_gpu] Metal inversion NYI."); +} + } // namespace mlx::core diff --git a/mlx/backend/no_metal/primitives.cpp b/mlx/backend/no_metal/primitives.cpp index 63114d386..fa322b03a 100644 --- a/mlx/backend/no_metal/primitives.cpp +++ b/mlx/backend/no_metal/primitives.cpp @@ -106,6 +106,7 @@ NO_GPU(Tan) NO_GPU(Tanh) NO_GPU(Transpose) NO_GPU(Inverse) +NO_GPU(Cholesky) namespace fast { NO_GPU_MULTI(LayerNorm) diff --git a/mlx/linalg.cpp b/mlx/linalg.cpp index d772c0e14..845d1981f 100644 --- a/mlx/linalg.cpp +++ b/mlx/linalg.cpp @@ -261,4 +261,35 @@ array inv(const array& a, StreamOrDevice s /* = {} */) { a.shape(), a.dtype(), std::make_shared(to_stream(s)), {a}); } +array cholesky( + const array& a, + bool upper /* = false */, + StreamOrDevice s /* = {} */) { + if (a.dtype() != float32) { + std::ostringstream msg; + msg << "[linalg::cholesky] Arrays must type float32. Received array " + << "with type " << a.dtype() << "."; + throw std::invalid_argument(msg.str()); + } + + if (a.ndim() < 2) { + std::ostringstream msg; + msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array " + "with " + << a.ndim() << " dimensions."; + throw std::invalid_argument(msg.str()); + } + + if (a.shape(-1) != a.shape(-2)) { + throw std::invalid_argument( + "[linalg::cholesky] Cholesky decomposition is only defined for square " + "matrices."); + } + return array( + a.shape(), + a.dtype(), + std::make_shared(to_stream(s), upper), + {a}); +} + } // namespace mlx::core::linalg diff --git a/mlx/linalg.h b/mlx/linalg.h index aa46a7959..16a2bf25b 100644 --- a/mlx/linalg.h +++ b/mlx/linalg.h @@ -66,4 +66,6 @@ std::vector svd(const array& a, StreamOrDevice s = {}); array inv(const array& a, StreamOrDevice s = {}); +array cholesky(const array& a, bool upper = false, StreamOrDevice s = {}); + } // namespace mlx::core::linalg diff --git a/mlx/primitives.h b/mlx/primitives.h index 868b5e7f5..3b6c80205 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -2065,4 +2065,20 @@ class Inverse : public UnaryPrimitive { void eval(const std::vector& inputs, array& output); }; +class Cholesky : public UnaryPrimitive { + public: + explicit Cholesky(Stream stream, bool upper) + : UnaryPrimitive(stream), upper_(upper) {}; + + void eval_cpu(const std::vector& inputs, array& out) override; + void eval_gpu(const std::vector& inputs, array& out) override; + + DEFINE_VMAP() + DEFINE_PRINT(Cholesky) + + private: + void eval(const std::vector& inputs, array& output); + bool upper_; +}; + } // namespace mlx::core diff --git a/python/src/linalg.cpp b/python/src/linalg.cpp index a6a86e414..eed8fe53f 100644 --- a/python/src/linalg.cpp +++ b/python/src/linalg.cpp @@ -260,4 +260,33 @@ void init_linalg(nb::module_& parent_module) { Returns: array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])`` )pbdoc"); + m.def( + "cholesky", + &cholesky, + "a"_a, + "upper"_a = false, + nb::kw_only(), + "stream"_a = nb::none(), + nb::sig( + "def cholesky(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), + R"pbdoc( + Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix. + + This function supports arrays with at least 2 dimensions. When the input + has more than two dimensions, the Cholesky decomposition is computed for each matrix + in the last two dimensions of ``a``. + + If the input matrix is not symmetric positive semi-definite, behaviour is undefined. + + Args: + a (array): Input array. + upper (bool, optional): If ``True``, return the upper triangular Cholesky factor. + If ``False``, return the lower triangular Cholesky factor. Default: ``False``. + stream (Stream, optional): Stream or device. Defaults to ``None`` + in which case the default stream of the default device is used. + + Returns: + array: if ``upper = False``, it returns a lower trinagular ``L``matrix such that ``dot(L, L.T) = a``. + If ``upper = True``, it returns an upper triangular ``U`` matrix such that ``dot(U.T, U) = a``. + )pbdoc"); } diff --git a/python/tests/test_linalg.py b/python/tests/test_linalg.py index a8dec0322..944df89b8 100644 --- a/python/tests/test_linalg.py +++ b/python/tests/test_linalg.py @@ -150,6 +150,23 @@ def test_inverse(self): mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5) ) + def test_cholesky(self): + sqrtA = mx.array( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32 + ) + A = sqrtA.T @ sqrtA / 81 + L = mx.linalg.cholesky(A, stream=mx.cpu) + U = mx.linalg.cholesky(A, upper=True, stream=mx.cpu) + self.assertTrue(mx.allclose(L @ L.T, A, rtol=1e-5, atol=1e-7)) + self.assertTrue(mx.allclose(U.T @ U, A, rtol=1e-5, atol=1e-7)) + + # Multiple matrices + B = A + 1 / 9 + AB = mx.stack([A, B]) + Ls = mx.linalg.cholesky(AB, stream=mx.cpu) + for M, L in zip(AB, Ls): + self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7)) + if __name__ == "__main__": unittest.main() diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 45ccb6134..03a40f7ae 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -322,3 +322,29 @@ TEST_CASE("test matrix inversion") { CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6) .item()); } + +TEST_CASE("test matrix cholensky") { + // 0D and 1D throw + CHECK_THROWS(linalg::cholesky(array(0.0), /* upper = */ false, Device::cpu)); + CHECK_THROWS( + linalg::cholesky(array({0.0, 1.0}), /* upper = */ false, Device::cpu)); + + // Unsupported types throw + CHECK_THROWS(linalg::cholesky( + array({0, 1}, {1, 2}), /* upper = */ false, Device::cpu)); + + // Non-square throws. + CHECK_THROWS(linalg::cholesky( + array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ false, Device::cpu)); + + const auto prng_key = random::key(220398); + const auto sqrtA = random::normal({5, 5}, prng_key); + const auto A = matmul(sqrtA, transpose(sqrtA)); + const auto L = linalg::cholesky(A, /* upper = */ false, Device::cpu); + const auto U = linalg::cholesky(A, /* upper = */ true, Device::cpu); + + CHECK(allclose(matmul(L, transpose(L)), A, /* rtol = */ 0, /* atol = */ 1e-6) + .item()); + CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6) + .item()); +} \ No newline at end of file From 989a245c26cfdc2ab55d4996295e32fc46da4e2d Mon Sep 17 00:00:00 2001 From: Luca Arnaboldi Date: Thu, 16 May 2024 11:14:46 +0200 Subject: [PATCH 2/4] Typos from the reviews --- mlx/backend/common/cholesky.cpp | 20 +++++++++----------- mlx/backend/metal/primitives.cpp | 3 ++- tests/linalg_tests.cpp | 2 +- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp index 2100266d1..43fd2f4a3 100644 --- a/mlx/backend/common/cholesky.cpp +++ b/mlx/backend/common/cholesky.cpp @@ -13,7 +13,7 @@ namespace mlx::core { -void cholesky_impl(const array& a, array& T, bool upper) { +void cholesky_impl(const array& a, array& factor, bool upper) { // Lapack uses the column-major convention. We take advantage of the fact that // the matrix should be symmetric: // (A)ᵀ = A @@ -21,16 +21,14 @@ void cholesky_impl(const array& a, array& T, bool upper) { // triangular matrix, so uplo is the opposite of what we would expect from // upper - char uplo; - if (upper) { - uplo = 'L'; - } else { - uplo = 'U'; - } + char uplo = (upper) ? 'L' : 'U'; // The decomposition is computed in place, so just copy the input to the // output. - copy(a, T, a.flags().row_contiguous ? CopyType::Vector : CopyType::General); + copy( + a, + factor, + a.flags().row_contiguous ? CopyType::Vector : CopyType::General); const int N = a.shape(-1); const size_t num_matrices = a.size() / (N * N); @@ -42,7 +40,7 @@ void cholesky_impl(const array& a, array& T, bool upper) { spotrf_( /* uplo = */ &uplo, /* n = */ &N, - /* a = */ T.data() + N * N * i, + /* a = */ factor.data() + N * N * i, /* lda = */ &N, /* info = */ &info); @@ -60,9 +58,9 @@ void cholesky_impl(const array& a, array& T, bool upper) { for (int j = 0; j < N; j++) { for (int k = 0; k < j; k++) { if (upper) - T.data()[N * N * i + j * N + k] = 0.; + factor.data()[N * N * i + j * N + k] = 0.; else - T.data()[N * N * i + k * N + j] = 0.; + factor.data()[N * N * i + k * N + j] = 0.; } } } diff --git a/mlx/backend/metal/primitives.cpp b/mlx/backend/metal/primitives.cpp index 0b5e6ef3b..a2c3df651 100644 --- a/mlx/backend/metal/primitives.cpp +++ b/mlx/backend/metal/primitives.cpp @@ -1013,7 +1013,8 @@ void Inverse::eval_gpu(const std::vector& inputs, array& output) { } void Cholesky::eval_gpu(const std::vector& inputs, array& out) { - throw std::runtime_error("[Cholesky::eval_gpu] Metal inversion NYI."); + throw std::runtime_error( + "[Cholesky::eval_gpu] Metal Cholesky decomposition NYI."); } } // namespace mlx::core diff --git a/tests/linalg_tests.cpp b/tests/linalg_tests.cpp index 03a40f7ae..2af868965 100644 --- a/tests/linalg_tests.cpp +++ b/tests/linalg_tests.cpp @@ -323,7 +323,7 @@ TEST_CASE("test matrix inversion") { .item()); } -TEST_CASE("test matrix cholensky") { +TEST_CASE("test matrix cholesky") { // 0D and 1D throw CHECK_THROWS(linalg::cholesky(array(0.0), /* upper = */ false, Device::cpu)); CHECK_THROWS( From f3afcb61ddb0108d4aab6d5a4e46f3d30c69398a Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 17 May 2024 10:44:53 -0700 Subject: [PATCH 3/4] Change the zeroing order and remove throw --- mlx/backend/common/cholesky.cpp | 34 +++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp index 43fd2f4a3..090494ad2 100644 --- a/mlx/backend/common/cholesky.cpp +++ b/mlx/backend/common/cholesky.cpp @@ -34,34 +34,36 @@ void cholesky_impl(const array& a, array& factor, bool upper) { const size_t num_matrices = a.size() / (N * N); int info; + float* matrix = factor.data(); for (int i = 0; i < num_matrices; i++) { // Compute Cholesky factorization. spotrf_( /* uplo = */ &uplo, /* n = */ &N, - /* a = */ factor.data() + N * N * i, + /* a = */ matrix, /* lda = */ &N, /* info = */ &info); - if (info != 0) { - std::stringstream ss; - if (info < 0) - ss << "cholesky_impl: failed with error code " << info; - else { - ss << "cholesky_impl: matrix is not positive definite."; - } - throw std::runtime_error(ss.str()); + // TODO: We do nothing when the matrix is not positive semi-definite + // because throwing an error would result in a crash. If we figure out how + // to catch errors from the implementation we should throw. + if (info < 0) { + std::stringstream msg; + msg << "[cholesky] Cholesky decomposition failed with error code " + << info; + throw std::runtime_error(msg.str()); } - // Zero out the upper/lower triangle. - for (int j = 0; j < N; j++) { - for (int k = 0; k < j; k++) { - if (upper) - factor.data()[N * N * i + j * N + k] = 0.; - else - factor.data()[N * N * i + k * N + j] = 0.; + // Zero out the upper/lower triangle while advancing the pointer to the + // next matrix at the same time. + for (int row = 0; row < N; row++) { + if (upper) { + std::fill(matrix, matrix + row, 0); + } else { + std::fill(matrix + row + 1, matrix + N, 0); } + matrix += N; } } } From 9cbdaba4eea301fee68a9ea7df5d1097055d1fcc Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Fri, 17 May 2024 12:11:57 -0700 Subject: [PATCH 4/4] Account for differences in passing strings to Fortran --- mlx/backend/common/cholesky.cpp | 37 ++++++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/mlx/backend/common/cholesky.cpp b/mlx/backend/common/cholesky.cpp index 090494ad2..2af5d8ddf 100644 --- a/mlx/backend/common/cholesky.cpp +++ b/mlx/backend/common/cholesky.cpp @@ -13,6 +13,35 @@ namespace mlx::core { +namespace { + +// Delegate to the Cholesky factorization taking into account differences in +// LAPACK implementations (basically how to pass the 'uplo' string to fortran). +int spotrf_wrapper(char uplo, float* matrix, int N) { + int info; + +#ifdef LAPACK_FORTRAN_STRLEN_END + spotrf_( + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* info = */ &info, + /* uplo_len = */ static_cast(1)); +#else + spotrf_( + /* uplo = */ &uplo, + /* n = */ &N, + /* a = */ matrix, + /* lda = */ &N, + /* info = */ &info); +#endif + + return info; +} + +} // namespace + void cholesky_impl(const array& a, array& factor, bool upper) { // Lapack uses the column-major convention. We take advantage of the fact that // the matrix should be symmetric: @@ -33,17 +62,11 @@ void cholesky_impl(const array& a, array& factor, bool upper) { const int N = a.shape(-1); const size_t num_matrices = a.size() / (N * N); - int info; float* matrix = factor.data(); for (int i = 0; i < num_matrices; i++) { // Compute Cholesky factorization. - spotrf_( - /* uplo = */ &uplo, - /* n = */ &N, - /* a = */ matrix, - /* lda = */ &N, - /* info = */ &info); + int info = spotrf_wrapper(uplo, matrix, N); // TODO: We do nothing when the matrix is not positive semi-definite // because throwing an error would result in a crash. If we figure out how