diff --git a/docs/detection/double_detection_filter.md b/docs/detection/double_detection_filter.md index b02663715..6cb2393fd 100644 --- a/docs/detection/double_detection_filter.md +++ b/docs/detection/double_detection_filter.md @@ -16,12 +16,25 @@ comments: true :::supervision.detection.overlap_filter.box_non_max_suppression +
+

box_soft_non_max_suppression

+
+ +:::supervision.detection.overlap_filter.box_soft_non_max_suppression + +

mask_non_max_suppression

:::supervision.detection.overlap_filter.mask_non_max_suppression +
+

mask_soft_non_max_suppression

+
+ +:::supervision.detection.overlap_filter.mask_soft_non_max_suppression +

box_non_max_merge

diff --git a/supervision/detection/core.py b/supervision/detection/core.py index 113948fc9..ab3ab348d 100644 --- a/supervision/detection/core.py +++ b/supervision/detection/core.py @@ -19,7 +19,9 @@ from supervision.detection.overlap_filter import ( box_non_max_merge, box_non_max_suppression, + box_soft_non_max_suppression, mask_non_max_suppression, + mask_soft_non_max_suppression, ) from supervision.detection.tools.transformers import ( process_transformers_detection_result, @@ -1320,6 +1322,63 @@ def with_nms( return self[indices] + def with_soft_nms( + self, sigma: float = 0.5, class_agnostic: bool = False + ) -> Detections: + """ + Perform soft non-maximum suppression on the current set of object detections. + + Args: + sigma (float): The sigma value to use for the soft non-maximum suppression + algorithm. Defaults to 0.5. + class_agnostic (bool): Whether to perform class-agnostic + non-maximum suppression. If True, the class_id of each detection + will be ignored. Defaults to False. + + Returns: + Detections: A new Detections object containing the subset of detections + after non-maximum suppression. + + Raises: + AssertionError: If `confidence` is None and class_agnostic is False. + """ + if len(self) == 0: + return self + + assert ( + self.confidence is not None + ), "Detections confidence must be given for NMS to be executed." + + if class_agnostic: + predictions = np.hstack((self.xyxy, self.confidence.reshape(-1, 1))) + else: + assert self.class_id is not None, ( + "Detections class_id must be given for NMS to be executed. If you" + " intended to perform class agnostic NMS set class_agnostic=True." + ) + predictions = np.hstack( + ( + self.xyxy, + self.confidence.reshape(-1, 1), + self.class_id.reshape(-1, 1), + ) + ) + + if self.mask is not None: + soft_confidences = mask_soft_non_max_suppression( + predictions=predictions, + masks=self.mask, + sigma=sigma, + ) + self.confidence = soft_confidences + else: + soft_confidences = box_soft_non_max_suppression( + predictions=predictions, sigma=sigma + ) + self.confidence = soft_confidences + + return self + def with_nmm( self, threshold: float = 0.5, class_agnostic: bool = False ) -> Detections: diff --git a/supervision/detection/overlap_filter.py b/supervision/detection/overlap_filter.py index 4c59295f6..a7ef40c19 100644 --- a/supervision/detection/overlap_filter.py +++ b/supervision/detection/overlap_filter.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import List, Union +from typing import List, Tuple, Union import numpy as np import numpy.typing as npt @@ -38,6 +38,48 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray: return resized_masks +def __prepare_data_for_mask_nms( + mask_dimension: int, + masks: np.ndarray, + predictions: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: + """ + Get IOUs from mask. Prepare the data for non-max suppression. + + Args: + mask_dimension (int): The dimension to which the masks should be + resized before computing IOU values. + masks (np.ndarray): A 3D array of binary masks corresponding to the predictions. + Shape: `(N, H, W)`, where N is the number of predictions, and H, W are the + dimensions of each + predictions (np.ndarray): An array of object detection predictions in the format + of `(x_min, y_min, x_max, y_max, score)` or + `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`, + where N is the number of predictions. + + Returns: + Tuple[np.ndarray, np.ndarray, int, np.ndarray]: A tuple containing the + predictions, categories, IOUs, number of rows, and the sorted indices. + + Raises: + AssertionError: If `iou_threshold` is not within the closed range from + `0` to `1`. + """ + rows, columns = predictions.shape + + if columns == 5: + predictions = np.c_[predictions, np.zeros(rows)] + + sort_index = predictions[:, 4].argsort()[::-1] + predictions = predictions[sort_index] + masks = masks[sort_index] + masks_resized = resize_masks(masks, mask_dimension) + ious = mask_iou_batch(masks_resized, masks_resized) + categories = predictions[:, 5] + + return predictions, categories, ious, rows, sort_index + + def mask_non_max_suppression( predictions: np.ndarray, masks: np.ndarray, @@ -72,17 +114,9 @@ def mask_non_max_suppression( "Value of `iou_threshold` must be in the closed range from 0 to 1, " f"{iou_threshold} given." ) - rows, columns = predictions.shape - - if columns == 5: - predictions = np.c_[predictions, np.zeros(rows)] - - sort_index = predictions[:, 4].argsort()[::-1] - predictions = predictions[sort_index] - masks = masks[sort_index] - masks_resized = resize_masks(masks, mask_dimension) - ious = mask_iou_batch(masks_resized, masks_resized) - categories = predictions[:, 5] + _, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( + mask_dimension, masks, predictions + ) keep = np.ones(rows, dtype=bool) for i in range(rows): @@ -93,31 +127,71 @@ def mask_non_max_suppression( return keep[sort_index.argsort()] -def box_non_max_suppression( - predictions: np.ndarray, iou_threshold: float = 0.5 +def mask_soft_non_max_suppression( + predictions: np.ndarray, + masks: np.ndarray, + mask_dimension: int = 640, + sigma: float = 0.5, ) -> np.ndarray: """ - Perform Non-Maximum Suppression (NMS) on object detection predictions. + Perform Soft Non-Maximum Suppression (Soft-NMS) on segmentation predictions. - Args: + Args: predictions (np.ndarray): An array of object detection predictions in the format of `(x_min, y_min, x_max, y_max, score)` or `(x_min, y_min, x_max, y_max, score, class)`. iou_threshold (float): The intersection-over-union threshold to use for non-maximum suppression. + sigma (float): The sigma value to use for soft non-maximum suppression. Returns: - np.ndarray: A boolean array indicating which predictions to keep after n - on-maximum suppression. + np.ndarray: An array containing the updated confidence scores. Raises: AssertionError: If `iou_threshold` is not within the closed range from `0` to `1`. + AssertionError: If `sigma` is not within the open range from `0` to `1`. """ - assert 0 <= iou_threshold <= 1, ( - "Value of `iou_threshold` must be in the closed range from 0 to 1, " - f"{iou_threshold} given." + assert ( + 0 < sigma < 1 + ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." + predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms( + mask_dimension, masks, predictions ) + + not_this_row = np.ones(rows) + for i in range(rows): + not_this_row[i] = 0 + condition = (categories[i] == categories) * not_this_row + predictions[:, 4] = predictions[:, 4] * np.exp( + -(ious[i] ** 2) / sigma * condition + ) + + return predictions[sort_index.argsort(), 4] + + +def __prepare_data_for_box_nsm( + predictions: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: + """ + Prepare the data for non-max suppression. + + Args: + predictions (np.ndarray): An array of object detection predictions in the + format of `(x_min, y_min, x_max, y_max, score)` or + `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`, + where N is the number of predictions. + + Returns: + Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: A tuple containing + the predictions, categories, IOUs, number of rows, and the sorted indices + + Raises: + AssertionError: If `iou_threshold` is not within the closed range from `0` + to `1`. + + + """ rows, columns = predictions.shape # add column #5 - category filled with zeros for agnostic nms @@ -127,14 +201,42 @@ def box_non_max_suppression( # sort predictions column #4 - score sort_index = np.flip(predictions[:, 4].argsort()) predictions = predictions[sort_index] - boxes = predictions[:, :4] categories = predictions[:, 5] ious = box_iou_batch(boxes, boxes) ious = ious - np.eye(rows) - keep = np.ones(rows, dtype=bool) + return predictions, categories, ious, rows, sort_index + + +def box_non_max_suppression( + predictions: np.ndarray, iou_threshold: float = 0.5 +) -> np.ndarray: + """ + Perform Non-Maximum Suppression (NMS) on object detection predictions. + + Args: + predictions (np.ndarray): An array of object detection predictions in + the format of `(x_min, y_min, x_max, y_max, score)` + or `(x_min, y_min, x_max, y_max, score, class)`. + iou_threshold (float): The intersection-over-union threshold + to use for non-maximum suppression. + + Returns: + np.ndarray: A boolean array indicating which predictions to keep after n + on-maximum suppression. + + Raises: + AssertionError: If `iou_threshold` is not within the + closed range from `0` to `1`. + """ + assert 0 <= iou_threshold <= 1, ( + "Value of `iou_threshold` must be in the closed range from 0 to 1, " + f"{iou_threshold} given." + ) + _, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(predictions) + keep = np.ones(rows, dtype=bool) for index, (iou, category) in enumerate(zip(ious, categories)): if not keep[index]: continue @@ -147,6 +249,46 @@ def box_non_max_suppression( return keep[sort_index.argsort()] +def box_soft_non_max_suppression( + predictions: np.ndarray, sigma: float = 0.5 +) -> np.ndarray: + """ + Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions. + + Args: + predictions (np.ndarray): An array of object detection predictions in + the format of `(x_min, y_min, x_max, y_max, score)` + or `(x_min, y_min, x_max, y_max, score, class)`. + iou_threshold (float): The intersection-over-union threshold + to use for soft non-maximum suppression. + sigma (float): The sigma value to use for soft non-maximum suppression. + + Returns: + np.ndarray: An array containing the updated confidence scores. + Raises: + AssertionError: If `iou_threshold` is not within the + closed range from `0` to `1`. + AssertionError: If `sigma` is not within the opened range from `0` to `1`. + """ + + assert ( + 0 < sigma < 1 + ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given." + predictions, categories, ious, rows, sort_index = __prepare_data_for_box_nsm( + predictions + ) + + not_this_row = np.ones(rows) + for i in range(rows): + not_this_row[i] = 0 + condition = (categories[i] == categories) * not_this_row + predictions[:, 4] = predictions[:, 4] * np.exp( + -(ious[i] ** 2) / sigma * condition + ) + + return predictions[sort_index.argsort(), 4] + + def group_overlapping_boxes( predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5 ) -> List[List[int]]: diff --git a/test/detection/test_overlap_filter.py b/test/detection/test_overlap_filter.py index f628c30f9..6b0df77a4 100644 --- a/test/detection/test_overlap_filter.py +++ b/test/detection/test_overlap_filter.py @@ -6,8 +6,10 @@ from supervision.detection.overlap_filter import ( box_non_max_suppression, + box_soft_non_max_suppression, group_overlapping_boxes, mask_non_max_suppression, + mask_soft_non_max_suppression, ) @@ -243,6 +245,109 @@ def test_box_non_max_suppression( assert np.array_equal(result, expected_result) +@pytest.mark.parametrize( + "predictions, sigma, expected_result, exception", + [ + ( + np.empty(shape=(0, 5)), + 0.1, + np.array([]), + DoesNotRaise(), + ), # single box with no category + ( + np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]), + 0.8, + np.array([0.8]), + DoesNotRaise(), + ), # single box with no category + ( + np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]), + 0.9, + np.array([0.8]), + DoesNotRaise(), + ), # single box with category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8], + [15.0, 15.0, 40.0, 40.0, 0.9], + ] + ), + 0.2, + np.array([0.07176137, 0.9]), + DoesNotRaise(), + ), # two boxes with no category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8, 0], + [15.0, 15.0, 40.0, 40.0, 0.9, 1], + ] + ), + 0.3, + np.array([0.8, 0.9]), + DoesNotRaise(), + ), # two boxes with different category + ( + np.array( + [ + [10.0, 10.0, 40.0, 40.0, 0.8, 0], + [15.0, 15.0, 40.0, 40.0, 0.9, 0], + ] + ), + 0.9, + np.array([0.46814354, 0.9]), + DoesNotRaise(), + ), # two boxes with same category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8], + [5.0, 5.0, 35.0, 45.0, 0.9], + [10.0, 10.0, 40.0, 50.0, 0.85], + ] + ), + 0.7, + np.array([0.42648529, 0.9, 0.53109062]), + DoesNotRaise(), + ), # three boxes with no category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8, 0], + [5.0, 5.0, 35.0, 45.0, 0.9, 1], + [10.0, 10.0, 40.0, 50.0, 0.85, 2], + ] + ), + 0.5, + np.array([0.8, 0.9, 0.85]), + DoesNotRaise(), + ), # three boxes with same category + ( + np.array( + [ + [0.0, 0.0, 30.0, 40.0, 0.8, 0], + [5.0, 5.0, 35.0, 45.0, 0.9, 0], + [10.0, 10.0, 40.0, 50.0, 0.85, 1], + ] + ), + 0.9, + np.array([0.55491779, 0.9, 0.85]), + DoesNotRaise(), + ), # three boxes with different category + ], +) +def test_box_soft_non_max_suppression( + predictions: np.ndarray, + sigma: float, + expected_result: Optional[np.ndarray], + exception: Exception, +) -> None: + with exception: + result = box_soft_non_max_suppression(predictions=predictions, sigma=sigma) + np.testing.assert_almost_equal(result, expected_result, decimal=5) + + @pytest.mark.parametrize( "predictions, masks, iou_threshold, expected_result, exception", [ @@ -447,3 +552,211 @@ def test_mask_non_max_suppression( predictions=predictions, masks=masks, iou_threshold=iou_threshold ) assert np.array_equal(result, expected_result) + + +@pytest.mark.parametrize( + "predictions, masks, sigma, expected_result, exception", + [ + ( + np.empty((0, 6)), + np.empty((0, 5, 5)), + 0.1, + np.array([]), + DoesNotRaise(), + ), # empty predictions and masks + ( + np.array([[0, 0, 0, 0, 0.8]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ] + ), + 0.2, + np.array([0.8]), + DoesNotRaise(), + ), # single mask with no category + ( + np.array([[0, 0, 0, 0, 0.8, 0]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ] + ] + ), + 0.99, + np.array([0.8]), + DoesNotRaise(), + ), # single mask with category + ( + np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, False, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.8, + np.array([0.8, 0.9]), + DoesNotRaise(), + ), # two masks non-overlapping with no category + ( + np.array([[0, 0, 0, 0, 0.8], [0, 0, 0, 0, 0.9]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.6, + np.array([0.3831756, 0.9]), + DoesNotRaise(), + ), # two masks partially overlapping with no category + ( + np.array([[0, 0, 0, 0, 0.8, 0], [0, 0, 0, 0, 0.9, 1]]), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, True, True, True, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, True, True, True], + [False, False, False, False, False], + ], + ] + ), + 0.9, + np.array([0.8, 0.9]), + DoesNotRaise(), + ), # two masks partially overlapping with different category + ( + np.array( + [ + [0, 0, 0, 0, 0.8], + [0, 0, 0, 0, 0.85], + [0, 0, 0, 0, 0.9], + ] + ), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, False, False, True, True], + [False, False, False, True, True], + [False, False, False, False, False], + [False, False, False, False, False], + ], + ] + ), + 0.3, + np.array([0.02853919, 0.85, 0.9]), + DoesNotRaise(), + ), # three masks with no category + ( + np.array( + [ + [0, 0, 0, 0, 0.8, 0], + [0, 0, 0, 0, 0.85, 1], + [0, 0, 0, 0, 0.9, 2], + ] + ), + np.array( + [ + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + ], + [ + [False, False, False, False, False], + [False, True, True, False, False], + [False, True, True, False, False], + [False, False, False, False, False], + [False, False, False, False, False], + ], + ] + ), + 0.1, + np.array([0.8, 0.85, 0.9]), + DoesNotRaise(), + ), # three masks with different category + ], +) +def test_mask_soft_non_max_suppression( + predictions: np.ndarray, + masks: np.ndarray, + sigma: float, + expected_result: Optional[np.ndarray], + exception: Exception, +) -> None: + with exception: + result = mask_soft_non_max_suppression( + predictions=predictions, + masks=masks, + sigma=sigma, + ) + np.testing.assert_almost_equal(result, expected_result, decimal=6)