Commit f80cb75

Added TF specific documentation to DistributedEmbedding.
1 parent bdb77cb

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 96 additions & 2 deletions
@@ -30,6 +30,19 @@ class DistributedEmbedding(keras.layers.Layer):

---

`DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
and can dramatically improve the speed of embedding lookups and embedding
training, in particular for large embedding tables. It works by combining
multiple lookups into a single invocation, and by sharding the embedding
tables across the available chips.

On other hardware (GPUs, CPUs) and TPUs without SparseCore,
`DistributedEmbedding` provides the same API without any specific
acceleration.

`DistributedEmbedding` embeds sequences of inputs and reduces them to a
single embedding by applying a configurable combiner function.
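
For illustration, a minimal configuration and lookup might look as follows.
The exact `TableConfig` and `FeatureConfig` arguments shown here (including
`combiner`) are assumptions; see the configuration section below for the
authoritative API.

```python
import keras
import keras_rs

# One table and one feature that looks up into it. Each input is a sequence
# of ids; the combiner reduces the looked-up rows to a single embedding.
movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=10_000,
    embedding_dim=32,
    combiner="mean",  # assumed argument: how per-id embeddings are reduced
)
feature_configs = {
    "movie_ids": keras_rs.layers.FeatureConfig(
        name="movie_ids",
        table=movie_table,
        input_shape=(256, 10),   # assumed: batch of 256, up to 10 ids each
        output_shape=(256, 32),  # one combined embedding per example
    ),
}

embedding = keras_rs.layers.DistributedEmbedding(feature_configs)

# Inputs mirror the structure of `feature_configs`; outputs have the same
# structure, with one reduced embedding per example.
inputs = {"movie_ids": keras.ops.ones((256, 10), dtype="int32")}
outputs = embedding(inputs)  # {"movie_ids": tensor of shape (256, 32)}
```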

## Configuration

A `DistributedEmbedding` embedding layer is configured via a set of
@@ -83,15 +96,96 @@ class DistributedEmbedding(keras.layers.Layer):
`model.compile()`.

Note that not all optimizers are supported. Currently, the following are
supported on all backends and accelerators:

- `keras.optimizers.Adagrad`
- `keras.optimizers.SGD`

The following are additionally available when using the TensorFlow backend:

- `keras.optimizers.Adam`
- `keras.optimizers.Ftrl`

Also, not all parameters of the optimizers are supported (e.g. the
`nesterov` option of `SGD`). An error is raised when an unsupported
optimizer or an unsupported optimizer parameter is used.
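
For reference, a minimal sketch of selecting one of the supported optimizers
for the embedding tables; the `optimizer` argument of `TableConfig` shown
here is an assumption based on the configuration described above.

```python
import keras
import keras_rs

table = keras_rs.layers.TableConfig(
    name="table",
    vocabulary_size=10_000,
    embedding_dim=32,
    # Adagrad and SGD are supported on every backend; Adam and Ftrl only
    # with the TensorFlow backend.
    optimizer=keras.optimizers.Adagrad(learning_rate=0.05),
)

# Using an unsupported optimizer, or an unsupported parameter such as
# `keras.optimizers.SGD(nesterov=True)`, raises an error.
```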

## Using with TensorFlow on TPU with SparseCore

### Setup

To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
`tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
created under the `TPUStrategy`.

```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)

with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
```

### Using in a Keras model

To use Keras' `model.fit()`, one must compile the model under the
`TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
can be called directly. Keras takes care of running the model under the
strategy and also automatically distributes the dataset.

```python
model = create_model(embedding)

with strategy.scope():
    model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")

model.fit(dataset, epochs=10)
```
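
The `create_model` helper above is a placeholder for any Keras model built
around the embedding layer. One possible sketch, assuming the hypothetical
single-feature configuration from earlier (input names, shapes and dtypes
must match `feature_configs`):

```python
def create_model(embedding):
    movie_ids = keras.Input(
        shape=(10,), batch_size=256, dtype="int32", name="movie_ids"
    )
    # The layer returns the same structure as `feature_configs`.
    embedded = embedding({"movie_ids": movie_ids})["movie_ids"]
    output = keras.layers.Dense(1)(embedded)
    return keras.Model(inputs={"movie_ids": movie_ids}, outputs=output)
```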

### Manual invocation

`DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
`tf.function`.

```python
@tf.function
def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
    def strategy_fn(st_fn_inputs, st_fn_weights):
        return embedding(st_fn_inputs, st_fn_weights)

    return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))

embedding_wrapper(my_inputs, my_weights)
```

When using a dataset, the dataset must be distributed. The iterator can then
be passed to the `tf.function` that uses `strategy.run`.

```python
dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def run_loop(iterator):
    def step(data):
        (inputs, weights), labels = data
        with tf.GradientTape() as tape:
            result = embedding(inputs, weights)
            loss = keras.losses.mean_squared_error(labels, result)
        # Gradients flow to the sharded embedding tables; applying them via
        # an optimizer is omitted in this snippet.
        tape.gradient(loss, embedding.trainable_variables)
        return result

    for _ in tf.range(4):
        result = strategy.run(step, args=(next(iterator),))

run_loop(iter(dataset))
```
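
The `dataset` above is expected to yield `((inputs, weights), labels)`
batches matching the destructuring in `step`. A hypothetical construction,
reusing the assumed single-feature configuration:

```python
import tensorflow as tf

inputs = {"movie_ids": tf.random.uniform((1024, 10), maxval=10_000, dtype=tf.int32)}
weights = {"movie_ids": tf.ones((1024, 10))}  # per-id weights for the combiner
labels = tf.random.uniform((1024, 1))

# 4 batches of 256, matching the `tf.range(4)` loop above; per-replica
# batching details are omitted.
dataset = tf.data.Dataset.from_tensor_slices(((inputs, weights), labels)).batch(256)
```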

Args:
    feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
    table_stacking: The table stacking to use. `None` means no table
