"""
@author: Baixu Chen
"""
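# Illustrative invocation (the script name fixmatch.py, dataset path, and the
# Office31 domain codes A/W are assumptions for illustration, not from this file):
#   python fixmatch.py data/office31 -d Office31 -s A -t W -a resnet50 \
#       --log logs/fixmatch/Office31_A2W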
import random
import time
import warnings
import argparse
import shutil
import os.path as osp

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
import torch.nn.functional as F

import utils
from tllib.modules.classifier import Classifier
from tllib.self_training.pseudo_label import ConfidenceBasedSelfTrainingLoss
from tllib.vision.transforms import MultipleApply
from tllib.utils.data import ForeverDataIterator
from tllib.utils.metric import accuracy
from tllib.utils.meter import AverageMeter, ProgressMeter
from tllib.utils.logger import CompleteLogger
from tllib.utils.analysis import collect_feature, tsne, a_distance

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ImageClassifier(Classifier):
    def __init__(self, backbone: nn.Module, num_classes: int, bottleneck_dim=1024, **kwargs):
        bottleneck = nn.Sequential(
            nn.Linear(backbone.out_features, bottleneck_dim),
            nn.BatchNorm1d(bottleneck_dim),
            nn.ReLU()
        )
        super(ImageClassifier, self).__init__(backbone, num_classes, bottleneck, bottleneck_dim, **kwargs)

    def forward(self, x: torch.Tensor):
        """Pool the backbone features, project them through the bottleneck,
        and return the class predictions."""
        f = self.pool_layer(self.backbone(x))
        f = self.bottleneck(f)
        predictions = self.head(f)
        return predictions


def main(args: argparse.Namespace):
    logger = CompleteLogger(args.log, args.phase)
    print(args)

    if args.seed is not None:
        random.seed(args.seed)
        torch.manual_seed(args.seed)
        cudnn.deterministic = True
        warnings.warn('You have chosen to seed training. '
                      'This will turn on the CUDNN deterministic setting, '
                      'which can slow down your training considerably! '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    cudnn.benchmark = True

    # Data loading code
    train_source_transform = utils.get_train_transform(args.train_resizing, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                                                       random_horizontal_flip=not args.no_hflip,
                                                       random_color_jitter=False, resize_size=args.resize_size,
                                                       norm_mean=args.norm_mean, norm_std=args.norm_std)
    weak_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                             random_horizontal_flip=not args.no_hflip,
                                             random_color_jitter=False, resize_size=args.resize_size,
                                             norm_mean=args.norm_mean, norm_std=args.norm_std)
    strong_augment = utils.get_train_transform(args.train_resizing, scale=args.scale, ratio=args.ratio,
                                               random_horizontal_flip=not args.no_hflip,
                                               random_color_jitter=False, resize_size=args.resize_size,
                                               norm_mean=args.norm_mean, norm_std=args.norm_std,
                                               auto_augment=args.auto_augment)
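    # FixMatch consumes two views of every target image: the weakly augmented
    # view produces pseudo labels, while the strongly augmented (AutoAugment)
    # view is trained to match them. MultipleApply returns both views per sample.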
    train_target_transform = MultipleApply([weak_augment, strong_augment])
    val_transform = utils.get_val_transform(args.val_resizing, resize_size=args.resize_size,
                                            norm_mean=args.norm_mean, norm_std=args.norm_std)

    print("train_source_transform: ", train_source_transform)
    print("train_target_transform: ", train_target_transform)
    print("val_transform: ", val_transform)

    train_source_dataset, train_target_dataset, val_dataset, test_dataset, num_classes, args.class_names = \
        utils.get_dataset(args.data, args.root, args.source, args.target, train_source_transform, val_transform,
                          train_target_transform=train_target_transform)
    train_source_loader = DataLoader(train_source_dataset, batch_size=args.batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    train_target_loader = DataLoader(train_target_dataset, batch_size=args.unlabeled_batch_size,
                                     shuffle=True, num_workers=args.workers, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.workers)

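    # wrap the loaders so batches can be drawn indefinitely; the training loop
    # below runs a fixed number of iterations per epoch instead of one pass
    # over the (differently sized) source and target datasets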
    train_source_iter = ForeverDataIterator(train_source_loader)
    train_target_iter = ForeverDataIterator(train_target_loader)

    # create model
    print("=> using model '{}'".format(args.arch))
    backbone = utils.get_model(args.arch, pretrain=not args.scratch)
    pool_layer = nn.Identity() if args.no_pool else None
    classifier = ImageClassifier(backbone, num_classes, bottleneck_dim=args.bottleneck_dim,
                                 pool_layer=pool_layer, finetune=not args.scratch).to(device)
    print(classifier)

    # define optimizer and lr scheduler
    optimizer = SGD(classifier.get_parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
                    nesterov=True)
    lr_scheduler = LambdaLR(optimizer, lambda x: args.lr * (1. + args.lr_gamma * float(x)) ** (-args.lr_decay))
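    # The LambdaLR above implements a polynomial decay per optimizer step:
    #   lr(step) = group_base_lr * args.lr * (1 + args.lr_gamma * step) ** (-args.lr_decay)
    # where group_base_lr comes from classifier.get_parameters(), which (in tllib)
    # typically gives the pretrained backbone a smaller multiplier than the
    # newly added bottleneck and head.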

    # resume from the best checkpoint
    if args.phase != 'train':
        checkpoint = torch.load(logger.get_checkpoint_path('best'), map_location='cpu')
        classifier.load_state_dict(checkpoint)

    # analyze the model
    if args.phase == 'analysis':
        # extract features from both domains
        feature_extractor = nn.Sequential(classifier.backbone, classifier.pool_layer, classifier.bottleneck).to(device)
        source_feature = collect_feature(train_source_loader, feature_extractor, device)
        target_feature = collect_feature(train_target_loader, feature_extractor, device)
        # plot t-SNE
        tSNE_filename = osp.join(logger.visualize_directory, 'TSNE.pdf')
        tsne.visualize(source_feature, target_feature, tSNE_filename)
        print("Saving t-SNE to", tSNE_filename)
        # calculate the A-distance, a measure of distribution discrepancy
        A_distance = a_distance.calculate(source_feature, target_feature, device)
        print("A-distance =", A_distance)
        return

    if args.phase == 'test':
        acc1 = utils.validate(test_loader, classifier, args, device)
        print(acc1)
        return

    # start training
    best_acc1 = 0.
    for epoch in range(args.epochs):
        print("lr:", lr_scheduler.get_last_lr())
        # train for one epoch
        train(train_source_iter, train_target_iter, classifier, optimizer, lr_scheduler, epoch, args)

        # evaluate on validation set
        acc1 = utils.validate(val_loader, classifier, args, device)

        # remember best acc@1 and save checkpoint
        torch.save(classifier.state_dict(), logger.get_checkpoint_path('latest'))
        if acc1 > best_acc1:
            shutil.copy(logger.get_checkpoint_path('latest'), logger.get_checkpoint_path('best'))
        best_acc1 = max(acc1, best_acc1)

    print("best_acc1 = {:3.1f}".format(best_acc1))

    # evaluate on test set
    classifier.load_state_dict(torch.load(logger.get_checkpoint_path('best')))
    acc1 = utils.validate(test_loader, classifier, args, device)
    print("test_acc1 = {:3.1f}".format(acc1))

    logger.close()


def train(train_source_iter: ForeverDataIterator, train_target_iter: ForeverDataIterator,
          model: ImageClassifier, optimizer: SGD, lr_scheduler: LambdaLR, epoch: int, args: argparse.Namespace):
    batch_time = AverageMeter('Time', ':5.2f')
    data_time = AverageMeter('Data', ':5.2f')
    cls_losses = AverageMeter('Cls Loss', ':6.2f')
    self_training_losses = AverageMeter('Self Training Loss', ':6.2f')
    losses = AverageMeter('Loss', ':6.2f')
    cls_accs = AverageMeter('Cls Acc', ':3.1f')
    pseudo_label_ratios = AverageMeter('Pseudo Label Ratio', ':3.1f')
    pseudo_label_accs = AverageMeter('Pseudo Label Acc', ':3.1f')

    progress = ProgressMeter(
        args.iters_per_epoch,
        [batch_time, data_time, losses, cls_losses, self_training_losses, cls_accs, pseudo_label_accs,
         pseudo_label_ratios],
        prefix="Epoch: [{}]".format(epoch))

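    # The criterion below is FixMatch's confidence-based self-training loss.
    # A rough sketch of what tllib's ConfidenceBasedSelfTrainingLoss computes
    # (see its source for the authoritative version):
    #   confidence, pseudo_labels = F.softmax(y_weak.detach(), dim=1).max(dim=1)
    #   mask = (confidence > threshold).float()
    #   loss = (F.cross_entropy(y_strong, pseudo_labels, reduction='none') * mask).mean()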
    self_training_criterion = ConfidenceBasedSelfTrainingLoss(args.threshold).to(device)
    # switch to train mode
    model.train()

    end = time.time()
    for i in range(args.iters_per_epoch):
        x_s, labels_s = next(train_source_iter)[:2]
        (x_t, x_t_strong), labels_t = next(train_target_iter)[:2]

        x_s = x_s.to(device)
        x_t = x_t.to(device)
        x_t_strong = x_t_strong.to(device)
        labels_s = labels_s.to(device)
        labels_t = labels_t.to(device)

        # measure data loading time
        data_time.update(time.time() - end)

        # clear grad
        optimizer.zero_grad()

        # forward the weakly augmented target view without gradients; its
        # predictions only serve as pseudo-label targets for the strong view
        with torch.no_grad():
            y_t = model(x_t)

        # supervised cross-entropy loss on the labeled source batch
        y_s = model(x_s)
        cls_loss = F.cross_entropy(y_s, labels_s)
        cls_loss.backward()

        # self-training loss
        y_t_strong = model(x_t_strong)
        self_training_loss, mask, pseudo_labels = self_training_criterion(y_t_strong, y_t)
        self_training_loss = args.trade_off * self_training_loss
        self_training_loss.backward()
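        # the two backward() calls above accumulate gradients from both losses
        # into .grad, so the parameters receive a single combined update below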

        # measure accuracy and record loss
        loss = cls_loss + self_training_loss
        losses.update(loss.item(), x_s.size(0))
        cls_losses.update(cls_loss.item(), x_s.size(0))
        self_training_losses.update(self_training_loss.item(), x_s.size(0))

        cls_acc = accuracy(y_s, labels_s)[0]
        cls_accs.update(cls_acc.item(), x_s.size(0))

        # ratio of pseudo labels
        n_pseudo_labels = mask.sum()
        ratio = n_pseudo_labels / x_t.size(0)
        pseudo_label_ratios.update(ratio.item() * 100, x_t.size(0))

        # accuracy of pseudo labels
        if n_pseudo_labels > 0:
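            # overwrite unconfident entries with -1 so they can never match a
            # ground-truth label in the comparison below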
            pseudo_labels = pseudo_labels * mask - (1 - mask)
            n_correct = (pseudo_labels == labels_t).float().sum()
            pseudo_label_acc = n_correct / n_pseudo_labels * 100
            pseudo_label_accs.update(pseudo_label_acc.item(), n_pseudo_labels)

        # update parameters with the accumulated gradients and advance the lr schedule
        optimizer.step()
        lr_scheduler.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
            progress.display(i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='FixMatch for Unsupervised Domain Adaptation')
    # dataset parameters
    parser.add_argument('root', metavar='DIR',
                        help='root path of dataset')
    parser.add_argument('-d', '--data', metavar='DATA', default='Office31', choices=utils.get_dataset_names(),
                        help='dataset: ' + ' | '.join(utils.get_dataset_names()) +
                             ' (default: Office31)')
    parser.add_argument('-s', '--source', help='source domain(s)', nargs='+')
    parser.add_argument('-t', '--target', help='target domain(s)', nargs='+')
    parser.add_argument('--train-resizing', type=str, default='default')
    parser.add_argument('--val-resizing', type=str, default='default')
    parser.add_argument('--resize-size', type=int, default=224,
                        help='the image size after resizing')
    parser.add_argument('--scale', type=float, nargs='+', default=[0.5, 1.0], metavar='PCT',
                        help='Random resize scale (default: 0.5 1.0)')
    parser.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
                        help='Random resize aspect ratio (default: 0.75 1.33)')
    parser.add_argument('--no-hflip', action='store_true',
                        help='no random horizontal flipping during training')
    parser.add_argument('--norm-mean', type=float, nargs='+',
                        default=(0.485, 0.456, 0.406), help='normalization mean')
    parser.add_argument('--norm-std', type=float, nargs='+',
                        default=(0.229, 0.224, 0.225), help='normalization std')
    parser.add_argument('--auto-augment', default='rand-m10-n2-mstd2', type=str,
                        help='AutoAugment policy (default: rand-m10-n2-mstd2)')
    # model parameters
    parser.add_argument('-a', '--arch', metavar='ARCH', default='resnet18',
                        choices=utils.get_model_names(),
                        help='backbone architecture: ' +
                             ' | '.join(utils.get_model_names()) +
                             ' (default: resnet18)')
    parser.add_argument('--bottleneck-dim', default=1024, type=int,
                        help='Dimension of bottleneck')
    parser.add_argument('--no-pool', action='store_true',
                        help='no pool layer after the feature extractor.')
    parser.add_argument('--scratch', action='store_true', help='whether train from scratch.')
    parser.add_argument('--trade-off', default=1., type=float,
                        help='the trade-off hyper-parameter for transfer loss')
    # training parameters
    parser.add_argument('-b', '--batch-size', default=32, type=int,
                        metavar='N',
                        help='mini-batch size (default: 32)')
    parser.add_argument('-ub', '--unlabeled-batch-size', default=32, type=int,
                        help='mini-batch size of unlabeled data (target domain) (default: 32)')
    parser.add_argument('--threshold', default=0.9, type=float,
                        help='confidence threshold')
    parser.add_argument('--lr', '--learning-rate', default=0.003, type=float,
                        metavar='LR', help='initial learning rate', dest='lr')
    parser.add_argument('--lr-gamma', default=0.0004, type=float, help='parameter for lr scheduler')
    parser.add_argument('--lr-decay', default=0.75, type=float, help='parameter for lr scheduler')
    parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                        help='momentum')
    parser.add_argument('--wd', '--weight-decay', default=1e-3, type=float,
                        metavar='W', help='weight decay (default: 1e-3)',
                        dest='weight_decay')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('--epochs', default=20, type=int, metavar='N',
                        help='number of total epochs to run')
    parser.add_argument('-i', '--iters-per-epoch', default=1000, type=int,
                        help='Number of iterations per epoch')
    parser.add_argument('-p', '--print-freq', default=100, type=int,
                        metavar='N', help='print frequency (default: 100)')
    parser.add_argument('--seed', default=None, type=int,
                        help='seed for initializing training')
    parser.add_argument('--per-class-eval', action='store_true',
                        help='whether output per-class accuracy during evaluation')
    parser.add_argument("--log", type=str, default='fixmatch',
                        help="Where to save logs, checkpoints and debugging images.")
    parser.add_argument("--phase", type=str, default='train', choices=['train', 'test', 'analysis'],
                        help="When phase is 'test', only test the model. "
                             "When phase is 'analysis', only analyze the model.")
    args = parser.parse_args()
    main(args)