@@ -30,6 +30,19 @@ class DistributedEmbedding(keras.layers.Layer):
---

+ `DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
+ and can dramatically improve the speed of embedding lookups and embedding
+ training, in particular for large embedding tables. It works by combining
+ multiple lookups into one invocation and by sharding the embedding tables
+ across the available chips.
+
+ On other hardware (GPUs, CPUs) and TPUs without SparseCore,
+ `DistributedEmbedding` provides the same API without any specific
+ acceleration.
+
+ `DistributedEmbedding` embeds sequences of inputs and reduces them to a
+ single embedding by applying a configurable combiner function.
+
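+ As an illustrative sketch of the combiner behavior (names are placeholders
+ and the `combiner` argument on `keras_rs.layers.TableConfig` is an
+ assumption; the exact configuration API is described in the Configuration
+ section below), each feature maps a batch of id sequences to one embedding
+ per sequence:
+
+ ```python
+ BATCH_SIZE = 32
+ SEQUENCE_LENGTH = 20
+
+ # One table whose lookups are combined ("mean") over the sequence dimension.
+ item_table = keras_rs.layers.TableConfig(
+     name="item_table",
+     vocabulary_size=10_000,
+     embedding_dim=16,
+     combiner="mean",
+ )
+ item_feature = keras_rs.layers.FeatureConfig(
+     name="item_ids",
+     table=item_table,
+     input_shape=(BATCH_SIZE, SEQUENCE_LENGTH),
+     output_shape=(BATCH_SIZE, 16),
+ )
+ embedding = keras_rs.layers.DistributedEmbedding({"item_ids": item_feature})
+
+ # A batch of id sequences is reduced to one embedding per sequence,
+ # so `outputs["item_ids"]` has shape (BATCH_SIZE, 16).
+ item_ids = keras.random.randint((BATCH_SIZE, SEQUENCE_LENGTH), 0, 10_000)
+ outputs = embedding({"item_ids": item_ids})
+ ```
+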
## Configuration

A `DistributedEmbedding` embedding layer is configured via a set of
@@ -83,15 +96,96 @@ class DistributedEmbedding(keras.layers.Layer):
`model.compile()`.

Note that not all optimizers are supported. Currently, the following are
- always supported (i.e. on all backends and accelerators) :
+ supported on all backends and accelerators:

- `keras.optimizers.Adagrad`
- `keras.optimizers.SGD`

- Additionally, not all parameters of the optimizers are supported (e.g. the
+ The following are additionally available when using the TensorFlow backend:
+
+ - `keras.optimizers.Adam`
+ - `keras.optimizers.Ftrl`
+
+ Also, not all parameters of the optimizers are supported (e.g. the
`nesterov` option of `SGD`). An error is raised when an unsupported
optimizer or an unsupported optimizer parameter is used.
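+
+ As a sketch of per-table optimizers (assuming `keras_rs.layers.TableConfig`
+ accepts an `optimizer` argument, and with placeholder names), a table can
+ use its own Adagrad optimizer while everything else uses the optimizer
+ passed to `model.compile()`:
+
+ ```python
+ # This table's embedding updates use Adagrad regardless of the model optimizer.
+ user_table = keras_rs.layers.TableConfig(
+     name="user_table",
+     vocabulary_size=50_000,
+     embedding_dim=32,
+     optimizer=keras.optimizers.Adagrad(learning_rate=0.05),
+ )
+
+ # Tables configured without an optimizer fall back to the one given to
+ # `model.compile()`, here SGD.
+ model.compile(
+     loss=keras.losses.MeanSquaredError(),
+     optimizer=keras.optimizers.SGD(learning_rate=0.01),
+ )
+ ```
+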
+ ## Using with TensorFlow on TPU with SparseCore
+
+ ### Setup
+
+ To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
+ `tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
+ created under the `TPUStrategy`.
+
+ ```python
+ resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+ topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+ device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+     topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
+ )
+ strategy = tf.distribute.TPUStrategy(
+     resolver, experimental_device_assignment=device_assignment
+ )
+
+ with strategy.scope():
+     embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+ ```
+
+ ### Using in a Keras model
+
+ To use Keras' `model.fit()`, one must compile the model under the
+ `TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or `model.predict()`
+ can be called directly. The Keras model takes care of running the model
+ using the strategy and also automatically distributes the dataset.
+
+ ```python
+ model = create_model(embedding)
+
+ with strategy.scope():
+     model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
+
+ model.fit(dataset, epochs=10)
+ ```
+
+ ### Manual invocation
+
+ `DistributedEmbedding` must be invoked via a `strategy.run` call nested in a
+ `tf.function`.
+
+ ```python
+ @tf.function
+ def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
+     def strategy_fn(st_fn_inputs, st_fn_weights):
+         return embedding(st_fn_inputs, st_fn_weights)
+
+     return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))
+
+ embedding_wrapper(my_inputs, my_weights)
+ ```
+
+ When using a dataset, the dataset must be distributed. The iterator can then
+ be passed to the `tf.function` that uses `strategy.run`.
+
+ ```python
+ dataset = strategy.experimental_distribute_dataset(dataset)
+
+ @tf.function
+ def run_loop(iterator):
+     def step(data):
+         (inputs, weights), labels = data
+         with tf.GradientTape() as tape:
+             result = embedding(inputs, weights)
+             loss = keras.losses.mean_squared_error(labels, result)
+         tape.gradient(loss, embedding.trainable_variables)
+         return result
+
+     for _ in tf.range(4):
+         result = strategy.run(step, args=(next(iterator),))
+
+ run_loop(iter(dataset))
+ ```
+

Args:
    feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
    table_stacking: The table stacking to use. `None` means no table