diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 994fc70339c..27cf1a5d4f9 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,9 @@ New Features Breaking changes ~~~~~~~~~~~~~~~~ +- Skip empty nodes in :py:func:`map_over_datasets`. Also affects binary operations. + This is a breaking change in xarray, but restores the behavior of the xarray-datatree package (:issue:`9693`, :pull:`10042`). + By `Mathias Hauser <https://github.com/mathause>`_. - Warn instead of raise if phony_dims are detected when using h5netcdf-backend and ``phony_dims=None`` (:issue:`10049`, :pull:`10058`) By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_. diff --git a/xarray/core/datatree_mapping.py b/xarray/core/datatree_mapping.py index 6262c7f19cd..a1ae531338b 100644 --- a/xarray/core/datatree_mapping.py +++ b/xarray/core/datatree_mapping.py @@ -47,15 +47,15 @@ def map_over_datasets( kwargs: Mapping[str, Any] | None = None, ) -> DataTree | tuple[DataTree, ...]: """ - Applies a function to every dataset in one or more DataTree objects with - the same structure (ie.., that are isomorphic), returning new trees which + Applies a function to every non-empty dataset in one or more DataTree objects + with the same structure (i.e., that are isomorphic), returning new trees which store the results. - The function will be applied to any dataset stored in any of the nodes in - the trees. The returned trees will have the same structure as the supplied - trees. + The function will be applied to every node containing data (i.e., which has + ``data_vars`` and/or ``coordinates``) in the trees. The returned tree(s) will + have the same structure as the supplied trees. - ``func`` needs to return a Dataset, tuple of Dataset objects or None in order + ``func`` needs to return a Dataset, tuple of Dataset objects or None to be able to rebuild the subtrees after mapping, as each result will be assigned to its respective node of a new tree via `DataTree.from_dict`. 
Any returned value that is one of these types will be stacked into a separate @@ -63,7 +63,7 @@ def map_over_datasets( ``map_over_datasets`` is essentially syntactic sugar for the combination of ``group_subtrees`` and ``DataTree.from_dict``. For example, in the case of - a two argument function that return one result, it is equivalent to:: + a two argument function that returns one result, it is equivalent to:: results = {} for path, (left, right) in group_subtrees(left_tree, right_tree): @@ -107,21 +107,35 @@ def map_over_datasets( # We don't know which arguments are DataTrees so we zip all arguments together as iterables # Store tuples of results in a dict because we don't yet know how many trees we need to rebuild to return out_data_objects: dict[str, Dataset | None | tuple[Dataset | None, ...]] = {} + func_called: dict[str, bool] = {} # empty nodes don't call `func` tree_args = [arg for arg in args if isinstance(arg, DataTree)] name = result_name(tree_args) for path, node_tree_args in group_subtrees(*tree_args): - node_dataset_args = [arg.dataset for arg in node_tree_args] - for i, arg in enumerate(args): - if not isinstance(arg, DataTree): - node_dataset_args.insert(i, arg) + if node_tree_args[0].has_data: + node_dataset_args = [arg.dataset for arg in node_tree_args] + for i, arg in enumerate(args): + if not isinstance(arg, DataTree): + node_dataset_args.insert(i, arg) + + func_with_error_context = _handle_errors_with_path_context(path)(func) + results = func_with_error_context(*node_dataset_args, **kwargs) + func_called[path] = True + + elif node_tree_args[0].has_attrs: + # propagate attrs + results = node_tree_args[0].dataset + func_called[path] = False + + else: + # use Dataset instead of None to ensure it has copy method + results = Dataset() + func_called[path] = False - func_with_error_context = _handle_errors_with_path_context(path)(func) - results = func_with_error_context(*node_dataset_args, **kwargs) out_data_objects[path] = results - 
num_return_values = _check_all_return_values(out_data_objects) + num_return_values = _check_all_return_values(out_data_objects, func_called) if num_return_values is None: # one return value @@ -134,6 +148,11 @@ def map_over_datasets( {} for _ in range(num_return_values) ] for path, outputs in out_data_tuples.items(): + # duplicate outputs when func was not called (empty nodes) + if not func_called[path]: + out = cast(Dataset, outputs) + outputs = tuple(out.copy() for _ in range(num_return_values)) + for output_dict, output in zip(output_dicts, outputs, strict=False): output_dict[path] = output @@ -185,17 +204,32 @@ def _check_single_set_return_values(path_to_node: str, obj: Any) -> int | None: return len(obj) -def _check_all_return_values(returned_objects) -> int | None: +def _check_all_return_values(returned_objects, func_called) -> int | None: """Walk through all values returned by mapping func over subtrees, raising on any invalid or inconsistent types.""" result_data_objects = list(returned_objects.items()) - first_path, result = result_data_objects[0] - return_values = _check_single_set_return_values(first_path, result) + func_called_before = False + + # initialize to None if all nodes are empty + return_values = None - for path_to_node, obj in result_data_objects[1:]: + for path_to_node, obj in result_data_objects: cur_return_values = _check_single_set_return_values(path_to_node, obj) + cur_func_called = func_called[path_to_node] + + # the first node where func was actually called - needed to find the number of + # return values + if cur_func_called and not func_called_before: + return_values = cur_return_values + func_called_before = True + first_path = path_to_node + + # no need to check if the function was not called + if not cur_func_called: + continue + if return_values != cur_return_values: if return_values is None: raise TypeError( diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 55b809307e4..4afd0855a0b 100644 --- 
a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -2113,6 +2113,22 @@ def test_binary_op_on_dataset(self) -> None: result = dt * other_ds assert_equal(result, expected) + def test_binary_op_on_dataset_skip_empty_nodes(self) -> None: + # https://github.com/pydata/xarray/issues/10013 + + a = xr.Dataset(data_vars={"x": ("time", [10])}, coords={"time": [0]}) + b = xr.Dataset(data_vars={"x": ("time", [11, 22])}, coords={"time": [0, 1]}) + + dt = DataTree.from_dict({"a": a, "b": b}) + + expected = DataTree.from_dict({"a": a - b, "b": b - b}) + + # if the empty root node is not skipped its coordinates become inconsistent + # with the ones of node a + result = dt - b + + assert_equal(result, expected) + def test_binary_op_on_datatree(self) -> None: ds1 = xr.Dataset({"a": [5], "b": [3]}) ds2 = xr.Dataset({"x": [0.1, 0.2], "y": [10, 20]}) diff --git a/xarray/tests/test_datatree_mapping.py b/xarray/tests/test_datatree_mapping.py index 6cb4455b739..c7b47d7a974 100644 --- a/xarray/tests/test_datatree_mapping.py +++ b/xarray/tests/test_datatree_mapping.py @@ -64,6 +64,13 @@ def multiply_by_kwarg(ds, **kwargs): ) assert_equal(result_tree, expected) + def test_single_tree_skip_empty_nodes(self, create_test_datatree): + dt = create_test_datatree() + expected = create_test_datatree(lambda ds: ds.rename(a="c")) + # this would fail on empty nodes + result_tree = map_over_datasets(lambda ds: ds.rename(a="c"), dt) + assert_equal(result_tree, expected) + def test_multiple_tree_args(self, create_test_datatree): dt1 = create_test_datatree() dt2 = create_test_datatree() @@ -79,6 +86,14 @@ def test_return_multiple_trees(self, create_test_datatree): expected_max = create_test_datatree(modify=lambda ds: ds.max()) assert_equal(dt_max, expected_max) + def test_return_multiple_trees_empty_first_node(self): + # check result tree is constructed correctly even if first nodes are empty + ds = xr.Dataset(data_vars={"a": ("x", [1, 2, 3])}) + dt = xr.DataTree.from_dict({"set1": 
None, "set2": ds}) + res_min, res_max = xr.map_over_datasets(lambda ds: (ds.min(), ds.max()), dt) + assert_equal(res_min, dt.min()) + assert_equal(res_max, dt.max()) + def test_return_wrong_type(self, simple_datatree): dt1 = simple_datatree @@ -183,7 +198,6 @@ def empty_func(ds): def test_error_contains_path_of_offending_node(self, create_test_datatree): dt = create_test_datatree() dt["set1"]["bad_var"] = 0 - print(dt) def fail_on_specific_node(ds): if "bad_var" in ds: