Added TF specific documentation to DistributedEmbedding. #94

Merged
merged 1 commit into from
May 21, 2025
149 changes: 144 additions & 5 deletions keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -32,7 +32,25 @@ class DistributedEmbedding(keras.layers.Layer):

---

`DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
and can dramatically improve the speed of embedding lookups and embedding
training. It works by combining multiple lookups into one invocation, and by
sharding the embedding tables across the available chips. Note that one will
only see performance benefits for embedding tables that are large enough to
require sharding because they don't fit on a single chip. More details are
provided in the "Placement" section below.

On other hardware (GPUs, CPUs, and TPUs without SparseCore),
`DistributedEmbedding` provides the same API without any specific
acceleration. No particular distribution scheme is applied besides the one
set via `keras.distribution.set_distribution`.

`DistributedEmbedding` embeds sequences of inputs and reduces them to a
single embedding by applying a configurable combiner function.
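
For intuition, here is a minimal plain-TensorFlow sketch of what a `"mean"`
combiner does conceptually; the table values, shapes, and combiner choice
are hypothetical and only illustrate the reduction:

```python
import tensorflow as tf

# Illustrative only: a tiny embedding table with a vocabulary of 8 ids and
# an embedding dimension of 4 (both hypothetical).
table = tf.random.uniform((8, 4))

# A batch of two examples, each a fixed-length sequence of ids.
ids = tf.constant([[1, 3, 5], [2, 2, 6]])

# Look up every id, then reduce each sequence with a "mean" combiner.
per_id_embeddings = tf.gather(table, ids)             # shape [2, 3, 4]
combined = tf.reduce_mean(per_id_embeddings, axis=1)  # shape [2, 4]
```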

### Configuration

#### Features and tables

A `DistributedEmbedding` embedding layer is configured via a set of
`keras_rs.layers.FeatureConfig` objects, which themselves refer to
@@ -50,11 +68,13 @@ class DistributedEmbedding(keras.layers.Layer):
name="table1",
vocabulary_size=TABLE1_VOCABULARY_SIZE,
embedding_dim=TABLE1_EMBEDDING_SIZE,
placement="auto",
)
table2 = keras_rs.layers.TableConfig(
name="table2",
vocabulary_size=TABLE2_VOCABULARY_SIZE,
embedding_dim=TABLE2_EMBEDDING_SIZE,
placement="auto",
)

feature1 = keras_rs.layers.FeatureConfig(
@@ -78,22 +98,141 @@ class DistributedEmbedding(keras.layers.Layer):
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
```

#### Optimizers

Each embedding table within `DistributedEmbedding` uses its own optimizer
for training, which is independent from the optimizer set on the model via
`model.compile()`.

Note that not all optimizers are supported. Currently, the following are
supported on all backends and accelerators:

- `keras.optimizers.Adagrad`
- `keras.optimizers.SGD`

The following are additionally available when using the TensorFlow backend:

- `keras.optimizers.Adam`
- `keras.optimizers.Ftrl`

Also, not all parameters of the optimizers are supported (e.g. the
`nesterov` option of `SGD`). An error is raised when an unsupported
optimizer or an unsupported optimizer parameter is used.
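
As a minimal sketch, a per-table optimizer could be configured as shown
below. This assumes that `keras_rs.layers.TableConfig` accepts an
`optimizer` argument; the sizes are hypothetical.

```python
import keras
import keras_rs

# Hypothetical sizes, for illustration only.
VOCABULARY_SIZE = 1024
EMBEDDING_DIM = 16

# Sketch: the per-table optimizer is assumed to be passed via an `optimizer`
# argument on `TableConfig`; it is independent of the optimizer passed to
# `model.compile()`.
table = keras_rs.layers.TableConfig(
    name="table",
    vocabulary_size=VOCABULARY_SIZE,
    embedding_dim=EMBEDDING_DIM,
    optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
    placement="auto",
)
```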

#### Placement

Each embedding table within `DistributedEmbedding` can be placed either on
the SparseCore chips or on the accelerator's default device (e.g. the HBM
of the Tensor Cores on TPU). This is controlled by the `placement`
attribute of `keras_rs.layers.TableConfig`.

- A placement of `"sparsecore"` indicates that the table should be placed on
the SparseCore chips. An error is raised if this option is selected and
there are no SparseCore chips.
- A placement of `"default_device"` indicates that the table should not be
placed on SparseCore, even if available. Instead the table is placed on
the device where the model normally goes, i.e. the HBM on TPUs and GPUs.
In this case, if applicable, the table is distributed using the scheme set
via `keras.distribution.set_distribution`. On GPUs, CPUs and TPUs without
SparseCore, this is the only placement available, and is the one selected
by `"auto"`.
- A placement of `"auto"` indicates to use `"sparsecore"` if available, and
`"default_device"` otherwise. This is the default when not specified.

To optimize performance on TPU:

- Tables that are so large that they need to be sharded should use the
`"sparsecore"` placement.
- Tables that are small enough to fit on a single chip should use
`"default_device"` and should typically be replicated across TPUs by using
the `keras.distribution.DataParallel` distribution option, as sketched below.
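
A minimal sketch of these recommendations, assuming the Keras distribution
API is available for the chosen backend; the table sizes are hypothetical:

```python
import keras
import keras_rs

# Replicate the rest of the model (including "default_device" tables)
# across the available devices.
keras.distribution.set_distribution(keras.distribution.DataParallel())

# A large table that does not fit on a single chip: shard it on SparseCore.
large_table = keras_rs.layers.TableConfig(
    name="large_table",
    vocabulary_size=30_000_000,  # hypothetical
    embedding_dim=128,
    placement="sparsecore",
)

# A small table that fits comfortably on a single chip: keep it with the
# rest of the model.
small_table = keras_rs.layers.TableConfig(
    name="small_table",
    vocabulary_size=1_000,  # hypothetical
    embedding_dim=8,
    placement="default_device",
)
```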

### Usage with TensorFlow on TPU with SparseCore

#### Inputs

In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
must be ragged in the dimension with index 1. Note that if weights are
passed, each weight tensor must be of the same class as the inputs for that
particular feature and use the exact same ragged row lengths for ragged
tensors, and the same indices for sparse tensors. All outputs of
`DistributedEmbedding` are dense tensors.
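
For example, a ragged feature and its matching weights might look like the
following; the feature name and values are hypothetical, and the weights
use the exact same row lengths as the ids:

```python
import tensorflow as tf

# Two examples with different numbers of ids: ragged in dimension 1.
inputs = {
    "feature1": tf.ragged.constant([[3, 7, 9], [12]], dtype=tf.int64),
}

# Weights must be of the same class (ragged) and share the same row lengths.
weights = {
    "feature1": tf.ragged.constant([[1.0, 0.5, 0.5], [2.0]], dtype=tf.float32),
}

# outputs = embedding(inputs, weights)  # each output is a dense tensor
```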

#### Setup

To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
`tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
created under the `TPUStrategy`.

```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)

with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
```

#### Usage in a Keras model

To use Keras' `model.fit()`, one must compile the model under the
`TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
can be called directly. The Keras model takes care of running the model
using the strategy and also automatically distributes the dataset.

```python
with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
    model = create_model(embedding)
    model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")

model.fit(dataset, epochs=10)
```
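
For reference, here is a minimal sketch of a compatible `tf.data.Dataset`.
It assumes a single dense feature named `"feature1"`, that the model built
by the hypothetical `create_model` maps that structure of inputs to a
single regression output, and hypothetical sizes throughout:

```python
import numpy as np
import tensorflow as tf

# Hypothetical sizes, for illustration only.
NUM_EXAMPLES = 1024
SEQUENCE_LENGTH = 8
VOCABULARY_SIZE = 1000
BATCH_SIZE = 16

ids = np.random.randint(
    0, VOCABULARY_SIZE, size=(NUM_EXAMPLES, SEQUENCE_LENGTH), dtype=np.int64
)
labels = np.random.random((NUM_EXAMPLES, 1)).astype(np.float32)

# Inputs are keyed like the `feature_configs` structure; labels are dense.
dataset = tf.data.Dataset.from_tensor_slices(({"feature1": ids}, labels))
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
```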

#### Direct invocation

`DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
`tf.function`.

```python
@tf.function
def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
    def strategy_fn(st_fn_inputs, st_fn_weights):
        return embedding(st_fn_inputs, st_fn_weights)

    return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))

embedding_wrapper(my_inputs, my_weights)
```

When using a dataset, the dataset must be distributed. The iterator can then
be passed to the `tf.function` that uses `strategy.run`.

```python
dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def run_loop(iterator):
    def step(data):
        (inputs, weights), labels = data
        with tf.GradientTape() as tape:
            result = embedding(inputs, weights)
            loss = keras.losses.mean_squared_error(labels, result)
        tape.gradient(loss, embedding.trainable_variables)
        return result

    for _ in tf.range(4):
        result = strategy.run(step, args=(next(iterator),))

run_loop(iter(dataset))
```

Args:
    feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
    table_stacking: The table stacking to use. `None` means no table
@@ -282,7 +421,7 @@ def preprocess(
to feeding data into a model.

An example usage might look like:
```python
# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)
