Skip to content

Commit 8cf2200

Browse files
committed
reduce size of LLM prompts
* truncate text/binary sample data fields to 1024 characters (or smaller if judged to be needed) * truncate entire tables from schema representation if the representation is very large * for latency improvement, cache sample data and schema representation, passing the dbname in both cases to invalidate the cache if changing the db * add separate progress message when generating sample data The target_size values are chosen somewhat arbitrarily. We could also apply final size limits to the prompt string, though meaning-preserving truncation at that point is harder. Addresses #1348.
1 parent ff4ce79 commit 8cf2200

File tree

4 files changed

+67
-28
lines changed

4 files changed

+67
-28
lines changed

changelog.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Upcoming (TBD)
22
==============
33

4+
Features
5+
--------
6+
* Limit size of LLM prompts and cache LLM prompt data.
7+
8+
49
Internal
510
--------
611
* Add mypy to Pull Request template.

mycli/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -797,7 +797,7 @@ def one_iteration(text: str | None = None) -> None:
797797
try:
798798
assert sqlexecute.conn is not None
799799
cur = sqlexecute.conn.cursor()
800-
context, sql, duration = special.handle_llm(text, cur)
800+
context, sql, duration = special.handle_llm(text, cur, self.sqlexecute.dbname)
801801
if context:
802802
click.echo("LLM Response:")
803803
click.echo(context)

mycli/packages/special/llm.py

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import functools
23
import io
34
import logging
45
import os
@@ -7,7 +8,7 @@
78
import shlex
89
import sys
910
from time import time
10-
from typing import Optional, Tuple
11+
from typing import Any, Optional, Tuple
1112

1213
import click
1314
import llm
@@ -159,7 +160,7 @@ def ensure_mycli_template(replace=False):
159160
return
160161

161162

162-
def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
163+
def handle_llm(text, cur, dbname: str) -> Tuple[str, Optional[str], float]:
163164
_, verbosity, arg = parse_special_command(text)
164165
if not arg.strip():
165166
output = [(None, None, None, USAGE)]
@@ -205,7 +206,7 @@ def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
205206
try:
206207
ensure_mycli_template()
207208
start = time()
208-
context, sql = sql_using_llm(cur=cur, question=arg)
209+
context, sql = sql_using_llm(cur=cur, question=arg, dbname=dbname)
209210
end = time()
210211
if verbosity == Verbosity.SUCCINCT:
211212
context = ""
@@ -219,42 +220,75 @@ def is_llm_command(command) -> bool:
219220
return cmd in ("\\llm", "\\ai")
220221

221222

