Skip to content

Commit 1acd558

Browse files
vtavanaVahid Tavanashadvlad-perevezentsev
authored
implement dpnp.piecewise (#2550)
In this PR, `dpnp.piecewise` is implemented. Co-authored-by: Vahid Tavanashad <[email protected]> Co-authored-by: Vladislav Perevezentsev <[email protected]>
1 parent f244f40 commit 1acd558

File tree

9 files changed

+1249
-4
lines changed

9 files changed

+1249
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
* Added implementation of `dpnp.ndarray.data` and `dpnp.ndarray.data.ptr` attributes [#2521](https://github.com/IntelPython/dpnp/pull/2521)
1717
* Added `dpnp.ndarray.__contains__` method [#2534](https://github.com/IntelPython/dpnp/pull/2534)
1818
* Added implementation of `dpnp.linalg.lu_factor` (SciPy-compatible) [#2557](https://github.com/IntelPython/dpnp/pull/2557), [#2565](https://github.com/IntelPython/dpnp/pull/2565)
19+
* Added implementation of `dpnp.piecewise` [#2550](https://github.com/IntelPython/dpnp/pull/2550)
1920

2021
### Changed
2122

dpnp/dpnp_iface_functional.py

Lines changed: 143 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
3737
"""
3838

39+
# pylint: disable=protected-access
3940

4041
from dpctl.tensor._numpy_helper import (
4142
normalize_axis_index,
@@ -44,7 +45,10 @@
4445

4546
import dpnp
4647

47-
__all__ = ["apply_along_axis", "apply_over_axes"]
48+
# pylint: disable=no-name-in-module
49+
from dpnp.dpnp_utils import get_usm_allocations
50+
51+
__all__ = ["apply_along_axis", "apply_over_axes", "piecewise"]
4852

4953

5054
def apply_along_axis(func1d, axis, arr, *args, **kwargs):
@@ -266,3 +270,141 @@ def apply_over_axes(func, a, axes):
266270
)
267271
a = res
268272
return res
273+
274+
275+
def piecewise(x, condlist, funclist):
276+
"""
277+
Evaluate a piecewise-defined function.
278+
279+
Given a set of conditions and corresponding functions, evaluate each
280+
function on the input data wherever its condition is true.
281+
282+
For full documentation refer to :obj:`numpy.piecewise`.
283+
284+
Parameters
285+
----------
286+
x : {dpnp.ndarray, usm_ndarray}
287+
The input domain.
288+
condlist : {sequence of array-like boolean, bool scalars}
289+
Each boolean array/scalar corresponds to a function in `funclist`.
290+
Wherever `condlist[i]` is ``True``, `funclist[i](x)` is used as the
291+
output value.
292+
293+
Each boolean array in `condlist` selects a piece of `x`, and should
294+
therefore be of the same shape as `x`.
295+
296+
The length of `condlist` must correspond to that of `funclist`.
297+
If one extra function is given, i.e. if
298+
``len(funclist) == len(condlist) + 1``, then that extra function
299+
is the default value, used wherever all conditions are ``False``.
300+
funclist : {array-like of scalars}
301+
A constant value is returned wherever corresponding condition of `x`
302+
is ``True``.
303+
304+
Returns
305+
-------
306+
out : dpnp.ndarray
307+
The output is the same shape and type as `x` and is found by
308+
calling the functions in `funclist` on the appropriate portions of `x`,
309+
as defined by the boolean arrays in `condlist`. Portions not covered
310+
by any condition have a default value of ``0``.
311+
312+
Limitations
313+
-----------
314+
Parameters `args` and `kw` are not supported and `funclist` cannot include a
315+
callable functions.
316+
317+
See Also
318+
--------
319+
:obj:`dpnp.choose` : Construct an array from an index array and a set of
320+
arrays to choose from.
321+
:obj:`dpnp.select` : Return an array drawn from elements in `choicelist`,
322+
depending on conditions.
323+
:obj:`dpnp.where` : Return elements from one of two arrays depending
324+
on condition.
325+
326+
Examples
327+
--------
328+
>>> import dpnp as np
329+
330+
Define the signum function, which is -1 for ``x < 0`` and +1 for ``x >= 0``.
331+
332+
>>> x = np.linspace(-2.5, 2.5, 6)
333+
>>> np.piecewise(x, [x < 0, x >= 0], [-1, 1])
334+
array([-1., -1., -1., 1., 1., 1.])
335+
336+
"""
337+
dpnp.check_supported_arrays_type(x)
338+
x_dtype = x.dtype
339+
if dpnp.is_supported_array_type(condlist) and condlist.ndim in [0, 1]:
340+
condlist = [condlist]
341+
elif dpnp.isscalar(condlist) or (
342+
dpnp.isscalar(condlist[0]) and x.ndim != 0
343+
):
344+
# convert scalar to a list of one array
345+
# convert list of scalars to a list of one array
346+
condlist = [
347+
dpnp.full(
348+
x.shape, condlist, usm_type=x.usm_type, sycl_queue=x.sycl_queue
349+
)
350+
]
351+
elif not dpnp.is_supported_array_type(condlist[0]):
352+
# convert list of lists to list of arrays
353+
# convert list of scalars to a list of 0d arrays (for 0d input)
354+
tmp = []
355+
for _, cond in enumerate(condlist):
356+
tmp.append(
357+
dpnp.array(cond, usm_type=x.usm_type, sycl_queue=x.sycl_queue)
358+
)
359+
condlist = tmp
360+
361+
dpnp.check_supported_arrays_type(*condlist)
362+
if dpnp.is_supported_array_type(funclist):
363+
usm_type, exec_q = get_usm_allocations([x, *condlist, funclist])
364+
else:
365+
usm_type, exec_q = get_usm_allocations([x, *condlist])
366+
367+
result = dpnp.empty_like(x, usm_type=usm_type, sycl_queue=exec_q)
368+
369+
condlen = len(condlist)
370+
try:
371+
if isinstance(funclist, str):
372+
raise TypeError("funclist must be a non-string sequence")
373+
funclen = len(funclist)
374+
except TypeError as e:
375+
raise TypeError("funclist must be a sequence of scalars") from e
376+
377+
if condlen == funclen:
378+
# default value is zero
379+
default_value = x_dtype.type(0)
380+
elif condlen + 1 == funclen:
381+
# default value is the last element of funclist
382+
default_value = funclist[-1]
383+
if callable(default_value):
384+
raise NotImplementedError(
385+
"Callable functions are not supported currently"
386+
)
387+
if isinstance(default_value, dpnp.ndarray):
388+
default_value = default_value.astype(x_dtype, copy=False)
389+
else:
390+
default_value = x_dtype.type(default_value)
391+
funclist = funclist[:-1]
392+
else:
393+
raise ValueError(
394+
f"with {condlen} condition(s), either {condlen} or {condlen + 1} "
395+
"functions are expected"
396+
)
397+
398+
for condition, func in zip(condlist, funclist):
399+
if callable(func):
400+
raise NotImplementedError(
401+
"Callable functions are not supported currently"
402+
)
403+
if isinstance(func, dpnp.ndarray):
404+
func = func.astype(x_dtype, copy=False)
405+
else:
406+
func = x_dtype.type(func)
407+
dpnp.where(condition, func, default_value, out=result)
408+
default_value = result
409+
410+
return result

0 commit comments

Comments
 (0)