forked from chatchat-space/Langchain-Chatchat
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add relyt vector store (chatchat-space#3926)
* add relyt kb service --------- Co-authored-by: jingsi <[email protected]>
- Loading branch information
1 parent
609d2e5
commit adcc283
Showing
5 changed files
with
157 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ import os | |
# 默认使用的知识库 | ||
DEFAULT_KNOWLEDGE_BASE = "samples" | ||
|
||
# 默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector, chromadb 全文检索引擎es | ||
# 默认向量库/全文检索引擎类型。可选:faiss, milvus(离线) & zilliz(在线), pgvector, chromadb 全文检索引擎es, relyt | ||
DEFAULT_VS_TYPE = "faiss" | ||
|
||
# 缓存向量库数量(针对FAISS) | ||
|
@@ -99,7 +99,9 @@ kbs_config = { | |
"pg": { | ||
"connection_uri": "postgresql://postgres:[email protected]:5432/langchain_chatchat", | ||
}, | ||
|
||
"relyt": { | ||
"connection_uri": "postgresql+psycopg2://postgres:[email protected]:7000/langchain_chatchat", | ||
}, | ||
"es": { | ||
"host": "127.0.0.1", | ||
"port": "9200", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from typing import List, Dict | ||
|
||
from langchain.schema import Document | ||
from langchain_community.vectorstores.pgvecto_rs import PGVecto_rs | ||
from sqlalchemy import text, create_engine | ||
from sqlalchemy.orm import Session | ||
|
||
from configs import kbs_config | ||
from server.knowledge_base.kb_service.base import SupportedVSType, KBService, EmbeddingsFunAdapter, \ | ||
score_threshold_process | ||
from server.knowledge_base.utils import KnowledgeFile | ||
|
||
|
||
class RelytKBService(KBService):
    """Knowledge-base service that stores its vectors in Relyt.

    Vector reads and writes go through a ``PGVecto_rs`` store, while the
    maintenance queries below are raw SQL against the table
    ``collection_<kb_name>`` on the same connection URI.
    NOTE(review): the table is presumably created by ``PGVecto_rs`` when the
    store is constructed -- confirm against the langchain implementation.
    """

    def _quote_identifier(self, identifier: str) -> str:
        """Return *identifier* quoted for the engine's dialect if it needs it.

        Identifiers cannot be sent as bind parameters, so names derived from
        ``kb_name`` are quoted before being interpolated into SQL.
        """
        return self.engine.dialect.identifier_preparer.quote(identifier)

    def _collection_table(self) -> str:
        """Return the (quoted) name of the table backing this knowledge base."""
        return self._quote_identifier(f"collection_{self.kb_name}")

    @staticmethod
    def _text_with_ids(sql: str):
        """Build a statement whose ``:ids`` placeholder expands to a list.

        A plain ``:ids`` parameter binds exactly one value, so ``IN (:ids)``
        with a comma-joined string never matches more than a single literal.
        An expanding parameter renders one placeholder per element instead.
        """
        # Imported locally so this class does not depend on the module's
        # import block being edited.
        from sqlalchemy import bindparam

        return text(sql).bindparams(bindparam("ids", expanding=True))

    def _load_relyt_vector(self):
        """Create the ``PGVecto_rs`` store and the SQLAlchemy engine.

        The embedding dimension is taken from the length of a sample
        embedding produced by the configured embedding model.
        """
        connection_uri = kbs_config.get("relyt").get("connection_uri")
        embedding_func = EmbeddingsFunAdapter(self.embed_model)
        sample_embedding = embedding_func.embed_query("Hello relyt!")
        self.relyt = PGVecto_rs(
            embedding=embedding_func,
            dimension=len(sample_embedding),
            db_url=connection_uri,
            collection_name=self.kb_name,
        )
        self.engine = create_engine(connection_uri)

    def get_doc_by_ids(self, ids: List[str]) -> List[Document]:
        """Fetch the documents whose primary key is in *ids*.

        Returns an empty list when *ids* is empty, without querying.
        """
        if not ids:
            return []
        stmt = self._text_with_ids(
            f"SELECT text, meta FROM {self._collection_table()} WHERE id IN :ids"
        )
        with Session(self.engine) as session:
            rows = session.execute(stmt, {"ids": list(ids)}).fetchall()
        return [Document(page_content=row[0], metadata=row[1]) for row in rows]

    def del_doc_by_ids(self, ids: List[str]) -> bool:
        """Delete the documents whose primary key is in *ids*.

        Always returns ``True``; an empty *ids* is a no-op.
        """
        if not ids:
            return True
        stmt = self._text_with_ids(
            f"DELETE FROM {self._collection_table()} WHERE id IN :ids"
        )
        with Session(self.engine) as session:
            session.execute(stmt, {"ids": list(ids)})
            session.commit()
        return True

    def do_init(self):
        """Connect to Relyt, then make sure the vector index exists."""
        self._load_relyt_vector()
        self.do_create_kb()

    def do_create_kb(self):
        """Create the HNSW index on the embedding column if it is missing."""
        index_name = f"idx_{self.kb_name}_embedding"
        # The index name is a value here, so it can be bound as a parameter.
        index_query = text(
            """
            SELECT 1
            FROM pg_indexes
            WHERE indexname = :index_name;
            """
        )
        index_statement = text(
            f"""
            CREATE INDEX {self._quote_identifier(index_name)}
            ON {self._collection_table()}
            USING vectors (embedding vector_l2_ops)
            WITH (options = $$
            optimizing.optimizing_threads = 30
            segment.max_growing_segment_size = 2000
            segment.max_sealed_segment_size = 30000000
            [indexing.hnsw]
            m=30
            ef_construction=500
            $$);
            """
        )
        with self.engine.connect() as conn:
            with conn.begin():
                exists = conn.execute(
                    index_query, {"index_name": index_name}
                ).scalar()
                if not exists:
                    conn.execute(index_statement)

    def vs_type(self) -> str:
        """Return the vector-store type identifier for Relyt."""
        return SupportedVSType.RELYT

    def do_drop_kb(self):
        """Drop the table backing this knowledge base, if it exists."""
        drop_statement = text(f"DROP TABLE IF EXISTS {self._collection_table()};")
        with self.engine.connect() as conn:
            with conn.begin():
                conn.execute(drop_statement)

    def do_search(self, query: str, top_k: int, score_threshold: float):
        """Run a similarity search for *query*.

        The scored results for the *top_k* nearest documents are passed
        through ``score_threshold_process`` before being returned.
        """
        docs = self.relyt.similarity_search_with_score(query, top_k)
        return score_threshold_process(score_threshold, top_k, docs)

    def do_add_doc(self, docs: List[Document], **kwargs) -> List[Dict]:
        """Add *docs* to the vector store.

        Returns one ``{"id": ..., "metadata": ...}`` dict per stored
        document, pairing each returned id with the document's metadata.
        """
        ids = self.relyt.add_documents(docs)
        return [
            {"id": doc_id, "metadata": doc.metadata}
            for doc_id, doc in zip(ids, docs)
        ]

    def do_delete_doc(self, kb_file: KnowledgeFile, **kwargs):
        """Delete every row whose ``meta->>'source'`` matches *kb_file*."""
        filepath = self.get_relative_source_path(kb_file.filepath)
        # Bind the path as a parameter: file names may contain quotes.
        stmt = text(
            f"DELETE FROM {self._collection_table()} "
            "WHERE meta->>'source' = :source;"
        )
        with Session(self.engine) as session:
            session.execute(stmt, {"source": filepath})
            session.commit()

    def do_clear_vs(self):
        """Clear the vector store by dropping the backing table.

        NOTE(review): this removes the table itself rather than emptying it,
        so later queries presumably need ``do_init`` to run again first --
        confirm against how callers use ``clear_vs``.
        """
        self.do_drop_kb()
|
||
|
||
if __name__ == "__main__":
    # Ad-hoc manual exercise of the service; it talks to the configured
    # databases, so it is only meaningful with those services available.
    from server.db.base import Base, engine

    Base.metadata.create_all(bind=engine)

    service = RelytKBService("collection_test")
    readme = KnowledgeFile("README.md", "test")
    print(readme)
    service.add_doc(readme)
    print("has add README")

    service.delete_doc(KnowledgeFile("README.md", "test"))
    service.drop_kb()

    # NOTE(review): both lookups below run after drop_kb() -- confirm this
    # ordering is intentional.
    fetched = service.get_doc_by_ids(["444022434274215486"])
    print(fetched)
    hits = service.search_docs("如何启动api服务")
    print(hits)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from server.knowledge_base.kb_service.relyt_kb_service import RelytKBService | ||
from server.knowledge_base.migrate import create_tables | ||
from server.knowledge_base.utils import KnowledgeFile | ||
|
||
# Fixtures shared by every test in this module.
test_kb_name = "test"
test_file_name = "README.md"
search_content = "如何启动api服务"

# The service is built at import time, before any test function runs.
kbService = RelytKBService(test_kb_name)
testKnowledgeFile = KnowledgeFile(test_file_name, test_kb_name)
|
||
|
||
def test_init():
    """Run ``create_tables`` so the later tests have their tables."""
    create_tables()
|
||
|
||
def test_create_db():
    """``create_kb`` must report success with a truthy result."""
    created = kbService.create_kb()
    assert created
|
||
|
||
def test_add_doc():
    """Adding the sample file must return a truthy result."""
    added = kbService.add_doc(testKnowledgeFile)
    assert added
|
||
|
||
def test_search_db():
    """Searching for the sample query must return at least one hit."""
    assert len(kbService.search_docs(search_content)) > 0
|
||
|
||
def test_delete_doc():
    """Deleting the sample file must return a truthy result."""
    deleted = kbService.delete_doc(testKnowledgeFile)
    assert deleted