diff --git a/CHANGELOG.md b/CHANGELOG.md index ee736f0..376221e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,17 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## v0.2.0 (Unreleased) +## v0.3.0 + +### New + +- Added `ConcurrencyOptions` class for fine-grained concurrency control with separate limits for activities and orchestrations. The thread pool worker count can also be configured. + +### Fixed + +- Fixed an issue where a worker could not recover after its connection was interrupted or severed + +## v0.2.1 ### New diff --git a/durabletask/client.py b/durabletask/client.py index fae968d..7a72e1a 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import wrappers_pb2 @@ -16,6 +16,7 @@ import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared from durabletask import task +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -96,8 +97,25 @@ def __init__(self, *, metadata: Optional[list[tuple[str, str]]] = None, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False): - channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel) + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): + + # If the caller provided metadata, we need to create a new interceptor for it and + # add it to the list of interceptors. + if interceptors is not None: + interceptors = list(interceptors) + if metadata is not None: + interceptors.append(DefaultClientInterceptorImpl(metadata)) + elif metadata is not None: + interceptors = [DefaultClientInterceptorImpl(metadata)] + else: + interceptors = None + + channel = shared.get_grpc_channel( + host_address=host_address, + secure_channel=secure_channel, + interceptors=interceptors + ) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) @@ -116,7 +134,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, version=wrappers_pb2.StringValue(value=""), orchestrationIdReusePolicy=reuse_id_policy, - ) + ) self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") res: pb.CreateInstanceResponse = self._stub.StartInstance(req) diff --git a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 738fca9..69db3c5 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -19,10 +19,10 @@ class _ClientCallDetails( class DefaultClientInterceptorImpl ( - grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): + grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed.""" def __init__(self, metadata: list[tuple[str, str]]): @@ -30,17 +30,17 @@ def __init__(self, metadata: list[tuple[str, str]]): self._metadata = metadata def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details.""" if self._metadata is None: return client_call_details - + if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) else: metadata = [] - + metadata.extend(self._metadata) client_call_details = _ClientCallDetails( client_call_details.method, client_call_details.timeout, metadata, diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index c4f3aa4..c0fbe74 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -5,11 +5,16 @@ import json import logging from types import SimpleNamespace -from typing import Any, Optional +from typing import Any, Optional, Sequence, Union import grpc -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +ClientInterceptor = Union[ + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor +] # Field name used to indicate that an object was automatically serialized # and should be deserialized as a SimpleNamespace @@ -25,8 +30,8 @@ def get_default_host_address() -> str: def get_grpc_channel( host_address: Optional[str], - metadata: Optional[list[tuple[str, str]]], - secure_channel: bool = False) -> grpc.Channel: + secure_channel: bool = False, + interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel: if host_address is None: host_address = get_default_host_address() @@ -44,16 +49,18 @@ def get_grpc_channel( host_address = host_address[len(protocol):] break + # Create the base channel if secure_channel: channel = grpc.secure_channel(host_address, grpc.ssl_channel_credentials()) else: channel = grpc.insecure_channel(host_address) - if metadata is not None and len(metadata) > 0: - interceptors = [DefaultClientInterceptorImpl(metadata)] + # Apply interceptors ONLY if they exist + if interceptors: channel = grpc.intercept_channel(channel, *interceptors) return channel + def get_logger( name_suffix: str, log_handler: Optional[logging.Handler] = None, @@ -98,7 +105,7 @@ def default(self, obj): if dataclasses.is_dataclass(obj): # Dataclasses are not serializable by default, so we convert them to a dict and mark them for # automatic deserialization by the receiver - d = dataclasses.asdict(obj) # type: ignore + d = dataclasses.asdict(obj) # type: ignore d[AUTO_SERIALIZED] = True return d elif isinstance(obj, SimpleNamespace): diff --git a/durabletask/task.py b/durabletask/task.py index a40602b..d319bf2 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -333,7 +333,7 @@ class RetryableTask(CompletableTask[T]): """A task that can be retried according to a retry policy.""" def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, - start_time:datetime, is_sub_orch: bool) -> None: + start_time: datetime, is_sub_orch: bool) -> None: super().__init__() self._action = action self._retry_policy = retry_policy @@ -343,7 +343,7 @@ def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, def increment_attempt_count(self) -> None: self._attempt_count += 1 - + def compute_next_delay(self) -> Optional[timedelta]: if self._attempt_count >= self._retry_policy.max_number_of_attempts: return None @@ -351,7 +351,7 @@ def compute_next_delay(self) -> Optional[timedelta]: retry_expiration: datetime = datetime.max if self._retry_policy.retry_timeout is not None and self._retry_policy.retry_timeout != datetime.max: retry_expiration = self._start_time + self._retry_policy.retry_timeout - + if self._retry_policy.backoff_coefficient is None: backoff_coefficient = 1.0 else: diff --git a/durabletask/worker.py b/durabletask/worker.py index 75e2e37..0922567 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -1,29 +1,78 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import concurrent.futures +import asyncio +import inspect import logging +import os +import random +from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timedelta from threading import Event, Thread from types import GeneratorType from typing import Any, Generator, Optional, Sequence, TypeVar, Union import grpc -from google.protobuf import empty_pb2, wrappers_pb2 +from google.protobuf import empty_pb2 import durabletask.internal.helpers as ph -import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared from durabletask import task +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl + +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") + + +class ConcurrencyOptions: + """Configuration options for controlling concurrency of different work item types and the thread pool size. + + This class provides fine-grained control over concurrent processing limits for + activities, orchestrations and the thread pool size. + """ + + def __init__( + self, + maximum_concurrent_activity_work_items: Optional[int] = None, + maximum_concurrent_orchestration_work_items: Optional[int] = None, + maximum_thread_pool_workers: Optional[int] = None, + ): + """Initialize concurrency options. + + Args: + maximum_concurrent_activity_work_items: Maximum number of activity work items + that can be processed concurrently. Defaults to 100 * processor_count. + maximum_concurrent_orchestration_work_items: Maximum number of orchestration work items + that can be processed concurrently. Defaults to 100 * processor_count. + maximum_thread_pool_workers: Maximum number of thread pool workers to use. + """ + processor_count = os.cpu_count() or 1 + default_concurrency = 100 * processor_count + # see https://docs.python.org/3/library/concurrent.futures.html + default_max_workers = processor_count + 4 + + self.maximum_concurrent_activity_work_items = ( + maximum_concurrent_activity_work_items + if maximum_concurrent_activity_work_items is not None + else default_concurrency + ) -TInput = TypeVar('TInput') -TOutput = TypeVar('TOutput') + self.maximum_concurrent_orchestration_work_items = ( + maximum_concurrent_orchestration_work_items + if maximum_concurrent_orchestration_work_items is not None + else default_concurrency + ) + self.maximum_thread_pool_workers = ( + maximum_thread_pool_workers + if maximum_thread_pool_workers is not None + else default_max_workers + ) -class _Registry: +class _Registry: orchestrators: dict[str, task.Orchestrator] activities: dict[str, task.Activity] @@ -33,7 +82,7 @@ def __init__(self): def add_orchestrator(self, fn: task.Orchestrator) -> str: if fn is None: - raise ValueError('An orchestrator function argument is required.') + raise ValueError("An orchestrator function argument is required.") name = task.get_name(fn) self.add_named_orchestrator(name, fn) @@ -41,7 +90,7 @@ def add_orchestrator(self, fn: task.Orchestrator) -> str: def add_named_orchestrator(self, name: str, fn: task.Orchestrator) -> None: if not name: - raise ValueError('A non-empty orchestrator name is required.') + raise ValueError("A non-empty orchestrator name is required.") if name in self.orchestrators: raise ValueError(f"A '{name}' orchestrator already exists.") @@ -52,7 +101,7 @@ def get_orchestrator(self, name: str) -> Optional[task.Orchestrator]: def add_activity(self, fn: task.Activity) -> str: if fn is None: - raise ValueError('An activity function argument is required.') + raise ValueError("An activity function argument is required.") name = task.get_name(fn) self.add_named_activity(name, fn) @@ -60,7 +109,7 @@ def add_activity(self, fn: task.Activity) -> str: def add_named_activity(self, name: str, fn: task.Activity) -> None: if not name: - raise ValueError('A non-empty activity name is required.') + raise ValueError("A non-empty activity name is required.") if name in self.activities: raise ValueError(f"A '{name}' activity already exists.") @@ -72,31 +121,142 @@ def get_activity(self, name: str) -> Optional[task.Activity]: class OrchestratorNotRegisteredError(ValueError): """Raised when attempting to start an orchestration that is not registered""" + pass class ActivityNotRegisteredError(ValueError): """Raised when attempting to call an activity that is not registered""" + pass class TaskHubGrpcWorker: - _response_stream: Optional[grpc.Future] = None + """A gRPC-based worker for processing durable task orchestrations and activities. + + This worker connects to a Durable Task backend service via gRPC to receive and process + work items including orchestration functions and activity functions. It provides + concurrent execution capabilities with configurable limits and automatic retry handling. + + The worker manages the complete lifecycle: + - Registers orchestrator and activity functions + - Connects to the gRPC backend service + - Receives work items and executes them concurrently + - Handles failures, retries, and state management + - Provides logging and monitoring capabilities + + Args: + host_address (Optional[str], optional): The gRPC endpoint address of the backend service. + Defaults to the value from environment variables or localhost. + metadata (Optional[list[tuple[str, str]]], optional): gRPC metadata to include with + requests. Used for authentication and routing. Defaults to None. + log_handler (optional): Custom logging handler for worker logs. Defaults to None. + log_formatter (Optional[logging.Formatter], optional): Custom log formatter. + Defaults to None. + secure_channel (bool, optional): Whether to use a secure gRPC channel (TLS). + Defaults to False. + interceptors (Optional[Sequence[shared.ClientInterceptor]], optional): Custom gRPC + interceptors to apply to the channel. Defaults to None. + concurrency_options (Optional[ConcurrencyOptions], optional): Configuration for + controlling worker concurrency limits. If None, default settings are used. + + Attributes: + concurrency_options (ConcurrencyOptions): The current concurrency configuration. + + Example: + Basic worker setup: + + >>> from durabletask.worker import TaskHubGrpcWorker, ConcurrencyOptions + >>> + >>> # Create worker with custom concurrency settings + >>> concurrency = ConcurrencyOptions( + ... maximum_concurrent_activity_work_items=50, + ... maximum_concurrent_orchestration_work_items=20 + ... ) + >>> worker = TaskHubGrpcWorker( + ... host_address="localhost:4001", + ... concurrency_options=concurrency + ... ) + >>> + >>> # Register functions + >>> @worker.add_orchestrator + ... def my_orchestrator(context, input): + ... result = yield context.call_activity("my_activity", input="hello") + ... return result + >>> + >>> @worker.add_activity + ... def my_activity(context, input): + ... return f"Processed: {input}" + >>> + >>> # Start the worker + >>> worker.start() + >>> # ... worker runs in background thread + >>> worker.stop() + + Using as context manager: + + >>> with TaskHubGrpcWorker() as worker: + ... worker.add_orchestrator(my_orchestrator) + ... worker.add_activity(my_activity) + ... worker.start() + ... # Worker automatically stops when exiting context + + Raises: + RuntimeError: If attempting to add orchestrators/activities while the worker is running, + or if starting a worker that is already running. + OrchestratorNotRegisteredError: If an orchestration work item references an + unregistered orchestrator function. + ActivityNotRegisteredError: If an activity work item references an unregistered + activity function. + """ - def __init__(self, *, - host_address: Optional[str] = None, - metadata: Optional[list[tuple[str, str]]] = None, - log_handler=None, - log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False): + _response_stream: Optional[grpc.Future] = None + _interceptors: Optional[list[shared.ClientInterceptor]] = None + + def __init__( + self, + *, + host_address: Optional[str] = None, + metadata: Optional[list[tuple[str, str]]] = None, + log_handler=None, + log_formatter: Optional[logging.Formatter] = None, + secure_channel: bool = False, + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, + concurrency_options: Optional[ConcurrencyOptions] = None, + ): self._registry = _Registry() - self._host_address = host_address if host_address else shared.get_default_host_address() - self._metadata = metadata + self._host_address = ( + host_address if host_address else shared.get_default_host_address() + ) self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + # Use provided concurrency options or create default ones + self._concurrency_options = ( + concurrency_options + if concurrency_options is not None + else ConcurrencyOptions() + ) + + # Determine the interceptors to use + if interceptors is not None: + self._interceptors = list(interceptors) + if metadata: + self._interceptors.append(DefaultClientInterceptorImpl(metadata)) + elif metadata: + self._interceptors = [DefaultClientInterceptorImpl(metadata)] + else: + self._interceptors = None + + self._async_worker_manager = _AsyncWorkerManager(self._concurrency_options) + + @property + def concurrency_options(self) -> ConcurrencyOptions: + """Get the current concurrency options for this worker.""" + return self._concurrency_options + def __enter__(self): return self @@ -106,70 +266,223 @@ def __exit__(self, type, value, traceback): def add_orchestrator(self, fn: task.Orchestrator) -> str: """Registers an orchestrator function with the worker.""" if self._is_running: - raise RuntimeError('Orchestrators cannot be added while the worker is running.') + raise RuntimeError( + "Orchestrators cannot be added while the worker is running." + ) return self._registry.add_orchestrator(fn) def add_activity(self, fn: task.Activity) -> str: """Registers an activity function with the worker.""" if self._is_running: - raise RuntimeError('Activities cannot be added while the worker is running.') + raise RuntimeError( + "Activities cannot be added while the worker is running." + ) return self._registry.add_activity(fn) def start(self): """Starts the worker on a background thread and begins listening for work items.""" - channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel) - stub = stubs.TaskHubSidecarServiceStub(channel) - if self._is_running: - raise RuntimeError('The worker is already running.') + raise RuntimeError("The worker is already running.") def run_loop(): - # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity - # functions. We'd need to know ahead of time whether a function is async or not. - # TODO: Max concurrency configuration settings - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: - while not self._shutdown.is_set(): - try: - # send a "Hello" message to the sidecar to ensure that it's listening - stub.Hello(empty_pb2.Empty()) - - # stream work items - self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest()) - self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...') - - # The stream blocks until either a work item is received or the stream is canceled - # by another thread (see the stop() method). - for work_item in self._response_stream: # type: ignore - request_type = work_item.WhichOneof('request') - self._logger.debug(f'Received "{request_type}" work item') - if work_item.HasField('orchestratorRequest'): - executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub) - elif work_item.HasField('activityRequest'): - executor.submit(self._execute_activity, work_item.activityRequest, stub) - else: - self._logger.warning(f'Unexpected work item type: {request_type}') - - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore - self._logger.info(f'Disconnected from {self._host_address}') - elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore - self._logger.warning( - f'The sidecar at address {self._host_address} is unavailable - will continue retrying') - else: - self._logger.warning(f'Unexpected error: {rpc_error}') - except Exception as ex: - self._logger.warning(f'Unexpected error: {ex}') - - # CONSIDER: exponential backoff - self._shutdown.wait(5) - self._logger.info("No longer listening for work items") - return + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self._async_run_loop()) self._logger.info(f"Starting gRPC worker that connects to {self._host_address}") self._runLoop = Thread(target=run_loop) self._runLoop.start() self._is_running = True + async def _async_run_loop(self): + worker_task = asyncio.create_task(self._async_worker_manager.run()) + # Connection state management for retry fix + current_channel = None + current_stub = None + current_reader_thread = None + conn_retry_count = 0 + conn_max_retry_delay = 60 + + def create_fresh_connection(): + nonlocal current_channel, current_stub, conn_retry_count + if current_channel: + try: + current_channel.close() + except Exception: + pass + current_channel = None + current_stub = None + try: + current_channel = shared.get_grpc_channel( + self._host_address, self._secure_channel, self._interceptors + ) + current_stub = stubs.TaskHubSidecarServiceStub(current_channel) + current_stub.Hello(empty_pb2.Empty()) + conn_retry_count = 0 + self._logger.info(f"Created fresh connection to {self._host_address}") + except Exception as e: + self._logger.warning(f"Failed to create connection: {e}") + current_channel = None + current_stub = None + raise + + def invalidate_connection(): + nonlocal current_channel, current_stub, current_reader_thread + # Cancel the response stream first to signal the reader thread to stop + if self._response_stream is not None: + try: + self._response_stream.cancel() + except Exception: + pass + self._response_stream = None + + # Wait for the reader thread to finish + if current_reader_thread is not None: + try: + current_reader_thread.join(timeout=2) + if current_reader_thread.is_alive(): + self._logger.warning("Stream reader thread did not shut down gracefully") + except Exception: + pass + current_reader_thread = None + + # Close the channel + if current_channel: + try: + current_channel.close() + except Exception: + pass + current_channel = None + current_stub = None + + def should_invalidate_connection(rpc_error): + error_code = rpc_error.code() # type: ignore + connection_level_errors = { + grpc.StatusCode.UNAVAILABLE, + grpc.StatusCode.DEADLINE_EXCEEDED, + grpc.StatusCode.CANCELLED, + grpc.StatusCode.UNAUTHENTICATED, + grpc.StatusCode.ABORTED, + } + return error_code in connection_level_errors + + while not self._shutdown.is_set(): + if current_stub is None: + try: + create_fresh_connection() + except Exception: + conn_retry_count += 1 + delay = min( + conn_max_retry_delay, + (2 ** min(conn_retry_count, 6)) + random.uniform(0, 1), + ) + self._logger.warning( + f"Connection failed, retrying in {delay:.2f} seconds (attempt {conn_retry_count})" + ) + if self._shutdown.wait(delay): + break + continue + try: + assert current_stub is not None + stub = current_stub + get_work_items_request = pb.GetWorkItemsRequest( + maxConcurrentOrchestrationWorkItems=self._concurrency_options.maximum_concurrent_orchestration_work_items, + maxConcurrentActivityWorkItems=self._concurrency_options.maximum_concurrent_activity_work_items, + ) + self._response_stream = stub.GetWorkItems(get_work_items_request) + self._logger.info( + f"Successfully connected to {self._host_address}. Waiting for work items..." + ) + + # Use a thread to read from the blocking gRPC stream and forward to asyncio + import queue + + work_item_queue = queue.Queue() + + def stream_reader(): + try: + for work_item in self._response_stream: + work_item_queue.put(work_item) + except Exception as e: + work_item_queue.put(e) + + import threading + + current_reader_thread = threading.Thread(target=stream_reader, daemon=True) + current_reader_thread.start() + loop = asyncio.get_running_loop() + while not self._shutdown.is_set(): + try: + work_item = await loop.run_in_executor( + None, work_item_queue.get + ) + if isinstance(work_item, Exception): + raise work_item + request_type = work_item.WhichOneof("request") + self._logger.debug(f'Received "{request_type}" work item') + if work_item.HasField("orchestratorRequest"): + self._async_worker_manager.submit_orchestration( + self._execute_orchestrator, + work_item.orchestratorRequest, + stub, + work_item.completionToken, + ) + elif work_item.HasField("activityRequest"): + self._async_worker_manager.submit_activity( + self._execute_activity, + work_item.activityRequest, + stub, + work_item.completionToken, + ) + elif work_item.HasField("healthPing"): + pass + else: + self._logger.warning( + f"Unexpected work item type: {request_type}" + ) + except Exception as e: + self._logger.warning(f"Error in work item stream: {e}") + raise e + current_reader_thread.join(timeout=1) + self._logger.info("Work item stream ended normally") + except grpc.RpcError as rpc_error: + should_invalidate = should_invalidate_connection(rpc_error) + if should_invalidate: + invalidate_connection() + error_code = rpc_error.code() # type: ignore + error_details = str(rpc_error) + + if error_code == grpc.StatusCode.CANCELLED: + self._logger.info(f"Disconnected from {self._host_address}") + break + elif error_code == grpc.StatusCode.UNAVAILABLE: + # Check if this is a connection timeout scenario + if "Timeout occurred" in error_details or "Failed to connect to remote host" in error_details: + self._logger.warning( + f"Connection timeout to {self._host_address}: {error_details} - will retry with fresh connection" + ) + else: + self._logger.warning( + f"The sidecar at address {self._host_address} is unavailable: {error_details} - will continue retrying" + ) + elif should_invalidate: + self._logger.warning( + f"Connection-level gRPC error ({error_code}): {rpc_error} - resetting connection" + ) + else: + self._logger.warning( + f"Application-level gRPC error ({error_code}): {rpc_error}" + ) + self._shutdown.wait(1) + except Exception as ex: + invalidate_connection() + self._logger.warning(f"Unexpected error: {ex}") + self._shutdown.wait(1) + invalidate_connection() + self._logger.info("No longer listening for work items") + self._async_worker_manager.shutdown() + await worker_task + def stop(self): """Stops the worker and waits for any pending work items to complete.""" if not self._is_running: @@ -181,48 +494,80 @@ def stop(self): self._response_stream.cancel() if self._runLoop is not None: self._runLoop.join(timeout=30) + self._async_worker_manager.shutdown() self._logger.info("Worker shutdown completed") self._is_running = False - def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub): + def _execute_orchestrator( + self, + req: pb.OrchestratorRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): try: executor = _OrchestrationExecutor(self._registry, self._logger) result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) res = pb.OrchestratorResponse( instanceId=req.instanceId, actions=result.actions, - customStatus=pbh.get_string_value(result.encoded_custom_status)) + customStatus=ph.get_string_value(result.encoded_custom_status), + completionToken=completionToken, + ) except Exception as ex: - self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}") - failure_details = pbh.new_failure_details(ex) - actions = [pbh.new_complete_orchestration_action(-1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details)] - res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions) + self._logger.exception( + f"An error occurred while trying to execute instance '{req.instanceId}': {ex}" + ) + failure_details = ph.new_failure_details(ex) + actions = [ + ph.new_complete_orchestration_action( + -1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details + ) + ] + res = pb.OrchestratorResponse( + instanceId=req.instanceId, + actions=actions, + completionToken=completionToken, + ) try: stub.CompleteOrchestratorTask(res) except Exception as ex: - self._logger.exception(f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}") - - def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub): + self._logger.exception( + f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}" + ) + + def _execute_activity( + self, + req: pb.ActivityRequest, + stub: stubs.TaskHubSidecarServiceStub, + completionToken, + ): instance_id = req.orchestrationInstance.instanceId try: executor = _ActivityExecutor(self._registry, self._logger) - result = executor.execute(instance_id, req.name, req.taskId, req.input.value) + result = executor.execute( + instance_id, req.name, req.taskId, req.input.value + ) res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - result=pbh.get_string_value(result)) + result=ph.get_string_value(result), + completionToken=completionToken, + ) except Exception as ex: res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - failureDetails=pbh.new_failure_details(ex)) + failureDetails=ph.new_failure_details(ex), + completionToken=completionToken, + ) try: stub.CompleteActivityTask(res) except Exception as ex: self._logger.exception( - f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}") + f"Failed to deliver activity response for '{req.name}#{req.taskId}' of orchestration ID '{instance_id}' to sidecar: {ex}" + ) class _RuntimeOrchestrationContext(task.OrchestrationContext): @@ -256,7 +601,9 @@ def run(self, generator: Generator[task.Task, Any, Any]): def resume(self): if self._generator is None: # This is never expected unless maybe there's an issue with the history - raise TypeError("The orchestrator generator is not initialized! Was the orchestration history corrupted?") + raise TypeError( + "The orchestrator generator is not initialized! Was the orchestration history corrupted?" + ) # We can resume the generator only if the previously yielded task # has reached a completed state. The only time this won't be the @@ -277,7 +624,12 @@ def resume(self): raise TypeError("The orchestrator generator yielded a non-Task object") self._previous_task = next_task - def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_encoded: bool = False): + def set_complete( + self, + result: Any, + status: pb.OrchestrationStatus, + is_result_encoded: bool = False, + ): if self._is_complete: return @@ -290,7 +642,8 @@ def set_complete(self, result: Any, status: pb.OrchestrationStatus, is_result_en if result is not None: result_json = result if is_result_encoded else shared.to_json(result) action = ph.new_complete_orchestration_action( - self.next_sequence_number(), status, result_json) + self.next_sequence_number(), status, result_json + ) self._pending_actions[action.id] = action def set_failed(self, ex: Exception): @@ -302,7 +655,10 @@ def set_failed(self, ex: Exception): self._completion_status = pb.ORCHESTRATION_STATUS_FAILED action = ph.new_complete_orchestration_action( - self.next_sequence_number(), pb.ORCHESTRATION_STATUS_FAILED, None, ph.new_failure_details(ex) + self.next_sequence_number(), + pb.ORCHESTRATION_STATUS_FAILED, + None, + ph.new_failure_details(ex), ) self._pending_actions[action.id] = action @@ -326,14 +682,21 @@ def get_actions(self) -> list[pb.OrchestratorAction]: # replayed when the new instance starts. for event_name, values in self._received_events.items(): for event_value in values: - encoded_value = shared.to_json(event_value) if event_value else None - carryover_events.append(ph.new_event_raised_event(event_name, encoded_value)) + encoded_value = ( + shared.to_json(event_value) if event_value else None + ) + carryover_events.append( + ph.new_event_raised_event(event_name, encoded_value) + ) action = ph.new_complete_orchestration_action( self.next_sequence_number(), pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW, - result=shared.to_json(self._new_input) if self._new_input is not None else None, + result=shared.to_json(self._new_input) + if self._new_input is not None + else None, failure_details=None, - carryover_events=carryover_events) + carryover_events=carryover_events, + ) return [action] else: return list(self._pending_actions.values()) @@ -350,60 +713,84 @@ def instance_id(self) -> str: def current_utc_datetime(self) -> datetime: return self._current_utc_datetime - @property - def is_replaying(self) -> bool: - return self._is_replaying - @current_utc_datetime.setter def current_utc_datetime(self, value: datetime): self._current_utc_datetime = value + @property + def is_replaying(self) -> bool: + return self._is_replaying + def set_custom_status(self, custom_status: Any) -> None: - self._encoded_custom_status = shared.to_json(custom_status) if custom_status is not None else None + self._encoded_custom_status = ( + shared.to_json(custom_status) if custom_status is not None else None + ) def create_timer(self, fire_at: Union[datetime, timedelta]) -> task.Task: return self.create_timer_internal(fire_at) - def create_timer_internal(self, fire_at: Union[datetime, timedelta], - retryable_task: Optional[task.RetryableTask] = None) -> task.Task: + def create_timer_internal( + self, + fire_at: Union[datetime, timedelta], + retryable_task: Optional[task.RetryableTask] = None, + ) -> task.Task: id = self.next_sequence_number() if isinstance(fire_at, timedelta): fire_at = self.current_utc_datetime + fire_at action = ph.new_create_timer_action(id, fire_at) self._pending_actions[id] = action - timer_task = task.TimerTask() + timer_task: task.TimerTask = task.TimerTask() if retryable_task is not None: timer_task.set_retryable_parent(retryable_task) self._pending_tasks[id] = timer_task return timer_task - def call_activity(self, activity: Union[task.Activity[TInput, TOutput], str], *, - input: Optional[TInput] = None, - retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]: + def call_activity( + self, + activity: Union[task.Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, + ) -> task.Task[TOutput]: id = self.next_sequence_number() - self.call_activity_function_helper(id, activity, input=input, retry_policy=retry_policy, - is_sub_orch=False) + self.call_activity_function_helper( + id, activity, input=input, retry_policy=retry_policy, is_sub_orch=False + ) return self._pending_tasks.get(id, task.CompletableTask()) - def call_sub_orchestrator(self, orchestrator: task.Orchestrator[TInput, TOutput], *, - input: Optional[TInput] = None, - instance_id: Optional[str] = None, - retry_policy: Optional[task.RetryPolicy] = None) -> task.Task[TOutput]: + def call_sub_orchestrator( + self, + orchestrator: task.Orchestrator[TInput, TOutput], + *, + input: Optional[TInput] = None, + instance_id: Optional[str] = None, + retry_policy: Optional[task.RetryPolicy] = None, + ) -> task.Task[TOutput]: id = self.next_sequence_number() orchestrator_name = task.get_name(orchestrator) - self.call_activity_function_helper(id, orchestrator_name, input=input, retry_policy=retry_policy, - is_sub_orch=True, instance_id=instance_id) + self.call_activity_function_helper( + id, + orchestrator_name, + input=input, + retry_policy=retry_policy, + is_sub_orch=True, + instance_id=instance_id, + ) return self._pending_tasks.get(id, task.CompletableTask()) - def call_activity_function_helper(self, id: Optional[int], - activity_function: Union[task.Activity[TInput, TOutput], str], *, - input: Optional[TInput] = None, - retry_policy: Optional[task.RetryPolicy] = None, - is_sub_orch: bool = False, - instance_id: Optional[str] = None, - fn_task: Optional[task.CompletableTask[TOutput]] = None): + def call_activity_function_helper( + self, + id: Optional[int], + activity_function: Union[task.Activity[TInput, TOutput], str], + *, + input: Optional[TInput] = None, + retry_policy: Optional[task.RetryPolicy] = None, + is_sub_orch: bool = False, + instance_id: Optional[str] = None, + fn_task: Optional[task.CompletableTask[TOutput]] = None, + ): if id is None: id = self.next_sequence_number() @@ -414,7 +801,11 @@ def call_activity_function_helper(self, id: Optional[int], # We just need to take string representation of it. encoded_input = str(input) if not is_sub_orch: - name = activity_function if isinstance(activity_function, str) else task.get_name(activity_function) + name = ( + activity_function + if isinstance(activity_function, str) + else task.get_name(activity_function) + ) action = ph.new_schedule_task_action(id, name, encoded_input) else: if instance_id is None: @@ -422,16 +813,21 @@ def call_activity_function_helper(self, id: Optional[int], instance_id = f"{self.instance_id}:{id:04x}" if not isinstance(activity_function, str): raise ValueError("Orchestrator function name must be a string") - action = ph.new_create_sub_orchestration_action(id, activity_function, instance_id, encoded_input) + action = ph.new_create_sub_orchestration_action( + id, activity_function, instance_id, encoded_input + ) self._pending_actions[id] = action if fn_task is None: if retry_policy is None: fn_task = task.CompletableTask[TOutput]() else: - fn_task = task.RetryableTask[TOutput](retry_policy=retry_policy, action=action, - start_time=self.current_utc_datetime, - is_sub_orch=is_sub_orch) + fn_task = task.RetryableTask[TOutput]( + retry_policy=retry_policy, + action=action, + start_time=self.current_utc_datetime, + is_sub_orch=is_sub_orch, + ) self._pending_tasks[id] = fn_task def wait_for_external_event(self, name: str) -> task.Task: @@ -440,7 +836,7 @@ def wait_for_external_event(self, name: str) -> task.Task: # event with the given name so that we can resume the generator when it # arrives. If there are multiple events with the same name, we return # them in the order they were received. - external_event_task = task.CompletableTask() + external_event_task: task.CompletableTask = task.CompletableTask() event_name = name.casefold() event_list = self._received_events.get(event_name, None) if event_list: @@ -467,7 +863,10 @@ class ExecutionResults: actions: list[pb.OrchestratorAction] encoded_custom_status: Optional[str] - def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str]): + + def __init__( + self, actions: list[pb.OrchestratorAction], encoded_custom_status: Optional[str] + ): self.actions = actions self.encoded_custom_status = encoded_custom_status @@ -480,14 +879,23 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._is_suspended = False self._suspended_events: list[pb.HistoryEvent] = [] - def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_events: Sequence[pb.HistoryEvent]) -> ExecutionResults: + def execute( + self, + instance_id: str, + old_events: Sequence[pb.HistoryEvent], + new_events: Sequence[pb.HistoryEvent], + ) -> ExecutionResults: if not new_events: - raise task.OrchestrationStateError("The new history event list must have at least one event in it.") + raise task.OrchestrationStateError( + "The new history event list must have at least one event in it." + ) ctx = _RuntimeOrchestrationContext(instance_id) try: # Rebuild local state by replaying old history into the orchestrator function - self._logger.debug(f"{instance_id}: Rebuilding local state with {len(old_events)} history event...") + self._logger.debug( + f"{instance_id}: Rebuilding local state with {len(old_events)} history event..." + ) ctx._is_replaying = True for old_event in old_events: self.process_event(ctx, old_event) @@ -495,7 +903,9 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e # Get new actions by executing newly received events into the orchestrator function if self._logger.level <= logging.DEBUG: summary = _get_new_event_summary(new_events) - self._logger.debug(f"{instance_id}: Processing {len(new_events)} new event(s): {summary}") + self._logger.debug( + f"{instance_id}: Processing {len(new_events)} new event(s): {summary}" + ) ctx._is_replaying = False for new_event in new_events: self.process_event(ctx, new_event) @@ -507,17 +917,32 @@ def execute(self, instance_id: str, old_events: Sequence[pb.HistoryEvent], new_e if not ctx._is_complete: task_count = len(ctx._pending_tasks) event_count = len(ctx._pending_events) - self._logger.info(f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding.") - elif ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW: - completion_status_str = pbh.get_orchestration_status_str(ctx._completion_status) - self._logger.info(f"{instance_id}: Orchestration completed with status: {completion_status_str}") + self._logger.info( + f"{instance_id}: Orchestrator yielded with {task_count} task(s) and {event_count} event(s) outstanding." + ) + elif ( + ctx._completion_status and ctx._completion_status is not pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW + ): + completion_status_str = ph.get_orchestration_status_str( + ctx._completion_status + ) + self._logger.info( + f"{instance_id}: Orchestration completed with status: {completion_status_str}" + ) actions = ctx.get_actions() if self._logger.level <= logging.DEBUG: - self._logger.debug(f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}") - return ExecutionResults(actions=actions, encoded_custom_status=ctx._encoded_custom_status) - def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent) -> None: + self._logger.debug( + f"{instance_id}: Returning {len(actions)} action(s): {_get_action_summary(actions)}" + ) + return ExecutionResults( + actions=actions, encoded_custom_status=ctx._encoded_custom_status + ) + + def process_event( + self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEvent + ) -> None: if self._is_suspended and _is_suspendable(event): # We are suspended, so we need to buffer this event until we are resumed self._suspended_events.append(event) @@ -532,14 +957,19 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven fn = self._registry.get_orchestrator(event.executionStarted.name) if fn is None: raise OrchestratorNotRegisteredError( - f"A '{event.executionStarted.name}' orchestrator was not registered.") + f"A '{event.executionStarted.name}' orchestrator was not registered." + ) # deserialize the input, if any input = None - if event.executionStarted.input is not None and event.executionStarted.input.value != "": + if ( + event.executionStarted.input is not None and event.executionStarted.input.value != "" + ): input = shared.from_json(event.executionStarted.input.value) - result = fn(ctx, input) # this does not execute the generator, only creates it + result = fn( + ctx, input + ) # this does not execute the generator, only creates it if isinstance(result, GeneratorType): # Start the orchestrator's generator function ctx.run(result) @@ -552,10 +982,14 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven timer_id = event.eventId action = ctx._pending_actions.pop(timer_id, None) if not action: - raise _get_non_determinism_error(timer_id, task.get_name(ctx.create_timer)) + raise _get_non_determinism_error( + timer_id, task.get_name(ctx.create_timer) + ) elif not action.HasField("createTimer"): expected_method_name = task.get_name(ctx.create_timer) - raise _get_wrong_action_type_error(timer_id, expected_method_name, action) + raise _get_wrong_action_type_error( + timer_id, expected_method_name, action + ) elif event.HasField("timerFired"): timer_id = event.timerFired.timerId timer_task = ctx._pending_tasks.pop(timer_id, None) @@ -563,7 +997,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx._is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}.") + f"{ctx.instance_id}: Ignoring unexpected timerFired event with ID = {timer_id}." + ) return timer_task.complete(None) if timer_task._retryable_parent is not None: @@ -575,12 +1010,15 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven else: cur_task = activity_action.createSubOrchestration instance_id = cur_task.instanceId - ctx.call_activity_function_helper(id=activity_action.id, activity_function=cur_task.name, - input=cur_task.input.value, - retry_policy=timer_task._retryable_parent._retry_policy, - is_sub_orch=timer_task._retryable_parent._is_sub_orch, - instance_id=instance_id, - fn_task=timer_task._retryable_parent) + ctx.call_activity_function_helper( + id=activity_action.id, + activity_function=cur_task.name, + input=cur_task.input.value, + retry_policy=timer_task._retryable_parent._retry_policy, + is_sub_orch=timer_task._retryable_parent._is_sub_orch, + instance_id=instance_id, + fn_task=timer_task._retryable_parent, + ) else: ctx.resume() elif event.HasField("taskScheduled"): @@ -590,16 +1028,21 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven action = ctx._pending_actions.pop(task_id, None) activity_task = ctx._pending_tasks.get(task_id, None) if not action: - raise _get_non_determinism_error(task_id, task.get_name(ctx.call_activity)) + raise _get_non_determinism_error( + task_id, task.get_name(ctx.call_activity) + ) elif not action.HasField("scheduleTask"): expected_method_name = task.get_name(ctx.call_activity) - raise _get_wrong_action_type_error(task_id, expected_method_name, action) + raise _get_wrong_action_type_error( + task_id, expected_method_name, action + ) elif action.scheduleTask.name != event.taskScheduled.name: raise _get_wrong_action_name_error( task_id, method_name=task.get_name(ctx.call_activity), expected_task_name=event.taskScheduled.name, - actual_task_name=action.scheduleTask.name) + actual_task_name=action.scheduleTask.name, + ) elif event.HasField("taskCompleted"): # This history event contains the result of a completed activity task. task_id = event.taskCompleted.taskScheduledId @@ -608,7 +1051,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected taskCompleted event with ID = {task_id}." + ) return result = None if not ph.is_empty(event.taskCompleted.result): @@ -622,7 +1066,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected taskFailed event with ID = {task_id}." + ) return if isinstance(activity_task, task.RetryableTask): @@ -631,7 +1076,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if next_delay is None: activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", - event.taskFailed.failureDetails) + event.taskFailed.failureDetails, + ) ctx.resume() else: activity_task.increment_attempt_count() @@ -639,7 +1085,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven elif isinstance(activity_task, task.CompletableTask): activity_task.fail( f"{ctx.instance_id}: Activity task #{task_id} failed: {event.taskFailed.failureDetails.errorMessage}", - event.taskFailed.failureDetails) + event.taskFailed.failureDetails, + ) ctx.resume() else: raise TypeError("Unexpected task type") @@ -649,16 +1096,23 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven task_id = event.eventId action = ctx._pending_actions.pop(task_id, None) if not action: - raise _get_non_determinism_error(task_id, task.get_name(ctx.call_sub_orchestrator)) + raise _get_non_determinism_error( + task_id, task.get_name(ctx.call_sub_orchestrator) + ) elif not action.HasField("createSubOrchestration"): expected_method_name = task.get_name(ctx.call_sub_orchestrator) - raise _get_wrong_action_type_error(task_id, expected_method_name, action) - elif action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name: + raise _get_wrong_action_type_error( + task_id, expected_method_name, action + ) + elif ( + action.createSubOrchestration.name != event.subOrchestrationInstanceCreated.name + ): raise _get_wrong_action_name_error( task_id, method_name=task.get_name(ctx.call_sub_orchestrator), expected_task_name=event.subOrchestrationInstanceCreated.name, - actual_task_name=action.createSubOrchestration.name) + actual_task_name=action.createSubOrchestration.name, + ) elif event.HasField("subOrchestrationInstanceCompleted"): task_id = event.subOrchestrationInstanceCompleted.taskScheduledId sub_orch_task = ctx._pending_tasks.pop(task_id, None) @@ -666,11 +1120,14 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceCompleted event with ID = {task_id}." + ) return result = None if not ph.is_empty(event.subOrchestrationInstanceCompleted.result): - result = shared.from_json(event.subOrchestrationInstanceCompleted.result.value) + result = shared.from_json( + event.subOrchestrationInstanceCompleted.result.value + ) sub_orch_task.complete(result) ctx.resume() elif event.HasField("subOrchestrationInstanceFailed"): @@ -681,7 +1138,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven # TODO: Should this be an error? When would it ever happen? if not ctx.is_replaying: self._logger.warning( - f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}.") + f"{ctx.instance_id}: Ignoring unexpected subOrchestrationInstanceFailed event with ID = {task_id}." + ) return if isinstance(sub_orch_task, task.RetryableTask): if sub_orch_task._retry_policy is not None: @@ -689,7 +1147,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven if next_delay is None: sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", - failedEvent.failureDetails) + failedEvent.failureDetails, + ) ctx.resume() else: sub_orch_task.increment_attempt_count() @@ -697,7 +1156,8 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven elif isinstance(sub_orch_task, task.CompletableTask): sub_orch_task.fail( f"Sub-orchestration task #{task_id} failed: {failedEvent.failureDetails.errorMessage}", - failedEvent.failureDetails) + failedEvent.failureDetails, + ) ctx.resume() else: raise TypeError("Unexpected sub-orchestration task type") @@ -726,7 +1186,9 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven decoded_result = shared.from_json(event.eventRaised.input.value) event_list.append(decoded_result) if not ctx.is_replaying: - self._logger.info(f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it.") + self._logger.info( + f"{ctx.instance_id}: Event '{event_name}' has been buffered as there are no tasks waiting for it." + ) elif event.HasField("executionSuspended"): if not self._is_suspended and not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Execution suspended.") @@ -741,11 +1203,21 @@ def process_event(self, ctx: _RuntimeOrchestrationContext, event: pb.HistoryEven elif event.HasField("executionTerminated"): if not ctx.is_replaying: self._logger.info(f"{ctx.instance_id}: Execution terminating.") - encoded_output = event.executionTerminated.input.value if not ph.is_empty(event.executionTerminated.input) else None - ctx.set_complete(encoded_output, pb.ORCHESTRATION_STATUS_TERMINATED, is_result_encoded=True) + encoded_output = ( + event.executionTerminated.input.value + if not ph.is_empty(event.executionTerminated.input) + else None + ) + ctx.set_complete( + encoded_output, + pb.ORCHESTRATION_STATUS_TERMINATED, + is_result_encoded=True, + ) else: eventType = event.WhichOneof("eventType") - raise task.OrchestrationStateError(f"Don't know how to handle event of type '{eventType}'") + raise task.OrchestrationStateError( + f"Don't know how to handle event of type '{eventType}'" + ) except StopIteration as generatorStopped: # The orchestrator generator function completed ctx.set_complete(generatorStopped.value, pb.ORCHESTRATION_STATUS_COMPLETED) @@ -756,12 +1228,22 @@ def __init__(self, registry: _Registry, logger: logging.Logger): self._registry = registry self._logger = logger - def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: Optional[str]) -> Optional[str]: + def execute( + self, + orchestration_id: str, + name: str, + task_id: int, + encoded_input: Optional[str], + ) -> Optional[str]: """Executes an activity function and returns the serialized result, if any.""" - self._logger.debug(f"{orchestration_id}/{task_id}: Executing activity '{name}'...") + self._logger.debug( + f"{orchestration_id}/{task_id}: Executing activity '{name}'..." + ) fn = self._registry.get_activity(name) if not fn: - raise ActivityNotRegisteredError(f"Activity function named '{name}' was not registered!") + raise ActivityNotRegisteredError( + f"Activity function named '{name}' was not registered!" + ) activity_input = shared.from_json(encoded_input) if encoded_input else None ctx = task.ActivityContext(orchestration_id, task_id) @@ -769,49 +1251,54 @@ def execute(self, orchestration_id: str, name: str, task_id: int, encoded_input: # Execute the activity function activity_output = fn(ctx, activity_input) - encoded_output = shared.to_json(activity_output) if activity_output is not None else None + encoded_output = ( + shared.to_json(activity_output) if activity_output is not None else None + ) chars = len(encoded_output) if encoded_output else 0 self._logger.debug( - f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output.") + f"{orchestration_id}/{task_id}: Activity '{name}' completed successfully with {chars} char(s) of encoded output." + ) return encoded_output -def _get_non_determinism_error(task_id: int, action_name: str) -> task.NonDeterminismError: +def _get_non_determinism_error( + task_id: int, action_name: str +) -> task.NonDeterminismError: return task.NonDeterminismError( f"A previous execution called {action_name} with ID={task_id}, but the current " f"execution doesn't have this action with this ID. This problem occurs when either " f"the orchestration has non-deterministic logic or if the code was changed after an " - f"instance of this orchestration already started running.") + f"instance of this orchestration already started running." + ) def _get_wrong_action_type_error( - task_id: int, - expected_method_name: str, - action: pb.OrchestratorAction) -> task.NonDeterminismError: + task_id: int, expected_method_name: str, action: pb.OrchestratorAction +) -> task.NonDeterminismError: unexpected_method_name = _get_method_name_for_action(action) return task.NonDeterminismError( f"Failed to restore orchestration state due to a history mismatch: A previous execution called " f"{expected_method_name} with ID={task_id}, but the current execution is instead trying to call " f"{unexpected_method_name} as part of rebuilding it's history. This kind of mismatch can happen if an " f"orchestration has non-deterministic logic or if the code was changed after an instance of this " - f"orchestration already started running.") + f"orchestration already started running." + ) def _get_wrong_action_name_error( - task_id: int, - method_name: str, - expected_task_name: str, - actual_task_name: str) -> task.NonDeterminismError: + task_id: int, method_name: str, expected_task_name: str, actual_task_name: str +) -> task.NonDeterminismError: return task.NonDeterminismError( f"Failed to restore orchestration state due to a history mismatch: A previous execution called " f"{method_name} with name='{expected_task_name}' and sequence number {task_id}, but the current " f"execution is instead trying to call {actual_task_name} as part of rebuilding it's history. " f"This kind of mismatch can happen if an orchestration has non-deterministic logic or if the code " - f"was changed after an instance of this orchestration already started running.") + f"was changed after an instance of this orchestration already started running." + ) def _get_method_name_for_action(action: pb.OrchestratorAction) -> str: - action_type = action.WhichOneof('orchestratorActionType') + action_type = action.WhichOneof("orchestratorActionType") if action_type == "scheduleTask": return task.get_name(task.OrchestrationContext.call_activity) elif action_type == "createTimer": @@ -833,7 +1320,7 @@ def _get_new_event_summary(new_events: Sequence[pb.HistoryEvent]) -> str: else: counts: dict[str, int] = {} for event in new_events: - event_type = event.WhichOneof('eventType') + event_type = event.WhichOneof("eventType") counts[event_type] = counts.get(event_type, 0) + 1 return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]" @@ -847,11 +1334,210 @@ def _get_action_summary(new_actions: Sequence[pb.OrchestratorAction]) -> str: else: counts: dict[str, int] = {} for action in new_actions: - action_type = action.WhichOneof('orchestratorActionType') + action_type = action.WhichOneof("orchestratorActionType") counts[action_type] = counts.get(action_type, 0) + 1 return f"[{', '.join(f'{name}={count}' for name, count in counts.items())}]" def _is_suspendable(event: pb.HistoryEvent) -> bool: """Returns true if the event is one that can be suspended and resumed.""" - return event.WhichOneof("eventType") not in ["executionResumed", "executionTerminated"] + return event.WhichOneof("eventType") not in [ + "executionResumed", + "executionTerminated", + ] + + +class _AsyncWorkerManager: + def __init__(self, concurrency_options: ConcurrencyOptions): + self.concurrency_options = concurrency_options + self.activity_semaphore = None + self.orchestration_semaphore = None + # Don't create queues here - defer until we have an event loop + self.activity_queue: Optional[asyncio.Queue] = None + self.orchestration_queue: Optional[asyncio.Queue] = None + self._queue_event_loop: Optional[asyncio.AbstractEventLoop] = None + # Store work items when no event loop is available + self._pending_activity_work: list = [] + self._pending_orchestration_work: list = [] + self.thread_pool = ThreadPoolExecutor( + max_workers=concurrency_options.maximum_thread_pool_workers, + thread_name_prefix="DurableTask", + ) + self._shutdown = False + + def _ensure_queues_for_current_loop(self): + """Ensure queues are bound to the current event loop.""" + try: + current_loop = asyncio.get_running_loop() + except RuntimeError: + # No event loop running, can't create queues + return + + # Check if queues are already properly set up for current loop + if self._queue_event_loop is current_loop: + if self.activity_queue is not None and self.orchestration_queue is not None: + # Queues are already bound to the current loop and exist + return + + # Need to recreate queues for the current event loop + # First, preserve any existing work items + existing_activity_items = [] + existing_orchestration_items = [] + + if self.activity_queue is not None: + try: + while not self.activity_queue.empty(): + existing_activity_items.append(self.activity_queue.get_nowait()) + except Exception: + pass + + if self.orchestration_queue is not None: + try: + while not self.orchestration_queue.empty(): + existing_orchestration_items.append( + self.orchestration_queue.get_nowait() + ) + except Exception: + pass + + # Create fresh queues for the current event loop + self.activity_queue = asyncio.Queue() + self.orchestration_queue = asyncio.Queue() + self._queue_event_loop = current_loop + + # Restore the work items to the new queues + for item in existing_activity_items: + self.activity_queue.put_nowait(item) + for item in existing_orchestration_items: + self.orchestration_queue.put_nowait(item) + + # Move pending work items to the queues + for item in self._pending_activity_work: + self.activity_queue.put_nowait(item) + for item in self._pending_orchestration_work: + self.orchestration_queue.put_nowait(item) + + # Clear the pending work lists + self._pending_activity_work.clear() + self._pending_orchestration_work.clear() + + async def run(self): + # Reset shutdown flag in case this manager is being reused + self._shutdown = False + + # Ensure queues are properly bound to the current event loop + self._ensure_queues_for_current_loop() + + # Create semaphores in the current event loop + self.activity_semaphore = asyncio.Semaphore( + self.concurrency_options.maximum_concurrent_activity_work_items + ) + self.orchestration_semaphore = asyncio.Semaphore( + self.concurrency_options.maximum_concurrent_orchestration_work_items + ) + + # Start background consumers for each work type + if self.activity_queue is not None and self.orchestration_queue is not None: + await asyncio.gather( + self._consume_queue(self.activity_queue, self.activity_semaphore), + self._consume_queue( + self.orchestration_queue, self.orchestration_semaphore + ), + ) + + async def _consume_queue(self, queue: asyncio.Queue, semaphore: asyncio.Semaphore): + # List to track running tasks + running_tasks: set[asyncio.Task] = set() + + while True: + # Clean up completed tasks + done_tasks = {task for task in running_tasks if task.done()} + running_tasks -= done_tasks + + # Exit if shutdown is set and the queue is empty and no tasks are running + if self._shutdown and queue.empty() and not running_tasks: + break + + try: + work = await asyncio.wait_for(queue.get(), timeout=1.0) + except asyncio.TimeoutError: + continue + + func, args, kwargs = work + # Create a concurrent task for processing + task = asyncio.create_task( + self._process_work_item(semaphore, queue, func, args, kwargs) + ) + running_tasks.add(task) + + async def _process_work_item( + self, semaphore: asyncio.Semaphore, queue: asyncio.Queue, func, args, kwargs + ): + async with semaphore: + try: + await self._run_func(func, *args, **kwargs) + finally: + queue.task_done() + + async def _run_func(self, func, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return await func(*args, **kwargs) + else: + loop = asyncio.get_running_loop() + # Avoid submitting to executor after shutdown + if ( + getattr(self, "_shutdown", False) and getattr(self, "thread_pool", None) and getattr( + self.thread_pool, "_shutdown", False) + ): + return None + return await loop.run_in_executor( + self.thread_pool, lambda: func(*args, **kwargs) + ) + + def submit_activity(self, func, *args, **kwargs): + work_item = (func, args, kwargs) + self._ensure_queues_for_current_loop() + if self.activity_queue is not None: + self.activity_queue.put_nowait(work_item) + else: + # No event loop running, store in pending list + self._pending_activity_work.append(work_item) + + def submit_orchestration(self, func, *args, **kwargs): + work_item = (func, args, kwargs) + self._ensure_queues_for_current_loop() + if self.orchestration_queue is not None: + self.orchestration_queue.put_nowait(work_item) + else: + # No event loop running, store in pending list + self._pending_orchestration_work.append(work_item) + + def shutdown(self): + self._shutdown = True + self.thread_pool.shutdown(wait=True) + + def reset_for_new_run(self): + """Reset the manager state for a new run.""" + self._shutdown = False + # Clear any existing queues - they'll be recreated when needed + if self.activity_queue is not None: + # Clear existing queue by creating a new one + # This ensures no items from previous runs remain + try: + while not self.activity_queue.empty(): + self.activity_queue.get_nowait() + except Exception: + pass + if self.orchestration_queue is not None: + try: + while not self.orchestration_queue.empty(): + self.orchestration_queue.get_nowait() + except Exception: + pass + # Clear pending work lists + self._pending_activity_work.clear() + self._pending_orchestration_work.clear() + + +# Export public API +__all__ = ["ConcurrencyOptions", "TaskHubGrpcWorker"] diff --git a/examples/README.md b/examples/README.md index ec9088f..a6cd847 100644 --- a/examples/README.md +++ b/examples/README.md @@ -8,7 +8,11 @@ All the examples assume that you have a Durable Task-compatible sidecar running 1. Install the latest version of the [Dapr CLI](https://docs.dapr.io/getting-started/install-dapr-cli/), which contains and exposes an embedded version of the Durable Task engine. The setup process (which requires Docker) will configure the workflow engine to store state in a local Redis container. -1. Clone and run the [Durable Task Sidecar](https://github.com/microsoft/durabletask-go) project locally (requires Go 1.18 or higher). Orchestration state will be stored in a local sqlite database. +2. Run the [Durable Task Sidecar](https://github.com/dapr/durabletask-go) project locally (requires Go 1.18 or higher). Orchestration state will be stored in a local sqlite database. + ```sh + go install github.com/dapr/durabletask-go@main + durabletask-go --port 4001 + ``` ## Running the examples @@ -24,4 +28,4 @@ In some cases, the sample may require command-line parameters or user inputs. In - [Activity sequence](./activity_sequence.py): Orchestration that schedules three activity calls in a sequence. - [Fan-out/fan-in](./fanout_fanin.py): Orchestration that schedules a dynamic number of activity calls in parallel, waits for all of them to complete, and then performs an aggregation on the results. -- [Human interaction](./human_interaction.py): Orchestration that waits for a human to approve an order before continuing. \ No newline at end of file +- [Human interaction](./human_interaction.py): Orchestration that waits for a human to approve an order before continuing. diff --git a/pyproject.toml b/pyproject.toml index 9c05d86..8c4d1e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,8 @@ license = {file = "LICENSE"} readme = "README.md" dependencies = [ "grpcio", + "protobuf", + "asyncio" ] [project.urls] @@ -42,6 +44,7 @@ local_scheme = "no-local-version" [tool.pytest.ini_options] minversion = "6.0" testpaths = ["tests"] +pythonpath = ["."] markers = [ "e2e: mark a test as an end-to-end test that requires a running sidecar" ] diff --git a/requirements.txt b/requirements.txt index a31419b..07426eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newe protobuf pytest pytest-cov +asyncio diff --git a/tests/test_activity_executor.py b/tests/durabletask/test_activity_executor.py similarity index 100% rename from tests/test_activity_executor.py rename to tests/durabletask/test_activity_executor.py diff --git a/tests/durabletask/test_client.py b/tests/durabletask/test_client.py new file mode 100644 index 0000000..e5a8e9b --- /dev/null +++ b/tests/durabletask/test_client.py @@ -0,0 +1,88 @@ +from unittest.mock import ANY, patch + +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +from durabletask.internal.shared import (get_default_host_address, + get_grpc_channel) + +HOST_ADDRESS = 'localhost:50051' +METADATA = [('key1', 'value1'), ('key2', 'value2')] +INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] + + +def test_get_grpc_channel_insecure(): + with patch('grpc.insecure_channel') as mock_channel: + get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) + mock_channel.assert_called_once_with(HOST_ADDRESS) + + +def test_get_grpc_channel_secure(): + with patch('grpc.secure_channel') as mock_channel, patch( + 'grpc.ssl_channel_credentials') as mock_credentials: + get_grpc_channel(HOST_ADDRESS, True, interceptors=INTERCEPTORS) + mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) + +def test_get_grpc_channel_default_host_address(): + with patch('grpc.insecure_channel') as mock_channel: + get_grpc_channel(None, False, interceptors=INTERCEPTORS) + mock_channel.assert_called_once_with(get_default_host_address()) + + +def test_get_grpc_channel_with_metadata(): + with patch('grpc.insecure_channel') as mock_channel, patch( + 'grpc.intercept_channel') as mock_intercept_channel: + get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) + mock_channel.assert_called_once_with(HOST_ADDRESS) + mock_intercept_channel.assert_called_once() + + # Capture and check the arguments passed to intercept_channel() + args, kwargs = mock_intercept_channel.call_args + assert args[0] == mock_channel.return_value + assert isinstance(args[1], DefaultClientInterceptorImpl) + assert args[1]._metadata == METADATA + + +def test_grpc_channel_with_host_name_protocol_stripping(): + with patch('grpc.insecure_channel') as mock_insecure_channel, patch( + 'grpc.secure_channel') as mock_secure_channel: + + host_name = "myserver.com:1234" + + prefix = "grpc://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_insecure_channel.assert_called_with(host_name) + + prefix = "http://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_insecure_channel.assert_called_with(host_name) + + prefix = "HTTP://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_insecure_channel.assert_called_with(host_name) + + prefix = "GRPC://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_insecure_channel.assert_called_with(host_name) + + prefix = "" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_insecure_channel.assert_called_with(host_name) + + prefix = "grpcs://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_secure_channel.assert_called_with(host_name, ANY) + + prefix = "https://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_secure_channel.assert_called_with(host_name, ANY) + + prefix = "HTTPS://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_secure_channel.assert_called_with(host_name, ANY) + + prefix = "GRPCS://" + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) + mock_secure_channel.assert_called_with(host_name, ANY) + + prefix = "" + get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) + mock_secure_channel.assert_called_with(host_name, ANY) diff --git a/tests/durabletask/test_concurrency_options.py b/tests/durabletask/test_concurrency_options.py new file mode 100644 index 0000000..b49b7ec --- /dev/null +++ b/tests/durabletask/test_concurrency_options.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import os + +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker + + +def test_default_concurrency_options(): + """Test that default concurrency options work correctly.""" + options = ConcurrencyOptions() + processor_count = os.cpu_count() or 1 + expected_default = 100 * processor_count + expected_workers = processor_count + 4 + + assert options.maximum_concurrent_activity_work_items == expected_default + assert options.maximum_concurrent_orchestration_work_items == expected_default + assert options.maximum_thread_pool_workers == expected_workers + + +def test_custom_concurrency_options(): + """Test that custom concurrency options work correctly.""" + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=50, + maximum_concurrent_orchestration_work_items=25, + maximum_thread_pool_workers=30, + ) + + assert options.maximum_concurrent_activity_work_items == 50 + assert options.maximum_concurrent_orchestration_work_items == 25 + assert options.maximum_thread_pool_workers == 30 + + +def test_partial_custom_options(): + """Test that partially specified options use defaults for unspecified values.""" + processor_count = os.cpu_count() or 1 + expected_default = 100 * processor_count + expected_workers = processor_count + 4 + + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=30 + ) + + assert options.maximum_concurrent_activity_work_items == 30 + assert options.maximum_concurrent_orchestration_work_items == expected_default + assert options.maximum_thread_pool_workers == expected_workers + + +def test_worker_with_concurrency_options(): + """Test that TaskHubGrpcWorker accepts concurrency options.""" + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=10, + maximum_concurrent_orchestration_work_items=20, + maximum_thread_pool_workers=15, + ) + + worker = TaskHubGrpcWorker(concurrency_options=options) + + assert worker.concurrency_options == options + + +def test_worker_default_options(): + """Test that TaskHubGrpcWorker uses default options when no parameters are provided.""" + worker = TaskHubGrpcWorker() + + processor_count = os.cpu_count() or 1 + expected_default = 100 * processor_count + expected_workers = processor_count + 4 + + assert ( + worker.concurrency_options.maximum_concurrent_activity_work_items == expected_default + ) + assert ( + worker.concurrency_options.maximum_concurrent_orchestration_work_items == expected_default + ) + assert worker.concurrency_options.maximum_thread_pool_workers == expected_workers + + +def test_concurrency_options_property_access(): + """Test that the concurrency_options property works correctly.""" + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=15, + maximum_concurrent_orchestration_work_items=25, + maximum_thread_pool_workers=30, + ) + + worker = TaskHubGrpcWorker(concurrency_options=options) + retrieved_options = worker.concurrency_options + + # Should be the same object + assert retrieved_options is options + + # Should have correct values + assert retrieved_options.maximum_concurrent_activity_work_items == 15 + assert retrieved_options.maximum_concurrent_orchestration_work_items == 25 + assert retrieved_options.maximum_thread_pool_workers == 30 diff --git a/tests/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py similarity index 99% rename from tests/test_orchestration_e2e.py rename to tests/durabletask/test_orchestration_e2e.py index bcb3d3c..2343184 100644 --- a/tests/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -278,6 +278,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): assert state.runtime_status == client.OrchestrationStatus.TERMINATED assert state.serialized_output == json.dumps("some reason for termination") + def test_terminate_recursive(): thread_lock = threading.Lock() activity_counter = 0 @@ -330,7 +331,6 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): assert activity_counter == 5, "Activity should have executed without recursive termination" - def test_continue_as_new(): all_results = [] @@ -338,7 +338,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): result = yield ctx.wait_for_external_event("my_event") if not ctx.is_replaying: # NOTE: Real orchestrations should never interact with nonlocal variables like this. - nonlocal all_results + nonlocal all_results # noqa: F824 all_results.append(result) if len(all_results) <= 4: @@ -462,6 +462,7 @@ def throw_activity(ctx: task.ActivityContext, _): assert state.failure_details.stack_trace is not None assert throw_activity_counter == 4 + def test_custom_status(): def empty_orchestrator(ctx: task.OrchestrationContext, _): diff --git a/tests/test_orchestration_executor.py b/tests/durabletask/test_orchestration_executor.py similarity index 100% rename from tests/test_orchestration_executor.py rename to tests/durabletask/test_orchestration_executor.py diff --git a/tests/durabletask/test_orchestration_wait.py b/tests/durabletask/test_orchestration_wait.py new file mode 100644 index 0000000..03f7e30 --- /dev/null +++ b/tests/durabletask/test_orchestration_wait.py @@ -0,0 +1,63 @@ +from unittest.mock import patch, ANY, Mock + +from durabletask.client import TaskHubGrpcClient +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +from durabletask.internal.shared import (get_default_host_address, + get_grpc_channel) +import pytest + +@pytest.mark.parametrize("timeout", [None, 0, 5]) +def test_wait_for_orchestration_start_timeout(timeout): + instance_id = "test-instance" + + from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ + OrchestrationState, ORCHESTRATION_STATUS_RUNNING + + response = GetInstanceResponse() + state = OrchestrationState() + state.instanceId = instance_id + state.orchestrationStatus = ORCHESTRATION_STATUS_RUNNING + response.orchestrationState.CopyFrom(state) + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceStart.return_value = response + + grpc_timeout = None if timeout is None else timeout + c.wait_for_orchestration_start(instance_id, timeout=grpc_timeout) + + # Verify WaitForInstanceStart was called with timeout=None + c._stub.WaitForInstanceStart.assert_called_once() + _, kwargs = c._stub.WaitForInstanceStart.call_args + if timeout is None or timeout == 0: + assert kwargs.get('timeout') is None + else: + assert kwargs.get('timeout') == timeout + +@pytest.mark.parametrize("timeout", [None, 0, 5]) +def test_wait_for_orchestration_completion_timeout(timeout): + instance_id = "test-instance" + + from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ + OrchestrationState, ORCHESTRATION_STATUS_COMPLETED + + response = GetInstanceResponse() + state = OrchestrationState() + state.instanceId = instance_id + state.orchestrationStatus = ORCHESTRATION_STATUS_COMPLETED + response.orchestrationState.CopyFrom(state) + + c = TaskHubGrpcClient() + c._stub = Mock() + c._stub.WaitForInstanceCompletion.return_value = response + + grpc_timeout = None if timeout is None else timeout + c.wait_for_orchestration_completion(instance_id, timeout=grpc_timeout) + + # Verify WaitForInstanceStart was called with timeout=None + c._stub.WaitForInstanceCompletion.assert_called_once() + _, kwargs = c._stub.WaitForInstanceCompletion.call_args + if timeout is None or timeout == 0: + assert kwargs.get('timeout') is None + else: + assert kwargs.get('timeout') == timeout diff --git a/tests/durabletask/test_worker_concurrency_loop.py b/tests/durabletask/test_worker_concurrency_loop.py new file mode 100644 index 0000000..de6753b --- /dev/null +++ b/tests/durabletask/test_worker_concurrency_loop.py @@ -0,0 +1,140 @@ +import asyncio +import threading +import time + +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker + + +class DummyStub: + def __init__(self): + self.completed = [] + + def CompleteOrchestratorTask(self, res): + self.completed.append(('orchestrator', res)) + + def CompleteActivityTask(self, res): + self.completed.append(('activity', res)) + + +class DummyRequest: + def __init__(self, kind, instance_id): + self.kind = kind + self.instanceId = instance_id + self.orchestrationInstance = type('O', (), {'instanceId': instance_id}) + self.name = 'dummy' + self.taskId = 1 + self.input = type('I', (), {'value': ''}) + self.pastEvents = [] + self.newEvents = [] + + def HasField(self, field): + return (field == 'orchestratorRequest' and self.kind == 'orchestrator') or \ + (field == 'activityRequest' and self.kind == 'activity') + + def WhichOneof(self, _): + return f'{self.kind}Request' + + +class DummyCompletionToken: + pass + + +def test_worker_concurrency_loop_sync(): + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=2, + maximum_concurrent_orchestration_work_items=1, + maximum_thread_pool_workers=2, + ) + worker = TaskHubGrpcWorker(concurrency_options=options) + stub = DummyStub() + + def dummy_orchestrator(req, stub, completionToken): + time.sleep(0.1) + stub.CompleteOrchestratorTask('ok') + + def dummy_activity(req, stub, completionToken): + time.sleep(0.1) + stub.CompleteActivityTask('ok') + + # Patch the worker's _execute_orchestrator and _execute_activity + worker._execute_orchestrator = dummy_orchestrator + worker._execute_activity = dummy_activity + + orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] + activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] + + async def run_test(): + # Start the worker manager's run loop in the background + worker_task = asyncio.create_task(worker._async_worker_manager.run()) + for req in orchestrator_requests: + worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + for req in activity_requests: + worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + await asyncio.sleep(1.0) + orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') + activity_count = sum(1 for t, _ in stub.completed if t == 'activity') + assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" + assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" + worker._async_worker_manager._shutdown = True + await worker_task + asyncio.run(run_test()) + + +# Dummy orchestrator and activity for sync context +def dummy_orchestrator(ctx, input): + # Simulate some work + time.sleep(0.1) + return "orchestrator-done" + + +def dummy_activity(ctx, input): + # Simulate some work + time.sleep(0.1) + return "activity-done" + + +def test_worker_concurrency_sync(): + # Use small concurrency to make test observable + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=2, + maximum_concurrent_orchestration_work_items=2, + maximum_thread_pool_workers=2, + ) + worker = TaskHubGrpcWorker(concurrency_options=options) + worker.add_orchestrator(dummy_orchestrator) + worker.add_activity(dummy_activity) + + # Simulate submitting work items to the queues directly (bypassing gRPC) + # We'll use the internal _async_worker_manager for this test + manager = worker._async_worker_manager + results = [] + lock = threading.Lock() + + def make_work(kind, idx): + def fn(*args, **kwargs): + time.sleep(0.1) + with lock: + results.append((kind, idx)) + return f"{kind}-{idx}-done" + return fn + + # Submit more work than concurrency allows + for i in range(5): + manager.submit_orchestration(make_work("orch", i)) + manager.submit_activity(make_work("act", i)) + + # Run the manager loop in a thread (sync context) + def run_manager(): + asyncio.run(manager.run()) + + t = threading.Thread(target=run_manager) + t.start() + time.sleep(1.5) # Let work process + manager.shutdown() + # Unblock the consumers by putting dummy items in the queues + manager.activity_queue.put_nowait((lambda: None, (), {})) + manager.orchestration_queue.put_nowait((lambda: None, (), {})) + t.join(timeout=2) + + # Check that all work items completed + assert len(results) == 10 diff --git a/tests/durabletask/test_worker_concurrency_loop_async.py b/tests/durabletask/test_worker_concurrency_loop_async.py new file mode 100644 index 0000000..c7ba238 --- /dev/null +++ b/tests/durabletask/test_worker_concurrency_loop_async.py @@ -0,0 +1,80 @@ +import asyncio + +from durabletask.worker import ConcurrencyOptions, TaskHubGrpcWorker + + +class DummyStub: + def __init__(self): + self.completed = [] + + def CompleteOrchestratorTask(self, res): + self.completed.append(('orchestrator', res)) + + def CompleteActivityTask(self, res): + self.completed.append(('activity', res)) + + +class DummyRequest: + def __init__(self, kind, instance_id): + self.kind = kind + self.instanceId = instance_id + self.orchestrationInstance = type('O', (), {'instanceId': instance_id}) + self.name = 'dummy' + self.taskId = 1 + self.input = type('I', (), {'value': ''}) + self.pastEvents = [] + self.newEvents = [] + + def HasField(self, field): + return (field == 'orchestratorRequest' and self.kind == 'orchestrator') or \ + (field == 'activityRequest' and self.kind == 'activity') + + def WhichOneof(self, _): + return f'{self.kind}Request' + + +class DummyCompletionToken: + pass + + +def test_worker_concurrency_loop_async(): + options = ConcurrencyOptions( + maximum_concurrent_activity_work_items=2, + maximum_concurrent_orchestration_work_items=1, + maximum_thread_pool_workers=2, + ) + grpc_worker = TaskHubGrpcWorker(concurrency_options=options) + stub = DummyStub() + + async def dummy_orchestrator(req, stub, completionToken): + await asyncio.sleep(0.1) + stub.CompleteOrchestratorTask('ok') + + async def dummy_activity(req, stub, completionToken): + await asyncio.sleep(0.1) + stub.CompleteActivityTask('ok') + + # Patch the worker's _execute_orchestrator and _execute_activity + grpc_worker._execute_orchestrator = dummy_orchestrator + grpc_worker._execute_activity = dummy_activity + + orchestrator_requests = [DummyRequest('orchestrator', f'orch{i}') for i in range(3)] + activity_requests = [DummyRequest('activity', f'act{i}') for i in range(4)] + + async def run_test(): + # Clear stub state before each run + stub.completed.clear() + worker_task = asyncio.create_task(grpc_worker._async_worker_manager.run()) + for req in orchestrator_requests: + grpc_worker._async_worker_manager.submit_orchestration(dummy_orchestrator, req, stub, DummyCompletionToken()) + for req in activity_requests: + grpc_worker._async_worker_manager.submit_activity(dummy_activity, req, stub, DummyCompletionToken()) + await asyncio.sleep(1.0) + orchestrator_count = sum(1 for t, _ in stub.completed if t == 'orchestrator') + activity_count = sum(1 for t, _ in stub.completed if t == 'activity') + assert orchestrator_count == 3, f"Expected 3 orchestrator completions, got {orchestrator_count}" + assert activity_count == 4, f"Expected 4 activity completions, got {activity_count}" + grpc_worker._async_worker_manager._shutdown = True + await worker_task + asyncio.run(run_test()) + asyncio.run(run_test()) diff --git a/tests/test_client.py b/tests/test_client.py deleted file mode 100644 index 5990db0..0000000 --- a/tests/test_client.py +++ /dev/null @@ -1,147 +0,0 @@ -from unittest.mock import patch, ANY, Mock - -from durabletask.client import TaskHubGrpcClient -from durabletask.internal.shared import (DefaultClientInterceptorImpl, - get_default_host_address, - get_grpc_channel) -import pytest - -HOST_ADDRESS = 'localhost:50051' -METADATA = [('key1', 'value1'), ('key2', 'value2')] - - -def test_get_grpc_channel_insecure(): - with patch('grpc.insecure_channel') as mock_channel: - get_grpc_channel(HOST_ADDRESS, METADATA, False) - mock_channel.assert_called_once_with(HOST_ADDRESS) - - -def test_get_grpc_channel_secure(): - with patch('grpc.secure_channel') as mock_channel, patch( - 'grpc.ssl_channel_credentials') as mock_credentials: - get_grpc_channel(HOST_ADDRESS, METADATA, True) - mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) - - -def test_get_grpc_channel_default_host_address(): - with patch('grpc.insecure_channel') as mock_channel: - get_grpc_channel(None, METADATA, False) - mock_channel.assert_called_once_with(get_default_host_address()) - - -def test_get_grpc_channel_with_metadata(): - with patch('grpc.insecure_channel') as mock_channel, patch( - 'grpc.intercept_channel') as mock_intercept_channel: - get_grpc_channel(HOST_ADDRESS, METADATA, False) - mock_channel.assert_called_once_with(HOST_ADDRESS) - mock_intercept_channel.assert_called_once() - - # Capture and check the arguments passed to intercept_channel() - args, kwargs = mock_intercept_channel.call_args - assert args[0] == mock_channel.return_value - assert isinstance(args[1], DefaultClientInterceptorImpl) - assert args[1]._metadata == METADATA - - -def test_grpc_channel_with_host_name_protocol_stripping(): - with patch('grpc.insecure_channel') as mock_insecure_channel, patch( - 'grpc.secure_channel') as mock_secure_channel: - - host_name = "myserver.com:1234" - - prefix = "grpc://" - get_grpc_channel(prefix + host_name, METADATA) - mock_insecure_channel.assert_called_with(host_name) - - prefix = "http://" - get_grpc_channel(prefix + host_name, METADATA) - mock_insecure_channel.assert_called_with(host_name) - - prefix = "HTTP://" - get_grpc_channel(prefix + host_name, METADATA) - mock_insecure_channel.assert_called_with(host_name) - - prefix = "GRPC://" - get_grpc_channel(prefix + host_name, METADATA) - mock_insecure_channel.assert_called_with(host_name) - - prefix = "" - get_grpc_channel(prefix + host_name, METADATA) - mock_insecure_channel.assert_called_with(host_name) - - prefix = "grpcs://" - get_grpc_channel(prefix + host_name, METADATA) - mock_secure_channel.assert_called_with(host_name, ANY) - - prefix = "https://" - get_grpc_channel(prefix + host_name, METADATA) - mock_secure_channel.assert_called_with(host_name, ANY) - - prefix = "HTTPS://" - get_grpc_channel(prefix + host_name, METADATA) - mock_secure_channel.assert_called_with(host_name, ANY) - - prefix = "GRPCS://" - get_grpc_channel(prefix + host_name, METADATA) - mock_secure_channel.assert_called_with(host_name, ANY) - - prefix = "" - get_grpc_channel(prefix + host_name, METADATA, True) - mock_secure_channel.assert_called_with(host_name, ANY) - - -@pytest.mark.parametrize("timeout", [None, 0, 5]) -def test_wait_for_orchestration_start_timeout(timeout): - instance_id = "test-instance" - - from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ - OrchestrationState, ORCHESTRATION_STATUS_RUNNING - - response = GetInstanceResponse() - state = OrchestrationState() - state.instanceId = instance_id - state.orchestrationStatus = ORCHESTRATION_STATUS_RUNNING - response.orchestrationState.CopyFrom(state) - - c = TaskHubGrpcClient() - c._stub = Mock() - c._stub.WaitForInstanceStart.return_value = response - - grpc_timeout = None if timeout is None else timeout - c.wait_for_orchestration_start(instance_id, timeout=grpc_timeout) - - # Verify WaitForInstanceStart was called with timeout=None - c._stub.WaitForInstanceStart.assert_called_once() - _, kwargs = c._stub.WaitForInstanceStart.call_args - if timeout is None or timeout == 0: - assert kwargs.get('timeout') is None - else: - assert kwargs.get('timeout') == timeout - -@pytest.mark.parametrize("timeout", [None, 0, 5]) -def test_wait_for_orchestration_completion_timeout(timeout): - instance_id = "test-instance" - - from durabletask.internal.orchestrator_service_pb2 import GetInstanceResponse, \ - OrchestrationState, ORCHESTRATION_STATUS_COMPLETED - - response = GetInstanceResponse() - state = OrchestrationState() - state.instanceId = instance_id - state.orchestrationStatus = ORCHESTRATION_STATUS_COMPLETED - response.orchestrationState.CopyFrom(state) - - c = TaskHubGrpcClient() - c._stub = Mock() - c._stub.WaitForInstanceCompletion.return_value = response - - grpc_timeout = None if timeout is None else timeout - c.wait_for_orchestration_completion(instance_id, timeout=grpc_timeout) - - # Verify WaitForInstanceStart was called with timeout=None - c._stub.WaitForInstanceCompletion.assert_called_once() - _, kwargs = c._stub.WaitForInstanceCompletion.call_args - if timeout is None or timeout == 0: - assert kwargs.get('timeout') is None - else: - assert kwargs.get('timeout') == timeout \ No newline at end of file