-
Notifications
You must be signed in to change notification settings - Fork 990
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implemented Cholesky on CPU #1026 #1119
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
// Copyright © 2023-2024 Apple Inc. | ||
|
||
#include "mlx/allocator.h" | ||
#include "mlx/backend/common/copy.h" | ||
#include "mlx/linalg.h" | ||
#include "mlx/primitives.h" | ||
|
||
#ifdef ACCELERATE_NEW_LAPACK | ||
#include <Accelerate/Accelerate.h> | ||
#else | ||
#include <lapack.h> | ||
#endif | ||
|
||
namespace mlx::core { | ||
|
||
namespace { | ||
|
||
// Delegate to the Cholesky factorization taking into account differences in | ||
// LAPACK implementations (basically how to pass the 'uplo' string to fortran). | ||
int spotrf_wrapper(char uplo, float* matrix, int N) { | ||
int info; | ||
|
||
#ifdef LAPACK_FORTRAN_STRLEN_END | ||
spotrf_( | ||
/* uplo = */ &uplo, | ||
/* n = */ &N, | ||
/* a = */ matrix, | ||
/* lda = */ &N, | ||
/* info = */ &info, | ||
/* uplo_len = */ static_cast<size_t>(1)); | ||
#else | ||
spotrf_( | ||
/* uplo = */ &uplo, | ||
/* n = */ &N, | ||
/* a = */ matrix, | ||
/* lda = */ &N, | ||
/* info = */ &info); | ||
#endif | ||
|
||
return info; | ||
} | ||
|
||
} // namespace | ||
|
||
void cholesky_impl(const array& a, array& factor, bool upper) { | ||
// Lapack uses the column-major convention. We take advantage of the fact that | ||
// the matrix should be symmetric: | ||
// (A)ᵀ = A | ||
// and that a column-major lower triangular matrix is a row-major upper | ||
// triangular matrix, so uplo is the opposite of what we would expect from | ||
// upper | ||
|
||
char uplo = (upper) ? 'L' : 'U'; | ||
|
||
// The decomposition is computed in place, so just copy the input to the | ||
// output. | ||
copy( | ||
a, | ||
factor, | ||
a.flags().row_contiguous ? CopyType::Vector : CopyType::General); | ||
|
||
const int N = a.shape(-1); | ||
const size_t num_matrices = a.size() / (N * N); | ||
|
||
float* matrix = factor.data<float>(); | ||
|
||
for (int i = 0; i < num_matrices; i++) { | ||
// Compute Cholesky factorization. | ||
int info = spotrf_wrapper(uplo, matrix, N); | ||
|
||
// TODO: We do nothing when the matrix is not positive semi-definite | ||
// because throwing an error would result in a crash. If we figure out how | ||
// to catch errors from the implementation we should throw. | ||
if (info < 0) { | ||
std::stringstream msg; | ||
msg << "[cholesky] Cholesky decomposition failed with error code " | ||
<< info; | ||
throw std::runtime_error(msg.str()); | ||
} | ||
|
||
// Zero out the upper/lower triangle while advancing the pointer to the | ||
// next matrix at the same time. | ||
for (int row = 0; row < N; row++) { | ||
if (upper) { | ||
std::fill(matrix, matrix + row, 0); | ||
} else { | ||
std::fill(matrix + row + 1, matrix + N, 0); | ||
} | ||
matrix += N; | ||
} | ||
} | ||
} | ||
|
||
void Cholesky::eval(const std::vector<array>& inputs, array& output) { | ||
if (inputs[0].dtype() != float32) { | ||
throw std::runtime_error("[Cholesky::eval] only supports float32."); | ||
} | ||
cholesky_impl(inputs[0], output, upper_); | ||
} | ||
|
||
std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap( | ||
const std::vector<array>& inputs, | ||
const std::vector<int>& axes) { | ||
auto ax = axes[0] >= 0 ? 0 : -1; | ||
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0]; | ||
return {{linalg::cholesky(a, upper_, stream())}, {ax}}; | ||
} | ||
|
||
} // namespace mlx::core |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -112,6 +112,7 @@ DEFAULT(Tan) | |
DEFAULT(Tanh) | ||
DEFAULT(Transpose) | ||
DEFAULT(Inverse) | ||
DEFAULT(Cholesky) | ||
|
||
namespace { | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -260,4 +260,33 @@ void init_linalg(nb::module_& parent_module) { | |
Returns: | ||
array: ``ainv`` such that ``dot(a, ainv) = dot(ainv, a) = eye(a.shape[0])`` | ||
)pbdoc"); | ||
m.def( | ||
"cholesky", | ||
&cholesky, | ||
"a"_a, | ||
"upper"_a = false, | ||
nb::kw_only(), | ||
"stream"_a = nb::none(), | ||
nb::sig( | ||
"def cholesky(a: array, upper: bool = False, *, stream: Union[None, Stream, Device] = None) -> array"), | ||
R"pbdoc( | ||
Compute the Cholesky decomposition of a real symmetric positive semi-definite matrix. | ||
|
||
This function supports arrays with at least 2 dimensions. When the input | ||
has more than two dimensions, the Cholesky decomposition is computed for each matrix | ||
in the last two dimensions of ``a``. | ||
|
||
If the input matrix is not symmetric positive semi-definite, behaviour is undefined. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It currently crashes. Maybe undefined is indeed better. We could just not throw in that case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it does crash if the matrix is not positive semidefinite. The behavior I wanted was
Maybe I should specify it in the doc, or do you still prefer not raising at all? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I changed it to do nothing and added a comment to add some error checking if we can throw from the implementation in the future. Otherwise, potentially crashing the program in an unrecoverable way based on the values of the input matrix seems an extremely bad idea to me. |
||
|
||
Args: | ||
a (array): Input array. | ||
upper (bool, optional): If ``True``, return the upper triangular Cholesky factor. | ||
If ``False``, return the lower triangular Cholesky factor. Default: ``False``. | ||
stream (Stream, optional): Stream or device. Defaults to ``None`` | ||
in which case the default stream of the default device is used. | ||
|
||
Returns: | ||
array: if ``upper = False``, it returns a lower trinagular ``L``matrix such that ``dot(L, L.T) = a``. | ||
If ``upper = True``, it returns an upper triangular ``U`` matrix such that ``dot(U.T, U) = a``. | ||
)pbdoc"); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I missed it in the first review but here you were zeroing column-wise which is really inefficient especially for large matrices. I also use
std::fill
which may be faster if the compiler couldn't figure out that the previous loop was vectorizable.