[Feature] Cholesky decomposition #1026
I've been using the CPU Cholesky recently and unfortunately it's quite slow for large matrices. A hybrid CPU/GPU Cholesky with MLX ops is about 2-3x faster than the pure CPU version for large matrices:

```python
import mlx.core as mx

def cholesky(A: mx.array, block_size: int = 512):
    N = A.shape[-1]
    L = mx.zeros_like(A)
    A = mx.array(A)
    # Normalize for numerical stability; rescale the factor at the end.
    amax = A.abs().max(axis=list(range(1, A.ndim)), keepdims=True)
    A /= amax
    for k in range(0, N, block_size):
        end = min(k + block_size, N)
        # Factor the diagonal block on the CPU.
        L[..., k:end, k:end] = mx.linalg.cholesky(A[..., k:end, k:end], stream=mx.cpu)
        if end < N:
            # Compute the panel below the diagonal block and update the
            # trailing submatrix with large matmuls on the default stream.
            L_inv = mx.linalg.tri_inv(
                mx.swapaxes(L[..., k:end, k:end], -1, -2), upper=True, stream=mx.cpu
            )
            L[..., end:N, k:end] = A[..., end:N, k:end] @ L_inv
            A[..., end:N, end:N] -= mx.matmul(
                L[..., end:N, k:end], mx.swapaxes(L[..., end:N, k:end], -1, -2)
            )
    L *= mx.sqrt(amax)
    return L
```
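A rough sketch of how one might time the two paths; the matrix size, the SPD construction, and the timing loop are my own illustrative choices, not from the issue:

```python
import time
import mlx.core as mx

N = 4096                                   # illustrative size
X = mx.random.normal(shape=(N, N))
A = X @ X.T + N * mx.eye(N)                # symmetric positive definite
mx.eval(A)

for name, fn in [("cpu", lambda M: mx.linalg.cholesky(M, stream=mx.cpu)),
                 ("blocked hybrid", cholesky)]:
    tic = time.perf_counter()
    L = fn(A)
    mx.eval(L)                             # force the lazy graph to run
    print(f"{name}: {time.perf_counter() - tic:.3f}s")
```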
It seems to be similarly numerically accurate, if anything slightly better than the CPU version. We don't really have a great pattern for ops that force some computation on the CPU and some on the GPU, but maybe it's worth merging anyway? It could be quite hard to write a more performant GPU-only kernel for a single matrix, since the unblocked algorithm is inherently sequential.
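For what it's worth, one way to sanity-check the accuracy claim is to compare the reconstruction error of the two factorizations; again an illustrative sketch with my own matrix construction, not the check used in the issue:

```python
import mlx.core as mx

N = 2048
X = mx.random.normal(shape=(N, N))
A = X @ X.T + N * mx.eye(N)                # symmetric positive definite

L_cpu = mx.linalg.cholesky(A, stream=mx.cpu)
L_blk = cholesky(A)                        # blocked hybrid routine from above

# Max elementwise error of the reconstruction L @ L^T against A.
err_cpu = (L_cpu @ L_cpu.T - A).abs().max()
err_blk = (L_blk @ L_blk.T - A).abs().max()
mx.eval(err_cpu, err_blk)
print("cpu:", err_cpu.item(), "blocked:", err_blk.item())
```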
That's pretty awesome that it's faster. A rare example of mixing CPU / GPU speeding things up! I'm not sure what to do with it. On the one hand, it's a lot faster, which is nice. On the other hand, implementing this at the op level will kind of break a couple of patterns.
I think as a temporary speedup it's fine to add / we probably should. But it would be useful to know long term what a good plan is for Cholesky and friends (heavy ops which are hard to parallelize in just one or two kernels). Is it feasible that we eventually replace it with our own fast kernel(s)? The alternative is maybe we should rethink which of those patterns are worth being consistent about and which are not, and maybe come up with a consistent way of working around them.
I think it's feasible to write a GPU-only Cholesky that's at least close to as performant as the above, so maybe we don't need to change the pattern. Given that we'll likely want a batched version for all of these harder-to-parallelize ops, it is tempting to just keep it consistent and maybe sacrifice a little bit of performance.
Add for the CPU using LAPACK.
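The CPU path would presumably call LAPACK's `?potrf` routine directly from C++. Purely as an illustration of the routine involved, here is the same call through SciPy's LAPACK bindings; SciPy is used only for demonstration here and is my own assumption, not part of MLX:

```python
import numpy as np
from scipy.linalg.lapack import dpotrf

N = 512
X = np.random.randn(N, N)
A = X @ X.T + N * np.eye(N)                # symmetric positive definite

# dpotrf factors A in the requested triangle; info == 0 means success.
c, info = dpotrf(A, lower=1)
assert info == 0
L = np.tril(c)                             # keep only the lower-triangular factor
assert np.allclose(L @ L.T, A)
```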
For the GPU, MPS has a Cholesky which could be a good option to start with (following how we used to bind MPS matmul).