Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions aworld/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,8 @@ async def async_policy(self, observation: Observation, info: Dict[str, Any] = {}
output_message = Message(
category=Constants.OUTPUT,
payload=Output(
data=f"llm result error: {llm_response.error}"
data="",
metadata={"error": f"llm result error: {llm_response.error}"}
),
sender=self.id(),
session_id=message.context.session_id if message.context else "",
Expand Down Expand Up @@ -827,7 +828,8 @@ async def invoke_model(self,
await send_message(Message(
category=Constants.OUTPUT,
payload=Output(
data=f"Failed to call llm model: {e}"
data="",
metadata={"error": f"Failed to call llm model: {str(e)}"},
),
sender=self.id(),
session_id=message.context.session_id if message.context else "",
Expand Down
38 changes: 38 additions & 0 deletions aworld/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ class Task:
task_status: TaskStatus = field(default=TaskStatusValue.INIT)
# streaming support
streaming_mode: StreamingMode = field(default=None)
# custom error formatter for error responses
error_formatter: Optional[Callable[[str], str]] = field(default=None, repr=False)

def to_dict(self) -> Dict[str, Any]:
"""Serialize Task to dict while excluding parent_task to avoid recursion.
Expand Down Expand Up @@ -112,6 +114,42 @@ class TaskResponse:
# task final status, e.g. success/failed/cancelled
status: TaskStatus | None = field(default=TaskStatusValue.SUCCESS)

@classmethod
def build_error_response(cls, task_id: str, msg: str, status: str = TaskStatusValue.FAILED, context: Context = None,
                         time_cost: float = 0.0, usage: Dict[str, Any] = None,
                         error_formatter: Optional[Callable[[str], str]] = None) -> 'TaskResponse':
    """Build an unsuccessful response whose ``answer`` carries the formatted error.

    Args:
        task_id: The task ID, stored as the response ``id``.
        msg: Error message. Stored unchanged in the response ``msg`` field.
            When building ``answer``, ``None`` is treated as an empty string
            and any other non-string value is converted with ``str``.
        status: Task status, defaults to FAILED.
        context: Task context.
        time_cost: Time cost in seconds.
        usage: Usage statistics.
        error_formatter: Optional callable that receives the error text and
            returns the value used as ``answer``.
            If None, uses the default '__AWORLD_ERROR__: ' prefix.
            Example: lambda msg: f"ERROR: {msg}"

    Returns:
        TaskResponse: Error response object with ``success=False``.
    """
    # Callers forward messages they do not control (e.g. a task item's msg),
    # so the text is normalised here: concatenating None or a non-string onto
    # the prefix would raise TypeError while reporting another failure.
    error_text = "" if msg is None else str(msg)

    if error_formatter is None:
        formatted_answer = '__AWORLD_ERROR__: ' + error_text
    else:
        formatted_answer = error_formatter(error_text)

    return cls(
        answer=formatted_answer,
        success=False,
        context=context,
        id=task_id,
        time_cost=time_cost,
        usage=usage,
        msg=msg,
        status=status
    )


class Runner(object):
__metaclass__ = abc.ABCMeta
Expand Down
11 changes: 10 additions & 1 deletion aworld/dataset/trajectory_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import abc
import json
import traceback
from typing import Any, Dict, List, Optional, TYPE_CHECKING

from aworld.core.agent.base import AgentFactory
Expand All @@ -21,6 +22,7 @@
TrajectoryReward,
ExpMeta,
)
from aworld.runners.state_manager import RunNodeStatus

if TYPE_CHECKING:
from aworld.core.agent.swarm import Swarm
Expand Down Expand Up @@ -196,9 +198,15 @@ async def build_trajectory_action(self, source: Any, **kwargs) -> Optional[Traje
node = state_manager._find_node(source.id) if state_manager else None
agent_results = []
ext_info = {}
status = RunNodeStatus.SUCCESS.value
err_msg = ""
if node and node.results:
for handle_result in node.results:
result = handle_result.result
result_status = handle_result.status
if result_status == RunNodeStatus.FAILED or result_status == RunNodeStatus.TIMEOUT:
status = result_status.value
err_msg += f"{node.msg_id}:{handle_result.result_msg}\n"
if isinstance(result, Message) and isinstance(result.payload, list):
agent_results.extend(result.payload)
else:
Expand Down Expand Up @@ -234,7 +242,8 @@ def _get_attr_from_action(obj, attr, default=None):
if not agent_name:
agent_name = source.receiver
agent = AgentFactory.agent_instance(agent_name)
action = TrajectoryAction(content=action_content, tool_calls=tool_calls, is_agent_finished=agent.finished)
action = TrajectoryAction(content=action_content, tool_calls=tool_calls, is_agent_finished=agent.finished,
status=status, msg=err_msg)
return action

async def build_trajectory_reward(self, source: Any, **kwargs) -> Optional[TrajectoryReward]:
Expand Down
2 changes: 2 additions & 0 deletions aworld/dataset/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class TrajectoryAction(BaseModel):
content: Optional[str] = Field(default=None, description="Assistant message content")
tool_calls: List[Dict[str, Any]] = Field(default_factory=list, description="Tool calls")
is_agent_finished: bool = Field(default=False, description="Is agent finished")
status: Optional[str] = Field(default=None, description="Execution status")
msg: Optional[str] = Field(default=None, description="Execution error message")
ext_info: Dict[str, Any] = Field(default_factory=dict, description="Extra information")

class TrajectoryReward(BaseModel):
Expand Down
44 changes: 25 additions & 19 deletions aworld/runners/event_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,10 +321,16 @@ async def _do_run(self):
context=message.context,
success=True if not msg else False,
id=self.task.id,
time_cost=(
time.time() - start),
time_cost=(time.time() - start),
usage=self.context.token_usage,
status=TaskStatusValue.SUCCESS if not msg else TaskStatusValue.FAILED)
status=TaskStatusValue.SUCCESS)
if msg:
self._task_response = TaskResponse.build_error_response(
task_id=self.task.id, msg=msg,
context=message.context if message else self.context,
time_cost=(time.time() - start),
usage=self.context.token_usage,
error_formatter=self.task.error_formatter)
break
logger.debug(f"{task_flag} task {self.task.id} next message snap")
# consume message
Expand Down Expand Up @@ -395,9 +401,11 @@ def _response(self):
if self.context.get_task().conf and self.context.get_task().conf.resp_carry_context == False:
self._task_response.context = None
if self._task_response is None:
self._task_response = TaskResponse(id=self.context.task_id if self.context else "",
success=False,
msg="Task return None.")
self._task_response = TaskResponse.build_error_response(
task_id=self.context.task_id if self.context else "",
msg="Task return None.",
error_formatter=self.task.error_formatter
)
if self.context.get_task().conf and self.context.get_task().conf.resp_carry_raw_llm_resp == True:
self._task_response.raw_llm_resp = self.context.context_info.get('llm_output')
self._task_response.trace_id = get_trace_id()
Expand All @@ -410,7 +418,7 @@ async def _save_trajectories(self):
logger.debug(f"{self.task.id}|{self.task.is_sub_task}#task_graph from context: {self.context._task_graph}")
if traj:
self._task_response.trajectory = [step.to_dict() for step in traj]
logger.debug(f"{self.task.id}|{self.task.is_sub_task}#_task_response.trajectory: {json.dumps(self._task_response.trajectory, ensure_ascii=False)}")
logger.warn(f"{self.task.id}|{self.task.is_sub_task}#_task_response.trajectory: {json.dumps(self._task_response.trajectory, ensure_ascii=False)}")

except Exception as e:
logger.error(f"Failed to get trajectories: {str(e)}.{traceback.format_exc()}")
Expand All @@ -423,15 +431,14 @@ async def should_stop_task(self, message: Message):
if 0 < self.task.timeout < time_cost:
logger.warn(
f"{task_flag} task {self.task.id} timeout after {time_cost} seconds.")
self._task_response = TaskResponse(
answer='',
success=False,
self._task_response = TaskResponse.build_error_response(
task_id=self.task.id,
msg=f'Task timeout after {time_cost} seconds.',
context=message.context if message else self.context,
id=self.task.id,
time_cost=(time.time() - self.start_time),
usage=self.context.token_usage,
msg=f'Task timeout after {time_cost} seconds.',
status=TaskStatusValue.TIMEOUT
status=TaskStatusValue.TIMEOUT,
error_formatter=self.task.error_formatter
)
await self.context.update_task_status(self.task.id, TaskStatusValue.TIMEOUT)
return True
Expand All @@ -440,15 +447,14 @@ async def should_stop_task(self, message: Message):
task_status = await self.context.get_task_status()
if task_status == TaskStatusValue.INTERRUPTED or task_status == TaskStatusValue.CANCELLED:
logger.warn(f"{task_flag} task {self.task.id} is {task_status}.")
self._task_response = TaskResponse(
answer='',
success=False,
self._task_response = TaskResponse.build_error_response(
task_id=self.task.id,
msg=f'Task is {task_status}.',
context=message.context if message else self.context,
id=self.task.id,
time_cost=time_cost,
usage=self.context.token_usage,
msg=f'Task is {task_status}.',
status=task_status
status=task_status,
error_formatter=self.task.error_formatter
)
return True
return False
49 changes: 26 additions & 23 deletions aworld/runners/handler/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,14 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
yield event

logger.warning(f"{task_flag} task {self.runner.task.id} stop, cause: {task_item.msg}")
self.runner._task_response = TaskResponse(msg=task_item.msg,
answer='',
context=message.context,
success=False,
id=self.runner.task.id,
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage)
self.runner._task_response = TaskResponse.build_error_response(
task_id=self.runner.task.id, msg=task_item.msg,
context=message.context,
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
error_formatter=self.runner.task.error_formatter
)

await self.runner.stop()
yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE)
elif topic == TopicType.FINISHED:
Expand Down Expand Up @@ -130,29 +131,31 @@ async def _do_handle(self, message: Message) -> AsyncGenerator[Message, None]:
yield Message(session_id=self.runner.context.session_id, sender=self.name(), category='mock',
headers={"context": message.context})
# mark task response as cancelled
self.runner._task_response = TaskResponse(answer='',
success=False,
context=message.context,
id=self.runner.task.id,
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
msg=f'cancellation message received: {task_item.msg}',
status=TaskStatusValue.CANCELLED)
self.runner._task_response = TaskResponse.build_error_response(
task_id=self.runner.task.id, msg=f'cancellation message received: {task_item.msg}',
status=TaskStatusValue.CANCELLED,
context=message.context,
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
error_formatter=self.runner.task.error_formatter
)

await self.runner.stop()
yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers, topic=TopicType.TASK_RESPONSE)
elif topic == TopicType.INTERRUPT:
# Avoid waiting to receive events and send a mock event for quick interrupt
yield Message(session_id=self.runner.context.session_id, sender=self.name(), category='mock',
headers={"context": message.context})
# mark task response as interrupted
self.runner._task_response = TaskResponse(answer='',
success=False,
context=message.context,
id=self.runner.task.id,
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
msg=f'interruption message received: {task_item.msg}',
status=TaskStatusValue.INTERRUPTED)
self.runner._task_response = TaskResponse.build_error_response(
task_id=self.runner.task.id, msg=f'interruption message received: {task_item.msg}',
status=TaskStatusValue.INTERRUPTED,
context=message.context,
time_cost=(time.time() - self.runner.start_time),
usage=self.runner.context.token_usage,
error_formatter=self.runner.task.error_formatter
)

await self.runner.stop()
yield Message(payload=self.runner._task_response, session_id=message.session_id, headers=message.headers,
topic=TopicType.TASK_RESPONSE)
Expand Down
1 change: 1 addition & 0 deletions aworld/runners/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -806,6 +806,7 @@ def save_message_handle_result(self, name: str, message: Message, result: Messag
handle_result = HandleResult(
name=name,
status=RunNodeStatus.FAILED,
result_msg=getattr(result.payload, "msg"),
result=result)
else:
handle_result = HandleResult(
Expand Down