diff --git a/examples/TCL_rag/config.yaml b/examples/TCL_rag/config.yaml new file mode 100644 index 0000000..231b3a8 --- /dev/null +++ b/examples/TCL_rag/config.yaml @@ -0,0 +1,35 @@ +llm: + name: openai + base_url: "https://api.gptsapi.net/v1" + api_key: "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2" + model: "gpt-4.1-mini" + +embedding: + name: huggingface + model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B" + model_kwargs: + device: "cuda:0" + + + +store: + name: faiss + folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store + + +bm25: + name: bm25 + k: 10 + data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json + +retriever: + name: vectorstore + +reranker: + name: qwen3 + model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B" + device_id: "cuda:0" + +dataset: + name: TCL + diff --git a/examples/TCL_rag/rag_flow.py b/examples/TCL_rag/rag_flow.py new file mode 100644 index 0000000..1e5f9c9 --- /dev/null +++ b/examples/TCL_rag/rag_flow.py @@ -0,0 +1,85 @@ +import sys +import os + +# 添加 RAG-Factory 目录到 Python 路径 +rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, rag_factory_path) + +from rag_factory.llms import LLMRegistry +from rag_factory.Embed import EmbeddingRegistry +from rag_factory.Store import VectorStoreRegistry +from rag_factory.Retrieval import RetrieverRegistry +from rag_factory.rerankers import RerankerRegistry +from rag_factory.Retrieval import Document +from typing import List +import json + + +class TCL_RAG: + def __init__( + self, + *, + llm_config=None, + embedding_config=None, + vector_store_config=None, + bm25_retriever_config=None, + retriever_config=None, + reranker_config=None, + ): + llm_config = llm_config or {} + embedding_config = embedding_config or {} + vector_store_config = vector_store_config or {} + bm25_retriever_config = bm25_retriever_config or {} + retriever_config = retriever_config or {} + reranker_config = reranker_config or {} + self.llm = LLMRegistry.create(**llm_config) + self.embedding = EmbeddingRegistry.create(**embedding_config) + self.vector_store = VectorStoreRegistry.load(**vector_store_config, embedding=self.embedding) + self.bm25_retriever = RetrieverRegistry.create(**bm25_retriever_config) + self.bm25_retriever = self.bm25_retriever.from_documents(documents=self._load_data(bm25_retriever_config["data_path"]), preprocess_func=self.chinese_preprocessing_func, k=bm25_retriever_config["k"]) + + self.retriever = RetrieverRegistry.create(**retriever_config, vectorstore=self.vector_store) + self.multi_path_retriever = RetrieverRegistry.create("multipath", retrievers=[self.bm25_retriever, self.retriever]) + self.reranker = RerankerRegistry.create(**reranker_config) + + def invoke(self, query: str, k: int = None): + return self.multi_path_retriever.invoke(query, top_k=k) + + def rerank(self, query: str, documents: List[Document], k: int = None, batch_size: int = 8): + return self.reranker.rerank(query, documents, k, batch_size) + + def _load_data(self, data_path: str): + with open(data_path, "r", encoding="utf-8") as f: + data = json.load(f) + docs = [] + for item in data: + content = item.get("full_content", "") + metadata = {"title": item.get("original_filename", "")} + docs.append(Document(content=content, metadata=metadata)) + return docs + + def chinese_preprocessing_func(self, text: str) -> str: + import jieba + return " 
".join(jieba.cut(text)) + + + def answer(self, query: str, documents: List[Document]): + + template = ( + "你是一位工业领域的专家。根据以下检索到的材料回答用户问题。" + "如果回答所需信息未在材料中出现,请说明无法找到相关信息。\n\n" + "{context}\n\n" + "用户问题:{question}\n" + "答复:" + ) + context = "\n".join([doc.content for doc in documents]) + prompt = template.format(question=query, context=context) + messages = [ + {"role": "system", "content": "你是一位工业领域的专家。"}, + {"role": "user", "content": prompt} + ] + return self.llm.chat(messages) + + + + diff --git a/examples/TCL_rag/test.py b/examples/TCL_rag/test.py new file mode 100644 index 0000000..9607f3f --- /dev/null +++ b/examples/TCL_rag/test.py @@ -0,0 +1,32 @@ +from rag_flow import TCL_RAG +import yaml + +# 加载配置文件 +with open('/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/TCL_rag/config.yaml', 'r', encoding='utf-8') as f: + config = yaml.safe_load(f) + +llm_config = config['llm'] +embedding_config = config['embedding'] +reranker_config = config['reranker'] +bm25_retriever_config = config['bm25'] +retriever_config = config['retriever'] +vector_store_config = config['store'] + + + + +if __name__ == "__main__": + + rag = TCL_RAG(llm_config=llm_config, + embedding_config=embedding_config, + reranker_config=reranker_config, + retriever_config=retriever_config, + vector_store_config=vector_store_config, + bm25_retriever_config=bm25_retriever_config) + + result = rag.invoke("毛细管设计规范按照什么标准",k=20) + + answer = rag.answer("毛细管设计规范按照什么标准",result) + + + print(answer) \ No newline at end of file diff --git a/examples/bm25/config.yaml b/examples/bm25/config.yaml new file mode 100644 index 0000000..f52f8af --- /dev/null +++ b/examples/bm25/config.yaml @@ -0,0 +1,3 @@ +retriever: + name: bm25 + k: 8 \ No newline at end of file diff --git a/examples/bm25/main.py b/examples/bm25/main.py new file mode 100644 index 0000000..03fa291 --- /dev/null +++ b/examples/bm25/main.py @@ -0,0 +1,36 @@ +import sys +import os + +rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, rag_factory_path) + +import json +from rag_factory.Retrieval import Document +from rag_factory.Retrieval import RetrieverRegistry + +import yaml + + +def load_data(jsonl_path: str): + with open(jsonl_path, "r", encoding="utf-8") as f: + data = json.load(f) + docs = [] + for item in data: + content = item.get("full_content", "") + metadata = {"title": item.get("original_title", "")} + docs.append(Document(content=content, metadata=metadata)) + return docs + +def chinese_preprocessing_func(text: str) -> str: + import jieba + return " ".join(jieba.cut(text)) + +if __name__ == "__main__": + docs = load_data("/data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json") + with open("/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/bm25/config.yaml", "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + + bm25_retriever = RetrieverRegistry.create(**config["retriever"]) + bm25_retriever = bm25_retriever.from_documents(documents=docs, preprocess_func=chinese_preprocessing_func, k=config["retriever"]["k"]) + + print(bm25_retriever.invoke("什么是TCL?")) \ No newline at end of file diff --git a/examples/faiss_construct/config.yaml b/examples/faiss_construct/config.yaml new file mode 100644 index 0000000..d429d1a --- /dev/null +++ b/examples/faiss_construct/config.yaml @@ -0,0 +1,14 @@ +store: + name: faiss # 数据库 + folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store # 保存路径 + + +embedding: + name: 
huggingface # 嵌入模型 + model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B" # 模型路径 + model_kwargs: + device: "cuda:1" # 设备 + +dataset: + name: TCL + data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json \ No newline at end of file diff --git a/examples/faiss_construct/faiss_constructor.py b/examples/faiss_construct/faiss_constructor.py new file mode 100644 index 0000000..cb7c0c5 --- /dev/null +++ b/examples/faiss_construct/faiss_constructor.py @@ -0,0 +1,43 @@ +import sys +import os + +# 添加 RAG-Factory 目录到 Python 路径 +rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..") +sys.path.insert(0, rag_factory_path) + +from rag_factory.Store import VectorStoreRegistry +from rag_factory.Embed import EmbeddingRegistry +import yaml +from rag_factory.Retrieval import Document +import json + + +with open("/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/faiss_construct/config.yaml", "r", encoding="utf-8") as f: + config = yaml.safe_load(f) + +store_config = config["store"] +embedding_config = config["embedding"] +dataset_config = config["dataset"]["data_path"] +embedding = EmbeddingRegistry.create(**embedding_config) +store = VectorStoreRegistry.create(**store_config, embedding=embedding) + + +if __name__ == "__main__": + + # 读取数据 + with open(dataset_config, "r", encoding="utf-8") as f: + docs = [] + data = json.load(f) + for item in data: + full_content = item.get("full_content", "") + metadata = { + "title": item.get("original_filename"), + } + + docs.append(Document(content=full_content, metadata=metadata)) + + # 创建向量库 + vectorstore = store.from_documents(docs, embedding=embedding) + + # 保存到本地 + vectorstore.save_local(store_config["folder_path"]) \ No newline at end of file diff --git a/rag_factory/Embed/Embedding_Base.py b/rag_factory/Embed/Embedding_Base.py index 2af5bb2..c1527d3 100644 --- a/rag_factory/Embed/Embedding_Base.py +++ b/rag_factory/Embed/Embedding_Base.py @@ -2,12 +2,13 @@ from dataclasses import dataclass import asyncio from concurrent.futures import ThreadPoolExecutor +from typing import List class Embeddings(ABC): """嵌入接口""" @abstractmethod - def embed_documents(self, texts: list[str]) -> list[list[float]]: + def embed_documents(self, texts: List[str]) -> List[List[float]]: """Embed search docs. Args: @@ -19,7 +20,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]: pass @abstractmethod - def embed_query(self, text: str) -> list[float]: + def embed_query(self, text: str) -> List[float]: """Embed query text. Args: @@ -30,7 +31,7 @@ def embed_query(self, text: str) -> list[float]: """ pass - async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: """Asynchronous Embed search docs. Args: @@ -43,7 +44,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]: ThreadPoolExecutor(), self.embed_documents, texts ) - async def aembed_query(self, text: str) -> list[float]: + async def aembed_query(self, text: str) -> List[float]: """Asynchronous Embed query text. 
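# A minimal sketch, assuming the package layout added in this patch, of the
# Embeddings ABC just above: a custom backend only has to implement
# embed_documents() and embed_query(); the async variants fall back to the
# thread-pool defaults. "toy-hash", ToyHashEmbeddings and the hash-derived
# vectors are hypothetical stand-ins for a real embedding model.
import hashlib
from typing import List

from rag_factory.Embed import Embeddings, EmbeddingRegistry


class ToyHashEmbeddings(Embeddings):
    """Deterministic 8-dim vectors derived from an MD5 digest (illustration only)."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        digest = hashlib.md5(text.encode("utf-8")).digest()
        return [b / 255.0 for b in digest[:8]]


# Mirrors how "huggingface" is pre-registered at the bottom of Embed/registry.py.
EmbeddingRegistry.register("toy-hash", ToyHashEmbeddings)
embedding = EmbeddingRegistry.create("toy-hash")
print(len(embedding.embed_query("毛细管设计规范")))  # -> 8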
Args: diff --git a/rag_factory/Embed/__init__.py b/rag_factory/Embed/__init__.py index 209132e..dc3fe8f 100644 --- a/rag_factory/Embed/__init__.py +++ b/rag_factory/Embed/__init__.py @@ -1,4 +1,5 @@ from .Embedding_Base import Embeddings from .Embedding_Huggingface import HuggingFaceEmbeddings +from .registry import EmbeddingRegistry -__all__ = ["Embeddings", "HuggingFaceEmbeddings"] \ No newline at end of file +__all__ = ["Embeddings", "HuggingFaceEmbeddings", "EmbeddingRegistry"] \ No newline at end of file diff --git a/rag_factory/Embed/registry.py b/rag_factory/Embed/registry.py new file mode 100644 index 0000000..0899688 --- /dev/null +++ b/rag_factory/Embed/registry.py @@ -0,0 +1,79 @@ +from typing import Dict, Type, Any, Optional, List +import logging +from .Embedding_Huggingface import HuggingFaceEmbeddings +from .Embedding_Base import Embeddings + +class EmbeddingRegistry: + """嵌入模型注册器,用于管理和创建不同类型的嵌入模型""" + _embeddings: Dict[str, Type[Embeddings]] = {} + + @classmethod + def register(cls, name: str, embedding_class: Type[Embeddings]): + """注册嵌入模型类 + + Args: + name: 模型名称 + embedding_class: 嵌入模型类 + """ + cls._embeddings[name] = embedding_class + + @classmethod + def create(cls, name: str, **kwargs) -> Embeddings: + """获取嵌入模型实例 + + Args: + name: 模型名称 + **kwargs: 模型初始化参数 + + Returns: + 嵌入模型实例 + + Raises: + ValueError: 当模型名称不存在时 + """ + if name not in cls._embeddings: + available_embeddings = list(cls._embeddings.keys()) + raise ValueError(f"嵌入模型 '{name}' 未注册。可用的模型: {available_embeddings}") + + embedding_class = cls._embeddings[name] + return embedding_class(**kwargs) + + @classmethod + def list_embeddings(cls) -> List[str]: + """列出所有已注册的嵌入模型名称 + + Returns: + 已注册的模型名称列表 + """ + return list(cls._embeddings.keys()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """检查模型是否已注册 + + Args: + name: 模型名称 + + Returns: + 如果已注册返回True,否则返回False + """ + return name in cls._embeddings + + @classmethod + def unregister(cls, name: str) -> bool: + """取消注册模型 + + Args: + name: 模型名称 + + Returns: + 成功取消注册返回True,模型不存在返回False + """ + if name in cls._embeddings: + del cls._embeddings[name] + return True + return False + + +# 注册默认的嵌入模型 +EmbeddingRegistry.register("huggingface", HuggingFaceEmbeddings) \ No newline at end of file diff --git a/rag_factory/Retrieval/Retriever/Retriever_BM25.py b/rag_factory/Retrieval/Retriever/Retriever_BM25.py index 278b55e..98eef3e 100644 --- a/rag_factory/Retrieval/Retriever/Retriever_BM25.py +++ b/rag_factory/Retrieval/Retriever/Retriever_BM25.py @@ -9,12 +9,12 @@ from pydantic import ConfigDict, Field, model_validator logger = logging.getLogger(__name__) - +import numpy as np from rag_factory.Retrieval.RetrieverBase import BaseRetriever, Document def default_preprocessing_func(text: str) -> List[str]: - """默认的文本预处理函数 + """默认的文本预处理函数,仅在英文文本上有效 Args: text: 输入文本 @@ -25,33 +25,51 @@ def default_preprocessing_func(text: str) -> List[str]: return text.split() -def chinese_preprocessing_func(text: str) -> List[str]: - """中文文本预处理函数 - - Args: - text: 输入的中文文本 - - Returns: - 分词后的词语列表 - """ - try: - import jieba - return list(jieba.cut(text)) - except ImportError: - logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba") - return text.split() class BM25Retriever(BaseRetriever): - """BM25 检索器实现 - - 基于 BM25 算法的文档检索器。 - 使用 rank_bm25 库实现高效的 BM25 搜索。 - - 注意:BM25 算法适用于相对静态的文档集合。虽然支持动态添加/删除文档, - 但每次操作都会重建整个索引,在大型文档集合上可能有性能问题。 - 对于频繁更新的场景,建议使用 VectorStoreRetriever。 - + """ + BM25Retriever 是一个基于 BM25 算法的文档检索器,适用于信息检索、问答系统、知识库等场景下的高效文本相关性排序。 + + 该类通过集成 rank_bm25 库,实现了对文档集合的 BM25 
检索,支持文档的动态添加、删除、批量构建索引等操作。 + 适合文档集合相对静态、检索速度要求较高的场景。对于频繁增删文档的场景,建议使用向量检索(如 VectorStoreRetriever)。 + + 主要特性: + - 支持从文本列表或 Document 对象列表快速构建 BM25 检索器。 + - 支持自定义分词/预处理函数,适配不同语言和分词需求。 + - 支持动态添加、删除文档(每次操作会重建索引,适合中小规模数据集)。 + - 可获取检索分数、top-k 文档及分数、检索器配置信息等。 + - 兼容异步文档添加/删除,便于大规模数据处理。 + - 通过 Pydantic 校验参数,保证配置安全。 + + 主要参数: + vectorizer (Any): BM25 向量化器实例(通常为 BM25Okapi)。 + docs (List[Document]): 当前检索器持有的文档对象列表。 + k (int): 默认返回的相关文档数量。 + preprocess_func (Callable): 文本分词/预处理函数,默认为空格分词。 + bm25_params (Dict): 传递给 BM25Okapi 的参数(如 k1、b 等)。 + + 核心方法: + - from_texts/from_documents: 从原始文本或 Document 构建检索器。 + - _get_relevant_documents: 检索与查询最相关的前 k 个文档。 + - get_scores: 获取查询对所有文档的 BM25 分数。 + - get_top_k_with_scores: 获取 top-k 文档及其分数。 + - add_documents/delete_documents: 动态增删文档并重建索引。 + - get_bm25_info: 获取检索器配置信息和统计。 + - update_k: 动态调整返回文档数量。 + + 性能注意事项: + - 每次添加/删除文档都会重建 BM25 索引,适合文档量较小或更新不频繁的场景。 + - 文档量较大或频繁更新时,建议使用向量检索方案。 + - 支持异步操作,便于大规模数据处理。 + + 典型用法: + >>> retriever = BM25Retriever.from_texts(["文本1", "文本2"], k=3) + >>> results = retriever._get_relevant_documents("查询语句") + >>> retriever.add_documents([Document(content="新文档")]) + >>> retriever.delete_documents(ids=["doc_id"]) + >>> info = retriever.get_bm25_info() + Attributes: vectorizer: BM25 向量化器实例 docs: 文档列表 @@ -95,7 +113,7 @@ def __init__(self, **kwargs): # 设置属性 self.vectorizer = kwargs.get('vectorizer') self.docs = kwargs.get('docs', []) - self.k = kwargs.get('k', 4) + self.k = kwargs.get('k', 5) self.preprocess_func = kwargs.get('preprocess_func', default_preprocessing_func) self.bm25_params = kwargs.get('bm25_params', {}) @@ -125,7 +143,7 @@ def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]: Returns: 验证后的值 """ - k = values.get("k", 4) + k = values.get("k", 5) if k <= 0: raise ValueError(f"k 必须大于 0,当前值: {k}") @@ -139,7 +157,6 @@ def from_texts( ids: Optional[Iterable[str]] = None, bm25_params: Optional[Dict[str, Any]] = None, preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, - k: int = 4, **kwargs: Any, ) -> "BM25Retriever": """从文本列表创建 BM25Retriever @@ -150,7 +167,6 @@ def from_texts( ids: ID列表,可选 bm25_params: BM25 算法参数,可选 preprocess_func: 预处理函数 - k: 返回文档数量 **kwargs: 其他参数 Returns: @@ -213,7 +229,6 @@ def from_texts( return cls( vectorizer=vectorizer, docs=docs, - k=k, preprocess_func=preprocess_func, bm25_params=bm25_params, **kwargs @@ -225,7 +240,6 @@ def from_documents( documents: Iterable[Document], bm25_params: Optional[Dict[str, Any]] = None, preprocess_func: Callable[[str], List[str]] = default_preprocessing_func, - k: int = 4, **kwargs: Any, ) -> "BM25Retriever": """从文档列表创建 BM25Retriever @@ -234,7 +248,6 @@ def from_documents( documents: 文档列表 bm25_params: BM25 算法参数,可选 preprocess_func: 预处理函数 - k: 返回文档数量 **kwargs: 其他参数 Returns: @@ -254,51 +267,52 @@ def from_documents( ids=ids, bm25_params=bm25_params, preprocess_func=preprocess_func, - k=k, **kwargs, ) def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: - """获取与查询相关的文档 - + """获取与查询相关的前k个文档 + Args: query: 查询字符串 **kwargs: 其他参数,可能包含 'k' 来覆盖默认的返回数量 - + Returns: 相关文档列表 - + Raises: ValueError: 如果向量化器未初始化 """ if self.vectorizer is None: raise ValueError("BM25 向量化器未初始化") - + if not self.docs: logger.warning("文档列表为空,返回空结果") return [] - + # 获取返回文档数量 k = kwargs.get('k', self.k) k = min(k, len(self.docs)) # 确保不超过总文档数 - + try: # 预处理查询 processed_query = self.preprocess_func(query) logger.debug(f"预处理后的查询: {processed_query}") + + # 获取所有文档的分数 + scores = self.vectorizer.get_scores(processed_query) + # 获取分数最高的前k个文档索引 - # 获取相关文档 - relevant_docs = 
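# A minimal sketch, assuming the BM25Retriever above, of the preprocess_func
# contract: like default_preprocessing_func, a tokenizer should return a
# List[str] of tokens (not a whitespace-joined string), otherwise BM25Okapi
# ends up scoring individual characters. jieba is an optional dependency
# (`pip install jieba`), and the sample sentences are illustrative only.
from typing import List

import jieba

from rag_factory.Retrieval import RetrieverRegistry


def jieba_tokenize(text: str) -> List[str]:
    return list(jieba.cut(text))


bm25 = RetrieverRegistry.create("bm25").from_texts(
    texts=["毛细管设计规范按照国家标准执行", "BM25 是一种经典的词频检索算法"],
    preprocess_func=jieba_tokenize,
    k=1,
)
print(bm25.invoke("毛细管设计规范按照什么标准"))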
self.vectorizer.get_top_n( - processed_query, self.docs, n=k - ) - - logger.debug(f"找到 {len(relevant_docs)} 个相关文档") - return relevant_docs - + top_indices = np.argsort(scores)[::-1][:k] + # 返回前k个文档 + top_docs = [self.docs[idx] for idx in top_indices] + logger.debug(f"找到 {len(top_docs)} 个相关文档") + return top_docs + except Exception as e: logger.error(f"BM25 搜索时发生错误: {e}") raise - + def get_scores(self, query: str) -> List[float]: """获取查询对所有文档的 BM25 分数 diff --git a/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py b/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py new file mode 100644 index 0000000..cae3923 --- /dev/null +++ b/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py @@ -0,0 +1,164 @@ +from typing import List, Dict, Any, Optional, Union + +from ..RetrieverBase import Document +from ..RetrieverBase import BaseRetriever +from ..utils.Fusion import FusionMethod, RRFusion, RetrievalResult + + +class MultiPathRetriever(BaseRetriever): + """ + 多路检索器 + + 该类实现了多路检索功能,可以同时使用多个检索器进行文档检索, + 并通过指定的融合方法将多个检索器的结果进行合并和排序。 + + Attributes: + retrievers (List[BaseRetriever]): 检索器列表,每个检索器需要实现retrieve方法 + fusion_method (FusionMethod): 融合方法,用于合并多个检索器的结果 + top_k_per_retriever (int): 每个检索器返回的结果数量 + """ + + def __init__(self, + retrievers: List[BaseRetriever], + fusion_method: Optional[FusionMethod] = None, + top_k_per_retriever: int = 50): + """ + 初始化多路检索器 + + Args: + retrievers (List[BaseRetriever]): 检索器列表,每个检索器需要实现retrieve方法 + fusion_method (Optional[FusionMethod]): 融合方法,默认为RRF (Reciprocal Rank Fusion) + top_k_per_retriever (int): 每个检索器返回的结果数量,默认为50 + """ + self.retrievers = retrievers + self.fusion_method = fusion_method or RRFusion() + self.top_k_per_retriever = top_k_per_retriever + + def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: + """ + 获取与查询相关的文档 + + 该方法会调用所有配置的检索器,获取每个检索器的检索结果, + 然后使用指定的融合方法将所有结果进行合并和排序。 + + Args: + query (str): 查询字符串 + **kwargs (Any): 其他参数,包括top_k等 + + Returns: + List[Document]: 融合后的相关文档列表,按相关性排序 + + Note: + - 每个检索器的结果会被转换为RetrievalResult格式 + - 支持多种输入格式:Document对象、字典格式、字符串等 + - 融合后的结果会将score和rank信息保存在Document的metadata中 + """ + top_k = kwargs.get('top_k', 10) + + # 从每个检索器获取结果 + all_results = [] + for retriever in self.retrievers: + try: + # 使用BaseRetriever的invoke方法 + documents = retriever.invoke(query, **{**kwargs, 'k': self.top_k_per_retriever}) + + # 转换为RetrievalResult格式 + formatted_results = [] + for i, doc in enumerate(documents): + if isinstance(doc, Document): + # 如果是Document对象 + retrieval_result = RetrievalResult( + document=doc, + score=getattr(doc, 'score', 1.0), + rank=i + 1 + ) + elif isinstance(doc, dict): + # 如果返回的是字典格式,需要转换为Document对象 + content = doc.get('content', '') + metadata = doc.get('metadata', {}) + doc_id = doc.get('id') + + document = Document( + content=content, + metadata=metadata, + id=doc_id + ) + + retrieval_result = RetrievalResult( + document=document, + score=doc.get('score', 1.0), + rank=i + 1 + ) + else: + # 如果是字符串或其他格式,转换为Document对象 + document = Document( + content=str(doc), + metadata={}, + id=None + ) + + retrieval_result = RetrievalResult( + document=document, + score=1.0, + rank=i + 1 + ) + formatted_results.append(retrieval_result) + + all_results.append(formatted_results) + + except Exception as e: + print(f"检索器 {type(retriever).__name__} 执行失败: {e}") + all_results.append([]) + + # 使用融合方法合并结果 + if not all_results or all(len(results) == 0 for results in all_results): + return [] + + fused_results = self.fusion_method.fuse(all_results, top_k) + + # 转换回Document格式 + documents = [] + for result in 
fused_results: + doc = result.document + # 将score和rank添加到metadata中以便保留 + if doc.metadata is None: + doc.metadata = {} + doc.metadata['score'] = result.score + doc.metadata['rank'] = result.rank + documents.append(doc) + + return documents + + + def add_retriever(self, retriever: BaseRetriever): + """ + 添加新的检索器到多路检索器中 + + Args: + retriever (BaseRetriever): 要添加的检索器实例 + """ + self.retrievers.append(retriever) + + def remove_retriever(self, name: str): + """ + 移除指定名称的检索器 + + Args: + name (str): 要移除的检索器的类名 + + Note: + 该方法通过比较检索器的类名来识别要移除的检索器 + """ + for i, retriever in enumerate(self.retrievers): + if hasattr(retriever, '__class__') and retriever.__class__.__name__ == name: + self.retrievers.pop(i) + break + + def set_fusion_method(self, fusion_method: FusionMethod): + """ + 设置融合方法 + + Args: + fusion_method (FusionMethod): 新的融合方法实例 + """ + self.fusion_method = fusion_method diff --git a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py index 734c7d8..2b7eafe 100644 --- a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py +++ b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py @@ -10,8 +10,8 @@ import logging from pydantic import ConfigDict, Field, model_validator -from Retrieval.RetrieverBase import BaseRetriever, Document -from Store.VectorStore.VectorStoreBase import VectorStore +from ..RetrieverBase import BaseRetriever, Document +from ...Store.VectorStore.VectorStoreBase import VectorStore logger = logging.getLogger(__name__) @@ -141,9 +141,15 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: # 合并搜索参数 search_params = {**self.search_kwargs, **kwargs} + # 获取返回文档数量,参考BM25Retriever的做法 + k = search_params.get('k', getattr(self, 'k', 5)) + search_params['k'] = k + try: if self.search_type == "similarity": docs = self.vectorstore.similarity_search(query, **search_params) + # 确保返回前k个文档 + docs = docs[:k] elif self.search_type == "similarity_score_threshold": docs_and_similarities = ( @@ -152,11 +158,15 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]: ) ) docs = [doc for doc, _ in docs_and_similarities] + # 确保返回前k个文档 + docs = docs[:k] elif self.search_type == "mmr": docs = self.vectorstore.max_marginal_relevance_search( query, **search_params ) + # 确保返回前k个文档 + docs = docs[:k] else: msg = f"不支持的搜索类型: {self.search_type}" diff --git a/rag_factory/Retrieval/Retriever/registry.py b/rag_factory/Retrieval/Retriever/registry.py new file mode 100644 index 0000000..8558c76 --- /dev/null +++ b/rag_factory/Retrieval/Retriever/registry.py @@ -0,0 +1,118 @@ +from typing import Dict, Type, Any, Optional, List +import logging +from .Retriever_BM25 import BM25Retriever +from .Retriever_MultiPath import MultiPathRetriever +from .Retriever_VectorStore import VectorStoreRetriever +from ..RetrieverBase import BaseRetriever + +logger = logging.getLogger(__name__) + +# TODO 统一所有检索器的调用方式 + +class RetrieverRegistry: + """检索器注册表,用于管理和创建不同类型的检索器""" + + _retrievers: Dict[str, Type[BaseRetriever]] = {} + + @classmethod + def register(cls, name: str, retriever_class: Type[BaseRetriever]): + """ + 注册检索器类 + + Args: + name: 检索器名称 + retriever_class: 检索器类 + + Raises: + ValueError: 当检索器类不是BaseRetriever的子类时 + TypeError: 当name不是字符串时 + """ + if not isinstance(name, str): + raise TypeError("检索器名称必须是字符串类型") + + if not issubclass(retriever_class, BaseRetriever): + raise ValueError(f"检索器类 {retriever_class} 必须继承自 BaseRetriever") + + if name in cls._retrievers: + logger.warning(f"检索器 '{name}' 
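# A minimal sketch, assuming the registry entries added below, of wiring the
# MultiPathRetriever defined above: two small BM25 retrievers stand in for the
# usual BM25 + vector-store pair (see examples/TCL_rag/rag_flow.py), and the
# fused documents carry their RRF score and rank in Document.metadata. The
# whitespace-separated sample texts are illustrative only.
from rag_factory.Retrieval import Document, RetrieverRegistry

corpus_a = [Document(content="毛细管 设计 规范 按照 国家 标准 执行", metadata={})]
corpus_b = [Document(content="BM25 与 向量 检索 可以 通过 RRF 融合", metadata={})]

bm25_a = RetrieverRegistry.create("bm25").from_documents(documents=corpus_a, k=3)
bm25_b = RetrieverRegistry.create("bm25").from_documents(documents=corpus_b, k=3)

multi = RetrieverRegistry.create("multipath", retrievers=[bm25_a, bm25_b])
for doc in multi.invoke("毛细管 设计 规范", top_k=2):
    print(doc.metadata["rank"], round(doc.metadata["score"], 5), doc.content)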
已存在,将被覆盖") + + cls._retrievers[name] = retriever_class + logger.info(f"检索器 '{name}' 注册成功") + + @classmethod + def create(cls, name: str, **kwargs) -> BaseRetriever: + """ + 创建检索器实例 + + Args: + name: 检索器名称 + **kwargs: 传递给检索器构造函数的参数 + + Returns: + BaseRetriever: 检索器实例 + + Raises: + ValueError: 当检索器未注册时 + Exception: 当检索器创建失败时 + """ + if name not in cls._retrievers: + available = ', '.join(cls.list_available()) + raise ValueError(f"检索器 '{name}' 未注册。可用的检索器: {available}") + + retriever_class = cls._retrievers[name] + try: + return retriever_class(**kwargs) + except Exception as e: + logger.error(f"创建检索器 '{name}' 失败: {str(e)}") + raise Exception(f"无法创建检索器 '{name}': {str(e)}") from e + + @classmethod + def list_available(cls) -> List[str]: + """ + 获取所有可用的检索器名称 + + Returns: + List[str]: 检索器名称列表 + """ + return list(cls._retrievers.keys()) + + @classmethod + def unregister(cls, name: str) -> bool: + """ + 取消注册检索器 + + Args: + name: 检索器名称 + + Returns: + bool: 是否成功取消注册 + """ + if name in cls._retrievers: + del cls._retrievers[name] + logger.info(f"检索器 '{name}' 取消注册成功") + return True + return False + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + 检查检索器是否已注册 + + Args: + name: 检索器名称 + + Returns: + bool: 是否已注册 + """ + return name in cls._retrievers + + @classmethod + def clear_all(cls): + """清除所有已注册的检索器""" + cls._retrievers.clear() + logger.info("所有检索器注册已清除") + +# 预注册内置检索器 +RetrieverRegistry.register("bm25", BM25Retriever) +RetrieverRegistry.register("multipath", MultiPathRetriever) +RetrieverRegistry.register("vectorstore", VectorStoreRetriever) diff --git a/rag_factory/Retrieval/Retriever/test_bm25_retriever.py b/rag_factory/Retrieval/Retriever/test_bm25_retriever.py new file mode 100644 index 0000000..41b54d9 --- /dev/null +++ b/rag_factory/Retrieval/Retriever/test_bm25_retriever.py @@ -0,0 +1,64 @@ +from Retriever_BM25 import BM25Retriever +from rag_factory.Retrieval.RetrieverBase import Document +from typing import List +import logging +logger = logging.getLogger(__name__) + +def chinese_preprocessing_func(text: str) -> List[str]: + """中文文本预处理函数 + + Args: + text: 输入的中文文本 + + Returns: + 分词后的词语列表 + """ + try: + import jieba + return list(jieba.cut(text)) + except ImportError: + logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba") + return text.split() + + +if __name__ == "__main__": + # 构造测试数据 + texts = [ + "这是第一个测试文档", + "这是第二个测试文档,内容稍有不同", + "这是第三个文档,讨论不同的主题", + "第四个文档包含更多详细信息", + "最后一个文档作为总结" + ] + + print(chinese_preprocessing_func("这是第一个测试文档")) + metadatas = [ + {"source": "doc1", "type": "test"}, + {"source": "doc2", "type": "test"}, + {"source": "doc3", "type": "example"}, + {"source": "doc4", "type": "detailed"}, + {"source": "doc5", "type": "summary"}, + ] + ids = [f"doc{i+1}" for i in range(len(texts))] + + # 创建BM25Retriever + retriever = BM25Retriever.from_texts( + texts=texts, + metadatas=metadatas, + ids=ids, + preprocess_func=chinese_preprocessing_func, + k=3 # top3 + ) + + # 查询 + query = "第二个测试文档内容" + print(f"\n查询: {query}") + + # retriever.update_k(4) + results = retriever.invoke(query, k=4) + print("\n召回结果:") + for i, (doc) in enumerate(results): + # print(f"{i+1}. 
分数: {score:.4f}") + print(f" ID: {doc.id}") + print(f" 内容: {doc.content}") + print(f" 元数据: {doc.metadata}") diff --git a/rag_factory/Retrieval/__init__.py b/rag_factory/Retrieval/__init__.py index 85c684c..445771a 100644 --- a/rag_factory/Retrieval/__init__.py +++ b/rag_factory/Retrieval/__init__.py @@ -1,4 +1,5 @@ from .RetrieverBase import BaseRetriever, Document from .Retriever.Retriever_VectorStore import VectorStoreRetriever +from .Retriever.registry import RetrieverRegistry -__all__ = ["BaseRetriever", "Document", "VectorStoreRetriever"] \ No newline at end of file +__all__ = ["BaseRetriever", "Document", "VectorStoreRetriever", "RetrieverRegistry"] \ No newline at end of file diff --git a/rag_factory/Retrieval/utils/Fusion.py b/rag_factory/Retrieval/utils/Fusion.py new file mode 100644 index 0000000..c876590 --- /dev/null +++ b/rag_factory/Retrieval/utils/Fusion.py @@ -0,0 +1,76 @@ +from typing import List +from dataclasses import dataclass +from abc import ABC, abstractmethod +from collections import defaultdict + +from ..RetrieverBase import Document + + +@dataclass +class RetrievalResult: + """检索结果的数据类""" + document: Document + score: float + rank: int = 0 + + +class FusionMethod(ABC): + """融合方法的抽象基类""" + + @abstractmethod + def fuse(self, results: List[List[RetrievalResult]], top_k: int) -> List[RetrievalResult]: + """ + 融合多个检索器的结果 + + Args: + results: 每个检索器的结果列表 + top_k: 返回的最终结果数量 + + Returns: + 融合后的结果列表 + """ + pass + + +class RRFusion(FusionMethod): + """Reciprocal Rank Fusion (RRF) 方法""" + + def __init__(self, k: float = 60.0): + """ + Args: + k: RRF中的常数,默认为60.0 + """ + self.k = k + + def fuse(self, results: List[List[RetrievalResult]], top_k: int) -> List[RetrievalResult]: + # 为每个结果分配rank + for retriever_results in results: + for i, result in enumerate(retriever_results): + result.rank = i + 1 + + # 计算RRF分数 + rrf_scores = defaultdict(float) + document_map = {} + + for retriever_results in results: + for result in retriever_results: + rrf_score = 1.0 / (self.k + result.rank) + # 使用文档内容作为key来去重 + content_key = result.document.content + rrf_scores[content_key] += rrf_score + document_map[content_key] = result.document + + # 按RRF分数排序 + sorted_items = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) + + # 构建最终结果 + fused_results = [] + for i, (content, rrf_score) in enumerate(sorted_items[:top_k]): + result = RetrievalResult( + document=document_map[content], + score=rrf_score, + rank=i + 1 + ) + fused_results.append(result) + + return fused_results diff --git a/rag_factory/Store/VectorStore/VectorStoreBase.py b/rag_factory/Store/VectorStore/VectorStoreBase.py index f12931d..9027d40 100644 --- a/rag_factory/Store/VectorStore/VectorStoreBase.py +++ b/rag_factory/Store/VectorStore/VectorStoreBase.py @@ -15,6 +15,7 @@ Sequence, Iterable, Iterator, + Tuple, ) from itertools import cycle import asyncio @@ -23,6 +24,7 @@ if TYPE_CHECKING: from collections.abc import Collection from rag_factory.Embed import Embeddings + from ...Retrieval.Retriever.Retriever_VectorStore import VectorStoreRetriever logger = logging.getLogger(__name__) @@ -293,7 +295,7 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]: def similarity_search_with_score( self, *args: Any, **kwargs: Any - ) -> list[tuple[Document, float]]: + ) -> list[Tuple[Document, float]]: """使用距离运行相似性搜索 Args: @@ -307,7 +309,7 @@ def similarity_search_with_score( async def asimilarity_search_with_score( self, *args: Any, **kwargs: Any - ) -> list[tuple[Document, float]]: + ) -> list[Tuple[Document, float]]: 
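# A worked example of the Reciprocal Rank Fusion implemented by RRFusion above:
# every ranking a document appears in contributes 1 / (k + rank), with k = 60
# by default, and duplicates are merged by document content. The import assumes
# rag_factory/Retrieval/utils/ is importable exactly as laid out in this patch;
# the documents and scores are made up for illustration.
from rag_factory.Retrieval import Document
from rag_factory.Retrieval.utils.Fusion import RetrievalResult, RRFusion

doc_a = Document(content="doc A", metadata={})
doc_b = Document(content="doc B", metadata={})
doc_c = Document(content="doc C", metadata={})

# Two retrievers rank an overlapping candidate set differently.
run_1 = [RetrievalResult(document=doc_a, score=3.2), RetrievalResult(document=doc_b, score=1.1)]
run_2 = [RetrievalResult(document=doc_b, score=0.9), RetrievalResult(document=doc_c, score=0.5)]

for result in RRFusion(k=60.0).fuse([run_1, run_2], top_k=3):
    print(result.rank, round(result.score, 5), result.document.content)
# doc B is in both lists: 1/(60+2) + 1/(60+1) ≈ 0.03252, so it outranks
# doc A and doc C, which score 1/(60+1) ≈ 0.01639 and 1/(60+2) ≈ 0.01613.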
"""异步使用距离运行相似性搜索""" return await asyncio.get_event_loop().run_in_executor( ThreadPoolExecutor(), self.similarity_search_with_score, *args, **kwargs @@ -318,7 +320,7 @@ def _similarity_search_with_relevance_scores( query: str, k: int = 4, **kwargs: Any, - ) -> list[tuple[Document, float]]: + ) -> list[Tuple[Document, float]]: """默认的带相关性分数的相似性搜索 必要时在子类中修改。 @@ -344,7 +346,7 @@ async def _asimilarity_search_with_relevance_scores( query: str, k: int = 4, **kwargs: Any, - ) -> list[tuple[Document, float]]: + ) -> list[Tuple[Document, float]]: """异步带相关性分数的相似性搜索""" relevance_score_fn = self._select_relevance_score_fn() docs_and_scores = await self.asimilarity_search_with_score(query, k, **kwargs) @@ -355,7 +357,7 @@ def similarity_search_with_relevance_scores( query: str, k: int = 4, **kwargs: Any, - ) -> list[tuple[Document, float]]: + ) -> list[Tuple[Document, float]]: """返回[0, 1]范围内的文档和相关性分数 0表示不相似,1表示最相似。 @@ -402,7 +404,7 @@ async def asimilarity_search_with_relevance_scores( query: str, k: int = 4, **kwargs: Any, - ) -> list[tuple[Document, float]]: + ) -> list[Tuple[Document, float]]: """异步返回[0, 1]范围内的文档和相关性分数""" score_threshold = kwargs.pop("score_threshold", None) @@ -627,11 +629,11 @@ def _get_retriever_tags(self) -> list[str]: def as_retriever(self, **kwargs: Any) -> "VectorStoreRetriever": """从此VectorStore返回初始化的VectorStoreRetriever""" - from Retrieval import VectorStoreRetriever + # 延迟导入以避免循环依赖 + from ...Retrieval.Retriever.Retriever_VectorStore import VectorStoreRetriever tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags() return VectorStoreRetriever(vectorstore=self, tags=tags, **kwargs) - diff --git a/rag_factory/Store/VectorStore/VectorStore_Faiss.py b/rag_factory/Store/VectorStore/VectorStore_Faiss.py index 0e38e37..2f2f55d 100644 --- a/rag_factory/Store/VectorStore/VectorStore_Faiss.py +++ b/rag_factory/Store/VectorStore/VectorStore_Faiss.py @@ -4,20 +4,21 @@ import os import uuid import numpy as np -from typing import Any, Optional, Callable +from typing import Any, Optional, Callable, List, Tuple from .VectorStoreBase import VectorStore, Document -from Embed import Embeddings +from ...Embed.Embedding_Base import Embeddings import asyncio from concurrent.futures import ThreadPoolExecutor +# TODO 需要支持GPU,提高速度 def _mmr_select( - docs_and_scores: list[tuple[Document, float]], - embeddings: list[list[float]], - query_embedding: list[float], + docs_and_scores: List[Tuple[Document, float]], + embeddings: List[List[float]], + query_embedding: List[float], k: int, lambda_mult: float = 0.5, -) -> list[Document]: +) -> List[Document]: """最大边际相关性选择算法""" if k >= len(docs_and_scores): return [doc for doc, _ in docs_and_scores] @@ -153,12 +154,12 @@ def _normalize_vectors(self, vectors: np.ndarray) -> np.ndarray: def add_texts( self, - texts: list[str], - metadatas: Optional[list[dict]] = None, + texts: List[str], + metadatas: Optional[List[dict]] = None, *, - ids: Optional[list[str]] = None, + ids: Optional[List[str]] = None, **kwargs: Any, - ) -> list[str]: + ) -> List[str]: """添加文本到向量存储""" if not texts: return [] @@ -209,12 +210,12 @@ def add_texts( async def aadd_texts( self, - texts: list[str], - metadatas: Optional[list[dict]] = None, + texts: List[str], + metadatas: Optional[List[dict]] = None, *, - ids: Optional[list[str]] = None, + ids: Optional[List[str]] = None, **kwargs: Any, - ) -> list[str]: + ) -> List[str]: """异步添加文本""" return await asyncio.get_event_loop().run_in_executor( ThreadPoolExecutor(), self.add_texts, texts, metadatas, ids, **kwargs @@ -222,14 +223,14 
@@ async def aadd_texts( def similarity_search( self, query: str, k: int = 4, **kwargs: Any - ) -> list[Document]: + ) -> List[Document]: """相似性搜索""" docs_and_scores = self.similarity_search_with_score(query, k, **kwargs) return [doc for doc, _ in docs_and_scores] def similarity_search_with_score( self, query: str, k: int = 4, **kwargs: Any - ) -> list[tuple[Document, float]]: + ) -> List[Tuple[Document, float]]: """带分数的相似性搜索""" if self.index is None or self.index.ntotal == 0: return [] @@ -239,15 +240,15 @@ def similarity_search_with_score( return self.similarity_search_by_vector_with_score(query_embedding, k, **kwargs) def similarity_search_by_vector( - self, embedding: list[float], k: int = 4, **kwargs: Any - ) -> list[Document]: + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Document]: """根据向量相似性搜索""" docs_and_scores = self.similarity_search_by_vector_with_score(embedding, k, **kwargs) return [doc for doc, _ in docs_and_scores] def similarity_search_by_vector_with_score( - self, embedding: list[float], k: int = 4, **kwargs: Any - ) -> list[tuple[Document, float]]: + self, embedding: List[float], k: int = 4, **kwargs: Any + ) -> List[Tuple[Document, float]]: """根据向量带分数的相似性搜索""" if self.index is None or self.index.ntotal == 0: return [] @@ -278,7 +279,7 @@ def max_marginal_relevance_search( fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> list[Document]: + ) -> List[Document]: """最大边际相关性搜索""" if self.index is None or self.index.ntotal == 0: return [] @@ -324,12 +325,12 @@ def max_marginal_relevance_search( def max_marginal_relevance_search_by_vector( self, - embedding: list[float], + embedding: List[float], k: int = 4, fetch_k: int = 20, lambda_mult: float = 0.5, **kwargs: Any, - ) -> list[Document]: + ) -> List[Document]: """根据向量的最大边际相关性搜索""" if self.index is None or self.index.ntotal == 0: return [] @@ -369,7 +370,7 @@ def max_marginal_relevance_search_by_vector( return selected_docs - def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]: + def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]: """删除文档(FAISS不支持直接删除,需要重建索引)""" if ids is None: # 删除所有 @@ -412,7 +413,7 @@ def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[boo return True - def get_by_ids(self, ids: list[str]) -> list[Document]: + def get_by_ids(self, ids: List[str]) -> List[Document]: """根据ID获取文档""" return [self.docstore[doc_id] for doc_id in ids if doc_id in self.docstore] @@ -482,11 +483,11 @@ def load_local( @classmethod def from_texts( cls, - texts: list[str], + texts: List[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: Optional[List[dict]] = None, *, - ids: Optional[list[str]] = None, + ids: Optional[List[str]] = None, **kwargs: Any, ) -> "FaissVectorStore": """从文本创建FAISS向量存储""" @@ -497,11 +498,11 @@ def from_texts( @classmethod async def afrom_texts( cls, - texts: list[str], + texts: List[str], embedding: Embeddings, - metadatas: Optional[list[dict]] = None, + metadatas: Optional[List[dict]] = None, *, - ids: Optional[list[str]] = None, + ids: Optional[List[str]] = None, **kwargs: Any, ) -> "FaissVectorStore": """异步从文本创建FAISS向量存储""" diff --git a/rag_factory/Store/VectorStore/registry.py b/rag_factory/Store/VectorStore/registry.py index 611b690..ef7f41f 100644 --- a/rag_factory/Store/VectorStore/registry.py +++ b/rag_factory/Store/VectorStore/registry.py @@ -1,7 +1,7 @@ # VectorStore/registry.py -from typing import Dict, Type, Any, Optional +from typing 
import Dict, Type, Any, Optional, List from .VectorStoreBase import VectorStore -from Embed.Embedding_Base import Embeddings +from ...Embed.Embedding_Base import Embeddings from .VectorStore_Faiss import FaissVectorStore @@ -22,9 +22,21 @@ def create(cls, name: str, embedding: Embeddings, **kwargs) -> VectorStore: raise ValueError(f"未注册的向量存储类型: {name}") return cls._stores[name](embedding=embedding, **kwargs) + + @classmethod + def load(cls, name: str, folder_path: str, embedding: Embeddings, **kwargs) -> VectorStore: + """加载已保存的向量存储实例""" + if name not in cls._stores: + raise ValueError(f"未注册的向量存储类型: {name}") + store_class = cls._stores[name] + if hasattr(store_class, "load_local"): + return store_class.load_local(folder_path, embeddings=embedding, **kwargs) + else: + raise ValueError(f"{store_class.__name__} 不支持从本地加载") + @classmethod - def list_available(cls) -> list[str]: + def list_available(cls) -> List[str]: """列出可用的向量存储类型""" return list(cls._stores.keys()) diff --git a/rag_factory/Store/VectorStore/test.py b/rag_factory/Store/VectorStore/test.py new file mode 100644 index 0000000..26a839b --- /dev/null +++ b/rag_factory/Store/VectorStore/test.py @@ -0,0 +1,57 @@ +# test_faiss_vectorstore.py +import numpy as np +from rag_factory.Store.VectorStore.VectorStore_Faiss import FaissVectorStore, Document +from rag_factory.Embed import EmbeddingRegistry +from rag_factory.Store import VectorStoreRegistry +from rag_factory.Retrieval import RetrieverRegistry + +if __name__ == "__main__": + # 初始化 + embeddings = EmbeddingRegistry.create(name="huggingface", model_name="/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B") + # vs = VectorStoreRegistry.load(name="faiss", folder_path="/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/faiss_test_store", embedding=embeddings) + # print(vs) + # vs = VectorStoreRegistry.create(name="faiss", embedding=embeddings) + vs = FaissVectorStore.load_local(folder_path="/data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store", embeddings=embeddings, index_name="index") + # # # 添加一些文本 + # texts = ["苹果是一种水果", "香蕉是黄色的", "猫是一种动物", "狗喜欢跑步"] + # # ids = vs.add_texts(texts, metadatas=[{"type": "test"} for _ in texts]) + + # documents = [Document(content=text, metadata={"type": "test"}) for text in texts] + # vectorstore = vs.from_documents(documents, embedding=embeddings) + # vectorstore.save_local(folder_path="/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/faiss_test_store") + # print(f"添加的文档ID: {ids}") + + # # 相似性搜索 + # results = vs.similarity_search("苹果", k=2) + # print("\n=== 相似性搜索结果 ===") + # for doc in results: + # print(f"内容: {doc.content}, 元数据: {doc.metadata}") + + # # 带分数的搜索 + # results_with_score = vs.similarity_search_with_score("苹果", k=2) + # print("\n=== 带分数的搜索结果 ===") + # for doc, score in results_with_score: + # print(f"内容: {doc.content}, 分数: {score}") + + # # 最大边际相关性搜索 + # mmr_results = vs.max_marginal_relevance_search("苹果", k=2, fetch_k=3) + # print("\n=== MMR搜索结果 ===") + # for doc in mmr_results: + # print(f"内容: {doc.content}") + + # # 保存到本地 + # save_path = "./faiss_test_store" + # vs.save_local(save_path) + # print(f"\n索引已保存到: {save_path}") + + # 从本地加载 + # loaded_vs = FaissVectorStore.load_local(save_path, embeddings) + # load_results = loaded_vs.similarity_search("苹果", k=2) + # print("\n=== 从本地加载后的搜索结果 ===") + # for doc in load_results: + # print(f"内容: {doc.content}") + + retriever = vs.as_retriever(search_kwargs={"k": 2}) + # retriever = RetrieverRegistry.create(name="vectorstore", vectorstore=vs) + results = 
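# A minimal sketch, assuming the FaissVectorStore and registry code in this
# patch, of the new VectorStoreRegistry.load() path: build a store, persist it
# with save_local(), then reload it by name. DummyEmbeddings, the sample text
# and the /tmp path are placeholders so the round trip runs without a GPU
# model; swap in the HuggingFace embedding from the configs above for real use.
from typing import List

from rag_factory.Embed import Embeddings
from rag_factory.Retrieval import Document
from rag_factory.Store import VectorStoreRegistry


class DummyEmbeddings(Embeddings):
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        return [self.embed_query(t) for t in texts]

    def embed_query(self, text: str) -> List[float]:
        # Fixed 4-dimensional vectors keep the FAISS index dimension consistent.
        return [float(len(text) % 7), float(len(text) % 5), 1.0, 0.0]


emb = DummyEmbeddings()
store = VectorStoreRegistry.create("faiss", embedding=emb)
store = store.from_documents([Document(content="毛细管设计规范", metadata={})], embedding=emb)
store.save_local("/tmp/faiss_demo_store")

reloaded = VectorStoreRegistry.load("faiss", folder_path="/tmp/faiss_demo_store", embedding=emb)
print(reloaded.similarity_search("毛细管", k=1))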
retriever.invoke("文件名称:GB$T 2828.1-2012 计数抽样检验程序 第1部分:按接收质量限(AQL)检索的逐批检验抽样计划.pdf\n4 不合格的表示 4. 1 总则\n不合格的程度以不合格品百分数(见 3.1.8和 3.1.9)或每百单位产品不合格数(见 3.1.10和3.1.11)表示。表7、表8和表10 是基于假定不合格的出现是随机且统计独立的。如果已知产品的某个不合格可能由某一条件引起,此条件还可能引起其他一些不合格,则应仅考虑该产品是否为合格品,而不管该产品有多少个不合格。") + print(results) diff --git a/rag_factory/Store/__init__.py b/rag_factory/Store/__init__.py index a2b1a3a..895b140 100644 --- a/rag_factory/Store/__init__.py +++ b/rag_factory/Store/__init__.py @@ -1,5 +1,7 @@ from .VectorStore.registry import VectorStoreRegistry +from .VectorStore.VectorStore_Faiss import FaissVectorStore __all__ = [ "VectorStoreRegistry", + "FaissVectorStore", ] \ No newline at end of file diff --git a/rag_factory/llms/__init__.py b/rag_factory/llms/__init__.py index 33cf556..2476ec8 100644 --- a/rag_factory/llms/__init__.py +++ b/rag_factory/llms/__init__.py @@ -1,5 +1,9 @@ from .openai_compatible import OpenAICompatible from .dashscope.base import DashScope, DashScopeGenerationModels +from .openai_llm import OpenAILLM +from .registry import LLMRegistry __all__ = ['OpenAICompatible', - "DashScope", "DashScopeGenerationModels"] + "DashScope", "DashScopeGenerationModels", + "OpenAILLM", + "LLMRegistry"] diff --git a/rag_factory/llms/llm_base.py b/rag_factory/llms/llm_base.py new file mode 100644 index 0000000..78e4475 --- /dev/null +++ b/rag_factory/llms/llm_base.py @@ -0,0 +1,142 @@ +from abc import ABC, abstractmethod +from typing import Dict, Any, List, Optional, Union +import logging + +logger = logging.getLogger(__name__) + + +class LLMBase(ABC): + """ + 大语言模型基类,定义了所有LLM实现必须遵循的接口 + """ + + def __init__(self, model_name: str, **kwargs): + """ + 初始化LLM基类 + + Args: + model_name: 模型名称 + **kwargs: 其他配置参数 + """ + self.model_name = model_name + self.config = kwargs + self._setup_logging() + + def _setup_logging(self): + """设置日志配置""" + self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}") + + + @abstractmethod + def chat( + self, + messages: List[Dict[str, str]], + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + **kwargs + ) -> str: + """ + 对话式生成 + + Args: + messages: 对话消息列表,格式如[{"role": "user", "content": "问题"}] + max_tokens: 最大生成token数 + temperature: 生成温度参数 + **kwargs: 其他生成参数 + + Returns: + 生成的回复文本 + """ + pass + + @abstractmethod + def stream_chat( + self, + messages: List[Dict[str, str]], + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + **kwargs + ): + """ + 流式对话生成 + """ + pass + + @abstractmethod + def embed(self, texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]: + """ + 文本嵌入生成 + + Args: + texts: 单个文本或文本列表 + + Returns: + 嵌入向量或嵌入向量列表 + """ + pass + + def get_model_info(self) -> Dict[str, Any]: + """ + 获取模型信息 + + Returns: + 包含模型名称和配置的字典 + """ + return { + "model_name": self.model_name, + "config": self.config, + "class_name": self.__class__.__name__ + } + + def validate_input(self, input_text: str, max_length: Optional[int] = None) -> bool: + """ + 验证输入文本 + + Args: + input_text: 输入文本 + max_length: 最大长度限制 + + Returns: + 是否验证通过 + """ + if not isinstance(input_text, str): + self.logger.error("输入必须是字符串类型") + return False + + if not input_text.strip(): + self.logger.error("输入文本不能为空") + return False + + if max_length and len(input_text) > max_length: + self.logger.error(f"输入文本长度超过限制: {len(input_text)} > {max_length}") + return False + + return True + + def format_messages(self, user_message: str, system_message: Optional[str] = None) -> List[Dict[str, str]]: + """ + 格式化对话消息 + + Args: + user_message: 用户消息 + system_message: 
系统消息(可选) + + Returns: + 格式化后的消息列表 + """ + messages = [] + + if system_message: + messages.append({"role": "system", "content": system_message}) + + messages.append({"role": "user", "content": user_message}) + + return messages + + def __str__(self) -> str: + """字符串表示""" + return f"{self.__class__.__name__}(model_name='{self.model_name}')" + + def __repr__(self) -> str: + """详细字符串表示""" + return f"{self.__class__.__name__}(model_name='{self.model_name}', config={self.config})" diff --git a/rag_factory/llms/openai_llm.py b/rag_factory/llms/openai_llm.py new file mode 100644 index 0000000..6069155 --- /dev/null +++ b/rag_factory/llms/openai_llm.py @@ -0,0 +1,264 @@ +import openai +from typing import Dict, Any, List, Optional, Union, Tuple +from .llm_base import LLMBase + +class OpenAILLM(LLMBase): + """ + OpenAI LLM对话模型 + """ + + def __init__( + self, + model_name: str = "gpt-4o-mini", + api_key: Optional[str] = None, + base_url: Optional[str] = None, + organization: Optional[str] = None, + max_retries: int = 3, + timeout: float = 60.0, + **kwargs + ): + """ + 初始化OpenAI LLM + + Args: + model_name: 模型名称,如 gpt-3.5-turbo, gpt-4 等 + api_key: OpenAI API密钥 + base_url: API基础URL + organization: 组织ID(可选) + max_retries: 最大重试次数 + timeout: 请求超时时间 + **kwargs: 其他配置参数 + """ + super().__init__(model_name, **kwargs) + + # 初始化OpenAI客户端 + self.client = openai.OpenAI( + api_key=api_key, + base_url=base_url, + organization=organization, + max_retries=max_retries, + timeout=timeout + ) + + # 默认参数 + self.default_max_tokens = kwargs.get('max_tokens', 2000) + self.default_temperature = kwargs.get('temperature', 0.7) + + self.logger.info(f"OpenAI LLM初始化完成,模型: {model_name}") + + + def chat( + self, + messages: List[Dict[str, str]], + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + return_token_count: bool = False, + **kwargs + ) -> Union[str, Tuple[str, Dict[str, int]]]: + """ + 对话式生成 + + Args: + messages: 对话消息列表,格式如[{"role": "user", "content": "问题"}] + max_tokens: 最大生成token数 + temperature: 生成温度参数 + return_token_count: 是否返回token统计信息 + **kwargs: 其他生成参数 + + Returns: + 生成的回复文本,如果return_token_count为True则返回(文本, token统计) + """ + if not messages or not isinstance(messages, list): + raise ValueError("消息列表不能为空且必须是列表格式") + + # 验证消息格式 + for msg in messages: + if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: + raise ValueError("消息格式错误,必须包含role和content字段") + if not self.validate_input(msg['content']): + raise ValueError(f"消息内容验证失败: {msg['content']}") + + try: + response = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=max_tokens or self.default_max_tokens, + temperature=temperature or self.default_temperature, + **kwargs + ) + + result = response.choices[0].message.content.strip() + + if return_token_count: + # 获取输出token数 + input_tokens = response.usage.prompt_tokens if response.usage else 0 + output_tokens = response.usage.completion_tokens if response.usage else 0 + total_tokens = response.usage.total_tokens if response.usage else 0 + + token_stats = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens + } + + self.logger.debug(f"对话生成成功,长度: {len(result)}, 输入tokens: {input_tokens}, 输出tokens: {output_tokens}") + return result, token_stats + else: + self.logger.debug(f"对话生成成功,长度: {len(result)}") + return result + + except Exception as e: + self.logger.error(f"对话生成失败: {str(e)}") + raise + + + def stream_chat( + self, + messages: List[Dict[str, str]], + max_tokens: Optional[int] = None, + 
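# A minimal sketch, assuming the LLMBase interface defined in llm_base.py
# above: a trivial echo backend implements the three abstract methods (chat,
# stream_chat, embed) and is registered under the hypothetical name "echo",
# alongside the "openai" entry that llms/registry.py pre-registers further
# below. EchoLLM and "echo-v0" are illustrative names only.
from typing import Dict, List, Optional, Union

from rag_factory.llms import LLMRegistry
from rag_factory.llms.llm_base import LLMBase


class EchoLLM(LLMBase):
    def chat(self, messages: List[Dict[str, str]],
             max_tokens: Optional[int] = None,
             temperature: Optional[float] = None, **kwargs) -> str:
        # Echo the most recent user message instead of calling a real model.
        return messages[-1]["content"]

    def stream_chat(self, messages: List[Dict[str, str]],
                    max_tokens: Optional[int] = None,
                    temperature: Optional[float] = None, **kwargs):
        yield self.chat(messages)

    def embed(self, texts: Union[str, List[str]]):
        items = [texts] if isinstance(texts, str) else texts
        vectors = [[float(len(t))] for t in items]
        return vectors[0] if isinstance(texts, str) else vectors


LLMRegistry.register("echo", EchoLLM)
llm = LLMRegistry.create("echo", model_name="echo-v0")
print(llm.chat(llm.format_messages("你好", system_message="你是一位工业领域的专家。")))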
temperature: Optional[float] = None, + return_token_count: bool = False, + **kwargs + ): + """ + 流式对话生成 + + Args: + messages: 对话消息列表 + max_tokens: 最大生成token数 + temperature: 生成温度参数 + return_token_count: 是否返回token统计信息 + **kwargs: 其他生成参数 + + Yields: + 生成的文本片段,如果return_token_count为True则在流式输出结束后yield token统计 + """ + if not messages or not isinstance(messages, list): + raise ValueError("消息列表不能为空且必须是列表格式") + + # 验证消息格式 + for msg in messages: + if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg: + raise ValueError("消息格式错误,必须包含role和content字段") + if not self.validate_input(msg['content']): + raise ValueError(f"消息内容验证失败: {msg['content']}") + + try: + stream = self.client.chat.completions.create( + model=self.model_name, + messages=messages, + max_tokens=max_tokens or self.default_max_tokens, + temperature=temperature or self.default_temperature, + stream=True, + stream_options={"include_usage": True} if return_token_count else None, + **kwargs + ) + + full_response = "" + + for chunk in stream: + # 检查choices是否存在以及是否有内容 + if chunk.choices and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + if hasattr(delta, 'content') and delta.content is not None: + content = delta.content + full_response += content + yield content + + # 检查是否有usage信息(在流的最后一个chunk中) + if return_token_count and hasattr(chunk, 'usage') and chunk.usage is not None: + input_tokens = chunk.usage.prompt_tokens if chunk.usage else 0 + output_tokens = chunk.usage.completion_tokens if chunk.usage else 0 + total_tokens = chunk.usage.total_tokens if chunk.usage else 0 + + token_stats = { + "input_tokens": input_tokens, + "output_tokens": output_tokens, + "total_tokens": total_tokens + } + + self.logger.debug(f"流式对话生成成功,长度: {len(full_response)}, 输入tokens: {input_tokens}, 输出tokens: {output_tokens}") + yield token_stats + + except Exception as e: + self.logger.error(f"流式对话生成失败: {str(e)}") + raise + + + def embed(self, texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]: + """ + 文本嵌入生成 + + Args: + texts: 单个文本或文本列表 + + Returns: + 嵌入向量或嵌入向量列表 + """ + # 检查当前模型是否为嵌入模型 + embedding_models = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"] + if self.model_name not in embedding_models: + warning_msg = f"警告:当前模型'{self.model_name}'不是嵌入模型,建议使用嵌入专用模型如 text-embedding-ada-002" + self.logger.warning(warning_msg) + raise ValueError(f"当前模型'{self.model_name}'不支持嵌入生成,请使用嵌入专用模型") + + # 统一处理为列表格式 + is_single = isinstance(texts, str) + text_list = [texts] if is_single else texts + + # 验证输入 + for text in text_list: + if not self.validate_input(text): + raise ValueError(f"文本内容验证失败: {text}") + + try: + # 使用当前嵌入模型生成嵌入 + response = self.client.embeddings.create( + model=self.model_name, + input=text_list + ) + + embeddings = [data.embedding for data in response.data] + + # 根据输入格式返回结果 + result = embeddings[0] if is_single else embeddings + self.logger.debug(f"嵌入生成成功,文本数量: {len(text_list)}") + return result + + except Exception as e: + self.logger.error(f"嵌入生成失败: {str(e)}") + raise + + def get_available_models(self) -> List[str]: + """ + 获取可用的模型列表 + + Returns: + 可用模型名称列表 + """ + try: + models = self.client.models.list() + model_names = [model.id for model in models.data] + self.logger.debug(f"获取到{len(model_names)}个可用模型") + return model_names + except Exception as e: + self.logger.error(f"获取模型列表失败: {str(e)}") + return [] + + def get_model_info(self) -> Dict[str, Any]: + """ + 获取模型信息 + + Returns: + 包含模型名称和配置的字典 + """ + info = super().get_model_info() + info.update({ + "api_base": 
getattr(self.client, 'base_url', None), + "organization": getattr(self.client, 'organization', None), + "max_retries": getattr(self.client, 'max_retries', None), + "timeout": getattr(self.client, 'timeout', None), + "default_max_tokens": self.default_max_tokens, + "default_temperature": self.default_temperature + }) + return info diff --git a/rag_factory/llms/registry.py b/rag_factory/llms/registry.py new file mode 100644 index 0000000..249f03d --- /dev/null +++ b/rag_factory/llms/registry.py @@ -0,0 +1,86 @@ +from .openai_llm import OpenAILLM +from .llm_base import LLMBase +from typing import Dict, Type, Any, Optional, List +import logging + +logging.basicConfig( + level=logging.INFO, # 设置最低输出级别为 INFO + format="%(asctime)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + +class LLMRegistry: + """LLM模型注册表,用于管理和创建不同类型的LLM模型""" + _llms: Dict[str, Type[LLMBase]] = {} + + @classmethod + def register(cls, name: str, llm_class: Type[LLMBase]): + """注册LLM模型类 + + Args: + name: 模型名称 + llm_class: LLM模型类 + """ + cls._llms[name] = llm_class + + @classmethod + def create(cls, name: str, **kwargs) -> LLMBase: + """创建LLM实例 + + Args: + name: 模型名称 + **kwargs: 模型初始化参数 + + Returns: + LLM实例 + + Raises: + ValueError: 当模型名称不存在时 + """ + if name not in cls._llms: + available_llms = list(cls._llms.keys()) + raise ValueError(f"LLM模型 '{name}' 未注册。可用的模型: {available_llms}") + + llm_class = cls._llms[name] + + return llm_class(**kwargs) + + @classmethod + def list_llms(cls) -> List[str]: + """列出所有已注册的LLM模型名称 + + Returns: + 已注册的模型名称列表 + """ + return list(cls._llms.keys()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """检查模型是否已注册 + + Args: + name: 模型名称 + + Returns: + 如果已注册返回True,否则返回False + """ + return name in cls._llms + + @classmethod + def unregister(cls, name: str) -> bool: + """取消注册模型 + + Args: + name: 模型名称 + + Returns: + 成功取消注册返回True,模型不存在返回False + """ + if name in cls._llms: + del cls._llms[name] + return True + return False + + +# 注册默认的LLM模型 +LLMRegistry.register("openai", OpenAILLM) \ No newline at end of file diff --git a/rag_factory/llms/test.py b/rag_factory/llms/test.py new file mode 100644 index 0000000..49332ba --- /dev/null +++ b/rag_factory/llms/test.py @@ -0,0 +1,75 @@ +import os +from pprint import pprint +from .openai_llm import OpenAILLM # 你的类所在文件 + +# ==== 配置 ==== +# API_KEY = os.getenv("OPENAI_API_KEY") # 或直接写成 "sk-xxxx" +API_KEY = "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2" +# BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1") +BASE_URL = "https://api.gptsapi.net/v1" + +def test_openai_llm(): + # 初始化普通对话模型 + llm = OpenAILLM( + model_name="gpt-4o-mini", + api_key=API_KEY, + base_url=BASE_URL + ) + + # ==== 1. 测试 chat ==== + messages = [{"role": "user", "content": "你好,请用一句话介绍你自己"}] + print("\n=== chat (普通) ===") + result = llm.chat(messages) + print("回复:", result) + + # ==== 2. 测试 chat + token 统计 ==== + print("\n=== chat (返回 token 数) ===") + result, token_stats = llm.chat(messages, return_token_count=True) + print("回复:", result) + print("token统计:", token_stats) + + # ==== 3. 测试 stream_chat (普通流式) ==== + print("\n=== stream_chat (普通流式) ===") + for chunk in llm.stream_chat(messages): + print(chunk, end="", flush=True) + print() + + # ==== 4. 
测试 stream_chat (返回 token 数) ==== + print("\n=== stream_chat (返回 token 数) ===") + for item in llm.stream_chat(messages, return_token_count=True): + if isinstance(item, dict) and "input_tokens" in item: + # 这是最后的token统计信息 + print("\nToken统计:", item) + else: + # 这是文本片段 + print(item, end="", flush=True) + + + # ==== 7. 测试 get_model_info ==== + print("\n=== get_model_info ===") + pprint(llm.get_model_info()) + + # ==== 8. 测试 get_available_models ==== + print("\n=== get_available_models ===") + model_list = llm.get_available_models() + print("可用模型数量:", len(model_list)) + print("前10个模型:", model_list[:10]) + + # ==== 9. 测试 embed ==== + # 使用专门的嵌入模型 + embed_llm = OpenAILLM( + model_name="text-embedding-3-small", + api_key=API_KEY, + base_url=BASE_URL + ) + print("\n=== embed (单条) ===") + vec = embed_llm.embed("这是一个测试文本") + print("向量维度:", len(vec)) + + print("\n=== embed (多条) ===") + vecs = embed_llm.embed(["第一句", "第二句"]) + print("向量数量:", len(vecs), "每个向量维度:", len(vecs[0])) + + +if __name__ == "__main__": + test_openai_llm() diff --git a/rag_factory/parser/Base.py b/rag_factory/parser/Base.py deleted file mode 100644 index 9ce5ce9..0000000 --- a/rag_factory/parser/Base.py +++ /dev/null @@ -1,528 +0,0 @@ -# type: ignore -""" -Generic Document Parser Utility - -This module provides functionality for parsing PDF and image documents using MinerU 2.0 library, -and converts the parsing results into markdown and JSON formats - -Note: MinerU 2.0 no longer includes LibreOffice document conversion module. -For Office documents (.doc, .docx, .ppt, .pptx), please convert them to PDF format first. -""" - -from __future__ import annotations - - -import json -import argparse -import base64 -import subprocess -import tempfile -import logging -from pathlib import Path -from typing import ( - Dict, - List, - Optional, - Union, - Tuple, - Any, - TypeVar, -) - -T = TypeVar("T") - - -class Parser: - """ - Base class for document parsing utilities. - - Defines common functionality and constants for parsing different document types. - """ - - # Define common file formats - OFFICE_FORMATS = {".doc", ".docx", ".ppt", ".pptx", ".xls", ".xlsx"} - IMAGE_FORMATS = {".png", ".jpeg", ".jpg", ".bmp", ".tiff", ".tif", ".gif", ".webp"} - TEXT_FORMATS = {".txt", ".md"} - - # Class-level logger - logger = logging.getLogger(__name__) - - def __init__(self) -> None: - """Initialize the base parser.""" - pass - - @staticmethod - def convert_office_to_pdf( - doc_path: Union[str, Path], output_dir: Optional[str] = None - ) -> Path: - """ - Convert Office document (.doc, .docx, .ppt, .pptx, .xls, .xlsx) to PDF. - Requires LibreOffice to be installed. 
- - Args: - doc_path: Path to the Office document file - output_dir: Output directory for the PDF file - - Returns: - Path to the generated PDF file - """ - try: - # Convert to Path object for easier handling - doc_path = Path(doc_path) - if not doc_path.exists(): - raise FileNotFoundError(f"Office document does not exist: {doc_path}") - - name_without_suff = doc_path.stem - - # Prepare output directory - if output_dir: - base_output_dir = Path(output_dir) - else: - base_output_dir = doc_path.parent / "libreoffice_output" - - base_output_dir.mkdir(parents=True, exist_ok=True) - - # Create temporary directory for PDF conversion - with tempfile.TemporaryDirectory() as temp_dir: - temp_path = Path(temp_dir) - - # Convert to PDF using LibreOffice - logging.info(f"Converting {doc_path.name} to PDF using LibreOffice...") - - # Prepare subprocess parameters to hide console window on Windows - import platform - - # Try LibreOffice commands in order of preference - commands_to_try = ["libreoffice", "soffice"] - - conversion_successful = False - for cmd in commands_to_try: - try: - convert_cmd = [ - cmd, - "--headless", - "--convert-to", - "pdf", - "--outdir", - str(temp_path), - str(doc_path), - ] - - # Prepare conversion subprocess parameters - convert_subprocess_kwargs = { - "capture_output": True, - "text": True, - "timeout": 60, # 60 second timeout - "encoding": "utf-8", - "errors": "ignore", - } - - # Hide console window on Windows - if platform.system() == "Windows": - convert_subprocess_kwargs["creationflags"] = ( - subprocess.CREATE_NO_WINDOW - ) - - result = subprocess.run( - convert_cmd, **convert_subprocess_kwargs - ) - - if result.returncode == 0: - conversion_successful = True - logging.info( - f"Successfully converted {doc_path.name} to PDF using {cmd}" - ) - break - else: - logging.warning( - f"LibreOffice command '{cmd}' failed: {result.stderr}" - ) - except FileNotFoundError: - logging.warning(f"LibreOffice command '{cmd}' not found") - except subprocess.TimeoutExpired: - logging.warning(f"LibreOffice command '{cmd}' timed out") - except Exception as e: - logging.error( - f"LibreOffice command '{cmd}' failed with exception: {e}" - ) - - if not conversion_successful: - raise RuntimeError( - f"LibreOffice conversion failed for {doc_path.name}. " - f"Please ensure LibreOffice is installed:\n" - "- Windows: Download from https://www.libreoffice.org/download/download/\n" - "- macOS: brew install --cask libreoffice\n" - "- Ubuntu/Debian: sudo apt-get install libreoffice\n" - "- CentOS/RHEL: sudo yum install libreoffice\n" - "Alternatively, convert the document to PDF manually." - ) - - # Find the generated PDF - pdf_files = list(temp_path.glob("*.pdf")) - if not pdf_files: - raise RuntimeError( - f"PDF conversion failed for {doc_path.name} - no PDF file generated. " - f"Please check LibreOffice installation or try manual conversion." - ) - - pdf_path = pdf_files[0] - logging.info( - f"Generated PDF: {pdf_path.name} ({pdf_path.stat().st_size} bytes)" - ) - - # Validate the generated PDF - if pdf_path.stat().st_size < 100: # Very small file, likely empty - raise RuntimeError( - "Generated PDF appears to be empty or corrupted. " - "Original file may have issues or LibreOffice conversion failed." 
- ) - - # Copy PDF to final output directory - final_pdf_path = base_output_dir / f"{name_without_suff}.pdf" - import shutil - - shutil.copy2(pdf_path, final_pdf_path) - - return final_pdf_path - - except Exception as e: - logging.error(f"Error in convert_office_to_pdf: {str(e)}") - raise - - @staticmethod - def convert_text_to_pdf( - text_path: Union[str, Path], output_dir: Optional[str] = None - ) -> Path: - """ - Convert text file (.txt, .md) to PDF using ReportLab with full markdown support. - - Args: - text_path: Path to the text file - output_dir: Output directory for the PDF file - - Returns: - Path to the generated PDF file - """ - try: - text_path = Path(text_path) - if not text_path.exists(): - raise FileNotFoundError(f"Text file does not exist: {text_path}") - - # Supported text formats - supported_text_formats = {".txt", ".md"} - if text_path.suffix.lower() not in supported_text_formats: - raise ValueError(f"Unsupported text format: {text_path.suffix}") - - # Read the text content - try: - with open(text_path, "r", encoding="utf-8") as f: - text_content = f.read() - except UnicodeDecodeError: - # Try with different encodings - for encoding in ["gbk", "latin-1", "cp1252"]: - try: - with open(text_path, "r", encoding=encoding) as f: - text_content = f.read() - logging.info(f"Successfully read file with {encoding} encoding") - break - except UnicodeDecodeError: - continue - else: - raise RuntimeError( - f"Could not decode text file {text_path.name} with any supported encoding" - ) - - # Prepare output directory - if output_dir: - base_output_dir = Path(output_dir) - else: - base_output_dir = text_path.parent / "reportlab_output" - - base_output_dir.mkdir(parents=True, exist_ok=True) - pdf_path = base_output_dir / f"{text_path.stem}.pdf" - - # Convert text to PDF - logging.info(f"Converting {text_path.name} to PDF...") - - try: - from reportlab.lib.pagesizes import A4 - from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer - from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle - from reportlab.lib.units import inch - from reportlab.pdfbase import pdfmetrics - - # Create PDF document - doc = SimpleDocTemplate( - str(pdf_path), - pagesize=A4, - leftMargin=inch, - rightMargin=inch, - topMargin=inch, - bottomMargin=inch, - ) - - # Get styles - styles = getSampleStyleSheet() - normal_style = styles["Normal"] - heading_style = styles["Heading1"] - - # Try to register a font that supports Chinese characters - try: - # Try to use system fonts that support Chinese - import platform - - system = platform.system() - if system == "Windows": - # Try common Windows fonts - for font_name in ["SimSun", "SimHei", "Microsoft YaHei"]: - try: - from reportlab.pdfbase.cidfonts import ( - UnicodeCIDFont, - ) - - pdfmetrics.registerFont(UnicodeCIDFont(font_name)) - normal_style.fontName = font_name - heading_style.fontName = font_name - break - except Exception: - continue - elif system == "Darwin": # macOS - for font_name in ["STSong-Light", "STHeiti"]: - try: - from reportlab.pdfbase.cidfonts import ( - UnicodeCIDFont, - ) - - pdfmetrics.registerFont(UnicodeCIDFont(font_name)) - normal_style.fontName = font_name - heading_style.fontName = font_name - break - except Exception: - continue - except Exception: - pass # Use default fonts if Chinese font setup fails - - # Build content - story = [] - - # Handle markdown or plain text - if text_path.suffix.lower() == ".md": - # Handle markdown content - simplified implementation - lines = text_content.split("\n") - for line in 
lines: - line = line.strip() - if not line: - story.append(Spacer(1, 12)) - continue - - # Headers - if line.startswith("#"): - level = len(line) - len(line.lstrip("#")) - header_text = line.lstrip("#").strip() - if header_text: - header_style = ParagraphStyle( - name=f"Heading{level}", - parent=heading_style, - fontSize=max(16 - level, 10), - spaceAfter=8, - spaceBefore=16 if level <= 2 else 12, - ) - story.append(Paragraph(header_text, header_style)) - else: - # Regular text - story.append(Paragraph(line, normal_style)) - story.append(Spacer(1, 6)) - else: - # Handle plain text files (.txt) - logging.info( - f"Processing plain text file with {len(text_content)} characters..." - ) - - # Split text into lines and process each line - lines = text_content.split("\n") - line_count = 0 - - for line in lines: - line = line.rstrip() - line_count += 1 - - # Empty lines - if not line.strip(): - story.append(Spacer(1, 6)) - continue - - # Regular text lines - # Escape special characters for ReportLab - safe_line = ( - line.replace("&", "&") - .replace("<", "<") - .replace(">", ">") - ) - - # Create paragraph - story.append(Paragraph(safe_line, normal_style)) - story.append(Spacer(1, 3)) - - logging.info(f"Added {line_count} lines to PDF") - - # If no content was added, add a placeholder - if not story: - story.append(Paragraph("(Empty text file)", normal_style)) - - # Build PDF - doc.build(story) - logging.info( - f"Successfully converted {text_path.name} to PDF ({pdf_path.stat().st_size / 1024:.1f} KB)" - ) - - except ImportError: - raise RuntimeError( - "reportlab is required for text-to-PDF conversion. " - "Please install it using: pip install reportlab" - ) - except Exception as e: - raise RuntimeError( - f"Failed to convert text file {text_path.name} to PDF: {str(e)}" - ) - - # Validate the generated PDF - if not pdf_path.exists() or pdf_path.stat().st_size < 100: - raise RuntimeError( - f"PDF conversion failed for {text_path.name} - generated PDF is empty or corrupted." - ) - - return pdf_path - - except Exception as e: - logging.error(f"Error in convert_text_to_pdf: {str(e)}") - raise - - @staticmethod - def _process_inline_markdown(text: str) -> str: - """ - Process inline markdown formatting (bold, italic, code, links) - - Args: - text: Raw text with markdown formatting - - Returns: - Text with ReportLab markup - """ - import re - - # Escape special characters for ReportLab - text = text.replace("&", "&").replace("<", "<").replace(">", ">") - - # Bold text: **text** or __text__ - text = re.sub(r"\*\*(.*?)\*\*", r"\1", text) - text = re.sub(r"__(.*?)__", r"\1", text) - - # Italic text: *text* or _text_ (but not in the middle of words) - text = re.sub(r"(?\1", text) - text = re.sub(r"(?\1", text) - - # Inline code: `code` - text = re.sub( - r"`([^`]+?)`", - r'\1', - text, - ) - - # Links: [text](url) - convert to text with URL annotation - def link_replacer(match): - link_text = match.group(1) - url = match.group(2) - return f'{link_text}' - - text = re.sub(r"\[([^\]]+?)\]\(([^)]+?)\)", link_replacer, text) - - # Strikethrough: ~~text~~ - text = re.sub(r"~~(.*?)~~", r"\1", text) - - return text - - def parse_pdf( - self, - pdf_path: Union[str, Path], - output_dir: Optional[str] = None, - method: str = "auto", - lang: Optional[str] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """ - Abstract method to parse PDF document. - Must be implemented by subclasses. 
- - Args: - pdf_path: Path to the PDF file - output_dir: Output directory path - method: Parsing method (auto, txt, ocr) - lang: Document language for OCR optimization - **kwargs: Additional parameters for parser-specific command - - Returns: - List[Dict[str, Any]]: List of content blocks - """ - raise NotImplementedError("parse_pdf must be implemented by subclasses") - - def parse_image( - self, - image_path: Union[str, Path], - output_dir: Optional[str] = None, - lang: Optional[str] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """ - Abstract method to parse image document. - Must be implemented by subclasses. - - Note: Different parsers may support different image formats. - Check the specific parser's documentation for supported formats. - - Args: - image_path: Path to the image file - output_dir: Output directory path - lang: Document language for OCR optimization - **kwargs: Additional parameters for parser-specific command - - Returns: - List[Dict[str, Any]]: List of content blocks - """ - raise NotImplementedError("parse_image must be implemented by subclasses") - - def parse_document( - self, - file_path: Union[str, Path], - method: str = "auto", - output_dir: Optional[str] = None, - lang: Optional[str] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """ - Abstract method to parse a document. - Must be implemented by subclasses. - - Args: - file_path: Path to the file to be parsed - method: Parsing method (auto, txt, ocr) - output_dir: Output directory path - lang: Document language for OCR optimization - **kwargs: Additional parameters for parser-specific command - - Returns: - List[Dict[str, Any]]: List of content blocks - """ - raise NotImplementedError("parse_document must be implemented by subclasses") - - def check_installation(self) -> bool: - """ - Abstract method to check if the parser is properly installed. - Must be implemented by subclasses. - - Returns: - bool: True if installation is valid, False otherwise - """ - raise NotImplementedError( - "check_installation must be implemented by subclasses" - ) - diff --git a/rag_factory/parser/Parser_Docling.py b/rag_factory/parser/Parser_Docling.py deleted file mode 100644 index f33886c..0000000 --- a/rag_factory/parser/Parser_Docling.py +++ /dev/null @@ -1,489 +0,0 @@ -from __future__ import annotations -import base64 -import subprocess -from pathlib import Path -from typing import Union, List, Dict, Any, Optional, Tuple -import json -import logging -from .Base import Parser - - -class DoclingParser(Parser): - """ - Docling document parsing utility class. - - Specialized in parsing Office documents and HTML files, converting the content - into structured data and generating markdown and JSON output. 
- """ - - # Define Docling-specific formats - HTML_FORMATS = {".html", ".htm", ".xhtml"} - - def __init__(self) -> None: - """Initialize DoclingParser""" - super().__init__() - - def parse_pdf( - self, - pdf_path: Union[str, Path], - output_dir: Optional[str] = None, - method: str = "auto", - lang: Optional[str] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """ - Parse PDF document using Docling - - Args: - pdf_path: Path to the PDF file - output_dir: Output directory path - method: Parsing method (auto, txt, ocr) - lang: Document language for OCR optimization - **kwargs: Additional parameters for docling command - - Returns: - List[Dict[str, Any]]: List of content blocks - """ - try: - # Convert to Path object for easier handling - pdf_path = Path(pdf_path) - if not pdf_path.exists(): - raise FileNotFoundError(f"PDF file does not exist: {pdf_path}") - - name_without_suff = pdf_path.stem - - # Prepare output directory - if output_dir: - base_output_dir = Path(output_dir) - else: - base_output_dir = pdf_path.parent / "docling_output" - - base_output_dir.mkdir(parents=True, exist_ok=True) - - # Run docling command - self._run_docling_command( - input_path=pdf_path, - output_dir=base_output_dir, - file_stem=name_without_suff, - **kwargs, - ) - - # Read the generated output files - content_list, _ = self._read_output_files( - base_output_dir, name_without_suff - ) - return content_list - - except Exception as e: - logging.error(f"Error in parse_pdf: {str(e)}") - raise - - def parse_document( - self, - file_path: Union[str, Path], - method: str = "auto", - output_dir: Optional[str] = None, - lang: Optional[str] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """ - Parse document using Docling based on file extension - - Args: - file_path: Path to the file to be parsed - method: Parsing method - output_dir: Output directory path - lang: Document language for optimization - **kwargs: Additional parameters for docling command - - Returns: - List[Dict[str, Any]]: List of content blocks - """ - # Convert to Path object - file_path = Path(file_path) - if not file_path.exists(): - raise FileNotFoundError(f"File does not exist: {file_path}") - - # Get file extension - ext = file_path.suffix.lower() - - # Choose appropriate parser based on file type - if ext == ".pdf": - return self.parse_pdf(file_path, output_dir, method, lang, **kwargs) - elif ext in self.OFFICE_FORMATS: - return self.parse_office_doc(file_path, output_dir, lang, **kwargs) - elif ext in self.HTML_FORMATS: - return self.parse_html(file_path, output_dir, lang, **kwargs) - else: - raise ValueError( - f"Unsupported file format: {ext}. 
" - f"Docling only supports PDF files, Office formats ({', '.join(self.OFFICE_FORMATS)}) " - f"and HTML formats ({', '.join(self.HTML_FORMATS)})" - ) - - def _run_docling_command( - self, - input_path: Union[str, Path], - output_dir: Union[str, Path], - file_stem: str, - **kwargs, - ) -> None: - """ - Run docling command line tool - - Args: - input_path: Path to input file or directory - output_dir: Output directory path - file_stem: File stem for creating subdirectory - **kwargs: Additional parameters for docling command - """ - # Create subdirectory structure similar to MinerU - file_output_dir = Path(output_dir) / file_stem / "docling" - file_output_dir.mkdir(parents=True, exist_ok=True) - - cmd_json = [ - "docling", - "--output", - str(file_output_dir), - "--to", - "json", - str(input_path), - ] - cmd_md = [ - "docling", - "--output", - str(file_output_dir), - "--to", - "md", - str(input_path), - ] - - try: - # Prepare subprocess parameters to hide console window on Windows - import platform - - docling_subprocess_kwargs = { - "capture_output": True, - "text": True, - "check": True, - "encoding": "utf-8", - "errors": "ignore", - } - - # Hide console window on Windows - if platform.system() == "Windows": - docling_subprocess_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW - - result_json = subprocess.run(cmd_json, **docling_subprocess_kwargs) - result_md = subprocess.run(cmd_md, **docling_subprocess_kwargs) - logging.info("Docling command executed successfully") - if result_json.stdout: - logging.debug(f"JSON cmd output: {result_json.stdout}") - if result_md.stdout: - logging.debug(f"Markdown cmd output: {result_md.stdout}") - except subprocess.CalledProcessError as e: - logging.error(f"Error running docling command: {e}") - if e.stderr: - logging.error(f"Error details: {e.stderr}") - raise - except FileNotFoundError: - raise RuntimeError( - "docling command not found. Please ensure Docling is properly installed." 
- ) - - def _read_output_files( - self, - output_dir: Path, - file_stem: str, - ) -> Tuple[List[Dict[str, Any]], str]: - """ - Read the output files generated by docling and convert to MinerU format - - Args: - output_dir: Output directory - file_stem: File name without extension - - Returns: - Tuple containing (content list JSON, Markdown text) - """ - # Use subdirectory structure similar to MinerU - file_subdir = output_dir / file_stem / "docling" - md_file = file_subdir / f"{file_stem}.md" - json_file = file_subdir / f"{file_stem}.json" - - # Read markdown content - md_content = "" - if md_file.exists(): - try: - with open(md_file, "r", encoding="utf-8") as f: - md_content = f.read() - except Exception as e: - logging.warning(f"Could not read markdown file {md_file}: {e}") - - # Read JSON content and convert format - content_list = [] - if json_file.exists(): - try: - with open(json_file, "r", encoding="utf-8") as f: - docling_content = json.load(f) - # Convert docling format to minerU format - content_list = self.read_from_block_recursive( - docling_content["body"], - "body", - file_subdir, - 0, - "0", - docling_content, - ) - except Exception as e: - logging.warning(f"Could not read or convert JSON file {json_file}: {e}") - return content_list, md_content - - def read_from_block_recursive( - self, - block, - type: str, - output_dir: Path, - cnt: int, - num: str, - docling_content: Dict[str, Any], - ) -> List[Dict[str, Any]]: - content_list = [] - if not block.get("children"): - cnt += 1 - content_list.append(self.read_from_block(block, type, output_dir, cnt, num)) - else: - if type not in ["groups", "body"]: - cnt += 1 - content_list.append( - self.read_from_block(block, type, output_dir, cnt, num) - ) - members = block["children"] - for member in members: - cnt += 1 - member_tag = member["$ref"] - member_type = member_tag.split("/")[1] - member_num = member_tag.split("/")[2] - member_block = docling_content[member_type][int(member_num)] - content_list.extend( - self.read_from_block_recursive( - member_block, - member_type, - output_dir, - cnt, - member_num, - docling_content, - ) - ) - return content_list - - def read_from_block( - self, block, type: str, output_dir: Path, cnt: int, num: str - ) -> Dict[str, Any]: - if type == "texts": - if block["label"] == "formula": - return { - "type": "equation", - "img_path": "", - "text": block["orig"], - "text_format": "unkown", - "page_idx": cnt // 10, - } - else: - return { - "type": "text", - "text": block["orig"], - "page_idx": cnt // 10, - } - elif type == "pictures": - try: - base64_uri = block["image"]["uri"] - base64_str = base64_uri.split(",")[1] - # Create images directory within the docling subdirectory - image_dir = output_dir / "images" - image_dir.mkdir(parents=True, exist_ok=True) # Ensure directory exists - image_path = image_dir / f"image_{num}.png" - with open(image_path, "wb") as f: - f.write(base64.b64decode(base64_str)) - return { - "type": "image", - "img_path": str(image_path.resolve()), # Convert to absolute path - "image_caption": block.get("caption", ""), - "image_footnote": block.get("footnote", ""), - "page_idx": cnt // 10, - } - except Exception as e: - logging.warning(f"Failed to process image {num}: {e}") - return { - "type": "text", - "text": f"[Image processing failed: {block.get('caption', '')}]", - "page_idx": cnt // 10, - } - else: - try: - return { - "type": "table", - "img_path": "", - "table_caption": block.get("caption", ""), - "table_footnote": block.get("footnote", ""), - "table_body": 
block.get("data", []), - "page_idx": cnt // 10, - } - except Exception as e: - logging.warning(f"Failed to process table {num}: {e}") - return { - "type": "text", - "text": f"[Table processing failed: {block.get('caption', '')}]", - "page_idx": cnt // 10, - } - - def parse_office_doc( - self, - doc_path: Union[str, Path], - output_dir: Optional[str] = None, - lang: Optional[str] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """ - Parse office document directly using Docling - - Supported formats: .doc, .docx, .ppt, .pptx, .xls, .xlsx - - Args: - doc_path: Path to the document file - output_dir: Output directory path - lang: Document language for optimization - **kwargs: Additional parameters for docling command - - Returns: - List[Dict[str, Any]]: List of content blocks - """ - try: - # Convert to Path object - doc_path = Path(doc_path) - if not doc_path.exists(): - raise FileNotFoundError(f"Document file does not exist: {doc_path}") - - if doc_path.suffix.lower() not in self.OFFICE_FORMATS: - raise ValueError(f"Unsupported office format: {doc_path.suffix}") - - name_without_suff = doc_path.stem - - # Prepare output directory - if output_dir: - base_output_dir = Path(output_dir) - else: - base_output_dir = doc_path.parent / "docling_output" - - base_output_dir.mkdir(parents=True, exist_ok=True) - - # Run docling command - self._run_docling_command( - input_path=doc_path, - output_dir=base_output_dir, - file_stem=name_without_suff, - **kwargs, - ) - - # Read the generated output files - content_list, _ = self._read_output_files( - base_output_dir, name_without_suff - ) - return content_list - - except Exception as e: - logging.error(f"Error in parse_office_doc: {str(e)}") - raise - - def parse_html( - self, - html_path: Union[str, Path], - output_dir: Optional[str] = None, - lang: Optional[str] = None, - **kwargs, - ) -> List[Dict[str, Any]]: - """ - Parse HTML document using Docling - - Supported formats: .html, .htm, .xhtml - - Args: - html_path: Path to the HTML file - output_dir: Output directory path - lang: Document language for optimization - **kwargs: Additional parameters for docling command - - Returns: - List[Dict[str, Any]]: List of content blocks - """ - try: - # Convert to Path object - html_path = Path(html_path) - if not html_path.exists(): - raise FileNotFoundError(f"HTML file does not exist: {html_path}") - - if html_path.suffix.lower() not in self.HTML_FORMATS: - raise ValueError(f"Unsupported HTML format: {html_path.suffix}") - - name_without_suff = html_path.stem - - # Prepare output directory - if output_dir: - base_output_dir = Path(output_dir) - else: - base_output_dir = html_path.parent / "docling_output" - - base_output_dir.mkdir(parents=True, exist_ok=True) - - # Run docling command - self._run_docling_command( - input_path=html_path, - output_dir=base_output_dir, - file_stem=name_without_suff, - **kwargs, - ) - - # Read the generated output files - content_list, _ = self._read_output_files( - base_output_dir, name_without_suff - ) - return content_list - - except Exception as e: - logging.error(f"Error in parse_html: {str(e)}") - raise - - def check_installation(self) -> bool: - """ - Check if Docling is properly installed - - Returns: - bool: True if installation is valid, False otherwise - """ - try: - # Prepare subprocess parameters to hide console window on Windows - import platform - - subprocess_kwargs = { - "capture_output": True, - "text": True, - "check": True, - "encoding": "utf-8", - "errors": "ignore", - } - - # Hide console window on Windows - 
if platform.system() == "Windows": - subprocess_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW - - result = subprocess.run(["docling", "--version"], **subprocess_kwargs) - logging.debug(f"Docling version: {result.stdout.strip()}") - return True - except (subprocess.CalledProcessError, FileNotFoundError): - logging.debug( - "Docling is not properly installed. " - "Please ensure it is installed correctly." - ) - return False - - diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/download_model.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/download_model.py new file mode 100644 index 0000000..32d7087 --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/download_model.py @@ -0,0 +1,24 @@ +from argparse import ArgumentParser +import os + + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--type', '-t', type=str, default="huggingface") + parser.add_argument('--name', '-n', type=str, default="rednote-hilab/dots.ocr") + args = parser.parse_args() + script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + print(f"Attention: The model save dir dots.ocr should be replace by a name without `.` like DotsOCR, util we merge our code to transformers.") + model_dir = os.path.join(script_dir, "weights/DotsOCR") + if not os.path.exists(model_dir): + os.makedirs(model_dir) + if args.type == "huggingface": + from huggingface_hub import snapshot_download + snapshot_download(repo_id=args.name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True) + elif args.type == "modelscope": + from modelscope import snapshot_download + snapshot_download(repo_id=args.name, local_dir=model_dir) + else: + raise ValueError(f"Invalid type: {args.type}") + + print(f"model downloaded to {model_dir}") diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/inference.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/inference.py new file mode 100644 index 0000000..80f3395 --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/inference.py @@ -0,0 +1,50 @@ +import json +import io +import base64 +import math +from PIL import Image +import requests +from dots_ocr.utils.image_utils import PILimage_to_base64 +from openai import OpenAI +import os + + +def inference_with_vllm( + image, + prompt, + ip="localhost", + port=8000, + temperature=0.1, + top_p=0.9, + max_completion_tokens=32768, + model_name='model', + ): + + addr = f"http://{ip}:{port}/v1" + client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr) + messages = [] + messages.append( + { + "role": "user", + "content": [ + { + "type": "image_url", + "image_url": {"url": PILimage_to_base64(image)}, + }, + {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"} # if no "<|img|><|imgpad|><|endofimg|>" here,vllm v1 will add "\n" here + ], + } + ) + try: + response = client.chat.completions.create( + messages=messages, + model=model_name, + max_completion_tokens=max_completion_tokens, + temperature=temperature, + top_p=top_p) + response = response.choices[0].message.content + return response + except requests.exceptions.RequestException as e: + print(f"request error: {e}") + return None + diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/consts.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/consts.py new file mode 100644 index 0000000..3b9f71b --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/consts.py @@ -0,0 +1,5 @@ +MIN_PIXELS=3136 +MAX_PIXELS=11289600 +IMAGE_FACTOR=28 + 
+image_extensions = {'.jpg', '.jpeg', '.png'} diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/doc_utils.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/doc_utils.py new file mode 100644 index 0000000..915c5c8 --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/doc_utils.py @@ -0,0 +1,60 @@ +import fitz +import numpy as np +import enum +from pydantic import BaseModel, Field +from PIL import Image + + +class SupportedPdfParseMethod(enum.Enum): + OCR = 'ocr' + TXT = 'txt' + + +class PageInfo(BaseModel): + """The width and height of page + """ + w: float = Field(description='the width of page') + h: float = Field(description='the height of page') + + +def fitz_doc_to_image(doc, target_dpi=200, origin_dpi=None) -> dict: + """Convert fitz.Document to image, Then convert the image to numpy array. + + Args: + doc (_type_): pymudoc page + dpi (int, optional): reset the dpi of dpi. Defaults to 200. + + Returns: + dict: {'img': numpy array, 'width': width, 'height': height } + """ + from PIL import Image + mat = fitz.Matrix(target_dpi / 72, target_dpi / 72) + pm = doc.get_pixmap(matrix=mat, alpha=False) + + if pm.width > 4500 or pm.height > 4500: + mat = fitz.Matrix(72 / 72, 72 / 72) # use fitz default dpi + pm = doc.get_pixmap(matrix=mat, alpha=False) + + image = Image.frombytes('RGB', (pm.width, pm.height), pm.samples) + return image + + +def load_images_from_pdf(pdf_file, dpi=200, start_page_id=0, end_page_id=None) -> list: + images = [] + with fitz.open(pdf_file) as doc: + pdf_page_num = doc.page_count + end_page_id = ( + end_page_id + if end_page_id is not None and end_page_id >= 0 + else pdf_page_num - 1 + ) + if end_page_id > pdf_page_num - 1: + print('end_page_id is out of range, use images length') + end_page_id = pdf_page_num - 1 + + for index in range(0, doc.page_count): + if start_page_id <= index <= end_page_id: + page = doc[index] + img = fitz_doc_to_image(page, target_dpi=dpi) + images.append(img) + return images \ No newline at end of file diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/format_transformer.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/format_transformer.py new file mode 100644 index 0000000..b2a2123 --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/format_transformer.py @@ -0,0 +1,205 @@ +import os +import sys +import json +import re + +from PIL import Image +from dots_ocr.utils.image_utils import PILimage_to_base64 + + +def has_latex_markdown(text: str) -> bool: + """ + Checks if a string contains LaTeX markdown patterns. + + Args: + text (str): The string to check. + + Returns: + bool: True if LaTeX markdown is found, otherwise False. + """ + if not isinstance(text, str): + return False + + # Define regular expression patterns for LaTeX markdown + latex_patterns = [ + r'\$\$.*?\$\$', # Block-level math formula $$...$$ + r'\$[^$\n]+?\$', # Inline math formula $...$ + r'\\begin\{.*?\}.*?\\end\{.*?\}', # LaTeX environment \begin{...}...\end{...} + r'\\[a-zA-Z]+\{.*?\}', # LaTeX command \command{...} + r'\\[a-zA-Z]+', # Simple LaTeX command \command + r'\\\[.*?\\\]', # Display math formula \[...\] + r'\\\(.*?\\\)', # Inline math formula \(...\) + ] + + # Check if any of the patterns match + for pattern in latex_patterns: + if re.search(pattern, text, re.DOTALL): + return True + + return False + + +def clean_latex_preamble(latex_text: str) -> str: + """ + Removes LaTeX preamble commands like document class and package imports. + + Args: + latex_text (str): The original LaTeX text. 
+ + Returns: + str: The cleaned LaTeX text without preamble commands. + """ + # Define patterns to be removed + patterns = [ + r'\\documentclass\{[^}]+\}', # \documentclass{...} + r'\\usepackage\{[^}]+\}', # \usepackage{...} + r'\\usepackage\[[^\]]*\]\{[^}]+\}', # \usepackage[options]{...} + r'\\begin\{document\}', # \begin{document} + r'\\end\{document\}', # \end{document} + ] + + # Apply each pattern to clean the text + cleaned_text = latex_text + for pattern in patterns: + cleaned_text = re.sub(pattern, '', cleaned_text, flags=re.IGNORECASE) + + return cleaned_text + + +def get_formula_in_markdown(text: str) -> str: + """ + Formats a string containing a formula into a standard Markdown block. + + Args: + text (str): The input string, potentially containing a formula. + + Returns: + str: The formatted string, ready for Markdown rendering. + """ + # Remove leading/trailing whitespace + text = text.strip() + + # Check if it's already enclosed in $$ + if text.startswith('$$') and text.endswith('$$'): + text_new = text[2:-2].strip() + if not '$' in text_new: + return f"$$\n{text_new}\n$$" + else: + return text + + # Handle \[...\] format, convert to $$...$$ + if text.startswith('\\[') and text.endswith('\\]'): + inner_content = text[2:-2].strip() + return f"$$\n{inner_content}\n$$" + + # Check if it's enclosed in \[ \] + if len(re.findall(r'.*\\\[.*\\\].*', text)) > 0: + return text + + # Handle inline formulas ($...$) + pattern = r'\$([^$]+)\$' + matches = re.findall(pattern, text) + if len(matches) > 0: + # It's an inline formula, return it as is + return text + + # If no LaTeX markdown syntax is present, return directly + if not has_latex_markdown(text): + return text + + # Handle unnecessary LaTeX formatting like \usepackage + if 'usepackage' in text: + text = clean_latex_preamble(text) + + if text[0] == '`' and text[-1] == '`': + text = text[1:-1] + + # Enclose the final text in a $$ block with newlines + text = f"$$\n{text}\n$$" + return text + + +def clean_text(text: str) -> str: + """ + Cleans text by removing extra whitespace. + + Args: + text: The original text. + + Returns: + str: The cleaned text. + """ + if not text: + return "" + + # Remove leading and trailing whitespace + text = text.strip() + + # Replace multiple consecutive whitespace characters with a single space + text = re.sub(r'\s+', ' ', text) + + return text + + +def layoutjson2md(image: Image.Image, cells: list, text_key: str = 'text', no_page_hf: bool = False) -> str: + """ + Converts a layout JSON format to Markdown. + + In the layout JSON, formulas are LaTeX, tables are HTML, and text is Markdown. + + Args: + image: A PIL Image object. + cells: A list of dictionaries, each representing a layout cell. + text_key: The key for the text field in the cell dictionary. + no_page_header_footer: If True, skips page headers and footers. + + Returns: + str: The text in Markdown format. 
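+
+    Example (illustrative): a cell {"bbox": [0, 0, 100, 40], "category": "Text",
+    "text": "Hello"} contributes the paragraph "Hello"; "Formula" cells are normalised
+    into $$ ... $$ blocks and "Picture" cells are embedded as inline base64 images.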
+ """ + text_items = [] + + for i, cell in enumerate(cells): + x1, y1, x2, y2 = [int(coord) for coord in cell['bbox']] + text = cell.get(text_key, "") + + if no_page_hf and cell['category'] in ['Page-header', 'Page-footer']: + continue + + if cell['category'] == 'Picture': + image_crop = image.crop((x1, y1, x2, y2)) + image_base64 = PILimage_to_base64(image_crop) + text_items.append(f"![]({image_base64})") + elif cell['category'] == 'Formula': + text_items.append(get_formula_in_markdown(text)) + else: + text = clean_text(text) + text_items.append(f"{text}") + + markdown_text = '\n\n'.join(text_items) + return markdown_text + + +def fix_streamlit_formulas(md: str) -> str: + """ + Fixes the format of formulas in Markdown to ensure they display correctly in Streamlit. + It adds a newline after the opening $$ and before the closing $$ if they don't already exist. + + Args: + md_text (str): The Markdown text to fix. + + Returns: + str: The fixed Markdown text. + """ + + # This inner function will be used by re.sub to perform the replacement + def replace_formula(match): + content = match.group(1) + # If the content already has surrounding newlines, don't add more. + if content.startswith('\n'): + content = content[1:] + if content.endswith('\n'): + content = content[:-1] + return f'$$\n{content}\n$$' + + # Use regex to find all $$....$$ patterns and replace them using the helper function. + return re.sub(r'\$\$(.*?)\$\$', replace_formula, md, flags=re.DOTALL) diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/image_utils.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/image_utils.py new file mode 100644 index 0000000..69479eb --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/image_utils.py @@ -0,0 +1,196 @@ +import math +import base64 +from PIL import Image +from typing import Tuple +import os +from dots_ocr.utils.consts import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS +from dots_ocr.utils.doc_utils import fitz_doc_to_image +from io import BytesIO +import fitz +import requests +import copy + + +def round_by_factor(number: int, factor: int) -> int: + """Returns the closest integer to 'number' that is divisible by 'factor'.""" + return round(number / factor) * factor + + +def ceil_by_factor(number: int, factor: int) -> int: + """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'.""" + return math.ceil(number / factor) * factor + + +def floor_by_factor(number: int, factor: int) -> int: + """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'.""" + return math.floor(number / factor) * factor + + +def smart_resize( + height: int, + width: int, + factor: int = 28, + min_pixels: int = 3136, + max_pixels: int = 11289600, +): + """Rescales the image so that the following conditions are met: + + 1. Both dimensions (height and width) are divisible by 'factor'. + + 2. The total number of pixels is within the range ['min_pixels', 'max_pixels']. + + 3. The aspect ratio of the image is maintained as closely as possible. 
+ + """ + if max(height, width) / min(height, width) > 200: + raise ValueError( + f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}" + ) + h_bar = max(factor, round_by_factor(height, factor)) + w_bar = max(factor, round_by_factor(width, factor)) + if h_bar * w_bar > max_pixels: + beta = math.sqrt((height * width) / max_pixels) + h_bar = max(factor, floor_by_factor(height / beta, factor)) + w_bar = max(factor, floor_by_factor(width / beta, factor)) + elif h_bar * w_bar < min_pixels: + beta = math.sqrt(min_pixels / (height * width)) + h_bar = ceil_by_factor(height * beta, factor) + w_bar = ceil_by_factor(width * beta, factor) + if h_bar * w_bar > max_pixels: # max_pixels first to control the token length + beta = math.sqrt((h_bar * w_bar) / max_pixels) + h_bar = max(factor, floor_by_factor(h_bar / beta, factor)) + w_bar = max(factor, floor_by_factor(w_bar / beta, factor)) + return h_bar, w_bar + + + +def PILimage_to_base64(image, format='PNG'): + buffered = BytesIO() + image.save(buffered, format=format) + base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8') + return f"data:image;base64,{base64_str}" + + +def to_rgb(pil_image: Image.Image) -> Image.Image: + if pil_image.mode == 'RGBA': + white_background = Image.new("RGB", pil_image.size, (255, 255, 255)) + white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask + return white_background + else: + return pil_image.convert("RGB") + + +# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py +def fetch_image( + image, + min_pixels=None, + max_pixels=None, + resized_height=None, + resized_width=None, + ) -> Image.Image: + assert image is not None, f"image not found, maybe input format error: {image}" + image_obj = None + if isinstance(image, Image.Image): + image_obj = image + elif image.startswith("http://") or image.startswith("https://"): + # fix memory leak issue while using BytesIO + with requests.get(image, stream=True) as response: + response.raise_for_status() + with BytesIO(response.content) as bio: + image_obj = copy.deepcopy(Image.open(bio)) + elif image.startswith("file://"): + image_obj = Image.open(image[7:]) + elif image.startswith("data:image"): + if "base64," in image: + _, base64_data = image.split("base64,", 1) + data = base64.b64decode(base64_data) + # fix memory leak issue while using BytesIO + with BytesIO(data) as bio: + image_obj = copy.deepcopy(Image.open(bio)) + else: + image_obj = Image.open(image) + if image_obj is None: + raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}") + image = to_rgb(image_obj) + ## resize + if resized_height and resized_width: + resized_height, resized_width = smart_resize( + resized_height, + resized_width, + factor=IMAGE_FACTOR, + ) + assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, " + image = image.resize((resized_width, resized_height)) + elif min_pixels or max_pixels: + width, height = image.size + if not min_pixels: + min_pixels = MIN_PIXELS + if not max_pixels: + max_pixels = MAX_PIXELS + resized_height, resized_width = smart_resize( + height, + width, + factor=IMAGE_FACTOR, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: 
{resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, " + image = image.resize((resized_width, resized_height)) + + return image + +def get_input_dimensions( + image: Image.Image, + min_pixels: int, + max_pixels: int, + factor: int = 28 +) -> Tuple[int, int]: + """ + Gets the resized dimensions of the input image. + + Args: + image: The original image. + min_pixels: The minimum number of pixels. + max_pixels: The maximum number of pixels. + factor: The resizing factor. + + Returns: + The resized (width, height). + """ + input_height, input_width = smart_resize( + image.height, + image.width, + factor=factor, + min_pixels=min_pixels, + max_pixels=max_pixels + ) + return input_width, input_height + + +def get_image_by_fitz_doc(image, target_dpi=200): + # get image through fitz, to get target dpi image, mainly for higher image + if not isinstance(image, Image.Image): + assert isinstance(image, str) + _, file_ext = os.path.splitext(image) + assert file_ext in {'.jpg', '.jpeg', '.png'} + + if image.startswith("http://") or image.startswith("https://"): + with requests.get(image, stream=True) as response: + response.raise_for_status() + data_bytes = response.content + else: + with open(image, 'rb') as f: + data_bytes = f.read() + + image = Image.open(BytesIO(data_bytes)) + else: + data_bytes = BytesIO() + image.save(data_bytes, format='PNG') + + origin_dpi = image.info.get('dpi', None) + pdf_bytes = fitz.open(stream=data_bytes).convert_to_pdf() + doc = fitz.open('pdf', pdf_bytes) + page = doc[0] + image_fitz = fitz_doc_to_image(page, target_dpi=target_dpi, origin_dpi=origin_dpi) + + return image_fitz diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/layout_utils.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/layout_utils.py new file mode 100644 index 0000000..86fe1ed --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/layout_utils.py @@ -0,0 +1,228 @@ +from PIL import Image +from typing import Dict, List + +import fitz +from io import BytesIO +import json + +from dots_ocr.utils.image_utils import smart_resize +from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS +from dots_ocr.utils.output_cleaner import OutputCleaner + + +# Define a color map (using RGBA format) +dict_layout_type_to_color = { + "Text": (0, 128, 0, 256), # Green, translucent + "Picture": (255, 0, 255, 256), # Magenta, translucent + "Caption": (255, 165, 0, 256), # Orange, translucent + "Section-header": (0, 255, 255, 256), # Cyan, translucent + "Footnote": (0, 128, 0, 256), # Green, translucent + "Formula": (128, 128, 128, 256), # Gray, translucent + "Table": (255, 192, 203, 256), # Pink, translucent + "Title": (255, 0, 0, 256), # Red, translucent + "List-item": (0, 0, 255, 256), # Blue, translucent + "Page-header": (0, 128, 0, 256), # Green, translucent + "Page-footer": (128, 0, 128, 256), # Purple, translucent + "Other": (165, 42, 42, 256), # Brown, translucent + "Unknown": (0, 0, 0, 0), +} + + +def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True): + """ + Draw transparent boxes on an image. + + Args: + image: The source PIL Image. + cells: A list of cells containing bounding box information. + resized_height: The resized height. + resized_width: The resized width. + fill_bbox: Whether to fill the bounding box. + draw_bbox: Whether to draw the bounding box. + + Returns: + PIL.Image: The image with drawings. 
+ """ + # origin_image = Image.open(image_path) + original_width, original_height = image.size + + # Create a new PDF document + doc = fitz.open() + + # Get image information + img_bytes = BytesIO() + image.save(img_bytes, format='PNG') + # pix = fitz.Pixmap(image_path) + pix = fitz.Pixmap(img_bytes) + + # Create a page + page = doc.new_page(width=pix.width, height=pix.height) + page.insert_image( + fitz.Rect(0, 0, pix.width, pix.height), + # filename=image_path + pixmap=pix + ) + + for i, cell in enumerate(cells): + bbox = cell['bbox'] + layout_type = cell['category'] + order = i + + top_left = (bbox[0], bbox[1]) + down_right = (bbox[2], bbox[3]) + if resized_height and resized_width: + scale_x = resized_width / original_width + scale_y = resized_height / original_height + top_left = (int(bbox[0] / scale_x), int(bbox[1] / scale_y)) + down_right = (int(bbox[2] / scale_x), int(bbox[3] / scale_y)) + + color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256)) + color = [col/255 for col in color[:3]] + + x0, y0, x1, y1 = top_left[0], top_left[1], down_right[0], down_right[1] + rect_coords = fitz.Rect(x0, y0, x1, y1) + if draw_bbox: + if fill_bbox: + page.draw_rect( + rect_coords, + color=None, + fill=color, + fill_opacity=0.3, + width=0.5, + overlay=True, + ) # Draw the rectangle + else: + page.draw_rect( + rect_coords, + color=color, + fill=None, + fill_opacity=1, + width=0.5, + overlay=True, + ) # Draw the rectangle + order_cate = f"{order}_{layout_type}" + page.insert_text( + (x1, y0 + 20), order_cate, fontsize=20, color=color + ) # Insert the index in the top left corner of the rectangle + + # Convert to a Pixmap (maintaining original dimensions) + mat = fitz.Matrix(1.0, 1.0) + pix = page.get_pixmap(matrix=mat) + + return Image.frombytes("RGB", [pix.width, pix.height], pix.samples) + + +def pre_process_bboxes( + origin_image, + bboxes, + input_width, + input_height, + factor: int = 28, + min_pixels: int = 3136, + max_pixels: int = 11289600 +): + assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list) + min_pixels = min_pixels or MIN_PIXELS + max_pixels = max_pixels or MAX_PIXELS + original_width, original_height = origin_image.size + + input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels) + + scale_x = original_width / input_width + scale_y = original_height / input_height + + bboxes_out = [] + for bbox in bboxes: + bbox_resized = [ + int(float(bbox[0]) / scale_x), + int(float(bbox[1]) / scale_y), + int(float(bbox[2]) / scale_x), + int(float(bbox[3]) / scale_y) + ] + bboxes_out.append(bbox_resized) + + return bboxes_out + +def post_process_cells( + origin_image: Image.Image, + cells: List[Dict], + input_width, # server input width, also has smart_resize in server + input_height, + factor: int = 28, + min_pixels: int = 3136, + max_pixels: int = 11289600 +) -> List[Dict]: + """ + Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions. + + Args: + origin_image: The original PIL Image. + cells: A list of cells containing bounding box information. + input_width: The width of the input image sent to the server. + input_height: The height of the input image sent to the server. + factor: Resizing factor. + min_pixels: Minimum number of pixels. + max_pixels: Maximum number of pixels. + + Returns: + A list of post-processed cells. 
+ """ + assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict) + min_pixels = min_pixels or MIN_PIXELS + max_pixels = max_pixels or MAX_PIXELS + original_width, original_height = origin_image.size + + input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels) + + scale_x = input_width / original_width + scale_y = input_height / original_height + + cells_out = [] + for cell in cells: + bbox = cell['bbox'] + bbox_resized = [ + int(float(bbox[0]) / scale_x), + int(float(bbox[1]) / scale_y), + int(float(bbox[2]) / scale_x), + int(float(bbox[3]) / scale_y) + ] + cell_copy = cell.copy() + cell_copy['bbox'] = bbox_resized + cells_out.append(cell_copy) + + return cells_out + +def is_legal_bbox(cells): + for cell in cells: + bbox = cell['bbox'] + if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]: + return False + return True + +def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None): + if prompt_mode in ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]: + return response + + json_load_failed = False + cells = response + try: + cells = json.loads(cells) + cells = post_process_cells( + origin_image, + cells, + input_image.width, + input_image.height, + min_pixels=min_pixels, + max_pixels=max_pixels + ) + return cells, False + except Exception as e: + print(f"cells post process error: {e}, when using {prompt_mode}") + json_load_failed = True + + if json_load_failed: + cleaner = OutputCleaner() + response_clean = cleaner.clean_model_output(cells) + if isinstance(response_clean, list): + response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell]) + return response_clean, True diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/output_cleaner.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/output_cleaner.py new file mode 100644 index 0000000..3fbc3aa --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/output_cleaner.py @@ -0,0 +1,623 @@ +#!/usr/bin/env python3 +""" +Data Cleaning Script - Cleans all data using a simplified regex method and saves the results + +Features: +1. Cleans all cases using a simplified regex method. +2. Saves the cleaned data for each case. +3. Ensures the relative order of dicts remains unchanged. +4. Generates a before-and-after cleaning report. 
+""" + +import json +import re +import os +from typing import Dict, List, Tuple, Optional, Any +from dataclasses import dataclass +from collections import Counter +import traceback + + +@dataclass +class CleanedData: + """Data structure for cleaned data""" + case_id: int + original_type: str # 'list' or 'str' + original_length: int + cleaned_data: List[Dict] + cleaning_operations: Dict[str, Any] # Records the cleaning operations performed + success: bool + + +class OutputCleaner: + """Data Cleaner - Based on a simplified regex method""" + + def __init__(self): + # Simplified regular expression patterns + self.dict_pattern = re.compile(r'\{[^{}]*?"bbox"\s*:\s*\[[^\]]*?\][^{}]*?\}', re.DOTALL) + self.bbox_pattern = re.compile(r'"bbox"\s*:\s*\[([^\]]+)\]') + self.missing_delimiter_pattern = re.compile(r'\}\s*\{(?!")') + + self.cleaned_results: List[CleanedData] = [] + + def clean_list_data(self, data: List[Dict], case_id: int) -> CleanedData: + """Cleans list-type data""" + + print(f"🔧 Cleaning List data - Case {case_id}") + print(f" Original items: {len(data)}") + + cleaned_data = [] + operations = { + 'type': 'list', + 'bbox_fixes': 0, + 'removed_items': 0, + 'original_count': len(data) + } + + for i, item in enumerate(data): + if not isinstance(item, dict): + operations['removed_items'] += 1 + continue + + # Check the bbox field + if 'bbox' in item: + bbox = item['bbox'] + + # Check bbox length - core logic + if isinstance(bbox, list) and len(bbox) == 3: + print(f" ⚠️ Item {i}: bbox has only 3 coordinates. Removing bbox, keeping category and text.") + # Keep only category and text, ensuring order is preserved + new_item = {} + if 'category' in item: + new_item['category'] = item['category'] + if 'text' in item: + new_item['text'] = item['text'] + if new_item: # Add only if there is valid content + cleaned_data.append(new_item) + operations['bbox_fixes'] += 1 + else: + operations['removed_items'] += 1 + continue + elif isinstance(bbox, list) and len(bbox) == 4: + # bbox is normal, add directly, preserving original order + cleaned_data.append(item.copy()) + continue + else: + print(f" ❌ Item {i}: Abnormal bbox format, skipping.") + operations['removed_items'] += 1 + continue + else: + # No bbox field, keep if category exists + if 'category' in item: + cleaned_data.append(item.copy()) + continue + else: + operations['removed_items'] += 1 + + operations['final_count'] = len(cleaned_data) + print(f" ✅ Cleaning complete: {len(cleaned_data)} items, {operations['bbox_fixes']} bbox fixes, {operations['removed_items']} items removed") + + return CleanedData( + case_id=case_id, + original_type='list', + original_length=len(data), + cleaned_data=cleaned_data, + cleaning_operations=operations, + success=True + ) + + def clean_string_data(self, data_str: str, case_id: int) -> CleanedData: + """Cleans string-type data""" + + print(f"🔧 Cleaning String data - Case {case_id}") + print(f" Original length: {len(data_str):,}") + + operations = { + 'type': 'str', + 'original_length': len(data_str), + 'delimiter_fixes': 0, + 'tail_truncated': False, + 'truncated_length': 0, + 'duplicate_dicts_removed': 0, + 'final_objects': 0 + } + + try: + # Step 1: Detect and fix missing delimiters + data_str, delimiter_fixes = self._fix_missing_delimiters(data_str) + operations['delimiter_fixes'] = delimiter_fixes + + # Step 2: Truncate the last incomplete element + data_str, tail_truncated = self._truncate_last_incomplete_element(data_str) + operations['tail_truncated'] = tail_truncated + operations['truncated_length'] = 
len(data_str) + + # Step 3: Remove duplicate complete dict objects, preserving order + data_str, duplicate_removes = self._remove_duplicate_complete_dicts_preserve_order(data_str) + operations['duplicate_dicts_removed'] = duplicate_removes + + # Step 4: Ensure correct JSON format + data_str = self._ensure_json_format(data_str) + + # Step 5: Try to parse the final result + final_data = self._parse_final_json(data_str) + + if final_data is not None: + operations['final_objects'] = len(final_data) + print(f" ✅ Cleaning complete: {len(final_data)} objects") + + return CleanedData( + case_id=case_id, + original_type='str', + original_length=operations['original_length'], + cleaned_data=final_data, + cleaning_operations=operations, + success=True + ) + else: + raise Exception("Could not parse the cleaned data") + + except Exception as e: + print(f" ❌ Cleaning failed: {e}") + return CleanedData( + case_id=case_id, + original_type='str', + original_length=operations['original_length'], + cleaned_data=[], + cleaning_operations=operations, + success=False + ) + + def _fix_missing_delimiters(self, text: str) -> Tuple[str, int]: + """Fixes missing delimiters""" + + fixes = 0 + + def replace_delimiter(match): + nonlocal fixes + fixes += 1 + return '},{' + + text = self.missing_delimiter_pattern.sub(replace_delimiter, text) + + if fixes > 0: + print(f" ✅ Fixed {fixes} missing delimiters") + + return text, fixes + + def _truncate_last_incomplete_element(self, text: str) -> Tuple[str, bool]: + """Truncates the last incomplete element""" + + # For very long text (>50k) or text not ending with ']', directly truncate the last '{"bbox":' + needs_truncation = ( + len(text) > 50000 or + not text.strip().endswith(']') + ) + + if needs_truncation: + # Check how many dict objects there are + bbox_count = text.count('{"bbox":') + + # If there is only one dict object, do not truncate to avoid deleting the only object + if bbox_count <= 1: + print(f" ⚠️ Only {bbox_count} dict objects found, skipping truncation to avoid deleting all content") + return text, False + + # Find the position of the last '{"bbox":' + last_bbox_pos = text.rfind('{"bbox":') + + if last_bbox_pos > 0: + # Truncate before this position + truncated_text = text[:last_bbox_pos].rstrip() + + # Remove trailing comma + if truncated_text.endswith(','): + truncated_text = truncated_text[:-1] + + print(f" ✂️ Truncated the last incomplete element, length reduced from {len(text):,} to {len(truncated_text):,}") + return truncated_text, True + + return text, False + + def _remove_duplicate_complete_dicts_preserve_order(self, text: str) -> Tuple[str, int]: + """Removes duplicate complete dict objects, preserving original order""" + + # Extract all dict objects, preserving order + dict_matches = list(self.dict_pattern.finditer(text)) + + if not dict_matches: + return text, 0 + + print(f" 📊 Found {len(dict_matches)} dict objects") + + # Deduplication while preserving order: only keep the first occurrence of a dict + unique_dicts = [] + seen_dict_strings = set() + total_duplicates = 0 + + for match in dict_matches: + dict_str = match.group() + + if dict_str not in seen_dict_strings: + unique_dicts.append(dict_str) + seen_dict_strings.add(dict_str) + else: + total_duplicates += 1 + + if total_duplicates > 0: + # Reconstruct the JSON array, preserving the original order + new_text = '[' + ', '.join(unique_dicts) + ']' + print(f" ✅ Removed {total_duplicates} duplicate dicts, keeping {len(unique_dicts)} unique dicts (order preserved)") + return new_text, 
total_duplicates + else: + print(f" ✅ No duplicate dict objects found") + return text, 0 + + def _ensure_json_format(self, text: str) -> str: + """Ensures correct JSON format""" + + text = text.strip() + + if not text.startswith('['): + text = '[' + text + + if not text.endswith(']'): + # Remove trailing comma + text = text.rstrip(',').rstrip() + text += ']' + + return text + + def _parse_final_json(self, text: str) -> Optional[List[Dict]]: + """Tries to parse the final JSON""" + + try: + data = json.loads(text) + if isinstance(data, list): + return data + except json.JSONDecodeError as e: + print(f" ❌ JSON parsing failed: {e}") + + # fallback1: Extract valid dict objects + valid_dicts = [] + + for match in self.dict_pattern.finditer(text): + dict_str = match.group() + try: + dict_obj = json.loads(dict_str) + valid_dicts.append(dict_obj) + except: + continue + + if valid_dicts: + print(f" ✅ Extracted {len(valid_dicts)} valid dicts") + return valid_dicts + + # fallback2: Special handling for a single incomplete dict + return self._handle_single_incomplete_dict(text) + + return None + + def _handle_single_incomplete_dict(self, text: str) -> Optional[List[Dict]]: + """Handles the special case of a single incomplete dict""" + + # Check if it's a single incomplete dict case + if not text.strip().startswith('[{"bbox":'): + return None + + try: + # Try to extract bbox coordinates + bbox_match = re.search(r'"bbox"\s*:\s*\[([^\]]+)\]', text) + if not bbox_match: + return None + + bbox_str = bbox_match.group(1) + bbox_coords = [int(x.strip()) for x in bbox_str.split(',')] + + if len(bbox_coords) != 4: + return None + + # Try to extract category + category_match = re.search(r'"category"\s*:\s*"([^"]+)"', text) + category = category_match.group(1) if category_match else "Text" + + # Try to extract the beginning of the text (first 10000 characters) + text_match = re.search(r'"text"\s*:\s*"([^"]{0,10000})', text) + if text_match: + text_content = text_match.group(1) + else: + text_content = "" + + # Construct the fixed dict + fixed_dict = { + "bbox": bbox_coords, + "category": category + } + + if text_content: + fixed_dict["text"] = text_content + + print(f" 🔧 Special fix: single incomplete dict → {fixed_dict}") + return [fixed_dict] + + except Exception as e: + print(f" ❌ Special fix failed: {e}") + return None + + def remove_duplicate_category_text_pairs_and_bbox(self, data_list: List[dict], case_id: int) -> List[dict]: + """Removes duplicate category-text pairs and duplicate bboxes""" + + if not data_list or len(data_list) <= 1: + print(f" 📊 Data length {len(data_list)} <= 1, skipping deduplication check") + return data_list + + print(f" 📊 Original data length: {len(data_list)}") + + # 1. Count occurrences and positions of each category-text pair + category_text_pairs = {} + for i, item in enumerate(data_list): + if isinstance(item, dict) and 'category' in item and 'text' in item: + pair_key = (item.get('category', ''), item.get('text', '')) + if pair_key not in category_text_pairs: + category_text_pairs[pair_key] = [] + category_text_pairs[pair_key].append(i) + + # 2. Count occurrences and positions of each bbox + bbox_pairs = {} + for i, item in enumerate(data_list): + if isinstance(item, dict) and 'bbox' in item: + bbox = item.get('bbox') + if isinstance(bbox, list) and len(bbox) > 0: + bbox_key = tuple(bbox) # Convert to tuple to use as a dictionary key + if bbox_key not in bbox_pairs: + bbox_pairs[bbox_key] = [] + bbox_pairs[bbox_key].append(i) + + # 3. 
Identify items to be removed + duplicates_to_remove = set() + + # 3a. Process category-text pairs that appear 5 or more times + for pair_key, positions in category_text_pairs.items(): + if len(positions) >= 5: + category, text = pair_key + # Keep the first occurrence, remove subsequent duplicates + positions_to_remove = positions[1:] + duplicates_to_remove.update(positions_to_remove) + + print(f" 🔍 Found duplicate category-text pair: category='{category}', first 50 chars of text='{text[:50]}...'") + print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}") + + # 3b. Process bboxes that appear 2 or more times + for bbox_key, positions in bbox_pairs.items(): + if len(positions) >= 2: + # Keep the first occurrence, remove subsequent duplicates + positions_to_remove = positions[1:] + duplicates_to_remove.update(positions_to_remove) + + print(f" 🔍 Found duplicate bbox: {list(bbox_key)}") + print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}") + + if not duplicates_to_remove: + print(f" ✅ No category-text pairs or bboxes found exceeding the duplication threshold") + return data_list + + # 4. Remove duplicate items from the original data (preserving order) + cleaned_data = [] + removed_count = 0 + for i, item in enumerate(data_list): + if i not in duplicates_to_remove: + cleaned_data.append(item) + else: + removed_count += 1 + + print(f" ✅ Deduplication complete: Removed {removed_count} duplicate items") + print(f" 📊 Cleaned data length: {len(cleaned_data)}") + + return cleaned_data + + def clean_model_output(self, model_output: str): + try: + # Select cleaning method based on data type + if isinstance(model_output, list): + result = self.clean_list_data(model_output, case_id=0) + else: + result = self.clean_string_data(str(model_output), case_id=0) + + # Add deduplication step: remove duplicate category-text pairs and bboxes + if result and hasattr(result, 'success') and result.success and result.cleaned_data: + original_data = result.cleaned_data + deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id=0) + # Update the cleaned_data in the CleanedData object + result.cleaned_data = deduplicated_data + return result.cleaned_data + except Exception as e: + print(f"❌ Case cleaning failed: {e}") + return model_output + + def clean_all_data(self, jsonl_path: str) -> List[CleanedData]: + """Cleans all data from a JSONL file""" + + print(f"🚀 Starting to clean JSONL file: {jsonl_path}") + + with open(jsonl_path, 'r', encoding='utf-8') as f: + lines = f.readlines() + + datas = [] + for i, line in enumerate(lines): + if line.strip(): + try: + data = json.loads(line) + predict_field = data.get('predict') + case_id = i + 1 + + print(f"\n{'='*50}") + print(f"🎯 Cleaning Case {case_id}") + print(f"{'='*50}") + + # Select cleaning method based on data type + if isinstance(predict_field, list): + print("📊 Data type: List") + result = self.clean_list_data(predict_field, case_id) + else: + print("📊 Data type: String") + result = self.clean_string_data(str(predict_field), case_id) + + # Add deduplication step: remove duplicate category-text pairs and bboxes + if result and hasattr(result, 'success') and result.success and result.cleaned_data: + print("🔄 Checking for and removing duplicate category-text pairs and bboxes...") + original_data = result.cleaned_data + deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id) + # Update the cleaned_data in the CleanedData object + 
result.cleaned_data = deduplicated_data + data['predict_resized'] = result.cleaned_data + + datas.append(data) + self.cleaned_results.append(result) + + except Exception as e: + print(f"❌ Case {i+1} cleaning failed: {e}") + traceback.print_exc() + + save_path = jsonl_path.replace('.jsonl', '_filtered.jsonl') + with open(save_path, 'w') as w: + for data in datas: + w.write(json.dumps(data, ensure_ascii=False) + '\n') + print(f"✅ Saved cleaned data to: {save_path}") + + return self.cleaned_results + + def save_cleaned_data(self, output_dir: str): + """Saves the cleaned data""" + + print(f"\n💾 Saving cleaned data to: {output_dir}") + os.makedirs(output_dir, exist_ok=True) + + # 1. Save cleaned data for each case + for result in self.cleaned_results: + case_filename = f"cleaned_case_{result.case_id:02d}.json" + case_filepath = os.path.join(output_dir, case_filename) + + # Save the cleaned data + with open(case_filepath, 'w', encoding='utf-8') as f: + json.dump(result.cleaned_data, f, ensure_ascii=False, indent=2) + + print(f" ✅ Case {result.case_id}: {len(result.cleaned_data)} objects → {case_filename}") + + # 2. Save all cleaned data to a single file + all_cleaned_data = [] + for result in self.cleaned_results: + all_cleaned_data.append({ + 'case_id': result.case_id, + 'original_type': result.original_type, + 'original_length': result.original_length, + 'cleaned_objects_count': len(result.cleaned_data), + 'success': result.success, + 'cleaning_operations': result.cleaning_operations, + 'cleaned_data': result.cleaned_data + }) + + all_data_filepath = os.path.join(output_dir, "all_cleaned_data.json") + with open(all_data_filepath, 'w', encoding='utf-8') as f: + json.dump(all_cleaned_data, f, ensure_ascii=False, indent=2) + + print(f" 📁 All data: {len(all_cleaned_data)} cases → all_cleaned_data.json") + + # 3. 
Generate a cleaning report + self._generate_cleaning_report(output_dir) + + def _generate_cleaning_report(self, output_dir: str): + """Generates a cleaning report""" + + report = [] + report.append("📊 Data Cleaning Report") + report.append("=" * 60) + import datetime + report.append(f"Processing Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + report.append("") + + # Overall statistics + total_cases = len(self.cleaned_results) + successful_cases = sum(1 for r in self.cleaned_results if r.success) + total_objects = sum(len(r.cleaned_data) for r in self.cleaned_results) + + report.append("📈 Overall Statistics:") + report.append(f" Total Cases: {total_cases}") + report.append(f" Successfully Cleaned: {successful_cases}") + report.append(f" Success Rate: {successful_cases/total_cases*100:.1f}%") + report.append(f" Total Recovered Objects: {total_objects}") + report.append("") + + # Detailed statistics + list_results = [r for r in self.cleaned_results if r.original_type == 'list'] + str_results = [r for r in self.cleaned_results if r.original_type == 'str'] + + if list_results: + report.append("📋 List Type Cleaning Statistics:") + for r in list_results: + ops = r.cleaning_operations + report.append(f" Case {r.case_id}: {ops['original_count']} → {ops['final_count']} objects") + if ops['bbox_fixes'] > 0: + report.append(f" - bbox fixes: {ops['bbox_fixes']}") + if ops['removed_items'] > 0: + report.append(f" - invalid items removed: {ops['removed_items']}") + report.append("") + + if str_results: + report.append("📝 String Type Cleaning Statistics:") + for r in str_results: + ops = r.cleaning_operations + status = "✅" if r.success else "❌" + report.append(f" Case {r.case_id} {status}: {ops['original_length']:,} chars → {ops['final_objects']} objects") + details = [] + if ops['delimiter_fixes'] > 0: + details.append(f"Delimiter fixes: {ops['delimiter_fixes']}") + if ops['tail_truncated']: + reduction = ops['original_length'] - ops['truncated_length'] + details.append(f"Tail truncation: -{reduction:,} chars") + if ops['duplicate_dicts_removed'] > 0: + details.append(f"Duplicates removed: {ops['duplicate_dicts_removed']}") + if details: + report.append(f" - {', '.join(details)}") + report.append("") + + # Note on data order + report.append("🔄 Data Order Guarantee:") + report.append(" ✅ The relative order of all dict objects is preserved during cleaning.") + report.append(" ✅ When deduplicating, the first occurrence of a dict is kept, and subsequent duplicates are removed.") + report.append(" ✅ The order of items in List-type data is fully preserved.") + + # Save the report + report_filepath = os.path.join(output_dir, "cleaning_report.txt") + with open(report_filepath, 'w', encoding='utf-8') as f: + f.write('\n'.join(report)) + + print(f" 📋 Cleaning report: cleaning_report.txt") + + # Also print to console + print(f"\n{chr(10).join(report)}") + + +def main(): + """Main function""" + + # Create a data cleaner instance + cleaner = OutputCleaner() + + # Input file + jsonl_path = "output_with_failcase.jsonl" + + # Output directory + output_dir = "output_with_failcase_cleaned" + + # Clean all data + results = cleaner.clean_all_data(jsonl_path) + + # Save the cleaned data + cleaner.save_cleaned_data(output_dir) + + print(f"\n🎉 Data cleaning complete!") + print(f"📁 Cleaned data saved in: {output_dir}") + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/prompts.py 
b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/prompts.py new file mode 100644 index 0000000..87714c3 --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/prompts.py @@ -0,0 +1,34 @@ +dict_promptmode_to_prompt = { + # prompt_layout_all_en: parse all layout info in json format. + "prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox. + +1. Bbox format: [x1, y1, x2, y2] + +2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. + +3. Text Extraction & Formatting Rules: + - Picture: For the 'Picture' category, the text field should be omitted. + - Formula: Format its text as LaTeX. + - Table: Format its text as HTML. + - All Others (Text, Title, etc.): Format their text as Markdown. + +4. Constraints: + - The output text must be the original text from the image, with no translation. + - All layout elements must be sorted according to human reading order. + +5. Final Output: The entire output must be a single JSON object. +""", + + # prompt_layout_only_en: layout detection + "prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""", + + # prompt_layout_only_en: parse ocr text except the Page-header and Page-footer + "prompt_ocr": """Extract the text content from this image.""", + + # prompt_grounding_ocr: extract text content in the given bounding box + "prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""", + + # "prompt_table_html": """Convert the table in this image to HTML.""", + # "prompt_table_latex": """Convert the table in this image to LaTeX.""", + # "prompt_formula_latex": """Convert the formula in this image to LaTeX.""", +} diff --git a/rag_factory/parser/Parser_Dotsocr/fig_recognize.py b/rag_factory/parser/Parser_Dotsocr/fig_recognize.py new file mode 100644 index 0000000..175529e --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/fig_recognize.py @@ -0,0 +1,183 @@ +import os +import glob +import json +import re +import fitz +from PIL import Image +from tqdm import tqdm +from dashscope import MultiModalConversation +import argparse +from pathlib import Path + + +def fig_understand(fig_path): + # prompt = '请给出图像中具体内容信息,并用json格式输出,仅输出json格式数据,其中,图片类型请从["chart","knowladge_map","other"]中选择' + prompt = ''' +你是一个图像内容理解专家,任务是读取图像内容并生成结构化 JSON 数据。请遵循以下规则: + +1. **仅输出 JSON 数据**,不要添加任何解释、前缀或后缀文字。 +2. JSON 格式中必须包含两个字段: + - "type": 图像类型,只能从 ["chart", "knowladge_map", "other"] 中选择。 + - "content": 图像的具体结构化内容描述。 +3. 
如果图像类型是: + - "chart": 请提取图表的标题、坐标轴标签、图例、系列等结构信息。 + - "knowladge_map": 输出树状结构,所有节点使用 {"name": xxx, "children": [...]} 格式。 + - "other": 尽可能准确描述图像的主要元素。 + +以下是几个示例,请模仿格式输出。 + +--- + +### 示例1(chart): +输入图像:柱状图,标题为“年度销售统计”,X轴为月份,Y轴为销售额,图例为“产品A”和“产品B”。 + +输出: +```json +{ + "type": "chart", + "content": { + + "title": "年度销售统计", + "x_axis": "月份", + "y_axis": "销售额", + "legend": ["产品A", "产品B"], + "series": [ + {"name": "产品A", "data": [100, 120, 130]}, + {"name": "产品B", "data": [80, 90, 100]} + ] + } +} +示例2(knowladge_map): +输入图像:知识图谱,核心为“机器学习”,子节点有“监督学习”和“无监督学习”,监督学习下有“回归”和“分类”。 + +输出: +{ + "type": "knowladge_map", + "content": { + "name": "机器学习", + "children": [ + { + "name": "监督学习", + "children": [ + {"name": "回归"}, + {"name": "分类"} + ] + }, + { + "name": "无监督学习" + } + ] + } +} +示例3(other): +输入图像:一张会议室内多个人开会的场景。 + +输出: +{ + "type": "other", + "content": "一个会议室中有5个人正在围绕会议桌讨论,桌上有笔记本电脑和文件。" +} + +请根据上面的示例输出格式,严格输出图像的内容识别结果,只返回符合格式的 JSON 数据。 + +''' + messages= [ + { + "role": "user", + "content": [ + {"image": f"file://{fig_path}"}, + {"text": prompt} + ] + } +] + + response = MultiModalConversation.call( + api_key=os.environ.get('DASHSCOPE_API_KEY'), + model="qwen-vl-plus", + messages=messages, + ) + + # print(response) + return response["output"]["choices"][0]["message"].content[0]["text"].replace("```json",'').replace("```",'').strip() + +def save_fig(file_path, page_no, index, bbox, scale): + + file_name, file_ext = os.path.splitext(os.path.basename(file_path)) + file_name = file_name.replace('_layout', '') + base_dir = os.path.dirname(file_path) + pdf_file = os.path.join(base_dir, f"{file_name}_original.pdf") + doc = fitz.open(pdf_file) + page = doc.load_page(page_no) + + pdf_width = page.rect.width + pdf_height = page.rect.height + + scale_x = scale[1] / pdf_width + scale_y = scale[0] / pdf_height + x1 = bbox[0] / scale_x + y1 = bbox[1] / scale_y + x2 = bbox[2] / scale_x + y2 = bbox[3] / scale_y + pdf_bbox = fitz.Rect(x1, y1, x2, y2) + zoom = 300 / 72 # 输出300 DPI + matrix = fitz.Matrix(zoom, zoom) + img = page.get_pixmap(matrix=matrix, clip=pdf_bbox, alpha=False) + + save_dir = os.path.join(base_dir,f"{file_name}/image") + if not os.path.exists(save_dir): + os.mkdir(save_dir) + + + text = '' + save_path = os.path.join(save_dir, f'page_{page_no}_{index}.png') + if img is not None: + img.save(save_path) + + text = fig_understand(save_path) + + return save_path, text + +def process_one_file(json_file): + file_name = os.path.basename(json_file) + base_dir = os.path.dirname(json_file) + output_path = str(json_file).replace("layout", "img_content") + data = [] + with open(json_file, 'r', encoding='utf-8') as f: + json_data = json.load(f) + print(f"Processing file: {file_name}") + for row in tqdm(json_data): + if row.get('category','') == 'Picture': + bbox = row['bbox'] + page_no = row['page_no'] + if (bbox[2]-bbox[0])*(bbox[3]-bbox[1]) < 52000: + row['text'] = "" + else: + fig_path, text = save_fig(json_file, page_no=page_no, index=row['index'], bbox=bbox, scale=row['scale']) + # print(text) + row['text'] = json.loads(text) + with open(output_path, 'w', encoding='utf-8') as f: + json.dump(json_data, f, ensure_ascii=False, indent=4) + return json_data + +def main(): + parser = argparse.ArgumentParser(description="Use vlm to get parsed figure content.") + parser.add_argument( + "--output", type=str, default="output", + help="Output parsed directory (default: output)" + ) + args = parser.parse_args() + + + if os.path.isdir(args.output): + for file in sorted(Path(args.output).glob('*_layout.json')): + 
data = process_one_file(file) + elif os.path.isfile(args.output): + data = process_one_file(args.output) + else: + print(f"'{args.output}' no exist") + +if __name__ == "__main__": + os.environ["DASHSCOPE_API_KEY"] = "your api key" + main() + + diff --git a/rag_factory/parser/Parser_Dotsocr/parser.py b/rag_factory/parser/Parser_Dotsocr/parser.py new file mode 100644 index 0000000..a143eec --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/parser.py @@ -0,0 +1,424 @@ +import os +import json +from tqdm import tqdm +from multiprocessing.pool import ThreadPool, Pool +import argparse + + +from dots_ocr.model.inference import inference_with_vllm +from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS +from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize +from dots_ocr.utils.doc_utils import fitz_doc_to_image, load_images_from_pdf +from dots_ocr.utils.prompts import dict_promptmode_to_prompt +from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes +from dots_ocr.utils.format_transformer import layoutjson2md + + +class DotsOCRParser: + """ + parse image or pdf file + """ + + def __init__(self, + ip='localhost', + port=8000, + model_name='model', + temperature=0.1, + top_p=1.0, + max_completion_tokens=16384, + num_thread=64, + dpi = 200, + output_dir="output", + min_pixels=None, + max_pixels=None, + use_hf=True, + ): + self.dpi = dpi + + # default args for vllm server + self.ip = ip + self.port = port + self.model_name = model_name + # default args for inference + self.temperature = temperature + self.top_p = top_p + self.max_completion_tokens = max_completion_tokens + self.num_thread = num_thread + self.output_dir = output_dir + self.min_pixels = min_pixels + self.max_pixels = max_pixels + + self.use_hf = use_hf + if self.use_hf: + self._load_hf_model() + print(f"use hf model, num_thread will be set to 1") + else: + print(f"use vllm model, num_thread will be set to {self.num_thread}") + assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS + assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS + + def _load_hf_model(self): + import torch + from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer + from qwen_vl_utils import process_vision_info + + model_path = "weights/DotsOCR" + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + attn_implementation="flash_attention_2", + torch_dtype=torch.bfloat16, + device_map="auto", + trust_remote_code=True + ) + self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True,use_fast=True) + self.process_vision_info = process_vision_info + + def _inference_with_hf(self, image, prompt): + messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image + }, + {"type": "text", "text": prompt} + ] + } + ] + + # Preparation for inference + text = self.processor.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + image_inputs, video_inputs = self.process_vision_info(messages) + inputs = self.processor( + text=[text], + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + + inputs = inputs.to("cuda") + + # Inference: Generation of the output + generated_ids = self.model.generate(**inputs, max_new_tokens=24000) + generated_ids_trimmed = [ + out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) + ] + response = self.processor.batch_decode( + generated_ids_trimmed, 
skip_special_tokens=True, clean_up_tokenization_spaces=False + )[0] + return response + + def _inference_with_vllm(self, image, prompt): + response = inference_with_vllm( + image, + prompt, + model_name=self.model_name, + ip=self.ip, + port=self.port, + temperature=self.temperature, + top_p=self.top_p, + max_completion_tokens=self.max_completion_tokens, + ) + return response + + def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None): + prompt = dict_promptmode_to_prompt[prompt_mode] + if prompt_mode == 'prompt_grounding_ocr': + assert bbox is not None + bboxes = [bbox] + bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0] + prompt = prompt + str(bbox) + return prompt + + # def post_process_results(self, response, prompt_mode, save_dir, save_name, origin_image, image, min_pixels, max_pixels) + def _parse_single_image( + self, + origin_image, + prompt_mode, + save_dir, + save_name, + source="image", + page_idx=0, + bbox=None, + fitz_preprocess=False, + ): + min_pixels, max_pixels = self.min_pixels, self.max_pixels + if prompt_mode == "prompt_grounding_ocr": + min_pixels = min_pixels or MIN_PIXELS # preprocess image to the final input + max_pixels = max_pixels or MAX_PIXELS + if min_pixels is not None: assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}" + if max_pixels is not None: assert max_pixels <= MAX_PIXELS, f"max_pixels should <+ {MAX_PIXELS}" + + if source == 'image' and fitz_preprocess: + image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi) + image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels) + else: + image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels) + input_height, input_width = smart_resize(image.height, image.width) + prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels) + if self.use_hf: + response = self._inference_with_hf(image, prompt) + else: + response = self._inference_with_vllm(image, prompt) + result = {'page_no': page_idx, + "input_height": input_height, + "input_width": input_width + } + if source == 'pdf': + save_name = f"{save_name}_page_{page_idx}" + if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']: + cells, filtered = post_process_output( + response, + prompt_mode, + origin_image, + image, + min_pixels=min_pixels, + max_pixels=max_pixels, + ) + if filtered and prompt_mode != 'prompt_layout_only_en': # model output json failed, use filtered process + json_file_path = os.path.join(save_dir, f"{save_name}.json") + with open(json_file_path, 'w', encoding="utf-8") as w: + json.dump(response, w, ensure_ascii=False, indent=4) + + image_layout_path = os.path.join(save_dir, f"{save_name}.jpg") + origin_image.save(image_layout_path) + result.update({ + 'layout_info_path': json_file_path, + 'layout_image_path': image_layout_path, + }) + + md_file_path = os.path.join(save_dir, f"{save_name}.md") + with open(md_file_path, "w", encoding="utf-8") as md_file: + md_file.write(cells) + result.update({ + 'md_content_path': md_file_path + }) + result.update({ + 'filtered': True + }) + else: + try: + image_with_layout = draw_layout_on_image(origin_image, cells) + except Exception as e: + print(f"Error drawing layout on image: {e}") + image_with_layout = origin_image + + json_file_path = os.path.join(save_dir, f"{save_name}.json") + with open(json_file_path, 
'w', encoding="utf-8") as w: + json.dump(cells, w, ensure_ascii=False, indent=4) + + image_layout_path = os.path.join(save_dir, f"{save_name}.jpg") + image_with_layout.save(image_layout_path) + result.update({ + 'layout_info_path': json_file_path, + 'layout_image_path': image_layout_path, + }) + if prompt_mode != "prompt_layout_only_en": # no text md when detection only + md_content = layoutjson2md(origin_image, cells, text_key='text') + md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True) # used for clean output or metric of omnidocbench、olmbench + md_file_path = os.path.join(save_dir, f"{save_name}.md") + with open(md_file_path, "w", encoding="utf-8") as md_file: + md_file.write(md_content) + md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md") + with open(md_nohf_file_path, "w", encoding="utf-8") as md_file: + md_file.write(md_content_no_hf) + result.update({ + 'md_content_path': md_file_path, + 'md_content_nohf_path': md_nohf_file_path, + }) + else: + image_layout_path = os.path.join(save_dir, f"{save_name}.jpg") + origin_image.save(image_layout_path) + result.update({ + 'layout_image_path': image_layout_path, + }) + + md_content = response + md_file_path = os.path.join(save_dir, f"{save_name}.md") + with open(md_file_path, "w", encoding="utf-8") as md_file: + md_file.write(md_content) + result.update({ + 'md_content_path': md_file_path, + }) + + return result + + def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False): + origin_image = fetch_image(input_path) + result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess) + result['file_path'] = input_path + return [result] + + def parse_pdf(self, input_path, filename, prompt_mode, save_dir): + print(f"loading pdf: {input_path}") + images_origin = load_images_from_pdf(input_path, dpi=self.dpi) + total_pages = len(images_origin) + tasks = [ + { + "origin_image": image, + "prompt_mode": prompt_mode, + "save_dir": save_dir, + "save_name": filename, + "source":"pdf", + "page_idx": i, + } for i, image in enumerate(images_origin) + ] + + def _execute_task(task_args): + return self._parse_single_image(**task_args) + + if self.use_hf: + num_thread = 1 + else: + num_thread = min(total_pages, self.num_thread) + print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...") + + results = [] + with ThreadPool(num_thread) as pool: + with tqdm(total=total_pages, desc="Processing PDF pages") as pbar: + for result in pool.imap_unordered(_execute_task, tasks): + results.append(result) + pbar.update(1) + + results.sort(key=lambda x: x["page_no"]) + for i in range(len(results)): + results[i]['file_path'] = input_path + return results + + def parse_file(self, + input_path, + output_dir="", + prompt_mode="prompt_layout_all_en", + bbox=None, + fitz_preprocess=False + ): + output_dir = output_dir or self.output_dir + output_dir = os.path.abspath(output_dir) + filename, file_ext = os.path.splitext(os.path.basename(input_path)) + save_dir = os.path.join(output_dir, filename) + os.makedirs(save_dir, exist_ok=True) + + if file_ext == '.pdf': + results = self.parse_pdf(input_path, filename, prompt_mode, save_dir) + elif file_ext in image_extensions: + results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess) + else: + raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} 
and pdf") + + print(f"Parsing finished, results saving to {save_dir}") + with open(os.path.join(output_dir, os.path.basename(filename)+'.jsonl'), 'w', encoding="utf-8") as w: + for result in results: + w.write(json.dumps(result, ensure_ascii=False) + '\n') + + return results + + + +def main(): + prompts = list(dict_promptmode_to_prompt.keys()) + parser = argparse.ArgumentParser( + description="dots.ocr Multilingual Document Layout Parser", + ) + + parser.add_argument( + "input_path", type=str, + help="Input PDF/image file path" + ) + + parser.add_argument( + "--output", type=str, default="./output", + help="Output directory (default: ./output)" + ) + + parser.add_argument( + "--prompt", choices=prompts, type=str, default="prompt_layout_all_en", + help="prompt to query the model, different prompts for different tasks" + ) + parser.add_argument( + '--bbox', + type=int, + nargs=4, + metavar=('x1', 'y1', 'x2', 'y2'), + help='should give this argument if you want to prompt_grounding_ocr' + ) + parser.add_argument( + "--ip", type=str, default="localhost", + help="" + ) + parser.add_argument( + "--port", type=int, default=8001, + help="vllm server port" + ) + parser.add_argument( + "--model_name", type=str, default="model", + help="" + ) + parser.add_argument( + "--temperature", type=float, default=0.1, + help="" + ) + parser.add_argument( + "--top_p", type=float, default=1.0, + help="" + ) + parser.add_argument( + "--dpi", type=int, default=200, + help="" + ) + parser.add_argument( + "--max_completion_tokens", type=int, default=16384, + help="" + ) + parser.add_argument( + "--num_thread", type=int, default=32, + help="" + ) + # parser.add_argument( + # "--fitz_preprocess", type=bool, default=False, + # help="False will use tikz dpi upsample pipeline, good for images which has been render with low dpi, but maybe result in higher computational costs" + # ) + parser.add_argument( + "--min_pixels", type=int, default=None, + help="" + ) + parser.add_argument( + "--max_pixels", type=int, default=None, + help="" + ) + parser.add_argument( + "--use_hf", type=bool, default=False, + help="" + ) + args = parser.parse_args() + + dots_ocr_parser = DotsOCRParser( + ip=args.ip, + port=args.port, + model_name=args.model_name, + temperature=args.temperature, + top_p=args.top_p, + max_completion_tokens=args.max_completion_tokens, + num_thread=args.num_thread, + dpi=args.dpi, + output_dir=args.output, + min_pixels=args.min_pixels, + max_pixels=args.max_pixels, + use_hf=args.use_hf, + ) + + result = dots_ocr_parser.parse_file( + args.input_path, + prompt_mode=args.prompt, + bbox=args.bbox, + ) + + + +if __name__ == "__main__": + main() diff --git a/rag_factory/parser/Parser_Dotsocr/readme.md b/rag_factory/parser/Parser_Dotsocr/readme.md new file mode 100644 index 0000000..d83b416 --- /dev/null +++ b/rag_factory/parser/Parser_Dotsocr/readme.md @@ -0,0 +1,47 @@ +# dots.ocr parser + +This parser is based on the [dots.ocr](https://github.com/rednote-hilab/dots.ocr) model. See [dots.ocr](https://github.com/rednote-hilab/dots.ocr) for details. + +## 1. 
Installation
+
+```
+python>=3.10
+pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu128
+pip install -r requirements.txt
+# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2
+
+```
+
+```
+# model download
+from huggingface_hub import snapshot_download
+snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir="Parser_Dotsocr/weights/DotsOCR", local_dir_use_symlinks=False, resume_download=True)
+
+# or
+
+from modelscope import snapshot_download
+snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir=model_dir)
+```
+
+## 2. vLLM inference
+
+Use vLLM for faster inference (based on vllm==0.9.1):
+
+```
+python vllm_launch.py --model_path weights/DotsOCR
+```
+
+## 3. Document parsing
+
+```
+python parser.py pdf_path.pdf (or pdfs_dir)
+```
+
+If you want to parse documents with transformers instead of vLLM, add `--use_hf=True`.
+
+## 4. Figure understanding
+
+Use a VL model to understand the content of parsed pictures. Run the PDF layout parsing step first to obtain its results.
+
+```
+python fig_recognize.py --output output
+```
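For reference, the `DotsOCRParser` class added in `parser.py` above can also be driven from Python rather than through the CLI commands in this readme. The snippet below is a minimal sketch, not part of the patch: it assumes the vLLM server from step 2 is already running on port 8001, that it is run from the `Parser_Dotsocr` directory so `parser.py` is importable, and that `example.pdf` stands in for a real input file.

```
# Minimal sketch of calling the parser programmatically (illustrative import path).
from parser import DotsOCRParser

ocr_parser = DotsOCRParser(ip="localhost", port=8001, output_dir="./output", use_hf=False)
results = ocr_parser.parse_file("example.pdf", prompt_mode="prompt_layout_all_en")
for page in results:
    # each per-page result records where the layout JSON and Markdown were written
    print(page["page_no"], page.get("md_content_path"))
```

`parse_file` also writes a `<filename>.jsonl` index of these per-page results under the output directory.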
diff --git a/rag_factory/parser/Parser_Dotsocr/vllm_launch.py b/rag_factory/parser/Parser_Dotsocr/vllm_launch.py
new file mode 100644
index 0000000..95d6fbc
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/vllm_launch.py
@@ -0,0 +1,65 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+import argparse
+
+def launch_vllm_server(hf_model_path="weights/DotsOCR", num_gpus="0", gpu_memory_utilization=0.95, port=8001):
+    # 1. Check the model path
+    model_path = Path(hf_model_path).resolve()
+    if not model_path.exists():
+        print(f"error: model path does not exist: {model_path}")
+        sys.exit(1)
+
+    # 2. Set environment variables
+    os.environ["hf_model_path"] = str(model_path)
+    os.environ["PYTHONPATH"] = f"{model_path.parent}:{os.environ.get('PYTHONPATH', '')}"
+    os.environ["CUDA_VISIBLE_DEVICES"] = num_gpus
+    os.environ["OPENAI_API_BASE"] = f"http://localhost:{port}/v1"
+    os.environ["OPENAI_API_KEY"] = "EMPTY"
+
+    # 3. Patch the vllm CLI to register the DotsOCR model
+    try:
+        vllm_path = subprocess.check_output(["which", "vllm"], text=True).strip()
+        with open(vllm_path, "r") as f:
+            vllm_content = f.read()
+
+        inject_line = "from DotsOCR import modeling_dots_ocr_vllm"
+        if inject_line not in vllm_content:
+            print("Patching the vllm CLI to import the DotsOCR model...")
+            sed_cmd = f"sed -i '/^from vllm\\.entrypoints\\.cli\\.main import main$/a\\{inject_line}' {vllm_path}"
+            subprocess.run(sed_cmd, shell=True, check=True)
+        else:
+            print("The vllm CLI already imports the DotsOCR model")
+
+    except subprocess.CalledProcessError as e:
+        print(f"error: failed to get the vllm path: {e}")
+        sys.exit(1)
+
+    # 4. Start the vLLM server
+    print("Starting the vLLM server...")
+    cmd = [
+        "vllm", "serve", str(model_path),
+        "--tensor-parallel-size", "1",
+        "--gpu-memory-utilization", str(gpu_memory_utilization),
+        "--chat-template-content-format", "string",
+        "--served-model-name", "model",
+        "--trust-remote-code",
+        "--port", str(port)
+    ]
+
+    try:
+        subprocess.run(cmd, check=True)
+    except subprocess.CalledProcessError as e:
+        print(f"error: vLLM failed to start: {e}")
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Launch vLLM server with dots_ocr model.")
+    parser.add_argument("--model_path", type=str, required=True, help="Path to your downloaded model weights. Please use a directory name without periods")
+    parser.add_argument("--gpus", type=str, default="0", help="GPU device ID(s)")
+    parser.add_argument("--gpu_memory_utilization", type=float, default=0.95, help="Desired GPU memory utilization percentage")
+    parser.add_argument("--port", type=int, default=8001, help="Port to launch the vLLM server on")
+
+    args = parser.parse_args()
+
+    launch_vllm_server(args.model_path, args.gpus, args.gpu_memory_utilization, args.port)
diff --git a/rag_factory/parser/__init__.py b/rag_factory/parser/__init__.py
index 22bd7c9..5595ce8 100644
--- a/rag_factory/parser/__init__.py
+++ b/rag_factory/parser/__init__.py
@@ -1,4 +1,3 @@
-from .Parser_Docling import DoclingParser
-from .Parser_MinerU import MineruParser
+from .Parser_MinerU import MinerUPdfParser
-__all__ = ["DoclingParser", "MineruParser"]
\ No newline at end of file
+__all__ = ["MinerUPdfParser"]
\ No newline at end of file
diff --git a/rag_factory/parser/main.py b/rag_factory/parser/main.py
deleted file mode 100644
index e0b6e02..0000000
--- a/rag_factory/parser/main.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from .Parser_MinerU import MineruParser
-from .Parser_Docling import DoclingParser
-import argparse
-
-
-def main():
-    """
-    Main function to run the document parser from command line
-    """
-    parser = argparse.ArgumentParser(
-        description="Parse documents using MinerU 2.0 or Docling"
-    )
-    parser.add_argument("file_path", help="Path to the document to parse")
-    parser.add_argument("--output", "-o", help="Output directory path")
-    parser.add_argument(
-        "--method",
-        "-m",
-        choices=["auto", "txt", "ocr"],
-        default="auto",
-        help="Parsing method (auto, txt, ocr)",
-    )
-    parser.add_argument(
-        "--lang",
-        "-l",
-        help="Document language for OCR optimization (e.g., ch, en, ja)",
-    )
-    parser.add_argument(
-        "--backend",
-        "-b",
-        choices=[
-            "pipeline",
-            "vlm-transformers",
-            "vlm-sglang-engine",
-            "vlm-sglang-client",
-        ],
-        default="pipeline",
-        help="Parsing backend",
-    )
-    parser.add_argument(
-        "--device",
-        "-d",
-        help="Inference device (e.g., cpu, cuda, cuda:0, npu, mps)",
-    )
-    parser.add_argument(
-        "--source",
-        choices=["huggingface", "modelscope", "local"],
-        default="huggingface",
-        help="Model source",
-    )
-    parser.add_argument(
-        "--no-formula",
-        action="store_true",
-        help="Disable formula parsing",
-    )
-    parser.add_argument(
-        "--no-table",
-        action="store_true",
-        help="Disable table parsing",
-    )
-    parser.add_argument(
-        "--stats", action="store_true", help="Display content statistics"
-    )
-    parser.add_argument(
-        "--check",
-        action="store_true",
-        help="Check parser installation",
-    )
-    parser.add_argument(
-        "--parser",
-        choices=["mineru", "docling"],
-        default="mineru",
-        help="Parser selection",
-    )
-    parser.add_argument(
-        "--vlm_url",
-        help="When the backend is `vlm-sglang-client`, you need to specify the server_url, for 
example:`http://127.0.0.1:30000`", - ) - - args = parser.parse_args() - - # Check installation if requested - if args.check: - doc_parser = DoclingParser() if args.parser == "docling" else MineruParser() - if doc_parser.check_installation(): - print(f"✅ {args.parser.title()} is properly installed") - return 0 - else: - print(f"❌ {args.parser.title()} installation check failed") - return 1 - - try: - # Parse the document - doc_parser = DoclingParser() if args.parser == "docling" else MineruParser() - content_list = doc_parser.parse_document( - file_path=args.file_path, - method=args.method, - output_dir=args.output, - lang=args.lang, - backend=args.backend, - device=args.device, - source=args.source, - formula=not args.no_formula, - table=not args.no_table, - vlm_url=args.vlm_url, - ) - - print(f"✅ Successfully parsed: {args.file_path}") - print(f"📊 Extracted {len(content_list)} content blocks") - - # Display statistics if requested - if args.stats: - print("\n📈 Document Statistics:") - print(f"Total content blocks: {len(content_list)}") - - # Count different types of content - content_types = {} - for item in content_list: - if isinstance(item, dict): - content_type = item.get("type", "unknown") - content_types[content_type] = content_types.get(content_type, 0) + 1 - - if content_types: - print("\n📋 Content Type Distribution:") - for content_type, count in sorted(content_types.items()): - print(f" • {content_type}: {count}") - - except Exception as e: - print(f"❌ Error: {str(e)}") - return 1 - - return 0 - - -if __name__ == "__main__": - exit(main()) \ No newline at end of file diff --git a/rag_factory/rerankers/Reranker_Base.py b/rag_factory/rerankers/Reranker_Base.py new file mode 100644 index 0000000..518b81e --- /dev/null +++ b/rag_factory/rerankers/Reranker_Base.py @@ -0,0 +1,27 @@ +from abc import ABC, abstractmethod +from ..Retrieval import Document +import warnings + +class RerankerBase(ABC): + """ + Reranker 基类,所有 Reranker 应该继承此类并实现 rerank 方法。 + 不建议直接实例化本类。 + + 使用方法: + class MyReranker(RerankerBase): + def rerank(self, query: str, documents: list[str], **kwargs) -> list[float]: + # 实现具体的重排序逻辑 + ... + """ + def __init__(self): + if type(self) is RerankerBase: + warnings.warn("RerankerBase 是抽象基类,不应直接实例化。请继承并实现 rerank 方法。", UserWarning) + + @abstractmethod + def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]: + """ + Rerank the documents based on the query. 
+ 需要子类实现。 + """ + warnings.warn("调用了未实现的 rerank 方法。请在子类中实现该方法。", UserWarning) + raise NotImplementedError("子类必须实现 rerank 方法。") diff --git a/rag_factory/rerankers/Reranker_Qwen3.py b/rag_factory/rerankers/Reranker_Qwen3.py new file mode 100644 index 0000000..b909c2b --- /dev/null +++ b/rag_factory/rerankers/Reranker_Qwen3.py @@ -0,0 +1,75 @@ +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from .Reranker_Base import RerankerBase +from ..Retrieval.RetrieverBase import Document + +class Qwen3Reranker(RerankerBase): + def __init__(self, model_name_or_path: str, max_length: int = 4096, instruction=None, attn_type='causal', device_id="cuda:0", **kwargs): + super().__init__() + device = torch.device(device_id) + self.max_length = max_length + self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left') + self.lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16) + self.lm = self.lm.to(device).eval() + self.token_false_id = self.tokenizer.convert_tokens_to_ids("no") + self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes") + self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n" + self.suffix = "<|im_end|>\n<|im_start|>assistant\n\n\n\n\n" + self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False) + self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False) + self.instruction = instruction or "Given the user query, retrieval the relevant passages" + self.device = device + + def format_instruction(self, instruction, query, doc): + if instruction is None: + instruction = self.instruction + output = f": {instruction}\n: {query}\n: {doc}" + return output + + def process_inputs(self, pairs): + out = self.tokenizer( + pairs, padding=False, truncation='longest_first', + return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens) + ) + for i, ele in enumerate(out['input_ids']): + out['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens + out = self.tokenizer.pad(out, padding=True, return_tensors="pt", max_length=self.max_length) + for key in out: + out[key] = out[key].to(self.lm.device) + return out + + @torch.no_grad() + def compute_logits(self, inputs, **kwargs): + batch_scores = self.lm(**inputs).logits[:, -1, :] + true_vector = batch_scores[:, self.token_true_id] + false_vector = batch_scores[:, self.token_false_id] + batch_scores = torch.stack([false_vector, true_vector], dim=1) + batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1) + scores = batch_scores[:, 1].exp().tolist() + return scores + + def compute_scores(self, pairs, instruction=None, **kwargs): + pairs = [self.format_instruction(instruction, query, doc) for query, doc in pairs] + inputs = self.process_inputs(pairs) + scores = self.compute_logits(inputs) + return scores + + def rerank(self, query: str, documents: list[Document], k: int = None, batch_size: int = 8, **kwargs) -> list[Document]: + # 1. 组装 (query, doc.content) 对 + pairs = [(query, doc.content) for doc in documents] + + # 2. 计算分数 + all_scores = [] + for i in range(0, len(pairs), batch_size): + batch_pairs = pairs[i:i+batch_size] + batch_scores = self.compute_scores(batch_pairs) + all_scores.extend(batch_scores) + scores = all_scores + + # 3. 
按分数排序 + doc_score_pairs = list(zip(documents, scores)) + doc_score_pairs.sort(key=lambda x: x[1], reverse=True) + reranked_docs = [doc for doc, score in doc_score_pairs] + if k is not None: + reranked_docs = reranked_docs[:k] + return reranked_docs \ No newline at end of file diff --git a/rag_factory/rerankers/__init__.py b/rag_factory/rerankers/__init__.py index e69de29..be8f542 100644 --- a/rag_factory/rerankers/__init__.py +++ b/rag_factory/rerankers/__init__.py @@ -0,0 +1,5 @@ +from .Reranker_Base import RerankerBase +from .Reranker_Qwen3 import Qwen3Reranker +from .registry import RerankerRegistry + +__all__ = ["RerankerBase", "Qwen3Reranker", "RerankerRegistry"] \ No newline at end of file diff --git a/rag_factory/rerankers/registry.py b/rag_factory/rerankers/registry.py new file mode 100644 index 0000000..d817c4a --- /dev/null +++ b/rag_factory/rerankers/registry.py @@ -0,0 +1,93 @@ +from typing import Dict, Type, Any, Optional, List +import logging +from .Reranker_Base import RerankerBase +from .Reranker_Qwen3 import Qwen3Reranker + +class RerankerRegistry: + _rerankers: Dict[str, Type[RerankerBase]] = {} + + @classmethod + def register(cls, name: str, reranker_class: Type[RerankerBase]): + """注册重排序器类到注册表中 + + Args: + name: 重排序器的名称 + reranker_class: 重排序器类,必须继承自RerankerBase + """ + if not issubclass(reranker_class, RerankerBase): + raise ValueError(f"重排序器类 {reranker_class} 必须继承自 RerankerBase") + + if name in cls._rerankers: + logging.warning(f"重排序器 '{name}' 已存在,将被覆盖") + + cls._rerankers[name] = reranker_class + logging.info(f"成功注册重排序器: {name}") + + @classmethod + def create(cls, name: str, **kwargs) -> RerankerBase: + """根据名称获取重排序器实例 + + Args: + name: 重排序器名称 + **kwargs: 传递给重排序器构造函数的参数 + + Returns: + RerankerBase: 重排序器实例 + + Raises: + ValueError: 当重排序器未注册时抛出 + """ + if name not in cls._rerankers: + available = list(cls._rerankers.keys()) + raise ValueError(f"未找到重排序器 '{name}'。可用的重排序器: {available}") + + reranker_class = cls._rerankers[name] + return reranker_class(**kwargs) + + @classmethod + def list_rerankers(cls) -> List[str]: + """获取所有已注册的重排序器名称列表 + + Returns: + List[str]: 重排序器名称列表 + """ + return list(cls._rerankers.keys()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """检查重排序器是否已注册 + + Args: + name: 重排序器名称 + + Returns: + bool: 如果已注册返回True,否则返回False + """ + return name in cls._rerankers + + @classmethod + def unregister(cls, name: str) -> bool: + """注销重排序器 + + Args: + name: 要注销的重排序器名称 + + Returns: + bool: 如果成功注销返回True,如果重排序器不存在返回False + """ + if name in cls._rerankers: + del cls._rerankers[name] + logging.info(f"成功注销重排序器: {name}") + return True + else: + logging.warning(f"尝试注销不存在的重排序器: {name}") + return False + + @classmethod + def clear_all(cls): + """清除所有已注册的重排序器""" + cls._rerankers.clear() + logging.info("已清除所有重排序器") + +# 注册默认的重排序器 +RerankerRegistry.register("qwen3", Qwen3Reranker) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index ac808e5..823e164 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,4 +12,22 @@ neo4j aioboto3 llama-index llama-index-core -peewee \ No newline at end of file +peewee + +mineru[core] +rank_bm25 +faiss_gpu + + + +# streamlit +PyMuPDF +openai +qwen_vl_utils +transformers==4.51.3 +huggingface_hub +modelscope +flash-attn==2.8.0.post2 +# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2 +accelerate +dashscope
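To make the reranker registry above concrete, here is a minimal usage sketch. It is not part of the patch: the model path is a placeholder, the documents are toy examples, and it assumes the package is importable as `rag_factory` with a CUDA device available; the `"qwen3"` name is the one registered in `rag_factory/rerankers/registry.py`.

```
from rag_factory.rerankers import RerankerRegistry
from rag_factory.Retrieval import Document

# "qwen3" is registered by rag_factory/rerankers/registry.py at import time
reranker = RerankerRegistry.create(
    "qwen3",
    model_name_or_path="/path/to/qwen_reranker_0.6B",  # placeholder path
    device_id="cuda:0",
)

docs = [
    Document(content="Retrieval-augmented generation combines a retriever with an LLM.", metadata={}),
    Document(content="A paragraph about an unrelated topic.", metadata={}),
]

# rerank() returns the documents sorted by relevance score, truncated to the top k
top_docs = reranker.rerank("What is retrieval-augmented generation?", docs, k=1, batch_size=8)
print(top_docs[0].content)
```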