Skip to content

optimizing extended numpy encoder decoder #811

@YoniChechik

Description

@YoniChechik

Question

I'm trying to match the speed of np.save and np.load when writing to and reading from a file, but I'm still pretty far off. Is there any way to optimize my class?

import pathlib
import time
from typing import Type, TypeVar

import msgspec
import msgspec.json
import msgspec.msgpack
import numpy as np
from typing_extensions import Buffer

# Type variable for return type hints
T = TypeVar("T", bound="MsgpackModel")

# ------------------------------------------------------------------------------
# Custom hooks for MessagePack
# ------------------------------------------------------------------------------


def msgpack_enc_hook(obj: object) -> object:
    """Convert objects the MessagePack encoder cannot handle natively.

    Args:
        obj: The object handed to the hook by the encoder.

    Returns:
        For ``np.ndarray``: a ``[buffer, dtype string, shape]`` list, where
        ``buffer`` is a memoryview over the array's elements laid out in
        C (row-major) order. For ``pathlib.Path``: the POSIX string form.

    Raises:
        NotImplementedError: If ``obj`` is neither an ndarray nor a Path.
    """
    if isinstance(obj, np.ndarray):
        # msgpack_dec_hook rebuilds the array with frombuffer(...).reshape(shape),
        # which assumes a flat row-major byte run. Fortran-ordered, transposed or
        # strided arrays do not expose such a buffer, so copy them to C order.
        # Already-contiguous arrays are passed through without a copy.
        contiguous = obj if obj.flags.c_contiguous else obj.copy(order="C")
        # .data is a memoryview, so no intermediate bytes object is created here.
        return [contiguous.data, str(obj.dtype), obj.shape]
    elif isinstance(obj, pathlib.Path):
        return obj.as_posix()
    raise NotImplementedError(f"Object of type {type(obj)} is not supported in {msgpack_enc_hook.__name__}")


def msgpack_dec_hook(expected_type: Type, obj: object) -> object:
    """Rebuild custom types from their decoded MessagePack representation.

    Args:
        expected_type: The type the decoder wants to produce.
        obj: The already-decoded value to convert.

    Returns:
        An ``np.ndarray`` built from a ``[buffer, dtype, shape]`` sequence when
        ``expected_type`` is ``np.ndarray``; a ``pathlib.Path`` when
        ``expected_type`` is ``pathlib.Path``; otherwise ``obj`` unchanged.
    """
    if expected_type is np.ndarray:
        # Index 0 holds the raw element buffer, 1 the dtype, 2 the shape.
        flat = np.frombuffer(obj[0], dtype=obj[1])
        return flat.reshape(obj[2])
    if expected_type is pathlib.Path:
        return pathlib.Path(obj)
    # Any other type is passed through untouched.
    return obj


# ------------------------------------------------------------------------------
# Base class for models with MessagePack serialization (bytes and file)
# ------------------------------------------------------------------------------


class MsgpackModel(msgspec.Struct):
    """
    Abstract base class for models that support serialization/deserialization
    via MessagePack, either to/from bytes or to/from a file.

    Encoding and decoding go through msgpack_enc_hook / msgpack_dec_hook,
    which add support for:
      - np.ndarray (encoded as raw buffer, dtype string and shape)
      - pathlib.Path (encoded as POSIX strings)

    This class is meant to be subclassed only.
    """

    def __post_init__(self):
        # Prevent direct instantiation of the base class. The identity check
        # (rather than isinstance) lets every subclass through.
        if type(self) is MsgpackModel:
            raise TypeError("MsgpackModel is an abstract base class; please subclass it.")

    # -- MessagePack serialization --

    def to_msgpack(self) -> bytes:
        """
        Serialize the instance to MessagePack bytes.
        """
        encoder = msgspec.msgpack.Encoder(enc_hook=msgpack_enc_hook)
        return encoder.encode(self)

    def to_msgpack_path(self, path: pathlib.Path | str) -> None:
        """
        Serialize the instance to MessagePack bytes and save to a file.

        The destination is only opened once encoding has succeeded.
        """
        # Opening with "wb" truncates the target immediately, so encode first:
        # an encoding error must not wipe out a file that already exists.
        payload = self.to_msgpack()
        with open(path, "wb") as f:
            f.write(payload)

    @classmethod
    def from_msgpack(cls: Type[T], data: Buffer) -> T:
        """
        Deserialize MessagePack bytes into an instance of the calling class.
        """
        decoder = msgspec.msgpack.Decoder(cls, dec_hook=msgpack_dec_hook)
        return decoder.decode(data)

    @classmethod
    def from_msgpack_path(cls: Type[T], path: pathlib.Path | str) -> T:
        """
        Read a file fully into memory and deserialize its MessagePack content
        into an instance of the calling class.
        """

        with open(path, "rb") as f:
            data = f.read()
        return cls.from_msgpack(data)


if __name__ == "__main__":
    # Benchmark: np.save/np.load versus the MsgpackModel file round trip.
    # Create a random 20000x20000 numpy array of floats (float64, ~3.2 GB).
    random_array = np.random.rand(20000, 20000)

    # Durations are measured with time.perf_counter(): unlike time.time() it is
    # monotonic and is the clock intended for timing short intervals.

    # Time saving and loading using numpy's save and load
    np_save_path = "random_array.npy"
    start_time = time.perf_counter()
    np.save(np_save_path, random_array)
    np_save_duration = time.perf_counter() - start_time
    print(f"NumPy save duration: {np_save_duration:.6f} seconds")

    start_time = time.perf_counter()
    loaded_array_np = np.load(np_save_path)
    np_load_duration = time.perf_counter() - start_time

    # The equality check sits outside the timed region on purpose.
    assert np.array_equal(random_array, loaded_array_np)
    print(f"NumPy load duration: {np_load_duration:.6f} seconds")

    # Define a model with a single field for the numpy array
    class ArrayModel(MsgpackModel):
        array: np.ndarray

    model_instance = ArrayModel(array=random_array)

    # Time saving and loading using the custom model
    start_time = time.perf_counter()
    # to_msgpack_path returns None, so there is nothing to keep from the call.
    model_instance.to_msgpack_path("random_array.msgpack")
    model_save_duration = time.perf_counter() - start_time
    print(f"Model save duration: {model_save_duration:.6f} seconds")

    start_time = time.perf_counter()
    loaded_model_instance = ArrayModel.from_msgpack_path("random_array.msgpack")
    model_load_duration = time.perf_counter() - start_time

    assert np.array_equal(random_array, loaded_model_instance.array)
    print(f"Model load duration: {model_load_duration:.6f} seconds")
NumPy save duration: 1.980130 seconds
NumPy load duration: 0.932738 seconds
Model save duration: 2.645414 seconds
Model load duration: 1.598974 seconds

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions