
Commit 14054b4

upgrade to jax 0.4.34 (apple#817)

* upgrade jax to 0.4.34
* add workaround for change to jax cluster autodetection

1 parent 4b559c5 commit 14054b4

File tree

6 files changed: +31 -16 lines changed

  CHANGELOG.md
  axlearn/common/learner_test.py
  axlearn/common/optimizers.py
  axlearn/common/update_transformation_test.py
  axlearn/common/utils_spmd.py
  pyproject.toml

CHANGELOG.md

Lines changed: 5 additions & 0 deletions

@@ -1,5 +1,10 @@
 # Change Log

+## 0.1.4
+
+* Changes
+  * Upgrade Jax from 0.4.33 to 0.4.34.
+
 ## 0.1.3

 * Changes

axlearn/common/learner_test.py

Lines changed: 1 addition & 1 deletion

@@ -1219,7 +1219,7 @@ def test_learner_masking(test_self):
         pre-existing `CompositeLearner` implementation.

         """
-        updates = axlearn.common.update_transformation_test.mock_updates()
+        updates = axlearn.common.update_transformation_test.mock_updates(state_param_none=False)

         param_keys = updates.opt_params.keys()
         state_keys = updates.inplace_updates.keys()

axlearn/common/optimizers.py

Lines changed: 11 additions & 3 deletions

@@ -544,11 +544,13 @@ def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: Nest
         lr_scale = lr**learning_rate_exponent

         param_scales = _weight_decay_scales(params, per_param_scale=per_param_scale)
+        f = lambda g, p, s: g + weight_decay * lr_scale * p.value * s
         updates = jax.tree.map(
-            lambda g, p, s: g + weight_decay * lr_scale * p.value * s,
+            lambda x, y, z: None if x is None else f(x, y, z),
             updates,
             params,
             param_scales,
+            is_leaf=lambda x: x is None,
         )
         if learning_rate_exponent is None:
             updated_state = state
@@ -1882,9 +1884,10 @@ def _smoothed_updates(
         # First compute raw updates.
         raw_updates, pps_tree = _split_update_results(
             jax.tree.map(
-                lambda g, s: _raw_updates(grad=g, pps=s),
+                lambda g, s: None if g is None else _raw_updates(grad=g, pps=s),
                 grads,
                 state.pps,
+                is_leaf=lambda x: x is None,
             )
         )
         # Clip raw updates if necessary.
@@ -1966,7 +1969,12 @@ def _update2(u: Tensor, param: OptParam):
             context.add_summary("weight_decay_rate", weight_decay * schedule_scale)
             return -schedule_scale * updates_with_wd

-        updates2 = jax.tree.map(lambda u, p: _update2(u, param=p), updates, params)
+        updates2 = jax.tree.map(
+            lambda u, p: None if u is None else _update2(u, param=p),
+            updates,
+            params,
+            is_leaf=lambda x: x is None,
+        )
         return updates2, optax.safe_int32_increment(step)

         # Stage 1.
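
For readers skimming the diff: the recurring `is_leaf=lambda x: x is None` pattern is needed because `jax.tree.map` treats `None` as an empty subtree by default, so an updates tree containing `None` entries (masked-out parameters) no longer matches the structure of a params tree that still holds real values at those positions. Treating `None` as a leaf and short-circuiting it in the mapped function restores the intended behavior. A minimal sketch with toy names, not the axlearn code:

import jax

grads = {"w": 2.0, "frozen": None}   # a masked-out update is None
params = {"w": 10.0, "frozen": 5.0}  # params still exist everywhere

# Default behavior: None is an empty subtree, so the two structures
# mismatch and this raises a ValueError.
try:
    jax.tree.map(lambda g, p: g + p, grads, params)
except ValueError:
    pass

# The pattern used in the diff: treat None as a leaf and pass it through.
out = jax.tree.map(
    lambda g, p: None if g is None else g + p,
    grads,
    params,
    is_leaf=lambda x: x is None,
)
assert out == {"w": 12.0, "frozen": None}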

axlearn/common/update_transformation_test.py

Lines changed: 5 additions & 7 deletions

@@ -166,9 +166,11 @@ def mock_params() -> Nested[Tensor]:
     )


-def mock_updates() -> axlearn.common.update_transformation.Updates:
+def mock_updates(state_param_none: bool = True) -> axlearn.common.update_transformation.Updates:
     """Create an updates object with various semi-reasonable values."""
     model_params = mock_params()
+    if state_param_none:
+        model_params["state"] = None
     opt_params = jax.tree.map(
         lambda p: OptParam(
             value=p,
@@ -197,6 +199,7 @@ def test_param_values(self):
         updates = mock_updates()
         actual = updates.param_values()
         expected = mock_params()
+        expected["state"] = None
         chex.assert_trees_all_equal_structs(actual, expected)
         self.assertNestedAllClose(actual, expected)
@@ -218,12 +221,7 @@ def test_param_specs(self):
                 weight_decay_scale=0.1,
             )
         ),
-        state=ParameterSpec(
-            shape=(2,),
-            dtype=jnp.int32,
-            factorization=FactorizationSpec([None]),
-            weight_decay_scale=0.1,
-        ),
+        state=None,
         more_state=ParameterSpec(
             shape=(3,),
             dtype=jnp.int32,
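
Note on the fixture change: once `model_params["state"]` is set to `None`, the `jax.tree.map` that builds `opt_params` needs no special handling, because a single-tree map simply skips `None` (an empty subtree) and carries it through to the output. A standalone illustration with toy values, not the fixture itself:

import jax

model_params = {"weight": 1.0, "state": None, "more_state": 2.0}

# The mapped function is applied only to the two real leaves;
# "state" passes through unchanged as None.
opt_params = jax.tree.map(lambda p: p * 2, model_params)
assert opt_params == {"weight": 2.0, "state": None, "more_state": 4.0}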

axlearn/common/utils_spmd.py

Lines changed: 4 additions & 0 deletions

@@ -87,6 +87,10 @@ def setup(
             num_processes=num_processes,
             process_id=process_id,
         )
+        if jax_backend == "gpu":
+            # jax 0.4.34 introduced a change to cluster auto-detection behavior; supplying
+            # the local_device_ids arg allows us to maintain the expected behavior.
+            init_kwargs["local_device_ids"] = list(range(8))

         jax.distributed.initialize(**init_kwargs)
         _jax_distributed_initialized = True
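
For context, `jax.distributed.initialize` accepts a `local_device_ids` argument; passing it explicitly pins each process to its local GPUs instead of relying on the auto-detection path that changed in jax 0.4.34. A hedged sketch of the surrounding logic (the `init_distributed` helper name is illustrative, mirroring `setup()` above; the count of 8 devices per process comes from the diff):

import jax

def init_distributed(jax_backend: str, coordinator_address: str, num_processes: int, process_id: int):
    # Hypothetical helper mirroring utils_spmd.setup().
    init_kwargs = dict(
        coordinator_address=coordinator_address,
        num_processes=num_processes,
        process_id=process_id,
    )
    if jax_backend == "gpu":
        # Explicit local device ids opt out of the cluster
        # auto-detection behavior that changed in jax 0.4.34.
        init_kwargs["local_device_ids"] = list(range(8))
    jax.distributed.initialize(**init_kwargs)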

pyproject.toml

Lines changed: 5 additions & 5 deletions

@@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

 [project]
 name = "axlearn"
-version = "0.1.3"
+version = "0.1.4"
 description = "AXLearn"
 readme = "README.md"
 requires-python = ">=3.10"
@@ -23,8 +23,8 @@ core = [
     "absl-py==2.1.0",
     "chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25.
     "importlab==0.7", # breaks pytype on 0.8
-    "jax==0.4.33",
-    "jaxlib==0.4.33",
+    "jax==0.4.34",
+    "jaxlib==0.4.34",
     "nltk==3.7", # for text preprocessing
     "optax==0.1.7", # optimizers (0.1.0 has known bugs).
     "portpicker",
@@ -101,7 +101,7 @@ gcp = [
 # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install.
 tpu = [
     "axlearn[gcp]",
-    "jax[tpu]==0.4.33", # must be >=0.4.19 for compat with v5p.
+    "jax[tpu]==0.4.34", # must be >=0.4.19 for compat with v5p.
 ]
 # Vertex AI tensorboard. TODO(markblee): Merge with `gcp`.
 vertexai_tensorboard = [
@@ -125,7 +125,7 @@ dataflow = [
 # GPU custom kernel dependency.
 gpu = [
     "triton==2.1.0",
-    "jax[cuda12_pip]==0.4.33",
+    "jax[cuda12]==0.4.34",
 ]
 # Open API inference.
 open_api = [
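
(The `cuda12_pip` extra was replaced by the plain `cuda12` extra in recent jax releases, hence the rename in the `gpu` dependency group alongside the version bump.)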
