diff --git a/changelog.md b/changelog.md index 16cb6f20..51ae71a4 100644 --- a/changelog.md +++ b/changelog.md @@ -17,6 +17,7 @@ Internal * Enable flake8-bugbear lint rules. * Fix flaky editor-command tests in CI. * Require release format of `changelog.md` when making a release. +* Improve type annotations on LLM driver. 1.40.0 (2025/10/14) diff --git a/mycli/packages/special/llm.py b/mycli/packages/special/llm.py index ce7e2ae1..d19b8c41 100644 --- a/mycli/packages/special/llm.py +++ b/mycli/packages/special/llm.py @@ -8,7 +8,7 @@ import shlex import sys from time import time -from typing import Optional, Tuple +from typing import Any import click @@ -30,6 +30,7 @@ LLM_CLI_IMPORTED = False except ImportError: LLM_CLI_IMPORTED = False +from pymysql.cursors import Cursor from mycli.packages.special.main import Verbosity, parse_special_command @@ -38,7 +39,13 @@ LLM_TEMPLATE_NAME = "mycli-llm-template" -def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True): +def run_external_cmd( + cmd: str, + *args, + capture_output=False, + restart_cli=False, + raise_exception=True, +) -> tuple[int, str]: original_exe = sys.executable original_args = sys.argv try: @@ -46,7 +53,8 @@ def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_ code = 0 if capture_output: buffer = io.StringIO() - redirect = contextlib.ExitStack() + redirect: contextlib.ExitStack[bool | None] | contextlib.nullcontext[None] = contextlib.ExitStack() + assert isinstance(redirect, contextlib.ExitStack) redirect.enter_context(contextlib.redirect_stdout(buffer)) redirect.enter_context(contextlib.redirect_stderr(buffer)) else: @@ -55,7 +63,7 @@ def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_ try: run_module(cmd, run_name="__main__") except SystemExit as e: - code = e.code + code = int(e.code or 0) if code != 0 and raise_exception: if capture_output: raise RuntimeError(buffer.getvalue()) from e @@ -76,24 +84,33 @@ def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_ sys.argv = original_args -def build_command_tree(cmd): - tree = {} +def _build_command_tree(cmd) -> dict[str, Any] | None: + tree: dict[str, Any] | None = {} + assert isinstance(tree, dict) if isinstance(cmd, click.Group): for name, subcmd in cmd.commands.items(): if cmd.name == "models" and name == "default": tree[name] = {x.model_id: None for x in llm.get_models()} else: - tree[name] = build_command_tree(subcmd) + tree[name] = _build_command_tree(subcmd) else: tree = None return tree +def build_command_tree(cmd) -> dict[str, Any]: + return _build_command_tree(cmd) or {} + + # Generate the command tree for autocompletion COMMAND_TREE = build_command_tree(cli) if LLM_CLI_IMPORTED is True else {} -def get_completions(tokens, tree=COMMAND_TREE): +def get_completions( + tokens: list[str], + tree: dict[str, Any] | None = None, +) -> list[str]: + tree = tree or COMMAND_TREE for token in tokens: if token.startswith("-"): continue @@ -182,13 +199,12 @@ def __init__(self, results=None): """ -def ensure_mycli_template(replace=False): +def ensure_mycli_template(replace: bool = False) -> None: if not replace: code, _ = run_external_cmd("llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False) if code == 0: return run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME) - return @functools.cache @@ -196,7 +212,7 @@ def cli_commands() -> list[str]: return list(cli.commands.keys()) -def handle_llm(text, cur) -> Tuple[str, Optional[str], float]: +def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]: _, verbosity, arg = parse_special_command(text) if not LLM_IMPORTED: output = [(None, None, None, NEED_DEPENDENCIES)] @@ -254,12 +270,15 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str], float]: raise RuntimeError(e) from e -def is_llm_command(command) -> bool: +def is_llm_command(command: str) -> bool: cmd, _, _ = parse_special_command(command) return cmd in ("\\llm", "\\ai") -def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]: +def sql_using_llm( + cur: Cursor | None, + question: str | None = None, +) -> tuple[str, str | None]: if cur is None: raise RuntimeError("Connect to a database and try again.") schema_query = """