Skip to content
This repository was archived by the owner on Jan 29, 2024. It is now read-only.

Commit 34ce831

Browse files
committed
Add fine-tune scripts for CUHK01.
1 parent 484c422 commit 34ce831

File tree

6 files changed

+587
-1
lines changed

6 files changed

+587
-1
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
| DukeMTMC | 86.4% | 74.5% |
99
| CUHK03-labeled | 75.9% | 72.5% |
1010
| CUHK03-detected | 71.2% | 68.0% |
11-
| CUHK01 | [] | [] |
11+
| CUHK01(finetune on CUHK03) | 90.5% | 89.4% |
1212

1313

1414
## Test

configs/sample_CUHK01.yaml

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
OUTPUT_DIR: "./checkpoint/CUHK01/Exp-1011-2-finetune-2"
2+
DEVICE: "cuda"
3+
DEVICE_ID: ('2')
4+
5+
MODEL:
6+
NAME: 'baseline'
7+
ARCH: 'resnet50'
8+
LABEL_SMOOTH: True
9+
10+
INPUT:
11+
SIZE_TRAIN: [384, 128]
12+
13+
DATASETS:
14+
NAME: 'CUHK01'
15+
ROOT: '/home/hzh/data'
16+
17+
DATALOADER:
18+
NUM_WORKERS: 4
19+
BATCH_SIZE: 64
20+
NUM_INSTANCES: 4
21+
22+
SOLVER:
23+
LOSS: 'softmax_triplet'
24+
OPTIMIZER_NAME: 'Adam'
25+
MAX_EPOCHS: 60
26+
BASE_LR: 0.0004
27+
WEIGHT_DECAY: 0.0005
28+
29+
EVAL_PERIOD: 5
30+
PRINT_FREQ: 10
31+
32+
SCHEDULER:
33+
NAME: 'WarmupStepLR'
34+
STEP: 20
35+
GAMMA: 0.1
36+
37+
WARMUP_FACTOR: 100
38+
WARMUP_ITERS: 20
39+
40+
41+
TEST:
42+
BATCH_SIZE: 64
43+
44+

data/__init__.py

+21
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from torch.utils.data import DataLoader
1111

12+
from data.cuhk01 import CUHK01
1213
from data.dataset import ReIDDataset, ImageData
1314
from data.transforms import TrainTransform, TestTransform
1415
from data.samplers import RandomIdentitySampler
@@ -46,3 +47,23 @@ def make_loader(cfg):
4647
pin_memory=True)
4748

4849
return train_loader, query_loader, gallery_loader, num_train_pids
50+
51+
52+
def make_loader_cuhk01(cfg):
53+
_data = CUHK01(root=cfg.DATASETS.ROOT)
54+
num_train_pids = _data.num_train_pids
55+
56+
train_loader = DataLoader(ImageData(_data.train, TrainTransform(p=0.5)),
57+
sampler=RandomIdentitySampler(_data.train, cfg.DATALOADER.NUM_INSTANCES),
58+
batch_size=cfg.DATALOADER.BATCH_SIZE, num_workers=cfg.DATALOADER.NUM_WORKERS,
59+
pin_memory=True, drop_last=True)
60+
61+
query_loader = DataLoader(ImageData(_data.query, TestTransform(flip=False)),
62+
batch_size=cfg.DATALOADER.BATCH_SIZE, num_workers=cfg.DATALOADER.NUM_WORKERS,
63+
pin_memory=True)
64+
65+
gallery_loader = DataLoader(ImageData(_data.gallery, TestTransform(flip=False)),
66+
batch_size=cfg.DATALOADER.BATCH_SIZE, num_workers=cfg.DATALOADER.NUM_WORKERS,
67+
pin_memory=True)
68+
69+
return train_loader, query_loader, gallery_loader, num_train_pids

0 commit comments

Comments
 (0)