diff --git a/changelog.md b/changelog.md index 4c9602e5..bd05c5b6 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,12 @@ Upcoming (TBD) Features --------- * Continue to expand TIPS. +* Make `--progress` and `--checkpoint` strictly by statement. + + +Internal +--------- +* Add an `AGENTS.md`. 1.67.1 (2026/03/28) diff --git a/mycli/main.py b/mycli/main.py index 79050e5f..b2fe711c 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -2190,7 +2190,7 @@ class CliArgs: @click.command() @clickdc.adddc('cli_args', CliArgs) -@click.version_option(__version__, '--version', '-V', help='Output mycli\'s version.') +@click.version_option(__version__, '--version', '-V', help="Output mycli's version.") def click_entrypoint( cli_args: CliArgs, ) -> None: @@ -2658,7 +2658,7 @@ def get_password_from_file(password_file: str | None) -> str | None: cli_args.port, ) - # --execute argument + # --execute argument if cli_args.execute: if not sys.stdin.isatty(): click.secho('Ignoring STDIN since --execute was also given.', err=True, fg='red') @@ -2742,6 +2742,7 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None: goal_statements += 1 batch_count_h.close() batch_h = click.open_file(cli_args.batch) + batch_gen = statements_from_filehandle(batch_h) except (OSError, FileNotFoundError): click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') sys.exit(1) @@ -2762,9 +2763,9 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None: ] err_output = prompt_toolkit.output.create_output(stdout=sys.stderr, always_prefer_tty=True) with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb: - for pb_counter in pb(range(goal_statements)): - statement, _untrusted_counter = next(statements_from_filehandle(batch_h)) - dispatch_batch_statements(statement, pb_counter) + for _pb_counter in pb(range(goal_statements)): + statement, statement_counter = next(batch_gen) + dispatch_batch_statements(statement, statement_counter) except (ValueError, StopIteration) as e: click.secho(str(e), err=True, fg='red') sys.exit(1) diff --git a/mycli/packages/batch_utils.py b/mycli/packages/batch_utils.py index 34e48073..d0ebd218 100644 --- a/mycli/packages/batch_utils.py +++ b/mycli/packages/batch_utils.py @@ -1,6 +1,7 @@ from typing import IO, Generator import sqlglot +import sqlparse MAX_MULTILINE_BATCH_STATEMENT = 5000 @@ -20,11 +21,16 @@ def statements_from_filehandle(file_h: IO) -> Generator[tuple[str, int], None, N continue # we don't yet handle changing the delimiter within the batch input if tokens[-1].text == ';': - yield (statements, batch_counter) - batch_counter += 1 + # The advantage of sqlparse for splitting is that it preserves the input. + # https://github.com/tobymao/sqlglot/issues/2587#issuecomment-1823109501 + for statement in sqlparse.split(statements): + yield (statement, batch_counter) + batch_counter += 1 statements = '' line_counter = 0 except sqlglot.errors.TokenError: continue if statements: - yield (statements, batch_counter) + for statement in sqlparse.split(statements): + yield (statement, batch_counter) + batch_counter += 1 diff --git a/test/pytests/test_batch_utils.py b/test/pytests/test_batch_utils.py new file mode 100644 index 00000000..c00a76a6 --- /dev/null +++ b/test/pytests/test_batch_utils.py @@ -0,0 +1,54 @@ +# type: ignore + +from io import StringIO + +import pytest + +import mycli.packages.batch_utils +from mycli.packages.batch_utils import statements_from_filehandle + + +def collect_statements(sql: str) -> list[tuple[str, int]]: + return list(statements_from_filehandle(StringIO(sql))) + + +def test_statements_from_filehandle_splits_on_statements() -> None: + statements = collect_statements('select 1;\nselect\n 2;\nselect 3; select 4;\n') + + assert statements == [ + ('select 1;', 0), + ('select\n 2;', 1), + ('select 3;', 2), + ('select 4;', 3), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_01() -> None: + statements = collect_statements('select 1;\nselect 2;') + + assert statements == [ + ('select 1;', 0), + ('select 2;', 1), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_02() -> None: + statements = collect_statements('select 1;\nselect 2') + + assert statements == [ + ('select 1;', 0), + ('select 2', 1), + ] + + +def test_statements_from_filehandle_yields_trailing_statement_without_newline_03() -> None: + statements = collect_statements('select 1\nwhere 1 == 1;') + + assert statements == [('select 1\nwhere 1 == 1;', 0)] + + +def test_statements_from_filehandle_rejects_overlong_statement(monkeypatch) -> None: + monkeypatch.setattr(mycli.packages.batch_utils, 'MAX_MULTILINE_BATCH_STATEMENT', 2) + + with pytest.raises(ValueError, match='Saw single input statement greater than 2 lines'): + list(statements_from_filehandle(StringIO('select 1,\n2\nwhere 1 = 1;'))) diff --git a/test/pytests/test_main.py b/test/pytests/test_main.py index a6182501..85b13405 100644 --- a/test/pytests/test_main.py +++ b/test/pytests/test_main.py @@ -2139,6 +2139,25 @@ def test_batch_file(monkeypatch): os.remove(batch_file.name) +def test_batch_file_no_progress_multiple_statements_per_line(monkeypatch): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write('select 2; select 3;\nselect 4;\n') + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', batch_file.name], + ) + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 2;', 'select 3;', 'select 4;'] + finally: + os.remove(batch_file.name) + + def test_batch_file_with_progress(monkeypatch): mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) runner = CliRunner() @@ -2182,7 +2201,56 @@ def __call__(self, iterable): args=['--batch', batch_file.name, '--progress'], ) assert result.exit_code == 0 - assert MockMyCli.ran_queries == ['select 2;\n', 'select 2;\n', 'select 2;\n'] + assert MockMyCli.ran_queries == ['select 2;', 'select 2;', 'select 2;'] + assert DummyProgressBar.calls == [[0, 1, 2]] + finally: + os.remove(batch_file.name) + + +def test_batch_file_with_progress_multiple_statements_per_line(monkeypatch): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + class DummyProgressBar: + calls = [] + + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def __call__(self, iterable): + values = list(iterable) + DummyProgressBar.calls.append(values) + return values + + monkeypatch.setattr(mycli_main, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(mycli_main.prompt_toolkit.output, 'create_output', lambda **kwargs: object()) + monkeypatch.setattr( + mycli_main, + 'sys', + SimpleNamespace( + stdin=SimpleNamespace(isatty=lambda: False), + stderr=SimpleNamespace(isatty=lambda: True), + exit=sys.exit, + ), + ) + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write('select 2; select 3;\nselect 4;\n') + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', batch_file.name, '--progress'], + ) + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 2;', 'select 3;', 'select 4;'] assert DummyProgressBar.calls == [[0, 1, 2]] finally: os.remove(batch_file.name)