
Commit 091472c

Move preprocess to base distributed embedding class.
Also fixed some minor issues encountered while testing on sparsecores.
1 parent: bdb77cb · commit: 091472c

9 files changed: +904 −220 lines

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 165 additions & 18 deletions
@@ -1,7 +1,9 @@
 import collections
+import typing
 from typing import Any, Optional, Sequence, Union
 
 import keras
+from keras.src import backend
 
 from keras_rs.src import types
 from keras_rs.src.layers.embedding import distributed_embedding_config
@@ -262,31 +264,74 @@ def populate_placement_to_path_to_input_shape(
 
         super().build(input_shapes)
 
-    def call(
+    def preprocess(
         self,
         inputs: types.Nested[types.Tensor],
         weights: Optional[types.Nested[types.Tensor]] = None,
         training: bool = False,
     ) -> types.Nested[types.Tensor]:
-        """Lookup features in embedding tables and apply reduction.
+        """Preprocesses and reformats the data for consumption by the model.
+
+        Calling `preprocess` explicitly is only required to enable `sparsecore`
+        placement with the JAX backend and `jit_compile = True`. For all other
+        cases and backends, explicit use of `preprocess` is optional.
+
+        In JAX, sparsecore usage requires specially formatted data that depends
+        on properties of the available hardware. This data reformatting
+        currently does not support jit-compilation, so it must be applied
+        _prior_ to feeding data into a model.
+
+        An example usage might look like:
+        ```
+        # Create the embedding layer.
+        embedding_layer = DistributedEmbedding(feature_configs)
+
+        # Add preprocessing to a data input pipeline.
+        def training_dataset_generator():
+            for (inputs, weights), labels in iter(training_dataset):
+                yield embedding_layer.preprocess(
+                    inputs, weights, training=True
+                ), labels
+
+        preprocessed_training_dataset = training_dataset_generator()
+
+        # Construct, compile, and fit the model using the preprocessed data.
+        model = keras.Sequential(
+            [
+                embedding_layer,
+                keras.layers.Dense(2),
+                keras.layers.Dense(3),
+                keras.layers.Dense(4),
+            ]
+        )
+        model.compile(optimizer="adam", loss="mse", jit_compile=True)
+        model.fit(preprocessed_training_dataset, epochs=10)
+        ```
+
+        For non-JAX backends, preprocessing will bundle together the inputs and
+        optional weights, and separate the inputs by device placement.
 
         Args:
-            inputs: A nested structure of 2D tensors to embed and reduce. The
-                structure must be the same as the `feature_configs` passed
-                during construction.
-            weights: An optional nested structure of 2D tensors of weights to
-                apply before reduction. When present, the structure must be the
-                same as `inputs` and the shapes must match.
+            inputs: Ragged or dense set of sample IDs.
+            weights: Optional ragged or dense set of sample weights.
+            training: If true, will update internal parameters, such as
+                required buffer sizes for the preprocessed data.
 
         Returns:
-            A nested structure of dense 2D tensors, which are the reduced
-            embeddings from the passed features. The structure is the same as
-            `inputs`.
+            Set of preprocessed inputs that can be fed directly into the
+            `inputs` argument of the layer.
         """
-
         # Verify input structure.
         keras.tree.assert_same_structure(self._feature_configs, inputs)
 
+        if not self.built:
+            input_shapes = keras.tree.map_structure_up_to(
+                self._feature_configs,
+                lambda array: backend.standardize_shape(array.shape),
+                inputs,
+            )
+            self.build(input_shapes)
+
         # Go from deeply nested structure of inputs to flat inputs.
         flat_inputs = keras.tree.flatten(inputs)
 
@@ -308,22 +353,98 @@ def call(
             k: None for k in placement_to_path_to_inputs
         }
 
+        placement_to_path_to_preprocessed: dict[
+            str, dict[str, types.Nested[types.Tensor]]
+        ] = {}
+
+        # Preprocess for features placed on "sparsecore".
+        if "sparsecore" in placement_to_path_to_inputs:
+            placement_to_path_to_preprocessed["sparsecore"] = (
+                self._sparsecore_preprocess(
+                    placement_to_path_to_inputs["sparsecore"],
+                    placement_to_path_to_weights["sparsecore"],
+                    training,
+                )
+            )
+
+        # Preprocess for features placed on "default_device".
+        if "default_device" in placement_to_path_to_inputs:
+            placement_to_path_to_preprocessed["default_device"] = (
+                self._default_device_preprocess(
+                    placement_to_path_to_inputs["default_device"],
+                    placement_to_path_to_weights["default_device"],
+                    training,
+                )
+            )
+
+        # Mark inputs as preprocessed using an extra level of nesting.
+        # This is necessary to detect whether inputs are already preprocessed
+        # in `call`.
+        output = {
+            "preprocessed_inputs_per_placement": (
+                placement_to_path_to_preprocessed
+            )
+        }
+        return output
+
+    def call(
+        self,
+        inputs: types.Nested[types.Tensor],
+        weights: Optional[types.Nested[types.Tensor]] = None,
+        training: bool = False,
+    ) -> types.Nested[types.Tensor]:
+        """Lookup features in embedding tables and apply reduction.
+
+        Args:
+            inputs: A nested structure of 2D tensors to embed and reduce. The
+                structure must be the same as the `feature_configs` passed
+                during construction. Alternatively, may consist of already
+                preprocessed inputs (see `preprocess`).
+            weights: An optional nested structure of 2D tensors of weights to
+                apply before reduction. When present, the structure must be the
+                same as `inputs` and the shapes must match.
+            training: Whether we are training or evaluating the model.
+
+        Returns:
+            A nested structure of dense 2D tensors, which are the reduced
+            embeddings from the passed features. The structure is the same as
+            `inputs`.
+        """
+        preprocessed_inputs = inputs
+        # Preprocess if not already done.
+        if (
+            not isinstance(inputs, dict)
+            or "preprocessed_inputs_per_placement" not in inputs
+        ):
+            preprocessed_inputs = self.preprocess(inputs, weights, training)
+
+        preprocessed_inputs = typing.cast(
+            dict[str, dict[str, dict[str, types.Tensor]]], preprocessed_inputs
+        )
+        # Placement -> path -> preprocessed inputs.
+        preprocessed_inputs = preprocessed_inputs[
+            "preprocessed_inputs_per_placement"
+        ]
+
         placement_to_path_to_outputs = {}
 
         # Call for features placed on "sparsecore".
-        if "sparsecore" in placement_to_path_to_inputs:
+        if "sparsecore" in preprocessed_inputs:
+            paths_to_inputs_and_weights = preprocessed_inputs["sparsecore"]
             placement_to_path_to_outputs["sparsecore"] = self._sparsecore_call(
-                placement_to_path_to_inputs["sparsecore"],
-                placement_to_path_to_weights["sparsecore"],
+                paths_to_inputs_and_weights["inputs"],
+                paths_to_inputs_and_weights.get("weights", None),
                 training,
             )
 
         # Call for features placed on "default_device".
-        if "default_device" in placement_to_path_to_inputs:
+        if "default_device" in preprocessed_inputs:
+            paths_to_inputs_and_weights = preprocessed_inputs["default_device"]
             placement_to_path_to_outputs["default_device"] = (
                 self._default_device_call(
-                    placement_to_path_to_inputs["default_device"],
-                    placement_to_path_to_weights["default_device"],
+                    paths_to_inputs_and_weights["inputs"],
+                    paths_to_inputs_and_weights.get("weights", None),
+                    training,
                 )
             )
 
@@ -389,6 +510,19 @@ def _default_device_build(
             if not embedding_layer.built:
                 embedding_layer.build(input_shape)
 
+    def _default_device_preprocess(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: Optional[dict[str, types.Tensor]],
+        training: bool = False,
+    ) -> dict[str, types.Tensor]:
+        del training
+        output: dict[str, types.Tensor] = {"inputs": inputs}
+        if weights is not None:
+            output["weights"] = weights
+
+        return output
+
     def _default_device_call(
         self,
         inputs: dict[str, types.Tensor],
@@ -434,6 +568,19 @@ def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:
         del input_shapes
         raise self._unsupported_placement_error("sparsecore")
 
+    def _sparsecore_preprocess(
+        self,
+        inputs: dict[str, types.Tensor],
+        weights: Optional[dict[str, types.Tensor]],
+        training: bool = False,
+    ) -> dict[str, types.Tensor]:
+        del training
+        output: dict[str, types.Tensor] = {"inputs": inputs}
+        if weights is not None:
+            output["weights"] = weights
+
+        return output
+
     def _sparsecore_call(
         self,
         inputs: dict[str, types.Tensor],
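
The extra level of nesting that `preprocess` adds is what lets `call` tell raw inputs apart from preprocessed ones. As an illustrative sketch (not code from this commit; the `"clicks"` path and the tensors are placeholders), the returned structure looks like:

```
# Placement -> {"inputs": path -> tensor, "weights": path -> tensor}.
preprocessed = {
    "preprocessed_inputs_per_placement": {
        "default_device": {
            "inputs": {"clicks": id_tensor},
            # "weights" is present only when weights were passed in.
            "weights": {"clicks": weight_tensor},
        },
        # "sparsecore": {...}  # hardware-dependent format under JAX.
    }
}

# `call` accepts either raw inputs or this structure; it checks for the
# "preprocessed_inputs_per_placement" key to avoid preprocessing twice.
outputs = embedding_layer(preprocessed)
```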

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 7 additions & 2 deletions
@@ -11,6 +11,7 @@
 import numpy as np
 import tensorflow as tf
 from absl import flags
+from absl.testing import absltest
 from absl.testing import parameterized
 
 from keras_rs.src import testing
@@ -355,8 +356,8 @@ def preprocess(inputs_and_labels):
         tf_train_dataset = train_dataset.repeat(16)
 
         def train_dataset_generator():
-            for inputs in iter(tf_train_dataset):
-                yield preprocess(inputs)
+            for inputs_and_labels in iter(tf_train_dataset):
+                yield preprocess(inputs_and_labels)
 
         train_dataset = train_dataset_generator()
 
@@ -728,3 +729,7 @@ def test_save_load_model(self):
             keras.tree.flatten(output_before), keras.tree.flatten(output_after)
         ):
             self.assertAllClose(before, after)
+
+
+if __name__ == "__main__":
+    absltest.main()
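
With the new `__main__` guard, the test file can also be executed directly (e.g. `python keras_rs/src/layers/embedding/distributed_embedding_test.py`) rather than only through a test runner.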

keras_rs/src/layers/embedding/jax/config_conversion.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ def key(self) -> Union[jax.Array, None]:
         return None
 
     def __call__(
-        self, key: Any, shape: types.Shape, dtype: types.DType = jnp.float_
+        self, key: Any, shape: Any, dtype: Any = jnp.float_
     ) -> jax.Array:
         # Force use of provided key. The JAX backend for random initializers
         # forwards the `seed` attribute to the underlying JAX random functions.
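
This change loosens `shape` and `dtype` from the `types.Shape`/`types.DType` aliases to `Any`, presumably (an assumption, not stated in the commit) because callers pass shape- and dtype-like values those aliases do not cover. For context, a minimal sketch of the JAX-style `init(key, shape, dtype)` calling convention this `__call__` mirrors, with hypothetical names:

```
from typing import Any

import jax
import jax.numpy as jnp

def normal_init(key: Any, shape: Any, dtype: Any = jnp.float32) -> jax.Array:
    # JAX initializers are invoked as init(key, shape, dtype); `shape` may
    # arrive as a tuple, list, or other shape-like object, hence `Any`.
    return jax.random.normal(key, tuple(shape), dtype)

table = normal_init(jax.random.PRNGKey(0), [8, 16], jnp.float32)
```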
