1
1
import collections
2
+ import typing
2
3
from typing import Any , Optional , Sequence , Union
3
4
4
5
import keras
6
+ from keras .src import backend
5
7
6
8
from keras_rs .src import types
7
9
from keras_rs .src .layers .embedding import distributed_embedding_config
@@ -262,31 +264,75 @@ def populate_placement_to_path_to_input_shape(
262
264
263
265
super ().build (input_shapes )
264
266
265
- def call (
267
+ def preprocess (
266
268
self ,
267
269
inputs : types .Nested [types .Tensor ],
268
270
weights : Optional [types .Nested [types .Tensor ]] = None ,
269
271
training : bool = False ,
270
272
) -> types .Nested [types .Tensor ]:
271
- """Lookup features in embedding tables and apply reduction.
273
+ """Preprocesses and reformats the data for consumption by the model.
274
+
275
+ Calling `preprocess` explicitly is only required to enable `sparsecore`
276
+ placement with the JAX backend and `jit_compile = True`. For all other
277
+ cases and backends, explicit use of `preprocess` is optional.
278
+
279
+ In JAX, sparsecore usage requires specially formatted data that depends
280
+ on properties of the available hardware. This data reformatting
281
+ currently does not support jit-compilation, so must be applied _prior_
282
+ to feeding data into a model.
283
+
284
+ An example usage might look like:
285
+ ```
286
+ # Create the embedding layer.
287
+ embedding_layer = DistributedEmbedding(feature_configs)
288
+
289
+ # Add preprocessing to a data input pipeline.
290
+ def training_dataset_generator():
291
+ for (inputs, weights), labels in iter(training_dataset):
292
+ yield embedding_layer.preprocess(
293
+ inputs, weights, training=True
294
+ ), labels
295
+
296
+ preprocessed_training_dataset = training_dataset_generator()
297
+
298
+ # Construct, compile, and fit the model using the preprocessed data.
299
+ model = keras.Sequential(
300
+ [
301
+ embedding_layer,
302
+ keras.layers.Dense(2),
303
+ keras.layers.Dense(3),
304
+ keras.layers.Dense(4),
305
+ ]
306
+ )
307
+ model.compile(optimizer="adam", loss="mse", jit_compile=True)
308
+ model.fit(preprocessed_training_dataset, epochs=10)
309
+ ```
310
+
311
+ For non-JAX backends, preprocessing will bundle together the inputs and
312
+ weights, and separate the inputs by device placement. This step is
313
+ entirely optional.
272
314
273
315
Args:
274
- inputs: A nested structure of 2D tensors to embed and reduce. The
275
- structure must be the same as the `feature_configs` passed
276
- during construction.
277
- weights: An optional nested structure of 2D tensors of weights to
278
- apply before reduction. When present, the structure must be the
279
- same as `inputs` and the shapes must match.
316
+ inputs: Ragged or dense set of sample IDs.
317
+ weights: Optional ragged or dense set of sample weights.
318
+ training: If true, will update internal parameters, such as
319
+ required buffer sizes for the preprocessed data.
280
320
281
321
Returns:
282
- A nested structure of dense 2D tensors, which are the reduced
283
- embeddings from the passed features. The structure is the same as
284
- `inputs`.
322
+ Set of preprocessed inputs that can be fed directly into the
323
+ `inputs` argument of the layer.
285
324
"""
286
-
287
325
# Verify input structure.
288
326
keras .tree .assert_same_structure (self ._feature_configs , inputs )
289
327
328
+ if not self .built :
329
+ input_shapes = keras .tree .map_structure_up_to (
330
+ self ._feature_configs ,
331
+ lambda array : backend .standardize_shape (array .shape ),
332
+ inputs ,
333
+ )
334
+ self .build (input_shapes )
335
+
290
336
# Go from deeply nested structure of inputs to flat inputs.
291
337
flat_inputs = keras .tree .flatten (inputs )
292
338
@@ -308,22 +354,98 @@ def call(
308
354
k : None for k in placement_to_path_to_inputs
309
355
}
310
356
357
+ placement_to_path_to_preprocessed : dict [
358
+ str , dict [str , types .Nested [types .Tensor ]]
359
+ ] = {}
360
+
361
+ # Preprocess for features placed on "sparsecore".
362
+ if "sparsecore" in placement_to_path_to_inputs :
363
+ placement_to_path_to_preprocessed ["sparsecore" ] = (
364
+ self ._sparsecore_preprocess (
365
+ placement_to_path_to_inputs ["sparsecore" ],
366
+ placement_to_path_to_weights ["sparsecore" ],
367
+ training ,
368
+ )
369
+ )
370
+
371
+ # Preprocess for features placed on "default_device".
372
+ if "default_device" in placement_to_path_to_inputs :
373
+ placement_to_path_to_preprocessed ["default_device" ] = (
374
+ self ._default_device_preprocess (
375
+ placement_to_path_to_inputs ["default_device" ],
376
+ placement_to_path_to_weights ["default_device" ],
377
+ training ,
378
+ )
379
+ )
380
+
381
+ # Mark inputs as preprocessed using an extra level of nesting.
382
+ # This is necessary to detect whether inputs are already preprocessed
383
+ # in `call`.
384
+ output = {
385
+ "preprocessed_inputs_per_placement" : (
386
+ placement_to_path_to_preprocessed
387
+ )
388
+ }
389
+ return output
390
+
391
+ def call (
392
+ self ,
393
+ inputs : types .Nested [types .Tensor ],
394
+ weights : Optional [types .Nested [types .Tensor ]] = None ,
395
+ training : bool = False ,
396
+ ) -> types .Nested [types .Tensor ]:
397
+ """Lookup features in embedding tables and apply reduction.
398
+
399
+ Args:
400
+ inputs: A nested structure of 2D tensors to embed and reduce. The
401
+ structure must be the same as the `feature_configs` passed
402
+ during construction. Alternatively, may consist of already
403
+ preprocessed inputs (see `preprocess`).
404
+ weights: An optional nested structure of 2D tensors of weights to
405
+ apply before reduction. When present, the structure must be the
406
+ same as `inputs` and the shapes must match.
407
+ training: Whether we are training or evaluating the model.
408
+
409
+ Returns:
410
+ A nested structure of dense 2D tensors, which are the reduced
411
+ embeddings from the passed features. The structure is the same as
412
+ `inputs`.
413
+ """
414
+ preprocessed_inputs = inputs
415
+ # Preprocess if not already done.
416
+ if (
417
+ not isinstance (inputs , dict )
418
+ or "preprocessed_inputs_per_placement" not in inputs
419
+ ):
420
+ preprocessed_inputs = self .preprocess (inputs , weights , training )
421
+
422
+ preprocessed_inputs = typing .cast (
423
+ dict [str , dict [str , dict [str , types .Tensor ]]], preprocessed_inputs
424
+ )
425
+ # Placement -> path -> preprocessed inputs.
426
+ preprocessed_inputs = preprocessed_inputs [
427
+ "preprocessed_inputs_per_placement"
428
+ ]
429
+
311
430
placement_to_path_to_outputs = {}
312
431
313
432
# Call for features placed on "sparsecore".
314
- if "sparsecore" in placement_to_path_to_inputs :
433
+ if "sparsecore" in preprocessed_inputs :
434
+ paths_to_inputs_and_weights = preprocessed_inputs ["sparsecore" ]
315
435
placement_to_path_to_outputs ["sparsecore" ] = self ._sparsecore_call (
316
- placement_to_path_to_inputs [ "sparsecore " ],
317
- placement_to_path_to_weights [ "sparsecore" ] ,
436
+ paths_to_inputs_and_weights [ "inputs " ],
437
+ paths_to_inputs_and_weights . get ( "weights" , None ) ,
318
438
training ,
319
439
)
320
440
321
441
# Call for features placed on "default_device".
322
- if "default_device" in placement_to_path_to_inputs :
442
+ if "default_device" in preprocessed_inputs :
443
+ paths_to_inputs_and_weights = preprocessed_inputs ["default_device" ]
323
444
placement_to_path_to_outputs ["default_device" ] = (
324
445
self ._default_device_call (
325
- placement_to_path_to_inputs ["default_device" ],
326
- placement_to_path_to_weights ["default_device" ],
446
+ paths_to_inputs_and_weights ["inputs" ],
447
+ paths_to_inputs_and_weights .get ("weights" , None ),
448
+ training ,
327
449
)
328
450
)
329
451
@@ -389,6 +511,19 @@ def _default_device_build(
389
511
if not embedding_layer .built :
390
512
embedding_layer .build (input_shape )
391
513
514
+ def _default_device_preprocess (
515
+ self ,
516
+ inputs : dict [str , types .Tensor ],
517
+ weights : Optional [dict [str , types .Tensor ]],
518
+ training : bool = False ,
519
+ ) -> dict [str , types .Tensor ]:
520
+ del training
521
+ output : dict [str , types .Tensor ] = {"inputs" : inputs }
522
+ if weights is not None :
523
+ output ["weights" ] = weights
524
+
525
+ return output
526
+
392
527
def _default_device_call (
393
528
self ,
394
529
inputs : dict [str , types .Tensor ],
@@ -434,6 +569,19 @@ def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
434
569
del input_shapes
435
570
raise self ._unsupported_placement_error ("sparsecore" )
436
571
572
+ def _sparsecore_preprocess (
573
+ self ,
574
+ inputs : dict [str , types .Tensor ],
575
+ weights : Optional [dict [str , types .Tensor ]],
576
+ training : bool = False ,
577
+ ) -> dict [str , types .Tensor ]:
578
+ del training
579
+ output : dict [str , types .Tensor ] = {"inputs" : inputs }
580
+ if weights is not None :
581
+ output ["weights" ] = weights
582
+
583
+ return output
584
+
437
585
def _sparsecore_call (
438
586
self ,
439
587
inputs : dict [str , types .Tensor ],
0 commit comments