Skip to content

Commit f13df06

Browse files
Fix sam models perf issue. (#333)
* fix sam models * Remove Xenova/slimsam-77-uniform from model list * fix comments
1 parent d79e53d commit f13df06

2 files changed

Lines changed: 228 additions & 26 deletions

File tree

scripts/e2e_eval/testsets/models_all.json

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,16 +1829,6 @@
18291829
"last_update_time": "2024-01-11T19:23:46+00:00",
18301830
"optimum_supported": true
18311831
},
1832-
{
1833-
"hf_id": "Xenova/slimsam-77-uniform",
1834-
"task": "mask-generation",
1835-
"model_type": "sam",
1836-
"group": "Top200",
1837-
"priority": "P1",
1838-
"downloads": 9389,
1839-
"last_update_time": "2024-12-12T22:29:52+00:00",
1840-
"optimum_supported": true
1841-
},
18421832
{
18431833
"hf_id": "Zigeng/SlimSAM-uniform-77",
18441834
"task": "mask-generation",

src/winml/modelkit/models/hf/sam.py

Lines changed: 228 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,30 @@
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# Licensed under the MIT License.
44
# --------------------------------------------------------------------------
5-
"""SAM2 HuggingFace Model Patches and Export Configs.
5+
"""SAM/SAM2 HuggingFace model patches and ONNX export configs.
66
7-
Provides QNN-compatible patches and ONNX export configs for SAM2
8-
(Segment Anything Model 2) from Meta/Facebook.
7+
Provides QNN-compatible patches and ONNX export configs for both:
8+
- SAM v1 (facebook/sam-vit-*)
9+
- SAM2 / SAM2-video (facebook/sam2-hiera-*)
910
10-
Key Features:
11-
- QNN-compatible patches: 5D window partition, arithmetic masking
12-
- Split export: Separate encoder and decoder ONNX files
11+
Key features:
12+
- SAM2 QNN-compatible patches: 5D window partition, arithmetic masking
13+
- Split and task-specific exports: encoder, full model, and decoder wrappers
1314
14-
Patch Targets (applied via Sam2ModelPatcher during export):
15+
Patch targets (applied via Sam2ModelPatcher during export):
1516
- Sam2MultiScaleBlock: 5D window partition (6D->5D for QNN)
1617
- Sam2PromptEncoder: Arithmetic masking (torch.where->arithmetic for ONNX)
1718
18-
Export Strategy (split):
19-
- Sam2ImageEncoderIOConfig: pixel_values -> embeddings + high_res features
20-
- Sam2MaskDecoderIOConfig: prompts + embeddings -> masks + iou_scores
21-
22-
Model: facebook/sam2-hiera-small, facebook/sam2-hiera-large, etc.
23-
Task: image-segmentation
19+
Export coverage:
20+
- SAM2: image encoder, full model, mask-generation decoder wrapper
21+
- SAM v1: mask-generation decoder wrapper
2422
2523
Exports:
2624
Sam2NormalizedVisionConfig: NormalizedVisionConfig with 1024 image_size
27-
Sam2ImageEncoderIOConfig: ONNX config for image encoder
28-
Sam2MaskDecoderIOConfig: ONNX config for mask decoder
25+
Sam2ImageEncoderIOConfig: ONNX config for SAM2 image encoder
26+
Sam2IOConfig: ONNX config for SAM2 full model
27+
Sam2MaskGenerationIOConfig: ONNX config for SAM2 mask-generation wrapper
28+
SamMaskGenerationIOConfig: ONNX config for SAM v1 mask-generation wrapper
2929
Sam2ModelPatcher: Custom ModelPatcher for SAM2 export patches
3030
_patched_sam2_multiscale_block_forward: Patched forward (internal)
3131
_patched_sam2_prompt_encoder_forward: Patched forward (internal)
@@ -45,7 +45,7 @@
4545
DummyInputGenerator,
4646
DummyVisionInputGenerator,
4747
)
48-
from transformers import Sam2Model
48+
from transformers import Sam2Model, SamModel
4949

5050
from ...export import register_onnx_overwrite
5151

