-
Notifications
You must be signed in to change notification settings - Fork 943
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Eigenvalues and eigenvectors #1334
base: main
Are you sure you want to change the base?
Changes from 8 commits
6735a59
f1789b3
a89ee52
b5f900a
9dd754b
19e0148
3100188
e64ca5e
523cb3d
97b965b
c0b653b
2383181
dc614eb
5b53354
859dd23
dbb5c64
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,3 +15,5 @@ Linear Algebra | |
cholesky_inv | ||
qr | ||
svd | ||
eigvalsh | ||
eigh |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,187 @@ | ||
// Copyright © 2023-2024 Apple Inc. | ||
|
||
#include "mlx/allocator.h" | ||
#include "mlx/backend/common/copy.h" | ||
#include "mlx/linalg.h" | ||
#include "mlx/primitives.h" | ||
|
||
#ifdef ACCELERATE_NEW_LAPACK | ||
#include <Accelerate/Accelerate.h> | ||
#else | ||
#include <lapack.h> | ||
#endif | ||
|
||
namespace mlx::core { | ||
|
||
namespace { | ||
|
||
// Delegate to the eigenvalue decomposition taking into account differences in | ||
// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings | ||
// to fortran). | ||
// Delegate to the eigenvalue decomposition taking into account differences in
// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings
// to fortran).
//
// jobz:   'N' computes eigenvalues only, 'V' also computes eigenvectors
//         (which LAPACK writes over `matrix` on exit).
// uplo:   'U'/'L' selects which triangle of the symmetric matrix is read.
// matrix: pointer to an N x N float matrix; clobbered by the call.
// w:      output buffer of length N receiving the eigenvalues.
// Returns the LAPACK `info` code: 0 on success, nonzero on failure.
int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) {
  int info;
  int lwork = -1;
  int liwork = -1;
  float work_query;
  int iwork_query;

  // Workspace query: passing lwork == liwork == -1 asks LAPACK for the
  // optimal work array sizes without performing the decomposition.
#ifdef LAPACK_FORTRAN_STRLEN_END
  ssyevd_(
      /* jobz = */ &jobz,
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* w = */ w,
      /* work = */ &work_query,
      /* lwork = */ &lwork,
      /* iwork = */ &iwork_query,
      /* liwork = */ &liwork,
      /* info = */ &info,
      /* jobz_len = */ static_cast<size_t>(1),
      /* uplo_len = */ static_cast<size_t>(1));
#else
  ssyevd_(
      /* jobz = */ &jobz,
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* w = */ w,
      /* work = */ &work_query,
      /* lwork = */ &lwork,
      /* iwork = */ &iwork_query,
      /* liwork = */ &liwork,
      /* info = */ &info);
#endif

  // If the size query itself failed (e.g. an illegal argument), the query
  // outputs are unspecified; bail out rather than sizing the work arrays
  // from undefined values.
  if (info != 0) {
    return info;
  }

  lwork = static_cast<int>(work_query);
  liwork = iwork_query;

  std::vector<float> work(lwork);
  std::vector<int> iwork(liwork);

  // Second call performs the actual decomposition with properly sized
  // workspace buffers.
#ifdef LAPACK_FORTRAN_STRLEN_END
  ssyevd_(
      /* jobz = */ &jobz,
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* w = */ w,
      /* work = */ work.data(),
      /* lwork = */ &lwork,
      /* iwork = */ iwork.data(),
      /* liwork = */ &liwork,
      /* info = */ &info,
      /* jobz_len = */ static_cast<size_t>(1),
      /* uplo_len = */ static_cast<size_t>(1));
#else
  ssyevd_(
      /* jobz = */ &jobz,
      /* uplo = */ &uplo,
      /* n = */ &N,
      /* a = */ matrix,
      /* lda = */ &N,
      /* w = */ w,
      /* work = */ work.data(),
      /* lwork = */ &lwork,
      /* iwork = */ iwork.data(),
      /* liwork = */ &liwork,
      /* info = */ &info);
#endif

  return info;
}
|
||
} // namespace | ||
|
||
void eigvalsh_impl( | ||
const array& a, | ||
array& values, | ||
bool upper) { | ||
char jobz = 'N'; // Only compute eigenvalues | ||
char uplo = (upper) ? 'U' : 'L'; | ||
|
||
array buffer = copy(a); | ||
|
||
const int N = static_cast<int>(a.shape(-1)); | ||
const int num_matrices = static_cast<int>(a.size() / (N * N)); | ||
|
||
std::vector<int> values_shape = {num_matrices, N}; | ||
values = array(allocator::malloc(num_matrices * N * size_of(a.dtype())), values_shape, a.dtype()); | ||
|
||
float* matrix = buffer.data<float>(); | ||
float* w = values.data<float>(); | ||
|
||
for (int i = 0; i < num_matrices; i++) { | ||
int info = ssyevd_wrapper(jobz, uplo, matrix, w, N); | ||
|
||
if (info != 0) { | ||
std::stringstream msg; | ||
msg << "[eigvalsh] Eigenvalue decomposition failed with error code " << info; | ||
throw std::runtime_error(msg.str()); | ||
} | ||
|
||
matrix += N * N; | ||
w += N; | ||
} | ||
} | ||
|
||
void eigh_impl( | ||
const array& a, | ||
array& vectors, | ||
bool upper) { | ||
char jobz = 'V'; // Compute both eigenvalues and eigenvectors | ||
char uplo = (upper) ? 'U' : 'L'; | ||
|
||
array buffer = copy(a); | ||
|
||
const int N = static_cast<int>(a.shape(-1)); | ||
const int num_matrices = static_cast<int>(a.size() / (N * N)); | ||
|
||
std::vector<int> vectors_shape = a.shape(); | ||
vectors = array(allocator::malloc(a.size() * size_of(a.dtype())), vectors_shape, a.dtype()); | ||
|
||
float* matrix = buffer.data<float>(); | ||
float* vecs = vectors.data<float>(); | ||
|
||
// Temporary buffer for eigenvalues (we don't return these) | ||
std::vector<float> w(N); | ||
|
||
for (int i = 0; i < num_matrices; i++) { | ||
int info = ssyevd_wrapper(jobz, uplo, matrix, w.data(), N); | ||
|
||
if (info != 0) { | ||
std::stringstream msg; | ||
msg << "[eigh] Eigenvalue decomposition failed with error code " << info; | ||
throw std::runtime_error(msg.str()); | ||
} | ||
|
||
// Copy eigenvectors to the output array | ||
std::copy(matrix, matrix + N * N, vecs); | ||
|
||
matrix += N * N; | ||
vecs += N * N; | ||
} | ||
} | ||
|
||
void Eigvalsh::eval( | ||
const std::vector<array>& inputs, array& output) { | ||
if (inputs[0].dtype() != float32) { | ||
throw std::runtime_error("[Eigvalsh::eval] only supports float32."); | ||
} | ||
eigvalsh_impl(inputs[0], output, upper_); | ||
} | ||
|
||
void Eigh::eval( | ||
const std::vector<array>& inputs, array& output) { | ||
if (inputs[0].dtype() != float32) { | ||
throw std::runtime_error("[Eigh::eval] only supports float32."); | ||
} | ||
eigh_impl(inputs[0], output, upper_); | ||
} | ||
|
||
} // namespace mlx::core |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -74,4 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {}); | |
|
||
array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {}); | ||
|
||
array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {}); | ||
|
||
array eigh(const array& a, bool upper = false, StreamOrDevice s = {}); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Shouldn't that return both the eigenvalues and eigenvectors like numpy? |
||
|
||
} // namespace mlx::core::linalg |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2158,4 +2158,36 @@ class Cholesky : public UnaryPrimitive { | |
bool upper_; | ||
}; | ||
|
||
// Primitive computing the eigenvalues of symmetric (Hermitian) matrices.
// `upper` selects which triangle of the input the backend kernel reads.
class Eigvalsh : public UnaryPrimitive {
 public:
  explicit Eigvalsh(Stream stream, bool upper)
      : UnaryPrimitive(stream), upper_(upper) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_PRINT(Eigvalsh)

