-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
233 lines (188 loc) · 8.61 KB
/
main.py
File metadata and controls
233 lines (188 loc) · 8.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
"""大学公众号文章审核 Agent - 入口文件"""
import json
import mimetypes
import sys
import traceback
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from pathlib import Path
from urllib.parse import unquote, urlparse

import yaml

from agent.client import create_llm_client
from agent.workflow import ArticleReviewWorkflow
# Directory containing this script; config/ and showcase/ are located relative to it.
WORKSPACE_DIR = Path(__file__).parent
# Service mode serves the front-end showcase page at "/" so it can be opened directly.
DEFAULT_INDEX = WORKSPACE_DIR.joinpath("showcase", "article-review-agent.html")

# Default and allowed range for the adjust-iteration count; the HTTP API clamps
# client-supplied values into [MIN_MAX_ITERATIONS, MAX_MAX_ITERATIONS].
DEFAULT_MAX_ITERATIONS = 5
MIN_MAX_ITERATIONS, MAX_MAX_ITERATIONS = 1, 10
class ArticleReviewAgent:
    """Article review agent.

    Loads prompt templates, model settings and review rules from ``config_dir``
    once at construction time, then runs the review/adjust workflow per article.
    """

    def __init__(self, config_dir: str = "config"):
        # A relative path is interpreted against the current working directory.
        self.config_dir = Path(config_dir)
        self._load_configs()

    def _load_configs(self) -> None:
        """Load prompts.yaml, models.yaml and rules.md from ``self.config_dir``.

        Sets ``review_prompts`` / ``adjust_prompts``, the ``review_llm`` /
        ``adjust_llm`` clients, and ``rules_markdown`` / ``review_rules``; the
        rules text is injected into the review system prompt.
        """
        prompts_config = self._read_yaml("prompts.yaml")
        self.review_prompts = prompts_config["prompts"]["review"]
        self.adjust_prompts = prompts_config["prompts"]["adjust"]

        models_config = self._read_yaml("models.yaml")
        self.review_llm = create_llm_client(models_config["models"]["review"])
        self.adjust_llm = create_llm_client(models_config["models"]["adjust"])

        # Inject the Markdown rules into the review system prompt: replace the
        # {{RULES_MD}} placeholder when present, otherwise append the rules.
        self.rules_markdown, self.review_rules = self._load_rules_markdown()
        # `or ""` also covers an explicit null (`system:` with no value) in the
        # YAML, which would otherwise raise TypeError on the `in` check below.
        review_system = self.review_prompts.get("system") or ""
        if "{{RULES_MD}}" in review_system:
            self.review_prompts["system"] = review_system.replace("{{RULES_MD}}", self.rules_markdown)
        else:
            self.review_prompts["system"] = f"{review_system}\n\n{self.rules_markdown}".strip()

    def _read_yaml(self, name: str):
        """Parse ``config_dir / name`` as UTF-8 YAML using ``yaml.safe_load``."""
        with open(self.config_dir / name, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)

    def _load_rules_markdown(self) -> tuple[str, dict]:
        """Read ``rules.md`` from the config directory.

        Returns:
            ``(rules_markdown, review_rules)``. ``rules_markdown`` is the raw file
            text, or ``""`` when the file does not exist. ``review_rules`` is
            always an empty dict: rules are plain Markdown entries injected into
            the prompt verbatim, not parsed into structured data.
        """
        rules_md_path = self.config_dir / "rules.md"
        if not rules_md_path.exists():
            return "", {}
        return rules_md_path.read_text(encoding="utf-8"), {}

    def review(self, article: str, title: str = "", max_iterations: int = DEFAULT_MAX_ITERATIONS) -> dict:
        """Review an article, letting the workflow adjust it when needed.

        Args:
            article: Article body.
            title: Article title (optional).
            max_iterations: Maximum number of adjust iterations, forwarded to
                ``ArticleReviewWorkflow``.

        Returns:
            dict with keys ``status``, ``approved`` (``status == "approved"``),
            ``iteration``, ``final_article`` (the workflow's adjusted article)
            and ``review_result`` (``model_dump()`` of the review result, or None).
        """
        workflow = ArticleReviewWorkflow(
            review_llm=self.review_llm.generate,
            adjust_llm=self.adjust_llm.generate,
            review_prompts=self.review_prompts,
            adjust_prompts=self.adjust_prompts,
            review_rules=self.review_rules,
            max_iterations=max_iterations,
        )
        result = workflow.review(article, title)
        return {
            "status": result.status,
            "approved": result.status == "approved",
            "iteration": result.iteration,
            "final_article": result.adjusted_article,
            "review_result": result.review_result.model_dump() if result.review_result else None,
        }
