1414#include < cmath>
1515#include < cstdint>
1616#include < cstdio>
17+ #include < functional>
1718#include < limits>
1819#include < memory>
1920#include < optional>
3536#include " jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
3637#include " jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h"
3738#include " jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h"
39+ #include " tsl/concurrency/async_value.h" // from @tsl
40+ #include " tsl/concurrency/async_value_ref.h" // from @tsl
3841
3942namespace jax_sc_embedding {
4043
@@ -168,6 +171,8 @@ void BM_ExtractCooTensors(benchmark::State& state) {
168171 std::vector<StackedTableMetadata> stacked_table_metadata;
169172 stacked_table_metadata.reserve (num_features);
170173 for (int i = 0 ; i < num_features; ++i) {
174+ // Set to INT_MAX to avoid ID dropping and observe the actual statistics of
175+ // the generated data. This doesn't affect performance of grouping itself.
171176 stacked_table_metadata.push_back (StackedTableMetadata (
172177 absl::StrCat (" table_" , i), /* feature_index=*/ i,
173178 /* max_ids_per_partition=*/ std::numeric_limits<int >::max (),
@@ -188,9 +193,13 @@ void BM_ExtractCooTensors(benchmark::State& state) {
188193 };
189194
190195 for (auto s : state) {
191- internal::ExtractCooTensorsForAllFeaturesPerLocalDevice (
192- stacked_table_metadata, absl::MakeSpan (input_batches),
193- /* local_device_id=*/ 0 , options);
196+ std::vector<tsl::AsyncValueRef<ExtractedCooTensors>> results_av =
197+ internal::ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync (
198+ stacked_table_metadata, absl::MakeSpan (input_batches),
199+ /* local_device_id=*/ 0 , options);
200+ for (auto & av : results_av) {
201+ tsl::BlockUntilReady (av.GetAsyncValue ());
202+ }
194203 }
195204}
196205BENCHMARK (BM_ExtractCooTensors)
@@ -233,10 +242,20 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
233242 .enable_minibatching = true ,
234243 };
235244
236- ExtractedCooTensors extracted_coo_tensors =
237- internal::ExtractCooTensorsForAllFeaturesPerLocalDevice (
245+ // Extract COO tensors for all features on a single local device.
246+ std::vector<tsl::AsyncValueRef<ExtractedCooTensors>> feature_results_av =
247+ internal::ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync (
238248 stacked_table_metadata_list, absl::MakeSpan (input_batches),
239249 /* local_device_id=*/ 0 , options);
250+ for (auto & av : feature_results_av) {
251+ tsl::BlockUntilReady (av.GetAsyncValue ());
252+ }
253+ std::vector<ExtractedCooTensors> feature_results;
254+ std::vector<std::reference_wrapper<const ExtractedCooTensors>>
255+ extracted_coo_tensors_list;
256+ internal::GetExtractedCooTensorsFromAsyncValues (
257+ absl::MakeSpan (feature_results_av), feature_results,
258+ extracted_coo_tensors_list);
240259
241260 bool minibatching_required = false ;
242261 StatsPerHost stats_per_host (
@@ -248,8 +267,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
248267
249268 if (state.thread_index () == 0 ) {
250269 SortAndGroupCooTensorsPerLocalDevice</* kHasVariableWeights=*/ false >(
251- extracted_coo_tensors, stacked_table_metadata_list[0 ], options,
252- stats_per_device, minibatching_required);
270+ absl::MakeSpan (extracted_coo_tensors_list),
271+ stacked_table_metadata_list[0 ], options, stats_per_device,
272+ minibatching_required);
253273 LogStats (stats_per_device.max_ids_per_partition ,
254274 " Max ids per partition across all global SCs" );
255275 LogStats (stats_per_device.max_unique_ids_per_partition ,
@@ -258,8 +278,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
258278
259279 for (auto s : state) {
260280 SortAndGroupCooTensorsPerLocalDevice</* kHasVariableWeights=*/ false >(
261- extracted_coo_tensors, stacked_table_metadata_list[0 ], options,
262- stats_per_device, minibatching_required);
281+ absl::MakeSpan (extracted_coo_tensors_list),
282+ stacked_table_metadata_list[0 ], options, stats_per_device,
283+ minibatching_required);
263284 }
264285}
265286BENCHMARK (BM_SortAndGroup_Phase1)
0 commit comments