From a0c50b8b3d96a4c62e2f1b9ac6df57428bcf1b32 Mon Sep 17 00:00:00 2001 From: Roland Walker Date: Sat, 28 Mar 2026 14:35:54 -0400 Subject: [PATCH] make progress and checkpoint strictly by-statement Previously --progress and --checkpoint were influenced by linebreaks to some extent: multiline queries were correctly joined and counted/ dispatched/checkpointed as one query, but multiple queries on a single line were dispatched together. That means that the progress estimation could be thrown off somewhat, depending on the file contents, and more importantly means that a statement which was part of line with more than one statement might fail to be written to the line-influenced checkpoint file if that particular query succeeded, but a subsequent query on the same line failed. This subtlety is important if we are to use the checkpoint file to resume scripts, though in general it would be best when running scripts to avoid all of these corner cases by having one statement per line. We pull in sqlparse in addition to sqlglot, because sqlparse has the feature of preserving the input literally when splitting multi-statement lines. This also fixes a bug: the generator here named batch_gen was recreated in the --progress loop, which didn't matter before this change since iterating over a filehandle covered up the issue. Tests are added for statements_from_filehandle(), which had no coverage before. Incidentally * fix missing changelog entry * fix whitespace in a comment * remove a backslash by double-quoting a string which contains a single quote --- changelog.md | 6 +++ mycli/main.py | 11 ++--- mycli/packages/batch_utils.py | 12 ++++-- test/pytests/test_batch_utils.py | 54 ++++++++++++++++++++++++ test/pytests/test_main.py | 70 +++++++++++++++++++++++++++++++- 5 files changed, 144 insertions(+), 9 deletions(-) create mode 100644 test/pytests/test_batch_utils.py diff --git a/changelog.md b/changelog.md index 4c9602e5..bd05c5b6 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,12 @@ Upcoming (TBD) Features --------- * Continue to expand TIPS. +* Make `--progress` and `--checkpoint` strictly by statement. + + +Internal +--------- +* Add an `AGENTS.md`. 1.67.1 (2026/03/28) diff --git a/mycli/main.py b/mycli/main.py index 79050e5f..b2fe711c 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -2190,7 +2190,7 @@ class CliArgs: @click.command() @clickdc.adddc('cli_args', CliArgs) -@click.version_option(__version__, '--version', '-V', help='Output mycli\'s version.') +@click.version_option(__version__, '--version', '-V', help="Output mycli's version.") def click_entrypoint( cli_args: CliArgs, ) -> None: @@ -2658,7 +2658,7 @@ def get_password_from_file(password_file: str | None) -> str | None: cli_args.port, ) - # --execute argument + # --execute argument if cli_args.execute: if not sys.stdin.isatty(): click.secho('Ignoring STDIN since --execute was also given.', err=True, fg='red') @@ -2742,6 +2742,7 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None: goal_statements += 1 batch_count_h.close() batch_h = click.open_file(cli_args.batch) + batch_gen = statements_from_filehandle(batch_h) except (OSError, FileNotFoundError): click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') sys.exit(1) @@ -2762,9 +2763,9 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None: ] err_output = prompt_toolkit.output.create_output(stdout=sys.stderr, always_prefer_tty=True) with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb: - for pb_counter in pb(range(goal_statements)): - statement, _untrusted_counter = next(statements_from_filehandle(batch_h)) - dispatch_batch_statements(statement, pb_counter) + for _pb_counter in pb(range(goal_statements)): + statement, statement_counter = next(batch_gen) + dispatch_batch_statements(statement, statement_counter) except (ValueError, StopIteration) as e: click.secho(str(e), err=True, fg='red') sys.exit(1) diff --git a/mycli/packages/batch_utils.py b/mycli/packages/batch_utils.py index 34e48073..d0ebd218 100644 --- a/mycli/packages/batch_utils.py +++ b/mycli/packages/batch_utils.py @@ -1,6 +1,7 @@ from typing import IO, Generator import sqlglot +import sqlparse MAX_MULTILINE_BATCH_STATEMENT = 5000 @@ -20,11 +21,16 @@ def statements_from_filehandle(file_h: IO) -> Generator[tuple[str, int], None, N continue # we don't yet handle changing the delimiter within the batch input if tokens[-1].text == ';': - yield (statements, batch_counter) - batch_counter += 1 + # The advantage of sqlparse for splitting is that it preserves the input. + # https://github.com/tobymao/sqlglot/issues/2587#issuecomment-1823109501 + for statement in sqlparse.split(statements): + yield (statement, batch_counter) + batch_counter += 1 statements = '' line_counter = 0 except sqlglot.errors.TokenError: continue if statements: - yield (statements, batch_counter) + for statement in sqlparse.split(statements): + yield (statement, batch_counter) + batch_counter += 1 diff --git a/test/pytests/test_batch_utils.py b/test/pytests/test_batch_utils.py new file mode 100644 index 00000000..c00a76a6 --- /dev/null +++ b/test/pytests/test_batch_utils.py @@ -0,0 +1,54 @@ +# type: ignore + +from io import StringIO + +import pytest + +import mycli.packages.batch_utils +from mycli.packages.batch_utils import statements_from_filehandle + + +def collect_statements(sql: str) -> list[tuple[str, int]]: + return list(statements_from_filehandle(StringIO(sql))) + + +def test_statements_from_filehandle_splits_on_statements() -> None: + statements = collect_statements('select 1;\nselect\n 2;\nselect 3; select 4;\n') + + assert statements == [ + ('select 1;', 0), + ('select\n 2;', 1), + ('select 3;', 2), + ('select 4;', 3), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_01() -> None: + statements = collect_statements('select 1;\nselect 2;') + + assert statements == [ + ('select 1;', 0), + ('select 2;', 1), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_02() -> None: + statements = collect_statements('select 1;\nselect 2') + + assert statements == [ + ('select 1;', 0), + ('select 2', 1), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_03() -> None: + statements = collect_statements('select 1\nwhere 1 == 1;') + + assert statements == [('select 1\nwhere 1 == 1;', 0)] + + +def test_statements_from_filehandle_rejects_overlong_statement(monkeypatch) -> None: + monkeypatch.setattr(mycli.packages.batch_utils, 'MAX_MULTILINE_BATCH_STATEMENT', 2) + + with pytest.raises(ValueError, match='Saw single input statement greater than 2 lines'): + list(statements_from_filehandle(StringIO('select 1,\n2\nwhere 1 = 1;'))) diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index a6182501..85b13405 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -2139,6 +2139,25 @@ def test_batch_file(monkeypatch): os.remove(batch_file.name) +def test_batch_file_no_progress_multiple_statements_per_line(monkeypatch): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write('select 2; select 3;\nselect 4;\n') + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', batch_file.name], + ) + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 2;', 'select 3;', 'select 4;'] + finally: + os.remove(batch_file.name) + + def test_batch_file_with_progress(monkeypatch): mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) runner = CliRunner() @@ -2182,7 +2201,56 @@ def __call__(self, iterable): args=['--batch', batch_file.name, '--progress'], ) assert result.exit_code == 0 - assert MockMyCli.ran_queries == ['select 2;\n', 'select 2;\n', 'select 2;\n'] + assert MockMyCli.ran_queries == ['select 2;', 'select 2;', 'select 2;'] + assert DummyProgressBar.calls == [[0, 1, 2]] + finally: + os.remove(batch_file.name) + + +def test_batch_file_with_progress_multiple_statements_per_line(monkeypatch): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + class DummyProgressBar: + calls = [] + + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def __call__(self, iterable): + values = list(iterable) + DummyProgressBar.calls.append(values) + return values + + monkeypatch.setattr(mycli_main, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(mycli_main.prompt_toolkit.output, 'create_output', lambda **kwargs: object()) + monkeypatch.setattr( + mycli_main, + 'sys', + SimpleNamespace( + stdin=SimpleNamespace(isatty=lambda: False), + stderr=SimpleNamespace(isatty=lambda: True), + exit=sys.exit, + ), + ) + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write('select 2; select 3;\nselect 4;\n') + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', batch_file.name, '--progress'], + ) + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 2;', 'select 3;', 'select 4;'] assert DummyProgressBar.calls == [[0, 1, 2]] finally: os.remove(batch_file.name)