[SAM3 Video] Add support for multi prompts #42293
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
molbap left a comment
Needed feature, nice. Most comments are for a later refactor/post-release cleanup.
```python
encoded_text = self.tokenizer(
    prompt_text, return_tensors="pt", padding="max_length", max_length=32
).to(inference_session.inference_device)
```
Does each individual prompt have to be at least length 32? You could instead batch the tokenization over prompts here. You still have to loop to slice afterwards, but it's a bit faster than individual tokenizer calls.
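A minimal sketch of what the batched call could look like, assuming `prompts` is the list of prompt strings and that prompt ids are simply their positions in that list (the storage targets are the session dicts from the diff below):

```python
# One tokenizer call over all prompts instead of one call per prompt.
encoded_texts = self.tokenizer(
    prompts,  # list[str] rather than a single prompt_text
    return_tensors="pt",
    padding="max_length",
    max_length=32,
).to(inference_session.inference_device)

# Slice per prompt afterwards when storing into the session state.
for prompt_id in range(len(prompts)):
    inference_session.prompt_input_ids[prompt_id] = encoded_texts.input_ids[prompt_id : prompt_id + 1]
    inference_session.prompt_attention_masks[prompt_id] = encoded_texts.attention_mask[prompt_id : prompt_id + 1]
```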
```python
# Multi-prompt support
self.prompts = {}  # prompt_id -> prompt_text
self.prompt_input_ids = {}  # prompt_id -> input_ids
self.prompt_embeddings = {}  # prompt_id -> text embeddings
self.prompt_attention_masks = {}  # prompt_id -> attention_mask
self.obj_id_to_prompt_id = {}  # obj_id -> prompt_id (assigned at detection time)
```
General comment, as we already discussed elsewhere: I think a dataclass holding the various state would be the correct data structure here, for a later cleanup.
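Roughly what that could look like, as an illustration only (field names copied from the dicts in the diff; not a proposal for the final API):

```python
from dataclasses import dataclass, field
from typing import Optional

import torch


@dataclass
class PromptState:
    """Per-prompt state that is currently spread across parallel dicts."""
    text: str
    input_ids: torch.Tensor
    attention_mask: torch.Tensor
    embeddings: Optional[torch.Tensor] = None  # filled in lazily


@dataclass
class MultiPromptState:
    """Groups all prompt-related session state in one object."""
    prompts: dict[int, PromptState] = field(default_factory=dict)  # prompt_id -> PromptState
    obj_id_to_prompt_id: dict[int, int] = field(default_factory=dict)  # assigned at detection time
```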
```python
    attention_mask=inference_session.prompt_attention_masks[prompt_id],
)
inference_session.prompt_embeddings[prompt_id] = text_embeds
else:
```
Here, if we batch the input_ids tokenized above, we can also batch-process the prompt embeddings; the attention mask will do the rest.
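Sketch under the same batching assumption; `get_text_features` is a hypothetical stand-in for whatever the actual text-encoder entry point is:

```python
# Stack per-prompt tensors and run the text encoder once for all prompts.
input_ids = torch.cat([inference_session.prompt_input_ids[p] for p in prompt_ids], dim=0)
attention_mask = torch.cat([inference_session.prompt_attention_masks[p] for p in prompt_ids], dim=0)

text_embeds = self.model.get_text_features(  # hypothetical name for the text encoder call
    input_ids=input_ids,
    attention_mask=attention_mask,
)

# Split the batched embeddings back out per prompt.
for i, prompt_id in enumerate(prompt_ids):
    inference_session.prompt_embeddings[prompt_id] = text_embeds[i : i + 1]
```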
```python
pred_probs = pred_logits.sigmoid()
presence_scores = presence_logits.sigmoid()
pred_probs = pred_probs * presence_scores

pred_boxes_xyxy = detector_outputs.pred_boxes
pred_masks = detector_outputs.pred_masks
# get the positive detection outputs above threshold
pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
det_out = {
    "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
    "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
    "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
}
run_nms = self.det_nms_thresh > 0.0
if run_nms:
    keep = nms_masks(
        pred_probs=pred_probs[0],
        pred_masks=detector_outputs.pred_masks[0],
        prob_threshold=self.score_threshold_detection,
        iou_threshold=self.det_nms_thresh,
    )
    # set suppressed detections' logits to a very low value
    detector_outputs.pred_logits[0] -= 1e4 * (~keep).float()
    # Recompute pred_probs after NMS suppression
    pred_probs = pred_logits.sigmoid()
    pred_probs = pred_probs * presence_scores

pred_boxes_xyxy = detector_outputs.pred_boxes
pred_masks = detector_outputs.pred_masks
# get the positive detection outputs above threshold
pos_pred_idx = torch.where(pred_probs > self.score_threshold_detection)
det_out = {
    "bbox": pred_boxes_xyxy[pos_pred_idx[0], pos_pred_idx[1]],
    "mask": pred_masks[pos_pred_idx[0], pos_pred_idx[1]],
    "scores": pred_probs[pos_pred_idx[0], pos_pred_idx[1]],
}

all_detections[prompt_id] = det_out
```
Missed that earlier: wouldn't recomputing pred_probs be equivalent to just masking the already-computed ones, i.e. pred_probs[0, ~keep] = 0?
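A quick self-contained check of that equivalence on toy tensors (not the actual model outputs), assuming `keep` is the boolean mask returned by `nms_masks`:

```python
import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
presence = torch.tensor([[0.9, 0.9, 0.9]])
keep = torch.tensor([True, False, True])

# Current approach: push suppressed logits very low, then recompute probabilities.
suppressed_logits = logits.clone()
suppressed_logits[0] -= 1e4 * (~keep).float()
probs_recomputed = suppressed_logits.sigmoid() * presence

# Suggested approach: zero out the already-computed probabilities in place.
probs_masked = logits.sigmoid() * presence
probs_masked[0, ~keep] = 0.0

print(torch.allclose(probs_recomputed, probs_masked, atol=1e-6))  # True
```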
You can also track multiple object categories simultaneously by providing multiple prompts. The model efficiently reuses vision features across all prompts:

```python
>>> # Add multiple text prompts (or use a list in add_text_prompt)
>>> multi_prompt_session = processor.init_video_session(
...     video=video_frames,
...     inference_device=device,
...     processing_device="cpu",
...     video_storage_device="cpu",
...     dtype=torch.bfloat16,
... )
>>>
>>> prompts = ["person", "bed", "lamp"]
>>> processor.add_text_prompt(multi_prompt_session, prompts)
>>>
>>> # Process video - detects objects from ALL prompts in a single pass
>>> multi_outputs_per_frame = {}
>>> for model_outputs in model.propagate_in_video_iterator(
...     inference_session=multi_prompt_session, max_frame_num_to_track=50
... ):
...     processed_outputs = processor.postprocess_outputs(multi_prompt_session, model_outputs)
...     multi_outputs_per_frame[model_outputs.frame_idx] = processed_outputs
>>>
>>> # Check which objects were detected by each prompt
>>> frame_0_outputs = multi_outputs_per_frame[0]
>>> prompt_to_obj_ids = frame_0_outputs["prompt_to_obj_ids"]
>>> for prompt, obj_ids in prompt_to_obj_ids.items():
...     print(f"{prompt}: {len(obj_ids)} objects")
person: 2 objects
bed: 1 objects
lamp: 1 objects
```
cool 🙌
```python
def _merge_detections_from_prompts(
    self,
    all_detections: dict[int, dict[str, Tensor]],
    inference_session: Sam3VideoInferenceSession,
) -> tuple[dict[str, Tensor], dict[int, int]]:
    """
    Merge detections from multiple prompts into a single detection output.
    Assigns unique object IDs and tracks which prompt detected each object.

    Args:
        all_detections: Dictionary mapping prompt_id to detection outputs
        inference_session: Session to track obj_id -> prompt_id mapping

    Returns:
        Tuple of (merged_det_out, det_idx_to_prompt_id) where det_idx_to_prompt_id
        maps detection index in the merged output to the prompt that produced it.
    """
    merged_bboxes, merged_masks, merged_scores = [], [], []
    det_idx_to_prompt_id = {}
    det_idx = 0

    for prompt_id, det_out in all_detections.items():
        num_dets = len(det_out["bbox"])
        if num_dets > 0:
            merged_bboxes.append(det_out["bbox"])
            merged_masks.append(det_out["mask"])
            merged_scores.append(det_out["scores"])
            for i in range(num_dets):
                det_idx_to_prompt_id[det_idx + i] = prompt_id
            det_idx += num_dets

    if merged_bboxes:
        merged_det_out = {
            "bbox": torch.cat(merged_bboxes),
            "mask": torch.cat(merged_masks),
            "scores": torch.cat(merged_scores),
        }
    else:
        device = inference_session.inference_device
        merged_det_out = {
            "bbox": torch.zeros(0, 4, device=device),
            "mask": torch.zeros(0, self.low_res_mask_size, self.low_res_mask_size, device=device),
            "scores": torch.zeros(0, device=device),
        }

    return merged_det_out, det_idx_to_prompt_id
```
Same here: if we batch prompts and keep track of indices, we won't need this in the end, no? OK to have it in a later PR or two with the rest of the cleanup.
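Sketch of what the merge could reduce to once detections come out batched over prompts; the shapes and names here are illustrative assumptions, not the actual Sam3 attributes:

```python
import torch

# Assume batched detector outputs of shape (num_prompts, num_queries, ...).
num_prompts, num_queries = 3, 5
pred_probs = torch.rand(num_prompts, num_queries)
pred_boxes = torch.rand(num_prompts, num_queries, 4)

# A single threshold over the whole batch replaces the per-prompt loop:
# the prompt index is recovered directly from the row index of each detection.
prompt_idx, query_idx = torch.where(pred_probs > 0.5)
merged_det_out = {
    "bbox": pred_boxes[prompt_idx, query_idx],    # (num_dets, 4)
    "scores": pred_probs[prompt_idx, query_idx],  # (num_dets,)
    "prompt_id": prompt_idx,                      # which prompt produced each detection
}
```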
```python
super().setUp()
self.video_model = Sam3TrackerVideoModel.from_pretrained("../sam3-hf-v4-video-full").to(torch.float32)
self.processor = Sam3TrackerVideoProcessor.from_pretrained("../sam3-hf-v4-video-full")
self.video_model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3").to(torch.float32)
```
dtype rather than .to, where relevant
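For example, roughly what that would look like (the kwarg is dtype on recent transformers versions, torch_dtype on older ones):

```python
import torch
from transformers import Sam3TrackerVideoModel

# Load directly in the target dtype instead of casting with .to() afterwards.
video_model = Sam3TrackerVideoModel.from_pretrained("facebook/sam3", dtype=torch.float32)
```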
[For maintainers] Suggested jobs to run (before merge): run-slow: sam3, sam3_tracker, sam3_tracker_video, sam3_video
Thanks @molbap for the review! I made the changes you suggested that are not too involved, and noted all the ones that are more involved for the post-release cleanup ;)
What does this PR do?
Also fixes checkpoints in tests