diff --git a/src/ursa/agents/base.py b/src/ursa/agents/base.py index 97d5d45..6173dd4 100644 --- a/src/ursa/agents/base.py +++ b/src/ursa/agents/base.py @@ -16,7 +16,9 @@ """ import re +import sqlite3 from abc import ABC, abstractmethod +from dataclasses import dataclass from functools import cached_property from pathlib import Path from typing import ( @@ -41,6 +43,8 @@ from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.graph.state import CompiledStateGraph, StateGraph from langgraph.prebuilt import ToolNode +from langgraph.store.base import BaseStore +from langgraph.store.sqlite import SqliteStore from ursa.observability.timing import ( Telemetry, # for timing / telemetry / metrics @@ -50,6 +54,17 @@ TState = TypeVar("TState", bound=Mapping[str, Any]) +@dataclass(frozen=True, kw_only=True) +class AgentContext: + """Immutable context provided during graph execution""" + + workspace: Path + """ Workspace path for the agent """ + + tool_character_limit: int = 3000 + """ Suggested limit on tool call responses """ + + def _to_snake(s: str) -> str: """Convert a string to snake_case format. @@ -170,6 +185,11 @@ def name(self) -> str: """Agent name.""" return self.__class__.__name__ + @property + def context(self) -> AgentContext: + """Immutable run-scoped information provided to the Agent's graph""" + return AgentContext(workspace=self.workspace) + def add_node( self, f: Callable[..., Mapping[str, Any]], @@ -512,11 +532,23 @@ def _normalize_inputs(self, inputs: InputLike) -> Mapping[str, Any]: def compiled_graph(self) -> CompiledStateGraph: """Return the compiled StateGraph application for the agent.""" graph = self.build_graph() - compiled = graph.compile(checkpointer=self.checkpointer).with_config({ - "recursion_limit": 50000 - }) + compiled = graph.compile( + checkpointer=self.checkpointer, + store=self.storage, + ).with_config({"recursion_limit": 50000}) return self._finalize_graph(compiled) + @cached_property + def storage(self) -> BaseStore: + """Create a SQLite-backed LangGraph store for persistent graph data.""" + store_path = self.workspace / "graph_store.sqlite" + conn = sqlite3.connect( + store_path, check_same_thread=False, isolation_level=None + ) + store = SqliteStore(conn) + store.setup() + return store + @final def build_graph(self) -> StateGraph: """Build and return the StateGraph backing this agent.""" diff --git a/src/ursa/agents/execution_agent.py b/src/ursa/agents/execution_agent.py index 72b9fa6..2385e4b 100644 --- a/src/ursa/agents/execution_agent.py +++ b/src/ursa/agents/execution_agent.py @@ -43,7 +43,7 @@ ToolMessage, ) from langchain_core.messages.utils import count_tokens_approximately -from langchain_core.output_parsers import StrOutputParser +from langgraph.runtime import Runtime from langgraph.types import Command # Rich @@ -51,7 +51,7 @@ from rich.markdown import Markdown from rich.panel import Panel -from ursa.agents.base import AgentWithTools, BaseAgent +from ursa.agents.base import AgentContext, AgentWithTools, BaseAgent from ursa.prompt_library.execution_prompts import ( executor_prompt, get_safety_prompt, @@ -82,7 +82,6 @@ class ExecutionState(TypedDict): Fields: - messages: list of messages (System/Human/AI/Tool). - current_progress: short status string describing agent progress. - - code_files: list of filenames created or edited in the workspace. - workspace: path to the working directory where files and commands run. - symlinkdir: optional dict describing a symlink operation (source, dest, is_linked). @@ -90,7 +89,6 @@ class ExecutionState(TypedDict): messages: list[AnyMessage] current_progress: str - code_files: list[str] workspace: Path symlinkdir: dict model: BaseChatModel @@ -280,11 +278,11 @@ def _summarize_context(self, state: ExecutionState) -> ExecutionState: pass summarize_prompt = f""" - Your only tasks is to provide a detailed, comprehensive summary of the following - conversation. + Your only tasks is to provide a detailed, comprehensive summary of the following + conversation. - Your summary will be the only information retained from the conversation, so ensure - it contains all details that need to be remembered to meet the goals of the work. + Your summary will be the only information retained from the conversation, so ensure + it contains all details that need to be remembered to meet the goals of the work. Conversation to summarize: {conversation_to_summarize} @@ -418,16 +416,10 @@ def tool_use(self, state: ExecutionState) -> ExecutionState: for resp in update: if isinstance(resp, Command): new_state["messages"].extend(resp.update["messages"]) - new_state.setdefault("code_files", []).extend( - resp.update["code_files"] - ) else: new_state["messages"].extend(resp["messages"]) elif isinstance(update, Command): new_state["messages"].extend(update.update["messages"]) - new_state.setdefault("code_files", []).extend( - update.update["code_files"] - ) except Exception as e: print(f"SOMETHING IS WRONG WITH {update}: {e}") new_state["messages"].extend(update["messages"]) @@ -514,7 +506,9 @@ def recap(self, state: ExecutionState) -> ExecutionState: # 5) Return a partial state update with only the summary content. return new_state - def safety_check(self, state: ExecutionState) -> ExecutionState: + def safety_check( + self, state: ExecutionState, runtime: Runtime[AgentContext] + ) -> ExecutionState: """Assess pending shell commands for safety and inject ToolMessages with results. This method inspects the most recent AI tool calls, evaluates any run_command @@ -544,15 +538,18 @@ def safety_check(self, state: ExecutionState) -> ExecutionState: if tool_call["name"] != "run_command": continue - query = tool_call["args"]["query"] - safety_result = StrOutputParser().invoke( - self.llm.invoke( - self.get_safety_prompt( - query, self.safe_codes, new_state.get("code_files", []) - ), - self.build_config(tags=["safety_check"]), + if runtime.store is not None: + search_results = runtime.store.search( + ("workspace", "file_edit"), limit=1000 ) - ) + edited_files = [item.key for item in search_results] + else: + edited_files = [] + query = tool_call["args"]["query"] + safety_result = self.llm.invoke( + self.get_safety_prompt(query, self.safe_codes, edited_files), + self.build_config(tags=["safety_check"]), + ).text if "[NO]" in safety_result: any_unsafe = True diff --git a/src/ursa/tools/read_file_tool.py b/src/ursa/tools/read_file_tool.py index f50e3c8..a98e162 100644 --- a/src/ursa/tools/read_file_tool.py +++ b/src/ursa/tools/read_file_tool.py @@ -1,15 +1,12 @@ -import os -from typing import Annotated - +from langchain.tools import ToolRuntime from langchain_core.tools import tool -from langgraph.prebuilt import InjectedState +from ursa.agents.base import AgentContext from ursa.util.parse import read_pdf_text, read_text_file -# Tools for ExecutionAgent @tool -def read_file(filename: str, state: Annotated[dict, InjectedState]) -> str: +def read_file(filename: str, runtime: ToolRuntime[AgentContext]) -> str: """ Reads in a file with a given filename into a string. Can read in PDF or files that are text/ASCII. Uses a PDF parser if the filename ends @@ -18,12 +15,11 @@ def read_file(filename: str, state: Annotated[dict, InjectedState]) -> str: Args: filename: string filename to read in """ - workspace_dir = state["workspace"] - full_filename = os.path.join(workspace_dir, filename) + full_filename = runtime.context.workspace.joinpath(filename) print("[READING]: ", full_filename) try: - if full_filename.lower().endswith(".pdf"): + if full_filename.suffix == ".pdf": file_contents = read_pdf_text(full_filename) else: file_contents = read_text_file(full_filename) diff --git a/src/ursa/tools/run_command_tool.py b/src/ursa/tools/run_command_tool.py index b990f8d..d7f86bb 100644 --- a/src/ursa/tools/run_command_tool.py +++ b/src/ursa/tools/run_command_tool.py @@ -1,19 +1,15 @@ -import os import subprocess -from typing import Annotated +from pathlib import Path +from langchain.tools import ToolRuntime from langchain_core.tools import tool -from langgraph.prebuilt import InjectedState -# Global variables for the module. - -# Set a limit for message characters - the user could overload -# that in their env, or maybe we could pull this out of the LLM parameters -MAX_TOOL_MSG_CHARS = int(os.getenv("MAX_TOOL_MSG_CHARS", "30000")) +from ursa.agents.base import AgentContext +from ursa.util.types import AsciiStr @tool -def run_command(query: str, state: Annotated[dict, InjectedState]) -> str: +def run_command(query: AsciiStr, runtime: ToolRuntime[AgentContext]) -> str: """Execute a shell command in the workspace and return its combined output. Runs the specified command using subprocess.run in the given workspace @@ -29,9 +25,10 @@ def run_command(query: str, state: Annotated[dict, InjectedState]) -> str: A formatted string with "STDOUT:" followed by the truncated stdout and "STDERR:" followed by the truncated stderr. """ - workspace_dir = state["workspace"] + workspace_dir = Path(runtime.context.workspace) print("RUNNING: ", query) + try: result = subprocess.run( query, @@ -45,18 +42,10 @@ def run_command(query: str, state: Annotated[dict, InjectedState]) -> str: except KeyboardInterrupt: print("Keyboard Interrupt of command: ", query) stdout, stderr = "", "KeyboardInterrupt:" - except UnicodeDecodeError: - print( - f"Invalid Command: {query} - only 'utf-8' decodable characters allowed." - ) - stdout, stderr = ( - "", - f"Invalid Command: {query} - only 'utf-8' decodable characters allowed.:", - ) # Fit BOTH streams under a single overall cap stdout_fit, stderr_fit = _fit_streams_to_budget( - stdout or "", stderr or "", MAX_TOOL_MSG_CHARS + stdout or "", stderr or "", runtime.context.tool_character_limit ) print("STDOUT: ", stdout_fit) diff --git a/src/ursa/tools/write_code_tool.py b/src/ursa/tools/write_code_tool.py index d337df2..da15265 100644 --- a/src/ursa/tools/write_code_tool.py +++ b/src/ursa/tools/write_code_tool.py @@ -1,16 +1,16 @@ -import os -from typing import Annotated +import time +from pathlib import Path -from langchain_core.messages import ToolMessage -from langchain_core.tools import InjectedToolCallId, tool -from langgraph.prebuilt import InjectedState -from langgraph.types import Command +from langchain.tools import ToolRuntime +from langchain_core.tools import tool from rich import get_console from rich.panel import Panel from rich.syntax import Syntax +from ursa.agents.base import AgentContext from ursa.util.diff_renderer import DiffRenderer from ursa.util.parse import read_text_file +from ursa.util.types import AsciiStr console = get_console() @@ -39,27 +39,23 @@ def _strip_fences(snippet: str) -> str: return "\n".join(body.split("\n")[1:]) if "\n" in body else body.strip() -@tool +@tool(description="Write source code to a file") def write_code( code: str, - filename: str, - tool_call_id: Annotated[str, InjectedToolCallId], - state: Annotated[dict, InjectedState], -) -> Command: - """Write source code to a file and update the agent’s workspace state. + filename: AsciiStr, + runtime: ToolRuntime[AgentContext], +) -> str: + """Write source code to a file + + Records successful file edits to the graph's store Args: code: The source code content to be written to disk. filename: Name of the target file (including its extension). - tool_call_id: Identifier for this tool invocation. - state: Agent state dict holding workspace path and file list. - Returns: - Command: Contains an updated state (including code_files) and - a ToolMessage acknowledging success or failure. """ # Determine the full path to the target file - workspace_dir = state["workspace"] + workspace_dir = runtime.context.workspace console.print("[cyan]Writing file:[/]", filename) # Clean up markdown fences on submitted code. @@ -80,7 +76,7 @@ def write_code( ) # Write cleaned code to disk - code_file = os.path.join(workspace_dir, filename) + code_file = workspace_dir.joinpath(filename) try: with open(code_file, "w", encoding="utf-8") as f: f.write(code) @@ -97,32 +93,29 @@ def write_code( f"[green]File written:[/] {code_file}" ) - # Append the file to the list in agent's state for later reference - file_list = state.get("code_files", []) - if filename not in file_list: - file_list.append(filename) - - # Create a tool message to send back to acknowledge success. - msg = ToolMessage( - content=f"File {filename} written successfully.", - tool_call_id=tool_call_id, - ) + # Record the edit operation + if (store := runtime.store) is not None: + store.put( + ("workspace", "file_edit"), + filename, + { + "modified": time.time(), + "tool_call_id": runtime.tool_call_id, + "thread_id": runtime.config.get("metadata", {}).get( + "thread_id", None + ), + }, + ) - # Return updated code files list & the message - return Command( - update={ - "code_files": file_list, - "messages": [msg], - } - ) + return f"File {filename} written successfully." @tool def edit_code( old_code: str, new_code: str, - filename: str, - state: Annotated[dict, InjectedState], + filename: AsciiStr, + runtime: ToolRuntime[AgentContext], ) -> str: """Replace the **first** occurrence of *old_code* with *new_code* in *filename*. @@ -134,17 +127,16 @@ def edit_code( Returns: Success / failure message. """ - workspace_dir = state["workspace"] + workspace_dir = runtime.context.workspace console.print("[cyan]Editing file:[/cyan]", filename) - code_file = os.path.join(workspace_dir, filename) + code_file = Path(workspace_dir, filename) try: content = read_text_file(code_file) except FileNotFoundError: console.print( "[bold bright_white on red] :heavy_multiplication_x: [/] " "[red]File not found:[/]", - filename, ) return f"Failed: {filename} not found." @@ -183,9 +175,19 @@ def edit_code( f"[bold bright_white on green] :heavy_check_mark: [/] " f"[green]File updated:[/] {code_file}" ) - file_list = state.get("code_files", []) - if code_file not in file_list: - file_list.append(filename) - state["code_files"] = file_list + + # Record the edit operation + if (store := runtime.store) is not None: + store.put( + ("workspace", "file_edit"), + filename, + { + "modified": time.time(), + "tool_call_id": runtime.tool_call_id, + "thread_id": runtime.config.get("metadata", {}).get( + "thread_id", None + ), + }, + ) return f"File {filename} updated successfully." diff --git a/src/ursa/util/parse.py b/src/ursa/util/parse.py index 88f6552..aa8b684 100644 --- a/src/ursa/util/parse.py +++ b/src/ursa/util/parse.py @@ -3,6 +3,7 @@ import re import shutil import unicodedata +from pathlib import Path from typing import Any, Optional from urllib.parse import urljoin, urlparse @@ -406,13 +407,13 @@ def extract_main_text_only(html: str, *, max_chars: int = 250_000) -> str: return txt[:max_chars] -def read_pdf_text(path: str) -> str: +def read_pdf_text(path: str | Path) -> str: loader = PyPDFLoader(path) pages = loader.load() return "\n".join(p.page_content for p in pages) -def read_text_file(path: str) -> str: +def read_text_file(path: str | Path) -> str: """ Reads in a file at a given path into a string diff --git a/src/ursa/util/types.py b/src/ursa/util/types.py new file mode 100644 index 0000000..3a2c498 --- /dev/null +++ b/src/ursa/util/types.py @@ -0,0 +1,11 @@ +from typing import Annotated + +from pydantic import StringConstraints + +AsciiStr = Annotated[ + str, + StringConstraints( + strip_whitespace=True, strict=True, pattern=r"^[\x20-\x7E\t\n\r\f\v]+$" + ), +] +""" Limit strings to "text" ASCII characters (letters, digits, symbols, whitespace) """ diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/agents/test_base/test_base.py b/tests/agents/test_base/test_base.py index 17f1af9..85eecf1 100644 --- a/tests/agents/test_base/test_base.py +++ b/tests/agents/test_base/test_base.py @@ -208,6 +208,20 @@ def test_metrics_toggle_off(tmpdir: Path, monkeypatch, pricing_file: Path): assert files == [] +def test_base_agent_provisions_sqlite_store(tmpdir: Path): + agent = Agent(llm=TinyCountingModel(), workspace=tmpdir) + + store = agent.storage + store.put(("tests",), "key", {"value": "ok"}) + + item = store.get(("tests",), "key") + assert item is not None + assert item.value["value"] == "ok" + + if hasattr(store, "conn"): + store.conn.close() + + async def test_chat_interface(tmpdir: Path): agent = Agent( llm=TinyCountingModel(), diff --git a/tests/agents/test_execution_agent/test_execution_agent.py b/tests/agents/test_execution_agent/test_execution_agent.py index 7941bfa..f48cbf6 100644 --- a/tests/agents/test_execution_agent/test_execution_agent.py +++ b/tests/agents/test_execution_agent/test_execution_agent.py @@ -1,11 +1,12 @@ from math import sqrt from pathlib import Path -from typing import Annotated +from typing import Iterator import pytest from langchain.tools import ToolRuntime, tool +from langchain_core.language_models.fake_chat_models import GenericFakeChatModel from langchain_core.messages import AIMessage, HumanMessage -from langgraph.prebuilt import InjectedState +from langgraph.runtime import Runtime from ursa.agents import ExecutionAgent @@ -15,9 +16,7 @@ def stub_execution_tools(monkeypatch): """Replace external tools with lightweight stubs for deterministic testing.""" @tool - def fake_run_command( - query: str, state: Annotated[dict, InjectedState] - ) -> str: + def fake_run_command(query: str, runtime: ToolRuntime) -> str: """Return a placeholder response instead of executing shell commands.""" return "STDOUT:\nstubbed output\nSTDERR:\n" @@ -65,6 +64,21 @@ def fake_run_osti_search( ) +class ToolReadyFakeChatModel(GenericFakeChatModel): + def bind_tools(self, tools, **kwargs): + return self + + +def _message_stream(content: str) -> Iterator[AIMessage]: + while True: + yield AIMessage(content=content) + + +@pytest.fixture +def chat_model(): + return ToolReadyFakeChatModel(messages=_message_stream("ok")) + + @pytest.mark.asyncio async def test_execution_agent_ainvoke_returns_ai_message( chat_model, tmpdir: Path @@ -130,3 +144,67 @@ def do_magic(a: int, b: int) -> float: assert ai_messages assert isinstance(result["messages"][-1], AIMessage) assert Path(result["workspace"]).exists() + + +def test_write_code_edits_are_considered_in_safety_check( + tmpdir: Path, monkeypatch +): + execution_agent = ExecutionAgent( + llm=ToolReadyFakeChatModel(messages=_message_stream("[YES] allowed")), + workspace=tmpdir, + ) + + runtime = Runtime( + context=execution_agent.context, store=execution_agent.storage + ) + + write_call = AIMessage( + content="", + tool_calls=[ + { + "id": "write-1", + "name": "write_code", + "args": {"code": "print('tracked')", "filename": "tracked.py"}, + "type": "tool_call", + } + ], + ) + + execution_agent.tool_node.invoke( + {"messages": [write_call]}, runtime=runtime + ) + + captured = {} + + def fake_prompt(query, safe_codes, edited_files): + captured["files"] = edited_files + return "prompt" + + monkeypatch.setattr(execution_agent, "get_safety_prompt", fake_prompt) + + command_call = AIMessage( + content="", + tool_calls=[ + { + "id": "run-1", + "name": "run_command", + "args": {"query": "ls"}, + "type": "tool_call", + } + ], + ) + state = { + "messages": [HumanMessage(content="start"), command_call], + "current_progress": "", + "workspace": execution_agent.workspace, + "symlinkdir": {}, + "model": execution_agent.llm, + } + + result = execution_agent.safety_check(state, runtime) + + assert captured["files"], "safety prompt receives recorded edits" + assert all(isinstance(entry, str) for entry in captured["files"]) + assert "tracked.py" in captured["files"] + # safe command should leave messages unchanged + assert result["messages"] == state["messages"] diff --git a/tests/cli/test_hitl.py b/tests/cli/test_hitl.py index 2ceb53a..4c9bee3 100644 --- a/tests/cli/test_hitl.py +++ b/tests/cli/test_hitl.py @@ -17,6 +17,34 @@ from ursa.cli.hitl import HITL, UrsaRepl +@pytest.fixture(autouse=True) +def stub_duckduckgo(monkeypatch): + class DummyDDGS: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def text(self, *args, **kwargs): + yield { + "href": "https://example.com", + "title": "Example Result", + "body": "Example summary", + } + + monkeypatch.setattr( + "ursa.agents.acquisition_agents.DDGS", + lambda: DummyDDGS(), + raising=False, + ) + monkeypatch.setattr( + "ursa.agents.hypothesizer_agent.DDGS", + lambda: DummyDDGS(), + raising=False, + ) + + @pytest.fixture(scope="function") def ursa_config(tmpdir, chat_model, embedding_model): config = UrsaConfig( diff --git a/tests/tools/__init__.py b/tests/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/tools/test_read_file_tool.py b/tests/tools/test_read_file_tool.py new file mode 100644 index 0000000..8ad312f --- /dev/null +++ b/tests/tools/test_read_file_tool.py @@ -0,0 +1,40 @@ +from pathlib import Path + +from tests.tools.utils import make_runtime +from ursa.tools.read_file_tool import read_file + + +def test_read_file_reads_text_from_workspace(tmp_path: Path): + target = tmp_path / "example.txt" + target.write_text("sample text", encoding="utf-8") + + result = read_file.func( + filename=str(target.name), + runtime=make_runtime(tmp_path, tool_call_id="read-file-call"), + ) + + assert result == "sample text" + + +def test_read_file_uses_pdf_reader(monkeypatch, tmp_path: Path): + called = {} + + def fake_pdf_reader(path: Path) -> str: + called["path"] = path + return "pdf contents" + + def fail_text_reader(path: Path) -> str: + raise AssertionError("read_text_file should not be called for PDFs") + + monkeypatch.setattr( + "ursa.tools.read_file_tool.read_pdf_text", fake_pdf_reader + ) + monkeypatch.setattr( + "ursa.tools.read_file_tool.read_text_file", fail_text_reader + ) + + runtime = make_runtime(tmp_path, tool_call_id="pdf-call") + result = read_file.func(filename="report.pdf", runtime=runtime) + + assert result == "pdf contents" + assert called["path"] == tmp_path / "report.pdf" diff --git a/tests/tools/test_run_command_tool.py b/tests/tools/test_run_command_tool.py new file mode 100644 index 0000000..c5250fc --- /dev/null +++ b/tests/tools/test_run_command_tool.py @@ -0,0 +1,100 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest +from pydantic import ValidationError + +from tests.tools.utils import make_runtime +from ursa.tools.run_command_tool import run_command +from ursa.util.types import AsciiStr + + +def test_run_command_invokes_subprocess_in_workspace( + monkeypatch, tmp_path: Path +): + recorded = {} + + def fake_run(*args, **kwargs): + recorded["args"] = args + recorded["kwargs"] = kwargs + return SimpleNamespace(stdout="output", stderr="") + + monkeypatch.setattr("ursa.tools.run_command_tool.subprocess.run", fake_run) + + result = run_command.func( + "echo hi", + runtime=make_runtime( + tmp_path, thread_id="run-thread", tool_call_id="run-call" + ), + ) + + assert result == "STDOUT:\noutput\nSTDERR:\n" + assert recorded["kwargs"]["cwd"] == tmp_path + assert recorded["kwargs"]["shell"] is True + + +def test_run_command_truncates_output(monkeypatch, tmp_path: Path): + long_stdout = "a" * 200 + long_stderr = "b" * 200 + + monkeypatch.setattr( + "ursa.tools.run_command_tool.subprocess.run", + lambda *args, **kwargs: SimpleNamespace( + stdout=long_stdout, stderr=long_stderr + ), + ) + + result = run_command.func( + "noop", + runtime=make_runtime( + tmp_path, limit=64, tool_call_id="truncate", thread_id="run-thread" + ), + ) + + stdout_part, stderr_part = result.split("STDERR:\n", maxsplit=1) + stdout_body = stdout_part.replace("STDOUT:\n", "", 1).rstrip("\n") + stderr_body = stderr_part + + assert "... [snipped" in stdout_body + assert "... [snipped" in stderr_body + assert len(stdout_body) < len(long_stdout) + assert len(stderr_body) < len(long_stderr) + + +def test_run_command_handles_keyboard_interrupt(monkeypatch, tmp_path: Path): + def raise_interrupt(*args, **kwargs): + raise KeyboardInterrupt() + + monkeypatch.setattr( + "ursa.tools.run_command_tool.subprocess.run", raise_interrupt + ) + + result = run_command.func( + "sleep 1", + runtime=make_runtime( + tmp_path, tool_call_id="interrupt", thread_id="run-thread" + ), + ) + + assert "KeyboardInterrupt:" in result + + +def test_run_command_rejects_unicode_input(tmp_path: Path): + runtime = make_runtime( + tmp_path, thread_id="run-thread", tool_call_id="unicode" + ) + + with pytest.raises(ValidationError): + run_command.invoke({"query": "ls café", "runtime": runtime}) + + +def test_run_command_schema_has_regex_constraint(): + field = run_command.args_schema.model_fields["query"] + assert field.annotation is str + constraints = [meta for meta in field.metadata if hasattr(meta, "pattern")] + assert constraints + ascii_constraints = [ + meta for meta in AsciiStr.__metadata__ if hasattr(meta, "pattern") + ] + assert ascii_constraints + assert constraints[0].pattern == ascii_constraints[0].pattern diff --git a/tests/tools/test_write_code_tool.py b/tests/tools/test_write_code_tool.py new file mode 100644 index 0000000..da90390 --- /dev/null +++ b/tests/tools/test_write_code_tool.py @@ -0,0 +1,97 @@ +import time +from pathlib import Path + +from langgraph.store.memory import InMemoryStore + +from tests.tools.utils import make_runtime +from ursa.tools.write_code_tool import edit_code, write_code + + +def test_write_code_strips_fences_and_writes(tmp_path: Path): + runtime = make_runtime( + tmp_path, thread_id="thread-1", tool_call_id="write-call" + ) + + fenced = """```python +print(\"hello\") +```""" + + result = write_code.func(code=fenced, filename="hello.py", runtime=runtime) + + target = tmp_path / "hello.py" + assert target.exists() + assert target.read_text(encoding="utf-8") == 'print("hello")\n' + assert "written successfully" in result + + +def test_write_code_records_store_entry(tmp_path: Path): + store = InMemoryStore() + runtime = make_runtime( + tmp_path, store=store, tool_call_id="tc-1", thread_id="thread-1" + ) + + write_code.func(code="print(42)", filename="sample.py", runtime=runtime) + + item = store.get(("workspace", "file_edit"), "sample.py") + assert item is not None + assert item.value["tool_call_id"] == "tc-1" + assert item.value["thread_id"] == "thread-1" + assert item.value["modified"] <= time.time() + + +def test_edit_code_updates_file_and_records(tmp_path: Path): + target = tmp_path / "app.py" + target.write_text("print('hello')\nprint('hello')\n", encoding="utf-8") + store = InMemoryStore() + runtime = make_runtime( + tmp_path, store=store, tool_call_id="tc-edit", thread_id="thread-7" + ) + + result = edit_code.func( + old_code="print('hello')", + new_code="print('bye')", + filename="app.py", + runtime=runtime, + ) + + assert "updated successfully" in result + assert ( + target.read_text(encoding="utf-8") == "print('bye')\nprint('hello')\n" + ) + item = store.get(("workspace", "file_edit"), "app.py") + assert item is not None + assert item.value["tool_call_id"] == "tc-edit" + assert item.value["thread_id"] == "thread-7" + + +def test_edit_code_noop_when_old_code_missing(tmp_path: Path): + target = tmp_path / "script.py" + target.write_text("print('hello')\n", encoding="utf-8") + store = InMemoryStore() + runtime = make_runtime(tmp_path, store=store, tool_call_id="tc-miss") + + result = edit_code.func( + old_code="print('world')", + new_code="print('bye')", + filename="script.py", + runtime=runtime, + ) + + assert "No changes made" in result + assert target.read_text(encoding="utf-8") == "print('hello')\n" + assert store.get(("workspace", "file_edit"), "script.py") is None + + +def test_edit_code_missing_file(tmp_path: Path): + store = InMemoryStore() + runtime = make_runtime(tmp_path, store=store, tool_call_id="tc-missing") + + result = edit_code.func( + old_code="print('hello')", + new_code="print('bye')", + filename="missing.py", + runtime=runtime, + ) + + assert "Failed: missing.py not found" in result + assert store.get(("workspace", "file_edit"), "missing.py") is None diff --git a/tests/tools/utils.py b/tests/tools/utils.py new file mode 100644 index 0000000..883c037 --- /dev/null +++ b/tests/tools/utils.py @@ -0,0 +1,26 @@ +from pathlib import Path +from typing import Optional + +from langchain.tools import ToolRuntime +from langgraph.store.base import BaseStore + +from ursa.agents.base import AgentContext + + +def make_runtime( + workspace: Path, + *, + tool_call_id: str = "tool-call", + thread_id: str = "thread", + limit: int = 3000, + store: Optional[BaseStore] = None, +) -> ToolRuntime[AgentContext]: + """Construct a minimal ToolRuntime populated with AgentContext.""" + return ToolRuntime( + state={}, + context=AgentContext(workspace=workspace, tool_character_limit=limit), + config={"metadata": {"thread_id": thread_id}}, + stream_writer=lambda _: None, + tool_call_id=tool_call_id, + store=store, + )