Implemented Cholesky on CPU #1026 #1119

Merged: 4 commits, May 17, 2024
1 change: 1 addition & 0 deletions mlx/backend/accelerate/primitives.cpp
@@ -79,6 +79,7 @@ DEFAULT(StopGradient)
DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
@@ -55,6 +55,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/qrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/svd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/inverse.cpp
${CMAKE_CURRENT_SOURCE_DIR}/cholesky.cpp
${CMAKE_CURRENT_BINARY_DIR}/compiled_preamble.cpp
)

109 changes: 109 additions & 0 deletions mlx/backend/common/cholesky.cpp
@@ -0,0 +1,109 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

namespace {

// Delegate to the Cholesky factorization, taking into account differences in
// LAPACK implementations (basically how to pass the 'uplo' string to Fortran).
int spotrf_wrapper(char uplo, float* matrix, int N) {
int info;

#ifdef LAPACK_FORTRAN_STRLEN_END
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info,
/* uplo_len = */ static_cast<size_t>(1));
#else
spotrf_(
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* info = */ &info);
#endif

return info;
}

} // namespace

void cholesky_impl(const array& a, array& factor, bool upper) {
// Lapack uses the column-major convention. We take advantage of the fact that
// the matrix should be symmetric:
// (A)ᵀ = A
// and that a column-major lower triangular matrix is a row-major upper
// triangular matrix, so uplo is the opposite of what we would expect from
// upper

char uplo = (upper) ? 'L' : 'U';

// The decomposition is computed in place, so just copy the input to the
// output.
copy(
a,
factor,
a.flags().row_contiguous ? CopyType::Vector : CopyType::General);

const int N = a.shape(-1);
const size_t num_matrices = a.size() / (N * N);

float* matrix = factor.data<float>();

for (int i = 0; i < num_matrices; i++) {
// Compute Cholesky factorization.
int info = spotrf_wrapper(uplo, matrix, N);

// TODO: We do nothing when the matrix is not positive semi-definite
// because throwing an error would result in a crash. If we figure out how
// to catch errors from the implementation we should throw.
if (info < 0) {
std::stringstream msg;
msg << "[cholesky] Cholesky decomposition failed with error code "
<< info;
throw std::runtime_error(msg.str());
}

// Zero out the upper/lower triangle while advancing the pointer to the
// next matrix at the same time.
for (int row = 0; row < N; row++) {
if (upper) {
std::fill(matrix, matrix + row, 0);
} else {
std::fill(matrix + row + 1, matrix + N, 0);
Review comment (Member): I missed it in the first review, but here you were zeroing column-wise, which is really inefficient, especially for large matrices. I also use std::fill, which may be faster if the compiler couldn't figure out that the previous loop was vectorizable.

}
matrix += N;
}
}
}

void Cholesky::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Cholesky::eval] only supports float32.");
}
cholesky_impl(inputs[0], output, upper_);
}

std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::cholesky(a, upper_, stream())}, {ax}};
}

} // namespace mlx::core
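
A side note on the layout argument in cholesky_impl above: because the input is symmetric, handing a row-major buffer to column-major LAPACK presents the same matrix, and the column-major lower factor LAPACK writes back is, read row-major, exactly the upper factor (its transpose). A minimal NumPy sketch of that identity (illustration only, not part of this diff):

import numpy as np

A = np.array([[4.0, 2.0], [2.0, 3.0]], dtype=np.float32)  # symmetric positive definite
L = np.linalg.cholesky(A)   # lower factor: L @ L.T == A
U = L.T                     # the same values read with the opposite ordering
assert np.allclose(L @ L.T, A)
assert np.allclose(U.T @ U, A)  # which is why uplo='L' is passed when the row-major upper factor is wanted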
1 change: 1 addition & 0 deletions mlx/backend/common/default_primitives.cpp
@@ -112,6 +112,7 @@ DEFAULT(Tan)
DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)

namespace {

5 changes: 5 additions & 0 deletions mlx/backend/metal/primitives.cpp
@@ -1012,4 +1012,9 @@ void Inverse::eval_gpu(const std::vector<array>& inputs, array& output) {
throw std::runtime_error("[Inverse::eval_gpu] Metal inversion NYI.");
}

void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}

} // namespace mlx::core
1 change: 1 addition & 0 deletions mlx/backend/no_metal/primitives.cpp
@@ -106,6 +106,7 @@ NO_GPU(Tan)
NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)
NO_GPU(Cholesky)

namespace fast {
NO_GPU_MULTI(LayerNorm)
31 changes: 31 additions & 0 deletions mlx/linalg.cpp
@@ -261,4 +261,35 @@ array inv(const array& a, StreamOrDevice s /* = {} */) {
a.shape(), a.dtype(), std::make_shared<Inverse>(to_stream(s)), {a});
}

array cholesky(
const array& a,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::cholesky] Arrays must type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}

if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::cholesky] Arrays must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(
"[linalg::cholesky] Cholesky decomposition is only defined for square "
"matrices.");
}
return array(
a.shape(),
a.dtype(),
std::make_shared<Cholesky>(to_stream(s), upper),
{a});
}

} // namespace mlx::core::linalg
2 changes: 2 additions & 0 deletions mlx/linalg.h
@@ -66,4 +66,6 @@ std::vector<array> svd(const array& a, StreamOrDevice s = {});

array inv(const array& a, StreamOrDevice s = {});

array cholesky(const array& a, bool upper = false, StreamOrDevice s = {});

} // namespace mlx::core::linalg
16 changes: 16 additions & 0 deletions mlx/primitives.h
@@ -2065,4 +2065,20 @@ class Inverse : public UnaryPrimitive {
void eval(const std::vector<array>& inputs, array& output);
};

class Cholesky : public UnaryPrimitive {
public:
explicit Cholesky(Stream stream, bool upper)
: UnaryPrimitive(stream), upper_(upper) {};

void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;

DEFINE_VMAP()
DEFINE_PRINT(Cholesky)

private:
void eval(const std::vector<array>& inputs, array& output);
bool upper_;
};

} // namespace mlx::core
29 changes: 29 additions & 0 deletions python/src/linalg.cpp
@@ -260,4 +260,33 @@ void init_linalg(nb::module_& parent_module) {
Returns:
array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])``
)pbdoc");
m.def(
"cholesky",
&cholesky,
"a"_a,
"upper"_a = false,
nb::kw_only(),
"stream"_a = nb::none(),
nb::sig(
"def cholesky(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix.

This function supports arrays with at least 2 dimensions. When the input
has more than two dimensions, the Cholesky decomposition is computed for each matrix
in the last two dimensions of ``a``.

If the input matrix is not symmetric positive semi-definite, behaviour is undefined.
Review comment (Member): It currently crashes. Maybe undefined is indeed better. We could just not throw in that case.

Reply (@arn4, Contributor Author, May 16, 2024): Yes, it does crash if the matrix is not positive semidefinite. The behavior I wanted was:

  • raise when the matrix is not positive semidefinite, since that is detected directly by LAPACK;
  • undefined if the matrix is not symmetric, since adding a check would incur unnecessary overhead.

Maybe I should specify it in the doc, or do you still prefer not raising at all?

Reply (Member): I changed it to do nothing and added a comment to add some error checking if we can throw from the implementation in the future. Otherwise, potentially crashing the program in an unrecoverable way based on the values of the input matrix seems like an extremely bad idea to me.


Args:
a (array): Input array.
upper (bool, optional): If ``True``, return the upper triangular Cholesky factor.
If ``False``, return the lower triangular Cholesky factor. Default: ``False``.
stream (Stream, optional): Stream or device. Defaults to ``None``
in which case the default stream of the default device is used.

Returns:
array: If ``upper = False``, it returns a lower triangular ``L`` matrix such that ``dot(L, L.T) = a``.
If ``upper = True``, it returns an upper triangular ``U`` matrix such that ``dot(U.T, U) = a``.
)pbdoc");
}
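
A short usage sketch of the new binding, mirroring the unit test below; it assumes the standard mlx.core helpers (mx.eye, mx.allclose) and runs on the CPU stream since the Metal kernel is not implemented in this PR:

import mlx.core as mx

# Build a symmetric positive definite matrix and factor it on the CPU.
sqrtA = mx.array([[1.0, 2.0], [3.0, 4.0]])
A = sqrtA.T @ sqrtA + mx.eye(2)
L = mx.linalg.cholesky(A, stream=mx.cpu)              # lower factor
U = mx.linalg.cholesky(A, upper=True, stream=mx.cpu)  # upper factor
assert mx.allclose(L @ L.T, A).item()
assert mx.allclose(U.T @ U, A).item()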
17 changes: 17 additions & 0 deletions python/tests/test_linalg.py
@@ -150,6 +150,23 @@ def test_inverse(self):
mx.allclose(M @ M_inv, mx.eye(M.shape[0]), rtol=0, atol=1e-5)
)

def test_cholesky(self):
sqrtA = mx.array(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=mx.float32
)
A = sqrtA.T @ sqrtA / 81
L = mx.linalg.cholesky(A, stream=mx.cpu)
U = mx.linalg.cholesky(A, upper=True, stream=mx.cpu)
self.assertTrue(mx.allclose(L @ L.T, A, rtol=1e-5, atol=1e-7))
self.assertTrue(mx.allclose(U.T @ U, A, rtol=1e-5, atol=1e-7))

# Multiple matrices
B = A + 1 / 9
AB = mx.stack([A, B])
Ls = mx.linalg.cholesky(AB, stream=mx.cpu)
for M, L in zip(AB, Ls):
self.assertTrue(mx.allclose(L @ L.T, M, rtol=1e-5, atol=1e-7))


if __name__ == "__main__":
unittest.main()
26 changes: 26 additions & 0 deletions tests/linalg_tests.cpp
@@ -322,3 +322,29 @@ TEST_CASE("test matrix inversion") {
CHECK(allclose(matmul(A_inv, A), identity, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
}

TEST_CASE("test matrix cholesky") {
// 0D and 1D throw
CHECK_THROWS(linalg::cholesky(array(0.0), /* upper = */ false, Device::cpu));
CHECK_THROWS(
linalg::cholesky(array({0.0, 1.0}), /* upper = */ false, Device::cpu));

// Unsupported types throw
CHECK_THROWS(linalg::cholesky(
array({0, 1}, {1, 2}), /* upper = */ false, Device::cpu));

// Non-square throws.
CHECK_THROWS(linalg::cholesky(
array({1, 2, 3, 4, 5, 6}, {2, 3}), /* upper = */ false, Device::cpu));

const auto prng_key = random::key(220398);
const auto sqrtA = random::normal({5, 5}, prng_key);
const auto A = matmul(sqrtA, transpose(sqrtA));
const auto L = linalg::cholesky(A, /* upper = */ false, Device::cpu);
const auto U = linalg::cholesky(A, /* upper = */ true, Device::cpu);

CHECK(allclose(matmul(L, transpose(L)), A, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
CHECK(allclose(matmul(transpose(U), U), A, /* rtol = */ 0, /* atol = */ 1e-6)
.item<bool>());
}