import torch
from fastdup.sentry import fastdup_capture_exception
from fastdup.definitions import MISSING_LABEL
from fastdup.galleries import fastdup_imread
import cv2
from tqdm import tqdm

# cache of initialized captioning pipelines, keyed by device
device_to_captioner = {}

def init_captioning(model_name='automatic', device='cpu', batch_size=8, max_new_tokens=20,
                    use_float_16=True):
    '''
    This function generates captions for a given set of images, and takes the following arguments:
    - filenames: the list of images passed to the function
    - model_name: the captioning model to use (default: 'automatic'). Available models:
        - ViT-GPT2: 'vitgpt2'
        - BLIP-2: 'blip-2'
        - BLIP: 'blip'
    - batch_size: the size of image batches to caption (default: 8)
    - device: whether to use a GPU ('gpu') or the CPU ('cpu', the default)
    - max_new_tokens: the maximum number of tokens to generate per caption (default: 20)
    A usage sketch follows generate_labels below.
    '''

    global device_to_captioner
    # use GPU if device is specified
    if device == 'gpu':
        device = 0
    elif device == 'cpu':
        device = -1
        use_float_16 = False  # half precision is not supported on CPU
    else:
        assert False, f"Incompatible device name entered {device}. Available device names are gpu and cpu."

    # confirm necessary dependencies are installed, and import them
    try:
        from transformers import pipeline
        from transformers.utils import logging
        logging.set_verbosity(50)  # 50 == CRITICAL, silencing transformers' informational logging
    except Exception as e:
        fastdup_capture_exception("Auto generate labels", e)
        print("Auto captioning requires an installation of the following libraries:\n")
        print("    huggingface transformers\n    pytorch\n")
        print("    to install, use `pip3 install transformers torch`")
        raise

    # dictionary of captioning models
    models = {
        'automatic': "nlpconnect/vit-gpt2-image-captioning",
        'vitgpt2': "nlpconnect/vit-gpt2-image-captioning",
        'blip-2': "Salesforce/blip2-opt-2.7b",
        'blip': "Salesforce/blip-image-captioning-large"
    }
    assert model_name in models.keys(), f"Unknown captioning model {model_name}; allowed models are {list(models.keys())}"
    model = models[model_name]
    has_gpu = torch.cuda.is_available()
    captioner = pipeline("image-to-text", model=model, device=device if has_gpu else "cpu",
                         max_new_tokens=max_new_tokens,
                         torch_dtype=torch.float16 if use_float_16 else torch.float32)
    device_to_captioner[device] = captioner

    return captioner

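# A minimal sketch of pre-warming the captioner cache directly (the image path
# below is hypothetical; init_captioning stores the pipeline in
# device_to_captioner, so later generate_labels calls on the same device reuse
# it instead of reloading the model):
#
#   captioner = init_captioning(model_name='automatic', device='cpu')
#   print(captioner("some_image.jpg"))  # assumed local file
#   # -> [{'generated_text': '...'}] from the transformers image-to-text pipeline
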
def generate_labels(filenames, model_name='automatic', device='cpu', batch_size=8, max_new_tokens=20, use_float_16=True):
    global device_to_captioner
    # initialize the captioning pipeline once per device, then reuse it
    if device not in device_to_captioner:
        captioner = init_captioning(model_name, device, batch_size, max_new_tokens, use_float_16)
    else:
        captioner = device_to_captioner[device]

    captions = []
    # generate captions in batches
    try:
        for i in tqdm(range(0, len(filenames), batch_size)):
            chunk = filenames[i:i + batch_size]
            try:
                for pred in captioner(chunk, batch_size=batch_size):
                    charstring = '' if model_name != 'blip' else ' '  # BLIP predictions are joined with spaces
                    caption = charstring.join([d['generated_text'] for d in pred])
                    # split the sentence into words
                    words = caption.split()
                    # filter out words containing '#' (subword artifacts such as '##ing')
                    filtered_words = [word for word in words if '#' not in word]
                    # join the filtered words back into a sentence
                    caption = ' '.join(filtered_words)
                    caption = caption.strip()
                    captions.append(caption)
            except Exception as ex:
                print("Failed to caption chunk", chunk[:5], ex)
                captions.extend([MISSING_LABEL] * len(chunk))

    except Exception as e:
        fastdup_capture_exception("Auto caption image", e)
        return [MISSING_LABEL] * len(filenames)

    return captions

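# Example usage of generate_labels, as a sketch (the directory path is
# hypothetical; any folder of images works):
#
#   import os
#   image_dir = "images"
#   files = [os.path.join(image_dir, f) for f in os.listdir(image_dir)]
#   captions = generate_labels(files, model_name='automatic', device='cpu')
#   for path, caption in zip(files, captions):
#       print(path, "->", caption)
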

def generate_vqa_labels(filenames, text, kwargs):
    # confirm necessary dependencies are installed, and import them
    ...


def generate_age_labels(filenames, kwargs):
    try:
        ...
    except Exception as e:
        fastdup_capture_exception("Age label", e)
        return [MISSING_LABEL] * len(filenames)

if __name__ == "__main__":
    import fastdup
    from fastdup.captions import generate_labels
    file = "/Users/dannybickson/visual_database/cxx/unittests/two_images/"
    import os
    files = os.listdir(file)
    files = [os.path.join(file, f) for f in files]
    ret = generate_labels(files, model_name='blip')
    assert len(ret) == 2
    print(ret)
    for r in ret:
        assert "shelf" in r or "shelves" in r or "store" in r