Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for complex number #374

Merged
merged 9 commits into from
Jan 27, 2025
Merged
36 changes: 28 additions & 8 deletions damnit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import h5py

from .backend.db import BlobTypes, DamnitDB
from .backend.db import BlobTypes, DamnitDB, blob2complex


# This is a copy of damnit.ctxsupport.ctxrunner.DataType, purely so that we can
Expand Down Expand Up @@ -157,7 +157,7 @@
[VariableData.read()][damnit.api.VariableData.read].
"""
result = self._db.conn.execute("""
SELECT value, max(version) FROM run_variables
SELECT value, summary_type, max(version) FROM run_variables
WHERE proposal=? AND run=? AND name=?
""", (self.proposal, self.run, self.name)).fetchone()

Expand All @@ -166,7 +166,10 @@
# after creating the VariableData object.
raise RuntimeError(f"Could not find value for '{self.name}' in p{self.proposal}, r{self.name}")
else:
return result[0]
value, summary_type, version = result
if isinstance(value, bytes) and summary_type == "complex":
return blob2complex(value)
return value

def __repr__(self):
return f"<VariableData for '{self.name}' in p{self.proposal}, r{self.run}>"
Expand Down Expand Up @@ -385,13 +388,30 @@
if "comment" not in df:
df.insert(3, "comment", None)

# Convert PNG blobs into a string
def image2str(value):
if isinstance(value, bytes) and BlobTypes.identify(value) is BlobTypes.png:
return "<image>"
# interpret blobs
def blob2type(value, summary_type=None):
if isinstance(value, bytes):
if summary_type == "complex":
return blob2complex(value)

Check warning on line 395 in damnit/api.py

View check run for this annotation

Codecov / codecov/patch

damnit/api.py#L395

Added line #L395 was not covered by tests
match BlobTypes.identify(value):
case BlobTypes.png | BlobTypes.numpy:
return "<image>"
case BlobTypes.unknown | _:
return "<unknown>"

Check warning on line 400 in damnit/api.py

View check run for this annotation

Codecov / codecov/patch

damnit/api.py#L399-L400

Added lines #L399 - L400 were not covered by tests
else:
return value
df = df.applymap(image2str)

def interpret_blobs(row):
summary_types = self._db.conn.execute(
"SELECT name, summary_type FROM run_variables WHERE proposal=? AND run=? AND summary_type IS NOT NULL",
(row["proposal"], row["run"])).fetchall()
summary_types = { row[0]: row[1] for row in summary_types }

for col in row.keys():
row[col] = blob2type(row[col], summary_types.get(col))
return row

df = df.apply(interpret_blobs, axis=1)

# Use the full variable titles
if with_titles:
Expand Down
24 changes: 20 additions & 4 deletions damnit/backend/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import os
import logging
import os
import sqlite3
from collections.abc import MutableMapping, ValuesView, ItemsView
from dataclasses import dataclass, asdict
import struct
from collections.abc import ItemsView, MutableMapping, ValuesView
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -65,6 +66,7 @@ class ReducedData:
value: Any
max_diff: float = None
summary_method: str = ''
summary_type: Optional[str] = None
attributes: Optional[dict] = None


Expand All @@ -83,6 +85,17 @@ def identify(cls, blob: bytes):
return cls.unknown


def complex2blob(data: complex) -> bytes:
# convert complex to bytes
return struct.pack('<dd', data.real, data.imag)


def blob2complex(data: bytes) -> complex:
# convert bytes to complex
real, imag = struct.unpack('<dd', data)
return complex(real, imag)


def db_path(root_path: Path):
return root_path / DB_NAME

Expand Down Expand Up @@ -324,6 +337,9 @@ def set_variable(self, proposal: int, run: int, name: str, reduced):
if variable["value"] is None:
for key in variable:
variable[key] = None
elif isinstance(variable["value"], complex):
variable["value"] = complex2blob(variable["value"])
variable["summary_type"] = "complex"

variable["proposal"] = proposal
variable["run"] = run
Expand All @@ -342,7 +358,7 @@ def set_variable(self, proposal: int, run: int, name: str, reduced):
variable["version"] = 1 # if latest_version is None else latest_version + 1

# These columns should match those in the run_variables table
cols = ["proposal", "run", "name", "version", "value", "timestamp", "max_diff", "provenance", "summary_method", "attributes"]
cols = ["proposal", "run", "name", "version", "value", "timestamp", "max_diff", "provenance", "summary_method", "summary_type", "attributes"]
col_list = ", ".join(cols)
col_values = ", ".join([f":{col}" for col in cols])
col_updates = ", ".join([f"{col} = :{col}" for col in cols])
Expand Down
2 changes: 1 addition & 1 deletion damnit/backend/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def add_to_db(reduced_data, db: DamnitDB, proposal, run):
db.ensure_run(proposal, run, start_time=start_time.value)

for name, reduced in reduced_data.items():
if not isinstance(reduced.value, (int, float, str, bytes)):
if not isinstance(reduced.value, (int, float, str, bytes, complex)):
raise TypeError(f"Unsupported type for database: {type(reduced.value)}")

db.set_variable(proposal, run, name, reduced)
Expand Down
4 changes: 3 additions & 1 deletion damnit/ctxsupport/damnit_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,9 @@
def _max_diff(self):
a = self.data
if isinstance(a, (np.ndarray, xr.DataArray)) and a.size > 1:
return abs(np.subtract(np.nanmax(a), np.nanmin(a), dtype=np.float64))
if np.issubdtype(a.dtype, np.bool_):
return 1. if (True in a) and (False in a) else 0.
return np.abs(np.subtract(np.nanmax(a), np.nanmin(a)), dtype=np.float64)

Check warning on line 181 in damnit/ctxsupport/damnit_ctx.py

View check run for this annotation

Codecov / codecov/patch

damnit/ctxsupport/damnit_ctx.py#L179-L181

Added lines #L179 - L181 were not covered by tests

def summary_attrs(self):
d = {}
Expand Down
1 change: 1 addition & 0 deletions damnit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def bool_to_numeric(data):
def fix_data_for_plotting(data):
return bool_to_numeric(make_finite(data))


def delete_variable(db, name):
# Remove from the database
db.delete_variable(name)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = ["License :: OSI Approved :: BSD License"]
dynamic = ["version"]
readme = "README.md"
dependencies = [
"h5netcdf",
"h5netcdf>=1.4.1",
"h5py",
"orjson", # used in plotly for faster json serialization
"pandas",
Expand Down
14 changes: 12 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,16 @@ def test_variable_data(mock_db_with_data, monkeypatch):

# Insert a DataSet variable
dataset_code = """
from damnit_ctx import Variable
from damnit_ctx import Cell, Variable
import xarray as xr

