Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for complex number #374

Merged
merged 9 commits into from
Jan 27, 2025
22 changes: 15 additions & 7 deletions damnit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import h5py

from .backend.db import BlobTypes, DamnitDB
from .backend.db import BlobTypes, DamnitDB, blob2complex


# This is a copy of damnit.ctxsupport.ctxrunner.DataType, purely so that we can
Expand Down Expand Up @@ -107,7 +107,7 @@
def _read_netcdf(self, one_array=False):
import xarray as xr
load = xr.load_dataarray if one_array else xr.load_dataset
obj = load(self._h5_path, group=self.name, engine="h5netcdf")
obj = load(self._h5_path, group=self.name, engine="h5netcdf", invalid_netcdf=True)
# Remove internal attributes from loaded object
obj.attrs = {k: v for (k, v) in obj.attrs.items()
if not k.startswith('_damnit_')}
Expand Down Expand Up @@ -166,6 +166,8 @@
# after creating the VariableData object.
raise RuntimeError(f"Could not find value for '{self.name}' in p{self.proposal}, r{self.name}")
else:
if isinstance(result[0], bytes) and BlobTypes.identify(result[0]) is BlobTypes.complex:
return blob2complex(result[0])
return result[0]

def __repr__(self):
Expand Down Expand Up @@ -385,13 +387,19 @@
if "comment" not in df:
df.insert(3, "comment", None)

# Convert PNG blobs into a string
def image2str(value):
if isinstance(value, bytes) and BlobTypes.identify(value) is BlobTypes.png:
return "<image>"
# interpret blobs
def blob2type(value):
if isinstance(value, bytes):
match BlobTypes.identify(value):
case BlobTypes.png | BlobTypes.numpy:
return "<image>"
case BlobTypes.complex:
return blob2complex(value)
case BlobTypes.unknown | _:
return "<unknown>"

Check warning on line 399 in damnit/api.py

View check run for this annotation

Codecov / codecov/patch

damnit/api.py#L396-L399

Added lines #L396 - L399 were not covered by tests
else:
return value
df = df.applymap(image2str)
df = df.applymap(blob2type)

# Use the full variable titles
if with_titles:
Expand Down
26 changes: 23 additions & 3 deletions damnit/backend/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import os
import logging
import os
import sqlite3
from collections.abc import MutableMapping, ValuesView, ItemsView
from dataclasses import dataclass, asdict
from collections.abc import ItemsView, MutableMapping, ValuesView
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -71,6 +71,7 @@ class ReducedData:
class BlobTypes(Enum):
png = 'png'
numpy = 'numpy'
complex = 'complex'
unknown = 'unknown'

@classmethod
Expand All @@ -79,10 +80,27 @@ def identify(cls, blob: bytes):
return cls.png
elif blob.startswith(b'\x93NUMPY'):
return cls.numpy
elif blob.startswith(b'_DAMNIT_COMPLEX_'):
return cls.complex

return cls.unknown


def complex2blob(data: complex) -> bytes:
# convert complex to bytes
real = data.real.hex()
imag = data.imag.hex()
return f"_DAMNIT_COMPLEX_{real}_{imag}".encode()


def blob2complex(data: bytes) -> complex:
# convert bytes to complex
real, _, imag = data[16:].decode().partition("_")
real = float.fromhex(real)
imag = float.fromhex(imag)
return complex(real, imag)


def db_path(root_path: Path):
return root_path / DB_NAME

Expand Down Expand Up @@ -324,6 +342,8 @@ def set_variable(self, proposal: int, run: int, name: str, reduced):
if variable["value"] is None:
for key in variable:
variable[key] = None
elif isinstance(variable["value"], complex):
variable["value"] = complex2blob(variable["value"])

variable["proposal"] = proposal
variable["run"] = run
Expand Down
2 changes: 1 addition & 1 deletion damnit/backend/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def add_to_db(reduced_data, db: DamnitDB, proposal, run):
db.ensure_run(proposal, run, start_time=start_time.value)

for name, reduced in reduced_data.items():
if not isinstance(reduced.value, (int, float, str, bytes)):
if not isinstance(reduced.value, (int, float, str, bytes, complex)):
raise TypeError(f"Unsupported type for database: {type(reduced.value)}")

db.set_variable(proposal, run, name, reduced)
Expand Down
6 changes: 6 additions & 0 deletions damnit/ctxsupport/ctxrunner.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,12 +621,18 @@ def save_hdf5(self, hdf5_path, reduced_only=False):
dataarray = _set_encoding(dataarray)
obj = obj.rename_vars(vars_names)

# HDF5 implements some features that are not yet supported by netcdf,
# e.g. data types support (bool, complex, etc.). We don't really care
# about netcdf compatibility since we offer an API to access the
# data, so we use the `invalid_netcdf` option to be able to use these
# features.
obj.to_netcdf(
hdf5_path,
mode="a",
format="NETCDF4",
group=name,
engine="h5netcdf",
invalid_netcdf=True,
)

