Skip to content
154 changes: 126 additions & 28 deletions durabletask/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def __init__(
else:
self._interceptors = None

self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options)
self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options, self._logger)

@property
def concurrency_options(self) -> ConcurrencyOptions:
Expand Down Expand Up @@ -533,27 +533,31 @@ def stream_reader():
if work_item.HasField("orchestratorRequest"):
self._async_worker_manager.submit_orchestration(
self._execute_orchestrator,
self._cancel_orchestrator,
work_item.orchestratorRequest,
stub,
work_item.completionToken,
)
elif work_item.HasField("activityRequest"):
self._async_worker_manager.submit_activity(
self._execute_activity,
self._cancel_activity,
work_item.activityRequest,
stub,
work_item.completionToken,
)
elif work_item.HasField("entityRequest"):
self._async_worker_manager.submit_entity_batch(
self._execute_entity_batch,
self._cancel_entity_batch,
work_item.entityRequest,
stub,
work_item.completionToken,
)
elif work_item.HasField("entityRequestV2"):
self._async_worker_manager.submit_entity_batch(
self._execute_entity_batch,
self._cancel_entity_batch,
work_item.entityRequestV2,
stub,
work_item.completionToken
Expand Down Expand Up @@ -670,6 +674,19 @@ def _execute_orchestrator(
f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}"
)

def _cancel_orchestrator(
    self,
    req: pb.OrchestratorRequest,
    stub: stubs.TaskHubSidecarServiceStub,
    completionToken,
):
    """Abandon an orchestrator work item instead of completing it.

    Passed to the worker manager as the cancellation callback that pairs
    with ``_execute_orchestrator``; it receives the same arguments.

    Args:
        req: The orchestrator request being abandoned (only used for logging).
        stub: Sidecar service stub used to send the abandon request.
        completionToken: Token identifying the work item to abandon.
    """
    # Build the request first, then hand it to the sidecar.
    abandon_request = pb.AbandonOrchestrationTaskRequest(completionToken=completionToken)
    stub.AbandonTaskOrchestratorWorkItem(abandon_request)
    # Only reached when the abandon call did not raise.
    self._logger.info(f"Cancelled orchestration task for invocation ID: {req.instanceId}")

def _execute_activity(
self,
req: pb.ActivityRequest,
Expand Down Expand Up @@ -703,6 +720,19 @@ def _execute_activity(
f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}"
)

def _cancel_activity(
    self,
    req: pb.ActivityRequest,
    stub: stubs.TaskHubSidecarServiceStub,
    completionToken,
):
    """Abandon an activity work item instead of completing it.

    Passed to the worker manager as the cancellation callback that pairs
    with ``_execute_activity``; it receives the same arguments.

    Args:
        req: The activity request being abandoned (only used for logging).
        stub: Sidecar service stub used to send the abandon request.
        completionToken: Token identifying the work item to abandon.
    """
    # Build the request first, then hand it to the sidecar.
    abandon_request = pb.AbandonActivityTaskRequest(completionToken=completionToken)
    stub.AbandonTaskActivityWorkItem(abandon_request)
    # Only reached when the abandon call did not raise.
    self._logger.info(f"Cancelled activity task for task ID: {req.taskId} on orchestration ID: {req.orchestrationInstance.instanceId}")

def _execute_entity_batch(
self,
req: Union[pb.EntityBatchRequest, pb.EntityRequest],
Expand Down Expand Up @@ -771,6 +801,19 @@ def _execute_entity_batch(

return batch_result

def _cancel_entity_batch(
    self,
    req: Union[pb.EntityBatchRequest, pb.EntityRequest],
    stub: stubs.TaskHubSidecarServiceStub,
    completionToken,
):
    """Abandon an entity batch work item instead of completing it.

    Passed to the worker manager as the cancellation callback that pairs
    with ``_execute_entity_batch``; it receives the same arguments.

    Args:
        req: The entity request being abandoned (only used for logging).
        stub: Sidecar service stub used to send the abandon request.
        completionToken: Token identifying the work item to abandon.
    """
    # Build the request first, then hand it to the sidecar.
    abandon_request = pb.AbandonEntityTaskRequest(completionToken=completionToken)
    stub.AbandonTaskEntityWorkItem(abandon_request)
    # Only reached when the abandon call did not raise.
    self._logger.info(f"Cancelled entity batch task for instance ID: {req.instanceId}")


class _RuntimeOrchestrationContext(task.OrchestrationContext):
_generator: Optional[Generator[task.Task, Any, Any]]
Expand Down Expand Up @@ -1933,8 +1976,10 @@ def _is_suspendable(event: pb.HistoryEvent) -> bool:


class _AsyncWorkerManager:
def __init__(self, concurrency_options: ConcurrencyOptions):
def __init__(self, concurrency_options: ConcurrencyOptions, logger: logging.Logger):
self.concurrency_options = concurrency_options
self._logger = logger

self.activity_semaphore = None
self.orchestration_semaphore = None
self.entity_semaphore = None
Expand Down Expand Up @@ -2044,17 +2089,51 @@ async def run(self):
)

# Start background consumers for each work type
if self.activity_queue is not None and self.orchestration_queue is not None \
and self.entity_batch_queue is not None:
await asyncio.gather(
self._consume_queue(self.activity_queue, self.activity_semaphore),
self._consume_queue(
self.orchestration_queue, self.orchestration_semaphore
),
self._consume_queue(
self.entity_batch_queue, self.entity_semaphore
try:
if self.activity_queue is not None and self.orchestration_queue is not None \
and self.entity_batch_queue is not None:
await asyncio.gather(
self._consume_queue(self.activity_queue, self.activity_semaphore),
self._consume_queue(
self.orchestration_queue, self.orchestration_semaphore
),
self._consume_queue(
self.entity_batch_queue, self.entity_semaphore
)
)
)
except Exception as queue_exception:
self._logger.error(f"Shutting down worker - Uncaught error in worker manager: {queue_exception}")
while self.activity_queue is not None and not self.activity_queue.empty():
try:
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
await self._run_func(cancellation_func, *args, **kwargs)
self._logger.error(f"Activity work item args: {args}, kwargs: {kwargs}")
except asyncio.QueueEmpty:
# Queue was empty, no cancellation needed
pass
except Exception as cancellation_exception:
self._logger.error(f"Uncaught error while cancelling activity work item: {cancellation_exception}")
while self.orchestration_queue is not None and not self.orchestration_queue.empty():
try:
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
await self._run_func(cancellation_func, *args, **kwargs)
self._logger.error(f"Orchestration work item args: {args}, kwargs: {kwargs}")
except asyncio.QueueEmpty:
# Queue was empty, no cancellation needed
pass
except Exception as cancellation_exception:
self._logger.error(f"Uncaught error while cancelling orchestration work item: {cancellation_exception}")
while self.entity_batch_queue is not None and not self.entity_batch_queue.empty():
try:
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
await self._run_func(cancellation_func, *args, **kwargs)
self._logger.error(f"Entity batch work item args: {args}, kwargs: {kwargs}")
except asyncio.QueueEmpty:
# Queue was empty, no cancellation needed
pass
except Exception as cancellation_exception:
self._logger.error(f"Uncaught error while cancelling entity batch work item: {cancellation_exception}")
self.shutdown()

async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore):
# List to track running tasks
Expand All @@ -2074,19 +2153,22 @@ async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphor
except asyncio.TimeoutError:
continue

func, args, kwargs = work
func, cancellation_func, args, kwargs = work
# Create a concurrent task for processing
task = asyncio.create_task(
self._process_work_item(semaphore, queue, func, args, kwargs)
self._process_work_item(semaphore, queue, func, cancellation_func, args, kwargs)
)
running_tasks.add(task)

async def _process_work_item(
    self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, cancellation_func, args, kwargs
):
    """Run one queued work item while holding a slot on ``semaphore``.

    If ``func`` raises, the error is logged and ``cancellation_func`` is
    invoked with the same arguments so the item can be abandoned. A failure
    of the cancellation callback is logged as well rather than allowed to
    escape this coroutine, matching how the queue-draining path in ``run``
    treats cancellation errors.

    Args:
        semaphore: Limits how many work items of this kind run at once.
        queue: Queue the item was taken from; ``task_done`` is always called.
        func: The work function to execute.
        cancellation_func: Callback used to abandon the item when ``func`` fails.
        args: Positional arguments forwarded to both callbacks.
        kwargs: Keyword arguments forwarded to both callbacks.
    """
    async with semaphore:
        try:
            await self._run_func(func, *args, **kwargs)
        except Exception as work_exception:
            self._logger.error(f"Uncaught error while processing work item, item will be abandoned: {work_exception}")
            try:
                await self._run_func(cancellation_func, *args, **kwargs)
            except Exception as cancellation_exception:
                # Abandoning is best-effort: this coroutine runs as a
                # fire-and-forget task, so an escaping error would only
                # surface as an unretrieved task exception.
                self._logger.error(f"Uncaught error while cancelling work item: {cancellation_exception}")
        finally:
            # Keep the queue's unfinished-task count balanced on every path.
            queue.task_done()

Expand All @@ -2105,26 +2187,32 @@ async def _run_func(self, func, *args, **kwargs):
self.thread_pool, lambda: func(*args, **kwargs)
)

def submit_activity(self, func, *args, **kwargs):
work_item = (func, args, kwargs)
def submit_activity(self, func, cancellation_func, *args, **kwargs):
if self._shutdown:
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
work_item = (func, cancellation_func, args, kwargs)
self._ensure_queues_for_current_loop()
if self.activity_queue is not None:
self.activity_queue.put_nowait(work_item)
else:
# No event loop running, store in pending list
self._pending_activity_work.append(work_item)

def submit_orchestration(self, func, *args, **kwargs):
work_item = (func, args, kwargs)
def submit_orchestration(self, func, cancellation_func, *args, **kwargs):
if self._shutdown:
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
work_item = (func, cancellation_func, args, kwargs)
self._ensure_queues_for_current_loop()
if self.orchestration_queue is not None:
self.orchestration_queue.put_nowait(work_item)
else:
# No event loop running, store in pending list
self._pending_orchestration_work.append(work_item)

def submit_entity_batch(self, func, *args, **kwargs):
work_item = (func, args, kwargs)
def submit_entity_batch(self, func, cancellation_func, *args, **kwargs):
if self._shutdown:
raise RuntimeError("Cannot submit new work items after shutdown has been initiated.")
work_item = (func, cancellation_func, args, kwargs)
self._ensure_queues_for_current_loop()
if self.entity_batch_queue is not None:
self.entity_batch_queue.put_nowait(work_item)
Expand All @@ -2136,7 +2224,7 @@ def shutdown(self):
self._shutdown = True
self.thread_pool.shutdown(wait=True)

def reset_for_new_run(self):
async def reset_for_new_run(self):
"""Reset the manager state for a new run."""
self._shutdown = False
# Clear any existing queues - they'll be recreated when needed
Expand All @@ -2145,18 +2233,28 @@ def reset_for_new_run(self):
# This ensures no items from previous runs remain
try:
while not self.activity_queue.empty():
self.activity_queue.get_nowait()
except Exception:
pass
func, cancellation_func, args, kwargs = self.activity_queue.get_nowait()
await self._run_func(cancellation_func, *args, **kwargs)
except Exception as reset_exception:
self._logger.warning(f"Error while clearing activity queue during reset: {reset_exception}")
if self.orchestration_queue is not None:
try:
while not self.orchestration_queue.empty():
self.orchestration_queue.get_nowait()
except Exception:
pass
func, cancellation_func, args, kwargs = self.orchestration_queue.get_nowait()
await self._run_func(cancellation_func, *args, **kwargs)
except Exception as reset_exception:
self._logger.warning(f"Error while clearing orchestration queue during reset: {reset_exception}")
if self.entity_batch_queue is not None:
try:
while not self.entity_batch_queue.empty():
func, cancellation_func, args, kwargs = self.entity_batch_queue.get_nowait()
await self._run_func(cancellation_func, *args, **kwargs)
except Exception as reset_exception:
self._logger.warning(f"Error while clearing entity queue during reset: {reset_exception}")
# Clear pending work lists
self._pending_activity_work.clear()
self._pending_orchestration_work.clear()
self._pending_entity_batch_work.clear()


# Export public API
Expand Down
21 changes: 17 additions & 4 deletions tests/durabletask/test_worker_concurrency_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,21 @@ def dummy_orchestrator(req, stub, completionToken):
time.sleep(0.1)
stub.CompleteOrchestratorTask('ok')

def cancel_dummy_orchestrator(req, stub, completionToken):
pass

def dummy_activity(req, stub, completionToken):
time.sleep(0.1)
stub.CompleteActivityTask('ok')

def cancel_dummy_activity(req, stub, completionToken):
pass

# Patch the worker's execute and cancel hooks for orchestrations and activities
worker._execute_orchestrator = dummy_orchestrator
worker._cancel_orchestrator = cancel_dummy_orchestrator
worker._execute_activity = dummy_activity
worker._cancel_activity = cancel_dummy_activity

orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
Expand All @@ -67,9 +75,9 @@ async def run_test():
# Start the worker manager's run loop in the background
worker_task = asyncio.create_task(worker._async_worker_manager.run())
for req in orchestrator_requests:
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
for req in activity_requests:
worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
await asyncio.sleep(1.0)
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')
Expand Down Expand Up @@ -120,8 +128,8 @@ def fn(*args, **kwargs):

# Submit more work than concurrency allows
for i in range(5):
manager.submit_orchestration(make_work("orch", i))
manager.submit_activity(make_work("act", i))
manager.submit_orchestration(make_work("orch", i), lambda *a, **k: None)
manager.submit_activity(make_work("act", i), lambda *a, **k: None)

# Run the manager loop in a thread (sync context)
def run_manager():
Expand All @@ -131,6 +139,11 @@ def run_manager():
t.start()
time.sleep(1.5) # Let work process
manager.shutdown()

# Ensure the queues have been started
if (manager.activity_queue is None or manager.orchestration_queue is None):
raise RuntimeError("Worker manager queues not initialized")

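# Unblock the consumers by putting dummy items in the queues
# NOTE(review): the consumer now unpacks (func, cancellation_func, args, kwargs);
# these 3-tuples look stale and may fail to unpack if consumed - confirm.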
manager.activity_queue.put_nowait((lambda: None, (), {}))
manager.orchestration_queue.put_nowait((lambda: None, (), {}))
Expand Down
21 changes: 17 additions & 4 deletions tests/durabletask/test_worker_concurrency_loop_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,21 @@ async def dummy_orchestrator(req, stub, completionToken):
await asyncio.sleep(0.1)
stub.CompleteOrchestratorTask('ok')

async def cancel_dummy_orchestrator(req, stub, completionToken):
pass

async def dummy_activity(req, stub, completionToken):
await asyncio.sleep(0.1)
stub.CompleteActivityTask('ok')

async def cancel_dummy_activity(req, stub, completionToken):
pass

# Patch the worker's execute and cancel hooks for orchestrations and activities
grpc_worker._execute_orchestrator = dummy_orchestrator
grpc_worker._execute_activity = dummy_activity
grpc_worker._execute_orchestrator = dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker)
grpc_worker._cancel_orchestrator = cancel_dummy_orchestrator.__get__(grpc_worker, TaskHubGrpcWorker)
grpc_worker._execute_activity = dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker)
grpc_worker._cancel_activity = cancel_dummy_activity.__get__(grpc_worker, TaskHubGrpcWorker)

orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)]
activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)]
Expand All @@ -65,10 +73,15 @@ async def run_test():
# Clear stub state before each run
stub.completed.clear()
worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run())
# Yield to the event loop so the worker task created above can start up on the second run
startup_attempts = 0
while grpc_worker._async_worker_manager._shutdown and startup_attempts < 10:
await asyncio.sleep(0.1)
startup_attempts += 1
for req in orchestrator_requests:
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken())
grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, cancel_dummy_orchestrator, req, stub, DummyCompletionToken())
for req in activity_requests:
grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken())
grpc_worker._async_worker_manager.submit_activity(dummy_activity, cancel_dummy_activity, req, stub, DummyCompletionToken())
await asyncio.sleep(1.0)
orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator')
activity_count = sum(1 for t, _ in stub.completed if t == 'activity')
Expand Down
Loading