diff --git a/supervision/config.py b/supervision/config.py index b18d2e20b..18236c8a7 100644 --- a/supervision/config.py +++ b/supervision/config.py @@ -1,2 +1,3 @@ CLASS_NAME_DATA_FIELD = "class_name" ORIENTED_BOX_COORDINATES = "xyxyxyxy" +DESCRIPTORS_FIELD = "descriptors" diff --git a/supervision/keypoint/core.py b/supervision/keypoint/core.py index 803b818eb..a55cab725 100644 --- a/supervision/keypoint/core.py +++ b/supervision/keypoint/core.py @@ -7,7 +7,7 @@ import numpy as np import numpy.typing as npt -from supervision.config import CLASS_NAME_DATA_FIELD +from supervision.config import CLASS_NAME_DATA_FIELD, DESCRIPTORS_FIELD from supervision.detection.utils import get_data_item, is_data_equal from supervision.validators import validate_keypoints_fields @@ -542,8 +542,10 @@ def from_transformers(cls, transformers_results: List) -> KeyPoints: ``` """ # noqa: E501 // docs - keypoints_list = [] - scores_list = [] + keypoints_list: List[np.ndarray] = [] + scores_list: List[np.ndarray] = [] + descriptors_list: List[np.ndarray] = [] + data: Dict[str, Any] = {} for result in transformers_results: if "keypoints" in result: @@ -554,12 +556,22 @@ def from_transformers(cls, transformers_results: List) -> KeyPoints: keypoints_list.append(keypoints) scores_list.append(scores) + if "descriptors" in result: + descriptors = result["descriptors"].detach().numpy() + + if descriptors.size > 0: + descriptors_list.append(descriptors) + if not keypoints_list: return cls.empty() + if descriptors_list: + data[DESCRIPTORS_FIELD] = np.array(descriptors_list) + return cls( - xy=np.array(keypoints_list,dtype=np.float32), - confidence=np.array(scores_list,dtype=np.float32), + xy=np.array(keypoints_list, dtype=np.float32), + confidence=np.array(scores_list, dtype=np.float32), + data=data if data else {}, ) def __getitem__(