Skip to content

Commit cb1cbed

Browse files
committed
✅: jax
Signed-off-by: nstarman <[email protected]>
1 parent d4078d7 commit cb1cbed

File tree

4 files changed

+412
-5
lines changed

4 files changed

+412
-5
lines changed

.github/workflows/ci.yml

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,28 +121,85 @@ jobs:
121121
# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
122122
- name: mypy
123123
run: |
124-
uv sync --no-editable --group=mypy
124+
uv sync --no-editable --group=test_numpy --group=mypy
125125
uv pip install numpy==${{ matrix.numpy-version }}
126126
uv run --no-sync --active \
127127
mypy --tb --no-incremental --cache-dir=/dev/null \
128128
${{ steps.collect-files.outputs.files }}
129129
130130
- name: basedmypy
131131
run: |
132-
uv sync --no-editable --group=basedmypy
132+
uv sync --no-editable --group=test_numpy --group=basedmypy
133133
uv pip install numpy==${{ matrix.numpy-version }}
134134
uv run --no-sync --active \
135135
mypy --tb --no-incremental --cache-dir=/dev/null \
136136
${{ steps.collect-files.outputs.files }}
137137
138138
- name: pyright
139139
run: |
140-
uv sync --no-editable --group=pyright
140+
uv sync --no-editable --group=test_numpy --group=pyright
141141
uv pip install numpy==${{ matrix.numpy-version }}
142142
uv run --no-sync --active \
143143
pyright ${{ steps.collect-files.outputs.files }}
144144
145145
# TODO: (based)pyright
146146

147+
test_integration_jax:
148+
name: integration tests (jax)
149+
runs-on: ubuntu-latest
150+
strategy:
151+
fail-fast: false
152+
matrix:
153+
jax-version: ["0.7.0"]
154+
155+
steps:
156+
- uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
157+
158+
- uses: astral-sh/setup-uv@d9e0f98d3fc6adb07d1e3d37f3043649ddad06a1 # v6.5.0
159+
with:
160+
python-version: "3.11"
161+
activate-environment: true
162+
163+
- name: get major.minor jax version
164+
id: jax-version
165+
run: |
166+
version="${{ matrix.jax-version }}"
167+
major=$(echo "$version" | cut -d. -f1)
168+
minor=$(echo "$version" | cut -d. -f2)
169+
170+
echo "major=$major" >> $GITHUB_OUTPUT
171+
echo "minor=$minor" >> $GITHUB_OUTPUT
172+
173+
- name: collect test files
174+
id: collect-files
175+
run: |
176+
major="${{ steps.jax-version.outputs.major }}"
177+
minor="${{ steps.jax-version.outputs.minor }}"
178+
179+
prefix="tests/integration"
180+
files=""
181+
182+
while IFS= read -r -d '' path; do
183+
fname=$(basename "$path")
184+
fminor=$(echo "$fname" | sed -E "s/test_jax${major}p([0-9]+)\.pyi/\1/")
185+
if [ "$fminor" -le "$minor" ]; then
186+
files="$files $path"
187+
fi
188+
done < <(find "$prefix" -name "test_jax${major}p*.pyi" -print0)
189+
190+
files="${files# }"
191+
echo "files=$files" >> "$GITHUB_OUTPUT"
192+
193+
# NOTE: `uv run --with=...` will be ignored by mypy (and `--isolated` does not help)
194+
- name: mypy
195+
run: |
196+
uv sync --no-editable --group=test_jax --group=mypy
197+
uv pip install jax==${{ matrix.jax-version }}
198+
uv run --no-sync --active \
199+
mypy --tb --no-incremental --cache-dir=/dev/null \
200+
${{ steps.collect-files.outputs.files }}
201+
202+
# TODO: basedmypy/(based)pyright
203+
147204
# TODO: integration tests for array-api-strict
148205
# TODO: integration tests for 3rd party libs such as cupy, pytorch, tensorflow, dask, etc.

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ pyright = [
6868
basedmypy = [
6969
"basedmypy>=2.10.0",
7070
]
71+
test-jax = [
72+
"jax>=0.6.2",
73+
]
7174

7275
[tool.hatch]
7376
version.source = "vcs"

tests/integration/test_jax0p7.pyi

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# mypy: disable-error-code="no-redef, explicit-any"
2+
3+
from types import ModuleType
4+
from typing import Any, TypeAlias
5+
6+
import jax.numpy as jnp
7+
from jax import Array
8+
9+
import array_api_typing as xpt
10+
11+
# DType aliases
12+
F32: TypeAlias = jnp.float32
13+
I32: TypeAlias = jnp.int32
14+
15+
# Define JAX Arrays against which we can test the protocols (JAX's Array type is
16+
# not parametrized by dtype at type level.)
17+
arr: Array
18+
arr_i32: Array
19+
arr_f32: Array
20+
arr_b: Array
21+
22+
# =========================================================
23+
# `xpt.HasArrayNamespace`
24+
25+
# Check assignment
26+
_001: xpt.HasArrayNamespace[ModuleType] = arr
27+
_002: xpt.HasArrayNamespace[ModuleType] = arr_i32
28+
_003: xpt.HasArrayNamespace[ModuleType] = arr_f32
29+
_004: xpt.HasArrayNamespace[ModuleType] = arr_b
30+
31+
# Check `__array_namespace__` method
32+
a_ns: xpt.HasArrayNamespace[ModuleType] = arr
33+
ns: ModuleType = a_ns.__array_namespace__()
34+
35+
# =========================================================
36+
# `xpt.HasDType`
37+
38+
# Check DTypeT_co assignment
39+
_005: xpt.HasDType[Any] = arr
40+
_006: xpt.HasDType[jnp.dtype[I32]] = arr_i32
41+
_007: xpt.HasDType[jnp.dtype[F32]] = arr_f32
42+
_008: xpt.HasDType[jnp.dtype[jnp.bool_]] = arr_b
43+
44+
# =========================================================
45+
# `xpt.Array`
46+
47+
# Check NamespaceT_co assignment
48+
x_ns: xpt.Array[Any, ModuleType] = arr
49+
50+
# Check DTypeT_co assignment
51+
_009: xpt.Array[Any] = arr
52+
_010: xpt.Array[jnp.dtype[I32]] = arr_i32
53+
_011: xpt.Array[jnp.dtype[F32]] = arr_f32
54+
_012: xpt.Array[jnp.dtype[jnp.bool_]] = arr_b

0 commit comments

Comments
 (0)