|
| 1 | +import contextlib |
| 2 | +import io |
| 3 | +import logging |
| 4 | +import os |
| 5 | +import re |
| 6 | +from runpy import run_module |
| 7 | +import shlex |
| 8 | +import sys |
| 9 | +from time import time |
| 10 | +from typing import Optional, Tuple |
| 11 | +from . import export |
| 12 | + |
| 13 | +import click |
| 14 | + |
| 15 | +try: |
| 16 | + import llm # type: ignore |
| 17 | + from llm.cli import cli # type: ignore |
| 18 | +except Exception: # pragma: no cover - llm may not be installed in all envs |
| 19 | + llm = None |
| 20 | + cli = None |
| 21 | + |
| 22 | +from pgspecial.main import parse_special_command, Verbosity |
| 23 | + |
| 24 | +log = logging.getLogger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +def _safe_models(): # pragma: no cover - used when llm is installed |
| 28 | + try: |
| 29 | + return {x.model_id: None for x in llm.get_models()} if llm else {} |
| 30 | + except Exception: |
| 31 | + return {} |
| 32 | + |
| 33 | + |
# Top-level llm CLI subcommand names (empty when llm is not installed).
LLM_CLI_COMMANDS = list(cli.commands.keys()) if cli else []
# Installed model ids, mapped to None for use as a completion leaf.
MODELS = _safe_models()
# Name under which the pgspecial prompt template is saved in llm.
LLM_TEMPLATE_NAME = "pgspecial-llm-template"
| 37 | + |
| 38 | + |
def run_external_cmd(cmd, *args, capture_output=False, restart_cli=False, raise_exception=True):
    """Run a Python module in-process as if invoked ``python -m cmd args...``.

    Args:
        cmd: module name to execute via runpy.
        *args: command-line arguments placed into ``sys.argv``.
        capture_output: redirect the command's stdout/stderr into a buffer.
        restart_cli: on success, re-exec the current interpreter (used after
            plugin install/uninstall so new plugins are picked up).
        raise_exception: raise RuntimeError on failure instead of returning.

    Returns:
        Tuple ``(exit_code, output)``; ``output`` is ``""`` unless
        ``capture_output`` is true.

    Raises:
        RuntimeError: the command exited non-zero or raised, and
            ``raise_exception`` is true.
    """
    original_exe = sys.executable
    original_args = sys.argv
    try:
        sys.argv = [cmd] + list(args)
        code = 0
        if capture_output:
            buffer = io.StringIO()
            redirect = contextlib.ExitStack()
            redirect.enter_context(contextlib.redirect_stdout(buffer))
            redirect.enter_context(contextlib.redirect_stderr(buffer))
        else:
            buffer = None
            redirect = contextlib.nullcontext()
        # Inner try so that our own RuntimeError below is raised exactly once
        # instead of being re-caught and double-wrapped by a broad handler.
        try:
            with redirect:
                try:
                    run_module(cmd, run_name="__main__")
                except SystemExit as e:
                    # Bare sys.exit()/sys.exit(None) signals success, not failure.
                    code = 0 if e.code is None else e.code
        except Exception as e:
            code = 1
            if raise_exception:
                if capture_output:
                    raise RuntimeError(buffer.getvalue()) from e
                raise RuntimeError(f"Command {cmd} failed: {e}") from e
        if code != 0 and raise_exception:
            if capture_output:
                raise RuntimeError(buffer.getvalue())
            raise RuntimeError(f"Command {cmd} failed with exit code {code}.")
        if restart_cli and code == 0:
            # Replace the current process so freshly installed plugins load.
            os.execv(original_exe, [original_exe] + original_args)
        return code, buffer.getvalue() if capture_output else ""
    finally:
        sys.argv = original_args
| 77 | + |
| 78 | + |
def build_command_tree(cmd):  # pragma: no cover - not used in tests directly
    """Recursively map a click command group to nested dicts of subcommand names.

    Leaf commands map to None; the ``models default`` node is seeded with the
    installed model ids so they can be completed.
    """
    subcommands = getattr(cmd, "commands", None) if cmd else None
    if not isinstance(subcommands, dict):
        return None
    tree = {}
    for name, subcmd in subcommands.items():
        if name == "default" and getattr(cmd, "name", None) == "models":
            tree[name] = MODELS
        else:
            tree[name] = build_command_tree(subcmd)
    return tree
