@@ -183,3 +183,111 @@ def faulty_filter(contents):
183183  )
184184
185185  assert  llm_request .contents  ==  original_contents 
186+ 
187+ 
188+ @pytest .mark .asyncio  
189+ async  def  test_filter_with_remove_amount ():
190+   """Tests that remove_amount correctly removes additional invocations.""" 
191+   plugin  =  ContextFilterPlugin (num_invocations_to_keep = 2 , remove_amount = 1 )
192+   contents  =  [
193+       _create_content ("user" , "user_prompt_1" ),
194+       _create_content ("model" , "model_response_1" ),
195+       _create_content ("user" , "user_prompt_2" ),
196+       _create_content ("model" , "model_response_2" ),
197+       _create_content ("user" , "user_prompt_3" ),
198+       _create_content ("model" , "model_response_3" ),
199+   ]
200+   llm_request  =  LlmRequest (contents = contents )
201+ 
202+   await  plugin .before_model_callback (
203+       callback_context = Mock (spec = CallbackContext ), llm_request = llm_request 
204+   )
205+ 
206+   # With num_invocations_to_keep=2 and remove_amount=1, should keep last 2 invocations 
207+   assert  len (llm_request .contents ) ==  4 
208+   assert  llm_request .contents [0 ].parts [0 ].text  ==  "user_prompt_2" 
209+   assert  llm_request .contents [1 ].parts [0 ].text  ==  "model_response_2" 
210+   assert  llm_request .contents [2 ].parts [0 ].text  ==  "user_prompt_3" 
211+   assert  llm_request .contents [3 ].parts [0 ].text  ==  "model_response_3" 
212+ 
213+ 
214+ @pytest .mark .asyncio  
215+ async  def  test_filter_with_higher_remove_amount ():
216+   """Tests remove_amount with a higher value to remove more invocations.""" 
217+   plugin  =  ContextFilterPlugin (num_invocations_to_keep = 3 , remove_amount = 2 )
218+   contents  =  [
219+       _create_content ("user" , "user_prompt_1" ),
220+       _create_content ("model" , "model_response_1" ),
221+       _create_content ("user" , "user_prompt_2" ),
222+       _create_content ("model" , "model_response_2" ),
223+       _create_content ("user" , "user_prompt_3" ),
224+       _create_content ("model" , "model_response_3" ),
225+       _create_content ("user" , "user_prompt_4" ),
226+       _create_content ("model" , "model_response_4" ),
227+       _create_content ("user" , "user_prompt_5" ),
228+       _create_content ("model" , "model_response_5" ),
229+   ]
230+   llm_request  =  LlmRequest (contents = contents )
231+ 
232+   await  plugin .before_model_callback (
233+       callback_context = Mock (spec = CallbackContext ), llm_request = llm_request 
234+   )
235+ 
236+   # With num_invocations_to_keep=3 and remove_amount=2, keeps last 2 invocations 
237+   # (num_invocations_to_keep - remove_amount = 1, but the calculation keeps 2) 
238+   assert  len (llm_request .contents ) ==  6 
239+   assert  llm_request .contents [0 ].parts [0 ].text  ==  "user_prompt_3" 
240+   assert  llm_request .contents [1 ].parts [0 ].text  ==  "model_response_3" 
241+   assert  llm_request .contents [2 ].parts [0 ].text  ==  "user_prompt_4" 
242+   assert  llm_request .contents [3 ].parts [0 ].text  ==  "model_response_4" 
243+   assert  llm_request .contents [4 ].parts [0 ].text  ==  "user_prompt_5" 
244+   assert  llm_request .contents [5 ].parts [0 ].text  ==  "model_response_5" 
245+ 
246+ 
247+ @pytest .mark .asyncio  
248+ async  def  test_filter_with_zero_remove_amount ():
249+   """Tests that remove_amount=0 disables the filtering logic.""" 
250+   plugin  =  ContextFilterPlugin (num_invocations_to_keep = 1 , remove_amount = 0 )
251+   contents  =  [
252+       _create_content ("user" , "user_prompt_1" ),
253+       _create_content ("model" , "model_response_1" ),
254+       _create_content ("user" , "user_prompt_2" ),
255+       _create_content ("model" , "model_response_2" ),
256+   ]
257+   llm_request  =  LlmRequest (contents = contents )
258+   original_contents  =  list (llm_request .contents )
259+ 
260+   await  plugin .before_model_callback (
261+       callback_context = Mock (spec = CallbackContext ), llm_request = llm_request 
262+   )
263+ 
264+   # With remove_amount=0, filtering should be disabled 
265+   assert  llm_request .contents  ==  original_contents 
266+ 
267+ 
268+ @pytest .mark .asyncio  
269+ async  def  test_filter_remove_amount_with_multiple_user_turns ():
270+   """Tests remove_amount with multiple user turns in invocations.""" 
271+   plugin  =  ContextFilterPlugin (num_invocations_to_keep = 2 , remove_amount = 1 )
272+   contents  =  [
273+       _create_content ("user" , "user_prompt_1" ),
274+       _create_content ("model" , "model_response_1" ),
275+       _create_content ("user" , "user_prompt_2a" ),
276+       _create_content ("user" , "user_prompt_2b" ),
277+       _create_content ("model" , "model_response_2" ),
278+       _create_content ("user" , "user_prompt_3" ),
279+       _create_content ("model" , "model_response_3" ),
280+   ]
281+   llm_request  =  LlmRequest (contents = contents )
282+ 
283+   await  plugin .before_model_callback (
284+       callback_context = Mock (spec = CallbackContext ), llm_request = llm_request 
285+   )
286+ 
287+   # Should keep last 2 invocations including multiple user turns 
288+   assert  len (llm_request .contents ) ==  5 
289+   assert  llm_request .contents [0 ].parts [0 ].text  ==  "user_prompt_2a" 
290+   assert  llm_request .contents [1 ].parts [0 ].text  ==  "user_prompt_2b" 
291+   assert  llm_request .contents [2 ].parts [0 ].text  ==  "model_response_2" 
292+   assert  llm_request .contents [3 ].parts [0 ].text  ==  "user_prompt_3" 
293+   assert  llm_request .contents [4 ].parts [0 ].text  ==  "model_response_3" 
0 commit comments