diff --git a/scripts/e2e_eval/testsets/models_all.json b/scripts/e2e_eval/testsets/models_all.json index 862f839b9..82809073c 100644 --- a/scripts/e2e_eval/testsets/models_all.json +++ b/scripts/e2e_eval/testsets/models_all.json @@ -1829,16 +1829,6 @@ "last_update_time": "2024-01-11T19:23:46+00:00", "optimum_supported": true }, - { - "hf_id": "Xenova/slimsam-77-uniform", - "task": "mask-generation", - "model_type": "sam", - "group": "Top200", - "priority": "P1", - "downloads": 9389, - "last_update_time": "2024-12-12T22:29:52+00:00", - "optimum_supported": true - }, { "hf_id": "Zigeng/SlimSAM-uniform-77", "task": "mask-generation", diff --git a/src/winml/modelkit/models/hf/sam.py b/src/winml/modelkit/models/hf/sam.py index b2ba5bd74..a7d6e460a 100644 --- a/src/winml/modelkit/models/hf/sam.py +++ b/src/winml/modelkit/models/hf/sam.py @@ -2,30 +2,30 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""SAM2 HuggingFace Model Patches and Export Configs. +"""SAM/SAM2 HuggingFace model patches and ONNX export configs. -Provides QNN-compatible patches and ONNX export configs for SAM2 -(Segment Anything Model 2) from Meta/Facebook. +Provides QNN-compatible patches and ONNX export configs for both: +- SAM v1 (facebook/sam-vit-*) +- SAM2 / SAM2-video (facebook/sam2-hiera-*) -Key Features: -- QNN-compatible patches: 5D window partition, arithmetic masking -- Split export: Separate encoder and decoder ONNX files +Key features: +- SAM2 QNN-compatible patches: 5D window partition, arithmetic masking +- Split and task-specific exports: encoder, full model, and decoder wrappers -Patch Targets (applied via Sam2ModelPatcher during export): +Patch targets (applied via Sam2ModelPatcher during export): - Sam2MultiScaleBlock: 5D window partition (6D->5D for QNN) - Sam2PromptEncoder: Arithmetic masking (torch.where->arithmetic for ONNX) -Export Strategy (split): -- Sam2ImageEncoderIOConfig: pixel_values -> embeddings + high_res features -- Sam2MaskDecoderIOConfig: prompts + embeddings -> masks + iou_scores - -Model: facebook/sam2-hiera-small, facebook/sam2-hiera-large, etc. -Task: image-segmentation +Export coverage: +- SAM2: image encoder, full model, mask-generation decoder wrapper +- SAM v1: mask-generation decoder wrapper Exports: Sam2NormalizedVisionConfig: NormalizedVisionConfig with 1024 image_size - Sam2ImageEncoderIOConfig: ONNX config for image encoder - Sam2MaskDecoderIOConfig: ONNX config for mask decoder + Sam2ImageEncoderIOConfig: ONNX config for SAM2 image encoder + Sam2IOConfig: ONNX config for SAM2 full model + Sam2MaskGenerationIOConfig: ONNX config for SAM2 mask-generation wrapper + SamMaskGenerationIOConfig: ONNX config for SAM v1 mask-generation wrapper Sam2ModelPatcher: Custom ModelPatcher for SAM2 export patches _patched_sam2_multiscale_block_forward: Patched forward (internal) _patched_sam2_prompt_encoder_forward: Patched forward (internal) @@ -45,7 +45,7 @@ DummyInputGenerator, DummyVisionInputGenerator, ) -from transformers import Sam2Model +from transformers import Sam2Model, SamModel from ...export import register_onnx_overwrite @@ -240,6 +240,119 @@ def forward( return masks, iou_scores, low_res_masks +class SAMMaskGeneration(torch.nn.Module): + """Export wrapper for SAM v1 mask generation (decoder portion). + + Composes prompt_encoder + mask_decoder + positional embeddings + into a single module with explicit I/O signature. + + Mirrors SamModel.forward flow: + 1. Encode prompts (points + optional mask) + 2. Compute positional embeddings + 3. Run mask decoder + + Inputs: + input_points: [B, 1, N, 2] - Point coordinates in pixels + input_labels: [B, 1, N] - Point labels (0=neg, 1=pos, -1=pad) + image_embeddings: [B, 256, 64, 64] - From vision encoder + mask_input: [B, 1, 256, 256] - Previous mask (for refinement) + use_mask_input: [B] - Flag: 0.0=ignore mask, 1.0=use mask + + Outputs: + masks: [B, 3, 1024, 1024] - Full resolution masks + iou_scores: [B, 3] - IoU predictions per mask + low_res_masks: [B, 3, 256, 256] - Low-res masks (for next iteration) + """ + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs) -> SAMMaskGeneration: + """Load from a HuggingFace SamModel checkpoint.""" + sam_model = SamModel.from_pretrained(model_name_or_path, **kwargs) + return cls(sam_model) + + def __init__(self, sam_model): + super().__init__() + + self.prompt_encoder = sam_model.prompt_encoder + self.mask_decoder = sam_model.mask_decoder + self.shared_image_embedding = sam_model.shared_image_embedding + self.image_embedding_size = self.prompt_encoder.image_embedding_size + self.config = sam_model.config + + def _get_image_positional_embeddings(self, batch_size: int = 1) -> torch.Tensor: + """Replicates SamModel.get_image_wide_positional_embeddings().""" + size = self.config.prompt_encoder_config.image_embedding_size + target_device = self.shared_image_embedding.positional_embedding.device + target_dtype = self.shared_image_embedding.positional_embedding.dtype + + grid = torch.ones((size, size), device=target_device, dtype=target_dtype) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / size + x_embed = x_embed / size + + positional_embedding = self.shared_image_embedding(torch.stack([x_embed, y_embed], dim=-1)) + positional_embedding = positional_embedding.permute(2, 0, 1).unsqueeze(0) + return positional_embedding.repeat(batch_size, 1, 1, 1) + + def forward( + self, + input_points: torch.Tensor, + input_labels: torch.Tensor, + image_embeddings: torch.Tensor, + mask_input: torch.Tensor, + use_mask_input: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run mask generation from pre-computed encoder features.""" + batch_size = image_embeddings.shape[0] + + # 1. Prompt embeddings (sparse - points only, mask handled separately) + sparse_embeddings, _ = self.prompt_encoder( + input_points=input_points, + input_labels=input_labels, + input_boxes=None, + input_masks=None, + ) + + # Arithmetic mask blending via use_mask_input flag + # (avoids torch.where for ONNX/QNN compatibility) + mask_dense = self.prompt_encoder.mask_embed(mask_input) + no_mask_dense = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + batch_size, + -1, + self.image_embedding_size[0], + self.image_embedding_size[1], + ) + flag = use_mask_input.reshape(-1, 1, 1, 1).to(mask_dense.dtype) + dense_embeddings = (1.0 - flag) * no_mask_dense + flag * mask_dense + + # 2. Positional embeddings + image_positional_embeddings = self._get_image_positional_embeddings(batch_size) + + # 3. Mask decoder + low_res_masks, iou_pred = self.mask_decoder( + image_embeddings=image_embeddings, + image_positional_embeddings=image_positional_embeddings, + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=True, + ) + + # Squeeze point_batch_size dimension + low_res_masks = low_res_masks.squeeze(1) # [B, 3, 256, 256] + iou_scores = iou_pred.squeeze(1) # [B, 3] + + # 4. Upsample to full resolution + masks = torch.nn.functional.interpolate( + low_res_masks, + size=(1024, 1024), + mode="bilinear", + align_corners=False, + ) + + return masks, iou_scores, low_res_masks + + # ============================================================================= # HuggingFace Model Class Mapping # ============================================================================= @@ -254,6 +367,7 @@ def forward( # Users wanting the full model use --task image-segmentation. MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("sam", "mask-generation"): SAMMaskGeneration, ("sam2", "image-segmentation"): Sam2Model, ("sam2", "feature-extraction"): Sam2VisionEncoder, ("sam2", "image-feature-extraction"): Sam2VisionEncoder, @@ -812,13 +926,111 @@ def outputs(self) -> dict[str, dict[int, str]]: } +# ============================================================================= +# SAM v1 Custom Dummy Input Generators +# ============================================================================= +class SamEmbeddingsInputGenerator(DummyInputGenerator): + """Embeddings input generator for SAM v1 mask generation decoder. + + Generates: + - image_embeddings: [B, 256, 64, 64] - From vision encoder + """ + + SUPPORTED_INPUT_NAMES = ("image_embeddings",) + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = 1, + **kwargs, + ): + self.task = task + self.batch_size = batch_size + + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ): + # SAM v1 decoder export expects the canonical embedding shape from the + # vision encoder output; this mirrors the existing SAM2 generator path. + shape = [self.batch_size, 256, 64, 64] + return self.random_float_tensor(shape, framework=framework, dtype=float_dtype) + + +# ============================================================================= +# SAM v1 Optimum ONNX Export Config Registration +# ============================================================================= + + +# ----------------------------------------------------------------------------- +# Mask generation export (SAMMaskGeneration wrapper) - SAM v1 +# ----------------------------------------------------------------------------- +@register_onnx_overwrite("sam", "mask-generation", library_name="transformers") +class SamMaskGenerationIOConfig(OnnxConfig): + """ONNX config for SAMMaskGeneration (SAM v1 decoder). + + Model: facebook/sam-vit-huge, facebook/sam-vit-large, facebook/sam-vit-base + Uses SAMMaskGeneration nn.Module which takes image_embeddings from the + vision encoder and runs prompt encoding + mask decoding. + + Inputs: + - input_points: {0: "batch_size"} [B, 1, N, 2] + - input_labels: {0: "batch_size"} [B, 1, N] + - image_embeddings: {0: "batch_size"} [B, 256, 64, 64] + - mask_input: {0: "batch_size"} [B, 1, 256, 256] + - use_mask_input: {0: "batch_size"} [B] + + Outputs: + - masks: {0: "batch_size"} [B, 3, 1024, 1024] + - iou_scores: {0: "batch_size"} [B, 3] + - low_res_masks: {0: "batch_size"} [B, 3, 256, 256] + """ + + # SAM v1 also uses 1024x1024 default image size, so this normalized config + # is intentionally shared across SAM v1 and SAM2 export configs. + NORMALIZED_CONFIG_CLASS = Sam2NormalizedVisionConfig + # SAM v1 reuses SAM2-named generators because prompt/mask tensor shapes + # are identical for this export path. + DUMMY_INPUT_GENERATOR_CLASSES = ( + Sam2PointsInputGenerator, + SamEmbeddingsInputGenerator, + Sam2MaskInputGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + """Return input tensors for SAM v1 mask generation.""" + return { + "input_points": {0: "batch_size"}, + "input_labels": {0: "batch_size"}, + "image_embeddings": {0: "batch_size"}, + "mask_input": {0: "batch_size"}, + "use_mask_input": {0: "batch_size"}, + } + + @property + def outputs(self) -> dict[str, dict[int, str]]: + """Return output tensors for SAM v1 mask generation.""" + return { + "masks": {0: "batch_size"}, + "iou_scores": {0: "batch_size"}, + "low_res_masks": {0: "batch_size"}, + } + + __all__ = [ "SAM2MaskGeneration", + "SAMMaskGeneration", "Sam2IOConfig", "Sam2ImageEncoderIOConfig", "Sam2MaskGenerationIOConfig", "Sam2ModelPatcher", "Sam2NormalizedVisionConfig", + "SamMaskGenerationIOConfig", "_patched_sam2_multiscale_block_forward", "_patched_sam2_prompt_encoder_forward", ]