Revised patch: route task error responses through TaskResponse.build_error_response

Review notes (this text precedes the first "diff --git" header and is skipped
by git apply / patch):
  * The reviewed copy of this patch had its line breaks and indentation
    collapsed. Hunk line counts below match the @@ headers, but leading
    whitespace and the position of blank context lines are reconstructed;
    verify them against the tree before applying.
  * core/task.py: build_error_response no longer raises TypeError when msg is
    None (handler/task.py forwards task_item.msg unchecked). The hunk header
    count changes from 42 to 44.
  * runners/event_runner.py: the hunk that raised the full trajectory JSON dump
    from logger.debug to logger.warn is dropped; it was unrelated to error
    formatting and would flood WARNING-level logs.
  * runners/state_manager.py: getattr(result.payload, "msg") gets a None
    default so a payload without .msg no longer raises AttributeError.
  * The post-image blob ids on the "index" lines of those three files are
    stale because their content changed.
  * dataset/trajectory_strategy.py adds "import traceback", but no hunk in
    this patch uses it; confirm it is needed.

diff --git a/aworld/agents/llm_agent.py b/aworld/agents/llm_agent.py
index 4109ca54f..d2eb236a8 100644
--- a/aworld/agents/llm_agent.py
+++ b/aworld/agents/llm_agent.py
@@ -596,7 +596,8 @@ async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}
                     output_message = Message(
                         category=Constants.OUTPUT,
                         payload=Output(
-                            data=f"llm result error: {llm_response.error}"
+                            data="",
+                            metadata={"error": f"llm result error: {llm_response.error}"}
                         ),
                         sender=self.id(),
                         session_id=message.context.session_id if message.context else "",
@@ -827,7 +828,8 @@ async def invoke_model(self,
                 await send_message(Message(
                     category=Constants.OUTPUT,
                     payload=Output(
-                        data=f"Failed to call llm model: {e}"
+                        data="",
+                        metadata={"error": f"Failed to call llm model: {str(e)}"},
                     ),
                     sender=self.id(),
                     session_id=message.context.session_id if message.context else "",
diff --git a/aworld/core/task.py b/aworld/core/task.py
index e0ce924de..8525269db 100644
--- a/aworld/core/task.py
+++ b/aworld/core/task.py
@@ -60,6 +60,8 @@ class Task:
     task_status: TaskStatus = field(default=TaskStatusValue.INIT)
     # streaming support
     streaming_mode: StreamingMode = field(default=None)
+    # custom error formatter for error responses
+    error_formatter: Optional[Callable[[str], str]] = field(default=None, repr=False)
 
     def to_dict(self) -> Dict[str, Any]:
         """Serialize Task to dict while excluding parent_task to avoid recursion.
@@ -112,6 +114,44 @@ class TaskResponse:
     # task final status, e.g. success/failed/cancelled
     status: TaskStatus | None = field(default=TaskStatusValue.SUCCESS)
 
+    @classmethod
+    def build_error_response(cls, task_id: str, msg: str, status: str = TaskStatusValue.FAILED, context: Context = None,
+                             time_cost: float = 0.0, usage: Dict[str, Any] = None,
+                             error_formatter: Optional[Callable[[str], str]] = None) -> 'TaskResponse':
+        """Build an error response with customizable error formatting.
+
+        Args:
+            task_id: The task ID
+            msg: Error message; None is formatted as an empty string
+            status: Task status, defaults to FAILED
+            context: Task context
+            time_cost: Time cost in seconds
+            usage: Usage statistics
+            error_formatter: Optional callable to format error message.
+                If None, uses default '__AWORLD_ERROR__: ' prefix.
+                Example: lambda msg: f"ERROR: {msg}"
+
+        Returns:
+            TaskResponse: Error response object
+        """
+        # Callers forward task_item.msg unchecked; do not raise TypeError on None.
+        error_text = "" if msg is None else str(msg)
+        if error_formatter is None:
+            formatted_answer = '__AWORLD_ERROR__: ' + error_text
+        else:
+            formatted_answer = error_formatter(error_text)
+
+        return cls(
+            answer=formatted_answer,
+            success=False,
+            context=context,
+            id=task_id,
+            time_cost=time_cost,
+            usage=usage,
+            msg=msg,
+            status=status
+        )
+
 
 class Runner(object):
     __metaclass__ = abc.ABCMeta
diff --git a/aworld/dataset/trajectory_strategy.py b/aworld/dataset/trajectory_strategy.py
index 654ccd2e3..4f2a84abf 100644
--- a/aworld/dataset/trajectory_strategy.py
+++ b/aworld/dataset/trajectory_strategy.py
@@ -9,6 +9,7 @@
 
 import abc
 import json
+import traceback
 from typing import Any, Dict, List, Optional, TYPE_CHECKING
 
 from aworld.core.agent.base import AgentFactory
@@ -21,6 +22,7 @@
     TrajectoryReward,
     ExpMeta,
 )
+from aworld.runners.state_manager import RunNodeStatus
 
 if TYPE_CHECKING:
     from aworld.core.agent.swarm import Swarm
@@ -196,9 +198,15 @@ async def build_trajectory_action(self, source: Any, **kwargs) -> Optional[Traje
         node = state_manager._find_node(source.id) if state_manager else None
         agent_results = []
         ext_info = {}
+        status = RunNodeStatus.SUCCESS.value
+        err_msg = ""
         if node and node.results:
             for handle_result in node.results:
                 result = handle_result.result
+                result_status = handle_result.status
+                if result_status == RunNodeStatus.FAILED or result_status == RunNodeStatus.TIMEOUT:
+                    status = result_status.value
+                    err_msg += f"{node.msg_id}:{handle_result.result_msg}\n"
                 if isinstance(result, Message) and isinstance(result.payload, list):
                     agent_results.extend(result.payload)
                 else:
@@ -234,7 +242,8 @@ def _get_attr_from_action(obj, attr, default=None):
         if not agent_name:
             agent_name = source.receiver
         agent = AgentFactory.agent_instance(agent_name)
-        action = TrajectoryAction(content=action_content, tool_calls=tool_calls, is_agent_finished=agent.finished)
+        action = TrajectoryAction(content=action_content, tool_calls=tool_calls, is_agent_finished=agent.finished,
+                                  status=status, msg=err_msg)
         return action
 
     async def build_trajectory_reward(self, source: Any, **kwargs) -> Optional[TrajectoryReward]:
diff --git a/aworld/dataset/types.py b/aworld/dataset/types.py
index a65d6ec01..f1a24644a 100644
--- a/aworld/dataset/types.py
+++ b/aworld/dataset/types.py
@@ -42,6 +42,8 @@ class TrajectoryAction(BaseModel):
     content: Optional[str] = Field(default=None, description="Assistant message content")
     tool_calls: List[Dict[str, Any]] = Field(default_factory=list, description="Tool calls")
     is_agent_finished: bool = Field(default=False, description="Is agent finished")
+    status: Optional[str] = Field(default=None, description="Execution status")
+    msg: Optional[str] = Field(default=None, description="Execution error message")
     ext_info: Dict[str, Any] = Field(default_factory=dict, description="Extra information")
 
 class TrajectoryReward(BaseModel):
diff --git a/aworld/runners/event_runner.py b/aworld/runners/event_runner.py
index fc7baf120..d692b40a3 100644
--- a/aworld/runners/event_runner.py
+++ b/aworld/runners/event_runner.py
@@ -321,10 +321,16 @@ async def _do_run(self):
                                                        context=message.context,
                                                        success=True if not msg else False,
                                                        id=self.task.id,
-                                                       time_cost=(
-                                                               time.time() - start),
+                                                       time_cost=(time.time() - start),
                                                        usage=self.context.token_usage,
-                                                       status=TaskStatusValue.SUCCESS if not msg else TaskStatusValue.FAILED)
+                                                       status=TaskStatusValue.SUCCESS)
+                    if msg:
+                        self._task_response = TaskResponse.build_error_response(
+                            task_id=self.task.id, msg=msg,
+                            context=message.context if message else self.context,
+                            time_cost=(time.time() - start),
+                            usage=self.context.token_usage,
+                            error_formatter=self.task.error_formatter)
                     break
                 logger.debug(f"{task_flag} task {self.task.id} next message snap")
                 # consume message
@@ -395,9 +401,11 @@ def _response(self):
         if self.context.get_task().conf and self.context.get_task().conf.resp_carry_context == False:
             self._task_response.context = None
         if self._task_response is None:
-            self._task_response = TaskResponse(id=self.context.task_id if self.context else "",
-                                               success=False,
-                                               msg="Task return None.")
+            self._task_response = TaskResponse.build_error_response(
+                task_id=self.context.task_id if self.context else "",
+                msg="Task return None.",
+                error_formatter=self.task.error_formatter
+            )
         if self.context.get_task().conf and self.context.get_task().conf.resp_carry_raw_llm_resp == True:
             self._task_response.raw_llm_resp = self.context.context_info.get('llm_output')
         self._task_response.trace_id = get_trace_id()
@@ -423,15 +431,14 @@ async def should_stop_task(self, message: Message):
         if 0 < self.task.timeout < time_cost:
             logger.warn(
                 f"{task_flag} task {self.task.id} timeout after {time_cost} seconds.")
-            self._task_response = TaskResponse(
-                answer='',
-                success=False,
+            self._task_response = TaskResponse.build_error_response(
+                task_id=self.task.id,
+                msg=f'Task timeout after {time_cost} seconds.',
                 context=message.context if message else self.context,
-                id=self.task.id,
                 time_cost=(time.time() - self.start_time),
                 usage=self.context.token_usage,
-                msg=f'Task timeout after {time_cost} seconds.',
-                status=TaskStatusValue.TIMEOUT
+                status=TaskStatusValue.TIMEOUT,
+                error_formatter=self.task.error_formatter
             )
             await self.context.update_task_status(self.task.id, TaskStatusValue.TIMEOUT)
             return True
@@ -440,15 +447,14 @@ async def should_stop_task(self, message: Message):
         task_status = await self.context.get_task_status()
         if task_status == TaskStatusValue.INTERRUPTED or task_status == TaskStatusValue.CANCELLED:
             logger.warn(f"{task_flag} task {self.task.id} is {task_status}.")
-            self._task_response = TaskResponse(
-                answer='',
-                success=False,
+            self._task_response = TaskResponse.build_error_response(
+                task_id=self.task.id,
+                msg=f'Task is {task_status}.',
                 context=message.context if message else self.context,
-                id=self.task.id,
                 time_cost=time_cost,
                 usage=self.context.token_usage,
-                msg=f'Task is {task_status}.',
-                status=task_status
+                status=task_status,
+                error_formatter=self.task.error_formatter
             )
             return True
         return False
diff --git a/aworld/runners/handler/task.py b/aworld/runners/handler/task.py
index f58bb5726..2fbac9a0f 100644
--- a/aworld/runners/handler/task.py
+++ b/aworld/runners/handler/task.py
@@ -73,13 +73,14 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
                 yield event
 
             logger.warning(f"{task_flag} task {self.runner.task.id} stop, cause: {task_item.msg}")
-            self.runner._task_response = TaskResponse(msg=task_item.msg,
-                                                      answer='',
-                                                      context=message.context,
-                                                      success=False,
-                                                      id=self.runner.task.id,
-                                                      time_cost=(time.time() - self.runner.start_time),
-                                                      usage=self.runner.context.token_usage)
+            self.runner._task_response = TaskResponse.build_error_response(
+                task_id=self.runner.task.id, msg=task_item.msg,
+                context=message.context,
+                time_cost=(time.time() - self.runner.start_time),
+                usage=self.runner.context.token_usage,
+                error_formatter=self.runner.task.error_formatter
+            )
+
             await self.runner.stop()
             yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE)
         elif topic == TopicType.FINISHED:
@@ -130,14 +131,15 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
             yield Message(session_id=self.runner.context.session_id, sender=self.name(), category='mock',
                           headers={"context": message.context})
             # mark task response as cancelled
-            self.runner._task_response = TaskResponse(answer='',
-                                                      success=False,
-                                                      context=message.context,
-                                                      id=self.runner.task.id,
-                                                      time_cost=(time.time() - self.runner.start_time),
-                                                      usage=self.runner.context.token_usage,
-                                                      msg=f'cancellation message received: {task_item.msg}',
-                                                      status=TaskStatusValue.CANCELLED)
+            self.runner._task_response = TaskResponse.build_error_response(
+                task_id=self.runner.task.id, msg=f'cancellation message received: {task_item.msg}',
+                status=TaskStatusValue.CANCELLED,
+                context=message.context,
+                time_cost=(time.time() - self.runner.start_time),
+                usage=self.runner.context.token_usage,
+                error_formatter=self.runner.task.error_formatter
+            )
+
             await self.runner.stop()
             yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE)
         elif topic == TopicType.INTERRUPT:
@@ -145,14 +147,15 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
             yield Message(session_id=self.runner.context.session_id, sender=self.name(), category='mock',
                           headers={"context": message.context})
             # mark task response as interrupted
-            self.runner._task_response = TaskResponse(answer='',
-                                                      success=False,
-                                                      context=message.context,
-                                                      id=self.runner.task.id,
-                                                      time_cost=(time.time() - self.runner.start_time),
-                                                      usage=self.runner.context.token_usage,
-                                                      msg=f'interruption message received: {task_item.msg}',
-                                                      status=TaskStatusValue.INTERRUPTED)
+            self.runner._task_response = TaskResponse.build_error_response(
+                task_id=self.runner.task.id, msg=f'interruption message received: {task_item.msg}',
+                status=TaskStatusValue.INTERRUPTED,
+                context=message.context,
+                time_cost=(time.time() - self.runner.start_time),
+                usage=self.runner.context.token_usage,
+                error_formatter=self.runner.task.error_formatter
+            )
+
             await self.runner.stop()
             yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE)
 
diff --git a/aworld/runners/state_manager.py b/aworld/runners/state_manager.py
index 8847b6df4..286b9286f 100644
--- a/aworld/runners/state_manager.py
+++ b/aworld/runners/state_manager.py
@@ -806,6 +806,7 @@ def save_message_handle_result(self, name: str, message: Message, result: Messag
                 handle_result = HandleResult(
                     name=name,
                     status=RunNodeStatus.FAILED,
+                    result_msg=getattr(result.payload, "msg", None),
                     result=result)
             else:
                 handle_result = HandleResult(