Move preprocess to base distributed embedding class. #93

Merged: 1 commit into keras-team:main on May 21, 2025

Conversation

cantonios (Collaborator) commented:

Also fixed some minor issues encountered while testing on SparseCores.

@cantonios requested a review from hertschuh on May 20, 2025 23:43
@@ -183,6 +190,128 @@ def _create_table_and_slot_variables(
return output


def create_feature_samples(
hertschuh (Collaborator) commented:

These are also added to test_utils.py below. This looks like a mistake.

cantonios (Collaborator, Author) replied:

The two are actually different. This one generates random samples. The one in embedding_utils.py combines the inputs and weights into a format for use with jax_tpu_embedding.
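As an illustration of the distinction, here is a minimal sketch of what a random-sample generator of this kind might look like. The signature and the ragged (tokens, weights) layout are assumptions for illustration, not the actual test_utils.py code:

```python
import numpy as np

def create_feature_samples(batch_size, vocab_size, max_ids, seed=0):
    # Hypothetical sketch: produce ragged per-sample token ids and weights
    # for embedding-lookup tests. Each sample draws its own length, so the
    # results cannot be stacked into a rectangular array.
    rng = np.random.default_rng(seed)
    tokens, weights = [], []
    for _ in range(batch_size):
        n = int(rng.integers(1, max_ids + 1))
        tokens.append(rng.integers(0, vocab_size, size=n))
        weights.append(rng.random(n).astype(np.float32))
    return tokens, weights
```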

cantonios (Collaborator, Author) replied:

This function was only used in embedding_lookup_test, which I hadn't previously included.

cantonios (Collaborator, Author) replied:

Renamed to generate_feature_samples to avoid confusion.

```python
model.fit(preprocessed_training_dataset, epochs=10)
```

For non-JAX backends, preprocessing will bundle together the inputs and weights.
hertschuh (Collaborator) commented:

Can you explicitly mention that preprocessing is optional?

cantonios (Collaborator, Author) replied:

Done.

cantonios (Collaborator, Author) replied:

Added "optional" here again to reiterate.
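To make the optionality concrete, here is a hedged sketch of the two usage paths; `embedding_layer`, `training_dataset`, and the exact `preprocess` signature are assumptions based on the surrounding docs, not the verbatim API:

```python
# Path 1: no preprocessing; pass the raw dataset straight to fit().
# Preprocessing is optional.
model.fit(training_dataset, epochs=10)

# Path 2: apply the layer's preprocess() to each batch up front, then
# train on the preprocessed dataset, as in the docstring snippet above.
preprocessed_training_dataset = training_dataset.map(
    lambda inputs, weights: embedding_layer.preprocess(inputs, weights)
)
model.fit(preprocessed_training_dataset, epochs=10)
```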

@cantonios force-pushed the preprocess branch 2 times, most recently from 091472c to d4ff069 on May 21, 2025 00:55
@@ -275,8 +404,8 @@ def _compute_expected_lookup_grad(
embedding_dim = activation_gradients.shape[1]
sample_lengths = jnp.array([len(sample) for sample in samples.tokens])
rows = jnp.repeat(jnp.arange(batch_size), sample_lengths)
-cols = jnp.concatenate(jnp.unstack(samples.tokens))
-vals = jnp.concatenate(jnp.unstack(samples.weights)).reshape(-1, 1)
+cols = jnp.concatenate(np.unstack(samples.tokens))  # type: ignore[attr-defined]
cantonios (Collaborator, Author) commented:

I really don't understand the mypy issue with this. The mypy job passes locally with the exact same version of numpy and mypy installed. The np.unstack attribute does exist in numpy 2.2.6, which is what is installed by the CI. I need np.unstack because the arrays are ragged, so jnp.unstack throws an error.
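For context, a small repro of the raggedness issue, assuming samples.tokens is a NumPy object array holding per-sample token arrays of different lengths (np.unstack requires NumPy >= 2.1):

```python
import numpy as np
import jax.numpy as jnp

# Ragged batch stored as an object array; this layout is an assumption
# based on the comment above.
tokens = np.empty(3, dtype=object)
tokens[0] = np.array([1, 2, 3])
tokens[1] = np.array([4])
tokens[2] = np.array([5, 6])

# jnp.unstack(tokens) raises: JAX first coerces its argument to a
# rectangular jax.Array, which is impossible for ragged data.
# np.unstack merely indexes along axis 0, yielding the stored
# per-sample arrays unchanged, so the concatenation succeeds.
cols = jnp.concatenate(np.unstack(tokens))  # type: ignore[attr-defined]
print(cols)  # [1 2 3 4 5 6]
```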

cantonios (Collaborator, Author) replied:

Added unused-ignore to the pyproject.toml to prevent local failure.

Commit: Also updated some minor issues encountered in testing on sparsecores.
@hertschuh merged commit 91422d5 into keras-team:main on May 21, 2025. 5 checks passed.