|
36 | 36 |
|
37 | 37 | """
|
38 | 38 |
|
| 39 | +# pylint: disable=protected-access |
39 | 40 |
|
40 | 41 | from dpctl.tensor._numpy_helper import (
|
41 | 42 | normalize_axis_index,
|
|
44 | 45 |
|
45 | 46 | import dpnp
|
46 | 47 |
|
47 |
| -__all__ = ["apply_along_axis", "apply_over_axes"] |
| 48 | +# pylint: disable=no-name-in-module |
| 49 | +from dpnp.dpnp_utils import get_usm_allocations |
| 50 | + |
| 51 | +__all__ = ["apply_along_axis", "apply_over_axes", "piecewise"] |
48 | 52 |
|
49 | 53 |
|
50 | 54 | def apply_along_axis(func1d, axis, arr, *args, **kwargs):
|
@@ -266,3 +270,141 @@ def apply_over_axes(func, a, axes):
|
266 | 270 | )
|
267 | 271 | a = res
|
268 | 272 | return res
|
| 273 | + |
| 274 | + |
| 275 | +def piecewise(x, condlist, funclist): |
| 276 | + """ |
| 277 | + Evaluate a piecewise-defined function. |
| 278 | +
|
| 279 | + Given a set of conditions and corresponding functions, evaluate each |
| 280 | + function on the input data wherever its condition is true. |
| 281 | +
|
| 282 | + For full documentation refer to :obj:`numpy.piecewise`. |
| 283 | +
|
| 284 | + Parameters |
| 285 | + ---------- |
| 286 | + x : {dpnp.ndarray, usm_ndarray} |
| 287 | + The input domain. |
| 288 | + condlist : {sequence of array-like boolean, bool scalars} |
| 289 | + Each boolean array/scalar corresponds to a function in `funclist`. |
| 290 | + Wherever `condlist[i]` is ``True``, `funclist[i](x)` is used as the |
| 291 | + output value. |
| 292 | +
|
| 293 | + Each boolean array in `condlist` selects a piece of `x`, and should |
| 294 | + therefore be of the same shape as `x`. |
| 295 | +
|
| 296 | + The length of `condlist` must correspond to that of `funclist`. |
| 297 | + If one extra function is given, i.e. if |
| 298 | + ``len(funclist) == len(condlist) + 1``, then that extra function |
| 299 | + is the default value, used wherever all conditions are ``False``. |
| 300 | + funclist : {array-like of scalars} |
| 301 | + A constant value is returned wherever corresponding condition of `x` |
| 302 | + is ``True``. |
| 303 | +
|
| 304 | + Returns |
| 305 | + ------- |
| 306 | + out : dpnp.ndarray |
| 307 | + The output is the same shape and type as `x` and is found by |
| 308 | + calling the functions in `funclist` on the appropriate portions of `x`, |
| 309 | + as defined by the boolean arrays in `condlist`. Portions not covered |
| 310 | + by any condition have a default value of ``0``. |
| 311 | +
|
| 312 | + Limitations |
| 313 | + ----------- |
| 314 | + Parameters `args` and `kw` are not supported and `funclist` cannot include a |
| 315 | + callable functions. |
| 316 | +
|
| 317 | + See Also |
| 318 | + -------- |
| 319 | + :obj:`dpnp.choose` : Construct an array from an index array and a set of |
| 320 | + arrays to choose from. |
| 321 | + :obj:`dpnp.select` : Return an array drawn from elements in `choicelist`, |
| 322 | + depending on conditions. |
| 323 | + :obj:`dpnp.where` : Return elements from one of two arrays depending |
| 324 | + on condition. |
| 325 | +
|
| 326 | + Examples |
| 327 | + -------- |
| 328 | + >>> import dpnp as np |
| 329 | +
|
| 330 | + Define the signum function, which is -1 for ``x < 0`` and +1 for ``x >= 0``. |
| 331 | +
|
| 332 | + >>> x = np.linspace(-2.5, 2.5, 6) |
| 333 | + >>> np.piecewise(x, [x < 0, x >= 0], [-1, 1]) |
| 334 | + array([-1., -1., -1., 1., 1., 1.]) |
| 335 | +
|
| 336 | + """ |
| 337 | + dpnp.check_supported_arrays_type(x) |
| 338 | + x_dtype = x.dtype |
| 339 | + if dpnp.is_supported_array_type(condlist) and condlist.ndim in [0, 1]: |
| 340 | + condlist = [condlist] |
| 341 | + elif dpnp.isscalar(condlist) or ( |
| 342 | + dpnp.isscalar(condlist[0]) and x.ndim != 0 |
| 343 | + ): |
| 344 | + # convert scalar to a list of one array |
| 345 | + # convert list of scalars to a list of one array |
| 346 | + condlist = [ |
| 347 | + dpnp.full( |
| 348 | + x.shape, condlist, usm_type=x.usm_type, sycl_queue=x.sycl_queue |
| 349 | + ) |
| 350 | + ] |
| 351 | + elif not dpnp.is_supported_array_type(condlist[0]): |
| 352 | + # convert list of lists to list of arrays |
| 353 | + # convert list of scalars to a list of 0d arrays (for 0d input) |
| 354 | + tmp = [] |
| 355 | + for _, cond in enumerate(condlist): |
| 356 | + tmp.append( |
| 357 | + dpnp.array(cond, usm_type=x.usm_type, sycl_queue=x.sycl_queue) |
| 358 | + ) |
| 359 | + condlist = tmp |
| 360 | + |
| 361 | + dpnp.check_supported_arrays_type(*condlist) |
| 362 | + if dpnp.is_supported_array_type(funclist): |
| 363 | + usm_type, exec_q = get_usm_allocations([x, *condlist, funclist]) |
| 364 | + else: |
| 365 | + usm_type, exec_q = get_usm_allocations([x, *condlist]) |
| 366 | + |
| 367 | + result = dpnp.empty_like(x, usm_type=usm_type, sycl_queue=exec_q) |
| 368 | + |
| 369 | + condlen = len(condlist) |
| 370 | + try: |
| 371 | + if isinstance(funclist, str): |
| 372 | + raise TypeError("funclist must be a non-string sequence") |
| 373 | + funclen = len(funclist) |
| 374 | + except TypeError as e: |
| 375 | + raise TypeError("funclist must be a sequence of scalars") from e |
| 376 | + |
| 377 | + if condlen == funclen: |
| 378 | + # default value is zero |
| 379 | + default_value = x_dtype.type(0) |
| 380 | + elif condlen + 1 == funclen: |
| 381 | + # default value is the last element of funclist |
| 382 | + default_value = funclist[-1] |
| 383 | + if callable(default_value): |
| 384 | + raise NotImplementedError( |
| 385 | + "Callable functions are not supported currently" |
| 386 | + ) |
| 387 | + if isinstance(default_value, dpnp.ndarray): |
| 388 | + default_value = default_value.astype(x_dtype, copy=False) |
| 389 | + else: |
| 390 | + default_value = x_dtype.type(default_value) |
| 391 | + funclist = funclist[:-1] |
| 392 | + else: |
| 393 | + raise ValueError( |
| 394 | + f"with {condlen} condition(s), either {condlen} or {condlen + 1} " |
| 395 | + "functions are expected" |
| 396 | + ) |
| 397 | + |
| 398 | + for condition, func in zip(condlist, funclist): |
| 399 | + if callable(func): |
| 400 | + raise NotImplementedError( |
| 401 | + "Callable functions are not supported currently" |
| 402 | + ) |
| 403 | + if isinstance(func, dpnp.ndarray): |
| 404 | + func = func.astype(x_dtype, copy=False) |
| 405 | + else: |
| 406 | + func = x_dtype.type(func) |
| 407 | + dpnp.where(condition, func, default_value, out=result) |
| 408 | + default_value = result |
| 409 | + |
| 410 | + return result |
0 commit comments