Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions model/common/src/icon4py/model/common/decomposition/halo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,19 @@ class IconLikeHaloConstructor(HaloConstructor):
def __init__(
self,
run_properties: defs.ProcessProperties,
connectivities: dict[gtx.FieldOffset | str, data_alloc.NDArray],
neighbor_tables: dict[gtx.FieldOffset | str, data_alloc.NDArray],
allocator: gtx_typing.Allocator | None = None,
):
"""

Args:
run_properties: contains information on the communicator and local compute node.
connectivities: connectivity arrays needed to construct the halos
neighbor_tables: neighbor table arrays needed to construct the halos
allocator: GT4Py buffer allocator
"""
self._xp = data_alloc.import_array_ns(allocator)
self._props = run_properties
self._connectivities = {self._value(k): v for k, v in connectivities.items()}
self._neighbor_tables = {self._value(k): v for k, v in neighbor_tables.items()}
self._assert_all_neighbor_tables()

@staticmethod
Expand All @@ -82,7 +82,7 @@ def _value(k: gtx.FieldOffset | str) -> str:

def _validate_mapping(self, cell_to_rank_mapping: data_alloc.NDArray) -> None:
# validate the distribution mapping:
num_cells = self._connectivity(dims.C2E2C).shape[0]
num_cells = self._neighbor_table(dims.C2E2C).shape[0]
expected_shape = (num_cells,)
if cell_to_rank_mapping.shape != expected_shape:
raise exceptions.ValidationError(
Expand All @@ -109,12 +109,12 @@ def _assert_all_neighbor_tables(self) -> None:
]
for d in relevant_dimension:
assert (
d.value in self._connectivities
d.value in self._neighbor_tables
), f"Table for {d} is missing from the neighbor table array."

def _connectivity(self, offset: gtx.FieldOffset | str) -> data_alloc.NDArray:
def _neighbor_table(self, offset: gtx.FieldOffset | str) -> data_alloc.NDArray:
try:
return self._connectivities[self._value(offset)]
return self._neighbor_tables[self._value(offset)]
except KeyError as err:
raise exceptions.MissingConnectivityError(
f"Connectivity for offset {offset} is not available"
Expand All @@ -137,7 +137,7 @@ def _find_neighbors(
) -> data_alloc.NDArray:
"""Get a flattened list of all (unique) neighbors to a given global index list"""
assert source_indices.ndim == 1
neighbors = self._xp.unique(self._connectivity(offset)[source_indices, :].flatten())
neighbors = self._xp.unique(self._neighbor_table(offset)[source_indices, :].flatten())
# Connectivities may have invalid neighbors, filter them out to avoid
# indexing with negative indices later.
return neighbors[neighbors >= 0]
Expand Down Expand Up @@ -378,7 +378,7 @@ def __call__(self, cell_to_rank: data_alloc.NDArray) -> defs.DecompositionInfo:
vertex_owner_mask,
all_vertices,
vertex_on_cutting_line,
self._connectivity(dims.V2C),
self._neighbor_table(dims.V2C),
)

# Once mask has been updated, some owned cells may now belong to the
Expand Down Expand Up @@ -430,7 +430,7 @@ def __call__(self, cell_to_rank: data_alloc.NDArray) -> defs.DecompositionInfo:
edge_owner_mask,
all_edges,
edge_on_cutting_line,
self._connectivity(dims.E2C),
self._neighbor_table(dims.E2C),
)

