Commit 84cfe22

Update fda.py
1 parent 630b1a1

File tree

1 file changed (+34 -1 lines)

  • examples/domain_adaptation/semantic_segmentation/fda.py

examples/domain_adaptation/semantic_segmentation/fda.py

Lines changed: 34 additions & 1 deletion
@@ -10,18 +10,19 @@
 from PIL import Image
 import numpy as np
 import os
+import math
 import shutil
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.backends.cudnn as cudnn
 from torch.optim import SGD
 from torch.optim.lr_scheduler import LambdaLR
 from torch.utils.data import DataLoader
 
 sys.path.append('../../..')
 from tllib.translation.fourier_transform import FourierTransform
-from tllib.alignment.fda import robust_entropy
 import tllib.vision.models.segmentation as models
 import tllib.vision.datasets.segmentation as datasets
 import tllib.vision.transforms.segmentation as T
@@ -35,6 +36,38 @@
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
+def robust_entropy(y, ita=1.5, num_classes=19, reduction='mean'):
+    """Robust entropy proposed in `FDA: Fourier Domain Adaptation for Semantic Segmentation (CVPR 2020) <https://arxiv.org/abs/2004.05498>`_
+
+    Args:
+        y (tensor): logits output of the segmentation model, of shape :math:`(N, C, H, W)`
+        ita (float, optional): parameter of the robust penalty. Default: 1.5
+        num_classes (int, optional): number of classes. Default: 19
+        reduction (string, optional): Specifies the reduction to apply to the output:
+            ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied,
+            ``'mean'``: the sum of the output will be divided by the number of
+            elements in the output. Default: ``'mean'``
+
+    Returns:
+        Scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, H, W)`.
+    """
+    # per-pixel entropy of the predicted class distribution, normalized to [0, 1]
+    P = F.softmax(y, dim=1)
+    logP = F.log_softmax(y, dim=1)
+    PlogP = P * logP
+    ent = -1.0 * PlogP.sum(dim=1)
+    ent = ent / math.log(num_classes)
+
+    # robust penalty: emphasize high-entropy pixels via (ent^2 + eps)^ita
+    ent = ent ** 2.0 + 1e-8
+    ent = ent ** ita
+
+    if reduction == 'mean':
+        return ent.mean()
+    else:
+        return ent
+
+
 def main(args: argparse.Namespace):
     logger = CompleteLogger(args.log, args.phase)
     print(args)
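Context for readers of this example: the commit inlines `robust_entropy` into the script instead of importing it from `tllib.alignment.fda`. The function implements the robust entropy minimization of the FDA paper, a Charbonnier-style penalty (x^2 + eps)^ita applied to the normalized per-pixel prediction entropy, so confident pixels contribute little while uncertain pixels are penalized sharply. Below is a minimal usage sketch, assuming the `robust_entropy` definition from the diff above is in scope; the tensor shapes and the `ita` value are illustrative assumptions, not part of this commit.

import torch

# Illustrative only: fake target-domain logits, batch of 2, 19 classes, 32x32 resolution
logits = torch.randn(2, 19, 32, 32, requires_grad=True)

# scalar entropy penalty (reduction='mean' is the default)
ent_loss = robust_entropy(logits, ita=2.0, num_classes=19)
print(ent_loss.item())

# gradients flow back through softmax into the logits, so this term
# can be weighted and added to the segmentation training objective
ent_loss.backward()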
