Skip to content
28 changes: 12 additions & 16 deletions model/common/src/icon4py/model/common/decomposition/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,34 +554,30 @@ class SingleNodeRun(RunType):


class Reductions(Protocol):
def min(
self, buffer: data_alloc.NDArray, array_ns: ModuleType = np
) -> state_utils.ScalarType: ...
def min(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType: ...

def max(
self, buffer: data_alloc.NDArray, array_ns: ModuleType = np
) -> state_utils.ScalarType: ...
def max(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType: ...

def sum(
self, buffer: data_alloc.NDArray, array_ns: ModuleType = np
) -> state_utils.ScalarType: ...
def sum(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType: ...

def mean(
self, buffer: data_alloc.NDArray, array_ns: ModuleType = np
) -> state_utils.ScalarType: ...
def mean(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType: ...


class SingleNodeReductions(Reductions):
def min(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
def min(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
return array_ns.min(buffer).item()

def max(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
def max(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
return array_ns.max(buffer).item()

def sum(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
def sum(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
return array_ns.sum(buffer).item()

def mean(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
def mean(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
return array_ns.sum(buffer).item() / buffer.size


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,6 @@ def get_halo_constructor(
def global_to_local(
global_indices: data_alloc.NDArray,
indices_to_translate: data_alloc.NDArray,
array_ns: ModuleType = np,
) -> data_alloc.NDArray:
"""Translate an array of global indices into rank-local ones.

Expand All @@ -510,6 +509,7 @@ def global_to_local(
indices_to_translate: the array to map to local indices

"""
array_ns = data_alloc.array_namespace(global_indices)
sorter = array_ns.argsort(global_indices)

mask = array_ns.isin(indices_to_translate, global_indices)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -451,8 +451,8 @@ def _reduce(
buffer: data_alloc.NDArray,
local_reduction: Callable[[data_alloc.NDArray], data_alloc.ScalarT],
global_reduction: mpi4py.MPI.Op,
array_ns: ModuleType = np,
) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
local_red_val = local_reduction(buffer)
recv_buffer = array_ns.empty(1, dtype=buffer.dtype)
if hasattr(
Expand All @@ -465,42 +465,43 @@ def _reduce(
def _calc_buffer_size(
self,
buffer: data_alloc.NDArray,
array_ns: ModuleType = np,
) -> state_utils.ScalarType:
return self._reduce(array_ns.asarray([buffer.size]), array_ns.sum, mpi4py.MPI.SUM, array_ns)
array_ns = data_alloc.array_namespace(buffer)
return self._reduce(array_ns.asarray([buffer.size]), array_ns.sum, mpi4py.MPI.SUM)

def min(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
if self._calc_buffer_size(buffer, array_ns) == 0:
def min(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
if self._calc_buffer_size(buffer) == 0:
raise ValueError("global_min requires a non-empty buffer")
return self._reduce(
buffer if buffer.size != 0 else self._min_identity(buffer.dtype, array_ns),
array_ns.min,
mpi4py.MPI.MIN,
array_ns,
)

def max(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
if self._calc_buffer_size(buffer, array_ns) == 0:
def max(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
if self._calc_buffer_size(buffer) == 0:
raise ValueError("global_max requires a non-empty buffer")
return self._reduce(
buffer if buffer.size != 0 else self._max_identity(buffer.dtype, array_ns),
array_ns.max,
mpi4py.MPI.MAX,
array_ns,
)

def sum(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
if self._calc_buffer_size(buffer, array_ns) == 0:
def sum(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
if self._calc_buffer_size(buffer) == 0:
raise ValueError("global_sum requires a non-empty buffer")
return self._reduce(
buffer if buffer.size != 0 else self._sum_identity(buffer.dtype, array_ns),
array_ns.sum,
mpi4py.MPI.SUM,
array_ns,
)

def mean(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_utils.ScalarType:
global_buffer_size = self._calc_buffer_size(buffer, array_ns)
def mean(self, buffer: data_alloc.NDArray) -> state_utils.ScalarType:
array_ns = data_alloc.array_namespace(buffer)
global_buffer_size = self._calc_buffer_size(buffer)
if global_buffer_size == 0:
raise ValueError("global_mean requires a non-empty buffer")

Expand All @@ -509,7 +510,6 @@ def mean(self, buffer: data_alloc.NDArray, array_ns: ModuleType = np) -> state_u
(buffer if buffer.size != 0 else self._sum_identity(buffer.dtype, array_ns)),
array_ns.sum,
mpi4py.MPI.SUM,
array_ns,
)
/ global_buffer_size
)
Expand Down
12 changes: 6 additions & 6 deletions model/common/src/icon4py/model/common/grid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import functools
import logging
from collections.abc import Callable, Sequence
from types import ModuleType

import gt4py.next as gtx
import gt4py.next.typing as gtx_typing
Expand Down Expand Up @@ -178,7 +177,7 @@ def construct_connectivity(
if replace_skip_values:
_log.debug(f"Replacing skip values in connectivity for {dim} with max valid neighbor.")
skip_value = None
table = _replace_skip_values(dim, table, array_ns=data_alloc.import_array_ns(allocator))
table = _replace_skip_values(dim, table)

return gtx.as_connectivity(
[from_dim, dim],
Expand All @@ -191,7 +190,7 @@ def construct_connectivity(


def _replace_skip_values(
domain: Sequence[gtx.Dimension], neighbor_table: data_alloc.NDArray, array_ns: ModuleType
domain: Sequence[gtx.Dimension], neighbor_table: data_alloc.NDArray
) -> data_alloc.NDArray:
"""
Manipulate a Connectivity's neighbor table to remove invalid indices.
Expand Down Expand Up @@ -221,11 +220,11 @@ def _replace_skip_values(
Args:
domain: the domain of the Connectivity
connectivity: NDArray object to be manipulated
array_ns: numpy or cupy module to use for array operations
Returns:
NDArray without skip values
"""
if _has_skip_values_in_table(neighbor_table, array_ns):
array_ns = data_alloc.array_namespace(neighbor_table)
if _has_skip_values_in_table(neighbor_table):
_log.info(f"Found invalid indices in {domain}. Replacing...")
max_valid_neighbor = neighbor_table.max(axis=1, keepdims=True)
if not array_ns.all(max_valid_neighbor >= 0):
Expand All @@ -241,5 +240,6 @@ def _replace_skip_values(
return neighbor_table


def _has_skip_values_in_table(data: data_alloc.NDArray, array_ns: ModuleType) -> bool:
def _has_skip_values_in_table(data: data_alloc.NDArray) -> bool:
array_ns = data_alloc.array_namespace(data)
return array_ns.amin(data).item() == GridFile.INVALID_INDEX
20 changes: 4 additions & 16 deletions model/common/src/icon4py/model/common/grid/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,7 @@ def _register_computed_fields(self) -> None:
self.register_provider(edge_areas)

mean_edge_length_np = factory.NumpyDataProvider(
func=functools.partial(
self._global_reductions.mean,
array_ns=self._xp,
),
func=self._global_reductions.mean,
domain=(),
deps={
"buffer": attrs.EDGE_LENGTH,
Expand All @@ -328,10 +325,7 @@ def _register_computed_fields(self) -> None:
self.register_provider(mean_edge_length_np)

mean_dual_edge_length_np = factory.NumpyDataProvider(
func=functools.partial(
self._global_reductions.mean,
array_ns=self._xp,
),
func=self._global_reductions.mean,
domain=(),
deps={
"buffer": attrs.DUAL_EDGE_LENGTH,
Expand All @@ -341,10 +335,7 @@ def _register_computed_fields(self) -> None:
self.register_provider(mean_dual_edge_length_np)

mean_cell_area_np = factory.NumpyDataProvider(
func=functools.partial(
self._global_reductions.mean,
array_ns=self._xp,
),
func=self._global_reductions.mean,
domain=(),
deps={
"buffer": attrs.CELL_AREA,
Expand All @@ -354,10 +345,7 @@ def _register_computed_fields(self) -> None:
self.register_provider(mean_cell_area_np)

mean_dual_cell_area_np = factory.NumpyDataProvider(
func=functools.partial(
self._global_reductions.mean,
array_ns=self._xp,
),
func=self._global_reductions.mean,
domain=(),
deps={
"buffer": attrs.DUAL_AREA,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from types import ModuleType

import gt4py.next.typing as gtx_typing
import numpy as np
from gt4py import next as gtx
from gt4py.next import sin, where

Expand Down Expand Up @@ -824,8 +822,8 @@ def compute_primal_cart_normal(
primal_cart_normal_x: data_alloc.NDArray,
primal_cart_normal_y: data_alloc.NDArray,
primal_cart_normal_z: data_alloc.NDArray,
array_ns: ModuleType = np,
) -> data_alloc.NDArray:
array_ns = data_alloc.array_namespace(primal_cart_normal_x)
primal_cart_normal = array_ns.transpose(
array_ns.stack(
(
Expand Down
Loading