@@ -291,6 +291,141 @@ def step(data):
result = strategy.run(step, args=(next(iterator),))

run_loop(iter(dataset))
+ ```
+
+ ### Usage with JAX on TPU with SparseCore
+
+ #### Setup
+
+ To use `DistributedEmbedding` on TPUs with JAX, one must create and set a
+ Keras `Distribution`.
+ ```python
+ distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
+ keras.distribution.set_distribution(distribution)
+ ```
+
+ #### Inputs
+
+ For JAX, inputs can be either dense tensors or ragged (nested) NumPy
+ arrays. To enable `jit_compile = True`, one must explicitly call
+ `layer.preprocess(...)` on the inputs and then feed the preprocessed
+ output to the model. See the next section on preprocessing for details.
+
+ Ragged input arrays must be ragged in dimension 1. Note that if weights
+ are passed, each weight tensor must be of the same class as the inputs
+ for that particular feature and, for ragged tensors, use the exact same
+ ragged row lengths. All outputs of `DistributedEmbedding` are dense
+ tensors.
+
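+ As a purely illustrative sketch (the feature names, id values, and the
+ dict-of-features structure below are hypothetical, not taken from this
+ API), dense and ragged inputs with matching weights might look like:
+ ```python
+ import numpy as np
+
+ inputs = {
+     # Dense feature: every sample has the same number of ids.
+     "dense_feature": np.array([[10, 21], [3, 7]], dtype=np.int32),
+     # Ragged feature: rows have different lengths, so the array is
+     # ragged in dimension 1 (an object array of per-row arrays).
+     "ragged_feature": np.array(
+         [np.array([1, 5, 9]), np.array([2])], dtype=object
+     ),
+ }
+ weights = {
+     # Weights mirror the inputs: same class and, for the ragged
+     # feature, the exact same row lengths.
+     "dense_feature": np.array([[1.0, 1.0], [0.5, 1.0]], dtype=np.float32),
+     "ragged_feature": np.array(
+         [np.array([1.0, 1.0, 0.3]), np.array([1.0])], dtype=object
+     ),
+ }
+ ```
+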
+ #### Preprocessing
+
+ In JAX, SparseCore usage requires specially formatted data that depends
+ on properties of the available hardware. This data reformatting
+ currently does not support jit-compilation, so it must be applied
+ _prior_ to passing data into a model.
+
+ Preprocessing works on dense or ragged NumPy arrays, or on tensors that
+ are convertible to dense or ragged NumPy arrays, such as
+ `tf.RaggedTensor`.
+
+ One simple way to add preprocessing is to append it to an input pipeline
+ by using a Python generator.
+ ```python
+ # Create the embedding layer.
+ embedding_layer = DistributedEmbedding(feature_configs)
+
+ # Add preprocessing to a data input pipeline.
+ def train_dataset_generator():
+     for (inputs, weights), labels in iter(train_dataset):
+         yield embedding_layer.preprocess(
+             inputs, weights, training=True
+         ), labels
+
+ preprocessed_train_dataset = train_dataset_generator()
+ ```
+ This explicit preprocessing stage combines the inputs and optional weights,
+ so the new data can be passed directly into the `inputs` argument of the
+ layer or model.
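+
+ For example, reusing the generator defined above (a minimal sketch meant
+ only to illustrate this point; in practice the preprocessed dataset is
+ usually passed to `model.fit` as shown in the next section):
+ ```python
+ # Each yielded element already bundles the inputs and weights, so only a
+ # single argument is passed to the layer.
+ preprocessed_inputs, labels = next(preprocessed_train_dataset)
+ activations = embedding_layer(preprocessed_inputs)
+ ```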
+
+ #### Usage in a Keras model
+
+ Once the global distribution is set and the input preprocessing pipeline
+ is defined, model training can proceed as normal. For example:
+ ```python
+ # Construct, compile, and fit the model using the preprocessed data.
+ model = keras.Sequential(
+     [
+         embedding_layer,
+         keras.layers.Dense(2),
+         keras.layers.Dense(3),
+         keras.layers.Dense(4),
+     ]
+ )
+ model.compile(optimizer="adam", loss="mse", jit_compile=True)
+ model.fit(preprocessed_train_dataset, epochs=10)
+ ```
+
+ #### Direct invocation
+
+ The `DistributedEmbedding` layer can also be invoked directly. Explicit
+ preprocessing is required when used with JIT compilation.
+ ```python
+ # Call the layer directly.
+ activations = embedding_layer(my_inputs, my_weights)
+
+ # Call the layer with JIT compilation and explicitly preprocessed inputs.
+ embedding_layer_jit = jax.jit(embedding_layer)
+ preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
+ activations = embedding_layer_jit(preprocessed_inputs)
+ ```
+
+ Similarly, for custom training loops, preprocessing must be applied prior
+ to passing the data to the JIT-compiled training step.
+ ```python
+ # Create an optimizer and loss function.
+ optimizer = keras.optimizers.Adam(learning_rate=1e-3)
+
+ def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
+     y_pred, non_trainable_variables = model.stateless_call(
+         trainable_variables, non_trainable_variables, x, training=True
+     )
+     loss = keras.losses.mean_squared_error(y, y_pred)
+     return loss, non_trainable_variables
+
+ grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)
+
+ # Create a JIT-compiled training step.
+ @jax.jit
+ def train_step(state, x, y):
+     (
+         trainable_variables,
+         non_trainable_variables,
+         optimizer_variables,
+     ) = state
+     (loss, non_trainable_variables), grads = grad_fn(
+         trainable_variables, non_trainable_variables, x, y
+     )
+     trainable_variables, optimizer_variables = optimizer.stateless_apply(
+         optimizer_variables, grads, trainable_variables
+     )
+     return loss, (
+         trainable_variables,
+         non_trainable_variables,
+         optimizer_variables,
+     )
+
+ # Build optimizer variables.
+ optimizer.build(model.trainable_variables)
+
+ # Assemble the training state.
+ trainable_variables = model.trainable_variables
+ non_trainable_variables = model.non_trainable_variables
+ optimizer_variables = optimizer.variables
+ state = trainable_variables, non_trainable_variables, optimizer_variables
+
+ # Training loop.
+ for (inputs, weights), labels in train_dataset:
+     # Explicitly preprocess the data.
+     preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
+     loss, state = train_step(state, preprocessed_inputs, labels)
```

Args:
@@ -471,41 +606,9 @@ def preprocess(
) -> types.Nested[types.Tensor]:
"""Preprocesses and reformats the data for consumption by the model.

- Calling `preprocess` explicitly is only required for the JAX backend
- to enable `jit_compile = True`. For all other cases and backends,
- explicit use of `preprocess` is optional.
-
- In JAX, sparsecore usage requires specially formatted data that depends
- on properties of the available hardware. This data reformatting
- currently does not support jit-compilation, so must be applied _prior_
- to feeding data into a model.
-
- An example usage might look like:
- ```python
- # Create the embedding layer.
- embedding_layer = DistributedEmbedding(feature_configs)
-
- # Add preprocessing to a data input pipeline.
- def training_dataset_generator():
-     for (inputs, weights), labels in iter(training_dataset):
-         yield embedding_layer.preprocess(
-             inputs, weights, training=True
-         ), labels
-
- preprocessed_training_dataset = training_dataset_generate()
-
- # Construct, compile, and fit the model using the preprocessed data.
- model = keras.Sequential(
-     [
-         embedding_layer,
-         keras.layers.Dense(2),
-         keras.layers.Dense(3),
-         keras.layers.Dense(4),
-     ]
- )
- model.compile(optimizer="adam", loss="mse", jit_compile=True)
- model.fit(preprocessed_training_dataset, epochs=10)
- ```
+ For the JAX backend, this converts the input data to a hardware-dependent
+ format required for use with SparseCore. Calling `preprocess` explicitly
+ is only necessary to enable `jit_compile = True`.

For non-JAX backends, preprocessing will bundle together the inputs and
weights, and separate the inputs by device placement. This step is