Skip to content

Commit 595e68f

Browse files
[JAX SC] Add an option in input pre-processing to pass in a custom scheduler callback.
Useful in cases where the client wishes to have tighter control over scheduling the pre-processing work. PiperOrigin-RevId: 840582756
1 parent 08d02f8 commit 595e68f

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ cc_library(
5454
deps = [
5555
":all_reduce_interface",
5656
":coo_format",
57+
":input_preprocessing_threads",
5758
":partitioned_coo_tensors",
5859
"@com_google_absl//absl/base:core_headers",
5960
"@com_google_absl//absl/base:nullability",
6061
"@com_google_absl//absl/container:flat_hash_map",
62+
"@com_google_absl//absl/functional:any_invocable",
6163
"@com_google_absl//absl/log",
6264
"@com_google_absl//absl/log:check",
6365
"@com_google_absl//absl/strings:string_view",
@@ -140,6 +142,8 @@ cc_library(
140142
"@com_google_absl//absl/algorithm:container",
141143
"@com_google_absl//absl/base:core_headers",
142144
"@com_google_absl//absl/container:flat_hash_map",
145+
"@com_google_absl//absl/functional:any_invocable",
146+
"@com_google_absl//absl/functional:bind_front",
143147
"@com_google_absl//absl/log",
144148
"@com_google_absl//absl/log:check",
145149
"@com_google_absl//absl/status:statusor",

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ void ExtractSortAndGroupCooTensorsForTable(
220220
});
221221
for (int local_device = 0; local_device < options.local_device_count;
222222
++local_device) {
223-
PreprocessingThreadPool()->Schedule(
223+
options.async_task_scheduler(
224224
[&, local_device, &state = state, input_batches] {
225225
state.extracted_coo_tensors_per_device[local_device] =
226226
internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
@@ -267,7 +267,7 @@ void CreateMinibatchingBucketsForTable(
267267
state.stats_per_host.dropped_id_count = 0;
268268
for (int local_device = 0; local_device < options.local_device_count;
269269
++local_device) {
270-
PreprocessingThreadPool()->Schedule([&, local_device, &state = state] {
270+
options.async_task_scheduler([&, local_device, &state = state] {
271271
// Note: We create a dummy stats object here because we don't want to
272272
// overwrite the stats from the first pass, which are authoritative.
273273
// The only stat we care about from this second pass is the number of

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,19 @@
2020
#include <limits>
2121
#include <optional>
2222
#include <string>
23+
#include <utility>
2324
#include <vector>
2425

2526
#include "absl/base/attributes.h" // from @com_google_absl
2627
#include "absl/base/nullability.h" // from @com_google_absl
2728
#include "absl/container/flat_hash_map.h" // from @com_google_absl
29+
#include "absl/functional/any_invocable.h" // from @com_google_absl
2830
#include "absl/strings/string_view.h" // from @com_google_absl
2931
#include "absl/types/span.h" // from @com_google_absl
3032
#include "Eigen/Core" // from @eigen_archive
3133
#include "jax_tpu_embedding/sparsecore/lib/core/all_reduce_interface.h"
3234
#include "jax_tpu_embedding/sparsecore/lib/core/coo_format.h"
35+
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
3336
#include "jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h"
3437

3538
namespace jax_sc_embedding {
@@ -215,6 +218,12 @@ struct PreprocessSparseDenseMatmulInputOptions {
215218
// Hash function used for creating minibatching buckets.
216219
CooFormat::HashFn minibatching_bucketing_hash_fn = HighwayHash;
217220

221+
// Callback used to schedule async work; defaults to PreprocessingThreadPool()->Schedule.
222+
absl::AnyInvocable<void(std::function<void()>) const> async_task_scheduler =
223+
[](std::function<void()> callback) {
224+
PreprocessingThreadPool()->Schedule(std::move(callback));
225+
};
226+
218227
// Returns the total number of SparseCores across all devices and hosts.
219228
uint32_t GetNumScs() const { return num_sc_per_device * global_device_count; }
220229
};

0 commit comments

Comments
 (0)