diff --git a/Makefile b/Makefile index 51cbd65..554902e 100644 --- a/Makefile +++ b/Makefile @@ -34,6 +34,10 @@ serve: build ## Serve the stlite static build dev: ## Run normal Streamlit locally $(UV) run streamlit run $(APP_PY) +.PHONY: editor +editor: ## Run the model editor app + $(UV) run streamlit run editor.py + .PHONY: stlite stlite: setup serve ## Install, build, and serve the stlite app diff --git a/editor.py b/editor.py new file mode 100644 index 0000000..86b58f5 --- /dev/null +++ b/editor.py @@ -0,0 +1,10 @@ +import sys +import runpy +from pathlib import Path + +SRC_DIR = Path(__file__).resolve().parent / "src" +if str(SRC_DIR) not in sys.path: + sys.path.insert(0, str(SRC_DIR)) + +# Importing this module executes the Streamlit editor app definition. +runpy.run_module("epicc.editor.__main__", run_name="__main__") diff --git a/src/epicc/editor/__init__.py b/src/epicc/editor/__init__.py new file mode 100644 index 0000000..1fa41ff --- /dev/null +++ b/src/epicc/editor/__init__.py @@ -0,0 +1 @@ +"""EPICC Model Editor package.""" diff --git a/src/epicc/editor/__main__.py b/src/epicc/editor/__main__.py new file mode 100644 index 0000000..99a4f02 --- /dev/null +++ b/src/epicc/editor/__main__.py @@ -0,0 +1,575 @@ +""" +Model Editor – a Streamlit app for building and editing YAML model files. + +Provides a form-based GUI for constructing model definitions that conform +to the ``epicc.model.schema.Model`` Pydantic schema. Users can start from +scratch, upload an existing YAML file, validate the document in real time, +and download the result. +""" + +from __future__ import annotations + +import copy +from typing import Any + +import streamlit as st +from pydantic import ValidationError + +from epicc.editor.helpers import ( + DEFAULT_STATE, + build_model_dict, + serialize_to_yaml, + validate_model_dict, + yaml_to_state, +) + +# --------------------------------------------------------------------------- +# Page configuration +# --------------------------------------------------------------------------- + +st.set_page_config(page_title="EPICC Model Editor", layout="wide") +st.title("EPICC Model Editor") +st.markdown( + "Build, edit, and validate YAML model files for the **EPICC Cost Calculator**." +) + +# --------------------------------------------------------------------------- +# Widget-key versioning +# --------------------------------------------------------------------------- +# Streamlit caches widget values by key. When the list of authors (or +# parameters, etc.) changes structurally (upload, add, remove), cached +# values under old keys can shadow the correct data. We embed a version +# counter in every dynamic widget key so that a version bump forces +# Streamlit to create genuinely new widgets that honour the ``value`` +# parameter from the current list state. + + +def _bump_version() -> None: + """Increment the widget-key version counter.""" + st.session_state["_wv"] = st.session_state.get("_wv", 0) + 1 + + +# --------------------------------------------------------------------------- +# Modal dialogs +# --------------------------------------------------------------------------- + + +@st.dialog("YAML Serialization Error") +def _show_yaml_error(msg: str) -> None: + """Show a YAML serialization error in a modal dialog.""" + st.error("The current form data could not be serialized to YAML.") + st.code(msg) + + +@st.dialog("File Load Error") +def _show_upload_error(msg: str) -> None: + """Show a file-load error in a modal dialog.""" + st.error("The uploaded file could not be loaded.") + st.code(msg) + + +def _v() -> int: + """Return the current widget-key version.""" + return int(st.session_state.get("_wv", 0)) + + +# --------------------------------------------------------------------------- +# Callbacks – executed *before* the next script rerun +# --------------------------------------------------------------------------- + + +def _add_item(section: str, template: dict[str, Any]) -> None: + st.session_state[section].append(copy.deepcopy(template)) + _bump_version() + + +def _remove_item(section: str, idx: int) -> None: + items: list[Any] = st.session_state[section] + if 0 <= idx < len(items): + items.pop(idx) + _bump_version() + + +# --------------------------------------------------------------------------- +# Session-state initialization +# --------------------------------------------------------------------------- + + +def _init_state() -> None: + for key, default in DEFAULT_STATE.items(): + if key not in st.session_state: + if isinstance(default, (list, dict)): + st.session_state[key] = copy.deepcopy(default) + else: + st.session_state[key] = default + + +_init_state() + +# --------------------------------------------------------------------------- +# Upload widget +# --------------------------------------------------------------------------- + +with st.sidebar: + st.header("Load / Save") + uploaded = st.file_uploader("Upload a model YAML file", type=["yaml", "yml"]) + if uploaded is not None: + file_id = f"{uploaded.name}:{uploaded.size}" + if st.session_state.get("_uploaded_file_id") != file_id: + try: + state = yaml_to_state(uploaded.getvalue()) + except Exception as exc: + _show_upload_error(str(exc)) + else: + st.session_state["_uploaded_file_id"] = file_id + _bump_version() + for k, v in state.items(): + st.session_state[k] = v + st.rerun() + if st.session_state.get("_uploaded_file_id") == file_id: + st.success(f"Loaded **{uploaded.name}**") + +# --------------------------------------------------------------------------- +# Sync helper – read widget values back into the canonical list +# --------------------------------------------------------------------------- + + +def _sync_list( + section: str, + widgets: list[dict[str, Any]], +) -> None: + """Replace *section* in session state with *widgets* (values read from UI).""" + st.session_state[section] = widgets if widgets else copy.deepcopy( + DEFAULT_STATE.get(section, []) + ) + + +# --------------------------------------------------------------------------- +# Form sections +# --------------------------------------------------------------------------- + +meta_tab, params_tab, eqs_tab, scenarios_tab, report_tab, figures_tab = st.tabs( + ["Metadata", "Parameters", "Equations", "Scenarios", "Report", "Figures"] +) + +# ---- Metadata ---- +with meta_tab: + st.subheader("Model Metadata") + st.session_state["model_title"] = st.text_input( + "Title *", st.session_state["model_title"] + ) + st.session_state["model_description"] = st.text_area( + "Description *", st.session_state["model_description"], height=100 + ) + + st.markdown("**Authors**") + _authors: list[dict[str, str]] = st.session_state["authors"] + _updated_authors: list[dict[str, str]] = [] + for _i, _author in enumerate(_authors): + _cols = st.columns([3, 3, 1]) + _name = _cols[0].text_input( + "Name", _author["name"], key=f"author_name_{_v()}_{_i}" + ) + _email = _cols[1].text_input( + "Email", _author["email"], key=f"author_email_{_v()}_{_i}" + ) + _cols[2].button( + "X", key=f"rm_author_{_v()}_{_i}", + on_click=_remove_item, args=("authors", _i), + ) + _updated_authors.append({"name": _name, "email": _email}) + + st.button( + "+ Add author", + on_click=_add_item, + args=("authors", {"name": "", "email": ""}), + ) + _sync_list("authors", _updated_authors) + +# ---- Parameters ---- +with params_tab: + st.subheader("Parameters") + st.caption( + "Each parameter needs a unique **ID** (used in equations), a display **label**, " + "a **type**, and a **default** value. For enum parameters, specify options as " + "``KEY: Label`` lines." + ) + _params: list[dict[str, Any]] = st.session_state["parameters"] + _updated_params: list[dict[str, Any]] = [] + for _i, _p in enumerate(_params): + with st.expander( + _p.get("label") or _p.get("id") or f"Parameter {_i + 1}", + expanded=(_i == len(_params) - 1 and not _p.get("id")), + ): + _c1, _c2 = st.columns(2) + _pid = _c1.text_input("ID *", _p.get("id", ""), key=f"pid_{_v()}_{_i}") + _ptype = _c2.selectbox( + "Type *", + ["integer", "number", "string", "boolean", "enum"], + index=["integer", "number", "string", "boolean", "enum"].index( + _p.get("type", "number") + ), + key=f"ptype_{_v()}_{_i}", + ) + _plabel = st.text_input("Label *", _p.get("label", ""), key=f"plabel_{_v()}_{_i}") + _pdesc = st.text_input( + "Description", _p.get("description", ""), key=f"pdesc_{_v()}_{_i}" + ) + + _dc1, _dc2, _dc3 = st.columns(3) + _pdefault_raw = _dc1.text_input( + "Default *", str(_p.get("default", "")), key=f"pdef_{_v()}_{_i}" + ) + _pmin_raw = _dc2.text_input( + "Min", str(_p.get("min", "")), key=f"pmin_{_v()}_{_i}" + ) + _pmax_raw = _dc3.text_input( + "Max", str(_p.get("max", "")), key=f"pmax_{_v()}_{_i}" + ) + _punit = st.text_input("Unit", _p.get("unit", ""), key=f"punit_{_v()}_{_i}") + _prefs = st.text_area( + "References (one per line)", + _p.get("references", ""), + key=f"prefs_{_v()}_{_i}", + height=68, + ) + + _poptions = "" + if _ptype == "enum": + _poptions = st.text_area( + "Options (KEY: Label, one per line)", + _p.get("options", ""), + key=f"popts_{_v()}_{_i}", + height=68, + ) + + st.button( + "Remove parameter", key=f"rm_param_{_v()}_{_i}", + on_click=_remove_item, args=("parameters", _i), + ) + _updated_params.append( + { + "id": _pid, + "type": _ptype, + "label": _plabel, + "description": _pdesc, + "default": _pdefault_raw, + "min": _pmin_raw, + "max": _pmax_raw, + "unit": _punit, + "references": _prefs, + "options": _poptions, + } + ) + + st.button( + "+ Add parameter", + on_click=_add_item, + args=( + "parameters", + { + "id": "", + "type": "number", + "label": "", + "description": "", + "default": "0", + "min": "0", + "max": "100", + "unit": "", + "references": "", + "options": "", + }, + ), + ) + _sync_list("parameters", _updated_params) + +# ---- Equations ---- +with eqs_tab: + st.subheader("Equations") + st.caption( + "Each equation has a unique **ID**, a **label**, and a **compute** expression " + "that may reference parameter IDs, scenario variable names, or other equation IDs." + ) + _eqs: list[dict[str, Any]] = st.session_state["equations"] + _updated_eqs: list[dict[str, Any]] = [] + for _i, _eq in enumerate(_eqs): + with st.expander( + _eq.get("label") or _eq.get("id") or f"Equation {_i + 1}", + expanded=(_i == len(_eqs) - 1 and not _eq.get("id")), + ): + _c1, _c2 = st.columns(2) + _eid = _c1.text_input("ID *", _eq.get("id", ""), key=f"eid_{_v()}_{_i}") + _elabel = _c2.text_input( + "Label *", _eq.get("label", ""), key=f"elabel_{_v()}_{_i}" + ) + _ec1, _ec2 = st.columns(2) + _eunit = _ec1.text_input("Unit", _eq.get("unit", ""), key=f"eunit_{_v()}_{_i}") + _eoutput = _ec2.selectbox( + "Output type", + ["number", "integer"], + index=["number", "integer"].index( + _eq.get("output", "number") or "number" + ), + key=f"eoutput_{_v()}_{_i}", + ) + _ecompute = st.text_area( + "Compute expression *", + _eq.get("compute", ""), + key=f"ecomp_{_v()}_{_i}", + height=80, + ) + + st.button( + "Remove equation", key=f"rm_eq_{_v()}_{_i}", + on_click=_remove_item, args=("equations", _i), + ) + _updated_eqs.append( + { + "id": _eid, + "label": _elabel, + "unit": _eunit, + "output": _eoutput, + "compute": _ecompute, + } + ) + + st.button( + "+ Add equation", + on_click=_add_item, + args=( + "equations", + {"id": "", "label": "", "unit": "", "output": "number", "compute": ""}, + ), + ) + _sync_list("equations", _updated_eqs) + +# ---- Scenarios ---- +with scenarios_tab: + st.subheader("Scenarios") + st.caption( + "Define scenarios with a unique **ID**, a **label**, and scenario **variables** " + "as ``name: value`` lines (one per line)." + ) + _scenarios: list[dict[str, Any]] = st.session_state["scenarios"] + _updated_scenarios: list[dict[str, Any]] = [] + for _i, _sc in enumerate(_scenarios): + with st.expander( + _sc.get("label") or _sc.get("id") or f"Scenario {_i + 1}", + expanded=(_i == len(_scenarios) - 1 and not _sc.get("id")), + ): + _c1, _c2 = st.columns(2) + _sid = _c1.text_input("ID *", _sc.get("id", ""), key=f"sid_{_v()}_{_i}") + _slabel = _c2.text_input( + "Label *", _sc.get("label", ""), key=f"slabel_{_v()}_{_i}" + ) + _svars = st.text_area( + "Variables (name: value, one per line) *", + _sc.get("vars", ""), + key=f"svars_{_v()}_{_i}", + height=80, + ) + + st.button( + "Remove scenario", key=f"rm_sc_{_v()}_{_i}", + on_click=_remove_item, args=("scenarios", _i), + ) + _updated_scenarios.append( + {"id": _sid, "label": _slabel, "vars": _svars} + ) + + st.button( + "+ Add scenario", + on_click=_add_item, + args=("scenarios", {"id": "", "label": "", "vars": ""}), + ) + _sync_list("scenarios", _updated_scenarios) + +# ---- Report ---- +with report_tab: + st.subheader("Report Blocks") + st.caption( + "Build the report from blocks. **Markdown** blocks hold free-form text. " + "**Table** and **Graph** blocks reference equation IDs. " + "Row format: ``Label | equation_id [| emphasis]``." + ) + _blocks: list[dict[str, Any]] = st.session_state["report_blocks"] + _updated_blocks: list[dict[str, Any]] = [] + for _i, _blk in enumerate(_blocks): + _btype = _blk.get("type", "markdown") + with st.expander( + f"{_btype.title()} block {_i + 1}", expanded=(_i == len(_blocks) - 1) + ): + _new_type = st.selectbox( + "Block type", + ["markdown", "table", "figure", "graph"], + index=["markdown", "table", "figure", "graph"].index(_btype), + key=f"btype_{_v()}_{_i}", + ) + _entry: dict[str, Any] = {"type": _new_type} + + if _new_type == "markdown": + _entry["content"] = st.text_area( + "Content (Markdown)", + _blk.get("content", ""), + key=f"bcontent_{_v()}_{_i}", + height=150, + ) + elif _new_type == "table": + _entry["caption"] = st.text_input( + "Caption", _blk.get("caption", ""), key=f"tcap_{_v()}_{_i}" + ) + _entry["columns"] = st.text_input( + "Columns (comma-separated scenario IDs, blank = all)", + _blk.get("columns", ""), + key=f"tcols_{_v()}_{_i}", + ) + _entry["rows"] = st.text_area( + "Rows (Label | equation_id [| emphasis])", + _blk.get("rows", ""), + key=f"trows_{_v()}_{_i}", + height=120, + ) + elif _new_type == "figure": + _entry["id"] = st.text_input( + "Figure ID", _blk.get("id", ""), key=f"fid_{_v()}_{_i}" + ) + elif _new_type == "graph": + _entry["kind"] = st.selectbox( + "Graph kind", + ["bar", "stacked_bar", "line", "pie"], + index=["bar", "stacked_bar", "line", "pie"].index( + _blk.get("kind", "bar") + ), + key=f"gkind_{_v()}_{_i}", + ) + _entry["title"] = st.text_input( + "Title", _blk.get("title", ""), key=f"gtitle_{_v()}_{_i}" + ) + _entry["caption"] = st.text_input( + "Caption", _blk.get("caption", ""), key=f"gcap_{_v()}_{_i}" + ) + _entry["columns"] = st.text_input( + "Columns (comma-separated scenario IDs, blank = all)", + _blk.get("columns", ""), + key=f"gcols_{_v()}_{_i}", + ) + _entry["rows"] = st.text_area( + "Rows (Label | equation_id [| emphasis])", + _blk.get("rows", ""), + key=f"grows_{_v()}_{_i}", + height=120, + ) + + st.button( + "Remove block", key=f"rm_blk_{_v()}_{_i}", + on_click=_remove_item, args=("report_blocks", _i), + ) + _updated_blocks.append(_entry) + + st.button( + "+ Add report block", + on_click=_add_item, + args=("report_blocks", {"type": "markdown", "content": ""}), + ) + _sync_list("report_blocks", _updated_blocks) + +# ---- Figures ---- +with figures_tab: + st.subheader("Figures") + st.caption( + "Define custom figures with Python code. These are referenced from " + "**figure** report blocks by their **ID**." + ) + _figs: list[dict[str, Any]] = st.session_state["figures"] + _updated_figs: list[dict[str, Any]] = [] + for _i, _fig in enumerate(_figs): + with st.expander( + _fig.get("title") or _fig.get("id") or f"Figure {_i + 1}", + expanded=True, + ): + _c1, _c2 = st.columns(2) + _fid = _c1.text_input( + "ID *", _fig.get("id", ""), key=f"figid_{_v()}_{_i}" + ) + _ftitle = _c2.text_input( + "Title *", _fig.get("title", ""), key=f"figtitle_{_v()}_{_i}" + ) + _falt = st.text_input( + "Alt text", _fig.get("alt_text", "") or "", key=f"figalt_{_v()}_{_i}" + ) + _fcode = st.text_area( + "Python code", + _fig.get("py_code", "") or "", + key=f"figcode_{_v()}_{_i}", + height=120, + ) + + st.button( + "Remove figure", key=f"rm_fig_{_v()}_{_i}", + on_click=_remove_item, args=("figures", _i), + ) + _updated_figs.append( + { + "id": _fid, + "title": _ftitle, + "alt_text": _falt, + "py_code": _fcode, + } + ) + + st.button( + "+ Add figure", + on_click=_add_item, + args=( + "figures", + {"id": "", "title": "", "alt_text": "", "py_code": ""}, + ), + ) + _sync_list("figures", _updated_figs) + +# --------------------------------------------------------------------------- +# Validate & Download +# --------------------------------------------------------------------------- + +st.divider() +val_col, dl_col = st.columns([1, 1]) + +with val_col: + if st.button("Validate model", type="primary", use_container_width=True): + doc = build_model_dict({str(k): v for k, v in st.session_state.items()}) + try: + validate_model_dict(doc) + st.success("Model is valid!") + except ValidationError as exc: + issues = exc.errors() + issue_word = "issue" if len(issues) == 1 else "issues" + st.error(f"Validation failed ({len(issues)} {issue_word})") + with st.expander("Validation details", expanded=True): + for issue in issues: + loc_parts = issue.get("loc", []) + path = " > ".join(str(p) for p in loc_parts) if loc_parts else "(root)" + st.write(f"- **{path}**: {issue.get('msg', 'Invalid value')}") + +with dl_col: + doc = build_model_dict({str(k): v for k, v in st.session_state.items()}) + try: + yaml_bytes = serialize_to_yaml(doc) + except Exception as exc: + yaml_bytes = b"" + _show_yaml_error(str(exc)) + + if yaml_bytes: + st.download_button( + "Download YAML", + data=yaml_bytes, + file_name="model.yaml", + mime="text/yaml", + use_container_width=True, + ) + +# Show a live YAML preview in a collapsed section +with st.expander("YAML Preview"): + if yaml_bytes: + st.code(yaml_bytes.decode("utf-8"), language="yaml") + else: + st.info("Fill in the form fields above to see a preview.") diff --git a/src/epicc/editor/helpers.py b/src/epicc/editor/helpers.py new file mode 100644 index 0000000..a2249f7 --- /dev/null +++ b/src/epicc/editor/helpers.py @@ -0,0 +1,398 @@ +"""Pure-logic helpers for the model editor. + +These functions contain no Streamlit dependencies and can be tested in +plain pytest without mocking the Streamlit runtime. +""" + +from __future__ import annotations + +import io +from typing import Any + +from epicc.formats import opaque_to_typed +from epicc.formats.yaml import YAMLFormat +from epicc.model.schema import Model + +# --------------------------------------------------------------------------- +# Coercion helpers +# --------------------------------------------------------------------------- + + +def coerce_numeric(value: str) -> int | float | str | bool: + """Attempt to coerce a string to a numeric type or boolean.""" + if value.lower() in ("true", "false"): + return value.lower() == "true" + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + return value + + +def coerce_numeric_or_none(value: str) -> int | float | None: + """Coerce to numeric or return ``None`` for empty / non-numeric strings.""" + if not value.strip(): + return None + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + return None + + +# --------------------------------------------------------------------------- +# Parsing helpers +# --------------------------------------------------------------------------- + + +def parse_key_value_lines(text: str) -> dict[str, Any]: + """Parse ``key: value`` lines into a dict, coercing numeric values.""" + result: dict[str, Any] = {} + for line in text.strip().splitlines(): + if ":" not in line: + continue + key, _, val = line.partition(":") + val = val.strip() + try: + result[key.strip()] = int(val) + except ValueError: + try: + result[key.strip()] = float(val) + except ValueError: + result[key.strip()] = val + return result + + +def parse_table_rows(text: str) -> list[dict[str, Any]]: + """Parse ``label | value [| emphasis]`` lines back to dicts.""" + rows: list[dict[str, Any]] = [] + for line in text.strip().splitlines(): + parts = [p.strip() for p in line.split("|")] + if len(parts) >= 2: + row: dict[str, Any] = {"label": parts[0], "value": parts[1]} + if len(parts) >= 3 and parts[2]: + row["emphasis"] = parts[2] + rows.append(row) + return rows + + +def table_row_to_str(row: dict[str, str]) -> str: + """Serialize a table-row dict to a ``label | value [| emphasis]`` string.""" + parts = [row.get("label", ""), row.get("value", "")] + emphasis = row.get("emphasis", "") + if emphasis: + parts.append(emphasis) + return " | ".join(parts) + + +# --------------------------------------------------------------------------- +# YAML load → session-state dicts +# --------------------------------------------------------------------------- + +# The following type alias describes the flat dict shape stored in +# ``st.session_state`` for each section of the editor form. +EditorState = dict[str, Any] + +#: Default state values for a blank model document. +DEFAULT_STATE: dict[str, Any] = { + "model_title": "", + "model_description": "", + "authors": [{"name": "", "email": ""}], + "parameters": [ + { + "id": "", + "type": "number", + "label": "", + "description": "", + "default": 0.0, + "min": 0.0, + "max": 100.0, + "unit": "", + "references": "", + "options": "", + } + ], + "equations": [{"id": "", "label": "", "unit": "", "output": "number", "compute": ""}], + "groups": [], + "scenarios": [{"id": "", "label": "", "vars": ""}], + "report_blocks": [{"type": "markdown", "content": ""}], + "figures": [], +} + + +def yaml_to_state(raw: bytes) -> EditorState: + """Parse raw YAML bytes into a flat *state* dict for the editor form. + + Returns a dict whose keys match :data:`DEFAULT_STATE`. + """ + fmt = YAMLFormat("upload.yaml") + data, _ = fmt.read(io.BytesIO(raw)) + + state: EditorState = {} + + state["model_title"] = data.get("title", "") + state["model_description"] = data.get("description", "") + + # Authors + authors_raw = data.get("authors", []) + state["authors"] = [ + {"name": a.get("name", ""), "email": a.get("email", "")} for a in authors_raw + ] or [{"name": "", "email": ""}] + + # Parameters + params_raw: dict = data.get("parameters", {}) + state["parameters"] = [ + { + "id": pid, + "type": p.get("type", "number"), + "label": p.get("label", ""), + "description": p.get("description", ""), + "default": p.get("default", 0), + "min": p.get("min") if p.get("min") is not None else "", + "max": p.get("max") if p.get("max") is not None else "", + "unit": p.get("unit", ""), + "references": "\n".join(p.get("references") or []), + "options": "\n".join( + f"{k}: {v}" for k, v in (p.get("options") or {}).items() + ), + } + for pid, p in params_raw.items() + ] or DEFAULT_STATE["parameters"] + + # Equations + eqs_raw: dict = data.get("equations", {}) + state["equations"] = [ + { + "id": eid, + "label": e.get("label", ""), + "unit": e.get("unit", ""), + "output": e.get("output", "number") or "number", + "compute": e.get("compute", "").strip(), + } + for eid, e in eqs_raw.items() + ] or DEFAULT_STATE["equations"] + + # Groups (pass through) + state["groups"] = data.get("groups", []) + + # Scenarios + scenarios_raw = data.get("scenarios", []) + state["scenarios"] = [ + { + "id": s.get("id", ""), + "label": s.get("label", ""), + "vars": "\n".join(f"{k}: {v}" for k, v in (s.get("vars") or {}).items()), + } + for s in scenarios_raw + ] or DEFAULT_STATE["scenarios"] + + # Report blocks + blocks_raw = data.get("report", []) + report_blocks: list[dict[str, Any]] = [] + for b in blocks_raw: + btype = b.get("type", "markdown") + entry: dict[str, Any] = {"type": btype} + if btype == "markdown": + entry["content"] = b.get("content", "") + elif btype == "table": + entry["caption"] = b.get("caption", "") + entry["columns"] = ", ".join(b.get("columns") or []) + entry["rows"] = "\n".join( + table_row_to_str(r) for r in b.get("rows", []) + ) + elif btype == "figure": + entry["id"] = b.get("id", "") + elif btype == "graph": + entry["kind"] = b.get("kind", "bar") + entry["title"] = b.get("title", "") + entry["caption"] = b.get("caption", "") + entry["columns"] = ", ".join(b.get("columns") or []) + entry["rows"] = "\n".join( + table_row_to_str(r) for r in b.get("rows", []) + ) + report_blocks.append(entry) + state["report_blocks"] = report_blocks or DEFAULT_STATE["report_blocks"] + + # Figures + figs_raw = data.get("figures", []) + state["figures"] = [ + { + "id": f.get("id", ""), + "title": f.get("title", ""), + "alt_text": f.get("alt-text", f.get("alt_text", "")), + "py_code": f.get("py-code", f.get("py_code", "")), + } + for f in figs_raw + ] or [] + + return state + + +# --------------------------------------------------------------------------- +# Session-state → model document dict +# --------------------------------------------------------------------------- + + +def build_model_dict(state: EditorState) -> dict[str, Any]: + """Assemble a model document dict from the flat editor *state*.""" + doc: dict[str, Any] = { + "title": state.get("model_title", ""), + "description": state.get("model_description", ""), + } + + # Authors (omit empty entries) + doc["authors"] = [ + {k: v for k, v in a.items() if v} + for a in state.get("authors", []) + if a.get("name") + ] + + # Parameters + parameters: dict[str, Any] = {} + for p in state.get("parameters", []): + pid = p["id"].strip() + if not pid: + continue + param: dict[str, Any] = { + "type": p["type"], + "label": p["label"], + "default": coerce_numeric(str(p["default"])), + } + if p.get("description"): + param["description"] = p["description"] + pmin = coerce_numeric_or_none(str(p.get("min", ""))) + pmax = coerce_numeric_or_none(str(p.get("max", ""))) + if pmin is not None: + param["min"] = pmin + if pmax is not None: + param["max"] = pmax + if p.get("unit"): + param["unit"] = p["unit"] + refs = [r.strip() for r in p.get("references", "").splitlines() if r.strip()] + if refs: + param["references"] = refs + if p["type"] == "enum" and p.get("options"): + options = parse_key_value_lines(p["options"]) + param["options"] = {str(k): str(v) for k, v in options.items()} + parameters[pid] = param + doc["parameters"] = parameters + + # Equations + equations: dict[str, Any] = {} + for eq in state.get("equations", []): + eid = eq["id"].strip() + if not eid: + continue + entry: dict[str, Any] = {"label": eq["label"], "compute": eq["compute"]} + if eq.get("unit"): + entry["unit"] = eq["unit"] + if eq.get("output"): + entry["output"] = eq["output"] + equations[eid] = entry + doc["equations"] = equations + + # Groups (pass through) + groups = state.get("groups") + if groups: + doc["groups"] = groups + + # Scenarios + scenarios: list[dict[str, Any]] = [] + for sc in state.get("scenarios", []): + sid = sc["id"].strip() + if not sid: + continue + scenarios.append( + { + "id": sid, + "label": sc["label"], + "vars": parse_key_value_lines(sc.get("vars", "")), + } + ) + doc["scenarios"] = scenarios + + # Report blocks + report: list[dict[str, Any]] = [] + for blk in state.get("report_blocks", []): + btype = blk["type"] + if btype == "markdown": + report.append({"type": "markdown", "content": blk.get("content", "")}) + elif btype == "table": + entry_t: dict[str, Any] = {"type": "table"} + if blk.get("caption"): + entry_t["caption"] = blk["caption"] + cols = [c.strip() for c in blk.get("columns", "").split(",") if c.strip()] + if cols: + entry_t["columns"] = cols + entry_t["rows"] = parse_table_rows(blk.get("rows", "")) + report.append(entry_t) + elif btype == "figure": + report.append({"type": "figure", "id": blk.get("id", "")}) + elif btype == "graph": + entry_g: dict[str, Any] = { + "type": "graph", + "kind": blk.get("kind", "bar"), + } + if blk.get("title"): + entry_g["title"] = blk["title"] + if blk.get("caption"): + entry_g["caption"] = blk["caption"] + cols_g = [c.strip() for c in blk.get("columns", "").split(",") if c.strip()] + if cols_g: + entry_g["columns"] = cols_g + entry_g["rows"] = parse_table_rows(blk.get("rows", "")) + report.append(entry_g) + doc["report"] = report + + # Figures + figures: list[dict[str, Any]] = [] + for fig in state.get("figures", []): + fid = fig["id"].strip() + if not fid: + continue + entry_f: dict[str, Any] = {"id": fid, "title": fig["title"]} + if fig.get("alt_text"): + entry_f["alt-text"] = fig["alt_text"] + if fig.get("py_code"): + entry_f["py-code"] = fig["py_code"] + figures.append(entry_f) + doc["figures"] = figures + + return doc + + +def validate_model_dict(doc: dict[str, Any]) -> Model: + """Validate a model document dict against the ``Model`` schema. + + Raises: + pydantic.ValidationError: if *doc* does not conform to the schema. + """ + return opaque_to_typed(doc, Model) + + +def serialize_to_yaml(doc: dict[str, Any]) -> bytes: + """Serialize a model document dict to YAML bytes.""" + fmt = YAMLFormat("model.yaml") + return fmt.write(doc) + + +__all__ = [ + "DEFAULT_STATE", + "EditorState", + "build_model_dict", + "coerce_numeric", + "coerce_numeric_or_none", + "parse_key_value_lines", + "parse_table_rows", + "serialize_to_yaml", + "table_row_to_str", + "validate_model_dict", + "yaml_to_state", +] diff --git a/src/epicc/formats/__init__.py b/src/epicc/formats/__init__.py index 4e094bb..777716d 100644 --- a/src/epicc/formats/__init__.py +++ b/src/epicc/formats/__init__.py @@ -76,12 +76,11 @@ def get_format(path: Path | str) -> BaseFormat: def opaque_to_typed(data: dict, model: type[M]) -> M: """ Validate the given data against a given Pydantic model. - """ - try: - return model.model_validate(data) - except Exception as e: - raise ValueError(f"Data validation failed: {e}") from e + Raises: + pydantic.ValidationError: if *data* does not conform to *model*'s schema. + """ + return model.model_validate(data) def read_from_format(path: Path | str, data: IO, model: type[M]) -> tuple[M, Any]: diff --git a/tests/epicc/test_editor.py b/tests/epicc/test_editor.py new file mode 100644 index 0000000..d559381 --- /dev/null +++ b/tests/epicc/test_editor.py @@ -0,0 +1,304 @@ +"""Tests for the non-UI helper functions in epicc.editor.helpers.""" + +from __future__ import annotations + +import io +from typing import Any + +import pytest + +from epicc.editor.helpers import ( + build_model_dict, + coerce_numeric, + coerce_numeric_or_none, + parse_key_value_lines, + parse_table_rows, + serialize_to_yaml, + table_row_to_str, + validate_model_dict, + yaml_to_state, +) +from epicc.formats import opaque_to_typed +from epicc.formats.yaml import YAMLFormat +from epicc.model.schema import Model +from pydantic import ValidationError + + +# --------------------------------------------------------------------------- +# coerce_numeric +# --------------------------------------------------------------------------- + + +class TestCoerceNumeric: + def test_integer(self) -> None: + assert coerce_numeric("42") == 42 + + def test_float(self) -> None: + assert coerce_numeric("3.14") == pytest.approx(3.14) + + def test_bool_true(self) -> None: + assert coerce_numeric("true") is True + + def test_bool_false(self) -> None: + assert coerce_numeric("False") is False + + def test_string_passthrough(self) -> None: + assert coerce_numeric("hello") == "hello" + + +# --------------------------------------------------------------------------- +# coerce_numeric_or_none +# --------------------------------------------------------------------------- + + +class TestCoerceNumericOrNone: + def test_integer(self) -> None: + assert coerce_numeric_or_none("10") == 10 + + def test_float(self) -> None: + assert coerce_numeric_or_none("2.5") == pytest.approx(2.5) + + def test_empty_string(self) -> None: + assert coerce_numeric_or_none("") is None + + def test_whitespace(self) -> None: + assert coerce_numeric_or_none(" ") is None + + def test_non_numeric(self) -> None: + assert coerce_numeric_or_none("abc") is None + + +# --------------------------------------------------------------------------- +# parse_key_value_lines +# --------------------------------------------------------------------------- + + +class TestParseKeyValueLines: + def test_basic(self) -> None: + result = parse_key_value_lines("n_cases: 22\nrate: 0.5") + assert result == {"n_cases": 22, "rate": pytest.approx(0.5)} + + def test_string_value(self) -> None: + result = parse_key_value_lines("name: hello world") + assert result == {"name": "hello world"} + + def test_empty(self) -> None: + assert parse_key_value_lines("") == {} + + def test_skips_lines_without_colon(self) -> None: + result = parse_key_value_lines("no colon here\nk: v") + assert result == {"k": "v"} + + +# --------------------------------------------------------------------------- +# parse_table_rows / table_row_to_str round-trip +# --------------------------------------------------------------------------- + + +class TestTableRows: + def test_round_trip_simple(self) -> None: + text = "Hospitalization | eq_hosp" + rows = parse_table_rows(text) + assert len(rows) == 1 + assert rows[0] == {"label": "Hospitalization", "value": "eq_hosp"} + assert table_row_to_str(rows[0]) == text + + def test_round_trip_with_emphasis(self) -> None: + text = "TOTAL | eq_total | strong" + rows = parse_table_rows(text) + assert len(rows) == 1 + assert rows[0] == {"label": "TOTAL", "value": "eq_total", "emphasis": "strong"} + assert table_row_to_str(rows[0]) == text + + def test_multiline(self) -> None: + text = "A | eq_a\nB | eq_b | em" + rows = parse_table_rows(text) + assert len(rows) == 2 + + def test_empty(self) -> None: + assert parse_table_rows("") == [] + + +# --------------------------------------------------------------------------- +# build_model_dict +# --------------------------------------------------------------------------- + + +class TestBuildModelDict: + def _minimal_state(self) -> dict[str, Any]: + return { + "model_title": "Test Model", + "model_description": "A test.", + "authors": [{"name": "Tester", "email": ""}], + "parameters": [ + { + "id": "x", + "type": "number", + "label": "X", + "description": "", + "default": "1.0", + "min": "0", + "max": "100", + "unit": "", + "references": "", + "options": "", + } + ], + "equations": [ + {"id": "eq_x", "label": "X value", "unit": "", "output": "number", "compute": "x * 2"} + ], + "groups": [], + "scenarios": [ + {"id": "base", "label": "Base", "vars": "n: 1"} + ], + "report_blocks": [ + {"type": "markdown", "content": "Hello"} + ], + "figures": [], + } + + def test_produces_valid_dict(self) -> None: + doc = build_model_dict(self._minimal_state()) + assert doc["title"] == "Test Model" + assert "x" in doc["parameters"] + assert "eq_x" in doc["equations"] + + def test_skips_empty_parameter_ids(self) -> None: + state = self._minimal_state() + state["parameters"].append( + { + "id": "", + "type": "number", + "label": "Y", + "description": "", + "default": "0", + "min": "", + "max": "", + "unit": "", + "references": "", + "options": "", + } + ) + doc = build_model_dict(state) + assert len(doc["parameters"]) == 1 + + def test_enum_options_parsed(self) -> None: + state = self._minimal_state() + state["parameters"][0]["type"] = "enum" + state["parameters"][0]["default"] = "A" + state["parameters"][0]["options"] = "A: Option A\nB: Option B" + doc = build_model_dict(state) + assert doc["parameters"]["x"]["options"] == {"A": "Option A", "B": "Option B"} + + +# --------------------------------------------------------------------------- +# End-to-end: build + validate +# --------------------------------------------------------------------------- + + +class TestValidateModelDict: + def _minimal_doc(self) -> dict[str, Any]: + return { + "title": "Test Model", + "description": "A test.", + "authors": [{"name": "Tester"}], + "parameters": { + "x": {"type": "number", "label": "X", "default": 1.0} + }, + "equations": { + "eq_x": {"label": "X value", "compute": "x * 2"} + }, + "scenarios": [{"id": "base", "label": "Base", "vars": {"n": 1}}], + "report": [{"type": "markdown", "content": "Hello"}], + "figures": [], + } + + def test_valid_minimal_doc(self) -> None: + model = validate_model_dict(self._minimal_doc()) + assert model.title == "Test Model" + + def test_missing_title_fails(self) -> None: + doc = self._minimal_doc() + del doc["title"] + with pytest.raises(ValidationError): + validate_model_dict(doc) + + def test_invalid_parameter_type_fails(self) -> None: + doc = self._minimal_doc() + doc["parameters"]["x"]["type"] = "invalid_type" + with pytest.raises(ValidationError): + validate_model_dict(doc) + + def test_enum_without_options_fails(self) -> None: + doc = self._minimal_doc() + doc["parameters"]["x"]["type"] = "enum" + with pytest.raises(ValidationError): + validate_model_dict(doc) + + def test_enum_with_options_valid(self) -> None: + doc = self._minimal_doc() + doc["parameters"]["x"]["type"] = "enum" + doc["parameters"]["x"]["default"] = "A" + doc["parameters"]["x"]["options"] = {"A": "Option A", "B": "Option B"} + model = validate_model_dict(doc) + assert model.parameters["x"].type == "enum" + + +# --------------------------------------------------------------------------- +# YAML serialization round-trip +# --------------------------------------------------------------------------- + + +class TestYAMLRoundTrip: + def test_write_and_read(self) -> None: + doc: dict[str, Any] = { + "title": "RT Model", + "description": "round-trip test", + "authors": [], + "parameters": { + "p": {"type": "number", "label": "P", "default": 5.0} + }, + "equations": {"eq": {"label": "E", "compute": "p + 1"}}, + "scenarios": [{"id": "s1", "label": "S1", "vars": {"n": 10}}], + "report": [{"type": "markdown", "content": "hi"}], + "figures": [], + } + + yaml_bytes = serialize_to_yaml(doc) + fmt = YAMLFormat("test.yaml") + data, _ = fmt.read(io.BytesIO(yaml_bytes)) + + validated = opaque_to_typed(data, Model) + assert validated.title == "RT Model" + assert "p" in validated.parameters + + +# --------------------------------------------------------------------------- +# yaml_to_state round-trip +# --------------------------------------------------------------------------- + + +class TestYAMLToState: + def test_loads_measles_yaml(self) -> None: + import importlib.resources + + measles_res = importlib.resources.files("epicc.model.models").joinpath("measles.yaml") + raw = measles_res.read_bytes() + state = yaml_to_state(raw) + + assert state["model_title"] == "Measles Outbreak Cost Estimation" + assert len(state["parameters"]) > 0 + assert len(state["equations"]) > 0 + assert len(state["scenarios"]) > 0 + + def test_state_to_dict_validates(self) -> None: + """Load a real model, convert through state, rebuild, and validate.""" + import importlib.resources + + measles_res = importlib.resources.files("epicc.model.models").joinpath("measles.yaml") + raw = measles_res.read_bytes() + state = yaml_to_state(raw) + + doc = build_model_dict(state) + model = validate_model_dict(doc) + assert model.title == "Measles Outbreak Cost Estimation"