diff --git a/examples/TCL_rag/config.yaml b/examples/TCL_rag/config.yaml index 231b3a8..9551aff 100644 --- a/examples/TCL_rag/config.yaml +++ b/examples/TCL_rag/config.yaml @@ -1,12 +1,12 @@ llm: name: openai - base_url: "https://api.gptsapi.net/v1" - api_key: "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2" + base_url: "xxx" + api_key: "xxx" model: "gpt-4.1-mini" embedding: name: huggingface - model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B" + model_name: "xxx" model_kwargs: device: "cuda:0" @@ -14,20 +14,20 @@ embedding: store: name: faiss - folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store + folder_path: xxx bm25: name: bm25 k: 10 - data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json + data_path: xxx retriever: name: vectorstore reranker: name: qwen3 - model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B" + model_name_or_path: "xxx" device_id: "cuda:0" dataset: diff --git a/examples/TCL_rag/test.py b/examples/TCL_rag/test.py index 9607f3f..34efb74 100644 --- a/examples/TCL_rag/test.py +++ b/examples/TCL_rag/test.py @@ -24,9 +24,8 @@ vector_store_config=vector_store_config, bm25_retriever_config=bm25_retriever_config) - result = rag.invoke("毛细管设计规范按照什么标准",k=20) + result = rag.invoke("模块机传感器端子不防呆的改善方案是什么?由哪个部门负责?",k=20) - answer = rag.answer("毛细管设计规范按照什么标准",result) - - - print(answer) \ No newline at end of file + for i in result: + print(i) + print("-"*100) \ No newline at end of file diff --git a/rag_factory/Retrieval/Retriever/Retriever_BM25.py b/rag_factory/Retrieval/Retriever/Retriever_BM25.py index 98eef3e..9749166 100644 --- a/rag_factory/Retrieval/Retriever/Retriever_BM25.py +++ b/rag_factory/Retrieval/Retriever/Retriever_BM25.py @@ -5,7 +5,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence from dataclasses import dataclass, field - +import uuid from pydantic import ConfigDict, Field, model_validator logger = logging.getLogger(__name__) @@ -207,7 +207,7 @@ def from_texts( f"与 texts 长度 ({len(texts_list)}) 不匹配" ) else: - ids_list = [None for _ in texts_list] + ids_list = [str(uuid.uuid4()) for _ in texts_list] # 预处理文本 logger.info(f"正在预处理 {len(texts_list)} 个文本...") diff --git a/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py b/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py index cae3923..30b1171 100644 --- a/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py +++ b/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py @@ -50,8 +50,8 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: Note: - 每个检索器的结果会被转换为RetrievalResult格式 - - 支持多种输入格式:Document对象、字典格式、字符串等 - - 融合后的结果会将score和rank信息保存在Document的metadata中 + - 输入只会是Document对象 + - 融合后的结果只返回排序好的Document对象 """ top_k = kwargs.get('top_k', 10) @@ -65,43 +65,12 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: # 转换为RetrievalResult格式 formatted_results = [] for i, doc in enumerate(documents): - if isinstance(doc, Document): - # 如果是Document对象 - retrieval_result = RetrievalResult( - document=doc, - score=getattr(doc, 'score', 1.0), - rank=i + 1 - ) - elif isinstance(doc, dict): - # 如果返回的是字典格式,需要转换为Document对象 - content = doc.get('content', '') - metadata = doc.get('metadata', {}) - doc_id = doc.get('id') - - document = Document( - content=content, - metadata=metadata, - id=doc_id - ) - - retrieval_result = RetrievalResult( - document=document, - score=doc.get('score', 1.0), - rank=i + 1 - ) - else: - # 如果是字符串或其他格式,转换为Document对象 - document = Document( - content=str(doc), - metadata={}, - id=None - ) - - retrieval_result = RetrievalResult( - document=document, - score=1.0, - rank=i + 1 - ) + # 输入只会是Document对象 + retrieval_result = RetrievalResult( + document=doc, + score=getattr(doc, 'score', 1.0), + rank=i + 1 + ) formatted_results.append(retrieval_result) all_results.append(formatted_results) @@ -116,16 +85,10 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: fused_results = self.fusion_method.fuse(all_results, top_k) - # 转换回Document格式 + # 转换回Document格式,只返回排序好的Document对象 documents = [] for result in fused_results: - doc = result.document - # 将score和rank添加到metadata中以便保留 - if doc.metadata is None: - doc.metadata = {} - doc.metadata['score'] = result.score - doc.metadata['rank'] = result.rank - documents.append(doc) + documents.append(result.document) return documents diff --git a/requirements.txt b/requirements.txt index 823e164..56e5bd3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,15 +14,13 @@ llama-index llama-index-core peewee -mineru[core] + rank_bm25 faiss_gpu - -# streamlit +# for ocr PyMuPDF -openai qwen_vl_utils transformers==4.51.3 huggingface_hub @@ -31,3 +29,6 @@ flash-attn==2.8.0.post2 # for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2 accelerate dashscope +torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu128 + +mineru[core] \ No newline at end of file