
Commit 815000a

New Code
1 parent bb46283 commit 815000a

File tree

452 files changed, +69317 −0 lines changed


.gitignore

Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
*.sh
.vscode
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
/Domain-Adaptation/.idea
*.npz
# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

MGN/.gitignore

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
.project
.settings/
.prefs
.pydevproject
.idea/
.idea
.DS_Store
.cache
*.pyc
*.html
*.xlm

MGN/README.md

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
# Multiple Granularity Network

Implementation of the paper: [Learning Discriminative Features with Multiple Granularities for Person Re-Identification](https://arxiv.org/abs/1804.01438v1)

## Dependencies

- Python >= 3.5
- PyTorch >= 0.4.0
- torchvision
- scipy
- numpy
- scikit_learn

## Current Result

| Re-Ranking | backbone | mAP   | rank1 | rank3 | rank5 | rank10 |
| :--------: | :------: | :---: | :---: | :---: | :---: | :----: |
| yes        | resnet50 | 94.33 | 95.58 | 97.54 | 97.92 | 98.46  |
| no         | resnet50 | 86.15 | 94.95 | 97.42 | 98.07 | 98.93  |

## Data

The data directory should look like:

```
data/
    bounding_box_train/
    bounding_box_test/
    query/
```

#### Market1501
Download from [here](http://www.liangzheng.org/Project/project_reid.html)

#### DukeMTMC-reID
Download from [here](http://vision.cs.duke.edu/DukeMTMC/)

#### CUHK03
1. Download the CUHK03 dataset from http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html
2. Unzip the file; you will get a cuhk03_release directory containing cuhk-03.mat.
3. Download "cuhk03_new_protocol_config_detected.mat" from https://github.com/zhunzhong07/person-re-ranking/tree/master/evaluation/data/CUHK03 and put it next to cuhk-03.mat. This new protocol is needed to split the dataset.

```
python utils/transform_cuhk03.py --src <path/to/cuhk03_release> --dst <path/to/save>
```

NOTICE: You need to change num_classes in the network depending on how many identities are in your training set, e.g. 751 for Market1501. A quick way to check this count for your own data is sketched below.
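
As a sanity check, the snippet below counts the unique person IDs in a Market1501-style training folder, which is the value num_classes should match. It is a minimal sketch, assuming the usual naming convention where the person ID is the first underscore-separated token of each image file name (the same convention data.py relies on) and that the path points at your actual bounding_box_train directory.

```python
# Minimal sketch: count unique person IDs in a Market1501-style train folder.
# Assumes filenames like 0002_c1s1_000451_03.jpg, where the leading token is
# the person ID; "-1" marks junk/distractor images and is skipped, as in data.py.
import os

train_dir = 'data/bounding_box_train'  # adjust to your dataset path

ids = set()
for name in os.listdir(train_dir):
    if not name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.ppm')):
        continue
    pid = int(name.split('_')[0])
    if pid != -1:  # skip distractor/junk images
        ids.add(pid)

print('num_classes should be', len(ids))  # 751 for Market1501
```
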
## Weights

Pretrained weights can be downloaded from [google drive](https://drive.google.com/open?id=16V7ZsflBbINHPjh_UVYGBVO6NuSxEMTi)
or [baidu drive](https://pan.baidu.com/s/12AkumLX10hLx9vh_SQwdyw) (password: mrl5)

## Train

You can specify more parameters in opt.py

```
python main.py --mode train --data_path <path/to/Market-1501-v15.09.15>
```

## Evaluate

Use the pretrained weights or your own trained weights

```
python main.py --mode evaluate --data_path <path/to/Market-1501-v15.09.15> --weight <path/to/weight_name.pt>
```

## Visualize

Visualize the rank10 query results of one image (queried from bounding_box_test)

Extracting features will take a few minutes, or you can save the features as a .mat file for reuse (one possible way is sketched after the command below)

![image](https://s1.ax1x.com/2018/11/27/FV9xyj.png)

```
python main.py --mode vis --query_image <path/to/query_image> --weight <path/to/weight_name.pt>
```
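
The README leaves the caching format open. One possible way to cache features, not part of this commit, is to write them to a .mat file with scipy, which is already listed as a dependency; the array name and file name below are illustrative assumptions.

```python
# Hypothetical sketch: cache extracted features in a .mat file so repeated
# visualization runs can skip feature extraction. Names are illustrative.
import numpy as np
from scipy.io import savemat, loadmat

# Suppose gallery_features is an (N, D) float array produced by the network.
gallery_features = np.random.rand(100, 2048).astype(np.float32)  # placeholder

savemat('gallery_features.mat', {'feat': gallery_features})

# Later runs can reload instead of re-extracting:
cached = loadmat('gallery_features.mat')['feat']
print(cached.shape)
```
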

## Citation

```text
@ARTICLE{2018arXiv180401438W,
   author = {{Wang}, G. and {Yuan}, Y. and {Chen}, X. and {Li}, J. and {Zhou}, X.},
    title = "{Learning Discriminative Features with Multiple Granularities for Person Re-Identification}",
  journal = {ArXiv e-prints},
     year = 2018,
}
```

MGN/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
# -*- coding: utf-8 -*-
"""
@File   : __init__.py
@Time   : 2020/4/18 16:35
@Author : KeyForce

"""
__all__ = ['utils', 'data', 'loss', 'network']

MGN/data.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
from torchvision import transforms
from torch.utils.data import dataset, dataloader
from torchvision.datasets.folder import default_loader
from utils.RandomErasing import RandomErasing
from utils.RandomSampler import RandomSampler
from opt import opt
import os
import re


class Data():
    def __init__(self):
        train_transform = transforms.Compose([
            transforms.Resize((384, 128), interpolation=3),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            RandomErasing(probability=0.5, mean=[0.0, 0.0, 0.0])
        ])

        test_transform = transforms.Compose([
            transforms.Resize((384, 128), interpolation=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        self.trainset = Market1501(train_transform, 'train', opt.data_path)
        self.testset = Market1501(test_transform, 'test', opt.data_path)
        self.queryset = Market1501(test_transform, 'query', opt.data_path)

        self.train_loader = dataloader.DataLoader(self.trainset,
                                                  sampler=RandomSampler(self.trainset, batch_id=opt.batchid,
                                                                        batch_image=opt.batchimage),
                                                  batch_size=opt.batchid * opt.batchimage, num_workers=8,
                                                  pin_memory=True)
        self.test_loader = dataloader.DataLoader(self.testset, batch_size=opt.batchtest, num_workers=8,
                                                 pin_memory=True)
        self.query_loader = dataloader.DataLoader(self.queryset, batch_size=opt.batchtest, num_workers=8,
                                                  pin_memory=True)

        if opt.mode == 'vis':
            self.query_image = test_transform(default_loader(opt.query_image))


class Market1501(dataset.Dataset):
    def __init__(self, transform, dtype, data_path):

        self.transform = transform
        self.loader = default_loader
        self.data_path = data_path

        if dtype == 'train':
            self.data_path += '/bounding_box_train'
        elif dtype == 'test':
            self.data_path += '/bounding_box_test'
        else:
            self.data_path += '/query'

        self.imgs = [path for path in self.list_pictures(self.data_path) if self.id(path) != -1]

        self._id2label = {_id: idx for idx, _id in enumerate(self.unique_ids)}

    def __getitem__(self, index):
        path = self.imgs[index]
        target = self._id2label[self.id(path)]

        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)

        return img, target

    def __len__(self):
        return len(self.imgs)

    @staticmethod
    def id(file_path):
        """
        :param file_path: unix style file path
        :return: person id
        """
        return int(file_path.split('/')[-1].split('_')[0])

    @staticmethod
    def camera(file_path):
        """
        :param file_path: unix style file path
        :return: camera id
        """
        return int(file_path.split('/')[-1].split('_')[1][1])

    @property
    def ids(self):
        """
        :return: person id list corresponding to dataset image paths
        """
        return [self.id(path) for path in self.imgs]

    @property
    def unique_ids(self):
        """
        :return: unique person ids in ascending order
        """
        return sorted(set(self.ids))

    @property
    def cameras(self):
        """
        :return: camera id list corresponding to dataset image paths
        """
        return [self.camera(path) for path in self.imgs]

    @staticmethod
    def list_pictures(directory, ext='jpg|jpeg|bmp|png|ppm|npy'):
        assert os.path.isdir(directory), 'dataset does not exist: {}'.format(directory)

        return sorted([os.path.join(root, f)
                       for root, _, files in os.walk(directory) for f in files
                       if re.match(r'([\w]+\.(?:' + ext + '))', f)])
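
For orientation, below is a minimal sketch of how the Data class defined above would typically be consumed. The opt fields (data_path, batchid, batchimage, batchtest, mode) come from opt.py in this commit, while the loop itself is illustrative and not the project's actual main.py.

```python
# Illustrative usage of the Data class above (not the actual main.py).
# Assumes opt.data_path points at a Market1501-style directory and that
# opt.batchid / opt.batchimage / opt.batchtest are configured in opt.py.
from data import Data

loader = Data()

for images, labels in loader.train_loader:
    # images: (batchid * batchimage, 3, 384, 128) tensors after the
    # resize/flip/normalize/RandomErasing pipeline defined above;
    # labels: dense person-ID indices produced by _id2label.
    print(images.shape, labels.shape)
    break  # demonstrate a single batch only
```
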

MGN/loss.py

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
from torch.nn import CrossEntropyLoss
from torch.nn.modules import loss
from utils.TripletLoss import TripletLoss


class Loss(loss._Loss):
    def __init__(self):
        super(Loss, self).__init__()

    def forward(self, outputs, labels):
        cross_entropy_loss = CrossEntropyLoss()
        triplet_loss = TripletLoss(margin=1.2)

        Triplet_Loss = [triplet_loss(output, labels) for output in outputs[1:4]]
        Triplet_Loss = sum(Triplet_Loss) / len(Triplet_Loss)

        CrossEntropy_Loss = [cross_entropy_loss(output, labels) for output in outputs[4:]]
        CrossEntropy_Loss = sum(CrossEntropy_Loss) / len(CrossEntropy_Loss)

        loss_sum = Triplet_Loss + 2 * CrossEntropy_Loss

        print('\rtotal loss:%.2f Triplet_Loss:%.2f CrossEntropy_Loss:%.2f' % (
            loss_sum.data.cpu().numpy(),
            Triplet_Loss.data.cpu().numpy(),
            CrossEntropy_Loss.data.cpu().numpy()),
            end=' ')
        return loss_sum
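
The slicing in forward encodes an assumption about the network's output tuple: outputs[1:4] are embedding vectors fed to the triplet term and outputs[4:] are classification logits fed to the cross-entropy term, combined as triplet + 2 * cross-entropy. Below is a minimal, self-contained sketch of calling this Loss under that assumption; the batch layout, feature size, and number of logit heads are illustrative, not the project's actual network outputs.

```python
# Illustrative call of the Loss above; the output-tuple layout is the
# assumption this loss encodes: index 0 is ignored by the loss, indices
# 1..3 are embeddings (triplet term), indices 4.. are class logits
# (cross-entropy term).
import torch
from loss import Loss

batch, feat_dim, num_classes = 8, 256, 751

# Mimic the P*K batch structure produced by RandomSampler: two images
# per identity, so every anchor has both positives and negatives.
labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

embeddings = [torch.randn(batch, feat_dim, requires_grad=True) for _ in range(3)]
logits = [torch.randn(batch, num_classes, requires_grad=True) for _ in range(2)]
outputs = [torch.randn(batch, feat_dim)] + embeddings + logits

criterion = Loss()
total = criterion(outputs, labels)  # triplet + 2 * cross-entropy
total.backward()
print(float(total))
```
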
