From b029d93664de64a894a61bde359d33374fafac4f Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Tue, 21 Oct 2025 12:24:01 -0400 Subject: [PATCH 01/15] initial design --- docs/agents/callbacks.md | 73 +++++++++++++++ src/any_agent/callbacks/__init__.py | 4 +- src/any_agent/callbacks/context.py | 92 ++++++++++++++++++- src/any_agent/callbacks/wrappers/tinyagent.py | 15 +++ src/any_agent/frameworks/any_agent.py | 3 +- tests/integration/callbacks/__init__.py | 0 .../callbacks/test_framework_state.py | 39 ++++++++ 7 files changed, 222 insertions(+), 4 deletions(-) create mode 100644 tests/integration/callbacks/__init__.py create mode 100644 tests/integration/callbacks/test_framework_state.py diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index cf39beafd..1dda71e4e 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -54,6 +54,26 @@ property so that callbacks can access information in a framework-agnostic way. You can see what attributes are available for LLM Calls and Tool Executions by examining the [`GenAI`][any_agent.tracing.attributes.GenAI] class. +### Framework State + +In addition to the span attributes, callbacks can access and modify framework-specific objects through [`Context.framework_state`][any_agent.callbacks.context.Context.framework_state]. + +This allows callbacks to directly manipulate the agent's execution, such as: + +- Modifying messages before they're sent to the LLM +- Injecting prompts mid-execution +- Changing user queries dynamically + +#### Helper Methods + +The `framework_state` provides helper methods to work with messages in a normalized format: + +**`get_messages()`**: Get messages as a list of dicts with `role` and `content` keys + +**`set_messages()`**: Set messages from a list of dicts with `role` and `content` keys + +These methods handle framework-specific message formats internally, providing a consistent API across frameworks. + ## Implementing Callbacks All callbacks must inherit from the base [`Callback`][any_agent.callbacks.base.Callback] class and can choose to implement any subset of the available callback methods. These methods include: @@ -272,3 +292,56 @@ class LimitToolExecutions(Callback): return context ``` + +## Example: Modifying prompts dynamically + +You can use callbacks to modify the prompt being sent to the LLM. This is useful for injecting instructions or reminders mid-execution: + +```python +from any_agent.callbacks.base import Callback +from any_agent.callbacks.context import Context + +class InjectReminderCallback(Callback): + def __init__(self, reminder: str, every_n_calls: int = 5): + self.reminder = reminder + self.every_n_calls = every_n_calls + self.call_count = 0 + + def before_llm_call(self, context: Context, *args, **kwargs) -> Context: + self.call_count += 1 + + if self.call_count % self.every_n_calls == 0: + try: + messages = context.framework_state.get_messages() + if messages: + messages[-1]["content"] += f"\n\n{self.reminder}" + context.framework_state.set_messages(messages) + except NotImplementedError: + pass + + return context +``` + +Example usage: + +```python +from any_agent import AgentConfig, AnyAgent + +callback = InjectReminderCallback( + reminder="Remember to use the Todo tool to track your tasks!", + every_n_calls=5 +) + +config = AgentConfig( + model_id="gpt-4o-mini", + instructions="You are a helpful assistant.", + callbacks=[callback], +) + +agent = await AnyAgent.create_async("tinyagent", config) +``` + +!!! 
tip + + Use try/except to gracefully handle frameworks that don't support message modification yet. The callback will simply skip modification for unsupported frameworks. +``` diff --git a/src/any_agent/callbacks/__init__.py b/src/any_agent/callbacks/__init__.py index 7e26da881..836625a46 100644 --- a/src/any_agent/callbacks/__init__.py +++ b/src/any_agent/callbacks/__init__.py @@ -1,8 +1,8 @@ from .base import Callback -from .context import Context +from .context import Context, FrameworkState from .span_print import ConsolePrintSpan -__all__ = ["Callback", "ConsolePrintSpan", "Context"] +__all__ = ["Callback", "ConsolePrintSpan", "Context", "FrameworkState"] def get_default_callbacks() -> list[Callback]: diff --git a/src/any_agent/callbacks/context.py b/src/any_agent/callbacks/context.py index e3479991d..7574bdb22 100644 --- a/src/any_agent/callbacks/context.py +++ b/src/any_agent/callbacks/context.py @@ -1,14 +1,87 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from collections.abc import Callable + from opentelemetry.trace import Span, Tracer from any_agent.tracing.agent_trace import AgentTrace +@dataclass +class FrameworkState: + """Framework-specific state that can be accessed and modified by callbacks. + + This object provides a consistent interface for accessing framework state across + different agent frameworks, while the actual content is framework-specific. + """ + + messages: list[dict[str, Any]] = field(default_factory=list) + """Internal storage for messages. Use get_messages() and set_messages() instead.""" + + _message_getter: Callable[[], list[dict[str, Any]]] | None = field( + default=None, repr=False + ) + """Framework-specific message getter function.""" + + _message_setter: Callable[[list[dict[str, Any]]], None] | None = field( + default=None, repr=False + ) + """Framework-specific message setter function.""" + + def get_messages(self) -> list[dict[str, Any]]: + """Get messages in a normalized dict format. + + Returns a list of message dicts with 'role' and 'content' keys. + Works consistently across all frameworks. + + Returns: + List of message dicts with 'role' and 'content' keys. + + Raises: + NotImplementedError: If the framework doesn't support message access yet. + + Example: + ```python + messages = context.framework_state.get_messages() + # [{"role": "user", "content": "Hello"}] + ``` + + """ + if self._message_getter is None: + msg = "get_messages() is not implemented for this framework yet" + raise NotImplementedError(msg) + return self._message_getter() + + def set_messages(self, messages: list[dict[str, Any]]) -> None: + """Set messages from a normalized dict format. + + Accepts a list of message dicts with 'role' and 'content' keys and + converts them to the framework-specific format. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Raises: + NotImplementedError: If the framework doesn't support message modification yet. + + Example: + ```python + messages = context.framework_state.get_messages() + messages[-1]["content"] = "Say hello" + context.framework_state.set_messages(messages) + ``` + + """ + if self._message_setter is None: + msg = "set_messages() is not implemented for this framework yet" + raise NotImplementedError(msg) + self._message_setter(messages) + + @dataclass class Context: """Object that will be shared across callbacks. 
@@ -31,3 +104,22 @@ class Context:
 
     shared: dict[str, Any]
     """Can be used to store arbitrary information for sharing across callbacks."""
+
+    framework_state: FrameworkState
+    """Framework-specific state that can be accessed and modified by callbacks.
+
+    Provides consistent access to framework state across different agent frameworks.
+    See [`FrameworkState`][any_agent.callbacks.context.FrameworkState] for available attributes.
+
+    Example:
+        ```python
+        class ModifyPromptCallback(Callback):
+            def before_llm_call(self, context: Context, *args, **kwargs) -> Context:
+                # Modify the last message content via the helper methods
+                messages = context.framework_state.get_messages()
+                if messages:
+                    messages[-1]["content"] = "Say hello"
+                    context.framework_state.set_messages(messages)
+                return context
+        ```
+    """
diff --git a/src/any_agent/callbacks/wrappers/tinyagent.py b/src/any_agent/callbacks/wrappers/tinyagent.py
index 471837065..a473cf5ee 100644
--- a/src/any_agent/callbacks/wrappers/tinyagent.py
+++ b/src/any_agent/callbacks/wrappers/tinyagent.py
@@ -26,6 +26,20 @@ async def wrap_call_model(**kwargs):
             context = self.callback_context[
                 get_current_span().get_span_context().trace_id
             ]
+
+            if "messages" in kwargs:
+                context.framework_state.messages = kwargs["messages"]
+
+            def get_messages():
+                return context.framework_state.messages
+
+            def set_messages(messages):
+                context.framework_state.messages = messages
+                kwargs["messages"] = messages
+
+            context.framework_state._message_getter = get_messages
+            context.framework_state._message_setter = set_messages
+
             for callback in agent.config.callbacks:
                 context = callback.before_llm_call(context, **kwargs)
 
@@ -42,6 +56,7 @@ async def wrapped_tool_execution(original_call, request):
             context = self.callback_context[
                 get_current_span().get_span_context().trace_id
             ]
+
             for callback in agent.config.callbacks:
                 context = callback.before_tool_execution(context, request)
 
diff --git a/src/any_agent/frameworks/any_agent.py b/src/any_agent/frameworks/any_agent.py
index b1a1da413..34971fe23 100644
--- a/src/any_agent/frameworks/any_agent.py
+++ b/src/any_agent/frameworks/any_agent.py
@@ -8,7 +8,7 @@
 from any_llm.utils.aio import run_async_in_sync
 from opentelemetry import trace as otel_trace
 
-from any_agent.callbacks.context import Context
+from any_agent.callbacks.context import Context, FrameworkState
 from any_agent.callbacks.wrappers import (
     _get_wrapper_by_framework,
 )
@@ -217,6 +217,7 @@ async def run_async(self, prompt: str, **kwargs: Any) -> AgentTrace:
             trace=AgentTrace(),
             tracer=self._tracer,
             shared={},
+            framework_state=FrameworkState(),
         )
 
         if len(self._wrapper.callback_context) == 1:
diff --git a/tests/integration/callbacks/__init__.py b/tests/integration/callbacks/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/integration/callbacks/test_framework_state.py b/tests/integration/callbacks/test_framework_state.py
new file mode 100644
index 000000000..971c756d2
--- /dev/null
+++ b/tests/integration/callbacks/test_framework_state.py
@@ -0,0 +1,39 @@
+from any_agent import AgentConfig, AnyAgent
+from any_agent.callbacks import Callback, Context
+from any_agent.config import AgentFramework
+from any_agent.testing.helpers import DEFAULT_SMALL_MODEL_ID
+from typing import Any
+
+
+class LLMInputModifier(Callback):
+    """Callback that modifies LLM input messages."""
+
+    def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context:
+        messages = context.framework_state.get_messages()
+        messages[-1]["content"] = "Say hello"
+        context.framework_state.set_messages(messages)
+        return context
+
+
+async def test_modify_llm_input(agent_framework: AgentFramework) -> None: + """Test that framework_state message modification works via helper methods.""" + modifier = LLMInputModifier() + config = AgentConfig( + model_id=DEFAULT_SMALL_MODEL_ID, + instructions="You are a helpful assistant.", + callbacks=[modifier], + ) + + agent = await AnyAgent.create_async(agent_framework, config) + + try: + result = await agent.run_async("Say goodbye") + assert result.final_output is not None + assert isinstance(result.final_output, str) + + assert "hello" in result.final_output.lower(), ( + "Expected 'hello' in the final output" + ) + + finally: + await agent.cleanup_async() From 85fbc09397a8680bd88a71f2e918bde90e89c99b Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Tue, 21 Oct 2025 12:26:19 -0400 Subject: [PATCH 02/15] fix provider naming --- docs/agents/callbacks.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index 1dda71e4e..c3e86f299 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -156,7 +156,7 @@ Callbacks are provided to the agent using the [`AgentConfig.callbacks`][any_agen agent = AnyAgent.create( "tinyagent", AgentConfig( - model_id="gpt-4.1-nano", + model_id="openai:gpt-4.1-nano", instructions="Use the tools to find an answer", tools=[search_web, visit_webpage], callbacks=[ @@ -177,7 +177,7 @@ Callbacks are provided to the agent using the [`AgentConfig.callbacks`][any_agen agent = AnyAgent.create( "tinyagent", AgentConfig( - model_id="gpt-4.1-nano", + model_id="openai:gpt-4.1-nano", instructions="Use the tools to find an answer", tools=[search_web, visit_webpage], callbacks=[ @@ -333,7 +333,7 @@ callback = InjectReminderCallback( ) config = AgentConfig( - model_id="gpt-4o-mini", + model_id="openai:gpt-4o-mini", instructions="You are a helpful assistant.", callbacks=[callback], ) From 35b6657c92bd6b3d7067746384ad6e00006f0cb9 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Tue, 21 Oct 2025 12:27:30 -0400 Subject: [PATCH 03/15] docs --- docs/agents/callbacks.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index c3e86f299..c1eb63b95 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -337,8 +337,7 @@ config = AgentConfig( instructions="You are a helpful assistant.", callbacks=[callback], ) - -agent = await AnyAgent.create_async("tinyagent", config) +# ... Continue to create and run agent ``` !!! 
tip From 0a9784a8fecd107ba4a414eaddd8772fbd940eb7 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Tue, 21 Oct 2025 12:33:54 -0400 Subject: [PATCH 04/15] add to api --- docs/api/callbacks.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api/callbacks.md b/docs/api/callbacks.md index ebca1863b..ed827a0e6 100644 --- a/docs/api/callbacks.md +++ b/docs/api/callbacks.md @@ -4,6 +4,8 @@ ::: any_agent.callbacks.context.Context +::: any_agent.callbacks.context.FrameworkState + ::: any_agent.callbacks.span_print.ConsolePrintSpan ::: any_agent.callbacks.get_default_callbacks From d955fa8d995c2a9a2058e5ffaf1b7e4b54cd1e04 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 27 Oct 2025 09:22:10 -0400 Subject: [PATCH 05/15] openai impl --- docs/agents/callbacks.md | 53 +++++++ src/any_agent/callbacks/wrappers/openai.py | 27 ++++ src/any_agent/frameworks/openai.py | 68 +++++++-- .../callbacks/test_framework_state.py | 135 +++++++++++++++++- 4 files changed, 270 insertions(+), 13 deletions(-) diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index c1eb63b95..1d45aa247 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -61,6 +61,7 @@ In addition to the span attributes, callbacks can access and modify framework-sp This allows callbacks to directly manipulate the agent's execution, such as: - Modifying messages before they're sent to the LLM +- Modifying the LLM's response after generation - Injecting prompts mid-execution - Changing user queries dynamically @@ -74,6 +75,14 @@ The `framework_state` provides helper methods to work with messages in a normali These methods handle framework-specific message formats internally, providing a consistent API across frameworks. +!!! note "Availability" + + The `get_messages()` and `set_messages()` methods are **only available in `before_llm_call` and `after_llm_call` callbacks**. + + - In `before_llm_call`: You can read and modify the messages that will be sent to the LLM + - In `after_llm_call`: You can read and modify the messages including the LLM's response + - In other callbacks (`before_tool_execution`, `after_tool_execution`, etc.): These methods will raise `NotImplementedError` + ## Implementing Callbacks All callbacks must inherit from the base [`Callback`][any_agent.callbacks.base.Callback] class and can choose to implement any subset of the available callback methods. These methods include: @@ -343,4 +352,48 @@ config = AgentConfig( !!! tip Use try/except to gracefully handle frameworks that don't support message modification yet. The callback will simply skip modification for unsupported frameworks. + +## Example: Modifying LLM responses + +You can also modify the LLM's response after it's generated using the `after_llm_call` callback. 
This is useful for post-processing responses, adding disclaimers, or formatting output: + +```python +from any_agent.callbacks.base import Callback +from any_agent.callbacks.context import Context + +class AddDisclaimerCallback(Callback): + def __init__(self, disclaimer: str): + self.disclaimer = disclaimer + + def after_llm_call(self, context: Context, *args, **kwargs) -> Context: + try: + messages = context.framework_state.get_messages() + if messages and len(messages) > 0: + # The last message is the LLM's response + last_message = messages[-1] + if last_message.get("role") == "assistant": + # Append disclaimer to the LLM's response + last_message["content"] += f"\n\n{self.disclaimer}" + context.framework_state.set_messages(messages) + except NotImplementedError: + pass + + return context +``` + +Example usage: + +```python +from any_agent import AgentConfig, AnyAgent + +callback = AddDisclaimerCallback( + disclaimer="*This response was generated by an AI assistant.*" +) + +config = AgentConfig( + model_id="openai:gpt-4o-mini", + instructions="You are a helpful assistant.", + callbacks=[callback], +) +# ... Continue to create and run agent ``` diff --git a/src/any_agent/callbacks/wrappers/openai.py b/src/any_agent/callbacks/wrappers/openai.py index 72413dad0..b6bc4c733 100644 --- a/src/any_agent/callbacks/wrappers/openai.py +++ b/src/any_agent/callbacks/wrappers/openai.py @@ -27,6 +27,33 @@ async def wrapped_llm_call(*args, **kwargs): get_current_span().get_span_context().trace_id ] context.shared["model_id"] = getattr(agent._agent.model, "model", None) + # import inside the wrap to avoid cases where the user hasn't installed the agents sdk + from any_agent.frameworks.openai import Converter + + system_instructions = kwargs.get("system_instructions") + input_data = kwargs.get("input") + + if input_data is not None: + converted_messages = Converter.params_to_messages( + system_instructions, input_data + ) + + context.framework_state.messages = converted_messages + + def get_messages(): + return context.framework_state.messages + + def set_messages(messages): + context.framework_state.messages = messages + + new_system_instructions, new_input = Converter.messages_to_params( + messages + ) + kwargs["system_instructions"] = new_system_instructions + kwargs["input"] = new_input + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) diff --git a/src/any_agent/frameworks/openai.py b/src/any_agent/frameworks/openai.py index 3aefcc183..7a752f326 100644 --- a/src/any_agent/frameworks/openai.py +++ b/src/any_agent/frameworks/openai.py @@ -22,6 +22,7 @@ from any_llm import AnyLLM from openai import NOT_GIVEN, NotGiven, Omit from openai.types.responses import Response + from openai.types.responses.response_input_item_param import Message from openai.types.responses.response_usage import ( InputTokensDetails, OutputTokensDetails, @@ -67,6 +68,62 @@ def tool_to_openai(cls, tool: Tool) -> ChatCompletionToolParam: ) raise UserError(msg) + @classmethod + def params_to_messages( + cls, + system_instructions: str | None, + input_data: str | list[TResponseInputItem], + ) -> list[dict[str, Any]]: + """Convert system_instructions and input parameters to unified message format. 
+ + Args: + system_instructions: Optional system instructions to prepend + input_data: Either a string or list of response input items + + Returns: + List of message dicts with 'role' and 'content' keys + + """ + converted_messages: list[dict[str, Any]] + if isinstance(input_data, str): + converted_messages = [{"role": "user", "content": input_data}] + else: + converted_messages = [ + dict(msg) for msg in cls.items_to_messages(input_data) + ] + + if system_instructions: + converted_messages.insert( + 0, {"content": system_instructions, "role": "system"} + ) + + return converted_messages + + @classmethod + def messages_to_params( + cls, messages: list[dict[str, Any]] + ) -> tuple[str | None, str | list[TResponseInputItem]]: + """Convert unified message format back to system_instructions and input parameters. + + Args: + messages: List of message dicts with 'role' and 'content' keys + + Returns: + Tuple of (system_instructions, input) where input is either a string or list of items + + """ + if messages and messages[0].get("role") == "system": + system_instructions = messages[0]["content"] + remaining_messages = messages[1:] + else: + system_instructions = None + remaining_messages = messages + + return system_instructions, [ + Message(role=msg["role"], content=msg["content"]) + for msg in remaining_messages + ] + class AnyllmModel(Model): """Enables using any model via AnyLLM. @@ -265,16 +322,7 @@ async def _fetch_response( any_llm.types.completion.ChatCompletion | tuple[Response, AsyncIterator[any_llm.types.completion.ChatCompletionChunk]] ): - converted_messages = Converter.items_to_messages(input) - - if system_instructions: - converted_messages.insert( - 0, - { - "content": system_instructions, - "role": "system", - }, - ) + converted_messages = Converter.params_to_messages(system_instructions, input) converted_messages = _to_dump_compatible(converted_messages) if tracing.include_data(): diff --git a/tests/integration/callbacks/test_framework_state.py b/tests/integration/callbacks/test_framework_state.py index 971c756d2..df8374552 100644 --- a/tests/integration/callbacks/test_framework_state.py +++ b/tests/integration/callbacks/test_framework_state.py @@ -8,31 +8,160 @@ class LLMInputModifier(Callback): """Callback that modifies LLM input messages.""" + def __init__(self): + self.original_messages: list[dict[str, Any]] = [] + self.modified_messages: list[dict[str, Any]] = [] + def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: + # Capture original messages for verification messages = context.framework_state.get_messages() + self.original_messages = [msg.copy() for msg in messages] + + # Verify message structure before modification + assert len(messages) > 0, "Expected at least one message" + assert "role" in messages[-1], "Expected 'role' key in message" + assert "content" in messages[-1], "Expected 'content' key in message" + + # Modify the last message messages[-1]["content"] = "Say hello" context.framework_state.set_messages(messages) + + # Capture modified messages for verification + self.modified_messages = context.framework_state.get_messages() + + return context + + +class SecondModifier(Callback): + """Second callback to test sequential modifications.""" + + def __init__(self): + self.saw_first_modification = False + + def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: + messages = context.framework_state.get_messages() + + # Verify the first callback's modification is present + if len(messages) > 0: + 
self.saw_first_modification = "Say hello" in messages[-1].get("content", "") + + # Add additional modification + if len(messages) > 0: + messages[-1]["content"] = "Say hello and goodbye" + context.framework_state.set_messages(messages) + + return context + + +class AfterLLMModifier(Callback): + """Callback that modifies messages after LLM response.""" + + def __init__(self): + self.llm_response_messages: list[dict[str, Any]] = [] + self.saw_llm_response = False + + def after_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: + # Get messages including the LLM's response + messages = context.framework_state.get_messages() + self.llm_response_messages = [msg.copy() for msg in messages] + + # Verify we can see the LLM's response + if len(messages) > 0: + last_msg = messages[-1] + # The LLM should have responded with hello and goodbye + self.saw_llm_response = ( + "hello" in last_msg.get("content", "").lower() and + "goodbye" in last_msg.get("content", "").lower() + ) + + # Modify the LLM's response by appending text + last_msg["content"] = last_msg["content"] + " Also, have a great day!" + context.framework_state.set_messages(messages) + return context async def test_modify_llm_input(agent_framework: AgentFramework) -> None: - """Test that framework_state message modification works via helper methods.""" + """Test that framework_state message modification works via helper methods. + + This test verifies: + 1. Original messages can be read before modification (before_llm_call) + 2. Message structure (role, content) is preserved after modifications + 3. Multiple callbacks can modify messages sequentially (before_llm_call) + 4. LLM response can be read and modified in after_llm_call + 5. Modifications don't leak between separate runs (isolation) + """ modifier = LLMInputModifier() + second_modifier = SecondModifier() + after_llm_modifier = AfterLLMModifier() + config = AgentConfig( model_id=DEFAULT_SMALL_MODEL_ID, instructions="You are a helpful assistant.", - callbacks=[modifier], + callbacks=[modifier, second_modifier, after_llm_modifier], ) agent = await AnyAgent.create_async(agent_framework, config) try: + # First run: Test modification and sequential callback behavior result = await agent.run_async("Say goodbye") assert result.final_output is not None assert isinstance(result.final_output, str) + # Verify the modification took effect (should say hello AND goodbye) assert "hello" in result.final_output.lower(), ( - "Expected 'hello' in the final output" + "Expected 'hello' in the final output from first modification" + ) + assert "goodbye" in result.final_output.lower(), ( + "Expected 'goodbye' in the final output from second modification" + ) + + # Verify after_llm_call modification took effect + assert "have a great day" in result.final_output.lower(), ( + "Expected 'have a great day' from after_llm_call modification" + ) + + # Verify we captured the original message before modification + assert len(modifier.original_messages) > 0, "Should have captured original messages" + assert "Say goodbye" in modifier.original_messages[-1]["content"], ( + "Original message should contain 'Say goodbye'" + ) + + # Verify message structure was preserved + assert "role" in modifier.modified_messages[-1], "Modified message should have 'role' key" + assert "content" in modifier.modified_messages[-1], "Modified message should have 'content' key" + + # Verify sequential callback execution + assert second_modifier.saw_first_modification, ( + "Second callback should see first callback's 
modification" + ) + + # Verify after_llm_call could read the LLM's response + assert after_llm_modifier.saw_llm_response, ( + "after_llm_call callback should be able to read LLM response" + ) + assert len(after_llm_modifier.llm_response_messages) > 0, ( + "after_llm_call should have captured messages including LLM response" + ) + + # Second run: Test that modifications don't leak between runs + result2 = await agent.run_async("Tell me a joke") + assert result2.final_output is not None + assert isinstance(result2.final_output, str) + + # Verify the second run also got modified correctly + # (should still say hello and goodbye, not "tell me a joke") + assert "hello" in result2.final_output.lower(), ( + "Expected 'hello' in second run output" + ) + + # Verify original message was different for second run + assert "Tell me a joke" in modifier.original_messages[-1]["content"], ( + "Second run should have different original message" + ) + assert "Tell me a joke" not in modifier.modified_messages[-1]["content"], ( + "Second run's original input should have been modified" ) finally: From 6c6a383401cb498d9799bbbf68568421cb1b9b0a Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 27 Oct 2025 10:13:53 -0400 Subject: [PATCH 06/15] lint and google adk impl --- src/any_agent/callbacks/wrappers/google.py | 188 ++++++++++++++ src/any_agent/frameworks/google.py | 236 ++++++++++++++---- .../callbacks/test_framework_state.py | 22 +- 3 files changed, 384 insertions(+), 62 deletions(-) diff --git a/src/any_agent/callbacks/wrappers/google.py b/src/any_agent/callbacks/wrappers/google.py index 88582af27..cbb288b1c 100644 --- a/src/any_agent/callbacks/wrappers/google.py +++ b/src/any_agent/callbacks/wrappers/google.py @@ -10,6 +10,95 @@ from any_agent.frameworks.google import GoogleAgent +def _import_google_converters() -> tuple[Any, Any]: + """Import conversion functions from google framework module.""" + from any_agent.frameworks.google import ( + _messages_from_content, + _messages_to_contents, + ) + + return _messages_from_content, _messages_to_contents + + +def _llm_response_to_message(llm_response) -> dict[str, Any]: + """Convert Google ADK LlmResponse to a normalized message dict.""" + from any_agent.frameworks.google import ADK_TO_ANY_LLM_ROLE, _safe_json_serialize + + if not llm_response or not llm_response.content: + return {"role": "assistant", "content": None} + + content = llm_response.content + role = ADK_TO_ANY_LLM_ROLE.get(str(content.role), "assistant") + + message_content: list[Any] = [] + tool_calls: list[Any] = [] + + if content.parts: + for part in content.parts: + if part.text: + message_content.append({"type": "text", "text": part.text}) + elif part.function_call: + tool_calls.append( + { + "type": "function", + "id": part.function_call.id, + "function": { + "name": part.function_call.name, + "arguments": _safe_json_serialize(part.function_call.args), + }, + } + ) + + return { + "role": role, + "content": message_content or None, + "tool_calls": tool_calls or None, + } + + +def _message_to_llm_response(message: dict[str, Any]): + """Convert a normalized message dict back to Google ADK LlmResponse.""" + import json + + from google.adk.models.llm_response import LlmResponse + from google.genai import types + + parts = [] + + # Handle content + content = message.get("content") + if content: + if isinstance(content, str): + parts.append(types.Part.from_text(text=content)) + elif isinstance(content, list): + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + 
parts.append(types.Part.from_text(text=part.get("text", ""))) + + # Handle tool calls + tool_calls = message.get("tool_calls") + if tool_calls: + for tool_call in tool_calls: + if tool_call.get("type") == "function": + function_data = tool_call.get("function", {}) + args_str = function_data.get("arguments", "{}") + try: + args = ( + json.loads(args_str) if isinstance(args_str, str) else args_str + ) + except json.JSONDecodeError: + args = {} + + part = types.Part.from_function_call( + name=function_data.get("name", ""), + args=args, + ) + part.function_call.id = tool_call.get("id", "") + parts.append(part) + + return LlmResponse(content=types.Content(role="model", parts=parts), partial=False) + + class _GoogleADKWrapper: def __init__(self) -> None: self.callback_context: dict[int, Context] = {} @@ -23,6 +112,59 @@ def before_model_callback(*args, **kwargs) -> Any | None: get_current_span().get_span_context().trace_id ] + # Set up message getters/setters if llm_request is present + llm_request = kwargs.get("llm_request") + if llm_request is not None: + _messages_from_content, _messages_to_contents = ( + _import_google_converters() + ) + + # Convert to normalized format + messages = _messages_from_content(llm_request) + + # Normalize content to strings when it's just text parts + for message in messages: + content = message.get("content") + if isinstance(content, list): + text_parts = [ + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" + ] + if text_parts and len(text_parts) == len(content): + # All parts are text, so join them + message["content"] = ( + " ".join(text_parts) + if len(text_parts) > 1 + else text_parts[0] + ) + + # Add system instruction as a message if present + if llm_request.config and llm_request.config.system_instruction: + messages.insert( + 0, + { + "role": "system", + "content": llm_request.config.system_instruction, + }, + ) + + context.framework_state.messages = messages + + def get_messages(): + return context.framework_state.messages + + def set_messages(new_messages): + context.framework_state.messages = new_messages + # Convert back to Google ADK format + system_instruction, contents = _messages_to_contents(new_messages) + llm_request.contents = contents + if llm_request.config: + llm_request.config.system_instruction = system_instruction + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) @@ -40,9 +182,55 @@ def after_model_callback(*args, **kwargs) -> Any | None: get_current_span().get_span_context().trace_id ] + # Set up message getters/setters to include the LLM response + llm_response = kwargs.get("llm_response") + if llm_response is not None: + # Convert response to normalized message format + response_message = _llm_response_to_message(llm_response) + + # Get existing messages and append the response + existing_messages = context.framework_state.messages.copy() + + # If the response has content as a list, convert to string for easier modification + if isinstance(response_message.get("content"), list): + text_parts = [ + part.get("text", "") + for part in response_message["content"] + if isinstance(part, dict) and part.get("type") == "text" + ] + if text_parts: + response_message["content"] = " ".join(text_parts) + + all_messages = [*existing_messages, response_message] + context.framework_state.messages = all_messages + + # Track the 
original response message for comparison + original_response_content = response_message.get("content") + + def get_messages(): + return context.framework_state.messages + + def set_messages(new_messages): + context.framework_state.messages = new_messages + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.after_llm_call(context, *args, **kwargs) + # Check if the response was modified and return the new response if so + if llm_response is not None: + final_messages = context.framework_state.messages + if final_messages: + final_response_message = final_messages[-1] + # If the content changed, convert back and return new response + if ( + final_response_message.get("content") + != original_response_content + ): + return _message_to_llm_response(final_response_message) + if callable(self._original["after_model"]): return self._original["after_model"](*args, **kwargs) diff --git a/src/any_agent/frameworks/google.py b/src/any_agent/frameworks/google.py index 1a1c33157..e4de0730a 100644 --- a/src/any_agent/frameworks/google.py +++ b/src/any_agent/frameworks/google.py @@ -36,6 +36,187 @@ "model": "assistant", } +ANY_LLM_TO_ADK_ROLE: dict[str, str] = { + "user": "user", + "assistant": "model", +} + + +def _messages_from_content(llm_request: LlmRequest) -> list[dict[str, Any]]: + """Convert Google ADK LlmRequest to normalized message format. + + Args: + llm_request: The LlmRequest to convert. + + Returns: + List of message dicts with 'role' and 'content' keys. + + """ + messages: list[dict[str, Any]] = [] + for content in llm_request.contents: + message: dict[str, Any] = { + "role": ADK_TO_ANY_LLM_ROLE[str(content.role)], + } + message_content: list[Any] = [] + tool_calls: list[Any] = [] + if parts := content.parts: + for part in parts: + if part.function_response: + messages.append( + { + "role": "tool", + "tool_call_id": part.function_response.id, + "content": _safe_json_serialize( + part.function_response.response + ), + } + ) + elif part.text: + message_content.append({"type": "text", "text": part.text}) + elif ( + part.inline_data + and part.inline_data.data + and part.inline_data.mime_type + ): + # TODO Handle multimodal input + msg = f"Part of type {part.inline_data.mime_type} is not supported." + raise NotImplementedError(msg) + elif part.function_call: + tool_calls.append( + { + "type": "function", + "id": part.function_call.id, + "function": { + "name": part.function_call.name, + "arguments": _safe_json_serialize( + part.function_call.args + ), + }, + } + ) + + message["content"] = message_content or None + message["tool_calls"] = tool_calls or None + # messages from function_response were directly appended before + if message["content"] or message["tool_calls"]: + messages.append(message) + + return messages + + +def _messages_to_contents( + messages: list[dict[str, Any]], +) -> tuple[str | None, list[types.Content]]: + """Convert normalized messages back to Google ADK format. + + This function performs a round-trip conversion from normalized message dicts + back to Google ADK Content objects and system instruction. + + Args: + messages: List of message dicts with 'role' and 'content' keys. + + Returns: + Tuple of (system_instruction, contents_list). 
+ + """ + system_instruction: str | None = None + system_parts: list[str] = [] + contents: list[types.Content] = [] + + # Track pending tool responses to group them with the next user message + pending_tool_responses: list[types.Part] = [] + + for message in messages: + role = message.get("role", "user") + + # Extract system messages + if role == "system": + content = message.get("content", "") + if isinstance(content, str): + system_parts.append(content) + elif isinstance(content, list): + # Handle content as list of parts + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + system_parts.append(part.get("text", "")) + continue + + # Handle tool responses - accumulate them + if role == "tool": + tool_call_id = message.get("tool_call_id", "") + content = message.get("content", "") + # Parse the content as JSON if it's a string + try: + response_data = ( + json.loads(content) if isinstance(content, str) else content + ) + except json.JSONDecodeError: + response_data = content + + part = types.Part( + function_response=types.FunctionResponse( + name="", # Name is not stored in tool messages + id=tool_call_id, + response=response_data, + ) + ) + pending_tool_responses.append(part) + continue + + # Process user/assistant messages + parts: list[types.Part] = [] + + # If there are pending tool responses and this is a user message, add them first + if pending_tool_responses and role == "user": + parts.extend(pending_tool_responses) + pending_tool_responses = [] + + # Handle content + content = message.get("content") + if content: + if isinstance(content, str): + parts.append(types.Part.from_text(text=content)) + elif isinstance(content, list): + # Handle content as list of parts + for part in content: + if isinstance(part, dict) and part.get("type") == "text": + parts.append(types.Part.from_text(text=part.get("text", ""))) + + # Handle tool calls + tool_calls = message.get("tool_calls") + if tool_calls: + for tool_call in tool_calls: + if tool_call.get("type") == "function": + function_data = tool_call.get("function", {}) + args_str = function_data.get("arguments", "{}") + # Parse arguments if they're a string + try: + args = ( + json.loads(args_str) + if isinstance(args_str, str) + else args_str + ) + except json.JSONDecodeError: + args = {} + + part = types.Part.from_function_call( + name=function_data.get("name", ""), + args=args, + ) + part.function_call.id = tool_call.get("id", "") + parts.append(part) + + # Only add content if there are parts + if parts: + adk_role = ANY_LLM_TO_ADK_ROLE.get(role, "user") + contents.append(types.Content(role=adk_role, parts=parts)) + + # Combine system messages into a single instruction + if system_parts: + system_instruction = "\n".join(system_parts) + + return system_instruction, contents + def _safe_json_serialize(obj: Any) -> str: """Convert any Python object to a JSON-serializable type or string. 
@@ -63,59 +244,6 @@ def __init__(self, model: str, **kwargs: Any) -> None: super().__init__(model=model) self._kwargs = kwargs or {} - @staticmethod - def _messages_from_content(llm_request: LlmRequest) -> list[dict[str, Any]]: - messages: list[dict[str, Any]] = [] - for content in llm_request.contents: - message: dict[str, Any] = { - "role": ADK_TO_ANY_LLM_ROLE[str(content.role)], - } - message_content: list[Any] = [] - tool_calls: list[Any] = [] - if parts := content.parts: - for part in parts: - if part.function_response: - messages.append( - { - "role": "tool", - "tool_call_id": part.function_response.id, - "content": _safe_json_serialize( - part.function_response.response - ), - } - ) - elif part.text: - message_content.append({"type": "text", "text": part.text}) - elif ( - part.inline_data - and part.inline_data.data - and part.inline_data.mime_type - ): - # TODO Handle multimodal input - msg = f"Part of type {part.inline_data.mime_type} is not supported." - raise NotImplementedError(msg) - elif part.function_call: - tool_calls.append( - { - "type": "function", - "id": part.function_call.id, - "function": { - "name": part.function_call.name, - "arguments": _safe_json_serialize( - part.function_call.args - ), - }, - } - ) - - message["content"] = message_content or None - message["tool_calls"] = tool_calls or None - # messages from function_response were directly appended before - if message["content"] or message["tool_calls"]: - messages.append(message) - - return messages - def _schema_to_dict(self, schema: types.Schema) -> dict[str, Any]: """Recursively converts a types.Schema to a pure-python dict. @@ -204,7 +332,7 @@ def _function_declaration_to_tool_param( def _llm_request_to_completion_args( self, llm_request: LlmRequest ) -> dict[str, Any]: - messages = self._messages_from_content(llm_request) + messages = _messages_from_content(llm_request) if llm_request.config.system_instruction: messages.insert( 0, {"role": "system", "content": llm_request.config.system_instruction} diff --git a/tests/integration/callbacks/test_framework_state.py b/tests/integration/callbacks/test_framework_state.py index df8374552..01f8f0c74 100644 --- a/tests/integration/callbacks/test_framework_state.py +++ b/tests/integration/callbacks/test_framework_state.py @@ -8,7 +8,7 @@ class LLMInputModifier(Callback): """Callback that modifies LLM input messages.""" - def __init__(self): + def __init__(self) -> None: self.original_messages: list[dict[str, Any]] = [] self.modified_messages: list[dict[str, Any]] = [] @@ -35,7 +35,7 @@ def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Contex class SecondModifier(Callback): """Second callback to test sequential modifications.""" - def __init__(self): + def __init__(self) -> None: self.saw_first_modification = False def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: @@ -56,7 +56,7 @@ def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Contex class AfterLLMModifier(Callback): """Callback that modifies messages after LLM response.""" - def __init__(self): + def __init__(self) -> None: self.llm_response_messages: list[dict[str, Any]] = [] self.saw_llm_response = False @@ -70,8 +70,8 @@ def after_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context last_msg = messages[-1] # The LLM should have responded with hello and goodbye self.saw_llm_response = ( - "hello" in last_msg.get("content", "").lower() and - "goodbye" in last_msg.get("content", "").lower() + "hello" in 
last_msg.get("content", "").lower() + and "goodbye" in last_msg.get("content", "").lower() ) # Modify the LLM's response by appending text @@ -123,14 +123,20 @@ async def test_modify_llm_input(agent_framework: AgentFramework) -> None: ) # Verify we captured the original message before modification - assert len(modifier.original_messages) > 0, "Should have captured original messages" + assert len(modifier.original_messages) > 0, ( + "Should have captured original messages" + ) assert "Say goodbye" in modifier.original_messages[-1]["content"], ( "Original message should contain 'Say goodbye'" ) # Verify message structure was preserved - assert "role" in modifier.modified_messages[-1], "Modified message should have 'role' key" - assert "content" in modifier.modified_messages[-1], "Modified message should have 'content' key" + assert "role" in modifier.modified_messages[-1], ( + "Modified message should have 'role' key" + ) + assert "content" in modifier.modified_messages[-1], ( + "Modified message should have 'content' key" + ) # Verify sequential callback execution assert second_modifier.saw_first_modification, ( From f979bfb0122eed9669d2d2077fdfb4c00b6594a2 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 27 Oct 2025 12:55:34 -0400 Subject: [PATCH 07/15] cleanup --- docs/agents/callbacks.md | 52 +------------ src/any_agent/callbacks/wrappers/agno.py | 55 ++++++++++++-- src/any_agent/callbacks/wrappers/google.py | 17 ----- src/any_agent/callbacks/wrappers/langchain.py | 1 - .../callbacks/wrappers/llama_index.py | 1 - src/any_agent/frameworks/agno.py | 2 +- .../callbacks/test_framework_state.py | 73 +------------------ 7 files changed, 53 insertions(+), 148 deletions(-) diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index 1d45aa247..934ac007d 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -77,11 +77,10 @@ These methods handle framework-specific message formats internally, providing a !!! note "Availability" - The `get_messages()` and `set_messages()` methods are **only available in `before_llm_call` and `after_llm_call` callbacks**. + The `get_messages()` and `set_messages()` methods are **only available in `before_llm_call` callbacks**. - In `before_llm_call`: You can read and modify the messages that will be sent to the LLM - - In `after_llm_call`: You can read and modify the messages including the LLM's response - - In other callbacks (`before_tool_execution`, `after_tool_execution`, etc.): These methods will raise `NotImplementedError` + - In other callbacks (`after_llm_call`, `before_tool_execution`, `after_tool_execution`, etc.): These methods will raise `NotImplementedError` ## Implementing Callbacks @@ -90,7 +89,7 @@ All callbacks must inherit from the base [`Callback`][any_agent.callbacks.base.C | Callback Method | Description | |:----------------:|:------------:| | before_agent_invocation | Should be used to check the Context before the agent is invoked. | -| befor_llm_call | Should be used before the chat history hits the LLM. | +| before_llm_call | Should be used before the chat history hits the LLM. | | after_llm_call | Should be used once LLM output is generated, before it appends to the chat history. | | before_tool_execution | Should be used to check the Context before tool execution. | | after_tool_execution | Should be used once tool has been executed, before the output is appended to chat history. | @@ -352,48 +351,3 @@ config = AgentConfig( !!! 
tip Use try/except to gracefully handle frameworks that don't support message modification yet. The callback will simply skip modification for unsupported frameworks. - -## Example: Modifying LLM responses - -You can also modify the LLM's response after it's generated using the `after_llm_call` callback. This is useful for post-processing responses, adding disclaimers, or formatting output: - -```python -from any_agent.callbacks.base import Callback -from any_agent.callbacks.context import Context - -class AddDisclaimerCallback(Callback): - def __init__(self, disclaimer: str): - self.disclaimer = disclaimer - - def after_llm_call(self, context: Context, *args, **kwargs) -> Context: - try: - messages = context.framework_state.get_messages() - if messages and len(messages) > 0: - # The last message is the LLM's response - last_message = messages[-1] - if last_message.get("role") == "assistant": - # Append disclaimer to the LLM's response - last_message["content"] += f"\n\n{self.disclaimer}" - context.framework_state.set_messages(messages) - except NotImplementedError: - pass - - return context -``` - -Example usage: - -```python -from any_agent import AgentConfig, AnyAgent - -callback = AddDisclaimerCallback( - disclaimer="*This response was generated by an AI assistant.*" -) - -config = AgentConfig( - model_id="openai:gpt-4o-mini", - instructions="You are a helpful assistant.", - callbacks=[callback], -) -# ... Continue to create and run agent -``` diff --git a/src/any_agent/callbacks/wrappers/agno.py b/src/any_agent/callbacks/wrappers/agno.py index b8f20d22c..31d15ea0a 100644 --- a/src/any_agent/callbacks/wrappers/agno.py +++ b/src/any_agent/callbacks/wrappers/agno.py @@ -9,33 +9,72 @@ from any_agent.callbacks.context import Context from any_agent.frameworks.agno import AgnoAgent +try: + from agno.models.message import Message + + agno_available = True +except ImportError: + agno_available = False + Message = None # type: ignore[assignment,misc] + class _AgnoWrapper: def __init__(self) -> None: self.callback_context: dict[int, Context] = {} - self._original_aprocess_model: Any = None + self._original_ainvoke: Any = None self._original_arun_function_call: Any = None async def wrap(self, agent: AgnoAgent) -> None: - self._original_aprocess_model = agent._agent.model._aprocess_model_response + if not agno_available: + msg = "Agno is not installed" + raise ImportError(msg) - async def wrapped_llm_call(*args, **kwargs): + self._original_ainvoke = agent._agent.model.ainvoke + + async def wrapped_ainvoke(messages, *args, **kwargs): context = self.callback_context[ get_current_span().get_span_context().trace_id ] context.shared["model_id"] = agent._agent.model.id + def get_messages(): + return [ + { + "role": msg.role, + "content": msg.content if msg.content else "", + } + for msg in messages + ] + + def set_messages(new_messages): + messages.clear() + for msg_dict in new_messages: + msg = Message( + role=msg_dict.get("role", "user"), + content=msg_dict.get("content", ""), + ) + if "tool_calls" in msg_dict: + msg.tool_calls = msg_dict["tool_calls"] + if "tool_call_id" in msg_dict: + msg.tool_call_id = msg_dict["tool_call_id"] + if "name" in msg_dict: + msg.name = msg_dict["name"] + messages.append(msg) + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) - result = await self._original_aprocess_model(*args, **kwargs) + result = await 
self._original_ainvoke(messages, *args, **kwargs) for callback in agent.config.callbacks: context = callback.after_llm_call(context, result, *args, **kwargs) return result - agent._agent.model._aprocess_model_response = wrapped_llm_call + agent._agent.model.ainvoke = wrapped_ainvoke self._original_arun_function_call = agent._agent.model.arun_function_call @@ -62,7 +101,7 @@ async def wrapped_tool_execution( agent._agent.model.arun_function_call = wrapped_tool_execution async def unwrap(self, agent: AgnoAgent): - if self._original_aprocess_model is not None: - agent._agent.model._aprocess_model_response = self._original_aprocess_model + if self._original_ainvoke is not None: + agent._agent.model.ainvoke = self._original_ainvoke if self._original_arun_function_call is not None: - agent._agent.model.arun_function_calls = self._original_arun_function_call + agent._agent.model.arun_function_call = self._original_arun_function_call diff --git a/src/any_agent/callbacks/wrappers/google.py b/src/any_agent/callbacks/wrappers/google.py index cbb288b1c..7f0b3acf0 100644 --- a/src/any_agent/callbacks/wrappers/google.py +++ b/src/any_agent/callbacks/wrappers/google.py @@ -65,7 +65,6 @@ def _message_to_llm_response(message: dict[str, Any]): parts = [] - # Handle content content = message.get("content") if content: if isinstance(content, str): @@ -75,7 +74,6 @@ def _message_to_llm_response(message: dict[str, Any]): if isinstance(part, dict) and part.get("type") == "text": parts.append(types.Part.from_text(text=part.get("text", ""))) - # Handle tool calls tool_calls = message.get("tool_calls") if tool_calls: for tool_call in tool_calls: @@ -112,17 +110,14 @@ def before_model_callback(*args, **kwargs) -> Any | None: get_current_span().get_span_context().trace_id ] - # Set up message getters/setters if llm_request is present llm_request = kwargs.get("llm_request") if llm_request is not None: _messages_from_content, _messages_to_contents = ( _import_google_converters() ) - # Convert to normalized format messages = _messages_from_content(llm_request) - # Normalize content to strings when it's just text parts for message in messages: content = message.get("content") if isinstance(content, list): @@ -132,14 +127,12 @@ def before_model_callback(*args, **kwargs) -> Any | None: if isinstance(part, dict) and part.get("type") == "text" ] if text_parts and len(text_parts) == len(content): - # All parts are text, so join them message["content"] = ( " ".join(text_parts) if len(text_parts) > 1 else text_parts[0] ) - # Add system instruction as a message if present if llm_request.config and llm_request.config.system_instruction: messages.insert( 0, @@ -156,7 +149,6 @@ def get_messages(): def set_messages(new_messages): context.framework_state.messages = new_messages - # Convert back to Google ADK format system_instruction, contents = _messages_to_contents(new_messages) llm_request.contents = contents if llm_request.config: @@ -182,16 +174,10 @@ def after_model_callback(*args, **kwargs) -> Any | None: get_current_span().get_span_context().trace_id ] - # Set up message getters/setters to include the LLM response llm_response = kwargs.get("llm_response") if llm_response is not None: - # Convert response to normalized message format response_message = _llm_response_to_message(llm_response) - - # Get existing messages and append the response existing_messages = context.framework_state.messages.copy() - - # If the response has content as a list, convert to string for easier modification if 
isinstance(response_message.get("content"), list): text_parts = [ part.get("text", "") @@ -204,7 +190,6 @@ def after_model_callback(*args, **kwargs) -> Any | None: all_messages = [*existing_messages, response_message] context.framework_state.messages = all_messages - # Track the original response message for comparison original_response_content = response_message.get("content") def get_messages(): @@ -219,12 +204,10 @@ def set_messages(new_messages): for callback in agent.config.callbacks: context = callback.after_llm_call(context, *args, **kwargs) - # Check if the response was modified and return the new response if so if llm_response is not None: final_messages = context.framework_state.messages if final_messages: final_response_message = final_messages[-1] - # If the content changed, convert back and return new response if ( final_response_message.get("content") != original_response_content diff --git a/src/any_agent/callbacks/wrappers/langchain.py b/src/any_agent/callbacks/wrappers/langchain.py index 207490c47..e78d0b4b6 100644 --- a/src/any_agent/callbacks/wrappers/langchain.py +++ b/src/any_agent/callbacks/wrappers/langchain.py @@ -126,7 +126,6 @@ async def wrap_ainvoke(*args, **kwargs): # type: ignore[no-untyped-def] agent._agent.ainvoke = wrap_ainvoke - # Wrap call_model to capture any-llm calls during structured output processing self._original_llm_call = agent.call_model async def wrap_call_model(**kwargs): diff --git a/src/any_agent/callbacks/wrappers/llama_index.py b/src/any_agent/callbacks/wrappers/llama_index.py index b0d034f87..62fc5a54f 100644 --- a/src/any_agent/callbacks/wrappers/llama_index.py +++ b/src/any_agent/callbacks/wrappers/llama_index.py @@ -73,7 +73,6 @@ async def acall(self, *args, **kwargs): wrapped = WrappedAcall(tool.metadata, tool.acall) tool.acall = wrapped.acall - # Wrap call_model to capture any-llm calls during structured output processing self._original_llm_call = agent.call_model async def wrap_call_model(**kwargs): diff --git a/src/any_agent/frameworks/agno.py b/src/any_agent/frameworks/agno.py index 027b02849..3679c73c4 100644 --- a/src/any_agent/frameworks/agno.py +++ b/src/any_agent/frameworks/agno.py @@ -139,7 +139,7 @@ async def ainvoke_stream( completion_kwargs = self.get_request_params(tools=tools) completion_kwargs["messages"] = self._format_messages(messages) completion_kwargs["stream"] = True - return acompletion(**completion_kwargs) + return await acompletion(**completion_kwargs) def parse_provider_response(self, response: Any, **kwargs) -> ModelResponse: # type: ignore[no-untyped-def] """Parse the provider response.""" diff --git a/tests/integration/callbacks/test_framework_state.py b/tests/integration/callbacks/test_framework_state.py index 01f8f0c74..f0d6e512f 100644 --- a/tests/integration/callbacks/test_framework_state.py +++ b/tests/integration/callbacks/test_framework_state.py @@ -53,52 +53,15 @@ def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Contex return context -class AfterLLMModifier(Callback): - """Callback that modifies messages after LLM response.""" - - def __init__(self) -> None: - self.llm_response_messages: list[dict[str, Any]] = [] - self.saw_llm_response = False - - def after_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context: - # Get messages including the LLM's response - messages = context.framework_state.get_messages() - self.llm_response_messages = [msg.copy() for msg in messages] - - # Verify we can see the LLM's response - if len(messages) > 0: - last_msg = 
messages[-1] - # The LLM should have responded with hello and goodbye - self.saw_llm_response = ( - "hello" in last_msg.get("content", "").lower() - and "goodbye" in last_msg.get("content", "").lower() - ) - - # Modify the LLM's response by appending text - last_msg["content"] = last_msg["content"] + " Also, have a great day!" - context.framework_state.set_messages(messages) - - return context - - async def test_modify_llm_input(agent_framework: AgentFramework) -> None: - """Test that framework_state message modification works via helper methods. - - This test verifies: - 1. Original messages can be read before modification (before_llm_call) - 2. Message structure (role, content) is preserved after modifications - 3. Multiple callbacks can modify messages sequentially (before_llm_call) - 4. LLM response can be read and modified in after_llm_call - 5. Modifications don't leak between separate runs (isolation) - """ + """Test that framework_state message modification works via helper methods.""" modifier = LLMInputModifier() second_modifier = SecondModifier() - after_llm_modifier = AfterLLMModifier() config = AgentConfig( model_id=DEFAULT_SMALL_MODEL_ID, instructions="You are a helpful assistant.", - callbacks=[modifier, second_modifier, after_llm_modifier], + callbacks=[modifier, second_modifier], ) agent = await AnyAgent.create_async(agent_framework, config) @@ -117,11 +80,6 @@ async def test_modify_llm_input(agent_framework: AgentFramework) -> None: "Expected 'goodbye' in the final output from second modification" ) - # Verify after_llm_call modification took effect - assert "have a great day" in result.final_output.lower(), ( - "Expected 'have a great day' from after_llm_call modification" - ) - # Verify we captured the original message before modification assert len(modifier.original_messages) > 0, ( "Should have captured original messages" @@ -143,32 +101,5 @@ async def test_modify_llm_input(agent_framework: AgentFramework) -> None: "Second callback should see first callback's modification" ) - # Verify after_llm_call could read the LLM's response - assert after_llm_modifier.saw_llm_response, ( - "after_llm_call callback should be able to read LLM response" - ) - assert len(after_llm_modifier.llm_response_messages) > 0, ( - "after_llm_call should have captured messages including LLM response" - ) - - # Second run: Test that modifications don't leak between runs - result2 = await agent.run_async("Tell me a joke") - assert result2.final_output is not None - assert isinstance(result2.final_output, str) - - # Verify the second run also got modified correctly - # (should still say hello and goodbye, not "tell me a joke") - assert "hello" in result2.final_output.lower(), ( - "Expected 'hello' in second run output" - ) - - # Verify original message was different for second run - assert "Tell me a joke" in modifier.original_messages[-1]["content"], ( - "Second run should have different original message" - ) - assert "Tell me a joke" not in modifier.modified_messages[-1]["content"], ( - "Second run's original input should have been modified" - ) - finally: await agent.cleanup_async() From f2475d7507c5a78e241ff5f0883b546d796d29fc Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 27 Oct 2025 13:24:13 -0400 Subject: [PATCH 08/15] smolagents --- .../callbacks/wrappers/smolagents.py | 58 ++++++++++++++++++- src/any_agent/frameworks/smolagents.py | 1 + 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/any_agent/callbacks/wrappers/smolagents.py 
b/src/any_agent/callbacks/wrappers/smolagents.py index 4b3cb9ffd..b1b531825 100644 --- a/src/any_agent/callbacks/wrappers/smolagents.py +++ b/src/any_agent/callbacks/wrappers/smolagents.py @@ -20,18 +20,70 @@ def __init__(self) -> None: self._original_tools: Any | None = None async def wrap(self, agent: SmolagentsAgent) -> None: + try: + from smolagents.models import ChatMessage, MessageRole + + smolagents_available = True + except ImportError: + smolagents_available = False + + if not smolagents_available: + msg = "Smolagents is not installed" + raise ImportError(msg) + self._original_llm_call = agent._agent.model.generate - def wrap_generate(*args, **kwargs): + def wrap_generate(messages: list[ChatMessage], **kwargs): context = self.callback_context[ get_current_span().get_span_context().trace_id ] context.shared["model_id"] = str(agent._agent.model.model_id) + def get_messages(): + normalized_messages = [] + for msg in messages: + msg_dict = msg.dict() + + # Handle content that might be a list + content = msg_dict.get("content") + if isinstance(content, list): + text_parts = [ + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" + ] + if text_parts and len(text_parts) == len(content): + msg_dict["content"] = ( + " ".join(text_parts) + if len(text_parts) > 1 + else text_parts[0] + ) + + normalized_messages.append(msg_dict) + return normalized_messages + + def set_messages(new_messages): + if len(new_messages) != len(messages): + raise ValueError( + "Number of messages must match, Smolagents only allows for modification of message content, not the number of messages" + ) + for i, msg_dict in enumerate(new_messages): + text = msg_dict["content"] + if i == 1: + # Because of https://github.com/huggingface/smolagents/blob/317b57336c955e4e7518c42cc4ba53d880dd621a/src/smolagents/memory.py#L192 + # Smolagents expects the first user message to start with the hardcoded string "New task:" + text = f"New task:\n{text}" + messages[i].content = [ + {"type": "text", "text": text} + ] + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: - context = callback.before_llm_call(context, *args, **kwargs) + context = callback.before_llm_call(context, messages, **kwargs) - output = self._original_llm_call(*args, **kwargs) + output = self._original_llm_call(messages, **kwargs) for callback in agent.config.callbacks: context = callback.after_llm_call(context, output) diff --git a/src/any_agent/frameworks/smolagents.py b/src/any_agent/frameworks/smolagents.py index 19208a9b7..23dcac8f0 100644 --- a/src/any_agent/frameworks/smolagents.py +++ b/src/any_agent/frameworks/smolagents.py @@ -66,6 +66,7 @@ def __init__( "model": model_id, "api_key": api_key, "api_base": api_base, + "allow_running_loop": True, # Because smolagents uses sync api **kwargs, } From bf05d9d15a48760977886026188d9a32097bc2b1 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 27 Oct 2025 13:29:59 -0400 Subject: [PATCH 09/15] lint --- src/any_agent/callbacks/wrappers/smolagents.py | 13 +++++-------- src/any_agent/frameworks/smolagents.py | 2 +- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/any_agent/callbacks/wrappers/smolagents.py b/src/any_agent/callbacks/wrappers/smolagents.py index b1b531825..4e6e92f77 100644 --- a/src/any_agent/callbacks/wrappers/smolagents.py +++ b/src/any_agent/callbacks/wrappers/smolagents.py @@ -9,6 +9,8 @@ if TYPE_CHECKING: from collections.abc 
import Callable + from smolagents.models import ChatMessage + from any_agent.callbacks.context import Context from any_agent.frameworks.smolagents import SmolagentsAgent @@ -21,8 +23,6 @@ def __init__(self) -> None: async def wrap(self, agent: SmolagentsAgent) -> None: try: - from smolagents.models import ChatMessage, MessageRole - smolagents_available = True except ImportError: smolagents_available = False @@ -64,18 +64,15 @@ def get_messages(): def set_messages(new_messages): if len(new_messages) != len(messages): - raise ValueError( - "Number of messages must match, Smolagents only allows for modification of message content, not the number of messages" - ) + msg = f"Number of messages must match, Smolagents only allows for modification of message content, not the number of messages. Expected {len(messages)} messages, got {len(new_messages)} messages." + raise ValueError(msg) for i, msg_dict in enumerate(new_messages): text = msg_dict["content"] if i == 1: # Because of https://github.com/huggingface/smolagents/blob/317b57336c955e4e7518c42cc4ba53d880dd621a/src/smolagents/memory.py#L192 # Smolagents expects the first user message to start with the hardcoded string "New task:" text = f"New task:\n{text}" - messages[i].content = [ - {"type": "text", "text": text} - ] + messages[i].content = [{"type": "text", "text": text}] context.framework_state._message_getter = get_messages context.framework_state._message_setter = set_messages diff --git a/src/any_agent/frameworks/smolagents.py b/src/any_agent/frameworks/smolagents.py index 23dcac8f0..380bf28cc 100644 --- a/src/any_agent/frameworks/smolagents.py +++ b/src/any_agent/frameworks/smolagents.py @@ -66,7 +66,7 @@ def __init__( "model": model_id, "api_key": api_key, "api_base": api_base, - "allow_running_loop": True, # Because smolagents uses sync api + "allow_running_loop": True, # Because smolagents uses sync api **kwargs, } From cc7100fe997009f0cb70ee338e2cbba44076bfeb Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 27 Oct 2025 15:03:11 -0400 Subject: [PATCH 10/15] everybody works --- src/any_agent/callbacks/wrappers/langchain.py | 72 +++++++++++++++++++ .../callbacks/wrappers/llama_index.py | 31 ++++++++ .../callbacks/wrappers/smolagents.py | 61 +++++++++++----- src/any_agent/frameworks/langchain.py | 4 +- .../callbacks/test_framework_state.py | 4 +- 5 files changed, 151 insertions(+), 21 deletions(-) diff --git a/src/any_agent/callbacks/wrappers/langchain.py b/src/any_agent/callbacks/wrappers/langchain.py index e78d0b4b6..11d8f5428 100644 --- a/src/any_agent/callbacks/wrappers/langchain.py +++ b/src/any_agent/callbacks/wrappers/langchain.py @@ -5,6 +5,8 @@ from opentelemetry.trace import get_current_span +from any_agent.logging import logger + if TYPE_CHECKING: from collections.abc import Callable from uuid import UUID @@ -16,11 +18,22 @@ from any_agent.frameworks.langchain import LangchainAgent +def _import_langchain_converters() -> tuple[Any, Any]: + """Import conversion functions from langchain vendor module.""" + from any_agent.vendor.langchain_any_llm import ( + _convert_dict_to_message, + _convert_message_to_dict, + ) + + return _convert_dict_to_message, _convert_message_to_dict + + class _LangChainWrapper: def __init__(self) -> None: self.callback_context: dict[int, Context] = {} self._original_ainvoke: Any | None = None self._original_llm_call: Callable[..., Any] | None = None + self._original_agenerate: Callable[..., Any] | None = None async def wrap(self, agent: LangchainAgent) -> None: from langchain_core.callbacks.base 
import BaseCallbackHandler @@ -30,6 +43,8 @@ def before_llm_call(*args, **kwargs): context = self.callback_context[ get_current_span().get_span_context().trace_id ] + # Note: Message getters/setters are set up in wrap_agenerate + # This callback is for span generation via on_chat_model_start for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) @@ -126,6 +141,56 @@ async def wrap_ainvoke(*args, **kwargs): # type: ignore[no-untyped-def] agent._agent.ainvoke = wrap_ainvoke + if agent._model is not None and hasattr(agent._model, "_agenerate"): + self._original_agenerate = agent._model._agenerate + + async def wrap_agenerate(messages, *args, **kwargs): + messages_list = list(messages) + + try: + context = self.callback_context[ + get_current_span().get_span_context().trace_id + ] + _convert_dict_to_message, _convert_message_to_dict = ( + _import_langchain_converters() + ) + + normalized_messages = [ + _convert_message_to_dict(msg) for msg in messages_list + ] + context.framework_state.messages = normalized_messages + + def get_messages(): + return context.framework_state.messages + + def set_messages(new_messages): + nonlocal messages_list + context.framework_state.messages = new_messages + messages_list[:] = [ + _convert_dict_to_message(msg) for msg in new_messages + ] + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + + # Call user callbacks (but not span generation, that will happen in on_chat_model_start) + for callback in agent.config.callbacks: + if not hasattr(callback, "_set_llm_input"): + context = callback.before_llm_call( + context, None, [messages_list], **kwargs + ) + + except Exception: + # If we can't get context, just proceed without modification + logger.warning( + "Could not get context, proceeding without modification" + ) + + return await self._original_agenerate(messages_list, *args, **kwargs) + + agent._model._agenerate = wrap_agenerate + else: + logger.warning("Could not wrap _agenerate, proceeding without modification") self._original_llm_call = agent.call_model async def wrap_call_model(**kwargs): @@ -149,5 +214,12 @@ async def unwrap(self, agent: LangchainAgent) -> None: if self._original_ainvoke is not None: agent._agent.ainvoke = self._original_ainvoke + if ( + self._original_agenerate is not None + and agent._model is not None + and hasattr(agent._model, "_agenerate") + ): + agent._model._agenerate = self._original_agenerate + if self._original_llm_call is not None: agent.call_model = self._original_llm_call diff --git a/src/any_agent/callbacks/wrappers/llama_index.py b/src/any_agent/callbacks/wrappers/llama_index.py index 62fc5a54f..a325cbb91 100644 --- a/src/any_agent/callbacks/wrappers/llama_index.py +++ b/src/any_agent/callbacks/wrappers/llama_index.py @@ -10,6 +10,16 @@ from any_agent.frameworks.llama_index import LlamaIndexAgent +def _import_llama_index_converters() -> tuple[Any, Any]: + """Import conversion functions from llama_index vendor module.""" + from any_agent.vendor.llama_index_utils import ( + from_openai_message_dict, + to_openai_message_dicts, + ) + + return to_openai_message_dicts, from_openai_message_dict + + class _LlamaIndexWrapper: def __init__(self) -> None: self.callback_context: dict[int, Context] = {} @@ -26,6 +36,27 @@ async def wrap_take_step(*args, **kwargs): ] context.shared["model_id"] = getattr(agent._agent.llm, "model", "No model") + if len(args) > 1 and isinstance(args[1], list): + to_openai_message_dicts, 
from_openai_message_dict = ( + _import_llama_index_converters() + ) + + normalized_messages = to_openai_message_dicts(args[1]) + context.framework_state.messages = normalized_messages + + def get_messages(): + return context.framework_state.messages + + def set_messages(new_messages): + context.framework_state.messages = new_messages + args[1].clear() + args[1].extend( + [from_openai_message_dict(msg) for msg in new_messages] + ) + + context.framework_state._message_getter = get_messages + context.framework_state._message_setter = set_messages + for callback in agent.config.callbacks: context = callback.before_llm_call(context, *args, **kwargs) diff --git a/src/any_agent/callbacks/wrappers/smolagents.py b/src/any_agent/callbacks/wrappers/smolagents.py index 4e6e92f77..10c4a4dfd 100644 --- a/src/any_agent/callbacks/wrappers/smolagents.py +++ b/src/any_agent/callbacks/wrappers/smolagents.py @@ -9,8 +9,6 @@ if TYPE_CHECKING: from collections.abc import Callable - from smolagents.models import ChatMessage - from any_agent.callbacks.context import Context from any_agent.frameworks.smolagents import SmolagentsAgent @@ -23,6 +21,9 @@ def __init__(self) -> None: async def wrap(self, agent: SmolagentsAgent) -> None: try: + from smolagents.memory import TaskStep + from smolagents.models import ChatMessage + smolagents_available = True except ImportError: smolagents_available = False @@ -63,36 +64,58 @@ def get_messages(): return normalized_messages def set_messages(new_messages): - if len(new_messages) != len(messages): - msg = f"Number of messages must match, Smolagents only allows for modification of message content, not the number of messages. Expected {len(messages)} messages, got {len(new_messages)} messages." - raise ValueError(msg) - for i, msg_dict in enumerate(new_messages): - text = msg_dict["content"] - if i == 1: - # Because of https://github.com/huggingface/smolagents/blob/317b57336c955e4e7518c42cc4ba53d880dd621a/src/smolagents/memory.py#L192 - # Smolagents expects the first user message to start with the hardcoded string "New task:" - text = f"New task:\n{text}" - messages[i].content = [{"type": "text", "text": text}] + messages.clear() + for msg_dict in new_messages: + content = msg_dict["content"] + if isinstance(content, str): + content = [{"type": "text", "text": content}] + + new_msg_dict = {**msg_dict, "content": content} + messages.append(ChatMessage.from_dict(new_msg_dict)) + + # Update TaskStep in memory so modifications persist through write_memory_to_messages() + # This is necessary because smolagents rebuilds messages from memory on every LLM call + if len(new_messages) >= 2: + for step in agent._agent.memory.steps: + if isinstance(step, TaskStep): + # Extract the task text (remove "New task:\n" prefix if present) + new_task = new_messages[1]["content"] + if isinstance(new_task, str): + new_task = new_task.removeprefix("New task:\n") + step.task = new_task + break context.framework_state._message_getter = get_messages context.framework_state._message_setter = set_messages - for callback in agent.config.callbacks: - context = callback.before_llm_call(context, messages, **kwargs) + # Only invoke callbacks on the first LLM call, not on retry attempts + # Retries can be detected by checking if there are error messages in the history + is_retry = any( + "Error:" in str(msg.content) and "Now let's retry" in str(msg.content) + for msg in messages + ) + + if not is_retry: + for callback in agent.config.callbacks: + context = callback.before_llm_call(context, messages, **kwargs) 
output = self._original_llm_call(messages, **kwargs) - for callback in agent.config.callbacks: - context = callback.after_llm_call(context, output) + if not is_retry: + for callback in agent.config.callbacks: + context = callback.after_llm_call(context, output) return output agent._agent.model.generate = wrap_generate def wrapped_tool_execution(original_tool, original_call, *args, **kwargs): - context = self.callback_context[ - get_current_span().get_span_context().trace_id - ] + trace_id = get_current_span().get_span_context().trace_id + + if trace_id == 0 or trace_id not in self.callback_context: + return original_call(**kwargs) + + context = self.callback_context[trace_id] context.shared["original_tool"] = original_tool for callback in agent.config.callbacks: diff --git a/src/any_agent/frameworks/langchain.py b/src/any_agent/frameworks/langchain.py index 34d524945..a910d9814 100644 --- a/src/any_agent/frameworks/langchain.py +++ b/src/any_agent/frameworks/langchain.py @@ -250,6 +250,7 @@ class LangchainAgent(AnyAgent): def __init__(self, config: AgentConfig): super().__init__(config) self._agent: CompiledStateGraph[Any] | None = None + self._model: LanguageModelLike | None = None @property def framework(self) -> AgentFramework: @@ -280,9 +281,10 @@ async def _load_agent(self) -> None: self._tools = imported_tools agent_type = self.config.agent_type or DEFAULT_AGENT_TYPE agent_args = self.config.agent_args or {} + self._model = self._get_model(self.config) self._agent = agent_type( name=self.config.name, - model=self._get_model(self.config), + model=self._model, tools=imported_tools, prompt=self.config.instructions, **agent_args, diff --git a/tests/integration/callbacks/test_framework_state.py b/tests/integration/callbacks/test_framework_state.py index f0d6e512f..4e6a60f4f 100644 --- a/tests/integration/callbacks/test_framework_state.py +++ b/tests/integration/callbacks/test_framework_state.py @@ -1,3 +1,5 @@ +import pytest + from any_agent import AgentConfig, AnyAgent from any_agent.callbacks import Callback, Context from any_agent.config import AgentFramework @@ -59,7 +61,7 @@ async def test_modify_llm_input(agent_framework: AgentFramework) -> None: second_modifier = SecondModifier() config = AgentConfig( - model_id=DEFAULT_SMALL_MODEL_ID, + model_id="openai:gpt-4.1-mini", instructions="You are a helpful assistant.", callbacks=[modifier, second_modifier], ) From 6c5c0e00dc3dd2b52edf0473340fc8bffddd765d Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 27 Oct 2025 15:17:39 -0400 Subject: [PATCH 11/15] lint --- docs/agents/callbacks.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/agents/callbacks.md b/docs/agents/callbacks.md index 934ac007d..fdf9d23a0 100644 --- a/docs/agents/callbacks.md +++ b/docs/agents/callbacks.md @@ -347,7 +347,3 @@ config = AgentConfig( ) # ... Continue to create and run agent ``` - -!!! tip - - Use try/except to gracefully handle frameworks that don't support message modification yet. The callback will simply skip modification for unsupported frameworks. 
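The tip removed above still describes real behavior after this series: when a wrapper has not installed the message accessors, `get_messages()` raises `NotImplementedError`. A minimal sketch of a response-editing callback that degrades gracefully on such frameworks — the class name and appended text are illustrative, echoing the `AfterLLMModifier` removed from the tests earlier in this series:

```python
from any_agent.callbacks import Callback, Context


class AppendClosingLine(Callback):
    """Sketch: append a closing line to the LLM's reply in after_llm_call,
    skipping frameworks that have not installed the message accessors."""

    def after_llm_call(self, context: Context, *args, **kwargs) -> Context:
        try:
            messages = context.framework_state.get_messages()
        except NotImplementedError:
            # No getter/setter installed for this framework; leave the reply as-is.
            return context
        if messages and isinstance(messages[-1].get("content"), str):
            messages[-1]["content"] += " Also, have a great day!"
            context.framework_state.set_messages(messages)
        return context
```

Because the wrappers in this series normalize messages to `role`/`content` dicts before invoking callbacks, the same callback works unchanged across the tinyagent, agno, google, langchain, llama_index, and smolagents wrappers touched here.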
From 4b1779f3b1d9a4cee61ba225e63ab43961a03c96 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 3 Nov 2025 20:33:17 -0500 Subject: [PATCH 12/15] fix callbacks wrapping order --- src/any_agent/callbacks/wrappers/agno.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/any_agent/callbacks/wrappers/agno.py b/src/any_agent/callbacks/wrappers/agno.py index 31d15ea0a..380111cc6 100644 --- a/src/any_agent/callbacks/wrappers/agno.py +++ b/src/any_agent/callbacks/wrappers/agno.py @@ -64,13 +64,15 @@ def set_messages(new_messages): context.framework_state._message_getter = get_messages context.framework_state._message_setter = set_messages + callback_kwargs = {**kwargs, "messages": messages} for callback in agent.config.callbacks: - context = callback.before_llm_call(context, *args, **kwargs) + context = callback.before_llm_call(context, *args, **callback_kwargs) result = await self._original_ainvoke(messages, *args, **kwargs) + callback_kwargs = {**kwargs, "assistant_message": result} for callback in agent.config.callbacks: - context = callback.after_llm_call(context, result, *args, **kwargs) + context = callback.after_llm_call(context, *args, **callback_kwargs) return result From 116dba5d75e2f60ed92dadad33189bb426e31146 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 3 Nov 2025 20:50:23 -0500 Subject: [PATCH 13/15] logic swap to resolve tests --- .../callbacks/span_generation/agno.py | 68 ++++++++++++++----- 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/src/any_agent/callbacks/span_generation/agno.py b/src/any_agent/callbacks/span_generation/agno.py index d4e445826..d73507c20 100644 --- a/src/any_agent/callbacks/span_generation/agno.py +++ b/src/any_agent/callbacks/span_generation/agno.py @@ -23,26 +23,58 @@ def before_llm_call(self, context: Context, *args, **kwargs): def after_llm_call(self, context: Context, *args, **kwargs) -> Context: output: str | list[dict[str, Any]] = "" + input_tokens: int = 0 + output_tokens: int = 0 + if assistant_message := kwargs.get("assistant_message"): - if content := getattr(assistant_message, "content", None): - output = str(content) - if tool_calls := getattr(assistant_message, "tool_calls", None): - output = [ - { - "tool.name": tool.get("function", {}).get("name", "No name"), - "tool.args": tool.get("function", {}).get( - "arguments", "No args" - ), - } - for tool in tool_calls - ] + if hasattr(assistant_message, "choices"): + choices = getattr(assistant_message, "choices", []) + if choices and len(choices) > 0: + choice = choices[0] + message = getattr(choice, "message", None) + if message: + if content := getattr(message, "content", None): + output = str(content) + if tool_calls := getattr(message, "tool_calls", None): + output = [ + { + "tool.name": getattr( + getattr(tool, "function", None), "name", "No name" + ), + "tool.args": getattr( + getattr(tool, "function", None), + "arguments", + "No args", + ), + } + for tool in tool_calls + ] + + if usage := getattr(assistant_message, "usage", None): + input_tokens = getattr(usage, "input_tokens", 0) or getattr( + usage, "prompt_tokens", 0 + ) + output_tokens = getattr(usage, "output_tokens", 0) or getattr( + usage, "completion_tokens", 0 + ) + else: + if content := getattr(assistant_message, "content", None): + output = str(content) + if tool_calls := getattr(assistant_message, "tool_calls", None): + output = [ + { + "tool.name": tool.get("function", {}).get("name", "No name"), + "tool.args": tool.get("function", {}).get( + "arguments", "No args" + ), + } 
+ for tool in tool_calls + ] - metrics: MessageMetrics | None - input_tokens: int = 0 - output_tokens: int = 0 - if metrics := getattr(assistant_message, "metrics", None): - input_tokens = metrics.input_tokens - output_tokens = metrics.output_tokens + metrics: MessageMetrics | None + if metrics := getattr(assistant_message, "metrics", None): + input_tokens = metrics.input_tokens + output_tokens = metrics.output_tokens context = self._set_llm_output(context, output, input_tokens, output_tokens) From b1330f889201c9b4fc6762f43040497efadb5db2 Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 3 Nov 2025 21:13:16 -0500 Subject: [PATCH 14/15] remove getters and setters when you're done with them --- src/any_agent/callbacks/wrappers/agno.py | 3 +++ src/any_agent/callbacks/wrappers/google.py | 3 +++ src/any_agent/callbacks/wrappers/langchain.py | 13 ++++++++++++- src/any_agent/callbacks/wrappers/llama_index.py | 3 +++ src/any_agent/callbacks/wrappers/openai.py | 3 +++ src/any_agent/callbacks/wrappers/smolagents.py | 3 +++ src/any_agent/callbacks/wrappers/tinyagent.py | 3 +++ 7 files changed, 30 insertions(+), 1 deletion(-) diff --git a/src/any_agent/callbacks/wrappers/agno.py b/src/any_agent/callbacks/wrappers/agno.py index 380111cc6..3cbedb2e6 100644 --- a/src/any_agent/callbacks/wrappers/agno.py +++ b/src/any_agent/callbacks/wrappers/agno.py @@ -74,6 +74,9 @@ def set_messages(new_messages): for callback in agent.config.callbacks: context = callback.after_llm_call(context, *args, **callback_kwargs) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return result agent._agent.model.ainvoke = wrapped_ainvoke diff --git a/src/any_agent/callbacks/wrappers/google.py b/src/any_agent/callbacks/wrappers/google.py index 7f0b3acf0..c6a9eabd1 100644 --- a/src/any_agent/callbacks/wrappers/google.py +++ b/src/any_agent/callbacks/wrappers/google.py @@ -204,6 +204,9 @@ def set_messages(new_messages): for callback in agent.config.callbacks: context = callback.after_llm_call(context, *args, **kwargs) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + if llm_response is not None: final_messages = context.framework_state.messages if final_messages: diff --git a/src/any_agent/callbacks/wrappers/langchain.py b/src/any_agent/callbacks/wrappers/langchain.py index 11d8f5428..241f7d569 100644 --- a/src/any_agent/callbacks/wrappers/langchain.py +++ b/src/any_agent/callbacks/wrappers/langchain.py @@ -186,7 +186,18 @@ def set_messages(new_messages): "Could not get context, proceeding without modification" ) - return await self._original_agenerate(messages_list, *args, **kwargs) + result = await self._original_agenerate(messages_list, *args, **kwargs) + + try: + context = self.callback_context[ + get_current_span().get_span_context().trace_id + ] + context.framework_state._message_getter = None + context.framework_state._message_setter = None + except Exception: + pass + + return result agent._model._agenerate = wrap_agenerate else: diff --git a/src/any_agent/callbacks/wrappers/llama_index.py b/src/any_agent/callbacks/wrappers/llama_index.py index a325cbb91..a24fc3cf0 100644 --- a/src/any_agent/callbacks/wrappers/llama_index.py +++ b/src/any_agent/callbacks/wrappers/llama_index.py @@ -67,6 +67,9 @@ def set_messages(new_messages): for callback in agent.config.callbacks: context = callback.after_llm_call(context, output) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return 
output # bypass Pydantic validation because _agent is a BaseModel diff --git a/src/any_agent/callbacks/wrappers/openai.py b/src/any_agent/callbacks/wrappers/openai.py index b6bc4c733..a3ac81e73 100644 --- a/src/any_agent/callbacks/wrappers/openai.py +++ b/src/any_agent/callbacks/wrappers/openai.py @@ -66,6 +66,9 @@ def set_messages(messages): output, ) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return output agent._agent.model.get_response = wrapped_llm_call diff --git a/src/any_agent/callbacks/wrappers/smolagents.py b/src/any_agent/callbacks/wrappers/smolagents.py index 10c4a4dfd..cfb1a2614 100644 --- a/src/any_agent/callbacks/wrappers/smolagents.py +++ b/src/any_agent/callbacks/wrappers/smolagents.py @@ -105,6 +105,9 @@ def set_messages(new_messages): for callback in agent.config.callbacks: context = callback.after_llm_call(context, output) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return output agent._agent.model.generate = wrap_generate diff --git a/src/any_agent/callbacks/wrappers/tinyagent.py b/src/any_agent/callbacks/wrappers/tinyagent.py index a473cf5ee..3c8cfaee2 100644 --- a/src/any_agent/callbacks/wrappers/tinyagent.py +++ b/src/any_agent/callbacks/wrappers/tinyagent.py @@ -48,6 +48,9 @@ def set_messages(messages): for callback in agent.config.callbacks: context = callback.after_llm_call(context, output) + context.framework_state._message_getter = None + context.framework_state._message_setter = None + return output agent.call_model = wrap_call_model From 829f539f58193b301b0be0a7d386bee31dd203cd Mon Sep 17 00:00:00 2001 From: Nathan Brake Date: Mon, 3 Nov 2025 21:17:18 -0500 Subject: [PATCH 15/15] lint --- src/any_agent/callbacks/span_generation/agno.py | 8 ++++++-- src/any_agent/callbacks/wrappers/langchain.py | 13 +++++-------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/any_agent/callbacks/span_generation/agno.py b/src/any_agent/callbacks/span_generation/agno.py index d73507c20..dd3c0cf71 100644 --- a/src/any_agent/callbacks/span_generation/agno.py +++ b/src/any_agent/callbacks/span_generation/agno.py @@ -39,7 +39,9 @@ def after_llm_call(self, context: Context, *args, **kwargs) -> Context: output = [ { "tool.name": getattr( - getattr(tool, "function", None), "name", "No name" + getattr(tool, "function", None), + "name", + "No name", ), "tool.args": getattr( getattr(tool, "function", None), @@ -63,7 +65,9 @@ def after_llm_call(self, context: Context, *args, **kwargs) -> Context: if tool_calls := getattr(assistant_message, "tool_calls", None): output = [ { - "tool.name": tool.get("function", {}).get("name", "No name"), + "tool.name": tool.get("function", {}).get( + "name", "No name" + ), "tool.args": tool.get("function", {}).get( "arguments", "No args" ), diff --git a/src/any_agent/callbacks/wrappers/langchain.py b/src/any_agent/callbacks/wrappers/langchain.py index 241f7d569..81b97d57e 100644 --- a/src/any_agent/callbacks/wrappers/langchain.py +++ b/src/any_agent/callbacks/wrappers/langchain.py @@ -188,14 +188,11 @@ def set_messages(new_messages): result = await self._original_agenerate(messages_list, *args, **kwargs) - try: - context = self.callback_context[ - get_current_span().get_span_context().trace_id - ] - context.framework_state._message_getter = None - context.framework_state._message_setter = None - except Exception: - pass + context = self.callback_context[ + get_current_span().get_span_context().trace_id + ] + 
context.framework_state._message_getter = None + context.framework_state._message_setter = None return result
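The last two patches reset `framework_state._message_getter` and `_message_setter` to `None` once the `after_llm_call` callbacks have run, so the accessors are live only for the duration of a single LLM call; this presumably prevents a stale closure over one call's messages from being invoked on a later run. A minimal probe callback, assuming the private attributes shown in these diffs, that a test could use to confirm that window:

```python
from typing import Any

from any_agent.callbacks import Callback, Context


class AccessorWindowProbe(Callback):
    """Sketch: record whether the normalized message accessors are live in each hook."""

    def __init__(self) -> None:
        self.live_in_before = False
        self.live_in_after = False

    def before_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context:
        # The wrappers install the getter/setter just before invoking callbacks,
        # so it should be non-None here for frameworks that support modification.
        self.live_in_before = context.framework_state._message_getter is not None
        return context

    def after_llm_call(self, context: Context, *args: Any, **kwargs: Any) -> Context:
        # Still live here; the wrappers reset the accessors to None only after
        # the after_llm_call loop completes.
        self.live_in_after = context.framework_state._message_getter is not None
        return context
```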