
Commit 5e4fa61

Updates to "Customizing fit() with JAX" guide. (#2188)
The starting point for these changes was to use `stateless_compute_loss` instead of `compute_loss` for consistency, since we state that everything in `train_step` should be stateless.

- This in turn required plumbing `metrics_variables` and `sample_weight` through.
- This required adding `keras.utils.unpack_x_y_sample_weight`.

These changes make the guide slightly longer, but the code is more generic and closer to the Keras default implementation. Also fixed an issue where `test_step` was not returning the metrics correctly.
1 parent f890f49 commit 5e4fa61
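
Since the change hinges on `keras.utils.unpack_x_y_sample_weight`, here is a minimal sketch of that helper's contract (the shapes and variable names below are illustrative, not taken from the commit): it normalizes whatever `fit()`/`evaluate()` yields per batch into an `(x, y, sample_weight)` triple, filling in `None` for missing entries.

```python
import numpy as np
import keras

x = np.zeros((8, 32))
y = np.zeros((8, 1))
sw = np.ones((8,))

# A 2-tuple comes back with sample_weight=None...
_, _, maybe_sw = keras.utils.unpack_x_y_sample_weight((x, y))
assert maybe_sw is None

# ...while a 3-tuple is unpacked into all three slots.
x2, y2, sw2 = keras.utils.unpack_x_y_sample_weight((x, y, sw))
assert sw2 is not None
```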

File tree

3 files changed: +198 -110 lines changed

guides/custom_train_step_in_jax.py

Lines changed: 56 additions & 32 deletions
@@ -60,17 +60,10 @@
 - We implement a fully-stateless `compute_loss_and_updates()` method
 to compute the loss as well as the updated values for the non-trainable
 variables of the model. Internally, it calls `stateless_call()` and
-the built-in `compute_loss()`.
+the built-in `stateless_compute_loss()`.
 - We implement a fully-stateless `train_step()` method to compute current
 metric values (including the loss) as well as updated values for the
 trainable variables, the optimizer variables, and the metric variables.
-
-Note that you can also take into account the `sample_weight` argument by:
-
-- Unpacking the data as `x, y, sample_weight = data`
-- Passing `sample_weight` to `compute_loss()`
-- Passing `sample_weight` alongside `y` and `y_pred`
-to metrics in `stateless_update_state()`
 """
 
 
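
The docstring above commits to a fully-stateless `compute_loss_and_updates()`, which the hunks below differentiate via `jax.value_and_grad(..., has_aux=True)`. As a minimal standalone sketch of that JAX pattern (the toy model and names here are illustrative, not from the guide):

```python
import jax
import jax.numpy as jnp

def loss_and_aux(w, x, y):
    y_pred = x @ w
    loss = jnp.mean((y_pred - y) ** 2)
    # With has_aux=True, only `loss` is differentiated; y_pred rides
    # along as auxiliary output, like the updated non-trainable and
    # metrics variables in the guide's compute_loss_and_updates().
    return loss, y_pred

grad_fn = jax.value_and_grad(loss_and_aux, has_aux=True)
(loss, y_pred), grads = grad_fn(
    jnp.ones((3, 1)), jnp.ones((4, 3)), jnp.zeros((4, 1))
)
```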

@@ -79,8 +72,10 @@ def compute_loss_and_updates(
         self,
         trainable_variables,
         non_trainable_variables,
+        metrics_variables,
         x,
         y,
+        sample_weight,
         training=False,
     ):
         y_pred, non_trainable_variables = self.stateless_call(
@@ -89,8 +84,21 @@ def compute_loss_and_updates(
             x,
             training=training,
         )
-        loss = self.compute_loss(x, y, y_pred)
-        return loss, (y_pred, non_trainable_variables)
+        loss, (
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+        ) = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=training,
+        )
+        return loss, (y_pred, non_trainable_variables, metrics_variables)
 
     def train_step(self, state, data):
         (
@@ -99,25 +107,24 @@ def train_step(self, state, data):
             optimizer_variables,
             metrics_variables,
         ) = state
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
 
         # Get the gradient function.
         grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
 
         # Compute the gradients.
-        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
+        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
+            sample_weight,
             training=True,
         )
 
         # Update trainable variables and optimizer variables.
-        (
-            trainable_variables,
-            optimizer_variables,
-        ) = self.optimizer.stateless_apply(
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
             optimizer_variables, grads, trainable_variables
         )
 
@@ -129,10 +136,12 @@ def train_step(self, state, data):
                 len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
             ]
             if metric.name == "loss":
-                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+                this_metric_vars = metric.stateless_update_state(
+                    this_metric_vars, loss, sample_weight=sample_weight
+                )
             else:
                 this_metric_vars = metric.stateless_update_state(
-                    this_metric_vars, y, y_pred
+                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                 )
             logs[metric.name] = metric.stateless_result(this_metric_vars)
             new_metrics_vars += this_metric_vars
@@ -186,6 +195,7 @@ def compute_loss_and_updates(
         non_trainable_variables,
         x,
         y,
+        sample_weight,
         training=False,
     ):
         y_pred, non_trainable_variables = self.stateless_call(
@@ -194,7 +204,7 @@ def compute_loss_and_updates(
             x,
             training=training,
         )
-        loss = self.loss_fn(y, y_pred)
+        loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)
         return loss, (y_pred, non_trainable_variables)
 
     def train_step(self, state, data):
@@ -204,7 +214,7 @@ def train_step(self, state, data):
             optimizer_variables,
             metrics_variables,
         ) = state
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
 
         # Get the gradient function.
         grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
@@ -215,14 +225,12 @@ def train_step(self, state, data):
             non_trainable_variables,
             x,
             y,
+            sample_weight,
             training=True,
         )
 
         # Update trainable variables and optimizer variables.
-        (
-            trainable_variables,
-            optimizer_variables,
-        ) = self.optimizer.stateless_apply(
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
             optimizer_variables, grads, trainable_variables
         )
 
@@ -231,10 +239,10 @@ def train_step(self, state, data):
         mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]
 
         loss_tracker_vars = self.loss_tracker.stateless_update_state(
-            loss_tracker_vars, loss
+            loss_tracker_vars, loss, sample_weight=sample_weight
         )
         mae_metric_vars = self.mae_metric.stateless_update_state(
-            mae_metric_vars, y, y_pred
+            mae_metric_vars, y, y_pred, sample_weight=sample_weight
         )
 
         logs = {}
@@ -287,7 +295,7 @@ def metrics(self):
 class CustomModel(keras.Model):
     def test_step(self, state, data):
         # Unpack the data.
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
         (
             trainable_variables,
             non_trainable_variables,
@@ -301,21 +309,37 @@ def test_step(self, state, data):
             x,
             training=False,
         )
-        loss = self.compute_loss(x, y, y_pred)
+        loss, (
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+        ) = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=False,
+        )
 
         # Update metrics.
         new_metrics_vars = []
+        logs = {}
         for metric in self.metrics:
             this_metric_vars = metrics_variables[
                 len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
             ]
             if metric.name == "loss":
-                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+                this_metric_vars = metric.stateless_update_state(
+                    this_metric_vars, loss, sample_weight=sample_weight
+                )
             else:
                 this_metric_vars = metric.stateless_update_state(
-                    this_metric_vars, y, y_pred
+                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                 )
-            logs = metric.stateless_result(this_metric_vars)
+            logs[metric.name] = metric.stateless_result(this_metric_vars)
             new_metrics_vars += this_metric_vars
 
         # Return metric logs and updated state variables.
@@ -336,7 +360,7 @@ def test_step(self, state, data):
 # Evaluate with our custom test_step
 x = np.random.random((1000, 32))
 y = np.random.random((1000, 1))
-model.evaluate(x, y)
+model.evaluate(x, y, return_dict=True)
 
 
 """
