diff --git a/loopai/agents/Judger/judger_agent.py b/loopai/agents/Judger/judger_agent.py
index 09a3863..a26ec59 100644
--- a/loopai/agents/Judger/judger_agent.py
+++ b/loopai/agents/Judger/judger_agent.py
@@ -26,8 +26,13 @@
from loopai.schema.events import StreamEvent
from loopai.logger import get_logger
+from loopai.common.prompts import PromptLoader
+
logger = get_logger()
+On = False
+Vllm_Start_Error = False
+
def _isNotNone(value):
return value != "" and value is not None
@@ -51,7 +56,7 @@ def extract_num(cp: str) -> int:
# 使用 min 选择:先按距离排序,再按数值本身排序
best_cp = min(checkpoints, key=lambda cp: (abs(extract_num(cp) - best_step), extract_num(cp)))
- return best_cp
+ return best_cp
class JudgerAgent(BaseAgent):
@property
@@ -74,7 +79,7 @@ def get_check_required_fields_node(self):
def check_required_fields(state: LoopAIState, runtime: Runtime[RuntimeContext]):
writer = get_stream_writer()
required_fields = {
- 'judger':["eval_api_key", "eval_temperature",
+ 'judger':["eval_temperature",
"eval_top_p", "eval_problem_path",
"eval_case_num", "eval_task_type"
],
@@ -96,8 +101,7 @@ def check_required_fields(state: LoopAIState, runtime: Runtime[RuntimeContext]):
missing_fields = get_missing_fields({'judger':["eval_model_path"]}, state)
"""vllm启动检查"""
- # base_url = state.get("judger", {}).get("eval_base_url", None)
- base_url = None
+ base_url = state.get("judger", {}).get("eval_base_url", None)
task_type = state.get("judger", {}).get("eval_task_type", "")
logger.info(f"base_url:{base_url}")
if(_isNotNone(base_url) and task_type!="general_text"):
@@ -178,18 +182,17 @@ def check_required_fields(state: LoopAIState, runtime: Runtime[RuntimeContext]):
check_file_fields = True
if task_type == "code" or task_type =="text2sql":
- check_file_fields = check_jsonl_fields(state.get("judger", {}).get("eval_problem_path", ""), required_fields)
+            check_file_fields, error_details = check_jsonl_fields(state.get("judger", {}).get("eval_problem_path", ""), required_fields)
if check_file_fields is not True:
logger.info("$"*50)
- logger.info(["eval_problem_path"])
state['exception'] = 'ConfigerError'
state['next_to'] = 'config_node'
state['automated_query'] = self.prompt_loader(
"automated_query", "judger_missing_fields_prompt")
- state.setdefault('configer',{})['configer_error'] = f'Wrong required fields: {json.dumps({"wrong_fields": "eval_problem_path"}, ensure_ascii=False)}'
+                state.setdefault('configer',{})['configer_error'] = json.dumps(error_details, ensure_ascii=False, indent=4)
goto_node = runtime.context['exception_navigate']
- logger.info(f'found wrong fields, required fields missing in the file, goto {goto_node}')
+ logger.warning(f'found wrong fields, required fields missing in the file, goto {goto_node}')
return Command(
update=state,
goto=goto_node,
@@ -240,6 +243,7 @@ def vllm_kill_node(state: LoopAIState) -> LoopAIState:
# vllm_port = state.get("judger", {}).get("eval_vllm_port", DEFAULT_VLLM_PORT)
# base_url = state.get("judger", {}).get("eval_base_url", None)
# base_url = None
+ global On
writer = get_stream_writer()
# 未设置base_url才会进入本地开启程序才会先关闭本地的vllm服务
logger.info("=== 准备关闭 vllm ===")
@@ -280,11 +284,13 @@ def vllm_kill_node(state: LoopAIState) -> LoopAIState:
# message="vllm开启结束",
# data={"msg": f"因已开启自定义vllm服务而跳过该过程"}
# ).json())
+ state["judger"]["eval_base_url"] = None
+ On = not On
return state
@staticmethod
@BaseAgent.set_current
- def vllm_start_node(state: LoopAIState) -> LoopAIState:
+ def vllm_start_node(state: LoopAIState) -> Union[LoopAIState, Command]:
# 设置 cuda_visible_devices
set_gpu(state)
@@ -325,8 +331,17 @@ def vllm_start_node(state: LoopAIState) -> LoopAIState:
message="vllm 启动异常",
data={"msg": f"vllm 启动异常 ,请解决后重新评测:{e}"}
).json())
- # 直接结束 Judger 当前节点,不再跳转父图异常路由
- return state
+ state['exception'] = 'ConfigerError'
+ state['next_to'] = 'config_node'
+            state['automated_query'] = "在上一轮执行中,`judger_agent` 因 vllm 启动异常而中断。请提示用户检查评估模型相关配置(模型路径、GPU、显存利用率等),并询问是否需要重新执行 `judge` 操作。"
+            state.setdefault('configer',{})['configer_error'] = 'vllm 启动异常,请检查配置是否正确'
+ goto_node = END
+ logger.warning(f'vllm start error, goto {goto_node}')
+ return Command(
+ update=state,
+ goto=goto_node,
+ graph=Command.PARENT
+ )
else:
if writer:
writer(StreamEvent(
@@ -337,6 +352,26 @@ def vllm_start_node(state: LoopAIState) -> LoopAIState:
).json())
return state
+ #@staticmethod
+ #@BaseAgent.set_current
+ #def vllm_start_next(state: LoopAIState) -> str:
+ # """
+ # 如果启动 vllm 发生异常直接 to_end
+ # """
+ # global Vllm_Start_Error
+ # # Vllm启动异常
+ # if not Vllm_Start_Error:
+ # # 正常启动
+ # return "to_data_format"
+ # else:
+ # # 异常则直接结束
+ # Vllm_Start_Error = False
+ # state['judger']['eval_base_url'] = None
+ # logger.warning("==== 发生异常结束 ====")
+ # logger.info("==== 结束 ====")
+ # logger.info(f"=== eval_base_url: {state['judger']['eval_base_url']}")
+ # return "to_end"
+
@staticmethod
@BaseAgent.set_current
def data_format_node(state: LoopAIState) -> LoopAIState:
@@ -469,7 +504,7 @@ def vllm_kill_next(state: LoopAIState) -> str:
"""
base_url = state.get("judger", {}).get("eval_base_url", None)
# 判断是否为空,为空则开启本地
- if not _isNotNone(base_url):
+ if not _isNotNone(base_url) and On:
return "to_vllm_start"
else:
state['judger']['eval_base_url'] = None
@@ -496,6 +531,15 @@ def init_graph(self, **kwargs):
builder.set_entry_point("check_required_fields")
builder.set_finish_point("eval_general_text")
+ #builder.add_conditional_edges(
+ # source="vllm_start",
+ # path=self.vllm_start_next,
+ # path_map={
+ # "to_data_format":"data_format",
+ # "to_end": END,
+ # }
+ #)
+
builder.add_conditional_edges(
source="check_param_type", # 来源节点(从哪个节点跳转出来)
path=self.check_param_type_next, # 路由函数(执行条件判断,返回路由键)
diff --git a/loopai/agents/Judger/nodes/eval_general_text_node.py b/loopai/agents/Judger/nodes/eval_general_text_node.py
index 91798c0..03218fb 100644
--- a/loopai/agents/Judger/nodes/eval_general_text_node.py
+++ b/loopai/agents/Judger/nodes/eval_general_text_node.py
@@ -124,7 +124,7 @@ def _build_model_config(cfg: Dict[str, Any]) -> ModelConfig:
model_name_or_path=model_name_or_path,
is_api=is_api,
api_url=api_url,
- api_key=cfg.get("eval_api_key",""),
+ api_key=cfg.get("eval_api_key","EMPTY"),
temperature=float(cfg.get("eval_temperature", 0.0)),
top_p=float(cfg.get("eval_top_p", 1.0)),
tensor_parallel_size=int(cfg.get("eval_vllm_tensor_parallel_size", 1)),
diff --git a/loopai/agents/Judger/utils/oj/data.py b/loopai/agents/Judger/utils/oj/data.py
index e9bc8de..ffd3389 100644
--- a/loopai/agents/Judger/utils/oj/data.py
+++ b/loopai/agents/Judger/utils/oj/data.py
@@ -7,10 +7,7 @@
def read_problems(evalset_file) -> Dict[str, Dict]:
return {task["task_id"]: task for task in stream_jsonl(evalset_file)}
-import json
-from typing import List
-
-def check_jsonl_fields(filepath: str, require_fields: List[str]) -> bool:
+def check_jsonl_fields(filepath: str, require_fields: list[str]) -> tuple[bool, list[str]]:
"""
检查 JSONL 文件中的每一行是否都包含指定的字段。
@@ -20,31 +17,43 @@ def check_jsonl_fields(filepath: str, require_fields: List[str]) -> bool:
返回:
如果所有行都包含所有字段,返回 True;否则返回 False
+ 是否全部通过校验, 缺失字段或错误的详细信息列表
"""
+ error_details = []
try:
- with open(filepath, 'r', encoding='utf-8') as f:
- for line_num, line in enumerate(f, start=1):
- line = line.strip()
+ with open(filepath, 'rb') as f:
+ for line_num, line_bytes in enumerate(f, start=1):
+ try:
+ # 去除首尾空白并解码
+ line = line_bytes.decode('utf-8').strip()
+ except UnicodeDecodeError:
+ error_details.append(f"第 {line_num} 行: 无法解码的字符编码")
+ continue
+
if not line: # 跳过空行
continue
+
try:
data = json.loads(line)
except json.JSONDecodeError as e:
- print(f"第 {line_num} 行 JSON 解析错误: {e}")
- return False
-
+ error_details.append(f"第 {line_num} 行: JSON 解析失败 ({e})")
+ continue
+
# 检查每个必需字段
for field in require_fields:
if field not in data:
- print(f"第 {line_num} 行缺少字段 '{field}'")
- return False
- return True
+ error_details.append(f"第 {line_num} 行: 缺少必填字段 '{field}'")
+
+
except FileNotFoundError:
- print(f"文件不存在: {filepath}")
- return False
+ return False, [f"文件未找到: {filepath}"]
+
except Exception as e:
- print(f"读取文件时发生错误: {e}")
- return False
+ return False, [f"读取文件时发生未知错误: {e}"]
+
+ # 如果错误列表为空,说明校验通过
+ is_valid = len(error_details) == 0
+ return is_valid, error_details
def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
diff --git a/loopai/agents/Judger/utils/oj/generate.py b/loopai/agents/Judger/utils/oj/generate.py
index fa548fc..8ca5a43 100644
--- a/loopai/agents/Judger/utils/oj/generate.py
+++ b/loopai/agents/Judger/utils/oj/generate.py
@@ -38,11 +38,11 @@ def generate_sample_code(state):
model = init_model(
model_path=judger_state['eval_model_path'],
base_url=judger_state['eval_base_url'],
- api_key=judger_state['eval_api_key'],
+ api_key=judger_state.get("eval_api_key","EMPTY"),
temperature=judger_state['eval_temperature'],
top_p=judger_state['eval_top_p']
)
- logger.info(f"模型路径:-> base_url: {judger_state['eval_base_url']} ->api_key: {judger_state['eval_api_key']}")
+    logger.info(f"模型路径:-> base_url: {judger_state['eval_base_url']} ->api_key: {judger_state.get('eval_api_key','EMPTY')}")
output_dir = Path(state.get("output_dir"))
problem_path = judger_state['eval_problem_path']
@@ -124,7 +124,7 @@ def generate_sample_text2sql(state):
model = init_model(
model_path=judger_state['eval_model_path'],
base_url=judger_state['eval_base_url'],
- api_key=judger_state['eval_api_key'],
+        api_key=judger_state.get("eval_api_key","EMPTY"),
temperature=judger_state['eval_temperature'],
top_p=judger_state['eval_top_p']
)
diff --git a/loopai/agents/Judger/utils/oj/vllm_killer.py b/loopai/agents/Judger/utils/oj/vllm_killer.py
index 9d1ae36..dd70827 100644
--- a/loopai/agents/Judger/utils/oj/vllm_killer.py
+++ b/loopai/agents/Judger/utils/oj/vllm_killer.py
@@ -103,9 +103,10 @@ def kill_vllm_openai_api_server(
logger.info("强制终止vllm进程成功")
break
- # 超出等待时间,抛出异常
+ # 超出等待时间,退出循环
if time.time() - start_kill_time > process_kill_wait:
- raise Exception(f"旧vllm进程终止超时,超过{process_kill_wait}秒仍未终止")
+ logger.warning(f"旧vllm进程终止超时,超过{process_kill_wait}秒仍未终止")
+ break
# 等待一段时间后重试校验
time.sleep(1.0)
diff --git a/loopai/agents/Judger/utils/oj/vllm_starter.py b/loopai/agents/Judger/utils/oj/vllm_starter.py
index 22d0af5..f199ff9 100644
--- a/loopai/agents/Judger/utils/oj/vllm_starter.py
+++ b/loopai/agents/Judger/utils/oj/vllm_starter.py
@@ -198,13 +198,13 @@ def stop_vllm_server(proc: subprocess.Popen, stop_event: threading.Event):
except Exception as e:
logger.error(f"停止vllm失败: {e}")
-# ========== 调用示例(和你原有调用方式几乎一致) ==========
+# ========== 调用示例 ==========
# if __name__ == "__main__":
# try:
# # 替换为你的环境路径
# ENV_PYTHON_PATH = "/root/miniconda3/envs/brjl/bin/python"
-# # 调用函数(仅返回值多了stop_event,其他参数完全不变)
+# # 调用函数
# vllm_proc, stop_event = start_vllm_openai_api_server(
# env_configs='{"CUDA_VISIBLE_DEVICES": "0","NCCL_P2P_DISABLE": "1","NCCL_IB_DISABLE": "1","NCCL_DEBUG": "INFO","NCCL_SOCKET_IFNAME": "lo","NCCL_BLOCKING_WAIT": "1"}',
# vllm_env_path=ENV_PYTHON_PATH,
diff --git a/loopai/schema/states.py b/loopai/schema/states.py
index d2eea95..a10ec95 100644
--- a/loopai/schema/states.py
+++ b/loopai/schema/states.py
@@ -1010,12 +1010,12 @@ class JudgerState(BaseModel):
# description="评估模型 Base URL,未设置或为空的时候,将会尝试通过本地开启vllm",
# json_schema_extra={"ui_type": "text", "ui_group": "评估模型"}
#)
- eval_api_key: str = Field(
- default="EMPTY",
- title="评估模型 API Key",
- description="评估模型 API Key",
- json_schema_extra={"ui_type": "password", "ui_group": "评估模型"}
- )
+ #eval_api_key: str = Field(
+ # default="EMPTY",
+ # title="评估模型 API Key",
+ # description="评估模型 API Key",
+ # json_schema_extra={"ui_type": "password", "ui_group": "评估模型"}
+ #)
eval_temperature: float = Field(
default=0,
title="评估模型温度",
@@ -1127,13 +1127,14 @@ class JudgerState(BaseModel):
description="通用文本评测使用的评测集名称",
json_schema_extra={"ui_type": "text", "ui_group": "评估模型"}
)
- # ===== 通用文本 / DataFlow Eval =====
cuda_visible_devices: str = Field(
default="0",
- title="通用文本可见GPU编号",
- description="通用文本任务指定运行GPU",
+ title="可见GPU编号",
+ description="评测任务指定运行GPU",
json_schema_extra={"ui_type": "text", "ui_group": "评估模型"}
)
+ # ===== 通用文本 / DataFlow Eval =====
+
# is_api: bool = Field(
# default=False,
# title="是否 API 模式",
@@ -1159,18 +1160,18 @@ class JudgerState(BaseModel):
description="DataFlow 评测字段映射,如 input_question_key / input_target_key / input_pred_key",
json_schema_extra={"ui_type": "json_viewer", "ui_group": "评估模型"}
)
- skip_dataflow_eval: bool = Field(
- default=False,
- title="跳过 DataFlow 正式评测",
- description="为 True 时仅准备 bench / records,不调用 DataFlowEvalTool.run_eval",
- json_schema_extra={"ui_type": "toggle_switch", "ui_group": "评估模型"}
- )
- output_dir: str = Field(
- default="",
- title="通用文本输出路径",
- description="通用文本任务结束后输出路径",
- json_schema_ectra={"ui_type": "text", "ui_group": "评估模型"}
- )
+ #skip_dataflow_eval: bool = Field(
+ # default=False,
+ # title="跳过 DataFlow 正式评测",
+ # description="为 True 时仅准备 bench / records,不调用 DataFlowEvalTool.run_eval",
+ # json_schema_extra={"ui_type": "toggle_switch", "ui_group": "评估模型"}
+ #)
+ #output_dir: str = Field(
+ # default="",
+ # title="通用文本输出路径",
+ # description="通用文本任务结束后输出路径",
+    #    json_schema_extra={"ui_type": "text", "ui_group": "评估模型"}
+ #)
class AnalyzerState(BaseModel):
diff --git a/tutorial/docs/guide/details/judger-agent.md b/tutorial/docs/guide/details/judger-agent.md
index 74ef060..3bc732b 100644
--- a/tutorial/docs/guide/details/judger-agent.md
+++ b/tutorial/docs/guide/details/judger-agent.md
@@ -10,12 +10,61 @@
## 进入它之前通常要准备什么
-常见前置条件包括:
+通用前置条件包括:
+| 字段名 | 类型 | 说明 |
+| --- | :---: | --- |
+| `eval_task_type` | str | 评测任务类型。目前支持的任务类型有代码生成(code), Text2sql(text2sql), 通用领域文本评估(general_text),配置时需从`code`、`text2sql`、`general_text`中选择,默认为`code` |
+| `eval_model_path` | str | 被评测的模型路径 |
+| `eval_temperature` | double | 模型温度参数 |
+| `eval_top_p` | double | 顶部概率质量采样的累积概率阈值 |
+| `eval_problem_path` | str | 格式为`jsonl`的问题集文件路径(相关字段要求请查阅"评测数据集相关字段"部分) |
+| `eval_vllm_tensor_parallel_size` | int | 张量并行大小,默认为1 |
+| `eval_vllm_gpu_memory_utilization` | double | GPU显存利用率。 |
+| `cuda_visible_devices` | str | 评测任务指定的GPU编号,比如`0,1`。该参数默认为`0` |
+
+code任务前置条件:
+| 字段名 | 类型 | 说明 |
+| --- | :---: | --- |
+| `eval_batch_size` | int | 生成样例批处理大小 |
+| `eval_case_num` | int | 问题集中每条问题生成样例的数量 |
+
+
+text2sql任务前置条件:
+| 字段名 | 类型 | 说明 |
+| --- | :---: | --- |
+| `eval_batch_size` | int | 生成样例批处理大小 |
+| `eval_case_num` | int | 问题集中每条问题生成样例的数量 |
+| `eval_text2sql_dir` | int | 数据库文件夹路径,如`path/to/your/database/` |
+
+
+general_text任务前置条件:
+| 字段名 | 类型 | 说明 |
+| --- | :---: | --- |
+| `bench_dataflow_eval_type` | str | 通用文本任务类型,例如 key2_qa / key1_text_score |
+| `key_mapping` | str | 问题集的json格式的字段映射,若用户未设置则从问题集中自动识别,如 {"input_question_key": "question","input_target_key": "answer","input_pred_key": "generated_ans"} |
+
+-----------------
+评测数据集相关字段:
+
+- code任务:
+
+    | 字段名 | 含义 | 说明 |
+ | --- | :---: | --- |
+ | task_id | 题目编号 | `问题集名/序号`或者`序号` |
+ | prompt | 问题提示词 | 函数定义+问题描述提示(以多行注释形式写在函数定义下),为了减少处理过程,需模型生成的结果为完整函数。如`def return1():\n \"\"\"This function has no input parameters, and your task is to make it return the integer 1.\n \"\"\"` |
+ | entry_point | 评测函数入口点 | 如`return1` |
+ | canonical_solution | 标准程序 | 如`def return1():\n return 1`。需要完整的代码(包含函数定义部分) |
+ | test_list | 测试用例列表 | 如`["assert return1() == 1"]`。需要为测试用例列表,其中的函数名需要和`entry_point`一致 |
+
+- text2sql任务:
+    | 字段名 | 含义 | 说明 |
+ | --- | :---: | --- |
+ | task_id | 题目编号 | `问题集名/序号`或者`序号` |
+ | prompt | 问题提示词 | |
+    | db_id | 数据库名称 | 如值为`dbName`,则`dbName.sqlite`数据库文件应在`{judger.eval_text2sql_dir}/dbName`目录中 |
+ | question | 问题内容 | 如`What is the highest eligible free rate for K-12 students in the schools in Alameda County?` |
+ | ground_truth | 标准答案 | 如`SELECT 'Free Meal Count (K-12)' / 'Enrollment (K-12)' FROM frpm WHERE 'County Name' = 'Alameda' ORDER BY (CAST('Free Meal Count (K-12)' AS REAL) / 'Enrollment (K-12)') DESC LIMIT 1` |
-- `judger.eval_task_type`
-- `judger.eval_model_path`
-- `judger.eval_base_url` 或本地 vLLM 相关配置
-- 评测数据集、任务配置或评测脚本相关字段
## 它的输入和输出可以怎么理解
@@ -27,9 +76,9 @@
输出通常是:
-- 样例输出
-- 评测结果
-- 可供 Analyzer 继续使用的结果路径
+- 样例输出:`code`和`text2sql`任务输出路径为`judger.output_case_path`;`general_text`任务输出路径为`judger.output_pred_path`。
+- 评测结果:输出路径为`judger.output_result_path`
+- 上述输出路径都是可供 Analyzer 继续使用的结果路径
## 在闭环中的位置