diff --git a/jax_tpu_embedding/sparsecore/lib/core/BUILD b/jax_tpu_embedding/sparsecore/lib/core/BUILD index 20b77982..f1584584 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/core/BUILD @@ -67,6 +67,7 @@ cc_library( "@eigen_archive//:eigen3", "@tsl//tsl/profiler/lib:traceme", "@xla//xla:util", + "@xla//xla/tsl/concurrency:async_value", ], ) @@ -433,6 +434,7 @@ cc_test( "@com_google_absl//absl/types:span", "@com_google_benchmark//:benchmark_main", "@eigen_archive//:eigen3", + "@xla//xla/tsl/concurrency:async_value", ], ) diff --git a/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h index d72611fc..98b7eeda 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h +++ b/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h @@ -16,6 +16,7 @@ #include #include +#include #include "absl/base/attributes.h" // from @com_google_absl #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" @@ -52,6 +53,13 @@ class AbstractInputBatch { // Returns the total number of embedding IDs across all samples. virtual int64_t id_count() const = 0; + // Returns number of ids in rows [start_row, end_row). + // If not implemented by a subclass, returns std::nullopt. + virtual std::optional GetIdsCountInSlice(int start_row, + int end_row) const { + return std::nullopt; + } + // Returns true if the input batch has variable weights. virtual bool HasVariableWeights() const { return true; } diff --git a/jax_tpu_embedding/sparsecore/lib/core/coo_format.h b/jax_tpu_embedding/sparsecore/lib/core/coo_format.h index 938508b4..b27a5fc5 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/coo_format.h +++ b/jax_tpu_embedding/sparsecore/lib/core/coo_format.h @@ -64,6 +64,7 @@ struct CooFormat { static constexpr uint32_t kDataMask = (1 << kDataBits) - 1; // Bit offset for rotated_col_id in grouping key. static constexpr uint32_t kRotatedColIdOffset = kDataBits; + // Bit offset for bucket_id in grouping key. static constexpr uint32_t kBucketIdOffset = kRotatedColIdOffset + 32; diff --git a/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc b/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc index 30862213..7d4c0b72 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc @@ -35,6 +35,8 @@ #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" #include "jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h" #include "jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h" +#include "xla/tsl/concurrency/async_value.h" // from @xla +#include "xla/tsl/concurrency/async_value_ref.h" // from @xla namespace jax_sc_embedding { @@ -168,6 +170,8 @@ void BM_ExtractCooTensors(benchmark::State& state) { std::vector stacked_table_metadata; stacked_table_metadata.reserve(num_features); for (int i = 0; i < num_features; ++i) { + // Set to INT_MAX to avoid ID dropping and observe the actual statistics of + // the generated data. This doesn't affect performance of grouping itself. 
stacked_table_metadata.push_back(StackedTableMetadata( absl::StrCat("table_", i), /*feature_index=*/i, /*max_ids_per_partition=*/std::numeric_limits::max(), @@ -188,9 +192,13 @@ void BM_ExtractCooTensors(benchmark::State& state) { }; for (auto s : state) { - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata, absl::MakeSpan(input_batches), - /*local_device_id=*/0, options); + std::vector> results_av = + internal::ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync( + stacked_table_metadata, absl::MakeSpan(input_batches), + /*local_device_id=*/0, options); + for (auto& av : results_av) { + tsl::BlockUntilReady(av.GetAsyncValue()); + } } } BENCHMARK(BM_ExtractCooTensors) @@ -233,6 +241,7 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) { .enable_minibatching = true, }; + // Extract COO tensors for all features on a single local device. ExtractedCooTensors extracted_coo_tensors = internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( stacked_table_metadata_list, absl::MakeSpan(input_batches), @@ -248,8 +257,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) { if (state.thread_index() == 0) { SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata_list[0], options, - stats_per_device, minibatching_required); + extracted_coo_tensors, + stacked_table_metadata_list[0], options, stats_per_device, + minibatching_required); LogStats(stats_per_device.max_ids_per_partition, "Max ids per partition across all global SCs"); LogStats(stats_per_device.max_unique_ids_per_partition, @@ -258,8 +268,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) { for (auto s : state) { SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata_list[0], options, - stats_per_device, minibatching_required); + extracted_coo_tensors, + stacked_table_metadata_list[0], options, stats_per_device, + minibatching_required); } } BENCHMARK(BM_SortAndGroup_Phase1) diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc index 960c7711..720fbde3 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc @@ -19,6 +19,7 @@ #include #include #include +#include #include #include #include @@ -41,6 +42,7 @@ #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" #include "jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h" #include "jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h" +#include "xla/tsl/concurrency/async_value.h" // from @xla #include "xla/tsl/concurrency/async_value_ref.h" // from @xla #include "tsl/platform/env.h" // from @tsl #include "tsl/platform/statusor.h" // from @tsl @@ -68,14 +70,34 @@ void ExtractCooTensorsForSingleFeatureSlice( const int local_device_id, const int feature_slice_id, const int feature_slices_per_device, const PreprocessSparseDenseMatmulInputOptions& options, + const int batch_size_for_device_total, ExtractedCooTensors& extracted_coo_tensors) { const int feature_index = metadata.feature_index; const std::unique_ptr& curr_batch = input_batches[feature_index]; const int num_samples = curr_batch->size(); + const int num_samples_per_split = + num_samples / (options.local_device_count * feature_slices_per_device); + const int start_index = + (local_device_id * feature_slices_per_device + feature_slice_id) * + num_samples_per_split; + int end_index = std::min(num_samples, start_index + 
num_samples_per_split); - const int batch_size_per_slice = xla::CeilOfRatio( - extracted_coo_tensors.batch_size_for_device, feature_slices_per_device); + // Reserve space in COO tensor vector if input batch provides slice + // count. + std::optional ids_in_slice = + curr_batch->GetIdsCountInSlice(start_index, end_index); + + if (ids_in_slice.has_value()) { + extracted_coo_tensors.coo_tensors.reserve(ids_in_slice.value()); + } else { + extracted_coo_tensors.coo_tensors.reserve(xla::CeilOfRatio( + curr_batch->id_count(), + options.local_device_count * feature_slices_per_device)); + } + + const int batch_size_per_slice = + xla::CeilOfRatio(batch_size_for_device_total, feature_slices_per_device); CHECK_GT(feature_slices_per_device, 0); CHECK_GT(options.global_device_count, 0); @@ -89,21 +111,13 @@ void ExtractCooTensorsForSingleFeatureSlice( const int col_offset = metadata.col_offset; const int col_shift = metadata.col_shift; - const int num_samples_per_split = - num_samples / (options.local_device_count * feature_slices_per_device); - const int start_index = - (local_device_id * feature_slices_per_device + feature_slice_id) * - num_samples_per_split; - int end_index = std::min(num_samples, start_index + num_samples_per_split); - - // In the case of feature stacking, we need to group all the COO tensors - // at this stage (i.e., before the sorting later on). VLOG(2) << absl::StrFormat( "Extracting COO Tensor from feature #%d from row %d to %d " "(local_device_id = %d, feature_slice_id = %d, row_offset = %d, " "batch_size_per_slice = %d)", feature_index, start_index, end_index, local_device_id, feature_slice_id, row_offset, batch_size_per_slice); + curr_batch->ExtractCooTensors( { .slice_start = start_index, @@ -187,25 +201,23 @@ struct TableState { }; template -void SortAndGroupCooTensorsForTableState( - TableState& state, int local_device, +PartitionedCooTensors SortAndGroupCooTensorsForTableState( + bool has_variable_weights, + const StackedTableMetadata& stacked_table_metadata, const PreprocessSparseDenseMatmulInputOptions& options, - internal::StatsPerDevice& stats, SplitType& split) { - if (state.has_variable_weights) { - state.partitioned_coo_tensors_per_device[local_device] = - SortAndGroupCooTensorsPerLocalDevice( - state.extracted_coo_tensors_per_device[local_device], - state.stacked_table_metadata[0], options, stats, split); + internal::StatsPerDevice& stats, SplitType& split, + const ExtractedCooTensors& extracted_coo_tensors) { + if (has_variable_weights) { + return SortAndGroupCooTensorsPerLocalDevice( + extracted_coo_tensors, stacked_table_metadata, options, stats, split); } else { - state.partitioned_coo_tensors_per_device[local_device] = - SortAndGroupCooTensorsPerLocalDevice( - state.extracted_coo_tensors_per_device[local_device], - state.stacked_table_metadata[0], options, stats, split); + return SortAndGroupCooTensorsPerLocalDevice( + extracted_coo_tensors, stacked_table_metadata, options, stats, split); } } -// Extracts, sorts, and groups COO tensors for a single stacked table across -// all local devices. This function populates +// Extracts, sorts, and groups COO tensors for a single stacked table across all +// local devices. This function populates // `state.extracted_coo_tensors_per_device` and // `state.partitioned_coo_tensors_per_device`. 
void ExtractSortAndGroupCooTensorsForTable( @@ -221,21 +233,49 @@ void ExtractSortAndGroupCooTensorsForTable( }); for (int local_device = 0; local_device < options.local_device_count; ++local_device) { - options.async_task_scheduler( - [&, local_device, &state = state, input_batches] { - state.extracted_coo_tensors_per_device[local_device] = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - state.stacked_table_metadata, input_batches, local_device, - options); - - internal::StatsPerDevice stats_per_device = - state.stats_per_host.GetStatsPerDevice(local_device); - SortAndGroupCooTensorsForTableState( - state, local_device, options, stats_per_device, - state.table_minibatching_required); - state.dropped_id_count_per_device[local_device] = - stats_per_device.dropped_id_count; - counter.DecrementCount(); + std::vector> + extracted_coo_tensors_av = + internal::ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync( + state.stacked_table_metadata, input_batches, local_device, + options); + + tsl::RunWhenReady( + absl::MakeConstSpan(extracted_coo_tensors_av), + [&, local_device, extracted_coo_tensors_av, &state = state]() mutable { + options.async_task_scheduler([&, local_device, + extracted_coo_tensors_av, + &state = state]() mutable { + int64_t total_ids_for_device = 0; + for (const auto& av : extracted_coo_tensors_av) { + total_ids_for_device += av.get().coo_tensors.size(); + } + + ExtractedCooTensors& merged_result = + state.extracted_coo_tensors_per_device[local_device]; + merged_result = ExtractedCooTensors( + options.num_sc_per_device, + extracted_coo_tensors_av[0].get().batch_size_for_device); + merged_result.coo_tensors.reserve(total_ids_for_device); + + for (auto& av : extracted_coo_tensors_av) { + merged_result.Append(std::move(av.get())); + } + state.batch_size_for_device = merged_result.batch_size_for_device; + internal::StatsPerDevice stats_per_device = + state.stats_per_host.GetStatsPerDevice(local_device); + + state.partitioned_coo_tensors_per_device[local_device] = + SortAndGroupCooTensorsForTableState( + state.has_variable_weights, state.stacked_table_metadata[0], + options, stats_per_device, + state.table_minibatching_required, + state.extracted_coo_tensors_per_device[local_device]); + + state.dropped_id_count_per_device[local_device] = + stats_per_device.dropped_id_count; + + counter.DecrementCount(); + }); }); } } @@ -278,9 +318,11 @@ void CreateMinibatchingBucketsForTable( options.num_sc_per_device); internal::StatsPerDevice dummy_stats = dummy_stats_host.GetStatsPerDevice(0); - SortAndGroupCooTensorsForTableState(state, local_device, options, - dummy_stats, - state.table_minibatching_split); + state.partitioned_coo_tensors_per_device[local_device] = + SortAndGroupCooTensorsForTableState( + state.has_variable_weights, state.stacked_table_metadata[0], + options, dummy_stats, state.table_minibatching_split, + state.extracted_coo_tensors_per_device[local_device]); state.dropped_id_count_per_device[local_device] = dummy_stats.dropped_id_count; counter.DecrementCount(); @@ -301,11 +343,20 @@ inline bool Serialize(bool value) { return value; } inline bool Deserialize(bool value) { return value; } // Extract the COO tensors for all features. 
-ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice( +std::vector> +ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync( const absl::Span stacked_table_metadata, absl::Span> input_batches, const int local_device_id, const PreprocessSparseDenseMatmulInputOptions& options) { + tsl::profiler::TraceMe traceme([&] { + return tsl::profiler::TraceMeEncode( + absl::StrCat("ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync-", + stacked_table_metadata[0].name), + {{"batch_number", options.batch_number}}); + }); + // Calculate total batch size and total number of IDs for this device by + // summing up sizes and ID counts of all features in the stacked table. int batch_size_for_device = 0; int64_t total_ids_for_device = 0; for (const auto& feature_metadata : stacked_table_metadata) { @@ -319,6 +370,10 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice( CheckDeviceBatchSize(batch_size_for_device, options.num_sc_per_device, stacked_table_metadata[0].name); + // Determine the number of slices per feature based on stacking strategy. + // kStackThenSplit: All features are stacked first, then split, so 1 slice. + // kSplitThenStack: Each feature is split, then stacked, so + // num_sc_per_device slices. int feature_slices_per_device; switch (options.feature_stacking_strategy) { case FeatureStackingStrategy::kStackThenSplit: @@ -348,15 +403,84 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice( // SC1: F1_2, F2_2, ... Fn_2, // <- batch_size_per_slice // ... // <- batch_size_per_slice // SCk: F1_k, F2_k, ..., Fn_k // <- batch_size_per_slice - for (int feature_slice_id = 0; feature_slice_id < feature_slices_per_device; - ++feature_slice_id) { - for (const auto& feature_metadata : stacked_table_metadata) { - ExtractCooTensorsForSingleFeatureSlice( - feature_metadata, input_batches, local_device_id, feature_slice_id, - feature_slices_per_device, options, extracted_coo_tensors); + + // Holds async results for each feature and slice. The size is + // num_features * num_slices. + std::vector> feature_results_av; + feature_results_av.reserve(feature_slices_per_device * + stacked_table_metadata.size()); + for (int i = 0; + i < feature_slices_per_device * stacked_table_metadata.size(); ++i) { + feature_results_av.push_back( + tsl::MakeUnconstructedAsyncValueRef()); + } + + // Schedule all feature slices for parallel execution. + for (int feature_idx = 0; feature_idx < stacked_table_metadata.size(); + ++feature_idx) { + for (int feature_slice_id = 0; feature_slice_id < feature_slices_per_device; + ++feature_slice_id) { + options.async_task_scheduler([=, feature_idx = feature_idx, + feature_slice_id = feature_slice_id, + feature_results_av = feature_results_av, + &options]() mutable { + const StackedTableMetadata& feature_metadata = + stacked_table_metadata[feature_idx]; + // Calculate the index in `feature_results_av` for the current + // feature and slice. + int result_idx = + feature_slice_id * stacked_table_metadata.size() + feature_idx; + + ExtractedCooTensors result(options.num_sc_per_device, + batch_size_for_device); + + // Extract COO tensors for the current slice. + ExtractCooTensorsForSingleFeatureSlice( + feature_metadata, input_batches, local_device_id, feature_slice_id, + feature_slices_per_device, options, batch_size_for_device, result); + + // Emplace result into async value to signal completion. 
+ feature_results_av[result_idx].emplace(std::move(result)); + }); } } - return extracted_coo_tensors; + + return feature_results_av; +} + +// Sync version of ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync. +// Merges the results from all features and slices into a single +// ExtractedCooTensors object. This is only for testing. +ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice( + const absl::Span stacked_table_metadata, + absl::Span> input_batches, + const int local_device_id, + const PreprocessSparseDenseMatmulInputOptions& options) { + std::vector> results_av = + ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync( + stacked_table_metadata, input_batches, local_device_id, options); + tsl::AsyncValueRef merged_av = + tsl::MakeUnconstructedAsyncValueRef(); + tsl::RunWhenReady( + absl::MakeConstSpan(results_av), + [results_av, merged_av, &options]() mutable { + int64_t total_ids_for_device = 0; + for (const auto& av : results_av) { + total_ids_for_device += av.get().coo_tensors.size(); + } + + ExtractedCooTensors merged_result( + options.num_sc_per_device, + results_av[0].get().batch_size_for_device); + merged_result.coo_tensors.reserve(total_ids_for_device); + for (auto& av : results_av) { + merged_result.Append(std::move(av.get())); + } + merged_av.emplace(std::move(merged_result)); + }); + + tsl::BlockUntilReady(merged_av.GetAsyncValue()); + return merged_av.get(); } } // namespace internal @@ -520,6 +644,7 @@ void FillDeviceBuffersForTable( global_minibatching_split] { PartitionedCooTensors& grouped_coo_tensors = state.partitioned_coo_tensors_per_device[local_device]; + if (options.enable_minibatching && global_minibatching_required) { grouped_coo_tensors.Merge(global_minibatching_split); } @@ -680,6 +805,7 @@ PreprocessSparseDenseMatmulInput( }); absl::BlockingCounter counter(table_states.size() * options.local_device_count); + for (auto& state : table_states) { ExtractSortAndGroupCooTensorsForTable(state, input_batches, options, counter); diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h index f4a8e483..937e4b39 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h @@ -23,6 +23,7 @@ #include "absl/types/span.h" // from @com_google_absl #include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h" #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" +#include "xla/tsl/concurrency/async_value_ref.h" // from @xla namespace jax_sc_embedding { @@ -42,6 +43,21 @@ struct SparseDenseMatmulInputStats { }; namespace internal { + +// Asynchronously extracts COO tensors for all features on a given local +// device. Returns a vector of AsyncValueRefs, one for each feature slice, +// which will become ready as extraction completes. +std::vector> +ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync( + absl::Span stacked_table_metadata, + absl::Span> input_batches, + int local_device_id, + const PreprocessSparseDenseMatmulInputOptions& options); + +// Synchronously extracts and merges COO tensors for all features on a given +// local device. This function calls the asynchronous version and blocks until +// all results are ready, then merges them into a single ExtractedCooTensors +// object. 
ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice( absl::Span stacked_table_metadata, absl::Span> input_batches, diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h index 0fcc5b6a..905155b8 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -27,6 +28,7 @@ #include "absl/base/nullability.h" // from @com_google_absl #include "absl/container/flat_hash_map.h" // from @com_google_absl #include "absl/functional/any_invocable.h" // from @com_google_absl +#include "absl/log/check.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl #include "Eigen/Core" // from @eigen_archive @@ -34,6 +36,7 @@ #include "jax_tpu_embedding/sparsecore/lib/core/coo_format.h" #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h" #include "jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h" +#include "xla/tsl/concurrency/async_value_ref.h" // from @xla namespace jax_sc_embedding { @@ -74,6 +77,44 @@ struct OutputCsrArrays { StackedTableMap> lhs_gains; }; +struct ExtractedCooTensors { + std::vector coo_tensors; + // Number of samples these coo_tensors are extracted from. + int batch_size_for_device; + // Count coo tensors per SC for efficient allocation of vector for sorting and + // grouping them. Might be lower after deduplication. + std::vector coo_tensors_per_sc; + + ExtractedCooTensors() : ExtractedCooTensors(0, 0) {} + ExtractedCooTensors(int num_sc_per_device, int batch_size_for_device) + : batch_size_for_device(batch_size_for_device), + coo_tensors_per_sc(num_sc_per_device, 0) {} + + // Test only constructor. + ExtractedCooTensors(int num_sc_per_device, int batch_size_for_device, + absl::Span coos) + : coo_tensors(coos.begin(), coos.end()), + batch_size_for_device(batch_size_for_device), + coo_tensors_per_sc(num_sc_per_device, 0) { + DCHECK_GT(num_sc_per_device, 0); + DCHECK_EQ(batch_size_for_device % num_sc_per_device, 0); + const int batch_size_per_sc = batch_size_for_device / num_sc_per_device; + for (const auto& coo : coo_tensors) { + coo_tensors_per_sc[coo.row_id / batch_size_per_sc]++; + } + } + + // Appends other to this, leaving other in an unspecified state. + void Append(ExtractedCooTensors&& other) { + coo_tensors.insert(coo_tensors.end(), + std::make_move_iterator(other.coo_tensors.begin()), + std::make_move_iterator(other.coo_tensors.end())); + for (int i = 0; i < coo_tensors_per_sc.size(); ++i) { + coo_tensors_per_sc[i] += other.coo_tensors_per_sc[i]; + } + } +}; + namespace internal { struct CsrArraysPerDevice { @@ -238,20 +279,6 @@ enum class RowCombiner { RowCombiner GetRowCombiner(absl::string_view combiner); -struct ExtractedCooTensors { - std::vector coo_tensors; - // Number of samples these coo_tensors are extracted from. - int batch_size_for_device; - // Count coo tensors per SC for efficient allocation of vector for sorting and - // grouping them. Might be lower after deduplication. 
- std::vector coo_tensors_per_sc; - - ExtractedCooTensors() : ExtractedCooTensors(0, 0) {} - ExtractedCooTensors(int num_sc_per_device, int batch_size_for_device) - : batch_size_for_device(batch_size_for_device), - coo_tensors_per_sc(num_sc_per_device, 0) {} -}; - struct StackedTableMetadata { StackedTableMetadata() = delete; StackedTableMetadata( diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc index 2bb1b776..f598aef4 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc @@ -146,8 +146,7 @@ TEST(SortAndGroupTest, Base) { coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 8, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, @@ -157,6 +156,7 @@ TEST(SortAndGroupTest, Base) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -234,8 +234,7 @@ TEST(SortAndGroupTest, TwoScs) { coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(2, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(2, 8, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, @@ -245,6 +244,7 @@ TEST(SortAndGroupTest, TwoScs) { .global_device_count = 1, .num_sc_per_device = 2, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/2, @@ -303,8 +303,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations1) { coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 8, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/2, /*max_unique_ids_per_partition=*/1, /*row_offset=*/0, /*col_offset=*/0, @@ -314,6 +313,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations1) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -354,8 +354,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations2) { coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 16); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 16, coo_formats); StackedTableMetadata 
stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/4, /*max_unique_ids_per_partition=*/1, /*row_offset=*/0, /*col_offset=*/0, @@ -365,6 +364,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations2) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -410,8 +410,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations3) { coo_formats.push_back(CooFormat(row, 6, 1.0)); coo_formats.push_back(CooFormat(row, 7, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 16); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 16, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/8, /*max_unique_ids_per_partition=*/2, /*row_offset=*/0, /*col_offset=*/0, @@ -421,6 +420,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations3) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -467,8 +467,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations4) { coo_formats.push_back(CooFormat(row, 6, 1.0)); coo_formats.push_back(CooFormat(row, 7, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 128); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 128, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/64, /*max_unique_ids_per_partition=*/2, /*row_offset=*/0, /*col_offset=*/0, @@ -478,6 +477,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations4) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -519,8 +519,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations5) { coo_formats.push_back(CooFormat(row, 8, 1.0)); coo_formats.push_back(CooFormat(row, 16, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 128); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 128, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/128, /*max_unique_ids_per_partition=*/4, /*row_offset=*/0, /*col_offset=*/0, @@ -530,6 +529,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations5) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -571,8 +571,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations6) { for (int row = 0; row < 128; ++row) { coo_formats.push_back(CooFormat(row, row * 4, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 128); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 128, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, 
/*row_offset=*/0, /*col_offset=*/0, @@ -582,6 +581,7 @@ TEST(SortAndGroupTest, VerifyIdLimitations6) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -624,8 +624,7 @@ TEST(SortAndGroupTest, IdDropping) { } // Force dropping of IDs here with max_ids_per_partition == 2 // The later 2 samples for each sparsecore will be dropped. - ExtractedCooTensors extracted_coo_tensors(4, 16); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 16, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/2, /*max_unique_ids_per_partition=*/1, /*row_offset=*/0, /*col_offset=*/0, @@ -635,6 +634,7 @@ TEST(SortAndGroupTest, IdDropping) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = true, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; bool minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -715,8 +715,7 @@ TEST(InputPreprocessingUtilTest, FillBuffer) { coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 8, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, @@ -726,6 +725,7 @@ TEST(InputPreprocessingUtilTest, FillBuffer) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, @@ -846,8 +846,7 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingSingleMinibatch) { coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 8, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, @@ -858,6 +857,7 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingSingleMinibatch) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, .enable_minibatching = true, .minibatching_bucketing_hash_fn = hash_fn}; MinibatchingSplit minibatching_split = 0; @@ -976,8 +976,7 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingFourMinibatches) { coo_formats.push_back(CooFormat(row, col, 1.0)); } } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + ExtractedCooTensors extracted_coo_tensors(4, 8, coo_formats); StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, @@ -988,6 +987,7 @@ TEST(InputPreprocessingUtilTest, 
FillBufferMinibatchingFourMinibatches) { .global_device_count = 1, .num_sc_per_device = 4, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, .enable_minibatching = true, .minibatching_bucketing_hash_fn = hash_fn}; MinibatchingSplit minibatching_split = 0; @@ -1153,8 +1153,7 @@ TEST(InputPreprocessingUtilTest, coo_formats.emplace_back(/*row=*/1, /*col=*/1, /*gain=*/1.0); ExtractedCooTensors extracted(/*num_sc_per_device=*/1, - /*batch_size_for_device=*/4); - extracted.coo_tensors = coo_formats; + /*batch_size_for_device=*/4, coo_formats); StackedTableMetadata meta("stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, @@ -1167,6 +1166,7 @@ TEST(InputPreprocessingUtilTest, .global_device_count = 1, .num_sc_per_device = 1, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, }; bool minibatching_required = false; @@ -1217,8 +1217,7 @@ TEST(InputPreprocessingUtilTest, coo_formats.emplace_back(/*row=*/3, /*col=*/1, /*gain=*/1.0); ExtractedCooTensors extracted(/*num_sc_per_device=*/1, - /*batch_size_for_device=*/4); - extracted.coo_tensors = coo_formats; + /*batch_size_for_device=*/4, coo_formats); StackedTableMetadata meta("stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, @@ -1232,6 +1231,7 @@ TEST(InputPreprocessingUtilTest, .global_device_count = 1, .num_sc_per_device = 1, .allow_id_dropping = false, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit, .enable_minibatching = true, .minibatching_bucketing_hash_fn = hash_fn, }; diff --git a/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h index 9454ea0b..8396c293 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h +++ b/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h @@ -64,6 +64,14 @@ class NumpySparseInputBatch : public AbstractInputBatch { // Returns the total number of embedding IDs across all samples. int64_t id_count() const override { return id_count_; } + std::optional GetIdsCountInSlice(int start_row, + int end_row) const override { + if (feature_.ndim() == 2) { + return (end_row - start_row) * feature_.shape(1); + } + return std::nullopt; + } + bool HasVariableWeights() const override { return weights_.has_value(); } void ExtractCooTensors(const ExtractCooTensorsOptions& options, diff --git a/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h index 27136fac..e8a3990f 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h +++ b/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h @@ -15,6 +15,7 @@ #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_RAGGED_TENSOR_INPUT_BATCH_H_ #include #include +#include #include #include @@ -94,6 +95,11 @@ class ABSL_ATTRIBUTE_VIEW RaggedTensorInputBatch : public AbstractInputBatch { // Returns the total number of embedding IDs across all samples. int64_t id_count() const override { return row_offsets_[size()]; } + std::optional GetIdsCountInSlice( + int start_row, int end_row) const override { + return row_offsets_[end_row] - row_offsets_[start_row]; + } + bool HasVariableWeights() const override { return false; } void ExtractCooTensors(const ExtractCooTensorsOptions& options, @@ -139,6 +145,11 @@ class RaggedTensorInputBatchWithOwnedData : public AbstractInputBatch { // Returns the total number of embedding IDs across all samples. 
int64_t id_count() const override { return view_.id_count(); } + std::optional GetIdsCountInSlice( + int start_row, int end_row) const override { + return view_.GetIdsCountInSlice(start_row, end_row); + } + bool HasVariableWeights() const override { return view_.HasVariableWeights(); } diff --git a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h index 322d8a67..86becbf0 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h +++ b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h @@ -329,8 +329,9 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl( absl::Span coo_tensors = extracted_coo_tensors.coo_tensors; const int num_sc_per_device = options.num_sc_per_device; bool allow_id_dropping = options.allow_id_dropping; - const int batch_size_per_sc = xla::CeilOfRatio( - extracted_coo_tensors.batch_size_for_device, options.num_sc_per_device); + const int batch_size_for_device = extracted_coo_tensors.batch_size_for_device; + const int batch_size_per_sc = + xla::CeilOfRatio(batch_size_for_device, options.num_sc_per_device); const uint32_t global_sc_count = options.GetNumScs(); const int num_sc_bits = absl::bit_width(global_sc_count - 1); const int max_ids_per_partition = @@ -387,8 +388,10 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl( // The key here is [bucket_id(6 bits), global_sc_id(num_scs bits), // local_embedding_id(32-num_scs bits), index(26 bits)]. // Note that this assumes `num_scs` is a power of 2. + uint32_t data = + kHasVariableWeights ? coo_tensor_index : coo_tensor.row_id; keys.push_back(coo_tensor.GetGroupingKey( - num_sc_bits, coo_tensor_index, kCreateBuckets, + num_sc_bits, data, kCreateBuckets, options.minibatching_bucketing_hash_fn, kHasVariableWeights)); DCHECK(kHasVariableWeights || coo_tensors[coo_tensor_index].gain == 1.0f) << "kHasVariableWeights: " << kHasVariableWeights diff --git a/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.cc b/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.cc index 16a952df..f5f10a57 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.cc @@ -31,11 +31,12 @@ namespace jax_sc_embedding { -void PySparseCooInputBatch::ConstructRowPointers() { +void PySparseCooInputBatch::ConstructRowPointers() const { if (!row_pointers_.empty()) { return; } auto indices_array = indices_.unchecked<2>(); + auto values_array = values_.unchecked<1>(); // Precompute indexes for row starts. Add a sentinel node for last row. row_pointers_.reserve(batch_size_ + 1); int row_pointers_index = 0; @@ -44,7 +45,7 @@ void PySparseCooInputBatch::ConstructRowPointers() { int last_val = -1; // Only for DCHECK. for (int i = 0; i < indices_array.shape(0); ++i) { const int row_id = indices_array(i, 0), col_id = indices_array(i, 1), - val = values_.at(i); + val = values_array(i); DCHECK_GE(row_id, last_row_id) << "Decreasing row id values for row-major."; while (row_pointers_index <= row_id) { // Increment index until we reach the current row. 
Keep storing the row @@ -73,7 +74,7 @@ void PySparseCooInputBatch::ConstructRowPointers() { DCHECK_EQ(row_pointers_.size(), batch_size_ + 1); } -void PySparseCooInputBatch::ConstructRowPointersIfRequired() { +void PySparseCooInputBatch::ConstructRowPointersIfRequired() const { absl::call_once(row_pointer_construction_flag_, &PySparseCooInputBatch::ConstructRowPointers, this); } diff --git a/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h index d5f5b6af..d3ef15bf 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h +++ b/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h @@ -17,6 +17,7 @@ #include #include +#include #include #include #include @@ -62,6 +63,12 @@ class PySparseCooInputBatch : public AbstractInputBatch { // Returns the total number of embedding IDs across all samples. int64_t id_count() const override { return values_.size(); } + std::optional GetIdsCountInSlice(int start_row, + int end_row) const override { + ConstructRowPointersIfRequired(); + return row_pointers_[end_row] - row_pointers_[start_row]; + } + bool HasVariableWeights() const override { return false; } // Extracts COO tensors for each SparseCore. @@ -76,16 +83,16 @@ class PySparseCooInputBatch : public AbstractInputBatch { const int64_t batch_size_; const std::string table_name_; - std::vector row_pointers_; - absl::once_flag row_pointer_construction_flag_; + mutable std::vector row_pointers_; + mutable absl::once_flag row_pointer_construction_flag_; // Converts this to a CSR format. A refactor could return an object of type // SparseCsrInputBatch after Slicing, and ExtractCooTensors can call // the same function on a temporary object of SparseCsrInputBatch type. - void ConstructRowPointersIfRequired(); + void ConstructRowPointersIfRequired() const; // Internal function called by `ConstructRowPointersIfRequired`. - void ConstructRowPointers(); + void ConstructRowPointers() const; }; } // namespace jax_sc_embedding
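
Usage note (not part of the patch): the change replaces the blocking per-device extraction with `ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync`, which returns one `tsl::AsyncValueRef` of `ExtractedCooTensors` per (feature, slice) and relies on the caller to wait and merge, as the new synchronous wrapper in input_preprocessing.cc does. Below is a minimal consumption sketch that mirrors that wrapper. The helper name `MergeExtractedCooTensors` is illustrative only, the include paths follow this patch, and the template argument of the async results is inferred from the `AsyncValueRef` / `MakeUnconstructedAsyncValueRef` usage in this change.

// Minimal sketch, assuming the headers and types introduced/used in this patch.
#include <cstdint>
#include <utility>
#include <vector>

#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "xla/tsl/concurrency/async_value.h"      // from @xla
#include "xla/tsl/concurrency/async_value_ref.h"  // from @xla

namespace jax_sc_embedding {

// Blocks on every per-(feature, slice) async result and folds them into a
// single ExtractedCooTensors via Append(), the same way the synchronous
// wrapper in input_preprocessing.cc does. Assumes at least one result, as the
// patch itself does.
ExtractedCooTensors MergeExtractedCooTensors(
    std::vector<tsl::AsyncValueRef<ExtractedCooTensors>> results_av,
    const PreprocessSparseDenseMatmulInputOptions& options) {
  int64_t total_ids = 0;
  for (auto& av : results_av) {
    tsl::BlockUntilReady(av.GetAsyncValue());
    total_ids += av.get().coo_tensors.size();
  }
  // Reserve once using the summed ID count, then move-append each slice.
  ExtractedCooTensors merged(options.num_sc_per_device,
                             results_av[0].get().batch_size_for_device);
  merged.coo_tensors.reserve(total_ids);
  for (auto& av : results_av) {
    merged.Append(std::move(av.get()));  // Leaves each source unspecified.
  }
  return merged;
}

}  // namespace jax_sc_embedding

Relatedly, input batches that can compute per-slice ID counts cheaply (ragged row offsets, dense 2-D NumPy arrays, lazily built CSR row pointers) override the new `GetIdsCountInSlice(start_row, end_row)` hook so `ExtractCooTensorsForSingleFeatureSlice` can reserve exact capacity for the slice; other batches keep the `std::nullopt` default and fall back to the CeilOfRatio estimate over the total `id_count()`.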