
Commit 1da3e2e

init release
1 parent cf3eaff commit 1da3e2e

37 files changed: +7887 −1 lines

.gitignore (+6)

```
__pycache__/
*.png
*.json
*.pt
._.DS_Store
.DS_Store
```

README.md (+138 −1, replacing the previous one-line `# self-sup-corr-dev`)

# SD4Match: Learning to Prompt Stable Diffusion Model for Semantic Matching

**[Project Page](http://sd4match.active.vision/) | [Arxiv](https://arxiv.org/abs/2310.17569) | [Pretrained Prompt](https://www.robots.ox.ac.uk/~xinghui/sd4match/pretrained_prompts.zip)**

[Xinghui Li<sup>1</sup>](https://scholar.google.com/citations?user=XLlgbBoAAAAJ&hl=en),
Jingyi Lu<sup>2</sup>,
[Kai Han<sup>2</sup>](https://www.kaihan.org/),
[Victor Prisacariu<sup>1</sup>](https://www.robots.ox.ac.uk/~victor//)

[<sup>1</sup>Active Vision Lab, University of Oxford](https://www.robots.ox.ac.uk/~lav/)&nbsp;&nbsp;&nbsp;
[<sup>2</sup>Visual AI Lab, University of Hong Kong](https://visailab.github.io/)
## Environment

The environment can be installed easily through [conda](https://docs.conda.io/projects/miniconda/en/latest/) and pip. After downloading the code, run the following commands:

```shell
conda create -n sd4match python=3.10
conda activate sd4match

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
conda install xformers -c xformers
pip install yacs pandas scipy einops matplotlib triton timm diffusers accelerate transformers datasets tensorboard pykeops scikit-learn
```
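As a quick sanity check after installation, a short stdlib-only snippet (not part of the repository) can report which of the required packages are missing from the active environment:

```python
# Sanity check (hypothetical helper, not part of the repo): list which of
# the required packages cannot be imported in the current environment.
import importlib.util

REQUIRED = [
    "torch", "torchvision", "yacs", "pandas", "scipy", "einops",
    "matplotlib", "timm", "diffusers", "accelerate", "transformers",
    "datasets", "sklearn",
]

def missing_packages(names):
    """Return the subset of `names` that cannot be imported."""
    return [n for n in names if importlib.util.find_spec(n) is None]

if __name__ == "__main__":
    missing = missing_packages(REQUIRED)
    print("all packages found" if not missing else f"missing: {missing}")
```

Note that `scikit-learn` installs as the module `sklearn`, hence the name in the list.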
## Data

#### PF-Pascal
1. Download the PF-Pascal dataset from [link](https://www.di.ens.fr/willow/research/proposalflow/).
2. Rename the outermost directory from `PF-dataset-PASCAL` to `pf-pascal`.
3. Download the lists of image pairs from [link](https://www.robots.ox.ac.uk/~xinghui/sd4match/pf-pascal_image_pairs.zip).
4. Place the lists of image pairs under the `pf-pascal` directory. The structure should be:
```
pf-pascal
├── __MACOSX
├── PF-dataset-PASCAL
├── trn_pairs.csv
├── val_pairs.csv
└── test_pairs.csv
```
#### PF-Willow
1. Download the PF-Willow dataset from [link](https://www.di.ens.fr/willow/research/proposalflow/).
2. Rename the outermost directory from `PF-dataset` to `pf-willow`.
3. Download the list of image pairs from [link](https://www.robots.ox.ac.uk/~xinghui/sd4match/test_pairs.csv).
4. Place the list of image pairs under the `pf-willow` directory. The structure should be:
```
pf-willow
├── __MACOSX
├── PF-dataset
└── test_pairs.csv
```
#### SPair-71K
1. Download the SPair-71K dataset from [link](https://cvlab.postech.ac.kr/research/SPair-71k/). After extraction, no further action is required.
## Setup

1. Create symbolic links to the PF-Pascal, PF-Willow and SPair-71k datasets in the `asset` directory:
```
ln -s /your/path/to/pf-pascal asset/pf-pascal
ln -s /your/path/to/pf-willow asset/pf-willow
ln -s /your/path/to/SPair-71k asset/SPair-71k
```
2. Create a directory named `sd4match` under `asset`. This is used to save pre-computed features, checkpoints and learned prompts.
```
# create the sd4match directory directly
mkdir asset/sd4match

# or create sd4match anywhere you want and use a symbolic link
ln -s /your/path/to/sd4match asset/sd4match
```
3. Run `pre_compute_dino_feature.py`. This pre-computes DINOv2 features for all images in PF-Pascal, PF-Willow and SPair-71k and saves them in `asset/sd4match`. The structure should be:
```
sd4match
└── asset
    └── DINOv2
        ├── pfpascal
        |   └── cached_output.pt
        ├── pfwillow
        |   └── cached_output.pt
        └── spair
            └── cached_output.pt
```
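Before training, a small stdlib-only helper (hypothetical, not part of the repository; the paths follow the layout shown above) can confirm the cached feature files are in place:

```python
from pathlib import Path

# Expected DINOv2 cache files under asset/sd4match, per the tree above.
EXPECTED = [
    "asset/DINOv2/pfpascal/cached_output.pt",
    "asset/DINOv2/pfwillow/cached_output.pt",
    "asset/DINOv2/spair/cached_output.pt",
]

def missing_caches(root):
    """Return the expected cache files that are absent under `root`."""
    root = Path(root)
    return [p for p in EXPECTED if not (root / p).is_file()]

if __name__ == "__main__":
    print(missing_caches("asset/sd4match") or "all caches present")
```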
## Training

Bash scripts for training are provided in the `script` directory, organized by training data and prompt type.

For example, to train `SD4Match-CPM` on the SPair-71k dataset, run:
```
cd script/spair
sh sd4match_cpm.sh
```
The batch size per GPU is currently set to `3`, which takes about `22G` of GPU memory to train. Reduce the batch size if necessary. The training script generates two directories in `asset/sd4match`: `log` and `prompt`. Tensorboard logs and training states are saved in `log`, and learned prompts are saved in `prompt`. For example, training `SD4Match-CPM` on SPair-71k generates:
```
sd4match
├── asset
|   ├── ...
├── log
|   └── spair
|       └── CPM_spair_sd2-1_Pair-DINO-Feat-G25-C50_constant_lr0.01
|           └── ...(Tensorboard log and training states)
└── prompt
    └── CPM_spair_sd2-1_Pair-DINO-Feat-G25-C50
        └── ckpt.pt
```
## Testing

To replicate the results reported in the paper on SPair-71k, either learn the prompt yourself or download our pre-trained prompts and place them under the `asset/sd4match/prompt` directory. Then run:
```
python test.py \
    --dataset spair \
    --prompt_type $PROMPT_NAME \
    --timestep 50 \
    --layer 1
```
Replace `$PROMPT_NAME` with the prompt you want. It needs to have a corresponding directory under `asset/sd4match/prompt`. For example, to evaluate `SD4Match-CPM`, run:
```
python test.py \
    --dataset spair \
    --prompt_type CPM_spair_sd2-1_Pair-DINO-Feat-G25-C50 \
    --timestep 50 \
    --layer 1
```
## Acknowledgement

[Kai Han](https://www.kaihan.org/) is supported by the Hong Kong Research Grants Council - Early Career Scheme (Grant No. 27208022), the National Natural Science Foundation of China (Grant No. 62306251), and the HKU Seed Fund for Basic Research.

We also sincerely thank [Zirui Wang](https://scholar.google.com/citations?user=zCBKqa8AAAAJ&hl=en) for his inspiring discussion.
## Citation

```
@misc{li2023sd4match,
      title={SD4Match: Learning to Prompt Stable Diffusion Model for Semantic Matching},
      author={Xinghui Li and Jingyi Lu and Kai Han and Victor Prisacariu},
      year={2023},
      eprint={2310.17569},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}
```

asset/.gitignore (+2)

```
*
!.gitignore
```

config/__init__.py (+1)

```python
from .base import get_default_defaults
```

config/base.py (+47)

```python
from yacs.config import CfgNode as CN

_CN = CN()

# dataset configuration
_CN.DATASET = CN()
_CN.DATASET.NAME = 'spair'
_CN.DATASET.ROOT = 'asset/'
_CN.DATASET.IMG_SIZE = 768
_CN.DATASET.MEAN = [0.5, 0.5, 0.5]
_CN.DATASET.STD = [0.5, 0.5, 0.5]

# stable diffusion configuration
_CN.STABLE_DIFFUSION = CN()
_CN.STABLE_DIFFUSION.VERSION = '1-5'
_CN.STABLE_DIFFUSION.SAVE_MEMORY = True  # if True, uses less memory but runs slower

# feature extractor configuration
_CN.FEATURE_EXTRACTOR = CN()
_CN.FEATURE_EXTRACTOR.METHOD = 'dift'  # select between ('dift', 'sd-dino')
_CN.FEATURE_EXTRACTOR.SELECT_TIMESTEP = 261  # select from 1-1000
_CN.FEATURE_EXTRACTOR.SELECT_LAYER = 1  # for 'dift': select from (0,1,2,3); for 'sd-dino': select from (0,1,...,11)
_CN.FEATURE_EXTRACTOR.ENSEMBLE_SIZE = 2  # if ensemble_size > 1, the denoising process is repeated and the feature is the average over multiple trials

_CN.FEATURE_EXTRACTOR.PROMPT_TYPE = 'text'
_CN.FEATURE_EXTRACTOR.ASSET_ROOT = "asset/sd4match/asset"  # root of cached assets, such as cached CLIP or DINO features
_CN.FEATURE_EXTRACTOR.PROMPT_CACHE_ROOT = "asset/sd4match/prompt"
_CN.FEATURE_EXTRACTOR.LOG_ROOT = "asset/sd4match/log"

_CN.FEATURE_EXTRACTOR.ENABLE_L2_NORM = True
_CN.FEATURE_EXTRACTOR.FUSE_DINO = False

# Evaluator configuration
_CN.EVALUATOR = CN()
_CN.EVALUATOR.ALPHA = [0.05, 0.1, 0.15]
_CN.EVALUATOR.BY = 'image'  # select between ('image', 'point'): PCK per image or PCK per point


def get_default_defaults():
    """Get a yacs CfgNode object with default values for my_project."""
    # Return a clone so that the defaults will not be altered.
    # This is for the "local variable" use pattern.
    return _CN.clone()


def convert_config_to_dict(config):
    if not isinstance(config, CN):
        return config
    return {k: convert_config_to_dict(v) for k, v in config.items()}
```
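The `EVALUATOR` block above configures PCK (Percentage of Correct Keypoints). As a rough sketch of the metric — assuming the common threshold of `alpha * max(h, w)`; note that SPair-71k conventionally normalizes by the object bounding box, so the repository's exact normalization may differ:

```python
import math

def pck(pred, gt, img_size, alpha=0.1):
    """Fraction of predicted keypoints within alpha * max(h, w) of ground truth.

    pred, gt: lists of (x, y) tuples; img_size: (h, w).
    """
    thresh = alpha * max(img_size)
    correct = sum(math.dist(p, g) <= thresh for p, g in zip(pred, gt))
    return correct / len(pred)

# EVALUATOR.BY = 'image': compute pck() per image pair, then average the scores.
# EVALUATOR.BY = 'point': pool all keypoints across pairs into one pck() call.
```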

config/dift.py (+29)

```python
from yacs.config import CfgNode as CN

_CN = CN()

# dataset configuration
_CN.DATASET = CN()
_CN.DATASET.NAME = 'spair'
_CN.DATASET.ROOT = 'asset/'  # '/home/xinghui/storage'
_CN.DATASET.IMG_SIZE = 768
_CN.DATASET.MEAN = [0.5, 0.5, 0.5]
_CN.DATASET.STD = [0.5, 0.5, 0.5]

# stable diffusion configuration
_CN.STABLE_DIFFUSION = CN()
_CN.STABLE_DIFFUSION.VERSION = '2-1'
_CN.STABLE_DIFFUSION.SAVE_MEMORY = True  # if True, uses less memory but runs slower

# feature extractor configuration
_CN.FEATURE_EXTRACTOR = CN()
_CN.FEATURE_EXTRACTOR.METHOD = 'dift'  # select between ('dift', 'sd-dino')
_CN.FEATURE_EXTRACTOR.SELECT_TIMESTEP = 261  # select from 1-1000
_CN.FEATURE_EXTRACTOR.SELECT_LAYER = 1  # for 'dift': select from (0,1,2,3); for 'sd-dino': select from (0,1,...,11)
_CN.FEATURE_EXTRACTOR.ENSEMBLE_SIZE = 8  # if ensemble_size > 1, the denoising process is repeated and the feature is the average over multiple trials

_CN.FEATURE_EXTRACTOR.FUSE_DINO = False
_CN.FEATURE_EXTRACTOR.ENABLE_L2_NORM = False

_CN.FEATURE_EXTRACTOR.PROMPT_TYPE = 'text'
cfg = _CN
```
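The `ENSEMBLE_SIZE` comment above describes averaging features over repeated denoising trials. A minimal sketch of that idea, where `extract_feature` is a stand-in for the repository's actual Stable Diffusion feature extraction (with torch, the last line would be `torch.stack(feats).mean(0)`):

```python
def ensemble_feature(extract_feature, image, ensemble_size=8):
    """Average feature vectors over `ensemble_size` independent trials.

    extract_feature(image) is assumed to return a feature as a list of
    floats; each call may be stochastic (a fresh denoising trial).
    """
    feats = [extract_feature(image) for _ in range(ensemble_size)]
    n = len(feats)
    # element-wise mean across trials
    return [sum(vals) / n for vals in zip(*feats)]
```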

config/fuseDino.py (+29)

```python
from yacs.config import CfgNode as CN

_CN = CN()

# dataset configuration
_CN.DATASET = CN()
_CN.DATASET.NAME = 'spair'
_CN.DATASET.ROOT = 'asset/'  # '/home/xinghui/storage'
_CN.DATASET.IMG_SIZE = 768
_CN.DATASET.MEAN = [0.5, 0.5, 0.5]
_CN.DATASET.STD = [0.5, 0.5, 0.5]

# stable diffusion configuration
_CN.STABLE_DIFFUSION = CN()
_CN.STABLE_DIFFUSION.VERSION = '2-1'
_CN.STABLE_DIFFUSION.SAVE_MEMORY = True  # if True, uses less memory but runs slower

# feature extractor configuration
_CN.FEATURE_EXTRACTOR = CN()
_CN.FEATURE_EXTRACTOR.METHOD = 'dift'  # select between ('dift', 'sd-dino')
_CN.FEATURE_EXTRACTOR.SELECT_TIMESTEP = 261  # select from 1-1000
_CN.FEATURE_EXTRACTOR.SELECT_LAYER = 1  # for 'dift': select from (0,1,2,3). if use 'sd-dino': select from (0,1,...,11)
_CN.FEATURE_EXTRACTOR.ENSEMBLE_SIZE = 8  # if ensemble_size > 1, the denoising process is repeated and the feature is the average over multiple trials

_CN.FEATURE_EXTRACTOR.FUSE_DINO = True
_CN.FEATURE_EXTRACTOR.ENABLE_L2_NORM = False

_CN.FEATURE_EXTRACTOR.PROMPT_TYPE = 'text'
cfg = _CN
```

config/learnedToken.py (+29)

```python
from yacs.config import CfgNode as CN

_CN = CN()

# dataset configuration
_CN.DATASET = CN()
_CN.DATASET.NAME = 'spair'
_CN.DATASET.ROOT = 'asset/'  # '/home/xinghui/storage'
_CN.DATASET.IMG_SIZE = 768
_CN.DATASET.MEAN = [0.5, 0.5, 0.5]
_CN.DATASET.STD = [0.5, 0.5, 0.5]

# stable diffusion configuration
_CN.STABLE_DIFFUSION = CN()
_CN.STABLE_DIFFUSION.VERSION = '2-1'
_CN.STABLE_DIFFUSION.SAVE_MEMORY = True  # if True, uses less memory but runs slower

# feature extractor configuration
_CN.FEATURE_EXTRACTOR = CN()
_CN.FEATURE_EXTRACTOR.METHOD = 'dift'  # select between ('dift', 'sd-dino')
_CN.FEATURE_EXTRACTOR.SELECT_TIMESTEP = 261  # select from 1-1000
_CN.FEATURE_EXTRACTOR.SELECT_LAYER = 1  # for 'dift': select from (0,1,2,3); for 'sd-dino': select from (0,1,...,11)
_CN.FEATURE_EXTRACTOR.ENSEMBLE_SIZE = 8  # if ensemble_size > 1, the denoising process is repeated and the feature is the average over multiple trials

_CN.FEATURE_EXTRACTOR.FUSE_DINO = False
_CN.FEATURE_EXTRACTOR.ENABLE_L2_NORM = True

_CN.FEATURE_EXTRACTOR.PROMPT_TYPE = 'text'
cfg = _CN
```
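The three configs above (`dift.py`, `fuseDino.py`, `learnedToken.py`) differ only in `FUSE_DINO` and `ENABLE_L2_NORM`. A hypothetical helper (not in the repository) for spotting such differences between nested config dicts, e.g. after `convert_config_to_dict`:

```python
def config_diff(a, b, prefix=""):
    """Yield (dotted_key, value_a, value_b) for keys whose values differ."""
    for k in sorted(set(a) | set(b)):
        key = f"{prefix}{k}"
        va, vb = a.get(k), b.get(k)
        if isinstance(va, dict) and isinstance(vb, dict):
            yield from config_diff(va, vb, key + ".")
        elif va != vb:
            yield key, va, vb

# Illustrative fragments of the configs above:
dift = {"FEATURE_EXTRACTOR": {"FUSE_DINO": False, "ENABLE_L2_NORM": False}}
fuse = {"FEATURE_EXTRACTOR": {"FUSE_DINO": True, "ENABLE_L2_NORM": False}}
print(list(config_diff(dift, fuse)))
# → [('FEATURE_EXTRACTOR.FUSE_DINO', False, True)]
```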

dataset/__init__.py (+3)

```python
from .pfpascal import PFPascalDataset, PFPascalImageDataset
from .pfwillow import PFWillowDataset, PFWillowImageDataset
from .spair import SPairDataset, SPairImageDataset
```
