Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
449 changes: 419 additions & 30 deletions README.md

Large diffs are not rendered by default.

608 changes: 608 additions & 0 deletions docs/assets/workflow-only.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
40 changes: 34 additions & 6 deletions src/winml/modelkit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,13 @@
model = WinMLAutoModel.from_pretrained("facebook/convnext-tiny-224", config=config)
"""

import logging
from importlib.metadata import PackageNotFoundError, version


logging.getLogger(__name__).addHandler(logging.NullHandler())

from . import _warnings # Configure warning filters before importing subpackages
from .config import WinMLBuildConfig
from .models import (
WinMLAutoModel,
WinMLModelForImageClassification,
WinMLPreTrainedModel,
)


try:
Expand All @@ -51,3 +49,33 @@
"WinMLPreTrainedModel",
"__version__",
]


# Table of lazily exported public names. Each key is the attribute name exposed
# on this package; each value is ``(relative module path, attribute name)``.
# The module-level ``__getattr__`` looks names up here and imports the listed
# module only on first access.
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"WinMLBuildConfig": (".config", "WinMLBuildConfig"),
"WinMLAutoModel": (".models", "WinMLAutoModel"),
"WinMLPreTrainedModel": (".models", "WinMLPreTrainedModel"),
"WinMLModelForImageClassification": (".models", "WinMLModelForImageClassification"),
}


def __getattr__(name: str):
    """Resolve lazily exported names on first attribute access (PEP 562).

    Names listed in ``_LAZY_IMPORTS`` are imported from their submodule only
    when first requested. This keeps the torch/transformers/optimum imports
    (~30s) out of lightweight operations such as ``winml --help``.

    The resolved object is stored in the module globals, so subsequent
    lookups find it directly and no longer reach this hook.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    target = _LAZY_IMPORTS.get(name)
    if target is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Imported here so the hook itself adds nothing to package import time.
    from importlib import import_module

    submodule, attribute = target
    resolved = getattr(import_module(submodule, __name__), attribute)
    globals()[name] = resolved  # cache for later accesses
    return resolved


def __dir__() -> list[str]:
    """Return the module's attribute names for ``dir()``.

    The names in ``__all__`` are merged with the current globals so that
    exports which have not been lazily loaded yet still show up in
    debuggers and IPython completion. Duplicates are removed; the order
    of the returned list is unspecified.
    """
    names = set(globals())
    names.update(__all__)
    return list(names)
41 changes: 22 additions & 19 deletions src/winml/modelkit/_warnings.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,38 +44,41 @@ class _DiffusersDistributionFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
    """Reject log records whose message mentions "Multiple distributions found".

    Returns:
        ``False`` (drop the record) when the formatted message contains the
        phrase, ``True`` (keep it) otherwise.
    """
    message = record.getMessage()
    return "Multiple distributions found" not in message

logging.getLogger("diffusers.utils.import_utils").addFilter(
_DiffusersDistributionFilter()
)
logging.getLogger("diffusers.utils.import_utils").addFilter(_DiffusersDistributionFilter())

class _HFPipelineFalsePositiveFilter(logging.Filter):
"""Filter false-positive HF pipeline warnings when using WinML models.
class _PipelineNoiseFilter(logging.Filter):
"""Filter noisy HF Pipeline warnings.

HF pipeline emits these because WinMLModel wraps ONNX via ORT, not a
native HF model class. These are expected and not actionable.
- 'The model X is not supported for Y' — WinML models are duck-type
compatible but not in HF's supported list.
- 'Device set to use cpu' — HF Pipeline forces CPU, we handle device.
- 'Using a slow image processor' — cosmetic deprecation notice.
"""

_FALSE_POSITIVES = (
"WinMLModel", # False positive warning which says WinML is not native HF model class
"Device set to use", # PyTorch tensor device, not ONNX device
"Using a slow image processor", # expected when using processor with pipeline.
_SUPPRESSED = (
"is not supported for",
"Device set to use cpu",
"Using a slow image processor",
)

def filter(self, record: logging.LogRecord) -> bool:
msg = record.getMessage()
return not any(phrase in msg for phrase in self._FALSE_POSITIVES)
return not any(s in msg for s in self._SUPPRESSED)

for _name in (
"transformers.pipelines.base",
"transformers.models.auto.image_processing_auto",
):
logging.getLogger(_name).addFilter(_HFPipelineFalsePositiveFilter())
logging.getLogger("transformers.pipelines.base").addFilter(_PipelineNoiseFilter())

# =========================================================================
# Warning filters (for warnings.warn() calls)
# =========================================================================
warnings.filterwarnings("ignore", category=FutureWarning, module=r"transformers.*")
warnings.filterwarnings("ignore", category=UserWarning, module=r"torch.*")
# Transformers: suppress cosmetic warnings (not RuntimeWarning/ResourceWarning)
for _cat in (FutureWarning, DeprecationWarning, UserWarning):
warnings.filterwarnings("ignore", category=_cat, module=r"transformers\..*")

# PyTorch: suppress cosmetic warnings (not RuntimeWarning/ResourceWarning)
for _cat in (FutureWarning, DeprecationWarning, UserWarning):
warnings.filterwarnings("ignore", category=_cat, module=r"torch\..*")

# Diffusers
warnings.filterwarnings(
"ignore", message=r".*CUDA.*", category=UserWarning, module=r"diffusers.*"
)
Expand Down
19 changes: 18 additions & 1 deletion src/winml/modelkit/analyze/analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@


if TYPE_CHECKING:
from collections.abc import Callable

import onnx

from .models.information import Action
Expand Down Expand Up @@ -492,6 +494,8 @@ def analyze(
htp_metadata_path: str | None = None,
run_unknown_op: bool = True,
save_node_types: set[str] | None = None,
on_node_result: Callable | None = None,
on_ep_start: Callable | None = None,
) -> AnalysisResult:
"""Analyze ONNX model for runtime support.

