Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Internal
* Enable flake8-bugbear lint rules.
* Fix flaky editor-command tests in CI.
* Require release format of `changelog.md` when making a release.
* Improve type annotations on LLM driver.


1.40.0 (2025/10/14)
Expand Down
45 changes: 32 additions & 13 deletions mycli/packages/special/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import shlex
import sys
from time import time
from typing import Optional, Tuple
from typing import Any

import click

Expand All @@ -30,6 +30,7 @@
LLM_CLI_IMPORTED = False
except ImportError:
LLM_CLI_IMPORTED = False
from pymysql.cursors import Cursor

from mycli.packages.special.main import Verbosity, parse_special_command

Expand All @@ -38,15 +39,22 @@
LLM_TEMPLATE_NAME = "mycli-llm-template"


def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True):
def run_external_cmd(
cmd: str,
*args,
capture_output=False,
restart_cli=False,
raise_exception=True,
) -> tuple[int, str]:
original_exe = sys.executable
original_args = sys.argv
try:
sys.argv = [cmd] + list(args)
code = 0
if capture_output:
buffer = io.StringIO()
redirect = contextlib.ExitStack()
redirect: contextlib.ExitStack[bool | None] | contextlib.nullcontext[None] = contextlib.ExitStack()
assert isinstance(redirect, contextlib.ExitStack)
redirect.enter_context(contextlib.redirect_stdout(buffer))
redirect.enter_context(contextlib.redirect_stderr(buffer))
else:
Expand All @@ -55,7 +63,7 @@ def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_
try:
run_module(cmd, run_name="__main__")
except SystemExit as e:
code = e.code
code = int(e.code or 0)
if code != 0 and raise_exception:
if capture_output:
raise RuntimeError(buffer.getvalue()) from e
Expand All @@ -76,24 +84,33 @@ def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_
sys.argv = original_args


def build_command_tree(cmd):
tree = {}
def _build_command_tree(cmd) -> dict[str, Any] | None:
tree: dict[str, Any] | None = {}
assert isinstance(tree, dict)
if isinstance(cmd, click.Group):
for name, subcmd in cmd.commands.items():
if cmd.name == "models" and name == "default":
tree[name] = {x.model_id: None for x in llm.get_models()}
else:
tree[name] = build_command_tree(subcmd)
tree[name] = _build_command_tree(subcmd)
else:
tree = None
return tree


def build_command_tree(cmd) -> dict[str, Any]:
return _build_command_tree(cmd) or {}


# Generate the command tree for autocompletion
COMMAND_TREE = build_command_tree(cli) if LLM_CLI_IMPORTED is True else {}


def get_completions(tokens, tree=COMMAND_TREE):
def get_completions(
tokens: list[str],
tree: dict[str, Any] | None = None,
) -> list[str]:
tree = tree or COMMAND_TREE
for token in tokens:
if token.startswith("-"):
continue
Expand Down Expand Up @@ -182,21 +199,20 @@ def __init__(self, results=None):
"""


def ensure_mycli_template(replace=False):
def ensure_mycli_template(replace: bool = False) -> None:
if not replace:
code, _ = run_external_cmd("llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False)
if code == 0:
return
run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME)
return


@functools.cache
def cli_commands() -> list[str]:
return list(cli.commands.keys())


def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
def handle_llm(text: str, cur: Cursor) -> tuple[str, str | None, float]:
_, verbosity, arg = parse_special_command(text)
if not LLM_IMPORTED:
output = [(None, None, None, NEED_DEPENDENCIES)]
Expand Down Expand Up @@ -254,12 +270,15 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
raise RuntimeError(e) from e


def is_llm_command(command) -> bool:
def is_llm_command(command: str) -> bool:
cmd, _, _ = parse_special_command(command)
return cmd in ("\\llm", "\\ai")


def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]:
def sql_using_llm(
cur: Cursor | None,
question: str | None = None,
) -> tuple[str, str | None]:
if cur is None:
raise RuntimeError("Connect to a database and try again.")
schema_query = """
Expand Down