Skip to content

Commit 08d02f8

Browse files
[JAX SC] Modify OutputCsrArrays and CsrArraysPerHost to use Eigen::Map.
PiperOrigin-RevId: 840489755
1 parent 53c4d81 commit 08d02f8

File tree

4 files changed

+61
-38
lines changed

4 files changed

+61
-38
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ cc_library(
148148
"@com_google_absl//absl/strings:string_view",
149149
"@com_google_absl//absl/synchronization",
150150
"@com_google_absl//absl/types:span",
151+
"@eigen_archive//:eigen3",
151152
"@tsl//tsl/platform:statusor",
152153
"@tsl//tsl/profiler/lib:traceme",
153154
"@xla//xla:util",

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "absl/synchronization/blocking_counter.h" // from @com_google_absl
3636
#include "absl/synchronization/mutex.h" // from @com_google_absl
3737
#include "absl/types/span.h" // from @com_google_absl
38+
#include "Eigen/Core" // from @eigen_archive
3839
#include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"
3940
#include "jax_tpu_embedding/sparsecore/lib/core/coo_format.h"
4041
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
@@ -167,8 +168,9 @@ struct TableState {
167168
bool has_variable_weights,
168169
const PreprocessSparseDenseMatmulInputOptions& options,
169170
int num_scs, int coo_buffer_size_per_device,
170-
MatrixXi& row_pointers, MatrixXi& embedding_ids,
171-
MatrixXi& sample_ids, MatrixXf& gains)
171+
Eigen::Ref<MatrixXi> row_pointers,
172+
Eigen::Ref<MatrixXi> embedding_ids,
173+
Eigen::Ref<MatrixXi> sample_ids, Eigen::Ref<MatrixXf> gains)
172174
: stacked_table_name(name),
173175
stacked_table_metadata(metadata),
174176
has_variable_weights(has_variable_weights),
@@ -537,12 +539,14 @@ void FillDeviceBuffersForTable(
537539
}
538540
}
539541

