
Commit db2290d

Implementation of FixMatch on domain adaptation classification tasks, with results.
1 parent 899189e commit db2290d

File tree

3 files changed: +363 -0 lines changed


examples/domain_adaptation/image_classification/README.md

Lines changed: 12 additions & 0 deletions
@@ -46,6 +46,7 @@ Supported methods include:
 - [Batch Spectral Penalization (BSP)](http://ise.thss.tsinghua.edu.cn/~mlong/doc/batch-spectral-penalization-icml19.pdf)
 - [Margin Disparity Discrepancy (MDD)](https://arxiv.org/abs/1904.05801)
 - [Minimum Class Confusion (MCC)](https://arxiv.org/abs/1912.03699)
+- [FixMatch](https://arxiv.org/abs/2001.07685)
 
 ## Usage
 
@@ -106,6 +107,7 @@ for three times and report their average accuracy.
 | AFN | 85.7 | 88.6 | 94.0 | 98.9 | 100.0 | 94.4 | 72.9 | 71.1 |
 | MDD | 88.9 | 89.6 | 95.6 | 98.6 | 100.0 | 94.4 | 76.6 | 72.2 |
 | MCC | 89.4 | 89.6 | 94.1 | 98.4 | 99.8 | 95.6 | 75.5 | 74.2 |
+| FixMatch | / | 86.4 | 86.4 | 98.2 | 100.0 | 95.4 | 70.0 | 68.1 |
 
 ### Office-Home accuracy on ResNet-50
 
@@ -122,6 +124,7 @@ for three times and report their average accuracy.
 | AFN | 67.3 | 68.2 | 53.2 | 72.7 | 76.8 | 65.0 | 71.3 | 72.3 | 65.0 | 51.4 | 77.9 | 72.3 | 57.8 | 82.4 |
 | MDD | 68.1 | 69.7 | 56.2 | 75.4 | 79.6 | 63.5 | 72.1 | 73.8 | 62.5 | 54.8 | 79.9 | 73.5 | 60.9 | 84.5 |
 | MCC | / | 72.4 | 58.4 | 79.6 | 83.0 | 67.5 | 77.0 | 78.5 | 66.6 | 54.8 | 81.8 | 74.4 | 61.4 | 85.6 |
+| FixMatch | / | 70.8 | 56.4 | 76.4 | 79.9 | 65.3 | 73.8 | 71.2 | 67.2 | 56.4 | 80.6 | 74.9 | 63.5 | 84.3 |
 
 ### Office-Home accuracy on vit_base_patch16_224 (batch size 24)
 
@@ -151,6 +154,7 @@ for three times and report their average accuracy.
 | AFN | 76.1 | 75.0 | 95.6 | 56.2 | 81.3 | 69.8 | 93.0 | 81.0 | 93.4 | 74.1 | 91.7 | 55.0 | 90.6 | 18.1 | 74.4 |
 | MDD | / | 82.0 | 88.3 | 62.8 | 85.2 | 69.9 | 91.9 | 95.1 | 94.4 | 81.2 | 93.8 | 89.8 | 84.1 | 47.9 | 79.8 |
 | MCC | 78.8 | 83.6 | 95.3 | 85.8 | 77.1 | 68.0 | 93.9 | 92.9 | 84.5 | 79.5 | 93.6 | 93.7 | 85.3 | 53.8 | 80.4 |
+| FixMatch | / | 77.5 | 94.3 | 75.1 | 72.4 | 87.2 | 95.6 | 88.5 | 92.2 | 37.5 | 97.3 | 84.6 | 88.0 | 17.0 | 77.1 |
 
 ### DomainNet accuracy on ResNet-101
 
@@ -293,4 +297,12 @@ If you use these methods in your research, please consider citing.
     year={2020},
     booktitle={ECCV},
 }
+
+@inproceedings{FixMatch,
+    title={Fixmatch: Simplifying semi-supervised learning with consistency and confidence},
+    author={Sohn, Kihyuk and Berthelot, David and Carlini, Nicholas and Zhang, Zizhao and Zhang, Han and Raffel, Colin A and Cubuk, Ekin Dogus and Kurakin, Alexey and Li, Chun-Liang},
+    booktitle={NIPS},
+    year={2020}
+}
+
 ```
Lines changed: 325 additions & 0 deletions
@@ -0,0 +1,325 @@
"""
@author: Baixu Chen
"""
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.modules.classifier import Classifier
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageClassifier(Classifier):
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor):
        """"""
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        predictions = self.head(f)
        return predictions


def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                                                       random_horizontal_flip=not args.no_hflip,
                                                       random_color_jitter=False, resize_size=args.resize_size,
                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)
    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                             random_horizontal_flip=not args.no_hflip,
                                             random_color_jitter=False, resize_size=args.resize_size,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                               random_horizontal_flip=not args.no_hflip,
                                               random_color_jitter=False, resize_size=args.resize_size,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std,
                                               auto_augment=args.auto_augment)
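    # MultipleApply applies each transform in the list to the same image, so every target sample
    # is returned as a pair: a weakly-augmented view (crop + flip) and a strongly-augmented view
    # (crop + flip + AutoAugment). The training loop unpacks this pair as (x_t, x_t_strong).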
    train_target_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)

    print("train_source_transform: ", train_source_transform)
    print("train_target_transform: ", train_target_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform,
                          train_target_transform=train_target_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.unlabeled_batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
                    nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
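    # The lambda above implements an inverse decay stepped once per iteration: every group's lr
    # becomes base_lr * args.lr * (1 + lr_gamma * step) ** (-lr_decay), where the per-group base
    # values come from classifier.get_parameters() (the pre-trained backbone typically gets a
    # smaller multiplier than the newly added bottleneck and head).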

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate A-distance, which is a measure for distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    cls_losses = AverageMeter('Cls Loss', ':6.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':6.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

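    # FixMatch objective for unlabeled target data: predictions on the weakly-augmented view
    # provide hard pseudo-labels, and only those whose confidence exceeds args.threshold
    # (0.9 by default) contribute a cross-entropy term on the strongly-augmented view.
    # A standalone sketch of this loss is given after the end of the file.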
    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)
    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        (x_t, x_t_strong), labels_t = next(train_target_iter)[:2]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_strong = x_t_strong.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

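        # Gradients from the supervised loss and the self-training loss are accumulated through
        # two separate backward passes before the single optimizer.step() below; the pseudo-label
        # source y_t is computed under torch.no_grad(), so no gradient flows through it.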
        # compute output
        with torch.no_grad():
            y_t = model(x_t)

        # cross entropy loss
        y_s = model(x_s)
        cls_loss = F.cross_entropy(y_s, labels_s)
        cls_loss.backward()

        # self-training loss
        y_t_strong = model(x_t_strong)
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_t_strong, y_t)
        self_training_loss = args.trade_off * self_training_loss
        self_training_loss.backward()

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), x_s.size(0))
        cls_losses.update(cls_loss.item(), x_s.size(0))
        self_training_losses.update(self_training_loss.item(), x_s.size(0))

        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # ratio of pseudo labels
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / x_t.size(0)
        pseudo_label_ratios.update(ratio.item() * 100, x_t.size(0))

        # accuracy of pseudo labels
        if n_pseudo_labels > 0:
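            # labels_t is used only to monitor pseudo-label quality (it never enters a loss);
            # entries below the confidence threshold are set to -1 so they cannot match any label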
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_t).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # compute gradient and do SGD step
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FixMatch for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether to train from scratch.')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('-ub', '--unlabeled-batch-size', default=32, type=int,
                        help='mini-batch size of unlabeled data (target domain) (default: 32)')
    parser.add_argument('--threshold', default=0.9, type=float,
                        help='confidence threshold')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training.')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether to output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='fixmatch',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)
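
The heavy lifting in the script above is done by `ConfidenceBasedSelfTrainingLoss` from tllib. For readers unfamiliar with FixMatch, the sketch below shows the kind of confidence-thresholded self-training loss the script relies on. It is illustrative only, not the tllib implementation; the function name and signature are hypothetical.

```python
import torch
import torch.nn.functional as F


def fixmatch_unlabeled_loss(y_strong: torch.Tensor, y_weak: torch.Tensor, threshold: float = 0.9):
    """FixMatch-style loss on unlabeled data (illustrative sketch).

    y_weak: logits on the weakly-augmented view, used only to build pseudo-labels.
    y_strong: logits on the strongly-augmented view, which receives the gradient.
    """
    # hard pseudo-labels and their confidence from the weak view, detached from the graph
    confidence, pseudo_labels = F.softmax(y_weak.detach(), dim=1).max(dim=1)
    # keep only predictions the model is sufficiently confident about
    mask = (confidence > threshold).float()
    # per-sample cross-entropy on the strong view, masked and averaged over the batch
    loss = (F.cross_entropy(y_strong, pseudo_labels, reduction='none') * mask).mean()
    return loss, mask, pseudo_labels
```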
