diff --git a/examples/TCL_rag/config.yaml b/examples/TCL_rag/config.yaml
new file mode 100644
index 0000000..231b3a8
--- /dev/null
+++ b/examples/TCL_rag/config.yaml
@@ -0,0 +1,35 @@
+llm:
+ name: openai
+ base_url: "https://api.gptsapi.net/v1"
+ api_key: "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2"
+ model: "gpt-4.1-mini"
+
+embedding:
+ name: huggingface
+ model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B"
+ model_kwargs:
+ device: "cuda:0"
+
+store:
+ name: faiss
+ folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store
+
+
+bm25:
+ name: bm25
+ k: 10
+ data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
+
+retriever:
+ name: vectorstore
+
+reranker:
+ name: qwen3
+ model_name_or_path: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_reranker_0.6B"
+ device_id: "cuda:0"
+
+dataset:
+ name: TCL
+
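+# 说明:以上 llm/embedding/store/bm25/retriever/reranker 各段对应
+# examples/TCL_rag/rag_flow.py 中 TCL_RAG 构造函数的同名 *_config 参数,
+# 由 examples/TCL_rag/test.py 读取本文件后逐段传入;dataset 段在 test.py 中暂未使用。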
diff --git a/examples/TCL_rag/rag_flow.py b/examples/TCL_rag/rag_flow.py
new file mode 100644
index 0000000..1e5f9c9
--- /dev/null
+++ b/examples/TCL_rag/rag_flow.py
@@ -0,0 +1,85 @@
+import sys
+import os
+
+# 添加 RAG-Factory 目录到 Python 路径
+rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
+sys.path.insert(0, rag_factory_path)
+
+from rag_factory.llms import LLMRegistry
+from rag_factory.Embed import EmbeddingRegistry
+from rag_factory.Store import VectorStoreRegistry
+from rag_factory.Retrieval import RetrieverRegistry, Document
+from rag_factory.rerankers import RerankerRegistry
+from typing import List
+import json
+
+
+class TCL_RAG:
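+    """TCL 行业文档问答的示例 RAG 流程。
+
+    由 BM25 与向量库检索组成多路召回,结果经 RRF 融合后交给 reranker 精排,
+    最后将召回材料拼入提示词,由 LLM 生成答案;各组件均通过注册表按配置创建。
+    """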
+ def __init__(
+ self,
+ *,
+ llm_config=None,
+ embedding_config=None,
+ vector_store_config=None,
+ bm25_retriever_config=None,
+ retriever_config=None,
+ reranker_config=None,
+ ):
+ llm_config = llm_config or {}
+ embedding_config = embedding_config or {}
+ vector_store_config = vector_store_config or {}
+ bm25_retriever_config = bm25_retriever_config or {}
+ retriever_config = retriever_config or {}
+ reranker_config = reranker_config or {}
+ self.llm = LLMRegistry.create(**llm_config)
+ self.embedding = EmbeddingRegistry.create(**embedding_config)
+ self.vector_store = VectorStoreRegistry.load(**vector_store_config, embedding=self.embedding)
+ self.bm25_retriever = RetrieverRegistry.create(**bm25_retriever_config)
+        self.bm25_retriever = self.bm25_retriever.from_documents(
+            documents=self._load_data(bm25_retriever_config["data_path"]),
+            preprocess_func=self.chinese_preprocessing_func,
+            k=bm25_retriever_config["k"],
+        )
+
+ self.retriever = RetrieverRegistry.create(**retriever_config, vectorstore=self.vector_store)
+ self.multi_path_retriever = RetrieverRegistry.create("multipath", retrievers=[self.bm25_retriever, self.retriever])
+ self.reranker = RerankerRegistry.create(**reranker_config)
+
+ def invoke(self, query: str, k: int = None):
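+        """多路召回:BM25 与向量检索的结果经 RRF 融合,返回前 k 条文档。"""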
+ return self.multi_path_retriever.invoke(query, top_k=k)
+
+ def rerank(self, query: str, documents: List[Document], k: int = None, batch_size: int = 8):
+ return self.reranker.rerank(query, documents, k, batch_size)
+
+ def _load_data(self, data_path: str):
+ with open(data_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ docs = []
+ for item in data:
+ content = item.get("full_content", "")
+ metadata = {"title": item.get("original_filename", "")}
+ docs.append(Document(content=content, metadata=metadata))
+ return docs
+
+    def chinese_preprocessing_func(self, text: str) -> List[str]:
+        # 返回 jieba 分词后的词语列表;BM25Okapi 需要 token 序列而非拼接后的字符串
+        import jieba
+        return list(jieba.cut(text))
+
+
+ def answer(self, query: str, documents: List[Document]):
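+        """将检索到的文档内容拼接为上下文,连同问题交给 LLM 生成最终回答。"""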
+
+ template = (
+ "你是一位工业领域的专家。根据以下检索到的材料回答用户问题。"
+ "如果回答所需信息未在材料中出现,请说明无法找到相关信息。\n\n"
+ "{context}\n\n"
+ "用户问题:{question}\n"
+ "答复:"
+ )
+ context = "\n".join([doc.content for doc in documents])
+ prompt = template.format(question=query, context=context)
+ messages = [
+ {"role": "system", "content": "你是一位工业领域的专家。"},
+ {"role": "user", "content": prompt}
+ ]
+ return self.llm.chat(messages)
+
+
+
+
diff --git a/examples/TCL_rag/test.py b/examples/TCL_rag/test.py
new file mode 100644
index 0000000..9607f3f
--- /dev/null
+++ b/examples/TCL_rag/test.py
@@ -0,0 +1,32 @@
+from rag_flow import TCL_RAG
+import yaml
+
+# 加载配置文件
+with open('/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/TCL_rag/config.yaml', 'r', encoding='utf-8') as f:
+ config = yaml.safe_load(f)
+
+llm_config = config['llm']
+embedding_config = config['embedding']
+reranker_config = config['reranker']
+bm25_retriever_config = config['bm25']
+retriever_config = config['retriever']
+vector_store_config = config['store']
+
+
+
+
+if __name__ == "__main__":
+
+ rag = TCL_RAG(llm_config=llm_config,
+ embedding_config=embedding_config,
+ reranker_config=reranker_config,
+ retriever_config=retriever_config,
+ vector_store_config=vector_store_config,
+ bm25_retriever_config=bm25_retriever_config)
+
+ result = rag.invoke("毛细管设计规范按照什么标准",k=20)
+
+ answer = rag.answer("毛细管设计规范按照什么标准",result)
+
+
+ print(answer)
\ No newline at end of file
diff --git a/examples/bm25/config.yaml b/examples/bm25/config.yaml
new file mode 100644
index 0000000..f52f8af
--- /dev/null
+++ b/examples/bm25/config.yaml
@@ -0,0 +1,3 @@
+retriever:
+ name: bm25
+ k: 8
\ No newline at end of file
diff --git a/examples/bm25/main.py b/examples/bm25/main.py
new file mode 100644
index 0000000..03fa291
--- /dev/null
+++ b/examples/bm25/main.py
@@ -0,0 +1,36 @@
+import sys
+import os
+
+rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
+sys.path.insert(0, rag_factory_path)
+
+import json
+from rag_factory.Retrieval import Document
+from rag_factory.Retrieval import RetrieverRegistry
+
+import yaml
+from typing import List
+
+
+def load_data(jsonl_path: str):
+ with open(jsonl_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+ docs = []
+ for item in data:
+ content = item.get("full_content", "")
+ metadata = {"title": item.get("original_title", "")}
+ docs.append(Document(content=content, metadata=metadata))
+ return docs
+
+def chinese_preprocessing_func(text: str) -> List[str]:
+    # 返回 jieba 分词后的词语列表;BM25Okapi 需要 token 序列而非拼接后的字符串
+    import jieba
+    return list(jieba.cut(text))
+
+if __name__ == "__main__":
+ docs = load_data("/data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json")
+ with open("/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/bm25/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+ bm25_retriever = RetrieverRegistry.create(**config["retriever"])
+ bm25_retriever = bm25_retriever.from_documents(documents=docs, preprocess_func=chinese_preprocessing_func, k=config["retriever"]["k"])
+
+ print(bm25_retriever.invoke("什么是TCL?"))
\ No newline at end of file
diff --git a/examples/faiss_construct/config.yaml b/examples/faiss_construct/config.yaml
new file mode 100644
index 0000000..d429d1a
--- /dev/null
+++ b/examples/faiss_construct/config.yaml
@@ -0,0 +1,14 @@
+store:
+ name: faiss # 数据库
+ folder_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store # 保存路径
+
+
+embedding:
+ name: huggingface # 嵌入模型
+ model_name: "/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B" # 模型路径
+ model_kwargs:
+ device: "cuda:1" # 设备
+
+dataset:
+ name: TCL
+ data_path: /data/FinAi_Mapping_Knowledge/chenmingzhen/tog3_backend/TCL/syn_table_data/data_all_clearn_short_chunk_with_caption_desc.json
\ No newline at end of file
diff --git a/examples/faiss_construct/faiss_constructor.py b/examples/faiss_construct/faiss_constructor.py
new file mode 100644
index 0000000..cb7c0c5
--- /dev/null
+++ b/examples/faiss_construct/faiss_constructor.py
@@ -0,0 +1,43 @@
+import sys
+import os
+
+# 添加 RAG-Factory 目录到 Python 路径
+rag_factory_path = os.path.join(os.path.dirname(__file__), "..", "..")
+sys.path.insert(0, rag_factory_path)
+
+from rag_factory.Store import VectorStoreRegistry
+from rag_factory.Embed import EmbeddingRegistry
+import yaml
+from rag_factory.Retrieval import Document
+import json
+
+
+with open("/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/examples/faiss_construct/config.yaml", "r", encoding="utf-8") as f:
+ config = yaml.safe_load(f)
+
+store_config = config["store"]
+embedding_config = config["embedding"]
+dataset_config = config["dataset"]["data_path"]
+embedding = EmbeddingRegistry.create(**embedding_config)
+store = VectorStoreRegistry.create(**store_config, embedding=embedding)
+
+
+if __name__ == "__main__":
+
+ # 读取数据
+ with open(dataset_config, "r", encoding="utf-8") as f:
+ docs = []
+ data = json.load(f)
+ for item in data:
+ full_content = item.get("full_content", "")
+ metadata = {
+ "title": item.get("original_filename"),
+ }
+
+ docs.append(Document(content=full_content, metadata=metadata))
+
+ # 创建向量库
+ vectorstore = store.from_documents(docs, embedding=embedding)
+
+ # 保存到本地
+ vectorstore.save_local(store_config["folder_path"])
\ No newline at end of file
diff --git a/rag_factory/Embed/Embedding_Base.py b/rag_factory/Embed/Embedding_Base.py
index 2af5bb2..c1527d3 100644
--- a/rag_factory/Embed/Embedding_Base.py
+++ b/rag_factory/Embed/Embedding_Base.py
@@ -2,12 +2,13 @@
from dataclasses import dataclass
import asyncio
from concurrent.futures import ThreadPoolExecutor
+from typing import List
class Embeddings(ABC):
"""嵌入接口"""
@abstractmethod
- def embed_documents(self, texts: list[str]) -> list[list[float]]:
+ def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed search docs.
Args:
@@ -19,7 +20,7 @@ def embed_documents(self, texts: list[str]) -> list[list[float]]:
pass
@abstractmethod
- def embed_query(self, text: str) -> list[float]:
+ def embed_query(self, text: str) -> List[float]:
"""Embed query text.
Args:
@@ -30,7 +31,7 @@ def embed_query(self, text: str) -> list[float]:
"""
pass
- async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
+ async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronous Embed search docs.
Args:
@@ -43,7 +44,7 @@ async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
ThreadPoolExecutor(), self.embed_documents, texts
)
- async def aembed_query(self, text: str) -> list[float]:
+ async def aembed_query(self, text: str) -> List[float]:
"""Asynchronous Embed query text.
Args:
diff --git a/rag_factory/Embed/__init__.py b/rag_factory/Embed/__init__.py
index 209132e..dc3fe8f 100644
--- a/rag_factory/Embed/__init__.py
+++ b/rag_factory/Embed/__init__.py
@@ -1,4 +1,5 @@
from .Embedding_Base import Embeddings
from .Embedding_Huggingface import HuggingFaceEmbeddings
+from .registry import EmbeddingRegistry
-__all__ = ["Embeddings", "HuggingFaceEmbeddings"]
\ No newline at end of file
+__all__ = ["Embeddings", "HuggingFaceEmbeddings", "EmbeddingRegistry"]
\ No newline at end of file
diff --git a/rag_factory/Embed/registry.py b/rag_factory/Embed/registry.py
new file mode 100644
index 0000000..0899688
--- /dev/null
+++ b/rag_factory/Embed/registry.py
@@ -0,0 +1,79 @@
+from typing import Dict, Type, Any, Optional, List
+import logging
+from .Embedding_Huggingface import HuggingFaceEmbeddings
+from .Embedding_Base import Embeddings
+
+class EmbeddingRegistry:
+ """嵌入模型注册器,用于管理和创建不同类型的嵌入模型"""
+ _embeddings: Dict[str, Type[Embeddings]] = {}
+
+ @classmethod
+ def register(cls, name: str, embedding_class: Type[Embeddings]):
+ """注册嵌入模型类
+
+ Args:
+ name: 模型名称
+ embedding_class: 嵌入模型类
+ """
+ cls._embeddings[name] = embedding_class
+
+ @classmethod
+ def create(cls, name: str, **kwargs) -> Embeddings:
+ """获取嵌入模型实例
+
+ Args:
+ name: 模型名称
+ **kwargs: 模型初始化参数
+
+ Returns:
+ 嵌入模型实例
+
+ Raises:
+ ValueError: 当模型名称不存在时
+ """
+ if name not in cls._embeddings:
+ available_embeddings = list(cls._embeddings.keys())
+ raise ValueError(f"嵌入模型 '{name}' 未注册。可用的模型: {available_embeddings}")
+
+ embedding_class = cls._embeddings[name]
+ return embedding_class(**kwargs)
+
+ @classmethod
+ def list_embeddings(cls) -> List[str]:
+ """列出所有已注册的嵌入模型名称
+
+ Returns:
+ 已注册的模型名称列表
+ """
+ return list(cls._embeddings.keys())
+
+ @classmethod
+ def is_registered(cls, name: str) -> bool:
+ """检查模型是否已注册
+
+ Args:
+ name: 模型名称
+
+ Returns:
+ 如果已注册返回True,否则返回False
+ """
+ return name in cls._embeddings
+
+ @classmethod
+ def unregister(cls, name: str) -> bool:
+ """取消注册模型
+
+ Args:
+ name: 模型名称
+
+ Returns:
+ 成功取消注册返回True,模型不存在返回False
+ """
+ if name in cls._embeddings:
+ del cls._embeddings[name]
+ return True
+ return False
+
+
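+# 用法示意(模型路径与设备仅为示例,请按实际环境替换):
+#   embedding = EmbeddingRegistry.create(
+#       "huggingface",
+#       model_name="/path/to/embedding_model",
+#       model_kwargs={"device": "cuda:0"},
+#   )
+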
+# 注册默认的嵌入模型
+EmbeddingRegistry.register("huggingface", HuggingFaceEmbeddings)
\ No newline at end of file
diff --git a/rag_factory/Retrieval/Retriever/Retriever_BM25.py b/rag_factory/Retrieval/Retriever/Retriever_BM25.py
index 278b55e..98eef3e 100644
--- a/rag_factory/Retrieval/Retriever/Retriever_BM25.py
+++ b/rag_factory/Retrieval/Retriever/Retriever_BM25.py
@@ -9,12 +9,12 @@
from pydantic import ConfigDict, Field, model_validator
logger = logging.getLogger(__name__)
-
+import numpy as np
from rag_factory.Retrieval.RetrieverBase import BaseRetriever, Document
def default_preprocessing_func(text: str) -> List[str]:
- """默认的文本预处理函数
+ """默认的文本预处理函数,仅在英文文本上有效
Args:
text: 输入文本
@@ -25,33 +25,51 @@ def default_preprocessing_func(text: str) -> List[str]:
return text.split()
-def chinese_preprocessing_func(text: str) -> List[str]:
- """中文文本预处理函数
-
- Args:
- text: 输入的中文文本
-
- Returns:
- 分词后的词语列表
- """
- try:
- import jieba
- return list(jieba.cut(text))
- except ImportError:
- logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba")
- return text.split()
class BM25Retriever(BaseRetriever):
- """BM25 检索器实现
-
- 基于 BM25 算法的文档检索器。
- 使用 rank_bm25 库实现高效的 BM25 搜索。
-
- 注意:BM25 算法适用于相对静态的文档集合。虽然支持动态添加/删除文档,
- 但每次操作都会重建整个索引,在大型文档集合上可能有性能问题。
- 对于频繁更新的场景,建议使用 VectorStoreRetriever。
-
+ """
+ BM25Retriever 是一个基于 BM25 算法的文档检索器,适用于信息检索、问答系统、知识库等场景下的高效文本相关性排序。
+
+ 该类通过集成 rank_bm25 库,实现了对文档集合的 BM25 检索,支持文档的动态添加、删除、批量构建索引等操作。
+ 适合文档集合相对静态、检索速度要求较高的场景。对于频繁增删文档的场景,建议使用向量检索(如 VectorStoreRetriever)。
+
+ 主要特性:
+ - 支持从文本列表或 Document 对象列表快速构建 BM25 检索器。
+ - 支持自定义分词/预处理函数,适配不同语言和分词需求。
+ - 支持动态添加、删除文档(每次操作会重建索引,适合中小规模数据集)。
+ - 可获取检索分数、top-k 文档及分数、检索器配置信息等。
+ - 兼容异步文档添加/删除,便于大规模数据处理。
+ - 通过 Pydantic 校验参数,保证配置安全。
+
+ 主要参数:
+ vectorizer (Any): BM25 向量化器实例(通常为 BM25Okapi)。
+ docs (List[Document]): 当前检索器持有的文档对象列表。
+ k (int): 默认返回的相关文档数量。
+ preprocess_func (Callable): 文本分词/预处理函数,默认为空格分词。
+ bm25_params (Dict): 传递给 BM25Okapi 的参数(如 k1、b 等)。
+
+ 核心方法:
+ - from_texts/from_documents: 从原始文本或 Document 构建检索器。
+ - _get_relevant_documents: 检索与查询最相关的前 k 个文档。
+ - get_scores: 获取查询对所有文档的 BM25 分数。
+ - get_top_k_with_scores: 获取 top-k 文档及其分数。
+ - add_documents/delete_documents: 动态增删文档并重建索引。
+ - get_bm25_info: 获取检索器配置信息和统计。
+ - update_k: 动态调整返回文档数量。
+
+ 性能注意事项:
+ - 每次添加/删除文档都会重建 BM25 索引,适合文档量较小或更新不频繁的场景。
+ - 文档量较大或频繁更新时,建议使用向量检索方案。
+ - 支持异步操作,便于大规模数据处理。
+
+ 典型用法:
+ >>> retriever = BM25Retriever.from_texts(["文本1", "文本2"], k=3)
+    >>> results = retriever.invoke("查询语句")
+ >>> retriever.add_documents([Document(content="新文档")])
+ >>> retriever.delete_documents(ids=["doc_id"])
+ >>> info = retriever.get_bm25_info()
+
Attributes:
vectorizer: BM25 向量化器实例
docs: 文档列表
@@ -95,7 +113,7 @@ def __init__(self, **kwargs):
# 设置属性
self.vectorizer = kwargs.get('vectorizer')
self.docs = kwargs.get('docs', [])
- self.k = kwargs.get('k', 4)
+ self.k = kwargs.get('k', 5)
self.preprocess_func = kwargs.get('preprocess_func', default_preprocessing_func)
self.bm25_params = kwargs.get('bm25_params', {})
@@ -125,7 +143,7 @@ def validate_params(cls, values: Dict[str, Any]) -> Dict[str, Any]:
Returns:
验证后的值
"""
- k = values.get("k", 4)
+ k = values.get("k", 5)
if k <= 0:
raise ValueError(f"k 必须大于 0,当前值: {k}")
@@ -139,7 +157,6 @@ def from_texts(
ids: Optional[Iterable[str]] = None,
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
- k: int = 4,
**kwargs: Any,
) -> "BM25Retriever":
"""从文本列表创建 BM25Retriever
@@ -150,7 +167,6 @@ def from_texts(
ids: ID列表,可选
bm25_params: BM25 算法参数,可选
preprocess_func: 预处理函数
- k: 返回文档数量
**kwargs: 其他参数
Returns:
@@ -213,7 +229,6 @@ def from_texts(
return cls(
vectorizer=vectorizer,
docs=docs,
- k=k,
preprocess_func=preprocess_func,
bm25_params=bm25_params,
**kwargs
@@ -225,7 +240,6 @@ def from_documents(
documents: Iterable[Document],
bm25_params: Optional[Dict[str, Any]] = None,
preprocess_func: Callable[[str], List[str]] = default_preprocessing_func,
- k: int = 4,
**kwargs: Any,
) -> "BM25Retriever":
"""从文档列表创建 BM25Retriever
@@ -234,7 +248,6 @@ def from_documents(
documents: 文档列表
bm25_params: BM25 算法参数,可选
preprocess_func: 预处理函数
- k: 返回文档数量
**kwargs: 其他参数
Returns:
@@ -254,51 +267,52 @@ def from_documents(
ids=ids,
bm25_params=bm25_params,
preprocess_func=preprocess_func,
- k=k,
**kwargs,
)
def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
- """获取与查询相关的文档
-
+ """获取与查询相关的前k个文档
+
Args:
query: 查询字符串
**kwargs: 其他参数,可能包含 'k' 来覆盖默认的返回数量
-
+
Returns:
相关文档列表
-
+
Raises:
ValueError: 如果向量化器未初始化
"""
if self.vectorizer is None:
raise ValueError("BM25 向量化器未初始化")
-
+
if not self.docs:
logger.warning("文档列表为空,返回空结果")
return []
-
+
# 获取返回文档数量
k = kwargs.get('k', self.k)
k = min(k, len(self.docs)) # 确保不超过总文档数
-
+
try:
# 预处理查询
processed_query = self.preprocess_func(query)
logger.debug(f"预处理后的查询: {processed_query}")
+
+ # 获取所有文档的分数
+ scores = self.vectorizer.get_scores(processed_query)
+ # 获取分数最高的前k个文档索引
- # 获取相关文档
- relevant_docs = self.vectorizer.get_top_n(
- processed_query, self.docs, n=k
- )
-
- logger.debug(f"找到 {len(relevant_docs)} 个相关文档")
- return relevant_docs
-
+ top_indices = np.argsort(scores)[::-1][:k]
+ # 返回前k个文档
+ top_docs = [self.docs[idx] for idx in top_indices]
+ logger.debug(f"找到 {len(top_docs)} 个相关文档")
+ return top_docs
+
except Exception as e:
logger.error(f"BM25 搜索时发生错误: {e}")
raise
-
+
def get_scores(self, query: str) -> List[float]:
"""获取查询对所有文档的 BM25 分数
diff --git a/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py b/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py
new file mode 100644
index 0000000..cae3923
--- /dev/null
+++ b/rag_factory/Retrieval/Retriever/Retriever_MultiPath.py
@@ -0,0 +1,164 @@
+import logging
+from typing import List, Dict, Any, Optional, Union
+
+from ..RetrieverBase import Document
+from ..RetrieverBase import BaseRetriever
+from ..utils.Fusion import FusionMethod, RRFusion, RetrievalResult
+
+logger = logging.getLogger(__name__)
+
+
+class MultiPathRetriever(BaseRetriever):
+ """
+ 多路检索器
+
+ 该类实现了多路检索功能,可以同时使用多个检索器进行文档检索,
+ 并通过指定的融合方法将多个检索器的结果进行合并和排序。
+
+ Attributes:
+ retrievers (List[BaseRetriever]): 检索器列表,每个检索器需要实现retrieve方法
+ fusion_method (FusionMethod): 融合方法,用于合并多个检索器的结果
+ top_k_per_retriever (int): 每个检索器返回的结果数量
+ """
+
+ def __init__(self,
+ retrievers: List[BaseRetriever],
+ fusion_method: Optional[FusionMethod] = None,
+ top_k_per_retriever: int = 50):
+ """
+ 初始化多路检索器
+
+ Args:
+ retrievers (List[BaseRetriever]): 检索器列表,每个检索器需要实现retrieve方法
+ fusion_method (Optional[FusionMethod]): 融合方法,默认为RRF (Reciprocal Rank Fusion)
+ top_k_per_retriever (int): 每个检索器返回的结果数量,默认为50
+ """
+ self.retrievers = retrievers
+ self.fusion_method = fusion_method or RRFusion()
+ self.top_k_per_retriever = top_k_per_retriever
+
+ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
+ """
+ 获取与查询相关的文档
+
+ 该方法会调用所有配置的检索器,获取每个检索器的检索结果,
+ 然后使用指定的融合方法将所有结果进行合并和排序。
+
+ Args:
+ query (str): 查询字符串
+ **kwargs (Any): 其他参数,包括top_k等
+
+ Returns:
+ List[Document]: 融合后的相关文档列表,按相关性排序
+
+ Note:
+ - 每个检索器的结果会被转换为RetrievalResult格式
+ - 支持多种输入格式:Document对象、字典格式、字符串等
+ - 融合后的结果会将score和rank信息保存在Document的metadata中
+ """
+ top_k = kwargs.get('top_k', 10)
+
+ # 从每个检索器获取结果
+ all_results = []
+ for retriever in self.retrievers:
+ try:
+ # 使用BaseRetriever的invoke方法
+ documents = retriever.invoke(query, **{**kwargs, 'k': self.top_k_per_retriever})
+
+ # 转换为RetrievalResult格式
+ formatted_results = []
+ for i, doc in enumerate(documents):
+ if isinstance(doc, Document):
+ # 如果是Document对象
+ retrieval_result = RetrievalResult(
+ document=doc,
+ score=getattr(doc, 'score', 1.0),
+ rank=i + 1
+ )
+ elif isinstance(doc, dict):
+ # 如果返回的是字典格式,需要转换为Document对象
+ content = doc.get('content', '')
+ metadata = doc.get('metadata', {})
+ doc_id = doc.get('id')
+
+ document = Document(
+ content=content,
+ metadata=metadata,
+ id=doc_id
+ )
+
+ retrieval_result = RetrievalResult(
+ document=document,
+ score=doc.get('score', 1.0),
+ rank=i + 1
+ )
+ else:
+ # 如果是字符串或其他格式,转换为Document对象
+ document = Document(
+ content=str(doc),
+ metadata={},
+ id=None
+ )
+
+ retrieval_result = RetrievalResult(
+ document=document,
+ score=1.0,
+ rank=i + 1
+ )
+ formatted_results.append(retrieval_result)
+
+ all_results.append(formatted_results)
+
+ except Exception as e:
+ print(f"检索器 {type(retriever).__name__} 执行失败: {e}")
+ all_results.append([])
+
+ # 使用融合方法合并结果
+ if not all_results or all(len(results) == 0 for results in all_results):
+ return []
+
+ fused_results = self.fusion_method.fuse(all_results, top_k)
+
+ # 转换回Document格式
+ documents = []
+ for result in fused_results:
+ doc = result.document
+ # 将score和rank添加到metadata中以便保留
+ if doc.metadata is None:
+ doc.metadata = {}
+ doc.metadata['score'] = result.score
+ doc.metadata['rank'] = result.rank
+ documents.append(doc)
+
+ return documents
+
+
+ def add_retriever(self, retriever: BaseRetriever):
+ """
+ 添加新的检索器到多路检索器中
+
+ Args:
+ retriever (BaseRetriever): 要添加的检索器实例
+ """
+ self.retrievers.append(retriever)
+
+ def remove_retriever(self, name: str):
+ """
+ 移除指定名称的检索器
+
+ Args:
+ name (str): 要移除的检索器的类名
+
+ Note:
+ 该方法通过比较检索器的类名来识别要移除的检索器
+ """
+ for i, retriever in enumerate(self.retrievers):
+ if hasattr(retriever, '__class__') and retriever.__class__.__name__ == name:
+ self.retrievers.pop(i)
+ break
+
+ def set_fusion_method(self, fusion_method: FusionMethod):
+ """
+ 设置融合方法
+
+ Args:
+ fusion_method (FusionMethod): 新的融合方法实例
+ """
+ self.fusion_method = fusion_method
diff --git a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py
index 734c7d8..2b7eafe 100644
--- a/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py
+++ b/rag_factory/Retrieval/Retriever/Retriever_VectorStore.py
@@ -10,8 +10,8 @@
import logging
from pydantic import ConfigDict, Field, model_validator
-from Retrieval.RetrieverBase import BaseRetriever, Document
-from Store.VectorStore.VectorStoreBase import VectorStore
+from ..RetrieverBase import BaseRetriever, Document
+from ...Store.VectorStore.VectorStoreBase import VectorStore
logger = logging.getLogger(__name__)
@@ -141,9 +141,15 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
# 合并搜索参数
search_params = {**self.search_kwargs, **kwargs}
+ # 获取返回文档数量,参考BM25Retriever的做法
+ k = search_params.get('k', getattr(self, 'k', 5))
+ search_params['k'] = k
+
try:
if self.search_type == "similarity":
docs = self.vectorstore.similarity_search(query, **search_params)
+ # 确保返回前k个文档
+ docs = docs[:k]
elif self.search_type == "similarity_score_threshold":
docs_and_similarities = (
@@ -152,11 +158,15 @@ def _get_relevant_documents(self, query: str, **kwargs: Any) -> List[Document]:
)
)
docs = [doc for doc, _ in docs_and_similarities]
+ # 确保返回前k个文档
+ docs = docs[:k]
elif self.search_type == "mmr":
docs = self.vectorstore.max_marginal_relevance_search(
query, **search_params
)
+ # 确保返回前k个文档
+ docs = docs[:k]
else:
msg = f"不支持的搜索类型: {self.search_type}"
diff --git a/rag_factory/Retrieval/Retriever/registry.py b/rag_factory/Retrieval/Retriever/registry.py
new file mode 100644
index 0000000..8558c76
--- /dev/null
+++ b/rag_factory/Retrieval/Retriever/registry.py
@@ -0,0 +1,118 @@
+from typing import Dict, Type, Any, Optional, List
+import logging
+from .Retriever_BM25 import BM25Retriever
+from .Retriever_MultiPath import MultiPathRetriever
+from .Retriever_VectorStore import VectorStoreRetriever
+from ..RetrieverBase import BaseRetriever
+
+logger = logging.getLogger(__name__)
+
+# TODO 统一所有检索器的调用方式
+
+class RetrieverRegistry:
+ """检索器注册表,用于管理和创建不同类型的检索器"""
+
+ _retrievers: Dict[str, Type[BaseRetriever]] = {}
+
+ @classmethod
+ def register(cls, name: str, retriever_class: Type[BaseRetriever]):
+ """
+ 注册检索器类
+
+ Args:
+ name: 检索器名称
+ retriever_class: 检索器类
+
+ Raises:
+ ValueError: 当检索器类不是BaseRetriever的子类时
+ TypeError: 当name不是字符串时
+ """
+ if not isinstance(name, str):
+ raise TypeError("检索器名称必须是字符串类型")
+
+ if not issubclass(retriever_class, BaseRetriever):
+ raise ValueError(f"检索器类 {retriever_class} 必须继承自 BaseRetriever")
+
+ if name in cls._retrievers:
+ logger.warning(f"检索器 '{name}' 已存在,将被覆盖")
+
+ cls._retrievers[name] = retriever_class
+ logger.info(f"检索器 '{name}' 注册成功")
+
+ @classmethod
+ def create(cls, name: str, **kwargs) -> BaseRetriever:
+ """
+ 创建检索器实例
+
+ Args:
+ name: 检索器名称
+ **kwargs: 传递给检索器构造函数的参数
+
+ Returns:
+ BaseRetriever: 检索器实例
+
+ Raises:
+ ValueError: 当检索器未注册时
+ Exception: 当检索器创建失败时
+ """
+ if name not in cls._retrievers:
+ available = ', '.join(cls.list_available())
+ raise ValueError(f"检索器 '{name}' 未注册。可用的检索器: {available}")
+
+ retriever_class = cls._retrievers[name]
+ try:
+ return retriever_class(**kwargs)
+ except Exception as e:
+ logger.error(f"创建检索器 '{name}' 失败: {str(e)}")
+ raise Exception(f"无法创建检索器 '{name}': {str(e)}") from e
+
+ @classmethod
+ def list_available(cls) -> List[str]:
+ """
+ 获取所有可用的检索器名称
+
+ Returns:
+ List[str]: 检索器名称列表
+ """
+ return list(cls._retrievers.keys())
+
+ @classmethod
+ def unregister(cls, name: str) -> bool:
+ """
+ 取消注册检索器
+
+ Args:
+ name: 检索器名称
+
+ Returns:
+ bool: 是否成功取消注册
+ """
+ if name in cls._retrievers:
+ del cls._retrievers[name]
+ logger.info(f"检索器 '{name}' 取消注册成功")
+ return True
+ return False
+
+ @classmethod
+ def is_registered(cls, name: str) -> bool:
+ """
+ 检查检索器是否已注册
+
+ Args:
+ name: 检索器名称
+
+ Returns:
+ bool: 是否已注册
+ """
+ return name in cls._retrievers
+
+ @classmethod
+ def clear_all(cls):
+ """清除所有已注册的检索器"""
+ cls._retrievers.clear()
+ logger.info("所有检索器注册已清除")
+
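+# 用法示意(与 examples/TCL_rag/rag_flow.py 中的组合方式一致,仅作参考):
+#   bm25 = RetrieverRegistry.create("bm25", k=5)                      # 随后用 from_documents 构建索引
+#   vec = RetrieverRegistry.create("vectorstore", vectorstore=vs)
+#   multi = RetrieverRegistry.create("multipath", retrievers=[bm25, vec])
+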
+# 预注册内置检索器
+RetrieverRegistry.register("bm25", BM25Retriever)
+RetrieverRegistry.register("multipath", MultiPathRetriever)
+RetrieverRegistry.register("vectorstore", VectorStoreRetriever)
diff --git a/rag_factory/Retrieval/Retriever/test_bm25_retriever.py b/rag_factory/Retrieval/Retriever/test_bm25_retriever.py
new file mode 100644
index 0000000..41b54d9
--- /dev/null
+++ b/rag_factory/Retrieval/Retriever/test_bm25_retriever.py
@@ -0,0 +1,64 @@
+from rag_factory.Retrieval.Retriever.Retriever_BM25 import BM25Retriever
+from rag_factory.Retrieval.RetrieverBase import Document
+from typing import List
+import logging
+logger = logging.getLogger(__name__)
+
+def chinese_preprocessing_func(text: str) -> List[str]:
+ """中文文本预处理函数
+
+ Args:
+ text: 输入的中文文本
+
+ Returns:
+ 分词后的词语列表
+ """
+ try:
+ import jieba
+ return list(jieba.cut(text))
+ except ImportError:
+ logger.warning("jieba 未安装,使用默认分词方法。请安装: pip install jieba")
+ return text.split()
+
+
+if __name__ == "__main__":
+ # 构造测试数据
+ texts = [
+ "这是第一个测试文档",
+ "这是第二个测试文档,内容稍有不同",
+ "这是第三个文档,讨论不同的主题",
+ "第四个文档包含更多详细信息",
+ "最后一个文档作为总结"
+ ]
+
+ print(chinese_preprocessing_func("这是第一个测试文档"))
+ metadatas = [
+ {"source": "doc1", "type": "test"},
+ {"source": "doc2", "type": "test"},
+ {"source": "doc3", "type": "example"},
+ {"source": "doc4", "type": "detailed"},
+ {"source": "doc5", "type": "summary"},
+ ]
+ ids = [f"doc{i+1}" for i in range(len(texts))]
+
+ # 创建BM25Retriever
+ retriever = BM25Retriever.from_texts(
+ texts=texts,
+ metadatas=metadatas,
+ ids=ids,
+ preprocess_func=chinese_preprocessing_func,
+ k=3 # top3
+ )
+
+ # 查询
+ query = "第二个测试文档内容"
+ print(f"\n查询: {query}")
+
+ # retriever.update_k(4)
+ results = retriever.invoke(query, k=4)
+ print("\n召回结果:")
+    for i, doc in enumerate(results, start=1):
+        print(f"{i}. ID: {doc.id}")
+        print(f"   内容: {doc.content}")
+        print(f"   元数据: {doc.metadata}")
diff --git a/rag_factory/Retrieval/__init__.py b/rag_factory/Retrieval/__init__.py
index 85c684c..445771a 100644
--- a/rag_factory/Retrieval/__init__.py
+++ b/rag_factory/Retrieval/__init__.py
@@ -1,4 +1,5 @@
from .RetrieverBase import BaseRetriever, Document
from .Retriever.Retriever_VectorStore import VectorStoreRetriever
+from .Retriever.registry import RetrieverRegistry
-__all__ = ["BaseRetriever", "Document", "VectorStoreRetriever"]
\ No newline at end of file
+__all__ = ["BaseRetriever", "Document", "VectorStoreRetriever", "RetrieverRegistry"]
\ No newline at end of file
diff --git a/rag_factory/Retrieval/utils/Fusion.py b/rag_factory/Retrieval/utils/Fusion.py
new file mode 100644
index 0000000..c876590
--- /dev/null
+++ b/rag_factory/Retrieval/utils/Fusion.py
@@ -0,0 +1,76 @@
+from typing import List
+from dataclasses import dataclass
+from abc import ABC, abstractmethod
+from collections import defaultdict
+
+from ..RetrieverBase import Document
+
+
+@dataclass
+class RetrievalResult:
+ """检索结果的数据类"""
+ document: Document
+ score: float
+ rank: int = 0
+
+
+class FusionMethod(ABC):
+ """融合方法的抽象基类"""
+
+ @abstractmethod
+ def fuse(self, results: List[List[RetrievalResult]], top_k: int) -> List[RetrievalResult]:
+ """
+ 融合多个检索器的结果
+
+ Args:
+ results: 每个检索器的结果列表
+ top_k: 返回的最终结果数量
+
+ Returns:
+ 融合后的结果列表
+ """
+ pass
+
+
+class RRFusion(FusionMethod):
+ """Reciprocal Rank Fusion (RRF) 方法"""
+
+ def __init__(self, k: float = 60.0):
+ """
+ Args:
+ k: RRF中的常数,默认为60.0
+ """
+ self.k = k
+
+ def fuse(self, results: List[List[RetrievalResult]], top_k: int) -> List[RetrievalResult]:
+ # 为每个结果分配rank
+ for retriever_results in results:
+ for i, result in enumerate(retriever_results):
+ result.rank = i + 1
+
+ # 计算RRF分数
+ rrf_scores = defaultdict(float)
+ document_map = {}
+
+ for retriever_results in results:
+ for result in retriever_results:
+ rrf_score = 1.0 / (self.k + result.rank)
+ # 使用文档内容作为key来去重
+ content_key = result.document.content
+ rrf_scores[content_key] += rrf_score
+ document_map[content_key] = result.document
+
+ # 按RRF分数排序
+ sorted_items = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True)
+
+ # 构建最终结果
+ fused_results = []
+ for i, (content, rrf_score) in enumerate(sorted_items[:top_k]):
+ result = RetrievalResult(
+ document=document_map[content],
+ score=rrf_score,
+ rank=i + 1
+ )
+ fused_results.append(result)
+
+ return fused_results
diff --git a/rag_factory/Store/VectorStore/VectorStoreBase.py b/rag_factory/Store/VectorStore/VectorStoreBase.py
index f12931d..9027d40 100644
--- a/rag_factory/Store/VectorStore/VectorStoreBase.py
+++ b/rag_factory/Store/VectorStore/VectorStoreBase.py
@@ -15,6 +15,7 @@
Sequence,
Iterable,
Iterator,
+ Tuple,
)
from itertools import cycle
import asyncio
@@ -23,6 +24,7 @@
if TYPE_CHECKING:
from collections.abc import Collection
from rag_factory.Embed import Embeddings
+ from ...Retrieval.Retriever.Retriever_VectorStore import VectorStoreRetriever
logger = logging.getLogger(__name__)
@@ -293,7 +295,7 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
def similarity_search_with_score(
self, *args: Any, **kwargs: Any
- ) -> list[tuple[Document, float]]:
+ ) -> list[Tuple[Document, float]]:
"""使用距离运行相似性搜索
Args:
@@ -307,7 +309,7 @@ def similarity_search_with_score(
async def asimilarity_search_with_score(
self, *args: Any, **kwargs: Any
- ) -> list[tuple[Document, float]]:
+ ) -> list[Tuple[Document, float]]:
"""异步使用距离运行相似性搜索"""
return await asyncio.get_event_loop().run_in_executor(
ThreadPoolExecutor(), self.similarity_search_with_score, *args, **kwargs
@@ -318,7 +320,7 @@ def _similarity_search_with_relevance_scores(
query: str,
k: int = 4,
**kwargs: Any,
- ) -> list[tuple[Document, float]]:
+ ) -> list[Tuple[Document, float]]:
"""默认的带相关性分数的相似性搜索
必要时在子类中修改。
@@ -344,7 +346,7 @@ async def _asimilarity_search_with_relevance_scores(
query: str,
k: int = 4,
**kwargs: Any,
- ) -> list[tuple[Document, float]]:
+ ) -> list[Tuple[Document, float]]:
"""异步带相关性分数的相似性搜索"""
relevance_score_fn = self._select_relevance_score_fn()
docs_and_scores = await self.asimilarity_search_with_score(query, k, **kwargs)
@@ -355,7 +357,7 @@ def similarity_search_with_relevance_scores(
query: str,
k: int = 4,
**kwargs: Any,
- ) -> list[tuple[Document, float]]:
+ ) -> list[Tuple[Document, float]]:
"""返回[0, 1]范围内的文档和相关性分数
0表示不相似,1表示最相似。
@@ -402,7 +404,7 @@ async def asimilarity_search_with_relevance_scores(
query: str,
k: int = 4,
**kwargs: Any,
- ) -> list[tuple[Document, float]]:
+ ) -> list[Tuple[Document, float]]:
"""异步返回[0, 1]范围内的文档和相关性分数"""
score_threshold = kwargs.pop("score_threshold", None)
@@ -627,11 +629,11 @@ def _get_retriever_tags(self) -> list[str]:
def as_retriever(self, **kwargs: Any) -> "VectorStoreRetriever":
"""从此VectorStore返回初始化的VectorStoreRetriever"""
- from Retrieval import VectorStoreRetriever
+ # 延迟导入以避免循环依赖
+ from ...Retrieval.Retriever.Retriever_VectorStore import VectorStoreRetriever
tags = kwargs.pop("tags", None) or [] + self._get_retriever_tags()
return VectorStoreRetriever(vectorstore=self, tags=tags, **kwargs)
-
diff --git a/rag_factory/Store/VectorStore/VectorStore_Faiss.py b/rag_factory/Store/VectorStore/VectorStore_Faiss.py
index 0e38e37..2f2f55d 100644
--- a/rag_factory/Store/VectorStore/VectorStore_Faiss.py
+++ b/rag_factory/Store/VectorStore/VectorStore_Faiss.py
@@ -4,20 +4,21 @@
import os
import uuid
import numpy as np
-from typing import Any, Optional, Callable
+from typing import Any, Optional, Callable, List, Tuple
from .VectorStoreBase import VectorStore, Document
-from Embed import Embeddings
+from ...Embed.Embedding_Base import Embeddings
import asyncio
from concurrent.futures import ThreadPoolExecutor
+# TODO 需要支持GPU,提高速度
def _mmr_select(
- docs_and_scores: list[tuple[Document, float]],
- embeddings: list[list[float]],
- query_embedding: list[float],
+ docs_and_scores: List[Tuple[Document, float]],
+ embeddings: List[List[float]],
+ query_embedding: List[float],
k: int,
lambda_mult: float = 0.5,
-) -> list[Document]:
+) -> List[Document]:
"""最大边际相关性选择算法"""
if k >= len(docs_and_scores):
return [doc for doc, _ in docs_and_scores]
@@ -153,12 +154,12 @@ def _normalize_vectors(self, vectors: np.ndarray) -> np.ndarray:
def add_texts(
self,
- texts: list[str],
- metadatas: Optional[list[dict]] = None,
+ texts: List[str],
+ metadatas: Optional[List[dict]] = None,
*,
- ids: Optional[list[str]] = None,
+ ids: Optional[List[str]] = None,
**kwargs: Any,
- ) -> list[str]:
+ ) -> List[str]:
"""添加文本到向量存储"""
if not texts:
return []
@@ -209,12 +210,12 @@ def add_texts(
async def aadd_texts(
self,
- texts: list[str],
- metadatas: Optional[list[dict]] = None,
+ texts: List[str],
+ metadatas: Optional[List[dict]] = None,
*,
- ids: Optional[list[str]] = None,
+ ids: Optional[List[str]] = None,
**kwargs: Any,
- ) -> list[str]:
+ ) -> List[str]:
"""异步添加文本"""
return await asyncio.get_event_loop().run_in_executor(
ThreadPoolExecutor(), self.add_texts, texts, metadatas, ids, **kwargs
@@ -222,14 +223,14 @@ async def aadd_texts(
def similarity_search(
self, query: str, k: int = 4, **kwargs: Any
- ) -> list[Document]:
+ ) -> List[Document]:
"""相似性搜索"""
docs_and_scores = self.similarity_search_with_score(query, k, **kwargs)
return [doc for doc, _ in docs_and_scores]
def similarity_search_with_score(
self, query: str, k: int = 4, **kwargs: Any
- ) -> list[tuple[Document, float]]:
+ ) -> List[Tuple[Document, float]]:
"""带分数的相似性搜索"""
if self.index is None or self.index.ntotal == 0:
return []
@@ -239,15 +240,15 @@ def similarity_search_with_score(
return self.similarity_search_by_vector_with_score(query_embedding, k, **kwargs)
def similarity_search_by_vector(
- self, embedding: list[float], k: int = 4, **kwargs: Any
- ) -> list[Document]:
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Document]:
"""根据向量相似性搜索"""
docs_and_scores = self.similarity_search_by_vector_with_score(embedding, k, **kwargs)
return [doc for doc, _ in docs_and_scores]
def similarity_search_by_vector_with_score(
- self, embedding: list[float], k: int = 4, **kwargs: Any
- ) -> list[tuple[Document, float]]:
+ self, embedding: List[float], k: int = 4, **kwargs: Any
+ ) -> List[Tuple[Document, float]]:
"""根据向量带分数的相似性搜索"""
if self.index is None or self.index.ntotal == 0:
return []
@@ -278,7 +279,7 @@ def max_marginal_relevance_search(
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
- ) -> list[Document]:
+ ) -> List[Document]:
"""最大边际相关性搜索"""
if self.index is None or self.index.ntotal == 0:
return []
@@ -324,12 +325,12 @@ def max_marginal_relevance_search(
def max_marginal_relevance_search_by_vector(
self,
- embedding: list[float],
+ embedding: List[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
**kwargs: Any,
- ) -> list[Document]:
+ ) -> List[Document]:
"""根据向量的最大边际相关性搜索"""
if self.index is None or self.index.ntotal == 0:
return []
@@ -369,7 +370,7 @@ def max_marginal_relevance_search_by_vector(
return selected_docs
- def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[bool]:
+ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
"""删除文档(FAISS不支持直接删除,需要重建索引)"""
if ids is None:
# 删除所有
@@ -412,7 +413,7 @@ def delete(self, ids: Optional[list[str]] = None, **kwargs: Any) -> Optional[boo
return True
- def get_by_ids(self, ids: list[str]) -> list[Document]:
+ def get_by_ids(self, ids: List[str]) -> List[Document]:
"""根据ID获取文档"""
return [self.docstore[doc_id] for doc_id in ids if doc_id in self.docstore]
@@ -482,11 +483,11 @@ def load_local(
@classmethod
def from_texts(
cls,
- texts: list[str],
+ texts: List[str],
embedding: Embeddings,
- metadatas: Optional[list[dict]] = None,
+ metadatas: Optional[List[dict]] = None,
*,
- ids: Optional[list[str]] = None,
+ ids: Optional[List[str]] = None,
**kwargs: Any,
) -> "FaissVectorStore":
"""从文本创建FAISS向量存储"""
@@ -497,11 +498,11 @@ def from_texts(
@classmethod
async def afrom_texts(
cls,
- texts: list[str],
+ texts: List[str],
embedding: Embeddings,
- metadatas: Optional[list[dict]] = None,
+ metadatas: Optional[List[dict]] = None,
*,
- ids: Optional[list[str]] = None,
+ ids: Optional[List[str]] = None,
**kwargs: Any,
) -> "FaissVectorStore":
"""异步从文本创建FAISS向量存储"""
diff --git a/rag_factory/Store/VectorStore/registry.py b/rag_factory/Store/VectorStore/registry.py
index 611b690..ef7f41f 100644
--- a/rag_factory/Store/VectorStore/registry.py
+++ b/rag_factory/Store/VectorStore/registry.py
@@ -1,7 +1,7 @@
# VectorStore/registry.py
-from typing import Dict, Type, Any, Optional
+from typing import Dict, Type, Any, Optional, List
from .VectorStoreBase import VectorStore
-from Embed.Embedding_Base import Embeddings
+from ...Embed.Embedding_Base import Embeddings
from .VectorStore_Faiss import FaissVectorStore
@@ -22,9 +22,21 @@ def create(cls, name: str, embedding: Embeddings, **kwargs) -> VectorStore:
raise ValueError(f"未注册的向量存储类型: {name}")
return cls._stores[name](embedding=embedding, **kwargs)
+
+ @classmethod
+ def load(cls, name: str, folder_path: str, embedding: Embeddings, **kwargs) -> VectorStore:
+ """加载已保存的向量存储实例"""
+ if name not in cls._stores:
+ raise ValueError(f"未注册的向量存储类型: {name}")
+ store_class = cls._stores[name]
+ if hasattr(store_class, "load_local"):
+ return store_class.load_local(folder_path, embeddings=embedding, **kwargs)
+ else:
+ raise ValueError(f"{store_class.__name__} 不支持从本地加载")
+
@classmethod
- def list_available(cls) -> list[str]:
+ def list_available(cls) -> List[str]:
"""列出可用的向量存储类型"""
return list(cls._stores.keys())
diff --git a/rag_factory/Store/VectorStore/test.py b/rag_factory/Store/VectorStore/test.py
new file mode 100644
index 0000000..26a839b
--- /dev/null
+++ b/rag_factory/Store/VectorStore/test.py
@@ -0,0 +1,57 @@
+# test_faiss_vectorstore.py
+import numpy as np
+from rag_factory.Store.VectorStore.VectorStore_Faiss import FaissVectorStore, Document
+from rag_factory.Embed import EmbeddingRegistry
+from rag_factory.Store import VectorStoreRegistry
+from rag_factory.Retrieval import RetrieverRegistry
+
+if __name__ == "__main__":
+ # 初始化
+ embeddings = EmbeddingRegistry.create(name="huggingface", model_name="/finance_ML/dataarc_syn_database/model/Qwen/qwen_embedding_0.6B")
+ # vs = VectorStoreRegistry.load(name="faiss", folder_path="/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/faiss_test_store", embedding=embeddings)
+ # print(vs)
+ # vs = VectorStoreRegistry.create(name="faiss", embedding=embeddings)
+ vs = FaissVectorStore.load_local(folder_path="/data/FinAi_Mapping_Knowledge/chenmingzhen/test_faiss_store", embeddings=embeddings, index_name="index")
+ # # # 添加一些文本
+ # texts = ["苹果是一种水果", "香蕉是黄色的", "猫是一种动物", "狗喜欢跑步"]
+ # # ids = vs.add_texts(texts, metadatas=[{"type": "test"} for _ in texts])
+
+ # documents = [Document(content=text, metadata={"type": "test"}) for text in texts]
+ # vectorstore = vs.from_documents(documents, embedding=embeddings)
+ # vectorstore.save_local(folder_path="/data/FinAi_Mapping_Knowledge/chenmingzhen/RAG-Factory/faiss_test_store")
+ # print(f"添加的文档ID: {ids}")
+
+ # # 相似性搜索
+ # results = vs.similarity_search("苹果", k=2)
+ # print("\n=== 相似性搜索结果 ===")
+ # for doc in results:
+ # print(f"内容: {doc.content}, 元数据: {doc.metadata}")
+
+ # # 带分数的搜索
+ # results_with_score = vs.similarity_search_with_score("苹果", k=2)
+ # print("\n=== 带分数的搜索结果 ===")
+ # for doc, score in results_with_score:
+ # print(f"内容: {doc.content}, 分数: {score}")
+
+ # # 最大边际相关性搜索
+ # mmr_results = vs.max_marginal_relevance_search("苹果", k=2, fetch_k=3)
+ # print("\n=== MMR搜索结果 ===")
+ # for doc in mmr_results:
+ # print(f"内容: {doc.content}")
+
+ # # 保存到本地
+ # save_path = "./faiss_test_store"
+ # vs.save_local(save_path)
+ # print(f"\n索引已保存到: {save_path}")
+
+ # 从本地加载
+ # loaded_vs = FaissVectorStore.load_local(save_path, embeddings)
+ # load_results = loaded_vs.similarity_search("苹果", k=2)
+ # print("\n=== 从本地加载后的搜索结果 ===")
+ # for doc in load_results:
+ # print(f"内容: {doc.content}")
+
+ retriever = vs.as_retriever(search_kwargs={"k": 2})
+ # retriever = RetrieverRegistry.create(name="vectorstore", vectorstore=vs)
+ results = retriever.invoke("文件名称:GB$T 2828.1-2012 计数抽样检验程序 第1部分:按接收质量限(AQL)检索的逐批检验抽样计划.pdf\n4 不合格的表示 4. 1 总则\n不合格的程度以不合格品百分数(见 3.1.8和 3.1.9)或每百单位产品不合格数(见 3.1.10和3.1.11)表示。表7、表8和表10 是基于假定不合格的出现是随机且统计独立的。如果已知产品的某个不合格可能由某一条件引起,此条件还可能引起其他一些不合格,则应仅考虑该产品是否为合格品,而不管该产品有多少个不合格。")
+ print(results)
diff --git a/rag_factory/Store/__init__.py b/rag_factory/Store/__init__.py
index a2b1a3a..895b140 100644
--- a/rag_factory/Store/__init__.py
+++ b/rag_factory/Store/__init__.py
@@ -1,5 +1,7 @@
from .VectorStore.registry import VectorStoreRegistry
+from .VectorStore.VectorStore_Faiss import FaissVectorStore
__all__ = [
"VectorStoreRegistry",
+ "FaissVectorStore",
]
\ No newline at end of file
diff --git a/rag_factory/llms/__init__.py b/rag_factory/llms/__init__.py
index 33cf556..2476ec8 100644
--- a/rag_factory/llms/__init__.py
+++ b/rag_factory/llms/__init__.py
@@ -1,5 +1,9 @@
from .openai_compatible import OpenAICompatible
from .dashscope.base import DashScope, DashScopeGenerationModels
+from .openai_llm import OpenAILLM
+from .registry import LLMRegistry
__all__ = ['OpenAICompatible',
- "DashScope", "DashScopeGenerationModels"]
+ "DashScope", "DashScopeGenerationModels",
+ "OpenAILLM",
+ "LLMRegistry"]
diff --git a/rag_factory/llms/llm_base.py b/rag_factory/llms/llm_base.py
new file mode 100644
index 0000000..78e4475
--- /dev/null
+++ b/rag_factory/llms/llm_base.py
@@ -0,0 +1,142 @@
+from abc import ABC, abstractmethod
+from typing import Dict, Any, List, Optional, Union
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class LLMBase(ABC):
+ """
+ 大语言模型基类,定义了所有LLM实现必须遵循的接口
+ """
+
+ def __init__(self, model_name: str, **kwargs):
+ """
+ 初始化LLM基类
+
+ Args:
+ model_name: 模型名称
+ **kwargs: 其他配置参数
+ """
+ self.model_name = model_name
+ self.config = kwargs
+ self._setup_logging()
+
+ def _setup_logging(self):
+ """设置日志配置"""
+ self.logger = logging.getLogger(f"{__name__}.{self.__class__.__name__}")
+
+
+ @abstractmethod
+ def chat(
+ self,
+ messages: List[Dict[str, str]],
+ max_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ **kwargs
+ ) -> str:
+ """
+ 对话式生成
+
+ Args:
+ messages: 对话消息列表,格式如[{"role": "user", "content": "问题"}]
+ max_tokens: 最大生成token数
+ temperature: 生成温度参数
+ **kwargs: 其他生成参数
+
+ Returns:
+ 生成的回复文本
+ """
+ pass
+
+ @abstractmethod
+ def stream_chat(
+ self,
+ messages: List[Dict[str, str]],
+ max_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ **kwargs
+ ):
+ """
+ 流式对话生成
+ """
+ pass
+
+ @abstractmethod
+ def embed(self, texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+ """
+ 文本嵌入生成
+
+ Args:
+ texts: 单个文本或文本列表
+
+ Returns:
+ 嵌入向量或嵌入向量列表
+ """
+ pass
+
+ def get_model_info(self) -> Dict[str, Any]:
+ """
+ 获取模型信息
+
+ Returns:
+ 包含模型名称和配置的字典
+ """
+ return {
+ "model_name": self.model_name,
+ "config": self.config,
+ "class_name": self.__class__.__name__
+ }
+
+ def validate_input(self, input_text: str, max_length: Optional[int] = None) -> bool:
+ """
+ 验证输入文本
+
+ Args:
+ input_text: 输入文本
+ max_length: 最大长度限制
+
+ Returns:
+ 是否验证通过
+ """
+ if not isinstance(input_text, str):
+ self.logger.error("输入必须是字符串类型")
+ return False
+
+ if not input_text.strip():
+ self.logger.error("输入文本不能为空")
+ return False
+
+ if max_length and len(input_text) > max_length:
+ self.logger.error(f"输入文本长度超过限制: {len(input_text)} > {max_length}")
+ return False
+
+ return True
+
+ def format_messages(self, user_message: str, system_message: Optional[str] = None) -> List[Dict[str, str]]:
+ """
+ 格式化对话消息
+
+ Args:
+ user_message: 用户消息
+ system_message: 系统消息(可选)
+
+ Returns:
+ 格式化后的消息列表
+ """
+ messages = []
+
+ if system_message:
+ messages.append({"role": "system", "content": system_message})
+
+ messages.append({"role": "user", "content": user_message})
+
+ return messages
+
+ def __str__(self) -> str:
+ """字符串表示"""
+ return f"{self.__class__.__name__}(model_name='{self.model_name}')"
+
+ def __repr__(self) -> str:
+ """详细字符串表示"""
+ return f"{self.__class__.__name__}(model_name='{self.model_name}', config={self.config})"
diff --git a/rag_factory/llms/openai_llm.py b/rag_factory/llms/openai_llm.py
new file mode 100644
index 0000000..6069155
--- /dev/null
+++ b/rag_factory/llms/openai_llm.py
@@ -0,0 +1,264 @@
+import openai
+from typing import Dict, Any, List, Optional, Union, Tuple
+from .llm_base import LLMBase
+
+class OpenAILLM(LLMBase):
+ """
+ OpenAI LLM对话模型
+ """
+
+ def __init__(
+ self,
+ model_name: str = "gpt-4o-mini",
+ api_key: Optional[str] = None,
+ base_url: Optional[str] = None,
+ organization: Optional[str] = None,
+ max_retries: int = 3,
+ timeout: float = 60.0,
+ **kwargs
+ ):
+ """
+ 初始化OpenAI LLM
+
+ Args:
+ model_name: 模型名称,如 gpt-3.5-turbo, gpt-4 等
+ api_key: OpenAI API密钥
+ base_url: API基础URL
+ organization: 组织ID(可选)
+ max_retries: 最大重试次数
+ timeout: 请求超时时间
+ **kwargs: 其他配置参数
+ """
+ super().__init__(model_name, **kwargs)
+
+ # 初始化OpenAI客户端
+ self.client = openai.OpenAI(
+ api_key=api_key,
+ base_url=base_url,
+ organization=organization,
+ max_retries=max_retries,
+ timeout=timeout
+ )
+
+ # 默认参数
+ self.default_max_tokens = kwargs.get('max_tokens', 2000)
+ self.default_temperature = kwargs.get('temperature', 0.7)
+
+ self.logger.info(f"OpenAI LLM初始化完成,模型: {model_name}")
+
+
+ def chat(
+ self,
+ messages: List[Dict[str, str]],
+ max_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ return_token_count: bool = False,
+ **kwargs
+ ) -> Union[str, Tuple[str, Dict[str, int]]]:
+ """
+ 对话式生成
+
+ Args:
+ messages: 对话消息列表,格式如[{"role": "user", "content": "问题"}]
+ max_tokens: 最大生成token数
+ temperature: 生成温度参数
+ return_token_count: 是否返回token统计信息
+ **kwargs: 其他生成参数
+
+ Returns:
+ 生成的回复文本,如果return_token_count为True则返回(文本, token统计)
+ """
+ if not messages or not isinstance(messages, list):
+ raise ValueError("消息列表不能为空且必须是列表格式")
+
+ # 验证消息格式
+ for msg in messages:
+ if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
+ raise ValueError("消息格式错误,必须包含role和content字段")
+ if not self.validate_input(msg['content']):
+ raise ValueError(f"消息内容验证失败: {msg['content']}")
+
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=messages,
+                max_tokens=max_tokens or self.default_max_tokens,
+                temperature=temperature if temperature is not None else self.default_temperature,
+ **kwargs
+ )
+
+            result = (response.choices[0].message.content or "").strip()
+
+ if return_token_count:
+ # 获取输出token数
+ input_tokens = response.usage.prompt_tokens if response.usage else 0
+ output_tokens = response.usage.completion_tokens if response.usage else 0
+ total_tokens = response.usage.total_tokens if response.usage else 0
+
+ token_stats = {
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ "total_tokens": total_tokens
+ }
+
+ self.logger.debug(f"对话生成成功,长度: {len(result)}, 输入tokens: {input_tokens}, 输出tokens: {output_tokens}")
+ return result, token_stats
+ else:
+ self.logger.debug(f"对话生成成功,长度: {len(result)}")
+ return result
+
+ except Exception as e:
+ self.logger.error(f"对话生成失败: {str(e)}")
+ raise
+
+
+ def stream_chat(
+ self,
+ messages: List[Dict[str, str]],
+ max_tokens: Optional[int] = None,
+ temperature: Optional[float] = None,
+ return_token_count: bool = False,
+ **kwargs
+ ):
+ """
+ 流式对话生成
+
+ Args:
+ messages: 对话消息列表
+ max_tokens: 最大生成token数
+ temperature: 生成温度参数
+ return_token_count: 是否返回token统计信息
+ **kwargs: 其他生成参数
+
+ Yields:
+ 生成的文本片段,如果return_token_count为True则在流式输出结束后yield token统计
+ """
+ if not messages or not isinstance(messages, list):
+ raise ValueError("消息列表不能为空且必须是列表格式")
+
+ # 验证消息格式
+ for msg in messages:
+ if not isinstance(msg, dict) or 'role' not in msg or 'content' not in msg:
+ raise ValueError("消息格式错误,必须包含role和content字段")
+ if not self.validate_input(msg['content']):
+ raise ValueError(f"消息内容验证失败: {msg['content']}")
+
+ try:
+ stream = self.client.chat.completions.create(
+ model=self.model_name,
+ messages=messages,
+                max_tokens=max_tokens or self.default_max_tokens,
+                temperature=temperature if temperature is not None else self.default_temperature,
+ stream=True,
+ stream_options={"include_usage": True} if return_token_count else None,
+ **kwargs
+ )
+
+ full_response = ""
+
+ for chunk in stream:
+ # 检查choices是否存在以及是否有内容
+ if chunk.choices and len(chunk.choices) > 0:
+ delta = chunk.choices[0].delta
+ if hasattr(delta, 'content') and delta.content is not None:
+ content = delta.content
+ full_response += content
+ yield content
+
+ # 检查是否有usage信息(在流的最后一个chunk中)
+ if return_token_count and hasattr(chunk, 'usage') and chunk.usage is not None:
+ input_tokens = chunk.usage.prompt_tokens if chunk.usage else 0
+ output_tokens = chunk.usage.completion_tokens if chunk.usage else 0
+ total_tokens = chunk.usage.total_tokens if chunk.usage else 0
+
+ token_stats = {
+ "input_tokens": input_tokens,
+ "output_tokens": output_tokens,
+ "total_tokens": total_tokens
+ }
+
+ self.logger.debug(f"流式对话生成成功,长度: {len(full_response)}, 输入tokens: {input_tokens}, 输出tokens: {output_tokens}")
+ yield token_stats
+
+ except Exception as e:
+ self.logger.error(f"流式对话生成失败: {str(e)}")
+ raise
+
+
+ def embed(self, texts: Union[str, List[str]]) -> Union[List[float], List[List[float]]]:
+ """
+ 文本嵌入生成
+
+ Args:
+ texts: 单个文本或文本列表
+
+ Returns:
+ 嵌入向量或嵌入向量列表
+ """
+ # 检查当前模型是否为嵌入模型
+ embedding_models = ["text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large"]
+ if self.model_name not in embedding_models:
+ warning_msg = f"警告:当前模型'{self.model_name}'不是嵌入模型,建议使用嵌入专用模型如 text-embedding-ada-002"
+ self.logger.warning(warning_msg)
+ raise ValueError(f"当前模型'{self.model_name}'不支持嵌入生成,请使用嵌入专用模型")
+
+ # 统一处理为列表格式
+ is_single = isinstance(texts, str)
+ text_list = [texts] if is_single else texts
+
+ # 验证输入
+ for text in text_list:
+ if not self.validate_input(text):
+ raise ValueError(f"文本内容验证失败: {text}")
+
+ try:
+ # 使用当前嵌入模型生成嵌入
+ response = self.client.embeddings.create(
+ model=self.model_name,
+ input=text_list
+ )
+
+ embeddings = [data.embedding for data in response.data]
+
+ # 根据输入格式返回结果
+ result = embeddings[0] if is_single else embeddings
+ self.logger.debug(f"嵌入生成成功,文本数量: {len(text_list)}")
+ return result
+
+ except Exception as e:
+ self.logger.error(f"嵌入生成失败: {str(e)}")
+ raise
+
+ def get_available_models(self) -> List[str]:
+ """
+ 获取可用的模型列表
+
+ Returns:
+ 可用模型名称列表
+ """
+ try:
+ models = self.client.models.list()
+ model_names = [model.id for model in models.data]
+ self.logger.debug(f"获取到{len(model_names)}个可用模型")
+ return model_names
+ except Exception as e:
+ self.logger.error(f"获取模型列表失败: {str(e)}")
+ return []
+
+ def get_model_info(self) -> Dict[str, Any]:
+ """
+ 获取模型信息
+
+ Returns:
+ 包含模型名称和配置的字典
+ """
+ info = super().get_model_info()
+ info.update({
+ "api_base": getattr(self.client, 'base_url', None),
+ "organization": getattr(self.client, 'organization', None),
+ "max_retries": getattr(self.client, 'max_retries', None),
+ "timeout": getattr(self.client, 'timeout', None),
+ "default_max_tokens": self.default_max_tokens,
+ "default_temperature": self.default_temperature
+ })
+ return info
diff --git a/rag_factory/llms/registry.py b/rag_factory/llms/registry.py
new file mode 100644
index 0000000..249f03d
--- /dev/null
+++ b/rag_factory/llms/registry.py
@@ -0,0 +1,86 @@
+from .openai_llm import OpenAILLM
+from .llm_base import LLMBase
+from typing import Dict, Type, Any, Optional, List
+import logging
+
+# 库模块内不调用 logging.basicConfig,日志级别与格式交由应用层配置
+logger = logging.getLogger(__name__)
+
+class LLMRegistry:
+ """LLM模型注册表,用于管理和创建不同类型的LLM模型"""
+ _llms: Dict[str, Type[LLMBase]] = {}
+
+ @classmethod
+ def register(cls, name: str, llm_class: Type[LLMBase]):
+ """注册LLM模型类
+
+ Args:
+ name: 模型名称
+ llm_class: LLM模型类
+ """
+ cls._llms[name] = llm_class
+
+ @classmethod
+ def create(cls, name: str, **kwargs) -> LLMBase:
+ """创建LLM实例
+
+ Args:
+ name: 模型名称
+ **kwargs: 模型初始化参数
+
+ Returns:
+ LLM实例
+
+ Raises:
+ ValueError: 当模型名称不存在时
+ """
+ if name not in cls._llms:
+ available_llms = list(cls._llms.keys())
+ raise ValueError(f"LLM模型 '{name}' 未注册。可用的模型: {available_llms}")
+
+ llm_class = cls._llms[name]
+
+ return llm_class(**kwargs)
+
+ @classmethod
+ def list_llms(cls) -> List[str]:
+ """列出所有已注册的LLM模型名称
+
+ Returns:
+ 已注册的模型名称列表
+ """
+ return list(cls._llms.keys())
+
+ @classmethod
+ def is_registered(cls, name: str) -> bool:
+ """检查模型是否已注册
+
+ Args:
+ name: 模型名称
+
+ Returns:
+ 如果已注册返回True,否则返回False
+ """
+ return name in cls._llms
+
+ @classmethod
+ def unregister(cls, name: str) -> bool:
+ """取消注册模型
+
+ Args:
+ name: 模型名称
+
+ Returns:
+ 成功取消注册返回True,模型不存在返回False
+ """
+ if name in cls._llms:
+ del cls._llms[name]
+ return True
+ return False
+
+
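+# 用法示意(base_url 与 api_key 为占位,请按实际环境配置):
+#   llm = LLMRegistry.create("openai", model_name="gpt-4o-mini",
+#                            api_key="...", base_url="https://api.openai.com/v1")
+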
+# 注册默认的LLM模型
+LLMRegistry.register("openai", OpenAILLM)
\ No newline at end of file
diff --git a/rag_factory/llms/test.py b/rag_factory/llms/test.py
new file mode 100644
index 0000000..49332ba
--- /dev/null
+++ b/rag_factory/llms/test.py
@@ -0,0 +1,75 @@
+import os
+from pprint import pprint
+from .openai_llm import OpenAILLM  # 相对导入:需从包内以 python -m rag_factory.llms.test 运行
+
+# ==== 配置 ====
+# API_KEY = os.getenv("OPENAI_API_KEY") # 或直接写成 "sk-xxxx"
+API_KEY = "sk-2T06b7c7f9c3870049fbf8fada596b0f8ef908d1e233KLY2"
+# BASE_URL = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
+BASE_URL = "https://api.gptsapi.net/v1"
+
+def test_openai_llm():
+ # 初始化普通对话模型
+ llm = OpenAILLM(
+ model_name="gpt-4o-mini",
+ api_key=API_KEY,
+ base_url=BASE_URL
+ )
+
+ # ==== 1. 测试 chat ====
+ messages = [{"role": "user", "content": "你好,请用一句话介绍你自己"}]
+ print("\n=== chat (普通) ===")
+ result = llm.chat(messages)
+ print("回复:", result)
+
+ # ==== 2. 测试 chat + token 统计 ====
+ print("\n=== chat (返回 token 数) ===")
+ result, token_stats = llm.chat(messages, return_token_count=True)
+ print("回复:", result)
+ print("token统计:", token_stats)
+
+ # ==== 3. 测试 stream_chat (普通流式) ====
+ print("\n=== stream_chat (普通流式) ===")
+ for chunk in llm.stream_chat(messages):
+ print(chunk, end="", flush=True)
+ print()
+
+ # ==== 4. 测试 stream_chat (返回 token 数) ====
+ print("\n=== stream_chat (返回 token 数) ===")
+ for item in llm.stream_chat(messages, return_token_count=True):
+ if isinstance(item, dict) and "input_tokens" in item:
+ # 这是最后的token统计信息
+ print("\nToken统计:", item)
+ else:
+ # 这是文本片段
+ print(item, end="", flush=True)
+
+
+ # ==== 7. 测试 get_model_info ====
+ print("\n=== get_model_info ===")
+ pprint(llm.get_model_info())
+
+ # ==== 8. 测试 get_available_models ====
+ print("\n=== get_available_models ===")
+ model_list = llm.get_available_models()
+ print("可用模型数量:", len(model_list))
+ print("前10个模型:", model_list[:10])
+
+ # ==== 9. 测试 embed ====
+ # 使用专门的嵌入模型
+ embed_llm = OpenAILLM(
+ model_name="text-embedding-3-small",
+ api_key=API_KEY,
+ base_url=BASE_URL
+ )
+ print("\n=== embed (单条) ===")
+ vec = embed_llm.embed("这是一个测试文本")
+ print("向量维度:", len(vec))
+
+ print("\n=== embed (多条) ===")
+ vecs = embed_llm.embed(["第一句", "第二句"])
+ print("向量数量:", len(vecs), "每个向量维度:", len(vecs[0]))
+
+
+if __name__ == "__main__":
+ test_openai_llm()
diff --git a/rag_factory/parser/Base.py b/rag_factory/parser/Base.py
deleted file mode 100644
index 9ce5ce9..0000000
--- a/rag_factory/parser/Base.py
+++ /dev/null
@@ -1,528 +0,0 @@
-# type: ignore
-"""
-Generic Document Parser Utility
-
-This module provides functionality for parsing PDF and image documents using MinerU 2.0 library,
-and converts the parsing results into markdown and JSON formats
-
-Note: MinerU 2.0 no longer includes LibreOffice document conversion module.
-For Office documents (.doc, .docx, .ppt, .pptx), please convert them to PDF format first.
-"""
-
-from __future__ import annotations
-
-
-import json
-import argparse
-import base64
-import subprocess
-import tempfile
-import logging
-from pathlib import Path
-from typing import (
- Dict,
- List,
- Optional,
- Union,
- Tuple,
- Any,
- TypeVar,
-)
-
-T = TypeVar("T")
-
-
-class Parser:
- """
- Base class for document parsing utilities.
-
- Defines common functionality and constants for parsing different document types.
- """
-
- # Define common file formats
- OFFICE_FORMATS = {".doc", ".docx", ".ppt", ".pptx", ".xls", ".xlsx"}
- IMAGE_FORMATS = {".png", ".jpeg", ".jpg", ".bmp", ".tiff", ".tif", ".gif", ".webp"}
- TEXT_FORMATS = {".txt", ".md"}
-
- # Class-level logger
- logger = logging.getLogger(__name__)
-
- def __init__(self) -> None:
- """Initialize the base parser."""
- pass
-
- @staticmethod
- def convert_office_to_pdf(
- doc_path: Union[str, Path], output_dir: Optional[str] = None
- ) -> Path:
- """
- Convert Office document (.doc, .docx, .ppt, .pptx, .xls, .xlsx) to PDF.
- Requires LibreOffice to be installed.
-
- Args:
- doc_path: Path to the Office document file
- output_dir: Output directory for the PDF file
-
- Returns:
- Path to the generated PDF file
- """
- try:
- # Convert to Path object for easier handling
- doc_path = Path(doc_path)
- if not doc_path.exists():
- raise FileNotFoundError(f"Office document does not exist: {doc_path}")
-
- name_without_suff = doc_path.stem
-
- # Prepare output directory
- if output_dir:
- base_output_dir = Path(output_dir)
- else:
- base_output_dir = doc_path.parent / "libreoffice_output"
-
- base_output_dir.mkdir(parents=True, exist_ok=True)
-
- # Create temporary directory for PDF conversion
- with tempfile.TemporaryDirectory() as temp_dir:
- temp_path = Path(temp_dir)
-
- # Convert to PDF using LibreOffice
- logging.info(f"Converting {doc_path.name} to PDF using LibreOffice...")
-
- # Prepare subprocess parameters to hide console window on Windows
- import platform
-
- # Try LibreOffice commands in order of preference
- commands_to_try = ["libreoffice", "soffice"]
-
- conversion_successful = False
- for cmd in commands_to_try:
- try:
- convert_cmd = [
- cmd,
- "--headless",
- "--convert-to",
- "pdf",
- "--outdir",
- str(temp_path),
- str(doc_path),
- ]
-
- # Prepare conversion subprocess parameters
- convert_subprocess_kwargs = {
- "capture_output": True,
- "text": True,
- "timeout": 60, # 60 second timeout
- "encoding": "utf-8",
- "errors": "ignore",
- }
-
- # Hide console window on Windows
- if platform.system() == "Windows":
- convert_subprocess_kwargs["creationflags"] = (
- subprocess.CREATE_NO_WINDOW
- )
-
- result = subprocess.run(
- convert_cmd, **convert_subprocess_kwargs
- )
-
- if result.returncode == 0:
- conversion_successful = True
- logging.info(
- f"Successfully converted {doc_path.name} to PDF using {cmd}"
- )
- break
- else:
- logging.warning(
- f"LibreOffice command '{cmd}' failed: {result.stderr}"
- )
- except FileNotFoundError:
- logging.warning(f"LibreOffice command '{cmd}' not found")
- except subprocess.TimeoutExpired:
- logging.warning(f"LibreOffice command '{cmd}' timed out")
- except Exception as e:
- logging.error(
- f"LibreOffice command '{cmd}' failed with exception: {e}"
- )
-
- if not conversion_successful:
- raise RuntimeError(
- f"LibreOffice conversion failed for {doc_path.name}. "
- f"Please ensure LibreOffice is installed:\n"
- "- Windows: Download from https://www.libreoffice.org/download/download/\n"
- "- macOS: brew install --cask libreoffice\n"
- "- Ubuntu/Debian: sudo apt-get install libreoffice\n"
- "- CentOS/RHEL: sudo yum install libreoffice\n"
- "Alternatively, convert the document to PDF manually."
- )
-
- # Find the generated PDF
- pdf_files = list(temp_path.glob("*.pdf"))
- if not pdf_files:
- raise RuntimeError(
- f"PDF conversion failed for {doc_path.name} - no PDF file generated. "
- f"Please check LibreOffice installation or try manual conversion."
- )
-
- pdf_path = pdf_files[0]
- logging.info(
- f"Generated PDF: {pdf_path.name} ({pdf_path.stat().st_size} bytes)"
- )
-
- # Validate the generated PDF
- if pdf_path.stat().st_size < 100: # Very small file, likely empty
- raise RuntimeError(
- "Generated PDF appears to be empty or corrupted. "
- "Original file may have issues or LibreOffice conversion failed."
- )
-
- # Copy PDF to final output directory
- final_pdf_path = base_output_dir / f"{name_without_suff}.pdf"
- import shutil
-
- shutil.copy2(pdf_path, final_pdf_path)
-
- return final_pdf_path
-
- except Exception as e:
- logging.error(f"Error in convert_office_to_pdf: {str(e)}")
- raise
-
- @staticmethod
- def convert_text_to_pdf(
- text_path: Union[str, Path], output_dir: Optional[str] = None
- ) -> Path:
- """
- Convert text file (.txt, .md) to PDF using ReportLab with full markdown support.
-
- Args:
- text_path: Path to the text file
- output_dir: Output directory for the PDF file
-
- Returns:
- Path to the generated PDF file
- """
- try:
- text_path = Path(text_path)
- if not text_path.exists():
- raise FileNotFoundError(f"Text file does not exist: {text_path}")
-
- # Supported text formats
- supported_text_formats = {".txt", ".md"}
- if text_path.suffix.lower() not in supported_text_formats:
- raise ValueError(f"Unsupported text format: {text_path.suffix}")
-
- # Read the text content
- try:
- with open(text_path, "r", encoding="utf-8") as f:
- text_content = f.read()
- except UnicodeDecodeError:
- # Try with different encodings
- for encoding in ["gbk", "latin-1", "cp1252"]:
- try:
- with open(text_path, "r", encoding=encoding) as f:
- text_content = f.read()
- logging.info(f"Successfully read file with {encoding} encoding")
- break
- except UnicodeDecodeError:
- continue
- else:
- raise RuntimeError(
- f"Could not decode text file {text_path.name} with any supported encoding"
- )
-
- # Prepare output directory
- if output_dir:
- base_output_dir = Path(output_dir)
- else:
- base_output_dir = text_path.parent / "reportlab_output"
-
- base_output_dir.mkdir(parents=True, exist_ok=True)
- pdf_path = base_output_dir / f"{text_path.stem}.pdf"
-
- # Convert text to PDF
- logging.info(f"Converting {text_path.name} to PDF...")
-
- try:
- from reportlab.lib.pagesizes import A4
- from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer
- from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
- from reportlab.lib.units import inch
- from reportlab.pdfbase import pdfmetrics
-
- # Create PDF document
- doc = SimpleDocTemplate(
- str(pdf_path),
- pagesize=A4,
- leftMargin=inch,
- rightMargin=inch,
- topMargin=inch,
- bottomMargin=inch,
- )
-
- # Get styles
- styles = getSampleStyleSheet()
- normal_style = styles["Normal"]
- heading_style = styles["Heading1"]
-
- # Try to register a font that supports Chinese characters
- try:
- # Try to use system fonts that support Chinese
- import platform
-
- system = platform.system()
- if system == "Windows":
- # Try common Windows fonts
- for font_name in ["SimSun", "SimHei", "Microsoft YaHei"]:
- try:
- from reportlab.pdfbase.cidfonts import (
- UnicodeCIDFont,
- )
-
- pdfmetrics.registerFont(UnicodeCIDFont(font_name))
- normal_style.fontName = font_name
- heading_style.fontName = font_name
- break
- except Exception:
- continue
- elif system == "Darwin": # macOS
- for font_name in ["STSong-Light", "STHeiti"]:
- try:
- from reportlab.pdfbase.cidfonts import (
- UnicodeCIDFont,
- )
-
- pdfmetrics.registerFont(UnicodeCIDFont(font_name))
- normal_style.fontName = font_name
- heading_style.fontName = font_name
- break
- except Exception:
- continue
- except Exception:
- pass # Use default fonts if Chinese font setup fails
-
- # Build content
- story = []
-
- # Handle markdown or plain text
- if text_path.suffix.lower() == ".md":
- # Handle markdown content - simplified implementation
- lines = text_content.split("\n")
- for line in lines:
- line = line.strip()
- if not line:
- story.append(Spacer(1, 12))
- continue
-
- # Headers
- if line.startswith("#"):
- level = len(line) - len(line.lstrip("#"))
- header_text = line.lstrip("#").strip()
- if header_text:
- header_style = ParagraphStyle(
- name=f"Heading{level}",
- parent=heading_style,
- fontSize=max(16 - level, 10),
- spaceAfter=8,
- spaceBefore=16 if level <= 2 else 12,
- )
- story.append(Paragraph(header_text, header_style))
- else:
- # Regular text
- story.append(Paragraph(line, normal_style))
- story.append(Spacer(1, 6))
- else:
- # Handle plain text files (.txt)
- logging.info(
- f"Processing plain text file with {len(text_content)} characters..."
- )
-
- # Split text into lines and process each line
- lines = text_content.split("\n")
- line_count = 0
-
- for line in lines:
- line = line.rstrip()
- line_count += 1
-
- # Empty lines
- if not line.strip():
- story.append(Spacer(1, 6))
- continue
-
- # Regular text lines
- # Escape special characters for ReportLab
- safe_line = (
- line.replace("&", "&")
- .replace("<", "<")
- .replace(">", ">")
- )
-
- # Create paragraph
- story.append(Paragraph(safe_line, normal_style))
- story.append(Spacer(1, 3))
-
- logging.info(f"Added {line_count} lines to PDF")
-
- # If no content was added, add a placeholder
- if not story:
- story.append(Paragraph("(Empty text file)", normal_style))
-
- # Build PDF
- doc.build(story)
- logging.info(
- f"Successfully converted {text_path.name} to PDF ({pdf_path.stat().st_size / 1024:.1f} KB)"
- )
-
- except ImportError:
- raise RuntimeError(
- "reportlab is required for text-to-PDF conversion. "
- "Please install it using: pip install reportlab"
- )
- except Exception as e:
- raise RuntimeError(
- f"Failed to convert text file {text_path.name} to PDF: {str(e)}"
- )
-
- # Validate the generated PDF
- if not pdf_path.exists() or pdf_path.stat().st_size < 100:
- raise RuntimeError(
- f"PDF conversion failed for {text_path.name} - generated PDF is empty or corrupted."
- )
-
- return pdf_path
-
- except Exception as e:
- logging.error(f"Error in convert_text_to_pdf: {str(e)}")
- raise
-
- @staticmethod
- def _process_inline_markdown(text: str) -> str:
- """
- Process inline markdown formatting (bold, italic, code, links)
-
- Args:
- text: Raw text with markdown formatting
-
- Returns:
- Text with ReportLab markup
- """
- import re
-
- # Escape special characters for ReportLab
- text = text.replace("&", "&").replace("<", "<").replace(">", ">")
-
- # Bold text: **text** or __text__
- text = re.sub(r"\*\*(.*?)\*\*", r"\1", text)
- text = re.sub(r"__(.*?)__", r"\1", text)
-
- # Italic text: *text* or _text_ (but not in the middle of words)
- text = re.sub(r"(?\1", text)
- text = re.sub(r"(?\1", text)
-
- # Inline code: `code`
- text = re.sub(
- r"`([^`]+?)`",
- r'<font face="Courier">\1</font>',
- text,
- )
-
- # Links: [text](url) - convert to text with URL annotation
- def link_replacer(match):
- link_text = match.group(1)
- url = match.group(2)
- return f'<link href="{url}">{link_text}</link>'
-
- text = re.sub(r"\[([^\]]+?)\]\(([^)]+?)\)", link_replacer, text)
-
- # Strikethrough: ~~text~~
- text = re.sub(r"~~(.*?)~~", r"\1", text)
-
- return text
-
- def parse_pdf(
- self,
- pdf_path: Union[str, Path],
- output_dir: Optional[str] = None,
- method: str = "auto",
- lang: Optional[str] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """
- Abstract method to parse PDF document.
- Must be implemented by subclasses.
-
- Args:
- pdf_path: Path to the PDF file
- output_dir: Output directory path
- method: Parsing method (auto, txt, ocr)
- lang: Document language for OCR optimization
- **kwargs: Additional parameters for parser-specific command
-
- Returns:
- List[Dict[str, Any]]: List of content blocks
- """
- raise NotImplementedError("parse_pdf must be implemented by subclasses")
-
- def parse_image(
- self,
- image_path: Union[str, Path],
- output_dir: Optional[str] = None,
- lang: Optional[str] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """
- Abstract method to parse image document.
- Must be implemented by subclasses.
-
- Note: Different parsers may support different image formats.
- Check the specific parser's documentation for supported formats.
-
- Args:
- image_path: Path to the image file
- output_dir: Output directory path
- lang: Document language for OCR optimization
- **kwargs: Additional parameters for parser-specific command
-
- Returns:
- List[Dict[str, Any]]: List of content blocks
- """
- raise NotImplementedError("parse_image must be implemented by subclasses")
-
- def parse_document(
- self,
- file_path: Union[str, Path],
- method: str = "auto",
- output_dir: Optional[str] = None,
- lang: Optional[str] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """
- Abstract method to parse a document.
- Must be implemented by subclasses.
-
- Args:
- file_path: Path to the file to be parsed
- method: Parsing method (auto, txt, ocr)
- output_dir: Output directory path
- lang: Document language for OCR optimization
- **kwargs: Additional parameters for parser-specific command
-
- Returns:
- List[Dict[str, Any]]: List of content blocks
- """
- raise NotImplementedError("parse_document must be implemented by subclasses")
-
- def check_installation(self) -> bool:
- """
- Abstract method to check if the parser is properly installed.
- Must be implemented by subclasses.
-
- Returns:
- bool: True if installation is valid, False otherwise
- """
- raise NotImplementedError(
- "check_installation must be implemented by subclasses"
- )
-
diff --git a/rag_factory/parser/Parser_Docling.py b/rag_factory/parser/Parser_Docling.py
deleted file mode 100644
index f33886c..0000000
--- a/rag_factory/parser/Parser_Docling.py
+++ /dev/null
@@ -1,489 +0,0 @@
-from __future__ import annotations
-import base64
-import subprocess
-from pathlib import Path
-from typing import Union, List, Dict, Any, Optional, Tuple
-import json
-import logging
-from .Base import Parser
-
-
-class DoclingParser(Parser):
- """
- Docling document parsing utility class.
-
- Specialized in parsing Office documents and HTML files, converting the content
- into structured data and generating markdown and JSON output.
- """
-
- # Define Docling-specific formats
- HTML_FORMATS = {".html", ".htm", ".xhtml"}
-
- def __init__(self) -> None:
- """Initialize DoclingParser"""
- super().__init__()
-
- def parse_pdf(
- self,
- pdf_path: Union[str, Path],
- output_dir: Optional[str] = None,
- method: str = "auto",
- lang: Optional[str] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """
- Parse PDF document using Docling
-
- Args:
- pdf_path: Path to the PDF file
- output_dir: Output directory path
- method: Parsing method (auto, txt, ocr)
- lang: Document language for OCR optimization
- **kwargs: Additional parameters for docling command
-
- Returns:
- List[Dict[str, Any]]: List of content blocks
- """
- try:
- # Convert to Path object for easier handling
- pdf_path = Path(pdf_path)
- if not pdf_path.exists():
- raise FileNotFoundError(f"PDF file does not exist: {pdf_path}")
-
- name_without_suff = pdf_path.stem
-
- # Prepare output directory
- if output_dir:
- base_output_dir = Path(output_dir)
- else:
- base_output_dir = pdf_path.parent / "docling_output"
-
- base_output_dir.mkdir(parents=True, exist_ok=True)
-
- # Run docling command
- self._run_docling_command(
- input_path=pdf_path,
- output_dir=base_output_dir,
- file_stem=name_without_suff,
- **kwargs,
- )
-
- # Read the generated output files
- content_list, _ = self._read_output_files(
- base_output_dir, name_without_suff
- )
- return content_list
-
- except Exception as e:
- logging.error(f"Error in parse_pdf: {str(e)}")
- raise
-
- def parse_document(
- self,
- file_path: Union[str, Path],
- method: str = "auto",
- output_dir: Optional[str] = None,
- lang: Optional[str] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """
- Parse document using Docling based on file extension
-
- Args:
- file_path: Path to the file to be parsed
- method: Parsing method
- output_dir: Output directory path
- lang: Document language for optimization
- **kwargs: Additional parameters for docling command
-
- Returns:
- List[Dict[str, Any]]: List of content blocks
- """
- # Convert to Path object
- file_path = Path(file_path)
- if not file_path.exists():
- raise FileNotFoundError(f"File does not exist: {file_path}")
-
- # Get file extension
- ext = file_path.suffix.lower()
-
- # Choose appropriate parser based on file type
- if ext == ".pdf":
- return self.parse_pdf(file_path, output_dir, method, lang, **kwargs)
- elif ext in self.OFFICE_FORMATS:
- return self.parse_office_doc(file_path, output_dir, lang, **kwargs)
- elif ext in self.HTML_FORMATS:
- return self.parse_html(file_path, output_dir, lang, **kwargs)
- else:
- raise ValueError(
- f"Unsupported file format: {ext}. "
- f"Docling only supports PDF files, Office formats ({', '.join(self.OFFICE_FORMATS)}) "
- f"and HTML formats ({', '.join(self.HTML_FORMATS)})"
- )
-
- def _run_docling_command(
- self,
- input_path: Union[str, Path],
- output_dir: Union[str, Path],
- file_stem: str,
- **kwargs,
- ) -> None:
- """
- Run docling command line tool
-
- Args:
- input_path: Path to input file or directory
- output_dir: Output directory path
- file_stem: File stem for creating subdirectory
- **kwargs: Additional parameters for docling command
- """
- # Create subdirectory structure similar to MinerU
- file_output_dir = Path(output_dir) / file_stem / "docling"
- file_output_dir.mkdir(parents=True, exist_ok=True)
-
- cmd_json = [
- "docling",
- "--output",
- str(file_output_dir),
- "--to",
- "json",
- str(input_path),
- ]
- cmd_md = [
- "docling",
- "--output",
- str(file_output_dir),
- "--to",
- "md",
- str(input_path),
- ]
-
- try:
- # Prepare subprocess parameters to hide console window on Windows
- import platform
-
- docling_subprocess_kwargs = {
- "capture_output": True,
- "text": True,
- "check": True,
- "encoding": "utf-8",
- "errors": "ignore",
- }
-
- # Hide console window on Windows
- if platform.system() == "Windows":
- docling_subprocess_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
-
- result_json = subprocess.run(cmd_json, **docling_subprocess_kwargs)
- result_md = subprocess.run(cmd_md, **docling_subprocess_kwargs)
- logging.info("Docling command executed successfully")
- if result_json.stdout:
- logging.debug(f"JSON cmd output: {result_json.stdout}")
- if result_md.stdout:
- logging.debug(f"Markdown cmd output: {result_md.stdout}")
- except subprocess.CalledProcessError as e:
- logging.error(f"Error running docling command: {e}")
- if e.stderr:
- logging.error(f"Error details: {e.stderr}")
- raise
- except FileNotFoundError:
- raise RuntimeError(
- "docling command not found. Please ensure Docling is properly installed."
- )
-
- def _read_output_files(
- self,
- output_dir: Path,
- file_stem: str,
- ) -> Tuple[List[Dict[str, Any]], str]:
- """
- Read the output files generated by docling and convert to MinerU format
-
- Args:
- output_dir: Output directory
- file_stem: File name without extension
-
- Returns:
- Tuple containing (content list JSON, Markdown text)
- """
- # Use subdirectory structure similar to MinerU
- file_subdir = output_dir / file_stem / "docling"
- md_file = file_subdir / f"{file_stem}.md"
- json_file = file_subdir / f"{file_stem}.json"
-
- # Read markdown content
- md_content = ""
- if md_file.exists():
- try:
- with open(md_file, "r", encoding="utf-8") as f:
- md_content = f.read()
- except Exception as e:
- logging.warning(f"Could not read markdown file {md_file}: {e}")
-
- # Read JSON content and convert format
- content_list = []
- if json_file.exists():
- try:
- with open(json_file, "r", encoding="utf-8") as f:
- docling_content = json.load(f)
- # Convert docling format to minerU format
- content_list = self.read_from_block_recursive(
- docling_content["body"],
- "body",
- file_subdir,
- 0,
- "0",
- docling_content,
- )
- except Exception as e:
- logging.warning(f"Could not read or convert JSON file {json_file}: {e}")
- return content_list, md_content
-
- def read_from_block_recursive(
- self,
- block,
- type: str,
- output_dir: Path,
- cnt: int,
- num: str,
- docling_content: Dict[str, Any],
- ) -> List[Dict[str, Any]]:
- content_list = []
- if not block.get("children"):
- cnt += 1
- content_list.append(self.read_from_block(block, type, output_dir, cnt, num))
- else:
- if type not in ["groups", "body"]:
- cnt += 1
- content_list.append(
- self.read_from_block(block, type, output_dir, cnt, num)
- )
- members = block["children"]
- for member in members:
- cnt += 1
- member_tag = member["$ref"]
- member_type = member_tag.split("/")[1]
- member_num = member_tag.split("/")[2]
- member_block = docling_content[member_type][int(member_num)]
- content_list.extend(
- self.read_from_block_recursive(
- member_block,
- member_type,
- output_dir,
- cnt,
- member_num,
- docling_content,
- )
- )
- return content_list
-
- def read_from_block(
- self, block, type: str, output_dir: Path, cnt: int, num: str
- ) -> Dict[str, Any]:
- if type == "texts":
- if block["label"] == "formula":
- return {
- "type": "equation",
- "img_path": "",
- "text": block["orig"],
- "text_format": "unkown",
- "page_idx": cnt // 10,
- }
- else:
- return {
- "type": "text",
- "text": block["orig"],
- "page_idx": cnt // 10,
- }
- elif type == "pictures":
- try:
- base64_uri = block["image"]["uri"]
- base64_str = base64_uri.split(",")[1]
- # Create images directory within the docling subdirectory
- image_dir = output_dir / "images"
- image_dir.mkdir(parents=True, exist_ok=True) # Ensure directory exists
- image_path = image_dir / f"image_{num}.png"
- with open(image_path, "wb") as f:
- f.write(base64.b64decode(base64_str))
- return {
- "type": "image",
- "img_path": str(image_path.resolve()), # Convert to absolute path
- "image_caption": block.get("caption", ""),
- "image_footnote": block.get("footnote", ""),
- "page_idx": cnt // 10,
- }
- except Exception as e:
- logging.warning(f"Failed to process image {num}: {e}")
- return {
- "type": "text",
- "text": f"[Image processing failed: {block.get('caption', '')}]",
- "page_idx": cnt // 10,
- }
- else:
- try:
- return {
- "type": "table",
- "img_path": "",
- "table_caption": block.get("caption", ""),
- "table_footnote": block.get("footnote", ""),
- "table_body": block.get("data", []),
- "page_idx": cnt // 10,
- }
- except Exception as e:
- logging.warning(f"Failed to process table {num}: {e}")
- return {
- "type": "text",
- "text": f"[Table processing failed: {block.get('caption', '')}]",
- "page_idx": cnt // 10,
- }
-
- def parse_office_doc(
- self,
- doc_path: Union[str, Path],
- output_dir: Optional[str] = None,
- lang: Optional[str] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """
- Parse office document directly using Docling
-
- Supported formats: .doc, .docx, .ppt, .pptx, .xls, .xlsx
-
- Args:
- doc_path: Path to the document file
- output_dir: Output directory path
- lang: Document language for optimization
- **kwargs: Additional parameters for docling command
-
- Returns:
- List[Dict[str, Any]]: List of content blocks
- """
- try:
- # Convert to Path object
- doc_path = Path(doc_path)
- if not doc_path.exists():
- raise FileNotFoundError(f"Document file does not exist: {doc_path}")
-
- if doc_path.suffix.lower() not in self.OFFICE_FORMATS:
- raise ValueError(f"Unsupported office format: {doc_path.suffix}")
-
- name_without_suff = doc_path.stem
-
- # Prepare output directory
- if output_dir:
- base_output_dir = Path(output_dir)
- else:
- base_output_dir = doc_path.parent / "docling_output"
-
- base_output_dir.mkdir(parents=True, exist_ok=True)
-
- # Run docling command
- self._run_docling_command(
- input_path=doc_path,
- output_dir=base_output_dir,
- file_stem=name_without_suff,
- **kwargs,
- )
-
- # Read the generated output files
- content_list, _ = self._read_output_files(
- base_output_dir, name_without_suff
- )
- return content_list
-
- except Exception as e:
- logging.error(f"Error in parse_office_doc: {str(e)}")
- raise
-
- def parse_html(
- self,
- html_path: Union[str, Path],
- output_dir: Optional[str] = None,
- lang: Optional[str] = None,
- **kwargs,
- ) -> List[Dict[str, Any]]:
- """
- Parse HTML document using Docling
-
- Supported formats: .html, .htm, .xhtml
-
- Args:
- html_path: Path to the HTML file
- output_dir: Output directory path
- lang: Document language for optimization
- **kwargs: Additional parameters for docling command
-
- Returns:
- List[Dict[str, Any]]: List of content blocks
- """
- try:
- # Convert to Path object
- html_path = Path(html_path)
- if not html_path.exists():
- raise FileNotFoundError(f"HTML file does not exist: {html_path}")
-
- if html_path.suffix.lower() not in self.HTML_FORMATS:
- raise ValueError(f"Unsupported HTML format: {html_path.suffix}")
-
- name_without_suff = html_path.stem
-
- # Prepare output directory
- if output_dir:
- base_output_dir = Path(output_dir)
- else:
- base_output_dir = html_path.parent / "docling_output"
-
- base_output_dir.mkdir(parents=True, exist_ok=True)
-
- # Run docling command
- self._run_docling_command(
- input_path=html_path,
- output_dir=base_output_dir,
- file_stem=name_without_suff,
- **kwargs,
- )
-
- # Read the generated output files
- content_list, _ = self._read_output_files(
- base_output_dir, name_without_suff
- )
- return content_list
-
- except Exception as e:
- logging.error(f"Error in parse_html: {str(e)}")
- raise
-
- def check_installation(self) -> bool:
- """
- Check if Docling is properly installed
-
- Returns:
- bool: True if installation is valid, False otherwise
- """
- try:
- # Prepare subprocess parameters to hide console window on Windows
- import platform
-
- subprocess_kwargs = {
- "capture_output": True,
- "text": True,
- "check": True,
- "encoding": "utf-8",
- "errors": "ignore",
- }
-
- # Hide console window on Windows
- if platform.system() == "Windows":
- subprocess_kwargs["creationflags"] = subprocess.CREATE_NO_WINDOW
-
- result = subprocess.run(["docling", "--version"], **subprocess_kwargs)
- logging.debug(f"Docling version: {result.stdout.strip()}")
- return True
- except (subprocess.CalledProcessError, FileNotFoundError):
- logging.debug(
- "Docling is not properly installed. "
- "Please ensure it is installed correctly."
- )
- return False
-
-
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/download_model.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/download_model.py
new file mode 100644
index 0000000..32d7087
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/download_model.py
@@ -0,0 +1,24 @@
+from argparse import ArgumentParser
+import os
+
+
+if __name__ == '__main__':
+ parser = ArgumentParser()
+ parser.add_argument('--type', '-t', type=str, default="huggingface")
+ parser.add_argument('--name', '-n', type=str, default="rednote-hilab/dots.ocr")
+ args = parser.parse_args()
+ script_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ print(f"Attention: The model save dir dots.ocr should be replace by a name without `.` like DotsOCR, util we merge our code to transformers.")
+ model_dir = os.path.join(script_dir, "weights/DotsOCR")
+ if not os.path.exists(model_dir):
+ os.makedirs(model_dir)
+ if args.type == "huggingface":
+ from huggingface_hub import snapshot_download
+ snapshot_download(repo_id=args.name, local_dir=model_dir, local_dir_use_symlinks=False, resume_download=True)
+ elif args.type == "modelscope":
+ from modelscope import snapshot_download
+ snapshot_download(repo_id=args.name, local_dir=model_dir)
+ else:
+ raise ValueError(f"Invalid type: {args.type}")
+
+ print(f"model downloaded to {model_dir}")
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/inference.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/inference.py
new file mode 100644
index 0000000..80f3395
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/model/inference.py
@@ -0,0 +1,50 @@
+import json
+import io
+import base64
+import math
+from PIL import Image
+import requests
+from dots_ocr.utils.image_utils import PILimage_to_base64
+from openai import OpenAI
+import os
+
+
+def inference_with_vllm(
+ image,
+ prompt,
+ ip="localhost",
+ port=8000,
+ temperature=0.1,
+ top_p=0.9,
+ max_completion_tokens=32768,
+ model_name='model',
+ ):
+
+ addr = f"http://{ip}:{port}/v1"
+ client = OpenAI(api_key="{}".format(os.environ.get("API_KEY", "0")), base_url=addr)
+ messages = []
+ messages.append(
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image_url",
+ "image_url": {"url": PILimage_to_base64(image)},
+ },
+ {"type": "text", "text": f"<|img|><|imgpad|><|endofimg|>{prompt}"} # if no "<|img|><|imgpad|><|endofimg|>" here,vllm v1 will add "\n" here
+ ],
+ }
+ )
+ try:
+ response = client.chat.completions.create(
+ messages=messages,
+ model=model_name,
+ max_completion_tokens=max_completion_tokens,
+ temperature=temperature,
+ top_p=top_p)
+ response = response.choices[0].message.content
+ return response
+ except requests.exceptions.RequestException as e:
+ print(f"request error: {e}")
+ return None
+
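+# Usage sketch, assuming a vLLM server is already serving the DotsOCR weights on
+# localhost:8000; "page.png" and the prompt text are placeholders.
+#
+#   from PIL import Image
+#   from dots_ocr.model.inference import inference_with_vllm
+#
+#   img = Image.open("page.png")
+#   md = inference_with_vllm(img, prompt="Parse all layout elements of this page.",
+#                            ip="localhost", port=8000, model_name="model")
+#   print(md)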
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/consts.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/consts.py
new file mode 100644
index 0000000..3b9f71b
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/consts.py
@@ -0,0 +1,5 @@
+MIN_PIXELS=3136
+MAX_PIXELS=11289600
+IMAGE_FACTOR=28
+
+image_extensions = {'.jpg', '.jpeg', '.png'}
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/doc_utils.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/doc_utils.py
new file mode 100644
index 0000000..915c5c8
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/doc_utils.py
@@ -0,0 +1,60 @@
+import fitz
+import numpy as np
+import enum
+from pydantic import BaseModel, Field
+from PIL import Image
+
+
+class SupportedPdfParseMethod(enum.Enum):
+ OCR = 'ocr'
+ TXT = 'txt'
+
+
+class PageInfo(BaseModel):
+ """The width and height of page
+ """
+ w: float = Field(description='the width of page')
+ h: float = Field(description='the height of page')
+
+
+def fitz_doc_to_image(doc, target_dpi=200, origin_dpi=None) -> Image.Image:
+ """Render a PyMuPDF (fitz) page to a PIL image.
+
+ Args:
+ doc: A pymupdf page object.
+ target_dpi (int, optional): Rendering DPI. Defaults to 200.
+
+ Returns:
+ PIL.Image.Image: The rendered page image.
+ """
+ from PIL import Image
+ mat = fitz.Matrix(target_dpi / 72, target_dpi / 72)
+ pm = doc.get_pixmap(matrix=mat, alpha=False)
+
+ if pm.width > 4500 or pm.height > 4500:
+ mat = fitz.Matrix(72 / 72, 72 / 72) # use fitz default dpi
+ pm = doc.get_pixmap(matrix=mat, alpha=False)
+
+ image = Image.frombytes('RGB', (pm.width, pm.height), pm.samples)
+ return image
+
+
+def load_images_from_pdf(pdf_file, dpi=200, start_page_id=0, end_page_id=None) -> list:
+ images = []
+ with fitz.open(pdf_file) as doc:
+ pdf_page_num = doc.page_count
+ end_page_id = (
+ end_page_id
+ if end_page_id is not None and end_page_id >= 0
+ else pdf_page_num - 1
+ )
+ if end_page_id > pdf_page_num - 1:
+ print('end_page_id is out of range, falling back to the last page')
+ end_page_id = pdf_page_num - 1
+
+ for index in range(0, doc.page_count):
+ if start_page_id <= index <= end_page_id:
+ page = doc[index]
+ img = fitz_doc_to_image(page, target_dpi=dpi)
+ images.append(img)
+ return images
\ No newline at end of file
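+# Usage sketch ("sample.pdf" is a placeholder path): render the first three pages
+# at 200 DPI and save them as PNGs.
+#
+#   images = load_images_from_pdf("sample.pdf", dpi=200, start_page_id=0, end_page_id=2)
+#   for i, img in enumerate(images):
+#       img.save(f"page_{i}.png")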
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/format_transformer.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/format_transformer.py
new file mode 100644
index 0000000..b2a2123
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/format_transformer.py
@@ -0,0 +1,205 @@
+import os
+import sys
+import json
+import re
+
+from PIL import Image
+from dots_ocr.utils.image_utils import PILimage_to_base64
+
+
+def has_latex_markdown(text: str) -> bool:
+ """
+ Checks if a string contains LaTeX markdown patterns.
+
+ Args:
+ text (str): The string to check.
+
+ Returns:
+ bool: True if LaTeX markdown is found, otherwise False.
+ """
+ if not isinstance(text, str):
+ return False
+
+ # Define regular expression patterns for LaTeX markdown
+ latex_patterns = [
+ r'\$\$.*?\$\$', # Block-level math formula $$...$$
+ r'\$[^$\n]+?\$', # Inline math formula $...$
+ r'\\begin\{.*?\}.*?\\end\{.*?\}', # LaTeX environment \begin{...}...\end{...}
+ r'\\[a-zA-Z]+\{.*?\}', # LaTeX command \command{...}
+ r'\\[a-zA-Z]+', # Simple LaTeX command \command
+ r'\\\[.*?\\\]', # Display math formula \[...\]
+ r'\\\(.*?\\\)', # Inline math formula \(...\)
+ ]
+
+ # Check if any of the patterns match
+ for pattern in latex_patterns:
+ if re.search(pattern, text, re.DOTALL):
+ return True
+
+ return False
+
+
+def clean_latex_preamble(latex_text: str) -> str:
+ """
+ Removes LaTeX preamble commands like document class and package imports.
+
+ Args:
+ latex_text (str): The original LaTeX text.
+
+ Returns:
+ str: The cleaned LaTeX text without preamble commands.
+ """
+ # Define patterns to be removed
+ patterns = [
+ r'\\documentclass\{[^}]+\}', # \documentclass{...}
+ r'\\usepackage\{[^}]+\}', # \usepackage{...}
+ r'\\usepackage\[[^\]]*\]\{[^}]+\}', # \usepackage[options]{...}
+ r'\\begin\{document\}', # \begin{document}
+ r'\\end\{document\}', # \end{document}
+ ]
+
+ # Apply each pattern to clean the text
+ cleaned_text = latex_text
+ for pattern in patterns:
+ cleaned_text = re.sub(pattern, '', cleaned_text, flags=re.IGNORECASE)
+
+ return cleaned_text
+
+
+def get_formula_in_markdown(text: str) -> str:
+ """
+ Formats a string containing a formula into a standard Markdown block.
+
+ Args:
+ text (str): The input string, potentially containing a formula.
+
+ Returns:
+ str: The formatted string, ready for Markdown rendering.
+ """
+ # Remove leading/trailing whitespace
+ text = text.strip()
+
+ # Check if it's already enclosed in $$
+ if text.startswith('$$') and text.endswith('$$'):
+ text_new = text[2:-2].strip()
+ if '$' not in text_new:
+ return f"$$\n{text_new}\n$$"
+ else:
+ return text
+
+ # Handle \[...\] format, convert to $$...$$
+ if text.startswith('\\[') and text.endswith('\\]'):
+ inner_content = text[2:-2].strip()
+ return f"$$\n{inner_content}\n$$"
+
+ # Check if it's enclosed in \[ \]
+ if len(re.findall(r'.*\\\[.*\\\].*', text)) > 0:
+ return text
+
+ # Handle inline formulas ($...$)
+ pattern = r'\$([^$]+)\$'
+ matches = re.findall(pattern, text)
+ if len(matches) > 0:
+ # It's an inline formula, return it as is
+ return text
+
+ # If no LaTeX markdown syntax is present, return directly
+ if not has_latex_markdown(text):
+ return text
+
+ # Handle unnecessary LaTeX formatting like \usepackage
+ if 'usepackage' in text:
+ text = clean_latex_preamble(text)
+
+ if text[0] == '`' and text[-1] == '`':
+ text = text[1:-1]
+
+ # Enclose the final text in a $$ block with newlines
+ text = f"$$\n{text}\n$$"
+ return text
+
+
+def clean_text(text: str) -> str:
+ """
+ Cleans text by removing extra whitespace.
+
+ Args:
+ text: The original text.
+
+ Returns:
+ str: The cleaned text.
+ """
+ if not text:
+ return ""
+
+ # Remove leading and trailing whitespace
+ text = text.strip()
+
+ # Replace multiple consecutive whitespace characters with a single space
+ text = re.sub(r'\s+', ' ', text)
+
+ return text
+
+
+def layoutjson2md(image: Image.Image, cells: list, text_key: str = 'text', no_page_hf: bool = False) -> str:
+ """
+ Converts a layout JSON format to Markdown.
+
+ In the layout JSON, formulas are LaTeX, tables are HTML, and text is Markdown.
+
+ Args:
+ image: A PIL Image object.
+ cells: A list of dictionaries, each representing a layout cell.
+ text_key: The key for the text field in the cell dictionary.
+ no_page_hf: If True, skips page headers and footers.
+
+ Returns:
+ str: The text in Markdown format.
+ """
+ text_items = []
+
+ for i, cell in enumerate(cells):
+ x1, y1, x2, y2 = [int(coord) for coord in cell['bbox']]
+ text = cell.get(text_key, "")
+
+ if no_page_hf and cell['category'] in ['Page-header', 'Page-footer']:
+ continue
+
+ if cell['category'] == 'Picture':
+ image_crop = image.crop((x1, y1, x2, y2))
+ image_base64 = PILimage_to_base64(image_crop)
+ text_items.append(f"")
+ elif cell['category'] == 'Formula':
+ text_items.append(get_formula_in_markdown(text))
+ else:
+ text = clean_text(text)
+ text_items.append(f"{text}")
+
+ markdown_text = '\n\n'.join(text_items)
+ return markdown_text
+
+
+def fix_streamlit_formulas(md: str) -> str:
+ """
+ Fixes the format of formulas in Markdown to ensure they display correctly in Streamlit.
+ It adds a newline after the opening $$ and before the closing $$ if they don't already exist.
+
+ Args:
+ md (str): The Markdown text to fix.
+
+ Returns:
+ str: The fixed Markdown text.
+ """
+
+ # This inner function will be used by re.sub to perform the replacement
+ def replace_formula(match):
+ content = match.group(1)
+ # If the content already has surrounding newlines, don't add more.
+ if content.startswith('\n'):
+ content = content[1:]
+ if content.endswith('\n'):
+ content = content[:-1]
+ return f'$$\n{content}\n$$'
+
+ # Use regex to find all $$....$$ patterns and replace them using the helper function.
+ return re.sub(r'\$\$(.*?)\$\$', replace_formula, md, flags=re.DOTALL)
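+# Usage sketch for layoutjson2md with hypothetical layout-JSON cells (the bbox
+# values, categories and texts below are made up for illustration):
+#
+#   from PIL import Image
+#   page = Image.new("RGB", (800, 600), "white")
+#   cells = [
+#       {"bbox": [40, 30, 760, 80], "category": "Title", "text": "# Quarterly Report"},
+#       {"bbox": [40, 100, 760, 200], "category": "Formula", "text": "$E = mc^2$"},
+#       {"bbox": [40, 220, 760, 560], "category": "Text", "text": "Revenue grew steadily."},
+#   ]
+#   print(layoutjson2md(page, cells, text_key="text", no_page_hf=True))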
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/image_utils.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/image_utils.py
new file mode 100644
index 0000000..69479eb
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/image_utils.py
@@ -0,0 +1,196 @@
+import math
+import base64
+from PIL import Image
+from typing import Tuple
+import os
+from dots_ocr.utils.consts import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS
+from dots_ocr.utils.doc_utils import fitz_doc_to_image
+from io import BytesIO
+import fitz
+import requests
+import copy
+
+
+def round_by_factor(number: int, factor: int) -> int:
+ """Returns the closest integer to 'number' that is divisible by 'factor'."""
+ return round(number / factor) * factor
+
+
+def ceil_by_factor(number: int, factor: int) -> int:
+ """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
+ return math.ceil(number / factor) * factor
+
+
+def floor_by_factor(number: int, factor: int) -> int:
+ """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
+ return math.floor(number / factor) * factor
+
+
+def smart_resize(
+ height: int,
+ width: int,
+ factor: int = 28,
+ min_pixels: int = 3136,
+ max_pixels: int = 11289600,
+):
+ """Rescales the image so that the following conditions are met:
+
+ 1. Both dimensions (height and width) are divisible by 'factor'.
+
+ 2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
+
+ 3. The aspect ratio of the image is maintained as closely as possible.
+
+ """
+ if max(height, width) / min(height, width) > 200:
+ raise ValueError(
+ f"absolute aspect ratio must be smaller than 200, got {max(height, width) / min(height, width)}"
+ )
+ h_bar = max(factor, round_by_factor(height, factor))
+ w_bar = max(factor, round_by_factor(width, factor))
+ if h_bar * w_bar > max_pixels:
+ beta = math.sqrt((height * width) / max_pixels)
+ h_bar = max(factor, floor_by_factor(height / beta, factor))
+ w_bar = max(factor, floor_by_factor(width / beta, factor))
+ elif h_bar * w_bar < min_pixels:
+ beta = math.sqrt(min_pixels / (height * width))
+ h_bar = ceil_by_factor(height * beta, factor)
+ w_bar = ceil_by_factor(width * beta, factor)
+ if h_bar * w_bar > max_pixels: # max_pixels first to control the token length
+ beta = math.sqrt((h_bar * w_bar) / max_pixels)
+ h_bar = max(factor, floor_by_factor(h_bar / beta, factor))
+ w_bar = max(factor, floor_by_factor(w_bar / beta, factor))
+ return h_bar, w_bar
+
+
+
+def PILimage_to_base64(image, format='PNG'):
+ buffered = BytesIO()
+ image.save(buffered, format=format)
+ base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
+ return f"data:image;base64,{base64_str}"
+
+
+def to_rgb(pil_image: Image.Image) -> Image.Image:
+ if pil_image.mode == 'RGBA':
+ white_background = Image.new("RGB", pil_image.size, (255, 255, 255))
+ white_background.paste(pil_image, mask=pil_image.split()[3]) # Use alpha channel as mask
+ return white_background
+ else:
+ return pil_image.convert("RGB")
+
+
+# copy from https://github.com/QwenLM/Qwen2.5-VL/blob/main/qwen-vl-utils/src/qwen_vl_utils/vision_process.py
+def fetch_image(
+ image,
+ min_pixels=None,
+ max_pixels=None,
+ resized_height=None,
+ resized_width=None,
+ ) -> Image.Image:
+ assert image is not None, f"image not found, maybe input format error: {image}"
+ image_obj = None
+ if isinstance(image, Image.Image):
+ image_obj = image
+ elif image.startswith("http://") or image.startswith("https://"):
+ # fix memory leak issue while using BytesIO
+ with requests.get(image, stream=True) as response:
+ response.raise_for_status()
+ with BytesIO(response.content) as bio:
+ image_obj = copy.deepcopy(Image.open(bio))
+ elif image.startswith("file://"):
+ image_obj = Image.open(image[7:])
+ elif image.startswith("data:image"):
+ if "base64," in image:
+ _, base64_data = image.split("base64,", 1)
+ data = base64.b64decode(base64_data)
+ # fix memory leak issue while using BytesIO
+ with BytesIO(data) as bio:
+ image_obj = copy.deepcopy(Image.open(bio))
+ else:
+ image_obj = Image.open(image)
+ if image_obj is None:
+ raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
+ image = to_rgb(image_obj)
+ ## resize
+ if resized_height and resized_width:
+ resized_height, resized_width = smart_resize(
+ resized_height,
+ resized_width,
+ factor=IMAGE_FACTOR,
+ )
+ assert resized_height > 0 and resized_width > 0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels: {max_pixels}"
+ image = image.resize((resized_width, resized_height))
+ elif min_pixels or max_pixels:
+ width, height = image.size
+ if not min_pixels:
+ min_pixels = MIN_PIXELS
+ if not max_pixels:
+ max_pixels = MAX_PIXELS
+ resized_height, resized_width = smart_resize(
+ height,
+ width,
+ factor=IMAGE_FACTOR,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+ assert resized_height>0 and resized_width>0, f"resized_height: {resized_height}, resized_width: {resized_width}, min_pixels: {min_pixels}, max_pixels:{max_pixels}, width: {width}, height:{height}, "
+ image = image.resize((resized_width, resized_height))
+
+ return image
+
+def get_input_dimensions(
+ image: Image.Image,
+ min_pixels: int,
+ max_pixels: int,
+ factor: int = 28
+) -> Tuple[int, int]:
+ """
+ Gets the resized dimensions of the input image.
+
+ Args:
+ image: The original image.
+ min_pixels: The minimum number of pixels.
+ max_pixels: The maximum number of pixels.
+ factor: The resizing factor.
+
+ Returns:
+ The resized (width, height).
+ """
+ input_height, input_width = smart_resize(
+ image.height,
+ image.width,
+ factor=factor,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels
+ )
+ return input_width, input_height
+
+
+def get_image_by_fitz_doc(image, target_dpi=200):
+ # get image through fitz, to get target dpi image, mainly for higher image
+ if not isinstance(image, Image.Image):
+ assert isinstance(image, str)
+ _, file_ext = os.path.splitext(image)
+ assert file_ext in {'.jpg', '.jpeg', '.png'}
+
+ if image.startswith("http://") or image.startswith("https://"):
+ with requests.get(image, stream=True) as response:
+ response.raise_for_status()
+ data_bytes = response.content
+ else:
+ with open(image, 'rb') as f:
+ data_bytes = f.read()
+
+ image = Image.open(BytesIO(data_bytes))
+ else:
+ data_bytes = BytesIO()
+ image.save(data_bytes, format='PNG')
+
+ origin_dpi = image.info.get('dpi', None)
+ pdf_bytes = fitz.open(stream=data_bytes).convert_to_pdf()
+ doc = fitz.open('pdf', pdf_bytes)
+ page = doc[0]
+ image_fitz = fitz_doc_to_image(page, target_dpi=target_dpi, origin_dpi=origin_dpi)
+
+ return image_fitz
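+# Worked example for smart_resize: with the default factor of 28 and pixel bounds,
+# a 700x1000 (h x w) input rounds each side to the nearest multiple of 28:
+#
+#   h, w = smart_resize(700, 1000)
+#   # h == 700 (25 * 28), w == 1008 (36 * 28); 700 * 1008 stays within
+#   # [MIN_PIXELS, MAX_PIXELS], so no further up/down scaling is applied.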
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/layout_utils.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/layout_utils.py
new file mode 100644
index 0000000..86fe1ed
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/layout_utils.py
@@ -0,0 +1,228 @@
+from PIL import Image
+from typing import Dict, List
+
+import fitz
+from io import BytesIO
+import json
+
+from dots_ocr.utils.image_utils import smart_resize
+from dots_ocr.utils.consts import MIN_PIXELS, MAX_PIXELS
+from dots_ocr.utils.output_cleaner import OutputCleaner
+
+
+# Define a color map (using RGBA format)
+dict_layout_type_to_color = {
+ "Text": (0, 128, 0, 256), # Green, translucent
+ "Picture": (255, 0, 255, 256), # Magenta, translucent
+ "Caption": (255, 165, 0, 256), # Orange, translucent
+ "Section-header": (0, 255, 255, 256), # Cyan, translucent
+ "Footnote": (0, 128, 0, 256), # Green, translucent
+ "Formula": (128, 128, 128, 256), # Gray, translucent
+ "Table": (255, 192, 203, 256), # Pink, translucent
+ "Title": (255, 0, 0, 256), # Red, translucent
+ "List-item": (0, 0, 255, 256), # Blue, translucent
+ "Page-header": (0, 128, 0, 256), # Green, translucent
+ "Page-footer": (128, 0, 128, 256), # Purple, translucent
+ "Other": (165, 42, 42, 256), # Brown, translucent
+ "Unknown": (0, 0, 0, 0),
+}
+
+
+def draw_layout_on_image(image, cells, resized_height=None, resized_width=None, fill_bbox=True, draw_bbox=True):
+ """
+ Draw transparent boxes on an image.
+
+ Args:
+ image: The source PIL Image.
+ cells: A list of cells containing bounding box information.
+ resized_height: The resized height.
+ resized_width: The resized width.
+ fill_bbox: Whether to fill the bounding box.
+ draw_bbox: Whether to draw the bounding box.
+
+ Returns:
+ PIL.Image: The image with drawings.
+ """
+ # origin_image = Image.open(image_path)
+ original_width, original_height = image.size
+
+ # Create a new PDF document
+ doc = fitz.open()
+
+ # Get image information
+ img_bytes = BytesIO()
+ image.save(img_bytes, format='PNG')
+ # pix = fitz.Pixmap(image_path)
+ pix = fitz.Pixmap(img_bytes)
+
+ # Create a page
+ page = doc.new_page(width=pix.width, height=pix.height)
+ page.insert_image(
+ fitz.Rect(0, 0, pix.width, pix.height),
+ # filename=image_path
+ pixmap=pix
+ )
+
+ for i, cell in enumerate(cells):
+ bbox = cell['bbox']
+ layout_type = cell['category']
+ order = i
+
+ top_left = (bbox[0], bbox[1])
+ down_right = (bbox[2], bbox[3])
+ if resized_height and resized_width:
+ scale_x = resized_width / original_width
+ scale_y = resized_height / original_height
+ top_left = (int(bbox[0] / scale_x), int(bbox[1] / scale_y))
+ down_right = (int(bbox[2] / scale_x), int(bbox[3] / scale_y))
+
+ color = dict_layout_type_to_color.get(layout_type, (0, 128, 0, 256))
+ color = [col/255 for col in color[:3]]
+
+ x0, y0, x1, y1 = top_left[0], top_left[1], down_right[0], down_right[1]
+ rect_coords = fitz.Rect(x0, y0, x1, y1)
+ if draw_bbox:
+ if fill_bbox:
+ page.draw_rect(
+ rect_coords,
+ color=None,
+ fill=color,
+ fill_opacity=0.3,
+ width=0.5,
+ overlay=True,
+ ) # Draw the rectangle
+ else:
+ page.draw_rect(
+ rect_coords,
+ color=color,
+ fill=None,
+ fill_opacity=1,
+ width=0.5,
+ overlay=True,
+ ) # Draw the rectangle
+ order_cate = f"{order}_{layout_type}"
+ page.insert_text(
+ (x1, y0 + 20), order_cate, fontsize=20, color=color
+ ) # Insert the index in the top left corner of the rectangle
+
+ # Convert to a Pixmap (maintaining original dimensions)
+ mat = fitz.Matrix(1.0, 1.0)
+ pix = page.get_pixmap(matrix=mat)
+
+ return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+
+
+def pre_process_bboxes(
+ origin_image,
+ bboxes,
+ input_width,
+ input_height,
+ factor: int = 28,
+ min_pixels: int = 3136,
+ max_pixels: int = 11289600
+):
+ assert isinstance(bboxes, list) and len(bboxes) > 0 and isinstance(bboxes[0], list)
+ min_pixels = min_pixels or MIN_PIXELS
+ max_pixels = max_pixels or MAX_PIXELS
+ original_width, original_height = origin_image.size
+
+ input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
+
+ scale_x = original_width / input_width
+ scale_y = original_height / input_height
+
+ bboxes_out = []
+ for bbox in bboxes:
+ bbox_resized = [
+ int(float(bbox[0]) / scale_x),
+ int(float(bbox[1]) / scale_y),
+ int(float(bbox[2]) / scale_x),
+ int(float(bbox[3]) / scale_y)
+ ]
+ bboxes_out.append(bbox_resized)
+
+ return bboxes_out
+
+def post_process_cells(
+ origin_image: Image.Image,
+ cells: List[Dict],
+ input_width, # width of the image as sent to the server (the server applies the same smart_resize)
+ input_height,
+ factor: int = 28,
+ min_pixels: int = 3136,
+ max_pixels: int = 11289600
+) -> List[Dict]:
+ """
+ Post-processes cell bounding boxes, converting coordinates from the resized dimensions back to the original dimensions.
+
+ Args:
+ origin_image: The original PIL Image.
+ cells: A list of cells containing bounding box information.
+ input_width: The width of the input image sent to the server.
+ input_height: The height of the input image sent to the server.
+ factor: Resizing factor.
+ min_pixels: Minimum number of pixels.
+ max_pixels: Maximum number of pixels.
+
+ Returns:
+ A list of post-processed cells.
+ """
+ assert isinstance(cells, list) and len(cells) > 0 and isinstance(cells[0], dict)
+ min_pixels = min_pixels or MIN_PIXELS
+ max_pixels = max_pixels or MAX_PIXELS
+ original_width, original_height = origin_image.size
+
+ input_height, input_width = smart_resize(input_height, input_width, min_pixels=min_pixels, max_pixels=max_pixels)
+
+ scale_x = input_width / original_width
+ scale_y = input_height / original_height
+
+ cells_out = []
+ for cell in cells:
+ bbox = cell['bbox']
+ bbox_resized = [
+ int(float(bbox[0]) / scale_x),
+ int(float(bbox[1]) / scale_y),
+ int(float(bbox[2]) / scale_x),
+ int(float(bbox[3]) / scale_y)
+ ]
+ cell_copy = cell.copy()
+ cell_copy['bbox'] = bbox_resized
+ cells_out.append(cell_copy)
+
+ return cells_out
+
+def is_legal_bbox(cells):
+ for cell in cells:
+ bbox = cell['bbox']
+ if bbox[2] <= bbox[0] or bbox[3] <= bbox[1]:
+ return False
+ return True
+
+def post_process_output(response, prompt_mode, origin_image, input_image, min_pixels=None, max_pixels=None):
+ if prompt_mode in ["prompt_ocr", "prompt_table_html", "prompt_table_latex", "prompt_formula_latex"]:
+ return response
+
+ json_load_failed = False
+ cells = response
+ try:
+ cells = json.loads(cells)
+ cells = post_process_cells(
+ origin_image,
+ cells,
+ input_image.width,
+ input_image.height,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels
+ )
+ return cells, False
+ except Exception as e:
+ print(f"cells post process error: {e}, when using {prompt_mode}")
+ json_load_failed = True
+
+ if json_load_failed:
+ cleaner = OutputCleaner()
+ response_clean = cleaner.clean_model_output(cells)
+ if isinstance(response_clean, list):
+ response_clean = "\n\n".join([cell['text'] for cell in response_clean if 'text' in cell])
+ return response_clean, True
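+# Worked example for post_process_cells: map a bbox predicted on the resized model
+# input back to original image coordinates (the sizes below are illustrative).
+#
+#   from PIL import Image
+#   origin = Image.new("RGB", (2000, 1400), "white")          # original page
+#   cells = [{"bbox": [0, 0, 504, 350], "category": "Text", "text": "hello"}]
+#   out = post_process_cells(origin, cells, input_width=1008, input_height=700)
+#   # scale_x = 1008 / 2000, scale_y = 700 / 1400, so the bbox becomes [0, 0, 1000, 700]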
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/output_cleaner.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/output_cleaner.py
new file mode 100644
index 0000000..3fbc3aa
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/output_cleaner.py
@@ -0,0 +1,623 @@
+#!/usr/bin/env python3
+"""
+Data Cleaning Script - Cleans all data using a simplified regex method and saves the results
+
+Features:
+1. Cleans all cases using a simplified regex method.
+2. Saves the cleaned data for each case.
+3. Ensures the relative order of dicts remains unchanged.
+4. Generates a before-and-after cleaning report.
+"""
+
+import json
+import re
+import os
+from typing import Dict, List, Tuple, Optional, Any
+from dataclasses import dataclass
+from collections import Counter
+import traceback
+
+
+@dataclass
+class CleanedData:
+ """Data structure for cleaned data"""
+ case_id: int
+ original_type: str # 'list' or 'str'
+ original_length: int
+ cleaned_data: List[Dict]
+ cleaning_operations: Dict[str, Any] # Records the cleaning operations performed
+ success: bool
+
+
+class OutputCleaner:
+ """Data Cleaner - Based on a simplified regex method"""
+
+ def __init__(self):
+ # Simplified regular expression patterns
+ self.dict_pattern = re.compile(r'\{[^{}]*?"bbox"\s*:\s*\[[^\]]*?\][^{}]*?\}', re.DOTALL)
+ self.bbox_pattern = re.compile(r'"bbox"\s*:\s*\[([^\]]+)\]')
+ self.missing_delimiter_pattern = re.compile(r'\}\s*\{(?!")')
+
+ self.cleaned_results: List[CleanedData] = []
+
+ def clean_list_data(self, data: List[Dict], case_id: int) -> CleanedData:
+ """Cleans list-type data"""
+
+ print(f"🔧 Cleaning List data - Case {case_id}")
+ print(f" Original items: {len(data)}")
+
+ cleaned_data = []
+ operations = {
+ 'type': 'list',
+ 'bbox_fixes': 0,
+ 'removed_items': 0,
+ 'original_count': len(data)
+ }
+
+ for i, item in enumerate(data):
+ if not isinstance(item, dict):
+ operations['removed_items'] += 1
+ continue
+
+ # Check the bbox field
+ if 'bbox' in item:
+ bbox = item['bbox']
+
+ # Check bbox length - core logic
+ if isinstance(bbox, list) and len(bbox) == 3:
+ print(f" ⚠️ Item {i}: bbox has only 3 coordinates. Removing bbox, keeping category and text.")
+ # Keep only category and text, ensuring order is preserved
+ new_item = {}
+ if 'category' in item:
+ new_item['category'] = item['category']
+ if 'text' in item:
+ new_item['text'] = item['text']
+ if new_item: # Add only if there is valid content
+ cleaned_data.append(new_item)
+ operations['bbox_fixes'] += 1
+ else:
+ operations['removed_items'] += 1
+ continue
+ elif isinstance(bbox, list) and len(bbox) == 4:
+ # bbox is normal, add directly, preserving original order
+ cleaned_data.append(item.copy())
+ continue
+ else:
+ print(f" ❌ Item {i}: Abnormal bbox format, skipping.")
+ operations['removed_items'] += 1
+ continue
+ else:
+ # No bbox field, keep if category exists
+ if 'category' in item:
+ cleaned_data.append(item.copy())
+ continue
+ else:
+ operations['removed_items'] += 1
+
+ operations['final_count'] = len(cleaned_data)
+ print(f" ✅ Cleaning complete: {len(cleaned_data)} items, {operations['bbox_fixes']} bbox fixes, {operations['removed_items']} items removed")
+
+ return CleanedData(
+ case_id=case_id,
+ original_type='list',
+ original_length=len(data),
+ cleaned_data=cleaned_data,
+ cleaning_operations=operations,
+ success=True
+ )
+
+ def clean_string_data(self, data_str: str, case_id: int) -> CleanedData:
+ """Cleans string-type data"""
+
+ print(f"🔧 Cleaning String data - Case {case_id}")
+ print(f" Original length: {len(data_str):,}")
+
+ operations = {
+ 'type': 'str',
+ 'original_length': len(data_str),
+ 'delimiter_fixes': 0,
+ 'tail_truncated': False,
+ 'truncated_length': 0,
+ 'duplicate_dicts_removed': 0,
+ 'final_objects': 0
+ }
+
+ try:
+ # Step 1: Detect and fix missing delimiters
+ data_str, delimiter_fixes = self._fix_missing_delimiters(data_str)
+ operations['delimiter_fixes'] = delimiter_fixes
+
+ # Step 2: Truncate the last incomplete element
+ data_str, tail_truncated = self._truncate_last_incomplete_element(data_str)
+ operations['tail_truncated'] = tail_truncated
+ operations['truncated_length'] = len(data_str)
+
+ # Step 3: Remove duplicate complete dict objects, preserving order
+ data_str, duplicate_removes = self._remove_duplicate_complete_dicts_preserve_order(data_str)
+ operations['duplicate_dicts_removed'] = duplicate_removes
+
+ # Step 4: Ensure correct JSON format
+ data_str = self._ensure_json_format(data_str)
+
+ # Step 5: Try to parse the final result
+ final_data = self._parse_final_json(data_str)
+
+ if final_data is not None:
+ operations['final_objects'] = len(final_data)
+ print(f" ✅ Cleaning complete: {len(final_data)} objects")
+
+ return CleanedData(
+ case_id=case_id,
+ original_type='str',
+ original_length=operations['original_length'],
+ cleaned_data=final_data,
+ cleaning_operations=operations,
+ success=True
+ )
+ else:
+ raise Exception("Could not parse the cleaned data")
+
+ except Exception as e:
+ print(f" ❌ Cleaning failed: {e}")
+ return CleanedData(
+ case_id=case_id,
+ original_type='str',
+ original_length=operations['original_length'],
+ cleaned_data=[],
+ cleaning_operations=operations,
+ success=False
+ )
+
+ def _fix_missing_delimiters(self, text: str) -> Tuple[str, int]:
+ """Fixes missing delimiters"""
+
+ fixes = 0
+
+ def replace_delimiter(match):
+ nonlocal fixes
+ fixes += 1
+ return '},{'
+
+ text = self.missing_delimiter_pattern.sub(replace_delimiter, text)
+
+ if fixes > 0:
+ print(f" ✅ Fixed {fixes} missing delimiters")
+
+ return text, fixes
+
+ def _truncate_last_incomplete_element(self, text: str) -> Tuple[str, bool]:
+ """Truncates the last incomplete element"""
+
+ # For very long text (>50k) or text not ending with ']', directly truncate the last '{"bbox":'
+ needs_truncation = (
+ len(text) > 50000 or
+ not text.strip().endswith(']')
+ )
+
+ if needs_truncation:
+ # Check how many dict objects there are
+ bbox_count = text.count('{"bbox":')
+
+ # If there is only one dict object, do not truncate to avoid deleting the only object
+ if bbox_count <= 1:
+ print(f" ⚠️ Only {bbox_count} dict objects found, skipping truncation to avoid deleting all content")
+ return text, False
+
+ # Find the position of the last '{"bbox":'
+ last_bbox_pos = text.rfind('{"bbox":')
+
+ if last_bbox_pos > 0:
+ # Truncate before this position
+ truncated_text = text[:last_bbox_pos].rstrip()
+
+ # Remove trailing comma
+ if truncated_text.endswith(','):
+ truncated_text = truncated_text[:-1]
+
+ print(f" ✂️ Truncated the last incomplete element, length reduced from {len(text):,} to {len(truncated_text):,}")
+ return truncated_text, True
+
+ return text, False
+
+ def _remove_duplicate_complete_dicts_preserve_order(self, text: str) -> Tuple[str, int]:
+ """Removes duplicate complete dict objects, preserving original order"""
+
+ # Extract all dict objects, preserving order
+ dict_matches = list(self.dict_pattern.finditer(text))
+
+ if not dict_matches:
+ return text, 0
+
+ print(f" 📊 Found {len(dict_matches)} dict objects")
+
+ # Deduplication while preserving order: only keep the first occurrence of a dict
+ unique_dicts = []
+ seen_dict_strings = set()
+ total_duplicates = 0
+
+ for match in dict_matches:
+ dict_str = match.group()
+
+ if dict_str not in seen_dict_strings:
+ unique_dicts.append(dict_str)
+ seen_dict_strings.add(dict_str)
+ else:
+ total_duplicates += 1
+
+ if total_duplicates > 0:
+ # Reconstruct the JSON array, preserving the original order
+ new_text = '[' + ', '.join(unique_dicts) + ']'
+ print(f" ✅ Removed {total_duplicates} duplicate dicts, keeping {len(unique_dicts)} unique dicts (order preserved)")
+ return new_text, total_duplicates
+ else:
+ print(f" ✅ No duplicate dict objects found")
+ return text, 0
+
+ def _ensure_json_format(self, text: str) -> str:
+ """Ensures correct JSON format"""
+
+ text = text.strip()
+
+ if not text.startswith('['):
+ text = '[' + text
+
+ if not text.endswith(']'):
+ # Remove trailing comma
+ text = text.rstrip(',').rstrip()
+ text += ']'
+
+ return text
+
+ def _parse_final_json(self, text: str) -> Optional[List[Dict]]:
+ """Tries to parse the final JSON"""
+
+ try:
+ data = json.loads(text)
+ if isinstance(data, list):
+ return data
+ except json.JSONDecodeError as e:
+ print(f" ❌ JSON parsing failed: {e}")
+
+ # fallback1: Extract valid dict objects
+ valid_dicts = []
+
+ for match in self.dict_pattern.finditer(text):
+ dict_str = match.group()
+ try:
+ dict_obj = json.loads(dict_str)
+ valid_dicts.append(dict_obj)
+            except json.JSONDecodeError:
+                continue
+
+ if valid_dicts:
+ print(f" ✅ Extracted {len(valid_dicts)} valid dicts")
+ return valid_dicts
+
+ # fallback2: Special handling for a single incomplete dict
+ return self._handle_single_incomplete_dict(text)
+
+ return None
+
+ def _handle_single_incomplete_dict(self, text: str) -> Optional[List[Dict]]:
+ """Handles the special case of a single incomplete dict"""
+
+ # Check if it's a single incomplete dict case
+ if not text.strip().startswith('[{"bbox":'):
+ return None
+
+ try:
+ # Try to extract bbox coordinates
+ bbox_match = re.search(r'"bbox"\s*:\s*\[([^\]]+)\]', text)
+ if not bbox_match:
+ return None
+
+ bbox_str = bbox_match.group(1)
+ bbox_coords = [int(x.strip()) for x in bbox_str.split(',')]
+
+ if len(bbox_coords) != 4:
+ return None
+
+ # Try to extract category
+ category_match = re.search(r'"category"\s*:\s*"([^"]+)"', text)
+ category = category_match.group(1) if category_match else "Text"
+
+ # Try to extract the beginning of the text (first 10000 characters)
+ text_match = re.search(r'"text"\s*:\s*"([^"]{0,10000})', text)
+ if text_match:
+ text_content = text_match.group(1)
+ else:
+ text_content = ""
+
+ # Construct the fixed dict
+ fixed_dict = {
+ "bbox": bbox_coords,
+ "category": category
+ }
+
+ if text_content:
+ fixed_dict["text"] = text_content
+
+ print(f" 🔧 Special fix: single incomplete dict → {fixed_dict}")
+ return [fixed_dict]
+
+ except Exception as e:
+ print(f" ❌ Special fix failed: {e}")
+ return None
+
+ def remove_duplicate_category_text_pairs_and_bbox(self, data_list: List[dict], case_id: int) -> List[dict]:
+ """Removes duplicate category-text pairs and duplicate bboxes"""
+
+ if not data_list or len(data_list) <= 1:
+ print(f" 📊 Data length {len(data_list)} <= 1, skipping deduplication check")
+ return data_list
+
+ print(f" 📊 Original data length: {len(data_list)}")
+
+ # 1. Count occurrences and positions of each category-text pair
+ category_text_pairs = {}
+ for i, item in enumerate(data_list):
+ if isinstance(item, dict) and 'category' in item and 'text' in item:
+ pair_key = (item.get('category', ''), item.get('text', ''))
+ if pair_key not in category_text_pairs:
+ category_text_pairs[pair_key] = []
+ category_text_pairs[pair_key].append(i)
+
+ # 2. Count occurrences and positions of each bbox
+ bbox_pairs = {}
+ for i, item in enumerate(data_list):
+ if isinstance(item, dict) and 'bbox' in item:
+ bbox = item.get('bbox')
+ if isinstance(bbox, list) and len(bbox) > 0:
+ bbox_key = tuple(bbox) # Convert to tuple to use as a dictionary key
+ if bbox_key not in bbox_pairs:
+ bbox_pairs[bbox_key] = []
+ bbox_pairs[bbox_key].append(i)
+
+ # 3. Identify items to be removed
+ duplicates_to_remove = set()
+
+ # 3a. Process category-text pairs that appear 5 or more times
+ for pair_key, positions in category_text_pairs.items():
+ if len(positions) >= 5:
+ category, text = pair_key
+ # Keep the first occurrence, remove subsequent duplicates
+ positions_to_remove = positions[1:]
+ duplicates_to_remove.update(positions_to_remove)
+
+ print(f" 🔍 Found duplicate category-text pair: category='{category}', first 50 chars of text='{text[:50]}...'")
+ print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
+
+ # 3b. Process bboxes that appear 2 or more times
+ for bbox_key, positions in bbox_pairs.items():
+ if len(positions) >= 2:
+ # Keep the first occurrence, remove subsequent duplicates
+ positions_to_remove = positions[1:]
+ duplicates_to_remove.update(positions_to_remove)
+
+ print(f" 🔍 Found duplicate bbox: {list(bbox_key)}")
+ print(f" Count: {len(positions)}, removing at positions: {positions_to_remove}")
+
+ if not duplicates_to_remove:
+ print(f" ✅ No category-text pairs or bboxes found exceeding the duplication threshold")
+ return data_list
+
+ # 4. Remove duplicate items from the original data (preserving order)
+ cleaned_data = []
+ removed_count = 0
+ for i, item in enumerate(data_list):
+ if i not in duplicates_to_remove:
+ cleaned_data.append(item)
+ else:
+ removed_count += 1
+
+ print(f" ✅ Deduplication complete: Removed {removed_count} duplicate items")
+ print(f" 📊 Cleaned data length: {len(cleaned_data)}")
+
+ return cleaned_data
+
+    def clean_model_output(self, model_output):
+        """Cleans a single model output, which may be a list or a JSON-like string."""
+ try:
+ # Select cleaning method based on data type
+ if isinstance(model_output, list):
+ result = self.clean_list_data(model_output, case_id=0)
+ else:
+ result = self.clean_string_data(str(model_output), case_id=0)
+
+ # Add deduplication step: remove duplicate category-text pairs and bboxes
+ if result and hasattr(result, 'success') and result.success and result.cleaned_data:
+ original_data = result.cleaned_data
+ deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id=0)
+ # Update the cleaned_data in the CleanedData object
+ result.cleaned_data = deduplicated_data
+ return result.cleaned_data
+ except Exception as e:
+ print(f"❌ Case cleaning failed: {e}")
+ return model_output
+
+ def clean_all_data(self, jsonl_path: str) -> List[CleanedData]:
+ """Cleans all data from a JSONL file"""
+
+ print(f"🚀 Starting to clean JSONL file: {jsonl_path}")
+
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
+ lines = f.readlines()
+
+ datas = []
+ for i, line in enumerate(lines):
+ if line.strip():
+ try:
+ data = json.loads(line)
+ predict_field = data.get('predict')
+ case_id = i + 1
+
+ print(f"\n{'='*50}")
+ print(f"🎯 Cleaning Case {case_id}")
+ print(f"{'='*50}")
+
+ # Select cleaning method based on data type
+ if isinstance(predict_field, list):
+ print("📊 Data type: List")
+ result = self.clean_list_data(predict_field, case_id)
+ else:
+ print("📊 Data type: String")
+ result = self.clean_string_data(str(predict_field), case_id)
+
+ # Add deduplication step: remove duplicate category-text pairs and bboxes
+ if result and hasattr(result, 'success') and result.success and result.cleaned_data:
+ print("🔄 Checking for and removing duplicate category-text pairs and bboxes...")
+ original_data = result.cleaned_data
+ deduplicated_data = self.remove_duplicate_category_text_pairs_and_bbox(original_data, case_id)
+ # Update the cleaned_data in the CleanedData object
+ result.cleaned_data = deduplicated_data
+ data['predict_resized'] = result.cleaned_data
+
+ datas.append(data)
+ self.cleaned_results.append(result)
+
+ except Exception as e:
+ print(f"❌ Case {i+1} cleaning failed: {e}")
+ traceback.print_exc()
+
+ save_path = jsonl_path.replace('.jsonl', '_filtered.jsonl')
+        with open(save_path, 'w', encoding='utf-8') as w:
+ for data in datas:
+ w.write(json.dumps(data, ensure_ascii=False) + '\n')
+ print(f"✅ Saved cleaned data to: {save_path}")
+
+ return self.cleaned_results
+
+ def save_cleaned_data(self, output_dir: str):
+ """Saves the cleaned data"""
+
+ print(f"\n💾 Saving cleaned data to: {output_dir}")
+ os.makedirs(output_dir, exist_ok=True)
+
+ # 1. Save cleaned data for each case
+ for result in self.cleaned_results:
+ case_filename = f"cleaned_case_{result.case_id:02d}.json"
+ case_filepath = os.path.join(output_dir, case_filename)
+
+ # Save the cleaned data
+ with open(case_filepath, 'w', encoding='utf-8') as f:
+ json.dump(result.cleaned_data, f, ensure_ascii=False, indent=2)
+
+ print(f" ✅ Case {result.case_id}: {len(result.cleaned_data)} objects → {case_filename}")
+
+ # 2. Save all cleaned data to a single file
+ all_cleaned_data = []
+ for result in self.cleaned_results:
+ all_cleaned_data.append({
+ 'case_id': result.case_id,
+ 'original_type': result.original_type,
+ 'original_length': result.original_length,
+ 'cleaned_objects_count': len(result.cleaned_data),
+ 'success': result.success,
+ 'cleaning_operations': result.cleaning_operations,
+ 'cleaned_data': result.cleaned_data
+ })
+
+ all_data_filepath = os.path.join(output_dir, "all_cleaned_data.json")
+ with open(all_data_filepath, 'w', encoding='utf-8') as f:
+ json.dump(all_cleaned_data, f, ensure_ascii=False, indent=2)
+
+ print(f" 📁 All data: {len(all_cleaned_data)} cases → all_cleaned_data.json")
+
+ # 3. Generate a cleaning report
+ self._generate_cleaning_report(output_dir)
+
+ def _generate_cleaning_report(self, output_dir: str):
+ """Generates a cleaning report"""
+
+ report = []
+ report.append("📊 Data Cleaning Report")
+ report.append("=" * 60)
+ import datetime
+ report.append(f"Processing Time: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+ report.append("")
+
+ # Overall statistics
+ total_cases = len(self.cleaned_results)
+ successful_cases = sum(1 for r in self.cleaned_results if r.success)
+ total_objects = sum(len(r.cleaned_data) for r in self.cleaned_results)
+
+ report.append("📈 Overall Statistics:")
+ report.append(f" Total Cases: {total_cases}")
+ report.append(f" Successfully Cleaned: {successful_cases}")
+ report.append(f" Success Rate: {successful_cases/total_cases*100:.1f}%")
+ report.append(f" Total Recovered Objects: {total_objects}")
+ report.append("")
+
+ # Detailed statistics
+ list_results = [r for r in self.cleaned_results if r.original_type == 'list']
+ str_results = [r for r in self.cleaned_results if r.original_type == 'str']
+
+ if list_results:
+ report.append("📋 List Type Cleaning Statistics:")
+ for r in list_results:
+ ops = r.cleaning_operations
+ report.append(f" Case {r.case_id}: {ops['original_count']} → {ops['final_count']} objects")
+ if ops['bbox_fixes'] > 0:
+ report.append(f" - bbox fixes: {ops['bbox_fixes']}")
+ if ops['removed_items'] > 0:
+ report.append(f" - invalid items removed: {ops['removed_items']}")
+ report.append("")
+
+ if str_results:
+ report.append("📝 String Type Cleaning Statistics:")
+ for r in str_results:
+ ops = r.cleaning_operations
+ status = "✅" if r.success else "❌"
+ report.append(f" Case {r.case_id} {status}: {ops['original_length']:,} chars → {ops['final_objects']} objects")
+ details = []
+ if ops['delimiter_fixes'] > 0:
+ details.append(f"Delimiter fixes: {ops['delimiter_fixes']}")
+ if ops['tail_truncated']:
+ reduction = ops['original_length'] - ops['truncated_length']
+ details.append(f"Tail truncation: -{reduction:,} chars")
+ if ops['duplicate_dicts_removed'] > 0:
+ details.append(f"Duplicates removed: {ops['duplicate_dicts_removed']}")
+ if details:
+ report.append(f" - {', '.join(details)}")
+ report.append("")
+
+ # Note on data order
+ report.append("🔄 Data Order Guarantee:")
+ report.append(" ✅ The relative order of all dict objects is preserved during cleaning.")
+ report.append(" ✅ When deduplicating, the first occurrence of a dict is kept, and subsequent duplicates are removed.")
+ report.append(" ✅ The order of items in List-type data is fully preserved.")
+
+ # Save the report
+ report_filepath = os.path.join(output_dir, "cleaning_report.txt")
+ with open(report_filepath, 'w', encoding='utf-8') as f:
+ f.write('\n'.join(report))
+
+ print(f" 📋 Cleaning report: cleaning_report.txt")
+
+ # Also print to console
+ print(f"\n{chr(10).join(report)}")
+
+
+def main():
+ """Main function"""
+
+ # Create a data cleaner instance
+ cleaner = OutputCleaner()
+
+ # Input file
+ jsonl_path = "output_with_failcase.jsonl"
+
+ # Output directory
+ output_dir = "output_with_failcase_cleaned"
+
+ # Clean all data
+ results = cleaner.clean_all_data(jsonl_path)
+
+ # Save the cleaned data
+ cleaner.save_cleaned_data(output_dir)
+
+ print(f"\n🎉 Data cleaning complete!")
+ print(f"📁 Cleaned data saved in: {output_dir}")
+
+
+if __name__ == "__main__":
+ main()
\ No newline at end of file
diff --git a/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/prompts.py b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/prompts.py
new file mode 100644
index 0000000..87714c3
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/dots_ocr/utils/prompts.py
@@ -0,0 +1,34 @@
+dict_promptmode_to_prompt = {
+ # prompt_layout_all_en: parse all layout info in json format.
+ "prompt_layout_all_en": """Please output the layout information from the PDF image, including each layout element's bbox, its category, and the corresponding text content within the bbox.
+
+1. Bbox format: [x1, y1, x2, y2]
+
+2. Layout Categories: The possible categories are ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title'].
+
+3. Text Extraction & Formatting Rules:
+ - Picture: For the 'Picture' category, the text field should be omitted.
+ - Formula: Format its text as LaTeX.
+ - Table: Format its text as HTML.
+ - All Others (Text, Title, etc.): Format their text as Markdown.
+
+4. Constraints:
+ - The output text must be the original text from the image, with no translation.
+ - All layout elements must be sorted according to human reading order.
+
+5. Final Output: The entire output must be a single JSON object.
+""",
+
+ # prompt_layout_only_en: layout detection
+ "prompt_layout_only_en": """Please output the layout information from this PDF image, including each layout's bbox and its category. The bbox should be in the format [x1, y1, x2, y2]. The layout categories for the PDF document include ['Caption', 'Footnote', 'Formula', 'List-item', 'Page-footer', 'Page-header', 'Picture', 'Section-header', 'Table', 'Text', 'Title']. Do not output the corresponding text. The layout result should be in JSON format.""",
+
+    # prompt_ocr: parse ocr text except the Page-header and Page-footer
+ "prompt_ocr": """Extract the text content from this image.""",
+
+ # prompt_grounding_ocr: extract text content in the given bounding box
+ "prompt_grounding_ocr": """Extract text from the given bounding box on the image (format: [x1, y1, x2, y2]).\nBounding Box:\n""",
+
+ # "prompt_table_html": """Convert the table in this image to HTML.""",
+ # "prompt_table_latex": """Convert the table in this image to LaTeX.""",
+ # "prompt_formula_latex": """Convert the formula in this image to LaTeX.""",
+}
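+
+# Illustrative only (an assumption about model behaviour, not part of any prompt): with
+# "prompt_layout_all_en", a single element of the returned JSON array is expected to look
+# roughly like
+#     {"bbox": [72, 104, 523, 138], "category": "Section-header", "text": "## 1 Introduction"}
+# where "text" is omitted for "Picture", LaTeX for "Formula", and HTML for "Table".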
diff --git a/rag_factory/parser/Parser_Dotsocr/fig_recognize.py b/rag_factory/parser/Parser_Dotsocr/fig_recognize.py
new file mode 100644
index 0000000..175529e
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/fig_recognize.py
@@ -0,0 +1,183 @@
+import os
+import glob
+import json
+import re
+import fitz
+from PIL import Image
+from tqdm import tqdm
+from dashscope import MultiModalConversation
+import argparse
+from pathlib import Path
+
+
+def fig_understand(fig_path):
+ # prompt = '请给出图像中具体内容信息,并用json格式输出,仅输出json格式数据,其中,图片类型请从["chart","knowladge_map","other"]中选择'
+ prompt = '''
+你是一个图像内容理解专家,任务是读取图像内容并生成结构化 JSON 数据。请遵循以下规则:
+
+1. **仅输出 JSON 数据**,不要添加任何解释、前缀或后缀文字。
+2. JSON 格式中必须包含两个字段:
+ - "type": 图像类型,只能从 ["chart", "knowladge_map", "other"] 中选择。
+ - "content": 图像的具体结构化内容描述。
+3. 如果图像类型是:
+ - "chart": 请提取图表的标题、坐标轴标签、图例、系列等结构信息。
+ - "knowladge_map": 输出树状结构,所有节点使用 {"name": xxx, "children": [...]} 格式。
+ - "other": 尽可能准确描述图像的主要元素。
+
+以下是几个示例,请模仿格式输出。
+
+---
+
+### 示例1(chart):
+输入图像:柱状图,标题为“年度销售统计”,X轴为月份,Y轴为销售额,图例为“产品A”和“产品B”。
+
+输出:
+```json
+{
+ "type": "chart",
+ "content": {
+
+ "title": "年度销售统计",
+ "x_axis": "月份",
+ "y_axis": "销售额",
+ "legend": ["产品A", "产品B"],
+ "series": [
+ {"name": "产品A", "data": [100, 120, 130]},
+ {"name": "产品B", "data": [80, 90, 100]}
+ ]
+ }
+}
+示例2(knowladge_map):
+输入图像:知识图谱,核心为“机器学习”,子节点有“监督学习”和“无监督学习”,监督学习下有“回归”和“分类”。
+
+输出:
+{
+ "type": "knowladge_map",
+ "content": {
+ "name": "机器学习",
+ "children": [
+ {
+ "name": "监督学习",
+ "children": [
+ {"name": "回归"},
+ {"name": "分类"}
+ ]
+ },
+ {
+ "name": "无监督学习"
+ }
+ ]
+ }
+}
+示例3(other):
+输入图像:一张会议室内多个人开会的场景。
+
+输出:
+{
+ "type": "other",
+ "content": "一个会议室中有5个人正在围绕会议桌讨论,桌上有笔记本电脑和文件。"
+}
+
+请根据上面的示例输出格式,严格输出图像的内容识别结果,只返回符合格式的 JSON 数据。
+
+'''
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"image": f"file://{fig_path}"},
+                {"text": prompt}
+            ]
+        }
+    ]
+
+ response = MultiModalConversation.call(
+ api_key=os.environ.get('DASHSCOPE_API_KEY'),
+ model="qwen-vl-plus",
+ messages=messages,
+ )
+
+ # print(response)
+ return response["output"]["choices"][0]["message"].content[0]["text"].replace("```json",'').replace("```",'').strip()
+
+def save_fig(file_path, page_no, index, bbox, scale):
+
+ file_name, file_ext = os.path.splitext(os.path.basename(file_path))
+ file_name = file_name.replace('_layout', '')
+ base_dir = os.path.dirname(file_path)
+ pdf_file = os.path.join(base_dir, f"{file_name}_original.pdf")
+ doc = fitz.open(pdf_file)
+ page = doc.load_page(page_no)
+
+ pdf_width = page.rect.width
+ pdf_height = page.rect.height
+
+ scale_x = scale[1] / pdf_width
+ scale_y = scale[0] / pdf_height
+ x1 = bbox[0] / scale_x
+ y1 = bbox[1] / scale_y
+ x2 = bbox[2] / scale_x
+ y2 = bbox[3] / scale_y
+ pdf_bbox = fitz.Rect(x1, y1, x2, y2)
+    zoom = 300 / 72  # render at 300 DPI
+ matrix = fitz.Matrix(zoom, zoom)
+ img = page.get_pixmap(matrix=matrix, clip=pdf_bbox, alpha=False)
+
+    save_dir = os.path.join(base_dir, f"{file_name}/image")
+    os.makedirs(save_dir, exist_ok=True)  # create the nested image directory if needed
+
+
+ text = ''
+ save_path = os.path.join(save_dir, f'page_{page_no}_{index}.png')
+ if img is not None:
+ img.save(save_path)
+
+ text = fig_understand(save_path)
+
+ return save_path, text
+
+def process_one_file(json_file):
+ file_name = os.path.basename(json_file)
+ base_dir = os.path.dirname(json_file)
+ output_path = str(json_file).replace("layout", "img_content")
+ data = []
+ with open(json_file, 'r', encoding='utf-8') as f:
+ json_data = json.load(f)
+ print(f"Processing file: {file_name}")
+ for row in tqdm(json_data):
+ if row.get('category','') == 'Picture':
+ bbox = row['bbox']
+ page_no = row['page_no']
+ if (bbox[2]-bbox[0])*(bbox[3]-bbox[1]) < 52000:
+ row['text'] = ""
+ else:
+ fig_path, text = save_fig(json_file, page_no=page_no, index=row['index'], bbox=bbox, scale=row['scale'])
+ # print(text)
+ row['text'] = json.loads(text)
+ with open(output_path, 'w', encoding='utf-8') as f:
+ json.dump(json_data, f, ensure_ascii=False, indent=4)
+ return json_data
+
+def main():
+ parser = argparse.ArgumentParser(description="Use vlm to get parsed figure content.")
+ parser.add_argument(
+ "--output", type=str, default="output",
+ help="Output parsed directory (default: output)"
+ )
+ args = parser.parse_args()
+
+
+ if os.path.isdir(args.output):
+ for file in sorted(Path(args.output).glob('*_layout.json')):
+ data = process_one_file(file)
+ elif os.path.isfile(args.output):
+ data = process_one_file(args.output)
+ else:
+        print(f"'{args.output}' does not exist")
+
+if __name__ == "__main__":
+ os.environ["DASHSCOPE_API_KEY"] = "your api key"
+ main()
+
+
diff --git a/rag_factory/parser/Parser_Dotsocr/parser.py b/rag_factory/parser/Parser_Dotsocr/parser.py
new file mode 100644
index 0000000..a143eec
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/parser.py
@@ -0,0 +1,424 @@
+import os
+import json
+from tqdm import tqdm
+from multiprocessing.pool import ThreadPool, Pool
+import argparse
+
+
+from dots_ocr.model.inference import inference_with_vllm
+from dots_ocr.utils.consts import image_extensions, MIN_PIXELS, MAX_PIXELS
+from dots_ocr.utils.image_utils import get_image_by_fitz_doc, fetch_image, smart_resize
+from dots_ocr.utils.doc_utils import fitz_doc_to_image, load_images_from_pdf
+from dots_ocr.utils.prompts import dict_promptmode_to_prompt
+from dots_ocr.utils.layout_utils import post_process_output, draw_layout_on_image, pre_process_bboxes
+from dots_ocr.utils.format_transformer import layoutjson2md
+
+
+class DotsOCRParser:
+ """
+ parse image or pdf file
+ """
+
+ def __init__(self,
+ ip='localhost',
+ port=8000,
+ model_name='model',
+ temperature=0.1,
+ top_p=1.0,
+ max_completion_tokens=16384,
+ num_thread=64,
+ dpi = 200,
+ output_dir="output",
+ min_pixels=None,
+ max_pixels=None,
+ use_hf=True,
+ ):
+ self.dpi = dpi
+
+ # default args for vllm server
+ self.ip = ip
+ self.port = port
+ self.model_name = model_name
+ # default args for inference
+ self.temperature = temperature
+ self.top_p = top_p
+ self.max_completion_tokens = max_completion_tokens
+ self.num_thread = num_thread
+ self.output_dir = output_dir
+ self.min_pixels = min_pixels
+ self.max_pixels = max_pixels
+
+ self.use_hf = use_hf
+ if self.use_hf:
+ self._load_hf_model()
+ print(f"use hf model, num_thread will be set to 1")
+ else:
+ print(f"use vllm model, num_thread will be set to {self.num_thread}")
+ assert self.min_pixels is None or self.min_pixels >= MIN_PIXELS
+ assert self.max_pixels is None or self.max_pixels <= MAX_PIXELS
+
+ def _load_hf_model(self):
+ import torch
+ from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
+ from qwen_vl_utils import process_vision_info
+
+ model_path = "weights/DotsOCR"
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ attn_implementation="flash_attention_2",
+ torch_dtype=torch.bfloat16,
+ device_map="auto",
+ trust_remote_code=True
+ )
+ self.processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True,use_fast=True)
+ self.process_vision_info = process_vision_info
+
+ def _inference_with_hf(self, image, prompt):
+ messages = [
+ {
+ "role": "user",
+ "content": [
+ {
+ "type": "image",
+ "image": image
+ },
+ {"type": "text", "text": prompt}
+ ]
+ }
+ ]
+
+ # Preparation for inference
+ text = self.processor.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+ image_inputs, video_inputs = self.process_vision_info(messages)
+ inputs = self.processor(
+ text=[text],
+ images=image_inputs,
+ videos=video_inputs,
+ padding=True,
+ return_tensors="pt",
+ )
+
+ inputs = inputs.to("cuda")
+
+ # Inference: Generation of the output
+ generated_ids = self.model.generate(**inputs, max_new_tokens=24000)
+ generated_ids_trimmed = [
+ out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
+ ]
+ response = self.processor.batch_decode(
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
+ )[0]
+ return response
+
+ def _inference_with_vllm(self, image, prompt):
+ response = inference_with_vllm(
+ image,
+ prompt,
+ model_name=self.model_name,
+ ip=self.ip,
+ port=self.port,
+ temperature=self.temperature,
+ top_p=self.top_p,
+ max_completion_tokens=self.max_completion_tokens,
+ )
+ return response
+
+ def get_prompt(self, prompt_mode, bbox=None, origin_image=None, image=None, min_pixels=None, max_pixels=None):
+ prompt = dict_promptmode_to_prompt[prompt_mode]
+ if prompt_mode == 'prompt_grounding_ocr':
+ assert bbox is not None
+ bboxes = [bbox]
+ bbox = pre_process_bboxes(origin_image, bboxes, input_width=image.width, input_height=image.height, min_pixels=min_pixels, max_pixels=max_pixels)[0]
+ prompt = prompt + str(bbox)
+ return prompt
+
+ # def post_process_results(self, response, prompt_mode, save_dir, save_name, origin_image, image, min_pixels, max_pixels)
+ def _parse_single_image(
+ self,
+ origin_image,
+ prompt_mode,
+ save_dir,
+ save_name,
+ source="image",
+ page_idx=0,
+ bbox=None,
+ fitz_preprocess=False,
+ ):
+ min_pixels, max_pixels = self.min_pixels, self.max_pixels
+ if prompt_mode == "prompt_grounding_ocr":
+ min_pixels = min_pixels or MIN_PIXELS # preprocess image to the final input
+ max_pixels = max_pixels or MAX_PIXELS
+ if min_pixels is not None: assert min_pixels >= MIN_PIXELS, f"min_pixels should >= {MIN_PIXELS}"
+        if max_pixels is not None: assert max_pixels <= MAX_PIXELS, f"max_pixels should <= {MAX_PIXELS}"
+
+ if source == 'image' and fitz_preprocess:
+ image = get_image_by_fitz_doc(origin_image, target_dpi=self.dpi)
+ image = fetch_image(image, min_pixels=min_pixels, max_pixels=max_pixels)
+ else:
+ image = fetch_image(origin_image, min_pixels=min_pixels, max_pixels=max_pixels)
+ input_height, input_width = smart_resize(image.height, image.width)
+ prompt = self.get_prompt(prompt_mode, bbox, origin_image, image, min_pixels=min_pixels, max_pixels=max_pixels)
+ if self.use_hf:
+ response = self._inference_with_hf(image, prompt)
+ else:
+ response = self._inference_with_vllm(image, prompt)
+ result = {'page_no': page_idx,
+ "input_height": input_height,
+ "input_width": input_width
+ }
+ if source == 'pdf':
+ save_name = f"{save_name}_page_{page_idx}"
+ if prompt_mode in ['prompt_layout_all_en', 'prompt_layout_only_en', 'prompt_grounding_ocr']:
+ cells, filtered = post_process_output(
+ response,
+ prompt_mode,
+ origin_image,
+ image,
+ min_pixels=min_pixels,
+ max_pixels=max_pixels,
+ )
+            if filtered and prompt_mode != 'prompt_layout_only_en':  # the model's JSON output could not be parsed; save the raw response instead
+ json_file_path = os.path.join(save_dir, f"{save_name}.json")
+ with open(json_file_path, 'w', encoding="utf-8") as w:
+ json.dump(response, w, ensure_ascii=False, indent=4)
+
+ image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
+ origin_image.save(image_layout_path)
+ result.update({
+ 'layout_info_path': json_file_path,
+ 'layout_image_path': image_layout_path,
+ })
+
+ md_file_path = os.path.join(save_dir, f"{save_name}.md")
+ with open(md_file_path, "w", encoding="utf-8") as md_file:
+ md_file.write(cells)
+ result.update({
+ 'md_content_path': md_file_path
+ })
+ result.update({
+ 'filtered': True
+ })
+ else:
+ try:
+ image_with_layout = draw_layout_on_image(origin_image, cells)
+ except Exception as e:
+ print(f"Error drawing layout on image: {e}")
+ image_with_layout = origin_image
+
+ json_file_path = os.path.join(save_dir, f"{save_name}.json")
+ with open(json_file_path, 'w', encoding="utf-8") as w:
+ json.dump(cells, w, ensure_ascii=False, indent=4)
+
+ image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
+ image_with_layout.save(image_layout_path)
+ result.update({
+ 'layout_info_path': json_file_path,
+ 'layout_image_path': image_layout_path,
+ })
+ if prompt_mode != "prompt_layout_only_en": # no text md when detection only
+ md_content = layoutjson2md(origin_image, cells, text_key='text')
+                md_content_no_hf = layoutjson2md(origin_image, cells, text_key='text', no_page_hf=True)  # used for clean output or metrics of omnidocbench / olmbench
+ md_file_path = os.path.join(save_dir, f"{save_name}.md")
+ with open(md_file_path, "w", encoding="utf-8") as md_file:
+ md_file.write(md_content)
+ md_nohf_file_path = os.path.join(save_dir, f"{save_name}_nohf.md")
+ with open(md_nohf_file_path, "w", encoding="utf-8") as md_file:
+ md_file.write(md_content_no_hf)
+ result.update({
+ 'md_content_path': md_file_path,
+ 'md_content_nohf_path': md_nohf_file_path,
+ })
+ else:
+ image_layout_path = os.path.join(save_dir, f"{save_name}.jpg")
+ origin_image.save(image_layout_path)
+ result.update({
+ 'layout_image_path': image_layout_path,
+ })
+
+ md_content = response
+ md_file_path = os.path.join(save_dir, f"{save_name}.md")
+ with open(md_file_path, "w", encoding="utf-8") as md_file:
+ md_file.write(md_content)
+ result.update({
+ 'md_content_path': md_file_path,
+ })
+
+ return result
+
+ def parse_image(self, input_path, filename, prompt_mode, save_dir, bbox=None, fitz_preprocess=False):
+ origin_image = fetch_image(input_path)
+ result = self._parse_single_image(origin_image, prompt_mode, save_dir, filename, source="image", bbox=bbox, fitz_preprocess=fitz_preprocess)
+ result['file_path'] = input_path
+ return [result]
+
+ def parse_pdf(self, input_path, filename, prompt_mode, save_dir):
+ print(f"loading pdf: {input_path}")
+ images_origin = load_images_from_pdf(input_path, dpi=self.dpi)
+ total_pages = len(images_origin)
+ tasks = [
+ {
+ "origin_image": image,
+ "prompt_mode": prompt_mode,
+ "save_dir": save_dir,
+ "save_name": filename,
+ "source":"pdf",
+ "page_idx": i,
+ } for i, image in enumerate(images_origin)
+ ]
+
+ def _execute_task(task_args):
+ return self._parse_single_image(**task_args)
+
+ if self.use_hf:
+ num_thread = 1
+ else:
+ num_thread = min(total_pages, self.num_thread)
+ print(f"Parsing PDF with {total_pages} pages using {num_thread} threads...")
+
+ results = []
+ with ThreadPool(num_thread) as pool:
+ with tqdm(total=total_pages, desc="Processing PDF pages") as pbar:
+ for result in pool.imap_unordered(_execute_task, tasks):
+ results.append(result)
+ pbar.update(1)
+
+ results.sort(key=lambda x: x["page_no"])
+ for i in range(len(results)):
+ results[i]['file_path'] = input_path
+ return results
+
+ def parse_file(self,
+ input_path,
+ output_dir="",
+ prompt_mode="prompt_layout_all_en",
+ bbox=None,
+ fitz_preprocess=False
+ ):
+ output_dir = output_dir or self.output_dir
+ output_dir = os.path.abspath(output_dir)
+ filename, file_ext = os.path.splitext(os.path.basename(input_path))
+ save_dir = os.path.join(output_dir, filename)
+ os.makedirs(save_dir, exist_ok=True)
+
+ if file_ext == '.pdf':
+ results = self.parse_pdf(input_path, filename, prompt_mode, save_dir)
+ elif file_ext in image_extensions:
+ results = self.parse_image(input_path, filename, prompt_mode, save_dir, bbox=bbox, fitz_preprocess=fitz_preprocess)
+ else:
+ raise ValueError(f"file extension {file_ext} not supported, supported extensions are {image_extensions} and pdf")
+
+    print(f"Parsing finished, results saved to {save_dir}")
+ with open(os.path.join(output_dir, os.path.basename(filename)+'.jsonl'), 'w', encoding="utf-8") as w:
+ for result in results:
+ w.write(json.dumps(result, ensure_ascii=False) + '\n')
+
+ return results
+
+
+
+def main():
+ prompts = list(dict_promptmode_to_prompt.keys())
+ parser = argparse.ArgumentParser(
+ description="dots.ocr Multilingual Document Layout Parser",
+ )
+
+ parser.add_argument(
+ "input_path", type=str,
+ help="Input PDF/image file path"
+ )
+
+ parser.add_argument(
+ "--output", type=str, default="./output",
+ help="Output directory (default: ./output)"
+ )
+
+ parser.add_argument(
+ "--prompt", choices=prompts, type=str, default="prompt_layout_all_en",
+ help="prompt to query the model, different prompts for different tasks"
+ )
+ parser.add_argument(
+ '--bbox',
+ type=int,
+ nargs=4,
+ metavar=('x1', 'y1', 'x2', 'y2'),
+ help='should give this argument if you want to prompt_grounding_ocr'
+ )
+ parser.add_argument(
+ "--ip", type=str, default="localhost",
+ help=""
+ )
+ parser.add_argument(
+ "--port", type=int, default=8001,
+ help="vllm server port"
+ )
+ parser.add_argument(
+ "--model_name", type=str, default="model",
+ help=""
+ )
+ parser.add_argument(
+ "--temperature", type=float, default=0.1,
+ help=""
+ )
+ parser.add_argument(
+ "--top_p", type=float, default=1.0,
+ help=""
+ )
+ parser.add_argument(
+ "--dpi", type=int, default=200,
+ help=""
+ )
+ parser.add_argument(
+ "--max_completion_tokens", type=int, default=16384,
+ help=""
+ )
+ parser.add_argument(
+ "--num_thread", type=int, default=32,
+ help=""
+ )
+ # parser.add_argument(
+ # "--fitz_preprocess", type=bool, default=False,
+ # help="False will use tikz dpi upsample pipeline, good for images which has been render with low dpi, but maybe result in higher computational costs"
+ # )
+ parser.add_argument(
+ "--min_pixels", type=int, default=None,
+ help=""
+ )
+ parser.add_argument(
+ "--max_pixels", type=int, default=None,
+ help=""
+ )
+ parser.add_argument(
+ "--use_hf", type=bool, default=False,
+ help=""
+ )
+ args = parser.parse_args()
+
+ dots_ocr_parser = DotsOCRParser(
+ ip=args.ip,
+ port=args.port,
+ model_name=args.model_name,
+ temperature=args.temperature,
+ top_p=args.top_p,
+ max_completion_tokens=args.max_completion_tokens,
+ num_thread=args.num_thread,
+ dpi=args.dpi,
+ output_dir=args.output,
+ min_pixels=args.min_pixels,
+ max_pixels=args.max_pixels,
+ use_hf=args.use_hf,
+ )
+
+ result = dots_ocr_parser.parse_file(
+ args.input_path,
+ prompt_mode=args.prompt,
+ bbox=args.bbox,
+ )
+
+
+
+if __name__ == "__main__":
+ main()
diff --git a/rag_factory/parser/Parser_Dotsocr/readme.md b/rag_factory/parser/Parser_Dotsocr/readme.md
new file mode 100644
index 0000000..d83b416
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/readme.md
@@ -0,0 +1,47 @@
+# dots.ocr parser
+
+This parser is based on the [dots.ocr](https://github.com/rednote-hilab/dots.ocr) model. See [dots.ocr](https://github.com/rednote-hilab/dots.ocr) for details.
+
+## 1. Installation
+
+```
+# requires python >= 3.10
+pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu128
+pip install -r requirements.txt
+# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2
+
+```
+
+```
+# model download
+from huggingface_hub import snapshot_download
+snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir="Parser_Dotsocr/weights/DotsOCR", local_dir_use_symlinks=False, resume_download=True)
+
+# or, download from ModelScope:
+
+from modelscope import snapshot_download
+snapshot_download(repo_id="rednote-hilab/dots.ocr", local_dir=model_dir)
+```
+
+## 2. vLLM inference
+
+Use vLLM for faster inference (based on vllm==0.9.1).
+
+```
+python vllm_launch.py --model_path weights/DotsOCR
+```
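+
+To quickly check that the server is up, query its OpenAI-compatible endpoint (a minimal sketch; the default port 8001 and the `openai` client from requirements.txt are assumptions):
+
+```
+# hypothetical smoke test against the local vLLM endpoint
+from openai import OpenAI
+
+client = OpenAI(base_url="http://localhost:8001/v1", api_key="EMPTY")
+print([m.id for m in client.models.list().data])  # should include the served model name, e.g. "model"
+```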
+
+## 3. Document parsing
+
+```
+python parser.py path/to/document.pdf
+```
+
+If you want to parse documents with HuggingFace Transformers instead of the vLLM server, add `--use_hf=True`.
+
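+The parser can also be used from Python (a minimal sketch based on `DotsOCRParser` in `parser.py`; the file path and a running vLLM server on port 8001 are assumptions):
+
+```
+from parser import DotsOCRParser
+
+dots_ocr_parser = DotsOCRParser(ip="localhost", port=8001, output_dir="./output", use_hf=False)
+results = dots_ocr_parser.parse_file("path/to/document.pdf", prompt_mode="prompt_layout_all_en")
+for page in results:
+    print(page["page_no"], page.get("md_content_path"))
+```
+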
+## 4. Figure understanding
+
+Use a VLM to understand the content of parsed pictures. Run the PDF layout parsing step first so that the `*_layout.json` results exist.
+```
+python fig_recognize.py --output output
+```
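+
+`fig_recognize.py` reads every `*_layout.json` file, sends each `Picture` region to the DashScope `qwen-vl-plus` model, and writes the result back as `*_img_content.json`. A programmatic sketch (the layout JSON path is an assumption; a valid `DASHSCOPE_API_KEY` is required):
+
+```
+import os
+from fig_recognize import process_one_file
+
+os.environ["DASHSCOPE_API_KEY"] = "your api key"
+process_one_file("output/mydoc/mydoc_layout.json")
+```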
diff --git a/rag_factory/parser/Parser_Dotsocr/vllm_launch.py b/rag_factory/parser/Parser_Dotsocr/vllm_launch.py
new file mode 100644
index 0000000..95d6fbc
--- /dev/null
+++ b/rag_factory/parser/Parser_Dotsocr/vllm_launch.py
@@ -0,0 +1,65 @@
+import os
+import subprocess
+import sys
+from pathlib import Path
+import argparse
+
+def launch_vllm_server(hf_model_path="weights/DotsOCR", num_gpus="0", gpu_memory_utilization=0.95, port=8001):
+    # 1. Check that the model path exists
+ model_path = Path(hf_model_path).resolve()
+ if not model_path.exists():
+ print(f"error: 模型路径不存在: {model_path}")
+ sys.exit(1)
+
+    # 2. Set environment variables
+    os.environ["hf_model_path"] = str(model_path)
+    os.environ["PYTHONPATH"] = f"{model_path.parent}:{os.environ.get('PYTHONPATH', '')}"
+ os.environ["CUDA_VISIBLE_DEVICES"] = num_gpus
+ os.environ["OPENAI_API_BASE"] = f"http://localhost:{port}/v1"
+ os.environ["OPENAI_API_KEY"] = "EMPTY"
+
+    # 3. Patch the vllm CLI so that it imports the DotsOCR model
+ try:
+ vllm_path = subprocess.check_output(["which", "vllm"], text=True).strip()
+ with open(vllm_path, "r") as f:
+ vllm_content = f.read()
+
+ inject_line = "from DotsOCR import modeling_dots_ocr_vllm"
+ if inject_line not in vllm_content:
+ print("修改 vllm CLI 引入 DotsOCR 模型...")
+ sed_cmd = f"sed -i '/^from vllm\\.entrypoints\\.cli\\.main import main$/a\\{inject_line}' {vllm_path}"
+ subprocess.run(sed_cmd, shell=True, check=True)
+ else:
+ print("vllm CLI 已包含 DotsOCR 模型")
+
+ except subprocess.CalledProcessError as e:
+ print(f"error: 获取 vllm 路径失败: {e}")
+ sys.exit(1)
+
+    # 4. Launch the vLLM server
+    print("Starting the vLLM server...")
+ cmd = [
+ "vllm", "serve", str(model_path),
+ "--tensor-parallel-size", "1",
+ "--gpu-memory-utilization", str(gpu_memory_utilization),
+ "--chat-template-content-format", "string",
+ "--served-model-name", "model",
+ "--trust-remote-code",
+ "--port", str(port)
+ ]
+
+ try:
+ subprocess.run(cmd, check=True)
+ except subprocess.CalledProcessError as e:
+ print(f"error: vLLM 启动失败: {e}")
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Launch vLLM server with dots_ocr model.")
+ parser.add_argument("--model_path", type=str, required=True, help="Path to your downloaded model weights, Please use a directory name without periods")
+ parser.add_argument("--gpus", type=str, default="0", help="GPU device ID(s)")
+ parser.add_argument("--gpu_memory_utilization", type=float, default=0.95, help="Desired GPU memory utilization percentage")
+ parser.add_argument("--port", type=int, default=8001, help="Port to launch the vLLM server on")
+
+ args = parser.parse_args()
+
+ launch_vllm_server(args.model_path, args.gpus, args.gpu_memory_utilization, args.port)
diff --git a/rag_factory/parser/__init__.py b/rag_factory/parser/__init__.py
index 22bd7c9..5595ce8 100644
--- a/rag_factory/parser/__init__.py
+++ b/rag_factory/parser/__init__.py
@@ -1,4 +1,3 @@
-from .Parser_Docling import DoclingParser
-from .Parser_MinerU import MineruParser
+from .Parser_MinerU import MinerUPdfParser
-__all__ = ["DoclingParser", "MineruParser"]
\ No newline at end of file
+__all__ = ["MinerUPdfParser"]
\ No newline at end of file
diff --git a/rag_factory/parser/main.py b/rag_factory/parser/main.py
deleted file mode 100644
index e0b6e02..0000000
--- a/rag_factory/parser/main.py
+++ /dev/null
@@ -1,135 +0,0 @@
-from .Parser_MinerU import MineruParser
-from .Parser_Docling import DoclingParser
-import argparse
-
-
-def main():
- """
- Main function to run the document parser from command line
- """
- parser = argparse.ArgumentParser(
- description="Parse documents using MinerU 2.0 or Docling"
- )
- parser.add_argument("file_path", help="Path to the document to parse")
- parser.add_argument("--output", "-o", help="Output directory path")
- parser.add_argument(
- "--method",
- "-m",
- choices=["auto", "txt", "ocr"],
- default="auto",
- help="Parsing method (auto, txt, ocr)",
- )
- parser.add_argument(
- "--lang",
- "-l",
- help="Document language for OCR optimization (e.g., ch, en, ja)",
- )
- parser.add_argument(
- "--backend",
- "-b",
- choices=[
- "pipeline",
- "vlm-transformers",
- "vlm-sglang-engine",
- "vlm-sglang-client",
- ],
- default="pipeline",
- help="Parsing backend",
- )
- parser.add_argument(
- "--device",
- "-d",
- help="Inference device (e.g., cpu, cuda, cuda:0, npu, mps)",
- )
- parser.add_argument(
- "--source",
- choices=["huggingface", "modelscope", "local"],
- default="huggingface",
- help="Model source",
- )
- parser.add_argument(
- "--no-formula",
- action="store_true",
- help="Disable formula parsing",
- )
- parser.add_argument(
- "--no-table",
- action="store_true",
- help="Disable table parsing",
- )
- parser.add_argument(
- "--stats", action="store_true", help="Display content statistics"
- )
- parser.add_argument(
- "--check",
- action="store_true",
- help="Check parser installation",
- )
- parser.add_argument(
- "--parser",
- choices=["mineru", "docling"],
- default="mineru",
- help="Parser selection",
- )
- parser.add_argument(
- "--vlm_url",
- help="When the backend is `vlm-sglang-client`, you need to specify the server_url, for example:`http://127.0.0.1:30000`",
- )
-
- args = parser.parse_args()
-
- # Check installation if requested
- if args.check:
- doc_parser = DoclingParser() if args.parser == "docling" else MineruParser()
- if doc_parser.check_installation():
- print(f"✅ {args.parser.title()} is properly installed")
- return 0
- else:
- print(f"❌ {args.parser.title()} installation check failed")
- return 1
-
- try:
- # Parse the document
- doc_parser = DoclingParser() if args.parser == "docling" else MineruParser()
- content_list = doc_parser.parse_document(
- file_path=args.file_path,
- method=args.method,
- output_dir=args.output,
- lang=args.lang,
- backend=args.backend,
- device=args.device,
- source=args.source,
- formula=not args.no_formula,
- table=not args.no_table,
- vlm_url=args.vlm_url,
- )
-
- print(f"✅ Successfully parsed: {args.file_path}")
- print(f"📊 Extracted {len(content_list)} content blocks")
-
- # Display statistics if requested
- if args.stats:
- print("\n📈 Document Statistics:")
- print(f"Total content blocks: {len(content_list)}")
-
- # Count different types of content
- content_types = {}
- for item in content_list:
- if isinstance(item, dict):
- content_type = item.get("type", "unknown")
- content_types[content_type] = content_types.get(content_type, 0) + 1
-
- if content_types:
- print("\n📋 Content Type Distribution:")
- for content_type, count in sorted(content_types.items()):
- print(f" • {content_type}: {count}")
-
- except Exception as e:
- print(f"❌ Error: {str(e)}")
- return 1
-
- return 0
-
-
-if __name__ == "__main__":
- exit(main())
\ No newline at end of file
diff --git a/rag_factory/rerankers/Reranker_Base.py b/rag_factory/rerankers/Reranker_Base.py
new file mode 100644
index 0000000..518b81e
--- /dev/null
+++ b/rag_factory/rerankers/Reranker_Base.py
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from ..Retrieval import Document
+import warnings
+
+class RerankerBase(ABC):
+ """
+ Reranker 基类,所有 Reranker 应该继承此类并实现 rerank 方法。
+ 不建议直接实例化本类。
+
+ 使用方法:
+ class MyReranker(RerankerBase):
+ def rerank(self, query: str, documents: list[str], **kwargs) -> list[float]:
+ # 实现具体的重排序逻辑
+ ...
+ """
+ def __init__(self):
+ if type(self) is RerankerBase:
+ warnings.warn("RerankerBase 是抽象基类,不应直接实例化。请继承并实现 rerank 方法。", UserWarning)
+
+ @abstractmethod
+ def rerank(self, query: str, documents: list[Document], **kwargs) -> list[Document]:
+ """
+ Rerank the documents based on the query.
+ 需要子类实现。
+ """
+ warnings.warn("调用了未实现的 rerank 方法。请在子类中实现该方法。", UserWarning)
+ raise NotImplementedError("子类必须实现 rerank 方法。")
diff --git a/rag_factory/rerankers/Reranker_Qwen3.py b/rag_factory/rerankers/Reranker_Qwen3.py
new file mode 100644
index 0000000..b909c2b
--- /dev/null
+++ b/rag_factory/rerankers/Reranker_Qwen3.py
@@ -0,0 +1,75 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from .Reranker_Base import RerankerBase
+from ..Retrieval.RetrieverBase import Document
+
+class Qwen3Reranker(RerankerBase):
+ def __init__(self, model_name_or_path: str, max_length: int = 4096, instruction=None, attn_type='causal', device_id="cuda:0", **kwargs):
+ super().__init__()
+ device = torch.device(device_id)
+ self.max_length = max_length
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True, padding_side='left')
+ self.lm = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.float16)
+ self.lm = self.lm.to(device).eval()
+ self.token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+ self.token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+ self.prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+ self.prefix_tokens = self.tokenizer.encode(self.prefix, add_special_tokens=False)
+ self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
+        self.instruction = instruction or "Given the user query, retrieve the relevant passages"
+ self.device = device
+
+ def format_instruction(self, instruction, query, doc):
+ if instruction is None:
+ instruction = self.instruction
+ output = f": {instruction}\n: {query}\n: {doc}"
+ return output
+
+ def process_inputs(self, pairs):
+ out = self.tokenizer(
+ pairs, padding=False, truncation='longest_first',
+ return_attention_mask=False, max_length=self.max_length - len(self.prefix_tokens) - len(self.suffix_tokens)
+ )
+ for i, ele in enumerate(out['input_ids']):
+ out['input_ids'][i] = self.prefix_tokens + ele + self.suffix_tokens
+ out = self.tokenizer.pad(out, padding=True, return_tensors="pt", max_length=self.max_length)
+ for key in out:
+ out[key] = out[key].to(self.lm.device)
+ return out
+
+ @torch.no_grad()
+ def compute_logits(self, inputs, **kwargs):
+ batch_scores = self.lm(**inputs).logits[:, -1, :]
+ true_vector = batch_scores[:, self.token_true_id]
+ false_vector = batch_scores[:, self.token_false_id]
+ batch_scores = torch.stack([false_vector, true_vector], dim=1)
+ batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+ scores = batch_scores[:, 1].exp().tolist()
+ return scores
+
+ def compute_scores(self, pairs, instruction=None, **kwargs):
+ pairs = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
+ inputs = self.process_inputs(pairs)
+ scores = self.compute_logits(inputs)
+ return scores
+
+ def rerank(self, query: str, documents: list[Document], k: int = None, batch_size: int = 8, **kwargs) -> list[Document]:
+        # 1. Build (query, doc.content) pairs
+ pairs = [(query, doc.content) for doc in documents]
+
+        # 2. Compute relevance scores in batches
+ all_scores = []
+ for i in range(0, len(pairs), batch_size):
+ batch_pairs = pairs[i:i+batch_size]
+ batch_scores = self.compute_scores(batch_pairs)
+ all_scores.extend(batch_scores)
+ scores = all_scores
+
+        # 3. Sort documents by score in descending order
+ doc_score_pairs = list(zip(documents, scores))
+ doc_score_pairs.sort(key=lambda x: x[1], reverse=True)
+ reranked_docs = [doc for doc, score in doc_score_pairs]
+ if k is not None:
+ reranked_docs = reranked_docs[:k]
+ return reranked_docs
\ No newline at end of file
diff --git a/rag_factory/rerankers/__init__.py b/rag_factory/rerankers/__init__.py
index e69de29..be8f542 100644
--- a/rag_factory/rerankers/__init__.py
+++ b/rag_factory/rerankers/__init__.py
@@ -0,0 +1,5 @@
+from .Reranker_Base import RerankerBase
+from .Reranker_Qwen3 import Qwen3Reranker
+from .registry import RerankerRegistry
+
+__all__ = ["RerankerBase", "Qwen3Reranker", "RerankerRegistry"]
\ No newline at end of file
diff --git a/rag_factory/rerankers/registry.py b/rag_factory/rerankers/registry.py
new file mode 100644
index 0000000..d817c4a
--- /dev/null
+++ b/rag_factory/rerankers/registry.py
@@ -0,0 +1,93 @@
+from typing import Dict, Type, Any, Optional, List
+import logging
+from .Reranker_Base import RerankerBase
+from .Reranker_Qwen3 import Qwen3Reranker
+
+class RerankerRegistry:
+ _rerankers: Dict[str, Type[RerankerBase]] = {}
+
+ @classmethod
+ def register(cls, name: str, reranker_class: Type[RerankerBase]):
+ """注册重排序器类到注册表中
+
+ Args:
+ name: 重排序器的名称
+ reranker_class: 重排序器类,必须继承自RerankerBase
+ """
+ if not issubclass(reranker_class, RerankerBase):
+ raise ValueError(f"重排序器类 {reranker_class} 必须继承自 RerankerBase")
+
+ if name in cls._rerankers:
+ logging.warning(f"重排序器 '{name}' 已存在,将被覆盖")
+
+ cls._rerankers[name] = reranker_class
+ logging.info(f"成功注册重排序器: {name}")
+
+ @classmethod
+ def create(cls, name: str, **kwargs) -> RerankerBase:
+ """根据名称获取重排序器实例
+
+ Args:
+ name: 重排序器名称
+ **kwargs: 传递给重排序器构造函数的参数
+
+ Returns:
+ RerankerBase: 重排序器实例
+
+ Raises:
+ ValueError: 当重排序器未注册时抛出
+ """
+ if name not in cls._rerankers:
+ available = list(cls._rerankers.keys())
+ raise ValueError(f"未找到重排序器 '{name}'。可用的重排序器: {available}")
+
+ reranker_class = cls._rerankers[name]
+ return reranker_class(**kwargs)
+
+ @classmethod
+ def list_rerankers(cls) -> List[str]:
+ """获取所有已注册的重排序器名称列表
+
+ Returns:
+ List[str]: 重排序器名称列表
+ """
+ return list(cls._rerankers.keys())
+
+ @classmethod
+ def is_registered(cls, name: str) -> bool:
+ """检查重排序器是否已注册
+
+ Args:
+ name: 重排序器名称
+
+ Returns:
+ bool: 如果已注册返回True,否则返回False
+ """
+ return name in cls._rerankers
+
+ @classmethod
+ def unregister(cls, name: str) -> bool:
+ """注销重排序器
+
+ Args:
+ name: 要注销的重排序器名称
+
+ Returns:
+ bool: 如果成功注销返回True,如果重排序器不存在返回False
+ """
+ if name in cls._rerankers:
+ del cls._rerankers[name]
+ logging.info(f"成功注销重排序器: {name}")
+ return True
+ else:
+ logging.warning(f"尝试注销不存在的重排序器: {name}")
+ return False
+
+ @classmethod
+ def clear_all(cls):
+ """清除所有已注册的重排序器"""
+ cls._rerankers.clear()
+ logging.info("已清除所有重排序器")
+
+# Register the default rerankers
+RerankerRegistry.register("qwen3", Qwen3Reranker)
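+
+# Illustrative usage (a sketch; the checkpoint path below is an assumption and must point to a
+# local Qwen3 reranker model):
+#
+#     from rag_factory.rerankers import RerankerRegistry
+#     reranker = RerankerRegistry.create("qwen3", model_name_or_path="/path/to/qwen_reranker_0.6B", device_id="cuda:0")
+#     top_docs = reranker.rerank(query, documents, k=5)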
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index ac808e5..823e164 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,22 @@ neo4j
aioboto3
llama-index
llama-index-core
-peewee
\ No newline at end of file
+peewee
+
+mineru[core]
+rank_bm25
+faiss_gpu
+
+
+
+# streamlit
+PyMuPDF
+openai
+qwen_vl_utils
+transformers==4.51.3
+huggingface_hub
+modelscope
+flash-attn==2.8.0.post2
+# for GLIBC 2.31, please use flash-attn==2.7.4.post1 instead of flash-attn==2.8.0.post2
+accelerate
+dashscope