From 642a981565fbb68ab11e8853f6996c7caeaecbf6 Mon Sep 17 00:00:00 2001 From: Ben Cherry Date: Mon, 26 Jun 2023 16:56:23 -0700 Subject: [PATCH 1/2] Update autocast to use CPU for MPS type --- clip_interrogator/clip_interrogator.py | 12 ++++++------ test.py | 7 +++++++ 2 files changed, 13 insertions(+), 6 deletions(-) create mode 100644 test.py diff --git a/clip_interrogator/clip_interrogator.py b/clip_interrogator/clip_interrogator.py index 02a9b896..77a902cb 100644 --- a/clip_interrogator/clip_interrogator.py +++ b/clip_interrogator/clip_interrogator.py @@ -197,7 +197,7 @@ def generate_caption(self, pil_image: Image) -> str: def image_to_features(self, image: Image) -> torch.Tensor: self._prepare_clip() images = self.clip_preprocess(image).unsqueeze(0).to(self.device) - with torch.no_grad(), torch.cuda.amp.autocast(): + with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'): image_features = self.clip_model.encode_image(images) image_features /= image_features.norm(dim=-1, keepdim=True) return image_features @@ -257,7 +257,7 @@ def interrogate(self, image: Image, min_flavors: int=8, max_flavors: int=32, cap def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: bool=False) -> str: self._prepare_clip() text_tokens = self.tokenize([text for text in text_array]).to(self.device) - with torch.no_grad(), torch.cuda.amp.autocast(): + with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'): text_features = self.clip_model.encode_text(text_tokens) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features @ image_features.T @@ -268,7 +268,7 @@ def rank_top(self, image_features: torch.Tensor, text_array: List[str], reverse: def similarity(self, image_features: torch.Tensor, text: str) -> float: self._prepare_clip() text_tokens = self.tokenize([text]).to(self.device) - with torch.no_grad(), torch.cuda.amp.autocast(): + with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'): text_features = self.clip_model.encode_text(text_tokens) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features @ image_features.T @@ -277,7 +277,7 @@ def similarity(self, image_features: torch.Tensor, text: str) -> float: def similarities(self, image_features: torch.Tensor, text_array: List[str]) -> List[float]: self._prepare_clip() text_tokens = self.tokenize([text for text in text_array]).to(self.device) - with torch.no_grad(), torch.cuda.amp.autocast(): + with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'): text_features = self.clip_model.encode_text(text_tokens) text_features /= text_features.norm(dim=-1, keepdim=True) similarity = text_features @ image_features.T @@ -319,7 +319,7 @@ def __init__(self, labels:List[str], desc:str, ci: Interrogator): chunks = np.array_split(self.labels, max(1, len(self.labels)/config.chunk_size)) for chunk in tqdm(chunks, desc=f"Preprocessing {desc}" if desc else None, disable=self.config.quiet): text_tokens = self.tokenize(chunk).to(self.device) - with torch.no_grad(), torch.cuda.amp.autocast(): + with torch.no_grad(), torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'): text_features = clip_model.encode_text(text_tokens) text_features /= text_features.norm(dim=-1, keepdim=True) text_features = text_features.half().cpu().numpy() @@ -373,7 +373,7 @@ def _load_cached(self, desc:str, hash:str, sanitized_name:str) -> bool: def _rank(self, image_features: torch.Tensor, text_embeds: torch.Tensor, top_count: int=1, reverse: bool=False) -> str: top_count = min(top_count, len(text_embeds)) text_embeds = torch.stack([torch.from_numpy(t) for t in text_embeds]).to(self.device) - with torch.cuda.amp.autocast(): + with torch.amp.autocast(device_type='cuda' if self.device == 'cuda' else 'cpu'): similarity = image_features @ text_embeds.T if reverse: similarity = -similarity diff --git a/test.py b/test.py new file mode 100644 index 00000000..1f957727 --- /dev/null +++ b/test.py @@ -0,0 +1,7 @@ +import sys +from PIL import Image +from clip_interrogator import Config, Interrogator +# import torch +image = Image.open('../pushd-gpt/output/resized.jpg').convert('RGB') +ci = Interrogator(Config(clip_model_name="ViT-H-14/laion2b_s32b_b79k", device='mps')) +print(ci.interrogate(image)) From 21392325e0a37ee4c367074d17d3c317f6ffd6b9 Mon Sep 17 00:00:00 2001 From: Ben Cherry Date: Mon, 26 Jun 2023 17:02:44 -0700 Subject: [PATCH 2/2] Delete test.py --- test.py | 7 ------- 1 file changed, 7 deletions(-) delete mode 100644 test.py diff --git a/test.py b/test.py deleted file mode 100644 index 1f957727..00000000 --- a/test.py +++ /dev/null @@ -1,7 +0,0 @@ -import sys -from PIL import Image -from clip_interrogator import Config, Interrogator -# import torch -image = Image.open('../pushd-gpt/output/resized.jpg').convert('RGB') -ci = Interrogator(Config(clip_model_name="ViT-H-14/laion2b_s32b_b79k", device='mps')) -print(ci.interrogate(image))