Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add new loss function #5414

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
5 changes: 5 additions & 0 deletions detectron2/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,3 +654,8 @@
# Do not commit any configs into it.
_C.GLOBAL = CN()
_C.GLOBAL.HACK = 1.0

# ここから追加
_C.MODEL.ROI_HEADS.LOSS_TYPE = "bce" # "focal"または"bce"も選択可能
_C.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA = 2.0
_C.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA = 0.25
19 changes: 18 additions & 1 deletion detectron2/modeling/roi_heads/fast_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from detectron2.structures import Boxes, Instances
from detectron2.utils.events import get_event_storage

mode = 1 #0:default 1:focal

__all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"]


Expand Down Expand Up @@ -338,10 +340,25 @@ def losses(self, predictions, proposals):
else:
proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device)

if self.use_sigmoid_ce:
#書き換えここから
loss_type = self.cfg.MODEL.ROI_HEADS.LOSS_TYPE
if loss_type == "focal":
# Focal Loss
gamma = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_GAMMA
alpha = self.cfg.MODEL.ROI_HEADS.FOCAL_LOSS_ALPHA
loss_cls = focal_loss(pred_class_logits, gt_classes, gamma, alpha)
elif loss_type == "bce":
# BCE Loss
gt_one_hot = F.one_hot(gt_classes, num_classes=pred_class_logits.size(1)).float()
loss_cls = F.binary_cross_entropy_with_logits(pred_class_logits, gt_one_hot, reduction="mean")
elif loss_type == 'dummy':
# dummy loss
loss_cls = torch.tensor(1.0, requires_grad=True, device=predictions[0].device)
elif self.use_sigmoid_ce:
loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes)
else:
loss_cls = cross_entropy(scores, gt_classes, reduction="mean")
#ここまで

losses = {
"loss_cls": loss_cls,
Expand Down
62 changes: 62 additions & 0 deletions detectron2/modeling/roi_heads/my_fastrcnn_loss_with_focal_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch.nn.functional as F
from torch import nn

class FocalLoss(nn.Module):

def __init__(self, weight=None,
gamma=2.5, reduction='mean'):
nn.Module.__init__(self)
self.weight=weight
self.gamma = gamma
self.reduction = reduction

def forward(self, input_tensor, target_tensor):
log_prob = F.log_softmax(input_tensor, dim=-1)
prob = torch.exp(log_prob)
return F.nll_loss(
((1 - prob) ** self.gamma) * log_prob,
target_tensor,
weight=self.weight,
reduction = self.reduction
)

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Computes the loss for Faster R-CNN.
Args:
class_logits (Tensor)
box_regression (Tensor)
labels (list[BoxList])
regression_targets (Tensor)
Returns:
classification_loss (Tensor)
box_loss (Tensor)
"""

labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)

#この部分をfocal_lossへ変更する
#classification_loss = F.cross_entropy(class_logits, labels)
focal=FocalLoss()
classification_loss = focal(class_logits, labels)
#変更はここまで

# get indices that correspond to the regression targets for
# the corresponding ground truth labels, to be used with
# advanced indexing
sampled_pos_inds_subset = torch.where(labels > 0)[0]
labels_pos = labels[sampled_pos_inds_subset]
N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

box_loss = F.smooth_l1_loss(
box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
beta=1 / 9,
reduction='sum',
)
box_loss = box_loss / labels.numel()

return classification_loss, box_loss
11 changes: 11 additions & 0 deletions detectron2/modeling/roi_heads/new_roy_heads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads
import torch

@ROI_HEADS_REGISTRY.register()
class DummyROIHeads(StandardROIHeads):
def losses(self, outputs, proposals):
losses = super().losses(outputs, proposals)
losses["loss_cls"] = torch.randn_like(losses["loss_cls"]) * 100 #ノイズ追加
losses["loss_box_reg"] = torch.tensor(1e5, device=losses["loss_box_reg"].device) #回帰の破壊で予測無効化

return losses
62 changes: 62 additions & 0 deletions my_fastrcnn_loss_with_focal_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import torch.nn.functional as F
from torch import nn

class FocalLoss(nn.Module):

def __init__(self, weight=None,
gamma=2.5, reduction='mean'):
nn.Module.__init__(self)
self.weight=weight
self.gamma = gamma
self.reduction = reduction

def forward(self, input_tensor, target_tensor):
log_prob = F.log_softmax(input_tensor, dim=-1)
prob = torch.exp(log_prob)
return F.nll_loss(
((1 - prob) ** self.gamma) * log_prob,
target_tensor,
weight=self.weight,
reduction = self.reduction
)

def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
"""
Computes the loss for Faster R-CNN.
Args:
class_logits (Tensor)
box_regression (Tensor)
labels (list[BoxList])
regression_targets (Tensor)
Returns:
classification_loss (Tensor)
box_loss (Tensor)
"""

labels = torch.cat(labels, dim=0)
regression_targets = torch.cat(regression_targets, dim=0)

#この部分をfocal_lossへ変更する
#classification_loss = F.cross_entropy(class_logits, labels)
focal=FocalLoss()
classification_loss = focal(class_logits, labels)
#変更はここまで

# get indices that correspond to the regression targets for
# the corresponding ground truth labels, to be used with
# advanced indexing
sampled_pos_inds_subset = torch.where(labels > 0)[0]
labels_pos = labels[sampled_pos_inds_subset]
N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, box_regression.size(-1) // 4, 4)

box_loss = F.smooth_l1_loss(
box_regression[sampled_pos_inds_subset, labels_pos],
regression_targets[sampled_pos_inds_subset],
beta=1 / 9,
reduction='sum',
)
box_loss = box_loss / labels.numel()

return classification_loss, box_loss
1 change: 1 addition & 0 deletions note
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# my_fastrcnn_loss_with_focal_loss.pyは新しく追加したもの