Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/winml/modelkit/commands/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from __future__ import annotations

import contextlib
import json
import logging
import sys
Expand Down Expand Up @@ -622,7 +623,11 @@ def run(
return

try:
engine.load(model, task=task, device=device, ep=ep)
# Redirect stdout → stderr during model load so that build-pipeline
# prints (from optimum, onnxruntime, etc.) don't contaminate
# structured output (--format json) or text output parsing.
with contextlib.redirect_stdout(sys.stderr):
Comment thread
DingmaomaoBJTU marked this conversation as resolved.
engine.load(model, task=task, device=device, ep=ep)
except (OSError, ValueError, RuntimeError) as exc:
click.echo(f"Error loading model: {exc}", err=True)
ctx.exit(3)
Expand Down
136 changes: 122 additions & 14 deletions src/winml/modelkit/inference/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,87 @@ def _decode_image(data: bytes) -> Any:
}


def _sanitize_numpy(obj: Any) -> Any:
    """Recursively convert numpy values to native Python types.

    HF pipelines (e.g. NER/TokenClassification) return dicts containing
    ``numpy.float32`` scores and ``numpy.int64`` offsets that pydantic
    cannot serialize. This function walks dicts, lists and plain tuples and
    converts every numpy scalar or array it finds.

    Args:
        obj: Arbitrary pipeline output (possibly nested).

    Returns:
        A structure of the same shape where numpy booleans, floats and
        integers are replaced by ``bool``/``float``/``int`` and arrays by
        nested lists. Anything else is returned unchanged.
    """
    import numpy as np

    if isinstance(obj, dict):
        return {k: _sanitize_numpy(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [_sanitize_numpy(v) for v in obj]
    # Exact-type check on purpose: tuple subclasses (e.g. namedtuples) are
    # passed through untouched so their type and field access are preserved.
    if type(obj) is tuple:
        return tuple(_sanitize_numpy(v) for v in obj)
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj


# ---------------------------------------------------------------------------
# Build-directory artifact discovery
# ---------------------------------------------------------------------------


def _find_build_artifacts(build_dir: Path, *, task: str | None = None) -> tuple[Path, dict | None]:
    """Locate model.onnx and build_manifest.json inside a build/cache directory.

    Supports both plain layout (``model.onnx``) and cache-key-prefixed layout
    (``{cache_key}_model.onnx``) so that ``InferenceEngine.load()`` can load
    directly from ``~/.cache/winml/artifacts/{slug}/`` without a full rebuild.

    When *task* is provided, an artifact is accepted only if its manifest
    ``task`` matches, or if it has no manifest at all (nothing to verify
    against). A cache directory may contain multiple task variants (e.g.
    ``feat_*`` for feature-extraction and ``txtcls_*`` for text-classification);
    using the wrong ONNX (different output head) would produce garbage results.
    Among prefixed variants, one whose manifest confirms *task* is always
    preferred over a manifest-less one, regardless of file-name order.

    Args:
        build_dir: Directory to search (not recursive).
        task: Task name to match against each manifest, or ``None`` to accept
            any task as long as the directory is unambiguous.

    Returns:
        ``(onnx_path, manifest_dict)`` — manifest is ``None`` when no
        manifest file is found.

    Raises:
        FileNotFoundError: if no matching ``*model.onnx`` is found, or if
            *task* is ``None`` and the directory holds several different tasks.
        json.JSONDecodeError: if a manifest that had to be read is not valid JSON.
    """

    def _read_manifest(manifest_path: Path) -> dict | None:
        # A missing manifest is not an error; the artifact is just unverified.
        if not manifest_path.exists():
            return None
        return json.loads(manifest_path.read_text())

    # Try plain layout first (bare build output)
    plain_onnx = build_dir / "model.onnx"
    if plain_onnx.exists():
        manifest = _read_manifest(build_dir / "build_manifest.json")
        if task is None or manifest is None or manifest.get("task") == task:
            return plain_onnx, manifest

    # Cache-key-prefixed layout: {cache_key}_model.onnx
    # Scan all candidates (sorted for determinism) and match by task when specified.
    candidates: list[tuple[Path, dict | None]] = []
    unverified: tuple[Path, dict | None] | None = None
    for onnx_path in sorted(build_dir.glob("*_model.onnx")):
        prefix = onnx_path.name.rsplit("_model.onnx", 1)[0]
        manifest = _read_manifest(build_dir / f"{prefix}_build_manifest.json")
        if task is None:
            candidates.append((onnx_path, manifest))
        elif manifest is None:
            # Remember the first manifest-less variant, but keep scanning: a
            # later variant may carry a manifest that confirms the task.
            if unverified is None:
                unverified = (onnx_path, None)
        elif manifest.get("task") == task:
            return onnx_path, manifest

    if unverified is not None:
        # No variant confirmed the task; fall back to the unverifiable one.
        return unverified

    if candidates:
        # task=None: if all variants share the same task, pick the first.
        # If multiple different tasks exist, raise to force the caller to specify.
        tasks = {m.get("task") for _, m in candidates if m is not None}
        if len(tasks) > 1:
            raise FileNotFoundError(
                f"Multiple task variants found in {build_dir}: {tasks}. "
                "Pass task= to select the correct ONNX model."
            )
        return candidates[0]

    raise FileNotFoundError(f"No model.onnx matching task={task!r} found in {build_dir}")


# ---------------------------------------------------------------------------
# Lightweight schema helpers (no model load required)
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -233,7 +314,24 @@ def load(
path = Path(model_path)

if path.is_dir():
self._load_from_build_dir(path, task=task, device=device, ep=ep)
try:
self._load_from_build_dir(path, task=task, device=device, ep=ep)
except (FileNotFoundError, json.JSONDecodeError, KeyError):
# No cached ONNX for this task (or corrupt manifest) — check
# if the manifest has a model_id we can rebuild from (e.g.
# cache was built for a different task like feature-extraction
# but caller wants text-classification).
model_id = self._resolve_model_id_from_dir(path)
if model_id:
logger.info(
"No cached ONNX for task=%s in %s — rebuilding from %s",
task,
path,
model_id,
)
self._load_from_hf(model_id, task=task, device=device, ep=ep)
else:
raise
elif path.suffix == ".onnx" and path.exists():
self._load_from_onnx(path, task=task, device=device, ep=ep)
else:
Expand Down Expand Up @@ -279,9 +377,11 @@ def load_schema_only(

if path.is_dir():
# Build dir: read manifest for task + model_id (no ORT session)
manifest_path = path / "build_manifest.json"
if manifest_path.exists():
manifest = json.loads(manifest_path.read_text())
try:
_, manifest = _find_build_artifacts(path, task=task)
except FileNotFoundError:
manifest = None
if manifest is not None:
self._model_id = manifest.get("model_id")
self._task = task or manifest.get("task")
else:
Expand Down Expand Up @@ -755,11 +855,14 @@ def _normalize_pipeline_output(
)
for item in raw
]
# Non-classification list of dicts (e.g. text-generation, NER)
return raw[0] if len(raw) == 1 else {"results": raw}
# Non-classification list of dicts (e.g. text-generation, NER).
# Sanitize numpy scalars so pydantic/JSON serialization works
# (NER pipelines return np.float32 scores).
result = raw[0] if len(raw) == 1 else {"results": raw}
return _sanitize_numpy(result)
# Other tasks: return as-is dict
if isinstance(raw, dict):
return raw
return _sanitize_numpy(raw)
# Fallback
return {"raw": str(raw)}

Expand Down Expand Up @@ -807,15 +910,10 @@ def _load_from_build_dir(
device: str,
ep: str | None,
) -> None:
manifest_path = build_dir / "build_manifest.json"
onnx_path = build_dir / "model.onnx"

if not onnx_path.exists():
raise FileNotFoundError(f"model.onnx not found in {build_dir}")
onnx_path, manifest = _find_build_artifacts(build_dir, task=task)

model_id: str | None = None
if manifest_path.exists():
manifest = json.loads(manifest_path.read_text())
if manifest is not None:
model_id = manifest.get("model_id")
task = task or manifest.get("task")

Expand All @@ -832,6 +930,16 @@ def _load_from_build_dir(

logger.info("Loaded from build dir: task=%s model_id=%s", task, model_id)

@staticmethod
def _resolve_model_id_from_dir(build_dir: Path) -> str | None:
    """Extract model_id from any manifest in the directory (task-agnostic).

    This is a best-effort lookup used as a fallback after loading from the
    directory has already failed, so a manifest that cannot be read or parsed
    is skipped instead of raising a second error from inside that fallback.

    Args:
        build_dir: Directory whose ``*build_manifest.json`` files are inspected
            (plain and cache-key-prefixed names both match the glob).

    Returns:
        The first non-empty ``model_id`` found, scanning manifests in sorted
        file-name order, or ``None`` when no manifest provides one.
    """
    import logging

    # Sorted so the result does not depend on filesystem enumeration order.
    for manifest_path in sorted(build_dir.glob("*build_manifest.json")):
        try:
            manifest = json.loads(manifest_path.read_text())
        except (OSError, ValueError) as exc:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.getLogger(__name__).debug(
                "Skipping unreadable manifest %s: %s", manifest_path, exc
            )
            continue
        if not isinstance(manifest, dict):
            # Valid JSON but not an object; there is no model_id to read.
            continue
        model_id = manifest.get("model_id")
        if model_id:
            return model_id
    return None

def _load_from_onnx(
self,
onnx_path: Path,
Expand Down
124 changes: 120 additions & 4 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Shared fixtures for E2E tests.
"""Shared fixtures and helpers for E2E tests.

These fixtures generate real ONNX files on-the-fly and provide
model-task combination parameters for parametrized tests.
Provides:
- Hub model parametrization helpers (``HUB_PAIRS``, ``hub_test_id``)
- Cache-aware model resolution (``find_cache_dir``, ``resolve_model_arg``)
- Common sample text inputs (``SAMPLE_TEXT``, ``TEXT_BY_FIELD``)
- Shared fixtures (``test_image``, ``runner``)
- Auto-skip for ``-m e2e``

E2E tests are auto-skipped unless explicitly selected with:
uv run pytest -m e2e
Expand All @@ -14,16 +18,128 @@
from __future__ import annotations

import json
from pathlib import Path
from typing import TYPE_CHECKING

import numpy as np
import onnx
import pytest
from onnx import TensorProto, helper
from PIL import Image


if TYPE_CHECKING:
from pathlib import Path
from click.testing import CliRunner


# ---------------------------------------------------------------------------
# Hub model parametrization (single source of truth)
# ---------------------------------------------------------------------------

# Path to the hub model list, resolved relative to this file: two directories
# above it, then ``src/winml/modelkit/data/hub_models.json``.
_HUB_JSON = (
Path(__file__).resolve().parents[2] / "src" / "winml" / "modelkit" / "data" / "hub_models.json"
)
# Parsed once at import time; a missing or invalid file fails collection of
# this conftest rather than individual tests.
HUB_DATA: dict = json.loads(_HUB_JSON.read_text(encoding="utf-8"))


def _unique_pairs() -> list[dict[str, str]]:
    """Return one ``{"model_id", "task"}`` dict per distinct pair in ``HUB_DATA``.

    Entries are visited in file order and only the first occurrence of each
    ``(model_id, task)`` combination is kept.
    """
    # dicts preserve insertion order, so setdefault keeps first occurrences
    # in their original sequence.
    first_seen: dict[tuple[str, str], dict[str, str]] = {}
    for record in HUB_DATA["models"]:
        model_id = record["model_id"]
        task = record["task"]
        first_seen.setdefault((model_id, task), {"model_id": model_id, "task": task})
    return list(first_seen.values())


# Deduplicated (model_id, task) pairs, computed once at import time from HUB_DATA.
HUB_PAIRS: list[dict[str, str]] = _unique_pairs()


def hub_test_id(pair: dict[str, str]) -> str:
    """Build a readable pytest ID such as ``finbert-text_classification``.

    The ID is the part of ``model_id`` after the last ``/`` (the whole ID when
    there is no slash), a hyphen, and the task with hyphens turned into
    underscores.
    """
    _, _, model_name = pair["model_id"].rpartition("/")
    task_slug = pair["task"].replace("-", "_")
    return "-".join((model_name, task_slug))


# ---------------------------------------------------------------------------
# Cache-aware model resolution
# ---------------------------------------------------------------------------


def find_cache_dir(model_id: str, task: str | None = None) -> Path | None:
    """Return the winml build-cache directory usable for a model+task, or None.

    The candidate is ``<cache root>/artifacts/<slug>/``. It is returned only
    when it exists and ``_find_build_artifacts`` can locate an ONNX artifact
    in it for *task*; a missing artifact or an unparsable manifest yields
    ``None``.
    """
    from winml.modelkit.cache import get_cache_dir, model_id_to_slug
    from winml.modelkit.inference.engine import _find_build_artifacts

    slug = model_id_to_slug(model_id)
    candidate = get_cache_dir() / "artifacts" / slug
    if candidate.is_dir():
        try:
            _find_build_artifacts(candidate, task=task)
        except (FileNotFoundError, json.JSONDecodeError):
            # Nothing usable for this task in the cache.
            pass
        else:
            return candidate
    return None


def resolve_model_arg(model_id: str, task: str | None = None) -> str:
    """Pick the model argument: cached build directory if present, else the model ID.

    A cache hit gives the fast path (path to the cache directory as a string);
    otherwise *model_id* is returned unchanged, which means a slow rebuild.
    """
    cached = find_cache_dir(model_id, task=task)
    return model_id if cached is None else str(cached)


# ---------------------------------------------------------------------------
# Common sample inputs
# ---------------------------------------------------------------------------

# Generic single-sentence text input shared by the E2E tests.
SAMPLE_TEXT = "The quick brown fox jumps over the lazy dog."

# Sample strings keyed by input field name.
# NOTE(review): presumably these keys mirror the input fields of multi-input
# tasks (question answering, sentence pairs) — confirm against the tests that
# consume this mapping.
TEXT_BY_FIELD: dict[str, str] = {
"question": "What is the capital of France?",
"context": (
"Paris is the capital of France. "
"It is known for the Eiffel Tower and its rich cultural heritage."
),
"text_1": "A man is eating food.",
"text_2": "A man is eating a piece of bread.",
}


# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------


@pytest.fixture(scope="module")
def test_image(tmp_path_factory: pytest.TempPathFactory) -> str:
    """Write a seeded-random 224x224 RGB JPEG and return its path as a string.

    Module-scoped, so the file is created once and shared by every test in
    the module. The fixed seed keeps the pixel data reproducible.
    """
    asset_dir = tmp_path_factory.mktemp("run_e2e_assets")
    target = str(asset_dir / "test_image.jpg")
    pixels = np.random.RandomState(42).randint(0, 255, (224, 224, 3), dtype=np.uint8)
    Image.fromarray(pixels).save(target, format="JPEG")
    return target


@pytest.fixture
def runner() -> CliRunner:
    """Provide a fresh Click ``CliRunner`` to each test that requests it."""
    # Imported lazily, like the original, so click is only needed when the
    # fixture is actually used.
    from click import testing as click_testing

    return click_testing.CliRunner()


# ---------------------------------------------------------------------------
# Auto-skip E2E
# ---------------------------------------------------------------------------


def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
Expand Down
Loading
Loading