-
Notifications
You must be signed in to change notification settings - Fork 373
/
dual_focal_loss.py
36 lines (30 loc) · 1.2 KB
/
dual_focal_loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
import torch.nn as nn
class Dual_Focal_loss(nn.Module):
    '''
    Dual Focal Loss, proposed in: https://arxiv.org/abs/1909.11932

    For each position it computes the per-class squared error between the
    softmax prediction and the one-hot target, maps it through
    -log(eps + 1 - mse), and sums over the class dimension.

    Note from the original author: "It does not work in my projects, hope it
    will work well in your projects. Hope you can correct me if there are any
    mistakes in the implementation."
    '''
    def __init__(self, ignore_lb=255, eps=1e-5, reduction='mean'):
        super(Dual_Focal_loss, self).__init__()
        self.ignore_lb = ignore_lb  # label value excluded from the loss
        self.eps = eps              # numerical floor inside log() to avoid log(0)
        self.reduction = reduction  # 'mean' | 'sum' | 'none'
        self.mse = nn.MSELoss(reduction='none')

    def forward(self, logits, label):
        """Compute the dual focal loss.

        Args:
            logits: raw scores of shape (N, C, *); softmax is taken over dim 1.
            label: integer targets of shape (N, *) — logits' shape without the
                class dimension. Positions equal to ``ignore_lb`` contribute 0.

        Returns:
            Scalar tensor for 'mean'/'sum'; per-position loss of shape (N, *)
            for 'none'.
        """
        # Build the ignore mask on the same device as the inputs. The original
        # moved it to CPU via label.data.cpu(), which breaks boolean indexing
        # (label[ignore], loss[ignore]) when label/logits live on CUDA.
        ignore = label == self.ignore_lb
        n_valid = (~ignore).sum()
        # Remap ignored positions to class 0 so scatter_ indices stay in range;
        # their loss is zeroed out below.
        label = label.clone()
        label[ignore] = 0
        # One-hot target: no gradient should flow through it.
        with torch.no_grad():
            lb_one_hot = torch.zeros_like(logits).scatter_(1, label.unsqueeze(1), 1)
        pred = torch.softmax(logits, dim=1)
        # mse is in [0, 1] per class, so eps + 1 - mse stays strictly positive.
        loss = -torch.log(self.eps + 1. - self.mse(pred, lb_one_hot)).sum(dim=1)
        loss[ignore] = 0
        if self.reduction == 'mean':
            # clamp guards against 0/0 = NaN when every position is ignored
            loss = loss.sum() / n_valid.clamp(min=1)
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss