12
12
13
13
import numpy as np
14
14
import torch
15
+ from torch .autograd import Variable
15
16
16
class CrossEntropyLoss(object):
    """Callable cross entropy loss for soft (probability) targets.

    Origin: https://github.com/moskomule/mixup.pytorch
    """

    def __call__(self, input, target, size_average=True):
        """Return the cross entropy between logits and soft targets.

        PyTorch's built-in cross entropy expects class labels as targets, so
        this loss is needed when the targets are probabilities. With q the
        target and p the softmax of ``input`` over dim 1::

            loss(p, q) = -sum_i q_i * log(p_i)

        Args:
            input: logits; softmax is taken over dim 1.
            target: soft targets with exactly the same size as ``input``.
            size_average: when True, divide the summed loss by
                ``input.size()[0]``; otherwise return the plain sum.

        Returns:
            The loss as a zero-dimensional tensor.

        Raises:
            AssertionError: if the sizes differ or either argument is not a
                ``Variable``.
        """
        assert input.size() == target.size()
        assert isinstance(input, Variable) and isinstance(target, Variable)
        # log_softmax is finite for finite logits, so no clamping is needed.
        # The previous log(softmax(x).clamp(1e-5, 1)) capped each term at
        # -log(1e-5) and zeroed the gradient for probabilities below 1e-5.
        log_prob = torch.nn.functional.log_softmax(input, dim=1)
        loss = -torch.sum(log_prob * target)
        return loss / input.size()[0] if size_average else loss
17
def mixup_cross_entropy_loss(input, target, size_average=True):
    """Return the cross entropy between logits and soft (probability) targets.

    Origin: https://github.com/moskomule/mixup.pytorch

    PyTorch's built-in cross entropy expects class labels as targets, so this
    loss is needed when the targets are probabilities. With q the target and
    p the softmax of ``input`` over dim 1::

        loss(p, q) = -sum_i q_i * log(p_i)

    Args:
        input: logits; softmax is taken over dim 1.
        target: soft targets with exactly the same size as ``input``.
        size_average: when True, divide the summed loss by
            ``input.size()[0]``; otherwise return the plain sum.

    Returns:
        The loss as a zero-dimensional tensor.

    Raises:
        AssertionError: if the sizes differ or either argument is not a
            ``Variable``.
    """
    assert input.size() == target.size()
    assert isinstance(input, Variable) and isinstance(target, Variable)
    # log_softmax is finite for finite logits, so no clamping is needed.
    # The previous log(softmax(x).clamp(1e-5, 1)) capped each term at
    # -log(1e-5) and zeroed the gradient for probabilities below 1e-5.
    log_prob = torch.nn.functional.log_softmax(input, dim=1)
    loss = -torch.sum(log_prob * target)
    return loss / input.size()[0] if size_average else loss
30
30
31
31
def onehot (targets , num_classes ):
32
32
"""Origin: https://github.com/moskomule/mixup.pytorch
@@ -44,8 +44,9 @@ def mixup(inputs, targets, num_classes, alpha=2):
44
44
weight = torch .Tensor (np .random .beta (alpha , alpha , s ))
45
45
index = np .random .permutation (s )
46
46
x1 , x2 = inputs , inputs [index , :, :, :]
47
- y1 , y2 = onehot (targets ), onehot (targets [index ,])
47
+ y1 , y2 = onehot (targets , num_classes ), onehot (targets [index ,], num_classes )
48
48
weight = weight .view (s , 1 , 1 , 1 )
49
49
inputs = weight * x1 + (1 - weight )* x2
50
50
weight = weight .view (s , 1 )
51
51
targets = weight * y1 + (1 - weight )* y2
52
+ return inputs , targets
0 commit comments