Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions scripts/e2e_eval/testsets/models_all.json
Original file line number Diff line number Diff line change
Expand Up @@ -1829,16 +1829,6 @@
"last_update_time": "2024-01-11T19:23:46+00:00",
"optimum_supported": true
},
{
"hf_id": "Xenova/slimsam-77-uniform",
"task": "mask-generation",
"model_type": "sam",
"group": "Top200",
"priority": "P1",
"downloads": 9389,
"last_update_time": "2024-12-12T22:29:52+00:00",
"optimum_supported": true
},
{
"hf_id": "Zigeng/SlimSAM-uniform-77",
"task": "mask-generation",
Expand Down
244 changes: 228 additions & 16 deletions src/winml/modelkit/models/hf/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,30 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""SAM2 HuggingFace Model Patches and Export Configs.
"""SAM/SAM2 HuggingFace model patches and ONNX export configs.

Provides QNN-compatible patches and ONNX export configs for SAM2
(Segment Anything Model 2) from Meta/Facebook.
Provides QNN-compatible patches and ONNX export configs for both:
- SAM v1 (facebook/sam-vit-*)
- SAM2 / SAM2-video (facebook/sam2-hiera-*)

Key Features:
- QNN-compatible patches: 5D window partition, arithmetic masking
- Split export: Separate encoder and decoder ONNX files
Key features:
- SAM2 QNN-compatible patches: 5D window partition, arithmetic masking
- Split and task-specific exports: encoder, full model, and decoder wrappers

Patch Targets (applied via Sam2ModelPatcher during export):
Patch targets (applied via Sam2ModelPatcher during export):
- Sam2MultiScaleBlock: 5D window partition (6D->5D for QNN)
- Sam2PromptEncoder: Arithmetic masking (torch.where->arithmetic for ONNX)

Export Strategy (split):
- Sam2ImageEncoderIOConfig: pixel_values -> embeddings + high_res features
- Sam2MaskDecoderIOConfig: prompts + embeddings -> masks + iou_scores

Model: facebook/sam2-hiera-small, facebook/sam2-hiera-large, etc.
Task: image-segmentation
Export coverage:
- SAM2: image encoder, full model, mask-generation decoder wrapper
- SAM v1: mask-generation decoder wrapper

Exports:
Sam2NormalizedVisionConfig: NormalizedVisionConfig with 1024 image_size
Sam2ImageEncoderIOConfig: ONNX config for image encoder
Sam2MaskDecoderIOConfig: ONNX config for mask decoder
Sam2ImageEncoderIOConfig: ONNX config for SAM2 image encoder
Sam2IOConfig: ONNX config for SAM2 full model
Sam2MaskGenerationIOConfig: ONNX config for SAM2 mask-generation wrapper
SamMaskGenerationIOConfig: ONNX config for SAM v1 mask-generation wrapper
Sam2ModelPatcher: Custom ModelPatcher for SAM2 export patches
_patched_sam2_multiscale_block_forward: Patched forward (internal)
_patched_sam2_prompt_encoder_forward: Patched forward (internal)
Expand All @@ -45,7 +45,7 @@
DummyInputGenerator,
DummyVisionInputGenerator,
)
from transformers import Sam2Model
from transformers import Sam2Model, SamModel
Comment thread
chinazhangchao marked this conversation as resolved.

from ...export import register_onnx_overwrite

Expand Down Expand Up @@ -240,6 +240,119 @@ def forward(
return masks, iou_scores, low_res_masks


class SAMMaskGeneration(torch.nn.Module):
"""Export wrapper for SAM v1 mask generation (decoder portion).

Composes prompt_encoder + mask_decoder + positional embeddings
into a single module with explicit I/O signature.

Mirrors SamModel.forward flow:
1. Encode prompts (points + optional mask)
2. Compute positional embeddings
3. Run mask decoder

Inputs:
input_points: [B, 1, N, 2] - Point coordinates in pixels
input_labels: [B, 1, N] - Point labels (0=neg, 1=pos, -1=pad)
image_embeddings: [B, 256, 64, 64] - From vision encoder
mask_input: [B, 1, 256, 256] - Previous mask (for refinement)
use_mask_input: [B] - Flag: 0.0=ignore mask, 1.0=use mask

Outputs:
masks: [B, 3, 1024, 1024] - Full resolution masks
iou_scores: [B, 3] - IoU predictions per mask
low_res_masks: [B, 3, 256, 256] - Low-res masks (for next iteration)
"""

@classmethod
def from_pretrained(cls, model_name_or_path: str, **kwargs) -> SAMMaskGeneration:
"""Load from a HuggingFace SamModel checkpoint."""
sam_model = SamModel.from_pretrained(model_name_or_path, **kwargs)
return cls(sam_model)

def __init__(self, sam_model):
super().__init__()

self.prompt_encoder = sam_model.prompt_encoder
self.mask_decoder = sam_model.mask_decoder
self.shared_image_embedding = sam_model.shared_image_embedding
self.image_embedding_size = self.prompt_encoder.image_embedding_size
self.config = sam_model.config

def _get_image_positional_embeddings(self, batch_size: int = 1) -> torch.Tensor:
"""Replicates SamModel.get_image_wide_positional_embeddings()."""
size = self.config.prompt_encoder_config.image_embedding_size
target_device = self.shared_image_embedding.positional_embedding.device
target_dtype = self.shared_image_embedding.positional_embedding.dtype

grid = torch.ones((size, size), device=target_device, dtype=target_dtype)
y_embed = grid.cumsum(dim=0) - 0.5
x_embed = grid.cumsum(dim=1) - 0.5
y_embed = y_embed / size
x_embed = x_embed / size

positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1))
positional_embedding = positional_embedding.permute(2, 0, 1).unsqueeze(0)
return positional_embedding.repeat(batch_size, 1, 1, 1)

def forward(
self,
input_points: torch.Tensor,
input_labels: torch.Tensor,
image_embeddings: torch.Tensor,
mask_input: torch.Tensor,
use_mask_input: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Run mask generation from pre-computed encoder features."""
batch_size = image_embeddings.shape[0]

# 1. Prompt embeddings (sparse - points only, mask handled separately)
sparse_embeddings, _ = self.prompt_encoder(
input_points=input_points,
input_labels=input_labels,
input_boxes=None,
input_masks=None,
)

# Arithmetic mask blending via use_mask_input flag
# (avoids torch.where for ONNX/QNN compatibility)
mask_dense = self.prompt_encoder.mask_embed(mask_input)
no_mask_dense = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
batch_size,
-1,
self.image_embedding_size[0],
self.image_embedding_size[1],
)
flag = use_mask_input.reshape(-1, 1, 1, 1).to(mask_dense.dtype)
dense_embeddings = (1.0 - flag) * no_mask_dense + flag * mask_dense

