
Commit c9d89e2

Improve concat performance (#7824)
* 1. var_idx very slow
* 2. slow any
* Add test
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* 3. Slow array_type called multiple times
* 4. Can use fastpath for variable.concat?
* 5. slow init of pd.unique
* typos
* Update concat.py
* Update merge.py
* 6. Avoid recalculating in loops
* 7. No need to transpose 1d arrays.
* 8. speed up dask_dataframe
* Update dataset.py
* Update dataset.py
* Update dataset.py
* Add dask combine test with many variables
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Update combine.py
* Update combine.py
* Update combine.py
* list not needed
* dim is usually string, might be faster to check for that
* first_var.dims doesn't change and can be defined 1 time
* mask bad points rather than append good points
* reduce duplicated code
* don't think id() is required here.
* get dtype directly instead of through result_dtype
* seems better to delete rather than append
* use internal fastpath if it's a dataset, values should be fine then
* Change isinstance order.
* use fastpath if already xarray object
* Update variable.py
* Update dtypes.py
* typing fixes
* more typing fixes
* test undoing as_compatible_data
* undo concat_dim_length deletion
* Update xarray/core/concat.py
* Remove .copy and sum
* Update concat.py
* Use OrderedSet
* Apply suggestions from code review
* Update whats-new.rst
* Update xarray/core/concat.py
* no need to check arrays if cupy isn't even installed
* Update whats-new.rst
* Add concat comment
* minimize diff
* revert sketchy

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Parent: 960f15c

File tree

12 files changed: +131 -51 lines changed

asv_bench/benchmarks/combine.py (+42 -1)

@@ -2,8 +2,49 @@
 
 import xarray as xr
 
+from . import requires_dask
 
-class Combine:
+
+class Combine1d:
+    """Benchmark concatenating and merging large datasets"""
+
+    def setup(self) -> None:
+        """Create 2 datasets with two different variables"""
+
+        t_size = 8000
+        t = np.arange(t_size)
+        data = np.random.randn(t_size)
+
+        self.dsA0 = xr.Dataset({"A": xr.DataArray(data, coords={"T": t}, dims=("T"))})
+        self.dsA1 = xr.Dataset(
+            {"A": xr.DataArray(data, coords={"T": t + t_size}, dims=("T"))}
+        )
+
+    def time_combine_by_coords(self) -> None:
+        """Also has to load and arrange t coordinate"""
+        datasets = [self.dsA0, self.dsA1]
+
+        xr.combine_by_coords(datasets)
+
+
+class Combine1dDask(Combine1d):
+    """Benchmark concatenating and merging large datasets"""
+
+    def setup(self) -> None:
+        """Create 2 datasets with two different variables"""
+        requires_dask()
+
+        t_size = 8000
+        t = np.arange(t_size)
+        var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk()
+
+        data_vars = {f"long_name_{v}": ("T", var) for v in range(500)}
+
+        self.dsA0 = xr.Dataset(data_vars, coords={"T": t})
+        self.dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size})
+
+
+class Combine3d:
     """Benchmark concatenating and merging large datasets"""
 
     def setup(self):
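The new Combine1dDask case targets exactly the many-variables workload this PR speeds up: two datasets with 500 dask-backed variables each. As a rough standalone reproduction outside of asv (the perf_counter timing here is my own sketch, not part of the benchmark suite; dask must be installed for .chunk()):

    import time

    import numpy as np
    import xarray as xr

    t_size = 8000
    t = np.arange(t_size)
    var = xr.Variable(dims=("T",), data=np.random.randn(t_size)).chunk()

    # 500 dask-backed variables sharing one dimension, as in Combine1dDask.setup
    data_vars = {f"long_name_{v}": ("T", var) for v in range(500)}
    dsA0 = xr.Dataset(data_vars, coords={"T": t})
    dsA1 = xr.Dataset(data_vars, coords={"T": t + t_size})

    start = time.perf_counter()
    xr.combine_by_coords([dsA0, dsA1])
    print(f"combine_by_coords took {time.perf_counter() - start:.3f} s")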

doc/whats-new.rst (+2 -1)

@@ -37,7 +37,8 @@ Deprecations
 
 Performance
 ~~~~~~~~~~~
-
+- Improve concatenation performance (:issue:`7833`, :pull:`7824`).
+  By `Jimmy Westling <https://github.com/illviljan>`_.
 
 Bug fixes
 ~~~~~~~~~

xarray/core/combine.py (+5 -5)

@@ -970,18 +970,18 @@ def combine_by_coords(
 
     # Perform the multidimensional combine on each group of data variables
     # before merging back together
-    concatenated_grouped_by_data_vars = []
-    for vars, datasets_with_same_vars in grouped_by_vars:
-        concatenated = _combine_single_variable_hypercube(
-            list(datasets_with_same_vars),
+    concatenated_grouped_by_data_vars = tuple(
+        _combine_single_variable_hypercube(
+            tuple(datasets_with_same_vars),
             fill_value=fill_value,
             data_vars=data_vars,
             coords=coords,
             compat=compat,
             join=join,
             combine_attrs=combine_attrs,
         )
-        concatenated_grouped_by_data_vars.append(concatenated)
+        for vars, datasets_with_same_vars in grouped_by_vars
+    )
 
     return merge(
         concatenated_grouped_by_data_vars,
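The loop-and-append pattern becomes a single tuple() over a generator expression, and the per-group list(...) copies become cheaper tuple(...) copies. A minimal sketch of the same refactor (the data and the summing step are hypothetical stand-ins, not xarray API):

    from itertools import groupby

    data = [("a", 1), ("a", 2), ("b", 3)]
    grouped = groupby(data, key=lambda kv: kv[0])

    # Before: accumulate results with repeated .append() calls.
    # After: one tuple() over a generator expression, materializing each
    # group as a tuple instead of a list.
    combined = tuple(
        sum(v for _, v in group)  # stand-in for the real combine step
        for _, group in grouped
    )
    print(combined)  # (3, 3)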

xarray/core/common.py (+1 -1)

@@ -211,7 +211,7 @@ def get_axis_num(self, dim: Hashable | Iterable[Hashable]) -> int | tuple[int, ...]:
         int or tuple of int
             Axis number or numbers corresponding to the given dimensions.
         """
-        if isinstance(dim, Iterable) and not isinstance(dim, str):
+        if not isinstance(dim, str) and isinstance(dim, Iterable):
             return tuple(self._get_axis_num(d) for d in dim)
         else:
             return self._get_axis_num(dim)
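Per the commit notes, dim is usually a string, so doing the cheap exact str check first lets the common case short-circuit before the slower Iterable ABC check (which dispatches through __instancecheck__). A quick, unscientific comparison:

    from collections.abc import Iterable
    from timeit import timeit

    dim = "time"  # the common case: a single dimension name

    print(timeit(lambda: isinstance(dim, Iterable) and not isinstance(dim, str)))
    print(timeit(lambda: not isinstance(dim, str) and isinstance(dim, Iterable)))
    # The second ordering is typically faster: for strings it answers after a
    # single built-in type check and never consults the ABC machinery.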

xarray/core/concat.py (+21 -13)

@@ -3,6 +3,7 @@
 from collections.abc import Hashable, Iterable
 from typing import TYPE_CHECKING, Any, Union, cast, overload
 
+import numpy as np
 import pandas as pd
 
 from xarray.core import dtypes, utils
@@ -517,7 +518,7 @@ def _dataset_concat(
     if variables_to_merge:
         grouped = {
             k: v
-            for k, v in collect_variables_and_indexes(list(datasets)).items()
+            for k, v in collect_variables_and_indexes(datasets).items()
             if k in variables_to_merge
         }
         merged_vars, merged_indexes = merge_collected(
@@ -543,7 +544,7 @@ def ensure_common_dims(vars, concat_dim_lengths):
         # ensure each variable with the given name shares the same
         # dimensions and the same shape for all of them except along the
         # concat dimension
-        common_dims = tuple(pd.unique([d for v in vars for d in v.dims]))
+        common_dims = tuple(utils.OrderedSet(d for v in vars for d in v.dims))
         if dim not in common_dims:
             common_dims = (dim,) + common_dims
         for var, dim_len in zip(vars, concat_dim_lengths):
@@ -568,38 +569,45 @@ def get_indexes(name):
                     yield PandasIndex(data, dim, coord_dtype=var.dtype)
 
     # create concatenation index, needed for later reindexing
-    concat_index = list(range(sum(concat_dim_lengths)))
+    file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
+    concat_index = np.arange(file_start_indexes[-1])
+    concat_index_size = concat_index.size
+    variable_index_mask = np.ones(concat_index_size, dtype=bool)
 
     # stack up each variable and/or index to fill-out the dataset (in order)
     # n.b. this loop preserves variable order, needed for groupby.
+    ndatasets = len(datasets)
     for name in vars_order:
         if name in concat_over and name not in result_indexes:
             variables = []
-            variable_index = []
+            # Initialize the mask to all True then set False if any name is missing in
+            # the datasets:
+            variable_index_mask.fill(True)
             var_concat_dim_length = []
             for i, ds in enumerate(datasets):
                 if name in ds.variables:
                     variables.append(ds[name].variable)
-                    # add to variable index, needed for reindexing
-                    var_idx = [
-                        sum(concat_dim_lengths[:i]) + k
-                        for k in range(concat_dim_lengths[i])
-                    ]
-                    variable_index.extend(var_idx)
-                    var_concat_dim_length.append(len(var_idx))
+                    var_concat_dim_length.append(concat_dim_lengths[i])
                 else:
                     # raise if coordinate not in all datasets
                     if name in coord_names:
                         raise ValueError(
                             f"coordinate {name!r} not present in all datasets."
                         )
+
+                    # Mask out the indexes without the name:
+                    start = file_start_indexes[i]
+                    end = file_start_indexes[i + 1]
+                    variable_index_mask[slice(start, end)] = False
+
+            variable_index = concat_index[variable_index_mask]
             vars = ensure_common_dims(variables, var_concat_dim_length)
 
             # Try to concatenate the indexes, concatenate the variables when no index
             # is found on all datasets.
             indexes: list[Index] = list(get_indexes(name))
             if indexes:
-                if len(indexes) < len(datasets):
+                if len(indexes) < ndatasets:
                     raise ValueError(
                         f"{name!r} must have either an index or no index in all datasets, "
                         f"found {len(indexes)}/{len(datasets)} datasets with an index."
@@ -623,7 +631,7 @@ def get_indexes(name):
                 vars, dim, positions, combine_attrs=combine_attrs
             )
             # reindex if variable is not present in all datasets
-            if len(variable_index) < len(concat_index):
+            if len(variable_index) < concat_index_size:
                 combined_var = reindex_variables(
                     variables={name: combined_var},
                     dim_pos_indexers={
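The removed code rebuilt variable_index with a Python list comprehension per dataset, recomputing sum(concat_dim_lengths[:i]) each time (quadratic in the number of datasets). The replacement precomputes every dataset's offset with one np.cumsum, flips a boolean mask off for datasets that lack the variable, and gathers the surviving positions in a single vectorized step. A self-contained sketch of that masking idea (variable names mirror the diff; have_variable is a hypothetical stand-in for the "name in ds.variables" test):

    import numpy as np

    concat_dim_lengths = [3, 2, 4]       # length along the concat dim per dataset
    have_variable = [True, False, True]  # which datasets contain the variable

    # One cumulative sum gives every dataset's start offset along the concat dim.
    file_start_indexes = np.append(0, np.cumsum(concat_dim_lengths))
    concat_index = np.arange(file_start_indexes[-1])

    # Mask out the slices belonging to datasets that lack the variable...
    mask = np.ones(concat_index.size, dtype=bool)
    for i, present in enumerate(have_variable):
        if not present:
            mask[file_start_indexes[i] : file_start_indexes[i + 1]] = False

    # ...then gather the surviving positions in one vectorized step,
    # instead of extending a Python list dataset by dataset.
    variable_index = concat_index[mask]
    print(variable_index)  # [0 1 2 5 6 7 8]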

xarray/core/dataset.py (+3 -3)

@@ -647,7 +647,7 @@ def __init__(
             )
 
         if isinstance(coords, Dataset):
-            coords = coords.variables
+            coords = coords._variables
 
         variables, coord_names, dims, indexes, _ = merge_data_and_coords(
             data_vars, coords, compat="broadcast_equals"
@@ -1399,8 +1399,8 @@ def _construct_dataarray(self, name: Hashable) -> DataArray:
         coords: dict[Hashable, Variable] = {}
         # preserve ordering
         for k in self._variables:
-            if k in self._coord_names and set(self.variables[k].dims) <= needed_dims:
-                coords[k] = self.variables[k]
+            if k in self._coord_names and set(self._variables[k].dims) <= needed_dims:
+                coords[k] = self._variables[k]
 
         indexes = filter_indexes_from_coords(self._indexes, set(coords))
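Dataset.variables is a property that wraps the internal _variables dict in a fresh read-only mapping on every access, so hot loops pay an allocation per lookup; reading self._variables directly avoids it. A simplified model of the pattern (this Frozen/Dataset pair is a stripped-down sketch, not xarray's actual classes):

    from collections.abc import Mapping


    class Frozen(Mapping):
        """Minimal read-only mapping wrapper, standing in for xarray's Frozen."""

        def __init__(self, mapping):
            self.mapping = mapping

        def __getitem__(self, key):
            return self.mapping[key]

        def __iter__(self):
            return iter(self.mapping)

        def __len__(self):
            return len(self.mapping)


    class Dataset:
        def __init__(self, variables):
            self._variables = variables

        @property
        def variables(self):
            return Frozen(self._variables)  # a new wrapper on *every* access


    ds = Dataset({"a": 1, "b": 2})
    slow = [ds.variables[k] for k in ds._variables]   # allocates a Frozen per lookup
    fast = [ds._variables[k] for k in ds._variables]  # plain dict access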

xarray/core/dtypes.py (+8 -6)

@@ -37,11 +37,11 @@ def __eq__(self, other):
 # instead of following NumPy's own type-promotion rules. These type promotion
 # rules match pandas instead. For reference, see the NumPy type hierarchy:
 # https://numpy.org/doc/stable/reference/arrays.scalars.html
-PROMOTE_TO_OBJECT = [
-    {np.number, np.character},  # numpy promotes to character
-    {np.bool_, np.character},  # numpy promotes to character
-    {np.bytes_, np.unicode_},  # numpy promotes to unicode
-]
+PROMOTE_TO_OBJECT: tuple[tuple[type[np.generic], type[np.generic]], ...] = (
+    (np.number, np.character),  # numpy promotes to character
+    (np.bool_, np.character),  # numpy promotes to character
+    (np.bytes_, np.unicode_),  # numpy promotes to unicode
+)
 
 
 def maybe_promote(dtype):
@@ -156,7 +156,9 @@ def is_datetime_like(dtype):
     return np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64)
 
 
-def result_type(*arrays_and_dtypes):
+def result_type(
+    *arrays_and_dtypes: np.typing.ArrayLike | np.typing.DTypeLike,
+) -> np.dtype:
     """Like np.result_type, but with type promotion rules matching pandas.
 
     Examples of changed behavior:
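For reference, the pandas-style promotion this table drives: NumPy promotes mixed numeric/string inputs to a string dtype, while xarray's result_type returns object. A quick check (assuming the internal xarray.core.dtypes module is importable in your version):

    import numpy as np

    from xarray.core import dtypes

    a = np.array([1.5, 2.5])      # float64
    b = np.array(["one", "two"])  # <U3

    print(np.result_type(a, b))      # a string dtype: NumPy promotes to character
    print(dtypes.result_type(a, b))  # object: xarray promotes to object, like pandas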

xarray/core/duck_array_ops.py (+4 -1)

@@ -194,7 +194,10 @@ def asarray(data, xp=np):
 
 def as_shared_dtype(scalars_or_arrays, xp=np):
     """Cast a arrays to a shared dtype using xarray's type promotion rules."""
-    if any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays):
+    array_type_cupy = array_type("cupy")
+    if array_type_cupy and any(
+        isinstance(x, array_type_cupy) for x in scalars_or_arrays
+    ):
         import cupy as cp
 
         arrays = [asarray(x, xp=cp) for x in scalars_or_arrays]
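Two wins here: array_type("cupy") used to be re-evaluated inside the any() generator for every element, and when cupy is not installed its (empty, hence falsy) type tuple now short-circuits the whole check before any isinstance call runs. A sketch of the effect, with a hypothetical stand-in for pycompat.array_type:

    def array_type(mod: str) -> tuple:
        """Stand-in: returns an empty tuple of array classes when mod is absent."""
        return ()  # pretend cupy is missing

    scalars_or_arrays = list(range(100_000))

    # Before: array_type("cupy") is called once per element, and the loop
    # always runs to completion when cupy is absent.
    slow = any(isinstance(x, array_type("cupy")) for x in scalars_or_arrays)

    # After: one lookup; an empty (falsy) type tuple short-circuits before
    # a single isinstance call happens.
    array_type_cupy = array_type("cupy")
    fast = bool(array_type_cupy) and any(
        isinstance(x, array_type_cupy) for x in scalars_or_arrays
    )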

xarray/core/indexes.py (+1 -1)

@@ -1495,7 +1495,7 @@ def filter_indexes_from_coords(
     of coordinate names.
 
     """
-    filtered_indexes: dict[Any, Index] = dict(**indexes)
+    filtered_indexes: dict[Any, Index] = dict(indexes)
 
     index_coord_names: dict[Hashable, set[Hashable]] = defaultdict(set)
     for name, idx in indexes.items():
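dict(**indexes) unpacks the mapping into keyword arguments, which requires string keys and adds call overhead; dict(indexes) copies the mapping directly and accepts any hashable key, which matters since index keys are typed Any here. A small illustration:

    indexes = {("multi", "key"): 1, "name": 2}  # hashable, non-string keys allowed

    copy = dict(indexes)  # fine: a plain mapping copy
    # dict(**indexes)     # TypeError: keywords must be strings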

xarray/core/merge.py (+19 -8)

@@ -195,11 +195,11 @@ def _assert_prioritized_valid(
 
 
 def merge_collected(
-    grouped: dict[Hashable, list[MergeElement]],
+    grouped: dict[Any, list[MergeElement]],
     prioritized: Mapping[Any, MergeElement] | None = None,
     compat: CompatOptions = "minimal",
     combine_attrs: CombineAttrsOptions = "override",
-    equals: dict[Hashable, bool] | None = None,
+    equals: dict[Any, bool] | None = None,
 ) -> tuple[dict[Hashable, Variable], dict[Hashable, Index]]:
     """Merge dicts of variables, while resolving conflicts appropriately.
 
@@ -306,7 +306,7 @@ def merge_collected(
 
 
 def collect_variables_and_indexes(
-    list_of_mappings: list[DatasetLike],
+    list_of_mappings: Iterable[DatasetLike],
     indexes: Mapping[Any, Any] | None = None,
 ) -> dict[Hashable, list[MergeElement]]:
     """Collect variables and indexes from list of mappings of xarray objects.
@@ -556,7 +556,12 @@ def merge_coords(
     return variables, out_indexes
 
 
-def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"):
+def merge_data_and_coords(
+    data_vars: Mapping[Any, Any],
+    coords: Mapping[Any, Any],
+    compat: CompatOptions = "broadcast_equals",
+    join: JoinOptions = "outer",
+) -> _MergeResult:
     """Used in Dataset.__init__."""
     indexes, coords = _create_indexes_from_coords(coords, data_vars)
     objects = [data_vars, coords]
@@ -570,7 +575,9 @@ def merge_data_and_coords(data_vars, coords, compat="broadcast_equals", join="outer"):
     )
 
 
-def _create_indexes_from_coords(coords, data_vars=None):
+def _create_indexes_from_coords(
+    coords: Mapping[Any, Any], data_vars: Mapping[Any, Any] | None = None
+) -> tuple[dict, dict]:
     """Maybe create default indexes from a mapping of coordinates.
 
     Return those indexes and updated coordinates.
@@ -605,7 +612,11 @@ def _create_indexes_from_coords(coords, data_vars=None):
     return indexes, updated_coords
 
 
-def assert_valid_explicit_coords(variables, dims, explicit_coords):
+def assert_valid_explicit_coords(
+    variables: Mapping[Any, Any],
+    dims: Mapping[Any, int],
+    explicit_coords: Iterable[Hashable],
+) -> None:
    """Validate explicit coordinate names/dims.
 
     Raise a MergeError if an explicit coord shares a name with a dimension
@@ -688,7 +699,7 @@ def merge_core(
     join: JoinOptions = "outer",
     combine_attrs: CombineAttrsOptions = "override",
     priority_arg: int | None = None,
-    explicit_coords: Sequence | None = None,
+    explicit_coords: Iterable[Hashable] | None = None,
     indexes: Mapping[Any, Any] | None = None,
     fill_value: object = dtypes.NA,
 ) -> _MergeResult:
@@ -1035,7 +1046,7 @@ def dataset_merge_method(
     # method due for backwards compatibility
     # TODO: consider deprecating it?
 
-    if isinstance(overwrite_vars, Iterable) and not isinstance(overwrite_vars, str):
+    if not isinstance(overwrite_vars, str) and isinstance(overwrite_vars, Iterable):
         overwrite_vars = set(overwrite_vars)
     else:
         overwrite_vars = {overwrite_vars}
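Most of these changes tighten annotations; the one with a behavioral payoff is collect_variables_and_indexes accepting any Iterable[DatasetLike], which is what lets _dataset_concat above drop its list(datasets) copy. The general pattern (a generic sketch, unrelated to xarray's actual types):

    from collections.abc import Iterable


    def collect(mappings: Iterable[dict]) -> list[dict]:
        # Accepting Iterable instead of list lets callers pass tuples,
        # generators, dict views, ... without an intermediate list() copy.
        return [m for m in mappings]


    collect(({"a": 1}, {"b": 2}))        # tuple: OK
    collect({"a": 1} for _ in range(3))  # generator: OK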

xarray/core/pycompat.py (+14 -2)

@@ -61,14 +61,26 @@ def __init__(self, mod: ModType) -> None:
         self.available = duck_array_module is not None
 
 
+_cached_duck_array_modules: dict[ModType, DuckArrayModule] = {}
+
+
+def _get_cached_duck_array_module(mod: ModType) -> DuckArrayModule:
+    if mod not in _cached_duck_array_modules:
+        duckmod = DuckArrayModule(mod)
+        _cached_duck_array_modules[mod] = duckmod
+        return duckmod
+    else:
+        return _cached_duck_array_modules[mod]
+
+
 def array_type(mod: ModType) -> DuckArrayTypes:
     """Quick wrapper to get the array class of the module."""
-    return DuckArrayModule(mod).type
+    return _get_cached_duck_array_module(mod).type
 
 
 def mod_version(mod: ModType) -> Version:
     """Quick wrapper to get the version of the module."""
-    return DuckArrayModule(mod).version
+    return _get_cached_duck_array_module(mod).version
 
 
 def is_dask_collection(x):
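The hand-rolled dict cache memoizes DuckArrayModule construction (the expensive import probing) per module name, so array_type and mod_version no longer rebuild it on every call. The same effect could be had with functools.lru_cache; a sketch under that assumption, with a stand-in class:

    from functools import lru_cache


    class DuckArrayModule:
        """Stand-in: in pycompat.py, __init__ probes/imports the module,
        which is the expensive step worth doing only once per name."""

        def __init__(self, mod: str) -> None:
            self.mod = mod


    @lru_cache(maxsize=None)
    def _get_cached_duck_array_module(mod: str) -> DuckArrayModule:
        return DuckArrayModule(mod)


    # The same object comes back on repeated lookups:
    assert _get_cached_duck_array_module("cupy") is _get_cached_duck_array_module("cupy")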
