diff --git a/utils/agent/__init__.py b/utils/agent/__init__.py index 7071362..355eaf2 100644 --- a/utils/agent/__init__.py +++ b/utils/agent/__init__.py @@ -8,13 +8,58 @@ from .agent import SpecializedAgent, AgentConfig from .cli_executor import CLIExecutor, CLIConfig, CLIType from .system_prompt import SystemPromptBuilder, create_default_system_prompt +from .models import ( + ModelFamily, + ModelInfo, + GPT_MODELS, + CLAUDE_MODELS, + GEMINI_MODELS, + ALL_MODELS, + MODEL_BY_NAME, + MODELS_BY_FAMILY, + CLI_SUPPORTED_FAMILIES, + CLI_DEFAULT_MODELS, + get_supported_families, + get_supported_models, + get_supported_model_names, + is_model_supported, + get_model_info, + get_model_family, + get_default_model, + validate_model_for_cli, + get_cli_model_name, +) __all__ = [ + # Agent classes "SpecializedAgent", "AgentConfig", + # CLI executor "CLIExecutor", "CLIConfig", "CLIType", + # System prompt "SystemPromptBuilder", "create_default_system_prompt", + # Model definitions + "ModelFamily", + "ModelInfo", + "GPT_MODELS", + "CLAUDE_MODELS", + "GEMINI_MODELS", + "ALL_MODELS", + "MODEL_BY_NAME", + "MODELS_BY_FAMILY", + "CLI_SUPPORTED_FAMILIES", + "CLI_DEFAULT_MODELS", + # Model helper functions + "get_supported_families", + "get_supported_models", + "get_supported_model_names", + "is_model_supported", + "get_model_info", + "get_model_family", + "get_default_model", + "validate_model_for_cli", + "get_cli_model_name", ] diff --git a/utils/agent/cli_executor.py b/utils/agent/cli_executor.py index 4296a82..6f41e63 100644 --- a/utils/agent/cli_executor.py +++ b/utils/agent/cli_executor.py @@ -4,17 +4,23 @@ Handles execution of different AI CLI tools (Claude, Codex, Copilot). """ +from __future__ import annotations + import asyncio import logging -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional, Union from enum import Enum from dataclasses import dataclass, field +if TYPE_CHECKING: + from .models import ModelInfo + logger = logging.getLogger(__name__) class CLIType(str, Enum): """Supported CLI types""" + CLAUDE = "claude" CODEX = "codex" COPILOT = "copilot" @@ -27,12 +33,12 @@ class CLIConfig: cli_type: CLIType = CLIType.COPILOT """CLI type to use: claude, codex, or copilot (default: copilot)""" - model: Optional[str] = None + model: Optional[Union[ModelInfo, str]] = None """ - Model to use for the agent. - - Claude: sonnet, opus, haiku (default: sonnet) - - Codex: Uses ChatGPT account or API key - - Copilot: Claude Sonnet 4.5, Claude Sonnet 4, GPT-5 (default: Claude Sonnet 4.5) + Model to use for the agent. Can be a ModelInfo object or model name string. + - Claude Code: Only Claude family models (default: claude-opus-4.5) + - Codex: Only GPT family models (default: gpt-5) + - Copilot: GPT, Claude, or Gemini models (default: claude-opus-4.5) """ skip_permissions: bool = True @@ -75,17 +81,92 @@ def get_cli_path(self) -> str: return self.cli_path return self.cli_type.value - def get_default_model(self) -> str: - """Get default model for the CLI type""" - if self.model: + def get_model_name(self) -> str: + """ + Get the canonical model name string. + + Returns the model name from ModelInfo or the string directly. + If no model is set, returns the default model for the CLI type. + """ + if self.model is not None: + # Import here to avoid circular import + from .models import ModelInfo as MI + + if isinstance(self.model, MI): + return self.model.name return self.model - if self.cli_type == CLIType.CLAUDE: - return "sonnet" - elif self.cli_type == CLIType.COPILOT: - return "claude-sonnet-4.5" - else: # CODEX - return "" # Uses account/API key + # Use defaults from models module + from .models import get_default_model + + return get_default_model(self.cli_type) + + def get_cli_model_name(self) -> str: + """ + Get the CLI-specific model name. + + Different CLIs may use different naming conventions. For example, + Claude CLI uses 'sonnet', 'opus', 'haiku' as aliases instead of + full model names like 'claude-sonnet-4.5'. + + Returns: + The model name appropriate for the configured CLI type. + """ + from .models import get_cli_model_name as convert_model_name, ModelInfo as MI + + if self.model is not None: + if isinstance(self.model, MI): + return self.model.get_cli_model_name(self.cli_type) + # String model name - convert it + return convert_model_name(self.cli_type, self.model) + + # Use default model and convert + from .models import get_default_model + + default_model = get_default_model(self.cli_type) + return convert_model_name(self.cli_type, default_model) + + def get_model_info(self) -> Optional[ModelInfo]: + """ + Get the ModelInfo object for the configured model. + + Returns ModelInfo if model is set (either directly or looked up by name), + or None if no model is configured. + """ + from .models import ModelInfo as MI, get_model_info as lookup_model + + if self.model is None: + # Get default model info + from .models import get_default_model + + default_name = get_default_model(self.cli_type) + if default_name: + return lookup_model(default_name) + return None + + if isinstance(self.model, MI): + return self.model + + # Look up by name + return lookup_model(self.model) + + def validate_model(self) -> None: + """ + Validate that the configured model is supported by the CLI type. + + Raises: + ValueError: If the model is not supported by the CLI type + """ + from .models import validate_model_for_cli + + model_name = self.get_model_name() + if model_name: + validate_model_for_cli(self.cli_type, model_name) + + # Keep for backward compatibility + def get_default_model(self) -> str: + """Get default model for the CLI type (deprecated, use get_model_name)""" + return self.get_model_name() class CLIExecutor: @@ -129,8 +210,8 @@ def _build_claude_command(self, base_cmd: List[str], prompt: str) -> List[str]: if self.config.skip_permissions: cmd.append("--dangerously-skip-permissions") - # Add model flag - model = self.config.get_default_model() + # Add model flag (use CLI-specific name for Claude, e.g., 'sonnet', 'opus') + model = self.config.get_cli_model_name() if model: cmd.extend(["--model", model]) @@ -207,14 +288,13 @@ async def execute(self, prompt: str) -> str: *cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE, - cwd=self.config.cwd + cwd=self.config.cwd, ) # Wait for completion with timeout try: stdout, stderr = await asyncio.wait_for( - result.communicate(), - timeout=self.config.timeout + result.communicate(), timeout=self.config.timeout ) except asyncio.TimeoutError: result.kill() @@ -222,8 +302,8 @@ async def execute(self, prompt: str) -> str: return f"Error: CLI execution timed out after {self.config.timeout} seconds" # Decode output - response = stdout.decode('utf-8').strip() - error_output = stderr.decode('utf-8').strip() + response = stdout.decode("utf-8").strip() + error_output = stderr.decode("utf-8").strip() logger.debug(f"Response length: {len(response)} chars") logger.debug(f"Response preview: {response[:200]}...") @@ -235,7 +315,9 @@ async def execute(self, prompt: str) -> str: error_msg += f"\nStderr: {error_output}" if response: error_msg += f"\nStdout: {response}" - logger.error(f"CLI execution failed:\nCommand: {' '.join(cmd[:3])} [prompt...]\n{error_msg}") + logger.error( + f"CLI execution failed:\nCommand: {' '.join(cmd[:3])} [prompt...]\n{error_msg}" + ) return f"Error: {error_msg}" if error_output: diff --git a/utils/agent/models.py b/utils/agent/models.py new file mode 100644 index 0000000..33fa937 --- /dev/null +++ b/utils/agent/models.py @@ -0,0 +1,277 @@ +""" +Model Definitions Module + +Defines supported AI models organized by family (GPT, Claude, Gemini) +and maps which CLI executors support which model families. +""" + +from enum import Enum +from dataclasses import dataclass +from typing import Dict, List, Set, Optional + +from .cli_executor import CLIType + + +class ModelFamily(str, Enum): + """Model family categories""" + + GPT = "gpt" + CLAUDE = "claude" + GEMINI = "gemini" + + +@dataclass(frozen=True) +class ModelInfo: + """Information about a specific model""" + + name: str + """The canonical model identifier/name""" + + family: ModelFamily + """The model family this model belongs to""" + + display_name: str + """Human-readable display name""" + + description: str = "" + """Optional description of the model""" + + claude_cli_name: str = "" + """Model name/alias for Claude CLI (e.g., 'sonnet', 'opus', 'haiku')""" + + def get_cli_model_name(self, cli_type: CLIType) -> str: + """ + Get the appropriate model name for a specific CLI type. + + Args: + cli_type: The CLI type to get the model name for + + Returns: + The CLI-specific model name + """ + if cli_type == CLIType.CLAUDE and self.claude_cli_name: + return self.claude_cli_name + return self.name + + +# GPT Family Models +GPT_MODELS: List[ModelInfo] = [ + ModelInfo( + name="gpt-5", + family=ModelFamily.GPT, + display_name="GPT-5", + description="Latest GPT-5 model", + ), +] + +# Claude Family Models +CLAUDE_MODELS: List[ModelInfo] = [ + ModelInfo( + name="claude-sonnet-4.5", + family=ModelFamily.CLAUDE, + display_name="Claude Sonnet 4.5", + description="Claude Sonnet 4.5 model", + claude_cli_name="sonnet", + ), + ModelInfo( + name="claude-haiku-4.5", + family=ModelFamily.CLAUDE, + display_name="Claude Haiku 4.5", + description="Fast and efficient Claude Haiku 4.5 model", + claude_cli_name="haiku", + ), + ModelInfo( + name="claude-opus-4.5", + family=ModelFamily.CLAUDE, + display_name="Claude Opus 4.5", + description="Most capable Claude Opus 4.5 model", + claude_cli_name="opus", + ), +] + +# Gemini Family Models +GEMINI_MODELS: List[ModelInfo] = [ + ModelInfo( + name="gemini-3.0-pro", + family=ModelFamily.GEMINI, + display_name="Gemini 3.0 Pro", + description="Latest Gemini 3.0 Pro model", + ), +] + +# All models combined +ALL_MODELS: List[ModelInfo] = GPT_MODELS + CLAUDE_MODELS + GEMINI_MODELS + +# Model lookup by name +MODEL_BY_NAME: Dict[str, ModelInfo] = {model.name: model for model in ALL_MODELS} + +# Models by family +MODELS_BY_FAMILY: Dict[ModelFamily, List[ModelInfo]] = { + ModelFamily.GPT: GPT_MODELS, + ModelFamily.CLAUDE: CLAUDE_MODELS, + ModelFamily.GEMINI: GEMINI_MODELS, +} + +# CLI type to supported model families mapping +# - Codex: Only supports GPT family (uses OpenAI/ChatGPT) +# - Claude Code: Only supports Claude family +# - Copilot: Supports all families (GPT, Claude, Gemini) +CLI_SUPPORTED_FAMILIES: Dict[CLIType, Set[ModelFamily]] = { + CLIType.CODEX: {ModelFamily.GPT}, + CLIType.CLAUDE: {ModelFamily.CLAUDE}, + CLIType.COPILOT: {ModelFamily.GPT, ModelFamily.CLAUDE, ModelFamily.GEMINI}, +} + +# Default models for each CLI type +CLI_DEFAULT_MODELS: Dict[CLIType, str] = { + CLIType.CODEX: "gpt-5", + CLIType.CLAUDE: "claude-opus-4.5", + CLIType.COPILOT: "claude-opus-4.5", +} + + +def get_supported_families(cli_type: CLIType) -> Set[ModelFamily]: + """ + Get the set of model families supported by a CLI type. + + Args: + cli_type: The CLI type to check + + Returns: + Set of supported model families + """ + return CLI_SUPPORTED_FAMILIES.get(cli_type, set()) + + +def get_supported_models(cli_type: CLIType) -> List[ModelInfo]: + """ + Get all models supported by a CLI type. + + Args: + cli_type: The CLI type to check + + Returns: + List of supported ModelInfo objects + """ + supported_families = get_supported_families(cli_type) + return [model for model in ALL_MODELS if model.family in supported_families] + + +def get_supported_model_names(cli_type: CLIType) -> List[str]: + """ + Get all model names supported by a CLI type. + + Args: + cli_type: The CLI type to check + + Returns: + List of supported model name strings + """ + return [model.name for model in get_supported_models(cli_type)] + + +def is_model_supported(cli_type: CLIType, model_name: str) -> bool: + """ + Check if a model is supported by a CLI type. + + Args: + cli_type: The CLI type to check + model_name: The model name to validate + + Returns: + True if the model is supported, False otherwise + """ + model_info = MODEL_BY_NAME.get(model_name) + if model_info is None: + return False + return model_info.family in get_supported_families(cli_type) + + +def get_model_info(model_name: str) -> Optional[ModelInfo]: + """ + Get model information by name. + + Args: + model_name: The model name to look up + + Returns: + ModelInfo if found, None otherwise + """ + return MODEL_BY_NAME.get(model_name) + + +def get_model_family(model_name: str) -> Optional[ModelFamily]: + """ + Get the family of a model by name. + + Args: + model_name: The model name to look up + + Returns: + ModelFamily if found, None otherwise + """ + model_info = get_model_info(model_name) + return model_info.family if model_info else None + + +def get_default_model(cli_type: CLIType) -> str: + """ + Get the default model for a CLI type. + + Args: + cli_type: The CLI type + + Returns: + Default model name string + """ + return CLI_DEFAULT_MODELS.get(cli_type, "") + + +def validate_model_for_cli(cli_type: CLIType, model_name: str) -> None: + """ + Validate that a model is supported by a CLI type. + + Args: + cli_type: The CLI type to check + model_name: The model name to validate + + Raises: + ValueError: If the model is not supported + """ + if not model_name: + return # Empty model means use default + + if not is_model_supported(cli_type, model_name): + supported = get_supported_model_names(cli_type) + model_info = get_model_info(model_name) + if model_info: + raise ValueError( + f"Model '{model_name}' (family: {model_info.family.value}) " + f"is not supported by {cli_type.value}. " + f"Supported models: {', '.join(supported)}" + ) + else: + raise ValueError( + f"Unknown model '{model_name}'. " + f"Supported models for {cli_type.value}: {', '.join(supported)}" + ) + + +def get_cli_model_name(cli_type: CLIType, model_name: str) -> str: + """ + Get the CLI-specific model name for a given model. + + Different CLIs may use different naming conventions. For example, + Claude CLI uses 'sonnet', 'opus', 'haiku' as aliases. + + Args: + cli_type: The CLI type to get the model name for + model_name: The canonical model name + + Returns: + The CLI-specific model name + """ + model_info = get_model_info(model_name) + if model_info: + return model_info.get_cli_model_name(cli_type) + return model_name diff --git a/utils/agent/tests/test_cli_executor.py b/utils/agent/tests/test_cli_executor.py index 11ef62b..0f729a8 100644 --- a/utils/agent/tests/test_cli_executor.py +++ b/utils/agent/tests/test_cli_executor.py @@ -70,20 +70,26 @@ def test_get_cli_path_custom(self): def test_get_default_model_claude(self): """Test getting default model for Claude""" config = CLIConfig(cli_type=CLIType.CLAUDE) - assert config.get_default_model() == "sonnet" + # Default model is claude-opus-4.5, but CLI name is 'opus' + assert config.get_model_name() == "claude-opus-4.5" + assert config.get_cli_model_name() == "opus" - config = CLIConfig(cli_type=CLIType.CLAUDE, model="haiku") - assert config.get_default_model() == "haiku" + config = CLIConfig(cli_type=CLIType.CLAUDE, model="claude-haiku-4.5") + assert config.get_model_name() == "claude-haiku-4.5" + assert config.get_cli_model_name() == "haiku" def test_get_default_model_copilot(self): """Test getting default model for Copilot""" config = CLIConfig(cli_type=CLIType.COPILOT) - assert config.get_default_model() == "claude-sonnet-4.5" + assert config.get_model_name() == "claude-opus-4.5" + # Copilot uses full model names + assert config.get_cli_model_name() == "claude-opus-4.5" def test_get_default_model_codex(self): """Test getting default model for Codex""" config = CLIConfig(cli_type=CLIType.CODEX) - assert config.get_default_model() == "" + assert config.get_model_name() == "gpt-5" + assert config.get_cli_model_name() == "gpt-5" class TestCLIExecutor: