# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
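"""Copy detection evaluation on the Copydays dataset.

Example launch (a sketch: single node, single GPU, launched with
torch.distributed.launch; all paths are placeholders):

    python -m torch.distributed.launch --nproc_per_node=1 eval_copy_detection.py \
        --pretrained_weights /path/to/checkpoint.pth \
        --data_path /path/to/copydays/ \
        --whitening_path /path/to/whitening_data/ \
        --distractors_path /path/to/distractors/
"""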
import os
import sys
import pickle
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import models as torchvision_models
from torchvision import transforms as pth_transforms
from PIL import Image, ImageFile
import numpy as np

import utils
import vision_transformer as vits


class CopydaysDataset:
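    """Copydays retrieval benchmark.

    Blocks of images: 157 'original' images, 229 'strong' distortions, plus
    JPEG-compressed ('jpegqual/<quality>') and cropped ('crops/<percent>')
    variants of the originals (157 each). Every block is used as queries;
    the database contains only the originals.
    """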
    def __init__(self, basedir):
        self.basedir = basedir
        self.block_names = (
            ['original', 'strong'] +
            ['jpegqual/%d' % i for i in
             [3, 5, 8, 10, 15, 20, 30, 50, 75]] +
            ['crops/%d' % i for i in
             [10, 15, 20, 30, 40, 50, 60, 70, 80]])
        self.nblocks = len(self.block_names)

        self.query_blocks = range(self.nblocks)
        self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157
        self.q_block_sizes[1] = 229
        # search only among originals
        self.database_blocks = [0]

    def get_block(self, i):
        dirname = self.basedir + '/' + self.block_names[i]
        fnames = [dirname + '/' + fname
                  for fname in sorted(os.listdir(dirname))
                  if fname.endswith('.jpg')]
        return fnames

    def get_block_filenames(self, subdir_name):
        dirname = self.basedir + '/' + subdir_name
        return [fname
                for fname in sorted(os.listdir(dirname))
                if fname.endswith('.jpg')]

    def eval_result(self, ids, distances):
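        """Print the mAP of each transformation block.

        ids: for each query (rows, ordered block by block), the ranked
        indices of the retrieved database images.
        """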
        j0 = 0
        for i in range(self.nblocks):
            j1 = j0 + self.q_block_sizes[i]
            block_name = self.block_names[i]
            I = ids[j0:j1]  # ranked results for this block's queries
            sum_AP = 0
            if block_name != 'strong':
                # 1:1 mapping of files to names
                positives_per_query = [[j] for j in range(j1 - j0)]
            else:
                originals = self.get_block_filenames('original')
                strongs = self.get_block_filenames('strong')

                # matching filename prefixes identify the positives
                positives_per_query = [
                    [j for j, bname in enumerate(originals)
                     if bname[:4] == qname[:4]]
                    for qname in strongs]

            for qno, Iline in enumerate(I):
                positives = positives_per_query[qno]
                ranks = []
                for rank, bno in enumerate(Iline):
                    if bno in positives:
                        ranks.append(rank)
                sum_AP += score_ap_from_ranks_1(ranks, len(positives))

            print("eval on %s mAP=%.3f" % (
                block_name, sum_AP / (j1 - j0)))
            j0 = j1


# from the Holidays evaluation package
def score_ap_from_ranks_1(ranks, nres):
    """Compute the average precision of one search.
    ranks = ordered list of ranks of true positives
    nres  = total number of positives in dataset
    """
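    # e.g. a single positive retrieved at rank 0 (ranks=[0], nres=1) gives
    # precision_0 = precision_1 = 1.0 and recall_step = 1.0, hence AP = 1.0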

    # accumulate trapezoids in PR-plot
    ap = 0.0

    # All have an x-size of:
    recall_step = 1.0 / nres

    for ntp, rank in enumerate(ranks):

        # y-size on left side of trapezoid:
        # ntp = nb of true positives so far
        # rank = nb of retrieved items so far
        if rank == 0:
            precision_0 = 1.0
        else:
            precision_0 = ntp / float(rank)

        # y-size on right side of trapezoid:
        # ntp and rank are increased by one
        precision_1 = (ntp + 1) / float(rank + 1)

        ap += (precision_1 + precision_0) * recall_step / 2.0

    return ap


class ImgListDataset(torch.utils.data.Dataset):
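    """Bare-bones dataset over a list of image paths; returns (image, index) pairs."""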
    def __init__(self, img_list, transform=None):
        self.samples = img_list
        self.transform = transform

    def __getitem__(self, i):
        with open(self.samples[i], 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, i

    def __len__(self):
        return len(self.samples)


def is_image_file(s):
    # case-insensitive check on the file extension
    ext = s.split(".")[-1].lower()
    return ext in ('jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp')


@torch.no_grad()
def extract_features(image_list, model, args):
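    """Compute one descriptor per image: the [CLS] token of the last
    transformer block concatenated with a GeM pooling of its patch tokens.

    All ranks run the forward passes; only rank 0 accumulates and returns
    the full feature matrix (the other ranks return None).
    """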
    transform = pth_transforms.Compose([
        pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    tempdataset = ImgListDataset(image_list, transform=transform)
    data_loader = torch.utils.data.DataLoader(tempdataset, batch_size=args.batch_size_per_gpu,
                                              num_workers=args.num_workers, drop_last=False,
                                              sampler=torch.utils.data.DistributedSampler(tempdataset, shuffle=False))
    features = None
    for samples, index in utils.MetricLogger(delimiter="  ").log_every(data_loader, 10):
        samples, index = samples.cuda(non_blocking=True), index.cuda(non_blocking=True)
        feats = model.get_intermediate_layers(samples, n=1)[0].clone()

        cls_output_token = feats[:, 0, :]  # [CLS] token
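        # the remaining tokens form an (h, w) grid of patch embeddings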
        # GeM (generalized mean) pooling with exponent 4 of the output patch tokens
        b = len(samples)
        h = samples.shape[-2] // model.patch_embed.patch_size
        w = samples.shape[-1] // model.patch_embed.patch_size
        d = feats.shape[-1]
        feats = feats[:, 1:, :].reshape(b, h, w, d)
        feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2)
        feats = nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1. / 4).reshape(b, -1)
        # concatenate [CLS] token and GeM pooled patch tokens
        feats = torch.cat((cls_output_token, feats), dim=1)

        # init storage feature matrix
        if dist.get_rank() == 0 and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
            if args.use_cuda:
                features = features.cuda(non_blocking=True)

        # get indexes from all processes
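        # note: DistributedSampler pads the dataset so every rank processes the
        # same number of samples; padded duplicates overwrite identical rows of
        # `features` through index_copy_ below, so they do not corrupt the result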
        y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
        y_l = list(y_all.unbind(0))
        y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
        y_all_reduce.wait()
        index_all = torch.cat(y_l)

        # share features between processes
        feats_all = torch.empty(dist.get_world_size(), feats.size(0), feats.size(1),
                                dtype=feats.dtype, device=feats.device)
        output_l = list(feats_all.unbind(0))
        output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
        output_all_reduce.wait()

        # update storage feature matrix
        if dist.get_rank() == 0:
            if args.use_cuda:
                features.index_copy_(0, index_all, torch.cat(output_l))
            else:
                features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
    return features  # still None on every rank other than the main rank (0)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Copy detection on Copydays')
    parser.add_argument('--data_path', default='/path/to/copydays/', type=str,
        help="See https://lear.inrialpes.fr/~jegou/data.php#copydays")
    parser.add_argument('--whitening_path', default='/path/to/whitening_data/', type=str,
        help="""Path to directory with images used for computing the whitening operator.
        In our paper, we use 20k random images from YFCC100M.""")
    parser.add_argument('--distractors_path', default='/path/to/distractors/', type=str,
        help="Path to directory with distractor images. In our paper, we use 10k random images from YFCC100M.")
    parser.add_argument('--imsize', default=320, type=int, help='Image size (square image)')
    parser.add_argument('--batch_size_per_gpu', default=16, type=int, help='Per-GPU batch size')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
    parser.add_argument('--arch', default='vit_base', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    else:
        print(f"Architecture {args.arch} not supported")
        sys.exit(1)
    if args.use_cuda:
        model.cuda()
    model.eval()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)

    dataset = CopydaysDataset(args.data_path)

    # ============ Extract features ... ============
    # extract features for queries
    queries = []
    for q in dataset.query_blocks:
        queries.append(extract_features(dataset.get_block(q), model, args))
    if utils.get_rank() == 0:
        queries = torch.cat(queries)
        print(f"Extraction of query features done. Shape: {queries.shape}")

    # extract features for database
    database = []
    for b in dataset.database_blocks:
        database.append(extract_features(dataset.get_block(b), model, args))

    # extract features for distractors
    if os.path.isdir(args.distractors_path):
        print("Using distractors...")
        list_distractors = [os.path.join(args.distractors_path, s) for s in os.listdir(args.distractors_path) if is_image_file(s)]
        database.append(extract_features(list_distractors, model, args))
    if utils.get_rank() == 0:
        database = torch.cat(database)
        print(f"Extraction of database and distractor features done. Shape: {database.shape}")

    # ============ Whitening ... ============
    if os.path.isdir(args.whitening_path):
        print(f"Extracting features on images from {args.whitening_path} for learning the whitening operator.")
        list_whit = [os.path.join(args.whitening_path, s) for s in os.listdir(args.whitening_path) if is_image_file(s)]
        features_for_whitening = extract_features(list_whit, model, args)
        if utils.get_rank() == 0:
            # center
            mean_feature = torch.mean(features_for_whitening, dim=0)
            database -= mean_feature
            queries -= mean_feature
            pca = utils.PCA(dim=database.shape[-1], whit=0.5)
            # compute covariance
            cov = torch.mm(features_for_whitening.T, features_for_whitening) / features_for_whitening.shape[0]
            pca.train_pca(cov.cpu().numpy())
            database = pca.apply(database)
            queries = pca.apply(queries)

    # ============ Copy detection ... ============
    if utils.get_rank() == 0:
        # l2 normalize the features
        database = nn.functional.normalize(database, dim=1, p=2)
        queries = nn.functional.normalize(queries, dim=1, p=2)

        # cosine similarity (the features are unit-normalized)
        similarity = torch.mm(queries, database.T)
        distances, indices = similarity.topk(20, largest=True, sorted=True)

        # evaluate mAP per block (eval_result prints and returns nothing)
        dataset.eval_result(indices, distances)
    dist.barrier()