 - We implement a fully-stateless `compute_loss_and_updates()` method
 to compute the loss as well as the updated values for the non-trainable
 variables of the model. Internally, it calls `stateless_call()` and
-the built-in `compute_loss()`.
+the built-in `stateless_compute_loss()`.
 - We implement a fully-stateless `train_step()` method to compute current
 metric values (including the loss) as well as updated values for the
 trainable variables, the optimizer variables, and the metric variables.
-
-Note that you can also take into account the `sample_weight` argument by:
-
-- Unpacking the data as `x, y, sample_weight = data`
-- Passing `sample_weight` to `compute_loss()`
-- Passing `sample_weight` alongside `y` and `y_pred`
-  to metrics in `stateless_update_state()`
 """


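Everywhere below, the batch is unpacked with `keras.utils.unpack_x_y_sample_weight`, which normalizes a 1-, 2-, or 3-tuple into `(x, y, sample_weight)`, filling missing entries with `None`. A minimal standalone sketch of that behavior:

```python
import keras

# Missing entries come back as None, so (x,), (x, y) and
# (x, y, sample_weight) batches all unpack the same way.
x, y, sw = keras.utils.unpack_x_y_sample_weight(("inputs",))
print(x, y, sw)  # inputs None None

x, y, sw = keras.utils.unpack_x_y_sample_weight(("inputs", "targets", "weights"))
print(x, y, sw)  # inputs targets weights
```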
@@ -79,8 +72,10 @@ def compute_loss_and_updates(
         self,
         trainable_variables,
         non_trainable_variables,
+        metrics_variables,
         x,
         y,
+        sample_weight,
         training=False,
     ):
         y_pred, non_trainable_variables = self.stateless_call(
@@ -89,8 +84,21 @@ def compute_loss_and_updates(
             x,
             training=training,
         )
-        loss = self.compute_loss(x, y, y_pred)
-        return loss, (y_pred, non_trainable_variables)
+        loss, (
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+        ) = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=training,
+        )
+        return loss, (y_pred, non_trainable_variables, metrics_variables)

     def train_step(self, state, data):
         (
@@ -99,25 +107,24 @@ def train_step(self, state, data):
             optimizer_variables,
             metrics_variables,
         ) = state
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

         # Get the gradient function.
         grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

         # Compute the gradients.
-        (loss, (y_pred, non_trainable_variables)), grads = grad_fn(
+        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
+            sample_weight,
             training=True,
         )

         # Update trainable variables and optimizer variables.
-        (
-            trainable_variables,
-            optimizer_variables,
-        ) = self.optimizer.stateless_apply(
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
             optimizer_variables, grads, trainable_variables
         )

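For reference, `stateless_apply` is the functional counterpart of applying gradients in place: it takes optimizer variable values, gradients, and trainable variable values, and returns updated values for both rather than mutating any state. A small sketch with a standalone SGD optimizer (toy shapes assumed):

```python
import keras
import numpy as np

optimizer = keras.optimizers.SGD(learning_rate=0.1)
var = keras.Variable(np.ones((2,)))
optimizer.build([var])

grads = [np.full((2,), 0.5)]
new_vars, new_opt_vars = optimizer.stateless_apply(
    [v.value for v in optimizer.variables], grads, [var.value]
)
print(new_vars)  # [0.95, 0.95] == var - lr * grad; `var` itself is untouched
```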
@@ -129,10 +136,12 @@ def train_step(self, state, data):
                 len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
             ]
             if metric.name == "loss":
-                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+                this_metric_vars = metric.stateless_update_state(
+                    this_metric_vars, loss, sample_weight=sample_weight
+                )
             else:
                 this_metric_vars = metric.stateless_update_state(
-                    this_metric_vars, y, y_pred
+                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                 )
             logs[metric.name] = metric.stateless_result(this_metric_vars)
             new_metrics_vars += this_metric_vars
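The metric calls above follow the stateless metric API: `stateless_update_state` returns new variable values instead of mutating the metric, and `stateless_result` computes the result from whatever values you hand it. A standalone sketch with a weighted `Mean`:

```python
import keras

metric = keras.metrics.Mean()
metric_vars = [v.value for v in metric.variables]

# Two weighted updates, neither of which touches metric.variables.
metric_vars = metric.stateless_update_state(metric_vars, 1.0, sample_weight=2.0)
metric_vars = metric.stateless_update_state(metric_vars, 4.0, sample_weight=1.0)
print(metric.stateless_result(metric_vars))  # (1*2 + 4*1) / (2+1) = 2.0
```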
@@ -186,6 +195,7 @@ def compute_loss_and_updates(
         non_trainable_variables,
         x,
         y,
+        sample_weight,
         training=False,
     ):
         y_pred, non_trainable_variables = self.stateless_call(
@@ -194,7 +204,7 @@ def compute_loss_and_updates(
             x,
             training=training,
         )
-        loss = self.loss_fn(y, y_pred)
+        loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)
         return loss, (y_pred, non_trainable_variables)

     def train_step(self, state, data):
@@ -204,7 +214,7 @@ def train_step(self, state, data):
             optimizer_variables,
             metrics_variables,
         ) = state
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

         # Get the gradient function.
         grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
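`has_aux=True` tells JAX that `compute_loss_and_updates` returns `(loss, aux)`: only the first element is differentiated, while the auxiliary payload (here, predictions and updated variables) is passed through untouched. A minimal sketch of the mechanism:

```python
import jax
import jax.numpy as jnp

def loss_fn(w, x):
    pred = w * x
    # The aux payload rides along; only the scalar loss is differentiated.
    return jnp.sum(pred**2), {"pred": pred}

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, aux), grad = grad_fn(jnp.array(2.0), jnp.array(3.0))
print(loss, grad)  # 36.0, and d/dw (w*x)^2 = 2*w*x**2 = 36.0
```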
@@ -215,14 +225,12 @@ def train_step(self, state, data):
             non_trainable_variables,
             x,
             y,
+            sample_weight,
             training=True,
         )

         # Update trainable variables and optimizer variables.
-        (
-            trainable_variables,
-            optimizer_variables,
-        ) = self.optimizer.stateless_apply(
+        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
             optimizer_variables, grads, trainable_variables
         )

@@ -231,10 +239,10 @@ def train_step(self, state, data):
         mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

         loss_tracker_vars = self.loss_tracker.stateless_update_state(
-            loss_tracker_vars, loss
+            loss_tracker_vars, loss, sample_weight=sample_weight
         )
         mae_metric_vars = self.mae_metric.stateless_update_state(
-            mae_metric_vars, y, y_pred
+            mae_metric_vars, y, y_pred, sample_weight=sample_weight
        )

         logs = {}
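Note the slicing convention used here and in the generic `train_step` above: `metrics_variables` is a flat list that concatenates each metric's variables in the order the metrics appear, so each consumer takes its chunk and advances by `len(metric.variables)`. A sketch of that invariant (assuming `model` and a flat `metrics_variables` list as in the surrounding code):

```python
# Hypothetical walk over the flat state list; mirrors the slicing above.
offset = 0
for metric in model.metrics:
    chunk = metrics_variables[offset : offset + len(metric.variables)]
    offset += len(metric.variables)  # next metric starts where this one ends
```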
@@ -287,7 +295,7 @@ def metrics(self):
 class CustomModel(keras.Model):
     def test_step(self, state, data):
         # Unpack the data.
-        x, y = data
+        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
         (
             trainable_variables,
             non_trainable_variables,
@@ -301,21 +309,37 @@ def test_step(self, state, data):
             x,
             training=False,
         )
-        loss = self.compute_loss(x, y, y_pred)
+        loss, (
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+        ) = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+            training=False,
+        )

         # Update metrics.
         new_metrics_vars = []
+        logs = {}
         for metric in self.metrics:
             this_metric_vars = metrics_variables[
                 len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
             ]
             if metric.name == "loss":
-                this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+                this_metric_vars = metric.stateless_update_state(
+                    this_metric_vars, loss, sample_weight=sample_weight
+                )
             else:
                 this_metric_vars = metric.stateless_update_state(
-                    this_metric_vars, y, y_pred
+                    this_metric_vars, y, y_pred, sample_weight=sample_weight
                 )
-            logs = metric.stateless_result(this_metric_vars)
+            logs[metric.name] = metric.stateless_result(this_metric_vars)
             new_metrics_vars += this_metric_vars

         # Return metric logs and updated state variables.
@@ -336,7 +360,7 @@ def test_step(self, state, data):
 # Evaluate with our custom test_step
 x = np.random.random((1000, 32))
 y = np.random.random((1000, 1))
-model.evaluate(x, y)
+model.evaluate(x, y, return_dict=True)


 """
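`return_dict=True` makes `evaluate()` return the `logs` dict assembled in `test_step`, keyed by metric name, instead of a positional list of scalars; that is also why the `logs[metric.name]` fix above matters once there is more than one metric. Illustrative only, with made-up values:

```python
scores = model.evaluate(x, y)                     # e.g. [0.27, 0.42] (positional)
results = model.evaluate(x, y, return_dict=True)  # e.g. {"loss": 0.27, "mae": 0.42}
```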