-
Notifications
You must be signed in to change notification settings - Fork 58
/
tracker.py
42 lines (37 loc) · 1.79 KB
/
tracker.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import sys
sys.path.insert(0, './YOLOX')
from YOLOX.yolox.data.datasets.coco_classes import COCO_CLASSES
from detector import Detector
from deep_sort.utils.parser import get_config
from deep_sort.deep_sort import DeepSort
import torch
import cv2
from utils.visualize import vis_track
class_names = COCO_CLASSES
class Tracker():
def __init__(self, filter_class=None, model='yolox-s', ckpt='yolox_s.pth.tar', ):
self.detector = Detector(model, ckpt)
cfg = get_config()
cfg.merge_from_file("deep_sort/configs/deep_sort.yaml")
self.deepsort = DeepSort(cfg.DEEPSORT.REID_CKPT,
max_dist=cfg.DEEPSORT.MAX_DIST, min_confidence=cfg.DEEPSORT.MIN_CONFIDENCE,
nms_max_overlap=cfg.DEEPSORT.NMS_MAX_OVERLAP, max_iou_distance=cfg.DEEPSORT.MAX_IOU_DISTANCE,
max_age=cfg.DEEPSORT.MAX_AGE, n_init=cfg.DEEPSORT.N_INIT, nn_budget=cfg.DEEPSORT.NN_BUDGET,
use_cuda=True)
self.filter_class = filter_class
def update(self, image):
info = self.detector.detect(image, visual=False)
outputs = []
if info['box_nums']>0:
bbox_xywh = []
scores = []
#bbox_xywh = torch.zeros((info['box_nums'], 4))
for (x1, y1, x2, y2), class_id, score in zip(info['boxes'],info['class_ids'],info['scores']):
if self.filter_class and class_names[int(class_id)] not in self.filter_class:
continue
bbox_xywh.append([int((x1+x2)/2), int((y1+y2)/2), x2-x1, y2-y1])
scores.append(score)
bbox_xywh = torch.Tensor(bbox_xywh)
outputs = self.deepsort.update(bbox_xywh, scores, image)
image = vis_track(image, outputs)
return image, outputs