From f755dca13c54f37fc39920a9f6378774de2fd97a Mon Sep 17 00:00:00 2001 From: shekhar Date: Mon, 5 Jan 2026 19:18:52 +0530 Subject: [PATCH 1/4] Add ToolAwareContextFilterPlugin to preserve tool call sequences - Fixes issue where ContextFilterPlugin splits function_call/function_response pairs - Groups tool call sequences as atomic invocations - Prevents OpenAI API errors when filtering conversation history - Adds comprehensive unit tests Fixes #4027 --- src/google/adk/plugins/__init__.py | 2 + .../tool_aware_context_filter_plugin.py | 267 +++++++++++++++ test_tool_aware_context_filter_plugin.py | 320 ++++++++++++++++++ .../test_tool_aware_context_filter_plugin.py | 320 ++++++++++++++++++ tool_aware_context_filter_plugin.py | 267 +++++++++++++++ 5 files changed, 1176 insertions(+) create mode 100644 src/google/adk/plugins/tool_aware_context_filter_plugin.py create mode 100644 test_tool_aware_context_filter_plugin.py create mode 100644 tests/unittests/plugins/test_tool_aware_context_filter_plugin.py create mode 100644 tool_aware_context_filter_plugin.py diff --git a/src/google/adk/plugins/__init__.py b/src/google/adk/plugins/__init__.py index c824622091..a79b8c966f 100644 --- a/src/google/adk/plugins/__init__.py +++ b/src/google/adk/plugins/__init__.py @@ -16,10 +16,12 @@ from .logging_plugin import LoggingPlugin from .plugin_manager import PluginManager from .reflect_retry_tool_plugin import ReflectAndRetryToolPlugin +from .tool_aware_context_filter_plugin import ToolAwareContextFilterPlugin __all__ = [ 'BasePlugin', 'LoggingPlugin', 'PluginManager', 'ReflectAndRetryToolPlugin', + 'ToolAwareContextFilterPlugin' ] diff --git a/src/google/adk/plugins/tool_aware_context_filter_plugin.py b/src/google/adk/plugins/tool_aware_context_filter_plugin.py new file mode 100644 index 0000000000..0a3b53e0ab --- /dev/null +++ b/src/google/adk/plugins/tool_aware_context_filter_plugin.py @@ -0,0 +1,267 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool-aware context filter plugin for managing conversation history. + +This plugin extends the standard context filtering to properly handle function +call/response sequences, ensuring they remain atomic during history trimming. + +PROBLEM WITH STANDARD ContextFilterPlugin: +========================================== +The standard ContextFilterPlugin treats each model message as a separate +"invocation", but when a model makes a tool call, it creates MULTIPLE model +messages in sequence: + 1. Model message with function_call + 2. User message with function_response (tool result) + 3. Model message with final text response + +When filtering to keep N "invocations", the standard plugin can split these +related messages apart, creating orphaned function_responses without their +corresponding function_calls, which violates OpenAI API requirements. + +HOW THIS PLUGIN SOLVES IT: +=========================== +This plugin groups messages into LOGICAL invocations where a complete cycle is: + - User query (one or more messages) + - Model response (possibly with function_call) + - Function response(s) (if tool was called) + - Model final response (after tool execution) + +All messages in a tool call sequence are kept together as an atomic unit. +""" + +from __future__ import annotations + +import logging +from typing import Callable, List, Optional + +from google.adk.agents.callback_context import CallbackContext +from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin + +logger = logging.getLogger("google_adk." + __name__) + + +class ToolAwareContextFilterPlugin(BasePlugin): + """A plugin that filters LLM context while preserving tool call sequences. + + This plugin extends context filtering to handle function call/response pairs + correctly, ensuring they are never split during history trimming. + """ + + def __init__( + self, + num_invocations_to_keep: Optional[int] = None, + custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None, + name: str = "tool_aware_context_filter_plugin", + ): + """Initializes the tool-aware context filter plugin. + + Args: + num_invocations_to_keep: The number of last invocations to keep. An + invocation is defined as a complete user-model interaction cycle, + including any tool calls and their responses. + custom_filter: A function to apply additional filtering to the context. + name: The name of the plugin instance. + """ + super().__init__(name) + self._num_invocations_to_keep = num_invocations_to_keep + self._custom_filter = custom_filter + + def _has_function_call(self, content) -> bool: + """Check if a content has a function_call part.""" + if not content.parts: + return False + return any( + hasattr(part, "function_call") and part.function_call + for part in content.parts + ) + + def _has_function_response(self, content) -> bool: + """Check if a content has a function_response part.""" + if not content.parts: + return False + return any( + hasattr(part, "function_response") and part.function_response + for part in content.parts + ) + + def _group_into_invocations(self, contents: List) -> List[List[int]]: + """Group message indices into complete invocations. + + An invocation pattern: + 1. One or more user messages (including consecutive user messages) + 2. Model response (possibly with function_call) + 3. If function_call exists: user message(s) with function_response + 4. If function_call exists: model final response + + Example grouping: + Messages: [user, user, model, user, model+func_call, user+func_response, + model] Groups: [0,1,2] [3,4,5,6] + ^^^^^^^ ^^^^^^^^^^^ + Inv 1 Inv 2 (includes tool cycle) + + Args: + contents: List of message contents to group. + + Returns: + List of invocations, where each invocation is a list of message indices. + """ + invocations = [] + current_invocation = [] + i = 0 + + while i < len(contents): + content = contents[i] + + # CASE 1: User message + if content.role == "user": + # Check if this is a function_response (part of ongoing tool cycle) + if self._has_function_response(content): + # This is a tool response - must be part of current invocation + current_invocation.append(i) + i += 1 + else: + # Regular user message (not a function_response) + # Only start a NEW invocation if we've completed a previous one + if current_invocation: + # Check if previous invocation has a model response + has_model = any( + contents[idx].role == "model" for idx in current_invocation + ) + if has_model: + invocations.append(current_invocation) + current_invocation = [] + + # Add this user message to current invocation + current_invocation.append(i) + i += 1 + + # CASE 2: Model message + elif content.role == "model": + current_invocation.append(i) + + # Check if model is making a tool call + if self._has_function_call(content): + # Model made a tool call - keep following messages together: + # 1. This model message (function_call) - already added + # 2. User message(s) with function_response - collect next + # 3. Model's final response - collect after tool responses + + i += 1 # Move to next message + + # Collect all function_response messages (usually 1, but could be + # multiple) + while ( + i < len(contents) + and contents[i].role == "user" + and self._has_function_response(contents[i]) + ): + current_invocation.append(i) + i += 1 + + # Now collect the model's final response after processing tool results + if i < len(contents) and contents[i].role == "model": + current_invocation.append(i) + i += 1 + + # Complete tool cycle collected - this is ONE complete invocation + invocations.append(current_invocation) + current_invocation = [] + else: + # Model response WITHOUT function call - simple case + # The invocation is complete (user query → model answer) + i += 1 + invocations.append(current_invocation) + current_invocation = [] + else: + # Unknown role - just add to current invocation + current_invocation.append(i) + i += 1 + + # Add any remaining messages as final invocation + if current_invocation: + invocations.append(current_invocation) + + return invocations + + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + """Filters the LLM request's context before it is sent to the model. + + This method groups messages into logical invocations and keeps only the + most recent N invocations, ensuring tool call sequences remain intact. + + Args: + callback_context: Context containing invocation and agent information. + llm_request: The LLM request to filter. + + Returns: + None - the request is modified in place. + """ + try: + contents = llm_request.contents + + if not contents: + return None + + # Apply invocation-based filtering if configured + if ( + self._num_invocations_to_keep is not None + and self._num_invocations_to_keep > 0 + ): + # Group messages into logical invocations + invocations = self._group_into_invocations(contents) + + logger.info( + "ToolAwareContextFilter: Total invocations=%d, keeping last %d", + len(invocations), + self._num_invocations_to_keep, + ) + + # Keep only the last N invocations + if len(invocations) > self._num_invocations_to_keep: + invocations_to_keep = invocations[-self._num_invocations_to_keep :] + + # Flatten the list of indices + indices_to_keep = [] + for invocation in invocations_to_keep: + indices_to_keep.extend(invocation) + + # Filter contents based on indices + filtered_contents = [contents[i] for i in sorted(indices_to_keep)] + + logger.info( + "ToolAwareContextFilter: Reduced from %d messages to %d messages" + " (kept %d invocations)", + len(contents), + len(filtered_contents), + len(invocations_to_keep), + ) + + contents = filtered_contents + + # Apply custom filter if provided + if self._custom_filter: + contents = self._custom_filter(contents) + + llm_request.contents = contents + + except Exception as e: + logger.error("ToolAwareContextFilter: Failed to filter context: %s", e) + + return None \ No newline at end of file diff --git a/test_tool_aware_context_filter_plugin.py b/test_tool_aware_context_filter_plugin.py new file mode 100644 index 0000000000..bfa0b5bbdb --- /dev/null +++ b/test_tool_aware_context_filter_plugin.py @@ -0,0 +1,320 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ToolAwareContextFilterPlugin.""" + +import pytest +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.plugins.tool_aware_context_filter_plugin import ( + ToolAwareContextFilterPlugin, +) +from google.genai.types import Content, FunctionCall, FunctionResponse, Part + + +class TestToolAwareContextFilterPlugin: + """Tests for ToolAwareContextFilterPlugin.""" + + def test_init(self): + """Test plugin initialization.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + assert plugin._num_invocations_to_keep == 2 + assert plugin._custom_filter is None + + def test_no_filtering_when_disabled(self): + """Test that no filtering occurs when num_invocations_to_keep is None.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=None) + + contents = [ + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + Content(role="user", parts=[Part(text="How are you?")]), + Content(role="model", parts=[Part(text="I'm good")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # No filtering should occur + assert len(request.contents) == 4 + + def test_simple_invocations_no_tool_calls(self): + """Test filtering simple Q&A without tool calls.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + # Create 3 simple invocations + contents = [ + # Invocation 1 + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 + Content(role="user", parts=[Part(text="How are you?")]), + Content(role="model", parts=[Part(text="I'm good")]), + # Invocation 3 + Content(role="user", parts=[Part(text="What's your name?")]), + Content(role="model", parts=[Part(text="I'm Claude")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep last 2 invocations (indices 2-5) + assert len(request.contents) == 4 + assert request.contents[0].parts[0].text == "How are you?" + assert request.contents[-1].parts[0].text == "I'm Claude" + + def test_tool_call_sequence_kept_together(self): + """Test that function_call and function_response stay together.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) + + # Create invocations where the last one has a tool call + contents = [ + # Invocation 1 (should be removed) + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 (should be kept - has tool call) + Content(role="user", parts=[Part(text="What's the weather?")]), + Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="get_weather", args={"location": "SF"} + ) + ) + ], + ), + Content( + role="user", + parts=[ + Part( + function_response=FunctionResponse( + name="get_weather", response={"temp": 72} + ) + ) + ], + ), + Content(role="model", parts=[Part(text="It's 72°F")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep entire tool call sequence (4 messages) + assert len(request.contents) == 4 + assert request.contents[0].parts[0].text == "What's the weather?" + assert hasattr(request.contents[1].parts[0], "function_call") + assert hasattr(request.contents[2].parts[0], "function_response") + assert request.contents[3].parts[0].text == "It's 72°F" + + def test_orphaned_function_response_prevented(self): + """Test that function_response is never orphaned without function_call.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + contents = [ + # Invocation 1 + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 (with tool call) + Content(role="user", parts=[Part(text="Query 1")]), + Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="tool1", args={} + ) + ) + ], + ), + Content( + role="user", + parts=[ + Part( + function_response=FunctionResponse( + name="tool1", response={} + ) + ) + ], + ), + Content(role="model", parts=[Part(text="Response 1")]), + # Invocation 3 (with tool call) + Content(role="user", parts=[Part(text="Query 2")]), + Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="tool2", args={} + ) + ) + ], + ), + Content( + role="user", + parts=[ + Part( + function_response=FunctionResponse( + name="tool2", response={} + ) + ) + ], + ), + Content(role="model", parts=[Part(text="Response 2")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep invocations 2 and 3 (indices 2-9) + assert len(request.contents) == 8 + + # Verify no orphaned function_response + for i, content in enumerate(request.contents): + if content.role == "user" and any( + hasattr(p, "function_response") and p.function_response + for p in content.parts + ): + # There must be a preceding model message with function_call + assert i > 0 + prev_content = request.contents[i - 1] + assert prev_content.role == "model" + assert any( + hasattr(p, "function_call") and p.function_call + for p in prev_content.parts + ) + + def test_consecutive_user_messages_grouped(self): + """Test that consecutive user messages are grouped together.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) + + contents = [ + # Invocation 1 (should be removed) + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 (should be kept) + Content(role="user", parts=[Part(text="For context:")]), + Content(role="user", parts=[Part(text="Tell me about X")]), + Content(role="model", parts=[Part(text="Here's about X")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep the last invocation with both user messages + assert len(request.contents) == 3 + assert request.contents[0].parts[0].text == "For context:" + assert request.contents[1].parts[0].text == "Tell me about X" + assert request.contents[2].parts[0].text == "Here's about X" + + def test_custom_filter_applied(self): + """Test that custom filter is applied after invocation filtering.""" + # Custom filter that removes all model messages + def custom_filter(contents): + return [c for c in contents if c.role != "model"] + + plugin = ToolAwareContextFilterPlugin( + num_invocations_to_keep=2, custom_filter=custom_filter + ) + + contents = [ + Content(role="user", parts=[Part(text="Query 1")]), + Content(role="model", parts=[Part(text="Response 1")]), + Content(role="user", parts=[Part(text="Query 2")]), + Content(role="model", parts=[Part(text="Response 2")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should have only user messages + assert len(request.contents) == 2 + assert all(c.role == "user" for c in request.contents) + + def test_empty_contents(self): + """Test handling of empty contents list.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + request = LlmRequest(model="test", contents=[]) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should handle empty contents gracefully + assert len(request.contents) == 0 + + def test_error_handling(self): + """Test that errors are caught and logged without crashing.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + # Create a malformed request that might cause errors + request = LlmRequest(model="test", contents=None) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Should not raise an exception + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should return None without crashing + assert result is None \ No newline at end of file diff --git a/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py b/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py new file mode 100644 index 0000000000..bfa0b5bbdb --- /dev/null +++ b/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py @@ -0,0 +1,320 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for ToolAwareContextFilterPlugin.""" + +import pytest +from google.adk.agents.callback_context import CallbackContext +from google.adk.models.llm_request import LlmRequest +from google.adk.plugins.tool_aware_context_filter_plugin import ( + ToolAwareContextFilterPlugin, +) +from google.genai.types import Content, FunctionCall, FunctionResponse, Part + + +class TestToolAwareContextFilterPlugin: + """Tests for ToolAwareContextFilterPlugin.""" + + def test_init(self): + """Test plugin initialization.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + assert plugin._num_invocations_to_keep == 2 + assert plugin._custom_filter is None + + def test_no_filtering_when_disabled(self): + """Test that no filtering occurs when num_invocations_to_keep is None.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=None) + + contents = [ + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + Content(role="user", parts=[Part(text="How are you?")]), + Content(role="model", parts=[Part(text="I'm good")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # No filtering should occur + assert len(request.contents) == 4 + + def test_simple_invocations_no_tool_calls(self): + """Test filtering simple Q&A without tool calls.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + # Create 3 simple invocations + contents = [ + # Invocation 1 + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 + Content(role="user", parts=[Part(text="How are you?")]), + Content(role="model", parts=[Part(text="I'm good")]), + # Invocation 3 + Content(role="user", parts=[Part(text="What's your name?")]), + Content(role="model", parts=[Part(text="I'm Claude")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep last 2 invocations (indices 2-5) + assert len(request.contents) == 4 + assert request.contents[0].parts[0].text == "How are you?" + assert request.contents[-1].parts[0].text == "I'm Claude" + + def test_tool_call_sequence_kept_together(self): + """Test that function_call and function_response stay together.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) + + # Create invocations where the last one has a tool call + contents = [ + # Invocation 1 (should be removed) + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 (should be kept - has tool call) + Content(role="user", parts=[Part(text="What's the weather?")]), + Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="get_weather", args={"location": "SF"} + ) + ) + ], + ), + Content( + role="user", + parts=[ + Part( + function_response=FunctionResponse( + name="get_weather", response={"temp": 72} + ) + ) + ], + ), + Content(role="model", parts=[Part(text="It's 72°F")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep entire tool call sequence (4 messages) + assert len(request.contents) == 4 + assert request.contents[0].parts[0].text == "What's the weather?" + assert hasattr(request.contents[1].parts[0], "function_call") + assert hasattr(request.contents[2].parts[0], "function_response") + assert request.contents[3].parts[0].text == "It's 72°F" + + def test_orphaned_function_response_prevented(self): + """Test that function_response is never orphaned without function_call.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + contents = [ + # Invocation 1 + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 (with tool call) + Content(role="user", parts=[Part(text="Query 1")]), + Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="tool1", args={} + ) + ) + ], + ), + Content( + role="user", + parts=[ + Part( + function_response=FunctionResponse( + name="tool1", response={} + ) + ) + ], + ), + Content(role="model", parts=[Part(text="Response 1")]), + # Invocation 3 (with tool call) + Content(role="user", parts=[Part(text="Query 2")]), + Content( + role="model", + parts=[ + Part( + function_call=FunctionCall( + name="tool2", args={} + ) + ) + ], + ), + Content( + role="user", + parts=[ + Part( + function_response=FunctionResponse( + name="tool2", response={} + ) + ) + ], + ), + Content(role="model", parts=[Part(text="Response 2")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep invocations 2 and 3 (indices 2-9) + assert len(request.contents) == 8 + + # Verify no orphaned function_response + for i, content in enumerate(request.contents): + if content.role == "user" and any( + hasattr(p, "function_response") and p.function_response + for p in content.parts + ): + # There must be a preceding model message with function_call + assert i > 0 + prev_content = request.contents[i - 1] + assert prev_content.role == "model" + assert any( + hasattr(p, "function_call") and p.function_call + for p in prev_content.parts + ) + + def test_consecutive_user_messages_grouped(self): + """Test that consecutive user messages are grouped together.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) + + contents = [ + # Invocation 1 (should be removed) + Content(role="user", parts=[Part(text="Hello")]), + Content(role="model", parts=[Part(text="Hi")]), + # Invocation 2 (should be kept) + Content(role="user", parts=[Part(text="For context:")]), + Content(role="user", parts=[Part(text="Tell me about X")]), + Content(role="model", parts=[Part(text="Here's about X")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should keep the last invocation with both user messages + assert len(request.contents) == 3 + assert request.contents[0].parts[0].text == "For context:" + assert request.contents[1].parts[0].text == "Tell me about X" + assert request.contents[2].parts[0].text == "Here's about X" + + def test_custom_filter_applied(self): + """Test that custom filter is applied after invocation filtering.""" + # Custom filter that removes all model messages + def custom_filter(contents): + return [c for c in contents if c.role != "model"] + + plugin = ToolAwareContextFilterPlugin( + num_invocations_to_keep=2, custom_filter=custom_filter + ) + + contents = [ + Content(role="user", parts=[Part(text="Query 1")]), + Content(role="model", parts=[Part(text="Response 1")]), + Content(role="user", parts=[Part(text="Query 2")]), + Content(role="model", parts=[Part(text="Response 2")]), + ] + + request = LlmRequest(model="test", contents=contents) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should have only user messages + assert len(request.contents) == 2 + assert all(c.role == "user" for c in request.contents) + + def test_empty_contents(self): + """Test handling of empty contents list.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + request = LlmRequest(model="test", contents=[]) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Run the plugin + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should handle empty contents gracefully + assert len(request.contents) == 0 + + def test_error_handling(self): + """Test that errors are caught and logged without crashing.""" + plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) + + # Create a malformed request that might cause errors + request = LlmRequest(model="test", contents=None) + context = CallbackContext(invocation_id="test", agent_name="test") + + # Should not raise an exception + result = pytest.mark.asyncio( + plugin.before_model_callback( + callback_context=context, llm_request=request + ) + ) + + # Should return None without crashing + assert result is None \ No newline at end of file diff --git a/tool_aware_context_filter_plugin.py b/tool_aware_context_filter_plugin.py new file mode 100644 index 0000000000..0a3b53e0ab --- /dev/null +++ b/tool_aware_context_filter_plugin.py @@ -0,0 +1,267 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tool-aware context filter plugin for managing conversation history. + +This plugin extends the standard context filtering to properly handle function +call/response sequences, ensuring they remain atomic during history trimming. + +PROBLEM WITH STANDARD ContextFilterPlugin: +========================================== +The standard ContextFilterPlugin treats each model message as a separate +"invocation", but when a model makes a tool call, it creates MULTIPLE model +messages in sequence: + 1. Model message with function_call + 2. User message with function_response (tool result) + 3. Model message with final text response + +When filtering to keep N "invocations", the standard plugin can split these +related messages apart, creating orphaned function_responses without their +corresponding function_calls, which violates OpenAI API requirements. + +HOW THIS PLUGIN SOLVES IT: +=========================== +This plugin groups messages into LOGICAL invocations where a complete cycle is: + - User query (one or more messages) + - Model response (possibly with function_call) + - Function response(s) (if tool was called) + - Model final response (after tool execution) + +All messages in a tool call sequence are kept together as an atomic unit. +""" + +from __future__ import annotations + +import logging +from typing import Callable, List, Optional + +from google.adk.agents.callback_context import CallbackContext +from google.adk.events.event import Event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.adk.plugins.base_plugin import BasePlugin + +logger = logging.getLogger("google_adk." + __name__) + + +class ToolAwareContextFilterPlugin(BasePlugin): + """A plugin that filters LLM context while preserving tool call sequences. + + This plugin extends context filtering to handle function call/response pairs + correctly, ensuring they are never split during history trimming. + """ + + def __init__( + self, + num_invocations_to_keep: Optional[int] = None, + custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None, + name: str = "tool_aware_context_filter_plugin", + ): + """Initializes the tool-aware context filter plugin. + + Args: + num_invocations_to_keep: The number of last invocations to keep. An + invocation is defined as a complete user-model interaction cycle, + including any tool calls and their responses. + custom_filter: A function to apply additional filtering to the context. + name: The name of the plugin instance. + """ + super().__init__(name) + self._num_invocations_to_keep = num_invocations_to_keep + self._custom_filter = custom_filter + + def _has_function_call(self, content) -> bool: + """Check if a content has a function_call part.""" + if not content.parts: + return False + return any( + hasattr(part, "function_call") and part.function_call + for part in content.parts + ) + + def _has_function_response(self, content) -> bool: + """Check if a content has a function_response part.""" + if not content.parts: + return False + return any( + hasattr(part, "function_response") and part.function_response + for part in content.parts + ) + + def _group_into_invocations(self, contents: List) -> List[List[int]]: + """Group message indices into complete invocations. + + An invocation pattern: + 1. One or more user messages (including consecutive user messages) + 2. Model response (possibly with function_call) + 3. If function_call exists: user message(s) with function_response + 4. If function_call exists: model final response + + Example grouping: + Messages: [user, user, model, user, model+func_call, user+func_response, + model] Groups: [0,1,2] [3,4,5,6] + ^^^^^^^ ^^^^^^^^^^^ + Inv 1 Inv 2 (includes tool cycle) + + Args: + contents: List of message contents to group. + + Returns: + List of invocations, where each invocation is a list of message indices. + """ + invocations = [] + current_invocation = [] + i = 0 + + while i < len(contents): + content = contents[i] + + # CASE 1: User message + if content.role == "user": + # Check if this is a function_response (part of ongoing tool cycle) + if self._has_function_response(content): + # This is a tool response - must be part of current invocation + current_invocation.append(i) + i += 1 + else: + # Regular user message (not a function_response) + # Only start a NEW invocation if we've completed a previous one + if current_invocation: + # Check if previous invocation has a model response + has_model = any( + contents[idx].role == "model" for idx in current_invocation + ) + if has_model: + invocations.append(current_invocation) + current_invocation = [] + + # Add this user message to current invocation + current_invocation.append(i) + i += 1 + + # CASE 2: Model message + elif content.role == "model": + current_invocation.append(i) + + # Check if model is making a tool call + if self._has_function_call(content): + # Model made a tool call - keep following messages together: + # 1. This model message (function_call) - already added + # 2. User message(s) with function_response - collect next + # 3. Model's final response - collect after tool responses + + i += 1 # Move to next message + + # Collect all function_response messages (usually 1, but could be + # multiple) + while ( + i < len(contents) + and contents[i].role == "user" + and self._has_function_response(contents[i]) + ): + current_invocation.append(i) + i += 1 + + # Now collect the model's final response after processing tool results + if i < len(contents) and contents[i].role == "model": + current_invocation.append(i) + i += 1 + + # Complete tool cycle collected - this is ONE complete invocation + invocations.append(current_invocation) + current_invocation = [] + else: + # Model response WITHOUT function call - simple case + # The invocation is complete (user query → model answer) + i += 1 + invocations.append(current_invocation) + current_invocation = [] + else: + # Unknown role - just add to current invocation + current_invocation.append(i) + i += 1 + + # Add any remaining messages as final invocation + if current_invocation: + invocations.append(current_invocation) + + return invocations + + async def before_model_callback( + self, *, callback_context: CallbackContext, llm_request: LlmRequest + ) -> Optional[LlmResponse]: + """Filters the LLM request's context before it is sent to the model. + + This method groups messages into logical invocations and keeps only the + most recent N invocations, ensuring tool call sequences remain intact. + + Args: + callback_context: Context containing invocation and agent information. + llm_request: The LLM request to filter. + + Returns: + None - the request is modified in place. + """ + try: + contents = llm_request.contents + + if not contents: + return None + + # Apply invocation-based filtering if configured + if ( + self._num_invocations_to_keep is not None + and self._num_invocations_to_keep > 0 + ): + # Group messages into logical invocations + invocations = self._group_into_invocations(contents) + + logger.info( + "ToolAwareContextFilter: Total invocations=%d, keeping last %d", + len(invocations), + self._num_invocations_to_keep, + ) + + # Keep only the last N invocations + if len(invocations) > self._num_invocations_to_keep: + invocations_to_keep = invocations[-self._num_invocations_to_keep :] + + # Flatten the list of indices + indices_to_keep = [] + for invocation in invocations_to_keep: + indices_to_keep.extend(invocation) + + # Filter contents based on indices + filtered_contents = [contents[i] for i in sorted(indices_to_keep)] + + logger.info( + "ToolAwareContextFilter: Reduced from %d messages to %d messages" + " (kept %d invocations)", + len(contents), + len(filtered_contents), + len(invocations_to_keep), + ) + + contents = filtered_contents + + # Apply custom filter if provided + if self._custom_filter: + contents = self._custom_filter(contents) + + llm_request.contents = contents + + except Exception as e: + logger.error("ToolAwareContextFilter: Failed to filter context: %s", e) + + return None \ No newline at end of file From 7e130a330a631e989d18be868430f129caef3522 Mon Sep 17 00:00:00 2001 From: shekhar Date: Mon, 5 Jan 2026 20:03:23 +0530 Subject: [PATCH 2/4] fix: Address PR feedback for ToolAwareContextFilterPlugin - Fix async test execution by adding @pytest.mark.asyncio decorators and await keywords - Fix test_error_handling to create valid LlmRequest before setting contents to None - Convert _has_function_call() and _has_function_response() to @staticmethod - Remove redundant sorted() call in filtering logic for minor performance improvement --- .../tool_aware_context_filter_plugin.py | 8 +- .../test_tool_aware_context_filter_plugin.py | 77 +++++++++---------- 2 files changed, 40 insertions(+), 45 deletions(-) diff --git a/src/google/adk/plugins/tool_aware_context_filter_plugin.py b/src/google/adk/plugins/tool_aware_context_filter_plugin.py index 0a3b53e0ab..5835b890e8 100644 --- a/src/google/adk/plugins/tool_aware_context_filter_plugin.py +++ b/src/google/adk/plugins/tool_aware_context_filter_plugin.py @@ -81,7 +81,8 @@ def __init__( self._num_invocations_to_keep = num_invocations_to_keep self._custom_filter = custom_filter - def _has_function_call(self, content) -> bool: + @staticmethod + def _has_function_call(content) -> bool: """Check if a content has a function_call part.""" if not content.parts: return False @@ -90,7 +91,8 @@ def _has_function_call(self, content) -> bool: for part in content.parts ) - def _has_function_response(self, content) -> bool: + @staticmethod + def _has_function_response(content) -> bool: """Check if a content has a function_response part.""" if not content.parts: return False @@ -243,7 +245,7 @@ async def before_model_callback( indices_to_keep.extend(invocation) # Filter contents based on indices - filtered_contents = [contents[i] for i in sorted(indices_to_keep)] + filtered_contents = [contents[i] for i in indices_to_keep] logger.info( "ToolAwareContextFilter: Reduced from %d messages to %d messages" diff --git a/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py b/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py index bfa0b5bbdb..a4c88710a1 100644 --- a/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py +++ b/tests/unittests/plugins/test_tool_aware_context_filter_plugin.py @@ -32,7 +32,8 @@ def test_init(self): assert plugin._num_invocations_to_keep == 2 assert plugin._custom_filter is None - def test_no_filtering_when_disabled(self): + @pytest.mark.asyncio + async def test_no_filtering_when_disabled(self): """Test that no filtering occurs when num_invocations_to_keep is None.""" plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=None) @@ -47,16 +48,15 @@ def test_no_filtering_when_disabled(self): context = CallbackContext(invocation_id="test", agent_name="test") # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + await plugin.before_model_callback( + callback_context=context, llm_request=request ) # No filtering should occur assert len(request.contents) == 4 - def test_simple_invocations_no_tool_calls(self): + @pytest.mark.asyncio + async def test_simple_invocations_no_tool_calls(self): """Test filtering simple Q&A without tool calls.""" plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) @@ -77,10 +77,8 @@ def test_simple_invocations_no_tool_calls(self): context = CallbackContext(invocation_id="test", agent_name="test") # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + await plugin.before_model_callback( + callback_context=context, llm_request=request ) # Should keep last 2 invocations (indices 2-5) @@ -88,7 +86,8 @@ def test_simple_invocations_no_tool_calls(self): assert request.contents[0].parts[0].text == "How are you?" assert request.contents[-1].parts[0].text == "I'm Claude" - def test_tool_call_sequence_kept_together(self): + @pytest.mark.asyncio + async def test_tool_call_sequence_kept_together(self): """Test that function_call and function_response stay together.""" plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) @@ -126,10 +125,8 @@ def test_tool_call_sequence_kept_together(self): context = CallbackContext(invocation_id="test", agent_name="test") # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + await plugin.before_model_callback( + callback_context=context, llm_request=request ) # Should keep entire tool call sequence (4 messages) @@ -139,7 +136,8 @@ def test_tool_call_sequence_kept_together(self): assert hasattr(request.contents[2].parts[0], "function_response") assert request.contents[3].parts[0].text == "It's 72°F" - def test_orphaned_function_response_prevented(self): + @pytest.mark.asyncio + async def test_orphaned_function_response_prevented(self): """Test that function_response is never orphaned without function_call.""" plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) @@ -199,10 +197,8 @@ def test_orphaned_function_response_prevented(self): context = CallbackContext(invocation_id="test", agent_name="test") # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + await plugin.before_model_callback( + callback_context=context, llm_request=request ) # Should keep invocations 2 and 3 (indices 2-9) @@ -223,7 +219,8 @@ def test_orphaned_function_response_prevented(self): for p in prev_content.parts ) - def test_consecutive_user_messages_grouped(self): + @pytest.mark.asyncio + async def test_consecutive_user_messages_grouped(self): """Test that consecutive user messages are grouped together.""" plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) @@ -241,10 +238,8 @@ def test_consecutive_user_messages_grouped(self): context = CallbackContext(invocation_id="test", agent_name="test") # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + await plugin.before_model_callback( + callback_context=context, llm_request=request ) # Should keep the last invocation with both user messages @@ -253,7 +248,8 @@ def test_consecutive_user_messages_grouped(self): assert request.contents[1].parts[0].text == "Tell me about X" assert request.contents[2].parts[0].text == "Here's about X" - def test_custom_filter_applied(self): + @pytest.mark.asyncio + async def test_custom_filter_applied(self): """Test that custom filter is applied after invocation filtering.""" # Custom filter that removes all model messages def custom_filter(contents): @@ -274,17 +270,16 @@ def custom_filter(contents): context = CallbackContext(invocation_id="test", agent_name="test") # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + await plugin.before_model_callback( + callback_context=context, llm_request=request ) # Should have only user messages assert len(request.contents) == 2 assert all(c.role == "user" for c in request.contents) - def test_empty_contents(self): + @pytest.mark.asyncio + async def test_empty_contents(self): """Test handling of empty contents list.""" plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) @@ -292,28 +287,26 @@ def test_empty_contents(self): context = CallbackContext(invocation_id="test", agent_name="test") # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + await plugin.before_model_callback( + callback_context=context, llm_request=request ) # Should handle empty contents gracefully assert len(request.contents) == 0 - def test_error_handling(self): + @pytest.mark.asyncio + async def test_error_handling(self): """Test that errors are caught and logged without crashing.""" plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) - # Create a malformed request that might cause errors - request = LlmRequest(model="test", contents=None) + # Create a valid request first, then set contents to None + request = LlmRequest(model="test", contents=[]) context = CallbackContext(invocation_id="test", agent_name="test") + request.contents = None # Should not raise an exception - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) + result = await plugin.before_model_callback( + callback_context=context, llm_request=request ) # Should return None without crashing From 1213b65d56de4e74826a9f9e426e1a595cd8f957 Mon Sep 17 00:00:00 2001 From: shekhar Date: Mon, 5 Jan 2026 20:21:21 +0530 Subject: [PATCH 3/4] refactor: Improve type safety and code clarity in ToolAwareContextFilterPlugin - Fix type hint for custom_filter to use List[types.Content] instead of List[Event] - Add type hint for contents parameter: List[types.Content] - Refactor _group_into_invocations into smaller helper methods for better maintainability: - _finalize_invocation_if_complete: Finalize current invocation if complete - _process_user_message: Handle user message processing - _process_model_message: Handle model message processing - _process_model_message_with_tool_call: Handle tool call sequences - Remove unused Event import - Add google.genai.types import for proper type annotations --- .../tool_aware_context_filter_plugin.py | 210 +++++++++++++----- 1 file changed, 149 insertions(+), 61 deletions(-) diff --git a/src/google/adk/plugins/tool_aware_context_filter_plugin.py b/src/google/adk/plugins/tool_aware_context_filter_plugin.py index 5835b890e8..a0e3d147fc 100644 --- a/src/google/adk/plugins/tool_aware_context_filter_plugin.py +++ b/src/google/adk/plugins/tool_aware_context_filter_plugin.py @@ -47,10 +47,10 @@ from typing import Callable, List, Optional from google.adk.agents.callback_context import CallbackContext -from google.adk.events.event import Event from google.adk.models.llm_request import LlmRequest from google.adk.models.llm_response import LlmResponse from google.adk.plugins.base_plugin import BasePlugin +from google.genai import types logger = logging.getLogger("google_adk." + __name__) @@ -65,7 +65,9 @@ class ToolAwareContextFilterPlugin(BasePlugin): def __init__( self, num_invocations_to_keep: Optional[int] = None, - custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None, + custom_filter: Optional[ + Callable[[List[types.Content]], List[types.Content]] + ] = None, name: str = "tool_aware_context_filter_plugin", ): """Initializes the tool-aware context filter plugin. @@ -101,7 +103,145 @@ def _has_function_response(content) -> bool: for part in content.parts ) - def _group_into_invocations(self, contents: List) -> List[List[int]]: + def _finalize_invocation_if_complete( + self, + current_invocation: List[int], + invocations: List[List[int]], + contents: List[types.Content], + ) -> List[int]: + """Finalize current invocation if it has a model response. + + Args: + current_invocation: Current invocation being built. + invocations: List of completed invocations. + contents: List of message contents. + + Returns: + Empty list if invocation was finalized, otherwise current_invocation. + """ + if current_invocation: + has_model = any( + contents[idx].role == "model" for idx in current_invocation + ) + if has_model: + invocations.append(current_invocation) + return [] + return current_invocation + + def _process_user_message( + self, + i: int, + contents: List[types.Content], + current_invocation: List[int], + invocations: List[List[int]], + ) -> tuple[int, List[int]]: + """Process a user message and update invocation tracking. + + Args: + i: Current index in contents. + contents: List of message contents. + current_invocation: Current invocation being built. + invocations: List of completed invocations. + + Returns: + Tuple of (next_index, updated_current_invocation). + """ + content = contents[i] + + # Check if this is a function_response (part of ongoing tool cycle) + if self._has_function_response(content): + # This is a tool response - must be part of current invocation + current_invocation.append(i) + return i + 1, current_invocation + + # Regular user message (not a function_response) + # Only start a NEW invocation if we've completed a previous one + current_invocation = self._finalize_invocation_if_complete( + current_invocation, invocations, contents + ) + + # Add this user message to current invocation + current_invocation.append(i) + return i + 1, current_invocation + + def _process_model_message_with_tool_call( + self, + i: int, + contents: List[types.Content], + current_invocation: List[int], + invocations: List[List[int]], + ) -> tuple[int, List[int]]: + """Process a model message with tool call and collect the full cycle. + + Args: + i: Current index in contents (at model message with function_call). + contents: List of message contents. + current_invocation: Current invocation being built. + invocations: List of completed invocations. + + Returns: + Tuple of (next_index, empty_invocation_list). + """ + # Model made a tool call - keep following messages together: + # 1. This model message (function_call) - already added + # 2. User message(s) with function_response - collect next + # 3. Model's final response - collect after tool responses + + i += 1 # Move to next message + + # Collect all function_response messages (usually 1, but could be multiple) + while ( + i < len(contents) + and contents[i].role == "user" + and self._has_function_response(contents[i]) + ): + current_invocation.append(i) + i += 1 + + # Now collect the model's final response after processing tool results + if i < len(contents) and contents[i].role == "model": + current_invocation.append(i) + i += 1 + + # Complete tool cycle collected - this is ONE complete invocation + invocations.append(current_invocation) + return i, [] + + def _process_model_message( + self, + i: int, + contents: List[types.Content], + current_invocation: List[int], + invocations: List[List[int]], + ) -> tuple[int, List[int]]: + """Process a model message and update invocation tracking. + + Args: + i: Current index in contents. + contents: List of message contents. + current_invocation: Current invocation being built. + invocations: List of completed invocations. + + Returns: + Tuple of (next_index, updated_current_invocation). + """ + content = contents[i] + current_invocation.append(i) + + # Check if model is making a tool call + if self._has_function_call(content): + return self._process_model_message_with_tool_call( + i, contents, current_invocation, invocations + ) + + # Model response WITHOUT function call - simple case + # The invocation is complete (user query → model answer) + invocations.append(current_invocation) + return i + 1, [] + + def _group_into_invocations( + self, contents: List[types.Content] + ) -> List[List[int]]: """Group message indices into complete invocations. An invocation pattern: @@ -129,66 +269,14 @@ def _group_into_invocations(self, contents: List) -> List[List[int]]: while i < len(contents): content = contents[i] - # CASE 1: User message if content.role == "user": - # Check if this is a function_response (part of ongoing tool cycle) - if self._has_function_response(content): - # This is a tool response - must be part of current invocation - current_invocation.append(i) - i += 1 - else: - # Regular user message (not a function_response) - # Only start a NEW invocation if we've completed a previous one - if current_invocation: - # Check if previous invocation has a model response - has_model = any( - contents[idx].role == "model" for idx in current_invocation - ) - if has_model: - invocations.append(current_invocation) - current_invocation = [] - - # Add this user message to current invocation - current_invocation.append(i) - i += 1 - - # CASE 2: Model message + i, current_invocation = self._process_user_message( + i, contents, current_invocation, invocations + ) elif content.role == "model": - current_invocation.append(i) - - # Check if model is making a tool call - if self._has_function_call(content): - # Model made a tool call - keep following messages together: - # 1. This model message (function_call) - already added - # 2. User message(s) with function_response - collect next - # 3. Model's final response - collect after tool responses - - i += 1 # Move to next message - - # Collect all function_response messages (usually 1, but could be - # multiple) - while ( - i < len(contents) - and contents[i].role == "user" - and self._has_function_response(contents[i]) - ): - current_invocation.append(i) - i += 1 - - # Now collect the model's final response after processing tool results - if i < len(contents) and contents[i].role == "model": - current_invocation.append(i) - i += 1 - - # Complete tool cycle collected - this is ONE complete invocation - invocations.append(current_invocation) - current_invocation = [] - else: - # Model response WITHOUT function call - simple case - # The invocation is complete (user query → model answer) - i += 1 - invocations.append(current_invocation) - current_invocation = [] + i, current_invocation = self._process_model_message( + i, contents, current_invocation, invocations + ) else: # Unknown role - just add to current invocation current_invocation.append(i) From f1677d95f6ddd2a13d87bcc8abc775b031d4bb8d Mon Sep 17 00:00:00 2001 From: shekhar Date: Mon, 5 Jan 2026 20:23:32 +0530 Subject: [PATCH 4/4] chore: Remove duplicate files --- test_tool_aware_context_filter_plugin.py | 320 ----------------------- tool_aware_context_filter_plugin.py | 267 ------------------- 2 files changed, 587 deletions(-) delete mode 100644 test_tool_aware_context_filter_plugin.py delete mode 100644 tool_aware_context_filter_plugin.py diff --git a/test_tool_aware_context_filter_plugin.py b/test_tool_aware_context_filter_plugin.py deleted file mode 100644 index bfa0b5bbdb..0000000000 --- a/test_tool_aware_context_filter_plugin.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for ToolAwareContextFilterPlugin.""" - -import pytest -from google.adk.agents.callback_context import CallbackContext -from google.adk.models.llm_request import LlmRequest -from google.adk.plugins.tool_aware_context_filter_plugin import ( - ToolAwareContextFilterPlugin, -) -from google.genai.types import Content, FunctionCall, FunctionResponse, Part - - -class TestToolAwareContextFilterPlugin: - """Tests for ToolAwareContextFilterPlugin.""" - - def test_init(self): - """Test plugin initialization.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) - assert plugin._num_invocations_to_keep == 2 - assert plugin._custom_filter is None - - def test_no_filtering_when_disabled(self): - """Test that no filtering occurs when num_invocations_to_keep is None.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=None) - - contents = [ - Content(role="user", parts=[Part(text="Hello")]), - Content(role="model", parts=[Part(text="Hi")]), - Content(role="user", parts=[Part(text="How are you?")]), - Content(role="model", parts=[Part(text="I'm good")]), - ] - - request = LlmRequest(model="test", contents=contents) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # No filtering should occur - assert len(request.contents) == 4 - - def test_simple_invocations_no_tool_calls(self): - """Test filtering simple Q&A without tool calls.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) - - # Create 3 simple invocations - contents = [ - # Invocation 1 - Content(role="user", parts=[Part(text="Hello")]), - Content(role="model", parts=[Part(text="Hi")]), - # Invocation 2 - Content(role="user", parts=[Part(text="How are you?")]), - Content(role="model", parts=[Part(text="I'm good")]), - # Invocation 3 - Content(role="user", parts=[Part(text="What's your name?")]), - Content(role="model", parts=[Part(text="I'm Claude")]), - ] - - request = LlmRequest(model="test", contents=contents) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # Should keep last 2 invocations (indices 2-5) - assert len(request.contents) == 4 - assert request.contents[0].parts[0].text == "How are you?" - assert request.contents[-1].parts[0].text == "I'm Claude" - - def test_tool_call_sequence_kept_together(self): - """Test that function_call and function_response stay together.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) - - # Create invocations where the last one has a tool call - contents = [ - # Invocation 1 (should be removed) - Content(role="user", parts=[Part(text="Hello")]), - Content(role="model", parts=[Part(text="Hi")]), - # Invocation 2 (should be kept - has tool call) - Content(role="user", parts=[Part(text="What's the weather?")]), - Content( - role="model", - parts=[ - Part( - function_call=FunctionCall( - name="get_weather", args={"location": "SF"} - ) - ) - ], - ), - Content( - role="user", - parts=[ - Part( - function_response=FunctionResponse( - name="get_weather", response={"temp": 72} - ) - ) - ], - ), - Content(role="model", parts=[Part(text="It's 72°F")]), - ] - - request = LlmRequest(model="test", contents=contents) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # Should keep entire tool call sequence (4 messages) - assert len(request.contents) == 4 - assert request.contents[0].parts[0].text == "What's the weather?" - assert hasattr(request.contents[1].parts[0], "function_call") - assert hasattr(request.contents[2].parts[0], "function_response") - assert request.contents[3].parts[0].text == "It's 72°F" - - def test_orphaned_function_response_prevented(self): - """Test that function_response is never orphaned without function_call.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) - - contents = [ - # Invocation 1 - Content(role="user", parts=[Part(text="Hello")]), - Content(role="model", parts=[Part(text="Hi")]), - # Invocation 2 (with tool call) - Content(role="user", parts=[Part(text="Query 1")]), - Content( - role="model", - parts=[ - Part( - function_call=FunctionCall( - name="tool1", args={} - ) - ) - ], - ), - Content( - role="user", - parts=[ - Part( - function_response=FunctionResponse( - name="tool1", response={} - ) - ) - ], - ), - Content(role="model", parts=[Part(text="Response 1")]), - # Invocation 3 (with tool call) - Content(role="user", parts=[Part(text="Query 2")]), - Content( - role="model", - parts=[ - Part( - function_call=FunctionCall( - name="tool2", args={} - ) - ) - ], - ), - Content( - role="user", - parts=[ - Part( - function_response=FunctionResponse( - name="tool2", response={} - ) - ) - ], - ), - Content(role="model", parts=[Part(text="Response 2")]), - ] - - request = LlmRequest(model="test", contents=contents) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # Should keep invocations 2 and 3 (indices 2-9) - assert len(request.contents) == 8 - - # Verify no orphaned function_response - for i, content in enumerate(request.contents): - if content.role == "user" and any( - hasattr(p, "function_response") and p.function_response - for p in content.parts - ): - # There must be a preceding model message with function_call - assert i > 0 - prev_content = request.contents[i - 1] - assert prev_content.role == "model" - assert any( - hasattr(p, "function_call") and p.function_call - for p in prev_content.parts - ) - - def test_consecutive_user_messages_grouped(self): - """Test that consecutive user messages are grouped together.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=1) - - contents = [ - # Invocation 1 (should be removed) - Content(role="user", parts=[Part(text="Hello")]), - Content(role="model", parts=[Part(text="Hi")]), - # Invocation 2 (should be kept) - Content(role="user", parts=[Part(text="For context:")]), - Content(role="user", parts=[Part(text="Tell me about X")]), - Content(role="model", parts=[Part(text="Here's about X")]), - ] - - request = LlmRequest(model="test", contents=contents) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # Should keep the last invocation with both user messages - assert len(request.contents) == 3 - assert request.contents[0].parts[0].text == "For context:" - assert request.contents[1].parts[0].text == "Tell me about X" - assert request.contents[2].parts[0].text == "Here's about X" - - def test_custom_filter_applied(self): - """Test that custom filter is applied after invocation filtering.""" - # Custom filter that removes all model messages - def custom_filter(contents): - return [c for c in contents if c.role != "model"] - - plugin = ToolAwareContextFilterPlugin( - num_invocations_to_keep=2, custom_filter=custom_filter - ) - - contents = [ - Content(role="user", parts=[Part(text="Query 1")]), - Content(role="model", parts=[Part(text="Response 1")]), - Content(role="user", parts=[Part(text="Query 2")]), - Content(role="model", parts=[Part(text="Response 2")]), - ] - - request = LlmRequest(model="test", contents=contents) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # Should have only user messages - assert len(request.contents) == 2 - assert all(c.role == "user" for c in request.contents) - - def test_empty_contents(self): - """Test handling of empty contents list.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) - - request = LlmRequest(model="test", contents=[]) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Run the plugin - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # Should handle empty contents gracefully - assert len(request.contents) == 0 - - def test_error_handling(self): - """Test that errors are caught and logged without crashing.""" - plugin = ToolAwareContextFilterPlugin(num_invocations_to_keep=2) - - # Create a malformed request that might cause errors - request = LlmRequest(model="test", contents=None) - context = CallbackContext(invocation_id="test", agent_name="test") - - # Should not raise an exception - result = pytest.mark.asyncio( - plugin.before_model_callback( - callback_context=context, llm_request=request - ) - ) - - # Should return None without crashing - assert result is None \ No newline at end of file diff --git a/tool_aware_context_filter_plugin.py b/tool_aware_context_filter_plugin.py deleted file mode 100644 index 0a3b53e0ab..0000000000 --- a/tool_aware_context_filter_plugin.py +++ /dev/null @@ -1,267 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tool-aware context filter plugin for managing conversation history. - -This plugin extends the standard context filtering to properly handle function -call/response sequences, ensuring they remain atomic during history trimming. - -PROBLEM WITH STANDARD ContextFilterPlugin: -========================================== -The standard ContextFilterPlugin treats each model message as a separate -"invocation", but when a model makes a tool call, it creates MULTIPLE model -messages in sequence: - 1. Model message with function_call - 2. User message with function_response (tool result) - 3. Model message with final text response - -When filtering to keep N "invocations", the standard plugin can split these -related messages apart, creating orphaned function_responses without their -corresponding function_calls, which violates OpenAI API requirements. - -HOW THIS PLUGIN SOLVES IT: -=========================== -This plugin groups messages into LOGICAL invocations where a complete cycle is: - - User query (one or more messages) - - Model response (possibly with function_call) - - Function response(s) (if tool was called) - - Model final response (after tool execution) - -All messages in a tool call sequence are kept together as an atomic unit. -""" - -from __future__ import annotations - -import logging -from typing import Callable, List, Optional - -from google.adk.agents.callback_context import CallbackContext -from google.adk.events.event import Event -from google.adk.models.llm_request import LlmRequest -from google.adk.models.llm_response import LlmResponse -from google.adk.plugins.base_plugin import BasePlugin - -logger = logging.getLogger("google_adk." + __name__) - - -class ToolAwareContextFilterPlugin(BasePlugin): - """A plugin that filters LLM context while preserving tool call sequences. - - This plugin extends context filtering to handle function call/response pairs - correctly, ensuring they are never split during history trimming. - """ - - def __init__( - self, - num_invocations_to_keep: Optional[int] = None, - custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None, - name: str = "tool_aware_context_filter_plugin", - ): - """Initializes the tool-aware context filter plugin. - - Args: - num_invocations_to_keep: The number of last invocations to keep. An - invocation is defined as a complete user-model interaction cycle, - including any tool calls and their responses. - custom_filter: A function to apply additional filtering to the context. - name: The name of the plugin instance. - """ - super().__init__(name) - self._num_invocations_to_keep = num_invocations_to_keep - self._custom_filter = custom_filter - - def _has_function_call(self, content) -> bool: - """Check if a content has a function_call part.""" - if not content.parts: - return False - return any( - hasattr(part, "function_call") and part.function_call - for part in content.parts - ) - - def _has_function_response(self, content) -> bool: - """Check if a content has a function_response part.""" - if not content.parts: - return False - return any( - hasattr(part, "function_response") and part.function_response - for part in content.parts - ) - - def _group_into_invocations(self, contents: List) -> List[List[int]]: - """Group message indices into complete invocations. - - An invocation pattern: - 1. One or more user messages (including consecutive user messages) - 2. Model response (possibly with function_call) - 3. If function_call exists: user message(s) with function_response - 4. If function_call exists: model final response - - Example grouping: - Messages: [user, user, model, user, model+func_call, user+func_response, - model] Groups: [0,1,2] [3,4,5,6] - ^^^^^^^ ^^^^^^^^^^^ - Inv 1 Inv 2 (includes tool cycle) - - Args: - contents: List of message contents to group. - - Returns: - List of invocations, where each invocation is a list of message indices. - """ - invocations = [] - current_invocation = [] - i = 0 - - while i < len(contents): - content = contents[i] - - # CASE 1: User message - if content.role == "user": - # Check if this is a function_response (part of ongoing tool cycle) - if self._has_function_response(content): - # This is a tool response - must be part of current invocation - current_invocation.append(i) - i += 1 - else: - # Regular user message (not a function_response) - # Only start a NEW invocation if we've completed a previous one - if current_invocation: - # Check if previous invocation has a model response - has_model = any( - contents[idx].role == "model" for idx in current_invocation - ) - if has_model: - invocations.append(current_invocation) - current_invocation = [] - - # Add this user message to current invocation - current_invocation.append(i) - i += 1 - - # CASE 2: Model message - elif content.role == "model": - current_invocation.append(i) - - # Check if model is making a tool call - if self._has_function_call(content): - # Model made a tool call - keep following messages together: - # 1. This model message (function_call) - already added - # 2. User message(s) with function_response - collect next - # 3. Model's final response - collect after tool responses - - i += 1 # Move to next message - - # Collect all function_response messages (usually 1, but could be - # multiple) - while ( - i < len(contents) - and contents[i].role == "user" - and self._has_function_response(contents[i]) - ): - current_invocation.append(i) - i += 1 - - # Now collect the model's final response after processing tool results - if i < len(contents) and contents[i].role == "model": - current_invocation.append(i) - i += 1 - - # Complete tool cycle collected - this is ONE complete invocation - invocations.append(current_invocation) - current_invocation = [] - else: - # Model response WITHOUT function call - simple case - # The invocation is complete (user query → model answer) - i += 1 - invocations.append(current_invocation) - current_invocation = [] - else: - # Unknown role - just add to current invocation - current_invocation.append(i) - i += 1 - - # Add any remaining messages as final invocation - if current_invocation: - invocations.append(current_invocation) - - return invocations - - async def before_model_callback( - self, *, callback_context: CallbackContext, llm_request: LlmRequest - ) -> Optional[LlmResponse]: - """Filters the LLM request's context before it is sent to the model. - - This method groups messages into logical invocations and keeps only the - most recent N invocations, ensuring tool call sequences remain intact. - - Args: - callback_context: Context containing invocation and agent information. - llm_request: The LLM request to filter. - - Returns: - None - the request is modified in place. - """ - try: - contents = llm_request.contents - - if not contents: - return None - - # Apply invocation-based filtering if configured - if ( - self._num_invocations_to_keep is not None - and self._num_invocations_to_keep > 0 - ): - # Group messages into logical invocations - invocations = self._group_into_invocations(contents) - - logger.info( - "ToolAwareContextFilter: Total invocations=%d, keeping last %d", - len(invocations), - self._num_invocations_to_keep, - ) - - # Keep only the last N invocations - if len(invocations) > self._num_invocations_to_keep: - invocations_to_keep = invocations[-self._num_invocations_to_keep :] - - # Flatten the list of indices - indices_to_keep = [] - for invocation in invocations_to_keep: - indices_to_keep.extend(invocation) - - # Filter contents based on indices - filtered_contents = [contents[i] for i in sorted(indices_to_keep)] - - logger.info( - "ToolAwareContextFilter: Reduced from %d messages to %d messages" - " (kept %d invocations)", - len(contents), - len(filtered_contents), - len(invocations_to_keep), - ) - - contents = filtered_contents - - # Apply custom filter if provided - if self._custom_filter: - contents = self._custom_filter(contents) - - llm_request.contents = contents - - except Exception as e: - logger.error("ToolAwareContextFilter: Failed to filter context: %s", e) - - return None \ No newline at end of file