@@ -265,7 +265,7 @@ def _add_table_variable(
         table_specs: Sequence[embedding_spec.TableSpec],
         num_shards: int,
         add_slot_variables: bool,
-    ) -> embedding.EmbeddingVariables:
+    ) -> tuple[keras.Variable, tuple[keras.Variable, ...] | None]:
         stacked_table_spec = typing.cast(
             embedding_spec.StackedTableSpec, table_specs[0].stacked_table_spec
         )
@@ -334,7 +334,7 @@ def _add_table_variable(
             slot_initializers, slot_variables
         )

-        return embedding.EmbeddingVariables(table_variable, slot_variables)
+        return table_variable, slot_variables

     @keras_utils.no_automatic_dependency_tracking
     def _sparsecore_init(
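Note on the two hunks above: the helper now returns a plain `(table_variable, slot_variables)` tuple instead of the library's `EmbeddingVariables` wrapper, which is why later hunks index `[0]` for the table. A minimal sketch of the new shape, with illustrative values only:

import keras

# Illustrative stand-ins for the real table and optimizer slot variables.
table_variable = keras.Variable(initializer="zeros", shape=(8, 4))
slot_variables = (keras.Variable(initializer="zeros", shape=(8, 4)),)

# _add_table_variable now returns this plain tuple; slot_variables may be
# None when add_slot_variables=False.
table_and_slots = (table_variable, slot_variables)
assert table_and_slots[0] is table_variable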
@@ -441,8 +441,8 @@ def sparsecore_build(
         )

         # Collect all stacked tables.
-        table_specs = embedding.get_table_specs(feature_specs)
-        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+        table_specs = embedding_utils.get_table_specs(feature_specs)
+        table_stacks = embedding_utils.get_table_stacks(table_specs)

         # Create variables for all stacked tables and slot variables.
         with sparsecore_distribution.scope():
@@ -515,8 +515,10 @@ def _sparsecore_symbolic_preprocess(
         del inputs, weights, training

         # Each stacked-table gets a ShardedCooMatrix.
-        table_specs = embedding.get_table_specs(self._config.feature_specs)
-        table_stacks = jte_table_stacking.get_table_stacks(table_specs)
+        table_specs = embedding_utils.get_table_specs(
+            self._config.feature_specs
+        )
+        table_stacks = embedding_utils.get_table_stacks(table_specs)
         stacked_table_specs = {
             stack_name: stack[0].stacked_table_spec
             for stack_name, stack in table_stacks.items()
@@ -580,18 +582,12 @@ def _sparsecore_preprocess(
         )

         layout = self._sparsecore_layout
-        print(f"-->{layout=}")
         mesh = layout.device_mesh.backend_mesh
-        print(f"-->{mesh=}")
         global_device_count = mesh.devices.size
-        print(f"-->{global_device_count=}")
         local_device_count = mesh.local_mesh.devices.size
-        print(f"{local_device_count=}")
         num_sc_per_device = jte_utils.num_sparsecores_per_device(
             mesh.devices.item(0)
         )
-        print(f"-->{num_sc_per_device=}")
-        print(f"-->{jax.process_count()=}")

         preprocessed, stats = embedding_utils.stack_and_shard_samples(
             self._config.feature_specs,
@@ -600,51 +596,44 @@ def _sparsecore_preprocess(
             global_device_count,
             num_sc_per_device,
         )
-        print(f"-->{stats=}")

         if training:
             # Synchronize input statistics across all devices and update the
             # underlying stacked tables specs in the feature specs.
+            prev_stats = embedding_utils.get_stacked_table_stats(
+                self._config.feature_specs
+            )
+
+            # Take the maximum with existing stats.
+            stats = keras.tree.map_structure(max, prev_stats, stats)
+
+            # Flatten the stats so we can more efficiently transfer them
+            # between hosts. We use jax.tree because we will later need to
+            # unflatten.
+            flat_stats, stats_treedef = jax.tree.flatten(stats)

-            # Aggregate stats across all processes/devices via pmax.
+            # In the case of multiple local CPU devices per host, we need to
+            # replicate the stats to placate JAX collectives.
             num_local_cpu_devices = jax.local_device_count("cpu")
-            print(f"-->{num_local_cpu_devices=}")
-
-            def pmax_aggregate(x: Any) -> Any:
-                if not hasattr(x, "ndim"):
-                    x = np.array(x)
-                jax.debug.print("--> x.shape={}", x.shape)
-                tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
-                jax.debug.print("--> tiled_x.shape={}", tiled_x.shape)
-                return jax.pmap(
-                    lambda y: jax.lax.pmax(y, "all_cpus"),  # type: ignore[no-untyped-call]
-                    axis_name="all_cpus",
-                    backend="cpu",
-                )(tiled_x)[0]
-
-            full_stats = jax.tree.map(pmax_aggregate, stats)
-
-            # Check if stats changed enough to warrant action.
-            stacked_table_specs = embedding.get_stacked_table_specs(
-                self._config.feature_specs
+            tiled_stats = np.tile(
+                np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1)
             )
-            changed = any(
-                np.max(full_stats.max_ids_per_partition[stack_name])
-                > spec.max_ids_per_partition
-                or np.max(full_stats.max_unique_ids_per_partition[stack_name])
-                > spec.max_unique_ids_per_partition
-                or (
-                    np.max(full_stats.required_buffer_size_per_sc[stack_name])
-                    * num_sc_per_device
-                )
-                > (spec.suggested_coo_buffer_size_per_device or 0)
-                for stack_name, spec in stacked_table_specs.items()
+
+            # Aggregate variables across all processes/devices.
+            max_across_cpus = jax.pmap(
+                lambda x: jax.lax.pmax(  # type: ignore[no-untyped-call]
+                    x, "all_cpus"
+                ),
+                axis_name="all_cpus",
+                backend="cpu",
             )
+            flat_stats = max_across_cpus(tiled_stats)[0].tolist()
+            stats = jax.tree.unflatten(stats_treedef, flat_stats)

             # Update configuration and repeat preprocessing if stats changed.
-            if changed:
-                embedding.update_preprocessing_parameters(
-                    self._config.feature_specs, full_stats, num_sc_per_device
+            if stats != prev_stats:
+                embedding_utils.update_stacked_table_stats(
+                    self._config.feature_specs, stats
                 )

                 # Re-execute preprocessing with consistent input statistics.
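The new aggregation path above flattens the stats, replicates them across the host's local CPU devices, and reduces with a pmapped `jax.lax.pmax` so every process ends up with the same maxima. A self-contained sketch of the same pattern on toy stats (single-host here; the collective also spans processes when JAX's distributed runtime is initialized):

import jax
import numpy as np

# Toy per-table stats standing in for the flattened input statistics.
stats = {"table_a": {"max_ids": 7, "buffer_size": 128}}
flat_stats, treedef = jax.tree.flatten(stats)

# One copy per local CPU device so the pmapped collective has a shard
# for every participating device.
num_local_cpu_devices = jax.local_device_count("cpu")
tiled = np.tile(np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1))

# Element-wise maximum across all devices (and processes) on the CPU backend.
max_across_cpus = jax.pmap(
    lambda x: jax.lax.pmax(x, "all_cpus"),
    axis_name="all_cpus",
    backend="cpu",
)
flat_max = max_across_cpus(tiled)[0].tolist()
print(jax.tree.unflatten(treedef, flat_max))  # same structure as `stats`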
@@ -729,8 +718,8 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding.get_table_specs(config.feature_specs)
-        sharded_tables = jte_table_stacking.stack_and_shard_tables(
+        table_specs = embedding_utils.get_table_specs(config.feature_specs)
+        sharded_tables = embedding_utils.stack_and_shard_tables(
             table_specs,
             tables,
             num_table_shards,
@@ -749,8 +738,8 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
         # Assign stacked table variables to the device values.
         keras.tree.map_structure_up_to(
             device_tables,
-            lambda embedding_variables,
-            table_value: embedding_variables.table.assign(table_value),
+            lambda table_and_slot_variables,
+            table_value: table_and_slot_variables[0].assign(table_value),
             self._table_and_slot_variables,
             device_tables,
         )
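For context on `keras.tree.map_structure_up_to` as used above: traversal stops at the depth of the first (shallow) structure, so each `(table, slots)` entry from `self._table_and_slot_variables` reaches the lambda as a whole tuple, and `[0]` picks out the table variable. A small sketch with made-up values:

import keras

# Shallow structure: one leaf per stacked table.
device_tables = {"stack_a": 1.0, "stack_b": 10.0}
# Deep structure: (table, slot_variables) pairs under the same keys.
table_and_slot_vars = {"stack_a": (2.0, (0.0,)), "stack_b": (3.0, None)}

# The lambda receives each pair intact because traversal only descends as
# deep as `device_tables`.
result = keras.tree.map_structure_up_to(
    device_tables,
    lambda pair, value: pair[0] * value,
    table_and_slot_vars,
    device_tables,
)
print(result)  # {'stack_a': 2.0, 'stack_b': 30.0}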
@@ -761,19 +750,17 @@ def _sparsecore_get_embedding_tables(self) -> dict[str, ArrayLike]:

         config = self._config
         num_table_shards = config.mesh.devices.size * config.num_sc_per_device
-        table_specs = embedding.get_table_specs(config.feature_specs)
+        table_specs = embedding_utils.get_table_specs(config.feature_specs)

         # Extract only the table variables, not the gradient slot variables.
         table_variables = {
-            name: jax.device_get(embedding_variables.table.value)
-            for name, embedding_variables in (
-                self._table_and_slot_variables.items()
-            )
+            name: jax.device_get(table_and_slots[0].value)
+            for name, table_and_slots in self._table_and_slot_variables.items()
         }

         return typing.cast(
             dict[str, ArrayLike],
-            jte_table_stacking.unshard_and_unstack_tables(
+            embedding_utils.unshard_and_unstack_tables(
                 table_specs, table_variables, num_table_shards
             ),
         )