@@ -1,9 +1,10 @@
from pathlib import Path

from langchain.chat_models import init_chat_model
from langchain_core.messages import HumanMessage
from langchain_openai import OpenAIEmbeddings

from ursa.agents import ExecutionAgent
from ursa.observability.timing import render_session_summary
from ursa.util import Checkpointer

### Run a simple example of an Execution Agent.

@@ -21,31 +22,19 @@
max_completion_tokens=30000,
)

embedding_kwargs = None
embedding_model = OpenAIEmbeddings(**(embedding_kwargs or {}))

workspace = Path("./workspace_BO")
checkpointer = Checkpointer.from_workspace(workspace)

# Initialize the agent
executor = ExecutionAgent(
llm=model,
enable_metrics=True,
) # , enable_metrics=False if you don't want metrics

set_workspace = False

if set_workspace:
# Syntax if you want to explicitly set the directory to work in
init = {
"messages": [HumanMessage(content=problem)],
"workspace": "workspace_BO",
}

print(f"\nSolving problem: {problem}\n")
thread_id="BO_test",
workspace=workspace,
checkpointer=checkpointer,
)

# Solve the problem
final_results = executor.invoke(init)

else:
final_results = executor.invoke(problem)
final_results = executor.invoke(problem)

render_session_summary(executor.thread_id)
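
A note on the new checkpointer wiring above: the sketch below is a guess at what Checkpointer.from_workspace does, reconstructed from the from_path implementation in src/ursa/util/__init__.py at the end of this diff. The db_name default and directory layout are assumptions, not confirmed URSA internals.

import sqlite3
from pathlib import Path

from langgraph.checkpoint.sqlite import SqliteSaver

def from_workspace_sketch(workspace: Path, db_name: str = "checkpointer.db") -> SqliteSaver:
    # Hypothetical reconstruction: open (or create) a sqlite file inside the
    # workspace and hand it to LangGraph's sqlite-backed checkpointer.
    workspace.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(workspace / db_name), check_same_thread=False)
    return SqliteSaver(conn)
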
@@ -0,0 +1,38 @@
from pathlib import Path

from langchain.chat_models import init_chat_model

from ursa.agents import ExecutionAgent
from ursa.observability.timing import render_session_summary
from ursa.util import Checkpointer

### Run a simple example of continuing the Execution Agent from a checkpoint.

problem = """
Make a plot of the evaluations of the target function with the running minimum overlaid to show convergence.
Make a second plot highlighting the important inputs of the function.
"""

model = init_chat_model(
model="openai:gpt-5-mini",
max_completion_tokens=30000,
)

workspace = Path(
"./workspace_BO"
) # Point at the same workspace as the original run.
checkpointer = Checkpointer.from_workspace(workspace)

# Initialize the agent
executor = ExecutionAgent(
llm=model,
enable_metrics=True,
thread_id="BO_test", # Set the thread_id to the same as the previous result
workspace=workspace,
checkpointer=checkpointer,
)

final_results = executor.invoke(problem)


render_session_summary(executor.thread_id)
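
This resume works because LangGraph checkpointers key saved state by thread_id: reusing the same checkpointer and the same thread_id makes invoke continue from the last checkpoint instead of starting fresh. A self-contained illustration of that mechanism, using the in-memory saver (the graph here is illustrative, not an URSA API):

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import END, START, MessagesState, StateGraph

def step(state: MessagesState):
    return {"messages": []}  # no-op node; real agents append their replies here

builder = StateGraph(MessagesState)
builder.add_node("step", step)
builder.add_edge(START, "step")
builder.add_edge("step", END)
graph = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "BO_test"}}
graph.invoke({"messages": [("user", "first run")]}, config)
# Same thread_id: the second call sees the first call's messages in its state.
graph.invoke({"messages": [("user", "second run, resumed")]}, config)
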
108 changes: 35 additions & 73 deletions src/ursa/agents/execution_agent.py
@@ -28,6 +28,7 @@
# from langchain_core.runnables.graph import MermaidDrawMethod
from pathlib import Path
from typing import (
Annotated,
Any,
Literal,
Optional,
@@ -39,12 +40,14 @@
from langchain_core.messages import (
AIMessage,
AnyMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.messages.utils import count_tokens_approximately
from langchain_core.output_parsers import StrOutputParser
from langgraph.types import Command
from langgraph.graph.message import add_messages
from langgraph.types import Overwrite

# Rich
from rich import get_console
@@ -88,9 +91,9 @@ class ExecutionState(TypedDict):
is_linked).
"""

messages: list[AnyMessage]
messages: Annotated[list[AnyMessage], add_messages]
current_progress: str
code_files: list[str]
code_files: set[str]
workspace: Path
symlinkdir: dict
model: BaseChatModel
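
The Annotated reducer is the load-bearing change in this state definition: with add_messages, a node can return only the new messages and LangGraph merges them into the channel (appending new ids, replacing existing ones) instead of overwriting the whole list. A small sketch of the standard semantics:

from typing import Annotated

from langchain_core.messages import AnyMessage, HumanMessage
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

class SketchState(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]

# add_messages is an ordinary function, so the merge rule is easy to test:
merged = add_messages(
    [HumanMessage(content="hi", id="1")],
    [HumanMessage(content="hi, edited", id="1"), HumanMessage(content="more", id="2")],
)
assert [m.content for m in merged] == ["hi, edited", "more"]
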
@@ -265,10 +268,6 @@ def _summarize_context(self, state: ExecutionState) -> ExecutionState:
conversation_to_summarize.append(msg)
conversation_to_keep.remove(msg)
tool_ids.remove(msg.tool_call_id)
elif isinstance(msg, ToolMessage):
print(
f"This is a Tool that happened in the last {self.messages_to_keep} messages:\n{str(msg)}\nbut not in {tool_ids}"
)
if tool_ids:
# We may need to implement something here for if a tool has not
# responded but its tool call is far enough back that it is being
@@ -314,7 +313,6 @@ def _summarize_context(self, state: ExecutionState) -> ExecutionState:
new_state["messages"] = summarized_messages
return new_state

# Define the function that calls the model
def query_executor(self, state: ExecutionState) -> ExecutionState:
"""Prepare workspace, handle optional symlinks, and invoke the executor LLM.

@@ -384,7 +382,7 @@ def query_executor(self, state: ExecutionState) -> ExecutionState:
else:
new_state["messages"] = [
SystemMessage(content=self.executor_prompt)
] + state["messages"]
] + new_state["messages"]

# 4) Invoke the LLM with the prepared message sequence.
try:
@@ -393,45 +391,20 @@
)
new_state["messages"].append(response)
except Exception as e:
response = AIMessage(content=f"Response error {e}")
msg = new_state["messages"][-1].text
print("Error: ", e, " ", msg)
new_state["messages"].append(
AIMessage(content=f"Response error {e}")
)
new_state["messages"].append(response)

# 5) Optionally persist the pre-invocation state for audit/debugging.
if self.log_state:
self.write_state("execution_agent.json", new_state)

# Return the model's response and the workspace path as a partial state update.
return new_state

def tool_use(self, state: ExecutionState) -> ExecutionState:
new_state = state.copy()
update = self.tool_node.invoke(state)
# Could be implemented better, but handles the different forms of tool response:
# dict of messages, list of Commands, etc.
try:
if isinstance(update, dict) and "messages" in update:
new_state["messages"].extend(update["messages"])
elif isinstance(update, list):
for resp in update:
if isinstance(resp, Command):
new_state["messages"].extend(resp.update["messages"])
new_state.setdefault("code_files", []).extend(
resp.update["code_files"]
)
else:
new_state["messages"].extend(resp["messages"])
elif isinstance(update, Command):
new_state["messages"].extend(update.update["messages"])
new_state.setdefault("code_files", []).extend(
update.update["code_files"]
)
except Exception as e:
print(f"SOMETHING IS WRONG WITH {update}: {e}")
new_state["messages"].extend(update["messages"])
return new_state
return {
"messages": Overwrite(new_state["messages"]),
"workspace": new_state["workspace"],
}
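
With the reducer appending by default, query_executor now returns Overwrite-wrapped messages. Per this PR's import from langgraph.types, the assumed semantics are that Overwrite replaces the stored channel value outright, bypassing add_messages, which matters once _summarize_context has pruned the history. A sketch under that assumption, reusing SketchState from the sketch above:

from langgraph.types import Overwrite

def prune_history(state: SketchState):
    # Assumption: Overwrite makes the pruned tail replace, rather than merge
    # with, the messages already stored in the channel.
    return {"messages": Overwrite(state["messages"][-5:])}
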

def recap(self, state: ExecutionState) -> ExecutionState:
"""Produce a concise summary of the conversation and optionally persist memory.
@@ -453,27 +426,21 @@ def recap(self, state: ExecutionState) -> ExecutionState:
new_state = self._summarize_context(new_state)

# 1) Construct the summarization message list (system prompt + prior messages).
if isinstance(new_state["messages"][0], SystemMessage):
messages = new_state["messages"].copy()
messages[0] = SystemMessage(content=recap_prompt)
else:
messages = [SystemMessage(content=recap_prompt)] + new_state[
"messages"
]
recap_message = HumanMessage(content=self.recap_prompt)
new_state["messages"] = new_state["messages"] + [recap_message]

# 2) Invoke the LLM to generate a recap; capture content even on failure.
response_content = ""
try:
response = self.llm.invoke(
messages, self.build_config(tags=["recap"])
input=new_state["messages"],
config=self.build_config(tags=["recap"]),
)
response_content = response.text
new_state["messages"].append(response)
except Exception as e:
print("Error: ", e, " ", messages[-1].text)
new_state["messages"].append(
AIMessage(content=f"Response error {e}")
)
response_content = f"Response error {e}"
response = AIMessage(content=response_content)
print("Error: ", e, " ", new_state["messages"][-1].text)

console.print(
Panel(
Markdown(response_content),
@@ -509,10 +476,11 @@ def recap(self, state: ExecutionState) -> ExecutionState:

# 4) Optionally write state to disk for debugging/auditing.
if self.log_state:
new_state["messages"].append(response)
self.write_state("execution_agent.json", new_state)

# 5) Return a partial state update with only the summary content.
return new_state
return {"messages": [recap_message, response]}

def safety_check(self, state: ExecutionState) -> ExecutionState:
"""Assess pending shell commands for safety and inject ToolMessages with results.
@@ -535,11 +503,9 @@ def safety_check(self, state: ExecutionState) -> ExecutionState:
last_msg = new_state["messages"][-1]

# 1.5) Check message history length and summarize to shorten the token usage:
new_state = self._summarize_context(new_state)

# 2) Evaluate any pending run_command tool calls for safety.
tool_responses: list[ToolMessage] = []
any_unsafe = False
tool_responses = []
for tool_call in last_msg.tool_calls:
if tool_call["name"] != "run_command":
continue
@@ -555,7 +521,6 @@
)

if "[NO]" in safety_result:
any_unsafe = True
tool_response = (
"[UNSAFE] That command `{q}` was deemed unsafe and cannot be run.\n"
"For reason: {r}"
@@ -568,24 +533,21 @@
console.print(
"[bold red][WARNING][/bold red] REASON:", tool_response
)
tool_responses.append(
ToolMessage(
content=tool_response, tool_call_id=tool_call["id"]
)
)
# last_msg.tool_calls.remove(tool_call)
else:
tool_response = f"Command `{query}` passed safety check."
console.print(
f"[green]Command passed safety check:[/green] {query}"
)

tool_responses.append(
ToolMessage(
content=tool_response,
tool_call_id=tool_call["id"],
)
)

# 3) If any command is unsafe, append all tool responses; otherwise keep state.
if any_unsafe:
new_state["messages"].extend(tool_responses)

return new_state
if tool_responses:
return {"messages": tool_responses}
else:
return {}
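
The reworked return hands back all collected ToolMessage verdicts as a partial update (relying on the add_messages reducer) rather than mutating and returning a full copy of the state only when something was unsafe. Each verdict must answer a specific pending call, which is what tool_call_id does; a minimal illustration of the pairing (ids here are hypothetical):

from langchain_core.messages import AIMessage, ToolMessage

request = AIMessage(
    content="",
    tool_calls=[{"name": "run_command", "args": {"command": "ls"}, "id": "call_1"}],
)
# A ToolMessage answers the pending call whose id matches its tool_call_id.
reply = ToolMessage(content="Command `ls` passed safety check.", tool_call_id="call_1")
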

def _build_graph(self):
"""Construct and compile the agent's LangGraph state machine."""
@@ -599,7 +561,7 @@ def _build_graph(self):
# - "recap": summary/finalization step
# - "safety_check": gate for shell command safety
self.add_node(self.query_executor, "agent")
self.add_node(self.tool_use, "action")
self.add_node(self.tool_node, "action")
self.add_node(self.recap, "recap")
self.add_node(self.safety_check, "safety_check")
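
Registering self.tool_node directly (instead of the tool_use wrapper deleted above) works because the prebuilt ToolNode returns a partial update that the reducer can merge, whether a tool yields a plain string or a Command. For context, the standard prebuilt usage, with a stand-in tool:

from langchain_core.tools import tool
from langgraph.prebuilt import ToolNode

@tool
def echo(text: str) -> str:
    """Echo the input text back."""
    return text

action = ToolNode([echo])  # invoking it yields {"messages": [ToolMessage(...)]}
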

2 changes: 1 addition & 1 deletion src/ursa/prompt_library/execution_prompts.py
@@ -26,7 +26,7 @@
"""

recap_prompt = """
You are a summarizing agent. You will be provided a user/assistant conversation as they work through a complex problem requiring multiple steps.
You are a summarizing agent. You have a user/assistant conversation as they work through a complex problem requiring multiple steps.

Your responsibility is to write a condensed summary of the conversation.
- Keep all important points from the conversation.
22 changes: 17 additions & 5 deletions src/ursa/tools/write_code_tool.py
@@ -98,9 +98,9 @@ def write_code(
)

# Append the file to the list in agent's state for later reference
file_list = state.get("code_files", [])
file_list = state.get("code_files", set([]))
if filename not in file_list:
file_list.append(filename)
file_list.add(filename)

# Create a tool message to send back to acknowledge success.
msg = ToolMessage(
@@ -122,6 +122,7 @@ def edit_code(
old_code: str,
new_code: str,
filename: str,
tool_call_id: Annotated[str, InjectedToolCallId],
state: Annotated[dict, InjectedState],
) -> str:
"""Replace the **first** occurrence of *old_code* with *new_code* in *filename*.
@@ -183,9 +184,20 @@ def edit_code(
f"[bold bright_white on green] :heavy_check_mark: [/] "
f"[green]File updated:[/] {code_file}"
)
file_list = state.get("code_files", [])
file_list = state.get("code_files", set([]))
if code_file not in file_list:
file_list.append(filename)
file_list.add(filename)
state["code_files"] = file_list

return f"File {filename} updated successfully."
# Create a tool message to send back to acknowledge success.
msg = ToolMessage(
content=f"File {filename} updated successfully.",
tool_call_id=tool_call_id,
)

return Command(
update={
"code_files": file_list,
"messages": [msg],
}
)
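
edit_code now mirrors write_code: instead of returning a bare string (which updated code_files only on a local copy of the injected state), it returns a Command carrying both the state update and the acknowledging ToolMessage. The general pattern, sketched with a hypothetical tool:

from langchain_core.messages import ToolMessage
from langchain_core.tools import InjectedToolCallId, tool
from langgraph.prebuilt import InjectedState
from langgraph.types import Command
from typing_extensions import Annotated

@tool
def record_file(
    filename: str,
    tool_call_id: Annotated[str, InjectedToolCallId],
    state: Annotated[dict, InjectedState],
) -> Command:
    """Add a filename to the agent's code_files set."""
    files = set(state.get("code_files", set()))
    files.add(filename)
    msg = ToolMessage(content=f"Recorded {filename}.", tool_call_id=tool_call_id)
    return Command(update={"code_files": files, "messages": [msg]})
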
6 changes: 4 additions & 2 deletions src/ursa/util/__init__.py
@@ -17,7 +17,9 @@ def from_workspace(
return SqliteSaver(conn)

@classmethod
def from_path(cls, db_path: Path) -> SqliteSaver:
def from_path(
cls, db_path: Path, db_name: str = "checkpointer.db"
) -> SqliteSaver:
"""Make checkpointer sqlite db.

Args
@@ -26,5 +28,5 @@ def from_path(cls, db_path: Path) -> SqliteSaver:
"""

db_path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(str(db_path), check_same_thread=False)
conn = sqlite3.connect(str(db_path / db_name), check_same_thread=False)
return SqliteSaver(conn)
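
With this change, from_path treats db_path as a directory and joins db_name onto it, so callers pass the folder rather than the full database file path. Assumed usage:

from pathlib import Path

from ursa.util import Checkpointer

saver = Checkpointer.from_path(Path("./workspace_BO"), db_name="checkpointer.db")

One caveat under that reading: the body still calls db_path.parent.mkdir, which creates the directory's parent rather than the directory itself, so the workspace folder must already exist for sqlite3.connect to succeed.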