
Eigenvalues and eigenvectors #1334

Draft: wants to merge 16 commits into base: main
2 changes: 2 additions & 0 deletions docs/src/python/linalg.rst
@@ -15,3 +15,5 @@ Linear Algebra
cholesky_inv
qr
svd
eigvalsh
eigh
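
A usage sketch of the two new ops (assuming the Python bindings expose them under mx.linalg with the C++ defaults, and noting that this draft's eigh returns only the eigenvectors):

import mlx.core as mx

a = mx.array([[2.0, 1.0], [1.0, 2.0]])    # symmetric input
w = mx.linalg.eigvalsh(a, stream=mx.cpu)  # eigenvalues, shape (2,)
v = mx.linalg.eigh(a, stream=mx.cpu)      # eigenvectors (this draft omits the values)
# Metal kernels are NYI in this PR, hence the explicit CPU stream.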
2 changes: 2 additions & 0 deletions mlx/backend/accelerate/primitives.cpp
@@ -81,6 +81,8 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT(Eigvalsh)
DEFAULT(Eigh)

void Abs::eval_cpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 1);
1 change: 1 addition & 0 deletions mlx/backend/common/CMakeLists.txt
@@ -41,6 +41,7 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/conv.cpp
${CMAKE_CURRENT_SOURCE_DIR}/copy.cpp
${CMAKE_CURRENT_SOURCE_DIR}/erf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eigvalsh.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
${CMAKE_CURRENT_SOURCE_DIR}/masked_mm.cpp
2 changes: 2 additions & 0 deletions mlx/backend/common/default_primitives.cpp
@@ -114,6 +114,8 @@ DEFAULT(Tanh)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT(Eigvalsh)
DEFAULT(Eigh)

namespace {

187 changes: 187 additions & 0 deletions mlx/backend/common/eigvalsh.cpp
@@ -0,0 +1,187 @@
// Copyright © 2023-2024 Apple Inc.

#include "mlx/allocator.h"
#include "mlx/backend/common/copy.h"
#include "mlx/linalg.h"
#include "mlx/primitives.h"

#ifdef ACCELERATE_NEW_LAPACK
#include <Accelerate/Accelerate.h>
#else
#include <lapack.h>
#endif

namespace mlx::core {

namespace {

// Delegate to the eigenvalue decomposition taking into account differences in
// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings
// to fortran).
int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) {
Collaborator:
I don't think it needs to be in this PR, but it would be nice to add support for arbitrary matrices (not just hermitian/symmetric ones).

Hopefully that could just be a flag on the Eig primitive that then uses a different lapack incantation?
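
For reference, a general (non-symmetric) matrix can have complex eigenpairs, so arbitrary-matrix support would need a different LAPACK routine (e.g. sgeev) and a complex output dtype; a small numpy illustration:

import numpy as np

a = np.array([[0.0, -1.0], [1.0, 0.0]], dtype=np.float32)  # 90-degree rotation
print(np.linalg.eigvals(a))  # [0.+1.j 0.-1.j], complex even for a real input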

int info;
int lwork = -1;
int liwork = -1;
float work_query;
int iwork_query;

// Query for optimal work array sizes
#ifdef LAPACK_FORTRAN_STRLEN_END
ssyevd_(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* w = */ w,
/* work = */ &work_query,
/* lwork = */ &lwork,
/* iwork = */ &iwork_query,
/* liwork = */ &liwork,
/* info = */ &info,
/* jobz_len = */ static_cast<size_t>(1),
/* uplo_len = */ static_cast<size_t>(1));
#else
ssyevd_(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* w = */ w,
/* work = */ &work_query,
/* lwork = */ &lwork,
/* iwork = */ &iwork_query,
/* liwork = */ &liwork,
/* info = */ &info);
#endif

lwork = static_cast<int>(work_query);
liwork = iwork_query;

std::vector<float> work(lwork);
std::vector<int> iwork(liwork);

// Compute eigenvalues (and optionally eigenvectors)
#ifdef LAPACK_FORTRAN_STRLEN_END
ssyevd_(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* w = */ w,
/* work = */ work.data(),
/* lwork = */ &lwork,
/* iwork = */ iwork.data(),
/* liwork = */ &liwork,
/* info = */ &info,
/* jobz_len = */ static_cast<size_t>(1),
/* uplo_len = */ static_cast<size_t>(1));
#else
ssyevd_(
/* jobz = */ &jobz,
/* uplo = */ &uplo,
/* n = */ &N,
/* a = */ matrix,
/* lda = */ &N,
/* w = */ w,
/* work = */ work.data(),
/* lwork = */ &lwork,
/* iwork = */ iwork.data(),
/* liwork = */ &liwork,
/* info = */ &info);
#endif

return info;
}

} // namespace

void eigvalsh_impl(const array& a, array& values, bool upper) {
char jobz = 'N'; // Only compute eigenvalues
char uplo = (upper) ? 'U' : 'L';
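// LAPACK is column-major, so it effectively sees the transpose of this
// row-major buffer; 'uplo' therefore selects the opposite triangle of the
// row-major input. For a genuinely symmetric input the result is the same.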

array buffer = copy(a);

const int N = static_cast<int>(a.shape(-1));
const int num_matrices = static_cast<int>(a.size() / (N * N));

std::vector<int> values_shape = {num_matrices, N};
values = array(
    allocator::malloc(num_matrices * N * size_of(a.dtype())),
    values_shape,
    a.dtype());

float* matrix = buffer.data<float>();
float* w = values.data<float>();

for (int i = 0; i < num_matrices; i++) {
int info = ssyevd_wrapper(jobz, uplo, matrix, w, N);

if (info != 0) {
std::stringstream msg;
msg << "[eigvalsh] Eigenvalue decomposition failed with error code " << info;
throw std::runtime_error(msg.str());
}

matrix += N * N;
w += N;
}
}

void eigh_impl(const array& a, array& vectors, bool upper) {
char jobz = 'V'; // Compute both eigenvalues and eigenvectors
char uplo = (upper) ? 'U' : 'L';

array buffer = copy(a);

const int N = static_cast<int>(a.shape(-1));
const int num_matrices = static_cast<int>(a.size() / (N * N));

std::vector<int> vectors_shape = a.shape();
vectors = array(
    allocator::malloc(a.size() * size_of(a.dtype())),
    vectors_shape,
    a.dtype());

float* matrix = buffer.data<float>();
float* vecs = vectors.data<float>();

// Temporary buffer for eigenvalues (we don't return these)
std::vector<float> w(N);

for (int i = 0; i < num_matrices; i++) {
int info = ssyevd_wrapper(jobz, uplo, matrix, w.data(), N);

if (info != 0) {
std::stringstream msg;
msg << "[eigh] Eigenvalue decomposition failed with error code " << info;
throw std::runtime_error(msg.str());
}

// Copy eigenvectors to the output array
std::copy(matrix, matrix + N * N, vecs);
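// LAPACK returns the eigenvectors as columns of its column-major output,
// so they land as rows of this row-major buffer.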

matrix += N * N;
vecs += N * N;
}
}

void Eigvalsh::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Eigvalsh::eval] only supports float32.");
}
eigvalsh_impl(inputs[0], output, upper_);
}

void Eigh::eval(const std::vector<array>& inputs, array& output) {
if (inputs[0].dtype() != float32) {
throw std::runtime_error("[Eigh::eval] only supports float32.");
}
eigh_impl(inputs[0], output, upper_);
}

} // namespace mlx::core
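
The batched loop above advances through contiguous N x N blocks; a numpy sketch of the intended per-matrix semantics (shapes here are illustrative):

import numpy as np

a = np.random.randn(3, 4, 4).astype(np.float32)
a = a + np.swapaxes(a, -1, -2)  # symmetrize each matrix in the batch
per_matrix = np.stack([np.linalg.eigvalsh(m) for m in a.reshape(-1, 4, 4)])
assert np.allclose(per_matrix, np.linalg.eigvalsh(a).reshape(-1, 4), atol=1e-4)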
8 changes: 8 additions & 0 deletions mlx/backend/metal/primitives.cpp
@@ -393,6 +393,14 @@ void Cholesky::eval_gpu(const std::vector<array>& inputs, array& out) {
"[Cholesky::eval_gpu] Metal Cholesky decomposition NYI.");
}

void Eigvalsh::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Eigvalsh::eval_gpu] Metal Eigvalsh NYI.");
}

void Eigh::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error("[Eigh::eval_gpu] Metal Eigh NYI.");
}

void View::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& in = inputs[0];
auto ibytes = size_of(in.dtype());
2 changes: 2 additions & 0 deletions mlx/backend/no_cpu/primitives.cpp
@@ -48,6 +48,8 @@ NO_CPU(Divide)
NO_CPU_MULTI(DivMod)
NO_CPU(NumberOfElements)
NO_CPU(Remainder)
NO_CPU(Eigvalsh)
NO_CPU(Eigh)
NO_CPU(Equal)
NO_CPU(Erf)
NO_CPU(ErfInv)
2 changes: 2 additions & 0 deletions mlx/backend/no_metal/primitives.cpp
@@ -110,6 +110,8 @@ NO_GPU(Tanh)
NO_GPU(Transpose)
NO_GPU(Inverse)
NO_GPU(Cholesky)
NO_GPU(Eigvalsh)
NO_GPU(Eigh)
NO_GPU(View)

namespace fast {
65 changes: 65 additions & 0 deletions mlx/linalg.cpp
@@ -382,4 +382,69 @@ array cholesky_inv(
}
}

array eigvalsh(
const array& a,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::eigvalsh] Arrays must be type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}

if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::eigvalsh] Arrays must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(
"[linalg::eigvalsh] Eigenvalues are only defined for square matrices.");
}

// The output shape drops the trailing matrix dimension: (..., n, n) -> (..., n).
std::vector<int> out_shape(a.shape().begin(), a.shape().end() - 1);

return array(
out_shape,
a.dtype(),
std::make_shared<Eigvalsh>(to_stream(s), upper),
{astype(a, a.dtype(), s)});
}

array eigh(
const array& a,
bool upper /* = false */,
StreamOrDevice s /* = {} */) {
if (a.dtype() != float32) {
std::ostringstream msg;
msg << "[linalg::eigh] Arrays must be type float32. Received array "
<< "with type " << a.dtype() << ".";
throw std::invalid_argument(msg.str());
}

if (a.ndim() < 2) {
std::ostringstream msg;
msg << "[linalg::eigh] Arrays must have >= 2 dimensions. Received array "
"with "
<< a.ndim() << " dimensions.";
throw std::invalid_argument(msg.str());
}

if (a.shape(-1) != a.shape(-2)) {
throw std::invalid_argument(
"[linalg::eigh] Eigenvectors are only defined for square matrices.");
}

return array(
a.shape(),
a.dtype(),
std::make_shared<Eigh>(to_stream(s), upper),
{astype(a, a.dtype(), s)});
}

} // namespace mlx::core::linalg
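
The shape contract mirrors numpy: eigvalsh drops the trailing matrix dimension, while this draft's eigh keeps the input shape for the eigenvectors (numpy's eigh also returns the values, as the review comment below notes). The numpy analogue:

import numpy as np

a = np.zeros((2, 3, 5, 5), dtype=np.float32)
assert np.linalg.eigvalsh(a).shape == (2, 3, 5)    # (..., n, n) -> (..., n)
assert np.linalg.eigh(a)[1].shape == (2, 3, 5, 5)  # eigenvectors keep (..., n, n)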
4 changes: 4 additions & 0 deletions mlx/linalg.h
@@ -74,4 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {});

array eigh(const array& a, bool upper = false, StreamOrDevice s = {});
Member:
Shouldn't that return both the eigenvalues and eigenvectors like numpy?


} // namespace mlx::core::linalg
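
On the reviewer's point: numpy's eigh returns the pair (eigenvalues, eigenvectors), which also makes the reconstruction check direct:

import numpy as np

a = np.array([[2.0, 1.0], [1.0, 2.0]], dtype=np.float32)
w, v = np.linalg.eigh(a)  # numpy returns both
assert np.allclose(v @ np.diag(w) @ v.T, a, atol=1e-6)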
16 changes: 16 additions & 0 deletions mlx/primitives.cpp
@@ -737,6 +737,22 @@ std::pair<std::vector<array>, std::vector<int>> Cholesky::vmap(
return {{linalg::cholesky(a, upper_, stream())}, {ax}};
}

std::pair<std::vector<array>, std::vector<int>> Eigvalsh::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::eigvalsh(a, upper_, stream())}, {ax}};
}

std::pair<std::vector<array>, std::vector<int>> Eigh::vmap(
const std::vector<array>& inputs,
const std::vector<int>& axes) {
auto ax = axes[0] >= 0 ? 0 : -1;
auto a = axes[0] > 0 ? moveaxis(inputs[0], axes[0], 0, stream()) : inputs[0];
return {{linalg::eigh(a, upper_, stream())}, {ax}};
}

std::vector<array> Concatenate::vjp(
const std::vector<array>& primals,
const std::vector<array>& cotangents,
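Both vmap rules move the mapped axis to the front and rely on the op's batching over leading dimensions; a numpy sketch of that equivalence:

import numpy as np

a = np.random.randn(5, 4, 3, 3).astype(np.float32)
a = a + np.swapaxes(a, -1, -2)  # symmetric in the trailing two dims
# mapping over axis 1 == moving axis 1 to the front, then batched eigvalsh
ref = np.stack([np.linalg.eigvalsh(a[:, i]) for i in range(a.shape[1])])
assert np.allclose(ref, np.linalg.eigvalsh(np.moveaxis(a, 1, 0)), atol=1e-4)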
32 changes: 32 additions & 0 deletions mlx/primitives.h
@@ -2158,4 +2158,36 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};

class Eigvalsh : public UnaryPrimitive {
Member (conversation marked as resolved by kashif):
Sorry if @barronalex already made this suggestion: but I think it makes sense to merge the Eigvalsh and Eigh into a single primitive. And have the ops use the same primitive but just return only the eigenvalues in the case of eigvalsh.

It looks like the work is done anyway.. and the underlying implementations are basically identical.

public:
explicit Eigvalsh(Stream stream, bool upper)
: UnaryPrimitive(stream), upper_(upper) {}

void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;

DEFINE_VMAP()
DEFINE_PRINT(Eigvalsh)

private:
void eval(const std::vector<array>& inputs, array& output);
bool upper_;
};

class Eigh : public UnaryPrimitive {
public:
explicit Eigh(Stream stream, bool upper)
: UnaryPrimitive(stream), upper_(upper) {}

void eval_cpu(const std::vector<array>& inputs, array& out) override;
void eval_gpu(const std::vector<array>& inputs, array& out) override;

DEFINE_VMAP()
DEFINE_PRINT(Eigh)

private:
void eval(const std::vector<array>& inputs, array& output);
bool upper_;
};

} // namespace mlx::core
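
On the merge suggestion: ssyevd with jobz = 'V' computes the eigenvalues either way, so a single primitive could return both and let the eigvalsh op discard the vectors. In numpy terms:

import numpy as np

a = np.array([[4.0, 1.0], [1.0, 3.0]], dtype=np.float32)
w_only = np.linalg.eigvalsh(a)
w, v = np.linalg.eigh(a)
assert np.allclose(w_only, w)  # eigvalsh is eigh minus the eigenvectors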