Skip to content

Commit a9b8e6e

Browse files
committed
Add llm support.
1 parent 8e1f5db commit a9b8e6e

4 files changed

Lines changed: 525 additions & 8 deletions

File tree

pgspecial/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ def export(defn):
1111
return defn
1212

1313

14-
from . import dbcommands, iocommands # noqa
14+
from . import dbcommands, iocommands, llm # noqa

pgspecial/llm.py

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
import contextlib
2+
import io
3+
import logging
4+
import os
5+
import re
6+
from runpy import run_module
7+
import shlex
8+
import sys
9+
from time import time
10+
from typing import Optional, Tuple
11+
from . import export
12+
13+
import click
14+
15+
try:
16+
import llm # type: ignore
17+
from llm.cli import cli # type: ignore
18+
except Exception: # pragma: no cover - llm may not be installed in all envs
19+
llm = None
20+
cli = None
21+
22+
from pgspecial.main import parse_special_command, Verbosity
23+
24+
log = logging.getLogger(__name__)
25+
26+
27+
def _safe_models(): # pragma: no cover - used when llm is installed
28+
try:
29+
return {x.model_id: None for x in llm.get_models()} if llm else {}
30+
except Exception:
31+
return {}
32+
33+
34+
LLM_CLI_COMMANDS = list(cli.commands.keys()) if cli else []
35+
MODELS = _safe_models()
36+
LLM_TEMPLATE_NAME = "pgspecial-llm-template"
37+
38+
39+
def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True):
40+
original_exe = sys.executable
41+
original_args = sys.argv
42+
try:
43+
sys.argv = [cmd] + list(args)
44+
code = 0
45+
if capture_output:
46+
buffer = io.StringIO()
47+
redirect = contextlib.ExitStack()
48+
redirect.enter_context(contextlib.redirect_stdout(buffer))
49+
redirect.enter_context(contextlib.redirect_stderr(buffer))
50+
else:
51+
redirect = contextlib.nullcontext()
52+
with redirect:
53+
try:
54+
run_module(cmd, run_name="__main__")
55+
except SystemExit as e:
56+
code = e.code
57+
if code != 0 and raise_exception:
58+
if capture_output:
59+
raise RuntimeError(buffer.getvalue())
60+
else:
61+
raise RuntimeError(f"Command {cmd} failed with exit code {code}.")
62+
except Exception as e:
63+
code = 1
64+
if raise_exception:
65+
if capture_output:
66+
raise RuntimeError(buffer.getvalue())
67+
else:
68+
raise RuntimeError(f"Command {cmd} failed: {e}")
69+
if restart_cli and code == 0:
70+
os.execv(original_exe, [original_exe] + original_args)
71+
if capture_output:
72+
return code, buffer.getvalue()
73+
else:
74+
return code, ""
75+
finally:
76+
sys.argv = original_args
77+
78+
79+
def build_command_tree(cmd): # pragma: no cover - not used in tests directly
80+
tree = {}
81+
if cmd and isinstance(getattr(cmd, "commands", None), dict):
82+
for name, subcmd in cmd.commands.items():
83+
if getattr(cmd, "name", None) == "models" and name == "default":
84+
tree[name] = MODELS
85+
else:
86+
tree[name] = build_command_tree(subcmd)
87+
else:
88+
tree = None
89+
return tree
90+
91+
92+
COMMAND_TREE = build_command_tree(cli) if cli else {}
93+
94+
95+
def get_completions(tokens, tree=COMMAND_TREE): # pragma: no cover - helper
96+
for token in tokens:
97+
if token.startswith("-"):
98+
continue
99+
if tree and token in tree:
100+
tree = tree[token]
101+
else:
102+
return []
103+
return list(tree.keys()) if tree else []
104+
105+
106+
@export
107+
class FinishIteration(Exception):
108+
def __init__(self, results=None):
109+
self.results = results
110+
111+
112+
USAGE = """
113+
Use an LLM to create SQL queries to answer questions from your database.
114+
Examples:
115+
116+
# Ask a question.
117+
> \\llm 'Most visited urls?'
118+
119+
# List available models
120+
> \\llm models
121+
> gpt-4o
122+
> gpt-3.5-turbo
123+
124+
# Change default model
125+
> \\llm models default llama3
126+
127+
# Set api key (not required for local models)
128+
> \\llm keys set openai
129+
130+
# Install a model plugin
131+
> \\llm install llm-ollama
132+
> llm-ollama installed.
133+
134+
# Plugins directory
135+
# https://llm.datasette.io/en/stable/plugins/directory.html
136+
"""
137+
138+
_SQL_CODE_FENCE = r"```sql\n(.*?)\n```"
139+
140+
PROMPT = """
141+
You are a helpful assistant who is a PostgreSQL expert. You are embedded in a
142+
psql-like cli tool called pgcli.
143+
144+
Answer this question:
145+
146+
$question
147+
148+
Use the following context if it is relevant to answering the question. If the
149+
question is not about the current database then ignore the context.
150+
151+
You are connected to a PostgreSQL database with the following schema:
152+
153+
$db_schema
154+
155+
Here is a sample row of data from each table:
156+
157+
$sample_data
158+
159+
If the answer can be found using a SQL query, include a sql query in a code
160+
fence such as this one:
161+
162+
```sql
163+
SELECT count(*) FROM table_name;
164+
```
165+
Keep your explanation concise and focused on the question asked.
166+
"""
167+
168+
169+
def ensure_pgspecial_template(replace=False):
170+
if not replace:
171+
code, _ = run_external_cmd("llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False)
172+
if code == 0:
173+
return
174+
run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME)
175+
return
176+
177+
178+
@export
179+
def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
180+
_, verbosity, arg = parse_special_command(text)
181+
if not arg.strip():
182+
output = USAGE
183+
raise FinishIteration(output)
184+
185+
parts = shlex.split(arg)
186+
restart = False
187+
if "-c" in parts:
188+
capture_output = True
189+
use_context = False
190+
elif "prompt" in parts:
191+
capture_output = True
192+
use_context = True
193+
elif "install" in parts or "uninstall" in parts:
194+
capture_output = False
195+
use_context = False
196+
restart = True
197+
elif parts and parts[0] in LLM_CLI_COMMANDS:
198+
capture_output = False
199+
use_context = False
200+
elif parts and parts[0] == "--help":
201+
capture_output = False
202+
use_context = False
203+
else:
204+
capture_output = True
205+
use_context = True
206+
207+
if not use_context:
208+
args = parts
209+
if capture_output:
210+
click.echo("Calling llm command")
211+
start = time()
212+
_, result = run_external_cmd("llm", *args, capture_output=capture_output)
213+
end = time()
214+
match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
215+
if match:
216+
sql = match.group(1).strip()
217+
else:
218+
output = result
219+
raise FinishIteration(output)
220+
return (result if verbosity == Verbosity.SUCCINCT else "", sql, end - start)
221+
else:
222+
run_external_cmd("llm", *args, restart_cli=restart)
223+
raise FinishIteration(None)
224+
225+
try:
226+
ensure_pgspecial_template()
227+
start = time()
228+
context, sql = sql_using_llm(cur=cur, question=arg)
229+
end = time()
230+
if verbosity == Verbosity.SUCCINCT:
231+
context = ""
232+
return (context, sql, end - start)
233+
except Exception as e:
234+
raise RuntimeError(e)
235+
236+
237+
@export
238+
def is_llm_command(command) -> bool:
239+
cmd, _, _ = parse_special_command(command)
240+
return cmd in ("\\llm", "\\ai")
241+
242+
243+
def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]:
244+
if cur is None:
245+
raise RuntimeError("Connect to a database and try again.")
246+
247+
schema_sql = """
248+
SELECT
249+
table_schema,
250+
table_name,
251+
string_agg(column_name || ' ' || data_type, ', ' ORDER BY ordinal_position) AS cols
252+
FROM information_schema.columns
253+
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
254+
GROUP BY table_schema, table_name
255+
ORDER BY table_schema, table_name
256+
"""
257+
tables_sql = """
258+
SELECT table_schema, table_name
259+
FROM information_schema.tables
260+
WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
261+
AND table_type IN ('BASE TABLE', 'VIEW')
262+
ORDER BY table_schema, table_name
263+
"""
264+
sample_row_tmpl = 'SELECT * FROM "{schema}"."{table}" LIMIT 1'
265+
266+
click.echo("Preparing schema information to feed the llm")
267+
cur.execute(schema_sql)
268+
db_schema = []
269+
for row in cur.fetchall():
270+
# Support both tuple results and dict-like rows
271+
if isinstance(row, (list, tuple)):
272+
schema, table, cols = row
273+
else:
274+
schema, table, cols = row["table_schema"], row["table_name"], row["cols"]
275+
db_schema.append(f"{schema}.{table}({cols})")
276+
db_schema = "\n".join(db_schema)
277+
278+
cur.execute(tables_sql)
279+
sample_data = {}
280+
for row in cur.fetchall():
281+
if isinstance(row, (list, tuple)):
282+
schema, table = row
283+
else:
284+
schema, table = row["table_schema"], row["table_name"]
285+
try:
286+
cur.execute(sample_row_tmpl.format(schema=schema, table=table))
287+
except Exception:
288+
continue
289+
cols = [desc[0] for desc in getattr(cur, "description", [])]
290+
one = getattr(cur, "fetchone", lambda: None)()
291+
if not one:
292+
continue
293+
sample_data[f"{schema}.{table}"] = list(zip(cols, one))
294+
295+
args = [
296+
"--template",
297+
LLM_TEMPLATE_NAME,
298+
"--param",
299+
"db_schema",
300+
db_schema,
301+
"--param",
302+
"sample_data",
303+
sample_data,
304+
"--param",
305+
"question",
306+
question,
307+
" ",
308+
]
309+
click.echo("Invoking llm command with schema information")
310+
_, result = run_external_cmd("llm", *args, capture_output=True)
311+
match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
312+
if match:
313+
sql = match.group(1).strip()
314+
else:
315+
sql = ""
316+
return (result, sql)

