Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions src/winml/modelkit/export/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,19 @@ class ONNXConfigNotFoundError(ValueError):
# Project-specific task-name overrides, keyed by the incoming task name and
# valued by the task name that should be used instead.
TASK_SYNONYM_EXTENSIONS: dict[str, str] = {
    # next-sentence-prediction has same I/O as text-classification: input_ids → logits
    "next-sentence-prediction": "text-classification",
    # mask-generation is registered via register_onnx_overwrite for SAM2.
    # Optimum incorrectly maps it to "feature-extraction"; preserve as-is.
    # (Identity mapping: the entry exists only to keep the name unchanged.)
    "mask-generation": "mask-generation",
}


def _map_task_synonym(task: str) -> str:
    """Map a task name to its canonical form, extending Optimum's synonym mapping.

    Entries in ``TASK_SYNONYM_EXTENSIONS`` take priority over Optimum's
    built-in synonym map. If ``task`` is found there, its mapped value is
    returned immediately without passing through Optimum (which may
    incorrectly normalize custom-registered tasks like ``mask-generation``).
    Any other task is delegated to ``TasksManager.map_from_synonym``.

    Args:
        task: Task name (e.g., "next-sentence-prediction", "image-feature-extraction")

    Returns:
        The task name from ``TASK_SYNONYM_EXTENSIONS`` when ``task`` is listed
        there, otherwise the result of ``TasksManager.map_from_synonym(task)``.

    Example:
        >>> _map_task_synonym("next-sentence-prediction")  # Our extension
        'text-classification'
        >>> _map_task_synonym("mask-generation")  # Preserved (not Optimum-normalized)
        'mask-generation'
        >>> _map_task_synonym("image-feature-extraction")  # Optimum's synonym
        'feature-extraction'
        >>> _map_task_synonym("text-classification")  # Already canonical
        'text-classification'
    """
    # Our extensions take priority — return early to prevent Optimum from
    # incorrectly normalizing custom-registered tasks. Every value in the
    # extension table is a non-empty string, so ``None`` safely means "absent".
    extension = TASK_SYNONYM_EXTENSIONS.get(task)
    if extension is not None:
        return extension

    # Fallback: normalize via Optimum's built-in synonym mapping.
    return TasksManager.map_from_synonym(task)


# =============================================================================
Expand Down
60 changes: 52 additions & 8 deletions src/winml/modelkit/models/hf/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,48 @@
# Sam2VisionModel cannot load weights from a Sam2VideoModel checkpoint because
# checkpoint keys are prefixed with "vision_encoder." (e.g., "vision_encoder.backbone.*")
# but Sam2VisionModel expects unprefixed keys (e.g., "backbone.*").
# This wrapper loads the full Sam2VideoModel and extracts the vision_encoder submodule,
# flattening the FPN tuple outputs into individual tensor outputs for ONNX compatibility.


