3535#include " absl/synchronization/blocking_counter.h" // from @com_google_absl
3636#include " absl/synchronization/mutex.h" // from @com_google_absl
3737#include " absl/types/span.h" // from @com_google_absl
38+ #include " Eigen/Core" // from @eigen_archive
3839#include " jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"
3940#include " jax_tpu_embedding/sparsecore/lib/core/coo_format.h"
4041#include " jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
@@ -167,8 +168,9 @@ struct TableState {
167168 bool has_variable_weights,
168169 const PreprocessSparseDenseMatmulInputOptions& options,
169170 int num_scs, int coo_buffer_size_per_device,
170- MatrixXi& row_pointers, MatrixXi& embedding_ids,
171- MatrixXi& sample_ids, MatrixXf& gains)
171+ Eigen::Ref<MatrixXi> row_pointers,
172+ Eigen::Ref<MatrixXi> embedding_ids,
173+ Eigen::Ref<MatrixXi> sample_ids, Eigen::Ref<MatrixXf> gains)
172174 : stacked_table_name(name),
173175 stacked_table_metadata (metadata),
174176 has_variable_weights(has_variable_weights),
@@ -537,12 +539,14 @@ void FillDeviceBuffersForTable(
537539 }
538540}
539541
540- std::tuple<MatrixXi&, MatrixXi&, MatrixXi&, MatrixXf&> GetOutputCsrBuffers (
541- const std::string& stacked_table_name,
542- const PreprocessSparseDenseMatmulInputOptions& options,
543- int row_pointers_size_per_device, int coo_buffer_size_per_device,
544- OutputCsrArrays* output_csr_arrays,
545- PreprocessSparseDenseMatmulOutput& out) {
542+ std::tuple<Eigen::Ref<MatrixXi>, Eigen::Ref<MatrixXi>, Eigen::Ref<MatrixXi>,
543+ Eigen::Ref<MatrixXf>>
544+ GetOutputCsrBuffers (const std::string& stacked_table_name,
545+ const PreprocessSparseDenseMatmulInputOptions& options,
546+ int row_pointers_size_per_device,
547+ int coo_buffer_size_per_device,
548+ OutputCsrArrays* output_csr_arrays,
549+ PreprocessSparseDenseMatmulOutput& out) {
546550 if (output_csr_arrays != nullptr ) {
547551 DCHECK (output_csr_arrays->lhs_row_pointers .contains (stacked_table_name))
548552 << " Missing lhs_row_pointers for table: " << stacked_table_name;
@@ -552,22 +556,23 @@ std::tuple<MatrixXi&, MatrixXi&, MatrixXi&, MatrixXf&> GetOutputCsrBuffers(
552556 << " Missing lhs_sample_ids for table: " << stacked_table_name;
553557 DCHECK (output_csr_arrays->lhs_gains .contains (stacked_table_name))
554558 << " Missing lhs_gains for table: " << stacked_table_name;
555- MatrixXi& row_pointers =
556- output_csr_arrays->lhs_row_pointers [ stacked_table_name] ;
559+ Eigen::Map< MatrixXi> & row_pointers =
560+ output_csr_arrays->lhs_row_pointers . find ( stacked_table_name)-> second ;
557561 DCHECK_EQ (row_pointers.rows (), options.local_device_count );
558562 DCHECK_EQ (row_pointers.cols (), row_pointers_size_per_device);
559563
560- MatrixXi& embedding_ids =
561- output_csr_arrays->lhs_embedding_ids [ stacked_table_name] ;
564+ Eigen::Map< MatrixXi> & embedding_ids =
565+ output_csr_arrays->lhs_embedding_ids . find ( stacked_table_name)-> second ;
562566 DCHECK_EQ (embedding_ids.rows (), options.local_device_count );
563567 DCHECK_EQ (embedding_ids.cols (), coo_buffer_size_per_device);
564568
565- MatrixXi& sample_ids =
566- output_csr_arrays->lhs_sample_ids [ stacked_table_name] ;
569+ Eigen::Map< MatrixXi> & sample_ids =
570+ output_csr_arrays->lhs_sample_ids . find ( stacked_table_name)-> second ;
567571 DCHECK_EQ (sample_ids.rows (), options.local_device_count );
568572 DCHECK_EQ (sample_ids.cols (), coo_buffer_size_per_device);
569573
570- MatrixXf& gains = output_csr_arrays->lhs_gains [stacked_table_name];
574+ Eigen::Map<MatrixXf>& gains =
575+ output_csr_arrays->lhs_gains .find (stacked_table_name)->second ;
571576 DCHECK_EQ (gains.rows (), options.local_device_count );
572577 DCHECK_EQ (gains.cols (), coo_buffer_size_per_device);
573578
0 commit comments