| 90 | + |
| 91 | + |
# Nested completion tree for the llm CLI ({} when llm is not installed).
COMMAND_TREE = build_command_tree(cli) if cli else {}
| 93 | + |
| 94 | + |
def get_completions(tokens, tree=None):  # pragma: no cover - helper
    """Walk the completion tree along *tokens* and return the next candidates.

    Flag tokens (starting with ``-``) are skipped.  ``tree`` defaults to the
    module-level COMMAND_TREE; a None sentinel is used so the global is not
    evaluated at function-definition time.

    Returns:
        List of completion strings; [] for an unknown path or a leaf node.
    """
    if tree is None:
        tree = COMMAND_TREE
    for token in tokens:
        if token.startswith("-"):
            continue
        if tree and token in tree:
            tree = tree[token]
        else:
            return []
    return list(tree.keys()) if tree else []
| 104 | + |
| 105 | + |
@export
class FinishIteration(Exception):
    """Control-flow exception: stop \\llm command handling and surface ``results``.

    ``results`` (optional) holds the output to display; None means nothing.
    """

    def __init__(self, results=None):
        self.results = results
| 110 | + |
| 111 | + |
# Help text printed when \llm is invoked with no arguments (see handle_llm).
USAGE = """
Use an LLM to create SQL queries to answer questions from your database.
Examples:

# Ask a question.
> \\llm 'Most visited urls?'

# List available models
> \\llm models
> gpt-4o
> gpt-3.5-turbo

# Change default model
> \\llm models default llama3

# Set api key (not required for local models)
> \\llm keys set openai

# Install a model plugin
> \\llm install llm-ollama
> llm-ollama installed.

# Plugins directory
# https://llm.datasette.io/en/stable/plugins/directory.html
"""
| 137 | + |
# Regex that captures the SQL inside a ```sql ... ``` fenced block of the
# model's response.
_SQL_CODE_FENCE = r"```sql\n(.*?)\n```"

# llm prompt template; $question, $db_schema and $sample_data are filled in
# via ``llm --param`` (see sql_using_llm).
PROMPT = """
You are a helpful assistant who is a PostgreSQL expert. You are embedded in a
psql-like cli tool called pgcli.

Answer this question:

$question

Use the following context if it is relevant to answering the question. If the
question is not about the current database then ignore the context.

You are connected to a PostgreSQL database with the following schema:

$db_schema

Here is a sample row of data from each table:

$sample_data

If the answer can be found using a SQL query, include a sql query in a code
fence such as this one:

```sql
SELECT count(*) FROM table_name;
```
Keep your explanation concise and focused on the question asked.
"""
| 167 | + |
| 168 | + |
def ensure_pgspecial_template(replace=False):
    """Save the pgspecial prompt template into llm, creating it if absent.

    With ``replace=True`` the template is rewritten unconditionally; otherwise
    it is created only when ``llm templates show`` reports it missing.
    """
    if replace:
        run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME)
        return
    code, _ = run_external_cmd(
        "llm", "templates", "show", LLM_TEMPLATE_NAME, capture_output=True, raise_exception=False
    )
    if code != 0:
        run_external_cmd("llm", PROMPT, "--save", LLM_TEMPLATE_NAME)
| 176 | + |
| 177 | + |
@export
def handle_llm(text, cur) -> Tuple[str, Optional[str], float]:
    """Handle a ``\\llm`` special command.

    Dispatches between passthrough llm CLI invocations and schema-aware
    questions about the connected database.

    Args:
        text: the full special-command text (e.g. ``\\llm 'top urls?'``).
        cur: an open database cursor (may be None for non-context commands).

    Returns:
        Tuple ``(context, sql, duration_seconds)``; ``context`` is "" when
        verbosity is succinct.

    Raises:
        FinishIteration: when there is output (or nothing) to show and no SQL.
        RuntimeError: when the llm invocation fails.
    """
    _, verbosity, arg = parse_special_command(text)
    if not arg.strip():
        # No arguments: show usage help.
        raise FinishIteration(USAGE)

    parts = shlex.split(arg)
    restart = False
    # Decide whether to capture llm's output and whether to add DB context.
    if "-c" in parts:
        capture_output = True
        use_context = False
    elif "prompt" in parts:
        capture_output = True
        use_context = True
    elif "install" in parts or "uninstall" in parts:
        capture_output = False
        use_context = False
        restart = True  # plugins change the environment; re-exec the CLI
    elif parts and parts[0] in LLM_CLI_COMMANDS:
        capture_output = False
        use_context = False
    elif parts and parts[0] == "--help":
        capture_output = False
        use_context = False
    else:
        # Free-form question: route through the schema-aware template below.
        capture_output = True
        use_context = True

    if not use_context:
        args = parts
        if capture_output:
            click.echo("Calling llm command")
            start = time()
            _, result = run_external_cmd("llm", *args, capture_output=capture_output)
            end = time()
            match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
            if not match:
                raise FinishIteration(result)
            sql = match.group(1).strip()
            # Succinct verbosity suppresses the explanatory text, matching
            # the context handling in the schema-aware branch below.
            return ("" if verbosity == Verbosity.SUCCINCT else result, sql, end - start)
        else:
            run_external_cmd("llm", *args, restart_cli=restart)
            raise FinishIteration(None)

    try:
        ensure_pgspecial_template()
        start = time()
        context, sql = sql_using_llm(cur=cur, question=arg)
        end = time()
        if verbosity == Verbosity.SUCCINCT:
            context = ""
        return (context, sql, end - start)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(e) from e
