diff --git a/clipvision_db.py b/clipvision_db.py index b702012..b29e2ca 100644 --- a/clipvision_db.py +++ b/clipvision_db.py @@ -22,18 +22,18 @@ def __init__(self): @classmethod def INPUT_TYPES(s): - return { - "required": { + return { + "required": { "db_name": (folder_paths.get_filename_list("EmbDBs"),) , "path_to_images_folder": ("STRING", { "multiline": False, "default": "/path/to/folder/with/images", "tooltip": "Basepath to the folder containing the images" - }), + }), }, "optional": { "img_db": ("LoadDB",), - } + } } RETURN_TYPES = ("LoadDB", ) @@ -52,7 +52,7 @@ def load_DB(self, db_name, path_to_images_folder, img_db=None): Returns: A LoadDB object. - """ + """ db_path = folder_paths.get_full_path_or_raise("EmbDBs", db_name) if self.LOADED_DB_Name != db_path or path_to_images_folder != self.images_folder: try: @@ -68,12 +68,12 @@ def load_DB(self, db_name, path_to_images_folder, img_db=None): except Exception as e: print(f"Error loading database file: {e}") return None, - + if img_db is not None: self.LOADED_DB = self.MY_LOADED_DB + img_db.LOADED_DB else: self.LOADED_DB = self.MY_LOADED_DB - return self, + return self, class EditDB: @@ -87,20 +87,20 @@ def __init__(self): @classmethod def INPUT_TYPES(s): - return { - "required": { + return { + "required": { "img_db": ("LoadDB",), "method": (["exclude", "filter", "replace"], {"default": "exclude", "tooltip": "Method to edit results"}), "edit_text": ("STRING", { "multiline": False, "default": "remove*images.jpg", "tooltip": "Use wildcards (*) to match filenames or paths" - },), + },), "replace_text": ("STRING", { "multiline": False, "default": "/newpath/", "tooltip": "Text to replace matched text with when using 'replace' method" - },), + },), } } @@ -132,7 +132,7 @@ def Edit_DB(self, img_db:LoadDB, method, edit_text, replace_text): new_name = file_name.replace(pat.strip("*"), replace_text) NewDB.LOADED_DB.append((new_name, embeddings)) - return NewDB, + return NewDB, class GenerateDB: """ @@ -146,7 +146,7 @@ def __init__(self): @classmethod def INPUT_TYPES(s): return { - "required": { + "required": { "clip_vision": ("CLIP_VISION",), "path_to_images_folder": ("STRING", { "multiline": False, @@ -163,7 +163,7 @@ def INPUT_TYPES(s): "unique_id": "UNIQUE_ID", } } - + RETURN_TYPES = ("STRING",) RETURN_NAMES = ("ERRORS",) OUTPUT_NODE = True @@ -181,13 +181,15 @@ def start_gen_db(self, clip_vision, path_to_images_folder, new_db_name, unique_i Returns: A string containing any errors that occurred. - """ + """ + # Make sure the database name ends with .json + if not new_db_name.endswith('.json'): + new_db_name += '.json' + path_to_images = Path(path_to_images_folder) path_to_database = Path(folder_paths.get_folder_paths("EmbDBs")[0]) / new_db_name - + path_to_database.parent.mkdir(exist_ok=True) - + errors = generate_clip_features_json(clip_vision, path_to_images, path_to_database, unique_id) return errors, - -