diff --git a/detectron/core/config.py b/detectron/core/config.py index afe3d1d03..ece79ae05 100644 --- a/detectron/core/config.py +++ b/detectron/core/config.py @@ -133,6 +133,9 @@ # faster) __C.TRAIN.ASPECT_GROUPING = True +# Include background images (images with no annotations) in training +__C.TRAIN.INCLUDE_BKG_IMAGES = False + # ---------------------------------------------------------------------------- # # RPN training options # ---------------------------------------------------------------------------- # diff --git a/detectron/datasets/roidb.py b/detectron/datasets/roidb.py index 57b6e9cfe..c55f0625f 100644 --- a/detectron/datasets/roidb.py +++ b/detectron/datasets/roidb.py @@ -115,6 +115,9 @@ def is_valid(entry): # Valid images have: # (1) At least one foreground RoI OR # (2) At least one background RoI + if cfg.TRAIN.INCLUDE_BKG_IMAGES and len(entry['boxes']) == 0: + return True # this is a bkg image + overlaps = entry['max_overlaps'] # find boxes with sufficient overlap fg_inds = np.where(overlaps >= cfg.TRAIN.FG_THRESH)[0] diff --git a/detectron/roi_data/rpn.py b/detectron/roi_data/rpn.py index 6adb5a75b..3a7eb96f3 100644 --- a/detectron/roi_data/rpn.py +++ b/detectron/roi_data/rpn.py @@ -199,7 +199,13 @@ def _get_rpn_blobs(im_height, im_width, foas, all_anchors, gt_boxes): # (samples with replacement, but since the set of bg inds is large most # samples will not have repeats) num_bg = cfg.TRAIN.RPN_BATCH_SIZE_PER_IM - np.sum(labels == 1) - bg_inds = np.where(anchor_to_gt_max < cfg.TRAIN.RPN_NEGATIVE_OVERLAP)[0] + if len(gt_boxes) > 0: + bg_inds = np.where(anchor_to_gt_max < cfg.TRAIN.RPN_NEGATIVE_OVERLAP)[0] + else: + # no ground-truth boxes: any anchor index can be used as a negative example + assert cfg.TRAIN.INCLUDE_BKG_IMAGES + bg_inds = np.array(range(len(labels))) + if len(bg_inds) > num_bg: enable_inds = bg_inds[npr.randint(len(bg_inds), size=num_bg)] else: @@ -209,9 +215,10 @@ def _get_rpn_blobs(im_height, im_width, foas, all_anchors, gt_boxes): bg_inds = 
np.where(labels == 0)[0] bbox_targets = np.zeros((num_inside, 4), dtype=np.float32) - bbox_targets[fg_inds, :] = data_utils.compute_targets( - anchors[fg_inds, :], gt_boxes[anchor_to_gt_argmax[fg_inds], :] - ) + if len(gt_boxes) > 0: + bbox_targets[fg_inds, :] = data_utils.compute_targets( + anchors[fg_inds, :], gt_boxes[anchor_to_gt_argmax[fg_inds], :] + ) # Bbox regression loss has the form: # loss(x) = weight_outside * L(weight_inside * x)