11"""Summarization middleware."""
22
33import uuid
4- from collections .abc import Callable , Iterable
5- from typing import Any , cast
4+ import warnings
5+ from collections .abc import Callable , Iterable , Mapping
6+ from typing import Any , Literal , cast
67
78from langchain_core .messages import (
89 AIMessage ,
5152{messages}
5253</messages>""" # noqa: E501
5354
54- SUMMARY_PREFIX = "## Previous conversation summary:"
55-
# Default retention/summarization tuning knobs.
_DEFAULT_MESSAGES_TO_KEEP = 20  # messages preserved after summarization
_DEFAULT_TRIM_TOKEN_LIMIT = 4000  # max tokens fed into the summary call
_DEFAULT_FALLBACK_MESSAGE_COUNT = 15
_SEARCH_RANGE_FOR_TOOL_PAIRS = 5

# A context-size policy is a discriminated tuple: the first element selects the
# unit, the second the amount.
ContextFraction = tuple[Literal["fraction"], float]  # fraction of the model's max input tokens
ContextTokens = tuple[Literal["tokens"], int]  # absolute token count
ContextMessages = tuple[Literal["messages"], int]  # absolute message count

ContextSize = ContextFraction | ContextTokens | ContextMessages

6166
6267class SummarizationMiddleware (AgentMiddleware ):
6368 """Summarizes conversation history when token limits are approached.
@@ -70,48 +75,97 @@ class SummarizationMiddleware(AgentMiddleware):
    def __init__(
        self,
        model: str | BaseChatModel,
        *,
        trigger: ContextSize | list[ContextSize] | None = None,
        keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
        token_counter: TokenCounter = count_tokens_approximately,
        summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
        trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
        **deprecated_kwargs: Any,
    ) -> None:
        """Initialize summarization middleware.

        Args:
            model: The language model to use for generating summaries.
            trigger: One or more thresholds that trigger summarization. Provide a single
                `ContextSize` tuple or a list of tuples, in which case summarization runs
                when any threshold is breached. Examples: `("messages", 50)`, `("tokens", 3000)`,
                `[("fraction", 0.8), ("messages", 100)]`.
            keep: Context retention policy applied after summarization. Provide a
                `ContextSize` tuple to specify how much history to preserve. Defaults to
                keeping the most recent 20 messages. Examples: `("messages", 20)`,
                `("tokens", 3000)`, or `("fraction", 0.3)`.
            token_counter: Function to count tokens in messages.
            summary_prompt: Prompt template for generating summaries.
            trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for the
                summarization call. Pass `None` to skip trimming entirely.

        Raises:
            ValueError: If a fractional threshold is configured but no model
                profile (max input tokens) is available, or if a `trigger`/`keep`
                tuple fails validation.
        """
        # Handle deprecated parameters.
        # NOTE(review): keys other than the two recognized below are silently
        # ignored, so a typo in a keyword argument will not raise — confirm
        # this is intentional.
        if "max_tokens_before_summary" in deprecated_kwargs:
            value = deprecated_kwargs["max_tokens_before_summary"]
            warnings.warn(
                "max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            # An explicitly supplied `trigger` takes precedence over the
            # deprecated parameter.
            if trigger is None and value is not None:
                trigger = ("tokens", value)

        if "messages_to_keep" in deprecated_kwargs:
            value = deprecated_kwargs["messages_to_keep"]
            warnings.warn(
                "messages_to_keep is deprecated. Use keep=('messages', value) instead.",
                DeprecationWarning,
                stacklevel=2,
            )
            # Only adopt the deprecated value if `keep` was left at its default.
            if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
                keep = ("messages", value)

        super().__init__()

        if isinstance(model, str):
            model = init_chat_model(model)

        self.model = model
        # Normalize `trigger` into a flat list of conditions for
        # `_should_summarize`, while `self.trigger` preserves the shape the
        # caller passed in (None, single tuple, or list).
        if trigger is None:
            self.trigger: ContextSize | list[ContextSize] | None = None
            trigger_conditions: list[ContextSize] = []
        elif isinstance(trigger, list):
            validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
            self.trigger = validated_list
            trigger_conditions = validated_list
        else:
            validated = self._validate_context_size(trigger, "trigger")
            self.trigger = validated
            trigger_conditions = [validated]
        self._trigger_conditions = trigger_conditions

        self.keep = self._validate_context_size(keep, "keep")
        self.token_counter = token_counter
        self.summary_prompt = summary_prompt
        self.trim_tokens_to_summarize = trim_tokens_to_summarize

        # Fractional policies need the model's max input size, which comes from
        # optional model-profile data; fail fast here rather than at runtime.
        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
        if self.keep[0] == "fraction":
            requires_profile = True
        if requires_profile and self._get_profile_limits() is None:
            msg = (
                "Model profile information is required to use fractional token limits. "
                'pip install "langchain[model-profiles]" or use absolute token counts '
                "instead."
            )
            raise ValueError(msg)
101158
102159 def before_model (self , state : AgentState , runtime : Runtime ) -> dict [str , Any ] | None : # noqa: ARG002
103160 """Process messages before model invocation, potentially triggering summarization."""
104161 messages = state ["messages" ]
105162 self ._ensure_message_ids (messages )
106163
107164 total_tokens = self .token_counter (messages )
108- if (
109- self .max_tokens_before_summary is not None
110- and total_tokens < self .max_tokens_before_summary
111- ):
165+ if not self ._should_summarize (messages , total_tokens ):
112166 return None
113167
114- cutoff_index = self ._find_safe_cutoff (messages )
168+ cutoff_index = self ._determine_cutoff_index (messages )
115169
116170 if cutoff_index <= 0 :
117171 return None
@@ -129,6 +183,124 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
129183 ]
130184 }
131185
186+ def _should_summarize (self , messages : list [AnyMessage ], total_tokens : int ) -> bool :
187+ """Determine whether summarization should run for the current token usage."""
188+ if not self ._trigger_conditions :
189+ return False
190+
191+ for kind , value in self ._trigger_conditions :
192+ if kind == "messages" and len (messages ) >= value :
193+ return True
194+ if kind == "tokens" and total_tokens >= value :
195+ return True
196+ if kind == "fraction" :
197+ max_input_tokens = self ._get_profile_limits ()
198+ if max_input_tokens is None :
199+ continue
200+ threshold = int (max_input_tokens * value )
201+ if threshold <= 0 :
202+ threshold = 1
203+ if total_tokens >= threshold :
204+ return True
205+ return False
206+
207+ def _determine_cutoff_index (self , messages : list [AnyMessage ]) -> int :
208+ """Choose cutoff index respecting retention configuration."""
209+ kind , value = self .keep
210+ if kind in {"tokens" , "fraction" }:
211+ token_based_cutoff = self ._find_token_based_cutoff (messages )
212+ if token_based_cutoff is not None :
213+ return token_based_cutoff
214+ # None cutoff -> model profile data not available (caught in __init__ but
215+ # here for safety), fallback to message count
216+ return self ._find_safe_cutoff (messages , _DEFAULT_MESSAGES_TO_KEEP )
217+ return self ._find_safe_cutoff (messages , cast ("int" , value ))
218+
    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
        """Find cutoff index based on target token retention.

        Returns the earliest safe index such that `messages[index:]` fits the
        configured token budget, or `None` when the budget cannot be derived
        (no model profile for a "fraction" policy, or a "messages" policy
        that this method does not handle).
        """
        if not messages:
            return 0

        kind, value = self.keep
        if kind == "fraction":
            max_input_tokens = self._get_profile_limits()
            if max_input_tokens is None:
                # Caller falls back to count-based retention.
                return None
            target_token_count = int(max_input_tokens * value)
        elif kind == "tokens":
            target_token_count = int(value)
        else:
            # "messages" retention is handled by _find_safe_cutoff instead.
            return None

        # Never allow a zero/negative budget.
        if target_token_count <= 0:
            target_token_count = 1

        # Whole history already fits: no cutoff needed.
        if self.token_counter(messages) <= target_token_count:
            return 0

        # Use binary search to identify the earliest message index that keeps the
        # suffix within the token budget.
        left, right = 0, len(messages)
        cutoff_candidate = len(messages)
        # Iteration cap (~log2(n) + 1): guards against a non-monotonic
        # token_counter preventing convergence.
        max_iterations = len(messages).bit_length() + 1
        for _ in range(max_iterations):
            if left >= right:
                break

            mid = (left + right) // 2
            if self.token_counter(messages[mid:]) <= target_token_count:
                cutoff_candidate = mid
                right = mid
            else:
                left = mid + 1

        if cutoff_candidate == len(messages):
            # No suffix satisfied the budget during the search; fall back to the
            # converged lower bound.
            cutoff_candidate = left

        if cutoff_candidate >= len(messages):
            # Even the last message alone exceeds the budget: keep one message
            # rather than summarizing everything away.
            if len(messages) == 1:
                return 0
            cutoff_candidate = len(messages) - 1

        # Walk backwards to the nearest cutoff that does not split an AI
        # message from its tool responses.
        for i in range(cutoff_candidate, -1, -1):
            if self._is_safe_cutoff_point(messages, i):
                return i

        return 0
