3232#include < cstdio>
3333#include < memory>
3434
35- #include < omp.h>
36-
3735namespace arrow {
3836namespace acero {
3937struct BenchmarkSettings {
@@ -56,6 +54,8 @@ struct BenchmarkSettings {
5654 int var_length_max = 20 ; // Maximum length of any var length types
5755
5856 Expression residual_filter = literal(true );
57+
58+ bool stats_probe_rows = true ;
5959};
6060
6161class JoinBenchmark {
@@ -128,6 +128,7 @@ class JoinBenchmark {
128128 for (ExecBatch& batch : r_batches_with_schema.batches )
129129 r_batches_.InsertBatch (std::move (batch));
130130
131+ stats_.num_build_rows = settings.num_build_batches * settings.batch_size ;
131132 stats_.num_probe_rows = settings.num_probe_batches * settings.batch_size ;
132133
133134 schema_mgr_ = std::make_unique<HashJoinSchema>();
@@ -141,14 +142,9 @@ class JoinBenchmark {
141142 join_ = *HashJoinImpl::MakeSwiss ();
142143 }
143144
144- omp_set_num_threads (settings.num_threads );
145- auto schedule_callback = [](std::function<Status (size_t )> func) -> Status {
146- #pragma omp task
147- { DCHECK_OK (func (omp_get_thread_num ())); }
148- return Status::OK ();
149- };
150-
151145 scheduler_ = TaskScheduler::Make ();
146+ thread_pool_ = arrow::internal::GetCpuThreadPool ();
147+ DCHECK_OK (thread_pool_->SetCapacity (settings.num_threads ));
152148 DCHECK_OK (ctx_.Init (nullptr ));
153149
154150 auto register_task_group_callback = [&](std::function<Status (size_t , int64_t )> task,
@@ -157,15 +153,15 @@ class JoinBenchmark {
157153 };
158154
159155 auto start_task_group_callback = [&](int task_group_id, int64_t num_tasks) {
160- return scheduler_->StartTaskGroup (omp_get_thread_num () , task_group_id, num_tasks);
156+ return scheduler_->StartTaskGroup (/* thread_id= */ 0 , task_group_id, num_tasks);
161157 };
162158
163159 DCHECK_OK (join_->Init (
164160 &ctx_, settings.join_type , settings.num_threads , &(schema_mgr_->proj_maps [0 ]),
165161 &(schema_mgr_->proj_maps [1 ]), std::move (key_cmp), settings.residual_filter ,
166162 std::move (register_task_group_callback), std::move (start_task_group_callback),
167163 [](int64_t , ExecBatch) { return Status::OK (); },
168- [](int64_t ) { return Status::OK (); }));
164+ [& ](int64_t ) { return Status::OK (); }));
169165
170166 task_group_probe_ = scheduler_->RegisterTaskGroup (
171167 [this ](size_t thread_index, int64_t task_id) -> Status {
@@ -178,25 +174,27 @@ class JoinBenchmark {
178174 scheduler_->RegisterEnd ();
179175
180176 DCHECK_OK (scheduler_->StartScheduling (
181- 0 /* thread index*/ , std::move (schedule_callback),
182- static_cast <int >(2 * settings.num_threads ) /* concurrent tasks*/ ,
183- settings.num_threads == 1 ));
177+ /* thread_id=*/ 0 ,
178+ [&](std::function<Status (size_t )> task) -> Status {
179+ return thread_pool_->Spawn ([&, task]() { DCHECK_OK (task (thread_indexer_ ())); });
180+ },
181+ thread_pool_->GetCapacity (), settings.num_threads == 1 ));
184182 }
185183
// Runs the join once: hands r_batches_ (moved from) to join_->BuildHashTable
// together with a callback that starts the probe task group with
// l_batches_.batch_count() tasks.
// NOTE(review): this span is a diff rendering. Lines marked `-` are the
// removed OpenMP variant (BuildHashTable called from an `omp single` region
// using omp_get_thread_num() as the thread id); lines marked `+` are the
// replacement that no longer uses OpenMP.
186184 void RunJoin () {
187- #pragma omp parallel
188- {
189- int tid = omp_get_thread_num ();
190- #pragma omp single
191- DCHECK_OK (
192- join_->BuildHashTable (tid, std::move (r_batches_), [this ](size_t thread_index) {
193- return scheduler_->StartTaskGroup (thread_index, task_group_probe_,
194- l_batches_.batch_count ());
195- }));
196- }
// Replacement: the calling thread invokes BuildHashTable with a fixed thread
// id of 0. The callback is presumably invoked once the build side is done --
// TODO confirm against HashJoinImpl::BuildHashTable.
185+ DCHECK_OK (join_->BuildHashTable (
186+ /* thread_id=*/ 0 , std::move (r_batches_), [this ](size_t thread_index) {
187+ return scheduler_->StartTaskGroup (thread_index, task_group_probe_,
188+ l_batches_.batch_count ());
189+ }));
190+
// Presumably blocks until every task spawned on thread_pool_ has finished, so
// the caller's timing covers the whole join -- TODO confirm WaitForIdle
// semantics.
191+ thread_pool_->WaitForIdle ();
197192 }
198193
199194 std::unique_ptr<TaskScheduler> scheduler_;
195+ ThreadIndexer thread_indexer_;
196+ arrow::internal::ThreadPool* thread_pool_;
197+
200198 AccumulationQueue l_batches_;
201199 AccumulationQueue r_batches_;
202200 std::unique_ptr<HashJoinSchema> schema_mgr_;
@@ -205,6 +203,7 @@ class JoinBenchmark {
205203 int task_group_probe_;
206204
207205 struct {
206+ uint64_t num_build_rows;
208207 uint64_t num_probe_rows;
209208 } stats_;
210209};
@@ -219,11 +218,13 @@ static void HashJoinBasicBenchmarkImpl(benchmark::State& st,
219218 st.ResumeTiming ();
220219 bm.RunJoin ();
221220 st.PauseTiming ();
222- total_rows += bm.stats_ .num_probe_rows ;
221+ total_rows += (settings.stats_probe_rows ? bm.stats_ .num_probe_rows
222+ : bm.stats_ .num_build_rows );
223223 }
224224 st.ResumeTiming ();
225225 }
226- st.counters [" rows/sec" ] = benchmark::Counter (total_rows, benchmark::Counter::kIsRate );
226+ st.counters [" rows/sec" ] =
227+ benchmark::Counter (static_cast <double >(total_rows), benchmark::Counter::kIsRate );
227228}
228229
229230template <typename ... Args>
@@ -302,6 +303,7 @@ static void BM_HashJoinBasic_BuildParallelism(benchmark::State& st) {
302303 settings.num_threads = static_cast <int >(st.range (0 ));
303304 settings.num_build_batches = static_cast <int >(st.range (1 ));
304305 settings.num_probe_batches = settings.num_threads ;
306+ settings.stats_probe_rows = false ;
305307
306308 HashJoinBasicBenchmarkImpl (st, settings);
307309}
0 commit comments