Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
settings:
director_host: localhost
director_port: 50050


params:
client_reconnect_interval: 5

Portland:
private_attributes: private_attributes.portland_attrs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
settings:
director_host: localhost
director_port: 50050


params:
client_reconnect_interval: 5

Seattle:
private_attributes: private_attributes.seattle_attrs
5 changes: 5 additions & 0 deletions openfl/experimental/workflow/component/envoy/envoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class Envoy:
_envoy_dir_client (EnvoyDirectorClient): The envoy director client.
install_requirements (bool): A flag indicating if the requirements
should be installed.
client_reconnect_interval (int): The interval for client reconnection attempts.
is_experiment_running (bool): A flag indicating if an experiment is
running.
executor (ThreadPoolExecutor): The executor for running tasks.
Expand All @@ -56,6 +57,7 @@ def __init__(
certificate: Optional[Union[Path, str]] = None,
tls: bool = True,
install_requirements: bool = True,
client_reconnect_interval: int = 5,
) -> None:
"""Initialize a envoy object.

Expand All @@ -74,12 +76,14 @@ def __init__(
connections. Defaults to True.
install_requirements (bool, optional): A flag indicating if the
requirements should be installed. Defaults to True.
client_reconnect_interval (int): The interval for client reconnection attempts.
"""
self.name = envoy_name
self.envoy_config = envoy_config
self.tls = tls
self._fill_certs(root_certificate, private_key, certificate)
self.install_requirements = install_requirements
self.client_reconnect_interval = client_reconnect_interval
self._envoy_dir_client = self._create_envoy_dir_client(director_host, director_port)
self.is_experiment_running = False
self.executor = ThreadPoolExecutor()
Expand Down Expand Up @@ -108,6 +112,7 @@ def _create_envoy_dir_client(
root_certificate=self.root_certificate,
private_key=self.private_key,
certificate=self.certificate,
client_reconnect_interval=self.client_reconnect_interval,
)

def _fill_certs(self, root_certificate, private_key, certificate) -> None:
Expand Down
5 changes: 4 additions & 1 deletion openfl/experimental/workflow/interface/cli/envoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,11 @@ def start_(
# Parse envoy parameters
envoy_params = config.get("params", {})
if envoy_params:
install_requirements = envoy_params["install_requirements"]
install_requirements = envoy_params.get("install_requirements", True)
client_reconnect_interval = envoy_params.get("client_reconnect_interval", 5)
else:
install_requirements = False
client_reconnect_interval = 5

if config.root_certificate:
config.root_certificate = Path(config.root_certificate).absolute()
Expand All @@ -153,6 +155,7 @@ def start_(
certificate=config.certificate,
tls=tls,
install_requirements=install_requirements,
client_reconnect_interval=client_reconnect_interval,
)

envoy.start()
5 changes: 5 additions & 0 deletions openfl/experimental/workflow/transport/grpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
RuntimeDirectorClient,
)
from openfl.experimental.workflow.transport.grpc.director_server import DirectorGRPCServer
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import (
ConstantBackoff,
RetryOnRpcErrorClientInterceptor,
channel_options,
)
62 changes: 6 additions & 56 deletions openfl/experimental/workflow/transport/grpc/aggregator_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,70 +4,19 @@

"""AggregatorGRPCClient module."""

import time
from logging import getLogger
from typing import Optional, Tuple

import grpc

from openfl.experimental.workflow.protocols import aggregator_pb2, aggregator_pb2_grpc
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import channel_options
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import (
ConstantBackoff,
RetryOnRpcErrorClientInterceptor,
channel_options,
)
from openfl.protocols.utils import datastream_to_proto, proto_to_datastream


class ConstantBackoff:
"""Constant Backoff policy."""

def __init__(self, reconnect_interval, logger, uri):
"""Initialize Constant Backoff."""
self.reconnect_interval = reconnect_interval
self.logger = logger
self.uri = uri

def sleep(self):
"""Sleep for specified interval."""
self.logger.info(f"Attempting to connect to aggregator at {self.uri}")
time.sleep(self.reconnect_interval)


class RetryOnRpcErrorClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor
):
"""Retry gRPC connection on failure."""

def __init__(
self,
sleeping_policy,
status_for_retry: Optional[Tuple[grpc.StatusCode]] = None,
):
"""Initialize function for gRPC retry."""
self.sleeping_policy = sleeping_policy
self.status_for_retry = status_for_retry

def _intercept_call(self, continuation, client_call_details, request_or_iterator):
"""Intercept the call to the gRPC server."""
while True:
response = continuation(client_call_details, request_or_iterator)

if isinstance(response, grpc.RpcError):
# If status code is not in retryable status codes
self.sleeping_policy.logger.info(f"Response code: {response.code()}")
if self.status_for_retry and response.code() not in self.status_for_retry:
return response

self.sleeping_policy.sleep()
else:
return response

