diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py
index c824622091..a79b8c966f 100644
--- a/src/google/adk/plugins/__init__.py
+++ b/src/google/adk/plugins/__init__.py
@@ -16,10 +16,12 @@
 from .logging_plugin import LoggingPlugin
 from .plugin_manager import PluginManager
 from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin
+from .tool_aware_context_filter_plugin import ToolAwareContextFilterPlugin
 
 __all__ = [
     'BasePlugin',
     'LoggingPlugin',
     'PluginManager',
     'ReflectAndRetryToolPlugin',
+    'ToolAwareContextFilterPlugin',
 ]
diff --git a/src/google/adk/plugins/tool_aware_context_filter_plugin.py b/src/google/adk/plugins/tool_aware_context_filter_plugin.py
new file mode 100644
index 0000000000..a0e3d147fc
--- /dev/null
+++ b/src/google/adk/plugins/tool_aware_context_filter_plugin.py
@@ -0,0 +1,357 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tool-aware context filter plugin for managing conversation history.
+
+This plugin extends the standard context filtering to properly handle function
+call/response sequences, ensuring they remain atomic during history trimming.
+
+PROBLEM WITH STANDARD ContextFilterPlugin:
+==========================================
+The standard ContextFilterPlugin treats each model message as a separate
+"invocation", but when a model makes a tool call, it creates MULTIPLE model
+messages in sequence:
+  1. Model message with function_call
+  2. User message with function_response (tool result)
+  3. Model message with final text response
+
+When filtering to keep N "invocations", the standard plugin can split these
+related messages apart, creating orphaned function_responses without their
+corresponding function_calls, which violates OpenAI API requirements.
+
+HOW THIS PLUGIN SOLVES IT:
+==========================
+This plugin groups messages into LOGICAL invocations where a complete cycle is:
+  - User query (one or more messages)
+  - Model response (possibly with function_call)
+  - Function response(s) (if a tool was called)
+  - Model final response (after tool execution)
+
+All messages in a tool call sequence are kept together as an atomic unit.
+"""
+
+from __future__ import annotations
+
+import logging
+from typing import Callable, List, Optional
+
+from google.adk.agents.callback_context import CallbackContext
+from google.adk.models.llm_request import LlmRequest
+from google.adk.models.llm_response import LlmResponse
+from google.adk.plugins.base_plugin import BasePlugin
+from google.genai import types
+
+logger = logging.getLogger("google_adk." + __name__)
+
+
+class ToolAwareContextFilterPlugin(BasePlugin):
+  """A plugin that filters LLM context while preserving tool call sequences.
+
+  This plugin extends context filtering to handle function call/response pairs
+  correctly, ensuring they are never split during history trimming.
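+
+  Example (an illustrative sketch; how the plugin is registered depends on
+  your runner setup):
+
+    plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=3)
+    runner = InMemoryRunner(agent, plugins=[plugin])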
+ """ + + def __init__( + self, + num_invocations_to_keep: Optional[int] = None, + custom_filter: Optional[ + Callable[[List[types.Content]], List[types.Content]] + ] = None, + name: str = "tool_aware_context_filter_plugin", + ): + """Initializes the tool-aware context filter plugin. + + Args: + num_invocations_to_keep: The number of last invocations to keep. An + invocation is defined as a complete user-model interaction cycle, + including any tool calls and their responses. + custom_filter: A function to apply additional filtering to the context. + name: The name of the plugin instance. + """ + super().__init__(name) + self._num_invocations_to_keep = num_invocations_to_keep + self._custom_filter = custom_filter + + @staticmethod + def _has_function_call(content) -> bool: + """Check if a content has a function_call part.""" + if not content.parts: + return False + return any( + hasattr(part, "function_call") and part.function_call + for part in content.parts + ) + + @staticmethod + def _has_function_response(content) -> bool: + """Check if a content has a function_response part.""" + if not content.parts: + return False + return any( + hasattr(part, "function_response") and part.function_response + for part in content.parts + ) + + def _finalize_invocation_if_complete( + self, + current_invocation: List[int], + invocations: List[List[int]], + contents: List[types.Content], + ) -> List[int]: + """Finalize current invocation if it has a model response. + + Args: + current_invocation: Current invocation being built. + invocations: List of completed invocations. + contents: List of message contents. + + Returns: + Empty list if invocation was finalized, otherwise current_invocation. + """ + if current_invocation: + has_model = any( + contents[idx].role == "model" for idx in current_invocation + ) + if has_model: + invocations.append(current_invocation) + return [] + return current_invocation + + def _process_user_message( + self, + i: int, + contents: List[types.Content], + current_invocation: List[int], + invocations: List[List[int]], + ) -> tuple[int, List[int]]: + """Process a user message and update invocation tracking. + + Args: + i: Current index in contents. + contents: List of message contents. + current_invocation: Current invocation being built. + invocations: List of completed invocations. + + Returns: + Tuple of (next_index, updated_current_invocation). + """ + content = contents[i] + + # Check if this is a function_response (part of ongoing tool cycle) + if self._has_function_response(content): + # This is a tool response - must be part of current invocation + current_invocation.append(i) + return i + 1, current_invocation + + # Regular user message (not a function_response) + # Only start a NEW invocation if we've completed a previous one + current_invocation = self._finalize_invocation_if_complete( + current_invocation, invocations, contents + ) + + # Add this user message to current invocation + current_invocation.append(i) + return i + 1, current_invocation + + def _process_model_message_with_tool_call( + self, + i: int, + contents: List[types.Content], + current_invocation: List[int], + invocations: List[List[int]], + ) -> tuple[int, List[int]]: + """Process a model message with tool call and collect the full cycle. + + Args: + i: Current index in contents (at model message with function_call). + contents: List of message contents. + current_invocation: Current invocation being built. + invocations: List of completed invocations. 
+
+    Returns:
+      Tuple of (next_index, empty_invocation_list).
+    """
+    # Model made a tool call - keep the following messages together:
+    #   1. This model message (function_call) - already added
+    #   2. User message(s) with function_response - collect next
+    #   3. Model's response after the tool results - collect last; if that
+    #      response contains another function_call, the cycle repeats and the
+    #      chained tool calls stay in the same invocation.
+
+    i += 1  # Move past the model message with the function_call
+
+    while True:
+      # Collect all function_response messages (usually 1, possibly several)
+      while (
+          i < len(contents)
+          and contents[i].role == "user"
+          and self._has_function_response(contents[i])
+      ):
+        current_invocation.append(i)
+        i += 1
+
+      # Collect the model's response to the tool results
+      if i < len(contents) and contents[i].role == "model":
+        current_invocation.append(i)
+        is_chained_call = self._has_function_call(contents[i])
+        i += 1
+        if is_chained_call:
+          # Another tool cycle follows; keep it in this invocation
+          continue
+      break
+
+    # Complete tool cycle collected - this is ONE complete invocation
+    invocations.append(current_invocation)
+    return i, []
+
+  def _process_model_message(
+      self,
+      i: int,
+      contents: List[types.Content],
+      current_invocation: List[int],
+      invocations: List[List[int]],
+  ) -> tuple[int, List[int]]:
+    """Processes a model message and updates invocation tracking.
+
+    Args:
+      i: Current index in contents.
+      contents: List of message contents.
+      current_invocation: Current invocation being built.
+      invocations: List of completed invocations.
+
+    Returns:
+      Tuple of (next_index, updated_current_invocation).
+    """
+    content = contents[i]
+    current_invocation.append(i)
+
+    # Check if the model is making a tool call
+    if self._has_function_call(content):
+      return self._process_model_message_with_tool_call(
+          i, contents, current_invocation, invocations
+      )
+
+    # Model response WITHOUT a function call - simple case.
+    # The invocation is complete (user query → model answer).
+    invocations.append(current_invocation)
+    return i + 1, []
+
+  def _group_into_invocations(
+      self, contents: List[types.Content]
+  ) -> List[List[int]]:
+    """Groups message indices into complete invocations.
+
+    An invocation pattern:
+      1. One or more user messages (including consecutive user messages)
+      2. Model response (possibly with function_call)
+      3. If function_call exists: user message(s) with function_response
+      4. If function_call exists: model final response
+
+    Example grouping:
+      Messages: [user, user, model, user, model+func_call,
+                 user+func_response, model]
+      Groups:   [0, 1, 2]   [3, 4, 5, 6]
+                 Inv 1       Inv 2 (includes the tool cycle)
+
+    Args:
+      contents: List of message contents to group.
+
+    Returns:
+      List of invocations, where each invocation is a list of message indices.
+    """
+    invocations = []
+    current_invocation = []
+    i = 0
+
+    while i < len(contents):
+      content = contents[i]
+
+      if content.role == "user":
+        i, current_invocation = self._process_user_message(
+            i, contents, current_invocation, invocations
+        )
+      elif content.role == "model":
+        i, current_invocation = self._process_model_message(
+            i, contents, current_invocation, invocations
+        )
+      else:
+        # Unknown role - just add it to the current invocation
+        current_invocation.append(i)
+        i += 1
+
+    # Add any remaining messages as the final invocation
+    if current_invocation:
+      invocations.append(current_invocation)
+
+    return invocations
+
+  async def before_model_callback(
+      self, *, callback_context: CallbackContext, llm_request: LlmRequest
+  ) -> Optional[LlmResponse]:
+    """Filters the LLM request's context before it is sent to the model.
+
+    This method groups messages into logical invocations and keeps only the
+    most recent N invocations, ensuring tool call sequences remain intact.
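+
+    For example, with num_invocations_to_keep=1 and a history of
+    [user, model, user, model+function_call, user+function_response, model],
+    the first two messages are dropped and the last four are kept as one
+    atomic invocation.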
+
+    Args:
+      callback_context: Context containing invocation and agent information.
+      llm_request: The LLM request to filter.
+
+    Returns:
+      None - the request is modified in place.
+    """
+    try:
+      contents = llm_request.contents
+
+      if not contents:
+        return None
+
+      # Apply invocation-based filtering if configured
+      if (
+          self._num_invocations_to_keep is not None
+          and self._num_invocations_to_keep > 0
+      ):
+        # Group messages into logical invocations
+        invocations = self._group_into_invocations(contents)
+
+        logger.info(
+            "ToolAwareContextFilter: Total invocations=%d, keeping last %d",
+            len(invocations),
+            self._num_invocations_to_keep,
+        )
+
+        # Keep only the last N invocations
+        if len(invocations) > self._num_invocations_to_keep:
+          invocations_to_keep = invocations[-self._num_invocations_to_keep :]
+
+          # Flatten the list of indices
+          indices_to_keep = []
+          for invocation in invocations_to_keep:
+            indices_to_keep.extend(invocation)
+
+          # Filter contents based on indices
+          filtered_contents = [contents[i] for i in indices_to_keep]
+
+          logger.info(
+              "ToolAwareContextFilter: Reduced from %d messages to %d messages"
+              " (kept %d invocations)",
+              len(contents),
+              len(filtered_contents),
+              len(invocations_to_keep),
+          )
+
+          contents = filtered_contents
+
+      # Apply custom filter if provided
+      if self._custom_filter:
+        contents = self._custom_filter(contents)
+
+      llm_request.contents = contents
+
+    except Exception as e:
+      logger.error("ToolAwareContextFilter: Failed to filter context: %s", e)
+
+    return None
\ No newline at end of file
diff --git a/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py b/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py
new file mode 100644
index 0000000000..a4c88710a1
--- /dev/null
+++ b/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py
@@ -0,0 +1,313 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +"""Tests for ToolAwareContextFilterPlugin.""" + +import pytest +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.plugins.tool_aware_context_filter_plugin import ( + ToolAwareContextFilterPlugin, +) +from google.genai.types import Content, FunctionCall, FunctionResponse, Part + + +class TestToolAwareContextFilterPlugin: + """Tests for ToolAwareContextFilterPlugin.""" + + def test_init(self): + """Test plugin initialization.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + assert plugin._num_invocations_to_keep == 2 + assert plugin._custom_filter is None + + @pytest.mark.asyncio + async def test_no_filtering_when_disabled(self): + """Test that no filtering occurs when num_invocations_to_keep is None.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=None) + + contents = [ + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + Content(role="user", parts=[Part(text="How are you?")]), + Content(role="model", parts=[Part(text="I'm good")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + await plugin.before_model_callback( + callback_context=context, llm_request=request + ) + + # No filtering should occur + assert len(request.contents) == 4 + + @pytest.mark.asyncio + async def test_simple_invocations_no_tool_calls(self): + """Test filtering simple Q&A without tool calls.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + # Create 3 simple invocations + contents = [ + # Invocation 1 + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 + Content(role="user", parts=[Part(text="How are you?")]), + Content(role="model", parts=[Part(text="I'm good")]), + # Invocation 3 + Content(role="user", parts=[Part(text="What's your name?")]), + Content(role="model", parts=[Part(text="I'm Claude")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + await plugin.before_model_callback( + callback_context=context, llm_request=request + ) + + # Should keep last 2 invocations (indices 2-5) + assert len(request.contents) == 4 + assert request.contents[0].parts[0].text == "How are you?" 
+    assert request.contents[-1].parts[0].text == "I'm Claude"
+
+  @pytest.mark.asyncio
+  async def test_tool_call_sequence_kept_together(self):
+    """Test that function_call and function_response stay together."""
+    plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1)
+
+    # Create invocations where the last one has a tool call
+    contents = [
+        # Invocation 1 (should be removed)
+        Content(role="user", parts=[Part(text="Hello")]),
+        Content(role="model", parts=[Part(text="Hi")]),
+        # Invocation 2 (should be kept - has tool call)
+        Content(role="user", parts=[Part(text="What's the weather?")]),
+        Content(
+            role="model",
+            parts=[
+                Part(
+                    function_call=FunctionCall(
+                        name="get_weather", args={"location": "SF"}
+                    )
+                )
+            ],
+        ),
+        Content(
+            role="user",
+            parts=[
+                Part(
+                    function_response=FunctionResponse(
+                        name="get_weather", response={"temp": 72}
+                    )
+                )
+            ],
+        ),
+        Content(role="model", parts=[Part(text="It's 72°F")]),
+    ]
+
+    request = LlmRequest(model="test", contents=contents)
+    context = CallbackContext(invocation_id="test", agent_name="test")
+
+    # Run the plugin
+    await plugin.before_model_callback(
+        callback_context=context, llm_request=request
+    )
+
+    # Should keep the entire tool call sequence (4 messages)
+    assert len(request.contents) == 4
+    assert request.contents[0].parts[0].text == "What's the weather?"
+    assert request.contents[1].parts[0].function_call is not None
+    assert request.contents[2].parts[0].function_response is not None
+    assert request.contents[3].parts[0].text == "It's 72°F"
+
+  @pytest.mark.asyncio
+  async def test_orphaned_function_response_prevented(self):
+    """Test that function_response is never orphaned without function_call."""
+    plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2)
+
+    contents = [
+        # Invocation 1
+        Content(role="user", parts=[Part(text="Hello")]),
+        Content(role="model", parts=[Part(text="Hi")]),
+        # Invocation 2 (with tool call)
+        Content(role="user", parts=[Part(text="Query 1")]),
+        Content(
+            role="model",
+            parts=[Part(function_call=FunctionCall(name="tool1", args={}))],
+        ),
+        Content(
+            role="user",
+            parts=[
+                Part(
+                    function_response=FunctionResponse(
+                        name="tool1", response={}
+                    )
+                )
+            ],
+        ),
+        Content(role="model", parts=[Part(text="Response 1")]),
+        # Invocation 3 (with tool call)
+        Content(role="user", parts=[Part(text="Query 2")]),
+        Content(
+            role="model",
+            parts=[Part(function_call=FunctionCall(name="tool2", args={}))],
+        ),
+        Content(
+            role="user",
+            parts=[
+                Part(
+                    function_response=FunctionResponse(
+                        name="tool2", response={}
+                    )
+                )
+            ],
+        ),
+        Content(role="model", parts=[Part(text="Response 2")]),
+    ]
+
+    request = LlmRequest(model="test", contents=contents)
+    context = CallbackContext(invocation_id="test", agent_name="test")
+
+    # Run the plugin
+    await plugin.before_model_callback(
+        callback_context=context, llm_request=request
+    )
+
+    # Should keep invocations 2 and 3 (indices 2-9)
+    assert len(request.contents) == 8
+
+    # Verify there is no orphaned function_response
+    for i, content in enumerate(request.contents):
+      if content.role == "user" and any(
+          p.function_response for p in content.parts
+      ):
+        # There must be a preceding model message with a function_call
+        assert i > 0
+        prev_content = request.contents[i - 1]
+        assert prev_content.role == "model"
+        assert any(p.function_call for p in prev_content.parts)
+
+  @pytest.mark.asyncio
+  async def test_consecutive_user_messages_grouped(self):
+    """Test that consecutive user messages are grouped together."""
+    plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1)
+
+    contents = [
+        # Invocation 1 (should be removed)
+        Content(role="user", parts=[Part(text="Hello")]),
+        Content(role="model", parts=[Part(text="Hi")]),
+        # Invocation 2 (should be kept)
+        Content(role="user", parts=[Part(text="For context:")]),
+        Content(role="user", parts=[Part(text="Tell me about X")]),
+        Content(role="model", parts=[Part(text="Here's about X")]),
+    ]
+
+    request = LlmRequest(model="test", contents=contents)
+    context = CallbackContext(invocation_id="test", agent_name="test")
+
+    # Run the plugin
+    await plugin.before_model_callback(
+        callback_context=context, llm_request=request
+    )
+
+    # Should keep the last invocation with both user messages
+    assert len(request.contents) == 3
+    assert request.contents[0].parts[0].text == "For context:"
+    assert request.contents[1].parts[0].text == "Tell me about X"
+    assert request.contents[2].parts[0].text == "Here's about X"
+
+  @pytest.mark.asyncio
+  async def test_custom_filter_applied(self):
+    """Test that the custom filter is applied after invocation filtering."""
+
+    # Custom filter that removes all model messages
+    def custom_filter(contents):
+      return [c for c in contents if c.role != "model"]
+
+    plugin = ToolAwareContextFilterPlugin(
+        num_invocations_to_keep=2, custom_filter=custom_filter
+    )
+
+    contents = [
+        Content(role="user", parts=[Part(text="Query 1")]),
+        Content(role="model", parts=[Part(text="Response 1")]),
+        Content(role="user", parts=[Part(text="Query 2")]),
+        Content(role="model", parts=[Part(text="Response 2")]),
+    ]
+
+    request = LlmRequest(model="test", contents=contents)
+    context = CallbackContext(invocation_id="test", agent_name="test")
+
+    # Run the plugin
+    await plugin.before_model_callback(
+        callback_context=context, llm_request=request
+    )
+
+    # Should have only user messages left
+    assert len(request.contents) == 2
+    assert all(c.role == "user" for c in request.contents)
+
+  @pytest.mark.asyncio
+  async def test_empty_contents(self):
+    """Test handling of an empty contents list."""
+    plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2)
+
+    request = LlmRequest(model="test", contents=[])
+    context = CallbackContext(invocation_id="test", agent_name="test")
+
+    # Run the plugin
+    await plugin.before_model_callback(
+        callback_context=context, llm_request=request
+    )
+
+    # Should handle empty contents gracefully
+    assert len(request.contents) == 0
+
+  @pytest.mark.asyncio
+  async def test_error_handling(self):
+    """Test that errors are caught and logged without crashing."""
+    plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2)
+
+    # Create a valid request first, then set contents to None
+    request = LlmRequest(model="test", contents=[])
+    context = CallbackContext(invocation_id="test", agent_name="test")
+    request.contents = None
+
+    # Should not raise an exception
+    result = await plugin.before_model_callback(
+        callback_context=context, llm_request=request
+    )
+
+    # Should return None without crashing
+    assert result is None
\ No newline at end of file
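
A minimal end-to-end sketch of how the new plugin could be wired up (not part
of the diff; the plugins argument on InMemoryRunner, the agent setup, and the
model name are illustrative assumptions based on how other ADK plugins are
registered):

    from google.adk.agents import Agent
    from google.adk.plugins import ToolAwareContextFilterPlugin
    from google.adk.runners import InMemoryRunner

    # Keep only the three most recent logical invocations; a function_call
    # and its function_response(s) always travel inside the same invocation.
    context_filter = ToolAwareContextFilterPlugin(num_invocations_to_keep=3)

    agent = Agent(name="assistant", model="gemini-2.0-flash")
    runner = InMemoryRunner(agent, plugins=[context_filter])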