Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 56 additions & 12 deletions loopai/agents/Judger/judger_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@
from loopai.schema.events import StreamEvent

from loopai.logger import get_logger
from loopai.common.prompts import PromptLoader

logger = get_logger()

On = False
Vllm_Start_Error = False

def _isNotNone(value):
return value != "" and value is not None

Expand All @@ -51,7 +56,7 @@ def extract_num(cp: str) -> int:
# 使用 min 选择:先按距离排序,再按数值本身排序
best_cp = min(checkpoints, key=lambda cp: (abs(extract_num(cp) - best_step), extract_num(cp)))

return best_cp
return best_cp

class JudgerAgent(BaseAgent):
@property
Expand All @@ -74,7 +79,7 @@ def get_check_required_fields_node(self):
def check_required_fields(state: LoopAIState, runtime: Runtime[RuntimeContext]):
writer = get_stream_writer()
required_fields = {
'judger':["eval_api_key", "eval_temperature",
'judger':["eval_temperature",
"eval_top_p", "eval_problem_path",
"eval_case_num", "eval_task_type"
],
Expand All @@ -96,8 +101,7 @@ def check_required_fields(state: LoopAIState, runtime: Runtime[RuntimeContext]):
missing_fields = get_missing_fields({'judger':["eval_model_path"]}, state)

"""vllm启动检查"""
# base_url = state.get("judger", {}).get("eval_base_url", None)
base_url = None
base_url = state.get("judger", {}).get("eval_base_url", None)
task_type = state.get("judger", {}).get("eval_task_type", "")
logger.info(f"base_url:{base_url}")
if(_isNotNone(base_url) and task_type!="general_text"):
Expand Down Expand Up @@ -178,18 +182,17 @@ def check_required_fields(state: LoopAIState, runtime: Runtime[RuntimeContext]):
check_file_fields = True

if task_type == "code" or task_type =="text2sql":
check_file_fields = check_jsonl_fields(state.get("judger", {}).get("eval_problem_path", ""), required_fields)
check_file_fields,error_details = check_jsonl_fields(state.get("judger", {}).get("eval_problem_path", ""), required_fields)

if check_file_fields is not True:
logger.info("$"*50)
logger.info(["eval_problem_path"])
state['exception'] = 'ConfigerError'
state['next_to'] = 'config_node'
state['automated_query'] = self.prompt_loader(
"automated_query", "judger_missing_fields_prompt")
state.setdefault('configer',{})['configer_error'] = f'Wrong required fields: {json.dumps({"wrong_fields": "eval_problem_path"}, ensure_ascii=False)}'
state.setdefault('configer',{})['configer_error'] = json.dumps(error_details, ensure_ascii=False,indent=4)
goto_node = runtime.context['exception_navigate']
logger.info(f'found wrong fields, required fields missing in the file, goto {goto_node}')
logger.warning(f'found wrong fields, required fields missing in the file, goto {goto_node}')
return Command(
update=state,
goto=goto_node,
Expand Down Expand Up @@ -240,6 +243,7 @@ def vllm_kill_node(state: LoopAIState) -> LoopAIState:
# vllm_port = state.get("judger", {}).get("eval_vllm_port", DEFAULT_VLLM_PORT)
# base_url = state.get("judger", {}).get("eval_base_url", None)
# base_url = None
global On
writer = get_stream_writer()
# 未设置base_url才会进入本地开启程序才会先关闭本地的vllm服务
logger.info("=== 准备关闭 vllm ===")
Expand Down Expand Up @@ -280,11 +284,13 @@ def vllm_kill_node(state: LoopAIState) -> LoopAIState:
# message="vllm开启结束",
# data={"msg": f"因已开启自定义vllm服务而跳过该过程"}
# ).json())
state["judger"]["eval_base_url"] = None
On = not On
return state

@staticmethod
@BaseAgent.set_current
def vllm_start_node(state: LoopAIState) -> LoopAIState:
def vllm_start_node(state: LoopAIState) -> Union[LoopAIState, Command]:
# 设置 cuda_visible_devices
set_gpu(state)

Expand Down Expand Up @@ -325,8 +331,17 @@ def vllm_start_node(state: LoopAIState) -> LoopAIState:
message="vllm 启动异常",
data={"msg": f"vllm 启动异常 ,请解决后重新评测:{e}"}
).json())
# 直接结束 Judger 当前节点,不再跳转父图异常路由
return state
state['exception'] = 'ConfigerError'
state['next_to'] = 'config_node'
state['automated_query'] = "<automated_query>在上一轮执行中,`judger_agent` 因参数缺失而中断并跳转至 `configer_agent`。现在用户已经补全了缺失的参数。请询问用户是否需要继续执行 `judge` 操作。</automated_query>"
state.setdefault('configer',{})['configer_error'] = f'vllm 启动异常,请检查配置是否正确'
goto_node = END
logger.warning(f'vllm start error, goto {goto_node}')
return Command(
update=state,
goto=goto_node,
graph=Command.PARENT
)
else:
if writer:
writer(StreamEvent(
Expand All @@ -337,6 +352,26 @@ def vllm_start_node(state: LoopAIState) -> LoopAIState:
).json())
return state

#@staticmethod
#@BaseAgent.set_current
#def vllm_start_next(state: LoopAIState) -> str:
# """
# 如果启动 vllm 发生异常直接 to_end
# """
# global Vllm_Start_Error
# # Vllm启动异常
# if not Vllm_Start_Error:
# # 正常启动
# return "to_data_format"
# else:
# # 异常则直接结束
# Vllm_Start_Error = False
# state['judger']['eval_base_url'] = None
# logger.warning("==== 发生异常结束 ====")
# logger.info("==== 结束 ====")
# logger.info(f"=== eval_base_url: {state['judger']['eval_base_url']}")
# return "to_end"

@staticmethod
@BaseAgent.set_current
def data_format_node(state: LoopAIState) -> LoopAIState:
Expand Down Expand Up @@ -469,7 +504,7 @@ def vllm_kill_next(state: LoopAIState) -> str:
"""
base_url = state.get("judger", {}).get("eval_base_url", None)
# 判断是否为空,为空则开启本地
if not _isNotNone(base_url):
if not _isNotNone(base_url) and On:
return "to_vllm_start"
else:
state['judger']['eval_base_url'] = None
Expand All @@ -496,6 +531,15 @@ def init_graph(self, **kwargs):
builder.set_entry_point("check_required_fields")
builder.set_finish_point("eval_general_text")

#builder.add_conditional_edges(
# source="vllm_start",
# path=self.vllm_start_next,
# path_map={
# "to_data_format":"data_format",
# "to_end": END,
# }
#)

builder.add_conditional_edges(
source="check_param_type", # 来源节点(从哪个节点跳转出来)
path=self.check_param_type_next, # 路由函数(执行条件判断,返回路由键)
Expand Down
2 changes: 1 addition & 1 deletion loopai/agents/Judger/nodes/eval_general_text_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def _build_model_config(cfg: Dict[str, Any]) -> ModelConfig:
model_name_or_path=model_name_or_path,
is_api=is_api,
api_url=api_url,
api_key=cfg.get("eval_api_key",""),
api_key=cfg.get("eval_api_key","EMPTY"),
temperature=float(cfg.get("eval_temperature", 0.0)),
top_p=float(cfg.get("eval_top_p", 1.0)),
tensor_parallel_size=int(cfg.get("eval_vllm_tensor_parallel_size", 1)),
Expand Down
43 changes: 26 additions & 17 deletions loopai/agents/Judger/utils/oj/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
def read_problems(evalset_file) -> Dict[str, Dict]:
    """Index every task in *evalset_file* (a JSONL evaluation set) by its "task_id".

    Later lines with a duplicate "task_id" overwrite earlier ones, exactly as
    the dict comprehension it replaces did.
    """
    problems: Dict[str, Dict] = {}
    for task in stream_jsonl(evalset_file):
        problems[task["task_id"]] = task
    return problems

import json
from typing import List

def check_jsonl_fields(filepath: str, require_fields: List[str]) -> bool:
def check_jsonl_fields(filepath: str, require_fields: List[str]) -> Tuple[bool, List[str]]:
    """
    检查 JSONL 文件中的每一行是否都包含指定的字段。

    Validate that every non-blank line of a JSONL file decodes as UTF-8,
    parses as JSON, and contains all required fields. Problems are collected
    per line instead of aborting on the first failure, so the caller can
    report every issue at once.

    参数:
        filepath: path to the JSONL file to validate.
        require_fields: field names every JSON record must contain.

    返回:
        是否全部通过校验, 缺失字段或错误的详细信息列表
        (is_valid, error_details) — is_valid is True iff error_details is empty.
    """
    error_details: List[str] = []
    try:
        # Read raw bytes so one undecodable line is reported individually
        # instead of aborting the whole scan with a UnicodeDecodeError.
        with open(filepath, 'rb') as f:
            for line_num, line_bytes in enumerate(f, start=1):
                try:
                    # 去除首尾空白并解码
                    line = line_bytes.decode('utf-8').strip()
                except UnicodeDecodeError:
                    error_details.append(f"第 {line_num} 行: 无法解码的字符编码")
                    continue

                if not line:  # 跳过空行
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    error_details.append(f"第 {line_num} 行: JSON 解析失败 ({e})")
                    continue

                # 检查每个必需字段
                for field in require_fields:
                    if field not in data:
                        error_details.append(f"第 {line_num} 行: 缺少必填字段 '{field}'")

    except FileNotFoundError:
        return False, [f"文件未找到: {filepath}"]

    except Exception as e:
        return False, [f"读取文件时发生未知错误: {e}"]

    # 如果错误列表为空,说明校验通过
    is_valid = len(error_details) == 0
    return is_valid, error_details

def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
Expand Down
6 changes: 3 additions & 3 deletions loopai/agents/Judger/utils/oj/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ def generate_sample_code(state):
model = init_model(
model_path=judger_state['eval_model_path'],
base_url=judger_state['eval_base_url'],
api_key=judger_state['eval_api_key'],
api_key=judger_state.get("eval_api_key","EMPTY"),
temperature=judger_state['eval_temperature'],
top_p=judger_state['eval_top_p']
)
logger.info(f"模型路径:-> base_url: {judger_state['eval_base_url']} ->api_key: {judger_state['eval_api_key']}")
logger.info(f"模型路径:-> base_url: {judger_state['eval_base_url']} ->api_key: {judger_state.get("eval_api_key","EMPTY")}")

output_dir = Path(state.get("output_dir"))
problem_path = judger_state['eval_problem_path']
Expand Down Expand Up @@ -124,7 +124,7 @@ def generate_sample_text2sql(state):
model = init_model(
model_path=judger_state['eval_model_path'],
base_url=judger_state['eval_base_url'],
api_key=judger_state['eval_api_key'],
api_key='EMPTY',
temperature=judger_state['eval_temperature'],
top_p=judger_state['eval_top_p']
)
Expand Down
5 changes: 3 additions & 2 deletions loopai/agents/Judger/utils/oj/vllm_killer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ def kill_vllm_openai_api_server(
logger.info("强制终止vllm进程成功")
break

# 超出等待时间,抛出异常
# 超出等待时间,退出循环
if time.time() - start_kill_time > process_kill_wait:
raise Exception(f"旧vllm进程终止超时,超过{process_kill_wait}秒仍未终止")
logger.warning(f"旧vllm进程终止超时,超过{process_kill_wait}秒仍未终止")
break

# 等待一段时间后重试校验
time.sleep(1.0)
Expand Down
4 changes: 2 additions & 2 deletions loopai/agents/Judger/utils/oj/vllm_starter.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,13 @@ def stop_vllm_server(proc: subprocess.Popen, stop_event: threading.Event):
except Exception as e:
logger.error(f"停止vllm失败: {e}")

# ========== 调用示例(和你原有调用方式几乎一致) ==========
# ========== 调用示例 ==========
# if __name__ == "__main__":
# try:
# # 替换为你的环境路径
# ENV_PYTHON_PATH = "/root/miniconda3/envs/brjl/bin/python"

# # 调用函数(仅返回值多了stop_event,其他参数完全不变)
# # 调用函数
# vllm_proc, stop_event = start_vllm_openai_api_server(
# env_configs='{"CUDA_VISIBLE_DEVICES": "0","NCCL_P2P_DISABLE": "1","NCCL_IB_DISABLE": "1","NCCL_DEBUG": "INFO","NCCL_SOCKET_IFNAME": "lo","NCCL_BLOCKING_WAIT": "1"}',
# vllm_env_path=ENV_PYTHON_PATH,
Expand Down
43 changes: 22 additions & 21 deletions loopai/schema/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,12 +1010,12 @@ class JudgerState(BaseModel):
# description="评估模型 Base URL,未设置或为空的时候,将会尝试通过本地开启vllm",
# json_schema_extra={"ui_type": "text", "ui_group": "评估模型"}
#)
eval_api_key: str = Field(
default="EMPTY",
title="评估模型 API Key",
description="评估模型 API Key",
json_schema_extra={"ui_type": "password", "ui_group": "评估模型"}
)
#eval_api_key: str = Field(
# default="EMPTY",
# title="评估模型 API Key",
# description="评估模型 API Key",
# json_schema_extra={"ui_type": "password", "ui_group": "评估模型"}
#)
eval_temperature: float = Field(
default=0,
title="评估模型温度",
Expand Down Expand Up @@ -1127,13 +1127,14 @@ class JudgerState(BaseModel):
description="通用文本评测使用的评测集名称",
json_schema_extra={"ui_type": "text", "ui_group": "评估模型"}
)
# ===== 通用文本 / DataFlow Eval =====
cuda_visible_devices: str = Field(
default="0",
title="通用文本可见GPU编号",
description="通用文本任务指定运行GPU",
title="可见GPU编号",
description="评测任务指定运行GPU",
json_schema_extra={"ui_type": "text", "ui_group": "评估模型"}
)
# ===== 通用文本 / DataFlow Eval =====

# is_api: bool = Field(
# default=False,
# title="是否 API 模式",
Expand All @@ -1159,18 +1160,18 @@ class JudgerState(BaseModel):
description="DataFlow 评测字段映射,如 input_question_key / input_target_key / input_pred_key",
json_schema_extra={"ui_type": "json_viewer", "ui_group": "评估模型"}
)
skip_dataflow_eval: bool = Field(
default=False,
title="跳过 DataFlow 正式评测",
description="为 True 时仅准备 bench / records,不调用 DataFlowEvalTool.run_eval",
json_schema_extra={"ui_type": "toggle_switch", "ui_group": "评估模型"}
)
output_dir: str = Field(
default="",
title="通用文本输出路径",
description="通用文本任务结束后输出路径",
json_schema_ectra={"ui_type": "text", "ui_group": "评估模型"}
)
#skip_dataflow_eval: bool = Field(
# default=False,
# title="跳过 DataFlow 正式评测",
# description="为 True 时仅准备 bench / records,不调用 DataFlowEvalTool.run_eval",
# json_schema_extra={"ui_type": "toggle_switch", "ui_group": "评估模型"}
#)
#output_dir: str = Field(
# default="",
# title="通用文本输出路径",
# description="通用文本任务结束后输出路径",
# json_schema_ectra={"ui_type": "text", "ui_group": "评估模型"}
#)


class AnalyzerState(BaseModel):
Expand Down
Loading