def intercept_unary_unary(self, continuation, client_call_details, request):
"""Wrap intercept call for unary->unary RPC."""
return self._intercept_call(continuation, client_call_details, request)

def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
"""Wrap intercept call for stream->unary RPC."""
return self._intercept_call(continuation, client_call_details, request_iterator)


def _atomic_connection(func):
def wrapper(self, *args, **kwargs):
self.reconnect()
Expand Down Expand Up @@ -148,6 +97,7 @@ def __init__(
logger=self.logger,
reconnect_interval=int(kwargs.get("client_reconnect_interval", 1)),
uri=self.uri,
participant="Aggregator",
),
status_for_retry=(grpc.StatusCode.UNAVAILABLE,),
),
Expand Down
25 changes: 22 additions & 3 deletions openfl/experimental/workflow/transport/grpc/director_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,13 @@

from openfl.experimental.workflow.protocols import director_pb2, director_pb2_grpc
from openfl.experimental.workflow.transport.grpc.exceptions import EnvoyNotFoundError
from openfl.experimental.workflow.transport.grpc.grpc_channel_options import (
ConstantBackoff,
RetryOnRpcErrorClientInterceptor,
channel_options,
)
from openfl.protocols.utils import datastream_to_proto

from .grpc_channel_options import channel_options

logger = logging.getLogger(__name__)


Expand All @@ -42,6 +45,7 @@ def __init__(
root_certificate: Optional[Union[Path, str]] = None,
private_key: Optional[Union[Path, str]] = None,
certificate: Optional[Union[Path, str]] = None,
client_reconnect_interval: int = 5,
) -> None:
"""
Initialize director client object.
Expand All @@ -57,6 +61,7 @@ def __init__(
connection.
certificate (Optional[Union[Path, str]]): The path to the certificate for the TLS
connection.
client_reconnect_interval (int): The interval for client reconnection attempts.
"""
director_addr = f"{director_host}:{director_port}"
self.envoy_name = envoy_name
Expand All @@ -82,7 +87,21 @@ def __init__(
certificate_chain=certificate_b,
)
channel = grpc.secure_channel(director_addr, credentials, options=channel_options)
self.stub = director_pb2_grpc.DirectorStub(channel)

# Adding an interceptor for RPC Errors
self.interceptors = (
RetryOnRpcErrorClientInterceptor(
sleeping_policy=ConstantBackoff(
logger=logger,
reconnect_interval=client_reconnect_interval,
uri=director_addr,
participant="Director",
),
status_for_retry=(grpc.StatusCode.UNAVAILABLE,),
),
)
intercept_channel = grpc.intercept_channel(channel, *self.interceptors)
self.stub = director_pb2_grpc.DirectorStub(intercept_channel)

def connect_envoy(self, envoy_name: str) -> bool:
"""Attempt to establish a connection with the director.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import time
from typing import Optional, Tuple

import grpc

max_metadata_size = 32 * 2**20
max_message_length = 2**30
Expand All @@ -10,3 +14,57 @@
("grpc.max_send_message_length", max_message_length),
("grpc.max_receive_message_length", max_message_length),
]


class ConstantBackoff:
"""Constant Backoff policy."""

def __init__(self, reconnect_interval, logger, uri, participant):
"""Initialize Constant Backoff."""
self.reconnect_interval = reconnect_interval
self.logger = logger
self.uri = uri
self.participant = participant

def sleep(self):
"""Sleep for specified interval."""
self.logger.info(f"Attempting to connect to {self.participant} at {self.uri}")
time.sleep(self.reconnect_interval)


class RetryOnRpcErrorClientInterceptor(
grpc.UnaryUnaryClientInterceptor, grpc.StreamUnaryClientInterceptor
):
"""Retry gRPC connection on failure."""

def __init__(
self,
sleeping_policy,
status_for_retry: Optional[Tuple[grpc.StatusCode]] = None,
):
"""Initialize function for gRPC retry."""
self.sleeping_policy = sleeping_policy
self.status_for_retry = status_for_retry

def _intercept_call(self, continuation, client_call_details, request_or_iterator):
"""Intercept the call to the gRPC server."""
while True:
response = continuation(client_call_details, request_or_iterator)

if isinstance(response, grpc.RpcError):
# If status code is not in retryable status codes
self.sleeping_policy.logger.info(f"Response code: {response.code()}")
if self.status_for_retry and response.code() not in self.status_for_retry:
return response

self.sleeping_policy.sleep()
else:
return response

def intercept_unary_unary(self, continuation, client_call_details, request):
"""Wrap intercept call for unary->unary RPC."""
return self._intercept_call(continuation, client_call_details, request)

def intercept_stream_unary(self, continuation, client_call_details, request_iterator):
"""Wrap intercept call for stream->unary RPC."""
return self._intercept_call(continuation, client_call_details, request_iterator)
Loading