Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eigenvalues and eigenvectors #1334

Draft
wants to merge 16 commits into
base: main
Choose a base branch
from
Draft

Conversation

kashif
Copy link

@kashif kashif commented Aug 18, 2024

Proposed changes

Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@kashif kashif marked this pull request as draft August 18, 2024 16:44
@kashif
Copy link
Author

kashif commented Aug 19, 2024

thanks @awni my question is:

  • the output with the lapack can in theory return 2 things... is that an issue with the unitary primitive? the output depending on the params can be values, vectors, or I believe both...
  • how do i initialize an empty array in cpp with a given shape?

Copy link
Collaborator

@barronalex barronalex left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nicely done and a great addition -- thanks!
Left a few comments/suggestions.

mlx/linalg.h Outdated Show resolved Hide resolved
mlx/primitives.h Outdated Show resolved Hide resolved
nb::sig(
"def eigvalsh(a: array, upper: bool = True, *, stream: Union[None, Stream, Device] = None) -> array"),
R"pbdoc(
Compute the eigenvalues of a complex Hermitian or real symmetric matrix.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't do a brilliant job of this in the other linalg functions, but we should probably check if complex inputs work and raise a not implemented error if not.

mlx/primitives.h Outdated
@@ -2158,4 +2158,23 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};

class Eigvalsh : public UnaryPrimitive {
public:
explicit Eigvalsh(Stream stream, bool upper, bool compute_vectors)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given it can compute eigenvectors, this can probably be called Eigh rather than Eigvalsh?

// Delegate to the eigenvalue decomposition taking into account differences in
// LAPACK implementations (basically how to pass the 'jobz' and 'uplo' strings
// to fortran).
int ssyevd_wrapper(char jobz, char uplo, float* matrix, float* w, int N) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it needs to be in this PR, but it would be nice to add support for arbitrary matrices (not just hermitian/symmetric ones).

Hopefully that could just be a flag on the Eig primitive that then uses a different lapack incantation?

mlx/linalg.cpp Outdated Show resolved Hide resolved
mlx/linalg.h Outdated
@@ -74,4 +74,8 @@ array pinv(const array& a, StreamOrDevice s = {});

array cholesky_inv(const array& a, bool upper = false, StreamOrDevice s = {});

array eigvalsh(const array& a, bool upper = false, StreamOrDevice s = {});

array eigh(const array& a, bool upper = false, StreamOrDevice s = {});
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't that return both the eigenvalues and eigenvectors like numpy?

mlx/primitives.h Outdated
@@ -2158,4 +2158,36 @@ class Cholesky : public UnaryPrimitive {
bool upper_;
};

class Eigvalsh : public UnaryPrimitive {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry if @barronalex already made this suggestion: but I think it makes sense to merge the Eighvalsh and Eigh into a single primitive. And have the ops use the same primitive but just return only the eigenvalues in the case of eighvalsh.

It looks like the work is done anyway.. and the underlying implementations are basically identical.

@kashif
Copy link
Author

kashif commented Sep 15, 2024

@awni is the current implementation more like what you mean?

@@ -81,6 +81,7 @@ DEFAULT_MULTI(SVD)
DEFAULT(Transpose)
DEFAULT(Inverse)
DEFAULT(Cholesky)
DEFAULT_MULTI(EighPrimitive)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please rename this to Eigh for consistency with other primitive names.

@awni
Copy link
Member

awni commented Sep 19, 2024

Yes I think it's better this way. I'm debating if we should bother with the compute_eigenvectors flag as opposed to just having the primitive always populate both outputs and then just extracting the output you need in the function in ops.cpp.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants