@@ -91,11 +91,11 @@ class TaskSchedulerImpl : public TaskScheduler {
9191 AbortContinuationImpl abort_cont_impl_;
9292
9393 std::vector<TaskGroup> task_groups_;
94- bool aborted_;
9594 bool register_finished_;
9695 std::mutex mutex_; // Mutex protecting task_groups_ (state_ and num_tasks_present_
97- // fields), aborted_ flag and register_finished_ flag
96+ // fields) and register_finished_ flag
9897
98+ AtomicWithPadding<bool > aborted_;
9999 AtomicWithPadding<int > num_tasks_to_schedule_;
100100 // If a task group adds tasks it's possible for a thread inside
101101 // ScheduleMore to miss this fact. This serves as a flag to
@@ -105,10 +105,8 @@ class TaskSchedulerImpl : public TaskScheduler {
105105};
106106
107107TaskSchedulerImpl::TaskSchedulerImpl ()
108- : use_sync_execution_(false ),
109- num_concurrent_tasks_(0 ),
110- aborted_(false ),
111- register_finished_(false ) {
108+ : use_sync_execution_(false ), num_concurrent_tasks_(0 ), register_finished_(false ) {
109+ aborted_.value .store (false );
112110 num_tasks_to_schedule_.value .store (0 );
113111 tasks_added_recently_.value .store (false );
114112}
@@ -131,13 +129,11 @@ Status TaskSchedulerImpl::StartTaskGroup(size_t thread_id, int group_id,
131129 ARROW_DCHECK (group_id >= 0 && group_id < static_cast <int >(task_groups_.size ()));
132130 TaskGroup& task_group = task_groups_[group_id];
133131
134- bool aborted = false ;
132+ bool aborted = aborted_. value . load () ;
135133 bool all_tasks_finished = false ;
136134 {
137135 std::lock_guard<std::mutex> lock (mutex_);
138136
139- aborted = aborted_;
140-
141137 if (task_group.state_ == TaskGroupState::NOT_READY) {
142138 task_group.num_tasks_present_ = total_num_tasks;
143139 if (total_num_tasks == 0 ) {
@@ -212,7 +208,7 @@ std::vector<std::pair<int, int64_t>> TaskSchedulerImpl::PickTasks(int num_tasks,
212208
213209Status TaskSchedulerImpl::ExecuteTask (size_t thread_id, int group_id, int64_t task_id,
214210 bool * task_group_finished) {
215- if (!aborted_) {
211+ if (!aborted_. value . load () ) {
216212 RETURN_NOT_OK (task_groups_[group_id].task_impl_ (thread_id, task_id));
217213 }
218214 *task_group_finished = PostExecuteTask (thread_id, group_id);
@@ -228,11 +224,10 @@ bool TaskSchedulerImpl::PostExecuteTask(size_t thread_id, int group_id) {
228224
229225Status TaskSchedulerImpl::OnTaskGroupFinished (size_t thread_id, int group_id,
230226 bool * all_task_groups_finished) {
231- bool aborted = false ;
227+ bool aborted = aborted_. value . load () ;
232228 {
233229 std::lock_guard<std::mutex> lock (mutex_);
234230
235- aborted = aborted_;
236231 TaskGroup& task_group = task_groups_[group_id];
237232 task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
238233 *all_task_groups_finished = true ;
@@ -260,7 +255,7 @@ Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute
260255
261256 int last_id = 0 ;
262257 for (;;) {
263- if (aborted_) {
258+ if (aborted_. value . load () ) {
264259 return Status::Cancelled (" Scheduler cancelled" );
265260 }
266261
@@ -278,8 +273,8 @@ Status TaskSchedulerImpl::ExecuteMore(size_t thread_id, int num_tasks_to_execute
278273 bool task_group_finished = false ;
279274 Status status = ExecuteTask (thread_id, group_id, task_id, &task_group_finished);
280275 if (!status.ok ()) {
281- // Mark the remaining picked tasks as finished
282- for (size_t j = i + 1 ; j < tasks.size (); ++j) {
276+ // Mark the current and remaining picked tasks as finished
277+ for (size_t j = i; j < tasks.size (); ++j) {
283278 if (PostExecuteTask (thread_id, tasks[j].first )) {
284279 bool all_task_groups_finished = false ;
285280 RETURN_NOT_OK (
@@ -328,7 +323,7 @@ Status TaskSchedulerImpl::StartScheduling(size_t thread_id, ScheduleImpl schedul
328323}
329324
330325Status TaskSchedulerImpl::ScheduleMore (size_t thread_id, int num_tasks_finished) {
331- if (aborted_) {
326+ if (aborted_. value . load () ) {
332327 return Status::Cancelled (" Scheduler cancelled" );
333328 }
334329
@@ -369,17 +364,25 @@ Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished)
369364 int group_id = tasks[i].first ;
370365 int64_t task_id = tasks[i].second ;
371366 RETURN_NOT_OK (schedule_impl_ ([this , group_id, task_id](size_t thread_id) -> Status {
372- RETURN_NOT_OK (ScheduleMore (thread_id, 1 ));
373-
374367 bool task_group_finished = false ;
375- RETURN_NOT_OK (ExecuteTask (thread_id, group_id, task_id, &task_group_finished));
368+ // PostExecuteTask must be called later if any error ocurres during task execution
369+ // (including ScheduleMore), so we preserve the status.
370+ auto status = [&]() {
371+ RETURN_NOT_OK (ScheduleMore (thread_id, 1 ));
372+ return ExecuteTask (thread_id, group_id, task_id, &task_group_finished);
373+ }();
374+
375+ if (!status.ok ()) {
376+ task_group_finished = PostExecuteTask (thread_id, group_id);
377+ }
376378
377379 if (task_group_finished) {
378380 bool all_task_groups_finished = false ;
379- return OnTaskGroupFinished (thread_id, group_id, &all_task_groups_finished);
381+ RETURN_NOT_OK (
382+ OnTaskGroupFinished (thread_id, group_id, &all_task_groups_finished));
380383 }
381384
382- return Status::OK () ;
385+ return status ;
383386 }));
384387 }
385388
@@ -388,31 +391,43 @@ Status TaskSchedulerImpl::ScheduleMore(size_t thread_id, int num_tasks_finished)
388391
389392void TaskSchedulerImpl::Abort (AbortContinuationImpl impl) {
390393 bool all_finished = true ;
394+ DCHECK_EQ (aborted_.value .load (), false );
395+ aborted_.value .store (true );
391396 {
392397 std::lock_guard<std::mutex> lock (mutex_);
393- aborted_ = true ;
394398 abort_cont_impl_ = std::move (impl);
395399 if (register_finished_) {
396400 for (size_t i = 0 ; i < task_groups_.size (); ++i) {
397401 TaskGroup& task_group = task_groups_[i];
398- if (task_group.state_ == TaskGroupState::NOT_READY) {
399- task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
400- } else if (task_group.state_ == TaskGroupState::READY) {
401- int64_t expected = task_group.num_tasks_started_ .value .load ();
402- for (;;) {
403- if (task_group.num_tasks_started_ .value .compare_exchange_strong (
404- expected, task_group.num_tasks_present_ )) {
405- break ;
402+ switch (task_group.state_ ) {
403+ case TaskGroupState::NOT_READY: {
404+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
405+ break ;
406+ }
407+ case TaskGroupState::READY: {
408+ int64_t expected = task_group.num_tasks_started_ .value .load ();
409+ for (;;) {
410+ if (task_group.num_tasks_started_ .value .compare_exchange_strong (
411+ expected, task_group.num_tasks_present_ )) {
412+ break ;
413+ }
406414 }
415+ int64_t before_add = task_group.num_tasks_finished_ .value .fetch_add (
416+ task_group.num_tasks_present_ - expected);
417+ if (before_add >= expected) {
418+ task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
419+ } else {
420+ all_finished = false ;
421+ task_group.state_ = TaskGroupState::ALL_TASKS_STARTED;
422+ }
423+ break ;
407424 }
408- int64_t before_add = task_group.num_tasks_finished_ .value .fetch_add (
409- task_group.num_tasks_present_ - expected);
410- if (before_add >= expected) {
411- task_group.state_ = TaskGroupState::ALL_TASKS_FINISHED;
412- } else {
425+ case TaskGroupState::ALL_TASKS_STARTED: {
413426 all_finished = false ;
414- task_group. state_ = TaskGroupState::ALL_TASKS_STARTED ;
427+ break ;
415428 }
429+ default :
430+ break ;
416431 }
417432 }
418433 }
0 commit comments