-
Notifications
You must be signed in to change notification settings - Fork 109
Open
Description
Question
I'm trying to match the speed of np.save and np.load when writing to and reading from a file, but I'm still quite far off. Is there any way to optimize my class?
import pathlib
import time
from typing import Type, TypeVar
import msgspec
import msgspec.json
import msgspec.msgpack
import numpy as np
from typing_extensions import Buffer
# Type variable for return type hints. It is bound to MsgpackModel so that the
# classmethod constructors annotated as ``cls: Type[T] -> T`` are typed as
# returning an instance of whichever subclass they are called on.
T = TypeVar("T", bound="MsgpackModel")
# ------------------------------------------------------------------------------
# Custom hooks for MessagePack
# ------------------------------------------------------------------------------
def msgpack_enc_hook(obj: object) -> object:
    """
    Convert objects the MessagePack encoder cannot handle natively.

    Supported conversions:
    - np.ndarray   -> ``[buffer, dtype string, shape]`` where ``buffer`` is a
      memoryview over the array's bytes in C (row-major) order
    - pathlib.Path -> POSIX-style string

    Raises:
        NotImplementedError: if ``obj`` is of any other type, or is an array
            whose dtype contains Python object references.
    """
    if isinstance(obj, np.ndarray):
        if obj.dtype.hasobject:
            # The raw memory of such an array only holds pointers, which
            # np.frombuffer cannot turn back into an array on decode.
            raise NotImplementedError(
                f"Arrays with dtype {obj.dtype} are not supported in {msgpack_enc_hook.__name__}"
            )
        # msgpack_dec_hook rebuilds the array with
        # np.frombuffer(...).reshape(shape), which assumes flat C-ordered
        # bytes. Transposed or strided arrays expose a non-contiguous
        # memoryview, so normalise first; ascontiguousarray does not copy
        # when the array is already C-contiguous.
        contiguous = np.ascontiguousarray(obj)
        # contiguous.data is a memoryview
        return [contiguous.data, str(obj.dtype), obj.shape]
    elif isinstance(obj, pathlib.Path):
        return obj.as_posix()
    raise NotImplementedError(f"Object of type {type(obj)} is not supported in {msgpack_enc_hook.__name__}")
def msgpack_dec_hook(expected_type: Type, obj: object) -> object:
    """
    Rebuild custom types from their decoded MessagePack representation.

    - np.ndarray: ``obj`` is indexed as ``[buffer, dtype, shape]``; the buffer
      is interpreted with ``np.frombuffer`` and reshaped to ``shape``.
    - pathlib.Path: ``obj`` is passed to the ``pathlib.Path`` constructor.

    For any other ``expected_type`` the decoded object is returned unchanged.
    """
    if expected_type is np.ndarray:
        # Interpret the raw buffer as a flat array first, then restore the shape.
        flat = np.frombuffer(obj[0], dtype=obj[1])
        return flat.reshape(obj[2])
    if expected_type is pathlib.Path:
        return pathlib.Path(obj)
    return obj
# ------------------------------------------------------------------------------
# Base class for models with multiple serialization methods
# ------------------------------------------------------------------------------
class MsgpackModel(msgspec.Struct):
    """
    Base struct that gives its subclasses MessagePack round-tripping.

    Encoding and decoding go through ``msgpack_enc_hook`` and
    ``msgpack_dec_hook``, which add support for:
    - np.ndarray (encoded as a ``[buffer, dtype, shape]`` triple)
    - pathlib.Path (encoded as POSIX strings)

    The base class itself cannot be instantiated; subclass it and declare
    the fields there.
    """

    def __post_init__(self):
        # Only subclasses may be instantiated.
        if type(self) is not MsgpackModel:
            return
        raise TypeError("MsgpackModel is an abstract base class; please subclass it.")

    # -- MessagePack serialization --
    def to_msgpack(self) -> bytes:
        """
        Return this instance encoded as MessagePack bytes.
        """
        return msgspec.msgpack.Encoder(enc_hook=msgpack_enc_hook).encode(self)

    def to_msgpack_path(self, path: pathlib.Path | str) -> None:
        """
        Encode this instance as MessagePack and write the bytes to ``path``.
        """
        with open(path, "wb") as sink:
            payload = self.to_msgpack()
            sink.write(payload)

    @classmethod
    def from_msgpack(cls: Type[T], data: Buffer) -> T:
        """
        Build an instance of the calling class from MessagePack ``data``.
        """
        return msgspec.msgpack.Decoder(cls, dec_hook=msgpack_dec_hook).decode(data)

    @classmethod
    def from_msgpack_path(cls: Type[T], path: pathlib.Path | str) -> T:
        """
        Read the file at ``path`` in binary mode and decode its contents
        into an instance of the calling class.
        """
        with open(path, "rb") as source:
            raw = source.read()
        return cls.from_msgpack(raw)
if __name__ == "__main__":
    # Benchmark: compare np.save / np.load against the MsgpackModel
    # file round-trip for one large array.
    # time.perf_counter() is used for all timings because it is monotonic
    # and intended for measuring short intervals, unlike time.time().

    # Create a random 20000x20000 float64 array (about 3.2 GB in memory)
    random_array = np.random.rand(20000, 20000)

    # Time saving and loading using numpy's save and load
    np_save_path = "random_array.npy"
    start_time = time.perf_counter()
    np.save(np_save_path, random_array)
    np_save_duration = time.perf_counter() - start_time
    print(f"NumPy save duration: {np_save_duration:.6f} seconds")

    start_time = time.perf_counter()
    loaded_array_np = np.load(np_save_path)
    np_load_duration = time.perf_counter() - start_time
    # The equality checks run after the timer has been read, so they do not
    # count towards the measured durations.
    assert np.array_equal(random_array, loaded_array_np)
    print(f"NumPy load duration: {np_load_duration:.6f} seconds")

    # Define a model with a single field for the numpy array
    class ArrayModel(MsgpackModel):
        array: np.ndarray

    model_instance = ArrayModel(array=random_array)

    # Time saving and loading using the custom model
    start_time = time.perf_counter()
    # to_msgpack_path returns None, so there is nothing to keep from the call.
    model_instance.to_msgpack_path("random_array.msgpack")
    model_save_duration = time.perf_counter() - start_time
    print(f"Model save duration: {model_save_duration:.6f} seconds")

    start_time = time.perf_counter()
    loaded_model_instance = ArrayModel.from_msgpack_path("random_array.msgpack")
    model_load_duration = time.perf_counter() - start_time
    assert np.array_equal(random_array, loaded_model_instance.array)
    print(f"Model load duration: {model_load_duration:.6f} seconds")
NumPy save duration: 1.980130 seconds
NumPy load duration: 0.932738 seconds
Model save duration: 2.645414 seconds
Model load duration: 1.598974 seconds
Metadata
Metadata
Assignees
Labels
No labels