diff --git a/.gitignore b/.gitignore index 47dcbb6..fb80dd3 100644 --- a/.gitignore +++ b/.gitignore @@ -129,3 +129,4 @@ Thumbs.db ~* /runs DocLayNet-base/ +datasets/ diff --git a/README.md b/README.md index e0b752d..4648370 100644 --- a/README.md +++ b/README.md @@ -52,10 +52,10 @@ A control loop that probes what the model remembers, tunes bottleneck budgets, r #### Operating the Auto‑IB Orchestrator -* Wrap each STM append in a :class:`UsageEvent` that can carry the compressed tensor **and** a :class:`CompressionRecord`. Telemetry such as selected indices, token counts, and IB metrics are captured under ``metadata["compression"]`` and mirrored in the STM index. +* Wrap each STM append in a :class:`UsageEvent` that can carry the compressed tensor **and** a :class:`CompressionRecord`. Telemetry such as selected indices, token counts, IB/MI lower bounds, constraint verdicts, and canonical cell artefacts are captured under ``metadata["compression"]`` and mirrored in the STM index. * Call :func:`Orchestrator.tune_budget` periodically (e.g., after processing a batch). The default :class:`CompressionRatioBudgetStrategy` looks at the recent compression ratios and adjusts ``Orchestrator.config.target_budget`` upward when quality drops or downward when utilisation saturates. Inspect ``Orchestrator.budget_history`` to audit the decisions. * Trigger :func:`Orchestrator.run_retention_probe` on a schedule to sample STM entries, reconstruct them with the stored :class:`~nd_llm.bottleneck.ib.IBottleneck` telemetry, and monitor reconstruction quality / drift. Any missing or malformed telemetry is surfaced under ``probe["issues"]``. -* Use the new STM query helpers such as :func:`STM.query` or :func:`STM.list_by_alignment` to fetch aligned batches of entries (e.g., all shards for a ``session_id``) without loading every payload. 
+* Use the new STM query helpers such as :func:`STM.query` or :func:`STM.list_by_alignment`, plus the holographic superposition channels (``write_superposition``/``read_superposition``) to fetch aligned batches of entries or aggregate long-horizon fingerprints without loading every payload. --- @@ -175,6 +175,7 @@ usage_key = orchestrator.log_usage_event( ### Bottleneck tuning knobs * **Objective / scoring:** pass `objective="l2-norm"` (default) for magnitude gating or `objective="query-dot"` to enable the built-in query-conditioned scorer. You can also inject your own scorer via the `scorer` argument; it receives `(field, embeddings, metadata, context)` and should return a score per token. +* **Mutual-information blending:** supply an `MIProxy` + `mi_targets` via :func:`build_mi_proxy_context` and set `mi_score_weight` to trade off between the base scorer and per-token MI similarities. * **Query context:** provide query embeddings or other conditioning signals through the `context` mapping (e.g. `{"query_embedding": vector}`) and they will be forwarded to the scoring strategy. * **Budget allocator:** override `budget_allocator` to customize per-field sub-budgets. The default `RegistryAwareBudgetAllocator` inspects registry metadata (salience flags, alignment keys, optional `budget_weight`) and records the resulting `field_budgets` and `allocation_weights` in `CompressionTelemetry`. * **Metrics:** every call to `compress` returns a `CompressionResult.metrics` dictionary with IB/RD proxies such as `ib_proxy`, `rd_proxy`, and an `embedding_reconstruction_error` computed from kept vs. dropped embeddings. @@ -203,58 +204,68 @@ affinity: ## Using real datasets -Fetch the official FUNSD release and the compact DocLayNet-base snapshot with the -helper script. By -default datasets are placed in ``~/.cache/n-dimensional-llm``; override the -location via ``ND_LLM_DATA_CACHE`` or ``--cache-dir`` if required. 
+The [CORD receipt dataset](https://huggingface.co/datasets/naver-clova-ix/cord-v2) is wired into the benchmark harness for a realistic document-understanding task. Install the optional dependency stack (``pip install .[benchmarks]`` or the explicit packages below) when you want the full dataset instead of the bundled JSONL sample: ```bash -python scripts/download_datasets.py +pip install datasets pillow ``` -The FUNSD benchmark helper consumes the extracted ``funsd`` directory. Set -``dataset_size=0`` to keep the full corpus and ``use_sample=False`` to disable -the bundled JSON sample. - ```bash python - <<'PY' -from pathlib import Path - -from benchmarks.doc_understanding import run_funsd_benchmark - -cache = Path.home() / ".cache" / "n-dimensional-llm" -report = run_funsd_benchmark( - budget_values=(8, 12, 16), - data_root=cache / "funsd", - dataset_size=0, - use_sample=False, +from benchmarks.doc_understanding import run_cord_benchmark + +report = run_cord_benchmark( + budget_values=(4, 8, 12), + dataset_size=8, + use_sample=True, # flip to False to stream the HF split or use a local directory + data_root="datasets", # automatically combines subdirectories named CORD* + threshold=250_000, ) print(report["budgets"][0]["metrics"]) # inspect results PY ``` -DocLayNet follows the same convention, reading from the ``doclaynet`` directory -inside the cache. The helper uses the -`pierreguillou/DocLayNet-base `_ -mirror hosted on Hugging Face instead of the much larger full corpus. +Set ``data_root`` to the directory that contains the official ``train/dev/test/json`` folders (or to a parent folder that has subdirectories named ``CORD*``—the loader will merge them) once you download the [CORD release](https://github.com/clovaai/cord). 
-```bash -python - <<'PY' -from pathlib import Path +### ChartQA field benchmark -from benchmarks.doc_understanding import run_doclaynet_benchmark +ChartQA-style chart reasoning now has a lightweight harness that exercises question text alongside structured chart metadata. The default configuration consumes the bundled sample; point it at the official dataset (e.g. [lmms-lab/chartqa](https://huggingface.co/datasets/lmms-lab/chartqa) or the [GitHub release](https://github.com/IBM/chartqa)) when you want full coverage: -cache = Path.home() / ".cache" / "n-dimensional-llm" -report = run_doclaynet_benchmark( - budget_values=(6, 12, 18), - data_root=cache / "doclaynet", - dataset_size=0, - use_sample=False, +```python +from benchmarks.chartqa import run_chartqa_benchmark + +report = run_chartqa_benchmark( + budget_values=(2, 4, 6), + dataset_size=4, + use_sample=True, # set False once you've downloaded the dataset locally ) -print(report["budgets"][0]["metrics"]) # inspect results -PY +print(report) ``` +### Rate–distortion & Fano audits + +Use the `scripts/rd_audit.py` CLI to sweep token budgets for the CORD benchmark in both N-D and text-only configurations, then compute empirical rate–distortion curves and Fano-consistent error bounds: + +```bash +python -m scripts.rd_audit --budgets 4 8 12 --dataset-size 8 --use-sample +``` + +Both modes record the mean mutual-information lower bound (from the MI proxy) alongside each budget’s accuracy/distortion so you can visualise the dominance of N-D inputs at fixed rate. + +### Local LLM harness (Ollama) + +If you have [Ollama](https://ollama.com) with `llama3.1:8b` installed locally, you can replay compressed field summaries into the model for qualitative checks: + +```bash +python -m scripts.ollama_harness \ + --dataset cord \ + --data-root datasets \ + --use-sample \ + --dry-run +``` + +Drop `--dry-run` to stream the prompt to your Ollama instance (`http://127.0.0.1:11434` by default). 
ChartQA prompts are supported as well (`--dataset chartqa`). + ### Runnable multi-field invoice demo Kick the tyres with the maintained invoice walk-through that wires the registry, stub encoders, bottleneck, STM, and orchestrator together: diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py index 0485332..b7a9e72 100644 --- a/benchmarks/__init__.py +++ b/benchmarks/__init__.py @@ -7,29 +7,25 @@ __all__ = [ "run_benchmark", - "run_doclaynet_benchmark", - "run_funsd_benchmark", + "run_cord_benchmark", + "run_chartqa_benchmark", "run_long_qa_benchmark", "run_video_qa_benchmark", - "build_doclaynet_registry", - "build_doclaynet_encoders", - "doclaynet_fields", - "doclaynet_contains_table", - "load_doclaynet_dataset", + "build_cord_registry", + "build_cord_encoders", + "cord_fields", + "cord_high_total_label", + "cord_total_amount", + "load_cord_dataset", + "build_chartqa_registry", + "build_chartqa_encoders", + "chartqa_fields", + "chartqa_answer", + "load_chartqa_dataset", "AmountEncoder", "build_invoice_encoders", "build_invoice_registry", - "build_funsd_encoders", - "build_funsd_registry", - "funsd_fields", - "funsd_numeric_answer_label", "invoice_fields", - "load_funsd_dataset", - "build_doclaynet_encoders", - "build_doclaynet_registry", - "doclaynet_fields", - "doclaynet_contains_table", - "load_doclaynet_dataset", "synthetic_invoice", "synthetic_invoice_dataset", "build_longqa_registry", @@ -52,14 +48,33 @@ def __getattr__(name: str) -> Any: # pragma: no cover - thin convenience wrapper if name == "run_benchmark": return import_module("benchmarks.doc_understanding").run_benchmark - if name == "run_doclaynet_benchmark": - return import_module("benchmarks.doc_understanding").run_doclaynet_benchmark - if name == "run_funsd_benchmark": - return import_module("benchmarks.doc_understanding").run_funsd_benchmark + if name == "run_cord_benchmark": + return import_module("benchmarks.doc_understanding").run_cord_benchmark + if name == "run_chartqa_benchmark": + 
return import_module("benchmarks.chartqa").run_chartqa_benchmark if name == "run_long_qa_benchmark": return import_module("benchmarks.long_qa").run_long_qa_benchmark if name == "run_video_qa_benchmark": return import_module("benchmarks.video_qa").run_video_qa_benchmark + if name in { + "build_cord_registry", + "build_cord_encoders", + "cord_fields", + "cord_high_total_label", + "cord_total_amount", + "load_cord_dataset", + }: + module = import_module("benchmarks.cord") + return getattr(module, name) + if name in { + "build_chartqa_registry", + "build_chartqa_encoders", + "chartqa_fields", + "chartqa_answer", + "load_chartqa_dataset", + }: + module = import_module("benchmarks.chartqa") + return getattr(module, name) if name in { "AmountEncoder", "build_invoice_encoders", @@ -70,33 +85,6 @@ def __getattr__(name: str) -> Any: # pragma: no cover - thin convenience wrappe }: module = import_module("benchmarks.synthetic") return getattr(module, name) - if name in { - "build_doclaynet_registry", - "build_doclaynet_encoders", - "doclaynet_fields", - "doclaynet_contains_table", - "load_doclaynet_dataset", - }: - module = import_module("benchmarks.doclaynet") - return getattr(module, name) - if name in { - "build_funsd_encoders", - "build_funsd_registry", - "funsd_fields", - "funsd_numeric_answer_label", - "load_funsd_dataset", - }: - module = import_module("benchmarks.funsd") - return getattr(module, name) - if name in { - "build_doclaynet_encoders", - "build_doclaynet_registry", - "doclaynet_fields", - "doclaynet_contains_table", - "load_doclaynet_dataset", - }: - module = import_module("benchmarks.doclaynet") - return getattr(module, name) if name in { "build_longqa_registry", "build_longqa_encoders", diff --git a/benchmarks/chartqa.py b/benchmarks/chartqa.py new file mode 100644 index 0000000..b612ac7 --- /dev/null +++ b/benchmarks/chartqa.py @@ -0,0 +1,274 @@ +"""ChartQA benchmark utilities leveraging the N-D registry.""" + +from __future__ import annotations + +import 
json +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Sequence + +from nd_llm.encoders import Encoder, TextEncoder +from nd_llm.registry import Registry +from nd_llm.utils import build_mi_proxy_context +from nd_llm.bottleneck import IBottleneck + +try: # pragma: no cover - optional HF dependency + from datasets import load_dataset as _load_hf_dataset # type: ignore +except Exception: # pragma: no cover + _load_hf_dataset = None # type: ignore[assignment] + +_DATA_DIR = Path(__file__).with_name("data") +_SAMPLE_PATH = _DATA_DIR.joinpath("chartqa_sample.jsonl") +_DATASET_NAME = "lmms-lab/chartqa" + +__all__ = [ + "load_chartqa_dataset", + "build_chartqa_registry", + "build_chartqa_encoders", + "chartqa_fields", + "chartqa_answer", + "run_chartqa_benchmark", +] + + +def load_chartqa_dataset( + *, + split: str = "test", + limit: Optional[int] = None, + use_sample: bool = True, + cache_dir: Optional[Path | str] = None, +) -> List[Dict[str, Any]]: + """Load ChartQA records either from the bundled sample or Hugging Face.""" + + documents: List[Dict[str, Any]] = [] + if use_sample: + documents.extend(_load_chartqa_sample(limit)) + if documents: + return documents + if _load_hf_dataset is None: + raise ImportError( + "datasets is required to fetch ChartQA from Hugging Face. Install it or set use_sample=True." 
+ ) + dataset = _load_hf_dataset( + _DATASET_NAME, + split=split, + cache_dir=str(cache_dir) if cache_dir is not None else None, + ) + for index, record in enumerate(dataset): + document = _prepare_chartqa_document(record, f"{split}-{index}") + documents.append(document) + if limit is not None and len(documents) >= limit: + break + return documents + + +def _load_chartqa_sample(limit: Optional[int]) -> List[Dict[str, Any]]: + if not _SAMPLE_PATH.exists(): + return [] + docs: List[Dict[str, Any]] = [] + with _SAMPLE_PATH.open("r", encoding="utf-8") as handle: + for line in handle: + if not line.strip(): + continue + docs.append(json.loads(line)) + if limit is not None and len(docs) >= limit: + break + return docs + + +def _prepare_chartqa_document(raw: Mapping[str, Any], default_doc_id: str) -> Dict[str, Any]: + doc_id = str(raw.get("id") or raw.get("doc_id") or f"chartqa-{default_doc_id}") + question = str(raw.get("question", "")).strip() + answer = str(raw.get("answer", "")).strip() + chart_data = raw.get("chart") or raw.get("chart_data") or [] + parsed_chart: List[Dict[str, Any]] = [] + for entry in chart_data: + if not isinstance(entry, Mapping): + continue + label = entry.get("label") or entry.get("category") or entry.get("x") + value = entry.get("value") or entry.get("y") + if label is None or value is None: + continue + try: + numeric = float(value) + except (TypeError, ValueError): + continue + parsed_chart.append({"label": str(label), "value": numeric}) + metadata = { + "chart_type": raw.get("type"), + "source": str(raw.get("source", "")), + } + return { + "doc_id": doc_id, + "question": question, + "answer": answer, + "chart": parsed_chart, + "metadata": metadata, + } + + +def build_chartqa_registry() -> Registry: + registry = Registry() + registry.add_field("question", keys=["doc_id", "token_id"], salience=True, modality="text") + registry.add_field("chart", keys=["doc_id", "row_id"], modality="table") + registry.add_affinity("question", "chart", 
keys=["doc_id"]) + registry.validate() + return registry + + +def build_chartqa_encoders(registry: Registry, *, question_dim: int = 16, chart_dim: int = 12) -> Dict[str, Encoder]: + encoders: Dict[str, Encoder] = { + "question": TextEncoder(embedding_dim=question_dim), + "chart": TextEncoder(embedding_dim=chart_dim), + } + for field, encoder in encoders.items(): + registry.register_encoder(field, encoder) + return encoders + + +def chartqa_fields(document: Mapping[str, Any]) -> Dict[str, List[MutableMapping[str, Any]]]: + doc_id = str(document.get("doc_id") or "") + question = str(document.get("question", "")) + chart_entries = list(document.get("chart") or []) + question_tokens: List[MutableMapping[str, Any]] = [ + { + "doc_id": doc_id, + "token_id": idx, + "text": token, + } + for idx, token in enumerate(question.split()) + ] + chart_field: List[MutableMapping[str, Any]] = [] + for row_id, entry in enumerate(chart_entries): + label = entry.get("label") + value = entry.get("value") + if label is None or value is None: + continue + chart_entry: MutableMapping[str, Any] = { + "doc_id": doc_id, + "row_id": row_id, + "label": str(label), + "value": float(value), + } + chart_field.append(chart_entry) + return {"question": question_tokens, "chart": chart_field} + + +def chartqa_answer(document: Mapping[str, Any]) -> str: + return str(document.get("answer", "")).strip() + + +def run_chartqa_benchmark( + budget_values: Iterable[int] = (4, 8), + *, + dataset_size: int = 4, + split: str = "test", + use_sample: bool = True, + cache_dir: Optional[Path | str] = None, +) -> Dict[str, Any]: + registry = build_chartqa_registry() + build_chartqa_encoders(registry) + limit = dataset_size if dataset_size > 0 else None + dataset = load_chartqa_dataset(split=split, limit=limit, use_sample=use_sample, cache_dir=cache_dir) + actual_size = len(dataset) + + budgets: List[Dict[str, Any]] = [] + for budget in budget_values: + metrics = _evaluate_chartqa_budget( + dataset=dataset, + 
budget=int(budget), + registry_encoders=registry.encoders, + ) + budgets.append(metrics) + + return { + "dataset": "ChartQA", + "split": split, + "dataset_size": actual_size, + "use_sample": bool(use_sample), + "budgets": budgets, + } + + +def _evaluate_chartqa_budget( + *, + dataset: Sequence[Mapping[str, Any]], + budget: int, + registry_encoders: Mapping[str, Encoder], +) -> Dict[str, Any]: + bottleneck = IBottleneck(target_budget=int(budget)) + correct = 0 + doc_count = 0 + info_bound = 0.0 + rate_dist = 0.0 + mi_total = 0.0 + mi_count = 0 + kept_totals = 0 + + for document in dataset: + fields = chartqa_fields(document) + mi_proxy, mi_context = build_mi_proxy_context( + fields, + registry_encoders, + preferred_fields=("question", "chart"), + ) + result = bottleneck.compress( + fields, + encoders=registry_encoders, + context=mi_context, + mi_proxy=mi_proxy, + ) + prediction = _chartqa_predict(document, result) + if prediction == chartqa_answer(document): + correct += 1 + metrics = result.metrics + info_bound += float(metrics.get("information_bound", 0.0) or 0.0) + rate_dist += float(metrics.get("rate_distortion", 0.0) or 0.0) + mi_value = metrics.get("mi_lower_bound") + if mi_value is not None: + mi_total += float(mi_value) + mi_count += 1 + kept = sum(len(indices) for indices in result.telemetry.selected_indices.values()) + kept_totals += kept + doc_count += 1 + + accuracy = float(correct) / float(doc_count or 1) + return { + "budget": int(budget), + "accuracy": accuracy, + "distortion": 1.0 - accuracy, + "average_kept_tokens": kept_totals / float(doc_count or 1), + "mean_information_bound": info_bound / float(doc_count or 1), + "mean_rate_distortion": rate_dist / float(doc_count or 1), + "mean_mi_lower_bound": mi_total / float(mi_count or 1) if mi_count else 0.0, + "evaluated_documents": doc_count, + } + + +def _chartqa_predict(document: Mapping[str, Any], result: Any) -> str: + question = str(document.get("question", "")).lower() + chart_entries = 
document.get("chart") or [] + if chart_entries and isinstance(chart_entries, list): + try: + values = [(entry["label"], float(entry["value"])) for entry in chart_entries] + except Exception: + values = [] + else: + values = [] + + if values: + if "highest" in question or "most" in question or "maximum" in question: + label = max(values, key=lambda item: item[1])[0] + return str(label) + if "lowest" in question or "least" in question or "minimum" in question: + label = min(values, key=lambda item: item[1])[0] + return str(label) + if "total" in question or "sum" in question: + total = sum(value for _, value in values) + return str(int(total) if total.is_integer() else round(total, 2)) + compressed_questions = result.compressed_fields.get("question", []) + if compressed_questions: + first = compressed_questions[0] + if isinstance(first, Mapping): + return str(first.get("text", "")) + return str(first) + return "" diff --git a/benchmarks/cord.py b/benchmarks/cord.py new file mode 100644 index 0000000..9ea09c2 --- /dev/null +++ b/benchmarks/cord.py @@ -0,0 +1,530 @@ +"""Utilities for loading and normalising the CORD receipt dataset.""" + +from __future__ import annotations + +import json +import re +from pathlib import Path +from typing import Any, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Sequence + +from nd_llm.encoders import Encoder, LayoutEncoder, TextEncoder +from nd_llm.registry import ( + FieldAdapter, + FieldAdapterRegistry, + LayoutAligner, + Registry, + quad_to_box, +) + +try: # pragma: no cover - optional dependency for the full dataset + from datasets import load_dataset as _load_hf_dataset # type: ignore +except Exception: # pragma: no cover - fall back to the bundled sample + _load_hf_dataset = None # type: ignore[assignment] + +__all__ = [ + "load_cord_dataset", + "build_cord_registry", + "build_cord_encoders", + "build_cord_field_adapters", + "cord_fields", + "cord_amount_from_text", + "cord_total_amount", + 
"cord_high_total_label", +] + +_DATA_DIR = Path(__file__).with_name("data") +_SAMPLE_PATH = _DATA_DIR.joinpath("cord_sample.jsonl") +_DATASET_NAME = "naver-clova-ix/cord-v2" +_AMOUNT_PATTERN = re.compile(r"-?\d[\d.,]*") +_CORD_FIELD_ADAPTERS: Optional[FieldAdapterRegistry] = None + + +def load_cord_dataset( + *, + split: str = "train", + limit: Optional[int] = None, + use_sample: bool = True, + data_root: Optional[Path | str] = None, + cache_dir: Optional[Path | str] = None, +) -> List[Dict[str, Any]]: + """Load the requested CORD split via Hugging Face or fall back to a bundled sample.""" + + if data_root is not None: + local_docs = _load_local_cord_documents(data_root, split=split, limit=limit) + if local_docs: + return local_docs + raise FileNotFoundError( + f"CORD data_root '{data_root}' does not contain any '{split}' JSON files" + ) + + documents: List[Dict[str, Any]] = [] + if use_sample: + documents.extend(_load_sample(limit)) + if documents: + return documents + + if _load_hf_dataset is None: + raise ImportError( + "The 'datasets' package is required to download the CORD dataset. " + "Install it with 'pip install datasets pillow' or pass use_sample=True." 
+ ) + + dataset = _load_hf_dataset( + _DATASET_NAME, + split=split, + cache_dir=str(cache_dir) if cache_dir is not None else None, + ) + for index, row in enumerate(dataset): + document = _prepare_document(row) + if not document.get("doc_id"): + document["doc_id"] = f"cord-{split}-{index:05d}" + documents.append(document) + if limit is not None and len(documents) >= limit: + break + return documents + + +def build_cord_registry() -> Registry: + """Return a registry describing the text, layout, and line-level fields.""" + + registry = Registry() + registry.add_field( + "text", + keys=["doc_id", "line_id", "token_id"], + salience=True, + modality="text", + ) + registry.add_field( + "layout", + keys=["doc_id", "line_id", "token_id"], + modality="layout", + ) + registry.add_field( + "line", + keys=["doc_id", "line_id"], + modality="entity", + ) + registry.add_affinity("text", "layout", keys=["doc_id", "line_id", "token_id"]) + registry.add_affinity("line", "text", keys=["doc_id", "line_id"]) + registry.validate() + return registry + + +def build_cord_encoders( + registry: Registry, + *, + text_dim: int = 12, + layout_dim: int = 8, + line_dim: int = 6, +) -> Dict[str, Encoder]: + """Register lightweight encoder stubs for the CORD registry.""" + + encoders: Dict[str, Encoder] = { + "text": TextEncoder(embedding_dim=text_dim), + "layout": LayoutEncoder(embedding_dim=layout_dim), + "line": TextEncoder(embedding_dim=line_dim), + } + for field, encoder in encoders.items(): + registry.register_encoder(field, encoder) + return encoders + + +def build_cord_field_adapters() -> FieldAdapterRegistry: + """Return field adapters that canonicalise CORD lines/words.""" + + registry = FieldAdapterRegistry() + aligner = LayoutAligner() + registry.register( + FieldAdapter( + name="text", + builder=_build_cord_text_entries, + aligner=aligner, + ) + ) + registry.register( + FieldAdapter( + name="layout", + builder=_build_cord_layout_entries, + aligner=aligner, + ) + ) + registry.register( 
+ FieldAdapter( + name="line", + builder=_build_cord_line_entries, + aligner=aligner, + ) + ) + return registry + + +def cord_fields(document: Mapping[str, Any]) -> Dict[str, List[MutableMapping[str, Any]]]: + """Convert a normalised CORD document into registry-aligned field batches.""" + + global _CORD_FIELD_ADAPTERS + if _CORD_FIELD_ADAPTERS is None: + _CORD_FIELD_ADAPTERS = build_cord_field_adapters() + return _CORD_FIELD_ADAPTERS.transform(document) + + +def _safe_int(value: Any, default: int = 0) -> int: + try: + if value is None: + return default + return int(value) + except Exception: + return default + + +def _iter_cord_words(document: Mapping[str, Any]) -> Iterable[MutableMapping[str, Any]]: + doc_id = str(document.get("doc_id") or document.get("id") or "") + token_counter = 0 + for index, line in enumerate(document.get("lines", [])): + line_id = _safe_int(line.get("line_id", line.get("group_id")), index) + category = str(line.get("category", "other")) + group_id = _safe_int(line.get("group_id"), line_id) + sub_group_id = _safe_int(line.get("sub_group_id"), 0) + for word in line.get("words", []): + token_id = word.get("token_id") + if token_id is None: + token_id = token_counter + token_counter += 1 + entry: MutableMapping[str, Any] = { + "doc_id": doc_id, + "line_id": line_id, + "token_id": _safe_int(token_id, token_counter), + "text": str(word.get("text", "")), + "category": category, + "group_id": group_id, + "sub_group_id": sub_group_id, + "quad": word.get("quad") or word.get("coords"), + } + yield entry + + +def _build_cord_text_entries(document: Mapping[str, Any]) -> Iterable[MutableMapping[str, Any]]: + return _iter_cord_words(document) + + +def _build_cord_layout_entries(document: Mapping[str, Any]) -> Iterable[MutableMapping[str, Any]]: + for word in _iter_cord_words(document): + entry: MutableMapping[str, Any] = { + "doc_id": word["doc_id"], + "line_id": word["line_id"], + "token_id": word["token_id"], + "category": word.get("category"), + 
"quad": word.get("quad"), + } + yield entry + + +def _build_cord_line_entries(document: Mapping[str, Any]) -> Iterable[MutableMapping[str, Any]]: + doc_id = str(document.get("doc_id") or document.get("id") or "") + for index, line in enumerate(document.get("lines", [])): + words = line.get("words", []) + text_value = " ".join(str(word.get("text", "")).strip() for word in words if word.get("text")).strip() + entry: MutableMapping[str, Any] = { + "doc_id": doc_id, + "line_id": _safe_int(line.get("line_id", line.get("group_id")), index), + "category": str(line.get("category", "other")), + "group_id": _safe_int(line.get("group_id"), index), + "sub_group_id": _safe_int(line.get("sub_group_id"), 0), + "text": text_value, + "token_count": len(words), + "quad": _line_quad(line), + } + yield entry + + +def _line_quad(line: Mapping[str, Any]) -> Optional[List[float]]: + quads: List[List[float]] = [] + for word in line.get("words", []): + quad = word.get("quad") + if quad is not None: + quads.append(quad_to_box(quad)) + if quads: + min_x = min(box[0] for box in quads) + min_y = min(box[1] for box in quads) + max_x = max(box[2] for box in quads) + max_y = max(box[3] for box in quads) + return [min_x, min_y, max_x, max_y] + quad = line.get("quad") + if quad is not None: + return quad_to_box(quad) + coords = line.get("coords") + if isinstance(coords, Sequence): + return [float(value) for value in coords[:4]] + return [0.0, 0.0, 0.0, 0.0] + + +def cord_amount_from_text(value: Any) -> float: + """Best-effort extraction of a numeric amount from a free-form token.""" + + if isinstance(value, (int, float)) and not isinstance(value, bool): + return float(value) + if value is None: + return 0.0 + text = str(value).strip() + if not text: + return 0.0 + candidates = _AMOUNT_PATTERN.findall(text.replace(" ", "")) + best = 0.0 + for candidate in candidates: + cleaned = candidate.replace(",", "") + if cleaned.count(".") > 1: + cleaned = cleaned.replace(".", "", cleaned.count(".") - 1) + 
try: + amount = float(cleaned) + except Exception: + continue + if abs(amount) > abs(best): + best = amount + return best + + +def cord_total_amount(document: Mapping[str, Any]) -> float: + """Return the parsed total amount for a document.""" + + if "total_amount" in document: + try: + return float(document["total_amount"]) + except Exception: + pass + total = document.get("total") or document.get("totals") or {} + sub_total = document.get("sub_total") or document.get("subtotal") or {} + for key in ("total_price", "cashprice", "creditcard_price", "total"): + if key in total: + amount = cord_amount_from_text(total[key]) + if amount: + return amount + for key in ("subtotal_price", "sum"): + if key in sub_total: + amount = cord_amount_from_text(sub_total[key]) + if amount: + return amount + return 0.0 + + +def cord_high_total_label(document: Mapping[str, Any], *, threshold: float) -> bool: + """Binary label for whether a receipt exceeds the provided total threshold.""" + + return cord_total_amount(document) >= float(threshold) + + +def _load_sample(limit: Optional[int]) -> Iterator[Dict[str, Any]]: + if not _SAMPLE_PATH.exists(): + return + count = 0 + with _SAMPLE_PATH.open("r", encoding="utf-8") as handle: + for line in handle: + line = line.strip() + if not line: + continue + data = json.loads(line) + yield _prepare_document(data) + count += 1 + if limit is not None and count >= limit: + return + + +def _load_local_cord_documents( + root: Path | str, + *, + split: str, + limit: Optional[int], +) -> List[Dict[str, Any]]: + base = Path(root).expanduser().resolve() + roots = _resolve_local_roots(base) + documents: List[Dict[str, Any]] = [] + seen_ids: set[str] = set() + for dataset_root in roots: + for doc in _iter_local_split(dataset_root, split): + doc_id = str(doc.get("doc_id") or "") + if not doc_id: + doc_id = f"{dataset_root.name}-{split}-{len(documents):05d}" + doc["doc_id"] = doc_id + if doc_id in seen_ids: + continue + seen_ids.add(doc_id) + 
documents.append(doc) + if limit is not None and len(documents) >= limit: + return documents + return documents + + +def _resolve_local_roots(base: Path) -> List[Path]: + def _is_cord_root(candidate: Path) -> bool: + return any((candidate / part).is_dir() for part in ("train", "test", "dev")) + + candidates: List[Path] = [] + if _is_cord_root(base): + return [base] + for child in sorted(base.iterdir()): + if not child.is_dir(): + continue + name = child.name.lower() + if not name.startswith("cord"): + continue + if _is_cord_root(child): + candidates.append(child) + return candidates + + +def _iter_local_split(root: Path, split: str) -> Iterator[Dict[str, Any]]: + split_dir = root / split + json_dir = split_dir / "json" + if not json_dir.is_dir(): + json_dir = split_dir + if not json_dir.is_dir(): + return + for path in sorted(json_dir.glob("*.json")): + try: + data = json.loads(path.read_text(encoding="utf-8")) + except Exception: + continue + document = _prepare_document(data) + document.setdefault("doc_id", path.stem) + yield document + + +def _prepare_document(raw: Mapping[str, Any]) -> Dict[str, Any]: + if "lines" in raw and "width" in raw and "height" in raw: + return _finalise_document(raw) + + payload: Mapping[str, Any] + ground_truth = raw.get("ground_truth") + if isinstance(ground_truth, str): + payload = json.loads(ground_truth) + elif isinstance(ground_truth, Mapping): + payload = ground_truth + else: + payload = {} + meta = payload.get("meta", {}) + + image_size = meta.get("image_size") or {} + width = image_size.get("width") + height = image_size.get("height") + if (not width or not height) and "image" in raw: + image = raw["image"] + try: + width = getattr(image, "width") + height = getattr(image, "height") + except Exception: + try: + size = getattr(image, "size") + width, height = size + except Exception: + width = width or 1000 + height = height or 1400 + + converted: Dict[str, Any] = { + "doc_id": raw.get("doc_id") or raw.get("id") or 
meta.get("image_id") or "", + "width": int(width or 1000), + "height": int(height or 1400), + "lines": payload.get("valid_line") or [], + "menu": payload.get("gt_parse", {}).get("menu", []), + "sub_total": payload.get("gt_parse", {}).get("sub_total", {}), + "total": payload.get("gt_parse", {}).get("total", {}), + "metadata": { + "split": meta.get("split") or raw.get("split"), + "version": meta.get("version"), + }, + } + return _finalise_document(converted) + + +def _finalise_document(raw: Mapping[str, Any]) -> Dict[str, Any]: + doc_id = str(raw.get("doc_id") or raw.get("id") or "") + width = max(int(raw.get("width", 0) or 1), 1) + height = max(int(raw.get("height", 0) or 1), 1) + metadata = dict(raw.get("metadata") or {}) + metadata.setdefault("source", raw.get("source", "cord")) + + lines = _prepare_lines(raw.get("lines") or raw.get("valid_line") or []) + document = { + "doc_id": doc_id, + "width": width, + "height": height, + "lines": lines, + "menu": list(raw.get("menu") or []), + "sub_total": dict(raw.get("sub_total") or raw.get("subtotal") or {}), + "total": dict(raw.get("total") or raw.get("totals") or {}), + "metadata": metadata, + } + document["total_amount"] = cord_total_amount(document) + return document + + +def _prepare_lines(lines: Iterable[Mapping[str, Any]]) -> List[Dict[str, Any]]: + prepared: List[Dict[str, Any]] = [] + token_counter = 0 + for index, line in enumerate(lines): + line_id = line.get("line_id") + if line_id is None: + line_id = line.get("group_id") + if line_id is None: + line_id = index + category = str(line.get("category", "other")) + group_id = int(line.get("group_id", line_id)) + sub_group_id = int(line.get("sub_group_id", 0)) + + prepared_words = [] + for word in line.get("words", []): + token_id = word.get("token_id") + if token_id is None: + token_id = token_counter + token_counter += 1 + prepared_words.append( + { + "token_id": int(token_id), + "text": str(word.get("text", "")), + "quad": _normalise_quad(word.get("quad")), + 
"is_key": int(word.get("is_key", 0)), + } + ) + + prepared.append( + { + "line_id": int(line_id), + "category": category, + "group_id": group_id, + "sub_group_id": sub_group_id, + "words": prepared_words, + } + ) + return prepared + + +def _normalise_quad(quad: Any) -> Dict[str, float]: + if isinstance(quad, Mapping): + return { + "x1": float(quad.get("x1", 0.0)), + "y1": float(quad.get("y1", 0.0)), + "x2": float(quad.get("x2", quad.get("x1", 0.0))), + "y2": float(quad.get("y2", quad.get("y1", 0.0))), + "x3": float(quad.get("x3", quad.get("x2", 0.0))), + "y3": float(quad.get("y3", quad.get("y2", 0.0))), + "x4": float(quad.get("x4", quad.get("x1", 0.0))), + "y4": float(quad.get("y4", quad.get("y3", 0.0))), + } + if isinstance(quad, Sequence) and len(quad) >= 8: + return { + "x1": float(quad[0]), + "y1": float(quad[1]), + "x2": float(quad[2]), + "y2": float(quad[3]), + "x3": float(quad[4]), + "y3": float(quad[5]), + "x4": float(quad[6]), + "y4": float(quad[7]), + } + return { + "x1": 0.0, + "y1": 0.0, + "x2": 0.0, + "y2": 0.0, + "x3": 0.0, + "y3": 0.0, + "x4": 0.0, + "y4": 0.0, + } diff --git a/benchmarks/data/chartqa_sample.jsonl b/benchmarks/data/chartqa_sample.jsonl new file mode 100644 index 0000000..ba91f8d --- /dev/null +++ b/benchmarks/data/chartqa_sample.jsonl @@ -0,0 +1,2 @@ +{"doc_id": "chartqa-sample-0", "question": "Which month had the highest revenue?", "answer": "March", "chart": [{"label": "January", "value": 40}, {"label": "February", "value": 55}, {"label": "March", "value": 70}, {"label": "April", "value": 45}], "metadata": {"chart_type": "bar", "unit": "k$"}} +{"doc_id": "chartqa-sample-1", "question": "What is the total number of subscriptions across the four regions?", "answer": "260", "chart": [{"label": "North", "value": 50}, {"label": "South", "value": 60}, {"label": "East", "value": 65}, {"label": "West", "value": 85}], "metadata": {"chart_type": "stacked bar", "unit": "k users"}} diff --git a/benchmarks/data/cord_sample.jsonl 
b/benchmarks/data/cord_sample.jsonl new file mode 100644 index 0000000..d3ad1bf --- /dev/null +++ b/benchmarks/data/cord_sample.jsonl @@ -0,0 +1 @@ +{"doc_id": "cord-sample-0", "width": 600, "height": 900, "metadata": {"split": "sample", "image_id": "cord-sample-0"}, "lines": [{"line_id": 0, "category": "menu.cnt", "group_id": 0, "sub_group_id": 0, "words": [{"token_id": 0, "text": "1", "quad": {"x1": 40, "y1": 120, "x2": 60, "y2": 120, "x3": 60, "y3": 160, "x4": 40, "y4": 160}}]}, {"line_id": 1, "category": "menu.nm", "group_id": 0, "sub_group_id": 1, "words": [{"token_id": 1, "text": "Sample", "quad": {"x1": 90, "y1": 120, "x2": 180, "y2": 120, "x3": 180, "y3": 160, "x4": 90, "y4": 160}}, {"token_id": 2, "text": "Latte", "quad": {"x1": 190, "y1": 120, "x2": 260, "y2": 120, "x3": 260, "y3": 160, "x4": 190, "y4": 160}}]}, {"line_id": 2, "category": "menu.price", "group_id": 0, "sub_group_id": 2, "words": [{"token_id": 3, "text": "4.50", "quad": {"x1": 400, "y1": 120, "x2": 460, "y2": 120, "x3": 460, "y3": 160, "x4": 400, "y4": 160}}]}, {"line_id": 3, "category": "menu.cnt", "group_id": 1, "sub_group_id": 0, "words": [{"token_id": 4, "text": "1", "quad": {"x1": 40, "y1": 200, "x2": 60, "y2": 200, "x3": 60, "y3": 240, "x4": 40, "y4": 240}}]}, {"line_id": 4, "category": "menu.nm", "group_id": 1, "sub_group_id": 1, "words": [{"token_id": 5, "text": "Citrus", "quad": {"x1": 90, "y1": 200, "x2": 170, "y2": 200, "x3": 170, "y3": 240, "x4": 90, "y4": 240}}, {"token_id": 6, "text": "Tea", "quad": {"x1": 180, "y1": 200, "x2": 230, "y2": 200, "x3": 230, "y3": 240, "x4": 180, "y4": 240}}]}, {"line_id": 5, "category": "menu.price", "group_id": 1, "sub_group_id": 2, "words": [{"token_id": 7, "text": "5.00", "quad": {"x1": 400, "y1": 200, "x2": 460, "y2": 200, "x3": 460, "y3": 240, "x4": 400, "y4": 240}}]}, {"line_id": 6, "category": "sub_total.subtotal_price", "group_id": 2, "sub_group_id": 0, "words": [{"token_id": 8, "text": "Subtotal", "quad": {"x1": 250, "y1": 320, "x2": 
330, "y2": 320, "x3": 330, "y3": 360, "x4": 250, "y4": 360}}, {"token_id": 9, "text": "9.50", "quad": {"x1": 400, "y1": 320, "x2": 460, "y2": 320, "x3": 460, "y3": 360, "x4": 400, "y4": 360}}]}, {"line_id": 7, "category": "sub_total.tax_price", "group_id": 3, "sub_group_id": 0, "words": [{"token_id": 10, "text": "Tax", "quad": {"x1": 250, "y1": 380, "x2": 300, "y2": 380, "x3": 300, "y3": 420, "x4": 250, "y4": 420}}, {"token_id": 11, "text": "0.95", "quad": {"x1": 400, "y1": 380, "x2": 460, "y2": 380, "x3": 460, "y3": 420, "x4": 400, "y4": 420}}]}, {"line_id": 8, "category": "sub_total.service_price", "group_id": 4, "sub_group_id": 0, "words": [{"token_id": 12, "text": "Service", "quad": {"x1": 230, "y1": 440, "x2": 320, "y2": 440, "x3": 320, "y3": 480, "x4": 230, "y4": 480}}, {"token_id": 13, "text": "0.55", "quad": {"x1": 400, "y1": 440, "x2": 460, "y2": 440, "x3": 460, "y3": 480, "x4": 400, "y4": 480}}]}, {"line_id": 9, "category": "total.total_price", "group_id": 5, "sub_group_id": 0, "words": [{"token_id": 14, "text": "Grand", "quad": {"x1": 180, "y1": 520, "x2": 250, "y2": 520, "x3": 250, "y3": 560, "x4": 180, "y4": 560}}, {"token_id": 15, "text": "Total", "quad": {"x1": 260, "y1": 520, "x2": 330, "y2": 520, "x3": 330, "y3": 560, "x4": 260, "y4": 560}}, {"token_id": 16, "text": "11.00", "quad": {"x1": 400, "y1": 520, "x2": 470, "y2": 520, "x3": 470, "y3": 560, "x4": 400, "y4": 560}}]}], "menu": [{"nm": "Sample Latte", "cnt": "1", "price": "4.50"}, {"nm": "Citrus Tea", "cnt": "1", "price": "5.00"}], "sub_total": {"subtotal_price": "9.50", "tax_price": "0.95", "service_price": "0.55"}, "total": {"total_price": "11.00"}} diff --git a/benchmarks/data/doclaynet_sample.jsonl b/benchmarks/data/doclaynet_sample.jsonl deleted file mode 100644 index 7646480..0000000 --- a/benchmarks/data/doclaynet_sample.jsonl +++ /dev/null @@ -1,3 +0,0 @@ -{"id": "doclaynet-sample-1", "size": [1200, 1800], "tokens": [{"id": 0, "text": "Quarterly", "bbox": [120, 110, 360, 180], 
"region_id": 0, "page_index": 0}, {"id": 1, "text": "Report", "bbox": [380, 110, 560, 180], "region_id": 0, "page_index": 0}, {"id": 2, "text": "Table", "bbox": [140, 420, 620, 720], "region_id": 1, "page_index": 0}, {"id": 3, "text": "1", "bbox": [630, 420, 660, 720], "region_id": 1, "page_index": 0}], "regions": [{"id": 0, "label": "title", "bbox": [100, 80, 600, 220], "token_ids": [0, 1], "page_index": 0}, {"id": 1, "label": "table", "bbox": [120, 400, 680, 760], "token_ids": [2, 3], "page_index": 0}]} -{"id": "doclaynet-sample-2", "image_size": [1000, 1600], "tokens": [{"id": 0, "text": "Summary", "bbox": [180, 160, 520, 220], "region_id": 0, "page": 0}, {"id": 1, "text": "Overview", "bbox": [180, 230, 520, 290], "region_id": 0, "page": 0}, {"id": 2, "text": "Conclusion", "bbox": [180, 300, 520, 360], "region_id": 0, "page": 0}], "regions": [{"id": 0, "category": "paragraph", "bbox": [160, 140, 540, 380], "token_ids": [0, 1, 2], "page": 0}]} -{"id": "doclaynet-sample-3", "page_size": {"width": 800, "height": 1200}, "words": [{"id": 0, "text": "Invoice", "box": [100, 100, 300, 180], "region": 0}, {"id": 1, "text": "Details", "box": [320, 100, 520, 180], "region": 0}, {"id": 2, "text": "Notes", "box": [120, 400, 420, 520], "region": 1}], "elements": [{"id": 0, "class": "header", "box": [80, 80, 560, 220], "tokens": [0, 1]}, {"id": 1, "class": "paragraph", "box": [100, 360, 500, 560], "tokens": [2]}]} diff --git a/benchmarks/data/funsd_sample.jsonl b/benchmarks/data/funsd_sample.jsonl deleted file mode 100644 index 87c815b..0000000 --- a/benchmarks/data/funsd_sample.jsonl +++ /dev/null @@ -1,2 +0,0 @@ -{"id": "doc_0001", "width": 1000, "height": 1000, "form": [{"id": 0, "label": "question", "text": "Invoice Number", "linking": [[0, 1]], "words": [{"id": 0, "text": "Invoice", "box": [40, 120, 210, 160]}, {"id": 1, "text": "Number", "box": [220, 120, 390, 160]}]}, {"id": 1, "label": "answer", "text": "INV-2024-001", "linking": [[0, 1]], "words": [{"id": 2, "text": 
"INV-2024-001", "box": [420, 120, 680, 160]}]}, {"id": 2, "label": "header", "text": "Acme Corporation", "words": [{"id": 3, "text": "Acme", "box": [40, 40, 180, 90]}, {"id": 4, "text": "Corporation", "box": [190, 40, 480, 90]}]}]} -{"id": "doc_0002", "width": 1000, "height": 1000, "form": [{"id": 0, "label": "question", "text": "Customer Name", "words": [{"id": 0, "text": "Customer", "box": [60, 140, 260, 180]}, {"id": 1, "text": "Name", "box": [270, 140, 410, 180]}]}, {"id": 1, "label": "answer", "text": "Jane Doe", "words": [{"id": 2, "text": "Jane", "box": [450, 140, 580, 180]}, {"id": 3, "text": "Doe", "box": [590, 140, 680, 180]}]}, {"id": 2, "label": "header", "text": "Registration Form", "words": [{"id": 4, "text": "Registration", "box": [60, 40, 340, 100]}, {"id": 5, "text": "Form", "box": [350, 40, 480, 100]}]}]} diff --git a/benchmarks/doc_understanding.py b/benchmarks/doc_understanding.py index 101999d..80bca74 100644 --- a/benchmarks/doc_understanding.py +++ b/benchmarks/doc_understanding.py @@ -1,11 +1,4 @@ -"""Synthetic doc-understanding benchmark for accuracy vs. token budget. - -For evaluations on the full FUNSD and DocLayNet datasets, first download the -official releases via ``python scripts/download_datasets.py``. Then point the -helpers at the cache directory with ``data_root=PATH / "funsd"`` or -``data_root=PATH / "doclaynet"`` while setting ``use_sample=False`` and -``dataset_size=0``. 
-""" +"""Doc-understanding benchmarks for the synthetic invoices and CORD receipts.""" from __future__ import annotations @@ -42,19 +35,14 @@ STMConfig, ) -from .doclaynet import ( - build_doclaynet_encoders, - build_doclaynet_registry, - doclaynet_contains_table, - doclaynet_fields, - load_doclaynet_dataset, -) -from .funsd import ( - build_funsd_encoders, - build_funsd_registry, - funsd_fields, - funsd_numeric_answer_label, - load_funsd_dataset, +from .cord import ( + cord_amount_from_text, + build_cord_encoders, + build_cord_registry, + cord_fields, + cord_high_total_label, + cord_total_amount, + load_cord_dataset, ) from .synthetic import ( build_invoice_encoders, @@ -149,102 +137,60 @@ def _label_fn(invoice: Mapping[str, Any], threshold_value: float = threshold) -> "budgets": [run.to_dict() for run in runs], } - -def run_funsd_benchmark( - budget_values: Iterable[int] = (8, 12, 16), +def run_cord_benchmark( + budget_values: Iterable[int] = (4, 8, 12), *, - dataset_size: int = 12, + dataset_size: int = 8, split: str = "train", - data_root: Optional[Path | str] = None, use_sample: bool = True, - seed: int = 0, - allocator_kwargs: Optional[Mapping[str, Any]] = None, -) -> Dict[str, Any]: - """Evaluate FUNSD documents for numeric-answer retention under budget constraints. - - Parameters - ---------- - allocator_kwargs: - Optional mapping forwarded to :class:`RegistryAwareBudgetAllocator` so callers can - specify advanced controls such as ``field_weights`` or ``field_min_quota``. 
- """ - - registry = build_funsd_registry() - build_funsd_encoders(registry) - limit = dataset_size if dataset_size > 0 else None - dataset = load_funsd_dataset(data_root, split=split, limit=limit, use_sample=use_sample) - actual_size = len(dataset) - - runs: List[BudgetRun] = [] - ablations = _funsd_ablation_suite(seed) - for budget in budget_values: - budget_run = _evaluate_budget( - budget=int(budget), - dataset=dataset, - registry_encoders=registry.encoders, - fields_fn=funsd_fields, - label_fn=funsd_numeric_answer_label, - predict_fn=_funsd_predict_numeric_answer, - metadata_fn=_funsd_metadata, - policy_name="funsd-doc-benchmark", - retention_probe_sample_size=3, - seed=seed, - ablations=ablations, - mi_field_priorities=("text", "layout"), - allocator_kwargs=allocator_kwargs, - ) - runs.append(budget_run) - - return { - "dataset": "FUNSD", - "split": split, - "dataset_size": actual_size, - "use_sample": bool(use_sample), - "budgets": [run.to_dict() for run in runs], - } - - -def run_doclaynet_benchmark( - budget_values: Iterable[int] = (6, 12), - *, - dataset_size: int = 6, - split: str = "train", data_root: Optional[Path | str] = None, - use_sample: bool = True, + cache_dir: Optional[Path | str] = None, + threshold: float = 250_000.0, seed: int = 0, ) -> Dict[str, Any]: - """Evaluate DocLayNet documents for table retention under budget constraints.""" + """Evaluate receipt totals from the CORD dataset under token budget constraints.""" - registry = build_doclaynet_registry() - build_doclaynet_encoders(registry) + registry = build_cord_registry() + build_cord_encoders(registry) limit = dataset_size if dataset_size > 0 else None - dataset = load_doclaynet_dataset(data_root, split=split, limit=limit, use_sample=use_sample) + dataset = load_cord_dataset( + split=split, + limit=limit, + use_sample=use_sample, + data_root=data_root, + cache_dir=cache_dir, + ) actual_size = len(dataset) + def _label_fn(document: Mapping[str, Any]) -> bool: + return 
cord_high_total_label(document, threshold=threshold) + + predict_fn = _cord_prediction_factory(threshold) + ablations = _cord_ablation_suite(seed) runs: List[BudgetRun] = [] - ablations = _doclaynet_ablation_suite(seed) for budget in budget_values: budget_run = _evaluate_budget( budget=int(budget), dataset=dataset, registry_encoders=registry.encoders, - fields_fn=doclaynet_fields, - label_fn=doclaynet_contains_table, - predict_fn=_doclaynet_predict_contains_table, - metadata_fn=_doclaynet_metadata, - policy_name="doclaynet-doc-benchmark", + fields_fn=cord_fields, + label_fn=_label_fn, + predict_fn=predict_fn, + metadata_fn=_cord_metadata, + policy_name="cord-receipt-benchmark", retention_probe_sample_size=3, seed=seed, ablations=ablations, - mi_field_priorities=("layout", "text", "segment"), + mi_field_priorities=("layout", "text", "line"), ) runs.append(budget_run) return { - "dataset": "DocLayNet", + "dataset": "CORD-v2", "split": split, "dataset_size": actual_size, "use_sample": bool(use_sample), + "threshold": float(threshold), "budgets": [run.to_dict() for run in runs], } @@ -667,21 +613,11 @@ def _invoice_ablation_suite(seed: int) -> Dict[str, AblationFn]: } -def _funsd_ablation_suite(seed: int) -> Dict[str, AblationFn]: - return { - "drop_layout": _drop_field_ablation("layout"), - "drop_text": _drop_field_ablation("text"), - "perturb_layout": _perturb_layout_ablation(scale=0.02), - "shuffle_entities": _shuffle_field_ablation("entity"), - } - - -def _doclaynet_ablation_suite(seed: int) -> Dict[str, AblationFn]: +def _cord_ablation_suite(seed: int) -> Dict[str, AblationFn]: return { "drop_layout": _drop_field_ablation("layout"), "drop_text": _drop_field_ablation("text"), - "perturb_layout": _perturb_layout_ablation(scale=0.03), - "shuffle_segments": _shuffle_field_ablation("segment"), + "shuffle_lines": _shuffle_field_ablation("line"), } @@ -1004,107 +940,70 @@ def _serialise_cell_fusion(fusion: Mapping[str, Any]) -> Dict[str, Any]: return payload -def 
_copy_nested_list(value: Any) -> Any: - if isinstance(value, (list, tuple)): - return [_copy_nested_list(item) for item in value] - return value - +def _cord_prediction_factory(threshold: float) -> PredictFn: + def _predict(result: CompressionResult, document: Mapping[str, Any]) -> bool: + guess = _cord_guess_total(result) + if guess is None: + guess = cord_total_amount(document) + return float(guess or 0.0) >= float(threshold) -def _doclaynet_predict_contains_table(result: CompressionResult, _: Mapping[str, Any]) -> bool: - segments = result.compressed_fields.get("segment") - if segments is None: - segments = result.compressed_fields.get("region", []) - for segment in segments: - if isinstance(segment, Mapping): - label = ( - segment.get("label") - or segment.get("segment_label") - or segment.get("region_label") - ) - else: - label = segment - if isinstance(label, str) and label.lower() == "table": - return True - for token in result.compressed_fields.get("text", []): - value: Any - if isinstance(token, Mapping): - value = token.get("text") - else: - value = token - if isinstance(value, str) and "table" in value.lower(): - return True - return False + return _predict -def _doclaynet_metadata(document: Mapping[str, Any], result: CompressionResult) -> Mapping[str, Any]: - doc_id = document.get("doc_id") or document.get("id") - segments = result.compressed_fields.get("segment") - if segments is None: - segments = result.compressed_fields.get("region", []) - kept_segments = [] - for segment in segments: - if isinstance(segment, Mapping): - label = ( - segment.get("label") - or segment.get("segment_label") - or segment.get("region_label") - ) - segment_id = ( - segment.get("segment_id") - if segment.get("segment_id") is not None - else segment.get("region_id") - ) - else: - label = segment - segment_id = None - kept_segments.append( - { - "segment_id": segment_id, - "region_id": segment_id, - "label": str(label) if label is not None else "", - } - ) - - metadata = { - 
"doc_id": doc_id, - "contains_table": doclaynet_contains_table(document), - "kept_segments": kept_segments, - "kept_segment_count": len(kept_segments), - } - # Preserve legacy keys for downstream consumers expecting the old naming. - metadata["kept_regions"] = [ - {"region_id": entry.get("region_id"), "label": entry.get("label", "")} - for entry in kept_segments - ] - metadata["kept_region_count"] = metadata["kept_segment_count"] - return metadata - - -def _funsd_predict_numeric_answer(result: CompressionResult, _: Mapping[str, Any]) -> bool: - for entity in result.compressed_fields.get("entity", []): - if isinstance(entity, Mapping) and str(entity.get("label", "")).lower() == "answer": - text = entity.get("text", "") - if any(char.isdigit() for char in str(text)): - return True - for token in result.compressed_fields.get("text", []): - if not isinstance(token, Mapping): +def _cord_guess_total(result: CompressionResult) -> Optional[float]: + best: Optional[float] = None + for field_name in ("text", "line"): + entries = result.compressed_fields.get(field_name, []) + candidate = _cord_max_amount(entries, prefer_field=field_name) + if candidate is None: continue - if token.get("is_answer") and any(char.isdigit() for char in str(token.get("text", ""))): - return True - return False + if best is None or candidate > best: + best = candidate + return best -def _funsd_metadata(document: Mapping[str, Any], _: CompressionResult) -> Mapping[str, Any]: - doc_id = document.get("doc_id") or document.get("id") - entities = document.get("form", []) - answer_count = sum(1 for item in entities if str(item.get("label", "")).lower() == "answer") +def _cord_max_amount( + entries: Sequence[Any], + *, + prefer_field: str, +) -> Optional[float]: + best: Optional[float] = None + for entry in entries: + if isinstance(entry, Mapping): + text_value = entry.get("text") + category = entry.get("category") or entry.get("line_category") + else: + text_value = entry + category = "" + amount = 
cord_amount_from_text(text_value) + if amount <= 0: + continue + weight = 1.0 + category_text = str(category or "").lower() + if "total" in category_text: + weight = 1.4 + elif prefer_field == "line": + weight = 1.15 + candidate = amount * weight + if best is None or candidate > best: + best = candidate + return best + + +def _cord_metadata(document: Mapping[str, Any], _: CompressionResult) -> Mapping[str, Any]: return { - "doc_id": doc_id, - "entity_count": len(entities), - "answer_entities": answer_count, + "doc_id": document.get("doc_id"), + "total_amount": cord_total_amount(document), + "line_count": len(document.get("lines", [])), } +def _copy_nested_list(value: Any) -> Any: + if isinstance(value, (list, tuple)): + return [_copy_nested_list(item) for item in value] + return value + + def main() -> None: """CLI entry point printing a JSON benchmark report.""" @@ -1116,4 +1015,4 @@ def main() -> None: main() -__all__ = ["run_benchmark", "run_funsd_benchmark", "run_doclaynet_benchmark"] +__all__ = ["run_benchmark", "run_cord_benchmark"] diff --git a/benchmarks/doclaynet.py b/benchmarks/doclaynet.py deleted file mode 100644 index 6f35671..0000000 --- a/benchmarks/doclaynet.py +++ /dev/null @@ -1,559 +0,0 @@ -"""Typed helpers for working with the DocLayNet document-layout dataset.""" - -from __future__ import annotations - -import json -from copy import deepcopy -from pathlib import Path -from typing import ( - Any, - Dict, - Iterable, - Iterator, - List, - Mapping, - MutableMapping, - Optional, - Sequence, - Tuple, -) - -from nd_llm.encoders import Encoder, LayoutEncoder, TextEncoder -from nd_llm.registry import Registry - -__all__ = [ - "load_doclaynet_dataset", - "build_doclaynet_registry", - "build_doclaynet_encoders", - "doclaynet_fields", - "doclaynet_contains_table", -] - - -# Paths to bundled DocLayNet caches and samples. 
-_DATA_DIR = Path(__file__).with_name("data") -_CACHE_PATH = _DATA_DIR.joinpath("doclaynet_cache.jsonl") -_SAMPLE_PATH = _DATA_DIR.joinpath("doclaynet_sample.jsonl") - -# A tiny in-memory sample used when the real dataset is unavailable. The sample -# mirrors the minimal structure the conversion helpers expect, keeping runtime -# behaviour deterministic for doctests and local experiments. -_SAMPLE_DOCUMENT: Dict[str, Any] = { - "doc_id": "doclaynet-sample", - "page_id": "doclaynet-sample-0", - "width": 1000, - "height": 1400, - "segments": [ - { - "segment_id": 0, - "label": "header", - "confidence": 0.92, - "polygon": [10, 10, 400, 10, 400, 120, 10, 120], - "text": "DocLayNet Sample Heading", - "tokens": [ - { - "token_id": 0, - "text": "DocLayNet", - "polygon": [10, 10, 180, 10, 180, 120, 10, 120], - "confidence": 0.9, - }, - { - "token_id": 1, - "text": "Sample", - "polygon": [180, 10, 300, 10, 300, 120, 180, 120], - "confidence": 0.91, - }, - { - "token_id": 2, - "text": "Heading", - "polygon": [300, 10, 400, 10, 400, 120, 300, 120], - "confidence": 0.93, - }, - ], - }, - { - "segment_id": 1, - "label": "paragraph", - "confidence": 0.88, - "polygon": [30, 180, 960, 180, 960, 620, 30, 620], - "text": "This is a lightweight fallback paragraph used when the DocLayNet dataset is not installed.", - "tokens": [ - { - "token_id": 0, - "text": "This", - "polygon": [30, 180, 110, 180, 110, 260, 30, 260], - "confidence": 0.87, - }, - { - "token_id": 1, - "text": "is", - "polygon": [120, 180, 150, 180, 150, 260, 120, 260], - "confidence": 0.88, - }, - { - "token_id": 2, - "text": "a", - "polygon": [160, 180, 190, 180, 190, 260, 160, 260], - "confidence": 0.9, - }, - { - "token_id": 3, - "text": "lightweight", - "polygon": [200, 180, 370, 180, 370, 260, 200, 260], - "confidence": 0.9, - }, - { - "token_id": 4, - "text": "fallback", - "polygon": [380, 180, 520, 180, 520, 260, 380, 260], - "confidence": 0.89, - }, - { - "token_id": 5, - "text": "paragraph", - "polygon": 
[530, 180, 700, 180, 700, 260, 530, 260], - "confidence": 0.9, - }, - ], - }, - ], - "metadata": { - "source": "synthetic", - "split": "sample", - }, -} - - -def load_doclaynet_dataset( - root: Optional[Path | str] = None, - *, - split: str = "train", - limit: Optional[int] = None, - use_sample: bool = True, -) -> List[Dict[str, Any]]: - """Load DocLayNet pages from ``root`` or return a bundled sample.""" - - if root is None: - if not use_sample: - raise FileNotFoundError( - "DocLayNet root path is required when use_sample is False" - ) - return list(_load_sample(limit)) - - path = Path(root) - if not path.exists(): - if use_sample: - return list(_load_sample(limit)) - raise FileNotFoundError(f"DocLayNet root directory '{path}' does not exist") - - pages = list(_load_from_directory(path, split=split, limit=limit)) - if pages: - return pages - - if use_sample: - return list(_load_sample(limit)) - - raise FileNotFoundError( - f"DocLayNet split '{split}' not found in '{path}' and no bundled sample is available" - ) - - -def build_doclaynet_registry() -> Registry: - """Return a registry that captures DocLayNet text, layout, and segment metadata.""" - - registry = Registry() - registry.add_field( - "text", - keys=["doc_id", "page_id", "segment_id", "token_id"], - salience=True, - modality="text", - ) - registry.add_field( - "layout", - keys=["doc_id", "page_id", "segment_id", "token_id"], - modality="layout", - ) - registry.add_field( - "segment", keys=["doc_id", "page_id", "segment_id"], modality="entity" - ) - registry.add_affinity("segment", "text", keys=["doc_id", "page_id", "segment_id"]) - registry.add_affinity("text", "layout", keys=["doc_id", "page_id", "token_id"]) - registry.validate() - return registry - - -def build_doclaynet_encoders( - registry: Registry, - *, - text_dim: int = 16, - layout_dim: int = 8, - segment_dim: int = 4, -) -> Dict[str, Encoder]: - """Register simple encoder stubs for the DocLayNet registry.""" - - encoders: Dict[str, Encoder] = { - 
"text": TextEncoder(embedding_dim=text_dim), - "layout": LayoutEncoder(embedding_dim=layout_dim), - "segment": TextEncoder(embedding_dim=segment_dim), - } - for field, encoder in encoders.items(): - registry.register_encoder(field, encoder) - return encoders - - -def doclaynet_fields( - document: Mapping[str, Any], -) -> Dict[str, List[MutableMapping[str, Any]]]: - """Convert a DocLayNet page into registry-aligned field payloads.""" - - doc_id = str(document.get("doc_id") or document.get("document_id") or "") - page_id = str(document.get("page_id") or document.get("page") or 0) - width = float(document.get("width") or document.get("img_width") or 1.0) - height = float(document.get("height") or document.get("img_height") or 1.0) - if width <= 0: - width = 1.0 - if height <= 0: - height = 1.0 - - text_field: List[MutableMapping[str, Any]] = [] - layout_field: List[MutableMapping[str, Any]] = [] - segment_field: List[MutableMapping[str, Any]] = [] - - raw_segments = document.get("segments") - if not isinstance(raw_segments, Sequence): - raw_segments = [] - - for index, raw_segment in enumerate(raw_segments): - if not isinstance(raw_segment, Mapping): - continue - segment_id = int( - raw_segment.get("segment_id") or raw_segment.get("id") or index - ) - label = str( - raw_segment.get("label") or raw_segment.get("category") or "segment" - ) - confidence_value = raw_segment.get("confidence") - confidence = float(confidence_value) if confidence_value is not None else 0.0 - segment_polygon = _coerce_polygon( - raw_segment.get("polygon") or raw_segment.get("bbox") - ) - - tokens = raw_segment.get("tokens") or raw_segment.get("words") - prepared_tokens = _prepare_tokens( - tokens, fallback_text=str(raw_segment.get("text", "")) - ) - - token_ids = [int(token["token_id"]) for token in prepared_tokens] - - segment_field.append( - { - "doc_id": doc_id, - "page_id": page_id, - "segment_id": segment_id, - "label": label, - "confidence": confidence, - "polygon": segment_polygon, - 
"token_ids": token_ids, - } - ) - - for token in prepared_tokens: - token_id = int(token["token_id"]) - token_text = str(token["text"]) - token_confidence = float(token["confidence"]) - token_polygon = _coerce_polygon(token.get("polygon")) - bbox = _polygon_to_bbox(token_polygon) - norm_bbox = _normalise_box(bbox, width, height) - - text_field.append( - { - "doc_id": doc_id, - "page_id": page_id, - "token_id": token_id, - "text": token_text, - "segment_id": segment_id, - "segment_label": label, - "confidence": token_confidence, - } - ) - layout_field.append( - { - "doc_id": doc_id, - "page_id": page_id, - "token_id": token_id, - "xyxy": norm_bbox, - "polygon": token_polygon, - "segment_id": segment_id, - } - ) - - return {"text": text_field, "layout": layout_field, "segment": segment_field} - - -def doclaynet_contains_table(document: Mapping[str, Any]) -> bool: - """Return ``True`` when a DocLayNet page includes a table segment.""" - - metadata = document.get("metadata") - if isinstance(metadata, Mapping): - for key in ("contains_table", "containsTable", "has_table"): - flag = metadata.get(key) - coerced = _coerce_bool(flag) - if coerced is not None: - return coerced - - segments = document.get("segments") or document.get("entities") - if isinstance(segments, Sequence): - for segment in segments: - if not isinstance(segment, Mapping): - continue - label = ( - segment.get("label") or segment.get("category") or segment.get("type") - ) - if isinstance(label, str) and "table" in label.lower(): - return True - - segment_metadata = segment.get("metadata") - if isinstance(segment_metadata, Mapping): - flag = segment_metadata.get("contains_table") - coerced = _coerce_bool(flag) - if coerced: - return True - - return False - - -def _prepare_tokens( - tokens: Optional[Iterable[Any]], - *, - fallback_text: str, -) -> List[Dict[str, Any]]: - prepared: List[Dict[str, Any]] = [] - if tokens is None: - tokens = [] - - for index, raw_token in enumerate(tokens): - if not 
isinstance(raw_token, Mapping): - continue - token_id = int(raw_token.get("token_id") or raw_token.get("id") or index) - text = str(raw_token.get("text") or "") - if not text and fallback_text: - text = fallback_text - confidence_value = raw_token.get("confidence") - confidence = float(confidence_value) if confidence_value is not None else 0.0 - polygon = _coerce_polygon(raw_token.get("polygon") or raw_token.get("bbox")) - prepared.append( - { - "token_id": token_id, - "text": text, - "confidence": confidence, - "polygon": polygon, - } - ) - - if not prepared and fallback_text: - prepared.append( - { - "token_id": 0, - "text": fallback_text, - "confidence": 0.0, - "polygon": [], - } - ) - - return prepared - - -def _load_sample(limit: Optional[int]) -> Iterator[Dict[str, Any]]: - for path in (_CACHE_PATH, _SAMPLE_PATH): - if not path.exists(): - continue - count = 0 - with path.open("r", encoding="utf-8") as handle: - for line in handle: - if not line.strip(): - continue - raw = json.loads(line) - yield _prepare_document( - raw, default_doc_id=str(raw.get("id") or "doclaynet-sample") - ) - count += 1 - if limit is not None and count >= limit: - return - if count: - return - - count = 0 - while True: - yield deepcopy(_SAMPLE_DOCUMENT) - count += 1 - if limit is not None and count >= limit: - break - if limit is None: - break - - -def _load_from_directory( - root: Path, *, split: str, limit: Optional[int] -) -> Iterator[Dict[str, Any]]: - split_path = root / split - if not split_path.exists(): - return - - count = 0 - for path in sorted(split_path.glob("*.json")): - data = json.loads(path.read_text(encoding="utf-8")) - document = _prepare_document(data, default_doc_id=path.stem) - yield document - count += 1 - if limit is not None and count >= limit: - return - - -def _prepare_document(raw: Mapping[str, Any], *, default_doc_id: str) -> Dict[str, Any]: - doc_id = str(raw.get("doc_id") or raw.get("id") or default_doc_id) - page_id = str(raw.get("page_id") or 
raw.get("page") or 0) - width_value = raw.get("width") or raw.get("img_width") or raw.get("page_width") - height_value = raw.get("height") or raw.get("img_height") or raw.get("page_height") - width = int(width_value) if width_value is not None else 0 - height = int(height_value) if height_value is not None else 0 - metadata_raw = raw.get("metadata") - metadata = dict(metadata_raw) if isinstance(metadata_raw, Mapping) else {} - metadata.setdefault("doc_id", doc_id) - metadata.setdefault("page_id", page_id) - metadata.setdefault("split", str(raw.get("split") or metadata.get("split") or "")) - - raw_tokens_source = raw.get("tokens") or raw.get("words") - tokens_by_id: Dict[str, Mapping[str, Any]] = {} - if isinstance(raw_tokens_source, Sequence): - for index, raw_token in enumerate(raw_tokens_source): - if not isinstance(raw_token, Mapping): - continue - token_id_value = raw_token.get("token_id") or raw_token.get("id") or index - try: - token_id = int(token_id_value) - except (TypeError, ValueError): - token_id = index - token_key = str(token_id) - tokens_by_id[token_key] = raw_token - - raw_segments = ( - raw.get("segments") or raw.get("entities") or raw.get("regions") or [] - ) - segments: List[Dict[str, Any]] = [] - for index, raw_segment in enumerate(raw_segments): - if not isinstance(raw_segment, Mapping): - continue - segment_id = int( - raw_segment.get("segment_id") or raw_segment.get("id") or index - ) - label = str( - raw_segment.get("label") or raw_segment.get("category") or "segment" - ) - confidence_value = raw_segment.get("confidence") - confidence = float(confidence_value) if confidence_value is not None else 0.0 - polygon = _coerce_polygon(raw_segment.get("polygon") or raw_segment.get("bbox")) - segment_tokens = raw_segment.get("tokens") or raw_segment.get("words") - token_ids = raw_segment.get("token_ids") or [] - if not segment_tokens and token_ids and tokens_by_id: - recovered_tokens = [] - if isinstance(token_ids, Sequence): - for token_id_value in 
token_ids: - token_key = str(token_id_value) - token_entry = tokens_by_id.get(token_key) - if token_entry is None: - try: - token_entry = tokens_by_id.get(str(int(token_id_value))) - except (TypeError, ValueError): - token_entry = None - if token_entry is None: - continue - token_map = dict(token_entry) - if "token_id" not in token_map: - token_map["token_id"] = ( - token_entry.get("id") - if isinstance(token_entry, Mapping) - else token_id_value - ) - recovered_tokens.append(token_map) - segment_tokens = recovered_tokens - - tokens = _prepare_tokens( - segment_tokens, - fallback_text=str(raw_segment.get("text", "")), - ) - segments.append( - { - "segment_id": segment_id, - "label": label, - "confidence": confidence, - "polygon": polygon, - "tokens": tokens, - } - ) - - document: Dict[str, Any] = { - "doc_id": doc_id, - "page_id": page_id, - "width": width, - "height": height, - "segments": segments, - "metadata": metadata, - } - return document - - -def _coerce_bool(value: Any) -> Optional[bool]: - if isinstance(value, bool): - return value - if isinstance(value, (int, float)): - return bool(value) - if isinstance(value, str): - lowered = value.strip().lower() - if lowered in {"true", "1", "yes", "y"}: - return True - if lowered in {"false", "0", "no", "n"}: - return False - return None - - -def _coerce_polygon(value: Any) -> List[float]: - if value is None: - return [] - if isinstance(value, Mapping): - coords = [] - for key in ("x1", "y1", "x2", "y2"): - coord_value = value.get(key) - if coord_value is None: - continue - coords.append(float(coord_value)) - if coords: - return coords - if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): - coords = [float(v) for v in value if isinstance(v, (int, float))] - return coords - return [] - - -def _polygon_to_bbox(polygon: Sequence[float]) -> Tuple[float, float, float, float]: - if not polygon: - return (0.0, 0.0, 0.0, 0.0) - xs = polygon[0::2] - ys = polygon[1::2] - if not xs or not ys: - 
return (0.0, 0.0, 0.0, 0.0) - x1 = min(xs) - y1 = min(ys) - x2 = max(xs) - y2 = max(ys) - return (float(x1), float(y1), float(x2), float(y2)) - - -def _normalise_box( - box: Tuple[float, float, float, float], width: float, height: float -) -> Tuple[float, float, float, float]: - x1, y1, x2, y2 = box - if width <= 0: - width = 1.0 - if height <= 0: - height = 1.0 - return (x1 / width, y1 / height, x2 / width, y2 / height) diff --git a/benchmarks/funsd.py b/benchmarks/funsd.py deleted file mode 100644 index c29421a..0000000 --- a/benchmarks/funsd.py +++ /dev/null @@ -1,445 +0,0 @@ -"""Utilities for loading and normalising the FUNSD form-understanding dataset.""" - -from __future__ import annotations - -import json -import importlib -import importlib.util -from pathlib import Path -from types import ModuleType -from typing import Any, Dict, Iterator, List, Mapping, MutableMapping, Optional, Sequence, cast - -from nd_llm.encoders import Encoder, LayoutEncoder, TextEncoder -from nd_llm.registry import Registry - -_DATA_DIR = Path(__file__).with_name("data") -_CACHE_PATH = _DATA_DIR.joinpath("funsd_cache.jsonl") -_SAMPLE_PATH = _DATA_DIR.joinpath("funsd_sample.jsonl") - -_FUNSD_SPLITS: Dict[str, List[str]] = { - "train": ["train", "training", "training_data"], - "training": ["train", "training", "training_data"], - "test": ["test", "testing", "testing_data"], - "testing": ["test", "testing", "testing_data"], - "validation": ["validation", "valid", "val", "dev", "testing", "testing_data"], - "val": ["validation", "valid", "val", "dev", "testing", "testing_data"], -} - - -def load_funsd_dataset( - root: Optional[Path | str] = None, - *, - split: str = "train", - limit: Optional[int] = None, - use_sample: Optional[bool] = None, -) -> List[Dict[str, Any]]: - """Load FUNSD documents from ``root`` or fall back to the bundled sample.""" - - documents: List[Dict[str, Any]] = [] - if use_sample or root is None: - documents = list(_load_sample(limit)) - else: - path = Path(root) - 
if not path.exists(): - raise FileNotFoundError(f"FUNSD root directory '{path}' does not exist") - documents = list(_load_from_directory(path, split=split, limit=limit)) - - if not documents and not use_sample and root is None: - documents = list(_load_sample(limit)) - return documents - - -def build_funsd_registry() -> Registry: - """Return a registry describing FUNSD token, layout and entity fields.""" - - registry = Registry() - registry.add_field("text", keys=["doc_id", "token_id"], salience=True, modality="text") - registry.add_field("layout", keys=["doc_id", "token_id"], modality="layout") - registry.add_field( - "entity", keys=["doc_id", "entity_id"], modality="entity", salience=True - ) - registry.add_affinity("text", "layout", keys=["doc_id", "token_id"]) - registry.add_affinity("entity", "text", keys=["doc_id"]) - registry.validate() - return registry - - -def build_funsd_encoders( - registry: Registry, - *, - text_dim: int = 8, - layout_dim: int = 6, - entity_dim: int = 4, -) -> Dict[str, Encoder]: - """Register simple encoder stubs for the FUNSD fields.""" - - encoders: Dict[str, Encoder] = { - "text": TextEncoder(embedding_dim=text_dim), - "layout": LayoutEncoder(embedding_dim=layout_dim), - "entity": TextEncoder(embedding_dim=entity_dim), - } - for field, encoder in encoders.items(): - registry.register_encoder(field, encoder) - return encoders - - -def funsd_fields(document: Mapping[str, Any]) -> Dict[str, List[MutableMapping[str, Any]]]: - """Convert a FUNSD document into registry-aligned field batches.""" - - doc_id = str(document.get("doc_id") or document.get("id") or "") - width, height = _resolve_size(document) - text_field: List[MutableMapping[str, Any]] = [] - layout_field: List[MutableMapping[str, Any]] = [] - entity_field: List[MutableMapping[str, Any]] = [] - document_box_mode = _resolve_box_mode(document) - - for entity_index, raw_entity in enumerate(document.get("form", [])): - entity_id = int(raw_entity.get("id", entity_index)) - label = 
str(raw_entity.get("label", "other")) - raw_text = raw_entity.get("text") - if not raw_text: - raw_text = " ".join(str(word.get("text", "")) for word in raw_entity.get("words", [])) - raw_text = str(raw_text).strip() - - token_ids: List[int] = [] - for word_index, raw_word in enumerate(raw_entity.get("words", [])): - token_id = raw_word.get("id") - if token_id is None: - token_id = len(text_field) - token_id = int(token_id) - token_ids.append(token_id) - token_text = str(raw_word.get("text", "")) - text_field.append( - { - "doc_id": doc_id, - "token_id": token_id, - "text": token_text, - "entity_id": entity_id, - "entity_label": label, - "is_answer": label.lower() == "answer", - } - ) - - bbox = raw_word.get("box") or raw_word.get("bbox") - box_mode = _resolve_box_mode(raw_word, raw_entity) - if box_mode is None: - box_mode = document_box_mode - norm_box = _normalise_box(bbox, width, height, mode=box_mode) - layout_field.append( - { - "doc_id": doc_id, - "token_id": token_id, - "xyxy": norm_box, - "entity_id": entity_id, - } - ) - - entity_field.append( - { - "doc_id": doc_id, - "entity_id": entity_id, - "label": label, - "text": raw_text, - "token_ids": token_ids, - } - ) - - return {"text": text_field, "layout": layout_field, "entity": entity_field} - - -def funsd_numeric_answer_label(document: Mapping[str, Any]) -> bool: - """Return ``True`` if an answer entity contains a digit.""" - - for entity in document.get("form", []): - label = str(entity.get("label", "")).lower() - if label != "answer": - continue - raw_text = entity.get("text") - if not raw_text: - raw_text = " ".join(str(word.get("text", "")) for word in entity.get("words", [])) - if any(char.isdigit() for char in str(raw_text)): - return True - return False - - -_PIL_IMAGE_MODULE: ModuleType | bool | None = None - - -def _load_sample(limit: Optional[int]) -> Iterator[Dict[str, Any]]: - for path in (_CACHE_PATH, _SAMPLE_PATH): - if not path.exists(): - continue - count = 0 - with path.open("r", 
encoding="utf-8") as handle: - for line in handle: - if not line.strip(): - continue - document = json.loads(line) - yield _prepare_document(document, document.get("id")) - count += 1 - if limit is not None and count >= limit: - return - if count: - return - - -def _load_from_directory(root: Path, *, split: str, limit: Optional[int]) -> Iterator[Dict[str, Any]]: - split_key = split.lower() - candidates = _FUNSD_SPLITS.get(split_key, [split_key]) - visited = set() - for candidate in candidates: - candidate = candidate.strip("/") - if not candidate: - continue - if candidate in visited: - continue - visited.add(candidate) - base = root / candidate - if not base.exists(): - continue - annotations_dir = base / "annotations" - if not annotations_dir.exists(): - continue - count = 0 - for path in sorted(annotations_dir.glob("*.json")): - data = json.loads(path.read_text(encoding="utf-8")) - document = _prepare_document(data, path.stem, source=path) - yield document - count += 1 - if limit is not None and count >= limit: - return - if count: - return - if split_key in {"validation", "val"}: - yield from _load_from_directory(root, split="test", limit=limit) - - -def _prepare_document( - raw: Mapping[str, Any], identifier: Optional[str], *, source: Optional[Path] = None -) -> Dict[str, Any]: - document: Dict[str, Any] = dict(raw) - doc_id = document.get("id") or document.get("doc_id") or identifier or "" - document["id"] = str(doc_id) - document["doc_id"] = str(doc_id) - if source is not None: - document.setdefault("_source_path", str(source)) - document.setdefault("_source_dir", str(source.parent)) - if "form" not in document: - if "annotations" in document and isinstance(document["annotations"], Sequence): - document["form"] = list(document["annotations"]) - else: - document["form"] = [] - width, height = _resolve_size(document) - document["width"] = width - document["height"] = height - return document - - -def _resolve_size(document: Mapping[str, Any]) -> tuple[float, 
float]: - for key in ("image_size", "img_size", "size"): - size = document.get(key) - if isinstance(size, Sequence) and len(size) >= 2: - seq_width = _coerce_float(size[0], 1000.0) - seq_height = _coerce_float(size[1], 1000.0) - return max(seq_width, 1.0), max(seq_height, 1.0) - image_meta = document.get("image") - if isinstance(image_meta, Mapping): - image_width = image_meta.get("width") or image_meta.get("w") - image_height = image_meta.get("height") or image_meta.get("h") - if image_width is not None and image_height is not None: - width_value = _coerce_float(image_width, 1000.0) - height_value = _coerce_float(image_height, 1000.0) - return max(width_value, 1.0), max(height_value, 1.0) - page = document.get("page_size") - if isinstance(page, Mapping): - page_width = page.get("width") or page.get("w") - page_height = page.get("height") or page.get("h") - if page_width is not None and page_height is not None: - width_value = _coerce_float(page_width, 1000.0) - height_value = _coerce_float(page_height, 1000.0) - return max(width_value, 1.0), max(height_value, 1.0) - doc_width = document.get("width") - doc_height = document.get("height") - if doc_width is not None and doc_height is not None: - width_value = _coerce_float(doc_width, 1000.0) - height_value = _coerce_float(doc_height, 1000.0) - return max(width_value, 1.0), max(height_value, 1.0) - - search_dirs: List[Path] = [] - for key in ("_source_path",): - value = document.get(key) - if isinstance(value, str) and value: - search_dirs.append(Path(value).parent) - for key in ("_source_dir", "image_dir", "image_root", "img_dir", "images_dir"): - value = document.get(key) - if isinstance(value, str) and value: - search_dirs.append(Path(value)) - search_dirs.append(Path.cwd()) - - image_paths: List[str] = [] - for key in ( - "image", - "image_path", - "img_path", - "img", - "path", - "file", - "file_path", - "filename", - ): - value = document.get(key) - if isinstance(value, str) and value: - 
image_paths.append(value) - elif isinstance(value, Mapping): - for inner_key in ("path", "file", "file_path", "filename"): - inner_value = value.get(inner_key) - if isinstance(inner_value, str) and inner_value: - image_paths.append(inner_value) - - for candidate in image_paths: - resolved = _resolve_image_candidate(candidate, search_dirs) - if resolved is None: - continue - size = _load_image_size(resolved) - if size is not None: - width_value, height_value = size - return max(width_value, 1.0), max(height_value, 1.0) - - width_fallback = _coerce_float(document.get("width"), 1000.0) - height_fallback = _coerce_float(document.get("height"), 1000.0) - return max(width_fallback, 1.0), max(height_fallback, 1.0) - - -def _normalise_box( - box: Any, - width: float, - height: float, - *, - mode: Optional[str] = None, -) -> List[float]: - if isinstance(box, Mapping): - candidates = [ - box.get("x1"), - box.get("y1"), - box.get("x2"), - box.get("y2"), - ] - if all(value is not None for value in candidates): - box = candidates - else: - alt_candidates = [ - box.get("x"), - box.get("y"), - box.get("w"), - box.get("h"), - ] - if all(value is not None for value in alt_candidates): - mode = mode or str(box.get("mode") or box.get("format") or "") - box = alt_candidates - if not isinstance(box, Sequence) or len(box) < 4: - return [0.0, 0.0, 0.0, 0.0] - w = max(float(width), 1.0) - h = max(float(height), 1.0) - coords = [float(box[i]) for i in range(4)] - left, top, third, fourth = coords - - box_mode = (mode or "").lower() - convert_xywh = box_mode in {"xywh", "wh", "width_height"} - - right = third - bottom = fourth - xyxy_valid = right >= left and bottom >= top - xywh_plausible = ( - left + third <= w + 1e-6 - and top + fourth <= h + 1e-6 - and third >= 0.0 - and fourth >= 0.0 - ) - - if not convert_xywh: - if not xyxy_valid and xywh_plausible: - convert_xywh = True - else: - right_norm = right / w - bottom_norm = bottom / h - if xywh_plausible and ( - right_norm > 1.0 or 
bottom_norm > 1.0 - ): - convert_xywh = True - - if convert_xywh: - right = left + third - bottom = top + fourth - - return [left / w, top / h, right / w, bottom / h] - - -def _resolve_box_mode(*sources: Mapping[str, Any]) -> Optional[str]: - for source in sources: - if not isinstance(source, Mapping): - continue - for key in ("box_mode", "bbox_mode", "bbox_format", "box_format", "mode"): - value = source.get(key) - if isinstance(value, str) and value: - return value - return None - - -def _resolve_image_candidate(path_value: str, search_dirs: Sequence[Path]) -> Optional[Path]: - candidate = Path(path_value) - if candidate.is_absolute() and candidate.exists(): - return candidate - for base in search_dirs: - resolved = base / candidate - if resolved.exists(): - return resolved - if candidate.exists(): - return candidate - return None - - -def _load_image_size(path: Path) -> Optional[tuple[float, float]]: - global _PIL_IMAGE_MODULE - module = _PIL_IMAGE_MODULE - if module is False: - return None - if module is None: - spec = importlib.util.find_spec("PIL.Image") - if spec is None: - _PIL_IMAGE_MODULE = False - return None - module = importlib.import_module("PIL.Image") - if not isinstance(module, ModuleType): - _PIL_IMAGE_MODULE = False - return None - _PIL_IMAGE_MODULE = module - if module is False: - return None - module = _PIL_IMAGE_MODULE - if module is False or module is None: - return None - assert isinstance(module, ModuleType) - ImageModule = cast(Any, module) - try: - with ImageModule.open(path) as handle: # type: ignore[call-arg] - return float(handle.width), float(handle.height) - except (FileNotFoundError, OSError): - return None - - -def _coerce_float(value: Any, default: float) -> float: - try: - if value is None: - raise TypeError("value is None") - return float(value) - except (TypeError, ValueError): - return default - - -__all__ = [ - "build_funsd_encoders", - "build_funsd_registry", - "funsd_fields", - "funsd_numeric_answer_label", - 
"load_funsd_dataset", -] diff --git a/docs/the-tensor-is-the-message.md b/docs/the-tensor-is-the-message.md index e45db5f..87dd48d 100644 --- a/docs/the-tensor-is-the-message.md +++ b/docs/the-tensor-is-the-message.md @@ -20,8 +20,8 @@ We translate these proofs into **code scaffolding** for ND-LLM: * a variable-rate **Token Bottleneck** that allocates \(K\) tokens to the most informative, co-registered cells; * a pluggable **Field Registry** for synchronized N-D inputs; * **InfoNCE mutual-information proxies** to maximize \(I(Y;Z)\) per token; -* an evaluation suite that plots empirical R–D curves and Fano-consistent error bounds; -* a memory integration point (**Semantic Tensor Memory**) and an **Auto-IB Orchestrator** sketch for adaptive data/route selection. +* an evaluation suite that plots empirical R–D curves and Fano-consistent error bounds (see `scripts/rd_audit.py`); +* a memory integration point (**Semantic Tensor Memory**) with holographic superpositions plus constraint modules, and an **Auto-IB Orchestrator** sketch for adaptive data/route selection. > “Structure that is known to the data should be paid once in the **input**, not many times in **inference**.” diff --git a/docs/tickets.md b/docs/tickets.md new file mode 100644 index 0000000..60ad32b --- /dev/null +++ b/docs/tickets.md @@ -0,0 +1,45 @@ +# Ticket Backlog + +Centralised backlog capturing the next implementation pushes required to reach the ND-LLM architecture described in the design docs. + +## T1 – Field Registry & Canonical Synchronisation Layer + +- **Goal:** Promote the ad-hoc field munging (e.g., `benchmarks/cord.py`, `nd_llm/model.py`) into a reusable registry that ingests arbitrary fields, aligns them into a canonical coordinate space (layout UV, timestamps, spans), and surfaces consistent tensors to the rest of the stack. 
+- **Scope:** Introduce a `FieldRegistry`/`FieldAdapter` module exposing encoder, alignment, and projection hooks; wire the canonical cell builder into `nd_llm.model.CanonicalCellAggregator`; add validation utilities under `nd_llm.registry`. +- **Acceptance:** Unit tests covering registration/alignment, docs describing how to register new fields, benchmarks updated to consume the registry. + +## T2 – Mutual-Information-Aware Token Bottleneck + +- **Goal:** Replace the FIFO/heuristic bottleneck with a learnable scorer that maximises an MI proxy per cell and supports variable-rate allocation, matching the “Mutual Information per Token” principle. +- **Scope:** Extend `nd_llm/bottleneck` with a scorer (InfoNCE/MINE surrogate), expose APIs for dynamic target budgets, integrate with the orchestrator’s budget allocator, add ablations showing R–D gains. +- **Acceptance:** Benchmark plots showing the MI-aware bottleneck dominating the current selector at equal token budgets; telemetry exposing per-token MI estimates. + +## T3 – Semantic Tensor Memory & Constraint Integration + +- **Goal:** Upgrade STM from an append-only log to a holographic tensor memory governed by IB policies, and integrate a neuro-symbolic constraint layer (LTN/TensorLog style) inside the reasoning loop. +- **Scope:** Enhance `nd_llm/stm` with write/read policies, add compression metrics, expose APIs for constraint modules; teach the orchestrator (`nd_llm/orchestration`) to route through STM and constraints. +- **Acceptance:** End-to-end example showing STM-assisted reasoning with constraints enabled, plus ablations (disable STM/constraints) demonstrating fidelity drops. + +## T4 – Rate–Distortion & Fano Audit CLI + +- **Goal:** Provide tooling to empirically validate the theory: sweep token budgets for text-only vs. N-D inputs, plot R–D curves, and compute Fano-consistent error bounds. 
+- **Scope:** New script/notebook under `scripts/` or `benchmarks/` to run both configurations, log metrics, and render plots; integrate with README/docs. +- **Acceptance:** Stored plot artifacts, README section linking to them, automated test (CI-safe subset) to ensure the CLI still runs. + +## T5 – Documentation & Story Alignment *(done)* + +- **Goal:** Keep the written narrative in lockstep with the implementation. +- **Scope:** Update `README.md` (bottleneck knobs, STM/superposition usage, rate–distortion CLI), `docs/the-tensor-is-the-message.md` (highlight empirical R–D tooling), and link the new registry/constraint features. +- **Acceptance:** Docs cite concrete modules and usage snippets; roadmap reflects the completed milestones. + +## T6 – Dataset Expansion & Evaluation Coverage *(done)* + +- **Goal:** Extend beyond the synthetic + CORD pairing once the registry exists, exercising field synchrony on at least one additional multi-field dataset (e.g., chart QA, timeline QA). +- **Scope:** Added the ChartQA harness (`benchmarks/chartqa.py`) with registry-aware field adapters, a bundled sample (`benchmarks/data/chartqa_sample.jsonl`), tests, and README instructions for plugging in the full dataset. +- **Acceptance:** New benchmark entry with tests, plus guidance in README/docs. + +## T7 – Ollama LLM Harness for Local Testing *(done)* + +- **Goal:** Exercise the encoder/orchestrator loop against a real LLM served locally via Ollama (e.g., `llama3.1:8b`) so developers can verify the end-to-end stack on Mac hardware. +- **Scope:** Added `scripts/ollama_harness.py` (with `--dry-run` mode), README instructions, and automated tests covering the prompt generation for both CORD and ChartQA samples. +- **Acceptance:** Documented instructions for using the harness, and a mock-tested path to keep CI green. 
diff --git a/examples/doc_benchmark_sweeps.ipynb b/examples/doc_benchmark_sweeps.ipynb deleted file mode 100644 index 9186a5e..0000000 --- a/examples/doc_benchmark_sweeps.ipynb +++ /dev/null @@ -1,99 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Document-understanding benchmark sweeps\n", - "This notebook compares synthetic invoice benchmarks against the bundled FUNSD sample.\n", - "It tracks accuracy, average tokens kept, and the information bottleneck proxies emitted per budget.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import sys\n", - "from pathlib import Path\n", - "\n", - "\n", - "def find_project_root(current: Path) -> Path:\n", - " for candidate in (current, *current.parents):\n", - " if (candidate / \"pyproject.toml\").exists():\n", - " return candidate\n", - " raise FileNotFoundError(\"Could not locate project root containing pyproject.toml\")\n", - "\n", - "\n", - "PROJECT_ROOT = find_project_root(Path.cwd())\n", - "if str(PROJECT_ROOT) not in sys.path:\n", - " sys.path.insert(0, str(PROJECT_ROOT))\n" - ] - }, - { - "cell_type": "code", - "metadata": {}, - "execution_count": null, - "outputs": [], - "source": [ - "from pprint import pprint\n", - "\n", - "from benchmarks.doc_understanding import run_benchmark, run_funsd_benchmark\n" - ] - }, - { - "cell_type": "code", - "metadata": {}, - "execution_count": null, - "outputs": [], - "source": [ - "synthetic_report = run_benchmark(budget_values=(2, 4, 6), dataset_size=8, threshold=400.0, seed=1)\n", - "funsd_report = run_funsd_benchmark(budget_values=(6, 8, 10), dataset_size=2, use_sample=True)\n", - "\n", - "print('Synthetic budgets:')\n", - "pprint(synthetic_report['budgets'])\n", - "print('\nFUNSD budgets:')\n", - "pprint(funsd_report['budgets'])\n" - ] - }, - { - "cell_type": "code", - "metadata": {}, - "execution_count": null, - "outputs": [], - "source": [ - "def summarise(report, 
label):\n", - " rows = []\n", - " for entry in report['budgets']:\n", - " metrics = entry.get('metrics', {})\n", - " rows.append({\n", - " 'benchmark': label,\n", - " 'budget': entry['budget'],\n", - " 'accuracy': entry['accuracy'],\n", - " 'average_kept_tokens': entry['average_kept_tokens'],\n", - " 'information_bound': metrics.get('information_bound'),\n", - " 'rate_distortion': metrics.get('rate_distortion'),\n", - " })\n", - " return rows\n", - "\n", - "comparison = summarise(synthetic_report, 'synthetic') + summarise(funsd_report, 'funsd-sample')\n", - "for row in comparison:\n", - " pprint(row)\n" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "name": "python", - "version": "3.10" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} \ No newline at end of file diff --git a/nd_llm/bottleneck/ib.py b/nd_llm/bottleneck/ib.py index 6667f48..b67c4ea 100644 --- a/nd_llm/bottleneck/ib.py +++ b/nd_llm/bottleneck/ib.py @@ -334,6 +334,7 @@ def __init__( scorer_config: Optional[Mapping[str, Any]] = None, learnable_scorer: Optional[nn.Module] = None, budget_allocator: Optional[BudgetAllocatorFn] = None, + mi_score_weight: float = 0.5, ) -> None: if target_budget <= 0: raise ValueError("target_budget must be positive") @@ -353,6 +354,9 @@ def __init__( self.learnable_scorer: Optional[nn.Module] = learnable_scorer self.scorer = scorer or self._resolve_objective(self.objective) self.budget_allocator = budget_allocator or RegistryAwareBudgetAllocator() + if mi_score_weight < 0.0 or mi_score_weight > 1.0: + raise ValueError("mi_score_weight must fall within [0, 1]") + self.mi_score_weight = float(mi_score_weight) self._score_tensors: Dict[str, Tensor] = {} def compress( @@ -370,7 +374,9 @@ def compress( metadata = self._normalise_field_metadata(fields.keys(), field_specs, registry) encoded = self._encode_fields(fields, encoders) - scoring_context = context or {} + 
scoring_context = dict(context) if context else {} + if mi_proxy is not None: + scoring_context["__mi_proxy__"] = mi_proxy self._score_tensors.clear() gating_scores = self._compute_scores(encoded, metadata, scoring_context) @@ -546,28 +552,122 @@ def _compute_scores( ) -> Dict[str, List[float]]: scores: Dict[str, List[float]] = {} for field, embeddings in encoded.items(): - raw_scores = self.scorer(field, embeddings, metadata.get(field, {}), context) - tensor: Optional[Tensor] = None - if torch is not None and isinstance(raw_scores, torch.Tensor): - tensor = raw_scores.squeeze(-1) if raw_scores.ndim > 1 else raw_scores - if tensor.ndim != 1: - raise ValueError( - "learnable scorer must return a 1D tensor of per-token scores" - ) - field_scores = [float(v) for v in tensor.detach().cpu().tolist()] - else: - field_scores = [ - float(v) for v in cast(Sequence[float], raw_scores) - ] - if len(field_scores) != len(embeddings): - raise ValueError( - f"scoring strategy returned {len(field_scores)} scores for {len(embeddings)} embeddings in field '{field}'" - ) - if tensor is not None: - self._score_tensors[field] = tensor + field_scores = self._score_field(field, embeddings, metadata, context) + mi_scores = self._mi_scores(field, embeddings, context) + if mi_scores is not None: + field_scores = self._blend_scores(field_scores, mi_scores) scores[field] = field_scores return scores + def _score_field( + self, + field: str, + embeddings: Sequence[Sequence[float]], + metadata: Mapping[str, FieldMetadata], + context: Mapping[str, Any], + ) -> List[float]: + raw_scores = self.scorer(field, embeddings, metadata.get(field, {}), context) + tensor: Optional[Tensor] = None + if torch is not None and isinstance(raw_scores, torch.Tensor): + tensor = raw_scores.squeeze(-1) if raw_scores.ndim > 1 else raw_scores + if tensor.ndim != 1: + raise ValueError("learnable scorer must return a 1D tensor of per-token scores") + field_scores = [float(v) for v in tensor.detach().cpu().tolist()] + 
else: + field_scores = [float(v) for v in cast(Sequence[float], raw_scores)] + if len(field_scores) != len(embeddings): + raise ValueError( + f"scoring strategy returned {len(field_scores)} scores for {len(embeddings)} embeddings in field '{field}'" + ) + if tensor is not None and torch is not None: + self._score_tensors[field] = tensor + return field_scores + + def _mi_scores( + self, + field: str, + embeddings: Sequence[Sequence[float]], + context: Mapping[str, Any], + ) -> Optional[List[float]]: + if self.mi_score_weight <= 0.0: + return None + if torch is None: + return None + if not embeddings: + return None + mi_proxy = context.get("__mi_proxy__") + if mi_proxy is None: + return None + targets = context.get("mi_targets") or context.get("target_embeddings") + if not isinstance(targets, Mapping): + return None + target_vector = targets.get(field) + if target_vector is None: + return None + try: + param = next(mi_proxy.parameters()) + except StopIteration: # pragma: no cover - MIProxy always defines params + param = None + if param is None: + device = torch.device("cpu") + dtype = torch.float32 + else: + device = param.device + dtype = param.dtype + try: + token_tensor = torch.as_tensor(embeddings, device=device, dtype=dtype) + except Exception: + return None + if token_tensor.ndim != 2 or token_tensor.size(1) == 0: + return None + target_tensor = torch.as_tensor(target_vector, device=device, dtype=dtype) + if target_tensor.ndim != 1: + target_tensor = target_tensor.view(-1) + with torch.no_grad(): + token_proj = mi_proxy.f(token_tensor) + target_proj = mi_proxy.h(target_tensor.unsqueeze(0)).squeeze(0) + if target_proj.ndim == 0: + target_proj = target_proj.unsqueeze(0) + token_proj = self._normalize_tensor(token_proj) + target_proj = self._normalize_tensor(target_proj.unsqueeze(0)).squeeze(0) + sims = torch.matmul(token_proj, target_proj) + tau = float(getattr(mi_proxy, "tau", 1.0) or 1.0) + sims = sims / tau + return sims.detach().cpu().tolist() + + 
@staticmethod + def _normalize_tensor(values: Tensor) -> Tensor: + norm = torch.linalg.norm(values, dim=-1, keepdim=True) + norm = norm.clamp_min(1e-6) + return values / norm + + def _blend_scores(self, base: Sequence[float], mi: Sequence[float]) -> List[float]: + if not base: + return list(mi) + if len(base) != len(mi): + return list(base) + weight = self.mi_score_weight + if weight >= 1.0: + return list(mi) + if weight <= 0.0: + return list(base) + base_norm = self._standardize_scores(base) + mi_norm = self._standardize_scores(mi) + blended = [ + (1.0 - weight) * b + weight * m + for b, m in zip(base_norm, mi_norm) + ] + return blended + + @staticmethod + def _standardize_scores(values: Sequence[float]) -> List[float]: + if not values: + return [] + mean = sum(values) / float(len(values)) + variance = sum((value - mean) ** 2 for value in values) / float(len(values)) + std = math.sqrt(variance) or 1.0 + return [(value - mean) / std for value in values] + def _select_indices( self, scores: Mapping[str, List[float]], diff --git a/nd_llm/constraints/__init__.py b/nd_llm/constraints/__init__.py new file mode 100644 index 0000000..6d5e7ae --- /dev/null +++ b/nd_llm/constraints/__init__.py @@ -0,0 +1,12 @@ +"""Constraint modules enforcing neuro-symbolic rules across STM events.""" + +from .base import ConstraintModule, ConstraintResult +from .field import FieldActivationConstraint +from .superposition import SuperpositionSimilarityConstraint + +__all__ = [ + "ConstraintModule", + "ConstraintResult", + "FieldActivationConstraint", + "SuperpositionSimilarityConstraint", +] diff --git a/nd_llm/constraints/base.py b/nd_llm/constraints/base.py new file mode 100644 index 0000000..575ccf1 --- /dev/null +++ b/nd_llm/constraints/base.py @@ -0,0 +1,46 @@ +"""Constraint interfaces used by the orchestrator.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Mapping, Optional, Protocol, runtime_checkable + +if False: # 
pragma: no cover - typing aid + from nd_llm.stm import STM + from nd_llm.orchestration.orchestrator import CompressionRecord, UsageEvent + + +@dataclass +class ConstraintResult: + """Evaluation outcome returned by constraint modules.""" + + name: str + satisfied: bool + confidence: float = 1.0 + details: Mapping[str, Any] = field(default_factory=dict) + + def to_dict(self) -> Mapping[str, Any]: + payload = { + "name": self.name, + "satisfied": bool(self.satisfied), + "confidence": float(self.confidence), + } + if self.details: + payload["details"] = dict(self.details) + return payload + + +@runtime_checkable +class ConstraintModule(Protocol): + """Protocol describing orchestrator-aware constraints.""" + + name: str + + def evaluate( + self, + *, + stm: "STM", + event: "UsageEvent", + compression: Optional["CompressionRecord"], + ) -> ConstraintResult: + """Evaluate the constraint against the current usage event.""" diff --git a/nd_llm/constraints/field.py b/nd_llm/constraints/field.py new file mode 100644 index 0000000..3d2a815 --- /dev/null +++ b/nd_llm/constraints/field.py @@ -0,0 +1,60 @@ +"""Field-level constraint modules.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Mapping, Optional + +from .base import ConstraintModule, ConstraintResult + + +def _count_tokens( + compression: Optional["CompressionRecord"], # type: ignore[name-defined] + field: str, +) -> int: + if compression is None: + return 0 + selected = compression.telemetry.get("selected_indices", {}) + if isinstance(selected, Mapping): + tokens = selected.get(field) + if isinstance(tokens, Mapping): + return len(list(tokens.values())) + if isinstance(tokens, (list, tuple)): + return len(tokens) + return 0 + + +@dataclass +class FieldActivationConstraint(ConstraintModule): + """Ensure a field keeps at least (and optionally at most) a number of tokens.""" + + field: str + min_tokens: int = 1 + max_tokens: Optional[int] = None + name: str = 
"field_activation" + + def evaluate( + self, + *, + stm: "STM", # type: ignore[name-defined] # pragma: no cover - protocol annotation + event: "UsageEvent", # type: ignore[name-defined] + compression: Optional["CompressionRecord"], # type: ignore[name-defined] + ) -> ConstraintResult: + count = _count_tokens(compression or event.compression, self.field) + within_max = True + if self.max_tokens is not None: + within_max = count <= int(self.max_tokens) + satisfied = count >= int(self.min_tokens) and within_max + details: Mapping[str, Any] = { + "field": self.field, + "count": int(count), + "min_tokens": int(self.min_tokens), + } + if self.max_tokens is not None: + details = {**details, "max_tokens": int(self.max_tokens)} + return ConstraintResult( + name=self.name, + satisfied=satisfied, + confidence=1.0, + details=details, + ) diff --git a/nd_llm/constraints/superposition.py b/nd_llm/constraints/superposition.py new file mode 100644 index 0000000..83e5cfd --- /dev/null +++ b/nd_llm/constraints/superposition.py @@ -0,0 +1,84 @@ +"""Constraints that operate on STM holographic superpositions.""" + +from __future__ import annotations + +import math +from dataclasses import dataclass +from typing import Any, Iterable, List, Optional + +from .base import ConstraintModule, ConstraintResult + + +def _flatten_tensor(payload: Any) -> List[float]: + if payload is None: + return [] + if isinstance(payload, (int, float)): + return [float(payload)] + if isinstance(payload, (list, tuple)): + flattened: List[float] = [] + for item in payload: + flattened.extend(_flatten_tensor(item)) + return flattened + if isinstance(payload, Iterable): + flattened: List[float] = [] + for item in payload: + flattened.extend(_flatten_tensor(item)) + return flattened + return [] + + +def _cosine_similarity(a: List[float], b: List[float]) -> float: + if not a or not b or len(a) != len(b): + return 0.0 + dot = sum(x * y for x, y in zip(a, b)) + norm_a = math.sqrt(sum(x * x for x in a)) + norm_b = 
math.sqrt(sum(y * y for y in b)) + if norm_a == 0.0 or norm_b == 0.0: + return 0.0 + return dot / (norm_a * norm_b) + + +@dataclass +class SuperpositionSimilarityConstraint(ConstraintModule): + """Ensure the current tensor is aligned with the historical superposition.""" + + channel: str + min_similarity: float = 0.2 + name: str = "superposition_similarity" + + def evaluate( + self, + *, + stm: "STM", # type: ignore[name-defined] # pragma: no cover - typing hook + event: "UsageEvent", # type: ignore[name-defined] + compression: Optional["CompressionRecord"], # type: ignore[name-defined] + ) -> ConstraintResult: + try: + historical, metadata = stm.read_superposition(self.channel, normalize=True) + except KeyError: + return ConstraintResult( + name=self.name, + satisfied=True, + details={"channel": self.channel, "reason": "empty_channel"}, + ) + current = _flatten_tensor(event.tensor) + if len(current) != len(historical): + return ConstraintResult( + name=self.name, + satisfied=True, + details={"channel": self.channel, "reason": "shape_mismatch"}, + ) + similarity = _cosine_similarity(current, historical) + satisfied = similarity >= float(self.min_similarity) + details = { + "channel": self.channel, + "similarity": float(similarity), + "threshold": float(self.min_similarity), + "weight": float(metadata.get("weight", 0.0) or 0.0), + } + return ConstraintResult( + name=self.name, + satisfied=satisfied, + confidence=max(0.0, min(1.0, similarity)), + details=details, + ) diff --git a/nd_llm/orchestration/orchestrator.py b/nd_llm/orchestration/orchestrator.py index bb55e3f..b160b3b 100644 --- a/nd_llm/orchestration/orchestrator.py +++ b/nd_llm/orchestration/orchestrator.py @@ -13,6 +13,7 @@ from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple, Union from nd_llm.bottleneck.ib import CompressionResult, CompressionTelemetry, IBottleneck +from nd_llm.constraints import ConstraintModule, ConstraintResult from nd_llm.orchestration.budget import ( 
BudgetDecision, BudgetObservation, @@ -795,6 +796,8 @@ def __init__( meta_model: Optional[BudgetMetaModel] = None, *, auto_attach_meta_model: bool = True, + constraints: Optional[Sequence[ConstraintModule]] = None, + superposition_channels: Optional[Sequence[str]] = None, ) -> None: self._stm = stm self._config = config @@ -807,6 +810,11 @@ def __init__( self._last_policy_metadata: Optional[Dict[str, Any]] = None self._recent_probe_outcomes: List[Dict[str, Any]] = [] self._probe_history_limit = 5 + self._constraints: List[ConstraintModule] = list(constraints or []) + self._recent_constraint_results: List[Dict[str, Any]] = [] + self._superposition_channels: Tuple[str, ...] = tuple( + channel for channel in (superposition_channels or []) if channel + ) @classmethod def from_components( @@ -822,6 +830,8 @@ def from_components( bottleneck: Optional["IBottleneck"] = None, meta_model: Optional[BudgetMetaModel] = None, auto_attach_meta_model: bool = True, + constraints: Optional[Sequence[ConstraintModule]] = None, + superposition_channels: Optional[Sequence[str]] = None, ) -> "Orchestrator": """Construct an orchestrator from primitive components.""" @@ -844,6 +854,8 @@ def from_components( bottleneck=bottleneck, meta_model=meta_model, auto_attach_meta_model=auto_attach_meta_model, + constraints=constraints, + superposition_channels=superposition_channels, ) @property @@ -970,6 +982,21 @@ def log_usage_event(self, event: UsageEvent) -> str: if layout_signature and "layout_signature" not in merged_compression: merged_compression["layout_signature"] = layout_signature + constraint_results = self._evaluate_constraints(event) + if constraint_results: + metadata["constraints"] = [result.to_dict() for result in constraint_results] + if any(not result.satisfied for result in constraint_results): + issues = metadata.setdefault("issues", []) + for result in constraint_results: + if not result.satisfied: + issues.append( + { + "type": "constraint_violation", + "constraint": 
result.name, + "details": dict(result.details), + } + ) + metadata.setdefault("task", metadata.get("policy_name", self._config.policy_name)) attempt_key = base_key @@ -985,6 +1012,7 @@ def log_usage_event(self, event: UsageEvent) -> str: metadata["duplicate_attempts"] = duplicate_attempts attempt_key = self._generate_key(prefix=base_key) + self._update_superpositions(event, metadata) self._usage_log.append(attempt_key) return attempt_key @@ -1379,6 +1407,51 @@ def _append_probe_outcome(self, outcome: Mapping[str, Any]) -> None: if self._probe_history_limit > 0 and len(self._recent_probe_outcomes) > self._probe_history_limit: self._recent_probe_outcomes = self._recent_probe_outcomes[-self._probe_history_limit :] + def _evaluate_constraints(self, event: UsageEvent) -> List[ConstraintResult]: + if not self._constraints: + return [] + results: List[ConstraintResult] = [] + for module in self._constraints: + try: + outcome = module.evaluate( + stm=self._stm, + event=event, + compression=event.compression, + ) + except Exception as exc: + outcome = ConstraintResult( + name=getattr(module, "name", module.__class__.__name__), + satisfied=False, + confidence=0.0, + details={"error": str(exc)}, + ) + results.append(outcome) + if results: + record = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "results": [result.to_dict() for result in results], + } + self._recent_constraint_results.append(record) + if self._probe_history_limit > 0: + self._recent_constraint_results = self._recent_constraint_results[-self._probe_history_limit :] + return results + + def _update_superpositions(self, event: UsageEvent, metadata: Mapping[str, Any]) -> None: + if not self._superposition_channels: + return + payload = event.tensor + if payload is None: + return + base_metadata = { + "policy_name": metadata.get("policy_name"), + "task": metadata.get("task"), + } + for channel in self._superposition_channels: + try: + self._stm.write_superposition(channel, payload, 
metadata=base_metadata) + except Exception: + continue + def _resolve_meta_model_probe_context( self, *, diff --git a/nd_llm/registry/__init__.py b/nd_llm/registry/__init__.py index 5b8d1b1..7342171 100644 --- a/nd_llm/registry/__init__.py +++ b/nd_llm/registry/__init__.py @@ -1,5 +1,21 @@ """Public registry exports for ND-LLM.""" from .models import AffinityRule, FieldSpec, Registry +from .adapters import ( + FieldAdapter, + FieldAdapterRegistry, + LayoutAligner, + normalise_box, + quad_to_box, +) -__all__ = ["AffinityRule", "FieldSpec", "Registry"] +__all__ = [ + "AffinityRule", + "FieldSpec", + "FieldAdapter", + "FieldAdapterRegistry", + "LayoutAligner", + "normalise_box", + "quad_to_box", + "Registry", +] diff --git a/nd_llm/registry/adapters.py b/nd_llm/registry/adapters.py new file mode 100644 index 0000000..1d13f32 --- /dev/null +++ b/nd_llm/registry/adapters.py @@ -0,0 +1,163 @@ +"""Field adapters and canonical alignment utilities for ND-LLM inputs.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableMapping, + Optional, + Sequence, +) + +Document = Mapping[str, Any] +Entry = MutableMapping[str, Any] +BuilderFn = Callable[[Document], Iterable[Entry]] +AlignerFn = Callable[[Document, Entry], Optional[Sequence[float]]] + + +def _ensure_sequence(value: Any) -> Optional[List[float]]: + if value is None: + return None + if isinstance(value, Mapping): + return [_ensure_number(value.get(key, 0.0)) for key in ("x1", "y1", "x2", "y2")] + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): + return [_ensure_number(val) for val in value[:4]] + return None + + +def _ensure_number(value: Any) -> float: + try: + return float(value) + except Exception: # pragma: no cover - defensive guardrail + return 0.0 + + +def _normalise_quad(quad: Any) -> Dict[str, float]: + if isinstance(quad, Mapping): + keys = ("x1", "y1", "x2", 
"y2", "x3", "y3", "x4", "y4") + return {key: _ensure_number(quad.get(key, 0.0)) for key in keys} + if isinstance(quad, Sequence) and len(quad) >= 8: + return { + "x1": _ensure_number(quad[0]), + "y1": _ensure_number(quad[1]), + "x2": _ensure_number(quad[2]), + "y2": _ensure_number(quad[3]), + "x3": _ensure_number(quad[4]), + "y3": _ensure_number(quad[5]), + "x4": _ensure_number(quad[6]), + "y4": _ensure_number(quad[7]), + } + return {key: 0.0 for key in ("x1", "y1", "x2", "y2", "x3", "y3", "x4", "y4")} + + +def quad_to_box(quad: Any) -> List[float]: + normalised = _normalise_quad(quad) + xs = [normalised["x1"], normalised["x2"], normalised["x3"], normalised["x4"]] + ys = [normalised["y1"], normalised["y2"], normalised["y3"], normalised["y4"]] + return [ + min(xs) if xs else 0.0, + min(ys) if ys else 0.0, + max(xs) if xs else 0.0, + max(ys) if ys else 0.0, + ] + + +def normalise_box(box: Sequence[float], *, width: float, height: float) -> List[float]: + """Normalise absolute coordinates to the unit square.""" + + safe_width = max(float(width) or 1.0, 1.0) + safe_height = max(float(height) or 1.0, 1.0) + if not box: + x1 = y1 = x2 = y2 = 0.0 + else: + x1 = _ensure_number(box[0]) + y1 = _ensure_number(box[1] if len(box) > 1 else 0.0) + x2 = _ensure_number(box[2] if len(box) > 2 else x1) + y2 = _ensure_number(box[3] if len(box) > 3 else y1) + return [ + max(0.0, min(1.0, x1 / safe_width)), + max(0.0, min(1.0, y1 / safe_height)), + max(0.0, min(1.0, x2 / safe_width)), + max(0.0, min(1.0, y2 / safe_height)), + ] + + +class LayoutAligner: + """Align layout-aware entries into document-normalised coordinates.""" + + def __init__( + self, + *, + quad_key: str = "quad", + width_key: str = "width", + height_key: str = "height", + ) -> None: + self._quad_key = quad_key + self._width_key = width_key + self._height_key = height_key + + def __call__(self, document: Document, entry: Entry) -> Optional[List[float]]: + quad = entry.get(self._quad_key) or entry.get("coords") or 
entry.get("xyxy") + if quad is None: + return None + width = document.get(self._width_key, 0.0) + height = document.get(self._height_key, 0.0) + box = quad_to_box(quad) + return normalise_box( + box, width=float(width or 1.0), height=float(height or 1.0) + ) + + +@dataclass +class FieldAdapter: + """Declarative adapter that prepares per-field entries with canonical coords.""" + + name: str + builder: BuilderFn + aligner: Optional[AlignerFn] = None + metadata: Mapping[str, Any] = field(default_factory=dict) + + def adapt(self, document: Document) -> List[Entry]: + adapted: List[Entry] = [] + for raw in self.builder(document): + entry: Entry = dict(raw) + if self.aligner is not None: + coords = self.aligner(document, entry) + if coords is not None: + entry["coords"] = list(coords) + # Provide xyxy alias to satisfy existing consumers. + entry.setdefault("xyxy", list(coords)) + adapted.append(entry) + return adapted + + +class FieldAdapterRegistry: + """Container that applies registered adapters to incoming documents.""" + + def __init__(self) -> None: + self._adapters: Dict[str, FieldAdapter] = {} + + def register(self, adapter: FieldAdapter) -> None: + if adapter.name in self._adapters: + raise ValueError(f"Field adapter '{adapter.name}' already registered") + self._adapters[adapter.name] = adapter + + def transform(self, document: Document) -> Dict[str, List[Entry]]: + return { + name: adapter.adapt(document) for name, adapter in self._adapters.items() + } + + def __contains__(self, name: str) -> bool: + return name in self._adapters + + def __len__(self) -> int: + return len(self._adapters) + + def __iter__(self): + return iter(self._adapters.items()) diff --git a/nd_llm/stm/stm.py b/nd_llm/stm/stm.py index 58e20ff..602610a 100644 --- a/nd_llm/stm/stm.py +++ b/nd_llm/stm/stm.py @@ -198,6 +198,100 @@ def list_by_task(self, task: str, limit: Optional[int] = None) -> Sequence[str]: matches = self.query(metadata_filter={"task": task}, limit=limit) return [key for key, 
_ in matches] + # ------------------------------------------------------------------ + # Superposition helpers + # ------------------------------------------------------------------ + def write_superposition( + self, + channel: str, + tensor: TensorLike, + *, + weight: float = 1.0, + metadata: Optional[Mapping[str, Any]] = None, + ) -> str: + """Accumulate ``tensor`` into the holographic superposition for ``channel``.""" + + if not channel: + raise ValueError("channel must be a non-empty string") + if weight == 0: + raise ValueError("weight must be non-zero") + + nested = self._to_nested_structure(tensor) + shape = self._infer_shape(nested) + flat = self._flatten_nested(nested) + scaled = [float(weight) * value for value in flat] + metadata_dict = self._normalize_metadata(metadata) + metadata_dict.update({"superposition": True, "channel": str(channel)}) + key = self._super_key(channel) + + with self._lock: + entry = self._index.get(key) + if entry is None: + entry = self._build_index_entry(key, shape, len(flat), metadata_dict) + existing: List[float] = [0.0] * len(flat) + total_weight = 0.0 + else: + existing = self._load_tensor_values(entry) + if len(existing) != len(scaled): + raise ValueError( + f"Superposition payload for channel '{channel}' has incompatible shape" + ) + total_weight = float(entry.get("metadata", {}).get("weight", 0.0)) + + combined = [current + delta for current, delta in zip(existing, scaled)] + total_weight += float(weight) + + metadata_dict["weight"] = total_weight + entry["metadata"] = metadata_dict + entry["shape"] = list(shape) + entry["length"] = len(flat) + entry.update(self._derive_index_annotations(metadata_dict)) + + tensor_path = self._storage_dir / entry["tensor_file"] + tmp_path = tensor_path.with_suffix(".tmp") + payload = array("d", combined).tobytes() + compressed = zlib.compress(payload) + with tmp_path.open("wb") as fh: + fh.write(compressed) + tmp_path.replace(tensor_path) + + self._index[key] = entry + self._save_index() + 
return key + + def read_superposition( + self, + channel: str, + *, + normalize: bool = True, + ) -> tuple[List[float], Dict[str, Any]]: + """Return the accumulated tensor for ``channel`` and its metadata.""" + + key = self._super_key(channel) + with self._lock: + entry = self._index.get(key) + if entry is None: + raise KeyError(f"Superposition channel '{channel}' is empty") + values = self._load_tensor_values(entry) + metadata = json.loads(json.dumps(entry.get("metadata", {}))) + weight = float(metadata.get("weight", 0.0) or 0.0) + if normalize and weight: + values = [value / weight for value in values] + return values, metadata + + def superposition_channels(self) -> List[str]: + """List channels with active holographic superpositions.""" + + channels: List[str] = [] + with self._lock: + for key, entry in self._index.items(): + if key.startswith("__super__::"): + metadata = entry.get("metadata", {}) + channel = metadata.get("channel") if isinstance(metadata, Mapping) else None + if isinstance(channel, str): + channels.append(channel) + return channels + def list_by_layout(self, layout_signature: str, limit: Optional[int] = None) -> Sequence[str]: matches = self.query(metadata_filter={"layout_signature": layout_signature}, limit=limit) return [key for key, _ in matches] @@ -238,6 +332,23 @@ def _tensor_filename(self, key: str) -> str: digest = hashlib.sha1(key.encode("utf-8")).hexdigest()[:8] return f"{safe_key}_{digest}.bin" + def _load_tensor_values(self, entry: Mapping[str, Any]) -> List[float]: + tensor_file = entry.get("tensor_file") + if not tensor_file: + return [] + tensor_path = self._storage_dir / str(tensor_file) + if not tensor_path.exists(): + return [] + compressed = tensor_path.read_bytes() + buffer = zlib.decompress(compressed) + values_array = array("d") + values_array.frombytes(buffer) + return values_array.tolist() + + @staticmethod + def _super_key(channel: str) -> str: + return f"__super__::{channel}" + def _prepare_payload(self, tensor: 
TensorLike) -> tuple[bytes, Sequence[int], int]: nested = self._to_nested_structure(tensor) shape = self._infer_shape(nested) diff --git a/pyproject.toml b/pyproject.toml index 93c3b4b..d16de7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,10 @@ docs = [ "mkdocs>=1.5", "mkdocs-material>=9.5", ] +benchmarks = [ + "datasets>=4.4", + "Pillow>=10.0", +] [tool.setuptools.packages.find] where = ["."] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..cbbf2fc --- /dev/null +++ b/requirements.txt @@ -0,0 +1,4 @@ +numpy>=1.24 +torch>=2.0 +datasets>=4.4 +Pillow>=10.0 diff --git a/scripts/download_datasets.py b/scripts/download_datasets.py deleted file mode 100644 index e5d6d67..0000000 --- a/scripts/download_datasets.py +++ /dev/null @@ -1,270 +0,0 @@ -#!/usr/bin/env python3 -"""Download official FUNSD and DocLayNet-base releases into a local cache.""" - -from __future__ import annotations - -import argparse -import hashlib -import os -import shutil -import sys -import tarfile -import tempfile -import zipfile -from dataclasses import dataclass -from pathlib import Path -from typing import Dict, Iterable, Optional, Sequence -from urllib.request import urlopen - -CACHE_ENV = "ND_LLM_DATA_CACHE" -DEFAULT_CACHE_DIR = Path.home() / ".cache" / "n-dimensional-llm" - - -@dataclass(frozen=True) -class DatasetSpec: - name: str - url: str - archive_name: str - checksum: Optional[str] - target_subdir: str - archive_type: str # "zip" or "tar" - - -_DATASETS: Dict[str, DatasetSpec] = { - "funsd": DatasetSpec( - name="funsd", - url="https://guillaumejaume.github.io/FUNSD/dataset.zip", - archive_name="funsd_dataset.zip", - checksum="c31735649e4f441bcbb4fd0f379574f7520b42286e80b01d80b445649d54761f", - target_subdir="funsd", - archive_type="zip", - ), - "doclaynet": DatasetSpec( - name="doclaynet", - url=( - "https://huggingface.co/datasets/pierreguillou/DocLayNet-base/blob/main/data/dataset_base.zip" - ), - 
archive_name="doclaynet-base.tar.gz", - checksum=None, - target_subdir="doclaynet", - archive_type="zip", - ), -} - - -class DownloadError(RuntimeError): - """Raised when a dataset download or extraction fails.""" - - -def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace: - parser = argparse.ArgumentParser( - description=( - "Download and unpack the official FUNSD and DocLayNet-base datasets into a " - "cache directory. The cache location defaults to ~/.cache/n-dimensional-llm " - "or can be overridden via the ND_LLM_DATA_CACHE environment variable." - ) - ) - parser.add_argument( - "datasets", - nargs="*", - choices=sorted(_DATASETS), - default=sorted(_DATASETS), - help=( - "One or more datasets to download. Defaults to both FUNSD and DocLayNet-base." - ), - ) - parser.add_argument( - "--cache-dir", - type=Path, - default=None, - help=( - "Destination directory for downloaded datasets. Defaults to the value of " - "the ND_LLM_DATA_CACHE environment variable or ~/.cache/n-dimensional-llm." - ), - ) - parser.add_argument( - "--force", - action="store_true", - help="Re-download and overwrite existing dataset caches.", - ) - parser.add_argument( - "--skip-checksum", - action="store_true", - help="Skip checksum validation (not recommended).", - ) - parser.add_argument( - "--override-checksum", - action="append", - metavar="DATASET=SHA256", - help="Override the expected checksum for a dataset (e.g. 
funsd=deadbeef...).", - ) - return parser.parse_args(argv) - - -def main(argv: Optional[Sequence[str]] = None) -> int: - args = parse_args(argv) - cache_dir = resolve_cache_dir(args.cache_dir) - overrides = parse_overrides(args.override_checksum) - - cache_dir.mkdir(parents=True, exist_ok=True) - - for name in args.datasets: - spec = _DATASETS[name] - checksum = overrides.get(name, spec.checksum) - try: - download_and_extract(spec, cache_dir, checksum, args.force, args.skip_checksum) - except DownloadError as error: - print(f"Error downloading {name}: {error}", file=sys.stderr) - return 1 - return 0 - - -def resolve_cache_dir(cli_value: Optional[Path]) -> Path: - if cli_value is not None: - return cli_value.expanduser().resolve() - env_value = os.environ.get(CACHE_ENV) - if env_value: - return Path(env_value).expanduser().resolve() - return DEFAULT_CACHE_DIR - - -def parse_overrides(values: Optional[Iterable[str]]) -> Dict[str, str]: - overrides: Dict[str, str] = {} - if not values: - return overrides - for entry in values: - if not entry: - continue - if "=" not in entry: - raise DownloadError( - f"Invalid checksum override '{entry}'. Expected format DATASET=SHA256." - ) - dataset, checksum = entry.split("=", 1) - dataset = dataset.strip().lower() - checksum = checksum.strip().lower() - if dataset not in _DATASETS: - raise DownloadError( - f"Unknown dataset '{dataset}' in checksum override. Valid options: " - f"{', '.join(sorted(_DATASETS))}." 
- ) - if not checksum or any(char.isspace() for char in checksum): - raise DownloadError(f"Invalid checksum value for dataset '{dataset}'.") - overrides[dataset] = checksum - return overrides - - -def download_and_extract( - spec: DatasetSpec, - cache_dir: Path, - checksum: Optional[str], - force: bool, - skip_checksum: bool, -) -> None: - target_dir = cache_dir / spec.target_subdir - if target_dir.exists() and not force: - print(f"Skipping {spec.name}: already present at {target_dir}") - return - - if target_dir.exists() and force: - print(f"Removing existing directory {target_dir}") - shutil.rmtree(target_dir) - - with tempfile.TemporaryDirectory(prefix=f"{spec.name}-download-") as tmp_dir: - tmp_path = Path(tmp_dir) - archive_path = tmp_path / spec.archive_name - print(f"Downloading {spec.name} from {spec.url}") - fetch_to_file(spec.url, archive_path) - - if not skip_checksum: - if checksum: - print(f"Validating checksum for {spec.name}") - digest = sha256_path(archive_path) - if digest.lower() != checksum.lower(): - raise DownloadError( - "Checksum mismatch: expected " - f"{checksum.lower()} but received {digest.lower()}" - ) - else: - print( - "No checksum registered for this dataset; skipping validation by default" - ) - else: - print("Checksum validation skipped by user request") - - print(f"Extracting {spec.name} to {target_dir}") - extract_archive(archive_path, target_dir, spec.archive_type) - - print(f"Finished installing {spec.name} into {target_dir}") - - -def fetch_to_file(url: str, destination: Path) -> None: - try: - with urlopen(url) as response, destination.open("wb") as handle: - shutil.copyfileobj(response, handle) - except Exception as error: # pragma: no cover - network failure - raise DownloadError(f"Failed to download '{url}': {error}") from error - - -def sha256_path(path: Path) -> str: - digest = hashlib.sha256() - with path.open("rb") as handle: - for chunk in iter(lambda: handle.read(1024 * 1024), b""): - digest.update(chunk) - return 
digest.hexdigest() - - -def extract_archive(archive_path: Path, destination: Path, archive_type: str) -> None: - destination.mkdir(parents=True, exist_ok=True) - if archive_type == "zip": - _extract_zip(archive_path, destination) - elif archive_type == "tar": - _extract_tar(archive_path, destination) - else: # pragma: no cover - defensive - raise DownloadError(f"Unsupported archive type '{archive_type}'") - _normalise_layout(destination) - - -def _extract_zip(path: Path, destination: Path) -> None: - try: - with zipfile.ZipFile(path) as archive: - archive.extractall(destination) - except zipfile.BadZipFile as error: - raise DownloadError(f"Corrupt ZIP archive: {error}") from error - - -def _extract_tar(path: Path, destination: Path) -> None: - try: - with tarfile.open(path, mode="r:*") as archive: - safe_extract(archive, destination) - except tarfile.TarError as error: - raise DownloadError(f"Corrupt TAR archive: {error}") from error - - -def safe_extract(archive: tarfile.TarFile, destination: Path) -> None: - for member in archive.getmembers(): - member_path = destination / member.name - if not is_within_directory(destination, member_path): - raise DownloadError("Attempted path traversal in tar archive") - archive.extractall(destination) - - -def is_within_directory(directory: Path, target: Path) -> bool: - directory = directory.resolve() - target = target.resolve(strict=False) - return os.path.commonpath([str(directory), str(target)]) == str(directory) - - -def _normalise_layout(destination: Path) -> None: - entries = list(destination.iterdir()) - if len(entries) == 1 and entries[0].is_dir(): - inner = entries[0] - for child in inner.iterdir(): - target = destination / child.name - if target.exists(): - continue - child.rename(target) - inner.rmdir() - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/ollama_harness.py b/scripts/ollama_harness.py new file mode 100644 index 0000000..bfb8a88 --- /dev/null +++ b/scripts/ollama_harness.py @@ -0,0 
+1,223 @@ +#!/usr/bin/env python3 +"""Run compressed field prompts through a local Ollama model.""" + +from __future__ import annotations + +import argparse +import json +import sys +import textwrap +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Sequence + +from benchmarks.cord import ( + build_cord_encoders, + build_cord_registry, + cord_fields, + cord_high_total_label, + cord_total_amount, + load_cord_dataset, +) +from benchmarks.chartqa import ( + build_chartqa_encoders, + build_chartqa_registry, + chartqa_answer, + chartqa_fields, + load_chartqa_dataset, +) +from nd_llm.bottleneck import IBottleneck +from nd_llm.utils import build_mi_proxy_context + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Compress structured fields then prompt a local Ollama model." + ) + parser.add_argument( + "--dataset", + choices=("cord", "chartqa"), + default="cord", + help="Dataset harness to use (default: cord).", + ) + parser.add_argument( + "--budget", + type=int, + default=6, + help="Target bottleneck budget (default: 6).", + ) + parser.add_argument( + "--dataset-size", + type=int, + default=1, + help="Number of documents to load before sampling one (default: 1).", + ) + parser.add_argument( + "--use-sample", + action="store_true", + help="Use bundled JSONL samples instead of the full dataset.", + ) + parser.add_argument( + "--data-root", + type=Path, + default=None, + help="Optional local path to the dataset (e.g. 
datasets/).", + ) + parser.add_argument( + "--model", + default="llama3.1:8b", + help="Ollama model name (default: llama3.1:8b).", + ) + parser.add_argument( + "--threshold", + type=float, + default=250_000.0, + help="CORD high-value threshold (default: 250000).", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print the prompt JSON without contacting the Ollama server.", + ) + parser.add_argument( + "--ollama-url", + default="http://127.0.0.1:11434", + help="Base URL for the Ollama server (default: http://127.0.0.1:11434).", + ) + return parser.parse_args() + + +def load_document(args: argparse.Namespace) -> tuple[Mapping[str, Any], str]: + if args.dataset == "cord": + docs = load_cord_dataset( + split="train", + limit=max(args.dataset_size, 1), + use_sample=args.use_sample, + data_root=args.data_root, + ) + if not docs: + raise RuntimeError("No CORD documents available for the requested configuration.") + return docs[0], "cord" + docs = load_chartqa_dataset( + split="test", + limit=max(args.dataset_size, 1), + use_sample=args.use_sample, + cache_dir=args.data_root, + ) + if not docs: + raise RuntimeError("No ChartQA documents available for the requested configuration.") + return docs[0], "chartqa" + + +def compress_document(document: Mapping[str, Any], dataset: str, budget: int) -> tuple[Any, Mapping[str, Any]]: + if dataset == "cord": + registry = build_cord_registry() + build_cord_encoders(registry) + fields = cord_fields(document) + else: + registry = build_chartqa_registry() + build_chartqa_encoders(registry) + fields = chartqa_fields(document) + mi_proxy, mi_context = build_mi_proxy_context(fields, registry.encoders, preferred_fields=tuple(fields)) + bottleneck = IBottleneck(target_budget=int(budget)) + result = bottleneck.compress(fields, encoders=registry.encoders, context=mi_context, mi_proxy=mi_proxy) + return result, fields + + +def build_prompt( + document: Mapping[str, Any], + result: Any, + dataset: str, + *, + threshold: 
float, +) -> str: + if dataset == "cord": + tokens = _collect_tokens(result.compressed_fields.get("text", [])) + total = cord_total_amount(document) + doc_id = document.get("doc_id") + return textwrap.dedent( + f""" + You are auditing a receipt (doc_id={doc_id}). Selected text tokens from the variable-rate encoder: + + {tokens or '[no tokens retained]'} + + The goal is to decide whether the receipt should be flagged as HIGH VALUE (total >= {threshold:,.0f}). + Respond with a single line explaining whether it should be flagged and cite the evidence. + """ + ).strip() + + question = str(document.get("question", "")) + chart_rows = document.get("chart") or [] + kept_rows = [ + f"{row.get('label')}: {row.get('value')}" + for row in result.compressed_fields.get("chart", []) + if isinstance(row, Mapping) + ] + chart_summary = kept_rows or [f"{row.get('label')}: {row.get('value')}" for row in chart_rows[:4]] + return textwrap.dedent( + f""" + You are answering a question about a chart. Selected chart entries: + + - """[1:] + + "\n- ".join(chart_summary) + + textwrap.dedent( + f""" + + Question: {question} + + Provide the answer and cite the most relevant entries. 
+ """ + ) + ).strip() + + +def _collect_tokens(entries: Sequence[Any]) -> str: + tokens: List[str] = [] + for entry in entries: + if isinstance(entry, Mapping): + text = entry.get("text") + if text: + tokens.append(str(text)) + else: + tokens.append(str(entry)) + if not tokens: + return "" + return " / ".join(tokens[:30]) + + +def call_ollama(model: str, prompt: str, url: str) -> Dict[str, Any]: + payload = json.dumps({"model": model, "prompt": prompt, "stream": False}).encode("utf-8") + request = urllib.request.Request( + f"{url.rstrip('/')}/api/generate", + data=payload, + headers={"Content-Type": "application/json"}, + ) + try: + with urllib.request.urlopen(request, timeout=60) as response: + body = response.read().decode("utf-8") + data = json.loads(body) + return {"response": data.get("response", ""), "raw": data} + except urllib.error.URLError as exc: + return {"error": f"Failed to reach Ollama server: {exc}"} + + +def main() -> None: + args = parse_args() + document, dataset = load_document(args) + result, _ = compress_document(document, dataset, args.budget) + prompt = build_prompt(document, result, dataset, threshold=args.threshold) + summary: Dict[str, Any] = { + "dataset": dataset, + "doc_id": document.get("doc_id"), + "model": args.model, + "budget": int(args.budget), + "prompt": prompt, + } + if not args.dry_run: + summary["ollama"] = call_ollama(args.model, prompt, args.ollama_url) + print(json.dumps(summary, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/rd_audit.py b/scripts/rd_audit.py new file mode 100644 index 0000000..052d64a --- /dev/null +++ b/scripts/rd_audit.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +"""Rate–distortion and Fano audit utility for CORD receipts.""" + +from __future__ import annotations + +import argparse +import json +import math +from pathlib import Path +from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Tuple + +from benchmarks.cord import ( + 
build_cord_encoders, + build_cord_registry, + cord_fields, + cord_high_total_label, + load_cord_dataset, +) +from benchmarks import doc_understanding as doc_bench +from nd_llm.bottleneck import IBottleneck +from nd_llm.utils import build_mi_proxy_context + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=( + "Sweep token budgets for the CORD dataset in both N-D and text-only " + "configurations, emitting rate–distortion and Fano summaries." + ) + ) + parser.add_argument( + "--budgets", + type=int, + nargs="+", + default=[4, 8, 12], + help="Token budgets to evaluate (default: 4 8 12).", + ) + parser.add_argument( + "--dataset-size", + type=int, + default=8, + help="Number of documents to sample (default: 8).", + ) + parser.add_argument( + "--threshold", + type=float, + default=250_000.0, + help="Receipt total threshold for high-value classification.", + ) + parser.add_argument( + "--use-sample", + action="store_true", + help="Use the bundled JSON sample instead of downloading from Hugging Face.", + ) + parser.add_argument( + "--split", + default="train", + help="Dataset split passed to Hugging Face (default: train).", + ) + parser.add_argument( + "--cache-dir", + type=Path, + default=None, + help="Optional cache directory for Hugging Face datasets.", + ) + parser.add_argument( + "--output", + type=Path, + default=None, + help="Optional path to write the JSON report (defaults to stdout).", + ) + return parser.parse_args() + + +def label_entropy_bits(labels: Sequence[bool]) -> Tuple[float, int]: + counts: Dict[bool, int] = {} + for label in labels: + counts[label] = counts.get(label, 0) + 1 + total = sum(counts.values()) or 1 + entropy = 0.0 + for value in counts.values(): + prob = value / total + if prob > 0: + entropy -= prob * math.log2(prob) + return entropy, len(counts) + + +def evaluate_mode( + *, + mode: str, + budgets: Sequence[int], + dataset: Sequence[Mapping[str, Any]], + registry_encoders: Mapping[str, Any], + 
threshold: float, +) -> Tuple[List[Dict[str, Any]], float]: + predict_fn = doc_bench._cord_prediction_factory(threshold) # type: ignore[attr-defined] + budgets_summary: List[Dict[str, Any]] = [] + total_mi = 0.0 + total_mi_count = 0 + + for budget in budgets: + bottleneck = IBottleneck(target_budget=int(budget)) + correct = 0 + kept_totals = 0 + info_bound_sum = 0.0 + rate_dist_sum = 0.0 + mi_sum = 0.0 + mi_count = 0 + doc_count = 0 + + for document in dataset: + fields = cord_fields(document) + if mode == "text": + fields = {"text": list(fields.get("text", []))} + mi_proxy, mi_context = build_mi_proxy_context( + fields, + registry_encoders, + preferred_fields=tuple(fields), + ) + result = bottleneck.compress( + fields, + encoders=registry_encoders, + context=mi_context, + mi_proxy=mi_proxy, + ) + label = cord_high_total_label(document, threshold=threshold) + prediction = predict_fn(result, document) + if prediction == label: + correct += 1 + kept = sum(len(indices) for indices in result.telemetry.selected_indices.values()) + kept_totals += kept + metrics = result.metrics + info_bound_sum += float(metrics.get("information_bound", 0.0) or 0.0) + rate_dist_sum += float(metrics.get("rate_distortion", 0.0) or 0.0) + mi_value = metrics.get("mi_lower_bound") + if mi_value is not None: + mi_sum += float(mi_value) + mi_count += 1 + doc_count += 1 + + accuracy = float(correct) / float(doc_count or 1) + avg_kept = float(kept_totals) / float(doc_count or 1) + mean_info_bound = info_bound_sum / float(doc_count or 1) + mean_rate_dist = rate_dist_sum / float(doc_count or 1) + mean_mi = mi_sum / float(mi_count or 1) if mi_count else 0.0 + total_mi += mi_sum + total_mi_count += mi_count + + budgets_summary.append( + { + "budget": int(budget), + "accuracy": accuracy, + "distortion": 1.0 - accuracy, + "average_kept_tokens": avg_kept, + "mean_information_bound": mean_info_bound, + "mean_rate_distortion": mean_rate_dist, + "mean_mi_lower_bound_nats": mean_mi, + "evaluated_documents": 
doc_count, + } + ) + + overall_mi = total_mi / float(total_mi_count or 1) if total_mi_count else 0.0 + return budgets_summary, overall_mi + + +def fano_error_bound( + *, + label_entropy_bits: float, + num_labels: int, + mi_lower_bound_nats: float, +) -> float: + if num_labels <= 1: + return 0.0 + mi_bits = mi_lower_bound_nats / math.log(2) if mi_lower_bound_nats else 0.0 + numerator = label_entropy_bits - mi_bits - 1.0 + denom = math.log2(num_labels) + if denom <= 0: + return 0.0 + return max(0.0, numerator / denom) + + +def main() -> None: + args = parse_args() + dataset = load_cord_dataset( + split=args.split, + limit=args.dataset_size, + use_sample=args.use_sample, + cache_dir=args.cache_dir, + ) + if not dataset: + raise RuntimeError("No documents available for the requested configuration.") + + registry = build_cord_registry() + build_cord_encoders(registry) + labels = [cord_high_total_label(doc, threshold=args.threshold) for doc in dataset] + entropy_bits, num_labels = label_entropy_bits(labels) + + nd_summary, nd_mi = evaluate_mode( + mode="nd", + budgets=args.budgets, + dataset=dataset, + registry_encoders=registry.encoders, + threshold=args.threshold, + ) + text_summary, text_mi = evaluate_mode( + mode="text", + budgets=args.budgets, + dataset=dataset, + registry_encoders=registry.encoders, + threshold=args.threshold, + ) + + report = { + "dataset": "CORD-v2", + "split": args.split, + "dataset_size": len(dataset), + "use_sample": bool(args.use_sample), + "threshold": float(args.threshold), + "label_entropy_bits": entropy_bits, + "budgets": [int(b) for b in args.budgets], + "modes": { + "nd": { + "mean_mi_lower_bound_nats": nd_mi, + "fano_error_lower_bound": fano_error_bound( + label_entropy_bits=entropy_bits, + num_labels=num_labels, + mi_lower_bound_nats=nd_mi, + ), + "results": nd_summary, + }, + "text": { + "mean_mi_lower_bound_nats": text_mi, + "fano_error_lower_bound": fano_error_bound( + label_entropy_bits=entropy_bits, + num_labels=num_labels, + 
mi_lower_bound_nats=text_mi, + ), + "results": text_summary, + }, + }, + } + + if args.output: + args.output.write_text(json.dumps(report, indent=2, sort_keys=True)) + else: + print(json.dumps(report, indent=2, sort_keys=True)) + + +if __name__ == "__main__": + main() diff --git a/tests/benchmarks/test_chartqa.py b/tests/benchmarks/test_chartqa.py new file mode 100644 index 0000000..2725255 --- /dev/null +++ b/tests/benchmarks/test_chartqa.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from benchmarks.chartqa import ( + build_chartqa_encoders, + build_chartqa_registry, + chartqa_fields, + chartqa_answer, + load_chartqa_dataset, + run_chartqa_benchmark, +) + + +def test_chartqa_sample_fields() -> None: + dataset = load_chartqa_dataset(use_sample=True, limit=1) + assert dataset + registry = build_chartqa_registry() + build_chartqa_encoders(registry) + + doc = dataset[0] + fields = chartqa_fields(doc) + assert set(fields) == {"question", "chart"} + assert fields["question"] + assert fields["chart"] + assert chartqa_answer(doc) + + +def test_run_chartqa_benchmark_smoke() -> None: + report = run_chartqa_benchmark(budget_values=(2,), dataset_size=1, use_sample=True) + assert report["dataset"] == "ChartQA" + assert len(report["budgets"]) == 1 + entry = report["budgets"][0] + assert entry["budget"] == 2 + assert 0.0 <= entry["accuracy"] <= 1.0 diff --git a/tests/benchmarks/test_cord.py b/tests/benchmarks/test_cord.py new file mode 100644 index 0000000..81cafd7 --- /dev/null +++ b/tests/benchmarks/test_cord.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import pytest + +from benchmarks.cord import ( + build_cord_encoders, + build_cord_registry, + cord_amount_from_text, + cord_fields, + cord_high_total_label, + cord_total_amount, + load_cord_dataset, +) +from nd_llm.bottleneck import IBottleneck +from nd_llm.utils import build_mi_proxy_context +import json +from pathlib import Path + + +def test_cord_sample_roundtrip() -> None: + dataset = 
load_cord_dataset(use_sample=True, limit=1) + assert dataset, "Expected bundled CORD sample to be non-empty" + + registry = build_cord_registry() + build_cord_encoders(registry) + + document = dataset[0] + fields = cord_fields(document) + + assert set(fields) == {"text", "layout", "line"} + assert len(fields["text"]) == len(fields["layout"]) + assert all(item["xyxy"] for item in fields["layout"]) + assert all(0.0 <= coord <= 1.0 for entry in fields["layout"] for coord in entry["xyxy"]) + assert all(entry.get("coords") for entry in fields["text"]) + assert all(0.0 <= coord <= 1.0 for entry in fields["text"] for coord in entry["coords"]) + assert all(entry.get("coords") for entry in fields["line"]) + + mi_proxy, mi_context = build_mi_proxy_context( + fields, + registry.encoders, + preferred_fields=("text", "line"), + ) + bottleneck = IBottleneck(target_budget=4) + result = bottleneck.compress( + fields, + encoders=registry.encoders, + context=mi_context, + mi_proxy=mi_proxy, + ) + assert "text" in result.compressed_fields + assert cord_total_amount(document) > 0 + assert isinstance(cord_high_total_label(document, threshold=5.0), bool) + + +def test_cord_amount_from_text_variants() -> None: + assert cord_amount_from_text("1,234.56") == pytest.approx(1234.56) + assert cord_amount_from_text("Total 59,000") == pytest.approx(59000.0) + assert cord_amount_from_text("Grand Total 0") == 0.0 + + +def test_load_cord_dataset_from_local_root(tmp_path) -> None: + root = tmp_path / "CORD" + split_dir = root / "train" / "json" + split_dir.mkdir(parents=True) + sample_path = Path(__file__).resolve().parents[2] / "benchmarks" / "data" / "cord_sample.jsonl" + sample_line = json.loads(sample_path.read_text().splitlines()[0]) + (split_dir / "sample.json").write_text(json.dumps(sample_line)) + + dataset = load_cord_dataset(split="train", use_sample=False, data_root=root, limit=1) + assert dataset and dataset[0]["doc_id"] == "cord-sample-0" diff --git 
a/tests/benchmarks/test_doc_understanding.py b/tests/benchmarks/test_doc_understanding.py index 1162e28..61c1444 100644 --- a/tests/benchmarks/test_doc_understanding.py +++ b/tests/benchmarks/test_doc_understanding.py @@ -1,10 +1,6 @@ from __future__ import annotations -from benchmarks.doc_understanding import ( - run_benchmark, - run_doclaynet_benchmark, - run_funsd_benchmark, -) +from benchmarks.doc_understanding import run_benchmark, run_cord_benchmark def test_run_benchmark_smoke() -> None: @@ -41,35 +37,16 @@ def test_run_benchmark_smoke() -> None: assert "accuracy" in ablation -def test_run_funsd_benchmark_smoke() -> None: - report = run_funsd_benchmark(budget_values=(6,), dataset_size=2, use_sample=True) +def test_run_cord_benchmark_smoke() -> None: + report = run_cord_benchmark( + budget_values=(4,), + dataset_size=1, + use_sample=True, + threshold=10.0, + ) - assert report["dataset"] == "FUNSD" - assert report["dataset_size"] == 2 - assert report["use_sample"] is True - assert len(report["budgets"]) == 1 - - entry = report["budgets"][0] - assert entry["budget"] == 6 - assert 0.0 <= entry["accuracy"] <= 1.0 - assert isinstance(entry["metrics"], dict) - assert "cell_fusions" in entry - assert entry["cell_fusions"] - assert "ablations" in entry - assert isinstance(entry["ablations"], dict) - assert "drop_layout" in entry["ablations"] - drop_layout_accuracy = entry["ablations"]["drop_layout"]["accuracy"] - assert drop_layout_accuracy <= entry["accuracy"] - assert "drop_text" in entry["ablations"] - drop_text_accuracy = entry["ablations"]["drop_text"]["accuracy"] - assert drop_text_accuracy <= entry["accuracy"] - - -def test_run_doclaynet_benchmark_smoke() -> None: - report = run_doclaynet_benchmark(budget_values=(4,), dataset_size=2, use_sample=True) - - assert report["dataset"] == "DocLayNet" - assert report["dataset_size"] == 2 + assert report["dataset"] == "CORD-v2" + assert report["dataset_size"] == 1 assert report["use_sample"] is True assert 
len(report["budgets"]) == 1 @@ -77,7 +54,6 @@ def test_run_doclaynet_benchmark_smoke() -> None: assert entry["budget"] == 4 assert 0.0 <= entry["accuracy"] <= 1.0 assert isinstance(entry["metrics"], dict) - assert "encoder_latency_seconds" in entry["metrics"] assert "cell_fusions" in entry assert entry["cell_fusions"] assert "ablations" in entry diff --git a/tests/benchmarks/test_doclaynet.py b/tests/benchmarks/test_doclaynet.py deleted file mode 100644 index 91dee74..0000000 --- a/tests/benchmarks/test_doclaynet.py +++ /dev/null @@ -1,64 +0,0 @@ -from __future__ import annotations - -from benchmarks.doclaynet import ( - build_doclaynet_encoders, - build_doclaynet_registry, - doclaynet_contains_table, - doclaynet_fields, - load_doclaynet_dataset, -) -from nd_llm.bottleneck import IBottleneck -from nd_llm.utils import build_mi_proxy_context - - -def test_doclaynet_sample_roundtrip() -> None: - dataset = load_doclaynet_dataset(use_sample=True, limit=3) - assert dataset, "Expected bundled DocLayNet sample to be non-empty" - - registry = build_doclaynet_registry() - build_doclaynet_encoders(registry) - - document = dataset[0] - fields = doclaynet_fields(document) - - assert set(fields) == {"text", "layout", "segment"} - assert len(fields["text"]) == len(fields["layout"]) >= 1 - - layout_by_token = { - (entry["segment_id"], entry["token_id"]): entry for entry in fields["layout"] - } - segments_by_id = {entry["segment_id"]: entry for entry in fields["segment"]} - - for text_entry in fields["text"]: - token_key = (text_entry["segment_id"], text_entry["token_id"]) - assert token_key in layout_by_token - layout_entry = layout_by_token[token_key] - assert layout_entry["segment_id"] == text_entry["segment_id"] - segment_id = text_entry["segment_id"] - if segment_id in segments_by_id: - assert text_entry["token_id"] in segments_by_id[segment_id].get( - "token_ids", [] - ) - - assert all(0.0 <= coord <= 1.0 for item in fields["layout"] for coord in item.get("xyxy", [])) - 
assert isinstance(doclaynet_contains_table(document), bool) - - bottleneck = IBottleneck(target_budget=4) - mi_proxy, mi_context = build_mi_proxy_context( - fields, - registry.encoders, - preferred_fields=("layout", "text"), - ) - result = bottleneck.compress( - fields, - encoders=registry.encoders, - context=mi_context, - mi_proxy=mi_proxy, - ) - assert "segment" in result.compressed_fields - assert result.compressed_fields["segment"], "Expected at least one segment to be retained" - - -def test_doclaynet_dataset_limit() -> None: - dataset = load_doclaynet_dataset(use_sample=True, limit=2) - assert len(dataset) == 2 diff --git a/tests/benchmarks/test_funsd.py b/tests/benchmarks/test_funsd.py deleted file mode 100644 index 7eb4c3e..0000000 --- a/tests/benchmarks/test_funsd.py +++ /dev/null @@ -1,74 +0,0 @@ -from __future__ import annotations - -import pytest - -from benchmarks.funsd import ( - build_funsd_encoders, - build_funsd_registry, - funsd_fields, - funsd_numeric_answer_label, - load_funsd_dataset, - _normalise_box, - _resolve_size, -) -from nd_llm.bottleneck import IBottleneck -from nd_llm.utils import build_mi_proxy_context - - -def test_funsd_sample_roundtrip() -> None: - dataset = load_funsd_dataset(use_sample=True, limit=3) - assert dataset, "Expected bundled FUNSD sample to be non-empty" - - registry = build_funsd_registry() - build_funsd_encoders(registry) - - document = dataset[0] - fields = funsd_fields(document) - - assert set(fields) == {"text", "layout", "entity"} - assert len(fields["text"]) == len(fields["layout"]) - assert all(0.0 <= coord <= 1.0 for item in fields["layout"] for coord in item.get("xyxy", [])) - - bottleneck = IBottleneck(target_budget=6) - mi_proxy, mi_context = build_mi_proxy_context( - fields, - registry.encoders, - preferred_fields=("text",), - ) - result = bottleneck.compress( - fields, - encoders=registry.encoders, - context=mi_context, - mi_proxy=mi_proxy, - ) - assert "text" in result.compressed_fields - assert 
isinstance(funsd_numeric_answer_label(document), bool) - - -def test_normalise_box_xyxy_passthrough() -> None: - result = _normalise_box([10, 20, 50, 60], width=100, height=100) - assert result == [0.1, 0.2, 0.5, 0.6] - - -def test_normalise_box_xywh_with_mode() -> None: - result = _normalise_box([10, 20, 30, 40], width=200, height=200, mode="xywh") - assert result == [0.05, 0.1, 0.2, 0.3] - - -def test_normalise_box_xywh_inferred() -> None: - result = _normalise_box([100, 120, 10, 20], width=200, height=200) - assert result == [0.5, 0.6, 0.55, 0.7] - - -def test_resolve_size_reads_image(tmp_path) -> None: - pytest.importorskip("PIL.Image") - image_path = tmp_path / "sample.png" - image_path.write_bytes( - ( - b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01" - b"\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc``\x00\x00\x00\x02\x00\x01" - b"\xe2!\xbc3\x00\x00\x00\x00IEND\xaeB`\x82" - ) - ) - width, height = _resolve_size({"image": str(image_path)}) - assert (width, height) == (1.0, 1.0) diff --git a/tests/bottleneck/test_ibottleneck.py b/tests/bottleneck/test_ibottleneck.py index 34063f2..97842ee 100644 --- a/tests/bottleneck/test_ibottleneck.py +++ b/tests/bottleneck/test_ibottleneck.py @@ -2,6 +2,7 @@ import pytest import torch +import torch.nn as nn from nd_llm.bottleneck import ( IBottleneck, @@ -74,6 +75,39 @@ def test_query_conditioned_scoring_prefers_query_aligned_tokens(): assert result.telemetry.selected_scores["text"][0] == pytest.approx(1.0) +class DeterministicMIProxy(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.tau = 1.0 + self.register_parameter("_scale", nn.Parameter(torch.ones(1))) + self._dim = dim + + def f(self, z: torch.Tensor) -> torch.Tensor: # noqa: D401 - matching MIProxy API + return z + + def h(self, target: torch.Tensor) -> torch.Tensor: + return target + + def forward(self, z: torch.Tensor, y: torch.Tensor): # noqa: D401 - compatibility shim + logits = torch.zeros(z.size(0), 
y.size(0), device=z.device, dtype=z.dtype) + zero = torch.zeros((), device=z.device, dtype=z.dtype) + return zero, logits + + +def test_mi_scoring_prioritises_target_aligned_tokens(): + fields = {"text": ["t0", "t1", "t2"]} + encoders = { + "text": MockEncoder([[1.0, 0.0], [0.0, 5.0], [0.0, 4.5]]), + } + context = {"mi_targets": {"text": [1.0, 0.0]}} + proxy = DeterministicMIProxy(dim=2) + + bottleneck = IBottleneck(target_budget=1, mi_score_weight=0.8) + result = bottleneck.compress(fields, encoders, context=context, mi_proxy=proxy) + + assert result.telemetry.selected_indices["text"] == [0] + + def test_budget_allocator_respects_salience_metadata(): registry = Registry() registry.add_field("salient", keys=["doc_id"], salience=True) diff --git a/tests/constraints/test_constraints.py b/tests/constraints/test_constraints.py new file mode 100644 index 0000000..5e8b8ee --- /dev/null +++ b/tests/constraints/test_constraints.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from nd_llm.constraints import ( + FieldActivationConstraint, + SuperpositionSimilarityConstraint, +) +from nd_llm.orchestration.orchestrator import CompressionRecord, UsageEvent +from nd_llm.stm import STM +from nd_llm.utils import STMConfig + + +def _make_record(selected_text: int) -> CompressionRecord: + telemetry = { + "selected_indices": {"text": list(range(selected_text))}, + "token_counts": {"text": selected_text}, + } + return CompressionRecord( + compressed_fields={"text": ["t"] * selected_text}, + telemetry=telemetry, + metrics={}, + ) + + +def test_field_activation_constraint_flags_underflow(tmp_path) -> None: + stm = STM(STMConfig(storage_dir=tmp_path)) + constraint = FieldActivationConstraint(field="text", min_tokens=2, max_tokens=3) + record = _make_record(1) + event = UsageEvent(tensor=[[]], metadata={}, compression=record) + + result = constraint.evaluate(stm=stm, event=event, compression=record) + assert result.name == "field_activation" + assert result.satisfied is False + 
assert result.details["count"] == 1 + + +def test_superposition_similarity_constraint_reads_memory(tmp_path) -> None: + stm = STM(STMConfig(storage_dir=tmp_path)) + stm.write_superposition("usage", [1.0, 0.0], metadata={"task": "demo"}) + constraint = SuperpositionSimilarityConstraint(channel="usage", min_similarity=0.5) + record = _make_record(1) + event = UsageEvent(tensor=[1.0, 0.0], metadata={}, compression=record) + + result = constraint.evaluate(stm=stm, event=event, compression=record) + assert result.satisfied is True + assert result.details["channel"] == "usage" diff --git a/tests/orchestration/test_orchestrator.py b/tests/orchestration/test_orchestrator.py index 10912a7..2d2437b 100644 --- a/tests/orchestration/test_orchestrator.py +++ b/tests/orchestration/test_orchestrator.py @@ -10,6 +10,8 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[2])) +from nd_llm.bottleneck import CompressionResult, CompressionTelemetry +from nd_llm.constraints import FieldActivationConstraint from nd_llm.orchestration import ( CompressionRecord, CompressionRatioBudgetStrategy, @@ -384,3 +386,46 @@ def test_from_components_attaches_default_meta_model(tmp_path) -> None: assert meta_summary["model"] == model_name assert meta_summary["selected"]["candidate"]["budget"] == pytest.approx(proposed) assert meta_summary["selected"]["score"] > 0 + + +def test_orchestrator_runs_constraints_and_superpositions(tmp_path) -> None: + storage_config = STMConfig(storage_dir=tmp_path) + stm = STM(storage_config) + constraint = FieldActivationConstraint(field="text", min_tokens=1) + orchestrator = Orchestrator( + stm=stm, + config=OrchestratorConfig(target_budget=1.0, policy_name="constraint-test", budget_step=0.5), + constraints=[constraint], + superposition_channels=("usage",), + ) + + telemetry = CompressionTelemetry( + selected_indices={"text": [0]}, + selected_scores={"text": [0.5]}, + token_counts={"text": 1}, + budget=1, + field_budgets={"text": 1}, + allocation_weights={"text": 
1.0}, + dropped_indices={"text": []}, + residual_statistics={}, + quantized_embeddings={}, + ) + result = CompressionResult( + compressed_fields={"text": ["token"]}, + telemetry=telemetry, + metrics={"ib_proxy": 0.1}, + ) + record = CompressionRecord.from_result(result, bottleneck="ib-test") + event = UsageEvent( + key="constraint-event", + tensor=[0.1, 0.2], + metadata={}, + compression=record, + ) + + key = orchestrator.log_usage_event(event) + entry = stm.get_index_entry(key) + assert "constraints" in entry["metadata"] + vector, info = stm.read_superposition("usage") + assert vector + assert info["channel"] == "usage" diff --git a/tests/registry/test_field_adapters.py b/tests/registry/test_field_adapters.py new file mode 100644 index 0000000..bc01193 --- /dev/null +++ b/tests/registry/test_field_adapters.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from nd_llm.registry import FieldAdapter, FieldAdapterRegistry, LayoutAligner, quad_to_box + + +def test_field_adapter_registry_adds_coords() -> None: + def builder(document): + yield {"quad": {"x1": 0, "y1": 0, "x2": 50, "y2": 0, "x3": 50, "y3": 100, "x4": 0, "y4": 100}} + + registry = FieldAdapterRegistry() + registry.register(FieldAdapter(name="layout", builder=builder, aligner=LayoutAligner())) + + transformed = registry.transform({"width": 100, "height": 200}) + assert "layout" in transformed + entry = transformed["layout"][0] + assert "coords" in entry + assert entry["coords"] == [0.0, 0.0, 0.5, 0.5] + assert entry["xyxy"] == entry["coords"] + + +def test_quad_to_box_handles_complete_values() -> None: + box = quad_to_box( + {"x1": 10, "y1": 20, "x2": 30, "y2": 20, "x3": 30, "y3": 80, "x4": 10, "y4": 80} + ) + assert box == [10.0, 20.0, 30.0, 80.0] diff --git a/tests/scripts/test_ollama_harness.py b/tests/scripts/test_ollama_harness.py new file mode 100644 index 0000000..5341b6d --- /dev/null +++ b/tests/scripts/test_ollama_harness.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import 
json +import subprocess +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def run_harness(args): + return subprocess.run( + [sys.executable, "-m", "scripts.ollama_harness", *args], + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=True, + ).stdout + + +def test_ollama_harness_dry_run_cord_sample(): + output = run_harness(["--dataset", "cord", "--use-sample", "--dry-run"]) + payload = json.loads(output) + assert payload["dataset"] == "cord" + assert "prompt" in payload + assert payload["model"] == "llama3.1:8b" + + +def test_ollama_harness_dry_run_chartqa_sample(): + output = run_harness(["--dataset", "chartqa", "--use-sample", "--dry-run"]) + payload = json.loads(output) + assert payload["dataset"] == "chartqa" + assert "Question" in payload["prompt"] diff --git a/tests/scripts/test_rd_audit.py b/tests/scripts/test_rd_audit.py new file mode 100644 index 0000000..396e712 --- /dev/null +++ b/tests/scripts/test_rd_audit.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import json +import subprocess +import sys +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[2] + + +def run(command: list[str]) -> str: + result = subprocess.run( + command, + cwd=REPO_ROOT, + capture_output=True, + text=True, + check=True, + ) + return result.stdout + + +def test_rd_audit_cli_runs_with_sample_dataset(): + output = run( + [ + sys.executable, + "-m", + "scripts.rd_audit", + "--budgets", + "4", + "--dataset-size", + "1", + "--use-sample", + ] + ) + report = json.loads(output) + assert report["dataset"] == "CORD-v2" + assert report["dataset_size"] == 1 + assert "modes" in report + assert set(report["modes"]) == {"nd", "text"} diff --git a/tests/stm/test_superposition.py b/tests/stm/test_superposition.py new file mode 100644 index 0000000..0b71eb4 --- /dev/null +++ b/tests/stm/test_superposition.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import pytest + +from nd_llm.stm import STM +from 
nd_llm.utils import STMConfig + + +def test_superposition_write_and_read(tmp_path) -> None: + stm = STM(STMConfig(storage_dir=tmp_path)) + + stm.write_superposition("usage", [1.0, 2.0], weight=1.0, metadata={"task": "alpha"}) + stm.write_superposition("usage", [3.0, 1.0], weight=2.0) + + vector, metadata = stm.read_superposition("usage") + assert pytest.approx(vector[0], rel=1e-6) == (1.0 + 6.0) / 3.0 + assert pytest.approx(vector[1], rel=1e-6) == (2.0 + 2.0) / 3.0 + assert metadata["channel"] == "usage" + assert pytest.approx(metadata["weight"], rel=1e-6) == 3.0 + + raw_vector, _ = stm.read_superposition("usage", normalize=False) + assert raw_vector == pytest.approx([7.0, 4.0])