Skip to content

Commit

Permalink
Allow saving and loading custom label embeddings
Browse files Browse the repository at this point in the history
Adds a method to CustomLabelsClassifier that saves the class labels and
embeddings to an .npy file. Adds an embeddings_path parameter to
CustomLabelsClassifier for loading the saved labels and embeddings.

This change needs to wait for this PR due to a field name change:
#64

Fixes #17
  • Loading branch information
johnbradley committed Nov 26, 2024
1 parent b3ac523 commit 431695c
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
34 changes: 31 additions & 3 deletions src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,11 +236,39 @@ def create_probabilities_for_images(self, images: List[str] | List[PIL.Image.Ima


class CustomLabelsClassifier(BaseClassifier):
def __init__(self, cls_ary: List[str] | None = None, embeddings_path: str | None = None, **kwargs):
    """
    Build a classifier from custom labels or from previously saved embeddings.

    Exactly one of cls_ary and embeddings_path must be supplied.

    Parameters:
        cls_ary (List[str] | None): Class labels; each label is stripped of
            surrounding whitespace before its text features are computed.
        embeddings_path (str | None): Path to a file written by
            save_embeddings; labels and text features are read from it
            instead of being computed.
        **kwargs: Forwarded unchanged to the base class constructor.

    Raises:
        ValueError: If both or neither of cls_ary and embeddings_path are given.
    """
    # Reject invalid argument combinations before doing any setup work.
    if embeddings_path and cls_ary:
        raise ValueError("Cannot provide both cls_ary and embeddings_path.")
    if not embeddings_path and not cls_ary:
        raise ValueError("Must provide cls_ary or embeddings_path.")

    super().__init__(**kwargs)
    self.tokenizer = create_bioclip_tokenizer(self.model_str)
    if embeddings_path:
        self.load_embeddings(embeddings_path)
    else:
        self.classes = [cls.strip() for cls in cls_ary]
        self.txt_features = self._get_txt_features(self.classes)

def save_embeddings(self, path: str):
    """
    Write the class labels and their text features to a single numpy file.

    Two arrays are stored back to back in the same file: first the labels,
    then the text features. load_embeddings reads them in that order.

    Parameters:
        path (str): The file path where the class labels and embeddings will be saved.
    """
    with open(path, 'wb') as outfile:
        label_array = np.array(self.classes)
        np.save(outfile, label_array)
        # Move the tensor to host memory before converting it to numpy.
        feature_array = self.txt_features.cpu().numpy()
        np.save(outfile, feature_array)

def load_embeddings(self, path: str):
    """
    Restore class labels and text features from a file written by save_embeddings.

    The first array in the file becomes self.classes (as a list) and the
    second becomes self.txt_features, placed on self.device.

    Args:
        path (str): The file path to the numpy embeddings file to load.
    Raises:
        FileNotFoundError: If the specified file does not exist.
        IOError: If there is an error reading the file.
    """
    with open(path, 'rb') as infile:
        # Arrays are read in the order they were saved: labels, then features.
        label_array = np.load(infile)
        self.classes = label_array.tolist()
        feature_array = np.load(infile)
        self.txt_features = torch.from_numpy(feature_array).to(self.device)

@torch.no_grad()
def _get_txt_features(self, classnames):
Expand Down
10 changes: 10 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,16 @@ def test_apply_filter(self):
self.assertEqual(classifier.get_txt_embeddings().shape, torch.Size([512, 1]))
self.assertEqual(len(classifier.get_current_txt_names()), 1)

def test_save_load_embeddings(self):
    """Embeddings saved by one classifier can be loaded into a new one and used to predict."""
    # Local imports keep this test self-contained.
    import os
    import tempfile

    labels = ['dog', 'cat', 'fish']
    classifier = CustomLabelsClassifier(cls_ary=labels)
    num_labels = len(classifier.classes)
    saved_features = classifier.txt_features
    feature_shape = saved_features.shape

    # A per-test temporary directory avoids a hard-coded /tmp path: it is
    # portable, cannot collide with parallel runs, and is cleaned up on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        embeddings_path = os.path.join(tmp_dir, 'test_embeddings.npy')
        classifier.save_embeddings(embeddings_path)
        loaded = CustomLabelsClassifier(embeddings_path=embeddings_path)

    self.assertEqual(len(loaded.classes), num_labels)
    self.assertEqual(loaded.classes, labels)
    self.assertEqual(loaded.txt_features.shape, feature_shape)
    # Compare on CPU so the check does not depend on where each tensor lives.
    self.assertTrue(torch.equal(loaded.txt_features.cpu(), saved_features.cpu()))
    loaded.predict(images=[EXAMPLE_CAT_IMAGE2])


class TestEmbed(unittest.TestCase):
def test_get_image_features(self):
Expand Down

0 comments on commit 431695c

Please sign in to comment.