Skip to content

Commit ff63f07

Browse files
feat: Introduce model definitions and enhance CLI configuration (#544)
This commit adds a new models module that defines supported AI models organized by family (GPT, Claude, Gemini) and their respective CLI configurations. Key changes include: - New `models.py` file containing model definitions and helper functions for model management. - Updates to `CLIConfig` in `cli_executor.py` to support model validation and retrieval of model-specific names. - Refactoring of model handling in CLI configuration methods to improve clarity and functionality. - Enhanced test coverage for CLI configuration methods to ensure correct model retrieval and validation. These changes improve the overall structure and usability of the agent's model management system, facilitating better integration with various CLI tools. 🤖 Generated with Claude Code
1 parent a76a6c5 commit ff63f07

File tree

4 files changed

+438
-28
lines changed

4 files changed

+438
-28
lines changed

utils/agent/__init__.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,58 @@
88
from .agent import SpecializedAgent, AgentConfig
99
from .cli_executor import CLIExecutor, CLIConfig, CLIType
1010
from .system_prompt import SystemPromptBuilder, create_default_system_prompt
11+
from .models import (
12+
ModelFamily,
13+
ModelInfo,
14+
GPT_MODELS,
15+
CLAUDE_MODELS,
16+
GEMINI_MODELS,
17+
ALL_MODELS,
18+
MODEL_BY_NAME,
19+
MODELS_BY_FAMILY,
20+
CLI_SUPPORTED_FAMILIES,
21+
CLI_DEFAULT_MODELS,
22+
get_supported_families,
23+
get_supported_models,
24+
get_supported_model_names,
25+
is_model_supported,
26+
get_model_info,
27+
get_model_family,
28+
get_default_model,
29+
validate_model_for_cli,
30+
get_cli_model_name,
31+
)
1132

1233
__all__ = [
34+
# Agent classes
1335
"SpecializedAgent",
1436
"AgentConfig",
37+
# CLI executor
1538
"CLIExecutor",
1639
"CLIConfig",
1740
"CLIType",
41+
# System prompt
1842
"SystemPromptBuilder",
1943
"create_default_system_prompt",
44+
# Model definitions
45+
"ModelFamily",
46+
"ModelInfo",
47+
"GPT_MODELS",
48+
"CLAUDE_MODELS",
49+
"GEMINI_MODELS",
50+
"ALL_MODELS",
51+
"MODEL_BY_NAME",
52+
"MODELS_BY_FAMILY",
53+
"CLI_SUPPORTED_FAMILIES",
54+
"CLI_DEFAULT_MODELS",
55+
# Model helper functions
56+
"get_supported_families",
57+
"get_supported_models",
58+
"get_supported_model_names",
59+
"is_model_supported",
60+
"get_model_info",
61+
"get_model_family",
62+
"get_default_model",
63+
"validate_model_for_cli",
64+
"get_cli_model_name",
2065
]

utils/agent/cli_executor.py

Lines changed: 105 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,23 @@
44
Handles execution of different AI CLI tools (Claude, Codex, Copilot).
55
"""
66

7+
from __future__ import annotations
8+
79
import asyncio
810
import logging
9-
from typing import List, Optional
11+
from typing import TYPE_CHECKING, List, Optional, Union
1012
from enum import Enum
1113
from dataclasses import dataclass, field
1214

15+
if TYPE_CHECKING:
16+
from .models import ModelInfo
17+
1318
logger = logging.getLogger(__name__)
1419

1520

1621
class CLIType(str, Enum):
1722
"""Supported CLI types"""
23+
1824
CLAUDE = "claude"
1925
CODEX = "codex"
2026
COPILOT = "copilot"
@@ -27,12 +33,12 @@ class CLIConfig:
2733
cli_type: CLIType = CLIType.COPILOT
2834
"""CLI type to use: claude, codex, or copilot (default: copilot)"""
2935

30-
model: Optional[str] = None
36+
model: Optional[Union[ModelInfo, str]] = None
3137
"""
32-
Model to use for the agent.
33-
- Claude: sonnet, opus, haiku (default: sonnet)
34-
- Codex: Uses ChatGPT account or API key
35-
- Copilot: Claude Sonnet 4.5, Claude Sonnet 4, GPT-5 (default: Claude Sonnet 4.5)
38+
Model to use for the agent. Can be a ModelInfo object or model name string.
39+
- Claude Code: Only Claude family models (default: claude-opus-4.5)
40+
- Codex: Only GPT family models (default: gpt-5)
41+
- Copilot: GPT, Claude, or Gemini models (default: claude-opus-4.5)
3642
"""
3743

3844
skip_permissions: bool = True
@@ -75,17 +81,92 @@ def get_cli_path(self) -> str:
7581
return self.cli_path
7682
return self.cli_type.value
7783

78-
def get_default_model(self) -> str:
79-
"""Get default model for the CLI type"""
80-
if self.model:
84+
def get_model_name(self) -> str:
85+
"""
86+
Get the canonical model name string.
87+
88+
Returns the model name from ModelInfo or the string directly.
89+
If no model is set, returns the default model for the CLI type.
90+
"""
91+
if self.model is not None:
92+
# Import here to avoid circular import
93+
from .models import ModelInfo as MI
94+
95+
if isinstance(self.model, MI):
96+
return self.model.name
8197
return self.model
8298

83-
if self.cli_type == CLIType.CLAUDE:
84-
return "sonnet"
85-
elif self.cli_type == CLIType.COPILOT:
86-
return "claude-sonnet-4.5"
87-
else: # CODEX
88-
return "" # Uses account/API key
99+
# Use defaults from models module
100+
from .models import get_default_model
101+
102+
return get_default_model(self.cli_type)
103+
104+
def get_cli_model_name(self) -> str:
105+
"""
106+
Get the CLI-specific model name.
107+
108+
Different CLIs may use different naming conventions. For example,
109+
Claude CLI uses 'sonnet', 'opus', 'haiku' as aliases instead of
110+
full model names like 'claude-sonnet-4.5'.
111+
112+
Returns:
113+
The model name appropriate for the configured CLI type.
114+
"""
115+
from .models import get_cli_model_name as convert_model_name, ModelInfo as MI
116+
117+
if self.model is not None:
118+
if isinstance(self.model, MI):
119+
return self.model.get_cli_model_name(self.cli_type)
120+
# String model name - convert it
121+
return convert_model_name(self.cli_type, self.model)
122+
123+
# Use default model and convert
124+
from .models import get_default_model
125+
126+
default_model = get_default_model(self.cli_type)
127+
return convert_model_name(self.cli_type, default_model)
128+
129+
def get_model_info(self) -> Optional[ModelInfo]:
130+
"""
131+
Get the ModelInfo object for the configured model.
132+
133+
Returns ModelInfo if model is set (either directly or looked up by name),
134+
or None if no model is configured.
135+
"""
136+
from .models import ModelInfo as MI, get_model_info as lookup_model
137+
138+
if self.model is None:
139+
# Get default model info
140+
from .models import get_default_model
141+
142+
default_name = get_default_model(self.cli_type)
143+
if default_name:
144+
return lookup_model(default_name)
145+
return None
146+
147+
if isinstance(self.model, MI):
148+
return self.model
149+
150+
# Look up by name
151+
return lookup_model(self.model)
152+
153+
def validate_model(self) -> None:
154+
"""
155+
Validate that the configured model is supported by the CLI type.
156+
157+
Raises:
158+
ValueError: If the model is not supported by the CLI type
159+
"""
160+
from .models import validate_model_for_cli
161+
162+
model_name = self.get_model_name()
163+
if model_name:
164+
validate_model_for_cli(self.cli_type, model_name)
165+
166+
# Keep for backward compatibility
167+
def get_default_model(self) -> str:
168+
"""Get default model for the CLI type (deprecated, use get_model_name)"""
169+
return self.get_model_name()
89170

90171

91172
class CLIExecutor:
@@ -129,8 +210,8 @@ def _build_claude_command(self, base_cmd: List[str], prompt: str) -> List[str]:
129210
if self.config.skip_permissions:
130211
cmd.append("--dangerously-skip-permissions")
131212

132-
# Add model flag
133-
model = self.config.get_default_model()
213+
# Add model flag (use CLI-specific name for Claude, e.g., 'sonnet', 'opus')
214+
model = self.config.get_cli_model_name()
134215
if model:
135216
cmd.extend(["--model", model])
136217

@@ -207,23 +288,22 @@ async def execute(self, prompt: str) -> str:
207288
*cmd,
208289
stdout=asyncio.subprocess.PIPE,
209290
stderr=asyncio.subprocess.PIPE,
210-
cwd=self.config.cwd
291+
cwd=self.config.cwd,
211292
)
212293

213294
# Wait for completion with timeout
214295
try:
215296
stdout, stderr = await asyncio.wait_for(
216-
result.communicate(),
217-
timeout=self.config.timeout
297+
result.communicate(), timeout=self.config.timeout
218298
)
219299
except asyncio.TimeoutError:
220300
result.kill()
221301
await result.wait()
222302
return f"Error: CLI execution timed out after {self.config.timeout} seconds"
223303

224304
# Decode output
225-
response = stdout.decode('utf-8').strip()
226-
error_output = stderr.decode('utf-8').strip()
305+
response = stdout.decode("utf-8").strip()
306+
error_output = stderr.decode("utf-8").strip()
227307

228308
logger.debug(f"Response length: {len(response)} chars")
229309
logger.debug(f"Response preview: {response[:200]}...")
@@ -235,7 +315,9 @@ async def execute(self, prompt: str) -> str:
235315
error_msg += f"\nStderr: {error_output}"
236316
if response:
237317
error_msg += f"\nStdout: {response}"
238-
logger.error(f"CLI execution failed:\nCommand: {' '.join(cmd[:3])} [prompt...]\n{error_msg}")
318+
logger.error(
319+
f"CLI execution failed:\nCommand: {' '.join(cmd[:3])} [prompt...]\n{error_msg}"
320+
)
239321
return f"Error: {error_msg}"
240322

241323
if error_output:

0 commit comments

Comments
 (0)