diff --git a/backend/app/agent/listen_chat_agent.py b/backend/app/agent/listen_chat_agent.py index 7fbaddc6a..c5ff88a09 100644 --- a/backend/app/agent/listen_chat_agent.py +++ b/backend/app/agent/listen_chat_agent.py @@ -13,14 +13,13 @@ # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import asyncio -import json import logging import threading from collections.abc import Callable from threading import Event from typing import Any -from camel.agents import ChatAgent +from camel.agents import ChatAgent, CloneContext from camel.agents._types import ToolCallRequest from camel.agents.chat_agent import ( AsyncStreamingChatAgentResponse, @@ -28,7 +27,7 @@ ) from camel.memories import AgentMemory from camel.messages import BaseMessage -from camel.models import BaseModelBackend, ModelManager, ModelProcessingError +from camel.models import BaseModelBackend, ModelManager from camel.responses import ChatAgentResponse from camel.terminators import ResponseTerminator from camel.toolkits import FunctionTool, RegisteredAgentToolkit @@ -36,17 +35,10 @@ from camel.types.agents import ToolCallingRecord from pydantic import BaseModel -from app.service.task import ( - Action, - ActionActivateAgentData, - ActionActivateToolkitData, - ActionBudgetNotEnough, - ActionDeactivateAgentData, - ActionDeactivateToolkitData, - get_task_lock, - set_process_task, +from app.agent.listen_chat_agent_callback import ( + ListenChatAgentCallback, ) -from app.utils.event_loop_utils import _schedule_async_task +from app.service.task import get_task_lock, set_process_task # Logger for agent tracking logger = logging.getLogger("agent") @@ -98,6 +90,19 @@ def __init__( step_timeout: float | None = 1800, # 30 minutes **kwargs: Any, ) -> None: + self.api_task_id = api_task_id + self.agent_name = agent_name + self.process_task_id = "" + + self._user_callbacks = list(kwargs.pop("callbacks", []) or []) + self._user_execution_context = dict( + kwargs.pop("execution_context", {}) or {} + ) + self._user_execution_context_provider = kwargs.pop( + "execution_context_provider", None + ) + self._listen_callback = ListenChatAgentCallback(self) + super().__init__( system_message=system_message, model=model, @@ -119,588 +124,112 @@ def __init__( prune_tool_calls_from_memory=prune_tool_calls_from_memory, enable_snapshot_clean=enable_snapshot_clean, step_timeout=step_timeout, + callbacks=[ + *self._user_callbacks, + self._listen_callback, + ], + execution_context={ + **self._build_static_execution_context(), + **self._user_execution_context, + }, + execution_context_provider=self._build_dynamic_execution_context, **kwargs, ) - self.api_task_id = api_task_id - self.agent_name = agent_name - - process_task_id: str = "" - - def _send_agent_deactivate(self, message: str, tokens: int) -> None: - """Send agent deactivation event to the frontend. - - Args: - message: The accumulated message content - tokens: The total token count used - """ - task_lock = get_task_lock(self.api_task_id) - _schedule_async_task( - task_lock.put_queue( - ActionDeactivateAgentData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "agent_id": self.agent_id, - "message": message, - "tokens": tokens, - }, - ) - ) - ) - - @staticmethod - def _extract_tokens(response) -> int: - """Extract total token count from a response chunk. - - Args: - response: The response chunk (ChatAgentResponse or similar) - - Returns: - Total token count or 0 if not available - """ - if response is None: - return 0 - usage_info = ( - response.info.get("usage") - or response.info.get("token_usage") - or {} - ) - return usage_info.get("total_tokens", 0) - - def _stream_chunks(self, response_gen): - """Generator that wraps a streaming response. - - Sends chunks to frontend. - - Args: - response_gen: The original streaming response generator - Yields: - Each chunk from the original generator - - Returns: - Tuple of (accumulated_content, total_tokens) via - StopIteration value - """ - accumulated_content = "" - last_chunk = None - - try: - for chunk in response_gen: - last_chunk = chunk - if chunk.msg and chunk.msg.content: - accumulated_content += chunk.msg.content + def _build_static_execution_context(self) -> dict[str, Any]: + return { + "api_task_id": self.api_task_id, + "agent_name": self.agent_name, + } + + def _build_dynamic_execution_context(self) -> dict[str, Any] | None: + execution_context: dict[str, Any] = {} + if self._user_execution_context_provider is not None: + provided_context = self._user_execution_context_provider() + if provided_context: + execution_context.update(provided_context) + if self.process_task_id: + execution_context["process_task_id"] = self.process_task_id + if getattr(self, "agent_id", None): + execution_context["agent_id"] = self.agent_id + return execution_context or None + + def _ensure_task_lock(self): + return get_task_lock(self.api_task_id) + + def _stream_with_process_context( + self, response_gen: StreamingChatAgentResponse + ): + with set_process_task(self.process_task_id): + yield from response_gen + + async def _astream_with_process_context( + self, + response_gen: AsyncStreamingChatAgentResponse, + ): + with set_process_task(self.process_task_id): + async for chunk in response_gen: yield chunk - finally: - total_tokens = self._extract_tokens(last_chunk) - self._send_agent_deactivate(accumulated_content, total_tokens) - - async def _astream_chunks(self, response_gen): - """Async generator that wraps a streaming response. - - Sends chunks to frontend. - - Args: - response_gen: The original async streaming response generator - Yields: - Each chunk from the original generator - """ - accumulated_content = "" - last_chunk = None + def _execute_tool( + self, + tool_call_request: ToolCallRequest, + ) -> ToolCallingRecord: + with set_process_task(self.process_task_id): + return super()._execute_tool(tool_call_request) - try: - async for chunk in response_gen: - last_chunk = chunk - if chunk.msg and chunk.msg.content: - delta_content = chunk.msg.content - accumulated_content += delta_content - yield chunk - finally: - total_tokens = self._extract_tokens(last_chunk) - self._send_agent_deactivate(accumulated_content, total_tokens) + async def _aexecute_tool( + self, + tool_call_request: ToolCallRequest, + ) -> ToolCallingRecord: + with set_process_task(self.process_task_id): + return await super()._aexecute_tool(tool_call_request) def step( self, input_message: BaseMessage | str, response_format: type[BaseModel] | None = None, ) -> ChatAgentResponse | StreamingChatAgentResponse: - task_lock = get_task_lock(self.api_task_id) - _schedule_async_task( - task_lock.put_queue( - ActionActivateAgentData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "agent_id": self.agent_id, - "message": ( - input_message.content - if isinstance(input_message, BaseMessage) - else input_message - ), - }, - ) - ) - ) - error_info = None - message = None - res = None - msg = ( - input_message.content - if isinstance(input_message, BaseMessage) - else input_message - ) - logger.info( - f"Agent {self.agent_name} starting step with message: {msg}" - ) - try: - res = super().step(input_message, response_format) - except ModelProcessingError as e: - res = None - error_info = e - if "Budget has been exceeded" in str(e): - message = "Budget has been exceeded" - logger.warning(f"Agent {self.agent_name} budget exceeded") - _schedule_async_task( - task_lock.put_queue(ActionBudgetNotEnough()) - ) - else: - message = str(e) - logger.error( - f"Agent {self.agent_name} model processing error: {e}" - ) - total_tokens = 0 - except Exception as e: - res = None - error_info = e - logger.error( - f"Agent {self.agent_name} unexpected error in step: {e}", - exc_info=True, + self._ensure_task_lock() + with set_process_task(self.process_task_id): + response = super().step(input_message, response_format) + if isinstance(response, StreamingChatAgentResponse): + return StreamingChatAgentResponse( + self._stream_with_process_context(response) ) - message = f"Error processing message: {e!s}" - total_tokens = 0 - - if res is not None: - if isinstance(res, StreamingChatAgentResponse): - # Use reusable stream wrapper to send chunks to frontend - return StreamingChatAgentResponse(self._stream_chunks(res)) - - message = res.msg.content if res.msg else "" - usage_info = ( - res.info.get("usage") or res.info.get("token_usage") or {} - ) - total_tokens = ( - usage_info.get("total_tokens", 0) if usage_info else 0 - ) - logger.info( - f"Agent {self.agent_name} completed step, " - f"tokens used: {total_tokens}" - ) - - assert message is not None - - _schedule_async_task( - task_lock.put_queue( - ActionDeactivateAgentData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "agent_id": self.agent_id, - "message": message, - "tokens": total_tokens, - }, - ) - ) - ) - - if error_info is not None: - raise error_info - assert res is not None - return res + return response async def astep( self, input_message: BaseMessage | str, response_format: type[BaseModel] | None = None, ) -> ChatAgentResponse | AsyncStreamingChatAgentResponse: - task_lock = get_task_lock(self.api_task_id) - await task_lock.put_queue( - ActionActivateAgentData( - action=Action.activate_agent, - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "agent_id": self.agent_id, - "message": ( - input_message.content - if isinstance(input_message, BaseMessage) - else input_message - ), - }, - ) - ) - - error_info = None - message = None - res = None - msg = ( - input_message.content - if isinstance(input_message, BaseMessage) - else input_message - ) - logger.debug( - f"Agent {self.agent_name} starting async step with message: {msg}" - ) - - try: - res = await super().astep(input_message, response_format) - if isinstance(res, AsyncStreamingChatAgentResponse): - # Use reusable async stream wrapper to send chunks to frontend - return AsyncStreamingChatAgentResponse( - self._astream_chunks(res) - ) - except ModelProcessingError as e: - res = None - error_info = e - if "Budget has been exceeded" in str(e): - message = "Budget has been exceeded" - logger.warning(f"Agent {self.agent_name} budget exceeded") - asyncio.create_task( - task_lock.put_queue(ActionBudgetNotEnough()) - ) - else: - message = str(e) - logger.error( - f"Agent {self.agent_name} model processing error: {e}" - ) - total_tokens = 0 - except Exception as e: - res = None - error_info = e - logger.error( - f"Agent {self.agent_name} unexpected error in async step: {e}", - exc_info=True, - ) - message = f"Error processing message: {e!s}" - total_tokens = 0 - - # For non-streaming responses, extract message and tokens from response - if res is not None and not isinstance( - res, AsyncStreamingChatAgentResponse - ): - message = res.msg.content if res.msg else "" - usage_info = ( - res.info.get("usage") or res.info.get("token_usage") or {} - ) - total_tokens = ( - usage_info.get("total_tokens", 0) if usage_info else 0 - ) - logger.info( - f"Agent {self.agent_name} completed step, " - f"tokens used: {total_tokens}" - ) - - # Send deactivation for all non-streaming cases (success or error) - # Streaming responses handle deactivation in _astream_chunks - assert message is not None - - asyncio.create_task( - task_lock.put_queue( - ActionDeactivateAgentData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "agent_id": self.agent_id, - "message": message, - "tokens": total_tokens, - }, - ) - ) - ) - - if error_info is not None: - raise error_info - assert res is not None - return res - - def _execute_tool( - self, tool_call_request: ToolCallRequest - ) -> ToolCallingRecord: - func_name = tool_call_request.tool_name - tool: FunctionTool = self._internal_tools[func_name] - # Route async functions to async execution - # even if they have __wrapped__ - if asyncio.iscoroutinefunction(tool.func): - # For async functions, we need to use the async execution path - return asyncio.run(self._aexecute_tool(tool_call_request)) - - # Handle all sync tools ourselves to maintain ContextVar context - args = tool_call_request.args - tool_call_id = tool_call_request.tool_call_id - - # Check if tool is wrapped by @listen_toolkit decorator - # If so, the decorator will handle activate/deactivate events - # TODO: Refactor - current marker detection is a workaround. - # The proper fix is to unify event sending: - # remove activate/deactivate from @listen_toolkit, only send here - has_listen_decorator = getattr(tool.func, "__listen_toolkit__", False) - - try: - task_lock = get_task_lock(self.api_task_id) - - toolkit_name = ( - tool._toolkit_name - if hasattr(tool, "_toolkit_name") - else "mcp_toolkit" - ) - logger.debug( - f"Agent {self.agent_name} executing tool: " - f"{func_name} from toolkit: {toolkit_name} " - f"with args: {json.dumps(args, ensure_ascii=False)}" + self._ensure_task_lock() + with set_process_task(self.process_task_id): + response = await super().astep(input_message, response_format) + if isinstance(response, AsyncStreamingChatAgentResponse): + return AsyncStreamingChatAgentResponse( + self._astream_with_process_context(response) ) + return response - # Only send activate event if tool is - # NOT wrapped by @listen_toolkit - if not has_listen_decorator: - _schedule_async_task( - task_lock.put_queue( - ActionActivateToolkitData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "toolkit_name": toolkit_name, - "method_name": func_name, - "message": json.dumps( - args, ensure_ascii=False - ), - }, - ) - ) - ) - # Set process_task context for all tool executions - with set_process_task(self.process_task_id): - raw_result = tool(**args) - logger.debug(f"Tool {func_name} executed successfully") - if self.mask_tool_output: - self._secure_result_store[tool_call_id] = raw_result - result = ( - "[The tool has been executed successfully, but the output" - " from the tool is masked. You can move forward]" - ) - mask_flag = True - else: - result = raw_result - mask_flag = False - # Prepare result message with truncation - if isinstance(result, str): - result_msg = result - else: - result_str = repr(result) - MAX_RESULT_LENGTH = 500 - if len(result_str) > MAX_RESULT_LENGTH: - result_msg = result_str[:MAX_RESULT_LENGTH] + ( - f"... (truncated, total length: " - f"{len(result_str)} chars)" - ) - else: - result_msg = result_str - - # Only send deactivate event if tool is - # NOT wrapped by @listen_toolkit - if not has_listen_decorator: - _schedule_async_task( - task_lock.put_queue( - ActionDeactivateToolkitData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "toolkit_name": toolkit_name, - "method_name": func_name, - "message": result_msg, - }, - ) - ) - ) - except Exception as e: - # Capture the error message to prevent framework crash - error_msg = f"Error executing tool '{func_name}': {e!s}" - result = f"Tool execution failed: {error_msg}" - mask_flag = False - logger.error( - f"Tool execution failed for {func_name}: {e}", exc_info=True - ) - - return self._record_tool_calling( - func_name, - args, - result, - tool_call_id, - mask_output=mask_flag, - extra_content=tool_call_request.extra_content, - ) - - async def _aexecute_tool( - self, tool_call_request: ToolCallRequest - ) -> ToolCallingRecord: - func_name = tool_call_request.tool_name - tool: FunctionTool = self._internal_tools[func_name] - - # Always handle tool execution ourselves to maintain ContextVar context - args = tool_call_request.args - tool_call_id = tool_call_request.tool_call_id - task_lock = get_task_lock(self.api_task_id) - - # Try to get the real toolkit name - toolkit_name = None - - # Method 1: Check _toolkit_name attribute - if hasattr(tool, "_toolkit_name"): - toolkit_name = tool._toolkit_name - - # Method 2: For MCP tools, check if func has __self__ - # (the toolkit instance) - if ( - not toolkit_name - and hasattr(tool, "func") - and hasattr(tool.func, "__self__") - ): - toolkit_instance = tool.func.__self__ - if hasattr(toolkit_instance, "toolkit_name") and callable( - toolkit_instance.toolkit_name - ): - toolkit_name = toolkit_instance.toolkit_name() - - # Method 3: Check if tool.func is a bound method with toolkit - if not toolkit_name and hasattr(tool, "func"): - if hasattr(tool.func, "func") and hasattr( - tool.func.func, "__self__" - ): - toolkit_instance = tool.func.func.__self__ - if hasattr(toolkit_instance, "toolkit_name") and callable( - toolkit_instance.toolkit_name - ): - toolkit_name = toolkit_instance.toolkit_name() - - # Default fallback - if not toolkit_name: - toolkit_name = "mcp_toolkit" - - logger.info( - f"Agent {self.agent_name} executing async tool: {func_name} " - f"from toolkit: {toolkit_name} " - f"with args: {json.dumps(args, ensure_ascii=False)}" - ) - - # Check if tool is wrapped by @listen_toolkit decorator - # If so, the decorator will handle activate/deactivate events - has_listen_decorator = getattr(tool.func, "__listen_toolkit__", False) - - # Only send activate event if tool is NOT wrapped by @listen_toolkit - if not has_listen_decorator: - await task_lock.put_queue( - ActionActivateToolkitData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "toolkit_name": toolkit_name, - "method_name": func_name, - "message": json.dumps(args, ensure_ascii=False), - }, - ) - ) - try: - # Set process_task context for all tool executions - with set_process_task(self.process_task_id): - # Try different invocation paths in order of preference - if hasattr(tool, "func") and hasattr(tool.func, "async_call"): - # MCP FunctionTool: always use async_call (sync wrapper can timeout) - result = await tool.func.async_call(**args) - - elif hasattr(tool, "async_call") and callable(tool.async_call): - # Case: tool itself has async_call - # Check if this is a sync tool to avoid run_in_executor - # (which breaks ContextVar) - if hasattr(tool, "is_async") and not tool.is_async: - # Sync tool: call directly to preserve ContextVar - # in same thread - result = tool(**args) - # Handle case where sync call returns a coroutine - if asyncio.iscoroutine(result): - result = await result - else: - # Async tool: use async_call - result = await tool.async_call(**args) - - elif hasattr(tool, "func") and asyncio.iscoroutinefunction( - tool.func - ): - # Case: tool wraps a direct async function - result = await tool.func(**args) - - elif asyncio.iscoroutinefunction(tool): - # Case: tool is itself a coroutine function - result = await tool(**args) - - else: - # Fallback: sync call - call directly in current context - # DO NOT use run_in_executor to preserve ContextVar - result = tool(**args) - # Handle case where synchronous call returns a coroutine - if asyncio.iscoroutine(result): - result = await result - - except Exception as e: - # Capture the error message to prevent framework crash - error_msg = f"Error executing async tool '{func_name}': {e!s}" - result = {"error": error_msg} - logger.error( - f"Async tool execution failed for {func_name}: {e}", - exc_info=True, - ) - - # Prepare result message with truncation - if isinstance(result, str): - result_msg = result - else: - result_str = repr(result) - MAX_RESULT_LENGTH = 500 - if len(result_str) > MAX_RESULT_LENGTH: - result_msg = ( - result_str[:MAX_RESULT_LENGTH] - + f"... (truncated, total length: {len(result_str)} chars)" - ) - else: - result_msg = result_str - - # Only send deactivate event if tool is NOT wrapped by @listen_toolkit - if not has_listen_decorator: - await task_lock.put_queue( - ActionDeactivateToolkitData( - data={ - "agent_name": self.agent_name, - "process_task_id": self.process_task_id, - "toolkit_name": toolkit_name, - "method_name": func_name, - "message": result_msg, - }, - ) - ) - return self._record_tool_calling( - func_name, - args, - result, - tool_call_id, - extra_content=tool_call_request.extra_content, - ) - - def clone(self, with_memory: bool = False) -> ChatAgent: + def clone( + self, + with_memory: bool = False, + clone_context: CloneContext | None = None, + ) -> ChatAgent: """Please see super.clone()""" system_message = None if with_memory else self._original_system_message + effective_clone_context = ( + clone_context.model_copy(deep=True) + if clone_context is not None + else CloneContext() + ) # If this agent has CDP acquire callback, acquire CDP BEFORE cloning # tools so that HybridBrowserToolkit clones with the correct CDP port new_cdp_port = None - new_cdp_session = None + new_cdp_session = effective_clone_context.session_id has_cdp = hasattr(self, "_cdp_acquire_callback") and callable( getattr(self, "_cdp_acquire_callback", None) ) @@ -711,11 +240,13 @@ def clone(self, with_memory: bool = False) -> ChatAgent: cdp_browsers = getattr(options, "cdp_browsers", []) if cdp_browsers and hasattr(self, "_browser_toolkit"): need_cdp_clone = True - import uuid as _uuid + from uuid import uuid4 from app.agent.factory.browser import _cdp_pool_manager - new_cdp_session = str(_uuid.uuid4())[:8] + if not new_cdp_session: + new_cdp_session = str(uuid4())[:8] + effective_clone_context.session_id = new_cdp_session selected = _cdp_pool_manager.acquire_browser( cdp_browsers, new_cdp_session, @@ -741,7 +272,9 @@ def clone(self, with_memory: bool = False) -> ChatAgent: f"http://localhost:{new_cdp_port}" ) try: - cloned_tools, toolkits_to_register = self._clone_tools() + cloned_tools, toolkits_to_register = self._clone_tools( + effective_clone_context + ) except Exception: _cdp_pool_manager.release_browser( new_cdp_port, new_cdp_session @@ -752,7 +285,15 @@ def clone(self, with_memory: bool = False) -> ChatAgent: original_cdp_url ) else: - cloned_tools, toolkits_to_register = self._clone_tools() + cloned_tools, toolkits_to_register = self._clone_tools( + effective_clone_context + ) + + clone_execution_context = dict(self._user_execution_context) + if effective_clone_context.execution_context: + clone_execution_context.update( + effective_clone_context.execution_context + ) new_agent = ListenChatAgent( api_task_id=self.api_task_id, @@ -779,8 +320,14 @@ def clone(self, with_memory: bool = False) -> ChatAgent: pause_event=self.pause_event, prune_tool_calls_from_memory=self.prune_tool_calls_from_memory, enable_snapshot_clean=self._enable_snapshot_clean, + retry_attempts=self.retry_attempts, + retry_delay=self.retry_delay, step_timeout=self.step_timeout, + callbacks=self._user_callbacks, + execution_context=clone_execution_context, + execution_context_provider=self._user_execution_context_provider, stream_accumulate=self.stream_accumulate, + summary_window_ratio=self.summary_window_ratio, ) new_agent.process_task_id = self.process_task_id diff --git a/backend/app/agent/listen_chat_agent_callback.py b/backend/app/agent/listen_chat_agent_callback.py new file mode 100644 index 000000000..38f5cac98 --- /dev/null +++ b/backend/app/agent/listen_chat_agent_callback.py @@ -0,0 +1,227 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from app.service.task import ( + ActionActivateAgentData, + ActionActivateToolkitData, + ActionBudgetNotEnough, + ActionDeactivateAgentData, + ActionDeactivateToolkitData, +) +from app.utils.event_loop_utils import _schedule_async_task +from camel.agents import ( + AgentCallback, + AgentEvent, + StepCompletedEvent, + StepFailedEvent, + StepStartedEvent, + ToolCompletedEvent, + ToolFailedEvent, + ToolStartedEvent, +) + +if TYPE_CHECKING: + from app.agent.listen_chat_agent import ListenChatAgent + from camel.toolkits import FunctionTool + +logger = logging.getLogger("listen_chat_agent_callback") + + +def _get_total_tokens(event: StepCompletedEvent) -> int: + usage = event.usage or {} + return usage.get("total_tokens", 0) + + +class ListenChatAgentCallback(AgentCallback): + """Bridge CAMEL agent lifecycle events into Eigent task actions.""" + + def __init__(self, agent: ListenChatAgent) -> None: + self._agent = agent + + def handle_event(self, event: AgentEvent) -> None: + if isinstance(event, StepStartedEvent): + self._handle_step_started(event) + elif isinstance(event, StepCompletedEvent): + self._handle_step_completed(event) + elif isinstance(event, StepFailedEvent): + self._handle_step_failed(event) + elif isinstance(event, ToolStartedEvent): + self._handle_tool_started(event) + elif isinstance(event, ToolCompletedEvent): + self._handle_tool_completed(event) + elif isinstance(event, ToolFailedEvent): + self._handle_tool_failed(event) + + def _queue_action(self, action) -> None: + task_lock = self._agent._ensure_task_lock() + _schedule_async_task(task_lock.put_queue(action)) + + def _resolve_tool(self, tool_name: str) -> FunctionTool | None: + return self._agent._internal_tools.get(tool_name) + + def _should_skip_toolkit_event(self, tool_name: str) -> bool: + tool = self._resolve_tool(tool_name) + if tool is None: + return False + return bool(getattr(tool.func, "__listen_toolkit__", False)) + + def _resolve_toolkit_name( + self, event_toolkit_name: str | None, tool_name: str + ) -> str: + if event_toolkit_name: + return event_toolkit_name + + tool = self._resolve_tool(tool_name) + if tool is None: + return "mcp_toolkit" + + if hasattr(tool, "_toolkit_name"): + return tool._toolkit_name + + if hasattr(tool, "func") and hasattr(tool.func, "__self__"): + toolkit_instance = tool.func.__self__ + if hasattr(toolkit_instance, "toolkit_name") and callable( + toolkit_instance.toolkit_name + ): + return toolkit_instance.toolkit_name() + + if ( + hasattr(tool, "func") + and hasattr(tool.func, "func") + and hasattr(tool.func.func, "__self__") + ): + toolkit_instance = tool.func.func.__self__ + if hasattr(toolkit_instance, "toolkit_name") and callable( + toolkit_instance.toolkit_name + ): + return toolkit_instance.toolkit_name() + + return "mcp_toolkit" + + def _activate_agent_payload(self, message: str) -> dict: + return { + "agent_name": self._agent.agent_name, + "process_task_id": self._agent.process_task_id, + "agent_id": self._agent.agent_id, + "message": message, + } + + def _deactivate_agent_payload(self, message: str, tokens: int = 0) -> dict: + return { + "agent_name": self._agent.agent_name, + "process_task_id": self._agent.process_task_id, + "agent_id": self._agent.agent_id, + "message": message, + "tokens": tokens, + } + + def _tool_payload( + self, + *, + tool_name: str, + toolkit_name: str, + message: str, + ) -> dict: + return { + "agent_name": self._agent.agent_name, + "process_task_id": self._agent.process_task_id, + "toolkit_name": toolkit_name, + "method_name": tool_name, + "message": message, + } + + def _handle_step_started(self, event: StepStartedEvent) -> None: + self._queue_action( + ActionActivateAgentData( + data=self._activate_agent_payload(event.input_summary or "") + ) + ) + + def _handle_step_completed(self, event: StepCompletedEvent) -> None: + self._queue_action( + ActionDeactivateAgentData( + data=self._deactivate_agent_payload( + event.output_summary or "", + _get_total_tokens(event), + ) + ) + ) + + def _handle_step_failed(self, event: StepFailedEvent) -> None: + message = event.error_message + if "Budget has been exceeded" in message: + self._queue_action(ActionBudgetNotEnough()) + message = "Budget has been exceeded" + + self._queue_action( + ActionDeactivateAgentData( + data=self._deactivate_agent_payload(message, 0) + ) + ) + + def _handle_tool_started(self, event: ToolStartedEvent) -> None: + if self._should_skip_toolkit_event(event.tool_name): + return + + toolkit_name = self._resolve_toolkit_name( + event.toolkit_name, event.tool_name + ) + self._queue_action( + ActionActivateToolkitData( + data=self._tool_payload( + tool_name=event.tool_name, + toolkit_name=toolkit_name, + message=event.input_summary or "", + ) + ) + ) + + def _handle_tool_completed(self, event: ToolCompletedEvent) -> None: + if self._should_skip_toolkit_event(event.tool_name): + return + + toolkit_name = self._resolve_toolkit_name( + event.toolkit_name, event.tool_name + ) + self._queue_action( + ActionDeactivateToolkitData( + data=self._tool_payload( + tool_name=event.tool_name, + toolkit_name=toolkit_name, + message=event.output_summary or "", + ) + ) + ) + + def _handle_tool_failed(self, event: ToolFailedEvent) -> None: + if self._should_skip_toolkit_event(event.tool_name): + return + + toolkit_name = self._resolve_toolkit_name( + event.toolkit_name, event.tool_name + ) + self._queue_action( + ActionDeactivateToolkitData( + data=self._tool_payload( + tool_name=event.tool_name, + toolkit_name=toolkit_name, + message=event.error_message, + ) + ) + ) diff --git a/backend/tests/app/agent/test_listen_chat_agent.py b/backend/tests/app/agent/test_listen_chat_agent.py index 3b576fcc0..9ca4e81b1 100644 --- a/backend/tests/app/agent/test_listen_chat_agent.py +++ b/backend/tests/app/agent/test_listen_chat_agent.py @@ -13,19 +13,20 @@ # ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= import asyncio +from contextlib import nullcontext from unittest.mock import AsyncMock, MagicMock, patch import pytest -from camel.agents import ChatAgent from camel.agents._types import ToolCallRequest + +from app.agent.listen_chat_agent import CloneContext, ListenChatAgent +from app.model.chat import Chat +from camel.agents import ChatAgent from camel.messages import BaseMessage from camel.responses import ChatAgentResponse from camel.toolkits import FunctionTool from camel.types.agents import ToolCallingRecord -from app.agent.listen_chat_agent import ListenChatAgent -from app.model.chat import Chat - _LCA = "app.agent.listen_chat_agent" pytestmark = pytest.mark.unit @@ -74,7 +75,9 @@ def test_listen_chat_agent_step_with_string_input(self, mock_task_lock): with ( patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock), patch("camel.models.ModelFactory.create") as mock_create_model, - patch("asyncio.create_task"), + patch( + f"{_LCA}.set_process_task", return_value=nullcontext() + ) as mock_set_process_task, ): # Mock the model backend creation mock_backend = MagicMock() @@ -106,8 +109,9 @@ def test_listen_chat_agent_step_with_string_input(self, mock_task_lock): mock_parent_step.assert_called_once() args, kwargs = mock_parent_step.call_args assert args[0] == "Test input message" - # Should queue activation notification - mock_task_lock.put_queue.assert_called() + mock_set_process_task.assert_called_once_with( + "test_process_task" + ) def test_listen_chat_agent_step_with_base_message_input( self, mock_task_lock @@ -119,7 +123,9 @@ def test_listen_chat_agent_step_with_base_message_input( with ( patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock), patch("camel.models.ModelFactory.create") as mock_create_model, - patch("asyncio.create_task"), + patch( + f"{_LCA}.set_process_task", return_value=nullcontext() + ) as mock_set_process_task, ): # Mock the model backend creation mock_backend = MagicMock() @@ -156,12 +162,9 @@ def test_listen_chat_agent_step_with_base_message_input( mock_parent_step.assert_called_once() args, kwargs = mock_parent_step.call_args assert args[0] is mock_message - - # Should queue activation with message content - mock_task_lock.put_queue.assert_called() - # Just verify put_queue was called - - # don't check internal data - # structure details + mock_set_process_task.assert_called_once_with( + "test_process_task" + ) @pytest.mark.asyncio async def test_listen_chat_agent_astep(self, mock_task_lock): @@ -172,7 +175,9 @@ async def test_listen_chat_agent_astep(self, mock_task_lock): with ( patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock), patch("camel.models.ModelFactory.create") as mock_create_model, - patch("asyncio.create_task"), + patch( + f"{_LCA}.set_process_task", return_value=nullcontext() + ) as mock_set_process_task, ): # Mock the model backend creation mock_backend = MagicMock() @@ -204,9 +209,9 @@ async def test_listen_chat_agent_astep(self, mock_task_lock): mock_parent_astep.assert_called_once() args, kwargs = mock_parent_astep.call_args assert args[0] == "Test async input" - - # Verify that task lock put_queue was called - mock_task_lock.put_queue.assert_called() + mock_set_process_task.assert_called_once_with( + "test_process_task" + ) def test_listen_chat_agent_execute_tool(self, mock_task_lock): """Test ListenChatAgent _execute_tool method.""" @@ -217,6 +222,7 @@ def test_listen_chat_agent_execute_tool(self, mock_task_lock): patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock), patch("camel.models.ModelFactory.create") as mock_create_model, patch("asyncio.create_task"), + patch("app.agent.listen_chat_agent_callback._schedule_async_task"), ): # Mock the model backend creation mock_backend = MagicMock() @@ -267,6 +273,7 @@ async def test_listen_chat_agent_aexecute_tool(self, mock_task_lock): with ( patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock), patch("camel.models.ModelFactory.create") as mock_create_model, + patch("app.agent.listen_chat_agent_callback._schedule_async_task"), ): # Mock the model backend creation mock_backend = MagicMock() @@ -361,7 +368,13 @@ def test_listen_chat_agent_clone(self, mock_task_lock): ) as mock_clone_constructor, patch.object(agent, "_clone_tools", return_value=([], [])), ): - result = agent.clone(with_memory=True) + result = agent.clone( + with_memory=True, + clone_context=CloneContext( + session_id="clone-session", + execution_context={"clone_mode": "test"}, + ), + ) assert result is cloned_agent mock_clone_constructor.assert_called_once() diff --git a/backend/tests/app/agent/test_listen_chat_agent_callback.py b/backend/tests/app/agent/test_listen_chat_agent_callback.py new file mode 100644 index 000000000..29db251a4 --- /dev/null +++ b/backend/tests/app/agent/test_listen_chat_agent_callback.py @@ -0,0 +1,96 @@ +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2025-2026 @ Eigent.ai All Rights Reserved. ========= + +from unittest.mock import MagicMock, patch + +import pytest + +from app.agent.listen_chat_agent import ListenChatAgent +from app.agent.listen_chat_agent_callback import ( + StepCompletedEvent, + StepFailedEvent, + StepStartedEvent, +) + +_LCA = "app.agent.listen_chat_agent" + + +@pytest.fixture +def mock_task_lock(): + mock_lock = MagicMock() + mock_lock.put_queue = MagicMock() + return mock_lock + + +def _make_agent(mock_task_lock, mock_create_model): + mock_backend = MagicMock() + mock_backend.model_type = "gpt-4" + mock_backend.current_model = MagicMock() + mock_backend.current_model.model_type = "gpt-4" + mock_create_model.return_value = mock_backend + + agent = ListenChatAgent( + api_task_id="test_api_task_123", + agent_name="TestAgent", + model="gpt-4", + ) + agent.process_task_id = "test_process_task" + agent.agent_id = "test_agent_123" + return agent + + +class TestListenChatAgentCallback: + def test_handles_step_events(self, mock_task_lock): + with ( + patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock), + patch("camel.models.ModelFactory.create") as mock_create_model, + patch("app.agent.listen_chat_agent_callback._schedule_async_task"), + ): + agent = _make_agent(mock_task_lock, mock_create_model) + + agent._listen_callback.handle_event( + StepStartedEvent( + agent_id=agent.agent_id, + role_name=agent.role_name, + input_summary="Test input message", + ) + ) + agent._listen_callback.handle_event( + StepCompletedEvent( + agent_id=agent.agent_id, + role_name=agent.role_name, + output_summary="Test response content", + usage={"total_tokens": 100}, + ) + ) + + assert mock_task_lock.put_queue.call_count == 2 + + def test_handles_budget_failure(self, mock_task_lock): + with ( + patch(f"{_LCA}.get_task_lock", return_value=mock_task_lock), + patch("camel.models.ModelFactory.create") as mock_create_model, + patch("app.agent.listen_chat_agent_callback._schedule_async_task"), + ): + agent = _make_agent(mock_task_lock, mock_create_model) + + agent._listen_callback.handle_event( + StepFailedEvent( + agent_id=agent.agent_id, + role_name=agent.role_name, + error_message="Budget has been exceeded", + ) + ) + + assert mock_task_lock.put_queue.call_count == 2