def create_http_handler(agent: ArticleReviewAgent):
# 通过闭包注入 agent,避免在每个请求里重复初始化模型客户端。
class ReviewHandler(BaseHTTPRequestHandler):
def _send_json(self, payload: dict, code: int = 200):
data = json.dumps(payload, ensure_ascii=False).encode("utf-8")
self.send_response(code)
self.send_header("Content-Type", "application/json; charset=utf-8")
self.send_header("Content-Length", str(len(data)))
self.end_headers()
self.wfile.write(data)
def _send_file(self, file_path: Path):
if not file_path.exists() or not file_path.is_file():
self._send_json({"error": "Not Found"}, 404)
return
content = file_path.read_bytes()
mime_type, _ = mimetypes.guess_type(str(file_path))
self.send_response(200)
self.send_header("Content-Type", f"{mime_type or 'application/octet-stream'}")
self.send_header("Content-Length", str(len(content)))
self.end_headers()
self.wfile.write(content)
def do_GET(self):
parsed = urlparse(self.path)
path = parsed.path
if path == "/health":
self._send_json({"ok": True})
return
if path == "/api/settings":
self._send_json(
{
"default_max_iterations": DEFAULT_MAX_ITERATIONS,
"min_max_iterations": MIN_MAX_ITERATIONS,
"max_max_iterations": MAX_MAX_ITERATIONS,
}
)
return
if path == "/":
self._send_file(DEFAULT_INDEX)
return
if path.startswith("/showcase/"):
target = (WORKSPACE_DIR / path.lstrip("/")).resolve()
showcase_root = (WORKSPACE_DIR / "showcase").resolve()
# 限制只读取 showcase 目录,避免目录穿越访问任意文件。
if showcase_root in target.parents or target == showcase_root:
self._send_file(target)
return
self._send_json({"error": "Forbidden"}, 403)
return
self._send_json({"error": "Not Found"}, 404)
def do_POST(self):
parsed = urlparse(self.path)
if parsed.path != "/api/review":
self._send_json({"error": "Not Found"}, 404)
return
try:
content_length = int(self.headers.get("Content-Length", "0"))
raw_body = self.rfile.read(content_length) if content_length > 0 else b"{}"
payload = json.loads(raw_body.decode("utf-8"))
article = (payload.get("article") or "").strip()
title = (payload.get("title") or "").strip()
max_iterations = int(payload.get("max_iterations", DEFAULT_MAX_ITERATIONS))
if not article:
self._send_json({"error": "article 不能为空"}, 400)
return
result = agent.review(
article=article,
title=title,
# 迭代次数做边界收敛,防止异常参数导致超长任务。
max_iterations=max(MIN_MAX_ITERATIONS, min(max_iterations, MAX_MAX_ITERATIONS)),
)
self._send_json(result, 200)
except json.JSONDecodeError:
self._send_json({"error": "请求体必须是合法 JSON"}, 400)
except Exception as exc:
self._send_json({"error": f"服务内部错误: {exc}"}, 500)
return ReviewHandler
def run_server(host: str = "127.0.0.1", port: int = 8000):
    """Start the HTTP service and block until interrupted (Ctrl+C).

    The agent is created once here and shared by every request thread through
    the handler closure, so configuration is loaded only at startup.
    """
    agent = ArticleReviewAgent(config_dir=str(WORKSPACE_DIR / "config"))
    handler_cls = create_http_handler(agent)
    # The context manager closes the listening socket when serving stops.
    with ThreadingHTTPServer((host, port), handler_cls) as server:
        print(f"服务已启动: http://{host}:{port}")
        # Flush so the banner shows up immediately even when stdout is redirected,
        # since serve_forever() blocks indefinitely.
        print("接口: POST /api/review", flush=True)
        try:
            server.serve_forever()
        except KeyboardInterrupt:
            # Ctrl+C: stop cleanly instead of dumping a traceback.
            print("服务已停止")
def main():
    """Command-line entry point.

    Usage:
        python main.py serve [port]       start the HTTP service (port defaults to 8000)
        python main.py <article> [title]  review an article passed as an argument
        python main.py - [title]          review an article read from stdin
    """
    args = sys.argv[1:]
    if args and args[0] == "serve":
        try:
            port = int(args[1]) if len(args) >= 2 else 8000
        except ValueError:
            print(f"端口必须是整数: {args[1]}")
            sys.exit(1)
        run_server(port=port)
        return

    if not args:
        print("用法: python main.py serve [端口]")
        print("或: python main.py <文章内容> [标题]")
        print("或通过 stdin 传入文章内容: python main.py - [标题]")
        sys.exit(1)

    # "-" means the article body is read from stdin.
    article = sys.stdin.read() if args[0] == "-" else args[0]
    title = args[1] if len(args) >= 2 else ""
    if not article.strip():
        # Mirror the HTTP API, which rejects an empty article with 400.
        print("文章内容不能为空")
        sys.exit(1)

    # Use the config directory next to this script (as run_server does) so the
    # CLI works regardless of the current working directory.
    agent = ArticleReviewAgent(config_dir=str(WORKSPACE_DIR / "config"))
    result = agent.review(article, title)
    print(f"审核状态: {result['status']}")
    print(f"是否通过: {result['approved']}")
    print(f"迭代次数: {result['iteration']}")
    if result["review_result"]:
        print(f"审核评语: {result['review_result'].get('comment', 'N/A')}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, never on import.
    sys.exit(main())