diff --git a/openfl/experimental/workflow/component/aggregator/aggregator.py b/openfl/experimental/workflow/component/aggregator/aggregator.py index 8ff1b0779b..053819295d 100644 --- a/openfl/experimental/workflow/component/aggregator/aggregator.py +++ b/openfl/experimental/workflow/component/aggregator/aggregator.py @@ -182,12 +182,38 @@ def __delete_private_attrs_from_clone(self, clone: Any, replace_str: str = None) def _log_big_warning(self) -> None: """Warn user about single collaborator cert mode.""" logger.warning( - f"\n{the_dragon}\nYOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS" - f" NOT PROPER PKI AND " - f"SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN" - f" WARNED!!!" + "YOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS" + " NOT PROPER PKI AND " + "SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN" + " WARNED!!!" ) + def _initialize_flow(self) -> None: + """Initialize flow by resetting and creating clones.""" + FLSpec._reset_clones() + FLSpec._create_clones(self.flow, self.flow.runtime.collaborators) + + def _enqueue_next_step_for_collaborators(self, next_step) -> None: + """Enqueue the next step and associated clone for each selected collaborator. + + Args: + next_step (str): Next step to be executed by collaborators + """ + for collaborator, task_queue in self.__collaborator_tasks_queue.items(): + if collaborator in self.selected_collaborators: + task_queue.put((next_step, self.clones_dict[collaborator])) + else: + logger.info( + f"Skipping task dispatch for collaborator '{collaborator}' " + f"as it is not part of selected_collaborators." 
+ ) + + def _restore_instance_snapshot(self) -> None: + """Restore the FLSpec state at the aggregator from a saved instance snapshot.""" + if hasattr(self, "instance_snapshot"): + self.flow.restore_instance_snapshot(self.flow, list(self.instance_snapshot)) + delattr(self, "instance_snapshot") + def _update_final_flow(self) -> None: """Update the final flow state with current flow artifacts.""" artifacts_iter, _ = generate_artifacts(ctx=self.flow) @@ -203,6 +229,33 @@ def _get_sleep_time() -> int: """ return 10 + async def _track_collaborator_status(self) -> None: + """Wait for selected collaborators to connect, request tasks, and submit results.""" + while not self.collaborator_task_results.is_set(): + len_sel_collabs = len(self.selected_collaborators) + len_connected_collabs = len(self.connected_collaborators) + if len_connected_collabs < len_sel_collabs: + # Waiting for collaborators to connect. + logger.info( + "Waiting for " + + f"{len_sel_collabs - len_connected_collabs}/{len_sel_collabs}" + + " collaborators to connect..." + ) + elif self.tasks_sent_to_collaborators != len_sel_collabs: + logger.info( + "Waiting for " + + f"{len_sel_collabs - self.tasks_sent_to_collaborators}/{len_sel_collabs}" + + " to make requests for tasks..." + ) + else: + # Waiting for selected collaborators to send the results. + logger.info( + "Waiting for " + + f"{len_sel_collabs - self.collaborators_counter}/{len_sel_collabs}" + + " collaborators to send results..." + ) + await asyncio.sleep(Aggregator._get_sleep_time()) + async def run_flow(self) -> FLSpec: """ Start the execution and run flow until completion. @@ -211,61 +264,36 @@ async def run_flow(self) -> FLSpec: Returns: flow (FLSpec): Updated instance. 
""" - # Start function will be the first step if any flow + self._initialize_flow() + # Start function will be the first step of any flow f_name = "start" - # Creating a clones from the flow object - FLSpec._reset_clones() - FLSpec._create_clones(self.flow, self.flow.runtime.collaborators) - logger.info(f"Starting round {self.current_round}...") + while True: + # Perform Aggregator steps if any next_step = self.do_task(f_name) - if self.time_to_quit: logger.info("Experiment Completed.") break - # Prepare queue for collaborator task, with clones - for k, v in self.__collaborator_tasks_queue.items(): - if k in self.selected_collaborators: - v.put((next_step, self.clones_dict[k])) - else: - logger.info(f"Tasks will not be sent to {k}") - - while not self.collaborator_task_results.is_set(): - len_sel_collabs = len(self.selected_collaborators) - len_connected_collabs = len(self.connected_collaborators) - if len_connected_collabs < len_sel_collabs: - # Waiting for collaborators to connect. - logger.info( - "Waiting for " - + f"{len_sel_collabs - len_connected_collabs}/{len_sel_collabs}" - + " collaborators to connect..." - ) - elif self.tasks_sent_to_collaborators != len_sel_collabs: - logger.info( - "Waiting for " - + f"{len_sel_collabs - self.tasks_sent_to_collaborators}/{len_sel_collabs}" - + " to make requests for tasks..." - ) - else: - # Waiting for selected collaborators to send the results. - logger.info( - "Waiting for " - + f"{len_sel_collabs - self.collaborators_counter}/{len_sel_collabs}" - + " collaborators to send results..." 
- ) - await asyncio.sleep(Aggregator._get_sleep_time()) - + self._enqueue_next_step_for_collaborators(next_step) + await self._track_collaborator_status() self.collaborator_task_results.clear() f_name = self.next_step - if hasattr(self, "instance_snapshot"): - self.flow.restore_instance_snapshot(self.flow, list(self.instance_snapshot)) - delattr(self, "instance_snapshot") + self._restore_instance_snapshot() self._update_final_flow() return self.final_flow_state + def extract_flow(self) -> FLSpec: + """Extract the flow object from the aggregator. + + Returns: + FLSpec: The flow object. + """ + self.__delete_private_attrs_from_clone(self.flow) + return self.flow + def call_checkpoint( self, name: str, ctx: Any, f: Callable, stream_buffer: bytes = None ) -> None: @@ -527,78 +555,11 @@ def valid_collaborator_cn_and_id( ) def all_quit_jobs_sent(self) -> bool: - """Assert all quit jobs are sent to collaborators.""" - return set(self.quit_job_sent_to) == set(self.authorized_cols) - + """ + Check whether a quit job has been sent to all authorized collaborators. -the_dragon = """ - - ,@@.@@+@@##@,@@@@.`@@#@+ *@@@@ #@##@ `@@#@# @@@@@ @@ @@@@` #@@@ :@@ `@#`@@@#.@ - @@ #@ ,@ +. @@.@* #@ :` @+*@ .@`+. @@ *@::@`@@ @@# @@ #`;@`.@@ @@@`@`#@* +:@` - @@@@@ ,@@@ @@@@ +@@+ @@@@ .@@@ @@ .@+:@@@: .;+@` @@ ,;,#@` @@ @@@@@ ,@@@* @ - @@ #@ ,@`*. @@.@@ #@ ,; `@+,@#.@.*` @@ ,@::@`@@` @@@@# @@`:@;*@+ @@ @`:@@`@ *@@ ` - .@@`@@,+@+;@.@@ @@`@@;*@ ;@@#@:*@+;@ `@@;@@ #@**@+;@ `@@:`@@@@ @@@@.`@+ .@ +@+@*,@ - `` `` ` `` . ` ` ` ` ` .` ` `` `` `` ` . ` - - - - .** - ;` `****: - @**`******* - *** +***********; - ,@***;` .*:,;************ - ;***********@@*********** - ;************************, - `************************* - ************************* - ,************************ - **#********************* - *@****` :**********; - +**; .********. - ;*; `*******#: `,: - ****@@@++:: ,,;***. - *@@@**;#;: +: **++*, - @***#@@@: +*; ,**** - @*@+**** ***` ****, - ,@#******. , **** **;,**. 
- * ******** :, ;*:*+ ** :,** - # ********:: *,.*:**` * ,*; - . *********: .+,*:;*: : `:** - ; :********: ***::** ` ` ** - + :****::*** , *;;::**` :* - `` .****::;**::: *;::::*; ;* - * *****::***:. **::::** ;: - # *****;:**** ;*::;*** ,*` - ; ************` ,**:****; ::* - : *************;:;*;*++: *. - : *****************;* `* - `. `*****************; : *. - .` .*+************+****;: :* - `. :;+***********+******;` : .,* - ; ::*+*******************. `:: .`:. - + :::**********************;;:` * - + ,::;*************;:::*******. * - # `:::+*************:::;******** :, * - @ :::***************;:;*********;:, * - @ ::::******:*********************: ,:* - @ .:::******:;*********************, :* - # :::******::******###@*******;;**** *, - # .::;*****::*****#****@*****;:::***; `` ** - * ::;***********+*****+#******::*****,,,,** - : :;***********#******#****************** - .` `;***********#******+****+************ - `, ***#**@**+***+*****+**************;` - ; *++**#******#+****+` `.,.. - + `@***#*******#****# - + +***@********+**+: - * .+**+;**;;;**;#**# - ,` ****@ +*+: - # +**+ :+** - @ ;**+, ,***+ - # #@+**** *#****+ - `; @+***+@ `#**+#++ - # #*#@##, .++:.,# - `* @# +. - @@@ - # `@ - , """ + Returns: + bool: True if quit jobs have been sent to all authorized collaborators, + False otherwise. 
+ """ + return set(self.quit_job_sent_to) == set(self.authorized_cols) diff --git a/openfl/experimental/workflow/component/director/director.py b/openfl/experimental/workflow/component/director/director.py index eecf0855dd..775e330683 100644 --- a/openfl/experimental/workflow/component/director/director.py +++ b/openfl/experimental/workflow/component/director/director.py @@ -9,13 +9,12 @@ import time from collections import defaultdict from pathlib import Path -from typing import Any, AsyncGenerator, Dict, Iterable, Optional, Tuple, Union - -import dill +from typing import Any, AsyncGenerator, Dict, Iterable, Optional, Union from openfl.experimental.workflow.component.director.experiment import ( Experiment, ExperimentsRegistry, + Status, ) from openfl.experimental.workflow.transport.grpc.exceptions import EnvoyNotFoundError @@ -37,7 +36,7 @@ class Director: _flow_status (Queue): Stores the flow status experiments_registry (ExperimentsRegistry): An object of ExperimentsRegistry to store the experiments. - col_exp (dict): A dictionary to store the experiments for + collaborator_experiments (dict): A dictionary to store the experiments for collaborators. col_exp_queues (defaultdict): A defaultdict to store the experiment queues for collaborators. 
@@ -84,7 +83,7 @@ def __init__( self._flow_status = asyncio.Queue() self.experiments_registry = ExperimentsRegistry() - self.col_exp = {} + self.collaborator_experiments = {} self.col_exp_queues = defaultdict(asyncio.Queue) self._envoy_registry = {} self.envoy_health_check_period = envoy_health_check_period @@ -96,6 +95,7 @@ async def start_experiment_execution_loop(self) -> None: loop = asyncio.get_event_loop() while True: try: + logger.info("Waiting for an experiment to run...") async with self.experiments_registry.get_next_experiment() as experiment: await self._wait_for_authorized_envoys() run_aggregator_future = loop.create_task( @@ -115,6 +115,11 @@ async def start_experiment_execution_loop(self) -> None: # Wait for the experiment to complete and save the result flow_status = await run_aggregator_future await self._flow_status.put(flow_status) + # Mark all envoys' experiment states as None, + # indicating no active experiment + self.collaborator_experiments = dict.fromkeys( + self.collaborator_experiments, None + ) except Exception as e: logger.error(f"Error while executing experiment: {e}") raise @@ -131,16 +136,15 @@ async def _wait_for_authorized_envoys(self) -> None: ) await asyncio.sleep(10) - async def get_flow_state(self) -> Tuple[bool, bytes]: + async def get_flow_state(self) -> Dict[str, Union[bool, Optional[Any], Optional[str]]]: """Wait until the experiment flow status indicates completion - and return the status along with a serialized FLSpec object. + and return the flow status. Returns: - status (bool): The flow status. - flspec_obj (bytes): A serialized FLSpec object (in bytes) using dill. + dict: A dictionary containing the flow status. """ - status, flspec_obj = await self._flow_status.get() - return status, dill.dumps(flspec_obj) + status = await self._flow_status.get() + return status async def wait_experiment(self, envoy_name: str) -> str: """Waits for an experiment to be ready for a given envoy. 
@@ -151,17 +155,17 @@ async def wait_experiment(self, envoy_name: str) -> str: Returns: str: The name of the experiment on the queue. """ - experiment_name = self.col_exp.get(envoy_name) + experiment_name = self.collaborator_experiments.get(envoy_name) # If any envoy gets disconnected if experiment_name and experiment_name in self.experiments_registry: experiment = self.experiments_registry[experiment_name] if experiment.aggregator.current_round < experiment.aggregator.rounds_to_train: return experiment_name - self.col_exp[envoy_name] = None + self.collaborator_experiments[envoy_name] = None queue = self.col_exp_queues[envoy_name] experiment_name = await queue.get() - self.col_exp[envoy_name] = experiment_name + self.collaborator_experiments[envoy_name] = experiment_name return experiment_name @@ -220,7 +224,11 @@ async def stream_experiment_stdout( f'No experiment name "{experiment_name}" in experiments list, or caller "{caller}"' f" does not have access to this experiment" ) - while not self.experiments_registry[experiment_name].aggregator: + experiment = self.experiments_registry[experiment_name] + while not experiment.aggregator: + if experiment.experiment_status.status == Status.FAILED: + # Exit early if the experiment failed to start + return await asyncio.sleep(5) aggregator = self.experiments_registry[experiment_name].aggregator while True: @@ -277,7 +285,7 @@ def get_envoys(self) -> Dict[str, Any]: envoy["is_online"] = time.time() < envoy.get("last_updated", 0) + envoy.get( "valid_duration", 0 ) - envoy["experiment_name"] = self.col_exp.get(envoy["name"], "None") + envoy["experiment_name"] = self.collaborator_experiments.get(envoy["name"], "None") return self._envoy_registry diff --git a/openfl/experimental/workflow/component/director/experiment.py b/openfl/experimental/workflow/component/director/experiment.py index a0c79cdab0..b0fb889550 100644 --- a/openfl/experimental/workflow/component/director/experiment.py +++ 
b/openfl/experimental/workflow/component/director/experiment.py @@ -6,10 +6,12 @@ import asyncio import logging +import traceback from contextlib import asynccontextmanager +from dataclasses import dataclass from enum import Enum, auto from pathlib import Path -from typing import Any, Iterable, List, Optional, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Union from openfl.experimental.workflow.federated import Plan from openfl.experimental.workflow.transport import AggregatorGRPCServer @@ -28,6 +30,51 @@ class Status(Enum): REJECTED = auto() +@dataclass +class ExperimentStatus: + """ + A class to track the status and exceptions of an experiment run. + + Attributes: + status (Status): The current running status of the experiment. + updated_flow (FLSpec): The updated flow object associated with the experiment. + exception (str, optional): Any exception that occurred during the experiment. + + """ + + status: Status = Status.PENDING + updated_flow: Optional[Any] = None + exception: Optional[str] = None + + def update_experiment_status( + self, + status: Status, + updated_flow: Optional[Any] = None, + exception: Optional[str] = None, + ) -> None: + """ + A method to update the experiment status and associated details. + """ + self.status = status + if updated_flow: + self.updated_flow = updated_flow + if exception: + self.exception = exception + + def get_status(self) -> Dict[str, Union[bool, Optional[Any], Optional[str]]]: + """ + Get the status of the experiment. + + Returns: + dict: The status of the experiment. + """ + return { + "status": self.status == Status.FINISHED, + "updated_flow": self.updated_flow, + "exception": self.exception, + } + + class Experiment: """Experiment class. @@ -41,6 +88,8 @@ class Experiment: plan_path (Union[Path, str]): The path to the plan. users (Iterable[str]): The list of users. status (str): The status of the experiment. + aggregator_grpc_server (AggregatorGRPCServer): The gRPC server + for the aggregator. 
aggregator (Aggregator): The aggregator instance. updated_flow (FLSpec): Updated flow instance. """ @@ -76,10 +125,75 @@ def __init__( # experiment workspace provided by the director self.plan_path = Path(plan_path) self.users = set() if users is None else set(users) - self.status = Status.PENDING + self.experiment_status = ExperimentStatus() + self._aggregator_grpc_server = None self.aggregator = None self.updated_flow = None + def _initialize_aggregator_server( + self, + tls: bool, + root_certificate: Optional[Union[Path, str]], + private_key: Optional[Union[Path, str]], + certificate: Optional[Union[Path, str]], + director_config: Path, + ) -> bool: + """Initialize the aggregator server. + + Args: + tls (bool): A flag indicating if TLS should be used for + connections. + root_certificate (Optional[Union[Path, str]]): The path to the + root certificate for TLS. May be None. + private_key (Optional[Union[Path, str]]): The path to the private + key for TLS. May be None. + certificate (Optional[Union[Path, str]]): The path to the + certificate for TLS. May be None. + director_config (Path): Path to director's config file. + Required; this helper defines no default. + + Returns: + bool: True if the server was created successfully, False otherwise. 
+ """ + try: + self._aggregator_grpc_server = self._create_aggregator_grpc_server( + tls=tls, + root_certificate=root_certificate, + private_key=private_key, + certificate=certificate, + director_config=director_config, + ) + self.aggregator = self._aggregator_grpc_server.aggregator + return True + except Exception: + exception_trace = traceback.format_exc() + logger.error(f"Failed to create aggregator server: {exception_trace}") + self.experiment_status.update_experiment_status( + Status.FAILED, + exception=exception_trace, + ) + return False + + async def _run_experiment_flow(self) -> None: + """Run the experiment flow and the aggregator gRPC server.""" + _, self.updated_flow = await asyncio.gather( + self._run_aggregator_grpc_server(self._aggregator_grpc_server), + self.aggregator.run_flow(), + ) + + def _handle_experiment_failure(self) -> None: + """Handle experiment failure and update status.""" + exception_trace = traceback.format_exc() + self.experiment_status.update_experiment_status( + Status.FAILED, + updated_flow=self.aggregator.extract_flow(), + exception=exception_trace, + ) + # Mark quit jobs as sent to all collaborators to allow the + # aggregator gRPC server to shut down + self.aggregator.quit_job_sent_to = self.collaborators + logger.error(f"Experiment {self.name} failed with error: {exception_trace}") + async def start( self, *, @@ -89,7 +203,7 @@ async def start( certificate: Optional[Union[Path, str]] = None, director_config: Path = None, install_requirements: bool = True, - ) -> Tuple[bool, Any]: + ) -> Dict[str, Union[bool, Optional[Any], Optional[str]]]: """Run experiment. Args: @@ -106,41 +220,32 @@ async def start( requirements should be installed. Defaults to True. Returns: - List[Union[bool, Any]]: - - status: status of the experiment. - - updated_flow: The updated flow object. + dict: A dictionary containing: + - status (bool): True if the experiment finished successfully. + - updated_flow (FLSpec): The updated flow object. 
+ - exception (str or None): Formatted traceback if any exception occurred. """ - self.status = Status.IN_PROGRESS + self.experiment_status.update_experiment_status(Status.IN_PROGRESS) + logger.info(f"New experiment {self.name} for collaborators {self.collaborators}") try: - logger.info(f"New experiment {self.name} for collaborators {self.collaborators}") - with ExperimentWorkspace( experiment_name=self.name, data_file_path=self.archive_path, install_requirements=install_requirements, ): - aggregator_grpc_server = self._create_aggregator_grpc_server( - tls=tls, - root_certificate=root_certificate, - private_key=private_key, - certificate=certificate, - director_config=director_config, - ) - self.aggregator = aggregator_grpc_server.aggregator - _, self.updated_flow = await asyncio.gather( - self._run_aggregator_grpc_server( - aggregator_grpc_server, - ), - self.aggregator.run_flow(), - ) - self.status = Status.FINISHED - logger.info("Experiment %s was finished successfully.", self.name) - except Exception as e: - self.status = Status.FAILED - logger.error("Experiment %s failed with error: %s.", self.name, e) - raise - - return self.status == Status.FINISHED, self.updated_flow + if self._initialize_aggregator_server( + tls, root_certificate, private_key, certificate, director_config + ): + await self._run_experiment_flow() + self.experiment_status.update_experiment_status( + Status.FINISHED, + updated_flow=self.updated_flow, + ) + logger.info("Experiment %s was finished successfully.", self.name) + except Exception: + self._handle_experiment_failure() + + return self.experiment_status.get_status() def _create_aggregator_grpc_server( self, diff --git a/openfl/experimental/workflow/component/envoy/envoy.py b/openfl/experimental/workflow/component/envoy/envoy.py index 1d27d48c79..d1a1ea7e39 100644 --- a/openfl/experimental/workflow/component/envoy/envoy.py +++ b/openfl/experimental/workflow/component/envoy/envoy.py @@ -138,12 +138,11 @@ def _run(self) -> None: # Wait 
for experiment from Director server experiment_name = self._envoy_dir_client.wait_experiment() data_stream = self._envoy_dir_client.get_experiment_data(experiment_name) + data_file_path = self._save_data_stream_to_file(data_stream) except Exception as exc: logger.exception("Failed to get experiment: %s", exc) time.sleep(self.DEFAULT_RETRY_TIMEOUT_IN_SECONDS) continue - data_file_path = self._save_data_stream_to_file(data_stream) - try: with ExperimentWorkspace( experiment_name=f"{self.name}_{experiment_name}", @@ -154,6 +153,7 @@ def _run(self) -> None: self._run_collaborator() except Exception as exc: logger.exception("Collaborator failed with error: %s:", exc) + continue finally: self.is_experiment_running = False diff --git a/openfl/experimental/workflow/interface/fl_spec.py b/openfl/experimental/workflow/interface/fl_spec.py index ac25c5a692..6dfaea1055 100644 --- a/openfl/experimental/workflow/interface/fl_spec.py +++ b/openfl/experimental/workflow/interface/fl_spec.py @@ -184,6 +184,7 @@ def _setup_initial_state(self) -> None: def _run_federated(self) -> None: """Executes the flow using FederatedRuntime.""" try: + exp_name = None # Prepare workspace and submit it for the FederatedRuntime archive_path, exp_name = self.runtime.prepare_workspace_archive() self.runtime.submit_experiment(archive_path, exp_name) @@ -193,11 +194,15 @@ def _run_federated(self) -> None: # Retrieve the flspec object to update the experiment state flspec_obj = self._get_flow_state() # Update state of self - self._update_from_flspec_obj(flspec_obj) + if flspec_obj: + self._update_from_flspec_obj(flspec_obj) except Exception as e: - raise Exception( - f"FederatedRuntime: Experiment {exp_name} failed to run due to error: {e}" + error_msg = ( + "FederatedRuntime: Failed to prepare workspace archive" + if exp_name is None + else f"FederatedRuntime: Experiment {exp_name} failed" ) + raise Exception(f"{error_msg} due to error: {e}") from e def _update_from_flspec_obj(self, flspec_obj: FLSpec) 
-> None: """Update self with attributes from the updated flspec instance. @@ -219,13 +224,15 @@ def _get_flow_state(self) -> Union[FLSpec, None]: flspec_obj (Union[FLSpec, None]): An updated FLSpec instance if the experiment runs successfully. None if the experiment could not run. """ - status, flspec_obj = self.runtime.get_flow_state() + status, flspec_obj, exception = self.runtime.get_flow_state() if status: - print("Experiment ran successfully") - return flspec_obj + print("\033[92mExperiment ran successfully\033[0m") else: - print("Experiment could not run") - return None + print( + "\033[91m Experiment could not run due to error:\033[0m", + f"\033[91m{exception}\033[0m", + ) + return flspec_obj def _capture_instance_snapshot(self, kwargs) -> List: """Takes backup of self before exclude or include filtering. diff --git a/openfl/experimental/workflow/protocols/director.proto b/openfl/experimental/workflow/protocols/director.proto index d8e8f9c692..dcb3361f4f 100644 --- a/openfl/experimental/workflow/protocols/director.proto +++ b/openfl/experimental/workflow/protocols/director.proto @@ -89,6 +89,7 @@ message GetFlowStateRequest {} message GetFlowStateResponse { bool completed = 1; bytes flspec_obj = 2; + string exception = 3; } message SendRuntimeRequest {} diff --git a/openfl/experimental/workflow/runtime/federated_runtime.py b/openfl/experimental/workflow/runtime/federated_runtime.py index 4bda382765..100d149487 100644 --- a/openfl/experimental/workflow/runtime/federated_runtime.py +++ b/openfl/experimental/workflow/runtime/federated_runtime.py @@ -179,22 +179,23 @@ def submit_experiment(self, archive_path, exp_name) -> None: finally: self.remove_workspace_archive(archive_path) - def get_flow_state(self) -> Tuple[bool, Any]: + def get_flow_state(self) -> Tuple[bool, Any, str]: """ Retrieve the updated flow status and deserialized flow object. Returns: status (bool): The flow status. flow_object: The deserialized flow object. 
+ exception (str): Exception message if any. """ - status, flspec_obj = self._runtime_dir_client.get_flow_state() + status, flspec_obj, exception = self._runtime_dir_client.get_flow_state() # Append generated workspace path to sys.path # to allow unpickling of flspec_obj sys.path.append(str(self.generated_workspace_path)) flow_object = dill.loads(flspec_obj) - return status, flow_object + return status, flow_object, exception def get_envoys(self) -> List[str]: """ diff --git a/openfl/experimental/workflow/transport/grpc/director_client.py b/openfl/experimental/workflow/transport/grpc/director_client.py index 4ebe4a636d..5e6441ea8d 100644 --- a/openfl/experimental/workflow/transport/grpc/director_client.py +++ b/openfl/experimental/workflow/transport/grpc/director_client.py @@ -316,7 +316,7 @@ def get_flow_state(self) -> Tuple: """ response_stream = self.stub.GetFlowState(director_pb2.GetFlowStateRequest()) response = datastream_to_proto(director_pb2.GetFlowStateResponse(), response_stream) - return response.completed, response.flspec_obj + return response.completed, response.flspec_obj, response.exception def stream_experiment_stdout(self, experiment_name) -> Iterator[Dict[str, Any]]: """Stream experiment stdout RPC. diff --git a/openfl/experimental/workflow/transport/grpc/director_server.py b/openfl/experimental/workflow/transport/grpc/director_server.py index b5de3fe4b6..74aed763fc 100644 --- a/openfl/experimental/workflow/transport/grpc/director_server.py +++ b/openfl/experimental/workflow/transport/grpc/director_server.py @@ -9,6 +9,7 @@ from pathlib import Path from typing import AsyncIterator, Optional, Union +import dill import grpc from grpc import aio, ssl_server_credentials @@ -250,14 +251,20 @@ async def GetExperimentData( Yields: director_pb2.ExperimentData: The experiment data. 
""" - data_file_path = self.director.get_experiment_data(request.experiment_name) - max_buffer_size = 2 * 1024 * 1024 - with open(data_file_path, "rb") as df: - while True: - data = df.read(max_buffer_size) - if len(data) == 0: - break - yield director_pb2.ExperimentData(size=len(data), exp_data=data) + try: + data_file_path = self.director.get_experiment_data(request.experiment_name) + max_buffer_size = 2 * 1024 * 1024 + with open(data_file_path, "rb") as df: + while True: + data = df.read(max_buffer_size) + if not data: + break + yield director_pb2.ExperimentData(size=len(data), exp_data=data) + except Exception as e: + await context.abort( + grpc.StatusCode.INTERNAL, + f"Failed to stream experiment data: {type(e).__name__}: {e}", + ) async def WaitExperiment(self, request, context) -> director_pb2.WaitExperimentResponse: """Handles a request to wait for an experiment to be ready. @@ -325,10 +332,11 @@ async def GetFlowState(self, request, context) -> director_pb2.GetFlowStateRespo Returns: director_pb2.GetFlowStateResponse: The response to the request. """ - status, flspec_obj = await self.director.get_flow_state() + status = await self.director.get_flow_state() response = director_pb2.GetFlowStateResponse( - completed=status, - flspec_obj=flspec_obj, + completed=status["status"], + flspec_obj=dill.dumps(status["updated_flow"]), + exception=status["exception"], ) for chunk in proto_to_datastream(response): await context.write(chunk)