diff --git a/keras_rs/src/layers/embedding/base_distributed_embedding.py b/keras_rs/src/layers/embedding/base_distributed_embedding.py
index 8219cc0..4e7ecb8 100644
--- a/keras_rs/src/layers/embedding/base_distributed_embedding.py
+++ b/keras_rs/src/layers/embedding/base_distributed_embedding.py
@@ -32,7 +32,25 @@ class DistributedEmbedding(keras.layers.Layer):
 
     ---
 
-    ## Configuration
+    `DistributedEmbedding` is a layer optimized for TPU chips with SparseCore
+    and can dramatically improve the speed of embedding lookups and embedding
+    training. It works by combining multiple lookups into one invocation, and
+    by sharding the embedding tables across the available chips. Note that one
+    will only see performance benefits for embedding tables that are large
+    enough to require sharding because they don't fit on a single chip. More
+    details are provided in the "Placement" section below.
+
+    On other hardware (GPUs, CPUs and TPUs without SparseCore),
+    `DistributedEmbedding` provides the same API without any specific
+    acceleration. No particular distribution scheme is applied besides the one
+    set via `keras.distribution.set_distribution`.
+
+    `DistributedEmbedding` embeds sequences of inputs and reduces them to a
+    single embedding by applying a configurable combiner function.
+
+    ### Configuration
+
+    #### Features and tables
 
     A `DistributedEmbedding` embedding layer is configured via a set of
     `keras_rs.layers.FeatureConfig` objects, which themselves refer to
@@ -50,11 +68,13 @@ class DistributedEmbedding(keras.layers.Layer):
         name="table1",
         vocabulary_size=TABLE1_VOCABULARY_SIZE,
         embedding_dim=TABLE1_EMBEDDING_SIZE,
+        placement="auto",
     )
     table2 = keras_rs.layers.TableConfig(
         name="table2",
         vocabulary_size=TABLE2_VOCABULARY_SIZE,
         embedding_dim=TABLE2_EMBEDDING_SIZE,
+        placement="auto",
     )
 
     feature1 = keras_rs.layers.FeatureConfig(
@@ -78,22 +98,141 @@ class DistributedEmbedding(keras.layers.Layer):
     embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
     ```
 
-    ## Optimizers
+    #### Optimizers
 
     Each embedding table within `DistributedEmbedding` uses its own optimizer
     for training, which is independent from the optimizer set on the model
     via `model.compile()`.
 
     Note that not all optimizers are supported. Currently, the following are
-    always supported (i.e. on all backends and accelerators):
+    supported on all backends and accelerators:
 
     - `keras.optimizers.Adagrad`
     - `keras.optimizers.SGD`
 
-    Additionally, not all parameters of the optimizers are supported (e.g. the
+    The following are additionally available when using the TensorFlow
+    backend:
+
+    - `keras.optimizers.Adam`
+    - `keras.optimizers.Ftrl`
+
+    Also, not all parameters of the optimizers are supported (e.g. the
     `nesterov` option of `SGD`). An error is raised when an unsupported
     optimizer or an unsupported optimizer parameter is used.
 
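+    As an illustration, and only as a sketch (it assumes the per-table
+    optimizer is passed via the `optimizer` argument of
+    `keras_rs.layers.TableConfig`), a table could be given its own optimizer
+    as follows:
+
+    ```python
+    # Sketch: this table trains with its own Adagrad optimizer, independent
+    # from the optimizer passed to `model.compile()`. Assumes `TableConfig`
+    # accepts an `optimizer` argument.
+    table1 = keras_rs.layers.TableConfig(
+        name="table1",
+        vocabulary_size=TABLE1_VOCABULARY_SIZE,
+        embedding_dim=TABLE1_EMBEDDING_SIZE,
+        optimizer=keras.optimizers.Adagrad(learning_rate=0.05),
+    )
+    ```
+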
+    #### Placement
+
+    Each embedding table within `DistributedEmbedding` can either be placed
+    on the SparseCore chips or use the default device placement for the
+    accelerator (e.g. the HBM of the Tensor Cores on TPU). This is controlled
+    by the `placement` attribute of `keras_rs.layers.TableConfig`.
+
+    - A placement of `"sparsecore"` indicates that the table should be placed
+      on the SparseCore chips. An error is raised if this option is selected
+      and there are no SparseCore chips.
+    - A placement of `"default_device"` indicates that the table should not
+      be placed on SparseCore, even if available. Instead, the table is
+      placed on the device where the model normally goes, i.e. the HBM on
+      TPUs and GPUs. In this case, if applicable, the table is distributed
+      using the scheme set via `keras.distribution.set_distribution`. On
+      GPUs, CPUs and TPUs without SparseCore, this is the only placement
+      available, and is the one selected by `"auto"`.
+    - A placement of `"auto"` indicates to use `"sparsecore"` if available,
+      and `"default_device"` otherwise. This is the default when not
+      specified.
+
+    To optimize performance on TPU:
+
+    - Tables that are so large that they need to be sharded should use the
+      `"sparsecore"` placement.
+    - Tables that are small enough to fit on a single chip should use
+      `"default_device"` and should typically be replicated across TPUs by
+      using the `keras.distribution.DataParallel` distribution option, as
+      sketched below.
+
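+    The following sketch illustrates these recommendations (the vocabulary
+    size and embedding dimension names are placeholders, and the distribution
+    setup assumes the `keras.distribution` API is available for the selected
+    backend):
+
+    ```python
+    # Sketch: replicate the non-embedding part of the model across chips.
+    keras.distribution.set_distribution(keras.distribution.DataParallel())
+
+    # A table too large for a single chip: shard it on SparseCore.
+    large_table = keras_rs.layers.TableConfig(
+        name="large_table",
+        vocabulary_size=LARGE_VOCABULARY_SIZE,
+        embedding_dim=LARGE_EMBEDDING_SIZE,
+        placement="sparsecore",
+    )
+    # A table that fits on a single chip: keep it on the default device.
+    small_table = keras_rs.layers.TableConfig(
+        name="small_table",
+        vocabulary_size=SMALL_VOCABULARY_SIZE,
+        embedding_dim=SMALL_EMBEDDING_SIZE,
+        placement="default_device",
+    )
+    ```
+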
+    ### Usage with TensorFlow on TPU with SparseCore
+
+    #### Inputs
+
+    In addition to `tf.Tensor`, `DistributedEmbedding` accepts
+    `tf.RaggedTensor` and `tf.SparseTensor` as inputs for the embedding
+    lookups. Ragged tensors must be ragged in the dimension with index 1.
+    Note that if weights are passed, each weight tensor must be of the same
+    class as the inputs for that particular feature and use the exact same
+    ragged row lengths for ragged tensors, and the same indices for sparse
+    tensors. All the outputs of `DistributedEmbedding` are dense tensors.
+
+    #### Setup
+
+    To use `DistributedEmbedding` on TPUs with TensorFlow, one must use a
+    `tf.distribute.TPUStrategy`. The `DistributedEmbedding` layer must be
+    created under the `TPUStrategy`.
+
+    ```python
+    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+    topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+    device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+        topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
+    )
+    strategy = tf.distribute.TPUStrategy(
+        resolver, experimental_device_assignment=device_assignment
+    )
+
+    with strategy.scope():
+        embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+    ```
+
+    #### Usage in a Keras model
+
+    To use Keras' `model.fit()`, one must compile the model under the
+    `TPUStrategy`. Then, `model.fit()`, `model.evaluate()` or
+    `model.predict()` can be called directly. The Keras model takes care of
+    running the model using the strategy and also automatically distributes
+    the dataset.
+
+    ```python
+    with strategy.scope():
+        embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
+        model = create_model(embedding)
+        model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
+
+    model.fit(dataset, epochs=10)
+    ```
+
+    #### Direct invocation
+
+    `DistributedEmbedding` must be invoked via a `strategy.run` call nested
+    in a `tf.function`.
+
+    ```python
+    @tf.function
+    def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
+        def strategy_fn(st_fn_inputs, st_fn_weights):
+            return embedding(st_fn_inputs, st_fn_weights)
+
+        return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights))
+
+    embedding_wrapper(my_inputs, my_weights)
+    ```
+
+    When using a dataset, the dataset must be distributed. The iterator can
+    then be passed to the `tf.function` that uses `strategy.run`.
+
+    ```python
+    dataset = strategy.experimental_distribute_dataset(dataset)
+
+    @tf.function
+    def run_loop(iterator):
+        def step(data):
+            (inputs, weights), labels = data
+            with tf.GradientTape() as tape:
+                result = embedding(inputs, weights)
+                loss = keras.losses.mean_squared_error(labels, result)
+            tape.gradient(loss, embedding.trainable_variables)
+            return result
+
+        for _ in tf.range(4):
+            result = strategy.run(step, args=(next(iterator),))
+
+    run_loop(iter(dataset))
+    ```
+
     Args:
         feature_configs: A nested structure of `keras_rs.layers.FeatureConfig`.
         table_stacking: The table stacking to use. `None` means no table
@@ -282,7 +421,7 @@ def preprocess(
         to feeding data into a model. An example usage might look like:
 
-        ```
+        ```python
         # Create the embedding layer.
         embedding_layer = DistributedEmbedding(feature_configs)