 private:
  // Backend-independent helper — presumably shared by eval_cpu/eval_gpu;
  // TODO(review): confirm against the backend implementations.
  void eval(const std::vector<array>& inputs, array& output);
  // True: read the upper triangle of the input; false: the lower one.
  bool upper_;
};
|
||
// Primitive computing the eigenvectors of symmetric (Hermitian) matrices.
// `upper` selects which triangle of the input the backend kernel reads.
class Eigh : public UnaryPrimitive {
 public:
  explicit Eigh(Stream stream, bool upper)
      : UnaryPrimitive(stream), upper_(upper) {}

  void eval_cpu(const std::vector<array>& inputs, array& out) override;
  void eval_gpu(const std::vector<array>& inputs, array& out) override;

  DEFINE_VMAP()
  DEFINE_PRINT(Eigh)

 private:
  // Backend-independent helper — presumably shared by eval_cpu/eval_gpu;
  // TODO(review): confirm against the backend implementations.
  void eval(const std::vector<array>& inputs, array& output);
  // True: read the upper triangle of the input; false: the lower one.
  bool upper_;
};
|
||
} // namespace mlx::core |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think it needs to be in this PR, but it would be nice to add support for arbitrary matrices (not just hermitian/symmetric ones).
Hopefully that could just be a flag on the `Eig` primitive that then uses a different `lapack` incantation?