Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ cc_library(
"@eigen_archive//:eigen3",
"@tsl//tsl/profiler/lib:traceme",
"@xla//xla:util",
"@xla//xla/tsl/concurrency:async_value",
],
)

Expand Down Expand Up @@ -433,6 +434,7 @@ cc_test(
"@com_google_absl//absl/types:span",
"@com_google_benchmark//:benchmark_main",
"@eigen_archive//:eigen3",
"@xla//xla/tsl/concurrency:async_value",
],
)

Expand Down
8 changes: 8 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <sys/types.h>

#include <cstdint>
#include <optional>

#include "absl/base/attributes.h" // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
Expand Down Expand Up @@ -52,6 +53,13 @@ class AbstractInputBatch {
// Returns the total number of embedding IDs across all samples.
virtual int64_t id_count() const = 0;

// Returns the number of IDs contained in the half-open row range
// [start_row, end_row).
//
// This default implementation ignores both arguments and always returns
// std::nullopt. Because the method is virtual, a subclass that is able to
// compute the count can override it; callers therefore need to handle the
// std::nullopt ("not implemented") result.
virtual std::optional<int64_t> GetIdsCountInSlice(int start_row,
int end_row) const {
return std::nullopt;
}

// Returns true if the input batch has variable weights.
// The default implementation returns true unconditionally; the method is
// virtual, so a subclass can override it to report a different answer.
virtual bool HasVariableWeights() const { return true; }

Expand Down
1 change: 1 addition & 0 deletions jax_tpu_embedding/sparsecore/lib/core/coo_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ struct CooFormat {
static constexpr uint32_t kDataMask = (1 << kDataBits) - 1;
// Bit offset for rotated_col_id in grouping key.
static constexpr uint32_t kRotatedColIdOffset = kDataBits;

// Bit offset for bucket_id in grouping key.
static constexpr uint32_t kBucketIdOffset = kRotatedColIdOffset + 32;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h"
#include "jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h"
#include "xla/tsl/concurrency/async_value.h" // from @xla
#include "xla/tsl/concurrency/async_value_ref.h" // from @xla

namespace jax_sc_embedding {

Expand Down Expand Up @@ -168,6 +170,8 @@ void BM_ExtractCooTensors(benchmark::State& state) {
std::vector<StackedTableMetadata> stacked_table_metadata;
stacked_table_metadata.reserve(num_features);
for (int i = 0; i < num_features; ++i) {
// Set to INT_MAX to avoid ID dropping and observe the actual statistics of
// the generated data. This doesn't affect performance of grouping itself.
stacked_table_metadata.push_back(StackedTableMetadata(
absl::StrCat("table_", i), /*feature_index=*/i,
/*max_ids_per_partition=*/std::numeric_limits<int>::max(),
Expand All @@ -188,9 +192,13 @@ void BM_ExtractCooTensors(benchmark::State& state) {
};

for (auto s : state) {
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
stacked_table_metadata, absl::MakeSpan(input_batches),
/*local_device_id=*/0, options);
std::vector<tsl::AsyncValueRef<ExtractedCooTensors>> results_av =
internal::ExtractCooTensorsForAllFeaturesPerLocalDeviceAsync(
stacked_table_metadata, absl::MakeSpan(input_batches),
/*local_device_id=*/0, options);
for (auto& av : results_av) {
tsl::BlockUntilReady(av.GetAsyncValue());
}
}
}
BENCHMARK(BM_ExtractCooTensors)
Expand Down Expand Up @@ -233,6 +241,7 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
.enable_minibatching = true,
};

// Extract COO tensors for all features on a single local device.
ExtractedCooTensors extracted_coo_tensors =
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
stacked_table_metadata_list, absl::MakeSpan(input_batches),
Expand All @@ -248,8 +257,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {

if (state.thread_index() == 0) {
SortAndGroupCooTensorsPerLocalDevice</*kHasVariableWeights=*/false>(
extracted_coo_tensors, stacked_table_metadata_list[0], options,
stats_per_device, minibatching_required);
extracted_coo_tensors,
stacked_table_metadata_list[0], options, stats_per_device,
minibatching_required);
LogStats(stats_per_device.max_ids_per_partition,
"Max ids per partition across all global SCs");
LogStats(stats_per_device.max_unique_ids_per_partition,
Expand All @@ -258,8 +268,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {

for (auto s : state) {
SortAndGroupCooTensorsPerLocalDevice</*kHasVariableWeights=*/false>(
extracted_coo_tensors, stacked_table_metadata_list[0], options,
stats_per_device, minibatching_required);
extracted_coo_tensors,
stacked_table_metadata_list[0], options, stats_per_device,
minibatching_required);
}
}
BENCHMARK(BM_SortAndGroup_Phase1)
Expand Down
Loading
Loading