Move preprocess to base distributed embedding class. #93
Conversation
```diff
@@ -183,6 +190,128 @@ def _create_table_and_slot_variables(
     return output


+def create_feature_samples(
```
These are also added to `test_utils.py` below. This looks like a mistake.
The two are actually different. This one generates random samples. The one in `embedding_utils.py` combines the inputs and weights into a format for use with `jax_tpu_embedding`.
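For context, a minimal sketch of what the random-sample flavor might look like; the name, signature, and logic here are hypothetical illustrations, not the actual implementation:

```python
import numpy as np

def create_feature_samples(batch_size, vocab_size, max_ids_per_sample, seed=0):
    """Hypothetical sketch: generate ragged random token/weight samples."""
    rng = np.random.default_rng(seed)
    tokens, weights = [], []
    for _ in range(batch_size):
        # Each sample gets a random number of token ids (a ragged batch).
        n = int(rng.integers(1, max_ids_per_sample + 1))
        tokens.append(rng.integers(0, vocab_size, size=n))
        weights.append(rng.random(n, dtype=np.float32))
    return tokens, weights
```

The `embedding_utils.py` helper, by contrast, would take samples like these and pack them into the input/weight structure that `jax_tpu_embedding` consumes.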
This function was only used in the `embedding_lookup_test`, which I previously didn't include.
Renamed to `generate_feature_samples` to avoid confusion.
```python
model.fit(preprocessed_training_dataset, epochs=10)
```

For non-JAX backends, preprocessing will bundle together the inputs and
Can you explicitly mention that preprocessing is optional?
Done.
Added "optional" here again to reiterate.
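To make the optional part concrete, a usage sketch: `model`, `embedding_layer`, `training_dataset`, and the exact `preprocess` signature are assumptions for illustration, not the confirmed API.

```python
# Without preprocessing: pass the raw dataset straight to fit().
model.fit(training_dataset, epochs=10)

# With optional preprocessing: bundle the inputs (and weights) ahead of
# time inside the input pipeline, then train on the preprocessed dataset.
preprocessed_training_dataset = training_dataset.map(
    lambda inputs, labels: (embedding_layer.preprocess(inputs), labels)
)
model.fit(preprocessed_training_dataset, epochs=10)
```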
Force-pushed from 091472c to d4ff069.
```diff
@@ -275,8 +404,8 @@ def _compute_expected_lookup_grad(
     embedding_dim = activation_gradients.shape[1]
     sample_lengths = jnp.array([len(sample) for sample in samples.tokens])
     rows = jnp.repeat(jnp.arange(batch_size), sample_lengths)
-    cols = jnp.concatenate(jnp.unstack(samples.tokens))
-    vals = jnp.concatenate(jnp.unstack(samples.weights)).reshape(-1, 1)
+    cols = jnp.concatenate(np.unstack(samples.tokens))  # type: ignore[attr-defined]
```
I really don't understand the mypy issue with this. The mypy job passes locally with the exact same versions of `numpy` and `mypy` installed. The `np.unstack` attribute does exist in numpy 2.2.6, which is what the CI installs. I need `np.unstack` because the arrays are ragged, so `jnp.unstack` throws an error.
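A minimal repro of the ragged-array constraint, assuming `samples.tokens` is a 1-D NumPy object array of variable-length token arrays (`np.unstack` requires NumPy >= 2.1):

```python
import numpy as np
import jax.numpy as jnp

# Ragged per-sample token arrays stored in a 1-D object array.
tokens = np.empty(3, dtype=object)
tokens[0] = np.array([1, 2, 3])
tokens[1] = np.array([4])
tokens[2] = np.array([5, 6])

# jnp.unstack(tokens) would fail here: JAX first converts its argument to
# a rectangular device array, which is impossible for ragged input.

# np.unstack iterates over axis 0 without converting, yielding the
# per-sample arrays, which can then be flattened with a concatenate.
cols = jnp.concatenate(np.unstack(tokens))  # -> [1 2 3 4 5 6]
```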
Added `unused-ignore` to the `pyproject.toml` to prevent local failure.
Also updated some minor issues encountered in testing on sparsecores.