Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions src/winml/modelkit/build/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def build_hf_model(
cache_key: str | None = None,
ep: EPNameOrAlias | None = None,
device: str | None = None,
model_type: str | None = None,
**kwargs: Any,
) -> BuildResult:
"""Build an ONNX model from a HuggingFace model architecture.
Expand Down Expand Up @@ -211,6 +212,7 @@ def _name(base: str) -> str:
model_id,
trust_remote_code,
random_init=random_init,
model_type=model_type,
)

# =========================================================================
Expand Down Expand Up @@ -315,6 +317,31 @@ def _name(base: str) -> str:
else:
logger.info("Quantizing model...")
t0 = time.monotonic()
# Some model types finalize their quant config only once the
# exported ONNX exists (calibration feeds / nodes-to-exclude derived
# from the graph). Resolve the model-type-specific quant policy from
# the quant registry, keyed on the live ``model_type``. Unregistered
# types return None → the quantizer uses its standard task-aware
# DatasetCalibrationReader.
from ..quant import get_quant_finalizer

resolved_model_type = (
getattr(getattr(pytorch_model, "config", None), "model_type", None) or model_type
)
quant_finalizer = get_quant_finalizer(resolved_model_type)
if quant_finalizer is not None:
# Generic id fallback: the policy loads a fresh reference model
# for calibration, so feed it the best-known HF id/path.
resolved_model_id = model_id or getattr(
getattr(pytorch_model, "config", None), "_name_or_path", None
)
config.quant = quant_finalizer.finalize(
config.quant, onnx_path=current_path, model_id=resolved_model_id
)
# The policy may overwrite the quant scheme (dtypes, symmetry,
# nodes-to-exclude) authoritatively, so re-persist the config
# to keep config.json consistent with what was actually applied.
config_path.write_text(json.dumps(config.to_dict(), indent=2))
quant_result = quantize_onnx(
model_path=current_path,
output_path=quantized_path,
Expand Down Expand Up @@ -443,6 +470,7 @@ def _load_model(
trust_remote_code: bool,
random_init: bool = False,
hf_config: Any | None = None,
model_type: str | None = None,
) -> Any:
"""Load PyTorch model — pretrained or random weights.

Expand Down Expand Up @@ -518,6 +546,7 @@ def _load_model(
task=task,
trust_remote_code=effective_trust,
hf_config=hf_config,
model_type=model_type,
)
return pytorch_model

Expand Down
32 changes: 31 additions & 1 deletion src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1339,7 +1339,11 @@ def _name(base: str) -> str:

# Load + export (blocking)
pytorch_model = _load_model(
config, model_id, trust_remote_code=False, hf_config=preloaded_hf_config
config,
model_id,
trust_remote_code=False,
hf_config=preloaded_hf_config,
model_type=config.loader.model_type,
)
t0 = time.monotonic()
# config.export is None only for the ONNX build path; this is the HF path.
Expand Down Expand Up @@ -1384,6 +1388,32 @@ def _name(base: str) -> str:
config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ───────────────────────────────────────────
# Some model types finalize their quant config only once the exported ONNX
# exists (calibration feeds / nodes-to-exclude derived from the graph).
# Resolve the model-type-specific quant policy from the quant registry,
# keyed on the live ``model_type`` — mirrors build.hf.build_hf_model so the
# CLI and library pipelines apply the same scheme. Unregistered types return
# None → the quantizer uses its standard task-aware DatasetCalibrationReader.
if config.quant is not None:
from ..quant import get_quant_finalizer

resolved_model_type = (
getattr(getattr(pytorch_model, "config", None), "model_type", None)
or config.loader.model_type
)
quant_finalizer = get_quant_finalizer(resolved_model_type)
if quant_finalizer is not None:
resolved_model_id = model_id or getattr(
getattr(pytorch_model, "config", None), "_name_or_path", None
)
config.quant = quant_finalizer.finalize(
config.quant, onnx_path=current_path, model_id=resolved_model_id
)
# The policy may overwrite the quant scheme (dtypes, symmetry,
# nodes-to-exclude) authoritatively, so re-persist the config to keep
# config.json consistent with what was actually applied.
config_path.write_text(json.dumps(config.to_dict(), indent=2))