@Variable(title="Dataset")
def dataset(run):
return xr.Dataset(data_vars={ "foo": xr.DataArray([1, 2, 3]) })
data = xr.Dataset(data_vars={
"foo": xr.DataArray([1, 2, 3]),
"bar/baz": xr.DataArray([1+2j, 3-4j, 5+6j]),
})
return Cell(data, summary_value=data['bar/baz'][2])
"""
(db_dir / "context.py").write_text(dedent(dataset_code))
extract_mock_run(1)
Expand All @@ -122,10 +126,16 @@ def dataset(run):
dataset = rv["dataset"].read()
assert isinstance(dataset, xr.Dataset)
assert isinstance(dataset.foo, xr.DataArray)
assert isinstance(dataset.bar_baz, xr.DataArray)
assert dataset.bar_baz.dtype == np.complex128

# Datasets have a internal _damnit attribute that should be removed
assert len(dataset.attrs) == 0

summary = rv["dataset"].summary()
assert isinstance(summary, complex)
assert summary == complex(5, 6)

fig = rv['plotly_mc_plotface'].read()
assert isinstance(fig, PlotlyFigure)
assert fig == px.bar(x=["a", "b", "c"], y=[1, 3, 2])
Expand Down
27 changes: 27 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,33 @@ def dataset(run):
with h5py.File(results_hdf5_path) as f:
assert f[".reduced/dataset"].asstr()[()].startswith("Dataset")

# Test returning complex results
complex_code = """
from damnit_ctx import Variable
import numpy as np
import xarray as xr

data = np.array([1+1j, 2+2j, 3+3j])

@Variable(title="Complex Dataset")
def complex_dataset(run):
return xr.Dataset(data_vars={"foo": xr.DataArray(data),})

@Variable(title='Complex Array')
def complex_array(run):
return data
"""
complex_ctx = mkcontext(complex_code)
results = results_create(complex_ctx)
results.save_hdf5(results_hdf5_path)

dataset = xr.load_dataset(results_hdf5_path, group="complex_dataset", engine="h5netcdf")
assert "foo" in dataset
assert dataset['foo'].dtype == np.complex128
with h5py.File(results_hdf5_path) as f:
assert f[".reduced/complex_dataset"].asstr()[()].startswith("Dataset")
assert np.allclose(f['complex_array/data'][()], np.array([1+1j, 2+2j, 3+3j]))

# Test getting mymdc fields
mymdc_code = """
from damnit_ctx import Variable
Expand Down
21 changes: 21 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import pytest

from damnit.backend.db import complex2blob, blob2complex


def test_metameta(mock_db):
_, db = mock_db

Expand Down Expand Up @@ -83,3 +88,19 @@ def test_tags(mock_db_with_data):

# Test untagging with nonexistent variable (should not raise error)
db.untag_variable("nonexistent_var", "important")


@pytest.mark.parametrize("value", [
1+2j,
0+0j,
-1.5-3.7j,
2.5+0j,
0+3.1j,
float('inf')+0j,
complex(float('inf'), -float('inf')),
])
def test_complex_blob_conversion(value):
# Test that converting complex -> blob -> complex preserves the value
blob = complex2blob(value)
result = blob2complex(blob)
assert result == value
Loading