From 22faf0453624cbfff0ade6f3d3bef442f4514453 Mon Sep 17 00:00:00 2001 From: Hubert Date: Fri, 22 Nov 2024 11:57:39 +0100 Subject: [PATCH] Add remove_from_index method --- byaldi/RAGModel.py | 11 +++++++++ byaldi/colpali.py | 58 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 67 insertions(+), 2 deletions(-) diff --git a/byaldi/RAGModel.py b/byaldi/RAGModel.py index 32b66bf..f1f0dba 100644 --- a/byaldi/RAGModel.py +++ b/byaldi/RAGModel.py @@ -155,6 +155,17 @@ def add_to_index( input_item, store_collection_with_index, doc_id, metadata=metadata ) + def remove_from_index(self, doc_id: int): + """Remove an item to an existing index. + + Parameters: + doc_id (int): The document ID for the item being removed. + + Returns: + None + """ + return self.model.remove_from_index(doc_id) + def search( self, query: Union[str, List[str]], diff --git a/byaldi/colpali.py b/byaldi/colpali.py index cc11dcb..8a5703e 100644 --- a/byaldi/colpali.py +++ b/byaldi/colpali.py @@ -589,8 +589,62 @@ def _add_to_index( if self.verbose > 0: print(f"Added page {page_id} of document {doc_id} to index.") - def remove_from_index(self): - raise NotImplementedError("This method is not implemented yet.") + def remove_from_index(self, doc_id: int) -> Dict[int, str]: + if self.index_name is None: + raise ValueError( + "No index loaded. Use index() to create or load an index first." + ) + + if not hasattr(self, "doc_ids_to_file_names"): + raise ValueError("No documents in the index.") + + if doc_id not in self.doc_ids_to_file_names: + raise ValueError(f"Document ID {doc_id} does not exist in the index.") + + # Remove the document from doc_ids_to_file_names + del self.doc_ids_to_file_names[doc_id] + + # Remove associated embeddings + embed_ids_to_remove = [ + embed_id + for embed_id, doc_info in self.embed_id_to_doc_id.items() + if doc_info["doc_id"] == doc_id + ] + self.indexed_embeddings = [ + embedding + for i, embedding in enumerate(self.indexed_embeddings) + if i not in embed_ids_to_remove + ] + + # Create a new dictionary for embed_id_to_doc_id excluding entries in embed_ids_to_remove + remaining_embed_id_to_doc_id = { + embed_id: self.embed_id_to_doc_id[embed_id] + for embed_id in self.embed_id_to_doc_id + if embed_id not in embed_ids_to_remove + } + + # Re-index the remaining items in embed_id_to_doc_id + self.embed_id_to_doc_id = { + new_id: doc_info + for new_id, doc_info in enumerate(remaining_embed_id_to_doc_id.values()) + } + + # Remove associated entries from the collection if they exist + if hasattr(self, "collection"): + for embed_id in embed_ids_to_remove: + if embed_id in self.collection: + del self.collection[embed_id] + + # Remove metadata if it exists + if hasattr(self, "doc_id_to_metadata") and doc_id in self.doc_id_to_metadata: + del self.doc_id_to_metadata[doc_id] + + # Update the highest_doc_id if necessary + if doc_id == self.highest_doc_id: + self.highest_doc_id = max(self.doc_ids_to_file_names.keys(), default=-1) + + self._export_index() + return self.doc_ids_to_file_names def filter_embeddings(self,filter_metadata:Dict[str,str]): req_doc_ids = []