From 6626c37c290efff3b578702bd040fb85fe6dcc36 Mon Sep 17 00:00:00 2001 From: Chao Zhang Date: Wed, 1 Apr 2026 16:43:00 +0800 Subject: [PATCH 1/4] fix sam2 --- src/winml/modelkit/export/io.py | 21 +++++++--- src/winml/modelkit/models/hf/sam.py | 60 +++++++++++++++++++++++++---- 2 files changed, 67 insertions(+), 14 deletions(-) diff --git a/src/winml/modelkit/export/io.py b/src/winml/modelkit/export/io.py index 0780ba4a7..64c7952be 100644 --- a/src/winml/modelkit/export/io.py +++ b/src/winml/modelkit/export/io.py @@ -73,14 +73,19 @@ class ONNXConfigNotFoundError(ValueError): TASK_SYNONYM_EXTENSIONS: dict[str, str] = { # next-sentence-prediction has same I/O as text-classification: input_ids → logits "next-sentence-prediction": "text-classification", + # mask-generation is registered via register_onnx_overwrite for SAM2. + # Optimum incorrectly maps it to "feature-extraction"; preserve as-is. + "mask-generation": "mask-generation", } def _map_task_synonym(task: str) -> str: """Map task name to canonical form, extending Optimum's synonym mapping. - This function first checks our custom extensions for tasks Optimum doesn't - recognize, then delegates to Optimum's map_from_synonym for known synonyms. + Our extensions take priority over Optimum's built-in synonym map. + If a task is found in ``TASK_SYNONYM_EXTENSIONS``, return immediately + without passing through Optimum (which may incorrectly normalize + custom-registered tasks like ``mask-generation``). Args: task: Task name (e.g., "next-sentence-prediction", "image-feature-extraction") @@ -91,16 +96,20 @@ def _map_task_synonym(task: str) -> str: Example: >>> map_task_synonym("next-sentence-prediction") # Our extension 'text-classification' + >>> map_task_synonym("mask-generation") # Preserved (not Optimum-normalized) + 'mask-generation' >>> map_task_synonym("image-feature-extraction") # Optimum's synonym 'feature-extraction' >>> map_task_synonym("text-classification") # Already canonical 'text-classification' """ - # First: apply our extensions for tasks Optimum doesn't recognize - mapped_task = TASK_SYNONYM_EXTENSIONS.get(task, task) + # Our extensions take priority — return early to prevent Optimum from + # incorrectly normalizing custom-registered tasks. + if task in TASK_SYNONYM_EXTENSIONS: + return TASK_SYNONYM_EXTENSIONS[task] - # Second: normalize via Optimum's built-in synonym mapping - return TasksManager.map_from_synonym(mapped_task) + # Fallback: normalize via Optimum's built-in synonym mapping + return TasksManager.map_from_synonym(task) # ============================================================================= diff --git a/src/winml/modelkit/models/hf/sam.py b/src/winml/modelkit/models/hf/sam.py index eab910471..b2ba5bd74 100644 --- a/src/winml/modelkit/models/hf/sam.py +++ b/src/winml/modelkit/models/hf/sam.py @@ -60,16 +60,48 @@ # Sam2VisionModel cannot load weights from a Sam2VideoModel checkpoint because # checkpoint keys are prefixed with "vision_encoder." (e.g., "vision_encoder.backbone.*") # but Sam2VisionModel expects unprefixed keys (e.g., "backbone.*"). -# This wrapper loads the full Sam2VideoModel and extracts the vision_encoder submodule. +# This wrapper loads the full Sam2VideoModel and extracts the vision_encoder submodule, +# flattening the FPN tuple outputs into individual tensor outputs for ONNX compatibility. class Sam2VisionEncoder(torch.nn.Module): - """Wrapper that loads Sam2VideoModel and extracts vision_encoder.""" + """Wrapper that loads Sam2VideoModel, extracts vision_encoder. + + Flattens FPN tuple outputs for ONNX export. + + Sam2VisionModel.forward() returns Sam2VisionEncoderOutput where + fpn_hidden_states is a tuple of 3 tensors (one per FPN level). + Optimum's ModelPatcher output filter matches by dict key name against + the ONNX config's output names, so tuple-of-tensor fields are invisible + to the filter, producing an empty ONNX graph. + + This wrapper flattens the FPN outputs into individual tensor entries + with names matching Sam2ImageEncoderIOConfig.outputs: + fpn_hidden_states[2] -> image_embeddings [B, 256, 64, 64] + fpn_hidden_states[0] -> high_res_features1 [B, 256, 256, 256] + fpn_hidden_states[1] -> high_res_features2 [B, 256, 128, 128] + """ + + def __init__(self, vision_encoder: torch.nn.Module) -> None: + super().__init__() + self.vision_encoder = vision_encoder + self.config = vision_encoder.config @classmethod - def from_pretrained(cls, model_name_or_path: str, **kwargs) -> torch.nn.Module: + def from_pretrained(cls, model_name_or_path: str, **kwargs) -> Sam2VisionEncoder: full_model = Sam2Model.from_pretrained(model_name_or_path, **kwargs) - return full_model.vision_encoder + return cls(full_model.vision_encoder) + + def forward( + self, pixel_values: torch.Tensor | None = None, **kwargs: Any + ) -> dict[str, torch.Tensor]: + out = self.vision_encoder(pixel_values=pixel_values, **kwargs) + fpn = out.fpn_hidden_states + return { + "image_embeddings": fpn[2], + "high_res_features1": fpn[0], + "high_res_features2": fpn[1], + } class SAM2MaskGeneration(torch.nn.Module): @@ -155,14 +187,26 @@ def forward( no_mem = self.no_memory_embedding.permute(0, 2, 1).unsqueeze(-1) image_embeddings = image_embeddings + no_mem - # 2. Prompt embeddings (patched by Sam2ModelPatcher during export) - sparse_embeddings, dense_embeddings = self.prompt_encoder( + # 2. Prompt embeddings + # Get sparse embeddings (without mask — mask blending handled below) + sparse_embeddings, _ = self.prompt_encoder( input_points=input_points, input_labels=input_labels, input_boxes=None, - input_masks=mask_input, - use_mask_input=use_mask_input, + input_masks=None, + ) + + # Arithmetic mask blending via use_mask_input flag + # (avoids torch.where for ONNX/QNN compatibility) + mask_dense = self.prompt_encoder.mask_embed(mask_input) + no_mask_dense = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, + -1, + self.image_embedding_size[0], + self.image_embedding_size[1], ) + flag = use_mask_input.reshape(-1, 1, 1, 1).to(mask_dense.dtype) + dense_embeddings = (1.0 - flag) * no_mask_dense + flag * mask_dense # 3. Positional embeddings image_positional_embeddings = self._get_image_positional_embeddings(batch_size) From 49d878ef5f26508da4c7cd1b533a5f482cc40ace Mon Sep 17 00:00:00 2001 From: Chao Zhang Date: Thu, 2 Apr 2026 11:21:14 +0800 Subject: [PATCH 2/4] optimize two times --- src/winml/modelkit/optim/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/winml/modelkit/optim/api.py b/src/winml/modelkit/optim/api.py index 47da57d8b..3429d81fe 100644 --- a/src/winml/modelkit/optim/api.py +++ b/src/winml/modelkit/optim/api.py @@ -256,6 +256,7 @@ def optimize_onnx( logger.info("Starting optimization pipeline...") optimizer = Optimizer() optimized_model = optimizer.optimize(loaded_model, **optimizer_kwargs) + optimized_model = optimizer.optimize(optimized_model, **optimizer_kwargs) # Step 10: Save if output path provided if output is not None: From 3cf41f1b326ad30e7c99f9fc70165641e19a5ab4 Mon Sep 17 00:00:00 2001 From: Chao Zhang Date: Thu, 2 Apr 2026 11:32:21 +0800 Subject: [PATCH 3/4] fix test --- tests/unit/optim/test_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/optim/test_api.py b/tests/unit/optim/test_api.py index 18e191c80..1c2c1d0ff 100644 --- a/tests/unit/optim/test_api.py +++ b/tests/unit/optim/test_api.py @@ -262,8 +262,8 @@ def test_accepts_string_path(self, model_file: Path) -> None: mock_opt.return_value.optimize.return_value = onnx.ModelProto() result = optimize_onnx(str(model_file)) assert isinstance(result, onnx.ModelProto) - # Verify model was loaded and passed to optimizer - mock_opt.return_value.optimize.assert_called_once() + # Verify model was loaded and passed to optimizer (called twice to optimize correctly) + assert mock_opt.return_value.optimize.call_count == 2 def test_accepts_path_object(self, model_file: Path) -> None: """Accept Path object as model input.""" From b4227d9929bba859112ae8210bf3151d5d0ba673 Mon Sep 17 00:00:00 2001 From: Chao Zhang Date: Thu, 2 Apr 2026 12:03:17 +0800 Subject: [PATCH 4/4] add comments --- src/winml/modelkit/optim/api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/winml/modelkit/optim/api.py b/src/winml/modelkit/optim/api.py index 3429d81fe..8993d76f7 100644 --- a/src/winml/modelkit/optim/api.py +++ b/src/winml/modelkit/optim/api.py @@ -255,6 +255,7 @@ def optimize_onnx( # Step 9: Run optimization logger.info("Starting optimization pipeline...") optimizer = Optimizer() + # More than one "optimize" call is necessary for certain models for thorough optimization optimized_model = optimizer.optimize(loaded_model, **optimizer_kwargs) optimized_model = optimizer.optimize(optimized_model, **optimizer_kwargs)