44Handles execution of different AI CLI tools (Claude, Codex, Copilot).
55"""
66
7+ from __future__ import annotations
8+
79import asyncio
810import logging
9- from typing import List , Optional
11+ from typing import TYPE_CHECKING , List , Optional , Union
1012from enum import Enum
1113from dataclasses import dataclass , field
1214
15+ if TYPE_CHECKING :
16+ from .models import ModelInfo
17+
1318logger = logging .getLogger (__name__ )
1419
1520
1621class CLIType (str , Enum ):
1722 """Supported CLI types"""
23+
1824 CLAUDE = "claude"
1925 CODEX = "codex"
2026 COPILOT = "copilot"
@@ -27,12 +33,12 @@ class CLIConfig:
2733 cli_type : CLIType = CLIType .COPILOT
2834 """CLI type to use: claude, codex, or copilot (default: copilot)"""
2935
30- model : Optional [str ] = None
36+ model : Optional [Union [ ModelInfo , str ] ] = None
3137 """
32- Model to use for the agent.
33- - Claude: sonnet, opus, haiku (default: sonnet )
34- - Codex: Uses ChatGPT account or API key
35- - Copilot: Claude Sonnet 4.5 , Claude Sonnet 4, GPT-5 (default: Claude Sonnet 4.5)
38+ Model to use for the agent. Can be a ModelInfo object or model name string.
39+ - Claude Code: Only Claude family models (default: claude-opus-4.5 )
40+ - Codex: Only GPT family models (default: gpt-5)
41+ - Copilot: GPT , Claude, or Gemini models (default: claude-opus- 4.5)
3642 """
3743
3844 skip_permissions : bool = True
@@ -75,17 +81,92 @@ def get_cli_path(self) -> str:
7581 return self .cli_path
7682 return self .cli_type .value
7783
78- def get_default_model (self ) -> str :
79- """Get default model for the CLI type"""
80- if self .model :
84+ def get_model_name (self ) -> str :
85+ """
86+ Get the canonical model name string.
87+
88+ Returns the model name from ModelInfo or the string directly.
89+ If no model is set, returns the default model for the CLI type.
90+ """
91+ if self .model is not None :
92+ # Import here to avoid circular import
93+ from .models import ModelInfo as MI
94+
95+ if isinstance (self .model , MI ):
96+ return self .model .name
8197 return self .model
8298
83- if self .cli_type == CLIType .CLAUDE :
84- return "sonnet"
85- elif self .cli_type == CLIType .COPILOT :
86- return "claude-sonnet-4.5"
87- else : # CODEX
88- return "" # Uses account/API key
99+ # Use defaults from models module
100+ from .models import get_default_model
101+
102+ return get_default_model (self .cli_type )
103+
104+ def get_cli_model_name (self ) -> str :
105+ """
106+ Get the CLI-specific model name.
107+
108+ Different CLIs may use different naming conventions. For example,
109+ Claude CLI uses 'sonnet', 'opus', 'haiku' as aliases instead of
110+ full model names like 'claude-sonnet-4.5'.
111+
112+ Returns:
113+ The model name appropriate for the configured CLI type.
114+ """
115+ from .models import get_cli_model_name as convert_model_name , ModelInfo as MI
116+
117+ if self .model is not None :
118+ if isinstance (self .model , MI ):
119+ return self .model .get_cli_model_name (self .cli_type )
120+ # String model name - convert it
121+ return convert_model_name (self .cli_type , self .model )
122+
123+ # Use default model and convert
124+ from .models import get_default_model
125+
126+ default_model = get_default_model (self .cli_type )
127+ return convert_model_name (self .cli_type , default_model )
128+
129+ def get_model_info (self ) -> Optional [ModelInfo ]:
130+ """
131+ Get the ModelInfo object for the configured model.
132+
133+ Returns ModelInfo if model is set (either directly or looked up by name),
134+ or None if no model is configured.
135+ """
136+ from .models import ModelInfo as MI , get_model_info as lookup_model
137+
138+ if self .model is None :
139+ # Get default model info
140+ from .models import get_default_model
141+
142+ default_name = get_default_model (self .cli_type )
143+ if default_name :
144+ return lookup_model (default_name )
145+ return None
146+
147+ if isinstance (self .model , MI ):
148+ return self .model
149+
150+ # Look up by name
151+ return lookup_model (self .model )
152+
153+ def validate_model (self ) -> None :
154+ """
155+ Validate that the configured model is supported by the CLI type.
156+
157+ Raises:
158+ ValueError: If the model is not supported by the CLI type
159+ """
160+ from .models import validate_model_for_cli
161+
162+ model_name = self .get_model_name ()
163+ if model_name :
164+ validate_model_for_cli (self .cli_type , model_name )
165+
166+ # Keep for backward compatibility
167+ def get_default_model (self ) -> str :
168+ """Get default model for the CLI type (deprecated, use get_model_name)"""
169+ return self .get_model_name ()
89170
90171
91172class CLIExecutor :
@@ -129,8 +210,8 @@ def _build_claude_command(self, base_cmd: List[str], prompt: str) -> List[str]:
129210 if self .config .skip_permissions :
130211 cmd .append ("--dangerously-skip-permissions" )
131212
132- # Add model flag
133- model = self .config .get_default_model ()
213+ # Add model flag (use CLI-specific name for Claude, e.g., 'sonnet', 'opus')
214+ model = self .config .get_cli_model_name ()
134215 if model :
135216 cmd .extend (["--model" , model ])
136217
@@ -207,23 +288,22 @@ async def execute(self, prompt: str) -> str:
207288 * cmd ,
208289 stdout = asyncio .subprocess .PIPE ,
209290 stderr = asyncio .subprocess .PIPE ,
210- cwd = self .config .cwd
291+ cwd = self .config .cwd ,
211292 )
212293
213294 # Wait for completion with timeout
214295 try :
215296 stdout , stderr = await asyncio .wait_for (
216- result .communicate (),
217- timeout = self .config .timeout
297+ result .communicate (), timeout = self .config .timeout
218298 )
219299 except asyncio .TimeoutError :
220300 result .kill ()
221301 await result .wait ()
222302 return f"Error: CLI execution timed out after { self .config .timeout } seconds"
223303
224304 # Decode output
225- response = stdout .decode (' utf-8' ).strip ()
226- error_output = stderr .decode (' utf-8' ).strip ()
305+ response = stdout .decode (" utf-8" ).strip ()
306+ error_output = stderr .decode (" utf-8" ).strip ()
227307
228308 logger .debug (f"Response length: { len (response )} chars" )
229309 logger .debug (f"Response preview: { response [:200 ]} ..." )
@@ -235,7 +315,9 @@ async def execute(self, prompt: str) -> str:
235315 error_msg += f"\n Stderr: { error_output } "
236316 if response :
237317 error_msg += f"\n Stdout: { response } "
238- logger .error (f"CLI execution failed:\n Command: { ' ' .join (cmd [:3 ])} [prompt...]\n { error_msg } " )
318+ logger .error (
319+ f"CLI execution failed:\n Command: { ' ' .join (cmd [:3 ])} [prompt...]\n { error_msg } "
320+ )
239321 return f"Error: { error_msg } "
240322
241323 if error_output :
0 commit comments