Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions backend/app/api/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def build_graph():

# 检查配置
errors = []
if not Config.ZEP_API_KEY:
if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY:
errors.append(t('api.zepApiKeyMissing'))
if errors:
logger.error(f"配置错误: {errors}")
Expand Down Expand Up @@ -387,7 +387,7 @@ def build_task():
)

# 创建图谱构建服务
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
builder = GraphBuilderService()

# 分块
task_manager.update_task(
Expand Down Expand Up @@ -572,13 +572,13 @@ def get_graph_data(graph_id: str):
获取图谱数据(节点和边)
"""
try:
if not Config.ZEP_API_KEY:
if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
}), 500

builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
builder = GraphBuilderService()
graph_data = builder.get_graph_data(graph_id)

return jsonify({
Expand All @@ -600,13 +600,13 @@ def delete_graph(graph_id: str):
删除Zep图谱
"""
try:
if not Config.ZEP_API_KEY:
if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
}), 500

builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
builder = GraphBuilderService()
builder.delete_graph(graph_id)

return jsonify({
Expand Down
6 changes: 3 additions & 3 deletions backend/app/api/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def get_graph_entities(graph_id: str):
enrich: 是否获取相关边信息(默认true)
"""
try:
if not Config.ZEP_API_KEY:
if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_graph_entities(graph_id: str):
def get_entity_detail(graph_id: str, entity_uuid: str):
"""获取单个实体的详细信息"""
try:
if not Config.ZEP_API_KEY:
if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
Expand Down Expand Up @@ -127,7 +127,7 @@ def get_entity_detail(graph_id: str, entity_uuid: str):
def get_entities_by_type(graph_id: str, entity_type: str):
"""获取指定类型的所有实体"""
try:
if not Config.ZEP_API_KEY:
if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
Expand Down
16 changes: 12 additions & 4 deletions backend/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,18 @@ class Config:
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')

# Zep配置
# 图谱存储配置
# 默认优先使用 SQLite 本地存储;如果显式提供 ZEP_API_KEY,也可以继续走远端 Zep。
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')

GRAPH_STORAGE_BACKEND = os.environ.get(
'MIROFISH_GRAPH_STORAGE',
'sqlite' if not os.environ.get('ZEP_API_KEY') else 'zep'
).strip().lower()
LOCAL_GRAPH_DB_PATH = os.environ.get(
'MIROFISH_GRAPH_DB_PATH',
os.path.join(os.path.dirname(__file__), '../uploads/local_graphs.sqlite3')
)

# 文件上传配置
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
Expand Down Expand Up @@ -69,7 +78,6 @@ def validate(cls) -> list[str]:
errors: list[str] = []
if not cls.LLM_API_KEY:
errors.append("LLM_API_KEY 未配置")
if not cls.ZEP_API_KEY:
errors.append("ZEP_API_KEY 未配置")
# 图谱存储现在默认可以使用 SQLite,本地模式下不再强制要求 ZEP_API_KEY。
return errors

108 changes: 96 additions & 12 deletions backend/app/services/graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

from ..config import Config
from ..models.task import TaskManager, TaskStatus
from ..utils.llm_client import LLMClient
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from .local_graph_store import LocalGraphStore
from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale

Expand Down Expand Up @@ -45,11 +47,23 @@ class GraphBuilderService:

def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置")

self.client = Zep(api_key=self.api_key)
self.use_local_storage = Config.GRAPH_STORAGE_BACKEND == 'sqlite' or not self.api_key
self.task_manager = TaskManager()
self.local_store = LocalGraphStore() if self.use_local_storage else None
self.llm_client = None

if self.use_local_storage:
# Local SQLite mode: no Zep client needed.
if Config.LLM_API_KEY:
try:
self.llm_client = LLMClient()
except Exception:
self.llm_client = None
self.client = None
else:
if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置")
self.client = Zep(api_key=self.api_key)

def build_graph_async(
self,
Expand Down Expand Up @@ -191,19 +205,31 @@ def _build_graph_worker(
self.task_manager.fail_task(task_id, error_msg)

def create_graph(self, name: str) -> str:
"""创建Zep图谱(公开方法)"""
"""创建图谱(公开方法)"""
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"


if self.use_local_storage:
self.local_store.create_graph(
graph_id=graph_id,
name=name,
description="MiroFish Social Simulation Graph",
)
return graph_id

self.client.graph.create(
graph_id=graph_id,
name=name,
description="MiroFish Social Simulation Graph"
)

return graph_id

def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
"""设置图谱本体(公开方法)"""
if self.use_local_storage:
self.local_store.set_ontology(graph_id, ontology)
return

import warnings
from typing import Optional
from pydantic import Field
Expand Down Expand Up @@ -290,7 +316,6 @@ def safe_attr_name(attr_name: str) -> str:
entities=entity_types if entity_types else None,
edges=edge_definitions if edge_definitions else None,
)

def add_text_batches(
self,
graph_id: str,
Expand All @@ -299,6 +324,25 @@ def add_text_batches(
progress_callback: Optional[Callable] = None
) -> List[str]:
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
if self.use_local_storage:
total = len(chunks)
if total == 0:
return []
if progress_callback:
progress_callback(t('progress.sendingBatch', current=1, total=1, chunks=total), 0.0)

episode_uuids = self.local_store.extract_and_store_chunks(
graph_id=graph_id,
chunks=chunks,
ontology=(self.local_store.get_graph(graph_id) or {}).get('ontology', {}),
llm_client=self.llm_client,
progress_callback=None,
batch_size=batch_size,
)
if progress_callback:
progress_callback(t('progress.processingComplete', completed=len(episode_uuids), total=len(episode_uuids)), 1.0)
return episode_uuids

episode_uuids = []
total_chunks = len(chunks)

Expand Down Expand Up @@ -351,6 +395,11 @@ def _wait_for_episodes(
timeout: int = 600
):
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
if self.use_local_storage:
if progress_callback:
progress_callback(t('progress.processingComplete', completed=len(episode_uuids), total=len(episode_uuids)), 1.0)
return

if not episode_uuids:
if progress_callback:
progress_callback(t('progress.noEpisodesWait'), 1.0)
Expand Down Expand Up @@ -402,11 +451,29 @@ def _wait_for_episodes(

def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""获取图谱信息"""
if self.use_local_storage:
graph = self.local_store.get_graph(graph_id) or {}
nodes = graph.get("nodes", [])
edges = graph.get("edges", [])
entity_types = set()
for node in nodes:
for label in node.get("labels", []):
if label not in ["Entity", "Node"]:
entity_types.add(label)
return GraphInfo(
graph_id=graph_id,
node_count=len(nodes),
edge_count=len(edges),
entity_types=list(entity_types),
)

# 获取节点(分页)
nodes = fetch_all_nodes(self.client, graph_id)
client = self.client
assert client is not None
nodes = fetch_all_nodes(client, graph_id)

# 获取边(分页)
edges = fetch_all_edges(self.client, graph_id)
edges = fetch_all_edges(client, graph_id)

# 统计实体类型
entity_types = set()
Expand All @@ -433,8 +500,22 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
Returns:
包含nodes和edges的字典,包括时间信息、属性等详细数据
"""
nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id)
if self.use_local_storage:
graph = self.local_store.get_graph(graph_id) or {}
nodes_data = graph.get("nodes", [])
edges_data = graph.get("edges", [])
return {
"graph_id": graph_id,
"nodes": nodes_data,
"edges": edges_data,
"node_count": len(nodes_data),
"edge_count": len(edges_data),
}

client = self.client
assert client is not None
nodes = fetch_all_nodes(client, graph_id)
edges = fetch_all_edges(client, graph_id)

# 创建节点映射用于获取节点名称
node_map = {}
Expand Down Expand Up @@ -502,5 +583,8 @@ def get_graph_data(self, graph_id: str) -> Dict[str, Any]:

def delete_graph(self, graph_id: str):
"""删除图谱"""
if self.use_local_storage:
self.local_store.delete_graph(graph_id)
return
self.client.graph.delete(graph_id=graph_id)

Loading