diff --git a/openhands_cli/runner.py b/openhands_cli/runner.py index 7f5028d..e0e15e5 100644 --- a/openhands_cli/runner.py +++ b/openhands_cli/runner.py @@ -16,6 +16,8 @@ from openhands_cli.user_actions import ask_user_confirmation from openhands_cli.user_actions.types import UserConfirmation +import asyncio + class ConversationRunner: """Handles the conversation state machine logic cleanly.""" @@ -188,3 +190,42 @@ def _handle_confirmation_request(self) -> UserConfirmation: # Accept action without changing existing policies assert decision == UserConfirmation.ACCEPT return decision + + # openhands_cli/runner.py + def set_input_manager(self, input_manager): + self.input_manager = input_manager + + async def _step_agent_safe(self): + if hasattr(self.conversation, 'step_async'): + await self.conversation.step_async() + else: + # Roda código bloqueante em outra thread + loop = asyncio.get_running_loop() + await loop.run_in_executor(None, self.conversation.step) + + + async def run_concurrent_loop(self): + input_task = asyncio.create_task(self.input_manager.read_input()) + + while self.conversation.state.execution_status == "RUNNING": + agent_task = asyncio.create_task(self._step_agent_safe()) + + done, _ = await asyncio.wait( + [input_task, agent_task], + return_when=asyncio.FIRST_COMPLETED + ) + + if input_task in done: + cmd = input_task.result() + + if cmd == "/exit": + break + + + self.conversation.send_message(cmd) + + input_task = asyncio.create_task(self.input_manager.read_input()) + + + if agent_task in done: + pass \ No newline at end of file diff --git a/openhands_cli/tui/tui.py b/openhands_cli/tui/tui.py index 7b6da8a..ddb9257 100644 --- a/openhands_cli/tui/tui.py +++ b/openhands_cli/tui/tui.py @@ -1,7 +1,7 @@ from collections.abc import Generator from uuid import UUID -from prompt_toolkit import print_formatted_text +from prompt_toolkit import print_formatted_text, PromptSession from prompt_toolkit.completion import CompleteEvent, Completer, Completion from prompt_toolkit.document import Document from prompt_toolkit.formatted_text import HTML @@ -10,6 +10,8 @@ from openhands_cli.pt_style import get_cli_style + + DEFAULT_STYLE = get_cli_style() # Available commands with descriptions @@ -100,3 +102,15 @@ def display_welcome(conversation_id: UUID, resume: bool = False) -> None: ) ) print() + +class InputManager: + def __init__(self): + self.session = PromptSession(style=get_cli_style()) + + async def read_input(self): + from prompt_toolkit.patch_stdout import patch_stdout + with patch_stdout(): + try: + return await self.session.prompt_async("> ") + except (EOFError, KeyboardInterrupt): + return "/exit" diff --git a/tests/test_concurrency_flow.py b/tests/test_concurrency_flow.py new file mode 100644 index 0000000..90c2d88 --- /dev/null +++ b/tests/test_concurrency_flow.py @@ -0,0 +1,124 @@ +import sys +from unittest.mock import MagicMock, AsyncMock +import pytest +import asyncio + +mock_obj = MagicMock() + + +modules_to_patch = [ + "openhands", + "openhands.sdk", + "openhands.sdk.conversation", + "openhands.sdk.conversation.state", + "openhands.sdk.security", + "openhands.sdk.security.confirmation_policy", + "openhands_cli.listeners", + "openhands_cli.listeners.pause_listener", + "openhands_cli.setup", + "openhands_cli.user_actions", + "openhands_cli.user_actions.types", +] + +for module in modules_to_patch: + sys.modules[module] = mock_obj + + +from openhands_cli.tui.tui import InputManager +from openhands_cli.runner import ConversationRunner + + + +def test_instantiate_input_manager(): + # Act + manager = InputManager() + # Assert + assert manager.session is not None + +@pytest.mark.asyncio +async def test_input_manager_reads_async(): + # Arrange + manager = InputManager() + # Mockamos o prompt para não travar esperando digitação real + manager.session.prompt_async = AsyncMock(return_value="hello") + + # Act + result = await manager.read_input() + + # Assert + assert result == "hello" + +def test_runner_accepts_input_manager(): + # Arrange + runner = ConversationRunner(MagicMock()) + input_mgr = InputManager() + + # Act + runner.set_input_manager(input_mgr) + + # Assert + assert runner.input_manager == input_mgr + + +@pytest.mark.asyncio +async def test_runner_executes_step_in_executor(): + """Ciclo 4: Testa execução segura (fallback para síncrono)""" + # Arrange + mock_conv = MagicMock() + + # TRUQUE: Deletamos explicitamente o atributo 'step_async'. + # Isso obriga o 'hasattr' a retornar False no código, forçando o 'else'. + del mock_conv.step_async + + # Definimos o método síncrono que esperamos que seja chamado + mock_conv.step = MagicMock() + + runner = ConversationRunner(mock_conv) + + # Act + await runner._step_agent_safe() + + # Assert + # Agora sim verificamos se o método síncrono foi chamado + mock_conv.step.assert_called_once() + +@pytest.mark.asyncio +async def test_concurrent_loop_exit(): + # Arrange + runner = ConversationRunner(MagicMock()) + runner.conversation.state.execution_status = "RUNNING" + + mock_input = MagicMock() + mock_input.read_input = AsyncMock(return_value="/exit") + runner.set_input_manager(mock_input) + + # Act + await runner.run_concurrent_loop() + + # Assert + # Se o loop não terminar, o teste trava (timeout) + assert True + +@pytest.mark.asyncio +async def test_input_interrupts_agent(): + + mock_conv = MagicMock() + mock_conv.state.execution_status = "RUNNING" + + + async def slow_agent_step(): + await asyncio.sleep(0.01) + + mock_conv.step_async = AsyncMock(side_effect=slow_agent_step) + + + mock_input = MagicMock() + + mock_input.read_input = AsyncMock(side_effect=["ajuda aqui", "/exit"]) + + runner = ConversationRunner(mock_conv) + runner.set_input_manager(mock_input) + + await runner.run_concurrent_loop() + + mock_conv.send_message.assert_called_with("ajuda aqui") \ No newline at end of file