-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathretrieve_embed.py
More file actions
140 lines (110 loc) · 5.59 KB
/
retrieve_embed.py
File metadata and controls
140 lines (110 loc) · 5.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import torch.nn as nn
import os
import numpy as np
from PIL import Image
from torchvision import transforms
import json
from models import CNN
class EmbeddingExtractor:
    """Extract 128-dim embeddings from a trained CNN's global-pooling layer.

    A forward hook on ``model.global_pool`` captures the pooled activations
    (shape ``[batch, 128, 1, 1]``) on every forward pass; per-image embeddings
    are saved as ``.npy`` files mirroring the input directory structure.
    """

    def __init__(self, model_path):
        """Load the trained model and register the embedding hook.

        Args:
            model_path: Path to a saved state dict for the ``CNN`` architecture.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Initialize with the same model architecture used during training.
        self.model = CNN(num_classes=10)
        # map_location lets a checkpoint saved on a CUDA machine load on a
        # CPU-only host (without it, torch.load raises a deserialization error).
        self.model.load_state_dict(
            torch.load(model_path, map_location=self.device, weights_only=True))
        self.model.eval()
        self.model.to(self.device)

        # Most recent embedding captured by the forward hook.
        self.embedding = None

        def hook(module, input, output):
            # Output shape is [batch_size, 128, 1, 1]; capture it here,
            # before the model flattens it with .view().
            self.embedding = output.detach()

        # Fire the hook every time the global pooling layer runs.
        self.model.global_pool.register_forward_hook(hook)

        # Transform pipeline matching the CIFAR10 preprocessing.
        self.transform = transforms.Compose([
            transforms.Resize((32, 32)),  # CIFAR10 image size
            transforms.ToTensor(),
            # If normalization was used during training, uncomment:
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Per-image metadata accumulated by process_directory().
        self.embeddings = {}

    def extract_embedding(self, image_path):
        """Extract the embedding for a single image.

        Args:
            image_path: Path to an image file readable by PIL.

        Returns:
            A 1-D numpy array of 128 features from the global pooling layer.
        """
        image = Image.open(image_path).convert('RGB')
        image = self.transform(image).unsqueeze(0).to(self.device)
        # Forward pass only to trigger the hook; the logits are discarded.
        with torch.no_grad():
            _ = self.model(image)
        # Squeeze [1, 128, 1, 1] down to a flat [128] feature vector.
        return self.embedding.squeeze().cpu().numpy()

    def process_directory(self, input_dir, output_dir):
        """Extract and save embeddings for every image under ``input_dir``.

        Expects class subdirectories named "0".."9"; writes one ``.npy`` per
        image under a mirrored structure in ``output_dir`` and records
        metadata in ``self.embeddings``. Errors on individual images are
        reported and skipped so one bad file does not abort the run.
        """
        os.makedirs(output_dir, exist_ok=True)
        for class_dir in range(10):
            class_path = os.path.join(input_dir, str(class_dir))
            output_class_path = os.path.join(output_dir, str(class_dir))
            os.makedirs(output_class_path, exist_ok=True)
            if not os.path.exists(class_path):
                continue
            print(f"Processing class {class_dir}...")
            # sorted() makes processing order (and thus metadata order)
            # deterministic across filesystems.
            for image_file in sorted(os.listdir(class_path)):
                if not image_file.lower().endswith(('.png', '.jpg', '.jpeg')):
                    continue
                image_path = os.path.join(class_path, image_file)
                try:
                    embedding = self.extract_embedding(image_path)
                    # Save the 128-dimensional embedding next to its class.
                    output_path = os.path.join(
                        output_class_path,
                        os.path.splitext(image_file)[0] + '.npy')
                    np.save(output_path, embedding)
                    self.embeddings[image_file] = {
                        'class': class_dir,
                        'embedding_path': output_path,
                        'original_path': image_path,
                        # list() so the shape round-trips through JSON cleanly.
                        'embedding_size': list(embedding.shape),
                    }
                except Exception as e:
                    # Best-effort: report and continue with remaining images.
                    print(f"Error processing {image_path}: {str(e)}")

    def save_metadata(self, output_dir):
        """Save metadata for all processed images as JSON in ``output_dir``."""
        metadata_path = os.path.join(output_dir, 'embedding_metadata.json')
        with open(metadata_path, 'w') as f:
            json.dump(self.embeddings, f, indent=2)

    def create_ml_dataset(self, output_dir):
        """Create a consolidated (X, y) dataset suitable for ML training.

        Loads every saved embedding back from disk, stacks them into a
        feature matrix, and writes ``X_features.npy``, ``y_labels.npy`` and
        ``image_ids.json`` to ``output_dir``.

        Returns:
            Tuple ``(X, y)`` of numpy arrays: features and integer labels.
        """
        features = []
        labels = []
        image_ids = []
        for image_file, metadata in self.embeddings.items():
            embedding = np.load(metadata['embedding_path'])
            features.append(embedding.flatten())  # already 128-dimensional
            labels.append(metadata['class'])
            image_ids.append(image_file)
        X = np.array(features)
        y = np.array(labels)
        np.save(os.path.join(output_dir, 'X_features.npy'), X)
        np.save(os.path.join(output_dir, 'y_labels.npy'), y)
        with open(os.path.join(output_dir, 'image_ids.json'), 'w') as f:
            json.dump(image_ids, f)
        print(f"Dataset created with features shape: {X.shape}")
        print(f"Each feature vector represents 128 dimensions from the global pooling layer")
        return X, y
def _run_extraction(model_path, input_dir, output_dir):
    """Run the full embedding-extraction pipeline for one dataset directory."""
    extractor = EmbeddingExtractor(model_path)
    extractor.process_directory(input_dir, output_dir)
    extractor.save_metadata(output_dir)
    return extractor.create_ml_dataset(output_dir)


# Usage example
if __name__ == "__main__":
    X, y = _run_extraction(
        model_path="/Users/jinjiahui/Desktop/CS470Project/models/target_model.mod",
        input_dir="adversarial_cifar10_epsilon_0.005",
        output_dir="embed_adversarial_cifar10_epsilon_0.005",
    )
    print(f"Created dataset with shape: {X.shape}, {y.shape}")
    # data = np.load("/Users/jinjiahui/Desktop/CS470Project/embed_cifar10_test_images_by_class/0/test_image_00010.npy")
    # print(data)