Commit dea75a2

Merge branch 'master' into fix/cache-strip-message-ids

2 parents fe812e2 + 69c7d1b

File tree

8 files changed: +531 -58 lines

libs/core/langchain_core/load/load.py

Lines changed: 0 additions & 2 deletions

@@ -265,8 +265,6 @@ def _load(obj: Any) -> Any:
             return reviver(loaded_obj)
         if isinstance(obj, list):
             return [_load(o) for o in obj]
-        if isinstance(obj, str) and obj in reviver.secrets_map:
-            return reviver.secrets_map[obj]
         return obj
 
     return _load(obj)
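
For context, a rough sketch (not part of the commit) of the behavior the removed branch appears to have enabled: during deserialization, any bare string equal to a key in `secrets_map` was itself replaced by the mapped secret, even when it was ordinary data rather than a serialized secret node. The payload and values below are illustrative.

from langchain_core.load.load import load

# Ordinary data that merely mentions a secret id; this is not a serialized
# secret node of the form {"lc": 1, "type": "secret", "id": [...]}.
payload = {"note": "OPENAI_API_KEY"}

# With the removed branch, the bare string "OPENAI_API_KEY" could come back
# as "sk-test"; after this commit only structured secret nodes are revived.
restored = load(payload, secrets_map={"OPENAI_API_KEY": "sk-test"})
print(restored)  # {'note': 'OPENAI_API_KEY'}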

libs/core/tests/unit_tests/load/test_load.py

Lines changed: 0 additions & 11 deletions
This file was deleted.

libs/langchain_v1/langchain/agents/middleware/summarization.py

Lines changed: 198 additions & 24 deletions

@@ -1,8 +1,9 @@
 """Summarization middleware."""
 
 import uuid
-from collections.abc import Callable, Iterable
-from typing import Any, cast
+import warnings
+from collections.abc import Callable, Iterable, Mapping
+from typing import Any, Literal, cast
 
 from langchain_core.messages import (
     AIMessage,
@@ -51,13 +52,17 @@
 {messages}
 </messages>"""  # noqa: E501
 
-SUMMARY_PREFIX = "## Previous conversation summary:"
-
 _DEFAULT_MESSAGES_TO_KEEP = 20
 _DEFAULT_TRIM_TOKEN_LIMIT = 4000
 _DEFAULT_FALLBACK_MESSAGE_COUNT = 15
 _SEARCH_RANGE_FOR_TOOL_PAIRS = 5
 
+ContextFraction = tuple[Literal["fraction"], float]
+ContextTokens = tuple[Literal["tokens"], int]
+ContextMessages = tuple[Literal["messages"], int]
+
+ContextSize = ContextFraction | ContextTokens | ContextMessages
+
 
 class SummarizationMiddleware(AgentMiddleware):
     """Summarizes conversation history when token limits are approached.
@@ -70,48 +75,97 @@ class SummarizationMiddleware(AgentMiddleware):
     def __init__(
         self,
         model: str | BaseChatModel,
-        max_tokens_before_summary: int | None = None,
-        messages_to_keep: int = _DEFAULT_MESSAGES_TO_KEEP,
+        *,
+        trigger: ContextSize | list[ContextSize] | None = None,
+        keep: ContextSize = ("messages", _DEFAULT_MESSAGES_TO_KEEP),
         token_counter: TokenCounter = count_tokens_approximately,
         summary_prompt: str = DEFAULT_SUMMARY_PROMPT,
-        summary_prefix: str = SUMMARY_PREFIX,
+        trim_tokens_to_summarize: int | None = _DEFAULT_TRIM_TOKEN_LIMIT,
+        **deprecated_kwargs: Any,
     ) -> None:
         """Initialize summarization middleware.
 
         Args:
             model: The language model to use for generating summaries.
-            max_tokens_before_summary: Token threshold to trigger summarization.
-                If `None`, summarization is disabled.
-            messages_to_keep: Number of recent messages to preserve after summarization.
+            trigger: One or more thresholds that trigger summarization. Provide a single
+                `ContextSize` tuple or a list of tuples, in which case summarization runs
+                when any threshold is breached. Examples: `("messages", 50)`, `("tokens", 3000)`,
+                `[("fraction", 0.8), ("messages", 100)]`.
+            keep: Context retention policy applied after summarization. Provide a
+                `ContextSize` tuple to specify how much history to preserve. Defaults to
+                keeping the most recent 20 messages. Examples: `("messages", 20)`,
+                `("tokens", 3000)`, or `("fraction", 0.3)`.
             token_counter: Function to count tokens in messages.
             summary_prompt: Prompt template for generating summaries.
-            summary_prefix: Prefix added to system message when including summary.
+            trim_tokens_to_summarize: Maximum tokens to keep when preparing messages for the
+                summarization call. Pass `None` to skip trimming entirely.
         """
+        # Handle deprecated parameters
+        if "max_tokens_before_summary" in deprecated_kwargs:
+            value = deprecated_kwargs["max_tokens_before_summary"]
+            warnings.warn(
+                "max_tokens_before_summary is deprecated. Use trigger=('tokens', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if trigger is None and value is not None:
+                trigger = ("tokens", value)
+
+        if "messages_to_keep" in deprecated_kwargs:
+            value = deprecated_kwargs["messages_to_keep"]
+            warnings.warn(
+                "messages_to_keep is deprecated. Use keep=('messages', value) instead.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+            if keep == ("messages", _DEFAULT_MESSAGES_TO_KEEP):
+                keep = ("messages", value)
+
         super().__init__()
 
         if isinstance(model, str):
             model = init_chat_model(model)
 
         self.model = model
-        self.max_tokens_before_summary = max_tokens_before_summary
-        self.messages_to_keep = messages_to_keep
+        if trigger is None:
+            self.trigger: ContextSize | list[ContextSize] | None = None
+            trigger_conditions: list[ContextSize] = []
+        elif isinstance(trigger, list):
+            validated_list = [self._validate_context_size(item, "trigger") for item in trigger]
+            self.trigger = validated_list
+            trigger_conditions = validated_list
+        else:
+            validated = self._validate_context_size(trigger, "trigger")
+            self.trigger = validated
+            trigger_conditions = [validated]
+        self._trigger_conditions = trigger_conditions
+
+        self.keep = self._validate_context_size(keep, "keep")
         self.token_counter = token_counter
         self.summary_prompt = summary_prompt
-        self.summary_prefix = summary_prefix
+        self.trim_tokens_to_summarize = trim_tokens_to_summarize
+
+        requires_profile = any(condition[0] == "fraction" for condition in self._trigger_conditions)
+        if self.keep[0] == "fraction":
+            requires_profile = True
+        if requires_profile and self._get_profile_limits() is None:
+            msg = (
+                "Model profile information is required to use fractional token limits. "
+                'pip install "langchain[model-profiles]" or use absolute token counts '
+                "instead."
+            )
+            raise ValueError(msg)
 
     def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] | None:  # noqa: ARG002
         """Process messages before model invocation, potentially triggering summarization."""
         messages = state["messages"]
         self._ensure_message_ids(messages)
 
         total_tokens = self.token_counter(messages)
-        if (
-            self.max_tokens_before_summary is not None
-            and total_tokens < self.max_tokens_before_summary
-        ):
+        if not self._should_summarize(messages, total_tokens):
             return None
 
-        cutoff_index = self._find_safe_cutoff(messages)
+        cutoff_index = self._determine_cutoff_index(messages)
 
         if cutoff_index <= 0:
             return None
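
Taken together, a usage sketch of the reworked constructor. The model identifier and numbers are illustrative; note that fractional sizes additionally require model-profile data, per the `__init__` check above:

from langchain.agents.middleware.summarization import SummarizationMiddleware

# New-style configuration: summarize when either threshold is breached, then
# retain roughly 30% of the context window worth of recent messages.
# (Fractional values need model profile info: pip install "langchain[model-profiles]".)
middleware = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    trigger=[("fraction", 0.8), ("messages", 100)],
    keep=("fraction", 0.3),
)

# Legacy kwargs still work but emit DeprecationWarning and are mapped onto
# trigger=("tokens", ...) and keep=("messages", ...).
legacy = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    max_tokens_before_summary=4000,
    messages_to_keep=20,
)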
@@ -129,6 +183,124 @@ def before_model(self, state: AgentState, runtime: Runtime) -> dict[str, Any] |
             ]
         }
 
+    def _should_summarize(self, messages: list[AnyMessage], total_tokens: int) -> bool:
+        """Determine whether summarization should run for the current token usage."""
+        if not self._trigger_conditions:
+            return False
+
+        for kind, value in self._trigger_conditions:
+            if kind == "messages" and len(messages) >= value:
+                return True
+            if kind == "tokens" and total_tokens >= value:
+                return True
+            if kind == "fraction":
+                max_input_tokens = self._get_profile_limits()
+                if max_input_tokens is None:
+                    continue
+                threshold = int(max_input_tokens * value)
+                if threshold <= 0:
+                    threshold = 1
+                if total_tokens >= threshold:
+                    return True
+        return False
+
+    def _determine_cutoff_index(self, messages: list[AnyMessage]) -> int:
+        """Choose cutoff index respecting retention configuration."""
+        kind, value = self.keep
+        if kind in {"tokens", "fraction"}:
+            token_based_cutoff = self._find_token_based_cutoff(messages)
+            if token_based_cutoff is not None:
+                return token_based_cutoff
+            # None cutoff -> model profile data not available (caught in __init__ but
+            # here for safety), fallback to message count
+            return self._find_safe_cutoff(messages, _DEFAULT_MESSAGES_TO_KEEP)
+        return self._find_safe_cutoff(messages, cast("int", value))
+
+    def _find_token_based_cutoff(self, messages: list[AnyMessage]) -> int | None:
+        """Find cutoff index based on target token retention."""
+        if not messages:
+            return 0
+
+        kind, value = self.keep
+        if kind == "fraction":
+            max_input_tokens = self._get_profile_limits()
+            if max_input_tokens is None:
+                return None
+            target_token_count = int(max_input_tokens * value)
+        elif kind == "tokens":
+            target_token_count = int(value)
+        else:
+            return None
+
+        if target_token_count <= 0:
+            target_token_count = 1
+
+        if self.token_counter(messages) <= target_token_count:
+            return 0
+
+        # Use binary search to identify the earliest message index that keeps the
+        # suffix within the token budget.
+        left, right = 0, len(messages)
+        cutoff_candidate = len(messages)
+        max_iterations = len(messages).bit_length() + 1
+        for _ in range(max_iterations):
+            if left >= right:
+                break
+
+            mid = (left + right) // 2
+            if self.token_counter(messages[mid:]) <= target_token_count:
+                cutoff_candidate = mid
+                right = mid
+            else:
+                left = mid + 1
+
+        if cutoff_candidate == len(messages):
+            cutoff_candidate = left
+
+        if cutoff_candidate >= len(messages):
+            if len(messages) == 1:
+                return 0
+            cutoff_candidate = len(messages) - 1
+
+        for i in range(cutoff_candidate, -1, -1):
+            if self._is_safe_cutoff_point(messages, i):
+                return i
+
+        return 0
+
+    def _get_profile_limits(self) -> int | None:
+        """Retrieve max input token limit from the model profile."""
+        try:
+            profile = self.model.profile
+        except (AttributeError, ImportError):
+            return None
+
+        if not isinstance(profile, Mapping):
+            return None
+
+        max_input_tokens = profile.get("max_input_tokens")
+
+        if not isinstance(max_input_tokens, int):
+            return None
+
+        return max_input_tokens
+
+    def _validate_context_size(self, context: ContextSize, parameter_name: str) -> ContextSize:
+        """Validate context configuration tuples."""
+        kind, value = context
+        if kind == "fraction":
+            if not 0 < value <= 1:
+                msg = f"Fractional {parameter_name} values must be between 0 and 1, got {value}."
+                raise ValueError(msg)
+        elif kind in {"tokens", "messages"}:
+            if value <= 0:
+                msg = f"{parameter_name} thresholds must be greater than 0, got {value}."
+                raise ValueError(msg)
+        else:
+            msg = f"Unsupported context size type {kind} for {parameter_name}."
+            raise ValueError(msg)
+        return context
+
     def _build_new_messages(self, summary: str) -> list[HumanMessage]:
         return [
             HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}")
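
The binary search in `_find_token_based_cutoff` works because token counts are monotonic in suffix length: dropping messages from the front never increases the count of what remains. A self-contained sketch of the same search with a toy whitespace counter (all names here are hypothetical, not from the diff):

def toy_count(msgs: list[str]) -> int:
    """Whitespace token counter standing in for the real token_counter."""
    return sum(len(m.split()) for m in msgs)


def earliest_suffix_within_budget(messages: list[str], budget: int) -> int:
    """Smallest index i such that messages[i:] fits within `budget` tokens."""
    left, right = 0, len(messages)
    best = len(messages)
    while left < right:
        mid = (left + right) // 2
        if toy_count(messages[mid:]) <= budget:
            best = mid      # this suffix fits; try keeping even more history
            right = mid
        else:
            left = mid + 1  # suffix too large; drop more from the front
    return best


print(earliest_suffix_within_budget(["a b", "c d e", "f"], 4))  # -> 1

The middleware additionally caps the loop at `len(messages).bit_length() + 1` iterations and then walks backward to the nearest safe cutoff so AI/Tool message pairs are never split.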
@@ -151,16 +323,16 @@ def _partition_messages(
 
         return messages_to_summarize, preserved_messages
 
-    def _find_safe_cutoff(self, messages: list[AnyMessage]) -> int:
+    def _find_safe_cutoff(self, messages: list[AnyMessage], messages_to_keep: int) -> int:
         """Find safe cutoff point that preserves AI/Tool message pairs.
 
         Returns the index where messages can be safely cut without separating
         related AI and Tool messages. Returns 0 if no safe cutoff is found.
         """
-        if len(messages) <= self.messages_to_keep:
+        if len(messages) <= messages_to_keep:
             return 0
 
-        target_cutoff = len(messages) - self.messages_to_keep
+        target_cutoff = len(messages) - messages_to_keep
 
         for i in range(target_cutoff, -1, -1):
             if self._is_safe_cutoff_point(messages, i):
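
`_find_safe_cutoff` now receives the retention count as an argument instead of reading instance state, so `_determine_cutoff_index` can pass either the configured value or the fallback. The safety scan matters because a `ToolMessage` must stay with the `AIMessage` that issued its tool call; a simplified sketch of that invariant (not the diff's actual `_is_safe_cutoff_point`):

from langchain_core.messages import AIMessage, AnyMessage, ToolMessage


def cut_orphans_tool_message(messages: list[AnyMessage], cutoff: int) -> bool:
    """True if cutting at `cutoff` keeps a ToolMessage whose issuing AIMessage is dropped."""
    kept = messages[cutoff:]
    issued_ids = {
        tool_call["id"]
        for m in kept
        if isinstance(m, AIMessage)
        for tool_call in m.tool_calls
    }
    return any(
        isinstance(m, ToolMessage) and m.tool_call_id not in issued_ids
        for m in kept
    )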
@@ -229,16 +401,18 @@ def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
 
         try:
             response = self.model.invoke(self.summary_prompt.format(messages=trimmed_messages))
-            return cast("str", response.content).strip()
+            return response.text.strip()
         except Exception as e:  # noqa: BLE001
             return f"Error generating summary: {e!s}"
 
     def _trim_messages_for_summary(self, messages: list[AnyMessage]) -> list[AnyMessage]:
         """Trim messages to fit within summary generation limits."""
         try:
+            if self.trim_tokens_to_summarize is None:
+                return messages
             return trim_messages(
                 messages,
-                max_tokens=_DEFAULT_TRIM_TOKEN_LIMIT,
+                max_tokens=self.trim_tokens_to_summarize,
                 token_counter=self.token_counter,
                 start_on="human",
                 strategy="last",

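The trimming budget for the summarization call is now configurable rather than fixed at the module constant, and `None` bypasses trimming entirely. A configuration sketch (model identifier and token counts illustrative):

from langchain.agents.middleware.summarization import SummarizationMiddleware

# Cap the text sent to the summarizer at 2000 tokens to keep summaries cheap.
bounded = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    trigger=("tokens", 8000),
    trim_tokens_to_summarize=2000,
)

# Or rely on the summarizing model's own context window and skip trimming.
unbounded = SummarizationMiddleware(
    "openai:gpt-4o-mini",
    trigger=("tokens", 8000),
    trim_tokens_to_summarize=None,
)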