diff --git a/inference/models/__init__.py b/inference/models/__init__.py index 68001cf10c..88bad88df9 100644 --- a/inference/models/__init__.py +++ b/inference/models/__init__.py @@ -44,3 +44,4 @@ YOLOv8KeypointsDetection, YOLOv8ObjectDetection, ) +from inference.models.yolov9 import YOLOv9ObjectDetection diff --git a/inference/models/utils.py b/inference/models/utils.py index 5899a574bc..61d9e81da5 100644 --- a/inference/models/utils.py +++ b/inference/models/utils.py @@ -18,6 +18,7 @@ YOLOv8Classification, YOLOv8InstanceSegmentation, YOLOv8ObjectDetection, + YOLOv9ObjectDetection, ) from inference.models.yolov8.yolov8_keypoints_detection import YOLOv8KeypointsDetection @@ -37,6 +38,7 @@ ("object-detection", "yolov5v6m"): YOLOv5ObjectDetection, ("object-detection", "yolov5v6l"): YOLOv5ObjectDetection, ("object-detection", "yolov5v6x"): YOLOv5ObjectDetection, + ("object-detection", "yolov9"): YOLOv9ObjectDetection, ("object-detection", "yolov8"): YOLOv8ObjectDetection, ("object-detection", "yolov8s"): YOLOv8ObjectDetection, ("object-detection", "yolov8n"): YOLOv8ObjectDetection, diff --git a/inference/models/yolov9/__init__.py b/inference/models/yolov9/__init__.py new file mode 100644 index 0000000000..7db651bf32 --- /dev/null +++ b/inference/models/yolov9/__init__.py @@ -0,0 +1 @@ +from inference.models.yolov9.yolov9_object_detection import YOLOv9ObjectDetection \ No newline at end of file diff --git a/inference/models/yolov9/yolov9_object_detection.py b/inference/models/yolov9/yolov9_object_detection.py new file mode 100644 index 0000000000..88fca2d667 --- /dev/null +++ b/inference/models/yolov9/yolov9_object_detection.py @@ -0,0 +1,45 @@ +from typing import Tuple + +import numpy as np + +from inference.core.models.object_detection_base import ( + ObjectDetectionBaseOnnxRoboflowInferenceModel, +) + + +class YOLOv9ObjectDetection(ObjectDetectionBaseOnnxRoboflowInferenceModel): + """Roboflow ONNX Object detection model (Implements an object detection specific infer method). + + This class is responsible for performing object detection using the YOLOv9 model + with ONNX runtime. + + Attributes: + weights_file (str): Path to the ONNX weights file. + """ + + @property + def weights_file(self) -> str: + """Gets the weights file for the YOLOv9 model. + + Returns: + str: Path to the ONNX weights file. + """ + return "weights.onnx" + + def predict(self, img_in: np.ndarray, **kwargs) -> Tuple[np.ndarray]: + """Performs object detection on the given image using the ONNX session. + + Args: + img_in (np.ndarray): Input image as a NumPy array. + + Returns: + Tuple[np.ndarray]: NumPy array representing the predictions. + """ + # (b x 8 x 8000) + predictions = self.onnx_session.run(None, {self.input_name: img_in})[0] + predictions = predictions.transpose(0, 2, 1) + boxes = predictions[:, :, :4] + class_confs = predictions[:, :, 4:] + confs = np.expand_dims(np.max(class_confs, axis=2), axis=2) + predictions = np.concatenate([boxes, confs, class_confs], axis=2) + return (predictions,)