@@ -240,6 +240,119 @@ def forward(
240240
return masks, iou_scores, low_res_masks
241241

242242

243+
class SAMMaskGeneration(torch.nn.Module):
244+
"""Export wrapper for SAM v1 mask generation (decoder portion).
245+
246+
Composes prompt_encoder + mask_decoder + positional embeddings
247+
into a single module with explicit I/O signature.
248+
249+
Mirrors SamModel.forward flow:
250+
1. Encode prompts (points + optional mask)
251+
2. Compute positional embeddings
252+
3. Run mask decoder
253+
254+
Inputs:
255+
input_points: [B, 1, N, 2] - Point coordinates in pixels
256+
input_labels: [B, 1, N] - Point labels (0=neg, 1=pos, -1=pad)
257+
image_embeddings: [B, 256, 64, 64] - From vision encoder
258+
mask_input: [B, 1, 256, 256] - Previous mask (for refinement)
259+
use_mask_input: [B] - Flag: 0.0=ignore mask, 1.0=use mask
260+
261+
Outputs:
262+
masks: [B, 3, 1024, 1024] - Full resolution masks
263+
iou_scores: [B, 3] - IoU predictions per mask
264+
low_res_masks: [B, 3, 256, 256] - Low-res masks (for next iteration)
265+
"""
266+
267+
@classmethod
268+
def from_pretrained(cls, model_name_or_path: str, **kwargs) -> SAMMaskGeneration:
269+
"""Load from a HuggingFace SamModel checkpoint."""
270+
sam_model = SamModel.from_pretrained(model_name_or_path, **kwargs)
271+
return cls(sam_model)
272+
273+
def __init__(self, sam_model):
274+
super().__init__()
275+
276+
self.prompt_encoder = sam_model.prompt_encoder
277+
self.mask_decoder = sam_model.mask_decoder
278+
self.shared_image_embedding = sam_model.shared_image_embedding
279+
self.image_embedding_size = self.prompt_encoder.image_embedding_size
280+
self.config = sam_model.config
281+
282+
def _get_image_positional_embeddings(self, batch_size: int = 1) -> torch.Tensor:
283+
"""Replicates SamModel.get_image_wide_positional_embeddings()."""
284+
size = self.config.prompt_encoder_config.image_embedding_size
285+
target_device = self.shared_image_embedding.positional_embedding.device
286+
target_dtype = self.shared_image_embedding.positional_embedding.dtype
287+
288+
grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
289+
y_embed = grid.cumsum(dim=0) - 0.5
290+
x_embed = grid.cumsum(dim=1) - 0.5
291+
y_embed = y_embed / size
292+
x_embed = x_embed / size
293+
294+
positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
295+
positional_embedding = positional_embedding.permute(2, 0, 1).unsqueeze(0)
296+
return positional_embedding.repeat(batch_size, 1, 1, 1)
297+
298+
def forward(
299+
self,
300+
input_points: torch.Tensor,
301+
input_labels: torch.Tensor,
302+
image_embeddings: torch.Tensor,
303+
mask_input: torch.Tensor,
304+
use_mask_input: torch.Tensor,
305+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
306+
"""Run mask generation from pre-computed encoder features."""
307+
batch_size = image_embeddings.shape[0]
308+
309+
# 1. Prompt embeddings (sparse - points only, mask handled separately)
310+
sparse_embeddings, _ = self.prompt_encoder(
311+
input_points=input_points,
312+
input_labels=input_labels,
313+
input_boxes=None,
314+
input_masks=None,
315+
)
316+
317+
# Arithmetic mask blending via use_mask_input flag
318+
# (avoids torch.where for ONNX/QNN compatibility)
319+
mask_dense = self.prompt_encoder.mask_embed(mask_input)
320+
no_mask_dense = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
321+
batch_size,
322+
-1,
323+
self.image_embedding_size[0],
324+
self.image_embedding_size[1],
325+
)
326+
flag = use_mask_input.reshape(-1, 1, 1, 1).to(mask_dense.dtype)
327+
dense_embeddings = (1.0 - flag) * no_mask_dense + flag * mask_dense
328+
329+
# 2. Positional embeddings
330+
image_positional_embeddings = self._get_image_positional_embeddings(batch_size)
331+
332+
# 3. Mask decoder
333+
low_res_masks, iou_pred = self.mask_decoder(
334+
image_embeddings=image_embeddings,
335+
image_positional_embeddings=image_positional_embeddings,
336+
sparse_prompt_embeddings=sparse_embeddings,
337+
dense_prompt_embeddings=dense_embeddings,
338+
multimask_output=True,
339+
)
340+
341+
# Squeeze point_batch_size dimension
342+
low_res_masks = low_res_masks.squeeze(1) # [B, 3, 256, 256]
343+
iou_scores = iou_pred.squeeze(1) # [B, 3]
344+
345+
# 4. Upsample to full resolution
346+
masks = torch.nn.functional.interpolate(
347+
low_res_masks,
348+
size=(1024, 1024),
349+
mode="bilinear",
350+
align_corners=False,
351+
)
352+
353+
return masks, iou_scores, low_res_masks
354+
355+
243356
# =============================================================================
244357
# HuggingFace Model Class Mapping
245358
# =============================================================================
@@ -254,6 +367,7 @@ def forward(
254367
# Users wanting the full model use --task image-segmentation.
255368

256369
MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
370+
("sam", "mask-generation"): SAMMaskGeneration,
257371
("sam2", "image-segmentation"): Sam2Model,
258372
("sam2", "feature-extraction"): Sam2VisionEncoder,
259373
("sam2", "image-feature-extraction"): Sam2VisionEncoder,
@@ -812,13 +926,111 @@ def outputs(self) -> dict[str, dict[int, str]]:
812926
}
813927

814928

929+
# =============================================================================
930+
# SAM v1 Custom Dummy Input Generators
931+
# =============================================================================
932+
class SamEmbeddingsInputGenerator(DummyInputGenerator):
933+
"""Embeddings input generator for SAM v1 mask generation decoder.
934+
935+
Generates:
936+
- image_embeddings: [B, 256, 64, 64] - From vision encoder
937+
"""
938+
939+
SUPPORTED_INPUT_NAMES = ("image_embeddings",)
940+
941+
def __init__(
942+
self,
943+
task: str,
944+
normalized_config: NormalizedConfig,
945+
batch_size: int = 1,
946+
**kwargs,
947+
):
948+
self.task = task
949+
self.batch_size = batch_size
950+
951+
def generate(
952+
self,
953+
input_name: str,
954+
framework: str = "pt",
955+
int_dtype: str = "int64",
956+
float_dtype: str = "fp32",
957+
):
958+
# SAM v1 decoder export expects the canonical embedding shape from the
959+
# vision encoder output; this mirrors the existing SAM2 generator path.
960+
shape = [self.batch_size, 256, 64, 64]
961+
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)
962+
963+
964+
# =============================================================================
965+
# SAM v1 Optimum ONNX Export Config Registration
966+
# =============================================================================
967+
968+
969+
# -----------------------------------------------------------------------------
970+
# Mask generation export (SAMMaskGeneration wrapper) - SAM v1
971+
# -----------------------------------------------------------------------------
972+
@register_onnx_overwrite("sam", "mask-generation", library_name="transformers")
973+
class SamMaskGenerationIOConfig(OnnxConfig):
974+
"""ONNX config for SAMMaskGeneration (SAM v1 decoder).
975+
976+
Model: facebook/sam-vit-huge, facebook/sam-vit-large, facebook/sam-vit-base
977+
Uses SAMMaskGeneration nn.Module which takes image_embeddings from the
978+
vision encoder and runs prompt encoding + mask decoding.
979+
980+
Inputs:
981+
- input_points: {0: "batch_size"} [B, 1, N, 2]
982+
- input_labels: {0: "batch_size"} [B, 1, N]
983+
- image_embeddings: {0: "batch_size"} [B, 256, 64, 64]
984+
- mask_input: {0: "batch_size"} [B, 1, 256, 256]
985+
- use_mask_input: {0: "batch_size"} [B]
986+
987+
Outputs:
988+
- masks: {0: "batch_size"} [B, 3, 1024, 1024]
989+
- iou_scores: {0: "batch_size"} [B, 3]
990+
- low_res_masks: {0: "batch_size"} [B, 3, 256, 256]
991+
"""
992+
993+
# SAM v1 also uses 1024x1024 default image size, so this normalized config
994+
# is intentionally shared across SAM v1 and SAM2 export configs.
995+
NORMALIZED_CONFIG_CLASS = Sam2NormalizedVisionConfig
996+
# SAM v1 reuses SAM2-named generators because prompt/mask tensor shapes
997+
# are identical for this export path.
998+
DUMMY_INPUT_GENERATOR_CLASSES = (
999+
Sam2PointsInputGenerator,
1000+
SamEmbeddingsInputGenerator,
1001+
Sam2MaskInputGenerator,
1002+
)
1003+
1004+
@property
1005+
def inputs(self) -> dict[str, dict[int, str]]:
1006+
"""Return input tensors for SAM v1 mask generation."""
1007+
return {
1008+
"input_points": {0: "batch_size"},
1009+
"input_labels": {0: "batch_size"},
1010+
"image_embeddings": {0: "batch_size"},
1011+
"mask_input": {0: "batch_size"},
1012+
"use_mask_input": {0: "batch_size"},
1013+
}
1014+
1015+
@property
1016+
def outputs(self) -> dict[str, dict[int, str]]:
1017+
"""Return output tensors for SAM v1 mask generation."""
1018+
return {
1019+
"masks": {0: "batch_size"},
1020+
"iou_scores": {0: "batch_size"},
1021+
"low_res_masks": {0: "batch_size"},
1022+
}
1023+
1024+
8151025
__all__ = [
8161026
"SAM2MaskGeneration",
1027+
"SAMMaskGeneration",
8171028
"Sam2IOConfig",
8181029
"Sam2ImageEncoderIOConfig",
8191030
"Sam2MaskGenerationIOConfig",
8201031
"Sam2ModelPatcher",
8211032
"Sam2NormalizedVisionConfig",
1033+
"SamMaskGenerationIOConfig",
8221034
"_patched_sam2_multiscale_block_forward",
8231035
"_patched_sam2_prompt_encoder_forward",
8241036
]

0 commit comments

Comments
 (0)