# 2. Positional embeddings
image_positional_embeddings = self._get_image_positional_embeddings(batch_size)

# 3. Mask decoder
low_res_masks, iou_pred = self.mask_decoder(
image_embeddings=image_embeddings,
image_positional_embeddings=image_positional_embeddings,
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=True,
)

# Squeeze point_batch_size dimension
low_res_masks = low_res_masks.squeeze(1) # [B, 3, 256, 256]
iou_scores = iou_pred.squeeze(1) # [B, 3]

# 4. Upsample to full resolution
masks = torch.nn.functional.interpolate(
low_res_masks,
size=(1024, 1024),
mode="bilinear",
align_corners=False,
)

return masks, iou_scores, low_res_masks


# =============================================================================
# HuggingFace Model Class Mapping
# =============================================================================
Expand All @@ -254,6 +367,7 @@ def forward(
# Users wanting the full model use --task image-segmentation.

MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
("sam", "mask-generation"): SAMMaskGeneration,
("sam2", "image-segmentation"): Sam2Model,
("sam2", "feature-extraction"): Sam2VisionEncoder,
("sam2", "image-feature-extraction"): Sam2VisionEncoder,
Expand Down Expand Up @@ -812,13 +926,111 @@ def outputs(self) -> dict[str, dict[int, str]]:
}


