Skip to content

Commit 2ee8f80

Browse files
committed
Use old code for preprocessing
1 parent ce15ed7 commit 2ee8f80

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def sparsecore_build(
441441
)
442442

443443
# Collect all stacked tables.
444-
table_specs = embedding_utils.get_table_specs(feature_specs)
444+
table_specs = embedding.get_table_specs(feature_specs)
445445
table_stacks = embedding_utils.get_table_stacks(table_specs)
446446

447447
# Create variables for all stacked tables and slot variables.
@@ -515,7 +515,7 @@ def _sparsecore_symbolic_preprocess(
515515
del inputs, weights, training
516516

517517
# Each stacked-table gets a ShardedCooMatrix.
518-
table_specs = embedding_utils.get_table_specs(
518+
table_specs = embedding.get_table_specs(
519519
self._config.feature_specs
520520
)
521521
table_stacks = embedding_utils.get_table_stacks(table_specs)
@@ -750,7 +750,7 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:
750750

751751
config = self._config
752752
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
753-
table_specs = embedding_utils.get_table_specs(config.feature_specs)
753+
table_specs = embedding.get_table_specs(config.feature_specs)
754754

755755
# Extract only the table variables, not the gradient slot variables.
756756
table_variables = {

0 commit comments

Comments
 (0)