Skip to content

ENH: Implement nan_to_num function #398

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
expand_dims
isclose
kron
nan_to_num
nunique
one_hot
pad
Expand Down
3 changes: 2 additions & 1 deletion src/array_api_extra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Extra array functions built on top of the array API standard."""

from ._delegation import isclose, one_hot, pad
from ._delegation import isclose, nan_to_num, one_hot, pad
from ._lib._at import at
from ._lib._funcs import (
apply_where,
Expand Down Expand Up @@ -33,6 +33,7 @@
"isclose",
"kron",
"lazy_apply",
"nan_to_num",
"nunique",
"one_hot",
"pad",
Expand Down
81 changes: 80 additions & 1 deletion src/array_api_extra/_delegation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from ._lib._utils._helpers import asarrays
from ._lib._utils._typing import Array, DType

__all__ = ["isclose", "one_hot", "pad"]
__all__ = ["isclose", "nan_to_num", "one_hot", "pad"]


def isclose(
Expand Down Expand Up @@ -113,6 +113,85 @@ def isclose(
return _funcs.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan, xp=xp)


def nan_to_num(
x: Array | float | complex,
/,
*,
fill_value: int | float = 0.0,
xp: ModuleType | None = None,
) -> Array:
"""
Replace NaN with zero and infinity with large finite numbers (default behaviour).

If `x` is inexact, NaN is replaced by zero or by the user defined value in
`fill_value` keyword, infinity is replaced by the largest finite floating
point values representable by ``x.dtype`` and -infinity is replaced by the
most negative finite floating point values representable by ``x.dtype``.

For complex dtypes, the above is applied to each of the real and
imaginary components of `x` separately.

Parameters
----------
x : array | float | complex
Input data.
fill_value : int | float, optional
Value to be used to fill NaN values. If no value is passed
then NaN values will be replaced with 0.0.
xp : array_namespace, optional
The standard-compatible namespace for `x`. Default: infer.

Returns
-------
array
`x`, with the non-finite values replaced.

See Also
--------
array_api.isnan : Shows which elements are Not a Number (NaN).

Examples
--------
>>> import array_api_extra as xpx
>>> import array_api_strict as xp
>>> xpx.nan_to_num(xp.inf)
1.7976931348623157e+308
>>> xpx.nan_to_num(-xp.inf)
-1.7976931348623157e+308
>>> xpx.nan_to_num(xp.nan)
0.0
>>> x = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
>>> xpx.nan_to_num(x)
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
>>> y = xp.asarray([complex(xp.inf, xp.nan), xp.nan, complex(xp.nan, xp.inf)])
array([ 1.79769313e+308, -1.79769313e+308, 0.00000000e+000, # may vary
-1.28000000e+002, 1.28000000e+002])
>>> xpx.nan_to_num(y)
array([ 1.79769313e+308 +0.00000000e+000j, # may vary
0.00000000e+000 +0.00000000e+000j,
0.00000000e+000 +1.79769313e+308j])
"""
if isinstance(fill_value, complex):
msg = "Complex fill values are not supported."
raise TypeError(msg)

xp = array_namespace(x) if xp is None else xp

# for scalars we want to output an array
y = xp.asarray(x)

if (
is_cupy_namespace(xp)
or is_jax_namespace(xp)
or is_numpy_namespace(xp)
or is_torch_namespace(xp)
):
return xp.nan_to_num(y, nan=fill_value)

return _funcs.nan_to_num(y, fill_value=fill_value, xp=xp)


def one_hot(
x: Array,
/,
Expand Down
41 changes: 41 additions & 0 deletions src/array_api_extra/_lib/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,47 @@ def kron(
return xp.reshape(result, res_shape)


def nan_to_num( # numpydoc ignore=PR01,RT01
x: Array,
/,
fill_value: int | float = 0.0,
*,
xp: ModuleType,
) -> Array:
"""See docstring in `array_api_extra._delegation.py`."""

def perform_replacements( # numpydoc ignore=PR01,RT01
x: Array,
fill_value: int | float,
xp: ModuleType,
) -> Array:
"""Internal function to perform the replacements."""
x = xp.where(xp.isnan(x), fill_value, x)

# convert infinities to finite values
finfo = xp.finfo(x.dtype)
idx_posinf = xp.isinf(x) & ~xp.signbit(x)
idx_neginf = xp.isinf(x) & xp.signbit(x)
x = xp.where(idx_posinf, finfo.max, x)
return xp.where(idx_neginf, finfo.min, x)

if xp.isdtype(x.dtype, "complex floating"):
return perform_replacements(
xp.real(x),
fill_value,
xp,
) + 1j * perform_replacements(
xp.imag(x),
fill_value,
xp,
)

if xp.isdtype(x.dtype, "numeric"):
return perform_replacements(x, fill_value, xp)

return x


def nunique(x: Array, /, *, xp: ModuleType | None = None) -> Array:
"""
Count the number of unique elements in an array.
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,11 @@ def device(
if library == Backend.TORCH_GPU:
return xp.device("cpu")
return get_device(xp.empty(0))


@pytest.fixture
def infinity(library: Backend) -> float:
"""Retrieve the positive infinity value for the given backend."""
if library in (Backend.TORCH, Backend.TORCH_GPU):
return 3.4028235e38
return 1.7976931348623157e308
136 changes: 136 additions & 0 deletions tests/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
expand_dims,
isclose,
kron,
nan_to_num,
nunique,
one_hot,
pad,
Expand All @@ -40,6 +41,7 @@
lazy_xp_function(create_diagonal)
lazy_xp_function(expand_dims)
lazy_xp_function(kron)
lazy_xp_function(nan_to_num)
lazy_xp_function(nunique)
lazy_xp_function(one_hot)
lazy_xp_function(pad)
Expand Down Expand Up @@ -941,6 +943,140 @@ def test_xp(self, xp: ModuleType):
xp_assert_equal(kron(a, b, xp=xp), k)


class TestNanToNum:
def test_bool(self, xp: ModuleType) -> None:
a = xp.asarray([True])
xp_assert_equal(nan_to_num(a, xp=xp), a)

def test_scalar_pos_inf(self, xp: ModuleType, infinity: float) -> None:
a = xp.inf
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(infinity))

