diff --git a/openfl/experimental/workflow/interface/cli/cli_helper.py b/openfl/experimental/workflow/interface/cli/cli_helper.py index 70577e3368..b1392f63f9 100644 --- a/openfl/experimental/workflow/interface/cli/cli_helper.py +++ b/openfl/experimental/workflow/interface/cli/cli_helper.py @@ -1,7 +1,6 @@ # Copyright 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Module with auxiliary CLI helper functions.""" import os @@ -14,128 +13,213 @@ from click import echo, style from yaml import FullLoader, load -FX = argv[0] +# Constants for better readability +FX_BINARY = argv[0] +TREE_SYMBOLS = {"space": " ", "branch": "│ ", "tee": "├── ", "last": "└── "} +DEFAULT_LENGTH_LIMIT = 1000 +WORKSPACE_CONFIG_FILE = ".workspace" +# Path constants SITEPACKS = Path(__file__).parent.parent.parent.parent.parent.parent WORKSPACE = SITEPACKS / "openfl-workspace" / "experimental" / "workflow" / "FederatedRuntime" TUTORIALS = SITEPACKS / "openfl-tutorials" OPENFL_USERDIR = Path.home() / ".openfl" -def pretty(o): - """Pretty-print the dictionary given.""" - m = max(map(len, o.keys())) +def pretty(dictionary): + """Pretty-print the dictionary with aligned formatting. + + Args: + dictionary (dict): Dictionary to print with color formatting. + """ + if not dictionary: + return - for k, v in o.items(): - echo(style(f"{k:<{m}} : ", fg="blue") + style(f"{v}", fg="cyan")) + max_key_length = max(map(len, dictionary.keys())) + + for key, value in dictionary.items(): + echo(style(f"{key:<{max_key_length}} : ", fg="blue") + style(f"{value}", fg="cyan")) def print_tree( dir_path: Path, level: int = -1, limit_to_directories: bool = False, - length_limit: int = 1000, + length_limit: int = DEFAULT_LENGTH_LIMIT, ): - """Given a directory Path object print a visual tree structure.""" - space = " " - branch = "│ " - tee = "├── " - last = "└── " - + """Print a visual tree structure for the given directory. + + Args: + dir_path (Path): Directory path to visualize. + level (int): Maximum depth to traverse (-1 for unlimited). + limit_to_directories (bool): If True, only show directories. + length_limit (int): Maximum number of items to display. + """ echo("\nNew experimental workspace directory structure:") - dir_path = Path(dir_path) # accept string coerceable to Path - files = 0 - directories = 0 + dir_path = Path(dir_path) # Accept string coerceable to Path + files_count = 0 + directories_count = 0 + + def _generate_tree_lines(directory: Path, prefix: str = "", level: int = -1): + """Generate tree structure lines recursively. + + Args: + directory (Path): Current directory to process. + prefix (str): Current prefix for tree formatting. + level (int): Remaining depth to traverse. + + Yields: + str: Formatted tree lines. + """ + nonlocal files_count, directories_count - def inner(dir_path: Path, prefix: str = "", level=-1): - nonlocal files, directories - if not level: - return # 0, stop iterating + if level == 0: + return # Stop traversing at depth limit + + # Get directory contents if limit_to_directories: - contents = [d for d in dir_path.iterdir() if d.is_dir()] + contents = [item for item in directory.iterdir() if item.is_dir()] else: - contents = list(dir_path.iterdir()) - pointers = [tee] * (len(contents) - 1) + [last] + contents = list(directory.iterdir()) + + # Create pointers for tree structure + pointers = [TREE_SYMBOLS["tee"]] * (len(contents) - 1) + [TREE_SYMBOLS["last"]] + for pointer, path in zip(pointers, contents): if path.is_dir(): yield prefix + pointer + path.name - directories += 1 - extension = branch if pointer == tee else space - yield from inner(path, prefix=prefix + extension, level=level - 1) + directories_count += 1 + # Determine next prefix + extension = ( + TREE_SYMBOLS["branch"] + if pointer == TREE_SYMBOLS["tee"] + else TREE_SYMBOLS["space"] + ) + yield from _generate_tree_lines(path, prefix=prefix + extension, level=level - 1) elif not limit_to_directories: yield prefix + pointer + path.name - files += 1 + files_count += 1 + # Print root directory and tree structure echo(dir_path.name) - iterator = inner(dir_path, level=level) - for line in islice(iterator, length_limit): + tree_iterator = _generate_tree_lines(dir_path, level=level) + + for line in islice(tree_iterator, length_limit): echo(line) - if next(iterator, None): + + # Check if we hit the length limit + if next(tree_iterator, None): echo(f"... length_limit, {length_limit}, reached, counted:") - echo(f"\n{directories} directories" + (f", {files} files" if files else "")) + # Print summary + files_info = f", {files_count} files" if files_count else "" + echo(f"\n{directories_count} directories{files_info}") -def get_workspace_parameter(name): - """Get a parameter from the workspace config file (.workspace).""" - # Update the .workspace file to show the current workspace plan - workspace_file = ".workspace" - with open(workspace_file, "r", encoding="utf-8") as f: - doc = load(f, Loader=FullLoader) +def get_workspace_parameter(parameter_name: str) -> str: + """Get a parameter from the workspace config file (.workspace). - if not doc: # YAML is not correctly formatted - doc = {} # Create empty dictionary + Args: + parameter_name (str): Name of the parameter to retrieve. - if name not in doc.keys() or not doc[name]: # List doesn't exist + Returns: + str: Parameter value or empty string if not found. + """ + try: + with open(WORKSPACE_CONFIG_FILE, "r", encoding="utf-8") as config_file: + workspace_config = load(config_file, Loader=FullLoader) + except FileNotFoundError: return "" - else: - return doc[name] + + # Handle case where YAML is not correctly formatted + if not workspace_config: + workspace_config = {} + + # Return parameter value or empty string if not found or empty + return workspace_config.get(parameter_name, "") or "" -def check_varenv(env: str = "", args: dict = None): - """Update "args" (dictionary) with if env has a defined - value in the host.""" +def check_varenv(env_var: str = "", args: dict = None) -> dict: + """Update args dictionary with environment variable value if defined. + + Args: + env_var (str): Environment variable name to check. + args (dict): Dictionary to update with environment variable value. + + Returns: + dict: Updated args dictionary. + """ if args is None: args = {} - env_val = environ.get(env) - if env and (env_val is not None): - args[env] = env_val + + if env_var: + env_value = environ.get(env_var) + if env_value is not None: + args[env_var] = env_value return args -def get_fx_path(curr_path=""): - """Return the absolute path to fx binary.""" +def get_fx_path(current_path: str = "") -> str: + """Return the absolute path to fx binary. + + Args: + current_path (str): Current path to process. - match = re.search("lib", curr_path) - idx = match.end() - path_prefix = curr_path[0:idx] + Returns: + str: Path to the fx binary. + + Raises: + AttributeError: If 'lib' pattern is not found in current_path. + """ + lib_match = re.search("lib", current_path) + if not lib_match: + raise AttributeError("'lib' pattern not found in current_path") + + lib_end_index = lib_match.end() + path_prefix = current_path[:lib_end_index] bin_path = re.sub("lib", "bin", path_prefix) fx_path = os.path.join(bin_path, "fx") return fx_path -def remove_line_from_file(pkg, filename): - """Remove line that contains `pkg` from the `filename` file.""" - with open(filename, "r+", encoding="utf-8") as f: - d = f.readlines() - f.seek(0) - for i in d: - if pkg not in i: - f.write(i) - f.truncate() - - -def replace_line_in_file(line, line_num_to_replace, filename): - """Replace line at `line_num_to_replace` with `line`.""" - with open(filename, "r+", encoding="utf-8") as f: - d = f.readlines() - f.seek(0) - for idx, i in enumerate(d): - if idx == line_num_to_replace: - f.write(line) +def remove_line_from_file(package_name: str, filename: str) -> None: + """Remove lines containing the specified package from the file. + + Args: + package_name (str): Package name to search for and remove. + filename (str): Path to the file to modify. + """ + with open(filename, "r+", encoding="utf-8") as file: + lines = file.readlines() + file.seek(0) + + # Write back only lines that don't contain the package name + for line in lines: + if package_name not in line: + file.write(line) + + file.truncate() + + +def replace_line_in_file(new_line: str, line_number: int, filename: str) -> None: + """Replace line at specified line number with new content. + + Args: + new_line (str): New line content to write. + line_number (int): Zero-based line number to replace. + filename (str): Path to the file to modify. + """ + with open(filename, "r+", encoding="utf-8") as file: + lines = file.readlines() + file.seek(0) + + for idx, line in enumerate(lines): + if idx == line_number: + file.write(new_line) else: - f.write(i) - f.truncate() + file.write(line) + + file.truncate() diff --git a/openfl/experimental/workflow/notebooktools/code_analyzer.py b/openfl/experimental/workflow/notebooktools/code_analyzer.py index 9fed1e7426..472571c6f2 100644 --- a/openfl/experimental/workflow/notebooktools/code_analyzer.py +++ b/openfl/experimental/workflow/notebooktools/code_analyzer.py @@ -12,14 +12,23 @@ import nbformat from nbdev.export import nb_export +# Constants for better readability +DEFAULT_EXP_PATTERN = r"#\s*\|\s*default_exp\s+(\w+)" +RUNTIME_CLASS_NAME = "FederatedRuntime" +RUN_METHOD_PATTERN = ".run()" +MAGIC_COMMANDS = ("!", "%") +EXCLUDE_PARAMS = ("self", "args", "kwargs") + class CodeAnalyzer: - """Analyzes and process Jupyter Notebooks. - Provides code extraction and transformation functionality + """Analyzes and processes Jupyter Notebooks. + + Provides code extraction and transformation functionality for converting + Jupyter notebooks to Python scripts with federated learning specific modifications. Attributes: - script_name (str): Name of the generated python script. - script_path (Path): Absolute path to the python script generated. + script_name (str): Name of the generated Python script. + script_path (Path): Absolute path to the Python script generated. requirements (List[str]): List of pip libraries found in the script. exported_script_module (ModuleType): The imported module object of the generated script. available_modules_in_exported_script (list): List of available attributes in the @@ -27,221 +36,266 @@ class CodeAnalyzer: """ def __init__(self, notebook_path: Path, output_path: Path) -> None: - """Initialize CodeAnalyzer and process the script from notebook + """Initialize CodeAnalyzer and process the script from notebook. Args: notebook_path (Path): Path to Jupyter notebook to be converted. output_path (Path): The directory where the converted Python script will be saved. """ print("Converting jupyter notebook to python script...") + # Extract the export filename from the notebook - self.script_name = self.__get_exp_name(notebook_path) + self.script_name = self._get_experiment_name(notebook_path) + # Convert the notebook to a Python script and set the script path + script_output_dir = output_path.joinpath("src") + script_filename = f"{self.script_name}.py" self.script_path = Path( - self.__convert_to_python( - notebook_path, - output_path.joinpath("src"), - f"{self.script_name}.py", - ) + self._convert_notebook_to_python(notebook_path, script_output_dir, script_filename) ).resolve() + self.requirements = self._get_requirements() - self.__modify_experiment_script() + self._modify_experiment_script() + + def _get_experiment_name(self, notebook_path: Path) -> str: + """Extract experiment name from Jupyter notebook. - def __get_exp_name(self, notebook_path: Path) -> str: - """Extract experiment name from Jupyter notebook Looks for '#| default_exp ' pattern in code cells and extracts the experiment name. The name must be a valid Python identifier. Args: - notebook_path (str): Path to Jupyter notebook. + notebook_path (Path): Path to Jupyter notebook. + + Returns: + str: The experiment name extracted from the notebook. + + Raises: + ValueError: If no default_exp marker is found in the notebook. """ - with notebook_path.open("r") as f: - notebook_content = nbformat.read(f, as_version=nbformat.NO_CONVERT) + with notebook_path.open("r") as notebook_file: + notebook_content = nbformat.read(notebook_file, as_version=nbformat.NO_CONVERT) for cell in notebook_content.cells: if cell.cell_type == "code": code = cell.source - match = re.search(r"#\s*\|\s*default_exp\s+(\w+)", code) + match = re.search(DEFAULT_EXP_PATTERN, code) if match: - print(f"Retrieved {match.group(1)} from default_exp") - return match.group(1) + experiment_name = match.group(1) + print(f"Retrieved {experiment_name} from default_exp") + return experiment_name + raise ValueError( - "The notebook does not contain a '#| default_exp ' marker. " "Please add the marker to the first cell of the notebook" ) - def __convert_to_python(self, notebook_path: Path, output_path: Path, export_filename) -> Path: - """Converts a Jupyter notebook to a Python script. + def _convert_notebook_to_python( + self, notebook_path: Path, output_path: Path, export_filename: str + ) -> Path: + """Convert a Jupyter notebook to a Python script. + Args: - notebook_path (Path): The path to the Jupyter notebook file - to be converted. - output_path (Path): The directory where the exported Python - script should be saved. - export_filename: The name of the exported Python script file. + notebook_path (Path): The path to the Jupyter notebook file to be converted. + output_path (Path): The directory where the exported Python script should be saved. + export_filename (str): The name of the exported Python script file. Returns: Path: The path to the exported Python script file. """ nb_export(notebook_path, output_path) - return Path(output_path).joinpath(export_filename).resolve() - def __modify_experiment_script(self) -> None: - """Modifies the given python script by commenting out following code: - - occurences of flflow.run() - - instance of FederatedRuntime + def _modify_experiment_script(self) -> None: + """Modify the generated Python script by commenting out specific code. + + Comments out the following: + - Occurrences of .run() method calls + - Instances of FederatedRuntime class """ - runtime_class = "FederatedRuntime" - instantiation_info = self.__extract_class_instantiation_info(runtime_class) - instance_name = instantiation_info.get("instance_name", []) + instantiation_info = self._extract_class_instantiation_info(RUNTIME_CLASS_NAME) + instance_names = instantiation_info.get("instance_name", []) + # Read the script content, excluding Jupyter magic commands with open(self.script_path, "r") as file: - code = "".join(line for line in file if not line.lstrip().startswith(("!", "%"))) + script_content = "".join( + line for line in file if not line.lstrip().startswith(MAGIC_COMMANDS) + ) - code = self.__comment_flow_execution(code) - code = self.__comment_class_instance(code, instance_name) + # Apply modifications + script_content = self._comment_flow_execution(script_content) + script_content = self._comment_class_instance(script_content, instance_names) + # Write the modified content back with open(self.script_path, "w") as file: - file.write(code) + file.write(script_content) + + def _comment_class_instance(self, script_code: str, instance_names: List[str]) -> str: + """Comment out specified class instances in the provided script. - def __comment_class_instance(self, script_code: str, instance_name: List[str]) -> str: - """ - Comments out specified class instance in the provided script Args: - script_code (str): Script content to be analyzed - instance_name (List[str]): The name of the instance + script_code (str): Script content to be analyzed. + instance_names (List[str]): The names of the instances to comment out. Returns: - str: The updated script with the specified instance lines commented out + str: The updated script with the specified instance lines commented out. """ tree = ast.parse(script_code) lines = script_code.splitlines() lines_to_comment = set() + for node in ast.walk(tree): if isinstance(node, (ast.Assign, ast.Expr)): + # Check if any subnode references our target instance names if any( - isinstance(subnode, ast.Name) and subnode.id in instance_name + isinstance(subnode, ast.Name) and subnode.id in instance_names for subnode in ast.walk(node) ): - for i in range(node.lineno - 1, node.end_lineno): - lines_to_comment.add(i) + # Add all lines from this node to the comment set + for line_idx in range(node.lineno - 1, node.end_lineno): + lines_to_comment.add(line_idx) + + # Comment out the identified lines modified_lines = [ f"# {line}" if idx in lines_to_comment else line for idx, line in enumerate(lines) ] - updated_script = "\n".join(modified_lines) - return updated_script + return "\n".join(modified_lines) + + def _comment_flow_execution(self, script_code: str) -> str: + """Comment out lines containing '.run()' in the specified Python script. - def __comment_flow_execution(self, script_code: str) -> str: - """ - Comment out lines containing '.run()' in the specified Python script Args: - script_code(str): Script content to be analyzed + script_code (str): Script content to be analyzed. Returns: - str: The modified script with run_statement commented out + str: The modified script with run statements commented out. """ - run_statement = ".run()" lines = script_code.splitlines() + for idx, line in enumerate(lines): stripped_line = line.strip() - if not stripped_line.startswith("#") and run_statement in stripped_line: + # Comment out lines that contain .run() and are not already commented + if not stripped_line.startswith("#") and RUN_METHOD_PATTERN in stripped_line: lines[idx] = f"# {line}" - updated_script = "\n".join(lines) - return updated_script + return "\n".join(lines) - def __import_generated_script(self) -> None: - """ - Imports the generated python script using the importlib module + def _import_generated_script(self) -> None: + """Import the generated Python script using the importlib module. + + Raises: + ImportError: If the script cannot be imported. """ try: sys.path.append(str(self.script_path.parent)) self.exported_script_module = import_module(self.script_name) self.available_modules_in_exported_script = dir(self.exported_script_module) except ImportError as e: - raise ImportError(f"Failed to import script {self.script_name}: {e}") + raise ImportError(f"Failed to import script {self.script_name}: {e}") from e - def __get_class_arguments(self, class_name) -> list: - """Given the class name returns expected class arguments. + def _get_class_arguments(self, class_name: str) -> List[str]: + """Get expected class arguments for the given class name. Args: class_name (str): The name of the class. Returns: - list: A list of expected class arguments. + List[str]: A list of expected class arguments. + + Raises: + NameError: If the class is not found in the exported script. """ if not hasattr(self, "exported_script_module"): - self.__import_generated_script() + self._import_generated_script() # Find class from imported python script module - for idx, attr in enumerate(self.available_modules_in_exported_script): - if attr == class_name: - cls = getattr( - self.exported_script_module, - self.available_modules_in_exported_script[idx], - ) - if "cls" not in locals(): + target_class = None + for attr_name in self.available_modules_in_exported_script: + if attr_name == class_name: + target_class = getattr(self.exported_script_module, attr_name) + break + + if target_class is None: raise NameError(f"{class_name} not found.") - if inspect.isclass(cls): - if "__init__" in cls.__dict__: - init_signature = inspect.signature(cls.__init__) - # Extract the parameter names (excluding 'self', 'args', and - # 'kwargs') + if inspect.isclass(target_class): + if "__init__" in target_class.__dict__: + init_signature = inspect.signature(target_class.__init__) + # Extract parameter names (excluding 'self', 'args', and 'kwargs') arg_names = [ - param - for param in init_signature.parameters - if param not in ("self", "args", "kwargs") + param_name + for param_name in init_signature.parameters + if param_name not in EXCLUDE_PARAMS ] return arg_names return [] - print(f"{cls} is not a class") - def __get_class_name(self, parent_class) -> Optional[str]: + print(f"{target_class} is not a class") + return [] + + def _get_class_name(self, parent_class) -> Optional[str]: """Find and return the name of a class derived from the provided parent class. + Args: - parent_class: FLSpec instance. + parent_class: FLSpec instance or parent class to search for. Returns: Optional[str]: The name of the derived class. + + Raises: + ValueError: If no flow class is found that inherits from the parent class. """ if not hasattr(self, "exported_script_module"): - self.__import_generated_script() + self._import_generated_script() + + # Go through all attributes in imported python script + for attr_name in self.available_modules_in_exported_script: + attribute = getattr(self.exported_script_module, attr_name) + if ( + inspect.isclass(attribute) + and attribute != parent_class + and issubclass(attribute, parent_class) + ): + return attr_name - # Going though all attributes in imported python script - for attr in self.available_modules_in_exported_script: - t = getattr(self.exported_script_module, attr) - if inspect.isclass(t) and t != parent_class and issubclass(t, parent_class): - return attr raise ValueError("No flow class found that inherits from FLSpec") - def __extract_class_instantiation_info(self, class_name: str) -> Dict[str, Any]: - """ - Extracts the instance name and its initialization arguments (both positional and keyword) - for the given class + def _extract_class_instantiation_info(self, class_name: str) -> Dict[str, Any]: + """Extract the instance name and initialization arguments for the given class. + Args: - class_name (str): The name of the class + class_name (str): The name of the class to search for. Returns: - Dict[str, Any]: A dictionary containing 'args', 'kwargs', and 'instance_name' + Dict[str, Any]: A dictionary containing 'args', 'kwargs', and 'instance_name'. """ - instantiation_args = {"args": {}, "kwargs": {}, "instance_name": []} + instantiation_info = {"args": {}, "kwargs": {}, "instance_name": []} + # Read script content, excluding Jupyter magic commands with open(self.script_path, "r") as file: - code = "".join(line for line in file if not line.lstrip().startswith(("!", "%"))) - tree = ast.parse(code) - for node in ast.walk(tree): - if isinstance(node, ast.Assign) and isinstance(node.value, ast.Call): - if isinstance(node.value.func, ast.Name) and node.value.func.id == class_name: - for target in node.targets: - if isinstance(target, ast.Name): - instantiation_args["instance_name"].append(target.id) - # We found an instantiation of the class - instantiation_args["args"] = self._extract_positional_args(node.value.args) - instantiation_args["kwargs"] = self._extract_keyword_args(node.value.keywords) + script_content = "".join( + line for line in file if not line.lstrip().startswith(MAGIC_COMMANDS) + ) - return instantiation_args + tree = ast.parse(script_content) + + for node in ast.walk(tree): + if ( + isinstance(node, ast.Assign) + and isinstance(node.value, ast.Call) + and isinstance(node.value.func, ast.Name) + and node.value.func.id == class_name + ): + # Extract instance names + for target in node.targets: + if isinstance(target, ast.Name): + instantiation_info["instance_name"].append(target.id) + + # Extract arguments + instantiation_info["args"] = self._extract_positional_args(node.value.args) + instantiation_info["kwargs"] = self._extract_keyword_args(node.value.keywords) + + return instantiation_info def _extract_positional_args(self, args) -> Dict[str, Any]: """Extract positional arguments from the AST nodes. @@ -295,40 +349,48 @@ def _clean_value(self, value: str) -> str: return value def _get_requirements(self) -> List[str]: - """Extract pip libraries from the script + """Extract pip libraries from the script. Returns: - requirements (List[str]): List of pip libraries found in the script. + List[str]: List of pip libraries found in the script. """ - data = None - with self.script_path.open("r") as f: - requirements = [] - data = f.readlines() - for _, line in enumerate(data): - line = line.strip() - if "pip install" in line: - # Avoid commented lines, libraries from *.txt file, or openfl.git - # installation - if not line.startswith("#") and "-r" not in line and "openfl.git" not in line: - requirements.append(f"{line.split(' ')[-1].strip()}\n") - - return requirements + requirements = [] + + with self.script_path.open("r") as file: + script_lines = file.readlines() + + for line in script_lines: + stripped_line = line.strip() + + # Look for pip install commands + if "pip install" in stripped_line: + # Skip commented lines, requirements files, and OpenFL git installations + is_commented = stripped_line.startswith("#") + is_requirements_file = "-r" in stripped_line + is_openfl_git = "openfl.git" in stripped_line + + if not (is_commented or is_requirements_file or is_openfl_git): + # Extract the package name (last part after 'pip install') + package_name = stripped_line.split(" ")[-1].strip() + requirements.append(f"{package_name}\n") + + return requirements def get_flow_class_details(self, parent_class) -> Dict[str, Any]: - """ - Retrieves details of a flow class that inherits from the given parent clas + """Retrieve details of a flow class that inherits from the given parent class. + Args: parent_class: The parent class (FLSpec instance). Returns: Dict[str, Any]: A dictionary containing: - flow_class_name (str): The name of the flow class. - expected_args (List[str]): The expected arguments for the flow class. - init_args (Dict[str, Any]): The initialization arguments for the flow class. + - flow_class_name (str): The name of the flow class. + - expected_args (List[str]): The expected arguments for the flow class. + - init_args (Dict[str, Any]): The initialization arguments for the flow class. """ - flow_class_name = self.__get_class_name(parent_class) - expected_arguments = self.__get_class_arguments(flow_class_name) - init_args = self.__extract_class_instantiation_info(flow_class_name) + flow_class_name = self._get_class_name(parent_class) + expected_arguments = self._get_class_arguments(flow_class_name) + init_args = self._extract_class_instantiation_info(flow_class_name) return { "flow_class_name": flow_class_name, @@ -338,11 +400,12 @@ def get_flow_class_details(self, parent_class) -> Dict[str, Any]: def fetch_flow_configuration(self, flow_details: Dict[str, Any]) -> Dict[str, Any]: """Get flow configuration from flow details. + Args: flow_details (Dict[str, Any]): Dictionary containing flow class details. Returns: - Dict[str, Any]: Dictionary containing the plan configuration + Dict[str, Any]: Dictionary containing the plan configuration. """ flow_config = { "federated_flow": { @@ -351,28 +414,35 @@ def fetch_flow_configuration(self, flow_details: Dict[str, Any]) -> Dict[str, An } } - def update_dictionary(args: dict, dtype: str = "args") -> None: + def update_config_with_args(args: Dict[str, Any], arg_type: str = "args") -> None: """Update plan configuration with argument values. Args: - args: Dictionary of arguments to process - dtype: Type of arguments ('args' or 'kwargs') + args (Dict[str, Any]): Dictionary of arguments to process. + arg_type (str): Type of arguments ('args' or 'kwargs'). """ - for idx, (k, v) in enumerate(args.items()): - if dtype == "args": - v = getattr(self.exported_script_module, str(k), None) - if v is not None and not isinstance(v, (int, str, bool)): - v = f"src.{self.script_name}.{k}" - k = flow_details["expected_args"][idx] - elif dtype == "kwargs": - if v is not None and not isinstance(v, (int, str, bool)): - v = f"src.{self.script_name}.{v}" - flow_config["federated_flow"]["settings"].update({k: v}) - - # Process arguments - pos_args = flow_details["init_args"].get("args", {}) - update_dictionary(pos_args, "args") - kw_args = flow_details["init_args"].get("kwargs", {}) - update_dictionary(kw_args, "kwargs") + for idx, (key, value) in enumerate(args.items()): + if arg_type == "args": + # For positional args, get value from module or use as template reference + module_value = getattr(self.exported_script_module, str(key), None) + if module_value is not None and not isinstance(module_value, (int, str, bool)): + value = f"src.{self.script_name}.{key}" + else: + value = module_value + # Use expected arg name instead of variable name + key = flow_details["expected_args"][idx] + elif arg_type == "kwargs": + # For keyword args, use template reference if not primitive type + if value is not None and not isinstance(value, (int, str, bool)): + value = f"src.{self.script_name}.{value}" + + flow_config["federated_flow"]["settings"].update({key: value}) + + # Process positional and keyword arguments + positional_args = flow_details["init_args"].get("args", {}) + update_config_with_args(positional_args, "args") + + keyword_args = flow_details["init_args"].get("kwargs", {}) + update_config_with_args(keyword_args, "kwargs") return flow_config diff --git a/openfl/utilities/utils.py b/openfl/utilities/utils.py index 4f4e5fc2eb..244ff8476f 100644 --- a/openfl/utilities/utils.py +++ b/openfl/utilities/utils.py @@ -265,12 +265,11 @@ def remove_readonly(func, path, _): return shutil.rmtree(path, ignore_errors=ignore_errors, onerror=remove_readonly) -def generate_port(hash, port_range=(49152, 60999)): - """ - Generate a deterministic port number based on a hash and a unique key. +def generate_port(hash_value, port_range=(49152, 60999)): + """Generate a deterministic port number based on a hash and a unique key. Args: - hash (str): A string representing the hash of the plan. + hash_value (str): A string representing the hash of the plan. port_range (tuple): A tuple containing the minimum and maximum port numbers (inclusive). The default range is (49152, 60999). @@ -279,6 +278,6 @@ def generate_port(hash, port_range=(49152, 60999)): """ min_port, max_port = port_range # Use the first 8 characters of the unique hash to ensure deterministic output - hash_segment = hash[:8] + hash_segment = hash_value[:8] port = int(hash_segment, 16) % (max_port - min_port) + min_port return port diff --git a/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py b/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py index bc1d8d262d..9c0e35bf5e 100644 --- a/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py +++ b/tests/github/experimental/workflow/LocalRuntime/testflow_exclude.py @@ -2,12 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import sys -from openfl.experimental.workflow.interface import FLSpec, Aggregator, Collaborator -from openfl.experimental.workflow.runtime import LocalRuntime + +from openfl.experimental.workflow.interface import Aggregator, Collaborator, FLSpec from openfl.experimental.workflow.placement import aggregator, collaborator +from openfl.experimental.workflow.runtime import LocalRuntime + +class TerminalColors: # NOQA: N801 + """ANSI color codes for terminal output formatting.""" -class bcolors: # NOQA: N801 HEADER = "\033[95m" OKBLUE = "\033[94m" OKCYAN = "\033[96m" @@ -19,6 +22,10 @@ class bcolors: # NOQA: N801 UNDERLINE = "\033[4m" +# Keep backward compatibility +bcolors = TerminalColors + + class TestFlowExclude(FLSpec): """ Testflow to validate exclude functionality in Federated Flow @@ -177,11 +184,7 @@ def end(self): + f"{bcolors.ENDC}" ) if TestFlowExclude.exclude_error_list: - raise ( - AssertionError( - f"{bcolors.FAIL}\n ...Test case failed ... {bcolors.ENDC}" - ) - ) + raise (AssertionError(f"{bcolors.FAIL}\n ...Test case failed ... {bcolors.ENDC}")) if __name__ == "__main__":