diff --git a/cecli/__init__.py b/cecli/__init__.py index ac9835ec20c..8afa159985f 100644 --- a/cecli/__init__.py +++ b/cecli/__init__.py @@ -1,6 +1,6 @@ from packaging import version -__version__ = "0.96.1.dev" +__version__ = "0.96.2.dev" safe_version = __version__ try: diff --git a/cecli/coders/agent_coder.py b/cecli/coders/agent_coder.py index 10618f6c8ca..db107fc2544 100644 --- a/cecli/coders/agent_coder.py +++ b/cecli/coders/agent_coder.py @@ -12,8 +12,6 @@ from datetime import datetime from pathlib import Path -from litellm import experimental_mcp_client - from cecli import urls, utils from cecli.change_tracker import ChangeTracker from cecli.helpers import nested @@ -29,6 +27,7 @@ normalize_vector, ) from cecli.helpers.skills import SkillsManager +from cecli.llm import litellm from cecli.mcp import LocalServer, McpServerManager from cecli.repo import ANY_GIT_ERROR from cecli.tools.utils.registry import ToolRegistry @@ -306,7 +305,7 @@ async def _exec_async(): } try: session = await server.connect() - call_result = await experimental_mcp_client.call_openai_tool( + call_result = await litellm.experimental_mcp_client.call_openai_tool( session=session, openai_tool=tool_call_dict ) content_parts = [] diff --git a/cecli/commands/tokens.py b/cecli/commands/tokens.py index 9e11fd5d58b..11e785b7d28 100644 --- a/cecli/commands/tokens.py +++ b/cecli/commands/tokens.py @@ -2,8 +2,7 @@ from cecli.commands.utils.base_command import BaseCommand from cecli.commands.utils.helpers import format_command_result -from cecli.helpers.conversation import ConversationManager -from cecli.utils import is_image_file +from cecli.helpers.conversation import ConversationManager, MessageTag class TokensCommand(BaseCommand): @@ -15,39 +14,29 @@ async def execute(cls, io, coder, args, **kwargs): res = [] coder.choose_fence() + coder.format_chat_chunks() # Show progress indicator total_files = len(coder.abs_fnames) + len(coder.abs_read_only_fnames) if total_files > 20: io.tool_output(f"Calculating tokens for {total_files} files...") - # system messages - main_sys = coder.fmt_system_prompt(coder.gpt_prompts.main_system) - main_sys += "\n" + coder.fmt_system_prompt(coder.gpt_prompts.system_reminder) - msgs = [ - dict(role="system", content=main_sys), - dict( - role="system", - content=coder.fmt_system_prompt(coder.gpt_prompts.system_reminder), - ), + # system messages - sum of SYSTEM, STATIC, EXAMPLES, and REMINDER tags + system_tags = [ + MessageTag.SYSTEM, + MessageTag.STATIC, + MessageTag.EXAMPLES, + MessageTag.REMINDER, ] + system_tokens = 0 - tokens = coder.main_model.token_count(msgs) - res.append((tokens, "system messages", "")) + for tag in system_tags: + msgs = ConversationManager.get_messages_dict(tag=tag) + if msgs: + system_tokens += coder.main_model.token_count(msgs) - # chat history - msgs = ConversationManager.get_messages_dict() - if msgs: - tokens = coder.main_model.token_count(msgs) - res.append((tokens, "chat history", "use /clear to clear")) - - # repo map - other_files = set(coder.get_all_abs_files()) - set(coder.abs_fnames) - if coder.repo_map: - repo_content = coder.repo_map.get_repo_map(coder.abs_fnames, other_files) - if repo_content: - tokens = coder.main_model.token_count(repo_content) - res.append((tokens, "repository map", "use --map-tokens to resize")) + # Calculate context block tokens (they are part of STATIC messages) + context_block_total = 0 # Enhanced context blocks (only for agent mode) if hasattr(coder, "use_enhanced_context") and coder.use_enhanced_context: @@ -56,86 +45,124 @@ 
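The agent_coder.py hunk above (and the matching one in cecli/mcp/manager.py further down) replaces the eager `from litellm import experimental_mcp_client` with attribute access through `cecli.llm`'s `litellm` object, so litellm is only imported when an MCP tool call actually runs. A minimal sketch of the kind of lazy proxy that makes `litellm.experimental_mcp_client.call_openai_tool(...)` work without paying the import cost at startup; `_LazyLiteLLM` is an illustrative name, not the real implementation in cecli/llm.py:

```python
import importlib


class _LazyLiteLLM:
    """Sketch of a proxy that defers importing litellm until first attribute access."""

    _module = None

    def _load(self):
        if type(self)._module is None:
            type(self)._module = importlib.import_module("litellm")
        return type(self)._module

    def __getattr__(self, name):
        module = self._load()
        try:
            return getattr(module, name)
        except AttributeError:
            # Sub-modules such as experimental_mcp_client are not always imported eagerly
            return importlib.import_module(f"litellm.{name}")


litellm = _LazyLiteLLM()
# Usage mirrors the new call sites:
#   await litellm.experimental_mcp_client.call_openai_tool(session=session, openai_tool=tool_call_dict)
```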
async def execute(cls, io, coder, args, **kwargs): if not hasattr(coder, "tokens_calculated") or not coder.tokens_calculated: coder._calculate_context_block_tokens() - # Add enhanced context blocks to the display + # Calculate total context block tokens if hasattr(coder, "context_block_tokens") and coder.context_block_tokens: - for block_name, tokens in coder.context_block_tokens.items(): - # Format the block name more nicely - display_name = block_name.replace("_", " ").title() - res.append( - (tokens, f"{display_name} context block", "/context-blocks to toggle") - ) + context_block_total = sum(coder.context_block_tokens.values()) - fence = "`" * 3 + # Subtract context block tokens from system token count + # Context blocks are part of STATIC messages, so we need to subtract them + system_tokens = max(0, system_tokens - context_block_total) - file_res = [] - # Process files with progress indication - total_editable_files = len(coder.abs_fnames) - total_readonly_files = len(coder.abs_read_only_fnames) + res.append((system_tokens, "system messages", "")) - # Display progress for editable files - if total_editable_files > 0: - if total_editable_files > 20: - io.tool_output(f"Calculating tokens for {total_editable_files} editable files...") + # chat history + msgs_done = ConversationManager.get_messages_dict(tag=MessageTag.DONE) + msgs_cur = ConversationManager.get_messages_dict(tag=MessageTag.CUR) + tokens_done = 0 + tokens_cur = 0 - # Calculate tokens for editable files - for i, fname in enumerate(coder.abs_fnames): - if i > 0 and i % 20 == 0 and total_editable_files > 20: - io.tool_output(f"Processed {i}/{total_editable_files} editable files...") + if msgs_done: + tokens_done = coder.main_model.token_count(msgs_done) - relative_fname = coder.get_rel_fname(fname) - content = io.read_text(fname) + if msgs_cur: + tokens_cur = coder.main_model.token_count(msgs_cur) - if not content: - continue + if tokens_cur + tokens_done: + res.append((tokens_cur + tokens_done, "chat history", "use /clear to clear")) - if is_image_file(relative_fname): - tokens = coder.main_model.token_count_for_image(fname) - else: - # approximate - content = f"{relative_fname}\n{fence}\n" + content + f"{fence}\n" - tokens = coder.main_model.token_count(content) - file_res.append((tokens, f"{relative_fname}", "/drop to remove")) + # repo map + if coder.repo_map: + tokens = coder.main_model.token_count( + ConversationManager.get_messages_dict(tag=MessageTag.REPO) + ) + res.append((tokens, "repository map", "use --map-tokens to resize")) - # Display progress for read-only files - if total_readonly_files > 0: - if total_readonly_files > 20: - io.tool_output(f"Calculating tokens for {total_readonly_files} read-only files...") + # Display enhanced context blocks (only for agent mode) + # Note: Context block tokens were already calculated and subtracted from system messages + if hasattr(coder, "use_enhanced_context") and coder.use_enhanced_context: + if hasattr(coder, "context_block_tokens") and coder.context_block_tokens: + for block_name, tokens in coder.context_block_tokens.items(): + # Format the block name more nicely + display_name = block_name.replace("_", " ").title() + res.append( + (tokens, f"{display_name} context block", "/context-blocks to toggle") + ) - # Calculate tokens for read-only files - for i, fname in enumerate(coder.abs_read_only_fnames): - if i > 0 and i % 20 == 0 and total_readonly_files > 20: - io.tool_output(f"Processed {i}/{total_readonly_files} read-only files...") + file_res = [] + # Calculate tokens 
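The rewritten /tokens report above derives every row from tagged slices of the already-formatted chat (`coder.format_chat_chunks()` plus `ConversationManager.get_messages_dict(tag=...)`) instead of re-rendering prompts and files itself, and it subtracts the context-block total from the system bucket because those blocks travel inside the STATIC messages. A short sketch of that accounting, reusing the names from the hunk (`system_row` is a hypothetical helper):

```python
from cecli.helpers.conversation import ConversationManager, MessageTag


def system_row(coder, context_block_total):
    """Sum tokens over the system-side tags, then remove the context-block share
    so those blocks can be reported as separate rows without double counting."""
    tags = [MessageTag.SYSTEM, MessageTag.STATIC, MessageTag.EXAMPLES, MessageTag.REMINDER]
    total = 0
    for tag in tags:
        msgs = ConversationManager.get_messages_dict(tag=tag)
        if msgs:
            total += coder.main_model.token_count(msgs)
    return max(0, total - context_block_total)
```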
for read-only files using READONLY_FILES tag + readonly_msgs = ConversationManager.get_messages_dict(tag=MessageTag.READONLY_FILES) + if readonly_msgs: + # Group messages by file (each file has user and assistant messages) + file_tokens = {} + for msg in readonly_msgs: + # Extract file name from message content + content = msg.get("content", "") + if content.startswith("File Contents"): + # Extract file path from "File Contents {path}:" + lines = content.split("\n", 1) + if lines: + file_line = lines[0] + if file_line.startswith("File Contents"): + fname = file_line[13:].rstrip(":") + # Calculate tokens for this message + tokens = coder.main_model.token_count([msg]) + if fname not in file_tokens: + file_tokens[fname] = 0 + file_tokens[fname] += tokens + elif "image_file" in msg: + # Handle image files + fname = msg.get("image_file") + if fname: + tokens = coder.main_model.token_count([msg]) + if fname not in file_tokens: + file_tokens[fname] = 0 + file_tokens[fname] += tokens + + # Add to results + for fname, tokens in file_tokens.items(): relative_fname = coder.get_rel_fname(fname) - content = io.read_text(fname) - - if not content: - continue - - if not is_image_file(relative_fname): - # approximate - content = f"{relative_fname}\n{fence}\n" + content + f"{fence}\n" - tokens = coder.main_model.token_count(content) - file_res.append((tokens, f"{relative_fname} (read-only)", "/drop to remove")) - - if total_files > 20: - io.tool_output("Token calculation complete. Generating report...") - - file_res.sort() - res.extend(file_res) - - # stub files - for fname in coder.abs_read_only_stubs_fnames: + file_res.append((tokens, f"{relative_fname} (read-only)", "/drop to remove")) + + # Calculate tokens for editable files using CHAT_FILES and EDIT_FILES tags + editable_tags = [MessageTag.CHAT_FILES, MessageTag.EDIT_FILES] + editable_file_tokens = {} + + for tag in editable_tags: + msgs = ConversationManager.get_messages_dict(tag=tag) + if msgs: + for msg in msgs: + # Extract file name from message content + content = msg.get("content", "") + if content.startswith("File Contents"): + # Extract file path from "File Contents {path}:" + lines = content.split("\n", 1) + if lines: + file_line = lines[0] + if file_line.startswith("File Contents"): + fname = file_line[13:].rstrip(":") + # Calculate tokens for this message + tokens = coder.main_model.token_count([msg]) + if fname not in editable_file_tokens: + editable_file_tokens[fname] = 0 + editable_file_tokens[fname] += tokens + elif "image_file" in msg: + # Handle image files + fname = msg.get("image_file") + if fname: + tokens = coder.main_model.token_count([msg]) + if fname not in editable_file_tokens: + editable_file_tokens[fname] = 0 + editable_file_tokens[fname] += tokens + + # Add editable files to results + for fname, tokens in editable_file_tokens.items(): relative_fname = coder.get_rel_fname(fname) - if not is_image_file(relative_fname): - stub = coder.get_file_stub(fname) - - if not stub: - continue + file_res.append((tokens, f"{relative_fname}", "/drop to remove")) - content = f"{relative_fname} (stub)\n{fence}\n" + stub + "{fence}\n" - tokens = coder.main_model.token_count(content) - res.append((tokens, f"{relative_fname} (read-only stub)", "/drop to remove")) + if file_res: + file_res.sort() + res.extend(file_res) io.tool_output(f"Approximate context window usage for {coder.main_model.name}, in tokens:") io.tool_output() diff --git a/cecli/helpers/conversation/base_message.py b/cecli/helpers/conversation/base_message.py index 
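Both the read-only and editable branches above recover the path by slicing the message's first line, `File Contents {path}:`, at offset 13, which leaves the separator space attached to the name that is later passed to `coder.get_rel_fname`. A slightly more defensive sketch of the same extraction, assuming the message shapes used in the hunk and Python 3.9+ for `removeprefix`:

```python
def file_name_from_message(msg):
    """Return the path embedded in a 'File Contents {path}:' message, or None."""
    content = msg.get("content", "")
    if isinstance(content, str) and content.startswith("File Contents"):
        header = content.split("\n", 1)[0]
        return header.removeprefix("File Contents").strip().rstrip(":")
    # Image attachments carry the path in a separate key instead of the content header
    return msg.get("image_file")
```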
79f0be2505e..e763b2f9bcb 100644 --- a/cecli/helpers/conversation/base_message.py +++ b/cecli/helpers/conversation/base_message.py @@ -1,3 +1,4 @@ +import json import time import uuid from dataclasses import dataclass, field @@ -54,6 +55,9 @@ def _transform_message(self, tool_calls): tool_calls_list.append(tool_call) return tool_calls_list + def _serialize_default(self, content): + return "" + def generate_id(self) -> str: """ Creates deterministic hash from hash_key or (role, content). @@ -81,7 +85,7 @@ def generate_id(self) -> str: if tool_calls: # For tool calls, include them in the hash transformed_tool_calls = self._transform_message(tool_calls) - tool_calls_str = str(transformed_tool_calls) + tool_calls_str = json.dumps(transformed_tool_calls, default=self._serialize_default) key_data = f"{role}:{content}:{tool_calls_str}" else: key_data = f"{role}:{content}" diff --git a/cecli/helpers/model_providers.py b/cecli/helpers/model_providers.py index 19fcbd33167..d96f4163e2b 100644 --- a/cecli/helpers/model_providers.py +++ b/cecli/helpers/model_providers.py @@ -17,24 +17,15 @@ import time from copy import deepcopy from pathlib import Path -from typing import Dict, Optional +from typing import Any, Dict, Optional import requests from cecli.helpers.file_searcher import handle_core_files -try: - from litellm.llms.custom_httpx.http_handler import HTTPHandler - from litellm.llms.custom_llm import CustomLLM, CustomLLMError - from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler -except Exception: - CustomLLM = None - CustomLLMError = Exception - OpenAILikeChatHandler = None - HTTPHandler = None RESOURCE_FILE = "providers.json" _PROVIDERS_REGISTERED = False -_CUSTOM_HANDLERS: Dict[str, "_JSONOpenAIProvider"] = {} +_CUSTOM_HANDLERS: Dict[str, "Any"] = {} def _coerce_str(value): @@ -61,146 +52,166 @@ def _first_env_value(names): return None -class _JSONOpenAIProvider(OpenAILikeChatHandler): - """CustomLLM wrapper that routes OpenAI-compatible providers through LiteLLM.""" - - def __init__(self, slug: str, config: Dict): - if CustomLLM is None or OpenAILikeChatHandler is None: - raise RuntimeError("litellm custom handler support unavailable") - super().__init__() - self.slug = slug - self.config = config +def _get_json_openai_handler(slug: str, config: Dict) -> Any: + """Create a custom handler for OpenAI-compatible providers, lazily importing litellm.""" + try: + from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler + except Exception: + return None - def _resolve_api_base(self, api_base: Optional[str]) -> str: - base = ( - api_base - or _first_env_value(self.config.get("base_url_env")) - or self.config.get("api_base") - ) - if not base: - raise CustomLLMError(500, f"{self.slug} missing base URL") - return base.rstrip("/") + class _JSONOpenAIProvider(OpenAILikeChatHandler): + """CustomLLM wrapper that routes OpenAI-compatible providers through LiteLLM.""" - def _resolve_api_key(self, api_key: Optional[str]) -> Optional[str]: - if api_key: - return api_key - env_val = _first_env_value(self.config.get("api_key_env")) - return env_val - - def _apply_special_handling(self, messages): - special = self.config.get("special_handling") or {} - if special.get("convert_content_list_to_string"): - from litellm.litellm_core_utils.prompt_templates.common_utils import ( - handle_messages_with_content_list_to_str_conversion, + def __init__(self, slug: str, config: Dict): + try: + from litellm.llms.custom_llm import CustomLLM + except Exception: + CustomLLM = None + + if 
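The base_message.py change swaps `str(transformed_tool_calls)` for `json.dumps(..., default=self._serialize_default)` when building the hash key, so tool calls that contain objects with unstable reprs (memory addresses, SDK wrappers) still serialize to the same string on every run. A standalone sketch of why that keeps generated IDs deterministic; the helper name and truncation are illustrative only:

```python
import hashlib
import json


def message_id(role, content, tool_calls=None):
    """Deterministic ID: logically identical messages always hash to the same value."""
    if tool_calls:
        # default=... replaces non-JSON-serializable values with a fixed placeholder
        # instead of raising, so an odd SDK object cannot perturb the key
        tool_calls_str = json.dumps(tool_calls, default=lambda o: "")
        key = f"{role}:{content}:{tool_calls_str}"
    else:
        key = f"{role}:{content}"
    return hashlib.sha256(key.encode("utf-8")).hexdigest()[:16]
```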
CustomLLM is None: + raise RuntimeError("litellm custom handler support unavailable") + + super().__init__() + self.slug = slug + self.config = config + + def _resolve_api_base(self, api_base: Optional[str]) -> str: + base = ( + api_base + or _first_env_value(self.config.get("base_url_env")) + or self.config.get("api_base") + ) + if not base: + try: + from litellm.llms.custom_llm import CustomLLMError + except Exception: + CustomLLMError = Exception + + raise CustomLLMError(500, f"{self.slug} missing base URL") + return base.rstrip("/") + + def _resolve_api_key(self, api_key: Optional[str]) -> Optional[str]: + if api_key: + return api_key + env_val = _first_env_value(self.config.get("api_key_env")) + return env_val + + def _apply_special_handling(self, messages): + special = self.config.get("special_handling") or {} + if special.get("convert_content_list_to_string"): + from litellm.litellm_core_utils.prompt_templates.common_utils import ( + handle_messages_with_content_list_to_str_conversion, + ) + + return handle_messages_with_content_list_to_str_conversion(messages) + return messages + + def _inject_headers(self, headers): + defaults = self.config.get("default_headers") or {} + combined = dict(defaults) + combined.update(headers or {}) + return combined + + def _normalize_model_name(self, model: str) -> str: + if not isinstance(model, str): + return model + trimmed = model + if trimmed.startswith(f"{self.slug}/"): + trimmed = trimmed.split("/", 1)[1] + hf_namespace = self.config.get("hf_namespace") + if hf_namespace and not trimmed.startswith("hf:"): + trimmed = f"hf:{trimmed}" + return trimmed + + def _build_request_params(self, optional_params, stream: bool): + params = dict(optional_params or {}) + default_headers = dict(self.config.get("default_headers") or {}) + headers = params.setdefault("extra_headers", default_headers) + if headers is default_headers and default_headers: + params["extra_headers"] = dict(default_headers) + if stream: + params["stream"] = True + return params + + def completion(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), False + ) + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" + return super().completion(*args, **kwargs) + + async def acompletion(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), False ) + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" + kwargs["acompletion"] = True + return await super().completion(*args, **kwargs) + + def streaming(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + 
kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), True + ) + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" + response = super().completion(*args, **kwargs) + for chunk in response: + yield self.get_generic_chunk(chunk) + + async def astreaming(self, *args, **kwargs): + kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) + kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) + kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) + kwargs["optional_params"] = self._build_request_params( + kwargs.get("optional_params", None), True + ) + kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) + kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) + kwargs["custom_llm_provider"] = "openai" + kwargs["acompletion"] = True + response = await super().completion(*args, **kwargs) + async for chunk in response: + yield self.get_generic_chunk(chunk) + + def get_generic_chunk(self, chunk): + choice = chunk.choices[0] if chunk.choices else None + delta = choice.delta if choice else None + text_content = delta.content if delta and delta.content else "" + tool_calls = delta.tool_calls if delta and delta.tool_calls else None + if tool_calls and len(tool_calls): + tool_calls = tool_calls[0] + usage_data = getattr(chunk, "usage", None) + if hasattr(usage_data, "model_dump"): + usage_dict = usage_data.model_dump() + elif isinstance(usage_data, dict): + usage_dict = usage_data + else: + usage_dict = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0} + generic_chunk = { + "finish_reason": choice.finish_reason if choice else None, + "index": choice.index if choice else 0, + "is_finished": bool(choice.finish_reason) if choice else False, + "text": text_content, + "tool_use": tool_calls, + "usage": usage_dict, + } + return generic_chunk - return handle_messages_with_content_list_to_str_conversion(messages) - return messages - - def _inject_headers(self, headers): - defaults = self.config.get("default_headers") or {} - combined = dict(defaults) - combined.update(headers or {}) - return combined - - def _normalize_model_name(self, model: str) -> str: - if not isinstance(model, str): - return model - trimmed = model - if trimmed.startswith(f"{self.slug}/"): - trimmed = trimmed.split("/", 1)[1] - hf_namespace = self.config.get("hf_namespace") - if hf_namespace and not trimmed.startswith("hf:"): - trimmed = f"hf:{trimmed}" - return trimmed - - def _build_request_params(self, optional_params, stream: bool): - params = dict(optional_params or {}) - default_headers = dict(self.config.get("default_headers") or {}) - headers = params.setdefault("extra_headers", default_headers) - if headers is default_headers and default_headers: - params["extra_headers"] = dict(default_headers) - if stream: - params["stream"] = True - return params - - def completion(self, *args, **kwargs): - kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) - kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) - kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) - kwargs["optional_params"] = self._build_request_params( - kwargs.get("optional_params", None), False - ) - kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) - kwargs["model"] = 
self._normalize_model_name(kwargs.get("model", None)) - kwargs["custom_llm_provider"] = "openai" - return super().completion(*args, **kwargs) - - async def acompletion(self, *args, **kwargs): - kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) - kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) - kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) - kwargs["optional_params"] = self._build_request_params( - kwargs.get("optional_params", None), False - ) - kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) - kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) - kwargs["custom_llm_provider"] = "openai" - kwargs["acompletion"] = True - return await super().completion(*args, **kwargs) - - def streaming(self, *args, **kwargs): - kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) - kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) - kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) - kwargs["optional_params"] = self._build_request_params( - kwargs.get("optional_params", None), True - ) - kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) - kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) - kwargs["custom_llm_provider"] = "openai" - response = super().completion(*args, **kwargs) - for chunk in response: - yield self.get_generic_chunk(chunk) - - async def astreaming(self, *args, **kwargs): - kwargs["api_base"] = self._resolve_api_base(kwargs.get("api_base", None)) - kwargs["api_key"] = self._resolve_api_key(kwargs.get("api_key", None)) - kwargs["headers"] = self._inject_headers(kwargs.get("headers", None)) - kwargs["optional_params"] = self._build_request_params( - kwargs.get("optional_params", None), True - ) - kwargs["messages"] = self._apply_special_handling(kwargs.get("messages", [])) - kwargs["model"] = self._normalize_model_name(kwargs.get("model", None)) - kwargs["custom_llm_provider"] = "openai" - kwargs["acompletion"] = True - response = await super().completion(*args, **kwargs) - async for chunk in response: - yield self.get_generic_chunk(chunk) - - def get_generic_chunk(self, chunk): - choice = chunk.choices[0] if chunk.choices else None - delta = choice.delta if choice else None - text_content = delta.content if delta and delta.content else "" - tool_calls = delta.tool_calls if delta and delta.tool_calls else None - if tool_calls and len(tool_calls): - tool_calls = tool_calls[0] - usage_data = getattr(chunk, "usage", None) - if hasattr(usage_data, "model_dump"): - usage_dict = usage_data.model_dump() - elif isinstance(usage_data, dict): - usage_dict = usage_data - else: - usage_dict = {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0} - generic_chunk = { - "finish_reason": choice.finish_reason if choice else None, - "index": choice.index if choice else 0, - "is_finished": bool(choice.finish_reason) if choice else False, - "text": text_content, - "tool_use": tool_calls, - "usage": usage_dict, - } - return generic_chunk + return _JSONOpenAIProvider(slug, config) def _register_provider_with_litellm(slug: str, config: Dict) -> None: @@ -220,7 +231,7 @@ def _register_provider_with_litellm(slug: str, config: Dict) -> None: return handler = _CUSTOM_HANDLERS.get(slug) if handler is None: - handler = _JSONOpenAIProvider(slug, config) + handler = _get_json_openai_handler(slug, config) _CUSTOM_HANDLERS[slug] = handler if handler is None: return diff --git 
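model_providers.py now defers every litellm import: `_get_json_openai_handler` builds the `_JSONOpenAIProvider` subclass inside the function and returns `None` when the `openai_like` handler cannot be imported, and `_register_provider_with_litellm` simply skips registration in that case. A compact sketch of the same deferred-subclass idiom with hypothetical names, in case the pattern is reused for other optional backends:

```python
from typing import Any, Dict, Optional


def make_optional_handler(slug: str, config: Dict) -> Optional[Any]:
    """Build a handler subclass only if its heavy base class can be imported."""
    try:
        # Optional dependency: resolved lazily, never at module import time
        from litellm.llms.openai_like.chat.handler import OpenAILikeChatHandler
    except Exception:
        return None

    class _Handler(OpenAILikeChatHandler):
        def __init__(self):
            super().__init__()
            self.slug = slug
            self.config = config

    return _Handler()


handler = make_optional_handler("myprovider", {"api_base": "https://example.invalid/v1"})
if handler is None:
    # Mirrors _register_provider_with_litellm: registration is simply skipped
    pass
```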
a/cecli/helpers/requests.py b/cecli/helpers/requests.py index 4ee4866fd75..2e39bc41e54 100644 --- a/cecli/helpers/requests.py +++ b/cecli/helpers/requests.py @@ -32,48 +32,44 @@ def remove_empty_tool_calls(messages): ] +def _process_thought_signature(container): + if "provider_specific_fields" not in container: + container["provider_specific_fields"] = {} + + psf = container["provider_specific_fields"] + + if "thought_signature" not in psf: + if "thought_signatures" in psf: + sigs = psf["thought_signatures"] + if isinstance(sigs, list) and len(sigs) > 0: + psf["thought_signature"] = sigs[0] + elif isinstance(sigs, str): + psf["thought_signature"] = sigs + psf.pop("thought_signatures", None) + + if "thought_signature" not in psf: + psf["thought_signature"] = "skip_thought_signature_validator" + + def thought_signature(model, messages): # Add thought signatures for Vertex AI and Gemini models if model.name.startswith("vertex_ai/") or model.name.startswith("gemini/"): for msg in messages: + # Handle top-level provider_specific_fields + if "provider_specific_fields" in msg or msg.get("role") == "assistant": + _process_thought_signature(msg) + if "tool_calls" in msg: tool_calls = msg["tool_calls"] - if tool_calls: for call in tool_calls: - if not call: - continue - - # Check if thought signature is missing in extra_content.google.thought_signature - if "provider_specific_fields" not in call: - call["provider_specific_fields"] = {} - if "thought_signature" not in call["provider_specific_fields"]: - if "thought_signatures" in call["provider_specific_fields"] and len( - call["provider_specific_fields"]["thought_signatures"] - ): - call["provider_specific_fields"]["thought_signature"] = call[ - "provider_specific_fields" - ]["thought_signatures"][0] - - call["provider_specific_fields"].pop("thought_signatures", None) - else: - call["provider_specific_fields"][ - "thought_signature" - ] = "skip_thought_signature_validator" + if call: + _process_thought_signature(call) if "function_call" in msg: call = msg["function_call"] - - if not call: - continue - - # Check if thought signature is missing in extra_content.google.thought_signature - if "provider_specific_fields" not in call: - call["provider_specific_fields"] = {} - if "thought_signature" not in call["provider_specific_fields"]: - call["provider_specific_fields"][ - "thought_signature" - ] = "skip_thought_signature_validator" + if call: + _process_thought_signature(call) return messages diff --git a/cecli/helpers/skills.py b/cecli/helpers/skills.py index 1f99c79018c..6ad8df3349d 100644 --- a/cecli/helpers/skills.py +++ b/cecli/helpers/skills.py @@ -486,7 +486,7 @@ def get_skills_content(self) -> Optional[str]: if not self._loaded_skills: return None - result = '\n' + result = '\n' result += "## Loaded Skills Content\n\n" result += f"Found {len(self._loaded_skills)} skill(s) in configured directories:\n\n" @@ -557,7 +557,7 @@ def get_skills_context(self) -> Optional[str]: if not summaries: return None - result = '\n' + result = '\n' result += "## Available Skills\n\n" result += f"Found {len(summaries)} skill(s) in configured directories:\n\n" diff --git a/cecli/llm.py b/cecli/llm.py index 7ddd15fcb5e..e177854c9fd 100644 --- a/cecli/llm.py +++ b/cecli/llm.py @@ -61,7 +61,8 @@ def _load_litellm(self): # See: https://github.com/BerriAI/litellm/issues/16518 # See: https://github.com/BerriAI/litellm/issues/14521 try: - from litellm.litellm_core_utils import logging_worker + # Use importlib for lazy loading + logging_worker = 
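`_process_thought_signature` centralizes the Vertex AI/Gemini quirk handling for both tool calls and top-level assistant messages: promote a plural `thought_signatures` entry (list or bare string) to the singular `thought_signature`, drop the plural key, and fall back to the `skip_thought_signature_validator` sentinel. A small runnable illustration of the three cases the helper normalizes (condensed, not a drop-in replacement):

```python
calls = [
    {"provider_specific_fields": {"thought_signatures": ["sig-abc"]}},  # list -> first entry
    {"provider_specific_fields": {"thought_signatures": "sig-xyz"}},    # bare string -> kept as-is
    {},                                                                 # nothing -> sentinel value
]

for call in calls:
    psf = call.setdefault("provider_specific_fields", {})
    sigs = psf.pop("thought_signatures", None)
    if "thought_signature" not in psf:
        if isinstance(sigs, list) and sigs:
            psf["thought_signature"] = sigs[0]
        elif isinstance(sigs, str):
            psf["thought_signature"] = sigs
        else:
            psf["thought_signature"] = "skip_thought_signature_validator"
    print(psf["thought_signature"])
# -> sig-abc, sig-xyz, skip_thought_signature_validator
```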
importlib.import_module("litellm.litellm_core_utils.logging_worker") except ImportError: # Module didn't exist before litellm 1.76.0 # https://github.com/BerriAI/litellm/pull/13905 diff --git a/cecli/mcp/manager.py b/cecli/mcp/manager.py index 6e795da397d..58524f768d5 100644 --- a/cecli/mcp/manager.py +++ b/cecli/mcp/manager.py @@ -1,7 +1,6 @@ import asyncio -from litellm import experimental_mcp_client - +from cecli.llm import litellm from cecli.mcp.server import LocalServer, McpServer from cecli.tools.utils.registry import ToolRegistry @@ -132,7 +131,9 @@ async def connect_server(self, name: str) -> bool: try: session = await server.connect() - tools = await experimental_mcp_client.load_mcp_tools(session=session, format="openai") + tools = await litellm.experimental_mcp_client.load_mcp_tools( + session=session, format="openai" + ) self._server_tools[server.name] = tools self._connected_servers.add(server) self._log_verbose(f"Connected to MCP server: {name}") diff --git a/cecli/mcp/server.py b/cecli/mcp/server.py index e4769f8d0ac..fa1fb46ba8d 100644 --- a/cecli/mcp/server.py +++ b/cecli/mcp/server.py @@ -266,12 +266,30 @@ def _create_transport(self, url, http_client): return streamable_http_client(url, http_client=http_client) -class SseServer(HttpBasedMcpServer): +class SseServer(McpServer): """SSE (Server-Sent Events) MCP server using mcp.client.sse_client.""" - def _create_transport(self, url, http_client): - """Create the SSE transport.""" - return sse_client(url, http_client=http_client) + async def connect(self): + if self.session is not None: + logging.info(f"Using existing session for SSE MCP server: {self.name}") + return self.session + + logging.info(f"Establishing new connection to SSE MCP server: {self.name}") + try: + url = self.config.get("url") + headers = self.config.get("headers", {}) + sse_transport = await self.exit_stack.enter_async_context( + sse_client(url, headers=headers) + ) + read, write = sse_transport + session = await self.exit_stack.enter_async_context(ClientSession(read, write)) + await session.initialize() + self.session = session + return session + except Exception as e: + logging.error(f"Error initializing SSE server {self.name}: {e}") + await self.disconnect() + raise class LocalServer(McpServer): diff --git a/cecli/tools/thinking.py b/cecli/tools/thinking.py index 023bba54608..b434ff59de3 100644 --- a/cecli/tools/thinking.py +++ b/cecli/tools/thinking.py @@ -35,7 +35,7 @@ def execute(cls, coder, content): iterates over tools to ideally help it guide itself to a proper solution """ coder.io.tool_output("🧠 Thoughts recorded in context") - return content + return "🧠 Thoughts recorded in context" @classmethod def format_output(cls, coder, mcp_server, tool_response): diff --git a/cecli/tui/app.py b/cecli/tui/app.py index a0966404b5c..6168c742813 100644 --- a/cecli/tui/app.py +++ b/cecli/tui/app.py @@ -48,6 +48,7 @@ def __init__(self, coder_worker, output_queue, input_queue, args): self._symbols_cache = None self._symbols_files_hash = None self._mouse_hold_timer = None + self._currently_generating = False self.tui_config = self._get_config() @@ -332,7 +333,7 @@ def on_mouse_up(self, event: events.MouseUp) -> None: if self._mouse_hold_timer: self._mouse_hold_timer.stop() self._mouse_hold_timer = None - self.update_key_hints() + self.update_key_hints(generating=self._currently_generating) def _show_select_hint(self) -> None: """Show the shift+drag to select hint.""" @@ -350,9 +351,11 @@ def update_key_hints(self, generating=False): try: hints = 
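The SseServer change drops the shared `HttpBasedMcpServer` transport plumbing and drives `mcp.client.sse.sse_client` directly again, passing per-server headers and wiring the resulting stream pair into a `ClientSession`. A minimal standalone sketch of that connect/initialize sequence, assuming the `mcp` package's client API as used in the hunk and a placeholder server URL:

```python
from contextlib import AsyncExitStack

from mcp import ClientSession
from mcp.client.sse import sse_client


async def connect_sse(url: str, headers: dict | None = None):
    async with AsyncExitStack() as stack:
        # sse_client yields a (read, write) stream pair for the SSE transport
        read, write = await stack.enter_async_context(sse_client(url, headers=headers or {}))
        session = await stack.enter_async_context(ClientSession(read, write))
        await session.initialize()
        tools = await session.list_tools()
        return [tool.name for tool in tools.tools]

# asyncio.run(connect_sse("http://localhost:8000/sse", headers={"Authorization": "Bearer ..."}))
```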
self.query_one(KeyHints) if generating: + self._currently_generating = True stop = self.app.get_keys_for("stop") hints.update_right(f"{stop} to cancel") else: + self._currently_generating = False submit = self.app.get_keys_for("submit") hints.update_right(f"{submit} to submit") except Exception: diff --git a/requirements.txt b/requirements.txt index 3e22dddda23..7c49985e697 100644 --- a/requirements.txt +++ b/requirements.txt @@ -39,10 +39,6 @@ beautifulsoup4==4.14.2 # via # -c requirements/common-constraints.txt # -r requirements/requirements.in -cachetools==6.2.2 - # via - # -c requirements/common-constraints.txt - # google-auth certifi==2025.11.12 # via # -c requirements/common-constraints.txt @@ -113,41 +109,6 @@ gitpython==3.1.45 # via # -c requirements/common-constraints.txt # -r requirements/requirements.in -google-ai-generativelanguage==0.6.15 - # via - # -c requirements/common-constraints.txt - # google-generativeai -google-api-core[grpc]==2.28.1 - # via - # -c requirements/common-constraints.txt - # google-ai-generativelanguage - # google-api-python-client - # google-generativeai -google-api-python-client==2.187.0 - # via - # -c requirements/common-constraints.txt - # google-generativeai -google-auth==2.43.0 - # via - # -c requirements/common-constraints.txt - # google-ai-generativelanguage - # google-api-core - # google-api-python-client - # google-auth-httplib2 - # google-generativeai -google-auth-httplib2==0.2.1 - # via - # -c requirements/common-constraints.txt - # google-api-python-client -google-generativeai==0.8.5 - # via - # -c requirements/common-constraints.txt - # -r requirements/requirements.in -googleapis-common-protos==1.72.0 - # via - # -c requirements/common-constraints.txt - # google-api-core - # grpcio-status grep-ast==0.9.0 # via # -c requirements/common-constraints.txt @@ -155,13 +116,7 @@ grep-ast==0.9.0 grpcio==1.67.1 # via # -c requirements/common-constraints.txt - # google-api-core - # grpcio-status # litellm -grpcio-status==1.67.1 - # via - # -c requirements/common-constraints.txt - # google-api-core h11==0.16.0 # via # -c requirements/common-constraints.txt @@ -175,11 +130,6 @@ httpcore==1.0.9 # via # -c requirements/common-constraints.txt # httpx -httplib2==0.31.0 - # via - # -c requirements/common-constraints.txt - # google-api-python-client - # google-auth-httplib2 httpx==0.28.1 # via # -c requirements/common-constraints.txt @@ -320,20 +270,6 @@ propcache==0.4.1 # -c requirements/common-constraints.txt # aiohttp # yarl -proto-plus==1.26.1 - # via - # -c requirements/common-constraints.txt - # google-ai-generativelanguage - # google-api-core -protobuf==5.29.5 - # via - # -c requirements/common-constraints.txt - # google-ai-generativelanguage - # google-api-core - # google-generativeai - # googleapis-common-protos - # grpcio-status - # proto-plus psutil==7.1.3 # via # -c requirements/common-constraints.txt @@ -342,15 +278,6 @@ ptyprocess==0.7.0 # via # -c requirements/common-constraints.txt # pexpect -pyasn1==0.6.1 - # via - # -c requirements/common-constraints.txt - # pyasn1-modules - # rsa -pyasn1-modules==0.4.2 - # via - # -c requirements/common-constraints.txt - # google-auth pycodestyle==2.14.0 # via # -c requirements/common-constraints.txt @@ -362,7 +289,6 @@ pycparser==2.23 pydantic==2.12.4 # via # -c requirements/common-constraints.txt - # google-generativeai # litellm # mcp # openai @@ -396,10 +322,6 @@ pypandoc==1.16.2 # via # -c requirements/common-constraints.txt # -r requirements/requirements.in -pyparsing==3.2.5 - # via - # -c 
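The tui/app.py hunks add a `_currently_generating` flag so `on_mouse_up` re-renders the key hints with the remembered state; previously any click during streaming reset the right-hand hint to the submit keys even though generation was still in flight. A tiny sketch of that state machine with the Textual widgets stubbed out (class and method names here are illustrative):

```python
class KeyHintState:
    """Remember whether a response is streaming so UI refreshes keep the right hint."""

    def __init__(self):
        self.generating = False

    def update_key_hints(self, generating=False):
        self.generating = generating
        return "cancel" if generating else "submit"

    def on_mouse_up(self):
        # Re-render with the remembered state instead of defaulting to generating=False
        return self.update_key_hints(generating=self.generating)


state = KeyHintState()
state.update_key_hints(generating=True)
assert state.on_mouse_up() == "cancel"
```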
requirements/common-constraints.txt - # httplib2 pyperclip==1.11.0 # via # -c requirements/common-constraints.txt @@ -430,7 +352,6 @@ regex==2025.11.3 requests==2.32.5 # via # -c requirements/common-constraints.txt - # google-api-core # huggingface-hub # tiktoken rich==14.2.0 @@ -443,10 +364,6 @@ rpds-py==0.29.0 # -c requirements/common-constraints.txt # jsonschema # referencing -rsa==4.9.1 - # via - # -c requirements/common-constraints.txt - # google-auth rustworkx==0.17.1 # via # -c requirements/common-constraints.txt @@ -507,7 +424,6 @@ tokenizers==0.22.1 tqdm==4.67.1 # via # -c requirements/common-constraints.txt - # google-generativeai # huggingface-hub # openai # via @@ -540,7 +456,6 @@ typing-extensions==4.15.0 # aiosignal # anyio # beautifulsoup4 - # google-generativeai # huggingface-hub # mcp # openai @@ -560,10 +475,6 @@ uc-micro-py==1.0.3 # via # -c requirements/common-constraints.txt # linkify-it-py -uritemplate==4.2.0 - # via - # -c requirements/common-constraints.txt - # google-api-python-client urllib3==2.5.0 # via # -c requirements/common-constraints.txt diff --git a/requirements/common-constraints.txt b/requirements/common-constraints.txt index cc340d17cf4..902d4066915 100644 --- a/requirements/common-constraints.txt +++ b/requirements/common-constraints.txt @@ -115,36 +115,21 @@ gitdb==4.0.12 # via gitpython gitpython==3.1.45 # via -r requirements/requirements.in -google-ai-generativelanguage==0.6.15 - # via google-generativeai google-api-core[grpc]==2.28.1 # via - # google-ai-generativelanguage - # google-api-python-client # google-cloud-bigquery # google-cloud-core - # google-generativeai -google-api-python-client==2.187.0 - # via google-generativeai google-auth==2.43.0 # via - # google-ai-generativelanguage # google-api-core - # google-api-python-client - # google-auth-httplib2 # google-cloud-bigquery # google-cloud-core - # google-generativeai -google-auth-httplib2==0.2.1 - # via google-api-python-client google-cloud-bigquery==3.38.0 # via -r requirements/requirements-dev.in google-cloud-core==2.5.0 # via google-cloud-bigquery google-crc32c==1.7.1 # via google-resumable-media -google-generativeai==0.8.5 - # via -r requirements/requirements.in google-resumable-media==2.8.0 # via google-cloud-bigquery googleapis-common-protos==1.72.0 @@ -174,10 +159,6 @@ hf-xet==1.2.0 # via huggingface-hub httpcore==1.0.9 # via httpx -httplib2==0.31.0 - # via - # google-api-python-client - # google-auth-httplib2 httpx==0.28.1 # via # litellm @@ -387,14 +368,10 @@ propcache==0.4.1 # aiohttp # yarl proto-plus==1.26.1 - # via - # google-ai-generativelanguage - # google-api-core + # via google-api-core protobuf==5.29.5 # via - # google-ai-generativelanguage # google-api-core - # google-generativeai # googleapis-common-protos # grpcio-status # proto-plus @@ -415,7 +392,6 @@ pycparser==2.23 pydantic==2.12.4 # via # banks - # google-generativeai # litellm # llama-index-core # llama-index-instrumentation @@ -443,9 +419,7 @@ pyjwt[crypto]==2.10.1 pypandoc==1.16.2 # via -r requirements/requirements.in pyparsing==3.2.5 - # via - # httplib2 - # matplotlib + # via matplotlib pyperclip==1.11.0 # via -r requirements/requirements.in pyproject-hooks==1.2.0 @@ -578,7 +552,6 @@ torch==2.9.1 # via sentence-transformers tqdm==4.67.1 # via - # google-generativeai # huggingface-hub # llama-index-core # nltk @@ -611,7 +584,6 @@ typing-extensions==4.15.0 # aiosqlite # anyio # beautifulsoup4 - # google-generativeai # huggingface-hub # llama-index-core # llama-index-workflows @@ -643,8 +615,6 @@ tzdata==2025.2 # 
via pandas uc-micro-py==1.0.3 # via linkify-it-py -uritemplate==4.2.0 - # via google-api-python-client urllib3==2.5.0 # via requests uv==0.9.11 diff --git a/requirements/requirements.in b/requirements/requirements.in index ef14d663bb8..c94e5d2a524 100644 --- a/requirements/requirements.in +++ b/requirements/requirements.in @@ -27,7 +27,6 @@ socksio>=1.0.0 pillow>=11.3.0 shtab>=1.7.2 oslex>=0.1.3 -google-generativeai>=0.8.5 mcp>=1.24.0 textual>=6.0.0 truststore diff --git a/tests/basic/test_reasoning.py b/tests/basic/test_reasoning.py index e5155f8a3b9..1ab67922ce4 100644 --- a/tests/basic/test_reasoning.py +++ b/tests/basic/test_reasoning.py @@ -2,11 +2,10 @@ import textwrap from unittest.mock import MagicMock, patch -import litellm - from cecli.coders.base_coder import Coder from cecli.dump import dump # noqa from cecli.io import InputOutput +from cecli.llm import litellm from cecli.models import Model from cecli.reasoning_tags import ( REASONING_END,
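tests/basic/test_reasoning.py now imports litellm through `from cecli.llm import litellm`, the same lazy accessor production code uses, so test doubles patch the object the coder actually calls. A hedged sketch of stubbing a completion through that shared accessor; whether `completion` is the exact attribute the coder invokes is an assumption:

```python
from unittest.mock import MagicMock, patch

from cecli.llm import litellm  # same lazy accessor used by the production code


def test_completion_is_stubbed():
    fake_response = MagicMock()
    fake_response.choices = [MagicMock(message=MagicMock(content="ok"))]
    # Patching the attribute on the shared object means the test and the coder
    # observe the same stub, which is the point of importing via cecli.llm
    with patch.object(litellm, "completion", return_value=fake_response) as completion:
        resp = litellm.completion(model="gpt-x", messages=[{"role": "user", "content": "hi"}])
    assert resp.choices[0].message.content == "ok"
    completion.assert_called_once()
```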