Skip to content

Commit 9a47e58

Browse files
authored
[Bug Fix] fix cross_entropy bug when 255 in label (PaddlePaddle#1499) (PaddlePaddle#1586)
1 parent 7e37f98 commit 9a47e58

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

paddleseg/models/losses/cross_entropy_loss.py

Lines changed: 2 additions & 4 deletions
Original file line number · Diff line number · Diff line change
@@ -48,11 +48,8 @@ def __init__(self,
4848
self.data_format = data_format
4949
if weight is not None:
5050
self.weight = paddle.to_tensor(weight, dtype='float32')
51-
long_weight = weight + [0] * (256 - len(weight))
52-
self.long_weight = paddle.to_tensor(long_weight, dtype='float32')
5351
else:
5452
self.weight = None
55-
self.long_weight = None
5653

5754
def forward(self, logit, label, semantic_weights=None):
5855
"""
@@ -82,12 +79,13 @@ def forward(self, logit, label, semantic_weights=None):
8279
label = label.astype('int64')
8380

8481
# In F.cross_entropy, the ignore_index is invalid, which needs to be fixed.
82+
# When there is 255 in the label and paddle version <= 2.1.3, the cross_entropy OP will report an error, which is fixed in paddle develop version.
8583
loss = F.cross_entropy(
8684
logit,
8785
label,
8886
ignore_index=self.ignore_index,
8987
reduction='none',
90-
weight=self.long_weight)
88+
weight=self.weight)
9189

9290
return self._post_process_loss(logit, label, semantic_weights, loss)
9391

0 commit comments

Comments (0)