540-
std::tuple<MatrixXi&, MatrixXi&, MatrixXi&, MatrixXf&> GetOutputCsrBuffers(
541-
const std::string& stacked_table_name,
542-
const PreprocessSparseDenseMatmulInputOptions& options,
543-
int row_pointers_size_per_device, int coo_buffer_size_per_device,
544-
OutputCsrArrays* output_csr_arrays,
545-
PreprocessSparseDenseMatmulOutput& out) {
542+
std::tuple<Eigen::Ref<MatrixXi>, Eigen::Ref<MatrixXi>, Eigen::Ref<MatrixXi>,
543+
Eigen::Ref<MatrixXf>>
544+
GetOutputCsrBuffers(const std::string& stacked_table_name,
545+
const PreprocessSparseDenseMatmulInputOptions& options,
546+
int row_pointers_size_per_device,
547+
int coo_buffer_size_per_device,
548+
OutputCsrArrays* output_csr_arrays,
549+
PreprocessSparseDenseMatmulOutput& out) {
546550
if (output_csr_arrays != nullptr) {
547551
DCHECK(output_csr_arrays->lhs_row_pointers.contains(stacked_table_name))
548552
<< "Missing lhs_row_pointers for table: " << stacked_table_name;
@@ -552,22 +556,23 @@ std::tuple<MatrixXi&, MatrixXi&, MatrixXi&, MatrixXf&> GetOutputCsrBuffers(
552556
<< "Missing lhs_sample_ids for table: " << stacked_table_name;
553557
DCHECK(output_csr_arrays->lhs_gains.contains(stacked_table_name))
554558
<< "Missing lhs_gains for table: " << stacked_table_name;
555-
MatrixXi& row_pointers =
556-
output_csr_arrays->lhs_row_pointers[stacked_table_name];
559+
Eigen::Map<MatrixXi>& row_pointers =
560+
output_csr_arrays->lhs_row_pointers.find(stacked_table_name)->second;
557561
DCHECK_EQ(row_pointers.rows(), options.local_device_count);
558562
DCHECK_EQ(row_pointers.cols(), row_pointers_size_per_device);
559563

560-
MatrixXi& embedding_ids =
561-
output_csr_arrays->lhs_embedding_ids[stacked_table_name];
564+
Eigen::Map<MatrixXi>& embedding_ids =
565+
output_csr_arrays->lhs_embedding_ids.find(stacked_table_name)->second;
562566
DCHECK_EQ(embedding_ids.rows(), options.local_device_count);
563567
DCHECK_EQ(embedding_ids.cols(), coo_buffer_size_per_device);
564568

565-
MatrixXi& sample_ids =
566-
output_csr_arrays->lhs_sample_ids[stacked_table_name];
569+
Eigen::Map<MatrixXi>& sample_ids =
570+
output_csr_arrays->lhs_sample_ids.find(stacked_table_name)->second;
567571
DCHECK_EQ(sample_ids.rows(), options.local_device_count);
568572
DCHECK_EQ(sample_ids.cols(), coo_buffer_size_per_device);
569573

570-
MatrixXf& gains = output_csr_arrays->lhs_gains[stacked_table_name];
574+
Eigen::Map<MatrixXf>& gains =
575+
output_csr_arrays->lhs_gains.find(stacked_table_name)->second;
571576
DCHECK_EQ(gains.rows(), options.local_device_count);
572577
DCHECK_EQ(gains.cols(), coo_buffer_size_per_device);
573578

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc

Lines changed: 33 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -485,24 +485,44 @@ TEST_F(TableStackingTest, PreprocessInputWritesToProvidedOutputBuffers) {
485485
stacked_tables[stacked_table_metadata_multi_[0].name].push_back(
486486
stacked_table_metadata_multi_[0]);
487487

488+
StackedTableMap<MatrixXi> row_pointers_store;
489+
StackedTableMap<MatrixXi> embedding_ids_store;
490+
StackedTableMap<MatrixXi> sample_ids_store;
491+
StackedTableMap<MatrixXf> gains_store;
492+
488493
for (const auto& [table_name, metadata_list] : stacked_tables) {
489494
int coo_buffer_size = ComputeCooBufferSizePerDevice(
490495
num_scs, num_sc_per_device, metadata_list, options.batch_number,
491496
options.enable_minibatching);
492497

493-
output_csr_arrays.lhs_row_pointers[table_name].resize(
494-
local_device_count, row_pointers_size_per_device);
495-
output_csr_arrays.lhs_embedding_ids[table_name].resize(local_device_count,
496-
coo_buffer_size);
497-
output_csr_arrays.lhs_sample_ids[table_name].resize(local_device_count,
498-
coo_buffer_size);
499-
output_csr_arrays.lhs_gains[table_name].resize(local_device_count,
500-
coo_buffer_size);
501-
502-
output_csr_arrays.lhs_row_pointers[table_name].setConstant(-1);
503-
output_csr_arrays.lhs_embedding_ids[table_name].setConstant(-1);
504-
output_csr_arrays.lhs_sample_ids[table_name].setConstant(-1);
505-
output_csr_arrays.lhs_gains[table_name].setConstant(-1.0f);
498+
row_pointers_store[table_name] =
499+
MatrixXi(local_device_count, row_pointers_size_per_device);
500+
MatrixXi& row_pointers = row_pointers_store[table_name];
501+
output_csr_arrays.lhs_row_pointers.insert(
502+
{table_name, Eigen::Map<MatrixXi>(row_pointers.data(),
503+
row_pointers.rows(),
504+
row_pointers.cols())});
505+
506+
embedding_ids_store[table_name] =
507+
MatrixXi(local_device_count, coo_buffer_size);
508+
MatrixXi& embedding_ids = embedding_ids_store[table_name];
509+
output_csr_arrays.lhs_embedding_ids.insert(
510+
{table_name, Eigen::Map<MatrixXi>(embedding_ids.data(),
511+
embedding_ids.rows(),
512+
embedding_ids.cols())});
513+
514+
sample_ids_store[table_name] =
515+
MatrixXi(local_device_count, coo_buffer_size);
516+
MatrixXi& sample_ids = sample_ids_store[table_name];
517+
output_csr_arrays.lhs_sample_ids.insert(
518+
{table_name, Eigen::Map<MatrixXi>(sample_ids.data(), sample_ids.rows(),
519+
sample_ids.cols())});
520+
521+
gains_store[table_name] = MatrixXf(local_device_count, coo_buffer_size);
522+
MatrixXf& gains = gains_store[table_name];
523+
output_csr_arrays.lhs_gains.insert(
524+
{table_name,
525+
Eigen::Map<MatrixXf>(gains.data(), gains.rows(), gains.cols())});
506526
}
507527

508528
TF_ASSERT_OK_AND_ASSIGN(
@@ -512,10 +532,6 @@ TEST_F(TableStackingTest, PreprocessInputWritesToProvidedOutputBuffers) {
512532
&output_csr_arrays));
513533

514534
for (const auto& [table_name, _] : stacked_tables) {
515-
const MatrixXi& row_ptrs = output_csr_arrays.lhs_row_pointers[table_name];
516-
// Check that data was written (first element shouldn't be -1).
517-
EXPECT_NE(row_ptrs(0, 0), -1);
518-
519535
// Verify that the returned output structure has empty matrices for this
520536
// table because we provided the buffers.
521537
EXPECT_EQ(output.lhs_row_pointers[table_name].size(), 0);

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,10 @@ using StackedTableMap = absl::flat_hash_map<std::string, T>;
6565
// Container for output CSR arrays for multiple stacked tables.
6666
// Allows pre-allocated buffers to be passed in, avoiding data copies.
6767
struct OutputCsrArrays {
68-
StackedTableMap<MatrixXi> lhs_row_pointers;
69-
StackedTableMap<MatrixXi> lhs_embedding_ids;
70-
StackedTableMap<MatrixXi> lhs_sample_ids;
71-
StackedTableMap<MatrixXf> lhs_gains;
68+
StackedTableMap<Eigen::Map<MatrixXi>> lhs_row_pointers;
69+
StackedTableMap<Eigen::Map<MatrixXi>> lhs_embedding_ids;
70+
StackedTableMap<Eigen::Map<MatrixXi>> lhs_sample_ids;
71+
StackedTableMap<Eigen::Map<MatrixXf>> lhs_gains;
7272
};
7373

7474
namespace internal {
@@ -95,8 +95,9 @@ struct CsrArraysPerHost {
9595
Eigen::Map<MatrixXi> sample_ids;
9696
Eigen::Map<MatrixXf> gains;
9797

98-
CsrArraysPerHost(MatrixXi& row_pointers, MatrixXi& embedding_ids,
99-
MatrixXi& sample_ids, MatrixXf& gains)
98+
CsrArraysPerHost(Eigen::Ref<MatrixXi> row_pointers,
99+
Eigen::Ref<MatrixXi> embedding_ids,
100+
Eigen::Ref<MatrixXi> sample_ids, Eigen::Ref<MatrixXf> gains)
100101
: row_pointers(row_pointers.data(), row_pointers.rows(),
101102
row_pointers.cols()),
102103
embedding_ids(embedding_ids.data(), embedding_ids.rows(),

0 commit comments

Comments
 (0)