diff --git a/src/google/adk/plugins/context_filter_plugin.py b/src/google/adk/plugins/context_filter_plugin.py
index b778de02ad..5d62e8f57b 100644
--- a/src/google/adk/plugins/context_filter_plugin.py
+++ b/src/google/adk/plugins/context_filter_plugin.py
@@ -36,6 +36,8 @@ def __init__(
       num_invocations_to_keep: Optional[int] = None,
       custom_filter: Optional[Callable[[List[Event]], List[Event]]] = None,
       name: str = "context_filter_plugin",
+      remove_amount: int = 1,
+
   ):
     """Initializes the context management plugin.
 
@@ -45,10 +47,12 @@ def __init__(
         by a model response.
       custom_filter: A function to filter the context.
       name: The name of the plugin instance.
+      remove_amount: The number of invocations to remove from the context.
     """
     super().__init__(name)
     self._num_invocations_to_keep = num_invocations_to_keep
     self._custom_filter = custom_filter
+    self._remove_amount = remove_amount
 
   async def before_model_callback(
       self, *, callback_context: CallbackContext, llm_request: LlmRequest
@@ -60,9 +64,10 @@ async def before_model_callback(
     if (
         self._num_invocations_to_keep is not None
         and self._num_invocations_to_keep > 0
+        and self._remove_amount > 0
     ):
       num_model_turns = sum(1 for c in contents if c.role == "model")
-      if num_model_turns >= self._num_invocations_to_keep:
+      if num_model_turns >= self._num_invocations_to_keep + self._remove_amount - 1:
         model_turns_to_find = self._num_invocations_to_keep
         split_index = 0
         for i in range(len(contents) - 1, -1, -1):
diff --git a/tests/unittests/plugins/test_context_filtering_plugin.py b/tests/unittests/plugins/test_context_filtering_plugin.py
index f9c8222ea3..f8c0c4f99b 100644
--- a/tests/unittests/plugins/test_context_filtering_plugin.py
+++ b/tests/unittests/plugins/test_context_filtering_plugin.py
@@ -183,3 +183,110 @@ def faulty_filter(contents):
   )
 
   assert llm_request.contents == original_contents
+
+
+@pytest.mark.asyncio
+async def test_filter_with_remove_amount():
+  """Tests that remove_amount correctly removes additional invocations."""
+  plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_content("model", "model_response_2"),
+      _create_content("user", "user_prompt_3"),
+      _create_content("model", "model_response_3"),
+  ]
+  llm_request = LlmRequest(contents=contents)
+
+  await plugin.before_model_callback(
+      callback_context=Mock(spec=CallbackContext), llm_request=llm_request
+  )
+
+  # With num_invocations_to_keep=2 and remove_amount=1, should keep last 2 invocations
+  assert len(llm_request.contents) == 4
+  assert llm_request.contents[0].parts[0].text == "user_prompt_2"
+  assert llm_request.contents[1].parts[0].text == "model_response_2"
+  assert llm_request.contents[2].parts[0].text == "user_prompt_3"
+  assert llm_request.contents[3].parts[0].text == "model_response_3"
+
+
+@pytest.mark.asyncio
+async def test_filter_with_higher_remove_amount():
+  """Tests remove_amount with a higher value to remove more invocations."""
+  plugin = ContextFilterPlugin(num_invocations_to_keep=3, remove_amount=2)
+  contents = [
+      _create_content("user", "user_prompt_1"),
+      _create_content("model", "model_response_1"),
+      _create_content("user", "user_prompt_2"),
+      _create_content("model", "model_response_2"),
+      _create_content("user", "user_prompt_3"),
+      _create_content("model", "model_response_3"),
+      _create_content("user", "user_prompt_4"),
+      _create_content("model", "model_response_4"),
_create_content("user", "user_prompt_5"), + _create_content("model", "model_response_5"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # With num_invocations_to_keep=3 and remove_amount=2, keeps last 3 invocations + assert len(llm_request.contents) == 6 + assert llm_request.contents[0].parts[0].text == "user_prompt_3" + assert llm_request.contents[1].parts[0].text == "model_response_3" + assert llm_request.contents[2].parts[0].text == "user_prompt_4" + assert llm_request.contents[3].parts[0].text == "model_response_4" + assert llm_request.contents[4].parts[0].text == "user_prompt_5" + assert llm_request.contents[5].parts[0].text == "model_response_5" + + +@pytest.mark.asyncio +async def test_filter_with_zero_remove_amount(): + """Tests that remove_amount=0 disables the filtering logic.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=1, remove_amount=0) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2"), + _create_content("model", "model_response_2"), + ] + llm_request = LlmRequest(contents=contents) + original_contents = list(llm_request.contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # With remove_amount=0, filtering should be disabled + assert llm_request.contents == original_contents + + +@pytest.mark.asyncio +async def test_filter_remove_amount_with_multiple_user_turns(): + """Tests remove_amount with multiple user turns in invocations.""" + plugin = ContextFilterPlugin(num_invocations_to_keep=2, remove_amount=1) + contents = [ + _create_content("user", "user_prompt_1"), + _create_content("model", "model_response_1"), + _create_content("user", "user_prompt_2a"), + _create_content("user", "user_prompt_2b"), + _create_content("model", "model_response_2"), + _create_content("user", "user_prompt_3"), + _create_content("model", "model_response_3"), + ] + llm_request = LlmRequest(contents=contents) + + await plugin.before_model_callback( + callback_context=Mock(spec=CallbackContext), llm_request=llm_request + ) + + # Should keep last 2 invocations including multiple user turns + assert len(llm_request.contents) == 5 + assert llm_request.contents[0].parts[0].text == "user_prompt_2a" + assert llm_request.contents[1].parts[0].text == "user_prompt_2b" + assert llm_request.contents[2].parts[0].text == "model_response_2" + assert llm_request.contents[3].parts[0].text == "user_prompt_3" + assert llm_request.contents[4].parts[0].text == "model_response_3"