def test_scalar_neg_inf(self, xp: ModuleType, infinity: float) -> None:
a = -xp.inf
xp_assert_equal(nan_to_num(a, xp=xp), -xp.asarray(infinity))

def test_scalar_nan(self, xp: ModuleType) -> None:
a = xp.nan
xp_assert_equal(nan_to_num(a, xp=xp), xp.asarray(0.0))

def test_real(self, xp: ModuleType, infinity: float) -> None:
a = xp.asarray([xp.inf, -xp.inf, xp.nan, -128, 128])
xp_assert_equal(
nan_to_num(a, xp=xp),
xp.asarray(
[
infinity,
-infinity,
0.0,
-128,
128,
]
),
)

def test_complex(self, xp: ModuleType, infinity: float) -> None:
a = xp.asarray(
[
complex(xp.inf, xp.nan),
xp.nan,
complex(xp.nan, xp.inf),
]
)
xp_assert_equal(
nan_to_num(a),
xp.asarray([complex(infinity, 0), complex(0, 0), complex(0, infinity)]),
)

def test_empty_array(self, xp: ModuleType) -> None:
a = xp.asarray([], dtype=xp.float32) # forced dtype due to torch
xp_assert_equal(nan_to_num(a, xp=xp), a)
assert xp.isdtype(nan_to_num(a, xp=xp).dtype, xp.float32)

@pytest.mark.parametrize(
("in_vals", "fill_value", "out_vals"),
[
([1, 2, np.nan, 4], 3, [1.0, 2.0, 3.0, 4.0]),
([1, 2, np.nan, 4], 3.0, [1.0, 2.0, 3.0, 4.0]),
(
[
complex(1, 1),
complex(2, 2),
complex(np.nan, 0),
complex(4, 4),
],
3,
[
complex(1.0, 1.0),
complex(2.0, 2.0),
complex(3.0, 0.0),
complex(4.0, 4.0),
],
),
(
[
complex(1, 1),
complex(2, 2),
complex(0, np.nan),
complex(4, 4),
],
3.0,
[
complex(1.0, 1.0),
complex(2.0, 2.0),
complex(0.0, 3.0),
complex(4.0, 4.0),
],
),
(
[
complex(1, 1),
complex(2, 2),
complex(np.nan, np.nan),
complex(4, 4),
],
3.0,
[
complex(1.0, 1.0),
complex(2.0, 2.0),
complex(3.0, 3.0),
complex(4.0, 4.0),
],
),
],
)
def test_fill_value_success(
self,
xp: ModuleType,
in_vals: Array,
fill_value: int | float,
out_vals: Array,
) -> None:
a = xp.asarray(in_vals)
xp_assert_equal(
nan_to_num(a, fill_value=fill_value, xp=xp),
xp.asarray(out_vals),
)

def test_fill_value_failure(self, xp: ModuleType) -> None:
a = xp.asarray(
[
complex(1, 1),
complex(xp.nan, xp.nan),
complex(3, 3),
]
)
with pytest.raises(
TypeError,
match="Complex fill values are not supported",
):
_ = nan_to_num(
a,
fill_value=complex(2, 2), # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
xp=xp,
)


class TestNUnique:
def test_simple(self, xp: ModuleType):
a = xp.asarray([[1, 1], [0, 2], [2, 2]])
Expand Down