pgspecial/main.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,16 @@
33
import logging
44
from collections import namedtuple
55

6-
from . import export
76
from .help.commands import helpcommands
7+
from . import export
8+
from enum import Enum
9+
10+
11+
class Verbosity(Enum):
12+
SUCCINCT = "succinct"
13+
NORMAL = "normal"
14+
VERBOSE = "verbose"
15+
816

917
log = logging.getLogger(__name__)
1018

@@ -96,7 +104,7 @@ def register(self, *args, **kwargs):
96104

97105
def execute(self, cur, sql):
98106
commands = self.commands
99-
command, verbose, pattern = parse_special_command(sql)
107+
command, verbosity, pattern = parse_special_command(sql)
100108

101109
if (command not in commands) and (command.lower() not in commands):
102110
raise CommandNotFound
@@ -111,7 +119,8 @@ def execute(self, cur, sql):
111119
if special_cmd.arg_type == NO_QUERY:
112120
return special_cmd.handler()
113121
elif special_cmd.arg_type == PARSED_QUERY:
114-
return special_cmd.handler(cur=cur, pattern=pattern, verbose=verbose)
122+
# Keep existing handlers working: convert Verbosity -> bool
123+
return special_cmd.handler(cur=cur, pattern=pattern, verbose=(verbosity == Verbosity.VERBOSE))
115124
elif special_cmd.arg_type == RAW_QUERY:
116125
return special_cmd.handler(cur=cur, query=sql)
117126

@@ -225,10 +234,14 @@ def content_exceeds_width(row, width):
225234
@export
226235
def parse_special_command(sql):
227236
command, _, arg = sql.partition(" ")
228-
verbose = "+" in command
229-
230-
command = command.strip().replace("+", "")
231-
return (command, verbose, arg.strip())
237+
verbosity = Verbosity.NORMAL
238+
if "+" in command:
239+
verbosity = Verbosity.VERBOSE
240+
elif "-" in command:
241+
verbosity = Verbosity.SUCCINCT
242+
243+
command = command.strip().strip("+-")
244+
return (command, verbosity, arg.strip())
232245

233246

234247
def show_extra_help_command(command, syntax, description):

0 commit comments

Comments
 (0)