Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions openfl/experimental/workflow/interface/participants.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def get_name(self) -> str:
"""
return self._name

def get_state(self) -> Dict[str, Any]:
"""Returns the state of the participant.

Returns:
Dict[str, Any]: The state of the participant.
"""
return self.__dict__
Copy link

Copilot AI May 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Returning __dict__ exposes all internal attributes, including private ones. Consider returning only the necessary state fields to avoid leaking internal implementation details.

Suggested change
return self.__dict__
return {
"private_attributes": self.private_attributes,
"name": self._name
}

Copilot uses AI. Check for mistakes.

def initialize_private_attributes(self, private_attrs: Dict[Any, Any] = None) -> None:
"""Initialize private attributes of Participant (aggregator or collaborator)
by invoking the callable specified by user."""
Expand Down
25 changes: 20 additions & 5 deletions openfl/experimental/workflow/runtime/local_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,7 +385,6 @@ def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any:
ResourcesNotAvailableError: If the requested resources exceed the
available resources.
"""

if aggregator.private_attributes and aggregator.private_attributes_callable:
self.logger.warning(
"Warning: Aggregator private attributes "
Expand All @@ -397,6 +396,10 @@ def __get_aggregator_object(self, aggregator: Type[Aggregator]) -> Any:
if self.backend == "single_process":
return aggregator

# Store aggregator reference to later sync the full internal state
# from the remote
self.__aggregator_reference = aggregator

total_available_cpus = os.cpu_count()
total_available_gpus = get_number_of_gpus()

Expand Down Expand Up @@ -474,6 +477,10 @@ def __get_collaborator_object(self, collaborators: List) -> Any:
if self.backend == "single_process":
return collaborators

# Store collaborators references to later sync the full internal state
# from the remote
self.__collaborators_reference = collaborators

total_available_cpus = os.cpu_count()
total_required_cpus = sum([collaborator.num_cpus for collaborator in collaborators])
if total_available_cpus < total_required_cpus:
Expand All @@ -482,9 +489,8 @@ def __get_collaborator_object(self, collaborators: List) -> Any:
({total_required_cpus} < {total_available_cpus})."
)

if self.backend == "ray":
collaborator_ray_refs = ray_group_assign(collaborators, num_actors=self.num_actors)
return collaborator_ray_refs
collaborator_ray_refs = ray_group_assign(collaborators, num_actors=self.num_actors)
return collaborator_ray_refs

@property
def aggregator(self) -> str:
Expand Down Expand Up @@ -536,6 +542,14 @@ def get_collab_name(collab):
get_collab_name(collaborator): collaborator for collaborator in collaborators
}

def _sync_participants_state(self) -> None:
"""Update local aggregator and collaborator references with remote states.
(Ray backend).
"""
self.__aggregator_reference.__dict__.update(ray.get(self._aggregator.get_state.remote()))
for idx, collab in enumerate(self.__collaborators.values()):
self.__collaborators_reference[idx].__dict__.update(ray.get(collab.get_state.remote()))

