Skip to content

Commit e9ec383

Browse files
Joob1n, zanmato1984 and pitrou
authored
GH-45266: [C++][Acero] Fix the running tasks count of Scheduler when get error tasks in multi-threads (#45268)
### Rationale for this change When a TaskGroup should be canceled, the scheduler moves the count of not-yet-started tasks into the finished count so that those tasks are skipped (in `TaskSchedulerImpl::Abort`). However, this operation races with task execution across multiple threads. At the same time, some tasks may already have started running and then fail with an error. Because the code uses `RETURN_NOT_OK`, such failing tasks return the bad status immediately without decrementing the running-task count — even though, from the Scheduler's point of view, they are still counted as running. ### What changes are included in this PR? For any task, whatever status it returns, the running count is now updated before returning. ### Are these changes tested? No; it is too hard to build a unit test for this. ### Are there any user-facing changes? No. But I am very surprised this has not happened to anyone before. * GitHub Issue: #45266 Lead-authored-by: zhouyunpei <zhouyunpei@yanhuangdata.com> Co-authored-by: Rossi Sun <zanmato1984@gmail.com> Co-authored-by: Antoine Pitrou <antoine@python.org> Signed-off-by: Rossi Sun <zanmato1984@gmail.com>
1 parent 13940cd commit e9ec383

File tree

2 files changed

+143
-36
lines changed

2 files changed

+143
-36
lines changed

cpp/src/arrow/acero/task_util.cc

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -91,11 +91,11 @@ class TaskSchedulerImpl : public TaskScheduler {
9191
AbortContinuationImpl abort_cont_impl_;
9292

9393
std::vector<TaskGroup> task_groups_;
94-
bool aborted_;
9594
bool register_finished_;
9695
std::mutex mutex_; // Mutex protecting task_groups_ (state_ and num_tasks_present_
97-
// fields), aborted_ flag and register_finished_ flag
96+
// fields) and register_finished_ flag
9897

98+
AtomicWithPadding<bool> aborted_;
9999
AtomicWithPadding<int> num_tasks_to_schedule_;
100100
// If a task group adds tasks it's possible for a thread inside
101101
// ScheduleMore to miss this fact. This serves as a flag to
@@ -105,10 +105,8 @@ class TaskSchedulerImpl : public TaskScheduler {
105105
};
106106

107107
TaskSchedulerImpl::TaskSchedulerImpl()
108-
: use_sync_execution_(false),
109-
num_concurrent_tasks_(0),
110-
aborted_(false),
111-
register_finished_(false) {
108+
: use_sync_execution_(false), num_concurrent_tasks_(0), register_finished_(false) {
109+
aborted_.value.store(false);
112110
num_tasks_to_schedule_.value.store(0);
113111
tasks_added_recently_.value.store(false);
114112
}
@@ -131,13 +129,11 @@ Status TaskSchedulerImpl::StartTaskGroup(size_t thread_id, int group_id,
131129
ARROW_DCHECK(group_id >= 0 && group_id < static_cast<int>(task_groups_.size()));
132130
TaskGroup& task_group = task_groups_[group_id];
133131

134-
bool aborted = false;
132+
bool aborted = aborted_.value.load();
135133
bool all_tasks_finished = false;
136134
{
137135
std::lock_guard<std::mutex> lock(mutex_);
138136

139-
aborted = aborted_;
140-
141137
if (task_group.state_ == TaskGroupState::NOT_READY) {
142138
task_group.num_tasks_present_ = total_num_tasks;
143139
if (total_num_tasks == 0) {
@@ -212,7 +208,7 @@ std::vector<std::pair<int, int64_t>> TaskSchedulerImpl::PickTasks(int num_tasks,
212208

213209
Status TaskSchedulerImpl::ExecuteTask(size_t thread_id, int group_id, int64_t task_id,
214210
bool* task_group_finished) {
215-
if (!aborted_) {
211+
if (!aborted_.value.load()) {
216212
RETURN_NOT_OK(task_groups_[group_id].task_impl_(thread_id, task_id));
217213
}
218214
*task_group_finished = PostExecuteTask(thread_id, group_id);
@@ -228,11 +224,10 @@ bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id, int group_id) {
228224

229225
Status TaskSchedulerImpl::OnTaskGroupFinished(size_t thread_id, int group_id,
230226
bool* all_task_groups_finished) {
231-
bool aborted = false;
227+
bool aborted = aborted_.value.load();
232228
{
233229
std::lock_guard<std::mutex> lock(mutex_);
234230

235-
aborted = aborted_;
236231
TaskGroup& task_group = task_groups_[group_id];
237232
task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
238233
*all_task_groups_finished = true;
@@ -260,7 +255,7 @@ Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute
260255

261256
int last_id = 0;
262257
for (;;) {
263-
if (aborted_) {
258+
if (aborted_.value.load()) {
264259
return Status::Cancelled("Scheduler cancelled");
265260
}
266261

@@ -278,8 +273,8 @@ Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute
278273
bool task_group_finished = false;
279274
Status status = ExecuteTask(thread_id, group_id, task_id, &task_group_finished);
280275
if (!status.ok()) {
281-
// Mark the remaining picked tasks as finished
282-
for (size_t j = i + 1; j < tasks.size(); ++j) {
276+
// Mark the current and remaining picked tasks as finished
277+
for (size_t j = i; j < tasks.size(); ++j) {
283278
if (PostExecuteTask(thread_id, tasks[j].first)) {
284279
bool all_task_groups_finished = false;
285280
RETURN_NOT_OK(
@@ -328,7 +323,7 @@ Status TaskSchedulerImpl::StartScheduling(size_t thread_id, ScheduleImpl schedul
328323
}
329324

330325
Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished) {
331-
if (aborted_) {
326+
if (aborted_.value.load()) {
332327
return Status::Cancelled("Scheduler cancelled");
333328
}
334329

@@ -369,17 +364,25 @@ Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished)
369364
int group_id = tasks[i].first;
370365
int64_t task_id = tasks[i].second;
371366
RETURN_NOT_OK(schedule_impl_([this, group_id, task_id](size_t thread_id) -> Status {
372-
RETURN_NOT_OK(ScheduleMore(thread_id, 1));
373-
374367
bool task_group_finished = false;
375-
RETURN_NOT_OK(ExecuteTask(thread_id, group_id, task_id, &task_group_finished));
368+
// PostExecuteTask must be called later if any error occurs during task execution
369+
// (including ScheduleMore), so we preserve the status.
370+
auto status = [&]() {
371+
RETURN_NOT_OK(ScheduleMore(thread_id, 1));
372+
return ExecuteTask(thread_id, group_id, task_id, &task_group_finished);
373+
}();
374+
375+
if (!status.ok()) {
376+
task_group_finished = PostExecuteTask(thread_id, group_id);
377+
}
376378

377379
if (task_group_finished) {
378380
bool all_task_groups_finished = false;
379-
return OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished);
381+
RETURN_NOT_OK(
382+
OnTaskGroupFinished(thread_id, group_id, &all_task_groups_finished));
380383
}
381384

382-
return Status::OK();
385+
return status;
383386
}));
384387
}
385388

@@ -388,31 +391,43 @@ Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished)
388391

389392
void TaskSchedulerImpl::Abort(AbortContinuationImpl impl) {
390393
bool all_finished = true;
394+
DCHECK_EQ(aborted_.value.load(), false);
395+
aborted_.value.store(true);
391396
{
392397
std::lock_guard<std::mutex> lock(mutex_);
393-
aborted_ = true;
394398
abort_cont_impl_ = std::move(impl);
395399
if (register_finished_) {
396400
for (size_t i = 0; i < task_groups_.size(); ++i) {
397401
TaskGroup& task_group = task_groups_[i];
398-
if (task_group.state_ == TaskGroupState::NOT_READY) {
399-
task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
400-
} else if (task_group.state_ == TaskGroupState::READY) {
401-
int64_t expected = task_group.num_tasks_started_.value.load();
402-
for (;;) {
403-
if (task_group.num_tasks_started_.value.compare_exchange_strong(
404-
expected, task_group.num_tasks_present_)) {
405-
break;
402+
switch (task_group.state_) {
403+
case TaskGroupState::NOT_READY: {
404+
task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
405+
break;
406+
}
407+
case TaskGroupState::READY: {
408+
int64_t expected = task_group.num_tasks_started_.value.load();
409+
for (;;) {
410+
if (task_group.num_tasks_started_.value.compare_exchange_strong(
411+
expected, task_group.num_tasks_present_)) {
412+
break;
413+
}
406414
}
415+
int64_t before_add = task_group.num_tasks_finished_.value.fetch_add(
416+
task_group.num_tasks_present_ - expected);
417+
if (before_add >= expected) {
418+
task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
419+
} else {
420+
all_finished = false;
421+
task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
422+
}
423+
break;
407424
}
408-
int64_t before_add = task_group.num_tasks_finished_.value.fetch_add(
409-
task_group.num_tasks_present_ - expected);
410-
if (before_add >= expected) {
411-
task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
412-
} else {
425+
case TaskGroupState::ALL_TASKS_STARTED: {
413426
all_finished = false;
414-
task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
427+
break;
415428
}
429+
default:
430+
break;
416431
}
417432
}
418433
}

cpp/src/arrow/acero/task_util_test.cc

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,5 +231,97 @@ TEST(TaskScheduler, StressTwo) {
231231
}
232232
}
233233

234+
TEST(TaskScheduler, AbortContOnTaskErrorSerial) {
235+
constexpr int kNumTasks = 16;
236+
237+
auto scheduler = TaskScheduler::Make();
238+
auto task = [&](std::size_t, int64_t task_id) {
239+
if (task_id == kNumTasks / 2) {
240+
return Status::Invalid("Task failed");
241+
}
242+
return Status::OK();
243+
};
244+
245+
int task_group =
246+
scheduler->RegisterTaskGroup(task, [](std::size_t) { return Status::OK(); });
247+
scheduler->RegisterEnd();
248+
249+
ASSERT_OK(scheduler->StartScheduling(
250+
/*thread_id=*/0,
251+
/*schedule_impl=*/
252+
[](TaskScheduler::TaskGroupContinuationImpl) { return Status::OK(); },
253+
/*num_concurrent_tasks=*/1, /*use_sync_execution=*/true));
254+
ASSERT_RAISES_WITH_MESSAGE(
255+
Invalid, "Invalid: Task failed",
256+
scheduler->StartTaskGroup(/*thread_id=*/0, task_group, kNumTasks));
257+
258+
int num_abort_cont_calls = 0;
259+
auto abort_cont = [&]() { ++num_abort_cont_calls; };
260+
261+
scheduler->Abort(abort_cont);
262+
263+
ASSERT_EQ(num_abort_cont_calls, 1);
264+
}
265+
266+
TEST(TaskScheduler, AbortContOnTaskErrorParallel) {
267+
#ifndef ARROW_ENABLE_THREADING
268+
GTEST_SKIP() << "Test requires threading support";
269+
#endif
270+
constexpr int kNumThreads = 16;
271+
272+
ThreadIndexer thread_indexer;
273+
int num_threads = std::min(static_cast<int>(thread_indexer.Capacity()), kNumThreads);
274+
ASSERT_OK_AND_ASSIGN(std::shared_ptr<ThreadPool> thread_pool,
275+
MakePrimedThreadPool(num_threads));
276+
TaskScheduler::ScheduleImpl schedule =
277+
[&](TaskScheduler::TaskGroupContinuationImpl task) {
278+
return thread_pool->Spawn([&, task] {
279+
std::size_t thread_id = thread_indexer();
280+
auto status = task(thread_id);
281+
ASSERT_TRUE(status.ok() || status.IsInvalid() || status.IsCancelled())
282+
<< status;
283+
});
284+
};
285+
286+
for (int num_tasks :
287+
{2, num_threads - 1, num_threads, num_threads + 1, 2 * num_threads}) {
288+
ARROW_SCOPED_TRACE("num_tasks = ", num_tasks);
289+
for (int num_concurrent_tasks :
290+
{1, num_tasks - 1, num_tasks, num_tasks + 1, 2 * num_tasks}) {
291+
ARROW_SCOPED_TRACE("num_concurrent_tasks = ", num_concurrent_tasks);
292+
for (int aborting_task_id = 0; aborting_task_id < num_tasks; ++aborting_task_id) {
293+
ARROW_SCOPED_TRACE("aborting_task_id = ", aborting_task_id);
294+
auto scheduler = TaskScheduler::Make();
295+
296+
int num_abort_cont_calls = 0;
297+
auto abort_cont = [&]() { ++num_abort_cont_calls; };
298+
299+
auto task = [&](std::size_t, int64_t task_id) {
300+
if (task_id == aborting_task_id) {
301+
scheduler->Abort(abort_cont);
302+
}
303+
if (task_id % 2 == 0) {
304+
return Status::Invalid("Task failed");
305+
}
306+
return Status::OK();
307+
};
308+
309+
int task_group =
310+
scheduler->RegisterTaskGroup(task, [](std::size_t) { return Status::OK(); });
311+
scheduler->RegisterEnd();
312+
313+
ASSERT_OK(scheduler->StartScheduling(/*thread_id=*/0, schedule,
314+
num_concurrent_tasks,
315+
/*use_sync_execution=*/false));
316+
ASSERT_OK(scheduler->StartTaskGroup(/*thread_id=*/0, task_group, num_tasks));
317+
318+
thread_pool->WaitForIdle();
319+
320+
ASSERT_EQ(num_abort_cont_calls, 1);
321+
}
322+
}
323+
}
324+
}
325+
234326
} // namespace acero
235327
} // namespace arrow

0 commit comments

Comments
 (0)