diff --git a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Portland/Portland_config.yaml b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Portland/Portland_config.yaml index a37e5010bd..5cd843112d 100755 --- a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Portland/Portland_config.yaml +++ b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Portland/Portland_config.yaml @@ -1,6 +1,9 @@ settings: director_host: localhost director_port: 50050 - + +params: + client_reconnect_interval: 5 + Portland: private_attributes: private_attributes.portland_attrs \ No newline at end of file diff --git a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Seattle/Seattle_config.yaml b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Seattle/Seattle_config.yaml index 07a819115c..817ba0b733 100755 --- a/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Seattle/Seattle_config.yaml +++ b/openfl-tutorials/experimental/workflow/FederatedRuntime/101_MNIST/Seattle/Seattle_config.yaml @@ -1,6 +1,9 @@ settings: director_host: localhost director_port: 50050 - + +params: + client_reconnect_interval: 5 + Seattle: private_attributes: private_attributes.seattle_attrs \ No newline at end of file diff --git a/openfl/experimental/workflow/component/envoy/envoy.py b/openfl/experimental/workflow/component/envoy/envoy.py index 1d27d48c79..a28207e80b 100644 --- a/openfl/experimental/workflow/component/envoy/envoy.py +++ b/openfl/experimental/workflow/component/envoy/envoy.py @@ -35,6 +35,7 @@ class Envoy: _envoy_dir_client (EnvoyDirectorClient): The envoy director client. install_requirements (bool): A flag indicating if the requirements should be installed. + client_reconnect_interval (int): The interval for client reconnection attempts. is_experiment_running (bool): A flag indicating if an experiment is running. executor (ThreadPoolExecutor): The executor for running tasks. @@ -56,6 +57,7 @@ def __init__( certificate: Optional[Union[Path, str]] = None, tls: bool = True, install_requirements: bool = True, + client_reconnect_interval: int = 5, ) -> None: """Initialize a envoy object. @@ -74,12 +76,14 @@ def __init__( connections. Defaults to True. install_requirements (bool, optional): A flag indicating if the requirements should be installed. Defaults to True. + client_reconnect_interval (int): The interval for client reconnection attempts. """ self.name = envoy_name self.envoy_config = envoy_config self.tls = tls self._fill_certs(root_certificate, private_key, certificate) self.install_requirements = install_requirements + self.client_reconnect_interval = client_reconnect_interval self._envoy_dir_client = self._create_envoy_dir_client(director_host, director_port) self.is_experiment_running = False self.executor = ThreadPoolExecutor() @@ -108,6 +112,7 @@ def _create_envoy_dir_client( root_certificate=self.root_certificate, private_key=self.private_key, certificate=self.certificate, + client_reconnect_interval=self.client_reconnect_interval, ) def _fill_certs(self, root_certificate, private_key, certificate) -> None: diff --git a/openfl/experimental/workflow/interface/cli/envoy.py b/openfl/experimental/workflow/interface/cli/envoy.py index 1c16ec9990..fc198a7d52 100644 --- a/openfl/experimental/workflow/interface/cli/envoy.py +++ b/openfl/experimental/workflow/interface/cli/envoy.py @@ -132,9 +132,11 @@ def start_( # Parse envoy parameters envoy_params = config.get("params", {}) if envoy_params: - install_requirements = envoy_params["install_requirements"] + install_requirements = envoy_params.get("install_requirements", True) + client_reconnect_interval = envoy_params.get("client_reconnect_interval", 5) else: install_requirements = False + client_reconnect_interval = 5 if config.root_certificate: config.root_certificate = Path(config.root_certificate).absolute() @@ -153,6 +155,7 @@ def start_( certificate=config.certificate, tls=tls, install_requirements=install_requirements, + client_reconnect_interval=client_reconnect_interval, ) envoy.start() diff --git a/openfl/experimental/workflow/transport/grpc/__init__.py b/openfl/experimental/workflow/transport/grpc/__init__.py index 464df70c54..fdeddf3338 100644 --- a/openfl/experimental/workflow/transport/grpc/__init__.py +++ b/openfl/experimental/workflow/transport/grpc/__init__.py @@ -11,3 +11,8 @@ RuntimeDirectorClient, ) from openfl.experimental.workflow.transport.grpc.director_server import DirectorGRPCServer +from openfl.experimental.workflow.transport.grpc.grpc_channel_options import ( + ConstantBackoff, + RetryOnRpcErrorClientInterceptor, + channel_options, +) diff --git a/openfl/experimental/workflow/transport/grpc/aggregator_client.py b/openfl/experimental/workflow/transport/grpc/aggregator_client.py index 00adb2f4b6..1cf609ea44 100644 --- a/openfl/experimental/workflow/transport/grpc/aggregator_client.py +++ b/openfl/experimental/workflow/transport/grpc/aggregator_client.py @@ -4,70 +4,19 @@ """AggregatorGRPCClient module.""" -import time from logging import getLogger -from typing import Optional, Tuple import grpc from openfl.experimental.workflow.protocols import aggregator_pb2, aggregator_pb2_grpc -from openfl.experimental.workflow.transport.grpc.grpc_channel_options import channel_options +from openfl.experimental.workflow.transport.grpc.grpc_channel_options import ( + ConstantBackoff, + RetryOnRpcErrorClientInterceptor, + channel_options, +) from openfl.protocols.utils import datastream_to_proto, proto_to_datastream -class ConstantBackoff: - """Constant Backoff policy.""" - - def __init__(self, reconnect_interval, logger, uri): - """Initialize Constant Backoff.""" - self.reconnect_interval = reconnect_interval - self.logger = logger - self.uri = uri - - def sleep(self): - """Sleep for specified interval.""" - self.logger.info(f"Attempting to connect to aggregator at {self.uri}") - time.sleep(self.reconnect_interval) - - -class RetryOnRpcErrorClientInterceptor( - grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor -): - """Retry gRPC connection on failure.""" - - def __init__( - self, - sleeping_policy, - status_for_retry: Optional[Tuple[grpc.StatusCode]] = None, - ): - """Initialize function for gRPC retry.""" - self.sleeping_policy = sleeping_policy - self.status_for_retry = status_for_retry - - def _intercept_call(self, continuation, client_call_details, request_or_iterator): - """Intercept the call to the gRPC server.""" - while True: - response = continuation(client_call_details, request_or_iterator) - - if isinstance(response, grpc.RpcError): - # If status code is not in retryable status codes - self.sleeping_policy.logger.info(f"Response code: {response.code()}") - if self.status_for_retry and response.code() not in self.status_for_retry: - return response - - self.sleeping_policy.sleep() - else: - return response - - def intercept_unary_unary(self, continuation, client_call_details, request): - """Wrap intercept call for unary->unary RPC.""" - return self._intercept_call(continuation, client_call_details, request) - - def intercept_stream_unary(self, continuation, client_call_details, request_iterator): - """Wrap intercept call for stream->unary RPC.""" - return self._intercept_call(continuation, client_call_details, request_iterator) - - def _atomic_connection(func): def wrapper(self, *args, **kwargs): self.reconnect() @@ -148,6 +97,7 @@ def __init__( logger=self.logger, reconnect_interval=int(kwargs.get("client_reconnect_interval", 1)), uri=self.uri, + participant="Aggregator", ), status_for_retry=(grpc.StatusCode.UNAVAILABLE,), ), diff --git a/openfl/experimental/workflow/transport/grpc/director_client.py b/openfl/experimental/workflow/transport/grpc/director_client.py index 4ebe4a636d..010d36ff5b 100644 --- a/openfl/experimental/workflow/transport/grpc/director_client.py +++ b/openfl/experimental/workflow/transport/grpc/director_client.py @@ -12,10 +12,13 @@ from openfl.experimental.workflow.protocols import director_pb2, director_pb2_grpc from openfl.experimental.workflow.transport.grpc.exceptions import EnvoyNotFoundError +from openfl.experimental.workflow.transport.grpc.grpc_channel_options import ( + ConstantBackoff, + RetryOnRpcErrorClientInterceptor, + channel_options, +) from openfl.protocols.utils import datastream_to_proto -from .grpc_channel_options import channel_options - logger = logging.getLogger(__name__) @@ -42,6 +45,7 @@ def __init__( root_certificate: Optional[Union[Path, str]] = None, private_key: Optional[Union[Path, str]] = None, certificate: Optional[Union[Path, str]] = None, + client_reconnect_interval: int = 5, ) -> None: """ Initialize director client object. @@ -57,6 +61,7 @@ def __init__( connection. certificate (Optional[Union[Path, str]]): The path to the certificate for the TLS connection. + client_reconnect_interval (int): The interval for client reconnection attempts. """ director_addr = f"{director_host}:{director_port}" self.envoy_name = envoy_name @@ -82,7 +87,21 @@ def __init__( certificate_chain=certificate_b, ) channel = grpc.secure_channel(director_addr, credentials, options=channel_options) - self.stub = director_pb2_grpc.DirectorStub(channel) + + # Adding an interceptor for RPC Errors + self.interceptors = ( + RetryOnRpcErrorClientInterceptor( + sleeping_policy=ConstantBackoff( + logger=logger, + reconnect_interval=client_reconnect_interval, + uri=director_addr, + participant="Director", + ), + status_for_retry=(grpc.StatusCode.UNAVAILABLE,), + ), + ) + intercept_channel = grpc.intercept_channel(channel, *self.interceptors) + self.stub = director_pb2_grpc.DirectorStub(intercept_channel) def connect_envoy(self, envoy_name: str) -> bool: """Attempt to establish a connection with the director. diff --git a/openfl/experimental/workflow/transport/grpc/grpc_channel_options.py b/openfl/experimental/workflow/transport/grpc/grpc_channel_options.py index 6e143f224f..30453a03f8 100644 --- a/openfl/experimental/workflow/transport/grpc/grpc_channel_options.py +++ b/openfl/experimental/workflow/transport/grpc/grpc_channel_options.py @@ -1,6 +1,10 @@ # Copyright 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 +import time +from typing import Optional, Tuple + +import grpc max_metadata_size = 32 * 2**20 max_message_length = 2**30 @@ -10,3 +14,57 @@ ("grpc.max_send_message_length", max_message_length), ("grpc.max_receive_message_length", max_message_length), ] + + +class ConstantBackoff: + """Constant Backoff policy.""" + + def __init__(self, reconnect_interval, logger, uri, participant): + """Initialize Constant Backoff.""" + self.reconnect_interval = reconnect_interval + self.logger = logger + self.uri = uri + self.participant = participant + + def sleep(self): + """Sleep for specified interval.""" + self.logger.info(f"Attempting to connect to {self.participant} at {self.uri}") + time.sleep(self.reconnect_interval) + + +class RetryOnRpcErrorClientInterceptor( + grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor +): + """Retry gRPC connection on failure.""" + + def __init__( + self, + sleeping_policy, + status_for_retry: Optional[Tuple[grpc.StatusCode]] = None, + ): + """Initialize function for gRPC retry.""" + self.sleeping_policy = sleeping_policy + self.status_for_retry = status_for_retry + + def _intercept_call(self, continuation, client_call_details, request_or_iterator): + """Intercept the call to the gRPC server.""" + while True: + response = continuation(client_call_details, request_or_iterator) + + if isinstance(response, grpc.RpcError): + # If status code is not in retryable status codes + self.sleeping_policy.logger.info(f"Response code: {response.code()}") + if self.status_for_retry and response.code() not in self.status_for_retry: + return response + + self.sleeping_policy.sleep() + else: + return response + + def intercept_unary_unary(self, continuation, client_call_details, request): + """Wrap intercept call for unary->unary RPC.""" + return self._intercept_call(continuation, client_call_details, request) + + def intercept_stream_unary(self, continuation, client_call_details, request_iterator): + """Wrap intercept call for stream->unary RPC.""" + return self._intercept_call(continuation, client_call_details, request_iterator)