Skip to content

✅: jax #55

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 85 additions & 15 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -98,38 +98,108 @@ jobs:
echo "major=$major" >> $GITHUB_OUTPUT
echo "minor=$minor" >> $GITHUB_OUTPUT

- name: install deps
run: |
uv sync --no-editable --group=mypy
uv pip install numpy==${{ matrix.numpy-version }}

# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
- name: mypy
- name: collect test files
id: collect-files
run: |
major="${{ steps.numpy-version.outputs.major }}"
minor="${{ steps.numpy-version.outputs.minor }}"

# Directory containing versioned test files
prefix="tests/integration"
files=""

# Find all test files matching the current major version
for path in $(find "$prefix" -name "test_numpy${major}p*.pyi"); do
# Extract file name
while IFS= read -r -d '' path; do
fname=$(basename "$path")
# Parse the minor version from the filename
fminor=$(echo "$fname" | sed -E "s/test_numpy${major}p([0-9]+)\.pyi/\1/")
# Include files where minor version ≤ NumPy's minor
if [ "$fminor" -le "$minor" ]; then
files="$files $path"
fi
done
done < <(find "$prefix" -name "test_numpy${major}p*.pyi" -print0)

files="${files# }"
echo "files=$files" >> "$GITHUB_OUTPUT"

# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
- name: mypy
run: |
uv sync --no-editable --group=test_numpy --group=mypy
uv pip install numpy==${{ matrix.numpy-version }}
uv run --no-sync --active \
mypy --tb --no-incremental --cache-dir=/dev/null \
${{ steps.collect-files.outputs.files }}

- name: basedmypy
run: |
uv sync --no-editable --group=test_numpy --group=basedmypy
uv pip install numpy==${{ matrix.numpy-version }}
uv run --no-sync --active \
mypy --tb --no-incremental --cache-dir=/dev/null \
$files
${{ steps.collect-files.outputs.files }}

- name: pyright
run: |
uv sync --no-editable --group=test_numpy --group=pyright
uv pip install numpy==${{ matrix.numpy-version }}
uv run --no-sync --active \
pyright ${{ steps.collect-files.outputs.files }}

# TODO: (based)pyright

test_integration_jax:
name: integration tests (jax)
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
jax-version: ["0.6.2", "0.7.0"]

steps:
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0

- uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0
with:
python-version: "3.11"
activate-environment: true

- name: get major.minor jax version
id: jax-version
run: |
version="${{ matrix.jax-version }}"
major=$(echo "$version" | cut -d. -f1)
minor=$(echo "$version" | cut -d. -f2)

echo "major=$major" >> $GITHUB_OUTPUT
echo "minor=$minor" >> $GITHUB_OUTPUT

- name: collect test files
id: collect-files
run: |
major="${{ steps.jax-version.outputs.major }}"
minor="${{ steps.jax-version.outputs.minor }}"

prefix="tests/integration"
files=""

while IFS= read -r -d '' path; do
fname=$(basename "$path")
fminor=$(echo "$fname" | sed -E "s/test_jax${major}p([0-9]+)\.pyi/\1/")
if [ "$fminor" -le "$minor" ]; then
files="$files $path"
fi
done < <(find "$prefix" -name "test_jax${major}p*.pyi" -print0)

files="${files# }"
echo "files=$files" >> "$GITHUB_OUTPUT"

# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
- name: mypy
run: |
uv sync --no-editable --group=test_jax --group=mypy
uv pip install jax==${{ matrix.jax-version }}
uv run --no-sync --active \
mypy --tb --no-incremental --cache-dir=/dev/null \
${{ steps.collect-files.outputs.files }}

# TODO: basedmypy/(based)pyright

