22# Copyright (c) Microsoft Corporation. All rights reserved.
33# Licensed under the MIT License.
44# --------------------------------------------------------------------------
5- """SAM2 HuggingFace Model Patches and Export Configs .
5+ """SAM/ SAM2 HuggingFace model patches and ONNX export configs .
66
7- Provides QNN-compatible patches and ONNX export configs for SAM2
8- (Segment Anything Model 2) from Meta/Facebook.
7+ Provides QNN-compatible patches and ONNX export configs for both:
8+ - SAM v1 (facebook/sam-vit-*)
9+ - SAM2 / SAM2-video (facebook/sam2-hiera-*)
910
10- Key Features :
11- - QNN-compatible patches: 5D window partition, arithmetic masking
12- - Split export: Separate encoder and decoder ONNX files
11+ Key features :
12+ - SAM2 QNN-compatible patches: 5D window partition, arithmetic masking
13+ - Split and task-specific exports: encoder, full model, and decoder wrappers
1314
14- Patch Targets (applied via Sam2ModelPatcher during export):
15+ Patch targets (applied via Sam2ModelPatcher during export):
1516- Sam2MultiScaleBlock: 5D window partition (6D->5D for QNN)
1617- Sam2PromptEncoder: Arithmetic masking (torch.where->arithmetic for ONNX)
1718
18- Export Strategy (split):
19- - Sam2ImageEncoderIOConfig: pixel_values -> embeddings + high_res features
20- - Sam2MaskDecoderIOConfig: prompts + embeddings -> masks + iou_scores
21-
22- Model: facebook/sam2-hiera-small, facebook/sam2-hiera-large, etc.
23- Task: image-segmentation
19+ Export coverage:
20+ - SAM2: image encoder, full model, mask-generation decoder wrapper
21+ - SAM v1: mask-generation decoder wrapper
2422
2523Exports:
2624 Sam2NormalizedVisionConfig: NormalizedVisionConfig with 1024 image_size
27- Sam2ImageEncoderIOConfig: ONNX config for image encoder
28- Sam2MaskDecoderIOConfig: ONNX config for mask decoder
25+ Sam2ImageEncoderIOConfig: ONNX config for SAM2 image encoder
26+ Sam2IOConfig: ONNX config for SAM2 full model
27+ Sam2MaskGenerationIOConfig: ONNX config for SAM2 mask-generation wrapper
28+ SamMaskGenerationIOConfig: ONNX config for SAM v1 mask-generation wrapper
2929 Sam2ModelPatcher: Custom ModelPatcher for SAM2 export patches
3030 _patched_sam2_multiscale_block_forward: Patched forward (internal)
3131 _patched_sam2_prompt_encoder_forward: Patched forward (internal)
4545 DummyInputGenerator ,
4646 DummyVisionInputGenerator ,
4747)
48- from transformers import Sam2Model
48+ from transformers import Sam2Model , SamModel
4949
5050from ...export import register_onnx_overwrite
5151
@@ -240,6 +240,119 @@ def forward(
240240 return masks , iou_scores , low_res_masks
241241
242242
243+ class SAMMaskGeneration (torch .nn .Module ):
244+ """Export wrapper for SAM v1 mask generation (decoder portion).
245+
246+ Composes prompt_encoder + mask_decoder + positional embeddings
247+ into a single module with explicit I/O signature.
248+
249+ Mirrors SamModel.forward flow:
250+ 1. Encode prompts (points + optional mask)
251+ 2. Compute positional embeddings
252+ 3. Run mask decoder
253+
254+ Inputs:
255+ input_points: [B, 1, N, 2] - Point coordinates in pixels
256+ input_labels: [B, 1, N] - Point labels (0=neg, 1=pos, -1=pad)
257+ image_embeddings: [B, 256, 64, 64] - From vision encoder
258+ mask_input: [B, 1, 256, 256] - Previous mask (for refinement)
259+ use_mask_input: [B] - Flag: 0.0=ignore mask, 1.0=use mask
260+
261+ Outputs:
262+ masks: [B, 3, 1024, 1024] - Full resolution masks
263+ iou_scores: [B, 3] - IoU predictions per mask
264+ low_res_masks: [B, 3, 256, 256] - Low-res masks (for next iteration)
265+ """
266+
267+ @classmethod
268+ def from_pretrained (cls , model_name_or_path : str , ** kwargs ) -> SAMMaskGeneration :
269+ """Load from a HuggingFace SamModel checkpoint."""
270+ sam_model = SamModel .from_pretrained (model_name_or_path , ** kwargs )
271+ return cls (sam_model )
272+
273+ def __init__ (self , sam_model ):
274+ super ().__init__ ()
275+
276+ self .prompt_encoder = sam_model .prompt_encoder
277+ self .mask_decoder = sam_model .mask_decoder
278+ self .shared_image_embedding = sam_model .shared_image_embedding
279+ self .image_embedding_size = self .prompt_encoder .image_embedding_size
280+ self .config = sam_model .config
281+
282+ def _get_image_positional_embeddings (self , batch_size : int = 1 ) -> torch .Tensor :
283+ """Replicates SamModel.get_image_wide_positional_embeddings()."""
284+ size = self .config .prompt_encoder_config .image_embedding_size
285+ target_device = self .shared_image_embedding .positional_embedding .device
286+ target_dtype = self .shared_image_embedding .positional_embedding .dtype
287+
288+ grid = torch .ones ((size , size ), device = target_device , dtype = target_dtype )
289+ y_embed = grid .cumsum (dim = 0 ) - 0.5
290+ x_embed = grid .cumsum (dim = 1 ) - 0.5
291+ y_embed = y_embed / size
292+ x_embed = x_embed / size
293+
294+ positional_embedding = self .shared_image_embedding (torch .stack ([x_embed , y_embed ], dim = - 1 ))
295+ positional_embedding = positional_embedding .permute (2 , 0 , 1 ).unsqueeze (0 )
296+ return positional_embedding .repeat (batch_size , 1 , 1 , 1 )
297+
298+ def forward (
299+ self ,
300+ input_points : torch .Tensor ,
301+ input_labels : torch .Tensor ,
302+ image_embeddings : torch .Tensor ,
303+ mask_input : torch .Tensor ,
304+ use_mask_input : torch .Tensor ,
305+ ) -> tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
306+ """Run mask generation from pre-computed encoder features."""
307+ batch_size = image_embeddings .shape [0 ]
308+
309+ # 1. Prompt embeddings (sparse - points only, mask handled separately)
310+ sparse_embeddings , _ = self .prompt_encoder (
311+ input_points = input_points ,
312+ input_labels = input_labels ,
313+ input_boxes = None ,
314+ input_masks = None ,
315+ )
316+
317+ # Arithmetic mask blending via use_mask_input flag
318+ # (avoids torch.where for ONNX/QNN compatibility)
319+ mask_dense = self .prompt_encoder .mask_embed (mask_input )
320+ no_mask_dense = self .prompt_encoder .no_mask_embed .weight .reshape (1 , - 1 , 1 , 1 ).expand (
321+ batch_size ,
322+ - 1 ,
323+ self .image_embedding_size [0 ],
324+ self .image_embedding_size [1 ],
325+ )
326+ flag = use_mask_input .reshape (- 1 , 1 , 1 , 1 ).to (mask_dense .dtype )
327+ dense_embeddings = (1.0 - flag ) * no_mask_dense + flag * mask_dense
328+
329+ # 2. Positional embeddings
330+ image_positional_embeddings = self ._get_image_positional_embeddings (batch_size )
331+
332+ # 3. Mask decoder
333+ low_res_masks , iou_pred = self .mask_decoder (
334+ image_embeddings = image_embeddings ,
335+ image_positional_embeddings = image_positional_embeddings ,
336+ sparse_prompt_embeddings = sparse_embeddings ,
337+ dense_prompt_embeddings = dense_embeddings ,
338+ multimask_output = True ,
339+ )
340+
341+ # Squeeze point_batch_size dimension
342+ low_res_masks = low_res_masks .squeeze (1 ) # [B, 3, 256, 256]
343+ iou_scores = iou_pred .squeeze (1 ) # [B, 3]
344+
345+ # 4. Upsample to full resolution
346+ masks = torch .nn .functional .interpolate (
347+ low_res_masks ,
348+ size = (1024 , 1024 ),
349+ mode = "bilinear" ,
350+ align_corners = False ,
351+ )
352+
353+ return masks , iou_scores , low_res_masks
354+
355+
243356# =============================================================================
244357# HuggingFace Model Class Mapping
245358# =============================================================================
@@ -254,6 +367,7 @@ def forward(
254367# Users wanting the full model use --task image-segmentation.
255368
256369MODEL_CLASS_MAPPING : dict [tuple [str , str ], type ] = {
370+ ("sam" , "mask-generation" ): SAMMaskGeneration ,
257371 ("sam2" , "image-segmentation" ): Sam2Model ,
258372 ("sam2" , "feature-extraction" ): Sam2VisionEncoder ,
259373 ("sam2" , "image-feature-extraction" ): Sam2VisionEncoder ,
@@ -812,13 +926,111 @@ def outputs(self) -> dict[str, dict[int, str]]:
812926 }
813927
814928
929+ # =============================================================================
930+ # SAM v1 Custom Dummy Input Generators
931+ # =============================================================================
932+ class SamEmbeddingsInputGenerator (DummyInputGenerator ):
933+ """Embeddings input generator for SAM v1 mask generation decoder.
934+
935+ Generates:
936+ - image_embeddings: [B, 256, 64, 64] - From vision encoder
937+ """
938+
939+ SUPPORTED_INPUT_NAMES = ("image_embeddings" ,)
940+
941+ def __init__ (
942+ self ,
943+ task : str ,
944+ normalized_config : NormalizedConfig ,
945+ batch_size : int = 1 ,
946+ ** kwargs ,
947+ ):
948+ self .task = task
949+ self .batch_size = batch_size
950+
951+ def generate (
952+ self ,
953+ input_name : str ,
954+ framework : str = "pt" ,
955+ int_dtype : str = "int64" ,
956+ float_dtype : str = "fp32" ,
957+ ):
958+ # SAM v1 decoder export expects the canonical embedding shape from the
959+ # vision encoder output; this mirrors the existing SAM2 generator path.
960+ shape = [self .batch_size , 256 , 64 , 64 ]
961+ return self .random_float_tensor (shape , framework = framework , dtype = float_dtype )
962+
963+
964+ # =============================================================================
965+ # SAM v1 Optimum ONNX Export Config Registration
966+ # =============================================================================
967+
968+
969+ # -----------------------------------------------------------------------------
970+ # Mask generation export (SAMMaskGeneration wrapper) - SAM v1
971+ # -----------------------------------------------------------------------------
972+ @register_onnx_overwrite ("sam" , "mask-generation" , library_name = "transformers" )
973+ class SamMaskGenerationIOConfig (OnnxConfig ):
974+ """ONNX config for SAMMaskGeneration (SAM v1 decoder).
975+
976+ Model: facebook/sam-vit-huge, facebook/sam-vit-large, facebook/sam-vit-base
977+ Uses SAMMaskGeneration nn.Module which takes image_embeddings from the
978+ vision encoder and runs prompt encoding + mask decoding.
979+
980+ Inputs:
981+ - input_points: {0: "batch_size"} [B, 1, N, 2]
982+ - input_labels: {0: "batch_size"} [B, 1, N]
983+ - image_embeddings: {0: "batch_size"} [B, 256, 64, 64]
984+ - mask_input: {0: "batch_size"} [B, 1, 256, 256]
985+ - use_mask_input: {0: "batch_size"} [B]
986+
987+ Outputs:
988+ - masks: {0: "batch_size"} [B, 3, 1024, 1024]
989+ - iou_scores: {0: "batch_size"} [B, 3]
990+ - low_res_masks: {0: "batch_size"} [B, 3, 256, 256]
991+ """
992+
993+ # SAM v1 also uses 1024x1024 default image size, so this normalized config
994+ # is intentionally shared across SAM v1 and SAM2 export configs.
995+ NORMALIZED_CONFIG_CLASS = Sam2NormalizedVisionConfig
996+ # SAM v1 reuses SAM2-named generators because prompt/mask tensor shapes
997+ # are identical for this export path.
998+ DUMMY_INPUT_GENERATOR_CLASSES = (
999+ Sam2PointsInputGenerator ,
1000+ SamEmbeddingsInputGenerator ,
1001+ Sam2MaskInputGenerator ,
1002+ )
1003+
1004+ @property
1005+ def inputs (self ) -> dict [str , dict [int , str ]]:
1006+ """Return input tensors for SAM v1 mask generation."""
1007+ return {
1008+ "input_points" : {0 : "batch_size" },
1009+ "input_labels" : {0 : "batch_size" },
1010+ "image_embeddings" : {0 : "batch_size" },
1011+ "mask_input" : {0 : "batch_size" },
1012+ "use_mask_input" : {0 : "batch_size" },
1013+ }
1014+
1015+ @property
1016+ def outputs (self ) -> dict [str , dict [int , str ]]:
1017+ """Return output tensors for SAM v1 mask generation."""
1018+ return {
1019+ "masks" : {0 : "batch_size" },
1020+ "iou_scores" : {0 : "batch_size" },
1021+ "low_res_masks" : {0 : "batch_size" },
1022+ }
1023+
1024+
8151025__all__ = [
8161026 "SAM2MaskGeneration" ,
1027+ "SAMMaskGeneration" ,
8171028 "Sam2IOConfig" ,
8181029 "Sam2ImageEncoderIOConfig" ,
8191030 "Sam2MaskGenerationIOConfig" ,
8201031 "Sam2ModelPatcher" ,
8211032 "Sam2NormalizedVisionConfig" ,
1033+ "SamMaskGenerationIOConfig" ,
8221034 "_patched_sam2_multiscale_block_forward" ,
8231035 "_patched_sam2_prompt_encoder_forward" ,
8241036]
0 commit comments