diff --git a/detectron2/config/defaults.py b/detectron2/config/defaults.py index 506651730e..3c72580e58 100644 --- a/detectron2/config/defaults.py +++ b/detectron2/config/defaults.py @@ -654,3 +654,8 @@ # Do not commit any configs into it. _C.GLOBAL = CN() _C.GLOBAL.HACK = 1.0 + +# ここから追加 +_C.MODEL.ROI_HEADS.LOSS_TYPE = "bce" # "focal"または"bce"も選択可能 +_C.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA = 2.0 +_C.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA = 0.25 diff --git a/detectron2/modeling/roi_heads/fast_rcnn.py b/detectron2/modeling/roi_heads/fast_rcnn.py index 039e2490fa..f1f15d1779 100644 --- a/detectron2/modeling/roi_heads/fast_rcnn.py +++ b/detectron2/modeling/roi_heads/fast_rcnn.py @@ -12,6 +12,8 @@ from detectron2.structures import Boxes, Instances from detectron2.utils.events import get_event_storage +mode = 1 #0:default 1:focal + __all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"] @@ -338,10 +340,25 @@ def losses(self, predictions, proposals): else: proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device) - if self.use_sigmoid_ce: + #書き換えここから + loss_type = self.cfg.MODEL.ROI_HEADS.LOSS_TYPE + if loss_type == "focal": + # Focal Loss + gamma = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA + alpha = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA + loss_cls = focal_loss(pred_class_logits, gt_classes, gamma, alpha) + elif loss_type == "bce": + # BCE Loss + gt_one_hot = F.one_hot(gt_classes, num_classes=pred_class_logits.size(1)).float() + loss_cls = F.binary_cross_entropy_with_logits(pred_class_logits, gt_one_hot, reduction="mean") + elif loss_type == 'dummy': + # dummy loss + loss_cls = torch.tensor(1.0, requires_grad=True, device=predictions[0].device) + elif self.use_sigmoid_ce: loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes) else: loss_cls = cross_entropy(scores, gt_classes, reduction="mean") + #ここまで losses = { "loss_cls": loss_cls, diff --git a/detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py b/detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py new file mode 100644 index 0000000000..48bcd9af38 --- /dev/null +++ b/detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py @@ -0,0 +1,62 @@ +import torch.nn.functional as F +from torch import nn + +class FocalLoss(nn.Module): + + def __init__(self, weight=None, + gamma=2.5, reduction='mean'): + nn.Module.__init__(self) + self.weight=weight + self.gamma = gamma + self.reduction = reduction + + def forward(self, input_tensor, target_tensor): + log_prob = F.log_softmax(input_tensor, dim=-1) + prob = torch.exp(log_prob) + return F.nll_loss( + ((1 - prob) ** self.gamma) * log_prob, + target_tensor, + weight=self.weight, + reduction = self.reduction + ) + +def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Computes the loss for Faster R-CNN. + Args: + class_logits (Tensor) + box_regression (Tensor) + labels (list[BoxList]) + regression_targets (Tensor) + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + + #この部分をfocal_lossへ変更する + #classification_loss = F.cross_entropy(class_logits, labels) + focal=FocalLoss() + classification_loss = focal(class_logits, labels) + #変更はここまで + + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.where(labels > 0)[0] + labels_pos = labels[sampled_pos_inds_subset] + N, num_classes = class_logits.shape + box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) + + box_loss = F.smooth_l1_loss( + box_regression[sampled_pos_inds_subset, labels_pos], + regression_targets[sampled_pos_inds_subset], + beta=1 / 9, + reduction='sum', + ) + box_loss = box_loss / labels.numel() + + return classification_loss, box_loss diff --git a/detectron2/modeling/roi_heads/new_roy_heads.py b/detectron2/modeling/roi_heads/new_roy_heads.py new file mode 100644 index 0000000000..2735fbb2fc --- /dev/null +++ b/detectron2/modeling/roi_heads/new_roy_heads.py @@ -0,0 +1,11 @@ +from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads +import torch + +@ROI_HEADS_REGISTRY.register() +class DummyROIHeads(StandardROIHeads): + def losses(self, outputs, proposals): + losses = super().losses(outputs, proposals) + losses["loss_cls"] = torch.randn_like(losses["loss_cls"]) * 100 #ノイズ追加 + losses["loss_box_reg"] = torch.tensor(1e5, device=losses["loss_box_reg"].device) #回帰の破壊で予測無効化 + + return losses diff --git a/my_fastrcnn_loss_with_focal_loss.py b/my_fastrcnn_loss_with_focal_loss.py new file mode 100644 index 0000000000..48bcd9af38 --- /dev/null +++ b/my_fastrcnn_loss_with_focal_loss.py @@ -0,0 +1,62 @@ +import torch.nn.functional as F +from torch import nn + +class FocalLoss(nn.Module): + + def __init__(self, weight=None, + gamma=2.5, reduction='mean'): + nn.Module.__init__(self) + self.weight=weight + self.gamma = gamma + self.reduction = reduction + + def forward(self, input_tensor, target_tensor): + log_prob = F.log_softmax(input_tensor, dim=-1) + prob = torch.exp(log_prob) + return F.nll_loss( + ((1 - prob) ** self.gamma) * log_prob, + target_tensor, + weight=self.weight, + reduction = self.reduction + ) + +def fastrcnn_loss(class_logits, box_regression, labels, regression_targets): + # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor] + """ + Computes the loss for Faster R-CNN. + Args: + class_logits (Tensor) + box_regression (Tensor) + labels (list[BoxList]) + regression_targets (Tensor) + Returns: + classification_loss (Tensor) + box_loss (Tensor) + """ + + labels = torch.cat(labels, dim=0) + regression_targets = torch.cat(regression_targets, dim=0) + + #この部分をfocal_lossへ変更する + #classification_loss = F.cross_entropy(class_logits, labels) + focal=FocalLoss() + classification_loss = focal(class_logits, labels) + #変更はここまで + + # get indices that correspond to the regression targets for + # the corresponding ground truth labels, to be used with + # advanced indexing + sampled_pos_inds_subset = torch.where(labels > 0)[0] + labels_pos = labels[sampled_pos_inds_subset] + N, num_classes = class_logits.shape + box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4) + + box_loss = F.smooth_l1_loss( + box_regression[sampled_pos_inds_subset, labels_pos], + regression_targets[sampled_pos_inds_subset], + beta=1 / 9, + reduction='sum', + ) + box_loss = box_loss / labels.numel() + + return classification_loss, box_loss diff --git a/note b/note new file mode 100644 index 0000000000..142504f0ad --- /dev/null +++ b/note @@ -0,0 +1 @@ +# my_fastrcnn_loss_with_focal_loss.pyは新しく追加したもの