Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Features
* Respond to `-h` alone with the helpdoc.
* Allow `--hostname` as an alias for `--host`.
* Deprecate `$DSN` environment variable in favor of `$MYSQL_DSN`.
* Add a `--progress` progress-bar option with `--batch`.


Bug Fixes
Expand Down
85 changes: 72 additions & 13 deletions mycli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import clickdc
from configobj import ConfigObj
import keyring
import prompt_toolkit
from prompt_toolkit import print_formatted_text
from prompt_toolkit.application.current import get_app
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest
Expand All @@ -55,7 +56,8 @@
from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.output import ColorDepth
from prompt_toolkit.shortcuts import CompleteStyle, PromptSession
from prompt_toolkit.shortcuts import CompleteStyle, ProgressBar, PromptSession
from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters
import pymysql
from pymysql.constants.CR import CR_SERVER_LOST
from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR
Expand Down Expand Up @@ -2036,7 +2038,7 @@ class CliArgs:
)
ssl_verify_server_cert: bool = clickdc.option(
is_flag=True,
help=('Verify server\'s "Common Name" in its cert against hostname used when connecting. This option is disabled by default.'),
help=("""Verify server's "Common Name" in its cert against hostname used when connecting. This option is disabled by default."""),
)
verbose: bool = clickdc.option(
'-v',
Expand Down Expand Up @@ -2167,6 +2169,10 @@ class CliArgs:
default=0.0,
help='Pause in seconds between queries in batch mode.',
)
progress: bool = clickdc.option(
is_flag=True,
help='Show progress on the standard error with --batch.',
)
use_keyring: str | None = clickdc.option(
type=click.Choice(['true', 'false', 'reset']),
default=None,
Expand Down Expand Up @@ -2721,17 +2727,70 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None:
click.secho(str(e), err=True, fg="red")
sys.exit(1)

if cli_args.batch or not sys.stdin.isatty():
if cli_args.batch:
if not sys.stdin.isatty() and cli_args.batch != '-':
click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red')
try:
batch_h = click.open_file(cli_args.batch)
except (OSError, FileNotFoundError):
click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red')
sys.exit(1)
else:
batch_h = click.get_text_stream('stdin')
if cli_args.batch and cli_args.batch != '-' and cli_args.progress and sys.stderr.isatty():
# The actual number of SQL statements can be greater, if there is more than
# one statement per line, but this is how the progress bar will count.
goal_statements = 0
if not sys.stdin.isatty() and cli_args.batch != '-':
click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='yellow')
if os.path.exists(cli_args.batch) and not os.path.isfile(cli_args.batch):
click.secho('--progress is only compatible with a plain file.', err=True, fg='red')
sys.exit(1)
try:
batch_count_h = click.open_file(cli_args.batch)
for _statement, _counter in statements_from_filehandle(batch_count_h):
goal_statements += 1
batch_count_h.close()
batch_h = click.open_file(cli_args.batch)
except (OSError, FileNotFoundError):
click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red')
sys.exit(1)
except ValueError as e:
click.secho(f'Error reading --batch file: {cli_args.batch}: {e}', err=True, fg='red')
sys.exit(1)
try:
if goal_statements:
pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'})
custom_formatters = [
progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '),
progress_bar_formatters.Text(' '),
progress_bar_formatters.Progress(),
progress_bar_formatters.Text(' '),
progress_bar_formatters.Text('eta ', style='class:time-left'),
progress_bar_formatters.TimeLeft(),
progress_bar_formatters.Text(' ', style='class:time-left'),
]
err_output = prompt_toolkit.output.create_output(stdout=sys.stderr, always_prefer_tty=True)
with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb:
for pb_counter in pb(range(goal_statements)):
statement, _untrusted_counter = next(statements_from_filehandle(batch_h))
dispatch_batch_statements(statement, pb_counter)
except (ValueError, StopIteration) as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)
finally:
batch_h.close()
sys.exit(0)

if cli_args.batch:
if not sys.stdin.isatty() and cli_args.batch != '-':
click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red')
try:
batch_h = click.open_file(cli_args.batch)
except (OSError, FileNotFoundError):
click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red')
sys.exit(1)
try:
for statement, counter in statements_from_filehandle(batch_h):
dispatch_batch_statements(statement, counter)
batch_h.close()
except ValueError as e:
click.secho(str(e), err=True, fg='red')
sys.exit(1)
sys.exit(0)

if not sys.stdin.isatty():
batch_h = click.get_text_stream('stdin')
try:
for statement, counter in statements_from_filehandle(batch_h):
dispatch_batch_statements(statement, counter)
Expand Down
75 changes: 75 additions & 0 deletions test/pytests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import io
import os
import shutil
import sys
from tempfile import NamedTemporaryFile
from textwrap import dedent
from types import SimpleNamespace

import click
from click.testing import CliRunner
Expand Down Expand Up @@ -2137,6 +2139,79 @@ def test_batch_file(monkeypatch):
os.remove(batch_file.name)


def test_batch_file_with_progress(monkeypatch):
mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch)
runner = CliRunner()

class DummyProgressBar:
calls = []

def __init__(self, *args, **kwargs):
pass

def __enter__(self):
return self

def __exit__(self, exc_type, exc, tb):
return False

def __call__(self, iterable):
values = list(iterable)
DummyProgressBar.calls.append(values)
return values

monkeypatch.setattr(mycli_main, 'ProgressBar', DummyProgressBar)
monkeypatch.setattr(mycli_main.prompt_toolkit.output, 'create_output', lambda **kwargs: object())
monkeypatch.setattr(
mycli_main,
'sys',
SimpleNamespace(
stdin=SimpleNamespace(isatty=lambda: False),
stderr=SimpleNamespace(isatty=lambda: True),
exit=sys.exit,
),
)

with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file:
batch_file.write('select 2;\nselect 2;\nselect 2;\n')
batch_file.flush()

try:
result = runner.invoke(
mycli_main.click_entrypoint,
args=['--batch', batch_file.name, '--progress'],
)
assert result.exit_code == 0
assert MockMyCli.ran_queries == ['select 2;\n', 'select 2;\n', 'select 2;\n']
assert DummyProgressBar.calls == [[0, 1, 2]]
finally:
os.remove(batch_file.name)


def test_batch_file_with_progress_requires_plain_file(monkeypatch, tmp_path):
mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch)
runner = CliRunner()

monkeypatch.setattr(
mycli_main,
'sys',
SimpleNamespace(
stdin=SimpleNamespace(isatty=lambda: False),
stderr=SimpleNamespace(isatty=lambda: True),
exit=sys.exit,
),
)

result = runner.invoke(
mycli_main.click_entrypoint,
args=['--batch', str(tmp_path), '--progress'],
)

assert result.exit_code != 0
assert '--progress is only compatible with a plain file.' in result.output
assert MockMyCli.ran_queries == []


def test_execute_arg_warns_about_ignoring_stdin(monkeypatch):
mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch)
runner = CliRunner()
Expand Down
Loading