Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 21 additions & 19 deletions clipvision_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,18 +22,18 @@ def __init__(self):

@classmethod
def INPUT_TYPES(s):
return {
"required": {
return {
"required": {
"db_name": (folder_paths.get_filename_list("EmbDBs"),) ,
"path_to_images_folder": ("STRING", {
"multiline": False,
"default": "/path/to/folder/with/images",
"tooltip": "Basepath to the folder containing the images"
}),
}),
},
"optional": {
"img_db": ("LoadDB",),
}
}
}

RETURN_TYPES = ("LoadDB", )
Expand All @@ -52,7 +52,7 @@ def load_DB(self, db_name, path_to_images_folder, img_db=None):

Returns:
A LoadDB object.
"""
"""
db_path = folder_paths.get_full_path_or_raise("EmbDBs", db_name)
if self.LOADED_DB_Name != db_path or path_to_images_folder != self.images_folder:
try:
Expand All @@ -68,12 +68,12 @@ def load_DB(self, db_name, path_to_images_folder, img_db=None):
except Exception as e:
print(f"Error loading database file: {e}")
return None,

if img_db is not None:
self.LOADED_DB = self.MY_LOADED_DB + img_db.LOADED_DB
else:
self.LOADED_DB = self.MY_LOADED_DB
return self,
return self,

class EditDB:

Expand All @@ -87,20 +87,20 @@ def __init__(self):

@classmethod
def INPUT_TYPES(s):
return {
"required": {
return {
"required": {
"img_db": ("LoadDB",),
"method": (["exclude", "filter", "replace"], {"default": "exclude", "tooltip": "Method to edit results"}),
"edit_text": ("STRING", {
"multiline": False,
"default": "remove*images.jpg",
"tooltip": "Use wildcards (*) to match filenames or paths"
},),
},),
"replace_text": ("STRING", {
"multiline": False,
"default": "/newpath/",
"tooltip": "Text to replace matched text with when using 'replace' method"
},),
},),
}
}

Expand Down Expand Up @@ -132,7 +132,7 @@ def Edit_DB(self, img_db:LoadDB, method, edit_text, replace_text):
new_name = file_name.replace(pat.strip("*"), replace_text)
NewDB.LOADED_DB.append((new_name, embeddings))

return NewDB,
return NewDB,

class GenerateDB:
"""
Expand All @@ -146,7 +146,7 @@ def __init__(self):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"required": {
"clip_vision": ("CLIP_VISION",),
"path_to_images_folder": ("STRING", {
"multiline": False,
Expand All @@ -163,7 +163,7 @@ def INPUT_TYPES(s):
"unique_id": "UNIQUE_ID",
}
}

RETURN_TYPES = ("STRING",)
RETURN_NAMES = ("ERRORS",)
OUTPUT_NODE = True
Expand All @@ -181,13 +181,15 @@ def start_gen_db(self, clip_vision, path_to_images_folder, new_db_name, unique_i

Returns:
A string containing any errors that occurred.
"""
"""
# Make sure the database name ends with .json
if not new_db_name.endswith('.json'):
new_db_name += '.json'

path_to_images = Path(path_to_images_folder)
path_to_database = Path(folder_paths.get_folder_paths("EmbDBs")[0]) / new_db_name

path_to_database.parent.mkdir(exist_ok=True)

errors = generate_clip_features_json(clip_vision, path_to_images, path_to_database, unique_id)
return errors,