Skip to content

Commit a9a6a16

Browse files
committed
fix wrong ported code
1 parent 8ec3b7b commit a9a6a16

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

mixup.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,21 @@
1212

1313
import numpy as np
1414
import torch
15+
from torch.autograd import Variable
1516

16-
class CrossEntropyLoss(object):
17-
def __call__(self, input, target, size_average=True):
18-
"""Origin: https://github.com/moskomule/mixup.pytorch
19-
in PyTorch's cross entropy, targets are expected to be labels
20-
so to predict probabilities this loss is needed
21-
suppose q is the target and p is the input
22-
loss(p, q) = -\sum_i q_i \log p_i
23-
"""
24-
assert input.size() == target.size()
25-
assert isinstance(input, Variable) and isinstance(target, Variable)
26-
input = torch.log(torch.nn.functional.softmax(input, dim=1).clamp(1e-5, 1))
27-
# input = input - torch.log(torch.sum(torch.exp(input), dim=1)).view(-1, 1)
28-
loss = - torch.sum(input * target)
29-
return loss / input.size()[0] if size_average else loss
17+
def mixup_cross_entropy_loss(input, target, size_average=True):
18+
"""Origin: https://github.com/moskomule/mixup.pytorch
19+
in PyTorch's cross entropy, targets are expected to be labels
20+
so to predict probabilities this loss is needed
21+
suppose q is the target and p is the input
22+
loss(p, q) = -\sum_i q_i \log p_i
23+
"""
24+
assert input.size() == target.size()
25+
assert isinstance(input, Variable) and isinstance(target, Variable)
26+
input = torch.log(torch.nn.functional.softmax(input, dim=1).clamp(1e-5, 1))
27+
# input = input - torch.log(torch.sum(torch.exp(input), dim=1)).view(-1, 1)
28+
loss = - torch.sum(input * target)
29+
return loss / input.size()[0] if size_average else loss
3030

3131
def onehot(targets, num_classes):
3232
"""Origin: https://github.com/moskomule/mixup.pytorch
@@ -44,8 +44,9 @@ def mixup(inputs, targets, num_classes, alpha=2):
4444
weight = torch.Tensor(np.random.beta(alpha, alpha, s))
4545
index = np.random.permutation(s)
4646
x1, x2 = inputs, inputs[index, :, :, :]
47-
y1, y2 = onehot(targets), onehot(targets[index,])
47+
y1, y2 = onehot(targets, num_classes), onehot(targets[index,], num_classes)
4848
weight = weight.view(s, 1, 1, 1)
4949
inputs = weight*x1 + (1-weight)*x2
5050
weight = weight.view(s, 1)
5151
targets = weight*y1 + (1-weight)*y2
52+
return inputs, targets

0 commit comments

Comments
 (0)