@@ -361,17 +361,24 @@ def _initialize_curvefit_params(params, p0, bounds, func_args):
361
361
"""Set initial guess and bounds for curvefit.
362
362
Priority: 1) passed args 2) func signature 3) scipy defaults
363
363
"""
364
+ from xarray .core .computation import where
364
365
365
366
def _initialize_feasible (lb , ub ):
366
367
# Mimics functionality of scipy.optimize.minpack._initialize_feasible
367
368
lb_finite = np .isfinite (lb )
368
369
ub_finite = np .isfinite (ub )
369
- p0 = np .nansum (
370
- [
371
- 0.5 * (lb + ub ) * int (lb_finite & ub_finite ),
372
- (lb + 1 ) * int (lb_finite & ~ ub_finite ),
373
- (ub - 1 ) * int (~ lb_finite & ub_finite ),
374
- ]
370
+ p0 = where (
371
+ lb_finite ,
372
+ where (
373
+ ub_finite ,
374
+ 0.5 * (lb + ub ), # both bounds finite
375
+ lb + 1 , # lower bound finite, upper infinite
376
+ ),
377
+ where (
378
+ ub_finite ,
379
+ ub - 1 , # lower bound infinite, upper finite
380
+ 0 , # both bounds infinite
381
+ ),
375
382
)
376
383
return p0
377
384
@@ -381,9 +388,13 @@ def _initialize_feasible(lb, ub):
381
388
if p in func_args and func_args [p ].default is not func_args [p ].empty :
382
389
param_defaults [p ] = func_args [p ].default
383
390
if p in bounds :
384
- bounds_defaults [p ] = tuple (bounds [p ])
385
- if param_defaults [p ] < bounds [p ][0 ] or param_defaults [p ] > bounds [p ][1 ]:
386
- param_defaults [p ] = _initialize_feasible (bounds [p ][0 ], bounds [p ][1 ])
391
+ lb , ub = bounds [p ]
392
+ bounds_defaults [p ] = (lb , ub )
393
+ param_defaults [p ] = where (
394
+ (param_defaults [p ] < lb ) | (param_defaults [p ] > ub ),
395
+ _initialize_feasible (lb , ub ),
396
+ param_defaults [p ],
397
+ )
387
398
if p in p0 :
388
399
param_defaults [p ] = p0 [p ]
389
400
return param_defaults , bounds_defaults
@@ -8617,8 +8628,8 @@ def curvefit(
8617
8628
func : Callable [..., Any ],
8618
8629
reduce_dims : Dims = None ,
8619
8630
skipna : bool = True ,
8620
- p0 : dict [str , Any ] | None = None ,
8621
- bounds : dict [str , Any ] | None = None ,
8631
+ p0 : dict [str , float | DataArray ] | None = None ,
8632
+ bounds : dict [str , tuple [ float | DataArray , float | DataArray ] ] | None = None ,
8622
8633
param_names : Sequence [str ] | None = None ,
8623
8634
kwargs : dict [str , Any ] | None = None ,
8624
8635
) -> T_Dataset :
@@ -8649,12 +8660,16 @@ def curvefit(
8649
8660
Whether to skip missing values when fitting. Default is True.
8650
8661
p0 : dict-like, optional
8651
8662
Optional dictionary of parameter names to initial guesses passed to the
8652
- `curve_fit` `p0` arg. If none or only some parameters are passed, the rest will
8653
- be assigned initial values following the default scipy behavior.
8663
+ `curve_fit` `p0` arg. If the values are DataArrays, they will be appropriately
8664
+ broadcast to the coordinates of the array. If none or only some parameters are
8665
+ passed, the rest will be assigned initial values following the default scipy
8666
+ behavior.
8654
8667
bounds : dict-like, optional
8655
- Optional dictionary of parameter names to bounding values passed to the
8656
- `curve_fit` `bounds` arg. If none or only some parameters are passed, the rest
8657
- will be unbounded following the default scipy behavior.
8668
+ Optional dictionary of parameter names to tuples of bounding values passed to the
8669
+ `curve_fit` `bounds` arg. If any of the bounds are DataArrays, they will be
8670
+ appropriately broadcast to the coordinates of the array. If none or only some
8671
+ parameters are passed, the rest will be unbounded following the default scipy
8672
+ behavior.
8658
8673
param_names : sequence of hashable, optional
8659
8674
Sequence of names for the fittable parameters of `func`. If not supplied,
8660
8675
this will be automatically determined by arguments of `func`. `param_names`
@@ -8721,29 +8736,53 @@ def curvefit(
8721
8736
"in fitting on scalar data."
8722
8737
)
8723
8738
8739
+ # Check that initial guess and bounds only contain coordinates that are in preserved_dims
8740
+ for param , guess in p0 .items ():
8741
+ if isinstance (guess , DataArray ):
8742
+ unexpected = set (guess .dims ) - set (preserved_dims )
8743
+ if unexpected :
8744
+ raise ValueError (
8745
+ f"Initial guess for '{ param } ' has unexpected dimensions "
8746
+ f"{ tuple (unexpected )} . It should only have dimensions that are in data "
8747
+ f"dimensions { preserved_dims } ."
8748
+ )
8749
+ for param , (lb , ub ) in bounds .items ():
8750
+ for label , bound in zip (("Lower" , "Upper" ), (lb , ub )):
8751
+ if isinstance (bound , DataArray ):
8752
+ unexpected = set (bound .dims ) - set (preserved_dims )
8753
+ if unexpected :
8754
+ raise ValueError (
8755
+ f"{ label } bound for '{ param } ' has unexpected dimensions "
8756
+ f"{ tuple (unexpected )} . It should only have dimensions that are in data "
8757
+ f"dimensions { preserved_dims } ."
8758
+ )
8759
+
8724
8760
# Broadcast all coords with each other
8725
8761
coords_ = broadcast (* coords_ )
8726
8762
coords_ = [
8727
8763
coord .broadcast_like (self , exclude = preserved_dims ) for coord in coords_
8728
8764
]
8765
+ n_coords = len (coords_ )
8729
8766
8730
8767
params , func_args = _get_func_args (func , param_names )
8731
8768
param_defaults , bounds_defaults = _initialize_curvefit_params (
8732
8769
params , p0 , bounds , func_args
8733
8770
)
8734
8771
n_params = len (params )
8735
- kwargs .setdefault ("p0" , [param_defaults [p ] for p in params ])
8736
- kwargs .setdefault (
8737
- "bounds" ,
8738
- [
8739
- [bounds_defaults [p ][0 ] for p in params ],
8740
- [bounds_defaults [p ][1 ] for p in params ],
8741
- ],
8742
- )
8743
8772
8744
- def _wrapper (Y , * coords_ , ** kwargs ):
8773
+ def _wrapper (Y , * args , ** kwargs ):
8745
8774
# Wrap curve_fit with raveled coordinates and pointwise NaN handling
8746
- x = np .vstack ([c .ravel () for c in coords_ ])
8775
+ # *args contains:
8776
+ # - the coordinates
8777
+ # - initial guess
8778
+ # - lower bounds
8779
+ # - upper bounds
8780
+ coords__ = args [:n_coords ]
8781
+ p0_ = args [n_coords + 0 * n_params : n_coords + 1 * n_params ]
8782
+ lb = args [n_coords + 1 * n_params : n_coords + 2 * n_params ]
8783
+ ub = args [n_coords + 2 * n_params :]
8784
+
8785
+ x = np .vstack ([c .ravel () for c in coords__ ])
8747
8786
y = Y .ravel ()
8748
8787
if skipna :
8749
8788
mask = np .all ([np .any (~ np .isnan (x ), axis = 0 ), ~ np .isnan (y )], axis = 0 )
@@ -8754,7 +8793,7 @@ def _wrapper(Y, *coords_, **kwargs):
8754
8793
pcov = np .full ([n_params , n_params ], np .nan )
8755
8794
return popt , pcov
8756
8795
x = np .squeeze (x )
8757
- popt , pcov = curve_fit (func , x , y , ** kwargs )
8796
+ popt , pcov = curve_fit (func , x , y , p0 = p0_ , bounds = ( lb , ub ), ** kwargs )
8758
8797
return popt , pcov
8759
8798
8760
8799
result = type (self )()
@@ -8764,13 +8803,21 @@ def _wrapper(Y, *coords_, **kwargs):
8764
8803
else :
8765
8804
name = f"{ str (name )} _"
8766
8805
8806
+ input_core_dims = [reduce_dims_ for _ in range (n_coords + 1 )]
8807
+ input_core_dims .extend (
8808
+ [[] for _ in range (3 * n_params )]
8809
+ ) # core_dims for p0 and bounds
8810
+
8767
8811
popt , pcov = apply_ufunc (
8768
8812
_wrapper ,
8769
8813
da ,
8770
8814
* coords_ ,
8815
+ * param_defaults .values (),
8816
+ * [b [0 ] for b in bounds_defaults .values ()],
8817
+ * [b [1 ] for b in bounds_defaults .values ()],
8771
8818
vectorize = True ,
8772
8819
dask = "parallelized" ,
8773
- input_core_dims = [ reduce_dims_ for d in range ( len ( coords_ ) + 1 )] ,
8820
+ input_core_dims = input_core_dims ,
8774
8821
output_core_dims = [["param" ], ["cov_i" , "cov_j" ]],
8775
8822
dask_gufunc_kwargs = {
8776
8823
"output_sizes" : {
0 commit comments