From 21802468ec8a26129d0b963a22682aba9220cfef Mon Sep 17 00:00:00 2001 From: Mikael Simberg Date: Mon, 2 Mar 2026 13:21:15 +0100 Subject: [PATCH] Add helper for getting array namespace from an array --- .../model/common/utils/data_allocation.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/model/common/src/icon4py/model/common/utils/data_allocation.py b/model/common/src/icon4py/model/common/utils/data_allocation.py index 011d1f72b9..76905b0de7 100644 --- a/model/common/src/icon4py/model/common/utils/data_allocation.py +++ b/model/common/src/icon4py/model/common/utils/data_allocation.py @@ -77,6 +77,22 @@ def array_ns(try_cupy: bool) -> ModuleType: return np +def array_ns_from_array(array: NDArray) -> ModuleType: + """ + Returns the array namespace for a given array. + """ + if hasattr(array, "__array_namespace__"): + return array.__array_namespace__() + elif isinstance(array, np.ndarray): + return np + elif isinstance(array, xp.ndarray): + return xp + else: + raise RuntimeError( + f"Unsupported array type '{type(array)}'. Cannot detect the array namespace." + ) + + def import_array_ns(allocator: gtx_typing.Allocator | None) -> ModuleType: """Import cupy or numpy depending on a chosen GT4Py backend DevicType.""" return array_ns(device_utils.is_cupy_device(allocator))