Commit 31d881b

Added TF specific documentation to DistributedEmbedding. (#94)
1 parent 03fe057 commit 31d881b

File tree: 1 file changed, +144 -5 lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 144 additions & 5 deletions
@@ -32,7 +32,25 @@ class DistributedEmbedding(keras.layers.Layer):

---

`DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
and can dramatically improve the speed of embedding lookups and embedding
training. It works by combining multiple lookups into one invocation, and
by sharding the embedding tables across the available chips. Note that one
will only see performance benefits for embedding tables that are large
enough to require sharding because they don't fit on a single chip. More
details are provided in the "Placement" section below.

On other hardware, GPUs, CPUs and TPUs without SparseCore,
`DistributedEmbedding` provides the same API without any specific
acceleration. No particular distribution scheme is applied besides the one
set via `keras.distribution.set_distribution`.

`DistributedEmbedding` embeds sequences of inputs and reduces them to a
single embedding by applying a configurable combiner function.
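
As a purely conceptual illustration of the combiner (not the
`DistributedEmbedding` API itself), the NumPy sketch below shows what a
`"mean"` combiner does: it reduces the looked-up embeddings of one input
sequence to a single vector. The table values, ids and dimensions are made
up.

```python
import numpy as np

# Hypothetical embedding table: vocabulary of 5 ids, embedding_dim of 4.
table = np.arange(20, dtype=np.float32).reshape(5, 4)

# One sample: a sequence of 3 ids for a single feature.
ids = np.array([1, 3, 3])

# Look up one embedding row per id, then reduce the sequence to a single
# embedding with a "mean" combiner.
combined = table[ids].mean(axis=0)
assert combined.shape == (4,)
```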

### Configuration

#### Features and tables

A `DistributedEmbedding` embedding layer is configured via a set of
`keras_rs.layers.FeatureConfig` objects, which themselves refer to
@@ -50,11 +68,13 @@ class DistributedEmbedding(keras.layers.Layer):
    name="table1",
    vocabulary_size=TABLE1_VOCABULARY_SIZE,
    embedding_dim=TABLE1_EMBEDDING_SIZE,
    placement="auto",
)
table2 = keras_rs.layers.TableConfig(
    name="table2",
    vocabulary_size=TABLE2_VOCABULARY_SIZE,
    embedding_dim=TABLE2_EMBEDDING_SIZE,
    placement="auto",
)

feature1 = keras_rs.layers.FeatureConfig(
@@ -78,22 +98,141 @@ class DistributedEmbedding(keras.layers.Layer):
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
```

#### Optimizers

Each embedding table within `DistributedEmbedding` uses its own optimizer
for training, which is independent from the optimizer set on the model via
`model.compile()`.

Note that not all optimizers are supported. Currently, the following are
supported on all backends and accelerators:

- `keras.optimizers.Adagrad`
- `keras.optimizers.SGD`

The following are additionally available when using the TensorFlow backend:

- `keras.optimizers.Adam`
- `keras.optimizers.Ftrl`

Also, not all parameters of the optimizers are supported (e.g. the
`nesterov` option of `SGD`). An error is raised when an unsupported
optimizer or an unsupported optimizer parameter is used.
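
As an illustration of a per-table optimizer, the sketch below passes a Keras
optimizer instance to `keras_rs.layers.TableConfig`; the `optimizer` argument
name, the sizes and the learning rate are assumptions made for this sketch
rather than confirmed API details.

```python
import keras
import keras_rs

# Hypothetical table that trains with its own Adagrad optimizer,
# independently of the optimizer passed to `model.compile()`.
table1 = keras_rs.layers.TableConfig(
    name="table1",
    vocabulary_size=1000,
    embedding_dim=32,
    optimizer=keras.optimizers.Adagrad(learning_rate=0.05),  # assumed argument
    placement="auto",
)
```
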
#### Placement

Each embedding table within `DistributedEmbedding` can be placed either on
the SparseCore chips or on the accelerator's default device (e.g. the HBM
of the TensorCores on TPU). This is controlled by the `placement` attribute
of `keras_rs.layers.TableConfig`.

- A placement of `"sparsecore"` indicates that the table should be placed
  on the SparseCore chips. An error is raised if this option is selected
  and there are no SparseCore chips.
- A placement of `"default_device"` indicates that the table should not be
  placed on SparseCore, even if available. Instead, the table is placed on
  the device where the model normally goes, i.e. the HBM on TPUs and GPUs.
  In this case, if applicable, the table is distributed using the scheme
  set via `keras.distribution.set_distribution`. On GPUs, CPUs and TPUs
  without SparseCore, this is the only placement available, and is the one
  selected by `"auto"`.
- A placement of `"auto"` indicates to use `"sparsecore"` if available, and
  `"default_device"` otherwise. This is the default when not specified.

To optimize performance on TPU (see the sketch after this list):

- Tables that are so large that they need to be sharded should use the
  `"sparsecore"` placement.
- Tables that are small enough should use `"default_device"` and should
  typically be replicated across TPUs by using the
  `keras.distribution.DataParallel` distribution option.
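
As an illustration of these guidelines, the sketch below configures a large
table for SparseCore and a small, replicated table on the default device;
the table names and sizes are hypothetical, and the
`keras.distribution.DataParallel` line assumes the Keras distribution API is
available for the backend in use.

```python
import keras
import keras_rs

# Replicate small "default_device" tables across accelerators.
keras.distribution.set_distribution(keras.distribution.DataParallel())

# A table too large for a single chip: shard it on SparseCore.
large_table = keras_rs.layers.TableConfig(
    name="large_table",
    vocabulary_size=50_000_000,
    embedding_dim=64,
    placement="sparsecore",
)

# A small table that fits on one chip: keep it on the default device.
small_table = keras_rs.layers.TableConfig(
    name="small_table",
    vocabulary_size=10_000,
    embedding_dim=16,
    placement="default_device",
)
```
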
### Usage with TensorFlow on TPU with SparseCore

#### Inputs

In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
must be ragged in the dimension with index 1. Note that if weights are
passed, each weight tensor must be of the same class as the inputs for that
particular feature and use the exact same ragged row lengths for ragged
tensors, and the same indices for sparse tensors. All the outputs of
`DistributedEmbedding` are dense tensors.
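
As an illustration, the sketch below builds a ragged input and per-id weights
with identical row lengths for a single feature; the feature name and values
are made up.

```python
import tensorflow as tf

# A batch of 2 samples, ragged in dimension 1: 3 ids, then 2 ids.
inputs = tf.ragged.constant([[3, 12, 7], [5, 9]], dtype=tf.int32)

# Weights must also be a `tf.RaggedTensor` with the exact same row lengths.
weights = tf.ragged.constant([[1.0, 0.5, 0.5], [1.0, 2.0]], dtype=tf.float32)

assert all(inputs.row_lengths().numpy() == weights.row_lengths().numpy())

# These would then be passed per feature, e.g. (hypothetical feature name):
# outputs = embedding({"feature1": inputs}, {"feature1": weights})
```
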
#### Setup

To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
`tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
created under the `TPUStrategy`.

```python
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)

with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
```

#### Usage in a Keras model

To use Keras' `model.fit()`, one must compile the model under the
`TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
can be called directly. The Keras model takes care of running the model
using the strategy and also automatically distributes the dataset.

```python
with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
    model = create_model(embedding)
    model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")

model.fit(dataset, epochs=10)
```

#### Direct invocation

`DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
`tf.function`.

```python
@tf.function
def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
    def strategy_fn(st_fn_inputs, st_fn_weights):
        return embedding(st_fn_inputs, st_fn_weights)

    return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))

embedding_wrapper(my_inputs, my_weights)
```

When using a dataset, the dataset must be distributed. The iterator can then
be passed to the `tf.function` that uses `strategy.run`.

```python
dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def run_loop(iterator):
    def step(data):
        (inputs, weights), labels = data
        with tf.GradientTape() as tape:
            result = embedding(inputs, weights)
            loss = keras.losses.mean_squared_error(labels, result)
        tape.gradient(loss, embedding.trainable_variables)
        return result

    for _ in tf.range(4):
        result = strategy.run(step, args=(next(iterator),))

run_loop(iter(dataset))
```

Args:
    feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
    table_stacking: The table stacking to use. `None` means no table
@@ -282,7 +421,7 @@ def preprocess(
to feeding data into a model.

An example usage might look like:
```python
# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)

