diff --git a/src/py/flwr/common/record/conversion_utils.py b/src/py/flwr/common/record/conversion_utils.py index 7cc0b04283e9..dbfdf504f35e 100644 --- a/src/py/flwr/common/record/conversion_utils.py +++ b/src/py/flwr/common/record/conversion_utils.py @@ -15,26 +15,10 @@ """Conversion utility functions for Records.""" -from io import BytesIO - -import numpy as np - -from ..constant import SType from ..typing import NDArray from .parametersrecord import Array def array_from_numpy(ndarray: NDArray) -> Array: """Create Array from NumPy ndarray.""" - buffer = BytesIO() - # WARNING: NEVER set allow_pickle to true. - # Reason: loading pickled data can execute arbitrary code - # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html - np.save(buffer, ndarray, allow_pickle=False) - data = buffer.getvalue() - return Array( - dtype=str(ndarray.dtype), - shape=list(ndarray.shape), - stype=SType.NUMPY, - data=data, - ) + return Array.from_numpy_ndarray(ndarray) diff --git a/src/py/flwr/common/record/parametersrecord.py b/src/py/flwr/common/record/parametersrecord.py index c0fc7fd44cf9..dfbf076766ac 100644 --- a/src/py/flwr/common/record/parametersrecord.py +++ b/src/py/flwr/common/record/parametersrecord.py @@ -15,10 +15,12 @@ """ParametersRecord and Array.""" +from __future__ import annotations + from collections import OrderedDict from dataclasses import dataclass from io import BytesIO -from typing import Optional, cast +from typing import Any, cast, overload import numpy as np @@ -27,6 +29,13 @@ from .typeddict import TypedDict +def _raise_array_init_error() -> None: + raise TypeError( + f"Invalid arguments for {Array.__qualname__}. Expected a " + "NumPy ndarray, or explicit dtype/shape/stype/data values." + ) + + @dataclass class Array: """Array type. @@ -50,6 +59,23 @@ class Array: data: bytes A buffer of bytes containing the data. + + Examples + -------- + You can create an `Array` object from a NumPy `ndarray`: + + >>> import numpy as np + >>> + >>> arr = Array(np.random.randn(3, 3)) + + Alternatively, you can create an `Array` object by specifying explicit values: + + >>> arr = Array( + >>> dtype="float32", + >>> shape=[3, 3], + >>> stype="numpy.ndarray", + >>> data=b"serialized_data...", + >>> ) """ dtype: str @@ -57,6 +83,94 @@ class Array: stype: str data: bytes + @overload + def __init__(self, ndarray: NDArray) -> None: ... # noqa: E704 + + @overload + def __init__( # noqa: E704 + self, dtype: str, shape: list[int], stype: str, data: bytes + ) -> None: ... + + def __init__( # pylint: disable=too-many-arguments, too-many-locals + self, + *args: Any, + ndarray: NDArray | None = None, + dtype: str | None = None, + shape: list[int] | None = None, + stype: str | None = None, + data: bytes | None = None, + ) -> None: + # Workaround to support multiple initialization signatures. + # This method validates and assigns the correct arguments, + # including keyword arguments such as dtype and shape. + # Supported initialization formats: + # 1. Array(dtype: str, shape: list[int], stype: str, data: bytes) + # 2. Array(ndarray: NDArray) + + # Init all arguments + # If more than 4 positional arguments are provided, raise an error. + if len(args) > 4: + _raise_array_init_error() + all_args = [None] * 4 + for i, arg in enumerate(args): + all_args[i] = arg + + def _try_set_arg(index: int, arg: Any) -> None: + if arg is None: + return + if all_args[index] is not None: + _raise_array_init_error() + all_args[index] = arg + + # Try to set keyword arguments in all_args + _try_set_arg(0, ndarray) + _try_set_arg(0, dtype) + _try_set_arg(1, shape) + _try_set_arg(2, stype) + _try_set_arg(3, data) + + # Check if all arguments are correctly set + all_args = [arg for arg in all_args if arg is not None] + if len(all_args) not in [1, 4]: + _raise_array_init_error() + + # Handle NumPy array + if isinstance(all_args[0], np.ndarray): + self.__dict__.update(self.from_numpy_ndarray(all_args[0]).__dict__) + return + + # Handle direct field initialization + if ( + isinstance(all_args[0], str) + and isinstance(all_args[1], list) + and all(isinstance(i, int) for i in all_args[1]) + and isinstance(all_args[2], str) + and isinstance(all_args[3], bytes) + ): + self.dtype, self.shape, self.stype, self.data = all_args + return + + _raise_array_init_error() + + @classmethod + def from_numpy_ndarray(cls, ndarray: NDArray) -> Array: + """Create Array from NumPy ndarray.""" + assert isinstance( + ndarray, np.ndarray + ), f"Expected NumPy ndarray, got {type(ndarray)}" + buffer = BytesIO() + # WARNING: NEVER set allow_pickle to true. + # Reason: loading pickled data can execute arbitrary code + # Source: https://numpy.org/doc/stable/reference/generated/numpy.save.html + np.save(buffer, ndarray, allow_pickle=False) + data = buffer.getvalue() + return Array( + dtype=str(ndarray.dtype), + shape=list(ndarray.shape), + stype=SType.NUMPY, + data=data, + ) + def numpy(self) -> NDArray: """Return the array as a NumPy array.""" if self.stype != SType.NUMPY: @@ -117,7 +231,6 @@ class ParametersRecord(TypedDict[str, Array]): >>> import numpy as np >>> from flwr.common import ParametersRecord - >>> from flwr.common import array_from_numpy >>> >>> # Let's create a simple NumPy array >>> arr_np = np.random.randn(3, 3) @@ -128,7 +241,7 @@ class ParametersRecord(TypedDict[str, Array]): >>> [-0.10758364, 1.97619858, -0.37120501]]) >>> >>> # Let's create an Array out of it - >>> arr = array_from_numpy(arr_np) + >>> arr = Array(arr_np) >>> >>> # If we print it you'll see (note the binary data) >>> Array(dtype='float64', shape=[3,3], stype='numpy.ndarray', data=b'@\x99\x18...') @@ -176,7 +289,7 @@ class ParametersRecord(TypedDict[str, Array]): def __init__( self, - array_dict: Optional[OrderedDict[str, Array]] = None, + array_dict: OrderedDict[str, Array] | None = None, keep_input: bool = False, ) -> None: super().__init__(_check_key, _check_value) diff --git a/src/py/flwr/common/record/parametersrecord_test.py b/src/py/flwr/common/record/parametersrecord_test.py index 9644fc0541d1..4639bb9371ac 100644 --- a/src/py/flwr/common/record/parametersrecord_test.py +++ b/src/py/flwr/common/record/parametersrecord_test.py @@ -18,9 +18,11 @@ import unittest from collections import OrderedDict from io import BytesIO +from typing import Any import numpy as np import pytest +from parameterized import parameterized from flwr.common import ndarray_to_bytes @@ -72,6 +74,60 @@ def test_numpy_conversion_invalid(self) -> None: with self.assertRaises(TypeError): array_instance.numpy() + def test_array_from_numpy(self) -> None: + """Test the array_from_numpy function.""" + # Prepare + original_array = np.array([1, 2, 3], dtype=np.float32) + + # Execute + array_instance = Array.from_numpy_ndarray(original_array) + buffer = BytesIO(array_instance.data) + deserialized_array = np.load(buffer, allow_pickle=False) + + # Assert + self.assertEqual(array_instance.dtype, str(original_array.dtype)) + self.assertEqual(array_instance.shape, list(original_array.shape)) + self.assertEqual(array_instance.stype, SType.NUMPY) + np.testing.assert_array_equal(deserialized_array, original_array) + + @parameterized.expand( # type: ignore + [ + ("ndarray", np.array([1, 2, 3])), + ("explicit_values", "float32", [2, 2], "dense", b"data"), + ] + ) + def test_valid_init_overloads_kwargs(self, name: str, *args: Any) -> None: + """Ensure valid overloads initialize correctly.""" + if name == "explicit_values": + array = Array(dtype=args[0], shape=args[1], stype=args[2], data=args[3]) + else: + kwargs = {name: args[0]} + array = Array(**kwargs) + self.assertIsInstance(array, Array) + + @parameterized.expand( # type: ignore + [ + (np.array([1, 2, 3]),), + ("float32", [2, 2], "dense", b"data"), + ] + ) + def test_valid_init_overloads_args(self, *args: Any) -> None: + """Ensure valid overloads initialize correctly.""" + array = Array(*args) + self.assertIsInstance(array, Array) + + @parameterized.expand( # type: ignore + [ + ("float32", [2, 2], "dense", 213), + ([2, 2], "dense", b"data"), + (123, "invalid"), + ] + ) + def test_invalid_init_combinations(self, *args: Any) -> None: + """Ensure invalid combinations raise TypeError.""" + with self.assertRaises(TypeError): + Array(*args) + @pytest.mark.parametrize( "shape, dtype",