diff --git a/src/centimators/narwhals_utils.py b/src/centimators/narwhals_utils.py index b506710..1848900 100644 --- a/src/centimators/narwhals_utils.py +++ b/src/centimators/narwhals_utils.py @@ -13,7 +13,7 @@ def _ensure_numpy(data, allow_series: bool = False): """Convert data to numpy array, handling both numpy arrays and dataframes. Args: - data: Input data (numpy array, dataframe, or series) + data: Input data (numpy array, dataframe, series, or PyTorch tensor) allow_series: Whether to allow series inputs Returns: @@ -24,6 +24,14 @@ def _ensure_numpy(data, allow_series: bool = False): try: return nw.from_native(data, allow_series=allow_series).to_numpy() except Exception: + # Handle PyTorch tensors (including CUDA tensors) + try: + import torch + if isinstance(data, torch.Tensor): + # Move to CPU if on GPU, then convert to numpy + return data.detach().cpu().numpy() + except ImportError: + pass return numpy.asarray(data)