def get_collaborator_kwargs(self, collaborator_name: str):
"""Returns kwargs of collaborator.

Expand Down Expand Up @@ -659,7 +673,8 @@ def execute_task(self, flspec_obj: Type[FLSpec], f: Callable, **kwargs):
f, parent_func, instance_snapshot, kwargs = flspec_obj.execute_task_args
else:
flspec_obj = self.execute_agg_task(flspec_obj, f)

if self.backend == "ray":
self._sync_participants_state()
artifacts_iter, _ = generate_artifacts(ctx=flspec_obj)
return artifacts_iter()

Expand Down
21 changes: 21 additions & 0 deletions tests/end_to_end/test_suites/wf_local_func_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from tests.end_to_end.workflow.private_attr_wo_callable import TestFlowPrivateAttributesWoCallable
from tests.end_to_end.workflow.private_attributes_flow import TestFlowPrivateAttributes
from tests.end_to_end.workflow.private_attr_both import TestFlowPrivateAttributesBoth
from tests.end_to_end.workflow.dynamic_private_attr_sync import TestFlowDynamicPrivateAttributeSync

from tests.end_to_end.utils import wf_helper as wf_helper

Expand Down Expand Up @@ -220,3 +221,23 @@ def test_private_attr_both(request, fx_local_federated_workflow_prvt_attr):
log.info(f"Starting round {i}...")
flflow.run()
log.info("Successfully ended test_private_attr_both")

@pytest.mark.parametrize("fx_local_federated_workflow", [("init_mock_pvt_attr", None, "init_mock_pvt_attr")], indirect=True)
def test_dynamic_private_attr_sync(request, fx_local_federated_workflow):
"""
Set private attribute through callable function and direct assignment
"""
log.info("Starting test_dynamic_private_attr_sync")
flflow = TestFlowDynamicPrivateAttributeSync(checkpoint=True)
flflow.runtime = fx_local_federated_workflow.runtime
flflow.run()
test_attribute_sets = wf_helper.get_test_attribute_sets()
wf_helper.check_modified_private_attributes(
fx_local_federated_workflow.aggregator,
test_attribute_sets[fx_local_federated_workflow.aggregator.name]
)
for collab in fx_local_federated_workflow.collaborators:
wf_helper.check_modified_private_attributes(
collab, test_attribute_sets[collab.name]
)
log.info("Successfully ended test_dynamic_private_attr_sync")
3 changes: 2 additions & 1 deletion tests/end_to_end/utils/wf_common_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def fx_local_federated_workflow(request):
init_collaborator_private_attr_index,
init_collaborator_private_attr_name,
init_collaborate_pvt_attr_np,
init_agg_pvt_attr_np
init_agg_pvt_attr_np,
init_mock_pvt_attr
)
collab_callback_func = request.param[0] if hasattr(request, 'param') and request.param else None
collab_value = request.param[1] if hasattr(request, 'param') and request.param else None
Expand Down
139 changes: 139 additions & 0 deletions tests/end_to_end/utils/wf_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@

from metaflow import Flow
import logging
import torch
from torch.utils.data import DataLoader
import torchvision
import datetime
import numpy as np
from typing import Dict, Any

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -56,6 +61,128 @@ def validate_flow(flow_obj, expected_flow_steps):
return steps_present_in_cli, missing_steps_in_cli, extra_steps_in_cli


def get_test_attribute_sets() -> Dict[str, Dict[str, Any]]:
"""
Generates a test dictionary of private attributes for multiple entities, including various
data types.