class Sam2VisionEncoder(torch.nn.Module):
    """Wrap a SAM2 vision encoder and flatten its FPN outputs for ONNX export.

    The wrapped encoder's forward pass returns an output object whose
    ``fpn_hidden_states`` field is a tuple of 3 tensors (one per FPN level).
    Optimum's ModelPatcher output filter matches by dict key name against
    the ONNX config's output names, so tuple-of-tensor fields are invisible
    to the filter, producing an empty ONNX graph.

    This wrapper flattens the FPN outputs into individual tensor entries
    with names matching Sam2ImageEncoderIOConfig.outputs:
        fpn_hidden_states[2] -> image_embeddings   [B, 256, 64, 64]
        fpn_hidden_states[0] -> high_res_features1 [B, 256, 256, 256]
        fpn_hidden_states[1] -> high_res_features2 [B, 256, 128, 128]
    """

    def __init__(self, vision_encoder: torch.nn.Module) -> None:
        """Store the encoder submodule and expose its config on the wrapper.

        Args:
            vision_encoder: Encoder module to wrap; it must provide a
                ``config`` attribute, which is mirrored as ``self.config``.
        """
        super().__init__()
        self.vision_encoder = vision_encoder
        self.config = vision_encoder.config

    @classmethod
    def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> "Sam2VisionEncoder":
        """Load the full model and wrap its ``vision_encoder`` submodule.

        Args:
            model_name_or_path: Checkpoint identifier or path, forwarded to
                ``Sam2Model.from_pretrained``.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            A ``Sam2VisionEncoder`` wrapping the loaded model's encoder.
        """
        full_model = Sam2Model.from_pretrained(model_name_or_path, **kwargs)
        # Return the wrapper (not the bare submodule) so that forward() below
        # is the function that actually gets exported.
        return cls(full_model.vision_encoder)

    def forward(
        self, pixel_values: torch.Tensor | None = None, **kwargs: Any
    ) -> dict[str, torch.Tensor]:
        """Run the wrapped encoder and return its FPN levels as named tensors.

        Args:
            pixel_values: Input passed to the wrapped encoder by keyword.
            **kwargs: Extra keyword arguments forwarded to the wrapped encoder.

        Returns:
            Dict with keys ``image_embeddings``, ``high_res_features1`` and
            ``high_res_features2`` taken from FPN levels 2, 0 and 1.
        """
        out = self.vision_encoder(pixel_values=pixel_values, **kwargs)
        fpn = out.fpn_hidden_states
        # Key order is kept as-is: embeddings first, then the two
        # high-resolution feature maps.
        return {
            "image_embeddings": fpn[2],
            "high_res_features1": fpn[0],
            "high_res_features2": fpn[1],
        }


class SAM2MaskGeneration(torch.nn.Module):
Expand Down Expand Up @@ -155,14 +187,26 @@ def forward(
no_mem = self.no_memory_embedding.permute(0, 2, 1).unsqueeze(-1)
image_embeddings = image_embeddings + no_mem

# 2. Prompt embeddings (patched by Sam2ModelPatcher during export)
sparse_embeddings, dense_embeddings = self.prompt_encoder(
# 2. Prompt embeddings
# Get sparse embeddings (without mask — mask blending handled below)
sparse_embeddings, _ = self.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=None,
input_masks=mask_input,
use_mask_input=use_mask_input,
input_masks=None,
)

# Arithmetic mask blending via use_mask_input flag
# (avoids torch.where for ONNX/QNN compatibility)
mask_dense = self.prompt_encoder.mask_embed(mask_input)
no_mask_dense = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
batch_size,
-1,
self.image_embedding_size[0],
self.image_embedding_size[1],
)
flag = use_mask_input.reshape(-1, 1, 1, 1).to(mask_dense.dtype)
dense_embeddings = (1.0 - flag) * no_mask_dense + flag * mask_dense

# 3. Positional embeddings
image_positional_embeddings = self._get_image_positional_embeddings(batch_size)
Expand Down
2 changes: 2 additions & 0 deletions src/winml/modelkit/optim/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,9 @@ def optimize_onnx(
# Step 9: Run optimization
logger.info("Starting optimization pipeline...")
optimizer = Optimizer()
# More than one "optimize" call is necessary for certain models for thorough optimization
optimized_model = optimizer.optimize(loaded_model, **optimizer_kwargs)
optimized_model = optimizer.optimize(optimized_model, **optimizer_kwargs)
Comment thread
vortex-captain marked this conversation as resolved.

# Step 10: Save if output path provided
if output is not None:
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/optim/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,8 @@ def test_accepts_string_path(self, model_file: Path) -> None:
mock_opt.return_value.optimize.return_value = onnx.ModelProto()
result = optimize_onnx(str(model_file))
assert isinstance(result, onnx.ModelProto)
# Verify model was loaded and passed to optimizer
mock_opt.return_value.optimize.assert_called_once()
# Verify model was loaded and passed to optimizer (called twice to optimize correctly)
assert mock_opt.return_value.optimize.call_count == 2

def test_accepts_path_object(self, model_file: Path) -> None:
"""Accept Path object as model input."""
Expand Down
Loading