Expand Down Expand Up @@ -590,6 +594,8 @@ def analyze(
htp_metadata_path=htp_metadata_path,
run_unknown_op=run_unknown_op,
save_node_types=save_node_types,
on_node_result=on_node_result,
on_ep_start=on_ep_start,
)

def analyze_from_proto(
Expand All @@ -602,6 +608,8 @@ def analyze_from_proto(
htp_metadata_path: str | None = None,
run_unknown_op: bool = True,
save_node_types: set[str] | None = None,
on_node_result: Callable | None = None,
on_ep_start: Callable | None = None,
) -> AnalysisResult:
"""Analyze ONNX model from ModelProto object.

Expand Down Expand Up @@ -691,6 +699,11 @@ def analyze_from_proto(

for current_ep in eps_to_analyze:
logger.info("Checking runtime support for %s...", current_ep)
if on_ep_start:
try:
on_ep_start(current_ep, metadata.operator_counts)
except Exception:
logger.debug("on_ep_start callback failed", exc_info=True)

runtime_checker = RuntimeChecker(
ep=current_ep,
Expand All @@ -708,6 +721,7 @@ def analyze_from_proto(
patterns=pattern_matches,
run_unknown_op=run_unknown_op_for_ep,
save_node_types=save_node_types,
on_node_result=on_node_result,
)

# Convert runtime summary to expected format
Expand All @@ -727,7 +741,6 @@ def analyze_from_proto(
ep=current_ep,
model=onnx_model,
device=device_to_use,
shape_inferred_model_proto=runtime_checker.get_shape_inferred_model_proto(),
)
information_list[current_ep] = engine.summary() # Use EP name as key

Expand Down Expand Up @@ -786,6 +799,8 @@ def analyze_onnx(
ep: str | None = None,
device: str | None = None,
autoconf: bool = True,
on_ep_start: Callable | None = None,
on_node_result: Callable | None = None,
) -> AnalyzeResult:
"""Analyze an ONNX model and return lint + autoconf results.

Expand Down Expand Up @@ -841,6 +856,8 @@ def analyze_onnx(
ep=ep,
device=device,
enable_information=autoconf,
on_ep_start=on_ep_start,
on_node_result=on_node_result,
)

# Extract lint result (always computed — uses RuntimeChecker classification)
Expand Down
52 changes: 38 additions & 14 deletions src/winml/modelkit/analyze/core/runtime_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@


if TYPE_CHECKING:
from collections.abc import Callable

import onnx

from winml.modelkit.pattern.match import PatternMatchResult
Expand Down Expand Up @@ -142,21 +144,35 @@ def _get_query(self) -> RuntimeCheckerQuery:

return self._query

def get_shape_inferred_model_proto(self) -> onnx.ModelProto | None:
    """Expose the model proto held by the cached query.

    Returns:
        ``self._query.model_proto`` when a query has already been cached on
        this instance, otherwise ``None``.
    """
    cached_query = self._query
    return None if cached_query is None else cached_query.model_proto

def op_support(
self,
run_unknown_op: bool = True,
save_node_types: set[str] | None = None,
on_node_result: Callable | None = None,
) -> list[PatternRuntime]:
"""Check operator-level runtime support.

Returns operator-level runtime check results for each operator.

Args:
on_node_result: Optional per-node progress callback.
When provided, tqdm progress bar is suppressed (caller
handles progress display via Rich Live).

Signature::

(result: PatternRuntime) -> None

The ``PatternRuntime`` passed to the callback has:

- ``pattern_id`` (str): Full pattern ID, e.g.
``"OP/ai.onnx/Conv"``. Use ``split("/")[-1]`` to get
the display name (``"Conv"``).
- ``result.classification`` (SupportLevel): The support
level enum. Call ``.value`` to get the string, e.g.
``"supported"``, ``"partial"``, ``"unsupported"``,
``"unknown"``.

Returns:
List[PatternRuntime]: Runtime results for each operator pattern

Expand All @@ -177,15 +193,21 @@ def op_support(
model_proto = self._model.get_model()
# Get cached RuntimeCheckerQuery
query = self._get_query()
for node in tqdm.tqdm(model_proto.graph.node):
# Run runtime check for node
results.append( # noqa: PERF401
query.run_for_node(
node,
run_unknown_op=run_unknown_op,
save_node_types=save_node_types,
)
# Use tqdm for progress unless caller provides a callback
nodes = model_proto.graph.node
iterator = nodes if on_node_result else tqdm.tqdm(nodes)
for node in iterator:
result = query.run_for_node(
node,
run_unknown_op=run_unknown_op,
save_node_types=save_node_types,
)
results.append(result)
if on_node_result:
try:
on_node_result(result)
except Exception:
logger.debug("on_node_result callback failed", exc_info=True)

logger.info("Checked %d operators", len(results))

Expand Down Expand Up @@ -302,6 +324,7 @@ def summary(
patterns: list[PatternMatchResult] | None = None,
run_unknown_op: bool = True,
save_node_types: set[str] | None = None,
on_node_result: Callable | None = None,
) -> dict[str, list[PatternRuntime]]:
"""Combine operator-level & pattern-level runtime results.

Expand All @@ -325,6 +348,7 @@ def summary(
op_results = self.op_support(
run_unknown_op=run_unknown_op,
save_node_types=save_node_types,
on_node_result=on_node_result,
)
summary_dict["op_runtime_check_result"] = op_results

Expand Down
Loading
Loading