Skip to content

map_over_datasets: skip empty nodes #10042

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ New Features

Breaking changes
~~~~~~~~~~~~~~~~
- Skip empty nodes in :py:func:`map_over_datasets`. Also affects binary operations.
This is a breaking change in xarray, but restores the behavior of the xarray-datatree package (:issue:`9693`, :pull:`10042`).
By `Mathias Hauser <https://github.com/mathause>`_.
- Warn instead of raise if phony_dims are detected when using h5netcdf-backend and ``phony_dims=None`` (:issue:`10049`, :pull:`10058`)
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.

Expand Down
70 changes: 52 additions & 18 deletions xarray/core/datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,23 @@ def map_over_datasets(
kwargs: Mapping[str, Any] | None = None,
) -> DataTree | tuple[DataTree, ...]:
"""
Applies a function to every dataset in one or more DataTree objects with
the same structure (ie.., that are isomorphic), returning new trees which
Applies a function to every non-empty dataset in one or more DataTree objects
with the same structure (i.e., that are isomorphic), returning new trees which
store the results.

The function will be applied to any dataset stored in any of the nodes in
the trees. The returned trees will have the same structure as the supplied
trees.
The function will be applied to every node containing data (i.e., which has
``data_vars`` and/or ``coordinates``) in the trees. The returned tree(s) will
have the same structure as the supplied trees.

``func`` needs to return a Dataset, tuple of Dataset objects or None in order
``func`` needs to return a Dataset, tuple of Dataset objects or None
to be able to rebuild the subtrees after mapping, as each result will be
assigned to its respective node of a new tree via `DataTree.from_dict`. Any
returned value that is one of these types will be stacked into a separate
tree before returning all of them.

``map_over_datasets`` is essentially syntactic sugar for the combination of
``group_subtrees`` and ``DataTree.from_dict``. For example, in the case of
a two argument function that return one result, it is equivalent to::
a two argument function that returns one result, it is equivalent to::

results = {}
for path, (left, right) in group_subtrees(left_tree, right_tree):
Expand Down Expand Up @@ -107,21 +107,35 @@ def map_over_datasets(
# We don't know which arguments are DataTrees so we zip all arguments together as iterables
# Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return
out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {}
func_called: dict[str, bool] = {} # empty nodes don't call `func`

tree_args = [arg for arg in args if isinstance(arg, DataTree)]
name = result_name(tree_args)

for path, node_tree_args in group_subtrees(*tree_args):
node_dataset_args = [arg.dataset for arg in node_tree_args]
for i, arg in enumerate(args):
if not isinstance(arg, DataTree):
node_dataset_args.insert(i, arg)
if node_tree_args[0].has_data:
node_dataset_args = [arg.dataset for arg in node_tree_args]
for i, arg in enumerate(args):
if not isinstance(arg, DataTree):
node_dataset_args.insert(i, arg)

func_with_error_context = _handle_errors_with_path_context(path)(func)
results = func_with_error_context(*node_dataset_args, **kwargs)
func_called[path] = True

elif node_tree_args[0].has_attrs:
# propagate attrs
results = node_tree_args[0].dataset
func_called[path] = False

else:
# use Dataset instead of None to ensure it has copy method
results = Dataset()
func_called[path] = False

func_with_error_context = _handle_errors_with_path_context(path)(func)
results = func_with_error_context(*node_dataset_args, **kwargs)
out_data_objects[path] = results

num_return_values = _check_all_return_values(out_data_objects)
num_return_values = _check_all_return_values(out_data_objects, func_called)

if num_return_values is None:
# one return value
Expand All @@ -134,6 +148,11 @@ def map_over_datasets(
{} for _ in range(num_return_values)
]
for path, outputs in out_data_tuples.items():
# duplicate outputs when func was not called (empty nodes)
if not func_called[path]:
out = cast(Dataset, outputs)
outputs = tuple(out.copy() for _ in range(num_return_values))

for output_dict, output in zip(output_dicts, outputs, strict=False):
output_dict[path] = output

Expand Down Expand Up @@ -185,17 +204,32 @@ def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None:
return len(obj)


def _check_all_return_values(returned_objects) -> int | None:
def _check_all_return_values(returned_objects, func_called) -> int | None:
"""Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types."""

result_data_objects = list(returned_objects.items())

first_path, result = result_data_objects[0]
return_values = _check_single_set_return_values(first_path, result)
func_called_before = False

# initialize to None if all nodes are empty
return_values = None

for path_to_node, obj in result_data_objects[1:]:
for path_to_node, obj in result_data_objects:
cur_return_values = _check_single_set_return_values(path_to_node, obj)

cur_func_called = func_called[path_to_node]

# the first node where func was actually called - needed to find the number of
# return values
if cur_func_called and not func_called_before:
return_values = cur_return_values
func_called_before = True
first_path = path_to_node

# no need to check if the function was not called
if not cur_func_called:
continue

if return_values != cur_return_values:
if return_values is None:
raise TypeError(
Expand Down
16 changes: 16 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -2113,6 +2113,22 @@ def test_binary_op_on_dataset(self) -> None:
result = dt * other_ds
assert_equal(result, expected)

def test_binary_op_on_dataset_skip_empty_nodes(self) -> None:
    # regression test for https://github.com/pydata/xarray/issues/10013

    short_ds = xr.Dataset(data_vars={"x": ("time", [10])}, coords={"time": [0]})
    long_ds = xr.Dataset(data_vars={"x": ("time", [11, 22])}, coords={"time": [0, 1]})

    tree = DataTree.from_dict({"a": short_ds, "b": long_ds})

    # if the empty root node were not skipped, its coordinates would become
    # inconsistent with the ones of node "a"
    result = tree - long_ds

    expected = DataTree.from_dict({"a": short_ds - long_ds, "b": long_ds - long_ds})

    assert_equal(result, expected)

def test_binary_op_on_datatree(self) -> None:
ds1 = xr.Dataset({"a": [5], "b": [3]})
ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]})
Expand Down
16 changes: 15 additions & 1 deletion xarray/tests/test_datatree_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ def multiply_by_kwarg(ds, **kwargs):
)
assert_equal(result_tree, expected)

def test_single_tree_skip_empty_nodes(self, create_test_datatree):
    # renaming a non-existent variable would raise on empty nodes,
    # so this only passes if empty nodes are skipped
    def rename_a_to_c(ds):
        return ds.rename(a="c")

    tree = create_test_datatree()
    result_tree = map_over_datasets(rename_a_to_c, tree)
    expected = create_test_datatree(rename_a_to_c)
    assert_equal(result_tree, expected)

def test_multiple_tree_args(self, create_test_datatree):
dt1 = create_test_datatree()
dt2 = create_test_datatree()
Expand All @@ -79,6 +86,14 @@ def test_return_multiple_trees(self, create_test_datatree):
expected_max = create_test_datatree(modify=lambda ds: ds.max())
assert_equal(dt_max, expected_max)

def test_return_multiple_trees_empty_first_node(self):
    # check the result trees are constructed correctly even if the first
    # node(s) visited are empty
    def min_and_max(ds):
        return ds.min(), ds.max()

    data = xr.Dataset(data_vars={"a": ("x", [1, 2, 3])})
    tree = xr.DataTree.from_dict({"set1": None, "set2": data})

    res_min, res_max = xr.map_over_datasets(min_and_max, tree)

    assert_equal(res_min, tree.min())
    assert_equal(res_max, tree.max())

def test_return_wrong_type(self, simple_datatree):
dt1 = simple_datatree

Expand Down Expand Up @@ -183,7 +198,6 @@ def empty_func(ds):
def test_error_contains_path_of_offending_node(self, create_test_datatree):
dt = create_test_datatree()
dt["set1"]["bad_var"] = 0
print(dt)

def fail_on_specific_node(ds):
if "bad_var" in ds:
Expand Down
Loading