From ec4c9cb805b0f38aea64c2f0154739fecd803be4 Mon Sep 17 00:00:00 2001 From: Ahmed Hassanin Date: Sat, 1 Nov 2025 20:50:28 -0400 Subject: [PATCH] fix: Add Windows support to python_repl tool (#15) The python_repl tool was POSIX-only due to PTY dependencies (pty, fcntl, termios). Added cross-platform SubprocessExecutor class for Windows while preserving existing PtyManager for Unix systems. No breaking changes. Changes: - Added platform detection (IS_WINDOWS, IS_POSIX) - Created SubprocessExecutor using subprocess.Popen for Windows - Conditional PtyManager import for POSIX systems only - Smart execution routing based on platform - Enhanced docs with cross-platform support info Closes #15 --- src/strands_tools/python_repl.py | 636 ++++++++++++++++++------------- 1 file changed, 381 insertions(+), 255 deletions(-) diff --git a/src/strands_tools/python_repl.py b/src/strands_tools/python_repl.py index ddd99775..c9a65cc5 100644 --- a/src/strands_tools/python_repl.py +++ b/src/strands_tools/python_repl.py @@ -3,11 +3,12 @@ This module provides a tool for running Python code through a Strands Agent, with features like: - Persistent state between executions -- Interactive PTY support for real-time feedback +- Interactive PTY support for real-time feedback (Unix) or subprocess (Windows) - Output capturing and formatting - Error handling and logging - State reset capabilities - User confirmation for code execution +- Cross-platform support (Windows, Linux, macOS) Usage with Strands Agent: ```python @@ -32,16 +33,13 @@ ``` """ -import fcntl import logging import os -import pty +import platform import re -import select -import signal -import struct +import subprocess import sys -import termios +import tempfile import threading import traceback import types @@ -60,23 +58,36 @@ from strands_tools.utils import console_util from strands_tools.utils.user_input import get_user_input +# Platform-specific imports +IS_WINDOWS = platform.system() == "Windows" +IS_POSIX = not IS_WINDOWS + +if IS_POSIX: + import fcntl + import pty + import select + import signal + import struct + import termios + # Initialize logging and set paths logger = logging.getLogger(__name__) # Tool specification TOOL_SPEC = { "name": "python_repl", - "description": "Execute Python code in a REPL environment with interactive PTY support and state persistence.\n\n" + "description": "Execute Python code in a REPL environment with interactive support and state persistence.\n\n" "IMPORTANT SAFETY FEATURES:\n" "1. User Confirmation: Requires explicit approval before executing code\n" "2. Code Preview: Shows syntax-highlighted code before execution\n" "3. State Management: Maintains variables between executions, default controlled by PYTHON_REPL_RESET_STATE\n" "4. Error Handling: Captures and formats errors with suggestions\n" "5. Development Mode: Can bypass confirmation in BYPASS_TOOL_CONSENT environments\n" - "6. Interactive Control: Can enable/disable interactive PTY mode in PYTHON_REPL_INTERACTIVE environments\n\n" + "6. Interactive Control: Can enable/disable interactive mode in PYTHON_REPL_INTERACTIVE environments\n" + "7. Cross-Platform: Works on Windows, Linux, and macOS\n\n" "Key Features:\n" "- Persistent state between executions\n" - "- Interactive PTY support for real-time feedback\n" + "- Interactive support for real-time feedback\n" "- Output capturing and formatting\n" "- Error handling and logging\n" "- State reset capabilities\n\n" @@ -92,7 +103,7 @@ "interactive": { "type": "boolean", "description": ( - "Whether to enable interactive PTY mode. " + "Whether to enable interactive mode. " "Default controlled by PYTHON_REPL_INTERACTIVE environment variable." ), "default": True, @@ -257,213 +268,123 @@ def clean_ansi(text: str) -> str: return ansi_escape.sub("", text) -class PtyManager: - """Manages PTY-based Python execution with state synchronization.""" +class SubprocessExecutor: + """Cross-platform subprocess-based Python execution with state synchronization.""" def __init__(self, callback: Optional[Callable] = None): - self.supervisor_fd = -1 - self.worker_fd = -1 - self.pid = -1 self.output_buffer: List[str] = [] - self.input_buffer: List[str] = [] - self.stop_event = threading.Event() self.callback = callback + self.process: Optional[subprocess.Popen] = None - def start(self, code: str) -> None: - """Start PTY session with code execution.""" - # Create PTY - self.supervisor_fd, self.worker_fd = pty.openpty() + def start(self, code: str) -> int: + """Start subprocess session with code execution.""" + # Create a temporary script that loads state, executes code, and saves state + script_code = f""" +import sys +import os +import dill +from pathlib import Path - # Set terminal size - term_size = struct.pack("HHHH", 24, 80, 0, 0) - fcntl.ioctl(self.worker_fd, termios.TIOCSWINSZ, term_size) +# Load state +persistence_dir = os.path.join(Path.cwd(), "repl_state") +state_file = os.path.join(persistence_dir, "repl_state.pkl") - # Fork process - self.pid = os.fork() +namespace = {{"__name__": "__main__"}} - if self.pid == 0: # Child process +if os.path.exists(state_file): + try: + with open(state_file, "rb") as f: + saved_state = dill.load(f) + namespace.update(saved_state) + except Exception: + pass + +# Execute user code +try: + exec(''' +{code} +''', namespace) + + # Save state + save_dict = {{}} + for name, value in namespace.items(): + if not name.startswith("_"): try: - # Setup PTY - os.close(self.supervisor_fd) - os.dup2(self.worker_fd, 0) - os.dup2(self.worker_fd, 1) - os.dup2(self.worker_fd, 2) - - # Execute in REPL namespace - namespace = repl_state.get_namespace() - exec(code, namespace) + dill.dumps(value) + save_dict[name] = value + except: + continue + + with open(state_file, "wb") as f: + dill.dump(save_dict, f) + + sys.exit(0) +except Exception as e: + import traceback + traceback.print_exc() + sys.exit(1) +""" - os._exit(0) + # Write script to temporary file + with tempfile.NamedTemporaryFile(mode='w', suffix='.py', delete=False, encoding='utf-8') as f: + script_path = f.name + f.write(script_code) - except Exception: - traceback.print_exc(file=sys.stderr) - os._exit(1) + try: + # Execute the script + self.process = subprocess.Popen( + [sys.executable, script_path], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + stdin=subprocess.PIPE, + text=True, + bufsize=1, + encoding='utf-8', + errors='replace' + ) - else: # Parent process - os.close(self.worker_fd) + # Read output in separate thread + output_thread = threading.Thread(target=self._read_output) + output_thread.daemon = True + output_thread.start() - # Start output reader - reader = threading.Thread(target=self._read_output) - reader.daemon = True - reader.start() + # Wait for completion + exit_code = self.process.wait() - # Start input handler - input_handler = threading.Thread(target=self._handle_input) - input_handler.daemon = True - input_handler.start() + # Wait for output thread to finish + output_thread.join(timeout=1.0) - def _read_output(self) -> None: - """Read and process PTY output with improved error handling and file descriptor management.""" - buffer = "" - incomplete_bytes = b"" # Buffer for incomplete UTF-8 sequences + return exit_code - while not self.stop_event.is_set(): + finally: + # Clean up temporary file try: - # Check if file descriptor is still valid - if self.supervisor_fd < 0: - logger.debug("Invalid file descriptor, stopping output reader") - break - - # Use select with timeout to avoid blocking - try: - r, _, _ = select.select([self.supervisor_fd], [], [], 0.1) - except (OSError, ValueError) as e: - # File descriptor became invalid during select - logger.debug(f"File descriptor error during select: {e}") - break - - if self.supervisor_fd in r: - try: - raw_data = os.read(self.supervisor_fd, 1024) - except (OSError, ValueError) as e: - # Handle closed file descriptor or other OS errors - if e.errno == 9: # Bad file descriptor - logger.debug("PTY closed, stopping output reader") - else: - logger.warning(f"Error reading from PTY: {e}") - break - - if not raw_data: - # EOF reached, PTY closed - logger.debug("EOF reached, PTY closed") - break + os.unlink(script_path) + except Exception: + pass - # Combine with any incomplete bytes from previous read - full_data = incomplete_bytes + raw_data + def _read_output(self) -> None: + """Read and process subprocess output.""" + if not self.process or not self.process.stdout: + return - try: - # Try to decode the data - data = full_data.decode("utf-8") - incomplete_bytes = b"" # Clear incomplete buffer on success - - except UnicodeDecodeError as e: - # Handle incomplete UTF-8 sequence at the end - if e.start > 0: - # We can decode part of the data - data = full_data[: e.start].decode("utf-8") - incomplete_bytes = full_data[e.start :] - else: - # Can't decode anything, save for next iteration - incomplete_bytes = full_data - continue - - if data: - # Append to buffer - buffer += data - - # Process complete lines - while "\n" in buffer: - line, buffer = buffer.split("\n", 1) - # Clean and store output - cleaned = clean_ansi(line + "\n") - self.output_buffer.append(cleaned) - - # Stream if callback exists - if self.callback: - try: - self.callback(cleaned) - except Exception as callback_error: - logger.warning(f"Error in output callback: {callback_error}") - - # Handle remaining buffer (usually prompts) - if buffer: - cleaned = clean_ansi(buffer) - if self.callback: - try: - self.callback(cleaned) - except Exception as callback_error: - logger.warning(f"Error in output callback: {callback_error}") - - except (OSError, IOError) as e: - # Handle file descriptor errors gracefully - if hasattr(e, "errno") and e.errno == 9: # Bad file descriptor - logger.debug("PTY file descriptor closed, stopping reader") + try: + for line in iter(self.process.stdout.readline, ''): + if not line: break - else: - logger.warning(f"I/O error reading PTY output: {e}") - # Don't break immediately, try to continue - continue - - except UnicodeDecodeError as e: - # This shouldn't happen anymore with our improved handling, but just in case - logger.warning(f"Unicode decode error: {e}") - incomplete_bytes = b"" - continue - - except Exception as e: - # Catch any other unexpected errors - logger.error(f"Unexpected error in _read_output: {e}") - break - - # Clean shutdown - handle any remaining buffer - if buffer: - try: - cleaned = clean_ansi(buffer) + + cleaned = clean_ansi(line) self.output_buffer.append(cleaned) - if self.callback: - self.callback(cleaned) - except Exception as e: - logger.warning(f"Error processing final buffer: {e}") - # Handle any remaining incomplete bytes at shutdown - if incomplete_bytes: - try: - # Try to decode with error handling - final_data = incomplete_bytes.decode("utf-8", errors="replace") - if final_data: - cleaned = clean_ansi(final_data) - self.output_buffer.append(cleaned) - if self.callback: + # Stream if callback exists + if self.callback: + try: self.callback(cleaned) - except Exception as e: - logger.warning(f"Failed to process remaining bytes at shutdown: {e}") - - logger.debug("PTY output reader thread finished") - - def _handle_input(self) -> None: - """Handle interactive user input with improved buffering.""" - while not self.stop_event.is_set(): - try: - r, _, _ = select.select([sys.stdin], [], [], 0.1) - if sys.stdin in r: - # Read all available input - input_data = "" - while True: - char = sys.stdin.read(1) - if not char or char == "\n": - input_data += "\n" - break - input_data += char + except Exception as callback_error: + logger.warning(f"Error in output callback: {callback_error}") - if input_data: - # Only store input once - if input_data not in self.input_buffer: - self.input_buffer.append(input_data) - # Send to PTY with proper line ending - os.write(self.supervisor_fd, input_data.encode()) - - except (OSError, IOError): - break + except Exception as e: + logger.warning(f"Error reading subprocess output: {e}") def get_output(self) -> str: """Get complete output with ANSI codes removed and binary content truncated.""" @@ -481,59 +402,253 @@ def format_binary(text: str, max_len: int = None) -> str: return format_binary(clean) def stop(self) -> None: - """Stop PTY session and clean up resources properly.""" - logger.debug("Stopping PTY session...") + """Stop subprocess and clean up resources.""" + if self.process: + try: + if self.process.poll() is None: # Process still running + self.process.terminate() + try: + self.process.wait(timeout=1.0) + except subprocess.TimeoutExpired: + self.process.kill() + self.process.wait() + except Exception as e: + logger.debug(f"Error stopping subprocess: {e}") - # Signal threads to stop - self.stop_event.set() - # Clean up child process - if self.pid > 0: - try: - # Try graceful termination first - os.kill(self.pid, signal.SIGTERM) +if IS_POSIX: + class PtyManager: + """Manages PTY-based Python execution with state synchronization (Unix only).""" - # Wait briefly for graceful shutdown + def __init__(self, callback: Optional[Callable] = None): + self.supervisor_fd = -1 + self.worker_fd = -1 + self.pid = -1 + self.output_buffer: List[str] = [] + self.input_buffer: List[str] = [] + self.stop_event = threading.Event() + self.callback = callback + + def start(self, code: str) -> None: + """Start PTY session with code execution.""" + # Create PTY + self.supervisor_fd, self.worker_fd = pty.openpty() + + # Set terminal size + term_size = struct.pack("HHHH", 24, 80, 0, 0) + fcntl.ioctl(self.worker_fd, termios.TIOCSWINSZ, term_size) + + # Fork process + self.pid = os.fork() + + if self.pid == 0: # Child process try: - pid, status = os.waitpid(self.pid, os.WNOHANG) - if pid == 0: # Process still running - # Give it a moment - import time + # Setup PTY + os.close(self.supervisor_fd) + os.dup2(self.worker_fd, 0) + os.dup2(self.worker_fd, 1) + os.dup2(self.worker_fd, 2) - time.sleep(0.1) - # Try again - pid, status = os.waitpid(self.pid, os.WNOHANG) - if pid == 0: - # Force kill if still running - logger.debug("Forcing process termination") - os.kill(self.pid, signal.SIGKILL) - os.waitpid(self.pid, 0) + # Execute in REPL namespace + namespace = repl_state.get_namespace() + exec(code, namespace) - except OSError as e: - # Process might have already exited - logger.debug(f"Process cleanup error (likely already exited): {e}") + os._exit(0) - except (OSError, ProcessLookupError) as e: - # Process doesn't exist or already terminated - logger.debug(f"Process termination error (likely already gone): {e}") + except Exception: + traceback.print_exc(file=sys.stderr) + os._exit(1) - finally: - self.pid = -1 + else: # Parent process + os.close(self.worker_fd) - # Clean up file descriptor - if self.supervisor_fd >= 0: - try: - os.close(self.supervisor_fd) - logger.debug("PTY supervisor file descriptor closed") - except OSError as e: - logger.debug(f"Error closing supervisor fd: {e}") - finally: - self.supervisor_fd = -1 + # Start output reader + reader = threading.Thread(target=self._read_output) + reader.daemon = True + reader.start() + + # Start input handler + input_handler = threading.Thread(target=self._handle_input) + input_handler.daemon = True + input_handler.start() + + def _read_output(self) -> None: + """Read and process PTY output with improved error handling.""" + buffer = "" + incomplete_bytes = b"" + + while not self.stop_event.is_set(): + try: + if self.supervisor_fd < 0: + logger.debug("Invalid file descriptor, stopping output reader") + break + + try: + r, _, _ = select.select([self.supervisor_fd], [], [], 0.1) + except (OSError, ValueError) as e: + logger.debug(f"File descriptor error during select: {e}") + break + + if self.supervisor_fd in r: + try: + raw_data = os.read(self.supervisor_fd, 1024) + except (OSError, ValueError) as e: + if hasattr(e, 'errno') and e.errno == 9: + logger.debug("PTY closed, stopping output reader") + else: + logger.warning(f"Error reading from PTY: {e}") + break + + if not raw_data: + logger.debug("EOF reached, PTY closed") + break + + full_data = incomplete_bytes + raw_data + + try: + data = full_data.decode("utf-8") + incomplete_bytes = b"" + + except UnicodeDecodeError as e: + if e.start > 0: + data = full_data[: e.start].decode("utf-8") + incomplete_bytes = full_data[e.start :] + else: + incomplete_bytes = full_data + continue + + if data: + buffer += data + + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + cleaned = clean_ansi(line + "\n") + self.output_buffer.append(cleaned) + + if self.callback: + try: + self.callback(cleaned) + except Exception as callback_error: + logger.warning(f"Error in output callback: {callback_error}") + + if buffer: + cleaned = clean_ansi(buffer) + if self.callback: + try: + self.callback(cleaned) + except Exception as callback_error: + logger.warning(f"Error in output callback: {callback_error}") + + except (OSError, IOError) as e: + if hasattr(e, "errno") and e.errno == 9: + logger.debug("PTY file descriptor closed, stopping reader") + break + else: + logger.warning(f"I/O error reading PTY output: {e}") + continue + + except Exception as e: + logger.error(f"Unexpected error in _read_output: {e}") + break - logger.debug("PTY session cleanup completed") + # Handle remaining buffer + if buffer: + try: + cleaned = clean_ansi(buffer) + self.output_buffer.append(cleaned) + if self.callback: + self.callback(cleaned) + except Exception as e: + logger.warning(f"Error processing final buffer: {e}") + + if incomplete_bytes: + try: + final_data = incomplete_bytes.decode("utf-8", errors="replace") + if final_data: + cleaned = clean_ansi(final_data) + self.output_buffer.append(cleaned) + if self.callback: + self.callback(cleaned) + except Exception as e: + logger.warning(f"Failed to process remaining bytes: {e}") + + logger.debug("PTY output reader thread finished") + + def _handle_input(self) -> None: + """Handle interactive user input.""" + while not self.stop_event.is_set(): + try: + r, _, _ = select.select([sys.stdin], [], [], 0.1) + if sys.stdin in r: + input_data = "" + while True: + char = sys.stdin.read(1) + if not char or char == "\n": + input_data += "\n" + break + input_data += char + + if input_data: + if input_data not in self.input_buffer: + self.input_buffer.append(input_data) + os.write(self.supervisor_fd, input_data.encode()) + + except (OSError, IOError): + break + + def get_output(self) -> str: + """Get complete output with ANSI codes removed.""" + raw = "".join(self.output_buffer) + clean = clean_ansi(raw) + + def format_binary(text: str, max_len: int = None) -> str: + if max_len is None: + max_len = int(os.environ.get("PYTHON_REPL_BINARY_MAX_LEN", "100")) + if "\\x" in text and len(text) > max_len: + return f"{text[:max_len]}... [binary content truncated]" + return text + + return format_binary(clean) + + def stop(self) -> None: + """Stop PTY session and clean up resources.""" + logger.debug("Stopping PTY session...") + self.stop_event.set() + + if self.pid > 0: + try: + os.kill(self.pid, signal.SIGTERM) + try: + pid, status = os.waitpid(self.pid, os.WNOHANG) + if pid == 0: + import time + time.sleep(0.1) + pid, status = os.waitpid(self.pid, os.WNOHANG) + if pid == 0: + logger.debug("Forcing process termination") + os.kill(self.pid, signal.SIGKILL) + os.waitpid(self.pid, 0) -output_buffer: List[str] = [] + except OSError as e: + logger.debug(f"Process cleanup error: {e}") + + except (OSError, ProcessLookupError) as e: + logger.debug(f"Process termination error: {e}") + + finally: + self.pid = -1 + + if self.supervisor_fd >= 0: + try: + os.close(self.supervisor_fd) + logger.debug("PTY supervisor file descriptor closed") + except OSError as e: + logger.debug(f"Error closing supervisor fd: {e}") + finally: + self.supervisor_fd = -1 + + logger.debug("PTY session cleanup completed") def python_repl(tool: ToolUse, **kwargs: Any) -> ToolResult: @@ -568,21 +683,18 @@ def python_repl(tool: ToolUse, **kwargs: Any) -> ToolResult: ) ) - # Add permissions check - only show confirmation dialog if not - # in BYPASS_TOOL_CONSENT mode and not in non_interactive mode + # Add permissions check if not strands_dev and not non_interactive_mode: - # Create a table with code details for better visualization details_table = Table(show_header=False, box=box.SIMPLE) details_table.add_column("Property", style="cyan", justify="right") details_table.add_column("Value", style="green") - # Add code details details_table.add_row("Code Length", f"{len(code)} characters") details_table.add_row("Line Count", f"{len(code.splitlines())} lines") details_table.add_row("Mode", "Interactive" if interactive else "Standard") details_table.add_row("Reset State", "Yes" if reset_state else "No") + details_table.add_row("Platform", platform.system()) - # Show confirmation panel console.print( Panel( details_table, @@ -591,7 +703,7 @@ def python_repl(tool: ToolUse, **kwargs: Any) -> ToolResult: box=box.ROUNDED, ) ) - # Get user confirmation + user_input = get_user_input( "Do you want to proceed with Python code execution? [y/*]" ) @@ -618,15 +730,39 @@ def python_repl(tool: ToolUse, **kwargs: Any) -> ToolResult: # Track execution time and capture output start_time = datetime.now() output = None + exit_status = 0 try: - if interactive: - console.print("[green]Running in interactive mode...[/]") + # On Windows or when PTY is not available, use subprocess + if IS_WINDOWS or not interactive: + if interactive: + console.print("[green]Running in interactive mode (subprocess)...[/]") + else: + console.print("[blue]Running in standard mode...[/]") + + if interactive: + # Use subprocess executor for interactive mode on Windows + executor = SubprocessExecutor() + exit_status = executor.start(code) + output = executor.get_output() + executor.stop() + else: + # Use direct execution for non-interactive mode + captured = OutputCapture() + with captured as output_capture: + repl_state.execute(code) + output = output_capture.get_output() + if output: + console.print("[cyan]Output:[/]") + console.print(output) + + # On Unix systems, use PTY for better interactive support + elif IS_POSIX and interactive: + console.print("[green]Running in interactive mode (PTY)...[/]") pty_mgr = PtyManager() pty_mgr.start(code) # Wait for completion - exit_status = None # Initialize exit_status variable while True: try: pid, exit_status = os.waitpid(pty_mgr.pid, os.WNOHANG) @@ -642,15 +778,6 @@ def python_repl(tool: ToolUse, **kwargs: Any) -> ToolResult: # Save state if execution succeeded if exit_status == 0: repl_state.save_state(code) - else: - console.print("[blue]Running in standard mode...[/]") - captured = OutputCapture() - with captured as output_capture: - repl_state.execute(code) - output = output_capture.get_output() - if output: - console.print("[cyan]Output:[/]") - console.print(output) # Show execution stats duration = (datetime.now() - start_time).total_seconds() @@ -673,7 +800,6 @@ def python_repl(tool: ToolUse, **kwargs: Any) -> ToolResult: except RecursionError: console.print("[yellow]Recursion error detected - resetting state...[/]") repl_state.clear_state() - # Re-raise the exception after cleanup raise except Exception as e: @@ -708,4 +834,4 @@ def python_repl(tool: ToolUse, **kwargs: Any) -> ToolResult: "toolUseId": tool_use_id, "status": "error", "content": [{"text": f"{error_msg}{suggestion}"}], - } + } \ No newline at end of file