diff --git a/src/winml/modelkit/commands/run.py b/src/winml/modelkit/commands/run.py index cf3bf3735..d243a90b9 100644 --- a/src/winml/modelkit/commands/run.py +++ b/src/winml/modelkit/commands/run.py @@ -25,6 +25,7 @@ from __future__ import annotations +import contextlib import json import logging import sys @@ -622,7 +623,11 @@ def run( return try: - engine.load(model, task=task, device=device, ep=ep) + # Redirect stdout → stderr during model load so that build-pipeline + # prints (from optimum, onnxruntime, etc.) don't contaminate + # structured output (--format json) or text output parsing. + with contextlib.redirect_stdout(sys.stderr): + engine.load(model, task=task, device=device, ep=ep) except (OSError, ValueError, RuntimeError) as exc: click.echo(f"Error loading model: {exc}", err=True) ctx.exit(3) diff --git a/src/winml/modelkit/inference/engine.py b/src/winml/modelkit/inference/engine.py index bdf37074a..234287ccb 100644 --- a/src/winml/modelkit/inference/engine.py +++ b/src/winml/modelkit/inference/engine.py @@ -100,6 +100,87 @@ def _decode_image(data: bytes) -> Any: } +def _sanitize_numpy(obj: Any) -> Any: + """Recursively convert numpy scalars to Python types for JSON serialization. + + HF pipelines (e.g. NER/TokenClassification) return dicts containing + ``numpy.float32`` scores and ``numpy.int64`` offsets that pydantic + cannot serialize. This function converts them to native Python types. + """ + import numpy as np + + if isinstance(obj, dict): + return {k: _sanitize_numpy(v) for k, v in obj.items()} + if isinstance(obj, list): + return [_sanitize_numpy(v) for v in obj] + if isinstance(obj, np.floating): + return float(obj) + if isinstance(obj, np.integer): + return int(obj) + if isinstance(obj, np.ndarray): + return obj.tolist() + return obj + + +# --------------------------------------------------------------------------- +# Build-directory artifact discovery +# --------------------------------------------------------------------------- + + +def _find_build_artifacts(build_dir: Path, *, task: str | None = None) -> tuple[Path, dict | None]: + """Locate model.onnx and build_manifest.json inside a build/cache directory. + + Supports both plain layout (``model.onnx``) and cache-key-prefixed layout + (``{cache_key}_model.onnx``) so that ``InferenceEngine.load()`` can load + directly from ``~/.cache/winml/artifacts/{slug}/`` without a full rebuild. + + When *task* is provided, only returns artifacts whose manifest ``task`` + matches. A cache directory may contain multiple task variants (e.g. + ``feat_*`` for feature-extraction and ``txtcls_*`` for text-classification); + using the wrong ONNX (different output head) would produce garbage results. + + Returns: + ``(onnx_path, manifest_dict)`` — manifest is ``None`` when no + manifest file is found. + + Raises: + FileNotFoundError: if no matching ``*model.onnx`` is found. + """ + # Try plain layout first (bare build output) + plain_onnx = build_dir / "model.onnx" + plain_manifest = build_dir / "build_manifest.json" + if plain_onnx.exists(): + manifest = json.loads(plain_manifest.read_text()) if plain_manifest.exists() else None + if task is None or manifest is None or manifest.get("task") == task: + return plain_onnx, manifest + + # Cache-key-prefixed layout: {cache_key}_model.onnx + # Scan all candidates and match by task when specified. + candidates: list[tuple[Path, dict | None]] = [] + for onnx_path in sorted(build_dir.glob("*_model.onnx")): + prefix = onnx_path.name.rsplit("_model.onnx", 1)[0] + manifest_path = build_dir / f"{prefix}_build_manifest.json" + manifest = json.loads(manifest_path.read_text()) if manifest_path.exists() else None + if task is not None: + if manifest is None or manifest.get("task") == task: + return onnx_path, manifest + else: + candidates.append((onnx_path, manifest)) + + if candidates: + # task=None: if all variants share the same task, pick the first. + # If multiple different tasks exist, raise to force the caller to specify. + tasks = {m.get("task") for _, m in candidates if m is not None} + if len(tasks) > 1: + raise FileNotFoundError( + f"Multiple task variants found in {build_dir}: {tasks}. " + "Pass task= to select the correct ONNX model." + ) + return candidates[0] + + raise FileNotFoundError(f"No model.onnx matching task={task!r} found in {build_dir}") + + # --------------------------------------------------------------------------- # Lightweight schema helpers (no model load required) # --------------------------------------------------------------------------- @@ -233,7 +314,24 @@ def load( path = Path(model_path) if path.is_dir(): - self._load_from_build_dir(path, task=task, device=device, ep=ep) + try: + self._load_from_build_dir(path, task=task, device=device, ep=ep) + except (FileNotFoundError, json.JSONDecodeError, KeyError): + # No cached ONNX for this task (or corrupt manifest) — check + # if the manifest has a model_id we can rebuild from (e.g. + # cache was built for a different task like feature-extraction + # but caller wants text-classification). + model_id = self._resolve_model_id_from_dir(path) + if model_id: + logger.info( + "No cached ONNX for task=%s in %s — rebuilding from %s", + task, + path, + model_id, + ) + self._load_from_hf(model_id, task=task, device=device, ep=ep) + else: + raise elif path.suffix == ".onnx" and path.exists(): self._load_from_onnx(path, task=task, device=device, ep=ep) else: @@ -279,9 +377,11 @@ def load_schema_only( if path.is_dir(): # Build dir: read manifest for task + model_id (no ORT session) - manifest_path = path / "build_manifest.json" - if manifest_path.exists(): - manifest = json.loads(manifest_path.read_text()) + try: + _, manifest = _find_build_artifacts(path, task=task) + except FileNotFoundError: + manifest = None + if manifest is not None: self._model_id = manifest.get("model_id") self._task = task or manifest.get("task") else: @@ -755,11 +855,14 @@ def _normalize_pipeline_output( ) for item in raw ] - # Non-classification list of dicts (e.g. text-generation, NER) - return raw[0] if len(raw) == 1 else {"results": raw} + # Non-classification list of dicts (e.g. text-generation, NER). + # Sanitize numpy scalars so pydantic/JSON serialization works + # (NER pipelines return np.float32 scores). + result = raw[0] if len(raw) == 1 else {"results": raw} + return _sanitize_numpy(result) # Other tasks: return as-is dict if isinstance(raw, dict): - return raw + return _sanitize_numpy(raw) # Fallback return {"raw": str(raw)} @@ -807,15 +910,10 @@ def _load_from_build_dir( device: str, ep: str | None, ) -> None: - manifest_path = build_dir / "build_manifest.json" - onnx_path = build_dir / "model.onnx" - - if not onnx_path.exists(): - raise FileNotFoundError(f"model.onnx not found in {build_dir}") + onnx_path, manifest = _find_build_artifacts(build_dir, task=task) model_id: str | None = None - if manifest_path.exists(): - manifest = json.loads(manifest_path.read_text()) + if manifest is not None: model_id = manifest.get("model_id") task = task or manifest.get("task") @@ -832,6 +930,16 @@ def _load_from_build_dir( logger.info("Loaded from build dir: task=%s model_id=%s", task, model_id) + @staticmethod + def _resolve_model_id_from_dir(build_dir: Path) -> str | None: + """Extract model_id from any manifest in the directory (task-agnostic).""" + for manifest_path in build_dir.glob("*build_manifest.json"): + manifest = json.loads(manifest_path.read_text()) + model_id = manifest.get("model_id") + if model_id: + return model_id + return None + def _load_from_onnx( self, onnx_path: Path, diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 00182c042..5bcdcf946 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -2,10 +2,14 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Shared fixtures for E2E tests. +"""Shared fixtures and helpers for E2E tests. -These fixtures generate real ONNX files on-the-fly and provide -model-task combination parameters for parametrized tests. +Provides: + - Hub model parametrization helpers (``HUB_PAIRS``, ``pytest_id``) + - Cache-aware model resolution (``find_cache_dir``, ``resolve_model_arg``) + - Common sample text inputs (``SAMPLE_TEXT``, ``TEXT_BY_FIELD``) + - Shared fixtures (``test_image``, ``runner``) + - Auto-skip for ``-m e2e`` E2E tests are auto-skipped unless explicitly selected with: uv run pytest -m e2e @@ -14,16 +18,128 @@ from __future__ import annotations import json +from pathlib import Path from typing import TYPE_CHECKING import numpy as np import onnx import pytest from onnx import TensorProto, helper +from PIL import Image if TYPE_CHECKING: - from pathlib import Path + from click.testing import CliRunner + + +# --------------------------------------------------------------------------- +# Hub model parametrization (single source of truth) +# --------------------------------------------------------------------------- + +_HUB_JSON = ( + Path(__file__).resolve().parents[2] / "src" / "winml" / "modelkit" / "data" / "hub_models.json" +) +HUB_DATA: dict = json.loads(_HUB_JSON.read_text(encoding="utf-8")) + + +def _unique_pairs() -> list[dict[str, str]]: + """Deduplicate ``(model_id, task)`` — keep first occurrence.""" + seen: set[tuple[str, str]] = set() + pairs: list[dict[str, str]] = [] + for entry in HUB_DATA["models"]: + key = (entry["model_id"], entry["task"]) + if key not in seen: + seen.add(key) + pairs.append({"model_id": entry["model_id"], "task": entry["task"]}) + return pairs + + +HUB_PAIRS: list[dict[str, str]] = _unique_pairs() + + +def hub_test_id(pair: dict[str, str]) -> str: + """Readable pytest ID, e.g. ``finbert-text_classification``.""" + short = pair["model_id"].rsplit("/", 1)[-1] + task = pair["task"].replace("-", "_") + return f"{short}-{task}" + + +# --------------------------------------------------------------------------- +# Cache-aware model resolution +# --------------------------------------------------------------------------- + + +def find_cache_dir(model_id: str, task: str | None = None) -> Path | None: + """Find the winml build-cache directory for a model+task, or None. + + Looks for ``~/.cache/winml/artifacts/{slug}/`` containing a + ``*_model.onnx`` file whose manifest task matches *task*. + """ + from winml.modelkit.cache import get_cache_dir, model_id_to_slug + from winml.modelkit.inference.engine import _find_build_artifacts + + slug = model_id_to_slug(model_id) + cache_dir = get_cache_dir() / "artifacts" / slug + if not cache_dir.is_dir(): + return None + try: + _find_build_artifacts(cache_dir, task=task) + return cache_dir + except (FileNotFoundError, json.JSONDecodeError): + return None + + +def resolve_model_arg(model_id: str, task: str | None = None) -> str: + """Return the cache directory (fast) or HF model ID (slow rebuild).""" + cache_dir = find_cache_dir(model_id, task=task) + if cache_dir is not None: + return str(cache_dir) + return model_id + + +# --------------------------------------------------------------------------- +# Common sample inputs +# --------------------------------------------------------------------------- + +SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog." + +TEXT_BY_FIELD: dict[str, str] = { + "question": "What is the capital of France?", + "context": ( + "Paris is the capital of France. " + "It is known for the Eiffel Tower and its rich cultural heritage." + ), + "text_1": "A man is eating food.", + "text_2": "A man is eating a piece of bread.", +} + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def test_image(tmp_path_factory: pytest.TempPathFactory) -> str: + """Generate a 224x224 random RGB JPEG (reused across the module).""" + d = tmp_path_factory.mktemp("run_e2e_assets") + img_path = d / "test_image.jpg" + rng = np.random.RandomState(42) + arr = rng.randint(0, 255, (224, 224, 3), dtype=np.uint8) + Image.fromarray(arr).save(str(img_path), format="JPEG") + return str(img_path) + + +@pytest.fixture +def runner() -> CliRunner: + from click.testing import CliRunner + + return CliRunner() + + +# --------------------------------------------------------------------------- +# Auto-skip E2E +# --------------------------------------------------------------------------- def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: diff --git a/tests/e2e/test_run_e2e.py b/tests/e2e/test_run_e2e.py new file mode 100644 index 000000000..d04414b31 --- /dev/null +++ b/tests/e2e/test_run_e2e.py @@ -0,0 +1,471 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""E2E quality-gate tests for ``winml run``. + +No mocks — real models, real inputs, real outputs. Tests are organized +in three tiers by scope and cost: + +Tier 1 — **Feature gates** (2 fixed models) + Validates CLI features: ``--file``, ``--text``, ``-I``, ``-P``, + ``--format text|json``, ``-o``, ``--schema``. + +Tier 2 — **Schema coverage** (all hub models) + ``winml run --schema`` for every ``(model_id, task)`` pair in + ``hub_models.json`` — lightweight, no ORT session. + +Tier 3 — **Inference coverage** (all hub models) + Full inference per hub model. Cache-aware: prefers already-built + directories under ``~/.cache/winml/artifacts/`` to avoid the slow + export → optimize → analyze pipeline on every run. + +Usage:: + + # All tiers + uv run pytest -m e2e tests/e2e/test_run_e2e.py -v + + # Tier 1 only (fast regression) + uv run pytest -m e2e tests/e2e/test_run_e2e.py -k "Feature" -v + + # Tier 2 only (schema) + uv run pytest -m e2e tests/e2e/test_run_e2e.py -k "Schema" -v + + # Tier 3 only (inference matrix) + uv run pytest -m e2e tests/e2e/test_run_e2e.py -k "Inference" -v + + # Filter by task or model name + uv run pytest -m e2e tests/e2e/test_run_e2e.py -k "text_classification" -v + uv run pytest -m e2e tests/e2e/test_run_e2e.py -k "finbert" -v + +Markers: + e2e: Full end-to-end test with real models + slow: Tests that take > 30 seconds + network: Requires network access to HuggingFace Hub +""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING + +import pytest + +from winml.modelkit.commands.run import run + +from .conftest import HUB_PAIRS as _PAIRS +from .conftest import SAMPLE_TEXT as _SAMPLE_TEXT +from .conftest import TEXT_BY_FIELD as _TEXT_BY_FIELD +from .conftest import hub_test_id as _pytest_id +from .conftest import resolve_model_arg as _resolve_model_arg + + +if TYPE_CHECKING: + from pathlib import Path + + from click.testing import CliRunner + +pytestmark = [pytest.mark.e2e, pytest.mark.slow, pytest.mark.network, pytest.mark.timeout(3600)] + + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +_IMAGE_HF_ID = "microsoft/resnet-18" +_TEXT_HF_ID = "prajjwal1/bert-tiny" + + +# --------------------------------------------------------------------------- +# JSON extraction helper +# --------------------------------------------------------------------------- + + +def _extract_json(output: str) -> dict: + """Extract the JSON object from CLI output that may have build-pipeline noise. + + ``run.py`` redirects ``sys.stdout`` → ``sys.stderr`` during ``engine.load()`` + to prevent build-pipeline prints from contaminating JSON output. However, + some C-extension code (e.g. onnxruntime, tqdm) writes directly to file + descriptor 1, bypassing Python's ``sys.stdout`` — ``redirect_stdout`` cannot + intercept those writes. This function provides a robust fallback by + scanning for the first valid top-level ``{...}`` JSON object. + """ + decoder = json.JSONDecoder() + # Scan forward for the first '{' that starts a valid JSON object + for i, ch in enumerate(output): + if ch == "{": + try: + obj, _ = decoder.raw_decode(output, i) + if isinstance(obj, dict): + return obj + except json.JSONDecodeError: + continue + raise ValueError(f"No JSON object found in output: {output[:200]!r}") + + +# --------------------------------------------------------------------------- +# Sample inputs for inference +# --------------------------------------------------------------------------- + +_FALLBACK_INPUT_ARGS: dict[str, list[str]] = { + "sentence-similarity": ["--text", _SAMPLE_TEXT], +} + + +def _build_inference_args( + schema_inputs: list[dict], + task: str, + test_image: str, +) -> list[str] | None: + """Build CLI args for inference from ``--schema`` output. + + Returns ``None`` when no inputs can be determined (caller should skip). + """ + required = [i for i in schema_inputs if i.get("required", False)] + + if not required: + return _FALLBACK_INPUT_ARGS.get(task) + + binary = [i for i in required if i["type"] in ("image", "audio", "video")] + text = [i for i in required if i["type"] == "text"] + json_fields = [i for i in required if i["type"] == "json"] + + args: list[str] = [] + + if len(binary) == 1 and binary[0]["type"] == "image": + args.extend(["--file", test_image]) + else: + for b in binary: + args.extend(["-I", f"{b['name']}=@{test_image}"]) + + if len(text) == 1 and not binary and not json_fields: + sample = _TEXT_BY_FIELD.get(text[0]["name"], _SAMPLE_TEXT) + args.extend(["--text", sample]) + else: + for t in text: + sample = _TEXT_BY_FIELD.get(t["name"], _SAMPLE_TEXT) + args.extend(["-I", f"{t['name']}={sample}"]) + + for j in json_fields: + args.extend(["-I", f'{j["name"]}=["positive","negative","neutral"]']) + + return args + + +# --------------------------------------------------------------------------- +# Tier 1 fixtures: cache-aware model paths for fixed models +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def image_model() -> str: + """Resolve resnet-18 to cache dir (fast) or HF ID (slow).""" + return _resolve_model_arg(_IMAGE_HF_ID) + + +@pytest.fixture(scope="module") +def text_model() -> str: + """Resolve bert-tiny to cache dir (fast) or HF ID (slow).""" + return _resolve_model_arg(_TEXT_HF_ID, task="text-classification") + + +# ===================================================================== +# Tier 1 — Feature gates (fixed models, deep assertions) +# ===================================================================== + + +class TestFeatureImageClassification: + """resnet-18: --file, --format, -o, -I image=@path.""" + + def test_file_text_format(self, runner: CliRunner, image_model: str, test_image: str) -> None: + result = runner.invoke( + run, + ["--model", image_model, "--file", test_image], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Task:" in result.output + assert "Device:" in result.output + assert "Latency:" in result.output + + def test_file_json_format(self, runner: CliRunner, image_model: str, test_image: str) -> None: + result = runner.invoke( + run, + ["--model", image_model, "--file", test_image, "--format", "json"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert data["task"] == "image-classification" + assert isinstance(data["predictions"], list) + assert len(data["predictions"]) > 0 + pred = data["predictions"][0] + assert "label" in pred and "score" in pred + assert isinstance(pred["score"], float) + assert data["latency_ms"] > 0 + + def test_file_json_to_file( + self, runner: CliRunner, image_model: str, test_image: str, tmp_path: Path + ) -> None: + out = tmp_path / "result.json" + result = runner.invoke( + run, + ["--model", image_model, "--file", test_image, "--format", "json", "-o", str(out)], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert out.exists() + data = json.loads(out.read_text(encoding="utf-8")) + assert data["task"] == "image-classification" + assert len(data["predictions"]) > 0 + + def test_named_input(self, runner: CliRunner, image_model: str, test_image: str) -> None: + result = runner.invoke( + run, + ["--model", image_model, "-I", f"image=@{test_image}", "--format", "json"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert len(data["predictions"]) > 0 + + +class TestFeatureTextClassification: + """bert-tiny: --text, -I text=, -P top_k=.""" + + def test_text_shortcut(self, runner: CliRunner, text_model: str, tmp_path: Path) -> None: + out = tmp_path / "result.json" + result = runner.invoke( + run, + [ + "--model", + text_model, + "--text", + "This product is amazing!", + "--task", + "text-classification", + "--format", + "json", + "-o", + str(out), + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + data = json.loads(out.read_text(encoding="utf-8")) + assert data["task"] == "text-classification" + assert "predictions" in data + assert data["latency_ms"] > 0 + + def test_named_input(self, runner: CliRunner, text_model: str) -> None: + result = runner.invoke( + run, + [ + "--model", + text_model, + "-I", + "text=Hello world", + "--task", + "text-classification", + "--format", + "json", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "predictions" in data + + def test_pipeline_param(self, runner: CliRunner, text_model: str) -> None: + result = runner.invoke( + run, + [ + "--model", + text_model, + "--text", + "Testing pipeline params", + "--task", + "text-classification", + "-P", + "top_k=3", + "--format", + "json", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "predictions" in data + + +class TestFeatureOutputFormats: + """Validate --format text vs json, and -o file output.""" + + def test_text_format_sections( + self, + runner: CliRunner, + image_model: str, + test_image: str, + ) -> None: + result = runner.invoke( + run, + ["--model", image_model, "--file", test_image, "--format", "text"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + out = result.output + assert "Task: image-classification" in out + assert "Device:" in out + assert "Results:" in out or "Output:" in out + assert "Latency:" in out and "ms" in out + + def test_json_format_keys(self, runner: CliRunner, image_model: str, test_image: str) -> None: + result = runner.invoke( + run, + ["--model", image_model, "--file", test_image, "--format", "json"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert {"task", "predictions", "latency_ms", "device"}.issubset(data.keys()) + + def test_output_to_file( + self, runner: CliRunner, image_model: str, test_image: str, tmp_path: Path + ) -> None: + out = tmp_path / "result.txt" + result = runner.invoke( + run, + ["--model", image_model, "--file", test_image, "--format", "text", "-o", str(out)], + catch_exceptions=False, + ) + assert result.exit_code == 0 + content = out.read_text(encoding="utf-8") + assert "Task:" in content and "Latency:" in content + + +class TestFeatureSchema: + """Validate --schema output (text + json + file).""" + + def test_schema_text(self, runner: CliRunner, image_model: str) -> None: + result = runner.invoke( + run, + ["--model", image_model, "--schema"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Inputs" in result.output or "inputs" in result.output.lower() + + def test_schema_json(self, runner: CliRunner, image_model: str) -> None: + result = runner.invoke( + run, + ["--model", image_model, "--schema", "--format", "json"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + data = json.loads(result.output) + assert "task" in data + assert isinstance(data["inputs"], list) + + def test_schema_to_file(self, runner: CliRunner, image_model: str, tmp_path: Path) -> None: + out = tmp_path / "schema.json" + result = runner.invoke( + run, + ["--model", image_model, "--schema", "--format", "json", "-o", str(out)], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert out.exists() + assert "inputs" in json.loads(out.read_text(encoding="utf-8")) + + def test_schema_does_not_run_inference(self, runner: CliRunner, image_model: str) -> None: + result = runner.invoke( + run, + ["--model", image_model, "--schema"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Inputs" in result.output or "inputs" in result.output.lower() + + +# ===================================================================== +# Tier 2 — Schema coverage (all hub models, lightweight) +# ===================================================================== + + +class TestSchemaAllModels: + """``--schema --format json`` for every hub model — no ORT session needed.""" + + @pytest.mark.parametrize("pair", _PAIRS, ids=[_pytest_id(p) for p in _PAIRS]) + def test_schema(self, runner: CliRunner, pair: dict[str, str]) -> None: + result = runner.invoke( + run, + ["--model", pair["model_id"], "--task", pair["task"], "--schema", "--format", "json"], + catch_exceptions=False, + ) + assert result.exit_code == 0, f"--schema failed (exit {result.exit_code}):\n{result.output}" + data = json.loads(result.output) + assert "task" in data + assert isinstance(data["inputs"], list) + if data["inputs"]: + assert "example" in data + assert data["example"].startswith("winml run") + + +# ===================================================================== +# Tier 3 — Inference coverage (all hub models, cache-aware) +# ===================================================================== + + +class TestInferenceAllModels: + """Full inference for every hub model — uses build cache when available. + + Flow per model: + 1. Resolve model arg (cache dir or HF ID) + 2. ``--schema`` → discover inputs + 3. Run inference → validate JSON output + """ + + @pytest.mark.parametrize("pair", _PAIRS, ids=[_pytest_id(p) for p in _PAIRS]) + def test_run( + self, + runner: CliRunner, + pair: dict[str, str], + test_image: str, + ) -> None: + model_id = pair["model_id"] + task = pair["task"] + model_arg = _resolve_model_arg(model_id, task=task) + + # Step 1: Discover inputs via --schema + schema_result = runner.invoke( + run, + ["--model", model_arg, "--task", task, "--schema", "--format", "json"], + catch_exceptions=False, + ) + assert schema_result.exit_code == 0, ( + f"--schema failed (exit {schema_result.exit_code}):\n{schema_result.output}" + ) + schema = json.loads(schema_result.output) + + # Step 2: Build inference args + input_args = _build_inference_args(schema["inputs"], task, test_image) + if input_args is None: + pytest.xfail( + f"Cannot determine inputs for task '{task}' (empty schema, no fallback) — " + "extend _FALLBACK_INPUT_ARGS or _build_inference_args to cover this task" + ) + + # Step 3: Run inference + result = runner.invoke( + run, + ["--model", model_arg, "--task", task, "--format", "json", *input_args], + catch_exceptions=False, + ) + assert result.exit_code == 0, ( + f"Inference failed (exit {result.exit_code}):\n{result.output}" + ) + data = _extract_json(result.output) + assert "task" in data + assert "latency_ms" in data + assert data["latency_ms"] > 0 diff --git a/tests/e2e/test_serve_e2e.py b/tests/e2e/test_serve_e2e.py new file mode 100644 index 000000000..dbf35af3a --- /dev/null +++ b/tests/e2e/test_serve_e2e.py @@ -0,0 +1,636 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""E2E quality-gate tests for ``winml serve``. + +No mocks — real models, real HTTP requests, real predictions. Tests are +organized in four tiers by scope and cost: + +Tier 1 — **Endpoint feature gates** (2 fixed models) + Validates every REST endpoint: ``/v1/health``, ``/v1/predict``, + ``/v1/predict/file``, ``/v1/schema``, ``/v1/tools``, ``/v1/mcp-schema``, + ``/v1/hub``, ``/v1/logs``, ``/v1/resources``, ``/v1/ep``. + +Tier 2 — **Schema coverage** (all hub models) + ``GET /v1/schema`` for every ``(model_id, task)`` pair in + ``hub_models.json`` — lightweight, validates schema discovery via HTTP. + +Tier 3 — **Inference coverage** (all hub models) + Full ``POST /v1/predict`` or ``/v1/predict/file`` per hub model. + Cache-aware: prefers already-built directories to avoid slow rebuilds. + +Tier 4 — **Pipeline parameters** (fixed models) + Validates that ``params`` are forwarded and affect predictions + (e.g. ``top_k`` limits output length). + +Usage:: + + # All tiers + uv run pytest -m e2e tests/e2e/test_serve_e2e.py -v + + # Tier 1 only (fast regression) + uv run pytest -m e2e tests/e2e/test_serve_e2e.py -k "Feature" -v + + # Tier 2 only (schema) + uv run pytest -m e2e tests/e2e/test_serve_e2e.py -k "SchemaAll" -v + + # Tier 3 only (inference matrix) + uv run pytest -m e2e tests/e2e/test_serve_e2e.py -k "InferenceAll" -v + + # Filter by task or model name + uv run pytest -m e2e tests/e2e/test_serve_e2e.py -k "text_classification" -v + uv run pytest -m e2e tests/e2e/test_serve_e2e.py -k "finbert" -v + +Markers: + e2e: Full end-to-end test with real models + slow: Tests that take > 30 seconds + network: Requires network access to HuggingFace Hub +""" + +from __future__ import annotations + +import base64 +from pathlib import Path +from typing import Any + +import pytest +from fastapi.testclient import TestClient + +from winml.modelkit.serve import create_app + +from .conftest import HUB_PAIRS as _PAIRS +from .conftest import SAMPLE_TEXT as _SAMPLE_TEXT +from .conftest import TEXT_BY_FIELD as _TEXT_BY_FIELD +from .conftest import hub_test_id as _pytest_id +from .conftest import resolve_model_arg as _resolve_model_arg + + +pytestmark = [pytest.mark.e2e, pytest.mark.slow, pytest.mark.network, pytest.mark.timeout(3600)] + + +# --------------------------------------------------------------------------- +# Constants — fixed P0 models for Tier 1 / Tier 4 +# --------------------------------------------------------------------------- + +_IMAGE_HF_ID = "microsoft/resnet-18" +_TEXT_HF_ID = "prajjwal1/bert-tiny" + + +def _build_predict_body( + schema_inputs: list[dict], + task: str, + test_image_b64: str, +) -> dict[str, Any] | None: + """Build POST /v1/predict JSON body from schema discovery output. + + Returns ``None`` when no inputs can be determined (caller should skip). + """ + required = [i for i in schema_inputs if i.get("required", False)] + + if not required: + # sentence-similarity fallback + if task == "sentence-similarity": + return { + "inputs": {"text_1": _TEXT_BY_FIELD["text_1"], "text_2": _TEXT_BY_FIELD["text_2"]}, + } + return None + + inputs: dict[str, Any] = {} + + for field in required: + name = field["name"] + ftype = field["type"] + if ftype in ("image", "audio", "video"): + inputs[name] = test_image_b64 + elif ftype == "text": + inputs[name] = _TEXT_BY_FIELD.get(name, _SAMPLE_TEXT) + elif ftype == "json": + inputs[name] = ["positive", "negative", "neutral"] + + if not inputs: + return None + + return {"inputs": inputs} + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +# test_image fixture is inherited from conftest.py + + +@pytest.fixture(scope="module") +def test_image_bytes(test_image: str) -> bytes: + """Raw bytes of the test JPEG image.""" + return Path(test_image).read_bytes() + + +@pytest.fixture(scope="module") +def test_image_b64(test_image_bytes: bytes) -> str: + """Base64-encoded test image for JSON predict requests.""" + return base64.b64encode(test_image_bytes).decode() + + +@pytest.fixture(scope="module") +def image_model() -> str: + """Resolve resnet-18 to cache dir (fast) or HF ID (slow).""" + return _resolve_model_arg(_IMAGE_HF_ID, task="image-classification") + + +@pytest.fixture(scope="module") +def text_model() -> str: + """Resolve bert-tiny to cache dir (fast) or HF ID (slow).""" + return _resolve_model_arg(_TEXT_HF_ID, task="text-classification") + + +@pytest.fixture(scope="module") +def image_client(image_model: str): + """TestClient wrapping a Phase 1 serve app with an image model loaded.""" + app = create_app(model_path=image_model) + with TestClient(app) as client: + yield client + + +@pytest.fixture(scope="module") +def text_client(text_model: str): + """TestClient wrapping a Phase 1 serve app with a text model loaded.""" + app = create_app(model_path=text_model, task="text-classification") + with TestClient(app) as client: + yield client + + +# ===================================================================== +# Tier 1 — Feature gates: endpoint validation with fixed models +# ===================================================================== + + +class TestFeatureHealth: + """GET /v1/health — liveness and model metadata.""" + + def test_health_image_model(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ready" + assert data["version"] + assert data["mode"] == "single" + assert data["task"] == "image-classification" + assert data["uptime_sec"] >= 0 + + def test_health_text_model(self, text_client: TestClient) -> None: + resp = text_client.get("/v1/health") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ready" + assert data["task"] == "text-classification" + + +class TestFeaturePredictFile: + """POST /v1/predict/file — image classification via file upload.""" + + def test_upload_returns_predictions( + self, + image_client: TestClient, + test_image_bytes: bytes, + ) -> None: + resp = image_client.post( + "/v1/predict/file", + files={"file": ("test.jpg", test_image_bytes, "image/jpeg")}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["task"] == "image-classification" + assert isinstance(data["predictions"], list) + assert len(data["predictions"]) > 0 + pred = data["predictions"][0] + assert "label" in pred and "score" in pred + assert isinstance(pred["score"], float) + assert data["latency_ms"] > 0 + + def test_upload_with_model_id_underscore( + self, + image_client: TestClient, + test_image_bytes: bytes, + ) -> None: + """model_id="_" (default) should route to the only loaded model.""" + resp = image_client.post( + "/v1/predict/file", + files={"file": ("test.jpg", test_image_bytes, "image/jpeg")}, + data={"model_id": "_"}, + ) + assert resp.status_code == 200 + assert len(resp.json()["predictions"]) > 0 + + def test_upload_with_task_hint( + self, + image_client: TestClient, + test_image_bytes: bytes, + ) -> None: + resp = image_client.post( + "/v1/predict/file", + files={"file": ("test.jpg", test_image_bytes, "image/jpeg")}, + data={"task": "image-classification"}, + ) + assert resp.status_code == 200 + assert resp.json()["task"] == "image-classification" + + +class TestFeaturePredictJson: + """POST /v1/predict — JSON named inputs.""" + + def test_text_classification(self, text_client: TestClient) -> None: + resp = text_client.post( + "/v1/predict", + json={"inputs": {"text": "This product is amazing!"}}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["task"] == "text-classification" + assert isinstance(data["predictions"], list) + assert len(data["predictions"]) > 0 + assert data["latency_ms"] > 0 + + def test_image_classification_base64( + self, + image_client: TestClient, + test_image_b64: str, + ) -> None: + resp = image_client.post( + "/v1/predict", + json={"inputs": {"image": test_image_b64}}, + ) + assert resp.status_code == 200 + data = resp.json() + assert data["task"] == "image-classification" + assert len(data["predictions"]) > 0 + + def test_empty_inputs_returns_error(self, text_client: TestClient) -> None: + resp = text_client.post( + "/v1/predict", + json={"inputs": {}}, + ) + assert resp.status_code in (400, 422, 500) + + def test_missing_inputs_field_returns_422(self, text_client: TestClient) -> None: + resp = text_client.post("/v1/predict", json={}) + assert resp.status_code == 422 + + +class TestFeatureSchema: + """GET /v1/schema — request/response schema discovery.""" + + def test_schema_image_model(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/schema") + assert resp.status_code == 200 + data = resp.json() + assert data["task"] == "image-classification" + assert isinstance(data["user_inputs"], list) + assert len(data["user_inputs"]) > 0 + names = {inp["name"] for inp in data["user_inputs"]} + assert "image" in names + assert "endpoints" in data + + def test_schema_text_model(self, text_client: TestClient) -> None: + resp = text_client.get("/v1/schema") + assert resp.status_code == 200 + data = resp.json() + assert data["task"] == "text-classification" + names = {inp["name"] for inp in data["user_inputs"]} + assert "text" in names + + def test_schema_task_override(self, image_client: TestClient) -> None: + """?task= query param overrides schema resolution.""" + resp = image_client.get("/v1/schema?task=object-detection") + assert resp.status_code == 200 + assert resp.json()["task"] == "object-detection" + + +class TestFeatureTools: + """GET /v1/tools — OpenAI function-calling tool definitions.""" + + def test_tools_image_model(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/tools") + assert resp.status_code == 200 + data = resp.json() + assert "tools" in data + assert isinstance(data["tools"], list) + assert len(data["tools"]) > 0 + tool = data["tools"][0] + assert "type" in tool + assert tool["type"] == "function" + assert "function" in tool + fn = tool["function"] + assert "name" in fn + assert "parameters" in fn + + def test_tools_text_model(self, text_client: TestClient) -> None: + resp = text_client.get("/v1/tools") + assert resp.status_code == 200 + assert len(resp.json()["tools"]) > 0 + + +class TestFeatureMcpSchema: + """GET /v1/mcp-schema — MCP-compatible tool definitions.""" + + def test_mcp_schema(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/mcp-schema") + assert resp.status_code == 200 + data = resp.json() + assert "tools" in data + assert isinstance(data["tools"], list) + assert len(data["tools"]) > 0 + mcp_tool = data["tools"][0] + assert "name" in mcp_tool + assert "description" in mcp_tool + assert "inputSchema" in mcp_tool + assert "server_info" in data + assert data["server_info"]["name"] == "ModelKit Inference" + + +class TestFeatureHub: + """GET /v1/hub — model catalog.""" + + def test_hub_returns_models(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/hub") + assert resp.status_code == 200 + data = resp.json() + assert "version" in data + assert "models" in data + assert isinstance(data["models"], list) + assert len(data["models"]) > 0 + # Every entry must have model_id and task + for m in data["models"]: + assert "model_id" in m + assert "task" in m + assert "source" in m + + +class TestFeatureLogs: + """GET /v1/logs — ring buffer log polling.""" + + def test_logs_returns_structure(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/logs") + assert resp.status_code == 200 + data = resp.json() + assert "lines" in data + assert "latest_seq" in data + assert isinstance(data["lines"], list) + assert isinstance(data["latest_seq"], int) + + def test_logs_after_filter(self, image_client: TestClient) -> None: + """?after=N filters to lines with seq > N.""" + resp = image_client.get("/v1/logs?after=999999") + assert resp.status_code == 200 + assert resp.json()["lines"] == [] + + +class TestFeatureResources: + """GET /v1/resources — runtime memory + request stats.""" + + def test_resources(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/resources") + assert resp.status_code == 200 + data = resp.json() + assert data["status"] in ("ready", "loading", "unloaded") + assert "uptime_sec" in data + assert data["uptime_sec"] >= 0 + assert "memory_mb" in data + assert "request_count" in data + + +class TestFeatureEpSwitch: + """POST /v1/ep — switch execution provider.""" + + def test_switch_to_cpu(self, image_client: TestClient) -> None: + resp = image_client.post("/v1/ep", json={"ep": "cpu"}) + assert resp.status_code == 200 + data = resp.json() + assert data["status"] == "ok" + assert data["ep"] == "cpu" + + def test_switch_invalid_ep(self, image_client: TestClient) -> None: + resp = image_client.post("/v1/ep", json={"ep": "nonexistent"}) + assert resp.status_code == 422 + + +class TestFeatureModels: + """GET /v1/models — list loaded models.""" + + def test_list_models(self, image_client: TestClient) -> None: + resp = image_client.get("/v1/models") + assert resp.status_code == 200 + models = resp.json() + assert isinstance(models, list) + assert len(models) >= 1 + m = models[0] + assert "model_id" in m + assert "status" in m + assert m["status"] == "ready" + + +class TestFeatureOutputConsistency: + """Cross-endpoint consistency: predictions from /predict and /predict/file should match.""" + + def test_file_vs_json_same_predictions( + self, + image_client: TestClient, + test_image_bytes: bytes, + test_image_b64: str, + ) -> None: + """Both endpoints should return the same top-1 label for the same image.""" + resp_file = image_client.post( + "/v1/predict/file", + files={"file": ("test.jpg", test_image_bytes, "image/jpeg")}, + ) + resp_json = image_client.post( + "/v1/predict", + json={"inputs": {"image": test_image_b64}}, + ) + assert resp_file.status_code == 200 + assert resp_json.status_code == 200 + preds_file = resp_file.json()["predictions"] + preds_json = resp_json.json()["predictions"] + assert preds_file[0]["label"] == preds_json[0]["label"] + + +# ===================================================================== +# Tier 2 — Schema coverage (all hub models via HTTP) +# ===================================================================== + + +class TestSchemaAllModels: + """``GET /v1/schema`` for every hub model.""" + + @pytest.mark.parametrize("pair", _PAIRS, ids=[_pytest_id(p) for p in _PAIRS]) + def test_schema(self, pair: dict[str, str]) -> None: + model_arg = _resolve_model_arg(pair["model_id"], task=pair["task"]) + app = create_app(model_path=model_arg, task=pair["task"]) + with TestClient(app) as client: + resp = client.get("/v1/schema") + assert resp.status_code == 200, f"GET /v1/schema failed ({resp.status_code}):\n{resp.text}" + data = resp.json() + assert data["task"] == pair["task"] + assert isinstance(data["user_inputs"], list) + assert "endpoints" in data + + +# ===================================================================== +# Tier 3 — Inference coverage (all hub models, cache-aware) +# ===================================================================== + + +class TestInferenceAllModels: + """Full ``POST /v1/predict`` for every hub model. + + Flow per model: + 1. Create a serve app with the model + 2. ``GET /v1/schema`` → discover inputs + 3. ``POST /v1/predict`` → validate JSON response + """ + + @pytest.mark.parametrize("pair", _PAIRS, ids=[_pytest_id(p) for p in _PAIRS]) + def test_predict(self, pair: dict[str, str], test_image_b64: str) -> None: + model_id = pair["model_id"] + task = pair["task"] + model_arg = _resolve_model_arg(model_id, task=task) + + app = create_app(model_path=model_arg, task=task) + with TestClient(app) as client: + # Step 1: Discover inputs via GET /v1/schema + schema_resp = client.get("/v1/schema") + assert schema_resp.status_code == 200, ( + f"GET /v1/schema failed ({schema_resp.status_code}):\n{schema_resp.text}" + ) + schema = schema_resp.json() + + # Step 2: Build predict body + body = _build_predict_body(schema["user_inputs"], task, test_image_b64) + if body is None: + pytest.xfail( + f"Cannot determine inputs for task '{task}' (empty schema, no fallback) — " + "extend _build_predict_body to cover this task" + ) + + # Step 3: Run inference + resp = client.post("/v1/predict", json=body) + assert resp.status_code == 200, ( + f"POST /v1/predict failed ({resp.status_code}):\n{resp.text}" + ) + data = resp.json() + assert "task" in data + assert "latency_ms" in data + assert data["latency_ms"] > 0 + + # Step 4: Validate response after inference + health_resp = client.get("/v1/health") + assert health_resp.status_code == 200 + assert health_resp.json()["status"] == "ready" + + +# ===================================================================== +# Tier 4 — Pipeline parameters (fixed models, params forwarding) +# ===================================================================== + + +class TestPipelineParams: + """Validate that ``params`` in predict body are forwarded correctly.""" + + def test_top_k_limits_predictions(self, text_client: TestClient) -> None: + """top_k=1 should return at most 1 prediction.""" + resp = text_client.post( + "/v1/predict", + json={ + "inputs": {"text": "This is a great product!"}, + "params": {"top_k": 1}, + }, + ) + assert resp.status_code == 200 + preds = resp.json()["predictions"] + assert isinstance(preds, list) + assert len(preds) == 1 + + def test_top_k_3_returns_up_to_3(self, text_client: TestClient) -> None: + resp = text_client.post( + "/v1/predict", + json={ + "inputs": {"text": "The market is doing well today."}, + "params": {"top_k": 3}, + }, + ) + assert resp.status_code == 200 + preds = resp.json()["predictions"] + assert isinstance(preds, list) + assert len(preds) <= 3 + + def test_params_empty_dict_ok(self, text_client: TestClient) -> None: + """Empty params {} should work fine (default behavior).""" + resp = text_client.post( + "/v1/predict", + json={"inputs": {"text": "Hello world"}, "params": {}}, + ) + assert resp.status_code == 200 + assert resp.json()["latency_ms"] > 0 + + +# ===================================================================== +# Error handling — negative cases +# ===================================================================== + + +class TestErrorHandling: + """Validate proper error responses for malformed requests.""" + + def test_predict_file_no_file(self, image_client: TestClient) -> None: + """Missing file field → 422.""" + resp = image_client.post("/v1/predict/file") + assert resp.status_code == 422 + + def test_predict_file_too_large(self, image_client: TestClient) -> None: + """File > 20 MB → 413.""" + huge = b"\x00" * (21 * 1024 * 1024) + resp = image_client.post( + "/v1/predict/file", + files={"file": ("big.bin", huge, "application/octet-stream")}, + ) + assert resp.status_code == 413 + + def test_predict_json_inputs_params_collision(self, text_client: TestClient) -> None: + """Same key in inputs and params → 400.""" + resp = text_client.post( + "/v1/predict", + json={ + "inputs": {"text": "hi", "top_k": 5}, + "params": {"top_k": 3}, + }, + ) + assert resp.status_code == 400 + assert "top_k" in resp.json()["detail"] + + def test_predict_json_invalid_base64_image(self, image_client: TestClient) -> None: + """Non-base64 string for image field → 400.""" + resp = image_client.post( + "/v1/predict", + json={"inputs": {"image": "not-valid-base64!@#$"}}, + ) + assert resp.status_code == 400 + + def test_switch_ep_invalid(self, image_client: TestClient) -> None: + """Unknown EP name → 422.""" + resp = image_client.post("/v1/ep", json={"ep": "banana"}) + assert resp.status_code == 422 + + def test_unload_model_single_mode(self, image_client: TestClient) -> None: + """DELETE /v1/models/{id} in single mode → 400.""" + resp = image_client.delete("/v1/models/some-model") + assert resp.status_code == 400 + + def test_load_model_single_mode(self, image_client: TestClient) -> None: + """POST /v1/models in single mode → 400.""" + resp = image_client.post( + "/v1/models", + json={"model_id": "some/model"}, + ) + assert resp.status_code == 400 diff --git a/tests/unit/inference/test_engine.py b/tests/unit/inference/test_engine.py index 715f2fe54..59d8a7215 100644 --- a/tests/unit/inference/test_engine.py +++ b/tests/unit/inference/test_engine.py @@ -32,7 +32,9 @@ from winml.modelkit.inference.engine import ( _build_param_entry, _discover_pipeline_params_from_task, + _find_build_artifacts, _pick_sample_value, + _sanitize_numpy, ) @@ -318,6 +320,7 @@ def test_build_dir_reads_manifest(self, tmp_path: Any) -> None: manifest = {"model_id": "test/model", "task": "text-classification"} (tmp_path / "build_manifest.json").write_text(json.dumps(manifest)) + (tmp_path / "model.onnx").write_bytes(b"fake") engine = InferenceEngine() engine.load_schema_only(tmp_path) assert engine._task == "text-classification" @@ -330,6 +333,178 @@ def test_task_param_overrides_manifest(self, tmp_path: Any) -> None: manifest = {"model_id": "test/model", "task": "text-classification"} (tmp_path / "build_manifest.json").write_text(json.dumps(manifest)) + (tmp_path / "model.onnx").write_bytes(b"fake") engine = InferenceEngine() engine.load_schema_only(tmp_path, task="image-classification") assert engine._task == "image-classification" + + +# --------------------------------------------------------------------------- +# _sanitize_numpy +# --------------------------------------------------------------------------- + + +class TestSanitizeNumpy: + """Ensure numpy scalars are converted to Python types for JSON serialization.""" + + def test_float32_to_float(self) -> None: + import numpy as np + + result = _sanitize_numpy({"score": np.float32(0.95)}) + assert isinstance(result["score"], float) + assert abs(result["score"] - 0.95) < 0.001 + + def test_int64_to_int(self) -> None: + import numpy as np + + result = _sanitize_numpy({"start": np.int64(10)}) + assert isinstance(result["start"], int) + assert result["start"] == 10 + + def test_nested_dict(self) -> None: + import numpy as np + + ner_output = { + "entity_group": "PER", + "score": np.float32(0.998), + "word": "John", + "start": np.int64(0), + "end": np.int64(4), + } + result = _sanitize_numpy(ner_output) + assert isinstance(result["score"], float) + assert isinstance(result["start"], int) + assert isinstance(result["end"], int) + assert result["entity_group"] == "PER" + + def test_list_of_dicts(self) -> None: + import numpy as np + + raw = [ + {"entity_group": "PER", "score": np.float32(0.99)}, + {"entity_group": "LOC", "score": np.float32(0.85)}, + ] + result = _sanitize_numpy(raw) + assert all(isinstance(d["score"], float) for d in result) + + def test_plain_types_unchanged(self) -> None: + result = _sanitize_numpy({"label": "cat", "score": 0.95, "count": 5}) + assert result == {"label": "cat", "score": 0.95, "count": 5} + + def test_ndarray_to_list(self) -> None: + import numpy as np + + result = _sanitize_numpy({"embedding": np.array([1.0, 2.0, 3.0])}) + assert result["embedding"] == [1.0, 2.0, 3.0] + + +# --------------------------------------------------------------------------- +# _find_build_artifacts +# --------------------------------------------------------------------------- + + +class TestFindBuildArtifacts: + def test_plain_layout(self, tmp_path: Any) -> None: + import json + + (tmp_path / "model.onnx").write_bytes(b"fake") + manifest = {"model_id": "test/model", "task": "text-classification"} + (tmp_path / "build_manifest.json").write_text(json.dumps(manifest)) + onnx_path, m = _find_build_artifacts(tmp_path) + assert onnx_path.name == "model.onnx" + assert m["task"] == "text-classification" + + def test_prefixed_layout(self, tmp_path: Any) -> None: + import json + + (tmp_path / "txtcls_abc123_model.onnx").write_bytes(b"fake") + manifest = {"model_id": "test/model", "task": "text-classification"} + (tmp_path / "txtcls_abc123_build_manifest.json").write_text(json.dumps(manifest)) + onnx_path, m = _find_build_artifacts(tmp_path) + assert onnx_path.name == "txtcls_abc123_model.onnx" + assert m["task"] == "text-classification" + + def test_no_onnx_raises(self, tmp_path: Any) -> None: + import pytest + + with pytest.raises(FileNotFoundError): + _find_build_artifacts(tmp_path) + + def test_onnx_without_manifest(self, tmp_path: Any) -> None: + (tmp_path / "model.onnx").write_bytes(b"fake") + onnx_path, m = _find_build_artifacts(tmp_path) + assert onnx_path.name == "model.onnx" + assert m is None + + def test_task_filter_selects_matching(self, tmp_path: Any) -> None: + """When task= is specified, only return artifacts whose manifest matches.""" + import json + + # Two variants in same directory + (tmp_path / "feat_aaa_model.onnx").write_bytes(b"fake-feat") + (tmp_path / "feat_aaa_build_manifest.json").write_text( + json.dumps({"model_id": "m", "task": "feature-extraction"}) + ) + (tmp_path / "txtcls_bbb_model.onnx").write_bytes(b"fake-txtcls") + (tmp_path / "txtcls_bbb_build_manifest.json").write_text( + json.dumps({"model_id": "m", "task": "text-classification"}) + ) + + onnx_path, m = _find_build_artifacts(tmp_path, task="text-classification") + assert "txtcls" in onnx_path.name + assert m["task"] == "text-classification" + + def test_task_filter_no_match_raises(self, tmp_path: Any) -> None: + """When task= doesn't match any manifest, raise FileNotFoundError.""" + import json + + import pytest + + (tmp_path / "feat_aaa_model.onnx").write_bytes(b"fake") + (tmp_path / "feat_aaa_build_manifest.json").write_text( + json.dumps({"model_id": "m", "task": "feature-extraction"}) + ) + + with pytest.raises(FileNotFoundError): + _find_build_artifacts(tmp_path, task="text-classification") + + def test_task_none_returns_first(self, tmp_path: Any) -> None: + """Without task filter, return the first candidate.""" + import json + + (tmp_path / "feat_aaa_model.onnx").write_bytes(b"fake") + (tmp_path / "feat_aaa_build_manifest.json").write_text( + json.dumps({"model_id": "m", "task": "feature-extraction"}) + ) + + onnx_path, _manifest = _find_build_artifacts(tmp_path, task=None) + assert onnx_path.exists() + + +# --------------------------------------------------------------------------- +# _normalize_pipeline_output sanitizes NER numpy types +# --------------------------------------------------------------------------- + + +class TestNormalizeNEROutput: + def test_ner_output_numpy_sanitized(self) -> None: + """NER pipeline output with numpy.float32 scores should be serializable.""" + import numpy as np + + engine = InferenceEngine() + engine._task = "token-classification" + # NER-like output: list of dicts with numpy scalars + raw = [ + { + "entity_group": "PER", + "score": np.float32(0.998), + "word": "John", + "start": np.int64(0), + "end": np.int64(4), + }, + ] + result = engine._normalize_pipeline_output(raw) + # Should be serializable — all numpy types converted + import json + + json.dumps(result) # Must not raise