@@ -32,7 +32,25 @@ class DistributedEmbedding(keras.layers.Layer):

---

- ## Configuration
+ `DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
+ and can dramatically improve the speed of embedding lookups and embedding
+ training. It works by combining multiple lookups into one invocation, and by
+ sharding the embedding tables across the available chips. Note that one will
+ only see performance benefits for embedding tables that are large enough
+ to require sharding because they don't fit on a single chip. More details
+ are provided in the "Placement" section below.
+
+ On other hardware, GPUs, CPUs and TPUs without SparseCore,
+ `DistributedEmbedding` provides the same API without any specific
+ acceleration. No particular distribution scheme is applied besides the one
+ set via `keras.distribution.set_distribution`.
+
+ `DistributedEmbedding` embeds sequences of inputs and reduces them to a
+ single embedding by applying a configurable combiner function.
+
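+ As a rough illustration of this reduction (a conceptual sketch, not the
+ layer's actual implementation), a `"mean"` combiner turns a sequence of
+ input ids into the average of their embedding vectors:
+
+ ```python
+ import numpy as np
+
+ table = np.random.rand(100, 8)    # vocabulary_size=100, embedding_dim=8
+ ids = np.array([3, 7, 42])        # one input sequence of ids
+ pooled = table[ids].mean(axis=0)  # single embedding of shape (8,)
+ ```
+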
+ ### Configuration
+
+ #### Features and tables

A `DistributedEmbedding` embedding layer is configured via a set of
`keras_rs.layers.FeatureConfig` objects, which themselves refer to
@@ -50,11 +68,13 @@ class DistributedEmbedding(keras.layers.Layer):
    name="table1",
    vocabulary_size=TABLE1_VOCABULARY_SIZE,
    embedding_dim=TABLE1_EMBEDDING_SIZE,
+     placement="auto",
)
table2 = keras_rs.layers.TableConfig(
    name="table2",
    vocabulary_size=TABLE2_VOCABULARY_SIZE,
    embedding_dim=TABLE2_EMBEDDING_SIZE,
+     placement="auto",
)

feature1 = keras_rs.layers.FeatureConfig(
@@ -78,22 +98,141 @@ class DistributedEmbedding(keras.layers.Layer):
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
```

- ## Optimizers
+ #### Optimizers

Each embedding table within `DistributedEmbedding` uses its own optimizer
for training, which is independent from the optimizer set on the model via
`model.compile()`.

Note that not all optimizers are supported. Currently, the following are
- always supported (i.e. on all backends and accelerators) :
+ supported on all backends and accelerators:

- `keras.optimizers.Adagrad`
- `keras.optimizers.SGD`

- Additionally, not all parameters of the optimizers are supported (e.g. the
+ The following are additionally available when using the TensorFlow backend:
+
+ - `keras.optimizers.Adam`
+ - `keras.optimizers.Ftrl`
+
+ Also, not all parameters of the optimizers are supported (e.g. the
`nesterov` option of `SGD`). An error is raised when an unsupported
optimizer or an unsupported optimizer parameter is used.

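+ For example (a minimal sketch, assuming `keras_rs.layers.TableConfig`
+ accepts a per-table `optimizer` argument), an optimizer can be attached to
+ an individual table:
+
+ ```python
+ # This table trains with Adagrad, independently of `model.compile()`.
+ table1 = keras_rs.layers.TableConfig(
+     name="table1",
+     vocabulary_size=TABLE1_VOCABULARY_SIZE,
+     embedding_dim=TABLE1_EMBEDDING_SIZE,
+     optimizer=keras.optimizers.Adagrad(learning_rate=0.05),
+     placement="auto",
+ )
+ ```
+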
+ #### Placement
+
+ Each embedding table within `DistributedEmbedding` can either be placed on
+ the SparseCore chips or use the default device placement for the
+ accelerator (e.g. the HBM of the TensorCores on TPU). This is controlled by
+ the `placement` attribute of `keras_rs.layers.TableConfig`.
+
+ - A placement of `"sparsecore"` indicates that the table should be placed
+   on the SparseCore chips. An error is raised if this option is selected
+   and there are no SparseCore chips.
+ - A placement of `"default_device"` indicates that the table should not be
+   placed on SparseCore, even if available. Instead, the table is placed on
+   the device where the model normally goes, i.e. the HBM on TPUs and GPUs.
+   In this case, if applicable, the table is distributed using the scheme
+   set via `keras.distribution.set_distribution`. On GPUs, CPUs and TPUs
+   without SparseCore, this is the only placement available, and is the one
+   selected by `"auto"`.
+ - A placement of `"auto"` indicates to use `"sparsecore"` if available, and
+   `"default_device"` otherwise. This is the default when not specified.
+
+ To optimize performance on TPU:
+
+ - Tables that are so large that they need to be sharded should use the
+   `"sparsecore"` placement.
+ - Tables that are small enough to fit on a single chip should use
+   `"default_device"` and should typically be replicated across TPUs by
+   using the `keras.distribution.DataParallel` distribution option, as
+   sketched below.
+
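+ For instance (a sketch of one possible configuration; table names and sizes
+ are illustrative), a very large table can be pinned to SparseCore while a
+ small table stays on the default device and is replicated:
+
+ ```python
+ # Replicate "default_device" tables across devices via data parallelism.
+ keras.distribution.set_distribution(keras.distribution.DataParallel())
+
+ large_table = keras_rs.layers.TableConfig(
+     name="large_table",
+     vocabulary_size=50_000_000,
+     embedding_dim=128,
+     placement="sparsecore",      # sharded across the SparseCore chips
+ )
+ small_table = keras_rs.layers.TableConfig(
+     name="small_table",
+     vocabulary_size=10_000,
+     embedding_dim=16,
+     placement="default_device",  # kept in HBM, replicated by DataParallel
+ )
+ ```
+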
+ ### Usage with TensorFlow on TPU with SparseCore
+
+ #### Inputs
+
+ In addition to `tf.Tensor`, `DistributedEmbedding` accepts `tf.RaggedTensor`
+ and `tf.SparseTensor` as inputs for the embedding lookups. Ragged tensors
+ must be ragged in the dimension with index 1. Note that if weights are
+ passed, each weight tensor must be of the same class as the inputs for that
+ particular feature, use the exact same ragged row lengths for ragged
+ tensors, and the same indices for sparse tensors. All outputs of
+ `DistributedEmbedding` are dense tensors.
+
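+ For example (a sketch assuming `tensorflow` is imported as `tf`, the
+ `embedding` layer built above, and inputs structured as a dict keyed by
+ feature name), ragged weights must use the same row lengths as the ids:
+
+ ```python
+ # Two id sequences of different lengths for the feature named "feature1".
+ inputs = {"feature1": tf.ragged.constant([[3, 7, 42], [1, 5]])}
+ # Per-id weights with the exact same ragged row lengths.
+ weights = {"feature1": tf.ragged.constant([[1.0, 1.0, 0.5], [1.0, 2.0]])}
+
+ outputs = embedding(inputs, weights)  # dense, one embedding per sequence
+ ```
+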
+ #### Setup
+
+ To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
+ `tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
+ created under the `TPUStrategy`.
+
+ ```python
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+ topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+ device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+     topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
+ )
+ strategy = tf.distribute.TPUStrategy(
+     resolver, experimental_device_assignment=device_assignment
+ )
+
+ with strategy.scope():
+     embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+ ```
+
+ #### Usage in a Keras model
+
+ To use Keras' `model.fit()`, one must compile the model under the
+ `TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
+ can be called directly. The Keras model takes care of running the model
+ using the strategy and also automatically distributes the dataset.
+
+ ```python
+ with strategy.scope():
+     embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+     model = create_model(embedding)
+     model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
+
+ model.fit(dataset, epochs=10)
+ ```
+
+ #### Direct invocation
+
+ `DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
+ `tf.function`.
+
+ ```python
+ @tf.function
+ def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
+     def strategy_fn(st_fn_inputs, st_fn_weights):
+         return embedding(st_fn_inputs, st_fn_weights)
+
+     return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))
+
+ embedding_wrapper(my_inputs, my_weights)
+ ```
+
+ When using a dataset, the dataset must be distributed. The iterator can then
+ be passed to the `tf.function` that uses `strategy.run`.
+
+ ```python
+ dataset = strategy.experimental_distribute_dataset(dataset)
+
+ @tf.function
+ def run_loop(iterator):
+     def step(data):
+         (inputs, weights), labels = data
+         with tf.GradientTape() as tape:
+             result = embedding(inputs, weights)
+             loss = keras.losses.mean_squared_error(labels, result)
+         tape.gradient(loss, embedding.trainable_variables)
+         return result
+
+     for _ in tf.range(4):
+         result = strategy.run(step, args=(next(iterator),))
+
+ run_loop(iter(dataset))
+ ```
+

Args:
98
237
feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
99
238
table_stacking: The table stacking to use. `None` means no table
@@ -282,7 +421,7 @@ def preprocess(
to feeding data into a model.

An example usage might look like:
- ```
+ ```python
# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)
