Skip to content

Commit 3f23b28

Browse files
[JAX SC] Extract features slice per SC in parallel.
PiperOrigin-RevId: 839887356
1 parent 595e68f commit 3f23b28

13 files changed

+530
-147
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ cc_library(
6565
"@com_google_absl//absl/strings:string_view",
6666
"@com_google_absl//absl/types:span",
6767
"@eigen_archive//:eigen3",
68+
"@tsl//tsl/concurrency:async_value",
6869
"@tsl//tsl/profiler/lib:traceme",
6970
"@xla//xla:util",
7071
],
@@ -153,6 +154,7 @@ cc_library(
153154
"@com_google_absl//absl/synchronization",
154155
"@com_google_absl//absl/types:span",
155156
"@eigen_archive//:eigen3",
157+
"@tsl//tsl/concurrency:async_value",
156158
"@tsl//tsl/platform:statusor",
157159
"@tsl//tsl/profiler/lib:traceme",
158160
"@xla//xla:util",
@@ -434,6 +436,7 @@ cc_test(
434436
"@com_google_absl//absl/types:span",
435437
"@com_google_benchmark//:benchmark_main",
436438
"@eigen_archive//:eigen3",
439+
"@tsl//tsl/concurrency:async_value",
437440
],
438441
)
439442

jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <sys/types.h>
1717

1818
#include <cstdint>
19+
#include <optional>
1920

2021
#include "absl/base/attributes.h" // from @com_google_absl
2122
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
@@ -52,6 +53,13 @@ class AbstractInputBatch {
5253
// Returns the total number of embedding IDs across all samples.
5354
virtual int64_t id_count() const = 0;
5455

56+
// Returns the number of IDs in rows [start_row, end_row).
57+
// If not implemented by a subclass, returns std::nullopt.
58+
virtual std::optional<int64_t> GetIdsCountInSlice(int start_row,
59+
int end_row) const {
60+
return std::nullopt;
61+
}
62+
5563
// Returns true if the input batch has variable weights.
5664
virtual bool HasVariableWeights() const { return true; }
5765

jax_tpu_embedding/sparsecore/lib/core/coo_format.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ struct CooFormat {
6464
static constexpr uint32_t kDataMask = (1 << kDataBits) - 1;
6565
// Bit offset for rotated_col_id in grouping key.
6666
static constexpr uint32_t kRotatedColIdOffset = kDataBits;
67+
68+
// For hierarchical indexing within data bits when using multiple feature
69+
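// slices: the high kSliceIndexBits of the data bits hold the slice index and
// the remaining low kItemIndexBits hold the item index within that slice.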
70+
static constexpr uint32_t kSliceIndexBits = 6;
71+
static constexpr uint32_t kItemIndexBits = kDataBits - kSliceIndexBits;
72+
static constexpr uint32_t kItemIndexMask = (1UL << kItemIndexBits) - 1;
73+
6774
// Bit offset for bucket_id in grouping key.
6875
static constexpr uint32_t kBucketIdOffset = kRotatedColIdOffset + 32;
6976

@@ -171,6 +178,21 @@ struct CooFormat {
171178
static uint32_t GetBucketIdFromKey(uint64_t key) {
172179
return key >> kBucketIdOffset;
173180
}
181+
182+
static uint32_t EncodeHierarchicalIndex(uint32_t slice_idx,
183+
uint32_t item_idx) {
184+
DCHECK_LT(slice_idx, 1 << kSliceIndexBits);
185+
DCHECK_LT(item_idx, 1 << kItemIndexBits);
186+
return (slice_idx << kItemIndexBits) | item_idx;
187+
}
188+
189+
static uint32_t GetSliceIndexFromData(uint32_t data) {
190+
return data >> kItemIndexBits;
191+
}
192+
193+
static uint32_t GetItemIndexFromData(uint32_t data) {
194+
return data & kItemIndexMask;
195+
}
174196
};
175197

176198
} // namespace jax_sc_embedding

jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <cmath>
1515
#include <cstdint>
1616
#include <cstdio>
17+
#include <functional>
1718
#include <limits>
1819
#include <memory>
1920
#include <optional>
@@ -35,6 +36,8 @@
3536
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
3637
#include "jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h"
3738
#include "jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h"
39+
#include "tsl/concurrency/async_value.h" // from @tsl
40+
#include "tsl/concurrency/async_value_ref.h" // from @tsl
3841

3942
namespace jax_sc_embedding {
4043

@@ -168,6 +171,8 @@ void BM_ExtractCooTensors(benchmark::State& state) {
168171
std::vector<StackedTableMetadata> stacked_table_metadata;
169172
stacked_table_metadata.reserve(num_features);
170173
for (int i = 0; i < num_features; ++i) {
174+
// Set the per-partition ID limits to INT_MAX to avoid ID dropping and observe the actual statistics of
175+
// the generated data. This doesn't affect performance of grouping itself.
171176
stacked_table_metadata.push_back(StackedTableMetadata(
172177
absl::StrCat("table_", i), /*feature_index=*/i,
173178
/*max_ids_per_partition=*/std::numeric_limits<int>::max(),
@@ -188,9 +193,13 @@ void BM_ExtractCooTensors(benchmark::State& state) {
188193
};
189194

190195
for (auto s : state) {
191-
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
192-
stacked_table_metadata, absl::MakeSpan(input_batches),
193-
/*local_device_id=*/0, options);
196+
std::vector<tsl::AsyncValueRef<ExtractedCooTensors>> results_av =
197+
internal::ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync(
198+
stacked_table_metadata, absl::MakeSpan(input_batches),
199+
/*local_device_id=*/0, options);
200+
for (auto& av : results_av) {
201+
tsl::BlockUntilReady(av.GetAsyncValue());
202+
}
194203
}
195204
}
196205
BENCHMARK(BM_ExtractCooTensors)
@@ -233,10 +242,20 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
233242
.enable_minibatching = true,
234243
};
235244

236-
ExtractedCooTensors extracted_coo_tensors =
237-
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
245+
// Extract COO tensors for all features on a single local device.
246+
std::vector<tsl::AsyncValueRef<ExtractedCooTensors>> feature_results_av =
247+
internal::ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync(
238248
stacked_table_metadata_list, absl::MakeSpan(input_batches),
239249
/*local_device_id=*/0, options);
250+
for (auto& av : feature_results_av) {
251+
tsl::BlockUntilReady(av.GetAsyncValue());
252+
}
253+
std::vector<ExtractedCooTensors> feature_results;
254+
std::vector<std::reference_wrapper<const ExtractedCooTensors>>
255+
extracted_coo_tensors_list;
256+
internal::GetExtractedCooTensorsFromAsyncValues(
257+
absl::MakeSpan(feature_results_av), feature_results,
258+
extracted_coo_tensors_list);
240259

241260
bool minibatching_required = false;
242261
StatsPerHost stats_per_host(
@@ -248,8 +267,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
248267

249268
if (state.thread_index() == 0) {
250269
SortAndGroupCooTensorsPerLocalDevice</*kHasVariableWeights=*/false>(
251-
extracted_coo_tensors, stacked_table_metadata_list[0], options,
252-
stats_per_device, minibatching_required);
270+
absl::MakeSpan(extracted_coo_tensors_list),
271+
stacked_table_metadata_list[0], options, stats_per_device,
272+
minibatching_required);
253273
LogStats(stats_per_device.max_ids_per_partition,
254274
"Max ids per partition across all global SCs");
255275
LogStats(stats_per_device.max_unique_ids_per_partition,
@@ -258,8 +278,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
258278

259279
for (auto s : state) {
260280
SortAndGroupCooTensorsPerLocalDevice</*kHasVariableWeights=*/false>(
261-
extracted_coo_tensors, stacked_table_metadata_list[0], options,
262-
stats_per_device, minibatching_required);
281+
absl::MakeSpan(extracted_coo_tensors_list),
282+
stacked_table_metadata_list[0], options, stats_per_device,
283+
minibatching_required);
263284
}
264285
}
265286
BENCHMARK(BM_SortAndGroup_Phase1)

0 commit comments

Comments
 (0)