Returns:
Dict[str, Dict[str, Any]]: A dictionary where each key is an entity name
(e.g., 'Aggregator', 'Paris') and the value is another dictionary of mock
private attributes using a variety of data types.
"""
torch.backends.cudnn.enabled = False
torch.manual_seed(1)

transform = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,), (0.3081,))]
)

train_loader = DataLoader(
torchvision.datasets.MNIST("files/", train=True, download=True, transform=transform),
batch_size=128,
shuffle=True,
)

test_loader = DataLoader(
torchvision.datasets.MNIST("files/", train=False, download=True, transform=transform),
batch_size=128,
shuffle=True,
)

return {
"agg": {
"private_attribute_1": train_loader,
"private_attribute_2": test_loader,
"private_attribute_3": 3.14,
"private_attribute_4": np.array([1, 2, 3]),
},
"collaborator0": {
"private_attribute_1": True,
"private_attribute_2": [1, 2, 3],
"private_attribute_3": {"a": 1},
"private_attribute_4": None,
},
"collaborator1": {
"private_attribute_1": (4, 5),
"private_attribute_2": b"bytes",
"private_attribute_3": complex(1, 2),
"private_attribute_4": np.int64(10),
},
"collaborator2": {
"private_attribute_1": {1, 2, 3},
"private_attribute_2": frozenset([4, 5]),
"private_attribute_3": range(5),
"private_attribute_4": np.float32(5.5),
},
"collaborator3": {
"private_attribute_1": bytearray(b"abc"),
"private_attribute_2": memoryview(b"xyz"),
"private_attribute_3": slice(1, 5, 2),
"private_attribute_4": np.bool_(False),
},
"collaborator4": {
"private_attribute_1": NotImplemented,
"private_attribute_2": Ellipsis,
"private_attribute_3": memoryview(bytearray(b"test")),
"private_attribute_4": np.complex64(3 + 4j),
},
"collaborator5": {
"private_attribute_1": set(),
"private_attribute_2": type,
"private_attribute_3": super,
"private_attribute_4": datetime.datetime(2023, 1, 1),
},
}


def dataloader_equal(dl1, dl2):
"""Check if two DataLoader objects are equal.
Args:
dl1 (torch.utils.data.DataLoader): First DataLoader object.
dl2 (torch.utils.data.DataLoader): Second DataLoader object.
"""
return (
isinstance(dl1, torch.utils.data.DataLoader)
and isinstance(dl2, torch.utils.data.DataLoader)
and dl1.batch_size == dl2.batch_size
and type(dl1.dataset) is type(dl2.dataset)
and isinstance(dl1.sampler, type(dl2.sampler))
)

def check_modified_private_attributes(participant, expected_attributes) -> None:
"""Check if the participant's private_attributes match the expected values.
Args:
participant (Participant): The participant (aggregator or collaborator) to check.
expected_attributes (dict): The expected private attributes.
"""
actual_attributes = participant.private_attributes
mismatches = []

for key, expected_value in expected_attributes.items():
actual_value = actual_attributes.get(key, "<Missing>")
if isinstance(expected_value, np.ndarray) and isinstance(actual_value, np.ndarray):
equal = np.array_equal(actual_value, expected_value)
elif isinstance(expected_value, torch.utils.data.DataLoader):
equal = dataloader_equal(actual_value, expected_value)
else:
equal = actual_value == expected_value

if not equal:
mismatches.append((key, actual_value, expected_value))

if mismatches:
print(f"{participant.name} attribute mismatches detected:")
for key, actual, expected in mismatches:
print(f"- {key}: actual={actual} | expected={expected}")
raise AssertionError(f"{len(mismatches)} mismatches found in {participant.name}")
else:
print(
f"{participant.name} attributes "
f"match expected values!"
)


def init_collaborator_private_attr_index(param):
"""
Initialize a collaborator's private attribute index.
Expand Down Expand Up @@ -112,3 +239,15 @@ def init_agg_pvt_attr_np():
of a NumPy array of shape (10, 28, 28) filled with random values.
"""
return {"test_loader": np.random.rand(10, 28, 28)}


def init_mock_pvt_attr(**kwargs):
"""
Initialize a dictionary with private attributes for testing.

Returns:
dict: A dictionary containing four keys "private_attribute_1", "private_attribute_2",
"private_attribute_3", and "private_attribute_4", each with a value of a NumPy
array of shape (10, 28, 28) filled with random values.
"""
return {f"private_attribute_{i}": np.random.rand(10, 28, 28) for i in range(1, 5)}
53 changes: 53 additions & 0 deletions tests/end_to_end/workflow/dynamic_private_attr_sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (C) 2020-2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import logging
from openfl.experimental.workflow.interface import FLSpec
from openfl.experimental.workflow.placement import aggregator, collaborator
from tests.end_to_end.utils.wf_helper import get_test_attribute_sets

log = logging.getLogger(__name__)


class bcolors:
HEADER = "\033[95m"
OKBLUE = "\033[94m"
OKCYAN = "\033[96m"
OKGREEN = "\033[92m"
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
BOLD = "\033[1m"
UNDERLINE = "\033[4m"


class TestFlowDynamicPrivateAttributeSync(FLSpec):
"""Test case to validate the dynamic assignment and synchronization of private attributes
across aggregator and collaborators during Flow execution.
"""

@aggregator
def start(self):
log.info(f"{bcolors.OKBLUE}Testing FederatedFlow - Starting Test {bcolors.ENDC}")
self.collaborators = self.runtime.collaborators
self.next(self.aggregator_step)

@aggregator
def aggregator_step(self):
self.modify_private_attributes("agg")
self.next(self.collaborator_step_b, foreach="collaborators")

@collaborator
def collaborator_step_b(self):
self.modify_private_attributes(self.input)
self.next(self.end)

@aggregator
def end(self, _):
log.info(f"{bcolors.OKBLUE}Test round completed.{bcolors.ENDC}")

def modify_private_attributes(self, participant) -> None:
"""Modify private attributes for the aggregator and collaborators."""
test_attribute_sets = get_test_attribute_sets()
for key, value in test_attribute_sets[participant].items():
setattr(self, key, value)
Loading