Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions aai_cli/commands/stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aai_cli.app.context import run_with_options
from aai_cli.commands.stream import _exec as stream_exec
from aai_cli.core import choices, llm
from aai_cli.streaming.turn_presets import TurnDetectionPreset
from aai_cli.ui.help_text import examples_epilog

app = typer.Typer()
Expand Down Expand Up @@ -113,6 +114,12 @@ def stream(
rich_help_panel=help_panels.OPT_MODEL,
),
# turn detection
turn_detection: TurnDetectionPreset | None = typer.Option(
None,
"--turn-detection",
help="Turn-detection sensitivity preset",
rich_help_panel=help_panels.OPT_TURNS,
),
end_of_turn_confidence_threshold: float | None = typer.Option(
None,
# Not "--end-of-turn-confidence-threshold": at 34 chars the name can't render
Expand Down Expand Up @@ -315,6 +322,7 @@ def stream(
end_of_turn_confidence_threshold=end_of_turn_confidence_threshold,
min_turn_silence=min_turn_silence,
max_turn_silence=max_turn_silence,
turn_detection=turn_detection,
vad_threshold=vad_threshold,
format_turns=format_turns,
include_partial_turns=include_partial_turns,
Expand Down
15 changes: 12 additions & 3 deletions aai_cli/commands/stream/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from aai_cli.core import choices, client, config_builder, youtube
from aai_cli.core.errors import UsageError
from aai_cli.core.microphone import MicrophoneSource
from aai_cli.streaming import turn_presets
from aai_cli.streaming.macos import MacSystemAudioSource
from aai_cli.streaming.render import StreamRenderer
from aai_cli.streaming.session import (
Expand All @@ -30,6 +31,7 @@
validate_sources,
)
from aai_cli.streaming.sources import TARGET_RATE, FileSource, StdinSource
from aai_cli.streaming.turn_presets import TurnDetectionPreset
from aai_cli.ui import output
from aai_cli.ui.follow import FollowRenderer

Expand Down Expand Up @@ -57,6 +59,7 @@ class StreamOptions:
end_of_turn_confidence_threshold: float | None
min_turn_silence: int | None
max_turn_silence: int | None
turn_detection: TurnDetectionPreset | None
vad_threshold: float | None
format_turns: bool | None
include_partial_turns: bool | None
Expand Down Expand Up @@ -93,15 +96,21 @@ def source_options(self) -> SourceOptions:

def base_flags(self) -> dict[str, object]:
"""Every streaming flag except sample_rate, which is set per source at stream time."""
end_of_turn_confidence_threshold, min_turn_silence, max_turn_silence = turn_presets.resolve(
self.turn_detection,
self.end_of_turn_confidence_threshold,
self.min_turn_silence,
self.max_turn_silence,
)
flags: dict[str, object] = {
"speech_model": config_builder.enum_value(self.speech_model),
"format_turns": self.format_turns if self.format_turns is not None else True,
"encoding": config_builder.enum_value(self.encoding),
"language_detection": self.language_detection,
"domain": self.domain,
"end_of_turn_confidence_threshold": self.end_of_turn_confidence_threshold,
"min_turn_silence": self.min_turn_silence,
"max_turn_silence": self.max_turn_silence,
"end_of_turn_confidence_threshold": end_of_turn_confidence_threshold,
"min_turn_silence": min_turn_silence,
"max_turn_silence": max_turn_silence,
"vad_threshold": self.vad_threshold,
"include_partial_turns": self.include_partial_turns,
"keyterms_prompt": list(self.keyterms_prompt) if self.keyterms_prompt else None,
Expand Down
50 changes: 50 additions & 0 deletions aai_cli/streaming/turn_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""Documented turn-detection quick-start presets for `assembly stream`.

The Aggressive/Balanced/Conservative configurations mirror the streaming
turn-detection docs (streaming/universal-streaming/turn-detection). A preset
sets the three end-of-turn knobs together; `resolve` lets any explicitly-passed
raw flag override its slot so users can start from a preset and tweak one value.
"""

from __future__ import annotations

import enum


class TurnDetectionPreset(enum.StrEnum):
"""Named end-of-turn sensitivity presets from the streaming turn-detection docs."""

aggressive = "aggressive"
balanced = "balanced"
conservative = "conservative"


# (end_of_turn_confidence_threshold, min_turn_silence, max_turn_silence) per the docs'
# quick-start configurations. Keep these verbatim — they're the published recommendations.
_PRESETS: dict[TurnDetectionPreset, tuple[float, int, int]] = {
TurnDetectionPreset.aggressive: (0.4, 160, 400),
TurnDetectionPreset.balanced: (0.4, 400, 1280),
TurnDetectionPreset.conservative: (0.7, 800, 3600),
}


def resolve(
preset: TurnDetectionPreset | None,
end_of_turn_confidence_threshold: float | None,
min_turn_silence: int | None,
max_turn_silence: int | None,
) -> tuple[float | None, int | None, int | None]:
"""Merge a preset with raw flags, where an explicitly-passed value wins its slot.

With no preset the three values pass through unchanged (server defaults apply).
"""
if preset is None:
return end_of_turn_confidence_threshold, min_turn_silence, max_turn_silence
preset_eot, preset_min, preset_max = _PRESETS[preset]
return (
end_of_turn_confidence_threshold
if end_of_turn_confidence_threshold is not None
else preset_eot,
min_turn_silence if min_turn_silence is not None else preset_min,
max_turn_silence if max_turn_silence is not None else preset_max,
)
3 changes: 3 additions & 0 deletions tests/__snapshots__/test_snapshots_help_run.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,9 @@
│ (repeatable) │
╰──────────────────────────────────────────────────────────────────────────────╯
╭─ Turn Detection ─────────────────────────────────────────────────────────────╮
│ --turn-detection [aggressive| Turn-detecti… │
│ balanced|con sensitivity │
│ servative] preset │
│ --end-of-turn-confidence FLOAT RANGE End-of-turn │
│ [0.0<=x<=1.0 confidence │
│ ] (0-1) │
Expand Down
35 changes: 35 additions & 0 deletions tests/test_stream_command_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,41 @@ def test_stream_turn_silence_below_minimum_is_rejected(monkeypatch):
assert result.exit_code == 2


def test_stream_turn_detection_preset_reaches_params(monkeypatch):
# --turn-detection balanced must thread through the command wiring into the
# documented (0.4, 400, 1280) trio on StreamingParameters.
config.set_api_key("default", "sk_live")
captured = {}
monkeypatch.setattr(
"aai_cli.commands.stream._exec.client.stream_audio",
lambda api_key, source, *, params, **kw: captured.update(params=params),
)

runner.invoke(app, ["stream", "--sample", "--turn-detection", "balanced"])
params = captured["params"]
assert params.end_of_turn_confidence_threshold == 0.4
assert params.min_turn_silence == 400
assert params.max_turn_silence == 1280


def test_stream_explicit_flag_overrides_preset_via_cli(monkeypatch):
# A raw flag passed alongside the preset wins its slot through the real argv path.
config.set_api_key("default", "sk_live")
captured = {}
monkeypatch.setattr(
"aai_cli.commands.stream._exec.client.stream_audio",
lambda api_key, source, *, params, **kw: captured.update(params=params),
)

runner.invoke(
app,
["stream", "--sample", "--turn-detection", "conservative", "--min-turn-silence", "200"],
)
params = captured["params"]
assert params.min_turn_silence == 200 # explicit flag, not the preset's 800
assert params.max_turn_silence == 3600 # preset's value survives


def test_stream_config_escape_hatch(monkeypatch):
config.set_api_key("default", "sk_live")
captured = {}
Expand Down
28 changes: 28 additions & 0 deletions tests/test_stream_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from aai_cli.commands.stream import _exec as stream_exec
from aai_cli.core import config, llm
from aai_cli.core.errors import UsageError
from aai_cli.streaming.turn_presets import TurnDetectionPreset

# The CLI's flag defaults, as data. Tests override per-case with dataclasses.replace.
DEFAULTS = stream_exec.StreamOptions(
Expand All @@ -35,6 +36,7 @@
end_of_turn_confidence_threshold=None,
min_turn_silence=None,
max_turn_silence=None,
turn_detection=None,
vad_threshold=None,
format_turns=None,
include_partial_turns=None,
Expand Down Expand Up @@ -123,6 +125,32 @@ def test_redact_pii_sub_enum_maps_to_its_string_value():
assert DEFAULTS.base_flags()["redact_pii_sub"] is None # unset stays None


def test_turn_detection_preset_fills_base_flags():
# --turn-detection balanced supplies the documented (0.4, 400, 1280) trio.
opts = dataclasses.replace(DEFAULTS, turn_detection=TurnDetectionPreset.balanced)
flags = opts.base_flags()
assert flags["end_of_turn_confidence_threshold"] == 0.4
assert flags["min_turn_silence"] == 400
assert flags["max_turn_silence"] == 1280


def test_explicit_turn_flag_overrides_the_preset_slot():
# A raw --min-turn-silence wins over the preset's value; the other slots stay.
opts = dataclasses.replace(
DEFAULTS, turn_detection=TurnDetectionPreset.balanced, min_turn_silence=900
)
flags = opts.base_flags()
assert flags["min_turn_silence"] == 900
assert flags["max_turn_silence"] == 1280


def test_no_preset_leaves_turn_flags_unset():
flags = DEFAULTS.base_flags()
assert flags["end_of_turn_confidence_threshold"] is None
assert flags["min_turn_silence"] is None
assert flags["max_turn_silence"] is None


def test_stream_options_are_immutable():
field_name = "sample"
with pytest.raises(dataclasses.FrozenInstanceError):
Expand Down
48 changes: 48 additions & 0 deletions tests/test_turn_presets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Unit tests for the streaming turn-detection presets (aai_cli.streaming.turn_presets).

The presets mirror the documented Aggressive/Balanced/Conservative quick-start
configurations (streaming/universal-streaming/turn-detection). `resolve` merges a
preset with explicitly-passed raw flags, where an explicit value always wins.
"""

from __future__ import annotations

import pytest

from aai_cli.streaming import turn_presets
from aai_cli.streaming.turn_presets import TurnDetectionPreset


def test_no_preset_passes_raw_values_through_unchanged():
assert turn_presets.resolve(None, None, None, None) == (None, None, None)
assert turn_presets.resolve(None, 0.5, 300, 900) == (0.5, 300, 900)


@pytest.mark.parametrize(
("preset", "expected"),
[
(TurnDetectionPreset.aggressive, (0.4, 160, 400)),
(TurnDetectionPreset.balanced, (0.4, 400, 1280)),
(TurnDetectionPreset.conservative, (0.7, 800, 3600)),
],
)
def test_preset_supplies_documented_values(preset, expected):
assert turn_presets.resolve(preset, None, None, None) == expected


def test_explicit_min_turn_silence_overrides_only_its_slot():
# balanced is (0.4, 400, 1280); overriding min_turn_silence keeps the other two.
assert turn_presets.resolve(TurnDetectionPreset.balanced, None, 500, None) == (0.4, 500, 1280)


def test_explicit_confidence_overrides_preset_confidence():
# conservative is (0.7, 800, 3600); an explicit eot threshold wins.
assert turn_presets.resolve(TurnDetectionPreset.conservative, 0.9, None, None) == (
0.9,
800,
3600,
)


def test_all_explicit_flags_override_every_preset_slot():
assert turn_presets.resolve(TurnDetectionPreset.aggressive, 0.1, 50, 100) == (0.1, 50, 100)
Loading