Skip to content

Commit e5efc5f

Browse files
committed
add reweight / normalization package
1 parent dfdbd89 commit e5efc5f

File tree

22 files changed

+70
-53
lines changed

22 files changed

+70
-53
lines changed

dglib/generalization/__init__.py

Lines changed: 0 additions & 6 deletions
This file was deleted.

dglib/generalization/irm.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

dglib/modules/classifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Optional, Tuple
66
import torch
77
import torch.nn as nn
8-
from common.modules.classifier import Classifier as ClassifierBase
8+
from tllib.modules.classifier import Classifier as ClassifierBase
99

1010

1111
class ImageClassifier(ClassifierBase):

examples/domain_adaptation/partial_domain_adaptation/iwan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tllib.modules.classifier import Classifier
2323
from tllib.modules.entropy import entropy
2424
from tllib.modules.domain_discriminator import DomainDiscriminator
25-
from tllib.alignment.iwan import ImportanceWeightModule, ImageClassifier
25+
from tllib.reweight.iwan import ImportanceWeightModule, ImageClassifier
2626
from tllib.alignment.dann import DomainAdversarialLoss
2727
from tllib.utils.data import ForeverDataIterator
2828
from tllib.utils.metric import accuracy

examples/domain_adaptation/partial_domain_adaptation/pada.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from tllib.modules.domain_discriminator import DomainDiscriminator
2323
from tllib.modules.classifier import Classifier
2424
from tllib.alignment.dann import DomainAdversarialLoss, ImageClassifier
25-
from tllib.alignment.pada import AutomaticUpdateClassWeightModule
25+
from tllib.reweight.pada import AutomaticUpdateClassWeightModule
2626
from tllib.utils.data import ForeverDataIterator
2727
from tllib.utils.metric import accuracy
2828
from tllib.utils.meter import AverageMeter, ProgressMeter

examples/domain_generalization/image_classification/groupdro.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""
2+
Adapted from https://github.com/facebookresearch/DomainBed
23
@author: Baixu Chen
34
45
"""
@@ -19,9 +20,9 @@
1920
import torch.nn.functional as F
2021

2122
sys.path.append('../../..')
22-
from dglib.generalization.groupdro import AutomaticUpdateDomainWeightModule
2323
from dglib.modules.sampler import RandomDomainSampler
2424
from dglib.modules.classifier import ImageClassifier as Classifier
25+
from tllib.reweight.groupdro import AutomaticUpdateDomainWeightModule
2526
from tllib.utils.data import ForeverDataIterator
2627
from tllib.utils.metric import accuracy
2728
from tllib.utils.meter import AverageMeter, ProgressMeter

examples/domain_generalization/image_classification/irm.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""
2+
Adapted from https://github.com/facebookresearch/DomainBed
23
@author: Baixu Chen
34
45
"""
@@ -17,9 +18,9 @@
1718
from torch.optim.lr_scheduler import CosineAnnealingLR
1819
from torch.utils.data import DataLoader
1920
import torch.nn.functional as F
21+
import torch.autograd as autograd
2022

2123
sys.path.append('../../..')
22-
from dglib.generalization.irm import InvariancePenaltyLoss
2324
from dglib.modules.sampler import RandomDomainSampler
2425
from dglib.modules.classifier import ImageClassifier as Classifier
2526
from tllib.utils.data import ForeverDataIterator
@@ -34,6 +35,35 @@
3435
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
3536

3637

