1919
2020class MultiPartLoss (nn .Module ):
2121
22- def __init__ (self , S = 7 , B = 2 , C = 20 , lambda_coord = 5 , lambda_noobj = 0.5 ):
22+ def __init__ (self , img_w , img_h , S = 7 , B = 2 , C = 20 , lambda_coord = 5 , lambda_noobj = 0.5 ):
2323 super (MultiPartLoss , self ).__init__ ()
2424 self .S = S
2525 self .B = B
@@ -28,6 +28,12 @@ def __init__(self, S=7, B=2, C=20, lambda_coord=5, lambda_noobj=0.5):
2828 self .coord = lambda_coord
2929 self .noobj = lambda_noobj
3030
31+ self .img_w = img_w
32+ self .img_h = img_h
33+
34+ self .grid_w = img_w / S
35+ self .grid_h = img_h / S
36+
3137 def forward (self , preds , targets ):
3238 """
3339 :param preds: (N, S*S, B*5+C) 其中
@@ -172,10 +178,8 @@ def _process3(self, preds, targets):
172178 # [N, S*S, B] -> [N*S*S, B]
173179 pred_confidences = preds [:, :, self .C : (self .B + self .C )].reshape (- 1 , self .B )
174180 # 提取每个网格的预测边界框坐标
175- # [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
176- pred_bboxs = preds [:, :, (self .B + self .C ): (self .B * 5 + self .C )] \
177- .reshape (- 1 , self .B * 4 ) \
178- .reshape (- 1 , self .B , 4 )
181+ # [N, S*S, B*4] -> [N, S*S, B, 4]
182+ pred_bboxs = preds [:, :, (self .B + self .C ): (self .B * 5 + self .C )].reshape (N , self .S * self .S , self .B , 4 )
179183
180184 ## 目标
181185 # 提取每个网格的分类概率
@@ -185,18 +189,20 @@ def _process3(self, preds, targets):
185189 # [N, S*S, B] -> [N*S*S, B]
186190 target_confidences = targets [:, :, self .C : (self .B + self .C )].reshape (- 1 , self .B )
187191 # 提取每个网格的边界框坐标
188- # [N, S*S, B*4] -> [N*S*S, B*4] -> [N*S*S, B, 4]
189- target_bboxs = targets [:, :, (self .B + self .C ): (self .B * 5 + self .C )] \
190- .reshape (- 1 , self .B * 4 ) \
191- .reshape (- 1 , self .B , 4 )
192+ # [N, S*S, B*4] -> [N, S*S, B, 4]
193+ target_bboxs = targets [:, :, (self .B + self .C ): (self .B * 5 + self .C )].reshape (N , self .S * self .S , self .B , 4 )
192194
193195 ## 首先计算所有边界框的置信度损失(假定不存在obj)
194196 loss = self .noobj * self .sum_squared_error (pred_confidences , target_confidences )
195197
196198 # 计算每个预测边界框与对应目标边界框的IoU
197- iou_scores = self .iou (pred_bboxs .reshape (- 1 , 4 ), target_bboxs .reshape (- 1 , 4 )).reshape (- 1 , 2 )
199+ # [N*S*S*B]
200+ iou_scores = self .compute_ious (pred_bboxs .clone (), target_bboxs .clone ())
201+ # [N, S*S, B, 4] -> [N*S*S, B, 4]
202+ pred_bboxs = pred_bboxs .reshape (- 1 , self .B , 4 )
203+ target_bboxs = target_bboxs .reshape (- 1 , self .B , 4 )
198204 # 选取每个网格中IoU最高的边界框
199- top_idxs = torch .argmax (iou_scores , dim = 1 )
205+ top_idxs = torch .argmax (iou_scores . reshape ( - 1 , self . B ) , dim = 1 )
200206 top_len = len (top_idxs )
201207 # 获取相应的置信度以及边界框
202208 top_pred_confidences = pred_confidences [range (top_len ), top_idxs ]
@@ -247,6 +253,41 @@ def bbox_loss(self, pred_boxs, target_boxs):
247253
248254 return loss
249255
256+ def compute_ious (self , pred_boxs , target_boxs ):
257+ """
258+ 将边界框变形回标准化之前,然后计算IoU
259+ :param pred_boxs: [N, S*S, B, 4]
260+ :param target_boxs: [N, S*S, B, 4]
261+ :return: [N*S*S*B]
262+ """
263+ N = pred_boxs .shape [0 ]
264+ for i in range (N ):
265+ for j in range (self .S * self .S ):
266+ col = j % self .S
267+ row = int (j / self .S )
268+ for k in range (self .B ):
269+ pred_box = pred_boxs [i , j , k ]
270+ target_box = target_boxs [i , j , k ]
271+
272+ # 变形会标准化之前
273+ # x_center
274+ pred_box [0 ] = (pred_box [0 ] + col ) * self .grid_w
275+ target_box [0 ] = (target_box [0 ] + col ) * self .grid_w
276+ # y_center
277+ pred_box [1 ] = (pred_box [1 ] + row ) * self .grid_h
278+ target_box [1 ] = (target_box [1 ] + row ) * self .grid_h
279+ # w
280+ pred_box [2 ] = pred_box [2 ] * self .img_w
281+ target_box [2 ] = target_box [2 ] * self .img_w
282+ # h
283+ pred_box [3 ] = pred_box [3 ] * self .img_h
284+ target_box [3 ] = target_box [3 ] * self .img_h
285+
286+ pred_boxs = pred_boxs .reshape (- 1 , 4 )
287+ target_boxs = target_boxs .reshape (- 1 , 4 )
288+
289+ return self .iou (pred_boxs , target_boxs )
290+
250291 def iou (self , pred_boxs , target_boxs ):
251292 """
252293 计算候选建议和标注边界框的IoU
@@ -291,7 +332,7 @@ def load_data(data_root_dir, cate_list, S=7, B=2, C=20):
291332 C = 3
292333 cate_list = ['cucumber' , 'eggplant' , 'mushroom' ]
293334
294- criterion = MultiPartLoss (S = 7 , B = 2 , C = 3 )
335+ criterion = MultiPartLoss (448 , 448 , S = 7 , B = 2 , C = 3 )
295336 # preds = torch.arange(637).reshape(1, 7 * 7, 13) * 0.01
296337 # targets = torch.ones((1, 7 * 7, 13)) * 0.01
297338 # loss = criterion(preds, targets)
0 commit comments