File tree 4 files changed +69
-1
lines changed 4 files changed +69
-1
lines changed Original file line number Diff line number Diff line change 38
38
run : |
39
39
python -m pip install --upgrade pip
40
40
python -m pip install pytest pytest-xdist
41
- python -m pip install -U jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets
41
+ python -m pip install -U chex jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets
42
42
- name : Run tests
43
43
run : |
44
44
pytest -n auto jax_ai_stack
@@ -147,3 +147,28 @@ jobs:
147
147
if : failure() && github.event.pull_request == null
148
148
with :
149
149
github-token : ${{ secrets.GITHUB_TOKEN }}
150
+
151
+ chex-nightly :
152
+ name : Test with chex nightly
153
+ runs-on : ubuntu-latest
154
+ timeout-minutes : 10
155
+ strategy :
156
+ fail-fast : false
157
+ steps :
158
+ - uses : actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
159
+ - uses : actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
160
+ with :
161
+ python-version : 3.12
162
+ - name : Install dependencies with chex nightly
163
+ run : |
164
+ python -m pip install --upgrade pip
165
+ python -m pip install .[dev,tfds,grain]
166
+ python -m pip install --upgrade 'git+https://github.com/google-deepmind/chex/'
167
+ - name : Run tests
168
+ run : |
169
+ pytest -n auto jax_ai_stack
170
+ - name : Notify failed build
171
+ uses : jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
172
+ if : failure() && github.event.pull_request == null
173
+ with :
174
+ github-token : ${{ secrets.GITHUB_TOKEN }}
Original file line number Diff line number Diff line change @@ -43,6 +43,7 @@ together via the integration tests in this repository. Packages include:
43
43
- [ ml_dtypes] ( http://github.com/jax-ml/ml_dtypes ) : NumPy dtype extensions for machine learning.
44
44
- [ optax] ( https://github.com/google-deepmind/optax ) : gradient processing and optimization in JAX.
45
45
- [ orbax] ( https://github.com/google/orbax ) : checkpointing and persistence utilities for JAX.
46
+ - [ chex] ( https://github.com/google-deepmind/chex ) : utilities for writing reliable JAX code.
46
47
47
48
### Optional packages
48
49
Original file line number Diff line number Diff line change
1
+ # Copyright 2024 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ import chex
17
+ import jax
18
+ import jax .numpy as jnp
19
+
20
+
21
+ class ChexTest (unittest .TestCase ):
22
+
23
+ def test_chex_dataclass (self ):
24
+ @chex .dataclass
25
+ class Params :
26
+ x : chex .ArrayDevice
27
+ y : chex .ArrayDevice
28
+
29
+ params = Params (
30
+ x = jnp .arange (4 ),
31
+ y = jnp .ones (10 ),
32
+ )
33
+
34
+ updated = jax .tree .map (lambda x : 2.0 * x , params )
35
+
36
+ chex .assert_trees_all_close (updated .x , jnp .arange (0 , 8 , 2 ))
37
+ chex .assert_trees_all_close (updated .y , jnp .full (10 , fill_value = 2.0 ))
38
+
39
+
40
+ if __name__ == '__main__' :
41
+ unittest .main ()
Original file line number Diff line number Diff line change @@ -15,6 +15,7 @@ keywords = []
15
15
16
16
# pip dependencies of the project
17
17
dependencies = [
18
+ " chex==0.1.88" ,
18
19
" jax==0.4.38" ,
19
20
" flax==0.10.2" ,
20
21
" ml_dtypes==0.4.0" ,
You can’t perform that action at this time.
0 commit comments