270+
271+ def _get_profile_limits (self ) -> int | None :
272+ """Retrieve max input token limit from the model profile."""
273+ try :
274+ profile = self .model .profile
275+ except (AttributeError , ImportError ):
276+ return None
277+
278+ if not isinstance (profile , Mapping ):
279+ return None
280+
281+ max_input_tokens = profile .get ("max_input_tokens" )
282+
283+ if not isinstance (max_input_tokens , int ):
284+ return None
285+
286+ return max_input_tokens
287+
288+ def _validate_context_size (self , context : ContextSize , parameter_name : str ) -> ContextSize :
289+ """Validate context configuration tuples."""
290+ kind , value = context
291+ if kind == "fraction" :
292+ if not 0 < value <= 1 :
293+ msg = f"Fractional { parameter_name } values must be between 0 and 1, got { value } ."
294+ raise ValueError (msg )
295+ elif kind in {"tokens" , "messages" }:
296+ if value <= 0 :
297+ msg = f"{ parameter_name } thresholds must be greater than 0, got { value } ."
298+ raise ValueError (msg )
299+ else :
300+ msg = f"Unsupported context size type { kind } for { parameter_name } ."
301+ raise ValueError (msg )
302+ return context
303+
132304 def _build_new_messages (self , summary : str ) -> list [HumanMessage ]:
133305 return [
134306 HumanMessage (content = f"Here is a summary of the conversation to date:\n \n { summary } " )
@@ -151,16 +323,16 @@ def _partition_messages(
151323
152324 return messages_to_summarize , preserved_messages
153325
154- def _find_safe_cutoff (self , messages : list [AnyMessage ]) -> int :
326+ def _find_safe_cutoff (self , messages : list [AnyMessage ], messages_to_keep : int ) -> int :
155327 """Find safe cutoff point that preserves AI/Tool message pairs.
156328
157329 Returns the index where messages can be safely cut without separating
158330 related AI and Tool messages. Returns 0 if no safe cutoff is found.
159331 """
160- if len (messages ) <= self . messages_to_keep :
332+ if len (messages ) <= messages_to_keep :
161333 return 0
162334
163- target_cutoff = len (messages ) - self . messages_to_keep
335+ target_cutoff = len (messages ) - messages_to_keep
164336
165337 for i in range (target_cutoff , - 1 , - 1 ):
166338 if self ._is_safe_cutoff_point (messages , i ):
@@ -229,16 +401,18 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
229401
230402 try :
231403 response = self .model .invoke (self .summary_prompt .format (messages = trimmed_messages ))
232- return cast ( "str" , response .content ) .strip ()
404+ return response .text .strip ()
233405 except Exception as e : # noqa: BLE001
234406 return f"Error generating summary: { e !s} "
235407
236408 def _trim_messages_for_summary (self , messages : list [AnyMessage ]) -> list [AnyMessage ]:
237409 """Trim messages to fit within summary generation limits."""
238410 try :
411+ if self .trim_tokens_to_summarize is None :
412+ return messages
239413 return trim_messages (
240414 messages ,
241- max_tokens = _DEFAULT_TRIM_TOKEN_LIMIT ,
415+ max_tokens = self . trim_tokens_to_summarize ,
242416 token_counter = self .token_counter ,
243417 start_on = "human" ,
244418 strategy = "last" ,
0 commit comments