
Commit 30aedce

Author: Mathilde Caron
Commit message: image retrieval on oxford and paris
Parent: 9085367

File tree

4 files changed: +353 additions, -7 deletions

README.md

Lines changed: 15 additions & 0 deletions
@@ -278,6 +278,21 @@ git clone https://github.com/davisvideochallenge/davis2017-evaluation $HOME/davi
 python $HOME/davis2017-evaluation/evaluation_method.py --task semi-supervised --results_path /path/to/saving_dir --davis_path $HOME/davis-2017/DAVIS/
 ```
 
+## Evaluation: Image Retrieval on revisited Oxford and Paris
+Step 1: Prepare revisited Oxford and Paris by following [this repo](https://github.com/filipradenovic/revisitop).
+
+Step 2: Image retrieval (if you do not specify weights with `--pretrained_weights`, then by default the [DINO weights pretrained on the Google Landmark v2 dataset](https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth) will be used).
+
+Paris:
+```
+python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 512 --multiscale 1 --data_path /path/to/revisited_paris_oxford/ --dataset rparis6k
+```
+
+Oxford:
+```
+python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 0 --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k
+```
+
 ## License
 This repository is released under the Apache 2.0 license as found in the [LICENSE](LICENSE) file.
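Not part of the committed README, but for illustration: the argument parser in eval_image_retrieval.py (below) also exposes `--pretrained_weights` and `--checkpoint_key`, so evaluating your own checkpoint instead of the default Google Landmark v2 weights would look roughly like this (paths are placeholders):

```
python -m torch.distributed.launch --use_env --nproc_per_node=1 eval_image_retrieval.py --imsize 224 --multiscale 1 --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/revisited_paris_oxford/ --dataset roxford5k
```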

eval_image_retrieval.py

Lines changed: 201 additions & 0 deletions
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import pickle
import argparse

import torch
from torch import nn
import torch.distributed as dist
import torch.backends.cudnn as cudnn
from torchvision import models as torchvision_models
from torchvision import transforms as pth_transforms
from PIL import Image, ImageFile
import numpy as np

import utils
import vision_transformer as vits
from eval_knn import extract_features


class OxfordParisDataset(torch.utils.data.Dataset):
    def __init__(self, dir_main, dataset, split, transform=None, imsize=None):
        if dataset not in ['roxford5k', 'rparis6k']:
            raise ValueError('Unknown dataset: {}!'.format(dataset))

        # loading imlist, qimlist, and gnd, in cfg as a dict
        gnd_fname = os.path.join(dir_main, dataset, 'gnd_{}.pkl'.format(dataset))
        with open(gnd_fname, 'rb') as f:
            cfg = pickle.load(f)
        cfg['gnd_fname'] = gnd_fname
        cfg['ext'] = '.jpg'
        cfg['qext'] = '.jpg'
        cfg['dir_data'] = os.path.join(dir_main, dataset)
        cfg['dir_images'] = os.path.join(cfg['dir_data'], 'jpg')
        cfg['n'] = len(cfg['imlist'])
        cfg['nq'] = len(cfg['qimlist'])
        cfg['im_fname'] = config_imname
        cfg['qim_fname'] = config_qimname
        cfg['dataset'] = dataset
        self.cfg = cfg

        self.samples = cfg["qimlist"] if split == "query" else cfg["imlist"]
        self.transform = transform
        self.imsize = imsize

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        path = os.path.join(self.cfg["dir_images"], self.samples[index] + ".jpg")
        ImageFile.LOAD_TRUNCATED_IMAGES = True
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
        if self.imsize is not None:
            img.thumbnail((self.imsize, self.imsize), Image.ANTIALIAS)
        if self.transform is not None:
            img = self.transform(img)
        return img, index


def config_imname(cfg, i):
    return os.path.join(cfg['dir_images'], cfg['imlist'][i] + cfg['ext'])


def config_qimname(cfg, i):
    return os.path.join(cfg['dir_images'], cfg['qimlist'][i] + cfg['qext'])


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Image Retrieval on revisited Paris and Oxford')
    parser.add_argument('--data_path', default='/path/to/revisited_paris_oxford/', type=str)
    parser.add_argument('--dataset', default='roxford5k', type=str, choices=['roxford5k', 'rparis6k'])
    parser.add_argument('--multiscale', default=False, type=utils.bool_flag)
    parser.add_argument('--imsize', default=224, type=int, help='Image size (square)')
    parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
    parser.add_argument('--use_cuda', default=True, type=utils.bool_flag)
    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
    parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
    parser.add_argument("--checkpoint_key", default="teacher", type=str,
        help='Key to use in the checkpoint (example: "teacher")')
    parser.add_argument('--num_workers', default=10, type=int, help='Number of data loading workers per GPU.')
    parser.add_argument("--dist_url", default="env://", type=str, help="""url used to set up
        distributed training; see https://pytorch.org/docs/stable/distributed.html""")
    parser.add_argument("--local_rank", default=0, type=int, help="Please ignore and do not set this argument.")
    args = parser.parse_args()

    utils.init_distributed_mode(args)
    print("git:\n  {}\n".format(utils.get_sha()))
    print("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
    cudnn.benchmark = True

    # ============ preparing data ... ============
    transform = pth_transforms.Compose([
        pth_transforms.ToTensor(),
        pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    dataset_train = OxfordParisDataset(args.data_path, args.dataset, split="train", transform=transform, imsize=args.imsize)
    dataset_query = OxfordParisDataset(args.data_path, args.dataset, split="query", transform=transform, imsize=args.imsize)
    sampler = torch.utils.data.DistributedSampler(dataset_train, shuffle=False)
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    data_loader_query = torch.utils.data.DataLoader(
        dataset_query,
        batch_size=1,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=False,
    )
    print(f"train: {len(dataset_train)} imgs / query: {len(dataset_query)} imgs")

    # ============ building network ... ============
    if "vit" in args.arch:
        model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
        print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
    elif "xcit" in args.arch:
        model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
    elif args.arch in torchvision_models.__dict__.keys():
        model = torchvision_models.__dict__[args.arch](num_classes=0)
    else:
        print(f"Architecture {args.arch} non supported")
        sys.exit(1)
    if args.use_cuda:
        model.cuda()
    model.eval()

    # load pretrained weights
    if os.path.isfile(args.pretrained_weights):
        state_dict = torch.load(args.pretrained_weights, map_location="cpu")
        if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
            print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
            state_dict = state_dict[args.checkpoint_key]
        # remove `module.` prefix
        state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
        # remove `backbone.` prefix induced by multicrop wrapper
        state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
        msg = model.load_state_dict(state_dict, strict=False)
        print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
    elif args.arch == "vit_small" and args.patch_size == 16:
        print("Since no pretrained weights have been provided, we load pretrained DINO weights on Google Landmark v2.")
        model.load_state_dict(torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_vitsmall16_googlelandmark_pretrain/dino_vitsmall16_googlelandmark_pretrain.pth"))
    else:
        print("Warning: We use random weights.")

    ############################################################################
    # Step 1: extract features
    train_features = extract_features(model, data_loader_train, args.use_cuda, multiscale=args.multiscale)
    query_features = extract_features(model, data_loader_query, args.use_cuda, multiscale=args.multiscale)

    if utils.get_rank() == 0:  # only rank 0 will work from now on
        # normalize features
        train_features = nn.functional.normalize(train_features, dim=1, p=2)
        query_features = nn.functional.normalize(query_features, dim=1, p=2)

        ############################################################################
        # Step 2: similarity
        sim = torch.mm(train_features, query_features.T)
        ranks = torch.argsort(-sim, dim=0).cpu().numpy()

        ############################################################################
        # Step 3: evaluate
        gnd = dataset_train.cfg['gnd']
        # evaluate ranks
        ks = [1, 5, 10]
        # search for easy & hard
        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['easy'], gnd[i]['hard']])
            g['junk'] = np.concatenate([gnd[i]['junk']])
            gnd_t.append(g)
        mapM, apsM, mprM, prsM = utils.compute_map(ranks, gnd_t, ks)
        # search for hard
        gnd_t = []
        for i in range(len(gnd)):
            g = {}
            g['ok'] = np.concatenate([gnd[i]['hard']])
            g['junk'] = np.concatenate([gnd[i]['junk'], gnd[i]['easy']])
            gnd_t.append(g)
        mapH, apsH, mprH, prsH = utils.compute_map(ranks, gnd_t, ks)
        print('>> {}: mAP M: {}, H: {}'.format(args.dataset, np.around(mapM*100, decimals=2), np.around(mapH*100, decimals=2)))
        print('>> {}: mP@k{} M: {}, H: {}'.format(args.dataset, np.array(ks), np.around(mprM*100, decimals=2), np.around(mprH*100, decimals=2)))
    dist.barrier()
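Note that `extract_features` (imported from eval_knn.py) calls `utils.multi_scale` when `--multiscale 1` is set, and the evaluation relies on `utils.compute_map`; neither helper is shown in this commit view. As a rough idea of what a multi-scale extractor consistent with the call `utils.multi_scale(samples, model)` could look like — a sketch under assumed scales and pooling, not the code from this commit:

```python
import torch
from torch import nn


@torch.no_grad()
def multi_scale(samples, model):
    # Hypothetical sketch: average features over a small scale pyramid,
    # then L2-normalize. The actual utils.multi_scale may differ.
    v = None
    for s in [1.0, 1.4142, 2.0]:  # assumed scales
        inp = samples if s == 1.0 else nn.functional.interpolate(
            samples, scale_factor=s, mode='bilinear', align_corners=False)
        feats = model(inp).clone()
        v = feats if v is None else v + feats
    v /= 3
    return nn.functional.normalize(v, dim=-1, p=2)
```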

eval_knn.py

Lines changed: 13 additions & 7 deletions
@@ -21,6 +21,7 @@
 import torch.backends.cudnn as cudnn
 from torchvision import datasets
 from torchvision import transforms as pth_transforms
+from torchvision import models as torchvision_models

 import utils
 import vision_transformer as vits
@@ -60,6 +61,8 @@ def extract_feature_pipeline(args):
         print(f"Model {args.arch} {args.patch_size}x{args.patch_size} built.")
     elif "xcit" in args.arch:
         model = torch.hub.load('facebookresearch/xcit', args.arch, num_classes=0)
+    elif args.arch in torchvision_models.__dict__.keys():
+        model = torchvision_models.__dict__[args.arch](num_classes=0)
     else:
         print(f"Architecture {args.arch} non supported")
         sys.exit(1)
@@ -69,9 +72,9 @@ def extract_feature_pipeline(args):

     # ============ extract features ... ============
     print("Extracting features for train set...")
-    train_features = extract_features(model, data_loader_train)
+    train_features = extract_features(model, data_loader_train, args.use_cuda)
     print("Extracting features for val set...")
-    test_features = extract_features(model, data_loader_val)
+    test_features = extract_features(model, data_loader_val, args.use_cuda)

     if utils.get_rank() == 0:
         train_features = nn.functional.normalize(train_features, dim=1, p=2)
@@ -89,18 +92,21 @@ def extract_feature_pipeline(args):


 @torch.no_grad()
-def extract_features(model, data_loader):
+def extract_features(model, data_loader, use_cuda=True, multiscale=False):
     metric_logger = utils.MetricLogger(delimiter="  ")
     features = None
     for samples, index in metric_logger.log_every(data_loader, 10):
         samples = samples.cuda(non_blocking=True)
         index = index.cuda(non_blocking=True)
-        feats = model(samples).clone()
+        if multiscale:
+            feats = utils.multi_scale(samples, model)
+        else:
+            feats = model(samples).clone()

         # init storage feature matrix
         if dist.get_rank() == 0 and features is None:
             features = torch.zeros(len(data_loader.dataset), feats.shape[-1])
-            if args.use_cuda:
+            if use_cuda:
                 features = features.cuda(non_blocking=True)
             print(f"Storing features into tensor of shape {features.shape}")

@@ -125,7 +131,7 @@ def extract_features(model, data_loader):

         # update storage feature matrix
         if dist.get_rank() == 0:
-            if args.use_cuda:
+            if use_cuda:
                 features.index_copy_(0, index_all, torch.cat(output_l))
             else:
                 features.index_copy_(0, index_all.cpu(), torch.cat(output_l).cpu())
@@ -191,7 +197,7 @@ def __getitem__(self, idx):
     parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
     parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
         help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
-    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture (support only ViT and XCiT atm).')
+    parser.add_argument('--arch', default='vit_small', type=str, help='Architecture')
     parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
     parser.add_argument("--checkpoint_key", default="teacher", type=str,
         help='Key to use in the checkpoint (example: "teacher")')
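With this refactor, `extract_features` takes `use_cuda` and `multiscale` as arguments instead of reading the module-level `args`, which is what lets eval_image_retrieval.py import and reuse it. A minimal illustrative call, assuming `model` and `data_loader` are built as in the scripts above and distributed mode is already initialized:

```python
from eval_knn import extract_features

# Gathered feature matrix lives on rank 0 (other ranks get None);
# use_cuda only controls where that storage tensor is kept.
feats = extract_features(model, data_loader, use_cuda=True, multiscale=True)
```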
