
Commit ba1043b ("first commit")
1 parent: 51f5f2b


76 files changed (+5830, -3 lines)

.gitignore

Lines changed: 4 additions & 0 deletions

```
/result/
/.idea/
/__pycache__/
/weights/
```

README.md

Lines changed: 175 additions & 3 deletions
# AdaCLIP (Detecting Anomalies for Novel Categories)

[![HuggingFace Space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)]()

> [**ECCV 24**] [**AdaCLIP: Adapting CLIP with Hybrid Learnable Prompts for Zero-Shot Anomaly Detection**]()
>
> by [Yunkang Cao](https://caoyunkang.github.io/), [Jiangning Zhang](https://zhangzjn.github.io/), [Luca Frittoli](https://scholar.google.com/citations?user=cdML_XUAAAAJ),
> [Yuqi Cheng](https://scholar.google.com/citations?user=02BC-WgAAAAJ&hl=en), [Weiming Shen](https://scholar.google.com/citations?user=FuSHsx4AAAAJ&hl=en), [Giacomo Boracchi](https://boracchi.faculty.polimi.it/)

## Introduction

Zero-shot anomaly detection (ZSAD) targets the identification of anomalies within images from arbitrary novel categories.
This study introduces AdaCLIP for the ZSAD task, leveraging a pre-trained vision-language model (VLM), CLIP.
AdaCLIP incorporates learnable prompts into CLIP and optimizes them through training on auxiliary annotated anomaly detection data.
Two types of learnable prompts are proposed: *static* and *dynamic*. Static prompts are shared across all images, serving to preliminarily adapt CLIP for ZSAD.
In contrast, dynamic prompts are generated for each test image, providing CLIP with dynamic adaptation capabilities.
The combination of static and dynamic prompts, referred to as hybrid prompts, yields enhanced ZSAD performance.
Extensive experiments conducted across 14 real-world anomaly detection datasets from industrial and medical domains indicate that AdaCLIP outperforms other ZSAD methods and generalizes better to different categories and even domains.
Finally, our analysis highlights the importance of diverse auxiliary data and optimized prompts for enhanced generalization capacity.

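As a rough sketch of the idea (not the implementation in this repository, which injects prompts at several CLIP layers; cf. `prompting_depth`, `prompting_length`, and `prompting_type` in `app.py`), hybrid prompting can be pictured as a shared static prompt tensor plus an image-conditioned dynamic one:

```python
import torch
import torch.nn as nn


class HybridPromptSketch(nn.Module):
    """Toy model of hybrid prompting: static prompts are shared learnable
    parameters; dynamic prompts are predicted from each image embedding."""

    def __init__(self, prompt_len: int = 5, dim: int = 768):
        super().__init__()
        # Static prompts: optimized on auxiliary data, shared by every image.
        self.static_prompts = nn.Parameter(torch.randn(prompt_len, dim) * 0.02)
        # Dynamic prompts: generated per image at test time.
        self.dynamic_generator = nn.Linear(dim, prompt_len * dim)

    def forward(self, image_embedding: torch.Tensor) -> torch.Tensor:
        # image_embedding: (batch, dim)
        batch, dim = image_embedding.shape
        dynamic = self.dynamic_generator(image_embedding).view(batch, -1, dim)
        # Hybrid prompts = static (shared) + dynamic (image-specific).
        return self.static_prompts.unsqueeze(0) + dynamic  # (batch, prompt_len, dim)


# One fake CLIP image embedding of width 768:
print(HybridPromptSketch()(torch.randn(1, 768)).shape)  # torch.Size([1, 5, 768])
```
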
## Overview of AdaCLIP

![overview](asset/framework.png)

## 🛠️ Getting Started

### Installation

To set up the AdaCLIP environment, follow one of the methods below:

- Clone this repo:
```shell
git clone https://github.com/caoyunkang/AdaCLIP.git && cd AdaCLIP
```
- You can use our provided installation script for an automated setup:
```shell
sh install.sh
```
- If you prefer to construct the experimental environment manually, follow these steps:
```shell
conda create -n AdaCLIP python=3.9.5 -y
conda activate AdaCLIP
pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
pip install tqdm tensorboard setuptools==58.0.4 opencv-python scikit-image scikit-learn matplotlib seaborn ftfy regex numpy==1.26.4
pip install gradio # Optional, for the demo app
```
- Remember to update the dataset root in `config.py` according to your preference:
```python
DATA_ROOT = '../datasets' # Original setting
```

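Optionally, a quick sanity check (a small snippet of our own, not part of the repository) confirms that the pinned packages import cleanly and that CUDA is visible:

```python
# Quick environment check; run inside the AdaCLIP conda environment.
import torch, numpy, cv2
print(torch.__version__, torch.cuda.is_available())  # expect 1.10.1+cu111 and True on a CUDA 11.1 machine
print(numpy.__version__, cv2.__version__)            # numpy should be 1.26.4 per the install command
```
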
### Dataset Preparation

Please download our processed visual anomaly detection datasets to your `DATA_ROOT` as needed.

#### Industrial Visual Anomaly Detection Datasets

Note: some links are still being prepared...

| Dataset | Google Drive | Baidu Drive | Task |
|------------|------------------|------------------|------------------|
| MVTec AD | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
| VisA | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
| MPDD | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
| BTAD | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
| KSDD | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
| DAGM | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |
| DTD-Synthetic | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection & Localization |

#### Medical Visual Anomaly Detection Datasets

| Dataset | Google Drive | Baidu Drive | Task |
|------------|------------------|------------------|------------------|
| HeadCT | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection |
| BrainMRI | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection |
| Br35H | [Google Drive](link) | [Baidu Drive](link) | Anomaly Detection |
| ISIC | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |
| ColonDB | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |
| ClinicDB | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |
| TN3K | [Google Drive](link) | [Baidu Drive](link) | Anomaly Localization |

#### Custom Datasets

To use your custom dataset, follow these steps:

1. Refer to the instructions in `./data_preprocess` to generate the JSON file for your dataset (a sketch of the expected layout is shown after this list).
2. Use `./dataset/base_dataset.py` to construct your own dataset.

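For reference, the preprocessing scripts in this commit (e.g. `data_preprocess/br35h.py`) write a `meta.json` with the structure sketched below; your own generator should produce the same keys. The paths and class names here are illustrative placeholders:

```python
# Sketch of the meta.json layout produced by the scripts in ./data_preprocess
# (see data_preprocess/br35h.py); values below are placeholders, not real data.
import json

meta = {
    "train": {
        "my_class": [
            {
                "img_path": "my_class/train/good/000.png",  # relative to the dataset root
                "mask_path": "",                             # empty when no pixel-level mask exists
                "cls_name": "my_class",
                "specie_name": "good",
                "anomaly": 0,                                # 0 = normal, 1 = anomalous
            },
        ],
    },
    "test": {
        "my_class": [],
    },
}

with open("meta.json", "w") as f:
    f.write(json.dumps(meta, indent=4) + "\n")
```
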
### Weight Preparation

We offer various pre-trained weights on different auxiliary datasets.
Please download the pre-trained weights into `./weights`.

| Pre-trained Datasets | Google Drive | Baidu Drive |
|------------|------------------|------------------|
| MVTec AD & ClinicDB | [Google Drive](https://drive.google.com/file/d/1xVXANHGuJBRx59rqPRir7iqbkYzq45W0/view?usp=drive_link) | [Baidu Drive](link) |
| VisA & ColonDB | [Google Drive](https://drive.google.com/file/d/1QGmPB0ByPZQ7FucvGODMSz7r5Ke5wx9W/view?usp=drive_link) | [Baidu Drive](link) |
| All Datasets Mentioned Above | [Google Drive](https://drive.google.com/file/d/1Cgkfx3GAaSYnXPLolx-P7pFqYV0IVzZF/view?usp=drive_link) | [Baidu Drive](link) |

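After downloading, `app.py` in this commit looks for the checkpoints under the names listed below; a minimal check, assuming the files were saved into `./weights`:

```python
# Checkpoint names referenced by app.py:
#   weights/pretrained_mvtec_colondb.pth
#   weights/pretrained_visa_clinicdb.pth
#   weights/pretrained_all.pth
from pathlib import Path

Path("weights").mkdir(exist_ok=True)
print(sorted(p.name for p in Path("weights").glob("*.pth")))  # should list the files above
```
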
### Train

By default, we use MVTec AD & ClinicDB for training and VisA for validation:

```shell
CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data mvtec colondb --testing_data visa
```

Alternatively, for evaluation on MVTec AD & ClinicDB, we use VisA & ColonDB for training and MVTec AD for validation:

```shell
CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True --training_data visa clinicdb --testing_data mvtec
```

Since we use half-precision (FP16) for training, the training process can occasionally be unstable.
It is recommended to run training several times and pick the model that performs best on the validation set as the final model.

To construct a robust ZSAD model for demonstration, we also train AdaCLIP on all of the AD datasets mentioned above:

```shell
CUDA_VISIBLE_DEVICES=0 python train.py --save_fig True \
  --training_data \
  br35h brain_mri btad clinicdb colondb \
  dagm dtd headct isic mpdd mvtec sdd tn3k visa \
  --testing_data mvtec
```

### Test

Manually select the best models from the validation set and place them in the `weights/` directory. Then run the following testing script:

```shell
sh test.sh
```

If you want to test on a single image, you can refer to `test_single_image.sh`:

```shell
CUDA_VISIBLE_DEVICES=0 python test.py --testing_model image --ckt_path weights/pretrained_all.pth --save_fig True \
  --image_path asset/img.png --class_name candle --save_name test.png
```

## Main Results

Due to differences in the versions utilized, the reported performance may vary slightly from the detection performance obtained with the provided pre-trained weights; some categories may score higher and others lower.

![Table_industrial](./asset/Table_industrial.png)
![Table_medical](./asset/Table_medical.png)
![Fig_detection_results](./asset/Fig_detection_results.png)

### :page_facing_up: Demo App

To run the demo application, use the following command:

```bash
python app.py
```

![Demo](./asset/Fig_app.png)

## 💘 Acknowledgements

Our work is largely inspired by the following projects. Thanks for their admirable contributions.

- [VAND-APRIL-GAN](https://github.com/ByChelsea/VAND-APRIL-GAN)
- [AnomalyCLIP](https://github.com/zqhang/AnomalyCLIP)
- [SAA](https://github.com/caoyunkang/Segment-Any-Anomaly)

## Stargazers over time

[![Stargazers over time](https://starchart.cc/caoyunkang/AdaCLIP.svg?variant=adaptive)](https://starchart.cc/caoyunkang/AdaCLIP)

## Citation

If you find this project helpful for your research, please consider citing the following BibTeX entry.

```BibTex
```

app.py

Lines changed: 133 additions & 0 deletions
```python
import gradio as gr
from PIL import Image, ImageDraw, ImageFont
import warnings
import os
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
import json
import torch
from scipy.ndimage import gaussian_filter
import cv2
from method import AdaCLIP_Trainer
import numpy as np

############ Init Model
ckt_path1 = 'weights/pretrained_mvtec_colondb.pth'
ckt_path2 = "weights/pretrained_visa_clinicdb.pth"
ckt_path3 = 'weights/pretrained_all.pth'

# Configurations
image_size = 518
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'
model = "ViT-L-14-336"
prompting_depth = 4
prompting_length = 5
prompting_type = 'SD'
prompting_branch = 'VL'
use_hsf = True
k_clusters = 20

config_path = os.path.join('./model_configs', f'{model}.json')

# Prepare model
with open(config_path, 'r') as f:
    model_configs = json.load(f)

# Set up the feature hierarchy: take features after each quarter of the vision layers
n_layers = model_configs['vision_cfg']['layers']
substage = n_layers // 4
features_list = [substage, substage * 2, substage * 3, substage * 4]

# Note: `model` holds the backbone name above and is rebound to the trainer here
model = AdaCLIP_Trainer(
    backbone=model,
    feat_list=features_list,
    input_dim=model_configs['vision_cfg']['width'],
    output_dim=model_configs['embed_dim'],
    learning_rate=0.,
    device=device,
    image_size=image_size,
    prompting_depth=prompting_depth,
    prompting_length=prompting_length,
    prompting_branch=prompting_branch,
    prompting_type=prompting_type,
    use_hsf=use_hsf,
    k_clusters=k_clusters
).to(device)


def process_image(image, text, options):
    # Load the checkpoint that matches the selected pre-training option
    if 'MVTec AD+Colondb' in options:
        model.load(ckt_path1)
    elif 'VisA+Clinicdb' in options:
        model.load(ckt_path2)
    elif 'All' in options:
        model.load(ckt_path3)
    else:
        # Default to 'All' if no valid option is provided
        model.load(ckt_path3)
        print('Invalid option. Defaulting to All.')

    # Ensure image is in RGB mode
    image = image.convert('RGB')

    # Convert PIL image to NumPy array
    np_image = np.array(image)

    # Convert RGB to BGR for OpenCV and resize to the model's input size
    np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
    np_image = cv2.resize(np_image, (image_size, image_size))

    # Preprocess the image and run the model
    img_input = model.preprocess(image).unsqueeze(0)
    img_input = img_input.to(model.device)

    with torch.no_grad():
        anomaly_map, anomaly_score = model.clip_model(img_input, [text], aggregation=True)

    # Process anomaly map: smooth it and scale to 0-255 for visualization
    anomaly_map = anomaly_map[0, :, :].cpu().numpy()
    anomaly_score = anomaly_score[0].cpu().numpy()
    anomaly_map = gaussian_filter(anomaly_map, sigma=4)
    anomaly_map = (anomaly_map * 255).astype(np.uint8)

    # Apply color map and blend with original image
    heat_map = cv2.applyColorMap(anomaly_map, cv2.COLORMAP_JET)
    vis_map = cv2.addWeighted(heat_map, 0.5, np_image, 0.5, 0)

    # Convert OpenCV image back to PIL image for Gradio
    vis_map_pil = Image.fromarray(cv2.cvtColor(vis_map, cv2.COLOR_BGR2RGB))

    return vis_map_pil, f'{anomaly_score:.3f}'

# Define examples
examples = [
    ["asset/img.png", "candle", "MVTec AD+Colondb"],
    ["asset/img2.png", "bottle", "VisA+Clinicdb"],
    ["asset/img3.png", "button", "All"],
]

# Gradio interface layout
demo = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Class Name"),
        gr.Radio(["MVTec AD+Colondb",
                  "VisA+Clinicdb",
                  "All"],
                 label="Pre-trained Datasets")
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Textbox(label="Anomaly Score"),
    ],
    examples=examples,
    title="AdaCLIP -- Zero-shot Anomaly Detection",
    description="Upload an image, enter the class name, and select pre-trained datasets to do zero-shot anomaly detection"
)

# Launch the demo
demo.launch()
# demo.launch(server_name="0.0.0.0", server_port=10002)
```
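The commented-out line above shows how to bind the demo to a specific host and port. Another option, not used in the original script and included here only as a sketch, is Gradio's standard `share` flag, which creates a temporary public link:

```python
# Alternative launch: generate a temporary public Gradio URL (needs outbound internet access).
demo.launch(share=True)
```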

asset/Fig_app.png (262 KB)
asset/Fig_detection_results.png (355 KB)
asset/Table_industrial.png (392 KB)
asset/Table_medical.png (284 KB)
asset/framework.png (430 KB)
asset/img.png (1.36 MB)
asset/img2.png (535 KB)
asset/img3.png (610 KB)

config.py

Lines changed: 1 addition & 0 deletions
```python
DATA_ROOT = '../datasets'
```

data_preprocess/br35h.py

Lines changed: 50 additions & 0 deletions
```python
import os
import json
import random
from config import DATA_ROOT

Br35h_ROOT = os.path.join(DATA_ROOT, 'Br35h_anomaly_detection')


class Br35hSolver(object):
    CLSNAMES = [
        'br35h',
    ]

    def __init__(self, root=Br35h_ROOT, train_ratio=0.5):
        self.root = root
        self.meta_path = f'{root}/meta.json'
        self.train_ratio = train_ratio

    def run(self):
        self.generate_meta_info()

    def generate_meta_info(self):
        info = dict(train={}, test={})
        for cls_name in self.CLSNAMES:
            cls_dir = f'{self.root}/{cls_name}'
            for phase in ['train', 'test']:
                cls_info = []
                # Each sub-folder ("specie") holds one image group; 'good' marks normal images
                species = os.listdir(f'{cls_dir}/{phase}')
                for specie in species:
                    is_abnormal = specie not in ['good']
                    img_names = os.listdir(f'{cls_dir}/{phase}/{specie}')
                    img_names.sort()

                    for idx, img_name in enumerate(img_names):
                        info_img = dict(
                            img_path=f'{cls_name}/{phase}/{specie}/{img_name}',
                            mask_path='',  # Br35H is detection-only, so no pixel-level masks
                            cls_name=cls_name,
                            specie_name=specie,
                            anomaly=1 if is_abnormal else 0,
                        )
                        cls_info.append(info_img)

                info[phase][cls_name] = cls_info

        with open(self.meta_path, 'w') as f:
            f.write(json.dumps(info, indent=4) + "\n")


if __name__ == '__main__':
    runner = Br35hSolver(root=Br35h_ROOT)
    runner.run()
```
