|
10 | 10 | from PIL import Image
|
11 | 11 | import numpy as np
|
12 | 12 | import os
|
| 13 | +import math |
13 | 14 | import shutil
|
14 | 15 |
|
15 | 16 | import torch
|
16 | 17 | import torch.nn as nn
|
| 18 | +import torch.nn.functional as F |
17 | 19 | import torch.backends.cudnn as cudnn
|
18 | 20 | from torch.optim import SGD
|
19 | 21 | from torch.optim.lr_scheduler import LambdaLR
|
20 | 22 | from torch.utils.data import DataLoader
|
21 | 23 |
|
22 | 24 | sys.path.append('../../..')
|
23 | 25 | from tllib.translation.fourier_transform import FourierTransform
|
24 |
| -from tllib.alignment.fda import robust_entropy |
25 | 26 | import tllib.vision.models.segmentation as models
|
26 | 27 | import tllib.vision.datasets.segmentation as datasets
|
27 | 28 | import tllib.vision.transforms.segmentation as T
|
|
35 | 36 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
36 | 37 |
|
37 | 38 |
|
| 39 | +def robust_entropy(y, ita=1.5, num_classes=19, reduction='mean'): |
| 40 | + """ Robust entropy proposed in `FDA: Fourier Domain Adaptation for Semantic Segmentation (CVPR 2020) <https://arxiv.org/abs/2004.05498>`_ |
| 41 | +
|
| 42 | + Args: |
| 43 | + y (tensor): logits output of segmentation model in shape of :math:`(N, C, H, W)` |
| 44 | + ita (float, optional): parameters for robust entropy. Default: 1.5 |
| 45 | + num_classes (int, optional): number of classes. Default: 19 |
| 46 | + reduction (string, optional): Specifies the reduction to apply to the output: |
| 47 | + ``'none'`` | ``'mean'``. ``'none'``: no reduction will be applied, |
| 48 | + ``'mean'``: the sum of the output will be divided by the number of |
| 49 | + elements in the output. Default: ``'mean'`` |
| 50 | +
|
| 51 | + Returns: |
| 52 | + Scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(N, )`. |
| 53 | +
|
| 54 | + """ |
| 55 | + P = F.softmax(y, dim=1) |
| 56 | + logP = F.log_softmax(y, dim=1) |
| 57 | + PlogP = P * logP |
| 58 | + ent = -1.0 * PlogP.sum(dim=1) |
| 59 | + ent = ent / math.log(num_classes) |
| 60 | + |
| 61 | + # compute robust entropy |
| 62 | + ent = ent ** 2.0 + 1e-8 |
| 63 | + ent = ent ** ita |
| 64 | + |
| 65 | + if reduction == 'mean': |
| 66 | + return ent.mean() |
| 67 | + else: |
| 68 | + return ent |
| 69 | + |
| 70 | + |
38 | 71 | def main(args: argparse.Namespace):
|
39 | 72 | logger = CompleteLogger(args.log, args.phase)
|
40 | 73 | print(args)
|
|
0 commit comments