38+
class InvariancePenaltyLoss(nn.Module):
39+
r"""Invariance Penalty Loss from `Invariant Risk Minimization <https://arxiv.org/pdf/1907.02893.pdf>`_.
40+
We adopt implementation from `DomainBed <https://github.com/facebookresearch/DomainBed>`_. Given classifier
41+
output :math:`y` and ground truth :math:`labels`, we split :math:`y` into two parts :math:`y_1, y_2`, corresponding
42+
labels are :math:`labels_1, labels_2`. Next we calculate cross entropy loss with respect to a dummy classifier
43+
:math:`w`, resulting in :math:`grad_1, grad_2`. The invariance penalty is then the product :math:`grad_1 \cdot grad_2`.
44+
45+
Inputs:
46+
- y: predictions from model
47+
- labels: ground truth
48+
49+
Shape:
50+
- y: :math:`(N, C)` where C means the number of classes.
51+
- labels: :math:`(N, )` where N means the mini-batch size
52+
"""
53+
54+
def __init__(self):
55+
super(InvariancePenaltyLoss, self).__init__()
56+
self.scale = torch.tensor(1.).requires_grad_()
57+
58+
def forward(self, y: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
59+
loss_1 = F.cross_entropy(y[::2] * self.scale, labels[::2])
60+
loss_2 = F.cross_entropy(y[1::2] * self.scale, labels[1::2])
61+
grad_1 = autograd.grad(loss_1, [self.scale], create_graph=True)[0]
62+
grad_2 = autograd.grad(loss_2, [self.scale], create_graph=True)[0]
63+
penalty = torch.sum(grad_1 * grad_2)
64+
return penalty
65+
66+
3767
def main(args: argparse.Namespace):
3868
logger = CompleteLogger(args.log, args.phase)
3969
print(args)

examples/domain_generalization/image_classification/mixstyle.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import torch.nn.functional as F
2020

2121
sys.path.append('../../..')
22-
import dglib.generalization.mixstyle.models as models
22+
import tllib.modules.normalization.mixstyle.models as models
2323
from dglib.modules.sampler import RandomDomainSampler
2424
from dglib.modules.classifier import ImageClassifier as Classifier
2525
from tllib.utils.data import ForeverDataIterator

examples/domain_generalization/re_identification/mixstyle.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
from torch.utils.data import DataLoader
1919

2020
sys.path.append('../../..')
21-
from dglib.generalization.mixstyle.sampler import RandomDomainMultiInstanceSampler
22-
import dglib.generalization.mixstyle.models as models
21+
from tllib.modules.normalization.mixstyle.sampler import RandomDomainMultiInstanceSampler
22+
import tllib.modules.normalization.mixstyle.models as models
2323
from tllib.vision.models.reid.identifier import ReIdentifier
2424
from tllib.vision.models.reid.loss import CrossEntropyLossWithLabelSmooth, SoftTripletLoss
2525
import tllib.vision.datasets.reid as datasets

examples/task_adaptation/image_classification/stochnorm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
import torch.nn.functional as F
1818

1919
sys.path.append('../../..')
20-
from tllib.regularization.stochnorm import convert_model
20+
from tllib.modules.normalization.stochnorm import convert_model
2121
from tllib.modules.classifier import Classifier
2222
from tllib.utils.metric import accuracy
2323
from tllib.utils.meter import AverageMeter, ProgressMeter

tllib/modules/loss.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copy from https://github.com/CoinCheung/pytorch-loss
22
import torch.nn as nn
33
import torch
4+
import torch.nn.functional as F
45

56

67
# version 1: use torch.autograd
@@ -45,3 +46,32 @@ def forward(self, input, target):
4546
loss = loss.sum()
4647

4748
return loss
49+
50+
51+
class KnowledgeDistillationLoss(nn.Module):
52+
"""Knowledge Distillation Loss.
53+
54+
Args:
55+
T (float): Temperature. Default: 1.
56+
reduction (str, optional): Specifies the reduction to apply to the output:
57+
``'none'`` | ``'batchmean'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
58+
``'mean'``: the sum of the output will be divided by the number of
59+
elements in the output, ``'sum'``: the output will be summed, ``'batchmean'``: the sum of the
        output will be divided by the batch size. Default: ``'batchmean'``
60+
61+
Inputs:
62+
- y_student (tensor): logits output of the student
63+
- y_teacher (tensor): logits output of the teacher
64+
65+
Shape:
66+
- y_student: (minibatch, `num_classes`)
67+
- y_teacher: (minibatch, `num_classes`)
68+
69+
"""
70+
def __init__(self, T=1., reduction='batchmean'):
71+
super(KnowledgeDistillationLoss, self).__init__()
72+
self.T = T
73+
self.kl = nn.KLDivLoss(reduction=reduction)
74+
75+
def forward(self, y_student, y_teacher):
76+
""""""
77+
return self.kl(F.log_softmax(y_student / self.T, dim=-1), F.softmax(y_teacher / self.T, dim=-1))

tllib/modules/normalization/mixstyle/__init__.py

Whitespace-only changes.

tllib/reweight/__init__.py

Whitespace-only changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)