diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index cf39beaf..fdf9d23a 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -54,6 +54,34 @@ property so that callbacks can access information in a framework-agnostic way. You can see what attributes are available for LLM Calls and Tool Executions by examining the [`GenAI`][any_agent.tracing.attributes.GenAI] class. +### Framework State + +In addition to the span attributes, callbacks can access and modify framework-specific objects through [`Context.framework_state`][any_agent.callbacks.context.Context.framework_state]. + +This allows callbacks to directly manipulate the agent's execution, such as: + +- Modifying messages before they're sent to the LLM +- Modifying the LLM's response after generation +- Injecting prompts mid-execution +- Changing user queries dynamically + +#### Helper Methods + +The `framework_state` provides helper methods to work with messages in a normalized format: + +**`get_messages()`**: Get messages as a list of dicts with `role` and `content` keys + +**`set_messages()`**: Set messages from a list of dicts with `role` and `content` keys + +These methods handle framework-specific message formats internally, providing a consistent API across frameworks. + +!!! note "Availability" + + The `get_messages()` and `set_messages()` methods are **only available in `before_llm_call` callbacks**. + + - In `before_llm_call`: You can read and modify the messages that will be sent to the LLM + - In other callbacks (`after_llm_call`, `before_tool_execution`, `after_tool_execution`, etc.): These methods will raise `NotImplementedError` + ## Implementing Callbacks All callbacks must inherit from the base [`Callback`][any_agent.callbacks.base.Callback] class and can choose to implement any subset of the available callback methods. These methods include: @@ -61,7 +89,7 @@ All callbacks must inherit from the base [`Callback`][any_agent.callbacks.base.C | Callback Method | Description | |:----------------:|:------------:| | before_agent_invocation | Should be used to check the Context before the agent is invoked. | -| befor_llm_call | Should be used before the chat history hits the LLM. | +| before_llm_call | Should be used before the chat history hits the LLM. | | after_llm_call | Should be used once LLM output is generated, before it appends to the chat history. | | before_tool_execution | Should be used to check the Context before tool execution. | | after_tool_execution | Should be used once tool has been executed, before the output is appended to chat history. | @@ -136,7 +164,7 @@ Callbacks are provided to the agent using the [`AgentConfig.callbacks`][any_agen agent = AnyAgent.create( "tinyagent", AgentConfig( - model_id="gpt-4.1-nano", + model_id="openai:gpt-4.1-nano", instructions="Use the tools to find an answer", tools=[search_web, visit_webpage], callbacks=[ @@ -157,7 +185,7 @@ Callbacks are provided to the agent using the [`AgentConfig.callbacks`][any_agen agent = AnyAgent.create( "tinyagent", AgentConfig( - model_id="gpt-4.1-nano", + model_id="openai:gpt-4.1-nano", instructions="Use the tools to find an answer", tools=[search_web, visit_webpage], callbacks=[ @@ -272,3 +300,50 @@ class LimitToolExecutions(Callback): return context ``` + +## Example: Modifying prompts dynamically + +You can use callbacks to modify the prompt being sent to the LLM. 
This is useful for injecting instructions or reminders mid-execution: + +```python +from any_agent.callbacks.base import Callback +from any_agent.callbacks.context import Context + +class InjectReminderCallback(Callback): + def __init__(self, reminder: str, every_n_calls: int = 5): + self.reminder = reminder + self.every_n_calls = every_n_calls + self.call_count = 0 + + def before_llm_call(self, context: Context, *args, **kwargs) -> Context: + self.call_count += 1 + + if self.call_count % self.every_n_calls == 0: + try: + messages = context.framework_state.get_messages() + if messages: + messages[-1]["content"] += f"\n\n{self.reminder}" + context.framework_state.set_messages(messages) + except NotImplementedError: + pass + + return context +``` + +Example usage: + +```python +from any_agent import AgentConfig, AnyAgent + +callback = InjectReminderCallback( + reminder="Remember to use the Todo tool to track your tasks!", + every_n_calls=5 +) + +config = AgentConfig( + model_id="openai:gpt-4o-mini", + instructions="You are a helpful assistant.", + callbacks=[callback], +) +# ... Continue to create and run agent +``` diff --git a/docs/api/callbacks.md b/docs/api/callbacks.md index ebca1863..ed827a0e 100644 --- a/docs/api/callbacks.md +++ b/docs/api/callbacks.md @@ -4,6 +4,8 @@ ::: any_agent.callbacks.context.Context +::: any_agent.callbacks.context.FrameworkState + ::: any_agent.callbacks.span_print.ConsolePrintSpan ::: any_agent.callbacks.get_default_callbacks diff --git a/src/any_agent/callbacks/__init__.py b/src/any_agent/callbacks/__init__.py index 7e26da88..836625a4 100644 --- a/src/any_agent/callbacks/__init__.py +++ b/src/any_agent/callbacks/__init__.py @@ -1,8 +1,8 @@ from .base import Callback -from .context import Context +from .context import Context, FrameworkState from .span_print import ConsolePrintSpan -__all__ = ["Callback", "ConsolePrintSpan", "Context"] +__all__ = ["Callback", "ConsolePrintSpan", "Context", "FrameworkState"] def get_default_callbacks() -> list[Callback]: diff --git a/src/any_agent/callbacks/context.py b/src/any_agent/callbacks/context.py index e3479991..7574bdb2 100644 --- a/src/any_agent/callbacks/context.py +++ b/src/any_agent/callbacks/context.py @@ -1,14 +1,87 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from collections.abc import Callable + from opentelemetry.trace import Span, Tracer from any_agent.tracing.agent_trace import AgentTrace +@dataclass +class FrameworkState: + """Framework-specific state that can be accessed and modified by callbacks. + + This object provides a consistent interface for accessing framework state across + different agent frameworks, while the actual content is framework-specific. + """ + + messages: list[dict[str, Any]] = field(default_factory=list) + """Internal storage for messages. Use get_messages() and set_messages() instead.""" + + _message_getter: Callable[[], list[dict[str, Any]]] | None = field( + default=None, repr=False + ) + """Framework-specific message getter function.""" + + _message_setter: Callable[[list[dict[str, Any]]], None] | None = field( + default=None, repr=False + ) + """Framework-specific message setter function.""" + + def get_messages(self) -> list[dict[str, Any]]: + """Get messages in a normalized dict format. + + Returns a list of message dicts with 'role' and 'content' keys. + Works consistently across all frameworks. 
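+        Note: the framework wrappers only install the underlying getter
+        around ``before_llm_call``; in other callbacks this raises
+        ``NotImplementedError`` (see the availability note in the docs).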
+
+        Returns:
+            List of message dicts with 'role' and 'content' keys.
+
+        Raises:
+            NotImplementedError: If the framework doesn't support message access yet.
+
+        Example:
+            ```python
+            messages = context.framework_state.get_messages()
+            # [{"role": "user", "content": "Hello"}]
+            ```
+
+        """
+        if self._message_getter is None:
+            msg = "get_messages() is not implemented for this framework yet"
+            raise NotImplementedError(msg)
+        return self._message_getter()
+
+    def set_messages(self, messages: list[dict[str, Any]]) -> None:
+        """Set messages from a normalized dict format.
+
+        Accepts a list of message dicts with 'role' and 'content' keys and
+        converts them to the framework-specific format.
+
+        Args:
+            messages: List of message dicts with 'role' and 'content' keys.
+
+        Raises:
+            NotImplementedError: If the framework doesn't support message modification yet.
+
+        Example:
+            ```python
+            messages = context.framework_state.get_messages()
+            messages[-1]["content"] = "Say hello"
+            context.framework_state.set_messages(messages)
+            ```
+
+        """
+        if self._message_setter is None:
+            msg = "set_messages() is not implemented for this framework yet"
+            raise NotImplementedError(msg)
+        self._message_setter(messages)
+
+
 @dataclass
 class Context:
     """Object that will be shared across callbacks.
@@ -31,3 +104,22 @@ class Context:
 
     shared: dict[str, Any]
     """Can be used to store arbitrary information for sharing across callbacks."""
+
+    framework_state: FrameworkState
+    """Framework-specific state that can be accessed and modified by callbacks.
+
+    Provides consistent access to framework state across different agent frameworks.
+    See [`FrameworkState`][any_agent.callbacks.context.FrameworkState] for available attributes.
+
+    Example:
+        ```python
+        class ModifyPromptCallback(Callback):
+            def before_llm_call(self, context: Context, *args, **kwargs) -> Context:
+                # Modify the last message content via the helper methods
+                messages = context.framework_state.get_messages()
+                if messages:
+                    messages[-1]["content"] = "Say hello"
+                    context.framework_state.set_messages(messages)
+                return context
+        ```
+    """
diff --git a/src/any_agent/callbacks/span_generation/agno.py b/src/any_agent/callbacks/span_generation/agno.py
index d4e44582..dd3c0cf7 100644
--- a/src/any_agent/callbacks/span_generation/agno.py
+++ b/src/any_agent/callbacks/span_generation/agno.py
@@ -23,26 +23,62 @@ def before_llm_call(self, context: Context, *args, **kwargs):
 
     def after_llm_call(self, context: Context, *args, **kwargs) -> Context:
         output: str | list[dict[str, Any]] = ""
+        input_tokens: int = 0
+        output_tokens: int = 0
+
         if assistant_message := kwargs.get("assistant_message"):
-            if content := getattr(assistant_message, "content", None):
-                output = str(content)
-            if tool_calls := getattr(assistant_message, "tool_calls", None):
-                output = [
-                    {
-                        "tool.name": tool.get("function", {}).get("name", "No name"),
-                        "tool.args": tool.get("function", {}).get(
-                            "arguments", "No args"
-                        ),
-                    }
-                    for tool in tool_calls
-                ]
+            if hasattr(assistant_message, "choices"):
+                choices = getattr(assistant_message, "choices", [])
+                if choices and len(choices) > 0:
+                    choice = choices[0]
+                    message = getattr(choice, "message", None)
+                    if message:
+                        if content := getattr(message, "content", None):
+                            output = str(content)
+                        if tool_calls := getattr(message, "tool_calls", None):
+                            output = [
+                                {
+                                    "tool.name": getattr(
+                                        getattr(tool, "function", None),
+                                        "name",
+                                        "No name",
+                                    ),
+                                    "tool.args": getattr(
+                                        getattr(tool, "function", None),
+                                        "arguments",
+                                        "No args",
+                                    ),
+                                }
+                                for tool in tool_calls
+                            ]
+
+                if usage := getattr(assistant_message, "usage", 
None): + input_tokens = getattr(usage, "input_tokens", 0) or getattr( + usage, "prompt_tokens", 0 + ) + output_tokens = getattr(usage, "output_tokens", 0) or getattr( + usage, "completion_tokens", 0 + ) + else: + if content := getattr(assistant_message, "content", None): + output = str(content) + if tool_calls := getattr(assistant_message, "tool_calls", None): + output = [ + { + "tool.name": tool.get("function", {}).get( + "name", "No name" + ), + "tool.args": tool.get("function", {}).get( + "arguments", "No args" + ), + } + for tool in tool_calls + ] - metrics: MessageMetrics | None - input_tokens: int = 0 - output_tokens: int = 0 - if metrics := getattr(assistant_message, "metrics", None): - input_tokens = metrics.input_tokens - output_tokens = metrics.output_tokens + metrics: MessageMetrics | None + if metrics := getattr(assistant_message, "metrics", None): + input_tokens = metrics.input_tokens + output_tokens = metrics.output_tokens context = self._set_llm_output(context, output, input_tokens, output_tokens) diff --git a/src/any_agent/callbacks/wrappers/agno.py b/src/any_agent/callbacks/wrappers/agno.py index b8f20d22..3cbedb2e 100644 --- a/src/any_agent/callbacks/wrappers/agno.py +++ b/src/any_agent/callbacks/wrappers/agno.py @@ -9,33 +9,77 @@ from any_agent.callbacks.context import Context from any_agent.frameworks.agno import AgnoAgent +try: + from agno.models.message import Message + + agno_available = True +except ImportError: + agno_available = False + Message = None # type: ignore[assignment,misc] + class _AgnoWrapper: def __init__(self) -> None: self.callback_context: dict[int, Context] = {} - self._original_aprocess_model: Any = None + self._original_ainvoke: Any = None self._original_arun_function_call: Any = None async def wrap(self, agent: AgnoAgent) -> None: - self._original_aprocess_model = agent._agent.model._aprocess_model_response + if not agno_available: + msg = "Agno is not installed" + raise ImportError(msg) - async def wrapped_llm_call(*args, **kwargs): + self._original_ainvoke = agent._agent.model.ainvoke + + async def wrapped_ainvoke(messages, *args, **kwargs): context = self.callback_context[ get_current_span().get_span_context().trace_id ] context.shared["model_id"] = agent._agent.model.id + def get_messages(): + return [ + { + "role": msg.role, + "content": msg.content if msg.content else "", + } + for msg in messages + ] + + def set_messages(new_messages): + messages.clear() + for msg_dict in new_messages: + msg = Message( + role=msg_dict.get("role", "user"), + content=msg_dict.get("content", ""), + ) + if "tool_calls" in msg_dict: + msg.tool_calls = msg_dict["tool_calls"] + if "tool_call_id" in msg_dict: + msg.tool_call_id = msg_dict["tool_call_id"] + if "name" in msg_dict: + msg.name = msg_dict["name"] + messages.append(msg) + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + + callback_kwargs = {**kwargs, "messages": messages} for callback in agent.config.callbacks: - context = callback.before_llm_call(context, *args, **kwargs) + context = callback.before_llm_call(context, *args, **callback_kwargs) - result = await self._original_aprocess_model(*args, **kwargs) + result = await self._original_ainvoke(messages, *args, **kwargs) + callback_kwargs = {**kwargs, "assistant_message": result} for callback in agent.config.callbacks: - context = callback.after_llm_call(context, result, *args, **kwargs) + context = callback.after_llm_call(context, *args, **callback_kwargs) + + 
context.framework_state._message_getter = None + context.framework_state._message_setter = None return result - agent._agent.model._aprocess_model_response = wrapped_llm_call + agent._agent.model.ainvoke = wrapped_ainvoke self._original_arun_function_call = agent._agent.model.arun_function_call @@ -62,7 +106,7 @@ async def wrapped_tool_execution( agent._agent.model.arun_function_call = wrapped_tool_execution async def unwrap(self, agent: AgnoAgent): - if self._original_aprocess_model is not None: - agent._agent.model._aprocess_model_response = self._original_aprocess_model + if self._original_ainvoke is not None: + agent._agent.model.ainvoke = self._original_ainvoke if self._original_arun_function_call is not None: - agent._agent.model.arun_function_calls = self._original_arun_function_call + agent._agent.model.arun_function_call = self._original_arun_function_call diff --git a/src/any_agent/callbacks/wrappers/google.py b/src/any_agent/callbacks/wrappers/google.py index 88582af2..c6a9eabd 100644 --- a/src/any_agent/callbacks/wrappers/google.py +++ b/src/any_agent/callbacks/wrappers/google.py @@ -10,6 +10,93 @@ from any_agent.frameworks.google import GoogleAgent +def _import_google_converters() -> tuple[Any, Any]: + """Import conversion functions from google framework module.""" + from any_agent.frameworks.google import ( + _messages_from_content, + _messages_to_contents, + ) + + return _messages_from_content, _messages_to_contents + + +def _llm_response_to_message(llm_response) -> dict[str, Any]: + """Convert Google ADK LlmResponse to a normalized message dict.""" + from any_agent.frameworks.google import ADK_TO_ANY_LLM_ROLE, _safe_json_serialize + + if not llm_response or not llm_response.content: + return {"role": "assistant", "content": None} + + content = llm_response.content + role = ADK_TO_ANY_LLM_ROLE.get(str(content.role), "assistant") + + message_content: list[Any] = [] + tool_calls: list[Any] = [] + + if content.parts: + for part in content.parts: + if part.text: + message_content.append({"type": "text", "text": part.text}) + elif part.function_call: + tool_calls.append( + { + "type": "function", + "id": part.function_call.id, + "function": { + "name": part.function_call.name, + "arguments": _safe_json_serialize(part.function_call.args), + }, + } + ) + + return { + "role": role, + "content": message_content or None, + "tool_calls": tool_calls or None, + } + + +def _message_to_llm_response(message: dict[str, Any]): + """Convert a normalized message dict back to Google ADK LlmResponse.""" + import json + + from google.adk.models.llm_response import LlmResponse + from google.genai import types + + parts = [] + + content = message.get("content") + if content: + if isinstance(content, str): + parts.append(types.Part.from_text(text=content)) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(types.Part.from_text(text=part.get("text", ""))) + + tool_calls = message.get("tool_calls") + if tool_calls: + for tool_call in tool_calls: + if tool_call.get("type") == "function": + function_data = tool_call.get("function", {}) + args_str = function_data.get("arguments", "{}") + try: + args = ( + json.loads(args_str) if isinstance(args_str, str) else args_str + ) + except json.JSONDecodeError: + args = {} + + part = types.Part.from_function_call( + name=function_data.get("name", ""), + args=args, + ) + part.function_call.id = tool_call.get("id", "") + parts.append(part) + + return 
LlmResponse(content=types.Content(role="model", parts=parts), partial=False) + + class _GoogleADKWrapper: def __init__(self) -> None: self.callback_context: dict[int, Context] = {} @@ -23,6 +110,53 @@ def before_model_callback(*args, **kwargs) -> Any | None: get_current_span().get_span_context().trace_id ] + llm_request = kwargs.get("llm_request") + if llm_request is not None: + _messages_from_content, _messages_to_contents = ( + _import_google_converters() + ) + + messages = _messages_from_content(llm_request) + + for message in messages: + content = message.get("content") + if isinstance(content, list): + text_parts = [ + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" + ] + if text_parts and len(text_parts) == len(content): + message["content"] = ( + " ".join(text_parts) + if len(text_parts) > 1 + else text_parts[0] + ) + + if llm_request.config and llm_request.config.system_instruction: + messages.insert( + 0, + { + "role": "system", + "content": llm_request.config.system_instruction, + }, + ) + + context.framework_state.messages = messages + + def get_messages(): + return context.framework_state.messages + + def set_messages(new_messages): + context.framework_state.messages = new_messages + system_instruction, contents = _messages_to_contents(new_messages) + llm_request.contents = contents + if llm_request.config: + llm_request.config.system_instruction = system_instruction + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) @@ -40,9 +174,49 @@ def after_model_callback(*args, **kwargs) -> Any | None: get_current_span().get_span_context().trace_id ] + llm_response = kwargs.get("llm_response") + if llm_response is not None: + response_message = _llm_response_to_message(llm_response) + existing_messages = context.framework_state.messages.copy() + if isinstance(response_message.get("content"), list): + text_parts = [ + part.get("text", "") + for part in response_message["content"] + if isinstance(part, dict) and part.get("type") == "text" + ] + if text_parts: + response_message["content"] = " ".join(text_parts) + + all_messages = [*existing_messages, response_message] + context.framework_state.messages = all_messages + + original_response_content = response_message.get("content") + + def get_messages(): + return context.framework_state.messages + + def set_messages(new_messages): + context.framework_state.messages = new_messages + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.after_llm_call(context, *args, **kwargs) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + + if llm_response is not None: + final_messages = context.framework_state.messages + if final_messages: + final_response_message = final_messages[-1] + if ( + final_response_message.get("content") + != original_response_content + ): + return _message_to_llm_response(final_response_message) + if callable(self._original["after_model"]): return self._original["after_model"](*args, **kwargs) diff --git a/src/any_agent/callbacks/wrappers/langchain.py b/src/any_agent/callbacks/wrappers/langchain.py index 207490c4..81b97d57 100644 --- a/src/any_agent/callbacks/wrappers/langchain.py +++ b/src/any_agent/callbacks/wrappers/langchain.py @@ -5,6 
+5,8 @@
 
 from opentelemetry.trace import get_current_span
 
+from any_agent.logging import logger
+
 if TYPE_CHECKING:
     from collections.abc import Callable
     from uuid import UUID
@@ -16,11 +18,22 @@
     from any_agent.frameworks.langchain import LangchainAgent
 
 
+def _import_langchain_converters() -> tuple[Any, Any]:
+    """Import conversion functions from langchain vendor module."""
+    from any_agent.vendor.langchain_any_llm import (
+        _convert_dict_to_message,
+        _convert_message_to_dict,
+    )
+
+    return _convert_dict_to_message, _convert_message_to_dict
+
+
 class _LangChainWrapper:
     def __init__(self) -> None:
         self.callback_context: dict[int, Context] = {}
         self._original_ainvoke: Any | None = None
         self._original_llm_call: Callable[..., Any] | None = None
+        self._original_agenerate: Callable[..., Any] | None = None
 
     async def wrap(self, agent: LangchainAgent) -> None:
         from langchain_core.callbacks.base import BaseCallbackHandler
@@ -30,6 +43,8 @@ def before_llm_call(*args, **kwargs):
             context = self.callback_context[
                 get_current_span().get_span_context().trace_id
             ]
+            # Note: Message getters/setters are set up in wrap_agenerate
+            # This callback is for span generation via on_chat_model_start
             for callback in agent.config.callbacks:
                 context = callback.before_llm_call(context, *args, **kwargs)
 
@@ -126,7 +141,64 @@ async def wrap_ainvoke(*args, **kwargs):  # type: ignore[no-untyped-def]
 
         agent._agent.ainvoke = wrap_ainvoke
 
-        # Wrap call_model to capture any-llm calls during structured output processing
+        if agent._model is not None and hasattr(agent._model, "_agenerate"):
+            self._original_agenerate = agent._model._agenerate
+
+            async def wrap_agenerate(messages, *args, **kwargs):
+                messages_list = list(messages)
+
+                try:
+                    context = self.callback_context[
+                        get_current_span().get_span_context().trace_id
+                    ]
+                    _convert_dict_to_message, _convert_message_to_dict = (
+                        _import_langchain_converters()
+                    )
+
+                    normalized_messages = [
+                        _convert_message_to_dict(msg) for msg in messages_list
+                    ]
+                    context.framework_state.messages = normalized_messages
+
+                    def get_messages():
+                        return context.framework_state.messages
+
+                    def set_messages(new_messages):
+                        nonlocal messages_list
+                        context.framework_state.messages = new_messages
+                        messages_list[:] = [
+                            _convert_dict_to_message(msg) for msg in new_messages
+                        ]
+
+                    context.framework_state._message_getter = get_messages
+                    context.framework_state._message_setter = set_messages
+
+                    # Call user callbacks (but not span generation, that will happen in on_chat_model_start)
+                    for callback in agent.config.callbacks:
+                        if not hasattr(callback, "_set_llm_input"):
+                            context = callback.before_llm_call(
+                                context, None, [messages_list], **kwargs
+                            )
+
+                except Exception:
+                    # If we can't get context, just proceed without modification
+                    logger.warning(
+                        "Could not get context, proceeding without modification"
+                    )
+
+                result = await self._original_agenerate(messages_list, *args, **kwargs)
+
+                # Guard the cleanup lookup: if the context was missing above,
+                # an unguarded lookup here would raise KeyError after the call.
+                trace_id = get_current_span().get_span_context().trace_id
+                if trace_id in self.callback_context:
+                    context = self.callback_context[trace_id]
+                    context.framework_state._message_getter = None
+                    context.framework_state._message_setter = None
+
+                return result
+
+            agent._model._agenerate = wrap_agenerate
+        else:
+            logger.warning("Could not wrap _agenerate, proceeding without modification")
 
         self._original_llm_call = agent.call_model
 
         async def wrap_call_model(**kwargs):
@@ -150,5 +222,12 @@ async def unwrap(self, agent: LangchainAgent) -> None:
         if self._original_ainvoke is not None:
             agent._agent.ainvoke = self._original_ainvoke
 
+        if (
+            self._original_agenerate is 
not None + and agent._model is not None + and hasattr(agent._model, "_agenerate") + ): + agent._model._agenerate = self._original_agenerate + if self._original_llm_call is not None: agent.call_model = self._original_llm_call diff --git a/src/any_agent/callbacks/wrappers/llama_index.py b/src/any_agent/callbacks/wrappers/llama_index.py index b0d034f8..a24fc3cf 100644 --- a/src/any_agent/callbacks/wrappers/llama_index.py +++ b/src/any_agent/callbacks/wrappers/llama_index.py @@ -10,6 +10,16 @@ from any_agent.frameworks.llama_index import LlamaIndexAgent +def _import_llama_index_converters() -> tuple[Any, Any]: + """Import conversion functions from llama_index vendor module.""" + from any_agent.vendor.llama_index_utils import ( + from_openai_message_dict, + to_openai_message_dicts, + ) + + return to_openai_message_dicts, from_openai_message_dict + + class _LlamaIndexWrapper: def __init__(self) -> None: self.callback_context: dict[int, Context] = {} @@ -26,6 +36,27 @@ async def wrap_take_step(*args, **kwargs): ] context.shared["model_id"] = getattr(agent._agent.llm, "model", "No model") + if len(args) > 1 and isinstance(args[1], list): + to_openai_message_dicts, from_openai_message_dict = ( + _import_llama_index_converters() + ) + + normalized_messages = to_openai_message_dicts(args[1]) + context.framework_state.messages = normalized_messages + + def get_messages(): + return context.framework_state.messages + + def set_messages(new_messages): + context.framework_state.messages = new_messages + args[1].clear() + args[1].extend( + [from_openai_message_dict(msg) for msg in new_messages] + ) + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) @@ -36,6 +67,9 @@ async def wrap_take_step(*args, **kwargs): for callback in agent.config.callbacks: context = callback.after_llm_call(context, output) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return output # bypass Pydantic validation because _agent is a BaseModel @@ -73,7 +107,6 @@ async def acall(self, *args, **kwargs): wrapped = WrappedAcall(tool.metadata, tool.acall) tool.acall = wrapped.acall - # Wrap call_model to capture any-llm calls during structured output processing self._original_llm_call = agent.call_model async def wrap_call_model(**kwargs): diff --git a/src/any_agent/callbacks/wrappers/openai.py b/src/any_agent/callbacks/wrappers/openai.py index 72413dad..a3ac81e7 100644 --- a/src/any_agent/callbacks/wrappers/openai.py +++ b/src/any_agent/callbacks/wrappers/openai.py @@ -27,6 +27,33 @@ async def wrapped_llm_call(*args, **kwargs): get_current_span().get_span_context().trace_id ] context.shared["model_id"] = getattr(agent._agent.model, "model", None) + # import inside the wrap to avoid cases where the user hasn't installed the agents sdk + from any_agent.frameworks.openai import Converter + + system_instructions = kwargs.get("system_instructions") + input_data = kwargs.get("input") + + if input_data is not None: + converted_messages = Converter.params_to_messages( + system_instructions, input_data + ) + + context.framework_state.messages = converted_messages + + def get_messages(): + return context.framework_state.messages + + def set_messages(messages): + context.framework_state.messages = messages + + new_system_instructions, new_input = Converter.messages_to_params( + messages + ) + kwargs["system_instructions"] = 
new_system_instructions + kwargs["input"] = new_input + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) @@ -39,6 +66,9 @@ async def wrapped_llm_call(*args, **kwargs): output, ) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return output agent._agent.model.get_response = wrapped_llm_call diff --git a/src/any_agent/callbacks/wrappers/smolagents.py b/src/any_agent/callbacks/wrappers/smolagents.py index 4b3cb9ff..cfb1a261 100644 --- a/src/any_agent/callbacks/wrappers/smolagents.py +++ b/src/any_agent/callbacks/wrappers/smolagents.py @@ -20,30 +20,105 @@ def __init__(self) -> None: self._original_tools: Any | None = None async def wrap(self, agent: SmolagentsAgent) -> None: + try: + from smolagents.memory import TaskStep + from smolagents.models import ChatMessage + + smolagents_available = True + except ImportError: + smolagents_available = False + + if not smolagents_available: + msg = "Smolagents is not installed" + raise ImportError(msg) + self._original_llm_call = agent._agent.model.generate - def wrap_generate(*args, **kwargs): + def wrap_generate(messages: list[ChatMessage], **kwargs): context = self.callback_context[ get_current_span().get_span_context().trace_id ] context.shared["model_id"] = str(agent._agent.model.model_id) - for callback in agent.config.callbacks: - context = callback.before_llm_call(context, *args, **kwargs) - - output = self._original_llm_call(*args, **kwargs) - - for callback in agent.config.callbacks: - context = callback.after_llm_call(context, output) + def get_messages(): + normalized_messages = [] + for msg in messages: + msg_dict = msg.dict() + + # Handle content that might be a list + content = msg_dict.get("content") + if isinstance(content, list): + text_parts = [ + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" + ] + if text_parts and len(text_parts) == len(content): + msg_dict["content"] = ( + " ".join(text_parts) + if len(text_parts) > 1 + else text_parts[0] + ) + + normalized_messages.append(msg_dict) + return normalized_messages + + def set_messages(new_messages): + messages.clear() + for msg_dict in new_messages: + content = msg_dict["content"] + if isinstance(content, str): + content = [{"type": "text", "text": content}] + + new_msg_dict = {**msg_dict, "content": content} + messages.append(ChatMessage.from_dict(new_msg_dict)) + + # Update TaskStep in memory so modifications persist through write_memory_to_messages() + # This is necessary because smolagents rebuilds messages from memory on every LLM call + if len(new_messages) >= 2: + for step in agent._agent.memory.steps: + if isinstance(step, TaskStep): + # Extract the task text (remove "New task:\n" prefix if present) + new_task = new_messages[1]["content"] + if isinstance(new_task, str): + new_task = new_task.removeprefix("New task:\n") + step.task = new_task + break + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + + # Only invoke callbacks on the first LLM call, not on retry attempts + # Retries can be detected by checking if there are error messages in the history + is_retry = any( + "Error:" in str(msg.content) and "Now let's retry" in str(msg.content) + for msg in messages + ) + + if not is_retry: + for callback in agent.config.callbacks: + context = 
callback.before_llm_call(context, messages, **kwargs) + + output = self._original_llm_call(messages, **kwargs) + + if not is_retry: + for callback in agent.config.callbacks: + context = callback.after_llm_call(context, output) + + context.framework_state._message_getter = None + context.framework_state._message_setter = None return output agent._agent.model.generate = wrap_generate def wrapped_tool_execution(original_tool, original_call, *args, **kwargs): - context = self.callback_context[ - get_current_span().get_span_context().trace_id - ] + trace_id = get_current_span().get_span_context().trace_id + + if trace_id == 0 or trace_id not in self.callback_context: + return original_call(**kwargs) + + context = self.callback_context[trace_id] context.shared["original_tool"] = original_tool for callback in agent.config.callbacks: diff --git a/src/any_agent/callbacks/wrappers/tinyagent.py b/src/any_agent/callbacks/wrappers/tinyagent.py index 47183706..3c8cfaee 100644 --- a/src/any_agent/callbacks/wrappers/tinyagent.py +++ b/src/any_agent/callbacks/wrappers/tinyagent.py @@ -26,6 +26,20 @@ async def wrap_call_model(**kwargs): context = self.callback_context[ get_current_span().get_span_context().trace_id ] + + if "messages" in kwargs: + context.framework_state.messages = kwargs["messages"] + + def get_messages(): + return context.framework_state.messages + + def set_messages(messages): + context.framework_state.messages = messages + kwargs["messages"] = messages + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.before_llm_call(context, **kwargs) @@ -34,6 +48,9 @@ async def wrap_call_model(**kwargs): for callback in agent.config.callbacks: context = callback.after_llm_call(context, output) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return output agent.call_model = wrap_call_model @@ -42,6 +59,7 @@ async def wrapped_tool_execution(original_call, request): context = self.callback_context[ get_current_span().get_span_context().trace_id ] + for callback in agent.config.callbacks: context = callback.before_tool_execution(context, request) diff --git a/src/any_agent/frameworks/agno.py b/src/any_agent/frameworks/agno.py index 027b0284..3679c73c 100644 --- a/src/any_agent/frameworks/agno.py +++ b/src/any_agent/frameworks/agno.py @@ -139,7 +139,7 @@ async def ainvoke_stream( completion_kwargs = self.get_request_params(tools=tools) completion_kwargs["messages"] = self._format_messages(messages) completion_kwargs["stream"] = True - return acompletion(**completion_kwargs) + return await acompletion(**completion_kwargs) def parse_provider_response(self, response: Any, **kwargs) -> ModelResponse: # type: ignore[no-untyped-def] """Parse the provider response.""" diff --git a/src/any_agent/frameworks/any_agent.py b/src/any_agent/frameworks/any_agent.py index b1a1da41..34971fe2 100644 --- a/src/any_agent/frameworks/any_agent.py +++ b/src/any_agent/frameworks/any_agent.py @@ -8,7 +8,7 @@ from any_llm.utils.aio import run_async_in_sync from opentelemetry import trace as otel_trace -from any_agent.callbacks.context import Context +from any_agent.callbacks.context import Context, FrameworkState from any_agent.callbacks.wrappers import ( _get_wrapper_by_framework, ) @@ -217,6 +217,7 @@ async def run_async(self, prompt: str, **kwargs: Any) -> AgentTrace: trace=AgentTrace(), tracer=self._tracer, shared={}, + 
framework_state=FrameworkState(), ) if len(self._wrapper.callback_context) == 1: diff --git a/src/any_agent/frameworks/google.py b/src/any_agent/frameworks/google.py index 1a1c3315..e4de0730 100644 --- a/src/any_agent/frameworks/google.py +++ b/src/any_agent/frameworks/google.py @@ -36,6 +36,187 @@ "model": "assistant", } +ANY_LLM_TO_ADK_ROLE: dict[str, str] = { + "user": "user", + "assistant": "model", +} + + +def _messages_from_content(llm_request: LlmRequest) -> list[dict[str, Any]]: + """Convert Google ADK LlmRequest to normalized message format. + + Args: + llm_request: The LlmRequest to convert. + + Returns: + List of message dicts with 'role' and 'content' keys. + + """ + messages: list[dict[str, Any]] = [] + for content in llm_request.contents: + message: dict[str, Any] = { + "role": ADK_TO_ANY_LLM_ROLE[str(content.role)], + } + message_content: list[Any] = [] + tool_calls: list[Any] = [] + if parts := content.parts: + for part in parts: + if part.function_response: + messages.append( + { + "role": "tool", + "tool_call_id": part.function_response.id, + "content": _safe_json_serialize( + part.function_response.response + ), + } + ) + elif part.text: + message_content.append({"type": "text", "text": part.text}) + elif ( + part.inline_data + and part.inline_data.data + and part.inline_data.mime_type + ): + # TODO Handle multimodal input + msg = f"Part of type {part.inline_data.mime_type} is not supported." + raise NotImplementedError(msg) + elif part.function_call: + tool_calls.append( + { + "type": "function", + "id": part.function_call.id, + "function": { + "name": part.function_call.name, + "arguments": _safe_json_serialize( + part.function_call.args + ), + }, + } + ) + + message["content"] = message_content or None + message["tool_calls"] = tool_calls or None + # messages from function_response were directly appended before + if message["content"] or message["tool_calls"]: + messages.append(message) + + return messages + + +def _messages_to_contents( + messages: list[dict[str, Any]], +) -> tuple[str | None, list[types.Content]]: + """Convert normalized messages back to Google ADK format. + + This function performs a round-trip conversion from normalized message dicts + back to Google ADK Content objects and system instruction. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Returns: + Tuple of (system_instruction, contents_list). 
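+
+    Example:
+        A minimal round-trip sketch (illustrative):
+        ```python
+        system, contents = _messages_to_contents(
+            [
+                {"role": "system", "content": "Be brief."},
+                {"role": "user", "content": "Hi"},
+            ]
+        )
+        # system == "Be brief."
+        # contents is a single user types.Content with one text part
+        ```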
+ + """ + system_instruction: str | None = None + system_parts: list[str] = [] + contents: list[types.Content] = [] + + # Track pending tool responses to group them with the next user message + pending_tool_responses: list[types.Part] = [] + + for message in messages: + role = message.get("role", "user") + + # Extract system messages + if role == "system": + content = message.get("content", "") + if isinstance(content, str): + system_parts.append(content) + elif isinstance(content, list): + # Handle content as list of parts + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + system_parts.append(part.get("text", "")) + continue + + # Handle tool responses - accumulate them + if role == "tool": + tool_call_id = message.get("tool_call_id", "") + content = message.get("content", "") + # Parse the content as JSON if it's a string + try: + response_data = ( + json.loads(content) if isinstance(content, str) else content + ) + except json.JSONDecodeError: + response_data = content + + part = types.Part( + function_response=types.FunctionResponse( + name="", # Name is not stored in tool messages + id=tool_call_id, + response=response_data, + ) + ) + pending_tool_responses.append(part) + continue + + # Process user/assistant messages + parts: list[types.Part] = [] + + # If there are pending tool responses and this is a user message, add them first + if pending_tool_responses and role == "user": + parts.extend(pending_tool_responses) + pending_tool_responses = [] + + # Handle content + content = message.get("content") + if content: + if isinstance(content, str): + parts.append(types.Part.from_text(text=content)) + elif isinstance(content, list): + # Handle content as list of parts + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(types.Part.from_text(text=part.get("text", ""))) + + # Handle tool calls + tool_calls = message.get("tool_calls") + if tool_calls: + for tool_call in tool_calls: + if tool_call.get("type") == "function": + function_data = tool_call.get("function", {}) + args_str = function_data.get("arguments", "{}") + # Parse arguments if they're a string + try: + args = ( + json.loads(args_str) + if isinstance(args_str, str) + else args_str + ) + except json.JSONDecodeError: + args = {} + + part = types.Part.from_function_call( + name=function_data.get("name", ""), + args=args, + ) + part.function_call.id = tool_call.get("id", "") + parts.append(part) + + # Only add content if there are parts + if parts: + adk_role = ANY_LLM_TO_ADK_ROLE.get(role, "user") + contents.append(types.Content(role=adk_role, parts=parts)) + + # Combine system messages into a single instruction + if system_parts: + system_instruction = "\n".join(system_parts) + + return system_instruction, contents + def _safe_json_serialize(obj: Any) -> str: """Convert any Python object to a JSON-serializable type or string. 
@@ -63,59 +244,6 @@ def __init__(self, model: str, **kwargs: Any) -> None: super().__init__(model=model) self._kwargs = kwargs or {} - @staticmethod - def _messages_from_content(llm_request: LlmRequest) -> list[dict[str, Any]]: - messages: list[dict[str, Any]] = [] - for content in llm_request.contents: - message: dict[str, Any] = { - "role": ADK_TO_ANY_LLM_ROLE[str(content.role)], - } - message_content: list[Any] = [] - tool_calls: list[Any] = [] - if parts := content.parts: - for part in parts: - if part.function_response: - messages.append( - { - "role": "tool", - "tool_call_id": part.function_response.id, - "content": _safe_json_serialize( - part.function_response.response - ), - } - ) - elif part.text: - message_content.append({"type": "text", "text": part.text}) - elif ( - part.inline_data - and part.inline_data.data - and part.inline_data.mime_type - ): - # TODO Handle multimodal input - msg = f"Part of type {part.inline_data.mime_type} is not supported." - raise NotImplementedError(msg) - elif part.function_call: - tool_calls.append( - { - "type": "function", - "id": part.function_call.id, - "function": { - "name": part.function_call.name, - "arguments": _safe_json_serialize( - part.function_call.args - ), - }, - } - ) - - message["content"] = message_content or None - message["tool_calls"] = tool_calls or None - # messages from function_response were directly appended before - if message["content"] or message["tool_calls"]: - messages.append(message) - - return messages - def _schema_to_dict(self, schema: types.Schema) -> dict[str, Any]: """Recursively converts a types.Schema to a pure-python dict. @@ -204,7 +332,7 @@ def _function_declaration_to_tool_param( def _llm_request_to_completion_args( self, llm_request: LlmRequest ) -> dict[str, Any]: - messages = self._messages_from_content(llm_request) + messages = _messages_from_content(llm_request) if llm_request.config.system_instruction: messages.insert( 0, {"role": "system", "content": llm_request.config.system_instruction} diff --git a/src/any_agent/frameworks/langchain.py b/src/any_agent/frameworks/langchain.py index 34d52494..a910d981 100644 --- a/src/any_agent/frameworks/langchain.py +++ b/src/any_agent/frameworks/langchain.py @@ -250,6 +250,7 @@ class LangchainAgent(AnyAgent): def __init__(self, config: AgentConfig): super().__init__(config) self._agent: CompiledStateGraph[Any] | None = None + self._model: LanguageModelLike | None = None @property def framework(self) -> AgentFramework: @@ -280,9 +281,10 @@ async def _load_agent(self) -> None: self._tools = imported_tools agent_type = self.config.agent_type or DEFAULT_AGENT_TYPE agent_args = self.config.agent_args or {} + self._model = self._get_model(self.config) self._agent = agent_type( name=self.config.name, - model=self._get_model(self.config), + model=self._model, tools=imported_tools, prompt=self.config.instructions, **agent_args, diff --git a/src/any_agent/frameworks/openai.py b/src/any_agent/frameworks/openai.py index 3aefcc18..7a752f32 100644 --- a/src/any_agent/frameworks/openai.py +++ b/src/any_agent/frameworks/openai.py @@ -22,6 +22,7 @@ from any_llm import AnyLLM from openai import NOT_GIVEN, NotGiven, Omit from openai.types.responses import Response + from openai.types.responses.response_input_item_param import Message from openai.types.responses.response_usage import ( InputTokensDetails, OutputTokensDetails, @@ -67,6 +68,62 @@ def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam: ) raise UserError(msg) + @classmethod + def params_to_messages( + cls, 
+ system_instructions: str | None, + input_data: str | list[TResponseInputItem], + ) -> list[dict[str, Any]]: + """Convert system_instructions and input parameters to unified message format. + + Args: + system_instructions: Optional system instructions to prepend + input_data: Either a string or list of response input items + + Returns: + List of message dicts with 'role' and 'content' keys + + """ + converted_messages: list[dict[str, Any]] + if isinstance(input_data, str): + converted_messages = [{"role": "user", "content": input_data}] + else: + converted_messages = [ + dict(msg) for msg in cls.items_to_messages(input_data) + ] + + if system_instructions: + converted_messages.insert( + 0, {"content": system_instructions, "role": "system"} + ) + + return converted_messages + + @classmethod + def messages_to_params( + cls, messages: list[dict[str, Any]] + ) -> tuple[str | None, str | list[TResponseInputItem]]: + """Convert unified message format back to system_instructions and input parameters. + + Args: + messages: List of message dicts with 'role' and 'content' keys + + Returns: + Tuple of (system_instructions, input) where input is either a string or list of items + + """ + if messages and messages[0].get("role") == "system": + system_instructions = messages[0]["content"] + remaining_messages = messages[1:] + else: + system_instructions = None + remaining_messages = messages + + return system_instructions, [ + Message(role=msg["role"], content=msg["content"]) + for msg in remaining_messages + ] + class AnyllmModel(Model): """Enables using any model via AnyLLM. @@ -265,16 +322,7 @@ async def _fetch_response( any_llm.types.completion.ChatCompletion | tuple[Response, AsyncIterator[any_llm.types.completion.ChatCompletionChunk]] ): - converted_messages = Converter.items_to_messages(input) - - if system_instructions: - converted_messages.insert( - 0, - { - "content": system_instructions, - "role": "system", - }, - ) + converted_messages = Converter.params_to_messages(system_instructions, input) converted_messages = _to_dump_compatible(converted_messages) if tracing.include_data(): diff --git a/src/any_agent/frameworks/smolagents.py b/src/any_agent/frameworks/smolagents.py index 19208a9b..380bf28c 100644 --- a/src/any_agent/frameworks/smolagents.py +++ b/src/any_agent/frameworks/smolagents.py @@ -66,6 +66,7 @@ def __init__( "model": model_id, "api_key": api_key, "api_base": api_base, + "allow_running_loop": True, # Because smolagents uses sync api **kwargs, } diff --git a/tests/integration/callbacks/__init__.py b/tests/integration/callbacks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/callbacks/test_framework_state.py b/tests/integration/callbacks/test_framework_state.py new file mode 100644 index 00000000..4e6a60f4 --- /dev/null +++ b/tests/integration/callbacks/test_framework_state.py @@ -0,0 +1,107 @@ +import pytest + +from any_agent import AgentConfig, AnyAgent +from any_agent.callbacks import Callback, Context +from any_agent.config import AgentFramework +from any_agent.testing.helpers import DEFAULT_SMALL_MODEL_ID +from typing import Any + + +class LLMInputModifier(Callback): + """Callback that modifies LLM input messages.""" + + def __init__(self) -> None: + self.original_messages: list[dict[str, Any]] = [] + self.modified_messages: list[dict[str, Any]] = [] + + def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: + # Capture original messages for verification + messages = context.framework_state.get_messages() + 
self.original_messages = [msg.copy() for msg in messages] + + # Verify message structure before modification + assert len(messages) > 0, "Expected at least one message" + assert "role" in messages[-1], "Expected 'role' key in message" + assert "content" in messages[-1], "Expected 'content' key in message" + + # Modify the last message + messages[-1]["content"] = "Say hello" + context.framework_state.set_messages(messages) + + # Capture modified messages for verification + self.modified_messages = context.framework_state.get_messages() + + return context + + +class SecondModifier(Callback): + """Second callback to test sequential modifications.""" + + def __init__(self) -> None: + self.saw_first_modification = False + + def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: + messages = context.framework_state.get_messages() + + # Verify the first callback's modification is present + if len(messages) > 0: + self.saw_first_modification = "Say hello" in messages[-1].get("content", "") + + # Add additional modification + if len(messages) > 0: + messages[-1]["content"] = "Say hello and goodbye" + context.framework_state.set_messages(messages) + + return context + + +async def test_modify_llm_input(agent_framework: AgentFramework) -> None: + """Test that framework_state message modification works via helper methods.""" + modifier = LLMInputModifier() + second_modifier = SecondModifier() + + config = AgentConfig( + model_id="openai:gpt-4.1-mini", + instructions="You are a helpful assistant.", + callbacks=[modifier, second_modifier], + ) + + agent = await AnyAgent.create_async(agent_framework, config) + + try: + # First run: Test modification and sequential callback behavior + result = await agent.run_async("Say goodbye") + assert result.final_output is not None + assert isinstance(result.final_output, str) + + # Verify the modification took effect (should say hello AND goodbye) + assert "hello" in result.final_output.lower(), ( + "Expected 'hello' in the final output from first modification" + ) + assert "goodbye" in result.final_output.lower(), ( + "Expected 'goodbye' in the final output from second modification" + ) + + # Verify we captured the original message before modification + assert len(modifier.original_messages) > 0, ( + "Should have captured original messages" + ) + assert "Say goodbye" in modifier.original_messages[-1]["content"], ( + "Original message should contain 'Say goodbye'" + ) + + # Verify message structure was preserved + assert "role" in modifier.modified_messages[-1], ( + "Modified message should have 'role' key" + ) + assert "content" in modifier.modified_messages[-1], ( + "Modified message should have 'content' key" + ) + + # Verify sequential callback execution + assert second_modifier.saw_first_modification, ( + "Second callback should see first callback's modification" + ) + + finally: + await agent.cleanup_async()
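+
+
+# Companion sketch: exercises the documented availability guarantee using the
+# defensive try/except pattern recommended in docs/agents/callbacks.md. Reuses
+# the DEFAULT_SMALL_MODEL_ID helper imported above.
+class DefensiveReader(Callback):
+    """Callback that reads messages defensively inside before_llm_call."""
+
+    def __init__(self) -> None:
+        self.read_succeeded = False
+
+    def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context:
+        try:
+            self.read_succeeded = bool(context.framework_state.get_messages())
+        except NotImplementedError:
+            pass
+        return context
+
+
+async def test_defensive_message_access(agent_framework: AgentFramework) -> None:
+    """get_messages() should return the chat history inside before_llm_call."""
+    reader = DefensiveReader()
+
+    config = AgentConfig(
+        model_id=DEFAULT_SMALL_MODEL_ID,
+        instructions="You are a helpful assistant.",
+        callbacks=[reader],
+    )
+
+    agent = await AnyAgent.create_async(agent_framework, config)
+
+    try:
+        result = await agent.run_async("Say hi")
+        assert result.final_output is not None
+        assert reader.read_succeeded, (
+            "before_llm_call should have access to messages via get_messages()"
+        )
+    finally:
+        await agent.cleanup_async()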