
Commit c3fb1b1

[SAM3 Video] Add support for multi prompts (#42293)
* add support for multi prompts + fix checkpoints in tests
* Make sure to apply heuristics per prompt group
* simplify NMS to probs
1 parent a1afeca commit c3fb1b1

7 files changed, +425 -87 lines changed

docs/source/en/model_doc/sam3_video.md

Lines changed: 33 additions & 0 deletions
@@ -97,6 +97,39 @@ Processed 51 frames
 >>> print(f"Masks shape: {frame_0_outputs['masks'].shape}")
 ```
 
+You can also track multiple object categories simultaneously by providing multiple prompts. The model efficiently reuses vision features across all prompts:
+
+```python
+>>> # Add multiple text prompts (or use a list in add_text_prompt)
+>>> multi_prompt_session = processor.init_video_session(
+...     video=video_frames,
+...     inference_device=device,
+...     processing_device="cpu",
+...     video_storage_device="cpu",
+...     dtype=torch.bfloat16,
+... )
+>>>
+>>> prompts = ["person", "bed", "lamp"]
+>>> processor.add_text_prompt(multi_prompt_session, prompts)
+>>>
+>>> # Process video - detects objects from ALL prompts in a single pass
+>>> multi_outputs_per_frame = {}
+>>> for model_outputs in model.propagate_in_video_iterator(
+...     inference_session=multi_prompt_session, max_frame_num_to_track=50
+... ):
+...     processed_outputs = processor.postprocess_outputs(multi_prompt_session, model_outputs)
+...     multi_outputs_per_frame[model_outputs.frame_idx] = processed_outputs
+>>>
+>>> # Check which objects were detected by each prompt
+>>> frame_0_outputs = multi_outputs_per_frame[0]
+>>> prompt_to_obj_ids = frame_0_outputs["prompt_to_obj_ids"]
+>>> for prompt, obj_ids in prompt_to_obj_ids.items():
+...     print(f"{prompt}: {len(obj_ids)} objects")
+person: 2 objects
+bed: 1 objects
+lamp: 1 objects
+```
+
 #### Streaming Video Inference
 
 <div class="warning">

src/transformers/models/sam3_video/modeling_sam3_video.py

Lines changed: 249 additions & 64 deletions
Large diffs are not rendered by default.
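
The modeling changes themselves are not rendered above. As a rough, hypothetical illustration of the "apply heuristics per prompt group" idea from the commit message (not the actual `modeling_sam3_video.py` implementation), a minimal sketch could partition detections by prompt ID and rank them within each group:

```python
# Hypothetical sketch only: rank detections by probability and keep the top-k
# independently within each prompt group. The real per-prompt heuristics in
# modeling_sam3_video.py are not shown in this view and may differ.
import torch


def keep_top_detections_per_prompt(probs, prompt_ids, max_per_prompt=5):
    """Return sorted indices of detections to keep, selected per prompt group."""
    keep = []
    prompt_ids_tensor = torch.tensor(prompt_ids, dtype=torch.long)
    for prompt_id in prompt_ids_tensor.unique(sorted=True):
        # Indices of detections that belong to this prompt
        indices = torch.nonzero(prompt_ids_tensor == prompt_id, as_tuple=True)[0]
        # Rank by probability within the group and keep the top-k
        order = torch.argsort(probs[indices], descending=True)
        keep.extend(indices[order[:max_per_prompt]].tolist())
    return sorted(keep)


probs = torch.tensor([0.9, 0.2, 0.8, 0.7])
print(keep_top_detections_per_prompt(probs, prompt_ids=[0, 0, 1, 1], max_per_prompt=1))  # [0, 2]
```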

src/transformers/models/sam3_video/processing_sam3_video.py

Lines changed: 78 additions & 17 deletions
@@ -36,11 +36,11 @@ class Sam3VideoProcessor(ProcessorMixin):
     [`~Sam3ImageProcessor.__call__`] and [`~Sam3VideoProcessor.__call__`] for more information.
 
     Args:
-        image_processor (`Sam2ImageProcessorFast`):
-            An instance of [`Sam2ImageProcessorFast`].
+        image_processor (`Sam3ImageProcessorFast`):
+            An instance of [`Sam3ImageProcessorFast`].
         video_processor (`Sam2VideoVideoProcessor`):
             An instance of [`Sam2VideoVideoProcessor`].
-        tokenizer ([`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]):
+        tokenizer ([`CLIPTokenizer`, `CLIPTokenizerFast`]):
             An instance of [`PreTrainedTokenizer`, `PreTrainedTokenizerFast`]. The tokenizer is a required input.
         target_size (`int`, *optional*):
             The target size (target_size, target_size) to which the image will be resized.
@@ -109,16 +109,36 @@ def __call__(
 
         return encoding_image_processor
 
-    def add_text_prompt(self, inference_session, text):
+    def add_text_prompt(self, inference_session: Sam3VideoInferenceSession, text: Union[str, list[str]]):
         """
-        Add text prompt to the inference session.
+        Add text prompt(s) to the inference session.
+
+        Args:
+            inference_session (`Sam3VideoInferenceSession`): The inference session.
+            text (`str` or `list[str]`): The text prompt(s) to add.
+
+        Returns:
+            `Sam3VideoInferenceSession`: The inference session with the added text prompt(s).
         """
-        encoded_text = self.tokenizer(text, return_tensors="pt", padding="max_length", max_length=32).to(
-            inference_session.inference_device
-        )
-        inference_session.text_attention_mask = encoded_text.attention_mask
-        inference_session.text_input_ids = encoded_text.input_ids
-        inference_session.has_new_text_input = True
+        if isinstance(text, str):
+            text = [text]
+
+        prompt_ids = []
+        for prompt_text in text:
+            # Add prompt and get its ID (reuses existing if duplicate)
+            prompt_id = inference_session.add_prompt(prompt_text)
+
+            # Only encode if this is a new prompt (not already in prompt_input_ids)
+            if prompt_id not in inference_session.prompt_input_ids:
+                encoded_text = self.tokenizer(
+                    prompt_text, return_tensors="pt", padding="max_length", max_length=32
+                ).to(inference_session.inference_device)
+
+                inference_session.prompt_input_ids[prompt_id] = encoded_text.input_ids
+                inference_session.prompt_attention_masks[prompt_id] = encoded_text.attention_mask
+
+            prompt_ids.append(prompt_id)
+
         return inference_session
 
     def init_video_session(
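
A quick usage sketch based on the hunk above: `add_text_prompt` now accepts a single string or a list, and because `inference_session.add_prompt` returns the same ID for a repeated prompt, duplicates are tokenized only once. Assuming a fresh `inference_session` created as in the documentation example earlier:

```python
>>> # Sketch based on the diff above; inference_session comes from init_video_session
>>> session = processor.add_text_prompt(inference_session, ["person", "bed", "person"])
>>> # "person" maps to a single prompt ID, so only two encodings are stored
>>> len(session.prompt_input_ids)
2
```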
@@ -194,20 +214,46 @@ def _apply_non_overlapping_constraints(self, pred_masks):
         pred_masks = torch.where(keep, pred_masks, torch.clamp(pred_masks, max=-10.0))
         return pred_masks
 
-    def _apply_object_wise_non_overlapping_constraints(self, pred_masks, obj_scores, background_value=-10.0):
+    def _apply_object_wise_non_overlapping_constraints(
+        self,
+        pred_masks,
+        obj_scores,
+        background_value=-10.0,
+        prompt_ids=None,
+    ):
         """
-        Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region)
+        Applies non-overlapping constraints object wise (i.e. only one object can claim the overlapping region).
+        Constraints are enforced independently for each prompt group when `prompt_ids` are provided.
         """
+        if prompt_ids is None:
+            return self._apply_object_wise_non_overlapping_constraints_impl(pred_masks, obj_scores, background_value)
+
+        if len(prompt_ids) != pred_masks.size(0):
+            raise ValueError("prompt_ids must have the same length as pred_masks")
+
+        pred_masks_grouped = pred_masks.clone()
+        prompt_ids_tensor = torch.tensor(prompt_ids, device=pred_masks.device, dtype=torch.long)
+        for prompt_id in prompt_ids_tensor.unique(sorted=True):
+            indices = torch.nonzero(prompt_ids_tensor == prompt_id, as_tuple=True)[0]
+            if indices.numel() == 0:
+                continue
+            prompt_masks = self._apply_object_wise_non_overlapping_constraints_impl(
+                pred_masks_grouped[indices],
+                obj_scores[indices],
+                background_value,
+            )
+            pred_masks_grouped[indices] = prompt_masks.to(pred_masks_grouped.dtype)
+        return pred_masks_grouped
+
+    def _apply_object_wise_non_overlapping_constraints_impl(self, pred_masks, obj_scores, background_value=-10.0):
         pred_masks_single_score = torch.where(pred_masks > 0, obj_scores[..., None, None], background_value)
-        # Apply pixel-wise non-overlapping constraint based on mask scores
         pixel_level_non_overlapping_masks = self._apply_non_overlapping_constraints(pred_masks_single_score)
-        # Replace object scores with pixel scores. Note, that now only one object can claim the overlapping region
         pred_masks = torch.where(
             pixel_level_non_overlapping_masks > 0,
             pred_masks,
             torch.clamp(pred_masks, max=background_value),
         )
-        return pred_masks
+        return pred_masks.to(pred_masks_single_score.dtype)
 
     def postprocess_outputs(
         self,
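
To make the grouping concrete, here is a standalone toy example of the same partitioning pattern (hypothetical tensors, not the processor API): with `prompt_ids = [0, 0, 1]`, a contested pixel is resolved only between the first two objects, while the object from the other prompt keeps its claim.

```python
# Toy illustration of per-prompt-group overlap resolution: within a group,
# the highest-scoring object keeps contested pixels; other groups are untouched.
import torch

obj_scores = torch.tensor([0.9, 0.4, 0.7])     # objects 0 and 1 share prompt 0, object 2 is prompt 1
prompt_ids = torch.tensor([0, 0, 1])
masks = torch.ones(3, 2, 2, dtype=torch.bool)  # all three masks claim every pixel

resolved = masks.clone()
for prompt_id in prompt_ids.unique(sorted=True):
    idx = torch.nonzero(prompt_ids == prompt_id, as_tuple=True)[0]
    winner = idx[obj_scores[idx].argmax()]     # winner-takes-all inside the group
    for i in idx:
        if i != winner:
            resolved[i] = False

print(resolved[:, 0, 0])  # tensor([ True, False,  True]): object 1 loses only to object 0
```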
@@ -235,6 +281,8 @@ def postprocess_outputs(
                 (top_left_x, top_left_y, bottom_right_x, bottom_right_y).
             - **masks** (`torch.Tensor` of shape `(num_objects, height, width)`): Binary segmentation masks
                 for each object at the original video resolution.
+            - **prompt_to_obj_ids** (`dict[str, list[int]]`): Mapping from prompt text to list of
+                object IDs detected by that prompt.
         """
         obj_id_to_mask = model_outputs["obj_id_to_mask"]  # low res masks (1, H_low, W_low)
         curr_obj_ids = sorted(obj_id_to_mask.keys())
@@ -301,22 +349,35 @@ def postprocess_outputs(
 
         out_boxes_xyxy = masks_to_boxes(out_binary_masks)
 
-        # apply non-overlapping constraints on the existing masklets
+        # Apply non-overlapping constraints on the existing masklets.
+        # Constraints are enforced independently per prompt group.
         if out_binary_masks.shape[0] > 1:
             assert len(out_binary_masks) == len(out_tracker_probs)
+            prompt_ids_filtered = [
+                inference_session.obj_id_to_prompt_id[int(obj_id)] for obj_id in out_obj_ids.tolist()
+            ]
             out_binary_masks = (
                 self._apply_object_wise_non_overlapping_constraints(
                     out_binary_masks.unsqueeze(1),
                     out_tracker_probs.unsqueeze(1).to(out_binary_masks.device),
                     background_value=0,
+                    prompt_ids=prompt_ids_filtered,
                 ).squeeze(1)
             ) > 0
 
+        # Build prompt_to_obj_ids mapping: group object IDs by their associated prompt text.
+        prompt_to_obj_ids = {}
+        for obj_id in out_obj_ids.tolist():
+            prompt_id = inference_session.obj_id_to_prompt_id[obj_id]
+            prompt_text = inference_session.prompts[prompt_id]
+            prompt_to_obj_ids.setdefault(prompt_text, []).append(obj_id)
+
         outputs = {
             "object_ids": out_obj_ids,
             "scores": out_probs,
             "boxes": out_boxes_xyxy,
             "masks": out_binary_masks,
+            "prompt_to_obj_ids": prompt_to_obj_ids,
         }
         return outputs

tests/models/sam3/test_modeling_sam3.py

Lines changed: 1 addition & 1 deletion
@@ -987,7 +987,7 @@ class Sam3ModelIntegrationTest(unittest.TestCase):
 
     def setUp(self):
         super().setUp()
-        model_name = "../sam3-hf-v4-video-full"
+        model_name = "facebook/sam3"
         self.model = Sam3Model.from_pretrained(model_name).to(torch.float32)
         self.processor = Sam3Processor.from_pretrained(model_name)
         self.model.to(torch_device)

tests/models/sam3_tracker/test_modeling_sam3_tracker.py

Lines changed: 2 additions & 2 deletions
@@ -510,7 +510,7 @@ def prepare_video():
 class Sam3TrackerModelIntegrationTest(unittest.TestCase):
     def setUp(self):
         super().setUp()
-        checkpoint_path = "../sam3-hf-v4-video-full"
+        checkpoint_path = "facebook/sam3"
         self.model = Sam3TrackerModel.from_pretrained(checkpoint_path).to(torch.float32)
         self.processor = Sam3TrackerProcessor.from_pretrained(checkpoint_path)
         self.model.to(torch_device)
@@ -817,7 +817,7 @@ def test_inference_mask_generation_from_existing_points_and_mask(self):
         )
 
     def test_dummy_pipeline_generation(self):
-        generator = pipeline("mask-generation", model="../sam3-hf-v4-video-full", device=torch_device)
+        generator = pipeline("mask-generation", model="facebook/sam3", device=torch_device)
         raw_image = prepare_image()
 
         _ = generator(raw_image, points_per_batch=64)

tests/models/sam3_tracker_video/test_modeling_sam3_tracker_video.py

Lines changed: 2 additions & 2 deletions
@@ -66,8 +66,8 @@ def prepare_video():
 class Sam3TrackerVideoModelIntegrationTest(unittest.TestCase):
     def setUp(self):
         super().setUp()
-        self.video_model = Sam3TrackerVideoModel.from_pretrained("../sam3-hf-v4-video-full").to(torch.float32)
-        self.processor = Sam3TrackerVideoProcessor.from_pretrained("../sam3-hf-v4-video-full")
+        self.video_model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(torch.float32)
+        self.processor = Sam3TrackerVideoProcessor.from_pretrained("facebook/sam3")
         self.video_model.to(torch_device)
         self.video_model.eval()

tests/models/sam3_video/test_modeling_sam3_video.py

Lines changed: 60 additions & 1 deletion
@@ -42,7 +42,7 @@ def prepare_video():
 class Sam3VideoModelIntegrationTest(unittest.TestCase):
     def setUp(self):
         super().setUp()
-        checkpoint_path = "../sam3-hf-v4-video-full"
+        checkpoint_path = "facebook/sam3"
         self.video_model = Sam3VideoModel.from_pretrained(checkpoint_path).to(torch.float32)
         self.processor = Sam3VideoProcessor.from_pretrained(checkpoint_path)
         self.video_model.to(torch_device)
@@ -473,3 +473,62 @@ def test_inference_video_streaming_with_text_prompt(self):
             atol=5e-3,  # Higher tolerance for raw logits
             rtol=5e-3,
         )
+
+    def test_inference_video_multi_prompt(self):
+        """Test multi-prompt tracking - detecting multiple object categories in one pass."""
+        raw_video = prepare_video()
+        inference_session = self.processor.init_video_session(
+            video=raw_video,
+            inference_device=torch_device,
+            processing_device="cpu",
+            video_storage_device="cpu",
+        )
+
+        # Add multiple text prompts
+        prompts = ["person", "bed"]
+        self.processor.add_text_prompt(
+            inference_session=inference_session,
+            text=prompts,
+        )
+
+        # Propagate through video frames
+        outputs_per_frame = {}
+        for model_outputs in self.video_model.propagate_in_video_iterator(
+            inference_session=inference_session,
+            max_frame_num_to_track=3,
+        ):
+            processed_outputs = self.processor.postprocess_outputs(inference_session, model_outputs)
+            outputs_per_frame[model_outputs.frame_idx] = processed_outputs
+
+        # Check we processed the expected number of frames
+        self.assertGreaterEqual(len(outputs_per_frame), 1)
+        self.assertLessEqual(len(outputs_per_frame), 4)
+
+        # Check output structure for each frame
+        for processed_outputs in outputs_per_frame.values():
+            self.assertIn("object_ids", processed_outputs)
+            self.assertIn("scores", processed_outputs)
+            self.assertIn("boxes", processed_outputs)
+            self.assertIn("masks", processed_outputs)
+            self.assertIn("prompt_to_obj_ids", processed_outputs)  # Multi-prompt specific
+
+            # Check prompt_to_obj_ids structure
+            prompt_to_obj_ids = processed_outputs["prompt_to_obj_ids"]
+            self.assertIsInstance(prompt_to_obj_ids, dict)
+            for prompt, obj_ids in prompt_to_obj_ids.items():
+                self.assertIsInstance(prompt, str)
+                self.assertIsInstance(obj_ids, list)
+                # Each object ID should be in the main object_ids list
+                for obj_id in obj_ids:
+                    self.assertIn(obj_id, processed_outputs["object_ids"].tolist())
+
+        # Check that we detected objects from multiple prompts
+        first_frame_outputs = outputs_per_frame[min(outputs_per_frame.keys())]
+        prompt_to_obj_ids = first_frame_outputs["prompt_to_obj_ids"]
+
+        # Should have at least one prompt with detections
+        self.assertGreater(len(prompt_to_obj_ids), 0)
+
+        # All prompts in prompt_to_obj_ids should be from our original prompts
+        for prompt in prompt_to_obj_ids.keys():
+            self.assertIn(prompt, prompts)
