|  | 
| 3 | 3 | from collections import defaultdict | 
| 4 | 4 | from collections.abc import Hashable, Iterable, Mapping, Sequence | 
| 5 | 5 | from collections.abc import Set as AbstractSet | 
| 6 |  | -from typing import TYPE_CHECKING, Any, NamedTuple, Union | 
|  | 6 | +from typing import TYPE_CHECKING, Any, NamedTuple, Union, cast, overload | 
| 7 | 7 | 
 | 
| 8 | 8 | import pandas as pd | 
| 9 | 9 | 
 | 
|  | 
| 34 | 34 |     from xarray.core.coordinates import Coordinates | 
| 35 | 35 |     from xarray.core.dataarray import DataArray | 
| 36 | 36 |     from xarray.core.dataset import Dataset | 
|  | 37 | +    from xarray.core.datatree import DataTree | 
| 37 | 38 |     from xarray.core.types import ( | 
| 38 | 39 |         CombineAttrsOptions, | 
| 39 | 40 |         CompatOptions, | 
| @@ -793,18 +794,101 @@ def merge_core( | 
| 793 | 794 |     return _MergeResult(variables, coord_names, dims, out_indexes, attrs) | 
| 794 | 795 | 
 | 
| 795 | 796 | 
 | 
|  | 797 | +def merge_trees( | 
|  | 798 | +    trees: Iterable[DataTree], | 
|  | 799 | +    compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, | 
|  | 800 | +    join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, | 
|  | 801 | +    fill_value: object = dtypes.NA, | 
|  | 802 | +    combine_attrs: CombineAttrsOptions = "override", | 
|  | 803 | +) -> DataTree: | 
|  | 804 | +    """Merge specialized to DataTree objects.""" | 
|  | 805 | +    from xarray.core.dataset import Dataset | 
|  | 806 | +    from xarray.core.datatree import DataTree | 
|  | 807 | +    from xarray.core.datatree_mapping import add_path_context_to_errors | 
|  | 808 | + | 
|  | 809 | +    if fill_value is not dtypes.NA: | 
|  | 810 | +        # fill_value supports dicts, which should probably be mapped to sub-groups? | 
|  | 811 | +        raise NotImplementedError( | 
|  | 812 | +            "fill_value is not yet supported for DataTree objects in merge" | 
|  | 813 | +        ) | 
|  | 814 | + | 
|  | 815 | +    node_lists: defaultdict[str, list[DataTree]] = defaultdict(list) | 
|  | 816 | +    for tree in trees: | 
|  | 817 | +        for key, node in tree.subtree_with_keys: | 
|  | 818 | +            node_lists[key].append(node) | 
|  | 819 | + | 
|  | 820 | +    root_datasets = [node.dataset for node in node_lists.pop(".")] | 
|  | 821 | +    with add_path_context_to_errors("."): | 
|  | 822 | +        root_ds = merge( | 
|  | 823 | +            root_datasets, compat=compat, join=join, combine_attrs=combine_attrs | 
|  | 824 | +        ) | 
|  | 825 | +    result = DataTree(dataset=root_ds) | 
|  | 826 | + | 
|  | 827 | +    def level(kv): | 
|  | 828 | +        # all nodes with the same path have the same level | 
|  | 829 | +        _, trees = kv | 
|  | 830 | +        return trees[0].level | 
|  | 831 | + | 
|  | 832 | +    for key, nodes in sorted(node_lists.items(), key=level): | 
|  | 833 | +        # Merge datasets, including inherited indexes to ensure alignment. | 
|  | 834 | +        datasets = [node.dataset for node in nodes] | 
|  | 835 | +        with add_path_context_to_errors(key): | 
|  | 836 | +            merge_result = merge_core( | 
|  | 837 | +                datasets, | 
|  | 838 | +                compat=compat, | 
|  | 839 | +                join=join, | 
|  | 840 | +                combine_attrs=combine_attrs, | 
|  | 841 | +            ) | 
|  | 842 | +        # Remove inherited coordinates/indexes/dimensions. | 
|  | 843 | +        for var_name in list(merge_result.coord_names): | 
|  | 844 | +            if not any(var_name in node._coord_variables for node in nodes): | 
|  | 845 | +                del merge_result.variables[var_name] | 
|  | 846 | +                merge_result.coord_names.remove(var_name) | 
|  | 847 | +        for index_name in list(merge_result.indexes): | 
|  | 848 | +            if not any(index_name in node._node_indexes for node in nodes): | 
|  | 849 | +                del merge_result.indexes[index_name] | 
|  | 850 | +        for dim in list(merge_result.dims): | 
|  | 851 | +            if not any(dim in node._node_dims for node in nodes): | 
|  | 852 | +                del merge_result.dims[dim] | 
|  | 853 | + | 
|  | 854 | +        merged_ds = Dataset._construct_direct(**merge_result._asdict()) | 
|  | 855 | +        result[key] = DataTree(dataset=merged_ds) | 
|  | 856 | + | 
|  | 857 | +    return result | 
|  | 858 | + | 
|  | 859 | + | 
|  | 860 | +@overload | 
|  | 861 | +def merge( | 
|  | 862 | +    objects: Iterable[DataTree], | 
|  | 863 | +    compat: CompatOptions | CombineKwargDefault = ..., | 
|  | 864 | +    join: JoinOptions | CombineKwargDefault = ..., | 
|  | 865 | +    fill_value: object = ..., | 
|  | 866 | +    combine_attrs: CombineAttrsOptions = ..., | 
|  | 867 | +) -> DataTree: ... | 
|  | 868 | + | 
|  | 869 | + | 
|  | 870 | +@overload | 
|  | 871 | +def merge( | 
|  | 872 | +    objects: Iterable[DataArray | Dataset | Coordinates | dict], | 
|  | 873 | +    compat: CompatOptions | CombineKwargDefault = ..., | 
|  | 874 | +    join: JoinOptions | CombineKwargDefault = ..., | 
|  | 875 | +    fill_value: object = ..., | 
|  | 876 | +    combine_attrs: CombineAttrsOptions = ..., | 
|  | 877 | +) -> Dataset: ... | 
|  | 878 | + | 
|  | 879 | + | 
| 796 | 880 | def merge( | 
| 797 |  | -    objects: Iterable[DataArray | CoercibleMapping], | 
|  | 881 | +    objects: Iterable[DataTree | DataArray | Dataset | Coordinates | dict], | 
| 798 | 882 |     compat: CompatOptions | CombineKwargDefault = _COMPAT_DEFAULT, | 
| 799 | 883 |     join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT, | 
| 800 | 884 |     fill_value: object = dtypes.NA, | 
| 801 | 885 |     combine_attrs: CombineAttrsOptions = "override", | 
| 802 |  | -) -> Dataset: | 
|  | 886 | +) -> DataTree | Dataset: | 
| 803 | 887 |     """Merge any number of xarray objects into a single Dataset as variables. | 
| 804 | 888 | 
 | 
| 805 | 889 |     Parameters | 
| 806 | 890 |     ---------- | 
| 807 |  | -    objects : iterable of Dataset or iterable of DataArray or iterable of dict-like | 
|  | 891 | +    objects : iterable of DataArray, Dataset, DataTree or dict | 
| 808 | 892 |         Merge together all variables from these objects. If any of them are | 
| 809 | 893 |         DataArray objects, they must have a name. | 
| 810 | 894 |     compat : {"identical", "equals", "broadcast_equals", "no_conflicts", \ | 
| @@ -859,8 +943,9 @@ def merge( | 
| 859 | 943 | 
 | 
| 860 | 944 |     Returns | 
| 861 | 945 |     ------- | 
| 862 |  | -    Dataset | 
| 863 |  | -        Dataset with combined variables from each object. | 
|  | 946 | +    Dataset or DataTree | 
|  | 947 | +        Object with combined variables from the inputs. If any input is a | 
|  | 948 | +        DataTree, this will also be a DataTree. Otherwise it will be a Dataset. | 
| 864 | 949 | 
 | 
| 865 | 950 |     Examples | 
| 866 | 951 |     -------- | 
| @@ -1023,13 +1108,31 @@ def merge( | 
| 1023 | 1108 |     from xarray.core.coordinates import Coordinates | 
| 1024 | 1109 |     from xarray.core.dataarray import DataArray | 
| 1025 | 1110 |     from xarray.core.dataset import Dataset | 
|  | 1111 | +    from xarray.core.datatree import DataTree | 
|  | 1112 | + | 
|  | 1113 | +    objects = list(objects) | 
|  | 1114 | + | 
|  | 1115 | +    if any(isinstance(obj, DataTree) for obj in objects): | 
|  | 1116 | +        if not all(isinstance(obj, DataTree) for obj in objects): | 
|  | 1117 | +            raise TypeError( | 
|  | 1118 | +                "merge does not support mixed type arguments when one argument " | 
|  | 1119 | +                f"is a DataTree: {objects}" | 
|  | 1120 | +            ) | 
|  | 1121 | +        trees = cast(list[DataTree], objects) | 
|  | 1122 | +        return merge_trees( | 
|  | 1123 | +            trees, | 
|  | 1124 | +            compat=compat, | 
|  | 1125 | +            join=join, | 
|  | 1126 | +            combine_attrs=combine_attrs, | 
|  | 1127 | +            fill_value=fill_value, | 
|  | 1128 | +        ) | 
| 1026 | 1129 | 
 | 
| 1027 | 1130 |     dict_like_objects = [] | 
| 1028 | 1131 |     for obj in objects: | 
| 1029 | 1132 |         if not isinstance(obj, DataArray | Dataset | Coordinates | dict): | 
| 1030 | 1133 |             raise TypeError( | 
| 1031 |  | -                "objects must be an iterable containing only " | 
| 1032 |  | -                "Dataset(s), DataArray(s), and dictionaries." | 
|  | 1134 | +                "objects must be an iterable containing only DataTree(s), " | 
|  | 1135 | +                f"Dataset(s), DataArray(s), and dictionaries: {objects}" | 
| 1033 | 1136 |             ) | 
| 1034 | 1137 | 
 | 
| 1035 | 1138 |         if isinstance(obj, DataArray): | 
|  | 
0 commit comments