Skip to content

Commit

Permalink
Merge pull request #374 from European-XFEL/feat/complexNumbers
Browse files Browse the repository at this point in the history
feat: add support for complex number
  • Loading branch information
tmichela authored Jan 27, 2025
2 parents 112fe6d + 3eb25b9 commit 76352ec
Show file tree
Hide file tree
Showing 9 changed files with 114 additions and 17 deletions.
36 changes: 28 additions & 8 deletions damnit/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import h5py

from .backend.db import BlobTypes, DamnitDB
from .backend.db import BlobTypes, DamnitDB, blob2complex


# This is a copy of damnit.ctxsupport.ctxrunner.DataType, purely so that we can
Expand Down Expand Up @@ -157,7 +157,7 @@ def summary(self):
[VariableData.read()][damnit.api.VariableData.read].
"""
result = self._db.conn.execute("""
SELECT value, max(version) FROM run_variables
SELECT value, summary_type, max(version) FROM run_variables
WHERE proposal=? AND run=? AND name=?
""", (self.proposal, self.run, self.name)).fetchone()

Expand All @@ -166,7 +166,10 @@ def summary(self):
# after creating the VariableData object.
raise RuntimeError(f"Could not find value for '{self.name}' in p{self.proposal}, r{self.name}")
else:
return result[0]
value, summary_type, version = result
if isinstance(value, bytes) and summary_type == "complex":
return blob2complex(value)
return value

def __repr__(self):
return f"<VariableData for '{self.name}' in p{self.proposal}, r{self.run}>"
Expand Down Expand Up @@ -385,13 +388,30 @@ def table(self, with_titles=False) -> "pd.DataFrame":
if "comment" not in df:
df.insert(3, "comment", None)

# Convert PNG blobs into a string
def image2str(value):
if isinstance(value, bytes) and BlobTypes.identify(value) is BlobTypes.png:
return "<image>"
# interpret blobs
def blob2type(value, summary_type=None):
    """Turn a raw table cell into something displayable.

    Parameters:
        value: The cell content. Anything that is not ``bytes`` is returned
            unchanged.
        summary_type: Summary type recorded for the variable, if any. The
            only value handled specially here is ``"complex"``.

    Returns:
        - the decoded ``complex`` number when *value* is a blob and
          *summary_type* is ``"complex"``;
        - the placeholder ``"<image>"`` for blobs identified as PNG or numpy;
        - the placeholder ``"<unknown>"`` for any other blob;
        - *value* itself when it is not a blob.
    """
    if not isinstance(value, bytes):
        return value

    # The summary type takes precedence over sniffing the blob content.
    if summary_type == "complex":
        return blob2complex(value)

    if BlobTypes.identify(value) in (BlobTypes.png, BlobTypes.numpy):
        return "<image>"
    return "<unknown>"
df = df.applymap(image2str)

def interpret_blobs(row):
    """Decode or replace every blob in one table row.

    Looks up the non-NULL summary types stored for the row's proposal and
    run, then passes each cell through ``blob2type`` together with the
    summary type registered under the cell's column name (``None`` if there
    is none).

    The row is modified in place and also returned.
    """
    # NOTE(review): this runs one query per row; consider fetching all
    # summary types once before applying if tables get large.
    records = self._db.conn.execute(
        "SELECT name, summary_type FROM run_variables WHERE proposal=? AND run=? AND summary_type IS NOT NULL",
        (row["proposal"], row["run"])).fetchall()
    summary_types = {record[0]: record[1] for record in records}

    # Iterate over the keys explicitly: the labels are needed for the lookup.
    for col in row.keys():
        row[col] = blob2type(row[col], summary_types.get(col))
    return row

df = df.apply(interpret_blobs, axis=1)

# Use the full variable titles
if with_titles:
Expand Down
24 changes: 20 additions & 4 deletions damnit/backend/db.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import os
import logging
import os
import sqlite3
from collections.abc import MutableMapping, ValuesView, ItemsView
from dataclasses import dataclass, asdict
import struct
from collections.abc import ItemsView, MutableMapping, ValuesView
from dataclasses import asdict, dataclass
from datetime import datetime, timezone
from enum import Enum
from pathlib import Path
Expand Down Expand Up @@ -65,6 +66,7 @@ class ReducedData:
value: Any
max_diff: float = None
summary_method: str = ''
summary_type: Optional[str] = None
attributes: Optional[dict] = None


Expand All @@ -83,6 +85,17 @@ def identify(cls, blob: bytes):
return cls.unknown


def complex2blob(data: complex) -> bytes:
    """Serialise a complex number into a 16-byte blob.

    The real part is written first, then the imaginary part, each as a
    little-endian IEEE 754 double (``struct`` format ``'<dd'``).

    Parameters:
        data: The number to serialise; its ``real`` and ``imag`` attributes
            are packed.

    Returns:
        The packed bytes.
    """
    return struct.pack('<dd', data.real, data.imag)


def blob2complex(data: bytes) -> complex:
    """Deserialise a 16-byte blob into a complex number.

    The blob is read as two little-endian IEEE 754 doubles (``struct``
    format ``'<dd'``): the real part followed by the imaginary part.

    Parameters:
        data: The packed bytes.

    Returns:
        The decoded complex number.

    Raises:
        struct.error: If *data* does not have the size required by the
            format.
    """
    real, imag = struct.unpack('<dd', data)
    return complex(real, imag)


def db_path(root_path: Path) -> Path:
    """Return the path obtained by joining *root_path* with ``DB_NAME``."""
    return root_path / DB_NAME

Expand Down Expand Up @@ -324,6 +337,9 @@ def set_variable(self, proposal: int, run: int, name: str, reduced):
if variable["value"] is None:
for key in variable:
variable[key] = None
elif isinstance(variable["value"], complex):
variable["value"] = complex2blob(variable["value"])
variable["summary_type"] = "complex"

variable["proposal"] = proposal
variable["run"] = run
Expand All @@ -342,7 +358,7 @@ def set_variable(self, proposal: int, run: int, name: str, reduced):
variable["version"] = 1 # if latest_version is None else latest_version + 1

# These columns should match those in the run_variables table
cols = ["proposal", "run", "name", "version", "value", "timestamp", "max_diff", "provenance", "summary_method", "attributes"]
cols = ["proposal", "run", "name", "version", "value", "timestamp", "max_diff", "provenance", "summary_method", "summary_type", "attributes"]
col_list = ", ".join(cols)
col_values = ", ".join([f":{col}" for col in cols])
col_updates = ", ".join([f"{col} = :{col}" for col in cols])
Expand Down
2 changes: 1 addition & 1 deletion damnit/backend/extract_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def add_to_db(reduced_data, db: DamnitDB, proposal, run):
db.ensure_run(proposal, run, start_time=start_time.value)

for name, reduced in reduced_data.items():
if not isinstance(reduced.value, (int, float, str, bytes)):
if not isinstance(reduced.value, (int, float, str, bytes, complex)):
raise TypeError(f"Unsupported type for database: {type(reduced.value)}")

db.set_variable(proposal, run, name, reduced)
Expand Down
4 changes: 3 additions & 1 deletion damnit/ctxsupport/damnit_ctx.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,9 @@ def get_summary(self):
def _max_diff(self):
a = self.data
if isinstance(a, (np.ndarray, xr.DataArray)) and a.size > 1:
return abs(np.subtract(np.nanmax(a), np.nanmin(a), dtype=np.float64))
if np.issubdtype(a.dtype, np.bool_):
return 1. if (True in a) and (False in a) else 0.
return np.abs(np.subtract(np.nanmax(a), np.nanmin(a)), dtype=np.float64)

def summary_attrs(self):
d = {}
Expand Down
1 change: 1 addition & 0 deletions damnit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def bool_to_numeric(data):
def fix_data_for_plotting(data):
    """Return *data* passed through ``make_finite`` and then ``bool_to_numeric``."""
    finite = make_finite(data)
    return bool_to_numeric(finite)


def delete_variable(db, name):
# Remove from the database
db.delete_variable(name)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ classifiers = ["License :: OSI Approved :: BSD License"]
dynamic = ["version"]
readme = "README.md"
dependencies = [
"h5netcdf",
"h5netcdf>=1.4.1",
"h5py",
"orjson", # used in plotly for faster json serialization
"pandas",
Expand Down
14 changes: 12 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,16 @@ def test_variable_data(mock_db_with_data, monkeypatch):

# Insert a DataSet variable
dataset_code = """
from damnit_ctx import Variable
from damnit_ctx import Cell, Variable
import xarray as xr
@Variable(title="Dataset")
def dataset(run):
return xr.Dataset(data_vars={ "foo": xr.DataArray([1, 2, 3]) })
data = xr.Dataset(data_vars={
"foo": xr.DataArray([1, 2, 3]),
"bar/baz": xr.DataArray([1+2j, 3-4j, 5+6j]),
})
return Cell(data, summary_value=data['bar/baz'][2])
"""
(db_dir / "context.py").write_text(dedent(dataset_code))
extract_mock_run(1)
Expand All @@ -122,10 +126,16 @@ def dataset(run):
dataset = rv["dataset"].read()
assert isinstance(dataset, xr.Dataset)
assert isinstance(dataset.foo, xr.DataArray)
assert isinstance(dataset.bar_baz, xr.DataArray)
assert dataset.bar_baz.dtype == np.complex128

# Datasets have an internal _damnit attribute that should be removed
assert len(dataset.attrs) == 0

summary = rv["dataset"].summary()
assert isinstance(summary, complex)
assert summary == complex(5, 6)

fig = rv['plotly_mc_plotface'].read()
assert isinstance(fig, PlotlyFigure)
assert fig == px.bar(x=["a", "b", "c"], y=[1, 3, 2])
Expand Down
27 changes: 27 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,33 @@ def dataset(run):
with h5py.File(results_hdf5_path) as f:
assert f[".reduced/dataset"].asstr()[()].startswith("Dataset")

# Test returning complex results
complex_code = """
from damnit_ctx import Variable
import numpy as np
import xarray as xr
data = np.array([1+1j, 2+2j, 3+3j])
@Variable(title="Complex Dataset")
def complex_dataset(run):
return xr.Dataset(data_vars={"foo": xr.DataArray(data),})
@Variable(title='Complex Array')
def complex_array(run):
return data
"""
complex_ctx = mkcontext(complex_code)
results = results_create(complex_ctx)
results.save_hdf5(results_hdf5_path)

dataset = xr.load_dataset(results_hdf5_path, group="complex_dataset", engine="h5netcdf")
assert "foo" in dataset
assert dataset['foo'].dtype == np.complex128
with h5py.File(results_hdf5_path) as f:
assert f[".reduced/complex_dataset"].asstr()[()].startswith("Dataset")
assert np.allclose(f['complex_array/data'][()], np.array([1+1j, 2+2j, 3+3j]))

# Test getting mymdc fields
mymdc_code = """
from damnit_ctx import Variable
Expand Down
21 changes: 21 additions & 0 deletions tests/test_db.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import pytest

from damnit.backend.db import complex2blob, blob2complex


def test_metameta(mock_db):
_, db = mock_db

Expand Down Expand Up @@ -83,3 +88,19 @@ def test_tags(mock_db_with_data):

# Test untagging with nonexistent variable (should not raise error)
db.untag_variable("nonexistent_var", "important")


@pytest.mark.parametrize("value", [
    1+2j,
    0+0j,
    -1.5-3.7j,
    2.5+0j,
    0+3.1j,
    float('inf')+0j,
    complex(float('inf'), -float('inf')),
])
def test_complex_blob_conversion(value):
    """Check that complex -> blob -> complex preserves the value."""
    blob = complex2blob(value)

    # Two packed doubles: the serialised form is always 16 bytes.
    assert isinstance(blob, bytes)
    assert len(blob) == 16

    result = blob2complex(blob)
    assert result == value

0 comments on commit 76352ec

Please sign in to comment.