Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support pickle serialization #53

Merged
merged 2 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions test/test_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import pickle

import numpy as np
import pytest

from tunits import Value, UnitMismatchError
import tunits as tu


def test_construction() -> None:
Expand Down Expand Up @@ -263,3 +267,15 @@ def test_dimensionless() -> None:
_ = A.dimensionless()

assert B.dimensionless() == 1.2


def test_pick_roundtrip() -> None:
units = itertools.product(
[tu.ns, tu.GHz, tu.dBm, tu.deg, 1, tu.GHz**0.5, tu.m**0.75, tu.s**0.25], repeat=3
)
for value in np.random.random(20):
for unit_list in units:
unit = np.prod(unit_list) # type: ignore[arg-type]
x = value * unit
s = pickle.dumps(x)
assert x == pickle.loads(s)
16 changes: 16 additions & 0 deletions test/test_value_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import pickle

import numpy as np
import pytest
from tunits.core import raw_WithUnit, raw_UnitArray

from tunits import ValueArray, UnitMismatchError, Value
import tunits as tu


def test_construction() -> None:
Expand Down Expand Up @@ -238,3 +242,15 @@ def test_dimensionless() -> None:
_ = A.dimensionless()

np.testing.assert_equal(B.dimensionless(), np.arange(5) / 1000)


def test_pick_roundtrip() -> None:
units = itertools.product(
[tu.ns, tu.GHz, tu.dBm, tu.deg, 1, tu.GHz**0.5, tu.m**0.75, tu.s**0.25], repeat=3
)
for value in np.random.random((5, 20)):
for unit_list in units:
unit = np.prod(unit_list) # type: ignore[arg-type]
x = value * unit
s = pickle.dumps(x)
assert all(x == pickle.loads(s))
15 changes: 15 additions & 0 deletions tunits/core/cython/unit_array.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -237,4 +237,19 @@ cdef class UnitArray:
raise ValueError("UnitArray power does not support third argument")
return self.pow_frac(float_to_twelths_frac(exponent));

def __getstate__(self):
return {
'unit_count': self.unit_count,
'units': [*self],
}

def __setstate__(self, pickle_info: dict[str, Any]):
self.unit_count = pickle_info['unit_count']
self.units = <UnitTerm *>PyMem_Malloc(self.unit_count*sizeof(UnitTerm))
for i, (name, numer, denom) in enumerate(pickle_info['units']):
Py_INCREF(name)
self.units[i].name = <PyObject *>name
self.units[i].power.numer = numer
self.units[i].power.denom = denom

_EmptyUnit = UnitArray()
16 changes: 16 additions & 0 deletions tunits/core/cython/with_unit.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,22 @@ cdef class WithUnit:
def _from_json_dict_(cls, **kwargs):
return cls(kwargs["value"], kwargs["unit"])

def __getstate__(self):
return {
'value': self.value,
'conv': self.conv,
'display_units': self.display_units.__getstate__(),
'base_units': self.base_units.__getstate__(),
}

def __setstate__(self, pickle_info):
self.value = pickle_info['value']
self.conv = pickle_info['conv']
self.display_units = UnitArray()
self.base_units = UnitArray()
self.display_units.__setstate__(pickle_info['display_units'])
self.base_units.__setstate__(pickle_info['base_units'])

_try_interpret_as_with_unit = None
_is_value_consistent_with_default_unit_database = None
def init_base_unit_functions(
Expand Down