Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 80 additions & 119 deletions openfl/experimental/workflow/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,38 @@ def __delete_private_attrs_from_clone(self, clone: Any, replace_str: str = None)
def _log_big_warning(self) -> None:
    """Warn user about single collaborator cert mode.

    Single collaborator cert mode bypasses proper PKI and must never be
    used outside development settings; this emits a loud warning once.
    """
    logger.warning(
        "YOU ARE RUNNING IN SINGLE COLLABORATOR CERT MODE! THIS IS"
        " NOT PROPER PKI AND "
        "SHOULD ONLY BE USED IN DEVELOPMENT SETTINGS!!!! YE HAVE BEEN"
        " WARNED!!!"
    )

def _initialize_flow(self) -> None:
    """Prepare the flow for execution.

    Discards any clones left over from a previous run, then creates a
    fresh clone per collaborator registered on the flow's runtime.
    """
    FLSpec._reset_clones()
    runtime_collaborators = self.flow.runtime.collaborators
    FLSpec._create_clones(self.flow, runtime_collaborators)

def _enqueue_next_step_for_collaborators(self, next_step) -> None:
"""Enqueue the next step and associated clone for each selected collaborator.

Args:
next_step (str): Next step to be executed by collaborators
"""
for collaborator, task_queue in self.__collaborator_tasks_queue.items():
if collaborator in self.selected_collaborators:
task_queue.put((next_step, self.clones_dict[collaborator]))
else:
logger.info(
f"Skipping task dispatch for collaborator '{collaborator}' "
f"as it is not part of selected_collaborators."
)

def _restore_instance_snapshot(self) -> None:
    """Restore the FLSpec state at the aggregator from a saved instance snapshot.

    No-op when no snapshot was captured for the current round. The snapshot
    attribute is removed after a successful restore so a stale snapshot is
    never replayed on a later round.
    """
    # NOTE(review): FLSpec.restore_instance_snapshot reportedly guards
    # against a missing backup itself; the hasattr check here additionally
    # avoids touching `instance_snapshot` when it was never captured.
    if hasattr(self, "instance_snapshot"):
        self.flow.restore_instance_snapshot(self.flow, list(self.instance_snapshot))
        delattr(self, "instance_snapshot")

