Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/mcp_agent/core/enhanced_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,10 +320,12 @@ def pre_process_input(text):
result = await session.prompt_async(HTML(prompt_text), default=default)
return pre_process_input(result)
except KeyboardInterrupt:
# Handle Ctrl+C gracefully
return "STOP"
# Handle Ctrl+C gracefully at the prompt
rich_print("\n[yellow]Input cancelled. Type a command or 'STOP' to exit session.[/yellow]")
return "" # Return empty string to re-prompt
except EOFError:
# Handle Ctrl+D gracefully
rich_print("\n[yellow]EOF received. Type 'STOP' to exit session.[/yellow]")
return "STOP"
except Exception as e:
# Log and gracefully handle other exceptions
Expand Down
77 changes: 63 additions & 14 deletions src/mcp_agent/core/interactive_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
"""

import asyncio
from typing import Awaitable, Callable, Dict, List, Mapping, Optional, Protocol, Union

from mcp.types import Prompt, PromptMessage
Expand All @@ -38,12 +39,20 @@

class PromptProvider(Protocol):
"""Protocol for objects that can provide prompt functionality."""

async def list_prompts(self, server_name: Optional[str] = None, agent_name: Optional[str] = None) -> Mapping[str, List[Prompt]]:

async def list_prompts(
self, server_name: Optional[str] = None, agent_name: Optional[str] = None
) -> Mapping[str, List[Prompt]]:
"""List available prompts."""
...

async def apply_prompt(self, prompt_name: str, arguments: Optional[Dict[str, str]] = None, agent_name: Optional[str] = None, **kwargs) -> str:

async def apply_prompt(
self,
prompt_name: str,
arguments: Optional[Dict[str, str]] = None,
agent_name: Optional[str] = None,
**kwargs,
) -> str:
"""Apply a prompt."""
...

Expand Down Expand Up @@ -160,17 +169,19 @@ async def prompt_loop(
await self._list_prompts(prompt_provider, agent)
else:
# Use the name-based selection
await self._select_prompt(
prompt_provider, agent, prompt_name
)
await self._select_prompt(prompt_provider, agent, prompt_name)
continue

# Skip further processing if:
# 1. The command was handled (command_result is truthy)
# 2. The original input was a dictionary (special command like /prompt)
# 3. The command result itself is a dictionary (special command handling result)
# This fixes the issue where /prompt without arguments gets sent to the LLM
if command_result or isinstance(user_input, dict) or isinstance(command_result, dict):
if (
command_result
or isinstance(user_input, dict)
or isinstance(command_result, dict)
):
continue

if user_input.upper() == "STOP":
Expand All @@ -179,11 +190,45 @@ async def prompt_loop(
continue

# Send the message to the agent
result = await send_func(user_input, agent)
try:
result = await send_func(user_input, agent)
except KeyboardInterrupt:
rich_print("\n[yellow]Request cancelled by user (Ctrl+C).[/yellow]")
result = "" # Ensure result has a benign value for the loop
# Attempt to stop progress display safely
try:
# For rich.progress.Progress, 'progress_display.live.is_started' is a common check
if hasattr(progress_display, "live") and progress_display.live.is_started:
progress_display.stop()
# Fallback for older rich or different progress setup
elif hasattr(progress_display, "is_running") and progress_display.is_running:
progress_display.stop()
else: # If unsure, try stopping directly if stop() is available
if hasattr(progress_display, "stop"):
progress_display.stop()
except Exception:
pass # Continue anyway, don't let progress display crash the cancel
continue
except asyncio.CancelledError:
rich_print("\n[yellow]Request task was cancelled.[/yellow]")
result = ""
try:
if hasattr(progress_display, "live") and progress_display.live.is_started:
progress_display.stop()
elif hasattr(progress_display, "is_running") and progress_display.is_running:
progress_display.stop()
else:
if hasattr(progress_display, "stop"):
progress_display.stop()
except Exception:
pass
continue

return result

async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Optional[str] = None):
async def _get_all_prompts(
self, prompt_provider: PromptProvider, agent_name: Optional[str] = None
):
"""
Get a list of all available prompts.

Expand All @@ -196,8 +241,10 @@ async def _get_all_prompts(self, prompt_provider: PromptProvider, agent_name: Op
"""
try:
# Call list_prompts on the provider
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)

prompt_servers = await prompt_provider.list_prompts(
server_name=None, agent_name=agent_name
)

all_prompts = []

# Process the returned prompt servers
Expand Down Expand Up @@ -326,9 +373,11 @@ async def _select_prompt(
try:
# Get all available prompts directly from the prompt provider
rich_print(f"\n[bold]Fetching prompts for agent [cyan]{agent_name}[/cyan]...[/bold]")

# Call list_prompts on the provider
prompt_servers = await prompt_provider.list_prompts(server_name=None, agent_name=agent_name)
prompt_servers = await prompt_provider.list_prompts(
server_name=None, agent_name=agent_name
)

if not prompt_servers:
rich_print("[yellow]No prompts available for this agent[/yellow]")
Expand Down
Loading