@@ -162,12 +162,48 @@ def deform_bboxs(pred_bboxs, data_dict, S):
162162 return bboxs .astype (int )
163163
164164
165- def nms (cates , probs , bboxs ):
165+ def nms (rect_list , score_list , cate_list , thresh = 0.3 ):
166166 """
167- non-maximum suppression
168- :param cates:
169- :param probs:
170- :param bboxs:
171- :return:
167+ 非最大抑制
168+ :param rect_list: list,大小为[N, 4]
169+ :param score_list: list,大小为[N]
170+ :param cate_list: list, 大小为[N]
172171 """
173- pass
172+ nms_rects = list ()
173+ nms_scores = list ()
174+ nms_cates = list ()
175+
176+ rect_array = np .array (rect_list )
177+ score_array = np .array (score_list )
178+ cate_array = np .array (cate_list )
179+
180+ # 一次排序后即可
181+ # 按分类概率从大到小排序
182+ idxs = np .argsort (score_array )[::- 1 ]
183+ rect_array = rect_array [idxs ]
184+ score_array = score_array [idxs ]
185+ cate_array = cate_array [idxs ]
186+
187+ while len (score_array ) > 0 :
188+ # 添加分类概率最大的边界框
189+ nms_rects .append (rect_array [0 ])
190+ nms_scores .append (score_array [0 ])
191+ nms_cates .append (cate_array [0 ])
192+ rect_array = rect_array [1 :]
193+ score_array = score_array [1 :]
194+ cate_array = cate_array [1 :]
195+
196+ length = len (score_array )
197+ if length <= 0 :
198+ break
199+
200+ # 计算IoU
201+ iou_scores = iou (np .array (nms_rects [len (nms_rects ) - 1 ]), rect_array )
202+ # print(iou_scores)
203+ # 去除重叠率大于等于thresh的边界框
204+ idxs = np .where (iou_scores < thresh )[0 ]
205+ rect_array = rect_array [idxs ]
206+ score_array = score_array [idxs ]
207+ cate_array = cate_array [idxs ]
208+
209+ return nms_rects , nms_scores , nms_cates
0 commit comments