# JAX TPU Embedding

[![Unittests](https://github.com/jax-ml/jax-tpu-embedding/actions/workflows/build_and_test.yml/badge.svg)](https://github.com/jax-ml/jax-tpu-embedding/actions/workflows/build_and_test.yml)

JAX TPU Embedding is a library for JAX that provides an efficient, scalable API for large embedding tables on Google Cloud TPUs, leveraging the specialized SparseCore hardware. It is designed to integrate seamlessly into JAX and Flax workflows.

---

## Installation

You can build and install the library from a clone of its GitHub repository:

```bash
git clone https://github.com/jax-ml/jax-tpu-embedding.git
cd jax-tpu-embedding
chmod +x .tools/local_build_wheel.sh
.tools/local_build_wheel.sh
pip install ./dist/*.whl
```

### Development

To build and test the library from a local clone, you will need to install Bazel. You can find the required version in the `.bazelversion` file.

```bash
# Clone the repository
git clone https://github.com/jax-ml/jax-tpu-embedding.git
cd jax-tpu-embedding

# Run all tests
bazel test //...
```
---

## Quick Start

Here's a quick example of how to use the high-level Flax API to define a model with an embedding layer.

### 1. Define Embedding Table and Feature Specifications

First, define the structure of your embedding table (`TableSpec`) and how your features map to it (`FeatureSpec`). These specifications tell the library how to allocate memory on the SparseCores and configure the lookup hardware.

```python
import jax
import jax.numpy as jnp
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec

# Example constants
BATCH_SIZE = 32
SEQ_LEN = 16
VOCAB_SIZE = 1024
EMBEDDING_DIM = 128

# Define the embedding table properties
table_spec = embedding_spec.TableSpec(
    vocabulary_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    name='word_embedding_table',
    optimizer=embedding_spec.SGDOptimizerSpec(learning_rate=0.05)
)

# Define a feature that uses this table
feature_spec = embedding_spec.FeatureSpec(
    table_spec=table_spec,
    input_shape=(BATCH_SIZE, SEQ_LEN),  # Shape of the input IDs
    output_shape=(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM),  # Desired output activation shape
    name='word_ids'
)

# Feature specs are passed as a PyTree (e.g., a dictionary)
feature_specs = {'word_ids': feature_spec}
```

### 2. Create a Flax Model

Use the `SparseCoreEmbed` layer within a standard Flax `nn.Module`. This layer handles all the communication with the SparseCores.

```python
from flax import linen as nn
from jax_tpu_embedding.sparsecore.lib.flax import embed

class ShakespeareModel(nn.Module):
  feature_specs: embed.Nested[embedding_spec.FeatureSpec]

  @nn.compact
  def __call__(self, embedding_inputs):
    # This layer performs the embedding lookup on SparseCores.
    # `embedding_inputs` is a dictionary of integer ID tensors.
    embedding_activations = embed.SparseCoreEmbed(
        feature_specs=self.feature_specs
    )(embedding_inputs)

    # The result is a dictionary of activations, matching the feature spec keys.
    x = embedding_activations['word_ids']

    # Flatten sequence and embedding dimensions for the dense layer
    x = x.reshape((x.shape[0], -1))

    # Add a dense layer (runs on TensorCores)
    x = nn.Dense(VOCAB_SIZE)(x)
    return x
```

### 3. Initialize and Run the Training Step

The `SparseCoreEmbed` layer separates its parameters (the embedding tables, which live on SparseCore HBM) from the rest of the model's parameters (which live on TensorCore HBM). You initialize them separately and pass them to the training step.

```python
# Create a mesh for device layout
devices = jax.devices()
mesh = jax.sharding.Mesh(devices, axis_names=('data',))

# 1. Initialize the embedding tables (SparseCore parameters)
# This needs to be done under the mesh context
with mesh:
  embedding_params = embed.SparseCoreEmbed.create_embedding_variables(
      feature_specs, jax.random.PRNGKey(0)
  )

# 2. Initialize the dense model parts (TensorCore parameters)
model = ShakespeareModel(feature_specs=feature_specs)
dummy_inputs = {'word_ids': jnp.zeros((BATCH_SIZE, SEQ_LEN), dtype=jnp.int32)}
dense_params = model.init(
    jax.random.PRNGKey(1), dummy_inputs
)['params']

# 3. Define and JIT-compile the training step
@jax.jit
def train_step(dense_params, embedding_params, features, labels):
  def loss_fn(params):
    # The 'embedding' collection is automatically handled by SparseCoreEmbed
    logits = model.apply({'params': params, 'embedding': embedding_params}, features)
    # A real implementation would use a proper loss function like cross-entropy
    return jnp.mean(jnp.square(logits - labels))  # Dummy loss for demonstration

  grad_fn = jax.grad(loss_fn)
  grads = grad_fn(dense_params)

  # Gradients for `embedding_params` are computed and applied on the
  # SparseCores automatically by the `SparseCoreEmbed` layer, using the
  # optimizer defined in the TableSpec.

  # Update dense parameters using your preferred optimizer
  new_dense_params = jax.tree.map(lambda p, g: p - 0.01 * g, dense_params, grads)

  return new_dense_params

# --- Example usage with dummy data ---
features = {'word_ids': jax.random.randint(
    jax.random.PRNGKey(2), (BATCH_SIZE, SEQ_LEN), 0, VOCAB_SIZE)}
labels = jax.random.normal(jax.random.PRNGKey(3), (BATCH_SIZE, VOCAB_SIZE))

# Run one training step
new_dense_params = train_step(dense_params, embedding_params, features, labels)
```

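The single step above composes directly into a training loop. The sketch below is illustrative only and reuses names already defined in the Quick Start code (`train_step`, `dense_params`, `embedding_params`, and the constants); a real pipeline would feed actual batches rather than random dummy data.

```python
# Illustrative toy loop: run a few steps on freshly sampled dummy batches.
for step in range(3):
  ids_key, labels_key = jax.random.split(jax.random.PRNGKey(step))
  features = {'word_ids': jax.random.randint(
      ids_key, (BATCH_SIZE, SEQ_LEN), 0, VOCAB_SIZE)}
  labels = jax.random.normal(labels_key, (BATCH_SIZE, VOCAB_SIZE))
  # Dense parameters are updated here; embedding tables are updated on the
  # SparseCores as described in the comments of `train_step` above.
  dense_params = train_step(dense_params, embedding_params, features, labels)
```
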
---

## Key Concepts

- **`TableSpec`**: Defines the properties of a single embedding table, including its shape (`vocabulary_size`, `embedding_dim`), initializer, and the optimizer (e.g., `SGDOptimizerSpec`) to be run on the SparseCores.
- **`FeatureSpec`**: Describes a logical feature that maps to a `TableSpec`. It specifies the `input_shape` of the integer ID tensor and the desired `output_shape` of the resulting activation tensors. A single table can be used by multiple features, as shown in the sketch after this list.
- **`SparseCoreEmbed`**: A Flax Linen layer that acts as the entry point for embedding lookups. It takes a dictionary of feature names to ID tensors and returns a dictionary of feature names to activation tensors. It manages the embedding table parameters separately from the dense model parameters.

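Because each `FeatureSpec` references its table through `table_spec`, several features can share one `TableSpec`. The sketch below is illustrative only: the constructor arguments mirror the Quick Start example, and the feature names `query_ids` and `title_ids` are hypothetical.

```python
# Illustrative sketch: two hypothetical features sharing one embedding table.
shared_table = embedding_spec.TableSpec(
    vocabulary_size=VOCAB_SIZE,
    embedding_dim=EMBEDDING_DIM,
    name='shared_word_table',
    optimizer=embedding_spec.SGDOptimizerSpec(learning_rate=0.05)
)

feature_specs = {
    'query_ids': embedding_spec.FeatureSpec(
        table_spec=shared_table,
        input_shape=(BATCH_SIZE, SEQ_LEN),
        output_shape=(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM),
        name='query_ids'
    ),
    'title_ids': embedding_spec.FeatureSpec(
        table_spec=shared_table,
        input_shape=(BATCH_SIZE, SEQ_LEN),
        output_shape=(BATCH_SIZE, SEQ_LEN, EMBEDDING_DIM),
        name='title_ids'
    ),
}
```

`SparseCoreEmbed` then returns one activation tensor per feature key, while both lookups read from (and, during training, update) the same underlying table.
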
---

## Running the Examples

The repository includes a complete Shakespeare next-word-prediction model. You can run it using Bazel.

To run the tests for the example:

```bash
# Make sure you are at the root of the repository
bazel test //jax_tpu_embedding/sparsecore/examples/shakespeare/...
```

This will build and run both the `pmap` and `jit` + `shard_map` versions of the model.

---

## Contributing

Contributions are welcome! Please read the [CONTRIBUTING.md](CONTRIBUTING.md) file for guidelines on how to contribute to this project.

---

## License

This project is licensed under the terms of the Apache 2.0 license. See the [LICENSE](LICENSE) file for more details.
