-
Notifications
You must be signed in to change notification settings - Fork 37
TYP: Type annotations, part 4 #313
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
362c48a
ad375dc
49f9ba7
4371506
c724a52
14f70af
0a571bc
0172300
8711041
014e20f
7c5408c
924fc3d
5d98aa8
8eb647f
a06d51f
247ee6d
85fce08
983296f
d81b3aa
2954efd
c244872
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -12,56 +12,51 @@ | |||||
import math | ||||||
import sys | ||||||
import warnings | ||||||
from collections.abc import Collection | ||||||
from types import NoneType | ||||||
from typing import ( | ||||||
TYPE_CHECKING, | ||||||
Any, | ||||||
Final, | ||||||
Literal, | ||||||
SupportsIndex, | ||||||
TypeAlias, | ||||||
TypeGuard, | ||||||
TypeVar, | ||||||
cast, | ||||||
overload, | ||||||
) | ||||||
|
||||||
from ._typing import Array, Device, HasShape, Namespace, SupportsArrayNamespace | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
|
||||||
import cupy as cp | ||||||
import dask.array as da | ||||||
import jax | ||||||
import ndonnx as ndx | ||||||
import numpy as np | ||||||
import numpy.typing as npt | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
import torch | ||||||
|
||||||
# TODO: import from typing (requires Python >=3.13) | ||||||
from typing_extensions import TypeIs, TypeVar | ||||||
|
||||||
_SizeT = TypeVar("_SizeT", bound = int | None) | ||||||
from typing_extensions import TypeIs | ||||||
|
||||||
_ZeroGradientArray: TypeAlias = npt.NDArray[np.void] | ||||||
_CupyArray: TypeAlias = Any # cupy has no py.typed | ||||||
|
||||||
_ArrayApiObj: TypeAlias = ( | ||||||
npt.NDArray[Any] | ||||||
| cp.ndarray | ||||||
| da.Array | ||||||
| jax.Array | ||||||
| ndx.Array | ||||||
| sparse.SparseArray | ||||||
| torch.Tensor | ||||||
| SupportsArrayNamespace[Any] | ||||||
| _CupyArray | ||||||
| SupportsArrayNamespace | ||||||
) | ||||||
|
||||||
_API_VERSIONS_OLD: Final = frozenset({"2021.12", "2022.12", "2023.12"}) | ||||||
_API_VERSIONS: Final = _API_VERSIONS_OLD | frozenset({"2024.12"}) | ||||||
|
||||||
|
||||||
def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | ||||||
def _is_jax_zero_gradient_array(x: object) -> TypeIs[_ZeroGradientArray]: | ||||||
crusaderky marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
"""Return True if `x` is a zero-gradient array. | ||||||
|
||||||
These arrays are a design quirk of Jax that may one day be removed. | ||||||
|
@@ -80,7 +75,7 @@ def _is_jax_zero_gradient_array(x: object) -> TypeGuard[_ZeroGradientArray]: | |||||
) | ||||||
|
||||||
|
||||||
def is_numpy_array(x: object) -> TypeGuard[npt.NDArray[Any]]: | ||||||
def is_numpy_array(x: object) -> TypeIs[npt.NDArray[Any]]: | ||||||
""" | ||||||
Return True if `x` is a NumPy array. | ||||||
|
||||||
|
@@ -137,7 +132,7 @@ def is_cupy_array(x: object) -> bool: | |||||
if "cupy" not in sys.modules: | ||||||
return False | ||||||
|
||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
# TODO: Should we reject ndarray subclasses? | ||||||
return isinstance(x, cp.ndarray) # pyright: ignore[reportUnknownMemberType] | ||||||
|
@@ -280,13 +275,13 @@ def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]: | |||||
if "sparse" not in sys.modules: | ||||||
return False | ||||||
|
||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
|
||||||
# TODO: Account for other backends. | ||||||
return isinstance(x, sparse.SparseArray) | ||||||
|
||||||
|
||||||
def is_array_api_obj(x: object) -> TypeIs[_ArrayApiObj]: # pyright: ignore[reportUnknownParameterType] | ||||||
def is_array_api_obj(x: object) -> TypeGuard[_ArrayApiObj]: | ||||||
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
""" | ||||||
Return True if `x` is an array API compatible array object. | ||||||
|
||||||
|
@@ -587,7 +582,7 @@ def your_function(x, y): | |||||
|
||||||
namespaces.add(cupy_namespace) | ||||||
else: | ||||||
import cupy as cp # pyright: ignore[reportMissingTypeStubs] | ||||||
import cupy as cp | ||||||
|
||||||
namespaces.add(cp) | ||||||
elif is_torch_array(x): | ||||||
|
@@ -624,14 +619,14 @@ def your_function(x, y): | |||||
if hasattr(jax.numpy, "__array_api_version__"): | ||||||
jnp = jax.numpy | ||||||
else: | ||||||
import jax.experimental.array_api as jnp # pyright: ignore[reportMissingImports] | ||||||
import jax.experimental.array_api as jnp # type: ignore[no-redef] | ||||||
namespaces.add(jnp) | ||||||
elif is_pydata_sparse_array(x): | ||||||
if use_compat is True: | ||||||
_check_api_version(api_version) | ||||||
raise ValueError("`sparse` does not have an array-api-compat wrapper") | ||||||
else: | ||||||
import sparse # pyright: ignore[reportMissingTypeStubs] | ||||||
import sparse | ||||||
jorenham marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
# `sparse` is already an array namespace. We do not have a wrapper | ||||||
# submodule for it. | ||||||
namespaces.add(sparse) | ||||||
|
@@ -640,9 +635,9 @@ def your_function(x, y): | |||||
raise ValueError( | ||||||
"The given array does not have an array-api-compat wrapper" | ||||||
) | ||||||
x = cast("SupportsArrayNamespace[Any]", x) | ||||||
x = cast(SupportsArrayNamespace, x) | ||||||
jorenham marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||||||
namespaces.add(x.__array_namespace__(api_version=api_version)) | ||||||
elif isinstance(x, (bool, int, float, complex, type(None))): | ||||||
elif isinstance(x, int | float | complex | NoneType): | ||||||
|
elif isinstance(x, int | float | complex | NoneType): | |
elif x is None or isinstance(x, int | float | complex): |
(I'll spare you the pseudo-philosophical rant this time)
jorenham marked this conversation as resolved.
Show resolved
Hide resolved
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
casual bugfix 🤔 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's actually logically identical to before. But it was convoluted and rightfully the type checker was complaining.
Uh oh!
There was an error while loading. Please reload this page.