@@ -36,11 +36,11 @@ class Sam3VideoProcessor(ProcessorMixin):
     [`~Sam3ImageProcessor.__call__`] and [`~Sam3VideoProcessor.__call__`] for more information.

     Args:
-        image_processor (`Sam2ImageProcessorFast`):
-            An instance of [`Sam2ImageProcessorFast`].
+        image_processor (`Sam3ImageProcessorFast`):
+            An instance of [`Sam3ImageProcessorFast`].
         video_processor (`Sam2VideoVideoProcessor`):
             An instance of [`Sam2VideoVideoProcessor`].
-        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
+        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
             An instance of [`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]. The tokenizer is a required input.
         target_size (`int`, *optional*):
             The target size (target_size, target_size) to which the image will be resized.
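For orientation, a minimal loading sketch under stated assumptions: the checkpoint id `"org/sam3-video"` is a placeholder, not one confirmed by this diff; only the component types come from the docstring above.

```python
from transformers import Sam3VideoProcessor

# Placeholder checkpoint id for illustration only.
processor = Sam3VideoProcessor.from_pretrained("org/sam3-video")

# Per the docstring above, the composite processor bundles:
#   processor.image_processor -> Sam3ImageProcessorFast
#   processor.video_processor -> Sam2VideoVideoProcessor
#   processor.tokenizer       -> CLIPTokenizer or CLIPTokenizerFast
```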
@@ -109,16 +109,36 @@ def __call__(

         return encoding_image_processor

-    def add_text_prompt(self, inference_session, text):
+    def add_text_prompt(self, inference_session: Sam3VideoInferenceSession, text: Union[str, list[str]]):
         """
-        Add text prompt to the inference session.
+        Add text prompt(s) to the inference session.
+
+        Args:
+            inference_session (`Sam3VideoInferenceSession`): The inference session.
+            text (`str` or `list[str]`): The text prompt(s) to add.
+
+        Returns:
+            `Sam3VideoInferenceSession`: The inference session with the added text prompt(s).
         """
-        encoded_text = self.tokenizer(text, return_tensors="pt", padding="max_length", max_length=32).to(
-            inference_session.inference_device
-        )
-        inference_session.text_attention_mask = encoded_text.attention_mask
-        inference_session.text_input_ids = encoded_text.input_ids
-        inference_session.has_new_text_input = True
+        if isinstance(text, str):
+            text = [text]
+
+        prompt_ids = []
+        for prompt_text in text:
+            # Add the prompt and get its ID (reuses the existing ID for duplicates)
+            prompt_id = inference_session.add_prompt(prompt_text)
+
+            # Only encode if this is a new prompt (not already in prompt_input_ids)
+            if prompt_id not in inference_session.prompt_input_ids:
+                encoded_text = self.tokenizer(
+                    prompt_text, return_tensors="pt", padding="max_length", max_length=32
+                ).to(inference_session.inference_device)
+
+                inference_session.prompt_input_ids[prompt_id] = encoded_text.input_ids
+                inference_session.prompt_attention_masks[prompt_id] = encoded_text.attention_mask
+
+            prompt_ids.append(prompt_id)
+
         return inference_session

     def init_video_session(
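A hedged usage sketch of the new multi-prompt API: `video_frames` and the `init_video_session` keyword arguments are assumptions for illustration; only the `add_text_prompt` semantics come from the hunk above.

```python
# Assumed setup: `video_frames` holds the decoded frames; session kwargs are illustrative.
inference_session = processor.init_video_session(video=video_frames, inference_device="cuda")

# A single string or a list of strings is accepted.
inference_session = processor.add_text_prompt(inference_session, "a cat")
inference_session = processor.add_text_prompt(inference_session, ["a cat", "a red ball"])
# "a cat" resolves to the same prompt ID both times, so it is tokenized only once;
# its input IDs and attention mask are cached in inference_session.prompt_input_ids
# and inference_session.prompt_attention_masks.
```

Note the design shift: instead of a single pending text stored on the session (`text_input_ids`, `has_new_text_input`), the session now keeps per-prompt dictionaries keyed by prompt ID, which is what lets several concurrent text queries coexist.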
@@ -194,20 +214,46 @@ def _apply_non_overlapping_constraints(self, pred_masks):
         pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
         return pred_masks

-    def _apply_object_wise_non_overlapping_constraints(self, pred_masks, obj_scores, background_value=-10.0):
+    def _apply_object_wise_non_overlapping_constraints(
+        self,
+        pred_masks,
+        obj_scores,
+        background_value=-10.0,
+        prompt_ids=None,
+    ):
         """
-        Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region)
+        Applies non-overlapping constraints object-wise (i.e. only one object can claim the overlapping region).
+        Constraints are enforced independently for each prompt group when `prompt_ids` are provided.
         """
+        if prompt_ids is None:
+            return self._apply_object_wise_non_overlapping_constraints_impl(pred_masks, obj_scores, background_value)
+
+        if len(prompt_ids) != pred_masks.size(0):
+            raise ValueError("prompt_ids must have the same length as pred_masks")
+
+        pred_masks_grouped = pred_masks.clone()
+        prompt_ids_tensor = torch.tensor(prompt_ids, device=pred_masks.device, dtype=torch.long)
+        for prompt_id in prompt_ids_tensor.unique(sorted=True):
+            indices = torch.nonzero(prompt_ids_tensor == prompt_id, as_tuple=True)[0]
+            if indices.numel() == 0:
+                continue
+            prompt_masks = self._apply_object_wise_non_overlapping_constraints_impl(
+                pred_masks_grouped[indices],
+                obj_scores[indices],
+                background_value,
+            )
+            pred_masks_grouped[indices] = prompt_masks.to(pred_masks_grouped.dtype)
+        return pred_masks_grouped
+
+    def _apply_object_wise_non_overlapping_constraints_impl(self, pred_masks, obj_scores, background_value=-10.0):
         pred_masks_single_score = torch.where(pred_masks > 0, obj_scores[..., None, None], background_value)
-        # Apply pixel-wise non-overlapping constraint based on mask scores
         pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks_single_score)
-        # Replace object scores with pixel scores. Note, that now only one object can claim the overlapping region
         pred_masks = torch.where(
             pixel_level_non_overlapping_masks > 0,
             pred_masks,
             torch.clamp(pred_masks, max=background_value),
         )
-        return pred_masks
+        return pred_masks.to(pred_masks_single_score.dtype)

     def postprocess_outputs(
         self,
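To see why the per-prompt grouping matters, here is a self-contained toy sketch. `winner_take_all` is a hypothetical helper that condenses `_apply_object_wise_non_overlapping_constraints_impl` plus the pixel-wise constraint into one function; shapes follow the call site (`(num_objects, 1, H, W)` masks, `(num_objects, 1)` scores).

```python
import torch

def winner_take_all(masks, scores, background_value=-10.0):
    # At each pixel, only the highest-scoring object that claims the pixel keeps
    # its logits; every other object is clamped to the background value.
    single = torch.where(masks > 0, scores[..., None, None], background_value)
    keep = single == single.amax(dim=0, keepdim=True)
    return torch.where(keep & (single > 0), masks, torch.clamp(masks, max=background_value))

# Three objects all claim the same 4x4 region; objects 0 and 1 came from prompt 0,
# object 2 from prompt 1.
masks = torch.full((3, 1, 4, 4), 5.0)
scores = torch.tensor([[0.9], [0.4], [0.7]])
prompt_ids = torch.tensor([0, 0, 1])

out = masks.clone()
for pid in prompt_ids.unique():
    idx = (prompt_ids == pid).nonzero(as_tuple=True)[0]
    out[idx] = winner_take_all(out[idx], scores[idx])

print((out > 0).squeeze(1).flatten(1).any(dim=1))  # tensor([ True, False,  True])
```

Under a single global constraint, object 2 (score 0.7) would be suppressed by object 0 (score 0.9) even though they answer different text queries; grouping by prompt ID lets each query keep its own winner, which is exactly what the `prompt_ids` branch above enables.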
@@ -235,6 +281,8 @@ def postprocess_outputs(
                 (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
             - **masks** (`torch.Tensor` of shape `(num_objects, height, width)`): Binary segmentation masks
                 for each object at the original video resolution.
+            - **prompt_to_obj_ids** (`dict[str, list[int]]`): Mapping from prompt text to the list of
+                object IDs detected by that prompt.
         """
         obj_id_to_mask = model_outputs["obj_id_to_mask"]  # low-res masks (1, H_low, W_low)
         curr_obj_ids = sorted(obj_id_to_mask.keys())
@@ -301,22 +349,35 @@ def postprocess_outputs(

         out_boxes_xyxy = masks_to_boxes(out_binary_masks)

-        # apply non-overlapping constraints on the existing masklets
+        # Apply non-overlapping constraints on the existing masklets.
+        # Constraints are enforced independently per prompt group.
         if out_binary_masks.shape[0] > 1:
             assert len(out_binary_masks) == len(out_tracker_probs)
+            prompt_ids_filtered = [
+                inference_session.obj_id_to_prompt_id[int(obj_id)] for obj_id in out_obj_ids.tolist()
+            ]
             out_binary_masks = (
                 self._apply_object_wise_non_overlapping_constraints(
                     out_binary_masks.unsqueeze(1),
                     out_tracker_probs.unsqueeze(1).to(out_binary_masks.device),
                     background_value=0,
+                    prompt_ids=prompt_ids_filtered,
                 ).squeeze(1)
             ) > 0

+        # Build prompt_to_obj_ids mapping: group object IDs by their associated prompt text.
+        prompt_to_obj_ids = {}
+        for obj_id in out_obj_ids.tolist():
+            prompt_id = inference_session.obj_id_to_prompt_id[obj_id]
+            prompt_text = inference_session.prompts[prompt_id]
+            prompt_to_obj_ids.setdefault(prompt_text, []).append(obj_id)
+
         outputs = {
             "object_ids": out_obj_ids,
             "scores": out_probs,
             "boxes": out_boxes_xyxy,
             "masks": out_binary_masks,
+            "prompt_to_obj_ids": prompt_to_obj_ids,
         }
         return outputs
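Finally, a hedged sketch of consuming the enriched output dict. The model forward pass is elided, `model_outputs` is assumed to come from the SAM3 video model for the current frame, and the positional call signature of `postprocess_outputs` is an assumption; only the output keys come from this diff.

```python
# Assumed: `model_outputs` was produced by the SAM3 video model for one frame.
outputs = processor.postprocess_outputs(inference_session, model_outputs)

# New in this change: object IDs grouped by the prompt text that detected them.
for prompt_text, obj_ids in outputs["prompt_to_obj_ids"].items():
    print(f"{prompt_text!r} -> object IDs {obj_ids}")

# Per-object tensors stay index-aligned with outputs["object_ids"]:
object_ids = outputs["object_ids"]  # (num_objects,)
scores = outputs["scores"]          # (num_objects,)
boxes = outputs["boxes"]            # (num_objects, 4), xyxy
masks = outputs["masks"]            # (num_objects, height, width), binary
```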