# =============================================================================
# SAM v1 Custom Dummy Input Generators
# =============================================================================
class SamEmbeddingsInputGenerator(DummyInputGenerator):
"""Embeddings input generator for SAM v1 mask generation decoder.

Generates:
- image_embeddings: [B, 256, 64, 64] - From vision encoder
"""

SUPPORTED_INPUT_NAMES = ("image_embeddings",)

def __init__(
self,
task: str,
normalized_config: NormalizedConfig,
batch_size: int = 1,
**kwargs,
):
self.task = task
self.batch_size = batch_size

def generate(
self,
input_name: str,
framework: str = "pt",
int_dtype: str = "int64",
float_dtype: str = "fp32",
):
# SAM v1 decoder export expects the canonical embedding shape from the
# vision encoder output; this mirrors the existing SAM2 generator path.
shape = [self.batch_size, 256, 64, 64]
Comment thread
chinazhangchao marked this conversation as resolved.
return self.random_float_tensor(shape, framework=framework, dtype=float_dtype)


# =============================================================================
# SAM v1 Optimum ONNX Export Config Registration
# =============================================================================


# -----------------------------------------------------------------------------
# Mask generation export (SAMMaskGeneration wrapper) - SAM v1
# -----------------------------------------------------------------------------
@register_onnx_overwrite("sam", "mask-generation", library_name="transformers")
class SamMaskGenerationIOConfig(OnnxConfig):
"""ONNX config for SAMMaskGeneration (SAM v1 decoder).

Model: facebook/sam-vit-huge, facebook/sam-vit-large, facebook/sam-vit-base
Uses SAMMaskGeneration nn.Module which takes image_embeddings from the
vision encoder and runs prompt encoding + mask decoding.

Inputs:
- input_points: {0: "batch_size"} [B, 1, N, 2]
- input_labels: {0: "batch_size"} [B, 1, N]
- image_embeddings: {0: "batch_size"} [B, 256, 64, 64]
- mask_input: {0: "batch_size"} [B, 1, 256, 256]
- use_mask_input: {0: "batch_size"} [B]

Outputs:
- masks: {0: "batch_size"} [B, 3, 1024, 1024]
- iou_scores: {0: "batch_size"} [B, 3]
- low_res_masks: {0: "batch_size"} [B, 3, 256, 256]
"""

# SAM v1 also uses 1024x1024 default image size, so this normalized config
# is intentionally shared across SAM v1 and SAM2 export configs.
NORMALIZED_CONFIG_CLASS = Sam2NormalizedVisionConfig
Comment thread
chinazhangchao marked this conversation as resolved.
# SAM v1 reuses SAM2-named generators because prompt/mask tensor shapes
# are identical for this export path.
DUMMY_INPUT_GENERATOR_CLASSES = (
Comment thread
chinazhangchao marked this conversation as resolved.
Sam2PointsInputGenerator,
SamEmbeddingsInputGenerator,
Sam2MaskInputGenerator,
)

@property
def inputs(self) -> dict[str, dict[int, str]]:
"""Return input tensors for SAM v1 mask generation."""
return {
"input_points": {0: "batch_size"},
"input_labels": {0: "batch_size"},
"image_embeddings": {0: "batch_size"},
"mask_input": {0: "batch_size"},
"use_mask_input": {0: "batch_size"},
}

@property
def outputs(self) -> dict[str, dict[int, str]]:
"""Return output tensors for SAM v1 mask generation."""
return {
"masks": {0: "batch_size"},
"iou_scores": {0: "batch_size"},
"low_res_masks": {0: "batch_size"},
}


__all__ = [
"SAM2MaskGeneration",
"SAMMaskGeneration",
"Sam2IOConfig",
"Sam2ImageEncoderIOConfig",
"Sam2MaskGenerationIOConfig",
"Sam2ModelPatcher",
"Sam2NormalizedVisionConfig",
"SamMaskGenerationIOConfig",
"_patched_sam2_multiscale_block_forward",
"_patched_sam2_prompt_encoder_forward",
]
Loading