Commit 59e5123 (parent bdb77cb)

Move preprocess to base distributed embedding class.
Also fixed some minor issues encountered in testing on sparsecores.

10 files changed: +905 −221 lines

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 166 additions & 18 deletions
````diff
@@ -1,7 +1,9 @@
 import collections
+import typing
 from typing import Any, Optional, Sequence, Union
 
 import keras
+from keras.src import backend
 
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import distributed_embedding_config
@@ -262,31 +264,75 @@ def populate_placement_to_path_to_input_shape(
 
         super().build(input_shapes)
 
-    def call(
+    def preprocess(
         self,
         inputs: types.Nested[types.Tensor],
         weights: Optional[types.Nested[types.Tensor]] = None,
         training: bool = False,
     ) -> types.Nested[types.Tensor]:
-        """Lookup features in embedding tables and apply reduction.
+        """Preprocesses and reformats the data for consumption by the model.
+
+        Calling `preprocess` explicitly is only required to enable `sparsecore`
+        placement with the JAX backend and `jit_compile = True`. For all other
+        cases and backends, explicit use of `preprocess` is optional.
+
+        In JAX, sparsecore usage requires specially formatted data that depends
+        on properties of the available hardware. This data reformatting
+        currently does not support jit-compilation, so it must be applied
+        _prior_ to feeding data into a model.
+
+        An example usage might look like:
+        ```
+        # Create the embedding layer.
+        embedding_layer = DistributedEmbedding(feature_configs)
+
+        # Add preprocessing to a data input pipeline.
+        def training_dataset_generator():
+            for (inputs, weights), labels in iter(training_dataset):
+                yield embedding_layer.preprocess(
+                    inputs, weights, training=True
+                ), labels
+
+        preprocessed_training_dataset = training_dataset_generator()
+
+        # Construct, compile, and fit the model using the preprocessed data.
+        model = keras.Sequential(
+            [
+                embedding_layer,
+                keras.layers.Dense(2),
+                keras.layers.Dense(3),
+                keras.layers.Dense(4),
+            ]
+        )
+        model.compile(optimizer="adam", loss="mse", jit_compile=True)
+        model.fit(preprocessed_training_dataset, epochs=10)
+        ```
+
+        For non-JAX backends, preprocessing will bundle together the inputs and
+        weights, and separate the inputs by device placement. This step is
+        entirely optional.
 
         Args:
-            inputs: A nested structure of 2D tensors to embed and reduce. The
-                structure must be the same as the `feature_configs` passed
-                during construction.
-            weights: An optional nested structure of 2D tensors of weights to
-                apply before reduction. When present, the structure must be the
-                same as `inputs` and the shapes must match.
+            inputs: Ragged or dense set of sample IDs.
+            weights: Optional ragged or dense set of sample weights.
+            training: If true, will update internal parameters, such as
+                required buffer sizes for the preprocessed data.
 
         Returns:
-            A nested structure of dense 2D tensors, which are the reduced
-            embeddings from the passed features. The structure is the same as
-            `inputs`.
+            Set of preprocessed inputs that can be fed directly into the
+            `inputs` argument of the layer.
         """
-
         # Verify input structure.
         keras.tree.assert_same_structure(self._feature_configs, inputs)
 
+        if not self.built:
+            input_shapes = keras.tree.map_structure_up_to(
+                self._feature_configs,
+                lambda array: backend.standardize_shape(array.shape),
+                inputs,
+            )
+            self.build(input_shapes)
+
         # Go from deeply nested structure of inputs to flat inputs.
         flat_inputs = keras.tree.flatten(inputs)
 
@@ -308,22 +354,98 @@ def call(
             k: None for k in placement_to_path_to_inputs
         }
 
+        placement_to_path_to_preprocessed: dict[
+            str, dict[str, types.Nested[types.Tensor]]
+        ] = {}
+
+        # Preprocess for features placed on "sparsecore".
+        if "sparsecore" in placement_to_path_to_inputs:
+            placement_to_path_to_preprocessed["sparsecore"] = (
+                self._sparsecore_preprocess(
+                    placement_to_path_to_inputs["sparsecore"],
+                    placement_to_path_to_weights["sparsecore"],
+                    training,
+                )
+            )
+
+        # Preprocess for features placed on "default_device".
+        if "default_device" in placement_to_path_to_inputs:
+            placement_to_path_to_preprocessed["default_device"] = (
+                self._default_device_preprocess(
+                    placement_to_path_to_inputs["default_device"],
+                    placement_to_path_to_weights["default_device"],
+                    training,
+                )
+            )
+
+        # Mark inputs as preprocessed using an extra level of nesting.
+        # This is necessary to detect whether inputs are already preprocessed
+        # in `call`.
+        output = {
+            "preprocessed_inputs_per_placement": (
+                placement_to_path_to_preprocessed
+            )
+        }
+        return output
+
+    def call(
+        self,
+        inputs: types.Nested[types.Tensor],
+        weights: Optional[types.Nested[types.Tensor]] = None,
+        training: bool = False,
+    ) -> types.Nested[types.Tensor]:
+        """Lookup features in embedding tables and apply reduction.
+
+        Args:
+            inputs: A nested structure of 2D tensors to embed and reduce. The
+                structure must be the same as the `feature_configs` passed
+                during construction. Alternatively, may consist of already
+                preprocessed inputs (see `preprocess`).
+            weights: An optional nested structure of 2D tensors of weights to
+                apply before reduction. When present, the structure must be the
+                same as `inputs` and the shapes must match.
+            training: Whether we are training or evaluating the model.
+
+        Returns:
+            A nested structure of dense 2D tensors, which are the reduced
+            embeddings from the passed features. The structure is the same as
+            `inputs`.
+        """
+        preprocessed_inputs = inputs
+        # Preprocess if not already done.
+        if (
+            not isinstance(inputs, dict)
+            or "preprocessed_inputs_per_placement" not in inputs
+        ):
+            preprocessed_inputs = self.preprocess(inputs, weights, training)
+
+        preprocessed_inputs = typing.cast(
+            dict[str, dict[str, dict[str, types.Tensor]]], preprocessed_inputs
+        )
+        # Placement -> path -> preprocessed inputs.
+        preprocessed_inputs = preprocessed_inputs[
+            "preprocessed_inputs_per_placement"
+        ]
+
         placement_to_path_to_outputs = {}
 
         # Call for features placed on "sparsecore".
-        if "sparsecore" in placement_to_path_to_inputs:
+        if "sparsecore" in preprocessed_inputs:
+            paths_to_inputs_and_weights = preprocessed_inputs["sparsecore"]
             placement_to_path_to_outputs["sparsecore"] = self._sparsecore_call(
-                placement_to_path_to_inputs["sparsecore"],
-                placement_to_path_to_weights["sparsecore"],
+                paths_to_inputs_and_weights["inputs"],
+                paths_to_inputs_and_weights.get("weights", None),
                 training,
             )
 
         # Call for features placed on "default_device".
-        if "default_device" in placement_to_path_to_inputs:
+        if "default_device" in preprocessed_inputs:
+            paths_to_inputs_and_weights = preprocessed_inputs["default_device"]
             placement_to_path_to_outputs["default_device"] = (
                 self._default_device_call(
-                    placement_to_path_to_inputs["default_device"],
-                    placement_to_path_to_weights["default_device"],
+                    paths_to_inputs_and_weights["inputs"],
+                    paths_to_inputs_and_weights.get("weights", None),
+                    training,
                 )
             )
 
@@ -389,6 +511,19 @@ def _default_device_build(
         if not embedding_layer.built:
             embedding_layer.build(input_shape)
 
+    def _default_device_preprocess(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: Optional[dict[str, types.Tensor]],
+        training: bool = False,
+    ) -> dict[str, types.Tensor]:
+        del training
+        output: dict[str, types.Tensor] = {"inputs": inputs}
+        if weights is not None:
+            output["weights"] = weights
+
+        return output
+
     def _default_device_call(
         self,
         inputs: dict[str, types.Tensor],
@@ -434,6 +569,19 @@ def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
         del input_shapes
         raise self._unsupported_placement_error("sparsecore")
 
+    def _sparsecore_preprocess(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: Optional[dict[str, types.Tensor]],
+        training: bool = False,
+    ) -> dict[str, types.Tensor]:
+        del training
+        output: dict[str, types.Tensor] = {"inputs": inputs}
+        if weights is not None:
+            output["weights"] = weights
+
+        return output
+
     def _sparsecore_call(
         self,
         inputs: dict[str, types.Tensor],
````
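
The handshake between `preprocess` and `call` hinges on the `"preprocessed_inputs_per_placement"` marker key added above. The sketch below illustrates that round trip with plain dicts and a hypothetical `fake_preprocess`/`fake_call` pair standing in for the real layer methods; the feature path and data are placeholders, not the actual API.

```python
# Sketch of the preprocess/call handshake, with hypothetical placeholder data.
# The structure mirrors the diff above: placement -> {"inputs", "weights"}.
from typing import Any, Optional


def fake_preprocess(
    inputs: dict[str, Any], weights: Optional[dict[str, Any]] = None
) -> dict[str, Any]:
    # Bundle inputs (and optional weights) per placement, then wrap them in
    # the marker dict so `call` can tell they are already preprocessed.
    bundle: dict[str, Any] = {"inputs": inputs}
    if weights is not None:
        bundle["weights"] = weights
    return {"preprocessed_inputs_per_placement": {"default_device": bundle}}


def fake_call(inputs: Any, weights: Any = None) -> Any:
    # Same detection logic as `DistributedEmbedding.call` in the diff:
    # a dict containing the marker key skips re-preprocessing.
    if (
        not isinstance(inputs, dict)
        or "preprocessed_inputs_per_placement" not in inputs
    ):
        inputs = fake_preprocess(inputs, weights)
    per_placement = inputs["preprocessed_inputs_per_placement"]
    bundle = per_placement["default_device"]
    return bundle["inputs"], bundle.get("weights", None)


# Both raw and already-preprocessed inputs take the same path through `call`.
raw = {"movie_id": [[1, 2], [3, 4]]}
assert fake_call(raw) == fake_call(fake_preprocess(raw))
```

This extra level of nesting is what lets the layer accept either form of input without a separate flag or wrapper type.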

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 7 additions & 2 deletions
````diff
@@ -11,6 +11,7 @@
 import numpy as np
 import tensorflow as tf
 from absl import flags
+from absl.testing import absltest
 from absl.testing import parameterized
 
 from keras_rs.src import testing
@@ -355,8 +356,8 @@ def preprocess(inputs_and_labels):
         tf_train_dataset = train_dataset.repeat(16)
 
         def train_dataset_generator():
-            for inputs in iter(tf_train_dataset):
-                yield preprocess(inputs)
+            for inputs_and_labels in iter(tf_train_dataset):
+                yield preprocess(inputs_and_labels)
 
         train_dataset = train_dataset_generator()
 
@@ -728,3 +729,7 @@ def test_save_load_model(self):
             keras.tree.flatten(output_before), keras.tree.flatten(output_after)
         ):
             self.assertAllClose(before, after)
+
+
+if __name__ == "__main__":
+    absltest.main()
````
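
The renamed loop variable makes the pattern explicit: each dataset element is a whole `(inputs, labels)` pair, and it is passed to `preprocess` intact rather than being mistaken for inputs alone. A minimal self-contained sketch of the same generator-wrapping idea, using a trivial doubling step as a stand-in for the test's real embedding preprocessing:

```python
# Sketch of wrapping a tf.data dataset in a preprocessing generator.
import tensorflow as tf

# Toy dataset of (input, label) pairs; placeholder for the real test data.
dataset = tf.data.Dataset.from_tensor_slices(([1, 2, 3], [0, 1, 0]))


def preprocess(inputs_and_labels):
    # Placeholder transform; the test applies embedding preprocessing here.
    inputs, labels = inputs_and_labels
    return inputs * 2, labels


def train_dataset_generator():
    # Yield one preprocessed (inputs, labels) pair per dataset element.
    for inputs_and_labels in iter(dataset):
        yield preprocess(inputs_and_labels)


for inputs, labels in train_dataset_generator():
    print(inputs.numpy(), labels.numpy())
```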

keras_rs/src/layers/embedding/jax/config_conversion.py

Lines changed: 1 addition & 1 deletion
````diff
@@ -36,7 +36,7 @@ def key(self) -> Union[jax.Array, None]:
         return None
 
     def __call__(
-        self, key: Any, shape: types.Shape, dtype: types.DType = jnp.float_
+        self, key: Any, shape: Any, dtype: Any = jnp.float_
     ) -> jax.Array:
         # Force use of provided key. The JAX backend for random initializers
         # forwards the `seed` attribute to the underlying JAX random functions.
````
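
Loosening `shape` and `dtype` to `Any` presumably accommodates whatever shape and dtype representations the Keras backend forwards at build time, which need not match `types.Shape` and `types.DType` exactly. The sketch below shows the call pattern such a wrapper must accept; the class name and internals are illustrative assumptions, not the actual `config_conversion` code, and `jnp.float32` replaces the diff's `jnp.float_` for compatibility with recent JAX releases.

```python
# Illustrative only: a JAX initializer wrapper whose __call__ accepts loosely
# typed shape/dtype arguments, as in the annotation change above.
from typing import Any

import jax
import jax.numpy as jnp


class WrappedInitializer:  # hypothetical stand-in for the real wrapper
    def __init__(self, key: jax.Array):
        self._key = key

    def __call__(
        self, key: Any, shape: Any, dtype: Any = jnp.float32
    ) -> jax.Array:
        # Ignore the forwarded key and force use of the stored one, mirroring
        # the comment in the diff above.
        del key
        return jax.random.uniform(self._key, tuple(shape), dtype=dtype)


init = WrappedInitializer(jax.random.PRNGKey(0))
print(init(None, (2, 3)).shape)  # (2, 3)
```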
