diff --git a/aai_cli/commands/stream/__init__.py b/aai_cli/commands/stream/__init__.py index 79b88ab1..210ada1f 100644 --- a/aai_cli/commands/stream/__init__.py +++ b/aai_cli/commands/stream/__init__.py @@ -10,6 +10,7 @@ from aai_cli.app.context import run_with_options from aai_cli.commands.stream import _exec as stream_exec from aai_cli.core import choices, llm +from aai_cli.streaming.turn_presets import TurnDetectionPreset from aai_cli.ui.help_text import examples_epilog app = typer.Typer() @@ -113,6 +114,12 @@ def stream( rich_help_panel=help_panels.OPT_MODEL, ), # turn detection + turn_detection: TurnDetectionPreset | None = typer.Option( + None, + "--turn-detection", + help="Turn-detection sensitivity preset", + rich_help_panel=help_panels.OPT_TURNS, + ), end_of_turn_confidence_threshold: float | None = typer.Option( None, # Not "--end-of-turn-confidence-threshold": at 34 chars the name can't render @@ -315,6 +322,7 @@ def stream( end_of_turn_confidence_threshold=end_of_turn_confidence_threshold, min_turn_silence=min_turn_silence, max_turn_silence=max_turn_silence, + turn_detection=turn_detection, vad_threshold=vad_threshold, format_turns=format_turns, include_partial_turns=include_partial_turns, diff --git a/aai_cli/commands/stream/_exec.py b/aai_cli/commands/stream/_exec.py index 91d1993e..688b03ba 100644 --- a/aai_cli/commands/stream/_exec.py +++ b/aai_cli/commands/stream/_exec.py @@ -21,6 +21,7 @@ from aai_cli.core import choices, client, config_builder, youtube from aai_cli.core.errors import UsageError from aai_cli.core.microphone import MicrophoneSource +from aai_cli.streaming import turn_presets from aai_cli.streaming.macos import MacSystemAudioSource from aai_cli.streaming.render import StreamRenderer from aai_cli.streaming.session import ( @@ -30,6 +31,7 @@ validate_sources, ) from aai_cli.streaming.sources import TARGET_RATE, FileSource, StdinSource +from aai_cli.streaming.turn_presets import TurnDetectionPreset from aai_cli.ui import output from aai_cli.ui.follow import FollowRenderer @@ -57,6 +59,7 @@ class StreamOptions: end_of_turn_confidence_threshold: float | None min_turn_silence: int | None max_turn_silence: int | None + turn_detection: TurnDetectionPreset | None vad_threshold: float | None format_turns: bool | None include_partial_turns: bool | None @@ -93,15 +96,21 @@ def source_options(self) -> SourceOptions: def base_flags(self) -> dict[str, object]: """Every streaming flag except sample_rate, which is set per source at stream time.""" + end_of_turn_confidence_threshold, min_turn_silence, max_turn_silence = turn_presets.resolve( + self.turn_detection, + self.end_of_turn_confidence_threshold, + self.min_turn_silence, + self.max_turn_silence, + ) flags: dict[str, object] = { "speech_model": config_builder.enum_value(self.speech_model), "format_turns": self.format_turns if self.format_turns is not None else True, "encoding": config_builder.enum_value(self.encoding), "language_detection": self.language_detection, "domain": self.domain, - "end_of_turn_confidence_threshold": self.end_of_turn_confidence_threshold, - "min_turn_silence": self.min_turn_silence, - "max_turn_silence": self.max_turn_silence, + "end_of_turn_confidence_threshold": end_of_turn_confidence_threshold, + "min_turn_silence": min_turn_silence, + "max_turn_silence": max_turn_silence, "vad_threshold": self.vad_threshold, "include_partial_turns": self.include_partial_turns, "keyterms_prompt": list(self.keyterms_prompt) if self.keyterms_prompt else None, diff --git a/aai_cli/streaming/turn_presets.py b/aai_cli/streaming/turn_presets.py new file mode 100644 index 00000000..f8c1fde8 --- /dev/null +++ b/aai_cli/streaming/turn_presets.py @@ -0,0 +1,50 @@ +"""Documented turn-detection quick-start presets for `assembly stream`. + +The Aggressive/Balanced/Conservative configurations mirror the streaming +turn-detection docs (streaming/universal-streaming/turn-detection). A preset +sets the three end-of-turn knobs together; `resolve` lets any explicitly-passed +raw flag override its slot so users can start from a preset and tweak one value. +""" + +from __future__ import annotations + +import enum + + +class TurnDetectionPreset(enum.StrEnum): + """Named end-of-turn sensitivity presets from the streaming turn-detection docs.""" + + aggressive = "aggressive" + balanced = "balanced" + conservative = "conservative" + + +# (end_of_turn_confidence_threshold, min_turn_silence, max_turn_silence) per the docs' +# quick-start configurations. Keep these verbatim — they're the published recommendations. +_PRESETS: dict[TurnDetectionPreset, tuple[float, int, int]] = { + TurnDetectionPreset.aggressive: (0.4, 160, 400), + TurnDetectionPreset.balanced: (0.4, 400, 1280), + TurnDetectionPreset.conservative: (0.7, 800, 3600), +} + + +def resolve( + preset: TurnDetectionPreset | None, + end_of_turn_confidence_threshold: float | None, + min_turn_silence: int | None, + max_turn_silence: int | None, +) -> tuple[float | None, int | None, int | None]: + """Merge a preset with raw flags, where an explicitly-passed value wins its slot. + + With no preset the three values pass through unchanged (server defaults apply). + """ + if preset is None: + return end_of_turn_confidence_threshold, min_turn_silence, max_turn_silence + preset_eot, preset_min, preset_max = _PRESETS[preset] + return ( + end_of_turn_confidence_threshold + if end_of_turn_confidence_threshold is not None + else preset_eot, + min_turn_silence if min_turn_silence is not None else preset_min, + max_turn_silence if max_turn_silence is not None else preset_max, + ) diff --git a/tests/__snapshots__/test_snapshots_help_run.ambr b/tests/__snapshots__/test_snapshots_help_run.ambr index d5b63740..56e8e369 100644 --- a/tests/__snapshots__/test_snapshots_help_run.ambr +++ b/tests/__snapshots__/test_snapshots_help_run.ambr @@ -615,6 +615,9 @@ │ (repeatable) │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Turn Detection ─────────────────────────────────────────────────────────────╮ + │ --turn-detection [aggressive| Turn-detecti… │ + │ balanced|con sensitivity │ + │ servative] preset │ │ --end-of-turn-confidence FLOAT RANGE End-of-turn │ │ [0.0<=x<=1.0 confidence │ │ ] (0-1) │ diff --git a/tests/test_stream_command_flags.py b/tests/test_stream_command_flags.py index 098b3dc5..aed1e70c 100644 --- a/tests/test_stream_command_flags.py +++ b/tests/test_stream_command_flags.py @@ -81,6 +81,41 @@ def test_stream_turn_silence_below_minimum_is_rejected(monkeypatch): assert result.exit_code == 2 +def test_stream_turn_detection_preset_reaches_params(monkeypatch): + # --turn-detection balanced must thread through the command wiring into the + # documented (0.4, 400, 1280) trio on StreamingParameters. + config.set_api_key("default", "sk_live") + captured = {} + monkeypatch.setattr( + "aai_cli.commands.stream._exec.client.stream_audio", + lambda api_key, source, *, params, **kw: captured.update(params=params), + ) + + runner.invoke(app, ["stream", "--sample", "--turn-detection", "balanced"]) + params = captured["params"] + assert params.end_of_turn_confidence_threshold == 0.4 + assert params.min_turn_silence == 400 + assert params.max_turn_silence == 1280 + + +def test_stream_explicit_flag_overrides_preset_via_cli(monkeypatch): + # A raw flag passed alongside the preset wins its slot through the real argv path. + config.set_api_key("default", "sk_live") + captured = {} + monkeypatch.setattr( + "aai_cli.commands.stream._exec.client.stream_audio", + lambda api_key, source, *, params, **kw: captured.update(params=params), + ) + + runner.invoke( + app, + ["stream", "--sample", "--turn-detection", "conservative", "--min-turn-silence", "200"], + ) + params = captured["params"] + assert params.min_turn_silence == 200 # explicit flag, not the preset's 800 + assert params.max_turn_silence == 3600 # preset's value survives + + def test_stream_config_escape_hatch(monkeypatch): config.set_api_key("default", "sk_live") captured = {} diff --git a/tests/test_stream_exec.py b/tests/test_stream_exec.py index 8b71e2c6..5781fea5 100644 --- a/tests/test_stream_exec.py +++ b/tests/test_stream_exec.py @@ -17,6 +17,7 @@ from aai_cli.commands.stream import _exec as stream_exec from aai_cli.core import config, llm from aai_cli.core.errors import UsageError +from aai_cli.streaming.turn_presets import TurnDetectionPreset # The CLI's flag defaults, as data. Tests override per-case with dataclasses.replace. DEFAULTS = stream_exec.StreamOptions( @@ -35,6 +36,7 @@ end_of_turn_confidence_threshold=None, min_turn_silence=None, max_turn_silence=None, + turn_detection=None, vad_threshold=None, format_turns=None, include_partial_turns=None, @@ -123,6 +125,32 @@ def test_redact_pii_sub_enum_maps_to_its_string_value(): assert DEFAULTS.base_flags()["redact_pii_sub"] is None # unset stays None +def test_turn_detection_preset_fills_base_flags(): + # --turn-detection balanced supplies the documented (0.4, 400, 1280) trio. + opts = dataclasses.replace(DEFAULTS, turn_detection=TurnDetectionPreset.balanced) + flags = opts.base_flags() + assert flags["end_of_turn_confidence_threshold"] == 0.4 + assert flags["min_turn_silence"] == 400 + assert flags["max_turn_silence"] == 1280 + + +def test_explicit_turn_flag_overrides_the_preset_slot(): + # A raw --min-turn-silence wins over the preset's value; the other slots stay. + opts = dataclasses.replace( + DEFAULTS, turn_detection=TurnDetectionPreset.balanced, min_turn_silence=900 + ) + flags = opts.base_flags() + assert flags["min_turn_silence"] == 900 + assert flags["max_turn_silence"] == 1280 + + +def test_no_preset_leaves_turn_flags_unset(): + flags = DEFAULTS.base_flags() + assert flags["end_of_turn_confidence_threshold"] is None + assert flags["min_turn_silence"] is None + assert flags["max_turn_silence"] is None + + def test_stream_options_are_immutable(): field_name = "sample" with pytest.raises(dataclasses.FrozenInstanceError): diff --git a/tests/test_turn_presets.py b/tests/test_turn_presets.py new file mode 100644 index 00000000..b1a49c31 --- /dev/null +++ b/tests/test_turn_presets.py @@ -0,0 +1,48 @@ +"""Unit tests for the streaming turn-detection presets (aai_cli.streaming.turn_presets). + +The presets mirror the documented Aggressive/Balanced/Conservative quick-start +configurations (streaming/universal-streaming/turn-detection). `resolve` merges a +preset with explicitly-passed raw flags, where an explicit value always wins. +""" + +from __future__ import annotations + +import pytest + +from aai_cli.streaming import turn_presets +from aai_cli.streaming.turn_presets import TurnDetectionPreset + + +def test_no_preset_passes_raw_values_through_unchanged(): + assert turn_presets.resolve(None, None, None, None) == (None, None, None) + assert turn_presets.resolve(None, 0.5, 300, 900) == (0.5, 300, 900) + + +@pytest.mark.parametrize( + ("preset", "expected"), + [ + (TurnDetectionPreset.aggressive, (0.4, 160, 400)), + (TurnDetectionPreset.balanced, (0.4, 400, 1280)), + (TurnDetectionPreset.conservative, (0.7, 800, 3600)), + ], +) +def test_preset_supplies_documented_values(preset, expected): + assert turn_presets.resolve(preset, None, None, None) == expected + + +def test_explicit_min_turn_silence_overrides_only_its_slot(): + # balanced is (0.4, 400, 1280); overriding min_turn_silence keeps the other two. + assert turn_presets.resolve(TurnDetectionPreset.balanced, None, 500, None) == (0.4, 500, 1280) + + +def test_explicit_confidence_overrides_preset_confidence(): + # conservative is (0.7, 800, 3600); an explicit eot threshold wins. + assert turn_presets.resolve(TurnDetectionPreset.conservative, 0.9, None, None) == ( + 0.9, + 800, + 3600, + ) + + +def test_all_explicit_flags_override_every_preset_slot(): + assert turn_presets.resolve(TurnDetectionPreset.aggressive, 0.1, 50, 100) == (0.1, 50, 100)