Skip to content

Commit 6e8aeaf

Browse files
authored
Add chex to dependencies (#151)
1 parent d27cae2 commit 6e8aeaf

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

.github/workflows/nightly.yaml

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ jobs:
3838
run: |
3939
python -m pip install --upgrade pip
4040
python -m pip install pytest pytest-xdist
41-
python -m pip install -U jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets
41+
python -m pip install -U chex jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets
4242
- name: Run tests
4343
run: |
4444
pytest -n auto jax_ai_stack
@@ -147,3 +147,28 @@ jobs:
147147
if: failure() && github.event.pull_request == null
148148
with:
149149
github-token: ${{ secrets.GITHUB_TOKEN }}
150+
151+
chex-nightly:
152+
name: Test with chex nightly
153+
runs-on: ubuntu-latest
154+
timeout-minutes: 10
155+
strategy:
156+
fail-fast: false
157+
steps:
158+
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
159+
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
160+
with:
161+
python-version: 3.12
162+
- name: Install dependencies with chex nightly
163+
run: |
164+
python -m pip install --upgrade pip
165+
python -m pip install .[dev,tfds,grain]
166+
python -m pip install --upgrade 'git+https://github.com/google-deepmind/chex/'
167+
- name: Run tests
168+
run: |
169+
pytest -n auto jax_ai_stack
170+
- name: Notify failed build
171+
uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
172+
if: failure() && github.event.pull_request == null
173+
with:
174+
github-token: ${{ secrets.GITHUB_TOKEN }}

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ together via the integration tests in this repository. Packages include:
4343
- [ml_dtypes](http://github.com/jax-ml/ml_dtypes): NumPy dtype extensions for machine learning.
4444
- [optax](https://github.com/google-deepmind/optax): gradient processing and optimization in JAX.
4545
- [orbax](https://github.com/google/orbax): checkpointing and persistence utilities for JAX.
46+
- [chex](https://github.com/google-deepmind/chex): utilities for writing reliable JAX code.
4647

4748
### Optional packages
4849

jax_ai_stack/tests/test_chex.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2024 The JAX Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import chex
17+
import jax
18+
import jax.numpy as jnp
19+
20+
21+
class ChexTest(unittest.TestCase):
22+
23+
def test_chex_dataclass(self):
24+
@chex.dataclass
25+
class Params:
26+
x: chex.ArrayDevice
27+
y: chex.ArrayDevice
28+
29+
params = Params(
30+
x=jnp.arange(4),
31+
y=jnp.ones(10),
32+
)
33+
34+
updated = jax.tree.map(lambda x: 2.0 * x, params)
35+
36+
chex.assert_trees_all_close(updated.x, jnp.arange(0, 8, 2))
37+
chex.assert_trees_all_close(updated.y, jnp.full(10, fill_value=2.0))
38+
39+
40+
if __name__ == '__main__':
41+
unittest.main()

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ keywords = []
1515

1616
# pip dependencies of the project
1717
dependencies = [
18+
"chex==0.1.88",
1819
"jax==0.4.38",
1920
"flax==0.10.2",
2021
"ml_dtypes==0.4.0",

0 commit comments

Comments
 (0)