From b3652cf4bb508cca20094777f936c3e70eeaf579 Mon Sep 17 00:00:00 2001 From: Michael Neale Date: Tue, 10 Sep 2024 13:13:16 +1000 Subject: [PATCH] fix: goose should track files it reads and not overwrite changes (#46) Co-authored-by: Bradley Axen --- src/goose/toolkit/developer.py | 27 +++++++++++++++++++++++---- tests/toolkit/test_developer.py | 21 +++++++++++++++++++++ 2 files changed, 44 insertions(+), 4 deletions(-) diff --git a/src/goose/toolkit/developer.py b/src/goose/toolkit/developer.py index 062b1b4..450ce24 100644 --- a/src/goose/toolkit/developer.py +++ b/src/goose/toolkit/developer.py @@ -1,6 +1,7 @@ from pathlib import Path from subprocess import CompletedProcess, run -from typing import List +from typing import List, Dict +import os from goose.utils.check_shell_command import is_dangerous_command from exchange import Message @@ -30,6 +31,10 @@ class Developer(Toolkit): We also include some default shell strategies in the prompt, such as using ripgrep """ + def __init__(self, *args: object, **kwargs: Dict[str, object]) -> None: + super().__init__(*args, **kwargs) + self.timestamps: Dict[str, float] = {} + def system(self) -> str: """Retrieve system configuration details for developer""" hints_path = Path(".goosehints") @@ -65,7 +70,7 @@ def update_plan(self, tasks: List[dict]) -> List[dict]: table.add_column("Status", justify="left") # Mapping of statuses to emojis for better visual representation in the table. - emoji = {"planned": "⏳", "complete": "✅", "failed": "❌", "in-progress": "🕓"} + emoji = {"planned": "⏳", "complete": "✅", "failed": "❌", "in-progress": "🕑"} for i, entry in enumerate(tasks): table.add_row(str(i), entry["description"], emoji[entry["status"]]) @@ -124,6 +129,8 @@ def read_file(self, path: str) -> str: language = get_language(path) content = Path(path).expanduser().read_text() self.notifier.log(Panel.fit(Markdown(f"```\ncat {path}\n```"), box=box.MINIMAL)) + # Record the last read timestamp + self.timestamps[path] = os.path.getmtime(path) return f"```{language}\n{content}\n```" @tool @@ -183,12 +190,24 @@ def write_file(self, path: str, content: str) -> str: # this method is dynamically attached to functions in the Goose framework self.notifier.log(Panel.fit(Markdown(md), title=path)) - # Prepare the path and create any necessary parent directories _path = Path(path) + if path in self.timestamps: + last_read_timestamp = self.timestamps.get(path, 0.0) + current_timestamp = os.path.getmtime(path) + if current_timestamp > last_read_timestamp: + raise RuntimeError( + f"File '{path}' has been modified since it was last read." + + " Read the file to incorporate changes or update your plan." + ) + + # Prepare the path and create any necessary parent directories _path.parent.mkdir(parents=True, exist_ok=True) # Write the content to the file _path.write_text(content) + # Update the last read timestamp after writing to the file + self.timestamps[path] = os.path.getmtime(path) + # Return a success message - return f"Succesfully wrote to {path}" + return f"Successfully wrote to {path}" diff --git a/tests/toolkit/test_developer.py b/tests/toolkit/test_developer.py index a3a291a..e36c498 100644 --- a/tests/toolkit/test_developer.py +++ b/tests/toolkit/test_developer.py @@ -93,3 +93,24 @@ def test_write_file(temp_dir, developer_toolkit): content = "Hello World" developer_toolkit.write_file(test_file.as_posix(), content) assert test_file.read_text() == content + + +def test_write_file_prevent_write_if_changed(temp_dir, developer_toolkit): + test_file = temp_dir / "test.txt" + content = "Hello World" + updated_content = "Hello Universe" + + # Initial write to record the timestamp + developer_toolkit.write_file(test_file.as_posix(), content) + developer_toolkit.read_file(test_file.as_posix()) + + import time + + # Modify file externally to simulate change + time.sleep(1) + test_file.write_text(updated_content) + + # Try to write through toolkit and check for the raised exception + with pytest.raises(RuntimeError, match="has been modified"): + developer_toolkit.write_file(test_file.as_posix(), content) + assert test_file.read_text() == updated_content