# JAX TPU Embedding

[![Build and Test](https://github.com/jax-ml/jax-tpu-embedding/actions/workflows/build_and_test.yml/badge.svg)](https://github.com/jax-ml/jax-tpu-embedding/actions/workflows/build_and_test.yml)

JAX TPU Embedding is a library for JAX that provides an efficient, scalable API for large embedding tables on Google Cloud TPUs, leveraging the specialized SparseCore hardware. It is designed to integrate seamlessly into JAX and Flax workflows.

---

## Installation

You can install the library directly from its GitHub repository:

```bash
git clone https://github.com/jax-ml/jax-tpu-embedding.git
cd jax-tpu-embedding
chmod +x .tools/local_build_wheel.sh
.tools/local_build_wheel.sh
pip install ./dist/*.whl
```
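
To confirm the wheel installed correctly, you can open Python and import one of the modules used later in this README. This is only a sanity check, and it assumes a TPU-enabled JAX installation is already available on the machine:

```python
# Sanity check (illustrative): the package imports and JAX can enumerate devices.
import jax
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec

print("JAX devices:", jax.devices())
print("Imported:", embedding_spec.__name__)
```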

### Development

To build and test the library from a local clone, you will need to install Bazel. You can find the required version in the `.bazelversion` file.

```bash
# Clone the repository
git clone https://github.com/jax-ml/jax-tpu-embedding.git
cd jax-tpu-embedding

# Run all tests
bazel test //...
```

---

## Quick Start

Here's a quick example of how to use the high-level Flax API to define a model with an embedding layer.

### 1. Define Embedding Table and Feature Specifications

First, define the structure of your embedding table (`TableSpec`) and how your features map to it (`FeatureSpec`). These specifications tell the library how to allocate memory on the SparseCores and configure the lookup hardware.

```python
import jax
import jax.numpy as jnp
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec

# Example constants
BATCH_SIZE = 32
SEQ_LEN = 16
VOCAB_SIZE = 1024
EMBEDDING_DIM = 128

# Define the embedding table properties
table_spec = embedding_spec.TableSpec(
    vocabulary_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    name='word_embedding_table',
    optimizer=embedding_spec.SGDOptimizerSpec(learning_rate=0.05)
)

# Define a feature that uses this table
feature_spec = embedding_spec.FeatureSpec(
    table_spec=table_spec,
    input_shape=(BATCH_SIZE, SEQ_LEN),  # Shape of the input IDs
    output_shape=(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM),  # Desired output activation shape
    name='word_ids'
)

# Feature specs are passed as a PyTree (e.g., a dictionary)
feature_specs = {'word_ids': feature_spec}
```

### 2. Create a Flax Model

Use the `SparseCoreEmbed` layer within a standard Flax `nn.Module`. This layer handles all the communication with the SparseCores.

```python
from flax import linen as nn
from jax_tpu_embedding.sparsecore.lib.flax import embed

class ShakespeareModel(nn.Module):
  feature_specs: embed.Nested[embedding_spec.FeatureSpec]

  @nn.compact
  def __call__(self, embedding_inputs):
    # This layer performs the embedding lookup on SparseCores.
    # `embedding_inputs` is a dictionary of integer ID tensors.
    embedding_activations = embed.SparseCoreEmbed(
        feature_specs=self.feature_specs
    )(embedding_inputs)

    # The result is a dictionary of activations, matching the feature spec keys.
    x = embedding_activations['word_ids']

    # Flatten sequence and embedding dimensions for the dense layer
    x = x.reshape((x.shape[0], -1))

    # Add a dense layer (runs on TensorCores)
    x = nn.Dense(VOCAB_SIZE)(x)
    return x
```

### 3. Initialize and Run the Training Step

The `SparseCoreEmbed` layer separates its parameters (the embedding tables, which live on SparseCore HBM) from the rest of the model's parameters (which live on TensorCore HBM). You initialize them separately and pass them to the training step.

```python
# Create a mesh for device layout
devices = jax.devices()
mesh = jax.sharding.Mesh(devices, axis_names=('data',))

# 1. Initialize the embedding tables (SparseCore parameters)
# This needs to be done under the mesh context
with mesh:
  embedding_params = embed.SparseCoreEmbed.create_embedding_variables(
      feature_specs, jax.random.PRNGKey(0)
  )

# 2. Initialize the dense model parts (TensorCore parameters)
model = ShakespeareModel(feature_specs=feature_specs)
dummy_inputs = {'word_ids': jnp.zeros((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)}
dense_params = model.init(
    jax.random.PRNGKey(1), dummy_inputs
)['params']

# 3. Define and JIT-compile the training step
@jax.jit
def train_step(dense_params, embedding_params, features, labels):
  def loss_fn(params):
    # The 'embedding' collection is automatically handled by SparseCoreEmbed
    logits = model.apply({'params': params, 'embedding': embedding_params}, features)
    # A real implementation would use a proper loss function like cross-entropy
    return jnp.mean(jnp.square(logits - labels))  # Dummy loss for demonstration

  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(dense_params)

  # Gradients for `embedding_params` are computed and applied on the
  # SparseCores automatically by the `SparseCoreEmbed` layer, using the
  # optimizer defined in the TableSpec.

  # Update dense parameters using your preferred optimizer
  new_dense_params = jax.tree.map(lambda p, g: p - 0.01 * g, dense_params, grads)

  return new_dense_params

# --- Example usage with dummy data ---
features = {'word_ids': jax.random.randint(
    jax.random.PRNGKey(2), (BATCH_SIZE, SEQ_LEN), 0, VOCAB_SIZE)}
labels = jax.random.normal(jax.random.PRNGKey(3), (BATCH_SIZE, VOCAB_SIZE))

# Run one training step
new_dense_params = train_step(dense_params, embedding_params, features, labels)
```
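
In a full training run you would call `train_step` repeatedly on batches from your input pipeline. A minimal sketch that reuses the objects defined above; the random batch generation is purely illustrative, not part of the library:

```python
# Illustrative loop over dummy batches; a real pipeline would stream data
# from disk or a dataset library instead of generating random IDs each step.
num_steps = 10
rng = jax.random.PRNGKey(42)
for step in range(num_steps):
  rng, ids_key, labels_key = jax.random.split(rng, 3)
  batch_features = {'word_ids': jax.random.randint(
      ids_key, (BATCH_SIZE, SEQ_LEN), 0, VOCAB_SIZE)}
  batch_labels = jax.random.normal(labels_key, (BATCH_SIZE, VOCAB_SIZE))
  dense_params = train_step(dense_params, embedding_params, batch_features, batch_labels)
```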

---

## Key Concepts

- **`TableSpec`**: Defines the properties of a single embedding table, including its shape (`vocabulary_size`, `embedding_dim`), initializer, and the optimizer (e.g., `SGDOptimizerSpec`) to be run on the SparseCores.
- **`FeatureSpec`**: Describes a logical feature that maps to a `TableSpec`. It specifies the `input_shape` of the integer ID tensor and the desired `output_shape` of the resulting activation tensor. A single table can be used by multiple features (see the sketch after this list).
- **`SparseCoreEmbed`**: A Flax Linen layer that acts as the entry point for embedding lookups. It takes a dictionary of feature names to ID tensors and returns a dictionary of feature names to activation tensors. It manages the embedding table parameters separately from the dense model parameters.
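
The sketch below reuses the constructor arguments shown in the Quick Start to point two `FeatureSpec`s at the same `TableSpec`. The feature names (`query_ids`, `title_ids`) and shapes are illustrative only:

```python
# A minimal sketch of table sharing, using the same TableSpec/FeatureSpec
# arguments as the Quick Start above. Feature names are hypothetical.
shared_table = embedding_spec.TableSpec(
    vocabulary_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    name='shared_token_table',
    optimizer=embedding_spec.SGDOptimizerSpec(learning_rate=0.05)
)

# Both features look up into `shared_table`, so they share its parameters.
shared_feature_specs = {
    name: embedding_spec.FeatureSpec(
        table_spec=shared_table,
        input_shape=(BATCH_SIZE, SEQ_LEN),
        output_shape=(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM),
        name=name,
    )
    for name in ('query_ids', 'title_ids')
}
```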

---

## Running the Examples

The repository includes a complete Shakespeare next-word-prediction model. You can run it using Bazel.

To run the tests for the example:
```bash
# Make sure you are at the root of the repository
bazel test //jax_tpu_embedding/sparsecore/examples/shakespeare/...
```
This will build and run both the `pmap` and `jit` + `shard_map` versions of the model.

---

## Contributing

Contributions are welcome! Please read the [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines on how to contribute to this project.

---

## License

This project is licensed under the terms of the Apache 2.0 license. See the [LICENSE](LICENSE) file for more details.