
Conversation

@yonigozlan
Member

What does this PR do?

Adds multi-prompt support to SAM3 video inference, so a single session can track objects from several text prompts at once. Also fixes checkpoints in tests.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Contributor

@molbap molbap left a comment


Needed feature, nice. Most comments are for a later refactor / post-release cleanup.

Comment on lines +133 to +135
```python
encoded_text = self.tokenizer(
    prompt_text, return_tensors="pt", padding="max_length", max_length=32
).to(inference_session.inference_device)
```

Does each individual prompt have to be padded to length 32? You could instead batch the tokenization over all prompts here; you still have to loop afterwards to slice, but it's a bit faster than individual tokenizer calls.
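A minimal sketch of the suggested batched tokenization, assuming a `prompts` list of strings and the same `self.tokenizer` / `inference_session` as in the snippet above:

```python
# Hedged sketch: batch-tokenize all prompts in one call instead of one call per prompt.
# `prompts`, `self.tokenizer` and `inference_session` are assumed from the surrounding code.
encoded_texts = self.tokenizer(
    prompts, return_tensors="pt", padding="max_length", max_length=32
).to(inference_session.inference_device)
# encoded_texts.input_ids has shape (num_prompts, 32); slice per prompt when needed, e.g.
# input_ids_i = encoded_texts.input_ids[i : i + 1]
```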

Comment on lines +185 to +190
```python
# Multi-prompt support
self.prompts = {}  # prompt_id -> prompt_text
self.prompt_input_ids = {}  # prompt_id -> input_ids
self.prompt_embeddings = {}  # prompt_id -> text embeddings
self.prompt_attention_masks = {}  # prompt_id -> attention_mask
self.obj_id_to_prompt_id = {}  # obj_id -> prompt_id (assigned at detection time)
```

General comment, as we already discussed elsewhere: I think a dataclass holding the various state would be the correct data structure here, for a later cleanup.
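A rough sketch of the dataclass idea under discussion; the names and types are illustrative assumptions, not the actual Sam3 implementation:

```python
# Hypothetical sketch of a dataclass holding the per-prompt state listed above.
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class PromptState:
    prompt_text: str
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    text_embeddings: Optional[torch.Tensor] = None  # filled in lazily


@dataclass
class MultiPromptState:
    prompts: dict[int, PromptState] = field(default_factory=dict)  # prompt_id -> PromptState
    obj_id_to_prompt_id: dict[int, int] = field(default_factory=dict)  # assigned at detection time
```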

```python
        attention_mask=inference_session.prompt_attention_masks[prompt_id],
    )
    inference_session.prompt_embeddings[prompt_id] = text_embeds
else:
```

Here, if we batch the tokenized input_ids from before, we can also batch the prompt embedding computation; the attention mask will do the rest.
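A minimal sketch of what the batched version could look like, assuming the batched `encoded_texts` from the earlier comment and a text encoder call (named `get_text_features` here purely for illustration):

```python
# Hedged sketch: embed all prompts in one forward pass; the attention mask handles padding.
# `get_text_features` is a placeholder name, and `prompt_ids` lists the prompt ids in the
# same order as the batched tokenization.
text_embeds = self.get_text_features(
    input_ids=encoded_texts.input_ids,            # (num_prompts, 32)
    attention_mask=encoded_texts.attention_mask,  # masks out the max_length padding
)
for i, prompt_id in enumerate(prompt_ids):
    inference_session.prompt_embeddings[prompt_id] = text_embeds[i : i + 1]
    inference_session.prompt_attention_masks[prompt_id] = encoded_texts.attention_mask[i : i + 1]
```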

Comment on lines 577 to 637
```python
pred_probs = pred_logits.sigmoid()
presence_scores = presence_logits.sigmoid()
pred_probs = pred_probs * presence_scores

pred_boxes_xyxy = detector_outputs.pred_boxes
pred_masks = detector_outputs.pred_masks
# get the positive detection outputs above threshold
pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
det_out = {
    "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
    "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
    "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
}
run_nms = self.det_nms_thresh > 0.0
if run_nms:
    keep = nms_masks(
        pred_probs=pred_probs[0],
        pred_masks=detector_outputs.pred_masks[0],
        prob_threshold=self.score_threshold_detection,
        iou_threshold=self.det_nms_thresh,
    )
    # set suppressed detections' logits to a very low value
    detector_outputs.pred_logits[0] -= 1e4 * (~keep).float()
    # Recompute pred_probs after NMS suppression
    pred_probs = pred_logits.sigmoid()
    pred_probs = pred_probs * presence_scores

    pred_boxes_xyxy = detector_outputs.pred_boxes
    pred_masks = detector_outputs.pred_masks
    # get the positive detection outputs above threshold
    pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
    det_out = {
        "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
        "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
        "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
    }

all_detections[prompt_id] = det_out
```

Missed that earlier: wouldn't recomputing pred_probs be equivalent to just masking the old ones, like `pred_probs[0, ~keep] = 0`?
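A minimal sketch of the suggested simplification, reusing the names from the quoted snippet; this is the reviewer's alternative, not the merged implementation:

```python
# Instead of lowering the logits and re-running sigmoid, zero the suppressed probabilities
# directly; everything below score_threshold_detection is filtered out just the same.
keep = nms_masks(
    pred_probs=pred_probs[0],
    pred_masks=detector_outputs.pred_masks[0],
    prob_threshold=self.score_threshold_detection,
    iou_threshold=self.det_nms_thresh,
)
pred_probs[0, ~keep] = 0.0
pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
```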

Comment on lines +100 to +130
You can also track multiple object categories simultaneously by providing multiple prompts. The model efficiently reuses vision features across all prompts:

```python
>>> # Add multiple text prompts (or use a list in add_text_prompt)
>>> multi_prompt_session = processor.init_video_session(
... video=video_frames,
... inference_device=device,
... processing_device="cpu",
... video_storage_device="cpu",
... dtype=torch.bfloat16,
... )
>>>
>>> prompts = ["person", "bed", "lamp"]
>>> processor.add_text_prompt(multi_prompt_session, prompts)
>>>
>>> # Process video - detects objects from ALL prompts in a single pass
>>> multi_outputs_per_frame = {}
>>> for model_outputs in model.propagate_in_video_iterator(
... inference_session=multi_prompt_session, max_frame_num_to_track=50
... ):
... processed_outputs = processor.postprocess_outputs(multi_prompt_session, model_outputs)
... multi_outputs_per_frame[model_outputs.frame_idx] = processed_outputs
>>>
>>> # Check which objects were detected by each prompt
>>> frame_0_outputs = multi_outputs_per_frame[0]
>>> prompt_to_obj_ids = frame_0_outputs["prompt_to_obj_ids"]
>>> for prompt, obj_ids in prompt_to_obj_ids.items():
... print(f"{prompt}: {len(obj_ids)} objects")
person: 2 objects
bed: 1 objects
lamp: 1 objects
```

cool 🙌

Comment on lines +1471 to +1516
```python
def _merge_detections_from_prompts(
    self,
    all_detections: dict[int, dict[str, Tensor]],
    inference_session: Sam3VideoInferenceSession,
) -> tuple[dict[str, Tensor], dict[int, int]]:
    """
    Merge detections from multiple prompts into a single detection output.
    Assigns unique object IDs and tracks which prompt detected each object.

    Args:
        all_detections: Dictionary mapping prompt_id to detection outputs
        inference_session: Session to track obj_id -> prompt_id mapping

    Returns:
        Tuple of (merged_det_out, det_idx_to_prompt_id) where det_idx_to_prompt_id
        maps detection index in the merged output to the prompt that produced it.
    """
    merged_bboxes, merged_masks, merged_scores = [], [], []
    det_idx_to_prompt_id = {}
    det_idx = 0

    for prompt_id, det_out in all_detections.items():
        num_dets = len(det_out["bbox"])
        if num_dets > 0:
            merged_bboxes.append(det_out["bbox"])
            merged_masks.append(det_out["mask"])
            merged_scores.append(det_out["scores"])
            for i in range(num_dets):
                det_idx_to_prompt_id[det_idx + i] = prompt_id
        det_idx += num_dets

    if merged_bboxes:
        merged_det_out = {
            "bbox": torch.cat(merged_bboxes),
            "mask": torch.cat(merged_masks),
            "scores": torch.cat(merged_scores),
        }
    else:
        device = inference_session.inference_device
        merged_det_out = {
            "bbox": torch.zeros(0, 4, device=device),
            "mask": torch.zeros(0, self.low_res_mask_size, self.low_res_mask_size, device=device),
            "scores": torch.zeros(0, device=device),
        }

    return merged_det_out, det_idx_to_prompt_id
```

Same comment: if we batch the prompts and keep track of indices, we won't need this in the end, no? OK to leave for a later PR or two with the rest of the cleanup.
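A rough sketch of why batching could make the merge step unnecessary, assuming the detection scores for all prompts are stacked along the batch dimension; this is an assumption about a possible refactor, not the current code:

```python
# If pred_probs has shape (num_prompts, num_queries), thresholding the batched tensor
# already yields the prompt index for every detection, so no explicit merge is needed.
pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
prompt_idx_per_det, query_idx_per_det = pos_pred_idx  # prompt id comes for free
merged_det_out = {
    "bbox": pred_boxes_xyxy[prompt_idx_per_det, query_idx_per_det],
    "mask": pred_masks[prompt_idx_per_det, query_idx_per_det],
    "scores": pred_probs[prompt_idx_per_det, query_idx_per_det],
}
```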

```diff
 super().setUp()
-self.video_model = Sam3TrackerVideoModel.from_pretrained("../sam3-hf-v4-video-full").to(torch.float32)
-self.processor = Sam3TrackerVideoProcessor.from_pretrained("../sam3-hf-v4-video-full")
+self.video_model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(torch.float32)
```

Use the dtype argument rather than .to, where relevant.
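A minimal sketch of the suggested change, assuming from_pretrained accepts a dtype keyword (as in recent transformers releases):

```python
# Load directly in the target dtype instead of converting with .to afterwards.
self.video_model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3", dtype=torch.float32)
```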

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: sam3, sam3_tracker, sam3_tracker_video, sam3_video

@yonigozlan
Member Author

Thanks @molbap for the review! I made the suggested changes that are not too involved, and noted all the more involved ones for the post-release cleanup ;)

@yonigozlan yonigozlan merged commit c3fb1b1 into huggingface:main Nov 20, 2025
17 checks passed