diff --git a/CHANGELOG.md b/CHANGELOG.md index afdb40f..c0bbb21 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,18 @@ ## Unreleased +### Bug Fixes + +- Expand `~` in configured log file paths before opening the log. + ### Internal - Add a GitHub Actions workflow to run Codex review on pull requests. - Drop Python 3.9 from test matrices and tooling targets. +### Features + +- Add `--readonly` and `.open --readonly` support for opening databases read-only. + ## 1.19.0 - 2026-01-30 ### Features diff --git a/litecli/completion_refresher.py b/litecli/completion_refresher.py index 4e76faa..431ab1b 100644 --- a/litecli/completion_refresher.py +++ b/litecli/completion_refresher.py @@ -72,7 +72,7 @@ def _bg_refresh( executor = sqlexecute else: # Create a new sqlexecute method to populate the completions. - executor = SQLExecute(e.dbname) + executor = SQLExecute(e.connect_target) # If callbacks is a single function then push it into a list. if callable(callbacks): diff --git a/litecli/main.py b/litecli/main.py index 8151ffa..9407d45 100644 --- a/litecli/main.py +++ b/litecli/main.py @@ -12,7 +12,7 @@ from datetime import datetime from io import open from time import time -from typing import Any, Generator, Iterable, cast +from typing import Any, Generator, Iterable, Literal, TextIO, cast import click import sqlparse @@ -44,7 +44,7 @@ from .packages.prompt_utils import confirm, confirm_destructive_query from .packages.special.main import NO_QUERY from .sqlcompleter import SQLCompleter -from .sqlexecute import SQLExecute +from .sqlexecute import SQLExecute, make_readonly_uri def _load_sqlite3() -> Any: @@ -75,13 +75,13 @@ def __init__( self, sqlexecute: SQLExecute | None = None, prompt: str | None = None, - logfile: Any | None = None, + logfile: TextIO | None = None, auto_vertical_output: bool = False, warn: bool | None = None, liteclirc: str | None = None, ) -> None: self.sqlexecute = sqlexecute - self.logfile = logfile + self.logfile: TextIO | Literal[False] | None = logfile # Load config. c = self.config = get_config(liteclirc) @@ -203,6 +203,12 @@ def change_db(self, arg: str | None, **_: Any) -> Iterable[tuple]: assert self.sqlexecute is not None self.sqlexecute.connect() else: + open_args = arg.split(maxsplit=1) + if open_args and open_args[0] == "--readonly": + if len(open_args) == 1: + yield (None, None, None, "Missing required argument, database.") + return + arg = make_readonly_uri(open_args[1]) assert self.sqlexecute is not None self.sqlexecute.connect(database=arg) @@ -249,6 +255,7 @@ def initialize_logging(self) -> None: log_file = self.config["main"]["log_file"] if log_file == "default": log_file = config_location() + "log" + log_file = os.path.expanduser(log_file) try: ensure_dir_exists(log_file) except OSError: @@ -472,7 +479,9 @@ def one_iteration(text: str | None = None) -> None: try: start = time() assert self.sqlexecute is not None - cur = self.sqlexecute.conn and self.sqlexecute.conn.cursor() + conn = self.sqlexecute.conn + assert conn is not None + cur = conn.cursor() context, sql, duration = special.handle_llm(text, cur) if context: click.echo("LLM Reponse:") @@ -534,7 +543,9 @@ def one_iteration(text: str | None = None) -> None: except KeyboardInterrupt: try: # since connection can be sqlite3 or sqlean, it's hard to annotate the type for interrupt. so ignore the type hint warning. - sqlexecute.conn.interrupt() # type: ignore[attr-defined] + conn = sqlexecute.conn + if conn is not None: + conn.interrupt() # type: ignore[attr-defined] except Exception as e: self.echo( "Encountered error while cancelling query: {}".format(e), @@ -791,7 +802,7 @@ def _on_completions_refreshed(self, new_completer: SQLCompleter) -> None: def get_completions(self, text: str, cursor_positition: int) -> Iterable[Completion]: with self._completer_lock: - return cast(Iterable[Completion], self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None)) + return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None) def get_prompt(self, string: str) -> str: self.logger.debug("Getting prompt %r", string) @@ -927,17 +938,19 @@ def get_last_query(self) -> str | None: @click.option("-t", "--table", is_flag=True, help="Display batch output in table format.") @click.option("--csv", is_flag=True, help="Display batch output in CSV format.") @click.option("--warn/--no-warn", default=None, help="Warn before running a destructive query.") +@click.option("--readonly", is_flag=True, help="Open the database in read-only mode.") @click.option("-e", "--execute", type=str, help="Execute command and quit.") @click.argument("database", default="", nargs=1) def cli( database: str, dbname: str, prompt: str | None, - logfile: Any | None, + logfile: TextIO | None, auto_vertical_output: bool, table: bool, csv: bool, warn: bool | None, + readonly: bool, execute: str | None, liteclirc: str, ) -> None: @@ -958,6 +971,8 @@ def cli( # Choose which ever one has a valid value. database = database or dbname + if readonly and database: + database = make_readonly_uri(database) litecli.connect(database) diff --git a/litecli/sqlexecute.py b/litecli/sqlexecute.py index 5e52b73..03b942d 100644 --- a/litecli/sqlexecute.py +++ b/litecli/sqlexecute.py @@ -4,7 +4,8 @@ import os.path from contextlib import closing from typing import Any, Generator, Iterable, cast -from urllib.parse import urlparse +from urllib.parse import parse_qsl, unquote, urlencode, urlparse, urlunparse +from urllib.request import pathname2url import sqlparse @@ -60,6 +61,7 @@ class SQLExecute(object): WHERE ROUTINE_TYPE="FUNCTION" AND ROUTINE_SCHEMA = "%s"''' def __init__(self, database: str | None): + self.connect_target: str | None = database self.dbname: str | None = database self._server_type: tuple[str, str] | None = None # Connection can be sqlite3.Connection or sqlean.sqlite3 connection. @@ -70,7 +72,7 @@ def __init__(self, database: str | None): self.connect() def connect(self, database: str | None = None) -> None: - db = database or self.dbname + db = database or self.connect_target _logger.debug("Connection DB Params: \n\tdatabase: %r", db) if db is None: # Nothing to connect to. @@ -80,7 +82,7 @@ def connect(self, database: str | None = None) -> None: if location.scheme and location.scheme == "file": uri = True db_name = db - db_filename = location.path + db_filename = unquote(location.path) else: uri = False db_filename = db_name = os.path.expanduser(db) @@ -96,6 +98,7 @@ def connect(self, database: str | None = None) -> None: self.conn = conn # Update them after the connection is made to ensure that it was a # successful connection. + self.connect_target = db_name self.dbname = db_filename def run(self, statement: str) -> Iterable[tuple]: @@ -220,3 +223,13 @@ def functions(self) -> Iterable[tuple]: def server_type(self) -> tuple[str, str]: self._server_type = ("sqlite3", "3") return self._server_type + + +def make_readonly_uri(database: str) -> str: + location = urlparse(database) + if location.scheme == "file": + query = [(key, value) for key, value in parse_qsl(location.query, keep_blank_values=True) if key.lower() != "mode"] + query.append(("mode", "ro")) + return urlunparse(location._replace(query=urlencode(query))) + + return "file:{}?mode=ro".format(pathname2url(os.path.expanduser(database))) diff --git a/tests/test_completion_refresher.py b/tests/test_completion_refresher.py index 32da6bd..01110bc 100644 --- a/tests/test_completion_refresher.py +++ b/tests/test_completion_refresher.py @@ -93,3 +93,18 @@ def test_refresh_with_callbacks(refresher): refresher.refresh(sqlexecute, callbacks) time.sleep(1) # Wait for the thread to work. assert callbacks[0].call_count == 1 + + +def test_bg_refresh_uses_connect_target(refresher): + callbacks = Mock() + sqlexecute_class = Mock() + sqlexecute = Mock() + sqlexecute.dbname = "/tmp/test.db" + sqlexecute.connect_target = "file:/tmp/test.db?mode=ro" + + with patch("litecli.completion_refresher.SQLExecute", sqlexecute_class): + refresher.refreshers = {} + refresher._bg_refresh(sqlexecute, callbacks, {}) + + sqlexecute_class.assert_called_once_with("file:/tmp/test.db?mode=ro") + callbacks.assert_called_once() diff --git a/tests/test_main.py b/tests/test_main.py index 0a47c9b..5ece7ab 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,10 @@ +import logging import os import shutil from collections import namedtuple from datetime import datetime from textwrap import dedent +from typing import Any, cast from unittest.mock import patch import click @@ -12,6 +14,7 @@ from litecli.main import LiteCli, cli from litecli.packages.special.main import COMMANDS as SPECIAL_COMMANDS +from litecli.sqlexecute import OperationalError, make_readonly_uri from .utils import create_db, db_connection, dbtest, run @@ -148,9 +151,8 @@ def output(monkeypatch, terminal_size, testdata, explicit_pager, expect_pager): class TestOutput: def get_size(self): - size = namedtuple("Size", "rows columns") - size.columns, size.rows = terminal_size - return size + Size = namedtuple("Size", "rows columns") + return Size(rows=terminal_size[1], columns=terminal_size[0]) class TestExecute: host = "test" @@ -165,7 +167,7 @@ class PromptBuffer(PromptSession): output = TestOutput() m.prompt_app = PromptBuffer() - m.sqlexecute = TestExecute() + m.sqlexecute = cast(Any, TestExecute()) m.explicit_pager = explicit_pager def echo_via_pager(s): @@ -232,18 +234,15 @@ def test_conditional_pager(monkeypatch): SPECIAL_COMMANDS["pager"].handler("") -def test_reserved_space_is_integer(): +def test_reserved_space_is_integer(monkeypatch): """Make sure that reserved space is returned as an integer.""" - def stub_terminal_size(): - return (5, 5) + def stub_terminal_size(fallback=(80, 24)): + return os.terminal_size((5, 5)) - old_func = shutil.get_terminal_size - - shutil.get_terminal_size = stub_terminal_size # type: ignore[assignment] + monkeypatch.setattr(shutil, "get_terminal_size", stub_terminal_size) lc = LiteCli() assert isinstance(lc.get_reserved_space(), int) - shutil.get_terminal_size = old_func @dbtest @@ -278,6 +277,33 @@ def test_startup_commands(executor): # implement tests on executions of the startupcommands +def test_initialize_logging_expands_user_log_file(monkeypatch, tmp_path): + home = tmp_path / "home" + log_file = home / ".cache" / "litecli" / "log" + monkeypatch.setenv("HOME", str(home)) + monkeypatch.setenv("USERPROFILE", str(home)) + + m = cast(Any, object.__new__(LiteCli)) + m.config = {"main": {"log_file": "~/.cache/litecli/log", "log_level": "INFO"}} + echo_messages = [] + m.echo = lambda *args, **kwargs: echo_messages.append((args, kwargs)) + + root_logger = logging.getLogger("litecli") + original_handlers = list(root_logger.handlers) + try: + m.initialize_logging() + + added_handlers = [handler for handler in root_logger.handlers if handler not in original_handlers] + assert log_file.exists() + assert not echo_messages + assert any(isinstance(handler, logging.FileHandler) and handler.baseFilename == str(log_file) for handler in added_handlers) + finally: + for handler in root_logger.handlers[:]: + if handler not in original_handlers: + root_logger.removeHandler(handler) + handler.close() + + @patch("litecli.main.datetime") # Adjust if your module path is different def test_get_prompt(mock_datetime): # We'll freeze time at 2025-01-20 13:37:42 for comedic effect. @@ -365,3 +391,80 @@ def test_file_uri(tmp_path, uri, expected_dbname): lc.connect(uri) assert lc.get_prompt(r"\d") == expected_dbname.format(tmp_path=tmp_path) + + +def _create_readonly_test_db(db_path): + conn = db_connection(str(db_path)) + try: + conn.execute("create table test(value text)") + conn.execute("insert into test values('seed')") + finally: + conn.close() + + +def test_make_readonly_uri_adds_readonly_mode(tmp_path): + db_path = str(tmp_path / "test.db") + + assert make_readonly_uri(db_path) == f"file:{db_path}?mode=ro" + assert make_readonly_uri(f"file:{db_path}?cache=shared") == f"file://{db_path}?cache=shared&mode=ro" + assert make_readonly_uri(f"file:{db_path}?mode=rw&cache=shared") == f"file://{db_path}?cache=shared&mode=ro" + + +def test_readonly_option_opens_database_readonly(tmp_path): + db_path = tmp_path / "readonly.db" + _create_readonly_test_db(db_path) + + runner = CliRunner() + result = runner.invoke( + cli, + args=[ + "--liteclirc", + default_config_file, + "--readonly", + "-e", + "select value from test;", + str(db_path), + ], + ) + + assert result.exit_code == 0 + assert "seed" in result.output + + result = runner.invoke( + cli, + args=[ + "--liteclirc", + default_config_file, + "--readonly", + "-e", + "insert into test values('blocked');", + str(db_path), + ], + ) + + assert result.exit_code == 1 + assert "readonly" in result.output.lower() + + conn = db_connection(str(db_path)) + try: + rows = conn.execute("select value from test").fetchall() + finally: + conn.close() + assert rows == [("seed",)] + + +def test_open_readonly_opens_database_readonly(tmp_path): + db_path = tmp_path / "readonly.db" + _create_readonly_test_db(db_path) + + lc = LiteCli(liteclirc=default_config_file) + lc.connect("") + assert lc.sqlexecute is not None + + results = run(lc.sqlexecute, f".open --readonly {db_path}") + + assert results[0]["status"] == f'You are now connected to database "{db_path}"' + assert lc.sqlexecute.connect_target == make_readonly_uri(str(db_path)) + with pytest.raises(OperationalError) as excinfo: + run(lc.sqlexecute, "insert into test values('blocked')") + assert "readonly" in str(excinfo.value).lower()