222-
def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]:
223-
if cur is None:
224-
raise RuntimeError("Connect to a database and try again.")
225-
schema_query = """
226-
SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')')
223+
def truncate_list_elements(row: list) -> list:
224+
target_size = 20000
225+
width = 1024
226+
while width >= 0:
227+
truncated_row = [x[:width] if isinstance(x, (str, bytes)) else x for x in row]
228+
if sum(sys.getsizeof(x) for x in truncated_row) <= target_size:
229+
break
230+
width -= 100
231+
return truncated_row
232+
233+
234+
def truncate_table_lines(table: list[str]) -> list[str]:
235+
target_size = 40000
236+
truncated_table = []
237+
running_sum = 0
238+
while table and running_sum <= target_size:
239+
line = table.pop(0)
240+
running_sum += sys.getsizeof(line)
241+
truncated_table.append(line)
242+
return truncated_table
243+
244+
245+
@functools.cache
246+
def get_schema(cur, dbname) -> str:
247+
click.echo("Preparing schema information to feed the LLM")
248+
schema_query = f"""
249+
SELECT CONCAT(table_name, '(', GROUP_CONCAT(column_name, ' ', COLUMN_TYPE SEPARATOR ', '),')') AS schema
227250
FROM information_schema.columns
228-
WHERE table_schema = DATABASE()
251+
WHERE table_schema = '{dbname}'
229252
GROUP BY table_name
230253
ORDER BY table_name
231254
"""
232-
tables_query = "SHOW TABLES"
233-
sample_row_query = "SELECT * FROM `{table}` LIMIT 1"
234-
click.echo("Preparing schema information to feed the llm")
235255
cur.execute(schema_query)
236-
db_schema = "\n".join([row[0] for (row,) in cur.fetchall()])
256+
db_schema = [row[0] for (row,) in cur.fetchall()]
257+
return '\n'.join(truncate_table_lines(db_schema))
258+
259+
260+
@functools.cache
261+
def get_sample_data(cur, dbname) -> dict[str, Any]:
262+
click.echo("Preparing sample data to feed the LLM")
263+
tables_query = "SHOW TABLES"
264+
sample_row_query = "SELECT * FROM `{dbname}`.`{table}` LIMIT 1"
237265
cur.execute(tables_query)
238266
sample_data = {}
239267
for (table_name,) in cur.fetchall():
240268
try:
241-
cur.execute(sample_row_query.format(table=table_name))
269+
cur.execute(sample_row_query.format(dbname=dbname, table=table_name))
242270
except Exception:
243271
continue
244272
cols = [desc[0] for desc in cur.description]
245273
row = cur.fetchone()
246274
if row is None:
247275
continue
248-
sample_data[table_name] = list(zip(cols, row))
276+
sample_data[table_name] = list(zip(cols, truncate_list_elements(row)))
277+
return sample_data
278+
279+
280+
def sql_using_llm(cur, question=None, dbname: str = '') -> Tuple[str, Optional[str]]:
281+
if cur is None:
282+
raise RuntimeError("Connect to a database and try again.")
249283
args = [
250284
"--template",
251285
LLM_TEMPLATE_NAME,
252286
"--param",
253287
"db_schema",
254-
db_schema,
288+
get_schema(cur, dbname),
255289
"--param",
256290
"sample_data",
257-
sample_data,
291+
get_sample_data(cur, dbname),
258292
"--param",
259293
"question",
260294
question,

test/test_llm_special.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test_llm_command_without_args(mock_llm, executor):
2626
assert mock_llm is not None
2727
test_text = r"\llm"
2828
with pytest.raises(FinishIteration) as exc_info:
29-
handle_llm(test_text, executor)
29+
handle_llm(test_text, executor, 'mysql')
3030
# Should return usage message when no args provided
3131
assert exc_info.value.args[0] == [(None, None, None, USAGE)]
3232

@@ -38,7 +38,7 @@ def test_llm_command_with_c_flag(mock_run_cmd, mock_llm, executor):
3838
mock_run_cmd.return_value = (0, "Hello, no SQL today.")
3939
test_text = r"\llm -c 'Something?'"
4040
with pytest.raises(FinishIteration) as exc_info:
41-
handle_llm(test_text, executor)
41+
handle_llm(test_text, executor, 'mysql')
4242
# Expect raw output when no SQL fence found
4343
assert exc_info.value.args[0] == [(None, None, None, "Hello, no SQL today.")]
4444

@@ -51,7 +51,7 @@ def test_llm_command_with_c_flag_and_fenced_sql(mock_run_cmd, mock_llm, executor
5151
fenced = f"Here you go:\n```sql\n{sql_text}\n```"
5252
mock_run_cmd.return_value = (0, fenced)
5353
test_text = r"\llm -c 'Rewrite SQL'"
54-
result, sql, duration = handle_llm(test_text, executor)
54+
result, sql, duration = handle_llm(test_text, executor, 'mysql')
5555
# Without verbose, result is empty, sql extracted
5656
assert sql == sql_text
5757
assert result == ""
@@ -64,7 +64,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
6464
# 'models' is a known subcommand
6565
test_text = r"\llm models"
6666
with pytest.raises(FinishIteration) as exc_info:
67-
handle_llm(test_text, executor)
67+
handle_llm(test_text, executor, 'mysql')
6868
mock_run_cmd.assert_called_once_with("llm", "models", restart_cli=False)
6969
assert exc_info.value.args[0] is None
7070

@@ -74,7 +74,7 @@ def test_llm_command_known_subcommand(mock_run_cmd, mock_llm, executor):
7474
def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
7575
test_text = r"\llm --help"
7676
with pytest.raises(FinishIteration) as exc_info:
77-
handle_llm(test_text, executor)
77+
handle_llm(test_text, executor, 'mysql')
7878
mock_run_cmd.assert_called_once_with("llm", "--help", restart_cli=False)
7979
assert exc_info.value.args[0] is None
8080

@@ -84,7 +84,7 @@ def test_llm_command_with_help_flag(mock_run_cmd, mock_llm, executor):
8484
def test_llm_command_with_install_flag(mock_run_cmd, mock_llm, executor):
8585
test_text = r"\llm install openai"
8686
with pytest.raises(FinishIteration) as exc_info:
87-
handle_llm(test_text, executor)
87+
handle_llm(test_text, executor, 'mysql')
8888
mock_run_cmd.assert_called_once_with("llm", "install", "openai", restart_cli=True)
8989
assert exc_info.value.args[0] is None
9090

@@ -98,7 +98,7 @@ def test_llm_command_with_prompt(mock_sql_using_llm, mock_ensure_template, mock_
9898
"""
9999
mock_sql_using_llm.return_value = ("CTX", "SELECT 1;")
100100
test_text = r"\llm prompt 'Test?'"
101-
context, sql, duration = handle_llm(test_text, executor)
101+
context, sql, duration = handle_llm(test_text, executor, 'mysql')
102102
mock_ensure_template.assert_called_once()
103103
mock_sql_using_llm.assert_called()
104104
assert context == "CTX"
@@ -115,7 +115,7 @@ def test_llm_command_question_with_context(mock_sql_using_llm, mock_ensure_templ
115115
"""
116116
mock_sql_using_llm.return_value = ("CTX2", "SELECT 2;")
117117
test_text = r"\llm 'Top 10?'"
118-
context, sql, duration = handle_llm(test_text, executor)
118+
context, sql, duration = handle_llm(test_text, executor, 'mysql')
119119
mock_ensure_template.assert_called_once()
120120
mock_sql_using_llm.assert_called()
121121
assert context == "CTX2"
@@ -132,7 +132,7 @@ def test_llm_command_question_verbose(mock_sql_using_llm, mock_ensure_template,
132132
"""
133133
mock_sql_using_llm.return_value = ("NO_CTX", "SELECT 42;")
134134
test_text = r"\llm- 'Succinct?'"
135-
context, sql, duration = handle_llm(test_text, executor)
135+
context, sql, duration = handle_llm(test_text, executor, 'mysql')
136136
assert context == ""
137137
assert sql == "SELECT 42;"
138138
assert isinstance(duration, float)
@@ -194,5 +194,5 @@ def test_handle_llm_aliases_without_args(prefix, executor, monkeypatch):
194194

195195
monkeypatch.setattr(llm_module, "llm", object())
196196
with pytest.raises(FinishIteration) as exc_info:
197-
handle_llm(prefix, executor)
197+
handle_llm(prefix, executor, 'mysql')
198198
assert exc_info.value.args[0] == [(None, None, None, USAGE)]

0 commit comments

Comments
 (0)