Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions utils/agent/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,58 @@
from .agent import SpecializedAgent, AgentConfig
from .cli_executor import CLIExecutor, CLIConfig, CLIType
from .system_prompt import SystemPromptBuilder, create_default_system_prompt
from .models import (
ModelFamily,
ModelInfo,
GPT_MODELS,
CLAUDE_MODELS,
GEMINI_MODELS,
ALL_MODELS,
MODEL_BY_NAME,
MODELS_BY_FAMILY,
CLI_SUPPORTED_FAMILIES,
CLI_DEFAULT_MODELS,
get_supported_families,
get_supported_models,
get_supported_model_names,
is_model_supported,
get_model_info,
get_model_family,
get_default_model,
validate_model_for_cli,
get_cli_model_name,
)

__all__ = [
# Agent classes
"SpecializedAgent",
"AgentConfig",
# CLI executor
"CLIExecutor",
"CLIConfig",
"CLIType",
# System prompt
"SystemPromptBuilder",
"create_default_system_prompt",
# Model definitions
"ModelFamily",
"ModelInfo",
"GPT_MODELS",
"CLAUDE_MODELS",
"GEMINI_MODELS",
"ALL_MODELS",
"MODEL_BY_NAME",
"MODELS_BY_FAMILY",
"CLI_SUPPORTED_FAMILIES",
"CLI_DEFAULT_MODELS",
# Model helper functions
"get_supported_families",
"get_supported_models",
"get_supported_model_names",
"is_model_supported",
"get_model_info",
"get_model_family",
"get_default_model",
"validate_model_for_cli",
"get_cli_model_name",
]
128 changes: 105 additions & 23 deletions utils/agent/cli_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,23 @@
Handles execution of different AI CLI tools (Claude, Codex, Copilot).
"""

from __future__ import annotations

import asyncio
import logging
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional, Union
from enum import Enum
from dataclasses import dataclass, field

if TYPE_CHECKING:
from .models import ModelInfo

logger = logging.getLogger(__name__)


class CLIType(str, Enum):
"""Supported CLI types"""

CLAUDE = "claude"
CODEX = "codex"
COPILOT = "copilot"
Expand All @@ -27,12 +33,12 @@ class CLIConfig:
cli_type: CLIType = CLIType.COPILOT
"""CLI type to use: claude, codex, or copilot (default: copilot)"""

model: Optional[str] = None
model: Optional[Union[ModelInfo, str]] = None
"""
Model to use for the agent.
- Claude: sonnet, opus, haiku (default: sonnet)
- Codex: Uses ChatGPT account or API key
- Copilot: Claude Sonnet 4.5, Claude Sonnet 4, GPT-5 (default: Claude Sonnet 4.5)
Model to use for the agent. Can be a ModelInfo object or model name string.
- Claude Code: Only Claude family models (default: claude-opus-4.5)
- Codex: Only GPT family models (default: gpt-5)
- Copilot: GPT, Claude, or Gemini models (default: claude-opus-4.5)
"""

skip_permissions: bool = True
Expand Down Expand Up @@ -75,17 +81,92 @@ def get_cli_path(self) -> str:
return self.cli_path
return self.cli_type.value

def get_default_model(self) -> str:
"""Get default model for the CLI type"""
if self.model:
def get_model_name(self) -> str:
"""
Get the canonical model name string.

Returns the model name from ModelInfo or the string directly.
If no model is set, returns the default model for the CLI type.
"""
if self.model is not None:
# Import here to avoid circular import
from .models import ModelInfo as MI

if isinstance(self.model, MI):
return self.model.name
return self.model

if self.cli_type == CLIType.CLAUDE:
return "sonnet"
elif self.cli_type == CLIType.COPILOT:
return "claude-sonnet-4.5"
else: # CODEX
return "" # Uses account/API key
# Use defaults from models module
from .models import get_default_model

return get_default_model(self.cli_type)

def get_cli_model_name(self) -> str:
"""
Get the CLI-specific model name.

Different CLIs may use different naming conventions. For example,
Claude CLI uses 'sonnet', 'opus', 'haiku' as aliases instead of
full model names like 'claude-sonnet-4.5'.

Returns:
The model name appropriate for the configured CLI type.
"""
from .models import get_cli_model_name as convert_model_name, ModelInfo as MI

if self.model is not None:
if isinstance(self.model, MI):
return self.model.get_cli_model_name(self.cli_type)
# String model name - convert it
return convert_model_name(self.cli_type, self.model)

# Use default model and convert
from .models import get_default_model

default_model = get_default_model(self.cli_type)
return convert_model_name(self.cli_type, default_model)

def get_model_info(self) -> Optional[ModelInfo]:
"""
Get the ModelInfo object for the configured model.

Returns ModelInfo if model is set (either directly or looked up by name),
or None if no model is configured.
"""
from .models import ModelInfo as MI, get_model_info as lookup_model

if self.model is None:
# Get default model info
from .models import get_default_model

default_name = get_default_model(self.cli_type)
if default_name:
return lookup_model(default_name)
return None

if isinstance(self.model, MI):
return self.model

# Look up by name
return lookup_model(self.model)

def validate_model(self) -> None:
"""
Validate that the configured model is supported by the CLI type.

Raises:
ValueError: If the model is not supported by the CLI type
"""
from .models import validate_model_for_cli

model_name = self.get_model_name()
if model_name:
validate_model_for_cli(self.cli_type, model_name)

# Keep for backward compatibility
def get_default_model(self) -> str:
"""Get default model for the CLI type (deprecated, use get_model_name)"""
return self.get_model_name()


class CLIExecutor:
Expand Down Expand Up @@ -129,8 +210,8 @@ def _build_claude_command(self, base_cmd: List[str], prompt: str) -> List[str]:
if self.config.skip_permissions:
cmd.append("--dangerously-skip-permissions")

# Add model flag
model = self.config.get_default_model()
# Add model flag (use CLI-specific name for Claude, e.g., 'sonnet', 'opus')
model = self.config.get_cli_model_name()
if model:
cmd.extend(["--model", model])

Expand Down Expand Up @@ -207,23 +288,22 @@ async def execute(self, prompt: str) -> str:
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
cwd=self.config.cwd
cwd=self.config.cwd,
)

# Wait for completion with timeout
try:
stdout, stderr = await asyncio.wait_for(
result.communicate(),
timeout=self.config.timeout
result.communicate(), timeout=self.config.timeout
)
except asyncio.TimeoutError:
result.kill()
await result.wait()
return f"Error: CLI execution timed out after {self.config.timeout} seconds"

# Decode output
response = stdout.decode('utf-8').strip()
error_output = stderr.decode('utf-8').strip()
response = stdout.decode("utf-8").strip()
error_output = stderr.decode("utf-8").strip()

logger.debug(f"Response length: {len(response)} chars")
logger.debug(f"Response preview: {response[:200]}...")
Expand All @@ -235,7 +315,9 @@ async def execute(self, prompt: str) -> str:
error_msg += f"\nStderr: {error_output}"
if response:
error_msg += f"\nStdout: {response}"
logger.error(f"CLI execution failed:\nCommand: {' '.join(cmd[:3])} [prompt...]\n{error_msg}")
logger.error(
f"CLI execution failed:\nCommand: {' '.join(cmd[:3])} [prompt...]\n{error_msg}"
)
return f"Error: {error_msg}"

if error_output:
Expand Down
Loading