
Commit ba9edd1

Author: Mathilde Caron
copy detection

1 parent 30aedce commit ba9edd1

File tree

4 files changed: +368, -1 lines changed


README.md

Lines changed: 12 additions & 0 deletions

@@ -293,6 +293,18 @@ Oxford:
python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 0 --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k
```

## Evaluation: Copy detection on Copydays

Step 1: Prepare the [Copydays dataset](https://lear.inrialpes.fr/~jegou/data.php#copydays).

Step 2 (optional): Prepare a set of image distractors and a set of images on which to learn the whitening operator; see the preparation sketch at the end of this section. In our paper, we use 10k random images from YFCC100M as distractors and 20k random images from YFCC100M (distinct from the distractors) for learning the whitening operator.

Step 3: Run copy detection:
```
python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_copy_detection.py --data_path /path/to/copydays/ --whitening_path /path/to/whitening_data/ --distractors_path /path/to/distractors/
```
We report results on the "strong" subset. For example, the stdout of the command above includes: `eval on strong mAP=0.858`.
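
For Step 2, a minimal preparation sketch is below; the source directory, file extensions, and seed are illustrative assumptions, not part of this repo:
```
import os
import random
import shutil

src = "/path/to/yfcc100m_images/"  # assumed local folder of YFCC100M images
files = sorted(f for f in os.listdir(src) if f.lower().endswith((".jpg", ".jpeg", ".png")))
random.seed(0)
random.shuffle(files)
# 10k distractors and a disjoint 20k whitening set, as in the paper
for folder, subset in [("distractors", files[:10000]), ("whitening_data", files[10000:30000])]:
    os.makedirs(folder, exist_ok=True)
    for f in subset:
        shutil.copy(os.path.join(src, f), os.path.join(folder, f))
```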

## License
This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.

eval_copy_detection.py

Lines changed: 301 additions & 0 deletions

@@ -0,0 +1,301 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import transforms as pth_transforms
from PIL import Image
import numpy as np

import utils
import vision_transformer as vits


class CopydaysDataset():
    def __init__(self, basedir):
        self.basedir = basedir
        self.block_names = (
            ['original', 'strong'] +
            ['jpegqual/%d' % i for i in
             [3, 5, 8, 10, 15, 20, 30, 50, 75]] +
            ['crops/%d' % i for i in
             [10, 15, 20, 30, 40, 50, 60, 70, 80]])
        self.nblocks = len(self.block_names)

        self.query_blocks = range(self.nblocks)
        self.q_block_sizes = np.ones(self.nblocks, dtype=int) * 157
        self.q_block_sizes[1] = 229
        # search only among originals
        self.database_blocks = [0]

    def get_block(self, i):
        dirname = self.basedir + '/' + self.block_names[i]
        fnames = [dirname + '/' + fname
                  for fname in sorted(os.listdir(dirname))
                  if fname.endswith('.jpg')]
        return fnames

    def get_block_filenames(self, subdir_name):
        dirname = self.basedir + '/' + subdir_name
        return [fname
                for fname in sorted(os.listdir(dirname))
                if fname.endswith('.jpg')]

    def eval_result(self, ids, distances):
        j0 = 0
        for i in range(self.nblocks):
            j1 = j0 + self.q_block_sizes[i]
            block_name = self.block_names[i]
            I = ids[j0:j1]  # retrieved ids for this block of queries
            sum_AP = 0
            if block_name != 'strong':
                # 1:1 mapping of files to names
                positives_per_query = [[i] for i in range(j1 - j0)]
            else:
                originals = self.get_block_filenames('original')
                strongs = self.get_block_filenames('strong')

                # check if prefixes match
                positives_per_query = [
                    [j for j, bname in enumerate(originals)
                     if bname[:4] == qname[:4]]
                    for qname in strongs]

            for qno, Iline in enumerate(I):
                positives = positives_per_query[qno]
                ranks = []
                for rank, bno in enumerate(Iline):
                    if bno in positives:
                        ranks.append(rank)
                sum_AP += score_ap_from_ranks_1(ranks, len(positives))

            print("eval on %s mAP=%.3f" % (
                block_name, sum_AP / (j1 - j0)))
            j0 = j1
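

# Expected Copydays layout under --data_path (reference note, following the
# block sizes hard-coded above):
#   original/      157 reference images (the retrieval database)
#   strong/        229 strongly-attacked queries
#   jpegqual/<q>/  157 JPEG-compressed queries per quality level q
#   crops/<c>/     157 cropped queries per crop percentage c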


# from the Holidays evaluation package
def score_ap_from_ranks_1(ranks, nres):
    """ Compute the average precision of one search.
    ranks = ordered list of ranks of true positives
    nres = total number of positives in dataset
    """

    # accumulate trapezoids in PR-plot
    ap = 0.0

    # All have an x-size of:
    recall_step = 1.0 / nres

    for ntp, rank in enumerate(ranks):

        # y-size on left side of trapezoid:
        # ntp = nb of true positives so far
        # rank = nb of retrieved items so far
        if rank == 0:
            precision_0 = 1.0
        else:
            precision_0 = ntp / float(rank)

        # y-size on right side of trapezoid:
        # ntp and rank are increased by one
        precision_1 = (ntp + 1) / float(rank + 1)

        ap += (precision_1 + precision_0) * recall_step / 2.0

    return ap
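

# Worked example (illustrative, not part of the original file): with two
# positives retrieved at ranks 0 and 2 and nres=2, recall_step = 0.5, so
#   score_ap_from_ranks_1([0, 2], 2)
#   = (1.0 + 1.0) * 0.5 / 2 + (0.5 + 2/3) * 0.5 / 2 ≈ 0.7917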


class ImgListDataset(torch.utils.data.Dataset):
    def __init__(self, img_list, transform=None):
        self.samples = img_list
        self.transform = transform

    def __getitem__(self, i):
        with open(self.samples[i], 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, i

    def __len__(self):
        return len(self.samples)


def is_image_file(s):
    ext = s.split(".")[-1]
    if ext in ['jpg', 'jpeg', 'png', 'ppm', 'bmp', 'pgm', 'tif', 'tiff', 'webp']:
        return True
    return False


@torch.no_grad()
def extract_features(image_list, model, args):
    transform = pth_transforms.Compose([
        pth_transforms.Resize((args.imsize, args.imsize), interpolation=3),
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    tempdataset = ImgListDataset(image_list, transform=transform)
    data_loader = torch.utils.data.DataLoader(
        tempdataset, batch_size=args.batch_size_per_gpu,
        num_workers=args.num_workers, drop_last=False,
        sampler=torch.utils.data.DistributedSampler(tempdataset, shuffle=False))
    features = None
    for samples, index in utils.MetricLogger(delimiter="  ").log_every(data_loader, 10):
        samples, index = samples.cuda(non_blocking=True), index.cuda(non_blocking=True)
        feats = model.get_intermediate_layers(samples, n=1)[0].clone()

        cls_output_token = feats[:, 0, :]  # [CLS] token
        # GeM with exponent 4 for output patch tokens
        b, h, w, d = len(samples), int(samples.shape[-2] / model.patch_embed.patch_size), int(samples.shape[-1] / model.patch_embed.patch_size), feats.shape[-1]
        feats = feats[:, 1:, :].reshape(b, h, w, d)
        feats = feats.clamp(min=1e-6).permute(0, 3, 1, 2)
        feats = nn.functional.avg_pool2d(feats.pow(4), (h, w)).pow(1. / 4).reshape(b, -1)
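        # (Clarifying note, not in the original file: the three lines above are
        # generalized-mean (GeM) pooling, avg_pool2d(x.pow(p)).pow(1/p) with
        # p = 4; p = 1 is plain average pooling and large p approaches max pooling.)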
        # concatenate [CLS] token and GeM pooled patch tokens
        feats = torch.cat((cls_output_token, feats), dim=1)

        # init storage feature matrix
        if dist.get_rank() == 0 and features is None:
            features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
            if args.use_cuda:
                features = features.cuda(non_blocking=True)

        # get indexes from all processes
        y_all = torch.empty(dist.get_world_size(), index.size(0), dtype=index.dtype, device=index.device)
        y_l = list(y_all.unbind(0))
        y_all_reduce = torch.distributed.all_gather(y_l, index, async_op=True)
        y_all_reduce.wait()
        index_all = torch.cat(y_l)

        # share features between processes
        feats_all = torch.empty(dist.get_world_size(), feats.size(0), feats.size(1),
                                dtype=feats.dtype, device=feats.device)
        output_l = list(feats_all.unbind(0))
        output_all_reduce = torch.distributed.all_gather(output_l, feats, async_op=True)
        output_all_reduce.wait()

        # update storage feature matrix
        if dist.get_rank() == 0:
            if args.use_cuda:
                features.index_copy_(0, index_all, torch.cat(output_l))
            else:
                features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
    return features  # features is still None for every rank which is not 0 (main)


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Copy detection on Copydays')
    parser.add_argument('--data_path', default='/path/to/copydays/', type=str,
        help="See https://lear.inrialpes.fr/~jegou/data.php#copydays")
    parser.add_argument('--whitening_path', default='/path/to/whitening_data/', type=str,
        help="""Path to directory with images used for computing the whitening operator.
        In our paper, we use 20k random images from YFCC100M.""")
    parser.add_argument('--distractors_path', default='/path/to/distractors/', type=str,
        help="Path to directory with distractor images. In our paper, we use 10k random images from YFCC100M.")
    parser.add_argument('--imsize', default=320, type=int, help='Image size (square image)')
    parser.add_argument('--batch_size_per_gpu', default=16, type=int, help='Per-GPU batch-size')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
    parser.add_argument('--arch', default='vit_base', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    else:
        print(f"Architecture {args.arch} not supported")
        sys.exit(1)
    if args.use_cuda:
        model.cuda()
    model.eval()
    utils.load_pretrained_weights(model, args.pretrained_weights, args.checkpoint_key, args.arch, args.patch_size)

    dataset = CopydaysDataset(args.data_path)

    # ============ Extract features ... ============
    # extract features for queries
    queries = []
    for q in dataset.query_blocks:
        queries.append(extract_features(dataset.get_block(q), model, args))
    if utils.get_rank() == 0:
        queries = torch.cat(queries)
        print(f"Extraction of queries features done. Shape: {queries.shape}")

    # extract features for database
    database = []
    for b in dataset.database_blocks:
        database.append(extract_features(dataset.get_block(b), model, args))

    # extract features for distractors
    if os.path.isdir(args.distractors_path):
        print("Using distractors...")
        list_distractors = [os.path.join(args.distractors_path, s) for s in os.listdir(args.distractors_path) if is_image_file(s)]
        database.append(extract_features(list_distractors, model, args))
    if utils.get_rank() == 0:
        database = torch.cat(database)
        print(f"Extraction of database and distractors features done. Shape: {database.shape}")

    # ============ Whitening ... ============
    if os.path.isdir(args.whitening_path):
        print(f"Extracting features on images from {args.whitening_path} for learning the whitening operator.")
        list_whit = [os.path.join(args.whitening_path, s) for s in os.listdir(args.whitening_path) if is_image_file(s)]
        features_for_whitening = extract_features(list_whit, model, args)
        if utils.get_rank() == 0:
            # center
            mean_feature = torch.mean(features_for_whitening, dim=0)
            database -= mean_feature
            queries -= mean_feature
            pca = utils.PCA(dim=database.shape[-1], whit=0.5)
            # compute covariance
            cov = torch.mm(features_for_whitening.T, features_for_whitening) / features_for_whitening.shape[0]
            pca.train_pca(cov.cpu().numpy())
            database = pca.apply(database)
            queries = pca.apply(queries)
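            # (Note, added for clarity: with whit=0.5, utils.PCA rescales each
            # principal component by 1 / sqrt(eigenvalue), i.e. standard PCA
            # whitening of the features.)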

    # ============ Copy detection ... ============
    if utils.get_rank() == 0:
        # l2 normalize the features
        database = nn.functional.normalize(database, dim=1, p=2)
        queries = nn.functional.normalize(queries, dim=1, p=2)

        # similarity
        similarity = torch.mm(queries, database.T)
        distances, indices = similarity.topk(20, largest=True, sorted=True)

        # evaluate (eval_result prints per-block mAP and returns nothing)
        dataset.eval_result(indices, distances)
    dist.barrier()

eval_image_retrieval.py

Lines changed: 1 addition & 1 deletion

@@ -84,7 +84,7 @@ def config_qimname(cfg, i):
  parser.add_argument('--data_path', default='/path/to/revisited_paris_oxford/', type=str)
  parser.add_argument('--dataset', default='roxford5k', type=str, choices=['roxford5k', 'rparis6k'])
  parser.add_argument('--multiscale', default=False, type=utils.bool_flag)
- parser.add_argument('--imsize', default=224, type=int, help='Image size (square)')
+ parser.add_argument('--imsize', default=224, type=int, help='Image size')
  parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
  parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
  parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')

utils.py

Lines changed: 54 additions & 0 deletions

@@ -631,6 +631,60 @@ def has_batchnorms(model):
    return False


class PCA():
    """
    Class to compute and apply PCA.
    """
    def __init__(self, dim=256, whit=0.5):
        self.dim = dim
        self.whit = whit
        self.mean = None

    def train_pca(self, cov):
        """
        Takes a covariance matrix (np.ndarray) as input.
        """
        d, v = np.linalg.eigh(cov)
        eps = d.max() * 1e-5
        n_0 = (d < eps).sum()
        if n_0 > 0:
            d[d < eps] = eps

        # total energy
        totenergy = d.sum()

        # sort eigenvectors with eigenvalues order
        idx = np.argsort(d)[::-1][:self.dim]
        d = d[idx]
        v = v[:, idx]

        print("keeping %.2f %% of the energy" % (d.sum() / totenergy * 100.0))

        # for the whitening
        d = np.diag(1. / d ** self.whit)

        # principal components
        self.dvt = np.dot(d, v.T)

    def apply(self, x):
        # input is from numpy
        if isinstance(x, np.ndarray):
            if self.mean is not None:
                x -= self.mean
            return np.dot(self.dvt, x.T).T

        # input is from torch and is on GPU
        if x.is_cuda:
            if self.mean is not None:
                x -= torch.cuda.FloatTensor(self.mean)
            return torch.mm(torch.cuda.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)

        # input is from torch, on CPU
        if self.mean is not None:
            x -= torch.FloatTensor(self.mean)
        return torch.mm(torch.FloatTensor(self.dvt), x.transpose(0, 1)).transpose(0, 1)
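

# Illustrative usage of PCA, mirroring eval_copy_detection.py (the variable
# names below are hypothetical):
#   mean = torch.mean(feats_whitening, dim=0)
#   cov = torch.mm(feats_whitening.T, feats_whitening) / feats_whitening.shape[0]
#   pca = PCA(dim=feats.shape[-1], whit=0.5)
#   pca.train_pca(cov.cpu().numpy())
#   whitened = pca.apply(feats - mean)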


def compute_ap(ranks, nres):
    """
    Computes average precision for given ranked indexes.