| 235 | + |
| 236 | + |
@export
def is_llm_command(command) -> bool:
    """Return True when *command* is the \\llm (or its alias \\ai) special command."""
    name, _, _ = parse_special_command(command)
    return name in {"\\llm", "\\ai"}
| 241 | + |
| 242 | + |
def sql_using_llm(cur, question=None) -> Tuple[str, Optional[str]]:
    """Ask the llm CLI a question, feeding it the current database schema.

    Gathers table/column definitions and one sample row per table from the
    open cursor, renders them into the saved pgspecial template, and invokes
    ``llm``.

    Args:
        cur: an open database cursor.
        question: the user's natural-language question.

    Returns:
        Tuple ``(full_response, extracted_sql)``; ``extracted_sql`` is ""
        when the response contains no ```sql fence.

    Raises:
        RuntimeError: when no cursor is available.
    """
    if cur is None:
        raise RuntimeError("Connect to a database and try again.")

    schema_sql = """
    SELECT
        table_schema,
        table_name,
        string_agg(column_name || ' ' || data_type, ', ' ORDER BY ordinal_position) AS cols
    FROM information_schema.columns
    WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
    GROUP BY table_schema, table_name
    ORDER BY table_schema, table_name
    """
    tables_sql = """
    SELECT table_schema, table_name
    FROM information_schema.tables
    WHERE table_schema NOT IN ('pg_catalog', 'information_schema')
    AND table_type IN ('BASE TABLE', 'VIEW')
    ORDER BY table_schema, table_name
    """
    sample_row_tmpl = 'SELECT * FROM "{schema}"."{table}" LIMIT 1'

    click.echo("Preparing schema information to feed the llm")
    cur.execute(schema_sql)
    db_schema = []
    for row in cur.fetchall():
        # Support both tuple results and dict-like rows
        if isinstance(row, (list, tuple)):
            schema, table, cols = row
        else:
            schema, table, cols = row["table_schema"], row["table_name"], row["cols"]
        db_schema.append(f"{schema}.{table}({cols})")
    db_schema = "\n".join(db_schema)

    cur.execute(tables_sql)
    sample_data = {}
    for row in cur.fetchall():
        if isinstance(row, (list, tuple)):
            schema, table = row
        else:
            schema, table = row["table_schema"], row["table_name"]
        try:
            # Double embedded quotes so unusual identifiers stay valid SQL.
            cur.execute(
                sample_row_tmpl.format(
                    schema=schema.replace('"', '""'), table=table.replace('"', '""')
                )
            )
        except Exception:
            # Table may be unreadable (permissions etc.); skip its sample.
            continue
        # cur.description is None (not []) when there is no result set; guard
        # before iterating.
        cols = [desc[0] for desc in (getattr(cur, "description", None) or [])]
        one = getattr(cur, "fetchone", lambda: None)()
        if not one:
            continue
        sample_data[f"{schema}.{table}"] = list(zip(cols, one))

    args = [
        "--template",
        LLM_TEMPLATE_NAME,
        "--param",
        "db_schema",
        db_schema,
        "--param",
        "sample_data",
        # sys.argv entries must be strings; render the dict explicitly rather
        # than passing it through unconverted.
        str(sample_data),
        "--param",
        "question",
        question,
        " ",
    ]
    click.echo("Invoking llm command with schema information")
    _, result = run_external_cmd("llm", *args, capture_output=True)
    match = re.search(_SQL_CODE_FENCE, result, re.DOTALL)
    sql = match.group(1).strip() if match else ""
    return (result, sql)
0 commit comments