Skip to content

Commit

Permalink
Fix the push method when the limit parameter is bigger than the chunksize (
Browse files Browse the repository at this point in the history
#9940)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
3 people authored Jan 29, 2025
1 parent 5fdceff commit d7ac79a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 55 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ Deprecations

Bug fixes
~~~~~~~~~

- Fix :py:meth:`DataArray.ffill`, :py:meth:`DataArray.bfill`, :py:meth:`Dataset.ffill` and :py:meth:`Dataset.bfill` when the limit is bigger than the chunksize (:issue:`9939`).
By `Joseph Nowak <https://github.com/josephnowak>`_.
- Fix issues related to Pandas v3 ("us" vs. "ns" for python datetime, copy on write) and handling of 0d-numpy arrays in datetime/timedelta decoding (:pull:`9953`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Remove dask-expr from CI runs, add "pyarrow" dask dependency to windows CI runs, fix related tests (:issue:`9962`, :pull:`9971`).
Expand Down
50 changes: 13 additions & 37 deletions xarray/core/dask_array_ops.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import math
from functools import partial

from xarray.core import dtypes, nputils

Expand Down Expand Up @@ -92,31 +91,6 @@ def _dtype_push(a, axis, dtype=None):
return _push(a, axis=axis)


def _reset_cumsum(a, axis, dtype=None):
import numpy as np

cumsum = np.cumsum(a, axis=axis)
reset_points = np.maximum.accumulate(np.where(a == 0, cumsum, 0), axis=axis)
return cumsum - reset_points


def _last_reset_cumsum(a, axis, keepdims=None):
    """Return the final value (along *axis*) of the resetting cumulative sum.

    Used as the ``preop`` of the blelloch scan method; ``keepdims`` is
    accepted (unused) to satisfy that interface.  Selecting ``indices=[-1]``
    keeps the reduced axis with length one.
    """
    import numpy as np

    totals = _reset_cumsum(a, axis=axis)
    return np.take(totals, indices=[-1], axis=axis)


def _combine_reset_cumsum(a, b, axis):
import numpy as np

# It is going to sum the previous result until the first
# non nan value
bitmask = np.cumprod(b != 0, axis=axis)
return np.where(bitmask, b + a, b)


def push(array, n, axis, method="blelloch"):
"""
Dask-aware bottleneck.push
Expand Down Expand Up @@ -145,16 +119,18 @@ def push(array, n, axis, method="blelloch"):
)

if n is not None and 0 < n < array.shape[axis] - 1:
valid_positions = da.reductions.cumreduction(
func=_reset_cumsum,
binop=partial(_combine_reset_cumsum, axis=axis),
ident=0,
x=da.isnan(array, dtype=int),
axis=axis,
dtype=int,
method=method,
preop=_last_reset_cumsum,
)
pushed_array = da.where(valid_positions <= n, pushed_array, np.nan)
# The idea is to calculate a cumulative sum of a bitmask
# created from the isnan method, but every time a False is found the sum
# must be restarted, and the final result indicates the amount of contiguous
# nan values found in the original array on every position
nan_bitmask = da.isnan(array, dtype=int)
cumsum_nan = nan_bitmask.cumsum(axis=axis, method=method)
valid_positions = da.where(nan_bitmask == 0, cumsum_nan, np.nan)
valid_positions = push(valid_positions, None, axis, method=method)
# All the NaNs at the beginning are converted to 0
valid_positions = da.nan_to_num(valid_positions)
valid_positions = cumsum_nan - valid_positions
valid_positions = valid_positions <= n
pushed_array = da.where(valid_positions, pushed_array, np.nan)

return pushed_array
45 changes: 27 additions & 18 deletions xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1025,31 +1025,40 @@ def test_least_squares(use_dask, skipna):
@requires_dask
@requires_bottleneck
@pytest.mark.parametrize("method", ["sequential", "blelloch"])
def test_push_dask(method):
@pytest.mark.parametrize(
"arr",
[
[np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6],
[
np.nan,
np.nan,
np.nan,
2,
np.nan,
np.nan,
np.nan,
9,
np.nan,
np.nan,
np.nan,
np.nan,
],
],
)
def test_push_dask(method, arr):
import bottleneck
import dask.array
import dask.array as da

array = np.array([np.nan, 1, 2, 3, np.nan, np.nan, np.nan, np.nan, 4, 5, np.nan, 6])
arr = np.array(arr)
chunks = list(range(1, 11)) + [(1, 2, 3, 2, 2, 1, 1)]

for n in [None, 1, 2, 3, 4, 5, 11]:
expected = bottleneck.push(array, axis=0, n=n)
for c in range(1, 11):
expected = bottleneck.push(arr, axis=0, n=n)
for c in chunks:
with raise_if_dask_computes():
actual = push(
dask.array.from_array(array, chunks=c), axis=0, n=n, method=method
)
actual = push(da.from_array(arr, chunks=c), axis=0, n=n, method=method)
np.testing.assert_equal(actual, expected)

# some chunks of size-1 with NaN
with raise_if_dask_computes():
actual = push(
dask.array.from_array(array, chunks=(1, 2, 3, 2, 2, 1, 1)),
axis=0,
n=n,
method=method,
)
np.testing.assert_equal(actual, expected)


def test_extension_array_equality(categorical1, int1):
int_duck_array = PandasExtensionArray(int1)
Expand Down

0 comments on commit d7ac79a

Please sign in to comment.