
Commit 3a3231f

fix!: LLM truncation error catch (#906)
# Motivation

The LLM response sometimes gets truncated, preventing it from calling the create file tool.

# Content

We now check why the LLM stopped producing tokens. If the stop reason is "max_tokens", we return an error to the LLM.

# Testing

# Please check the following before marking your PR as ready for review

- [ ] I have added tests for my changes
- [ ] I have updated the documentation or added new documentation as needed
1 parent 5ace100 commit 3a3231f
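
At its core, the fix inspects the final AIMessage's response metadata to see whether the model stopped at its token limit. A minimal sketch of that check, condensed from the CustomToolNode diff below (the `stop_reason` and `usage` keys are what ChatAnthropic reports on its messages):

```python
from typing import Optional

from langchain_core.messages import AIMessage


def truncated_output_tokens(message: AIMessage) -> Optional[int]:
    """Return the output-token count if generation stopped at the
    max_tokens limit, otherwise None."""
    metadata = message.response_metadata
    if metadata.get("stop_reason") != "max_tokens":
        return None
    # ChatAnthropic reports token usage under "usage" -> "output_tokens".
    return metadata.get("usage", {}).get("output_tokens")
```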

5 files changed (+85 −18 lines)


src/codegen/extensions/langchain/graph.py

Lines changed: 12 additions & 4 deletions

@@ -6,17 +6,24 @@
 import anthropic
 import openai
 from langchain.tools import BaseTool
-from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage, ToolMessage
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    HumanMessage,
+    SystemMessage,
+    ToolMessage,
+)
 from langchain_core.prompts import ChatPromptTemplate
+from langchain_core.stores import InMemoryBaseStore
 from langgraph.checkpoint.memory import MemorySaver
 from langgraph.graph import END, START
 from langgraph.graph.state import CompiledGraph, StateGraph
-from langgraph.prebuilt import ToolNode
 from langgraph.pregel import RetryPolicy

 from codegen.agents.utils import AgentConfig
 from codegen.extensions.langchain.llm import LLM
 from codegen.extensions.langchain.prompts import SUMMARIZE_CONVERSATION_PROMPT
+from codegen.extensions.langchain.utils.custom_tool_node import CustomToolNode
 from codegen.extensions.langchain.utils.utils import get_max_model_input_tokens

@@ -87,6 +94,7 @@ def __init__(self, model: "LLM", tools: list[BaseTool], system_message: SystemMe
         self.config = config
         self.max_messages = config.get("max_messages", 100) if config else 100
         self.keep_first_messages = config.get("keep_first_messages", 1) if config else 1
+        self.store = InMemoryBaseStore()

     # =================================== NODES ====================================

@@ -459,7 +467,7 @@ def get_field_descriptions(tool_obj):

         # Add nodes
         builder.add_node("reasoner", self.reasoner, retry=retry_policy)
-        builder.add_node("tools", ToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy)
+        builder.add_node("tools", CustomToolNode(self.tools, handle_tool_errors=handle_tool_errors), retry=retry_policy)
         builder.add_node("summarize_conversation", self.summarize_conversation, retry=retry_policy)

         # Add edges
@@ -471,7 +479,7 @@ def get_field_descriptions(tool_obj):
         )
         builder.add_conditional_edges("summarize_conversation", self.should_continue)

-        return builder.compile(checkpointer=checkpointer, debug=debug)
+        return builder.compile(checkpointer=checkpointer, store=self.store, debug=debug)


 def create_react_agent(

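The graph now owns an InMemoryBaseStore and hands it to compile(), so the same store instance is visible to the tool node and, via injection, to the tools themselves. A quick illustration of the mget/mset batch semantics the new code relies on:

```python
from langchain_core.stores import InMemoryBaseStore

store = InMemoryBaseStore()
# mset writes (key, value) pairs; mget reads a batch of keys and
# returns None for any key that was never set.
store.mset([("create_file", {"max_tokens": 8192, "max_tokens_reached": True})])
print(store.mget(["create_file"])[0])  # {'max_tokens': 8192, 'max_tokens_reached': True}
print(store.mget(["missing_key"])[0])  # None
```
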
src/codegen/extensions/langchain/llm.py

Lines changed: 1 addition & 1 deletion

@@ -89,7 +89,7 @@ def _get_model(self) -> BaseChatModel:
             if not os.getenv("ANTHROPIC_API_KEY"):
                 msg = "ANTHROPIC_API_KEY not found in environment. Please set it in your .env file or environment variables."
                 raise ValueError(msg)
-            max_tokens = 16384 if "claude-3-7" in self.model_name else 8192
+            max_tokens = 8192
             return ChatAnthropic(**self._get_model_kwargs(), max_tokens=max_tokens, max_retries=10, timeout=1000)

         elif self.model_provider == "openai":

src/codegen/extensions/langchain/tools.py

Lines changed: 17 additions & 11 deletions

@@ -4,8 +4,10 @@
 from typing import Annotated, ClassVar, Literal, Optional

 from langchain_core.messages import ToolMessage
+from langchain_core.stores import InMemoryBaseStore
 from langchain_core.tools import InjectedToolCallId
 from langchain_core.tools.base import BaseTool
+from langgraph.prebuilt import InjectedStore
 from pydantic import BaseModel, Field

 from codegen.extensions.linear.linear_client import LinearClient
@@ -196,11 +198,13 @@ def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
 class CreateFileInput(BaseModel):
     """Input for creating a file."""

+    model_config = {"arbitrary_types_allowed": True}
     filepath: str = Field(..., description="Path where to create the file")
+    store: Annotated[InMemoryBaseStore, InjectedStore()]
     content: str = Field(
-        ...,
+        default="",
         description="""
-Content for the new file (REQUIRED).
+Content for the new file.

 ⚠️ IMPORTANT: This parameter MUST be a STRING, not a dictionary, JSON object, or any other data type.
 Example: content="print('Hello world')"
@@ -214,19 +218,14 @@ class CreateFileTool(BaseTool):
     name: ClassVar[str] = "create_file"
     description: ClassVar[str] = """
-Create a new file in the codebase. Always provide content for the new file, even if minimal.
-
-⚠️ CRITICAL WARNING ⚠️
-Both parameters MUST be provided as STRINGS:
-The content for the new file always needs to be provided.
+Create a new file in the codebase.

 1. filepath: The path where to create the file (as a string)
 2. content: The content for the new file (as a STRING, NOT as a dictionary or JSON object)

 ✅ CORRECT usage:
 create_file(filepath="path/to/file.py", content="print('Hello world')")
-
-The content parameter is REQUIRED and MUST be a STRING. If you receive a validation error about
+If you receive a validation error about
 missing content, you are likely trying to pass a dictionary instead of a string.
 """
     args_schema: ClassVar[type[BaseModel]] = CreateFileInput
@@ -235,8 +234,15 @@ class CreateFileTool(BaseTool):
     def __init__(self, codebase: Codebase) -> None:
         super().__init__(codebase=codebase)

-    def _run(self, filepath: str, content: str) -> str:
-        result = create_file(self.codebase, filepath, content)
+    def _run(self, filepath: str, store: InMemoryBaseStore, content: str = "") -> str:
+        create_file_tool_status = store.mget([self.name])[0]
+        if create_file_tool_status and create_file_tool_status.get("max_tokens_reached", False):
+            max_tokens = create_file_tool_status.get("max_tokens", None)
+            store.mset([(self.name, {"max_tokens": max_tokens, "max_tokens_reached": False})])
+            result = create_file(self.codebase, filepath, content, max_tokens=max_tokens)
+        else:
+            result = create_file(self.codebase, filepath, content)
+
         return result.render()

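The injected store parameter never reaches the model: arguments annotated with InjectedStore() are stripped from the tool schema the LLM sees, and the tool node supplies them from the store the graph was compiled with. A sketch of the pattern (example_tool is a hypothetical stand-in for CreateFileTool):

```python
from typing import Annotated

from langchain_core.stores import InMemoryBaseStore
from langchain_core.tools import tool
from langgraph.prebuilt import InjectedStore


@tool
def example_tool(key: str, store: Annotated[InMemoryBaseStore, InjectedStore()]) -> str:
    """Hypothetical tool using the same injection pattern as CreateFileTool."""
    # The model only supplies `key`; `store` arrives via injection.
    return str(store.mget([key])[0])
```
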
src/codegen/extensions/langchain/utils/custom_tool_node.py

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+from typing import Any, Literal, Optional, Union
+
+from langchain_core.messages import (
+    AIMessage,
+    AnyMessage,
+    ToolCall,
+)
+from langchain_core.stores import InMemoryBaseStore
+from langgraph.prebuilt import ToolNode
+from pydantic import BaseModel
+
+
+class CustomToolNode(ToolNode):
+    """Extended ToolNode that detects truncated tool calls."""
+
+    def _parse_input(
+        self,
+        input: Union[
+            list[AnyMessage],
+            dict[str, Any],
+            BaseModel,
+        ],
+        store: Optional[InMemoryBaseStore],
+    ) -> tuple[list[ToolCall], Literal["list", "dict", "tool_calls"]]:
+        """Parse the input and check for truncated tool calls."""
+        messages = input.get("messages", [])
+        if isinstance(messages, list):
+            if isinstance(messages[-1], AIMessage):
+                response_metadata = messages[-1].response_metadata
+                # Check if the stop reason is due to max tokens
+                if response_metadata.get("stop_reason") == "max_tokens":
+                    # Check if the response metadata contains usage information
+                    if "usage" not in response_metadata or "output_tokens" not in response_metadata["usage"]:
+                        msg = "Response metadata is missing usage information."
+                        raise ValueError(msg)
+
+                    output_tokens = response_metadata["usage"]["output_tokens"]
+                    for tool_call in messages[-1].tool_calls:
+                        if tool_call.get("name") == "create_file":
+                            # Set the max tokens and max tokens reached flag in the store
+                            store.mset([(tool_call["name"], {"max_tokens": output_tokens, "max_tokens_reached": True})])
+
+        return super()._parse_input(input, store)

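For reference, when generation is cut off, the metadata this check inspects looks roughly like the following (field names follow the Anthropic Messages API; the values are illustrative):

```python
# Shape of AIMessage.response_metadata for a truncated Anthropic response.
response_metadata = {
    "stop_reason": "max_tokens",
    "usage": {"input_tokens": 1532, "output_tokens": 8192},
}
```
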
src/codegen/extensions/tools/create_file.py

Lines changed: 12 additions & 2 deletions

@@ -1,6 +1,6 @@
 """Tool for creating new files."""

-from typing import ClassVar
+from typing import ClassVar, Optional

 from pydantic import Field

@@ -23,7 +23,7 @@ class CreateFileObservation(Observation):
     str_template: ClassVar[str] = "Created file {filepath}"


-def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileObservation:
+def create_file(codebase: Codebase, filepath: str, content: str, max_tokens: Optional[int] = None) -> CreateFileObservation:
     """Create a new file.

     Args:
@@ -34,6 +34,16 @@ def create_file(codebase: Codebase, filepath: str, content: str) -> CreateFileOb
     Returns:
         CreateFileObservation containing new file state, or error if file exists
     """
+    if max_tokens:
+        error = f"""Your response reached the max output tokens limit of {max_tokens} tokens (~ {max_tokens / 10} lines).
+Create the file in chunks or break up the content into smaller files.
+"""
+        return CreateFileObservation(
+            status="error",
+            error=error,
+            filepath=filepath,
+            file_info=ViewFileObservation(status="error", error=error, filepath=filepath, content="", raw_content="", line_count=0),
+        )
     if codebase.has_file(filepath):
         return CreateFileObservation(
             status="error",

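The "~ tokens / 10 lines" figure in the new error string is a rough lines-per-token heuristic; with the 8192-token cap now set in llm.py, the rendered message reads:

```python
max_tokens = 8192
print(f"Your response reached the max output tokens limit of {max_tokens} tokens (~ {max_tokens / 10} lines).")
# -> Your response reached the max output tokens limit of 8192 tokens (~ 819.2 lines).
```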