
Commit 9ced4ed

Use old code for preprocessing
1 parent 5782407 commit 9ced4ed

File tree

1 file changed: +44, -57 lines changed


keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 44 additions & 57 deletions
@@ -265,7 +265,7 @@ def _add_table_variable(
         table_specs: Sequence[embedding_spec.TableSpec],
         num_shards: int,
         add_slot_variables: bool,
-    ) -> embedding.EmbeddingVariables:
+    ) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
         stacked_table_spec = typing.cast(
             embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
         )
@@ -334,7 +334,7 @@ def _add_table_variable(
             slot_initializers, slot_variables
         )

-        return embedding.EmbeddingVariables(table_variable, slot_variables)
+        return table_variable, slot_variables

     @keras_utils.no_automatic_dependency_tracking
     def _sparsecore_init(
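With this change, `_add_table_variable` returns a plain `(table, slots)` tuple instead of the `embedding.EmbeddingVariables` wrapper, so downstream code indexes `[0]` rather than reading a `.table` attribute (see the later hunks). A minimal sketch of the new contract, using hypothetical variables with made-up shapes:

import numpy as np
import keras

# Hypothetical table and slot variables; shapes and names are illustrative only.
table = keras.Variable(np.zeros((8, 4), dtype="float32"), name="table_0")
slots = (keras.Variable(np.zeros((8, 4), dtype="float32"), name="momentum_0"),)

# New return value: a plain tuple, with None when no slot variables exist.
table_and_slots: tuple[keras.Variable, tuple[keras.Variable, ...] | None] = (
    table,
    slots,
)

# Consumers now index the tuple instead of accessing a named .table field.
table_and_slots[0].assign(keras.ops.ones((8, 4)))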
@@ -441,8 +441,8 @@ def sparsecore_build(
         )

         # Collect all stacked tables.
-        table_specs = embedding.get_table_specs(feature_specs)
-        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+        table_specs = embedding_utils.get_table_specs(feature_specs)
+        table_stacks = embedding_utils.get_table_stacks(table_specs)

         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -515,8 +515,10 @@ def _sparsecore_symbolic_preprocess(
         del inputs, weights, training

         # Each stacked-table gets a ShardedCooMatrix.
-        table_specs = embedding.get_table_specs(self._config.feature_specs)
-        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+        table_specs = embedding_utils.get_table_specs(
+            self._config.feature_specs
+        )
+        table_stacks = embedding_utils.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
@@ -580,18 +582,12 @@ def _sparsecore_preprocess(
         )

         layout = self._sparsecore_layout
-        print(f"-->{layout=}")
         mesh = layout.device_mesh.backend_mesh
-        print(f"-->{mesh=}")
         global_device_count = mesh.devices.size
-        print(f"-->{global_device_count=}")
         local_device_count = mesh.local_mesh.devices.size
-        print(f"{local_device_count=}")
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
             mesh.devices.item(0)
         )
-        print(f"-->{num_sc_per_device=}")
-        print(f"-->{jax.process_count()=}")

         preprocessed, stats = embedding_utils.stack_and_shard_samples(
             self._config.feature_specs,
@@ -600,51 +596,44 @@ def _sparsecore_preprocess(
             global_device_count,
             num_sc_per_device,
         )
-        print(f"-->{stats=}")

         if training:
             # Synchronize input statistics across all devices and update the
             # underlying stacked tables specs in the feature specs.
+            prev_stats = embedding_utils.get_stacked_table_stats(
+                self._config.feature_specs
+            )
+
+            # Take the maximum with existing stats.
+            stats = keras.tree.map_structure(max, prev_stats, stats)
+
+            # Flatten the stats so we can more efficiently transfer them
+            # between hosts. We use jax.tree because we will later need to
+            # unflatten.
+            flat_stats, stats_treedef = jax.tree.flatten(stats)

-            # Aggregate stats across all processes/devices via pmax.
+            # In the case of multiple local CPU devices per host, we need to
+            # replicate the stats to placate JAX collectives.
             num_local_cpu_devices = jax.local_device_count("cpu")
-            print(f"-->{num_local_cpu_devices=}")
-
-            def pmax_aggregate(x: Any) -> Any:
-                if not hasattr(x, "ndim"):
-                    x = np.array(x)
-                jax.debug.print("--> x.shape={}", x.shape)
-                tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
-                jax.debug.print("--> tiled_x.shape={}", tiled_x.shape)
-                return jax.pmap(
-                    lambda y: jax.lax.pmax(y, "all_cpus"),  # type: ignore[no-untyped-call]
-                    axis_name="all_cpus",
-                    backend="cpu",
-                )(tiled_x)[0]
-
-            full_stats = jax.tree.map(pmax_aggregate, stats)
-
-            # Check if stats changed enough to warrant action.
-            stacked_table_specs = embedding.get_stacked_table_specs(
-                self._config.feature_specs
+            tiled_stats = np.tile(
+                np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
             )
-            changed = any(
-                np.max(full_stats.max_ids_per_partition[stack_name])
-                > spec.max_ids_per_partition
-                or np.max(full_stats.max_unique_ids_per_partition[stack_name])
-                > spec.max_unique_ids_per_partition
-                or (
-                    np.max(full_stats.required_buffer_size_per_sc[stack_name])
-                    * num_sc_per_device
-                )
-                > (spec.suggested_coo_buffer_size_per_device or 0)
-                for stack_name, spec in stacked_table_specs.items()
+
+            # Aggregate variables across all processes/devices.
+            max_across_cpus = jax.pmap(
+                lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
+                    x, "all_cpus"
+                ),
+                axis_name="all_cpus",
+                backend="cpu",
             )
+            flat_stats = max_across_cpus(tiled_stats)[0].tolist()
+            stats = jax.tree.unflatten(stats_treedef, flat_stats)

             # Update configuration and repeat preprocessing if stats changed.
-            if changed:
-                embedding.update_preprocessing_parameters(
-                    self._config.feature_specs, full_stats, num_sc_per_device
+            if stats != prev_stats:
+                embedding_utils.update_stacked_table_stats(
+                    self._config.feature_specs, stats
                 )

                 # Re-execute preprocessing with consistent input statistics.
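For context, the new path above flattens the statistics tree, replicates the flat values across the local CPU devices, takes an element-wise maximum across all processes with `jax.pmap` and `jax.lax.pmax`, and then unflattens the result. A minimal standalone sketch of that pattern follows; the `example_stats` structure and its values are made up for illustration and are not the real per-stack statistics:

import jax
import numpy as np

# Made-up stats: a tree of integer leaves keyed by table stack.
example_stats = {"table_stack_0": [32, 16], "table_stack_1": [8, 4]}

# Flatten so the values can be shipped as one small integer array.
flat_stats, treedef = jax.tree.flatten(example_stats)

# Replicate the flat stats once per local CPU device so pmap sees one row
# per participating device.
num_local_cpu_devices = jax.local_device_count("cpu")
tiled = np.tile(np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1))

# Element-wise maximum across all processes/devices on the CPU backend.
max_across_cpus = jax.pmap(
    lambda x: jax.lax.pmax(x, "all_cpus"),
    axis_name="all_cpus",
    backend="cpu",
)
flat_max = max_across_cpus(tiled)[0].tolist()

# Restore the original tree structure with the aggregated values.
aggregated = jax.tree.unflatten(treedef, flat_max)

In a single-process run this is a no-op maximum, but across hosts it gives every process the same statistics before the stacked table specs are updated.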
@@ -729,8 +718,8 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding.get_table_specs(config.feature_specs)
-        sharded_tables = jte_table_stacking.stack_and_shard_tables(
+        table_specs = embedding_utils.get_table_specs(config.feature_specs)
+        sharded_tables = embedding_utils.stack_and_shard_tables(
             table_specs,
             tables,
             num_table_shards,
@@ -749,8 +738,8 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
         # Assign stacked table variables to the device values.
         keras.tree.map_structure_up_to(
             device_tables,
-            lambda embedding_variables,
-            table_value: embedding_variables.table.assign(table_value),
+            lambda table_and_slot_variables,
+            table_value: table_and_slot_variables[0].assign(table_value),
             self._table_and_slot_variables,
             device_tables,
         )
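The lambda above receives each whole `(table, slots)` tuple because `keras.tree.map_structure_up_to` only descends as far as the shallow structure (`device_tables`), which is why the new code indexes `[0]`. A small sketch of that behaviour with hypothetical stand-in values rather than real variables:

import keras

# Hypothetical stand-ins: plain strings instead of keras.Variable objects.
table_and_slot_vars = {"stack_0": ("table_var", ("slot_var",))}
device_tables = {"stack_0": "new_table_value"}

# The shallow structure (first argument) stops recursion at the dict values,
# so the lambda sees the whole (table, slots) tuple for each stack.
updated = keras.tree.map_structure_up_to(
    device_tables,
    lambda table_and_slots, value: (value, table_and_slots[1]),
    table_and_slot_vars,
    device_tables,
)
print(updated)  # {'stack_0': ('new_table_value', ('slot_var',))}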
@@ -761,19 +750,17 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding.get_table_specs(config.feature_specs)
+        table_specs = embedding_utils.get_table_specs(config.feature_specs)

         # Extract only the table variables, not the gradient slot variables.
         table_variables = {
-            name: jax.device_get(embedding_variables.table.value)
-            for name, embedding_variables in (
-                self._table_and_slot_variables.items()
-            )
+            name: jax.device_get(table_and_slots[0].value)
+            for name, table_and_slots in self._table_and_slot_variables.items()
         }

         return typing.cast(
             dict[str, ArrayLike],
-            jte_table_stacking.unshard_and_unstack_tables(
+            embedding_utils.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )

0 commit comments