def _update_final_flow(self) -> None:
"""Update the final flow state with current flow artifacts."""
artifacts_iter, _ = generate_artifacts(ctx=self.flow)
Expand All @@ -203,6 +229,33 @@ def _get_sleep_time() -> int:
"""
return 10

async def _track_collaborator_status(self) -> None:
    """Wait for selected collaborators to connect, request tasks, and submit results.

    Polls until `collaborator_task_results` is set, logging which phase of
    the round is still pending: connection, task request, or result upload.
    Sleeps between polls to avoid busy-waiting.
    """
    while not self.collaborator_task_results.is_set():
        len_sel_collabs = len(self.selected_collaborators)
        len_connected_collabs = len(self.connected_collaborators)
        if len_connected_collabs < len_sel_collabs:
            # Waiting for collaborators to connect.
            logger.info(
                "Waiting for "
                + f"{len_sel_collabs - len_connected_collabs}/{len_sel_collabs}"
                + " collaborators to connect..."
            )
        elif self.tasks_sent_to_collaborators != len_sel_collabs:
            # Connected, but not all have requested their tasks yet.
            logger.info(
                "Waiting for "
                + f"{len_sel_collabs - self.tasks_sent_to_collaborators}/{len_sel_collabs}"
                + " to make requests for tasks..."
            )
        else:
            # Waiting for selected collaborators to send the results.
            logger.info(
                "Waiting for "
                + f"{len_sel_collabs - self.collaborators_counter}/{len_sel_collabs}"
                + " collaborators to send results..."
            )
        await asyncio.sleep(Aggregator._get_sleep_time())

async def run_flow(self) -> FLSpec:
    """
    Start the execution and run flow until completion.

    Returns:
        flow (FLSpec): Updated instance.
    """
    self._initialize_flow()
    # Start function will be the first step of any flow
    f_name = "start"

    logger.info(f"Starting round {self.current_round}...")

    while True:
        # Perform Aggregator steps if any; returns the next step to dispatch.
        next_step = self.do_task(f_name)

        if self.time_to_quit:
            logger.info("Experiment Completed.")
            break

        # Queue the next step (with clones) for the selected collaborators,
        # then block until they have all reported results.
        self._enqueue_next_step_for_collaborators(next_step)
        await self._track_collaborator_status()
        # Reset the event so the next round waits again.
        self.collaborator_task_results.clear()
        f_name = self.next_step
        self._restore_instance_snapshot()

    self._update_final_flow()
    return self.final_flow_state

def extract_flow(self) -> FLSpec:
    """Return the aggregator's flow object, stripped of private attributes.

    Returns:
        FLSpec: The flow object.
    """
    flow = self.flow
    self.__delete_private_attrs_from_clone(flow)
    return flow

def call_checkpoint(
self, name: str, ctx: Any, f: Callable, stream_buffer: bytes = None
) -> None:
Expand Down Expand Up @@ -527,78 +555,11 @@ def valid_collaborator_cn_and_id(
)

def all_quit_jobs_sent(self) -> bool:
    """
    Check whether a quit job has been sent to all authorized collaborators.

    Returns:
        bool: True if quit jobs have been sent to all authorized collaborators,
            False otherwise.
    """
    # Set comparison: order of dispatch does not matter, only coverage.
    return set(self.quit_job_sent_to) == set(self.authorized_cols)
40 changes: 24 additions & 16 deletions openfl/experimental/workflow/component/director/director.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import time
from collections import defaultdict
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, Iterable, Optional, Tuple, Union

import dill
from typing import Any, AsyncGenerator, Dict, Iterable, Optional, Union

from openfl.experimental.workflow.component.director.experiment import (
Experiment,
ExperimentsRegistry,
Status,
)
from openfl.experimental.workflow.transport.grpc.exceptions import EnvoyNotFoundError

Expand All @@ -37,7 +36,7 @@ class Director:
_flow_status (Queue): Stores the flow status
experiments_registry (ExperimentsRegistry): An object of
ExperimentsRegistry to store the experiments.
col_exp (dict): A dictionary to store the experiments for
collaborator_experiments (dict): A dictionary to store the experiments for
collaborators.
col_exp_queues (defaultdict): A defaultdict to store the experiment
queues for collaborators.
Expand Down Expand Up @@ -84,7 +83,7 @@ def __init__(
self._flow_status = asyncio.Queue()

self.experiments_registry = ExperimentsRegistry()
self.col_exp = {}
self.collaborator_experiments = {}
self.col_exp_queues = defaultdict(asyncio.Queue)
self._envoy_registry = {}
self.envoy_health_check_period = envoy_health_check_period
Expand All @@ -96,6 +95,7 @@ async def start_experiment_execution_loop(self) -> None:
loop = asyncio.get_event_loop()
while True:
try:
logger.info("Waiting for an experiment to run...")
async with self.experiments_registry.get_next_experiment() as experiment:
await self._wait_for_authorized_envoys()
run_aggregator_future = loop.create_task(
Expand All @@ -115,6 +115,11 @@ async def start_experiment_execution_loop(self) -> None:
# Wait for the experiment to complete and save the result
flow_status = await run_aggregator_future
await self._flow_status.put(flow_status)
# Mark all envoys' experiment states as None,
# indicating no active experiment
self.collaborator_experiments = dict.fromkeys(
self.collaborator_experiments, None
)
except Exception as e:
logger.error(f"Error while executing experiment: {e}")
raise
Expand All @@ -131,16 +136,15 @@ async def _wait_for_authorized_envoys(self) -> None:
)
await asyncio.sleep(10)

async def get_flow_state(self) -> Dict[str, Union[bool, Optional[Any], Optional[str]]]:
    """Wait until the experiment flow status indicates completion
    and return the flow status.

    Returns:
        dict: A dictionary containing the flow status.
    """
    # Blocks until the aggregator task puts a status on the queue.
    status = await self._flow_status.get()
    return status

async def wait_experiment(self, envoy_name: str) -> str:
    """Waits for an experiment to be ready for a given envoy.

    Args:
        envoy_name (str): The name of the envoy.

    Returns:
        str: The name of the experiment on the queue.
    """
    experiment_name = self.collaborator_experiments.get(envoy_name)
    # If any envoy gets disconnected mid-experiment, resume the same
    # experiment rather than pulling a new one off the queue.
    if experiment_name and experiment_name in self.experiments_registry:
        experiment = self.experiments_registry[experiment_name]
        if experiment.aggregator.current_round < experiment.aggregator.rounds_to_train:
            return experiment_name

    # Clear the stale assignment, then block for the next experiment.
    self.collaborator_experiments[envoy_name] = None
    queue = self.col_exp_queues[envoy_name]
    experiment_name = await queue.get()
    self.collaborator_experiments[envoy_name] = experiment_name

    return experiment_name

Expand Down Expand Up @@ -220,7 +224,11 @@ async def stream_experiment_stdout(
f'No experiment name "{experiment_name}" in experiments list, or caller "{caller}"'
f" does not have access to this experiment"
)
while not self.experiments_registry[experiment_name].aggregator:
experiment = self.experiments_registry[experiment_name]
while not experiment.aggregator:
if experiment.experiment_status.status == Status.FAILED:
# Exit early if the experiment failed to start
return
Comment on lines +227 to +231
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code is not very intuitive. I think it is checking for the condition where the aggregator fails to start — say, due to issues in the imported experiment.py file — right? Can we evaluate some other way to do this?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into it

await asyncio.sleep(5)
aggregator = self.experiments_registry[experiment_name].aggregator
while True:
Expand Down Expand Up @@ -277,7 +285,7 @@ def get_envoys(self) -> Dict[str, Any]:
envoy["is_online"] = time.time() < envoy.get("last_updated", 0) + envoy.get(
"valid_duration", 0
)
envoy["experiment_name"] = self.col_exp.get(envoy["name"], "None")
envoy["experiment_name"] = self.collaborator_experiments.get(envoy["name"], "None")

return self._envoy_registry

Expand Down
Loading
Loading