
Commit 3736aaf

Add documentation for using DistributedEmbedding with JAX. (#111)
1 parent 9f54191 commit 3736aaf

File tree: 2 files changed, +144 −44 lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 138 additions & 35 deletions
@@ -291,6 +291,141 @@ def step(data):
result = strategy.run(step, args=(next(iterator),))

run_loop(iter(dataset))

### Usage with JAX on TPU with SparseCore

#### Setup

To use `DistributedEmbedding` on TPUs with JAX, one must create and set a
Keras `Distribution`.
```python
distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
keras.distribution.set_distribution(distribution)
```

#### Inputs

For JAX, inputs can either be dense tensors or ragged (nested) NumPy
arrays. To enable `jit_compile = True`, one must explicitly call
`layer.preprocess(...)` on the inputs, and then feed the preprocessed
output to the model. See the next section on preprocessing for details.

Ragged input arrays must be ragged in the dimension with index 1. Note that
if weights are passed, each weight tensor must be of the same class as the
inputs for that particular feature and use the exact same ragged row lengths
for ragged tensors. All outputs of `DistributedEmbedding` are dense
tensors.
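For illustration, here is a minimal sketch of matching inputs and weights,
assuming two hypothetical features named `"dense_feature"` and
`"ragged_feature"` in the layer's `feature_configs`:
```python
import numpy as np

# Dense inputs: one fixed-length row of ids per example.
dense_ids = np.array([[0, 1], [2, 3]], dtype=np.int32)
# Dense weights mirror the dense inputs' shape.
dense_weights = np.ones_like(dense_ids, dtype=np.float32)

# Ragged inputs: ragged in dimension 1, so rows may differ in length.
ragged_ids = [
    np.array([4, 5, 6], dtype=np.int32),
    np.array([7], dtype=np.int32),
]
# Ragged weights must use the exact same row lengths as the ragged inputs.
ragged_weights = [
    np.array([1.0, 0.5, 0.25], dtype=np.float32),
    np.array([1.0], dtype=np.float32),
]

inputs = {"dense_feature": dense_ids, "ragged_feature": ragged_ids}
weights = {"dense_feature": dense_weights, "ragged_feature": ragged_weights}
```
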
#### Preprocessing

In JAX, SparseCore usage requires specially formatted data that depends
on properties of the available hardware. This data reformatting
currently does not support jit-compilation, so must be applied _prior_
to passing data into a model.

Preprocessing works on dense or ragged NumPy arrays, or on tensors that are
convertible to dense or ragged NumPy arrays like `tf.RaggedTensor`.

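As a small sketch (assuming an already constructed `embedding_layer` and a
single hypothetical feature named `"feature"` in its `feature_configs`),
preprocessing can be applied directly to a `tf.RaggedTensor`:
```python
import tensorflow as tf

# A ragged batch of ids, ragged in dimension 1.
ragged_ids = tf.ragged.constant([[0, 1, 2], [3]], dtype=tf.int32)

# Reformat the batch ahead of time, outside of any jit-compiled function.
preprocessed = embedding_layer.preprocess({"feature": ragged_ids}, training=False)
```
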
One simple way to add preprocessing is to append the function to an input
pipeline by using a Python generator.
```python
# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)

# Add preprocessing to a data input pipeline.
def train_dataset_generator():
    for (inputs, weights), labels in iter(train_dataset):
        yield embedding_layer.preprocess(
            inputs, weights, training=True
        ), labels

preprocessed_train_dataset = train_dataset_generator()
```
This explicit preprocessing stage combines the input and optional weights,
so the new data can be passed directly into the `inputs` argument of the
layer or model.

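For instance, a single preprocessed batch drawn from the generator above can
be passed straight to the layer (a sketch reusing `embedding_layer` and
`preprocessed_train_dataset` from the previous example):
```python
# Pull one preprocessed batch and call the layer on it directly.
preprocessed_batch, labels = next(preprocessed_train_dataset)
activations = embedding_layer(preprocessed_batch)
```
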
#### Usage in a Keras model

Once the global distribution is set and the input preprocessing pipeline
is defined, model training can proceed as normal. For example:
```python
# Construct, compile, and fit the model using the preprocessed data.
model = keras.Sequential(
    [
        embedding_layer,
        keras.layers.Dense(2),
        keras.layers.Dense(3),
        keras.layers.Dense(4),
    ]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(preprocessed_train_dataset, epochs=10)
```

#### Direct invocation

The `DistributedEmbedding` layer can also be invoked directly. Explicit
preprocessing is required when used with JIT compilation.
```python
# Call the layer directly.
activations = embedding_layer(my_inputs, my_weights)

# Call the layer with JIT compilation and explicitly preprocessed inputs.
embedding_layer_jit = jax.jit(embedding_layer)
preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
activations = embedding_layer_jit(preprocessed_inputs)
```

Similarly, for custom training loops, preprocessing must be applied prior
to passing the data to the JIT-compiled training step.
```python
# Create an optimizer and loss function.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.MeanSquaredError()

def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = loss_fn(y, y_pred)
    return loss, non_trainable_variables

grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)

# Create a JIT-compiled training step.
@jax.jit
def train_step(state, x, y):
    (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    ) = state
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

# Build optimizer variables.
optimizer.build(model.trainable_variables)

# Assemble the training state.
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop.
for (inputs, weights), labels in train_dataset:
    # Explicitly preprocess the data.
    preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
    loss, state = train_step(state, preprocessed_inputs, labels)
```

Args:
@@ -471,41 +606,9 @@ def preprocess(
  ) -> types.Nested[types.Tensor]:
      """Preprocesses and reformats the data for consumption by the model.

-
- Calling `preprocess` explicitly is only required for the JAX backend
- to enable `jit_compile = True`. For all other cases and backends,
- explicit use of `preprocess` is optional.
-
- In JAX, sparsecore usage requires specially formatted data that depends
- on properties of the available hardware. This data reformatting
- currently does not support jit-compilation, so must be applied _prior_
- to feeding data into a model.
-
- An example usage might look like:
- ```python
- # Create the embedding layer.
- embedding_layer = DistributedEmbedding(feature_configs)
-
- # Add preprocessing to a data input pipeline.
- def training_dataset_generator():
-     for (inputs, weights), labels in iter(training_dataset):
-         yield embedding_layer.preprocess(
-             inputs, weights, training=True
-         ), labels
-
- preprocessed_training_dataset = training_dataset_generate()
-
- # Construct, compile, and fit the model using the preprocessed data.
- model = keras.Sequential(
-     [
-         embedding_layer,
-         keras.layers.Dense(2),
-         keras.layers.Dense(3),
-         keras.layers.Dense(4),
-     ]
- )
- model.compile(optimizer="adam", loss="mse", jit_compile=True)
- model.fit(preprocessed_training_dataset, epochs=10)
- ```
+ For the JAX backend, converts the input data to a hardware-dependent
+ format required for use with SparseCores. Calling `preprocess`
+ explicitly is only necessary to enable `jit_compile = True`.

  For non-JAX backends, preprocessing will bundle together the inputs and
  weights, and separate the inputs by device placement. This step is

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 6 additions & 9 deletions
@@ -326,15 +326,12 @@ def test_model_fit(self, input_type, use_weights):
  )

  # Call preprocess on dataset inputs/weights.
- def preprocess(inputs_and_labels):
+ def preprocess(inputs, labels):
      # Extract inputs, weights and labels.
      weights = None
-     inputs, labels = inputs_and_labels
-     labels = keras.tree.map_structure(lambda x: x.numpy(), labels)
      if use_weights:
          inputs, weights = inputs
-     preprocessed = layer.preprocess(inputs, weights, training=True)
-     return preprocessed, labels
+     return layer.preprocess(inputs, weights, training=True), labels

  # Create a dataset generator that applies the preprocess function.
  # We need to create an intermediary tf_dataset to avoid
@@ -343,16 +340,16 @@ def preprocess(inputs_and_labels):
  tf_train_dataset = train_dataset.repeat(16)

  def train_dataset_generator():
-     for inputs_and_labels in iter(tf_train_dataset):
-         yield preprocess(inputs_and_labels)
+     for inputs, labels in iter(tf_train_dataset):
+         yield preprocess(inputs, labels)

  train_dataset = train_dataset_generator()

  tf_test_dataset = test_dataset.repeat(16)

  def test_dataset_generator():
-     for inputs in iter(tf_test_dataset):
-         yield preprocess(inputs)
+     for inputs, labels in iter(tf_test_dataset):
+         yield preprocess(inputs, labels)

  test_dataset = test_dataset_generator()
  model = keras.Sequential([layer])
