Skip to content

Commit 9909f90

Browse files
mgunyhoIllviljan
andauthored
Implement multidimensional initial guess and bounds for curvefit (#7821)
* Add test for multidimensional initial guess to curvefit * Pass initial guess with *args * Update curvefit docstrings * Add examples to curvefit * Add test for error on invalid p0 coords * Raise exception on invalid coordinates in initial guess * Add link from polyfit to curvefit * Update doc so it matches CI * Formatting * Add basic test for multidimensional bounds * Add tests for curvefit_helpers with array-valued bounds * First attempt at fixing initialize_curvefit_params, issues warnings * Alternative implementation of bounds initialization using xr.where(), avoids warnings * Pass also bounds as *args to _wrapper * Raise exception on unexpected dimensions in bounds * Update docstring of bounds * Update bounds docstring in dataarray also * Update type hints for curvefit p0 and bounds * Change list to tuple to pass mypy * Update whats-new * Use tuples in error message Co-authored-by: Illviljan <[email protected]> * Add type hints to test Co-authored-by: Illviljan <[email protected]> --------- Co-authored-by: Illviljan <[email protected]>
1 parent e3db616 commit 9909f90

File tree

4 files changed

+298
-37
lines changed

4 files changed

+298
-37
lines changed

doc/whats-new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ v2023.05.1 (unreleased)
2323
New Features
2424
~~~~~~~~~~~~
2525

26+
- Added support for multidimensional initial guess and bounds in :py:meth:`DataArray.curvefit` (:issue:`7768`, :pull:`7821`).
27+
By `András Gunyhó <https://github.com/mgunyho>`_.
2628

2729
Breaking changes
2830
~~~~~~~~~~~~~~~~

xarray/core/dataarray.py

Lines changed: 91 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5504,6 +5504,7 @@ def polyfit(
55045504
numpy.polyfit
55055505
numpy.polyval
55065506
xarray.polyval
5507+
DataArray.curvefit
55075508
"""
55085509
return self._to_temp_dataset().polyfit(
55095510
dim, deg, skipna=skipna, rcond=rcond, w=w, full=full, cov=cov
@@ -6158,8 +6159,8 @@ def curvefit(
61586159
func: Callable[..., Any],
61596160
reduce_dims: Dims = None,
61606161
skipna: bool = True,
6161-
p0: dict[str, Any] | None = None,
6162-
bounds: dict[str, Any] | None = None,
6162+
p0: dict[str, float | DataArray] | None = None,
6163+
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
61636164
param_names: Sequence[str] | None = None,
61646165
kwargs: dict[str, Any] | None = None,
61656166
) -> Dataset:
@@ -6190,12 +6191,16 @@ def curvefit(
61906191
Whether to skip missing values when fitting. Default is True.
61916192
p0 : dict-like or None, optional
61926193
Optional dictionary of parameter names to initial guesses passed to the
6193-
`curve_fit` `p0` arg. If none or only some parameters are passed, the rest will
6194-
be assigned initial values following the default scipy behavior.
6195-
bounds : dict-like or None, optional
6196-
Optional dictionary of parameter names to bounding values passed to the
6197-
`curve_fit` `bounds` arg. If none or only some parameters are passed, the rest
6198-
will be unbounded following the default scipy behavior.
6194+
`curve_fit` `p0` arg. If the values are DataArrays, they will be appropriately
6195+
broadcast to the coordinates of the array. If none or only some parameters are
6196+
passed, the rest will be assigned initial values following the default scipy
6197+
behavior.
6198+
bounds : dict-like, optional
6199+
Optional dictionary of parameter names to tuples of bounding values passed to the
6200+
`curve_fit` `bounds` arg. If any of the bounds are DataArrays, they will be
6201+
appropriately broadcast to the coordinates of the array. If none or only some
6202+
parameters are passed, the rest will be unbounded following the default scipy
6203+
behavior.
61996204
param_names : sequence of Hashable or None, optional
62006205
Sequence of names for the fittable parameters of `func`. If not supplied,
62016206
this will be automatically determined by arguments of `func`. `param_names`
@@ -6214,6 +6219,84 @@ def curvefit(
62146219
[var]_curvefit_covariance
62156220
The covariance matrix of the coefficient estimates.
62166221
6222+
Examples
6223+
--------
6224+
Generate some exponentially decaying data, where the decay constant and amplitude are
6225+
different for different values of the coordinate ``x``:
6226+
6227+
>>> rng = np.random.default_rng(seed=0)
6228+
>>> def exp_decay(t, time_constant, amplitude):
6229+
... return np.exp(-t / time_constant) * amplitude
6230+
...
6231+
>>> t = np.linspace(0, 10, 11)
6232+
>>> da = xr.DataArray(
6233+
... np.stack(
6234+
... [
6235+
... exp_decay(t, 1, 0.1),
6236+
... exp_decay(t, 2, 0.2),
6237+
... exp_decay(t, 3, 0.3),
6238+
... ]
6239+
... )
6240+
... + rng.normal(size=(3, t.size)) * 0.01,
6241+
... coords={"x": [0, 1, 2], "time": t},
6242+
... )
6243+
>>> da
6244+
<xarray.DataArray (x: 3, time: 11)>
6245+
array([[ 0.1012573 , 0.0354669 , 0.01993775, 0.00602771, -0.00352513,
6246+
0.00428975, 0.01328788, 0.009562 , -0.00700381, -0.01264187,
6247+
-0.0062282 ],
6248+
[ 0.20041326, 0.09805582, 0.07138797, 0.03216692, 0.01974438,
6249+
0.01097441, 0.00679441, 0.01015578, 0.01408826, 0.00093645,
6250+
0.01501222],
6251+
[ 0.29334805, 0.21847449, 0.16305984, 0.11130396, 0.07164415,
6252+
0.04744543, 0.03602333, 0.03129354, 0.01074885, 0.01284436,
6253+
0.00910995]])
6254+
Coordinates:
6255+
* x (x) int64 0 1 2
6256+
* time (time) float64 0.0 1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 10.0
6257+
6258+
Fit the exponential decay function to the data along the ``time`` dimension:
6259+
6260+
>>> fit_result = da.curvefit("time", exp_decay)
6261+
>>> fit_result["curvefit_coefficients"].sel(param="time_constant")
6262+
<xarray.DataArray 'curvefit_coefficients' (x: 3)>
6263+
array([1.05692036, 1.73549639, 2.94215771])
6264+
Coordinates:
6265+
* x (x) int64 0 1 2
6266+
param <U13 'time_constant'
6267+
>>> fit_result["curvefit_coefficients"].sel(param="amplitude")
6268+
<xarray.DataArray 'curvefit_coefficients' (x: 3)>
6269+
array([0.1005489 , 0.19631423, 0.30003579])
6270+
Coordinates:
6271+
* x (x) int64 0 1 2
6272+
param <U13 'amplitude'
6273+
6274+
An initial guess can also be given with the ``p0`` arg (although it does not make much
6275+
of a difference in this simple example). To have a different guess for different
6276+
coordinate points, the guess can be a DataArray. Here we use the same initial guess
6277+
for the amplitude but different guesses for the time constant:
6278+
6279+
>>> fit_result = da.curvefit(
6280+
... "time",
6281+
... exp_decay,
6282+
... p0={
6283+
... "amplitude": 0.2,
6284+
... "time_constant": xr.DataArray([1, 2, 3], coords=[da.x]),
6285+
... },
6286+
... )
6287+
>>> fit_result["curvefit_coefficients"].sel(param="time_constant")
6288+
<xarray.DataArray 'curvefit_coefficients' (x: 3)>
6289+
array([1.0569213 , 1.73550052, 2.94215733])
6290+
Coordinates:
6291+
* x (x) int64 0 1 2
6292+
param <U13 'time_constant'
6293+
>>> fit_result["curvefit_coefficients"].sel(param="amplitude")
6294+
<xarray.DataArray 'curvefit_coefficients' (x: 3)>
6295+
array([0.10054889, 0.1963141 , 0.3000358 ])
6296+
Coordinates:
6297+
* x (x) int64 0 1 2
6298+
param <U13 'amplitude'
6299+
62176300
See Also
62186301
--------
62196302
DataArray.polyfit

xarray/core/dataset.py

Lines changed: 75 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -361,17 +361,24 @@ def _initialize_curvefit_params(params, p0, bounds, func_args):
361361
"""Set initial guess and bounds for curvefit.
362362
Priority: 1) passed args 2) func signature 3) scipy defaults
363363
"""
364+
from xarray.core.computation import where
364365

365366
def _initialize_feasible(lb, ub):
366367
# Mimics functionality of scipy.optimize.minpack._initialize_feasible
367368
lb_finite = np.isfinite(lb)
368369
ub_finite = np.isfinite(ub)
369-
p0 = np.nansum(
370-
[
371-
0.5 * (lb + ub) * int(lb_finite & ub_finite),
372-
(lb + 1) * int(lb_finite & ~ub_finite),
373-
(ub - 1) * int(~lb_finite & ub_finite),
374-
]
370+
p0 = where(
371+
lb_finite,
372+
where(
373+
ub_finite,
374+
0.5 * (lb + ub), # both bounds finite
375+
lb + 1, # lower bound finite, upper infinite
376+
),
377+
where(
378+
ub_finite,
379+
ub - 1, # lower bound infinite, upper finite
380+
0, # both bounds infinite
381+
),
375382
)
376383
return p0
377384

@@ -381,9 +388,13 @@ def _initialize_feasible(lb, ub):
381388
if p in func_args and func_args[p].default is not func_args[p].empty:
382389
param_defaults[p] = func_args[p].default
383390
if p in bounds:
384-
bounds_defaults[p] = tuple(bounds[p])
385-
if param_defaults[p] < bounds[p][0] or param_defaults[p] > bounds[p][1]:
386-
param_defaults[p] = _initialize_feasible(bounds[p][0], bounds[p][1])
391+
lb, ub = bounds[p]
392+
bounds_defaults[p] = (lb, ub)
393+
param_defaults[p] = where(
394+
(param_defaults[p] < lb) | (param_defaults[p] > ub),
395+
_initialize_feasible(lb, ub),
396+
param_defaults[p],
397+
)
387398
if p in p0:
388399
param_defaults[p] = p0[p]
389400
return param_defaults, bounds_defaults
@@ -8617,8 +8628,8 @@ def curvefit(
86178628
func: Callable[..., Any],
86188629
reduce_dims: Dims = None,
86198630
skipna: bool = True,
8620-
p0: dict[str, Any] | None = None,
8621-
bounds: dict[str, Any] | None = None,
8631+
p0: dict[str, float | DataArray] | None = None,
8632+
bounds: dict[str, tuple[float | DataArray, float | DataArray]] | None = None,
86228633
param_names: Sequence[str] | None = None,
86238634
kwargs: dict[str, Any] | None = None,
86248635
) -> T_Dataset:
@@ -8649,12 +8660,16 @@ def curvefit(
86498660
Whether to skip missing values when fitting. Default is True.
86508661
p0 : dict-like, optional
86518662
Optional dictionary of parameter names to initial guesses passed to the
8652-
`curve_fit` `p0` arg. If none or only some parameters are passed, the rest will
8653-
be assigned initial values following the default scipy behavior.
8663+
`curve_fit` `p0` arg. If the values are DataArrays, they will be appropriately
8664+
broadcast to the coordinates of the array. If none or only some parameters are
8665+
passed, the rest will be assigned initial values following the default scipy
8666+
behavior.
86548667
bounds : dict-like, optional
8655-
Optional dictionary of parameter names to bounding values passed to the
8656-
`curve_fit` `bounds` arg. If none or only some parameters are passed, the rest
8657-
will be unbounded following the default scipy behavior.
8668+
Optional dictionary of parameter names to tuples of bounding values passed to the
8669+
`curve_fit` `bounds` arg. If any of the bounds are DataArrays, they will be
8670+
appropriately broadcast to the coordinates of the array. If none or only some
8671+
parameters are passed, the rest will be unbounded following the default scipy
8672+
behavior.
86588673
param_names : sequence of hashable, optional
86598674
Sequence of names for the fittable parameters of `func`. If not supplied,
86608675
this will be automatically determined by arguments of `func`. `param_names`
@@ -8721,29 +8736,53 @@ def curvefit(
87218736
"in fitting on scalar data."
87228737
)
87238738

8739+
# Check that initial guess and bounds only contain coordinates that are in preserved_dims
8740+
for param, guess in p0.items():
8741+
if isinstance(guess, DataArray):
8742+
unexpected = set(guess.dims) - set(preserved_dims)
8743+
if unexpected:
8744+
raise ValueError(
8745+
f"Initial guess for '{param}' has unexpected dimensions "
8746+
f"{tuple(unexpected)}. It should only have dimensions that are in data "
8747+
f"dimensions {preserved_dims}."
8748+
)
8749+
for param, (lb, ub) in bounds.items():
8750+
for label, bound in zip(("Lower", "Upper"), (lb, ub)):
8751+
if isinstance(bound, DataArray):
8752+
unexpected = set(bound.dims) - set(preserved_dims)
8753+
if unexpected:
8754+
raise ValueError(
8755+
f"{label} bound for '{param}' has unexpected dimensions "
8756+
f"{tuple(unexpected)}. It should only have dimensions that are in data "
8757+
f"dimensions {preserved_dims}."
8758+
)
8759+
87248760
# Broadcast all coords with each other
87258761
coords_ = broadcast(*coords_)
87268762
coords_ = [
87278763
coord.broadcast_like(self, exclude=preserved_dims) for coord in coords_
87288764
]
8765+
n_coords = len(coords_)
87298766

87308767
params, func_args = _get_func_args(func, param_names)
87318768
param_defaults, bounds_defaults = _initialize_curvefit_params(
87328769
params, p0, bounds, func_args
87338770
)
87348771
n_params = len(params)
8735-
kwargs.setdefault("p0", [param_defaults[p] for p in params])
8736-
kwargs.setdefault(
8737-
"bounds",
8738-
[
8739-
[bounds_defaults[p][0] for p in params],
8740-
[bounds_defaults[p][1] for p in params],
8741-
],
8742-
)
87438772

8744-
def _wrapper(Y, *coords_, **kwargs):
8773+
def _wrapper(Y, *args, **kwargs):
87458774
# Wrap curve_fit with raveled coordinates and pointwise NaN handling
8746-
x = np.vstack([c.ravel() for c in coords_])
8775+
# *args contains:
8776+
# - the coordinates
8777+
# - initial guess
8778+
# - lower bounds
8779+
# - upper bounds
8780+
coords__ = args[:n_coords]
8781+
p0_ = args[n_coords + 0 * n_params : n_coords + 1 * n_params]
8782+
lb = args[n_coords + 1 * n_params : n_coords + 2 * n_params]
8783+
ub = args[n_coords + 2 * n_params :]
8784+
8785+
x = np.vstack([c.ravel() for c in coords__])
87478786
y = Y.ravel()
87488787
if skipna:
87498788
mask = np.all([np.any(~np.isnan(x), axis=0), ~np.isnan(y)], axis=0)
@@ -8754,7 +8793,7 @@ def _wrapper(Y, *coords_, **kwargs):
87548793
pcov = np.full([n_params, n_params], np.nan)
87558794
return popt, pcov
87568795
x = np.squeeze(x)
8757-
popt, pcov = curve_fit(func, x, y, **kwargs)
8796+
popt, pcov = curve_fit(func, x, y, p0=p0_, bounds=(lb, ub), **kwargs)
87588797
return popt, pcov
87598798

87608799
result = type(self)()
@@ -8764,13 +8803,21 @@ def _wrapper(Y, *coords_, **kwargs):
87648803
else:
87658804
name = f"{str(name)}_"
87668805

8806+
input_core_dims = [reduce_dims_ for _ in range(n_coords + 1)]
8807+
input_core_dims.extend(
8808+
[[] for _ in range(3 * n_params)]
8809+
) # core_dims for p0 and bounds
8810+
87678811
popt, pcov = apply_ufunc(
87688812
_wrapper,
87698813
da,
87708814
*coords_,
8815+
*param_defaults.values(),
8816+
*[b[0] for b in bounds_defaults.values()],
8817+
*[b[1] for b in bounds_defaults.values()],
87718818
vectorize=True,
87728819
dask="parallelized",
8773-
input_core_dims=[reduce_dims_ for d in range(len(coords_) + 1)],
8820+
input_core_dims=input_core_dims,
87748821
output_core_dims=[["param"], ["cov_i", "cov_j"]],
87758822
dask_gufunc_kwargs={
87768823
"output_sizes": {

0 commit comments

Comments
 (0)