if os.stat(hdf5_path).st_uid == os.getuid():
Expand Down
4 changes: 3 additions & 1 deletion damnit/ctxsupport/damnit_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@
def _max_diff(self):
a = self.data
if isinstance(a, (np.ndarray, xr.DataArray)) and a.size > 1:
return abs(np.subtract(np.nanmax(a), np.nanmin(a), dtype=np.float64))
if np.issubdtype(a.dtype, np.bool_):
return 1. if (True in a) and (False in a) else 0.
return np.abs(np.subtract(np.nanmax(a), np.nanmin(a)), dtype=np.float64)

Check warning on line 181 in damnit/ctxsupport/damnit_ctx.py

View check run for this annotation

Codecov / codecov/patch

damnit/ctxsupport/damnit_ctx.py#L179-L181

Added lines #L179 - L181 were not covered by tests

def summary_attrs(self):
d = {}
Expand Down
1 change: 1 addition & 0 deletions damnit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def bool_to_numeric(data):
def fix_data_for_plotting(data):
return bool_to_numeric(make_finite(data))


def delete_variable(db, name):
# Remove from the database
db.delete_variable(name)
Expand Down
14 changes: 12 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,16 @@ def test_variable_data(mock_db_with_data, monkeypatch):

# Insert a DataSet variable
dataset_code = """
from damnit_ctx import Variable
from damnit_ctx import Cell, Variable
import xarray as xr

@Variable(title="Dataset")
def dataset(run):
return xr.Dataset(data_vars={ "foo": xr.DataArray([1, 2, 3]) })
data = xr.Dataset(data_vars={
"foo": xr.DataArray([1, 2, 3]),
"bar/baz": xr.DataArray([1+2j, 3-4j, 5+6j]),
})
return Cell(data, summary_value=data['bar/baz'][2])
"""
(db_dir / "context.py").write_text(dedent(dataset_code))
extract_mock_run(1)
Expand All @@ -122,10 +126,16 @@ def dataset(run):
dataset = rv["dataset"].read()
assert isinstance(dataset, xr.Dataset)
assert isinstance(dataset.foo, xr.DataArray)
assert isinstance(dataset.bar_baz, xr.DataArray)
assert dataset.bar_baz.dtype == np.complex128

# Datasets have a internal _damnit attribute that should be removed
assert len(dataset.attrs) == 0

summary = rv["dataset"].summary()
assert isinstance(summary, complex)
assert summary == complex(5, 6)

fig = rv['plotly_mc_plotface'].read()
assert isinstance(fig, PlotlyFigure)
assert fig == px.bar(x=["a", "b", "c"], y=[1, 3, 2])
Expand Down
27 changes: 27 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,33 @@ def dataset(run):
with h5py.File(results_hdf5_path) as f:
assert f[".reduced/dataset"].asstr()[()].startswith("Dataset")

# Test returning complex results
complex_code = """
from damnit_ctx import Variable
import numpy as np
import xarray as xr

data = np.array([1+1j, 2+2j, 3+3j])

@Variable(title="Complex Dataset")
def complex_dataset(run):
return xr.Dataset(data_vars={"foo": xr.DataArray(data),})

@Variable(title='Complex Array')
def complex_array(run):
return data
"""
complex_ctx = mkcontext(complex_code)
results = results_create(complex_ctx)
results.save_hdf5(results_hdf5_path)

dataset = xr.load_dataset(results_hdf5_path, group="complex_dataset", engine="h5netcdf", invalid_netcdf=True)
assert "foo" in dataset
assert dataset['foo'].dtype == np.complex128
with h5py.File(results_hdf5_path) as f:
assert f[".reduced/complex_dataset"].asstr()[()].startswith("Dataset")
assert np.allclose(f['complex_array/data'][()], np.array([1+1j, 2+2j, 3+3j]))

# Test getting mymdc fields
mymdc_code = """
from damnit_ctx import Variable
Expand Down
21 changes: 21 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import pytest

from damnit.backend.db import complex2blob, blob2complex


def test_metameta(mock_db):
_, db = mock_db

Expand Down Expand Up @@ -83,3 +88,19 @@ def test_tags(mock_db_with_data):

# Test untagging with nonexistent variable (should not raise error)
db.untag_variable("nonexistent_var", "important")


@pytest.mark.parametrize("value", [
1+2j,
0+0j,
-1.5-3.7j,
2.5+0j,
0+3.1j,
float('inf')+0j,
complex(float('inf'), -float('inf')),
])
def test_complex_blob_conversion(value):
# Test that converting complex -> blob -> complex preserves the value
blob = complex2blob(value)
result = blob2complex(blob)
assert result == value
Loading