Skip to content

Commit 046b238

Browse files
Revert "Revert "Adding Schema for Tool Outputs"" (#894)
Reverts #892 --------- Co-authored-by: Rushil Patel <[email protected]> Co-authored-by: rushilpatel0 <[email protected]>
1 parent 5f91de7 commit 046b238

File tree

14 files changed

+466
-74
lines changed

14 files changed

+466
-74
lines changed

codegen-examples/examples/swebench_agent_run/local_run.ipynb

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"metadata": {},
3333
"outputs": [],
3434
"source": [
35-
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=20, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
35+
"await run_eval(use_existing_preds=None, dataset=\"lite\", length=5, repo=\"django/django\", num_workers=10, model=\"claude-3-7-sonnet-latest\")"
3636
]
3737
},
3838
{

src/codegen/agents/data.py

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class ToolMessageData(BaseMessage):
5252
tool_name: Optional[str] = None
5353
tool_response: Optional[str] = None
5454
tool_id: Optional[str] = None
55+
status: Optional[str] = None
5556

5657

5758
@dataclass

src/codegen/agents/tracer.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,14 @@ def extract_structured_data(self, chunk: dict[str, Any]) -> Optional[BaseMessage
7171
tool_calls = [ToolCall(name=tc.get("name"), arguments=tc.get("arguments"), id=tc.get("id")) for tc in tool_calls_data]
7272
return AssistantMessage(type=message_type, content=content, tool_calls=tool_calls)
7373
elif message_type == "tool":
74-
return ToolMessageData(type=message_type, content=content, tool_name=getattr(latest_message, "name", None), tool_response=content, tool_id=getattr(latest_message, "tool_call_id", None))
74+
return ToolMessageData(
75+
type=message_type,
76+
content=content,
77+
tool_name=getattr(latest_message, "name", None),
78+
tool_response=getattr(latest_message, "artifact", content),
79+
tool_id=getattr(latest_message, "tool_call_id", None),
80+
status=getattr(latest_message, "status", None),
81+
)
7582
elif message_type == "function":
7683
return FunctionMessageData(type=message_type, content=content)
7784
else:

src/codegen/extensions/langchain/graph.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def reasoner(self, state: GraphState) -> dict[str, Any]:
100100
messages.append(HumanMessage(content=query))
101101

102102
result = self.model.invoke([self.system_message, *messages])
103-
if isinstance(result, AIMessage):
103+
if isinstance(result, AIMessage) and not result.tool_calls:
104104
updated_messages = [*messages, result]
105105
return {"messages": updated_messages, "final_answer": result.content}
106106

@@ -455,7 +455,7 @@ def get_field_descriptions(tool_obj):
455455
return f"Error: Could not identify the tool you're trying to use.\n\nAvailable tools:\n{available_tools}\n\nPlease use one of the available tools with the correct parameters."
456456

457457
# For other types of errors
458-
return f"Error executing tool: {error_msg}\n\nPlease check your tool usage and try again with the correct parameters."
458+
return f"Error executing tool: {exception!s}\n\nPlease check your tool usage and try again with the correct parameters."
459459

460460
# Add nodes
461461
builder.add_node("reasoner", self.reasoner, retry=retry_policy)

src/codegen/extensions/langchain/tools.py

+30-21
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
"""Langchain tools for workspace operations."""
22

33
from collections.abc import Callable
4-
from typing import ClassVar, Literal
4+
from typing import Annotated, ClassVar, Literal, Optional
55

6+
from langchain_core.messages import ToolMessage
7+
from langchain_core.tools import InjectedToolCallId
68
from langchain_core.tools.base import BaseTool
79
from pydantic import BaseModel, Field
810

@@ -52,10 +54,11 @@ class ViewFileInput(BaseModel):
5254
"""Input for viewing a file."""
5355

5456
filepath: str = Field(..., description="Path to the file relative to workspace root")
55-
start_line: int | None = Field(None, description="Starting line number to view (1-indexed, inclusive)")
56-
end_line: int | None = Field(None, description="Ending line number to view (1-indexed, inclusive)")
57-
max_lines: int | None = Field(None, description="Maximum number of lines to view at once, defaults to 500")
58-
line_numbers: bool | None = Field(True, description="If True, add line numbers to the content (1-indexed)")
57+
start_line: Optional[int] = Field(None, description="Starting line number to view (1-indexed, inclusive)")
58+
end_line: Optional[int] = Field(None, description="Ending line number to view (1-indexed, inclusive)")
59+
max_lines: Optional[int] = Field(None, description="Maximum number of lines to view at once, defaults to 500")
60+
line_numbers: Optional[bool] = Field(True, description="If True, add line numbers to the content (1-indexed)")
61+
tool_call_id: Annotated[str, InjectedToolCallId]
5962

6063

6164
class ViewFileTool(BaseTool):
@@ -73,12 +76,13 @@ def __init__(self, codebase: Codebase) -> None:
7376

7477
def _run(
7578
self,
79+
tool_call_id: str,
7680
filepath: str,
77-
start_line: int | None = None,
78-
end_line: int | None = None,
79-
max_lines: int | None = None,
80-
line_numbers: bool | None = True,
81-
) -> str:
81+
start_line: Optional[int] = None,
82+
end_line: Optional[int] = None,
83+
max_lines: Optional[int] = None,
84+
line_numbers: Optional[bool] = True,
85+
) -> ToolMessage:
8286
result = view_file(
8387
self.codebase,
8488
filepath,
@@ -88,14 +92,15 @@ def _run(
8892
max_lines=max_lines if max_lines is not None else 500,
8993
)
9094

91-
return result.render()
95+
return result.render(tool_call_id)
9296

9397

9498
class ListDirectoryInput(BaseModel):
9599
"""Input for listing directory contents."""
96100

97101
dirpath: str = Field(default="./", description="Path to directory relative to workspace root")
98102
depth: int = Field(default=1, description="How deep to traverse. Use -1 for unlimited depth.")
103+
tool_call_id: Annotated[str, InjectedToolCallId]
99104

100105

101106
class ListDirectoryTool(BaseTool):
@@ -109,9 +114,9 @@ class ListDirectoryTool(BaseTool):
109114
def __init__(self, codebase: Codebase) -> None:
110115
super().__init__(codebase=codebase)
111116

112-
def _run(self, dirpath: str = "./", depth: int = 1) -> str:
117+
def _run(self, tool_call_id: str, dirpath: str = "./", depth: int = 1) -> ToolMessage:
113118
result = list_directory(self.codebase, dirpath, depth)
114-
return result.render()
119+
return result.render(tool_call_id)
115120

116121

117122
class SearchInput(BaseModel):
@@ -126,6 +131,7 @@ class SearchInput(BaseModel):
126131
page: int = Field(default=1, description="Page number to return (1-based, default: 1)")
127132
files_per_page: int = Field(default=10, description="Number of files to return per page (default: 10)")
128133
use_regex: bool = Field(default=False, description="Whether to treat query as a regex pattern (default: False)")
134+
tool_call_id: Annotated[str, InjectedToolCallId]
129135

130136

131137
class SearchTool(BaseTool):
@@ -139,16 +145,17 @@ class SearchTool(BaseTool):
139145
def __init__(self, codebase: Codebase) -> None:
140146
super().__init__(codebase=codebase)
141147

142-
def _run(self, query: str, file_extensions: list[str] | None = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> str:
148+
def _run(self, tool_call_id: str, query: str, file_extensions: Optional[list[str]] = None, page: int = 1, files_per_page: int = 10, use_regex: bool = False) -> ToolMessage:
143149
result = search(self.codebase, query, file_extensions=file_extensions, page=page, files_per_page=files_per_page, use_regex=use_regex)
144-
return result.render()
150+
return result.render(tool_call_id)
145151

146152

147153
class EditFileInput(BaseModel):
148154
"""Input for editing a file."""
149155

150156
filepath: str = Field(..., description="Path to the file to edit")
151157
content: str = Field(..., description="New content for the file")
158+
tool_call_id: Annotated[str, InjectedToolCallId]
152159

153160

154161
class EditFileTool(BaseTool):
@@ -181,9 +188,9 @@ class EditFileTool(BaseTool):
181188
def __init__(self, codebase: Codebase) -> None:
182189
super().__init__(codebase=codebase)
183190

184-
def _run(self, filepath: str, content: str) -> str:
191+
def _run(self, filepath: str, content: str, tool_call_id: str) -> str:
185192
result = edit_file(self.codebase, filepath, content)
186-
return result.render()
193+
return result.render(tool_call_id)
187194

188195

189196
class CreateFileInput(BaseModel):
@@ -340,6 +347,7 @@ class SemanticEditInput(BaseModel):
340347
edit_content: str = Field(..., description=FILE_EDIT_PROMPT)
341348
start: int = Field(default=1, description="Starting line number (1-indexed, inclusive). Default is 1.")
342349
end: int = Field(default=-1, description="Ending line number (1-indexed, inclusive). Default is -1 (end of file).")
350+
tool_call_id: Annotated[str, InjectedToolCallId]
343351

344352

345353
class SemanticEditTool(BaseTool):
@@ -353,10 +361,10 @@ class SemanticEditTool(BaseTool):
353361
def __init__(self, codebase: Codebase) -> None:
354362
super().__init__(codebase=codebase)
355363

356-
def _run(self, filepath: str, edit_content: str, start: int = 1, end: int = -1) -> str:
364+
def _run(self, filepath: str, tool_call_id: str, edit_content: str, start: int = 1, end: int = -1) -> ToolMessage:
357365
# Create the the draft editor mini llm
358366
result = semantic_edit(self.codebase, filepath, edit_content, start=start, end=end)
359-
return result.render()
367+
return result.render(tool_call_id)
360368

361369

362370
class RenameFileInput(BaseModel):
@@ -1033,6 +1041,7 @@ class RelaceEditInput(BaseModel):
10331041

10341042
filepath: str = Field(..., description="Path of the file relative to workspace root")
10351043
edit_snippet: str = Field(..., description=RELACE_EDIT_PROMPT)
1044+
tool_call_id: Annotated[str, InjectedToolCallId]
10361045

10371046

10381047
class RelaceEditTool(BaseTool):
@@ -1046,9 +1055,9 @@ class RelaceEditTool(BaseTool):
10461055
def __init__(self, codebase: Codebase) -> None:
10471056
super().__init__(codebase=codebase)
10481057

1049-
def _run(self, filepath: str, edit_snippet: str) -> str:
1058+
def _run(self, filepath: str, edit_snippet: str, tool_call_id: str) -> ToolMessage:
10501059
result = relace_edit(self.codebase, filepath, edit_snippet)
1051-
return result.render()
1060+
return result.render(tool_call_id=tool_call_id)
10521061

10531062

10541063
class ReflectionInput(BaseModel):

src/codegen/extensions/tools/edit_file.py

+27-6
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,53 @@
11
"""Tool for editing file contents."""
22

3-
from typing import ClassVar
3+
from typing import TYPE_CHECKING, ClassVar, Optional
44

5+
from langchain_core.messages import ToolMessage
56
from pydantic import Field
67

78
from codegen.sdk.core.codebase import Codebase
89

910
from .observation import Observation
1011
from .replacement_edit import generate_diff
1112

13+
if TYPE_CHECKING:
14+
from .tool_output_types import EditFileArtifacts
15+
1216

1317
class EditFileObservation(Observation):
1418
"""Response from editing a file."""
1519

1620
filepath: str = Field(
1721
description="Path to the edited file",
1822
)
19-
diff: str = Field(
23+
diff: Optional[str] = Field(
24+
default=None,
2025
description="Unified diff showing the changes made",
2126
)
2227

2328
str_template: ClassVar[str] = "Edited file {filepath}"
2429

25-
def render(self) -> str:
30+
def render(self, tool_call_id: str) -> ToolMessage:
2631
"""Render edit results in a clean format."""
27-
return f"""[EDIT FILE]: {self.filepath}
28-
29-
{self.diff}"""
32+
if self.status == "error":
33+
artifacts_error: EditFileArtifacts = {"filepath": self.filepath, "error": self.error}
34+
return ToolMessage(
35+
content=f"[ERROR EDITING FILE]: {self.filepath}: {self.error}",
36+
status=self.status,
37+
name="edit_file",
38+
artifact=artifacts_error,
39+
tool_call_id=tool_call_id,
40+
)
41+
42+
artifacts_success: EditFileArtifacts = {"filepath": self.filepath, "diff": self.diff}
43+
44+
return ToolMessage(
45+
content=f"""[EDIT FILE]: {self.filepath}\n\n{self.diff}""",
46+
status=self.status,
47+
name="edit_file",
48+
artifact=artifacts_success,
49+
tool_call_id=tool_call_id,
50+
)
3051

3152

3253
def edit_file(codebase: Codebase, filepath: str, new_content: str) -> EditFileObservation:

0 commit comments

Comments
 (0)