current_path = _run_quantize_stage(
config=config,
current_path=current_path,
Expand Down
13 changes: 13 additions & 0 deletions src/winml/modelkit/loader/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,19 @@ def resolve_loader_config(
f"attribute. Cannot proceed with config generation."
)

# Explicit model_type override alongside a model_id: honor the requested
# type so downstream class / build-config / export resolution selects the
# variant (e.g. "qwen3_transformer_only") rather than the architecture's
# native type. The model_type-only path above (AutoConfig.for_model) is
# unaffected because it only runs when model_id is None.
if model_id is not None and model_type is not None and hf_config.model_type != model_type:
logger.info(
"Overriding resolved model_type '%s' -> '%s' (explicit request)",
hf_config.model_type,
model_type,
)
hf_config.model_type = model_type

# 2-3. Unified resolution. Task detection — including the no-architectures
# --model-type fallback (first supported task) — now lives in resolve_task.
resolution = resolve_task(hf_config, task=task, model_class=model_class)
Expand Down
13 changes: 13 additions & 0 deletions src/winml/modelkit/loader/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def load_hf_model(
user_script: str | None = None,
trust_remote_code: bool = False,
hf_config: PretrainedConfig | None = None,
model_type: str | None = None,
) -> tuple[nn.Module, PretrainedConfig, str]:
"""Load, detect task, and prepare HuggingFace model.

Expand Down Expand Up @@ -218,6 +219,18 @@ def load_hf_model(
trust_remote_code=trust_remote_code,
)

# Explicit model_type override: select a registered build variant (e.g.
# "qwen3_transformer_only") rather than the architecture's native type.
# Mutates the freshly-loaded config only; gated on an explicit request so
# normal loading is unaffected.
if model_type is not None and getattr(hf_config, "model_type", None) != model_type:
logger.info(
"Overriding model_type '%s' -> '%s' (explicit request)",
getattr(hf_config, "model_type", None),
model_type,
)
hf_config.model_type = model_type

# [2] Task & Model Class Resolution
from .resolution import resolve_task

Expand Down
16 changes: 15 additions & 1 deletion src/winml/modelkit/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def from_pretrained(
trust_remote_code: bool = False,
shape_config: dict | None = None,
no_compile: bool = False,
model_type: str | None = None,
allow_unsupported_nodes: bool = False,
**kwargs: Any,
) -> WinMLPreTrainedModel:
Expand Down Expand Up @@ -300,6 +301,10 @@ def from_pretrained(
shape_config: Shape overrides passed to generate_build_config().
Valid keys -- text: sequence_length; vision: height, width;
audio: feature_size, nb_max_frames, audio_sequence_length.
model_type: Explicit model_type override. When provided alongside a
HF model_id, selects a registered build variant (e.g.
``"qwen3_transformer_only"``) instead of the architecture's
native model_type. Leave ``None`` for normal auto-detection.
allow_unsupported_nodes: If True, warn instead of raising when the
analyzer reports unsupported nodes that persist; the build
proceeds and the EP may fall back to another device for them.
Expand Down Expand Up @@ -361,6 +366,11 @@ def from_pretrained(
else:
_model_type = None

# Explicit override wins so a variant composite (e.g.
# "qwen3_transformer_only") can be selected over the native type.
if model_type is not None:
_model_type = model_type

if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY:
from .winml.composite_model import WinMLCompositeModel

Expand Down Expand Up @@ -398,6 +408,7 @@ def from_pretrained(
trust_remote_code=trust_remote_code,
ep=kwargs.get("ep"),
no_compile=no_compile,
model_type=model_type,
)

resolved_task = build_config.loader.task
Expand Down Expand Up @@ -432,7 +443,9 @@ def from_pretrained(
from transformers import AutoConfig

hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=effective_trust)
model_type = getattr(hf_config, "model_type", "unknown")
# Honor an explicit model_type override; otherwise probe from the config.
if model_type is None:
model_type = getattr(hf_config, "model_type", "unknown")
logger.debug("Model type: %s, task: %s", model_type, resolved_task)

# =====================================================================
Expand Down Expand Up @@ -470,6 +483,7 @@ def from_pretrained(
cache_key=cache_key,
ep=resolved_ep,
device=device,
model_type=model_type,
allow_unsupported_nodes=allow_unsupported_nodes,
**build_control_kwargs,
)
Expand Down
11 changes: 11 additions & 0 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@
from .qwen import QWEN_CONFIG
from .qwen import QwenGenIOConfig as _QwenGenIOConfig
from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig
from .qwen3.qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING
from .qwen3.qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG
from .qwen3.qwen_transformer_only import (
QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration
)
from .qwen3.qwen_transformer_only import (
# triggers registration
QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig,
)
from .roberta import ROBERTA_FAMILY_CONFIG
from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration
from .sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING
Expand Down Expand Up @@ -92,6 +101,7 @@
**_MARIAN_CLASS_MAPPING,
**_MU2_CLASS_MAPPING,
**_QWEN_CLASS_MAPPING,
**_QWEN_TO_CLASS_MAPPING,
**_SAM2_CLASS_MAPPING,
**_SEGFORMER_CLASS_MAPPING,
**_SIGLIP_CLASS_MAPPING,
Expand All @@ -115,6 +125,7 @@
"roberta": ROBERTA_FAMILY_CONFIG,
"mu2": MU2_CONFIG,
"qwen3": QWEN_CONFIG,
"qwen3-transformer-only": QWEN_TRANSFORMER_ONLY_CONFIG,
"siglip": SIGLIP_CONFIG,
"siglip-text-model": SIGLIP_CONFIG,
"siglip-vision-model": SIGLIP_CONFIG,
Expand Down
6 changes: 6 additions & 0 deletions src/winml/modelkit/models/hf/qwen3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

"""Qwen3 transformer-only export support (modeling, export ops, IO configs)."""
Loading