1
1
import collections
2
+ import typing
2
3
from typing import Any , Optional , Sequence , Union
3
4
4
5
import keras
6
+ from keras .src import backend
5
7
6
8
from keras_rs .src import types
7
9
from keras_rs .src .layers .embedding import distributed_embedding_config
@@ -262,31 +264,74 @@ def populate_placement_to_path_to_input_shape(
262
264
263
265
super ().build (input_shapes )
264
266
265
- def call (
267
+ def preprocess (
266
268
self ,
267
269
inputs : types .Nested [types .Tensor ],
268
270
weights : Optional [types .Nested [types .Tensor ]] = None ,
269
271
training : bool = False ,
270
272
) -> types .Nested [types .Tensor ]:
271
- """Lookup features in embedding tables and apply reduction.
273
+ """Preprocesses and reformats the data for consumption by the model.
274
+
275
+ Calling `preprocess` explicitly is only required to enable `sparsecore`
276
+ placement with the JAX backend and `jit_compile = True`. For all other
277
+ cases and backends, explicit use of `preprocess` is optional.
278
+
279
+ In JAX, sparsecore usage requires specially formatted data that depends
280
+ on properties of the available hardware. This data reformatting
281
+ currently does not support jit-compilation, so must be applied _prior_
282
+ to feeding data into a model.
283
+
284
+ An example usage might look like:
285
+ ```
286
+ # Create the embedding layer.
287
+ embedding_layer = DistributedEmbedding(feature_configs)
288
+
289
+ # Add preprocessing to a data input pipeline.
290
+ def training_dataset_generator():
291
+ for (inputs, weights), labels in iter(training_dataset):
292
+ yield embedding_layer.preprocess(
293
+ inputs, weights, training=True
294
+ ), labels
295
+
296
+ preprocessed_training_dataset = training_dataset_generate()
297
+
298
+ # Construct, compile, and fit the model using the preprocessed data.
299
+ model = keras.Sequential(
300
+ [
301
+ embedding_layer,
302
+ keras.layers.Dense(2),
303
+ keras.layers.Dense(3),
304
+ keras.layers.Dense(4),
305
+ ]
306
+ )
307
+ model.compile(optimizer="adam", loss="mse", jit_compile=True)
308
+ model.fit(preprocessed_training_dataset, epochs=10)
309
+ ```
310
+
311
+ For non-JAX backends, preprocessing will bundle together the inputs and
312
+ optional weights, and separate the inputs by device placement.
272
313
273
314
Args:
274
- inputs: A nested structure of 2D tensors to embed and reduce. The
275
- structure must be the same as the `feature_configs` passed
276
- during construction.
277
- weights: An optional nested structure of 2D tensors of weights to
278
- apply before reduction. When present, the structure must be the
279
- same as `inputs` and the shapes must match.
315
+ inputs: Ragged or dense set of sample IDs.
316
+ weights: Optional ragged or dense set of sample weights.
317
+ training: If true, will update internal parameters, such as
318
+ required buffer sizes for the preprocessed data.
280
319
281
320
Returns:
282
- A nested structure of dense 2D tensors, which are the reduced
283
- embeddings from the passed features. The structure is the same as
284
- `inputs`.
321
+ Set of preprocessed inputs that can be fed directly into the
322
+ `inputs` argument of the layer.
285
323
"""
286
-
287
324
# Verify input structure.
288
325
keras .tree .assert_same_structure (self ._feature_configs , inputs )
289
326
327
+ if not self .built :
328
+ input_shapes = keras .tree .map_structure_up_to (
329
+ self ._feature_configs ,
330
+ lambda array : backend .standardize_shape (array .shape ),
331
+ inputs ,
332
+ )
333
+ self .build (input_shapes )
334
+
290
335
# Go from deeply nested structure of inputs to flat inputs.
291
336
flat_inputs = keras .tree .flatten (inputs )
292
337
@@ -308,22 +353,98 @@ def call(
308
353
k : None for k in placement_to_path_to_inputs
309
354
}
310
355
356
+ placement_to_path_to_preprocessed : dict [
357
+ str , dict [str , types .Nested [types .Tensor ]]
358
+ ] = {}
359
+
360
+ # Preprocess for features placed on "sparsecore".
361
+ if "sparsecore" in placement_to_path_to_inputs :
362
+ placement_to_path_to_preprocessed ["sparsecore" ] = (
363
+ self ._sparsecore_preprocess (
364
+ placement_to_path_to_inputs ["sparsecore" ],
365
+ placement_to_path_to_weights ["sparsecore" ],
366
+ training ,
367
+ )
368
+ )
369
+
370
+ # Preprocess for features placed on "default_device".
371
+ if "default_device" in placement_to_path_to_inputs :
372
+ placement_to_path_to_preprocessed ["default_device" ] = (
373
+ self ._default_device_preprocess (
374
+ placement_to_path_to_inputs ["default_device" ],
375
+ placement_to_path_to_weights ["default_device" ],
376
+ training ,
377
+ )
378
+ )
379
+
380
+ # Mark inputs as preprocessed using an extra level of nesting.
381
+ # This is necessary to detect whether inputs are already preprocessed
382
+ # in `call`.
383
+ output = {
384
+ "preprocessed_inputs_per_placement" : (
385
+ placement_to_path_to_preprocessed
386
+ )
387
+ }
388
+ return output
389
+
390
+ def call (
391
+ self ,
392
+ inputs : types .Nested [types .Tensor ],
393
+ weights : Optional [types .Nested [types .Tensor ]] = None ,
394
+ training : bool = False ,
395
+ ) -> types .Nested [types .Tensor ]:
396
+ """Lookup features in embedding tables and apply reduction.
397
+
398
+ Args:
399
+ inputs: A nested structure of 2D tensors to embed and reduce. The
400
+ structure must be the same as the `feature_configs` passed
401
+ during construction. Alternatively, may consist of already
402
+ preprocessed inputs (see `preprocess`).
403
+ weights: An optional nested structure of 2D tensors of weights to
404
+ apply before reduction. When present, the structure must be the
405
+ same as `inputs` and the shapes must match.
406
+ training: Whether we are training or evaluating the model.
407
+
408
+ Returns:
409
+ A nested structure of dense 2D tensors, which are the reduced
410
+ embeddings from the passed features. The structure is the same as
411
+ `inputs`.
412
+ """
413
+ preprocessed_inputs = inputs
414
+ # Preprocess if not already done.
415
+ if (
416
+ not isinstance (inputs , dict )
417
+ or "preprocessed_inputs_per_placement" not in inputs
418
+ ):
419
+ preprocessed_inputs = self .preprocess (inputs , weights , training )
420
+
421
+ preprocessed_inputs = typing .cast (
422
+ dict [str , dict [str , dict [str , types .Tensor ]]], preprocessed_inputs
423
+ )
424
+ # Placement -> path -> preprocessed inputs.
425
+ preprocessed_inputs = preprocessed_inputs [
426
+ "preprocessed_inputs_per_placement"
427
+ ]
428
+
311
429
placement_to_path_to_outputs = {}
312
430
313
431
# Call for features placed on "sparsecore".
314
- if "sparsecore" in placement_to_path_to_inputs :
432
+ if "sparsecore" in preprocessed_inputs :
433
+ paths_to_inputs_and_weights = preprocessed_inputs ["sparsecore" ]
315
434
placement_to_path_to_outputs ["sparsecore" ] = self ._sparsecore_call (
316
- placement_to_path_to_inputs [ "sparsecore " ],
317
- placement_to_path_to_weights [ "sparsecore" ] ,
435
+ paths_to_inputs_and_weights [ "inputs " ],
436
+ paths_to_inputs_and_weights . get ( "weights" , None ) ,
318
437
training ,
319
438
)
320
439
321
440
# Call for features placed on "default_device".
322
- if "default_device" in placement_to_path_to_inputs :
441
+ if "default_device" in preprocessed_inputs :
442
+ paths_to_inputs_and_weights = preprocessed_inputs ["default_device" ]
323
443
placement_to_path_to_outputs ["default_device" ] = (
324
444
self ._default_device_call (
325
- placement_to_path_to_inputs ["default_device" ],
326
- placement_to_path_to_weights ["default_device" ],
445
+ paths_to_inputs_and_weights ["inputs" ],
446
+ paths_to_inputs_and_weights .get ("weights" , None ),
447
+ training ,
327
448
)
328
449
)
329
450
@@ -389,6 +510,19 @@ def _default_device_build(
389
510
if not embedding_layer .built :
390
511
embedding_layer .build (input_shape )
391
512
513
+ def _default_device_preprocess (
514
+ self ,
515
+ inputs : dict [str , types .Tensor ],
516
+ weights : Optional [dict [str , types .Tensor ]],
517
+ training : bool = False ,
518
+ ) -> dict [str , types .Tensor ]:
519
+ del training
520
+ output : dict [str , types .Tensor ] = {"inputs" : inputs }
521
+ if weights is not None :
522
+ output ["weights" ] = weights
523
+
524
+ return output
525
+
392
526
def _default_device_call (
393
527
self ,
394
528
inputs : dict [str , types .Tensor ],
@@ -434,6 +568,19 @@ def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
434
568
del input_shapes
435
569
raise self ._unsupported_placement_error ("sparsecore" )
436
570
571
+ def _sparsecore_preprocess (
572
+ self ,
573
+ inputs : dict [str , types .Tensor ],
574
+ weights : Optional [dict [str , types .Tensor ]],
575
+ training : bool = False ,
576
+ ) -> dict [str , types .Tensor ]:
577
+ del training
578
+ output : dict [str , types .Tensor ] = {"inputs" : inputs }
579
+ if weights is not None :
580
+ output ["weights" ] = weights
581
+
582
+ return output
583
+
437
584
def _sparsecore_call (
438
585
self ,
439
586
inputs : dict [str , types .Tensor ],
0 commit comments