# Once mask has been updated, some owned cells may now belong to the
Expand Down Expand Up @@ -467,7 +467,7 @@ def __call__(self, cell_to_rank: data_alloc.NDArray) -> defs.DecompositionInfo:
def get_halo_constructor(
run_properties: defs.ProcessProperties,
full_grid_size: base.HorizontalGridSize,
connectivities: dict[gtx.FieldOffset | str, data_alloc.NDArray],
neighbor_tables: dict[gtx.FieldOffset | str, data_alloc.NDArray],
allocator: gtx_typing.Allocator | None,
) -> HaloConstructor:
"""
Expand All @@ -480,7 +480,7 @@ def get_halo_constructor(
processor_props:
full_grid_size
allocator:
connectivities:
neighbor_tables:

Returns: a HaloConstructor suitable for the run_properties

Expand All @@ -493,7 +493,7 @@ def get_halo_constructor(

return IconLikeHaloConstructor(
run_properties=run_properties,
connectivities=connectivities,
neighbor_tables=neighbor_tables,
allocator=allocator,
)

Expand Down
2 changes: 1 addition & 1 deletion model/common/src/icon4py/model/common/grid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def _replace_skip_values(

Args:
domain: the domain of the Connectivity
connectivity: NDArray object to be manipulated
neighbor_table: NDArray object to be manipulated
array_ns: numpy or cupy module to use for array operations
Returns:
NDArray without skip values
Expand Down
10 changes: 5 additions & 5 deletions model/common/src/icon4py/model/common/grid/grid_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,17 +434,17 @@ def _construct_decomposed_grid(
halo_constructor = halo.get_halo_constructor(
run_properties=run_properties,
full_grid_size=global_size,
connectivities=global_neighbor_tables,
neighbor_tables=global_neighbor_tables,
allocator=allocator,
)

self._decomposition_info = halo_constructor(cells_to_rank_mapping)
distributed_size = self._decomposition_info.get_horizontal_size()

neighbor_tables = self._get_local_connectivities(global_neighbor_tables, array_ns=xp)
neighbor_tables = self._get_local_neighbor_tables(global_neighbor_tables, array_ns=xp)

# COMPUTE remaining derived connectivities
neighbor_tables.update(_get_derived_connectivities(neighbor_tables, array_ns=xp))
neighbor_tables.update(_get_derived_neighbor_tables(neighbor_tables, array_ns=xp))

refinement_fields = self._read_grid_refinement_fields(allocator)

Expand Down Expand Up @@ -475,7 +475,7 @@ def _construct_decomposed_grid(
refinement_control=refinement_fields,
)

def _get_local_connectivities(
def _get_local_neighbor_tables(
self,
neighbor_tables_global: dict[gtx.FieldOffset, data_alloc.NDArray],
array_ns,
Expand Down Expand Up @@ -545,7 +545,7 @@ def _get_index_field(
)


def _get_derived_connectivities(
def _get_derived_neighbor_tables(
neighbor_tables: dict[gtx.FieldOffset, data_alloc.NDArray], array_ns: ModuleType = np
) -> dict[gtx.FieldOffset, data_alloc.NDArray]:
e2v_table = neighbor_tables[dims.E2V]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def test_element_ownership_is_unique(
parallel_helpers.check_comm_size(processor_props, sizes=[4])

halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=processor_props,
allocator=backend,
)
Expand Down
14 changes: 7 additions & 7 deletions model/common/tests/common/decomposition/unit_tests/test_halo.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_halo_constructor_owned_cells(rank, simple_neighbor_tables, backend_like
processor_props = utils.DummyProps(rank=rank)
allocator = model_backends.get_allocator(backend_like)
halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=processor_props,
allocator=allocator,
)
Expand All @@ -53,7 +53,7 @@ def test_halo_constructor_decomposition_info_global_indices(rank, simple_neighbo
)

halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=processor_props,
)

Expand All @@ -80,7 +80,7 @@ def test_halo_constructor_decomposition_info_global_indices(rank, simple_neighbo
def test_halo_constructor_decomposition_info_halo_levels(rank, dim, simple_neighbor_tables):
processor_props = utils.DummyProps(rank=rank)
halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=processor_props,
)
decomp_info = halo_generator(utils.SIMPLE_DISTRIBUTION)
Expand Down Expand Up @@ -164,7 +164,7 @@ def test_halo_constructor_validate_rank_mapping_wrong_shape(simple_neighbor_tabl
num_cells = simple_neighbor_tables["C2E2C"].shape[0]
with pytest.raises(exceptions.ValidationError) as e:
halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=processor_props,
)
halo_generator(np.zeros((num_cells, 3), dtype=int))
Expand All @@ -178,7 +178,7 @@ def test_halo_constructor_validate_number_of_node_mismatch(rank, simple_neighbor
distribution = (processor_props.comm_size + 1) * np.ones((num_cells,), dtype=int)
with pytest.raises(expected_exception=exceptions.ValidationError) as e:
halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=processor_props,
)
halo_generator(distribution)
Expand All @@ -190,7 +190,7 @@ def test_owned_halo_mask_contiguous(rank):
simple_neighbor_tables = get_neighbor_tables_for_simple_grid()
props = dummy_four_ranks(rank)
halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=props,
)
decomp_info = halo_generator(utils.SIMPLE_DISTRIBUTION)
Expand Down Expand Up @@ -252,7 +252,7 @@ def test_horizontal_size(rank):
simple_neighbor_tables = get_neighbor_tables_for_simple_grid()
props = dummy_four_ranks(rank)
halo_generator = halo.IconLikeHaloConstructor(
connectivities=simple_neighbor_tables,
neighbor_tables=simple_neighbor_tables,
run_properties=props,
)
decomp_info = halo_generator(utils.SIMPLE_DISTRIBUTION)
Expand Down