# TODO: integration tests for array-api-strict
# TODO: integration tests for 3rd party libs such as cupy, pytorch, tensorflow, dask, etc.
9 changes: 9 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,15 @@ test_runtime = [
test_numpy = [
"numpy>=1.25",
]
pyright = [
"pyright>=1.1.403",
]
basedmypy = [
"basedmypy>=2.10.0",
]
test-jax = [
"jax>=0.6.2",
]

[tool.hatch]
version.source = "vcs"
Expand Down
54 changes: 54 additions & 0 deletions tests/integration/test_jax0p6.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# mypy: disable-error-code="no-redef, explicit-any"

from types import ModuleType
from typing import Any, TypeAlias

import jax.numpy as jnp
from jax import Array

import array_api_typing as xpt

# DType aliases
F32: TypeAlias = jnp.float32
I32: TypeAlias = jnp.int32

# Define JAX Arrays against which we can test the protocols (JAX's Array type is
# not parametrized by dtype at type level.)
arr: Array
arr_i32: Array
arr_f32: Array
arr_b: Array

# =========================================================
# `xpt.HasArrayNamespace`

# Check assignment
_001: xpt.HasArrayNamespace[ModuleType] = arr
_002: xpt.HasArrayNamespace[ModuleType] = arr_i32
_003: xpt.HasArrayNamespace[ModuleType] = arr_f32
_004: xpt.HasArrayNamespace[ModuleType] = arr_b

# Check `__array_namespace__` method
a_ns: xpt.HasArrayNamespace[ModuleType] = arr
ns: ModuleType = a_ns.__array_namespace__()

# =========================================================
# `xpt.HasDType`

# Check DTypeT_co assignment
_005: xpt.HasDType[Any] = arr
_006: xpt.HasDType[jnp.dtype[I32]] = arr_i32
_007: xpt.HasDType[jnp.dtype[F32]] = arr_f32
_008: xpt.HasDType[jnp.dtype[jnp.bool_]] = arr_b

# =========================================================
# `xpt.Array`

# Check NamespaceT_co assignment
x_ns: xpt.Array[Any, ModuleType] = arr

# Check DTypeT_co assignment
_009: xpt.Array[Any] = arr
_010: xpt.Array[jnp.dtype[I32]] = arr_i32
_011: xpt.Array[jnp.dtype[F32]] = arr_f32
_012: xpt.Array[jnp.dtype[jnp.bool_]] = arr_b
8 changes: 4 additions & 4 deletions tests/integration/test_numpy1p0.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# mypy: disable-error-code="no-redef"
# mypy: disable-error-code="no-redef, explicit-any"

from types import ModuleType
from typing import Any
Expand All @@ -22,8 +22,8 @@ _: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32

# Check `__array_namespace__` method
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = a_ns.__array_namespace__()
has_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = has_ns.__array_namespace__()

# Incorrect values are caught when using `__array_namespace__` and
# backpropagated to the type of `a_ns`
Expand All @@ -43,7 +43,7 @@ _: xpt.HasDType[dtype[Any]] = nparr_f32
# `xpt.Array`

# Check NamespaceT_co assignment
a_ns: xpt.Array[Any, ModuleType] = nparr
_: xpt.Array[Any, ModuleType] = nparr

# Check DTypeT_co assignment
# Note that `np.array_api` uses dtype objects, not dtype classes, so we can't
Expand Down
32 changes: 14 additions & 18 deletions tests/integration/test_numpy2p0.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# mypy: disable-error-code="no-redef"
# mypy: disable-error-code="no-redef, explicit-any"

from types import ModuleType
from typing import Any, TypeAlias
Expand All @@ -22,36 +22,32 @@ nparr_b: npt.NDArray[np.bool_]
# `xpt.HasArrayNamespace`

# Check assignment
_: xpt.HasArrayNamespace[ModuleType] = nparr
_: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_: xpt.HasArrayNamespace[ModuleType] = nparr_f32
_: xpt.HasArrayNamespace[ModuleType] = nparr_b
_001: xpt.HasArrayNamespace[ModuleType] = nparr
_002: xpt.HasArrayNamespace[ModuleType] = nparr_i32
_003: xpt.HasArrayNamespace[ModuleType] = nparr_f32
_004: xpt.HasArrayNamespace[ModuleType] = nparr_b

# Check `__array_namespace__` method
a_ns: xpt.HasArrayNamespace[ModuleType] = nparr
ns: ModuleType = a_ns.__array_namespace__()

# Incorrect values are caught when using `__array_namespace__` and
# backpropagated to the type of `a_ns`
_: xpt.HasArrayNamespace[dict[str, int]] = nparr # not caught

# =========================================================
# `xpt.HasDType`

# Check DTypeT_co assignment
_: xpt.HasDType[Any] = nparr
_: xpt.HasDType[np.dtype[I32]] = nparr_i32
_: xpt.HasDType[np.dtype[F32]] = nparr_f32
_: xpt.HasDType[np.dtype[np.bool_]] = nparr_b
_005: xpt.HasDType[Any] = nparr
_006: xpt.HasDType[np.dtype[I32]] = nparr_i32
_007: xpt.HasDType[np.dtype[F32]] = nparr_f32
_008: xpt.HasDType[np.dtype[np.bool_]] = nparr_b

# =========================================================
# `xpt.Array`

# Check NamespaceT_co assignment
a_ns: xpt.Array[Any, ModuleType] = nparr
x_ns: xpt.Array[Any, ModuleType] = nparr

# Check DTypeT_co assignment
_: xpt.Array[Any] = nparr
_: xpt.Array[np.dtype[I32]] = nparr_i32
_: xpt.Array[np.dtype[F32]] = nparr_f32
_: xpt.Array[np.dtype[np.bool_]] = nparr_b
_009: xpt.Array[Any] = nparr
_010: xpt.Array[np.dtype[I32]] = nparr_i32
_011: xpt.Array[np.dtype[F32]] = nparr_f32
_012: xpt.Array[np.dtype[np.bool_]] = nparr_b
4 changes: 3 additions & 1 deletion tests/integration/test_numpy2p2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ from test_numpy2p0 import nparr

import array_api_typing as xpt

_: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]
# Incorrect values are caught when using `__array_namespace__` and
# backpropagated to the type of `a_ns`
a_ns: xpt.HasArrayNamespace[dict[str, int]] = nparr # type: ignore[assignment]
Loading
Loading