Add TPU acceleration for generalized eigenvalue decompositions in decoding module (CSP, SPoC, SSD) #13959

@toxiifired-design

Describe the new feature or enhancement

This feature would let MNE-Python use Tensor Processing Units (TPUs) to accelerate computationally intensive steps in the mne.decoding module, specifically the generalized eigenvalue decomposition (GED) and approximate joint diagonalization (AJD) operations used by Common Spatial Patterns (CSP), SPoC, and SSD. Currently, these algorithms run on the CPU, relying on scipy.linalg or falling back to the custom _smart_ged and _smart_ajd functions. With an optional TPU backend powered by PyTorch/XLA (TorchTPU) or JAX, users could see significant speedups for large‑scale, batched problems (e.g., multi‑frequency CSP, high‑density EEG/MEG, or long recordings).

The feature could:

  • Detect TPU availability at runtime (e.g., via torch_xla or jax.devices()).
  • Offload covariance matrix operations and generalized eigenvalue problems to the TPU.
  • Fall back to the original CPU implementation when TPUs are not available or data shapes are suboptimal.
  • Follow the configuration pattern already established for CUDA (e.g., MNE_USE_TPU environment variable).
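
The opt-in and fallback behaviour listed above can be sketched in a few lines. This is only an illustration under stated assumptions: `tpu_requested` and `solve_with_fallback` are hypothetical names, the sketch reads the proposed MNE_USE_TPU variable directly from the environment (MNE's own configuration machinery is not used here), and the solvers are passed in as plain callables.

```python
import os


def tpu_requested(env=None):
    """Return True if the proposed MNE_USE_TPU variable is set to 'true'."""
    env = os.environ if env is None else env
    return env.get("MNE_USE_TPU", "false").strip().lower() == "true"


def solve_with_fallback(S, R, cpu_solver, tpu_solver, tpu_available, env=None):
    """Use tpu_solver only when requested and available, else cpu_solver."""
    if tpu_requested(env) and tpu_available():
        try:
            return tpu_solver(S, R)
        except RuntimeError:
            # Device-side failure: fall through to the CPU path.
            pass
    return cpu_solver(S, R)
```

With this shape, the CPU path stays the default, and a TPU is used only when the user asks for it and one is actually detected.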

Describe your proposed implementation

How I would implement it (with small code examples)
I will show three small pieces of code to demonstrate that the idea is workable and not too hard to implement.

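  1. TPU version of _smart_ged
    This function replaces the old one when a TPU is present. It moves the data to the TPU, solves the eigenvalue problem, then brings the results back. Because torch.linalg.eig accepts a single matrix, the generalized problem is first reduced to a standard symmetric one with a Cholesky factorization of R.

python
def _smart_ged_tpu(S, R, restr_mat=None, R_func=None):
    # Check if TPU is there
    if not _tpu_available():
        return _smart_ged_cpu(S, R, restr_mat, R_func)  # go back to old way

    import torch
    import torch_xla  # noqa: F401  (makes the 'xla' device available)

    # Put matrices on TPU (device='xla' for PyTorch/XLA)
    S_t = torch.tensor(S, dtype=torch.float64, device='xla')
    R_t = torch.tensor(R, dtype=torch.float64, device='xla')

    # torch.linalg.eig takes one matrix, so reduce S*v = lambda*R*v to a
    # standard symmetric problem via R = L L^T (assumes R is positive definite).
    # restr_mat and R_func are not handled in this sketch.
    L = torch.linalg.cholesky(R_t)
    A = torch.linalg.solve_triangular(L, S_t, upper=False)  # L^-1 S
    C = torch.linalg.solve_triangular(L, A.mT, upper=False)  # L^-1 S L^-T
    evals_t, W = torch.linalg.eigh(C)
    evecs_t = torch.linalg.solve_triangular(L.mT, W, upper=True)  # v = L^-T w

    # Send back to CPU and return as numpy
    return evals_t.cpu().numpy(), evecs_t.cpu().numpy()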
  2. Small change inside _GEDTransformer.fit
    This is the main class that CSP, SPoC, and SSD all use. I just add a small branch.

python
class _GEDTransformer(...):
    def fit(self, X, y=None):
        # ... old code to compute covariances S and R ...

        # The new part:
        if _tpu_available() and self.use_tpu:
            eigvals, eigvecs = _smart_ged_tpu(S, R)
        else:
            eigvals, eigvecs = _smart_ged(S, R)  # old CPU version
  3. Helper to check if a TPU exists
    Put this in a new file like mne/tpu.py, following the same pattern as mne/cuda.py.

python
def _tpu_available():
    """Return True if we can use TPU."""
    try:
        import torch_xla.core.xla_model as xm

        # Get first device, see if it is TPU
        dev = xm.get_xla_supported_devices()[0]
        return xm.xla_device_hw(dev) == 'TPU'
    except (ImportError, IndexError, RuntimeError):
        # RuntimeError: assumed to cover a failed XLA runtime start-up
        return False
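
For the eigenvalue step in the first piece, it may help to have a CPU-only reference to check a device implementation against. The sketch below is not MNE code: `ged_via_cholesky` is a hypothetical name, it uses only NumPy, and it assumes S is symmetric and R is symmetric positive definite. It reduces S v = λ R v to a standard symmetric problem through the Cholesky factor R = L Lᵀ, then verifies the result by its residual.

```python
import numpy as np


def ged_via_cholesky(S, R):
    """Solve S v = lambda R v for symmetric S and positive-definite R."""
    L = np.linalg.cholesky(R)      # R = L @ L.T
    A = np.linalg.solve(L, S)      # L^-1 S
    C = np.linalg.solve(L, A.T)    # L^-1 S L^-T (uses symmetry of S)
    C = (C + C.T) / 2              # remove round-off asymmetry
    evals, W = np.linalg.eigh(C)   # eigenvalues in ascending order
    V = np.linalg.solve(L.T, W)    # back-transform: v = L^-T w
    return evals, V


# Two random covariance-like matrices (8 channels, 200 samples each).
rng = np.random.default_rng(0)
X = rng.standard_normal((8, 200))
Y = rng.standard_normal((8, 200))
S = X @ X.T / 200
R = Y @ Y.T / 200

evals, V = ged_via_cholesky(S, R)
# Column j of V should satisfy S v_j = evals[j] * R v_j.
residual = np.abs(S @ V - (R @ V) * evals).max()
```

The eigenvectors returned this way are R-orthonormal (Vᵀ R V is the identity), so a device backend that uses the same reduction can be compared against this reference up to the sign of each column.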

Why these help the MNE community
The changes are very small: only a few functions in base.py need to be touched. CSP, SPoC, and SSD would all get faster without extra work. If something goes wrong, the code falls back to the CPU. This is the same pattern as the CUDA support already in MNE. I think this is a good first step. Later we could do more, such as source localization, but decoding comes first.

Describe possible alternatives

Small Alternative: Redirect GED to GPU (CuPy)

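Change _smart_ged to check for CuPy (CUDA GPU) first. If CuPy is available, it moves the matrices to the GPU, solves the eigenvalue problem there, then sends the results back to the CPU. cupy.linalg has no generalized eigensolver, so the snippet reduces the problem with a Cholesky factorization of R and calls cupy.linalg.eigh. If no GPU is available, it falls back to the normal CPU path.

Code snippet
python
def _smart_ged(S, R, restr_mat=None, R_func=None):
    # Trying GPU first
    try:
        import cupy as cp
        from cupyx.scipy.linalg import solve_triangular

        S_gpu = cp.asarray(S)
        R_gpu = cp.asarray(R)
        # Reduce S*v = lambda*R*v to a standard symmetric problem via
        # R = L L^T (assumes R is symmetric positive definite)
        L = cp.linalg.cholesky(R_gpu)
        A = solve_triangular(L, S_gpu, lower=True)  # L^-1 S
        C = solve_triangular(L, A.T, lower=True)  # L^-1 S L^-T
        evals_gpu, W = cp.linalg.eigh(C)
        evecs_gpu = solve_triangular(L.T, W, lower=False)  # v = L^-T w
        # Bring back to CPU
        return evals_gpu.get(), evecs_gpu.get()
    except (ImportError, MemoryError, RuntimeError):
        pass  # No CuPy, no usable GPU, or out of memory → use CPU

    # Original CPU can go below
    return _smart_ged_cpu(S, R, restr_mat, R_func)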

Additional context

I believe the impact could be high for a relatively small amount of code. From looking at the issue history, classes like CSP, SPoC, and SSD all inherit from the _GEDTransformer parent class. Therefore, if we modify the generalized eigenvalue decomposition in _GEDTransformer to optionally use a TPU, all of these downstream classes could benefit automatically.
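
The inheritance argument can be illustrated with a toy example. The classes below are hypothetical stand-ins, not MNE's real _GEDTransformer, CSP, or SSD: they only show that when the solver is chosen at a single point in the parent class, every subclass picks up a new backend without any change of its own.

```python
class ToyGEDBase:
    """Toy parent class with a single place where the solver is chosen."""

    def __init__(self, solver=None):
        # solver is any callable (S, R) -> result; None means the CPU default.
        self.solver = solver

    def _cpu_solver(self, S, R):
        return ("cpu", S, R)

    def fit(self, S, R):
        solve = self.solver if self.solver is not None else self._cpu_solver
        self.result_ = solve(S, R)
        return self


class ToyCSP(ToyGEDBase):
    """Subclass with no solver-specific code of its own."""


class ToySSD(ToyGEDBase):
    """Another subclass that also inherits the dispatch unchanged."""


def fake_tpu_solver(S, R):
    # Placeholder for an accelerated solver in this illustration.
    return ("tpu", S, R)


default_backend = ToyCSP().fit(1, 2).result_[0]
csp_backend = ToyCSP(solver=fake_tpu_solver).fit(1, 2).result_[0]
ssd_backend = ToySSD(solver=fake_tpu_solver).fit(1, 2).result_[0]
```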
