Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions model/common/src/icon4py/model/common/utils/data_allocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,22 @@ def array_ns(try_cupy: bool) -> ModuleType:
return np


def array_ns_from_array(array: NDArray) -> ModuleType:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def array_ns_from_array(array: NDArray) -> ModuleType:
def get_array_ns_from_array(array: NDArray) -> ModuleType:

to keep in line with (Magda's?) get/compute/apply

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I must say I don't love adding get etc. everywhere, but I don't object. If you prefer adding it let's do it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In gt4py we recently introduced array_api_compat as a dependency which provides array_api_compat.array_namespace(array)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
Returns the array namespace for a given array.
"""
if hasattr(array, "__array_namespace__"):
return array.__array_namespace__()
elif isinstance(array, np.ndarray):
return np
elif isinstance(array, xp.ndarray):
return xp
else:
raise RuntimeError(
f"Unsupported array type '{type(array)}'. Cannot detect the array namespace."
)


def import_array_ns(allocator: gtx_typing.Allocator | None) -> ModuleType:
"""Import cupy or numpy depending on a chosen GT4Py backend DevicType."""
return array_ns(device_utils.is_cupy_device(allocator))
Expand Down