Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions aai_cli/commands/stream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
("Stream a list of files in turn", "ls *.wav | assembly stream --from-stdin"),
("Stream the hosted sample", "assembly stream --sample"),
("Label speakers in the live transcript", "assembly stream --speaker-labels"),
("Save a WAV of the audio while streaming", "assembly stream --save-audio out.wav"),
(
"Boost domain terms with keyterm prompts",
'assembly stream --keyterms-prompt "AssemblyAI" --keyterms-prompt "Claude"',
Expand Down Expand Up @@ -82,6 +83,16 @@ def stream(
help="macOS only: stream system/app audio without the microphone",
rich_help_panel=help_panels.OPT_CAPTURE,
),
save_audio: Path | None = typer.Option(
None,
"--save-audio",
help="Tee the streamed PCM to PATH as a 16-bit mono WAV while transcribing",
rich_help_panel=help_panels.OPT_CAPTURE,
dir_okay=False,
# Click guardrail; flipping it changes no behavior a unit test can observe
# (and the writable check is a no-op under the test runner's root uid).
writable=True, # pragma: no mutate
),
# model & input
speech_model: SpeechModel = typer.Option(
DEFAULT_SPEECH_MODEL,
Expand Down Expand Up @@ -355,5 +366,6 @@ def stream(
config_file=config_file,
output_field=output_field,
show_code=show_code,
save_audio=save_audio,
)
run_with_options(ctx, stream_exec.run_stream, opts, json=json_out)
22 changes: 21 additions & 1 deletion aai_cli/commands/stream/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from aai_cli.core import choices, client, config_builder, stdio, youtube
from aai_cli.core.errors import UsageError, mutually_exclusive
from aai_cli.core.microphone import MicrophoneSource
from aai_cli.streaming import turn_presets
from aai_cli.streaming import record, turn_presets
from aai_cli.streaming.macos import MacSystemAudioSource
from aai_cli.streaming.render import StreamRenderer
from aai_cli.streaming.session import (
Expand Down Expand Up @@ -85,6 +85,7 @@ class StreamOptions:
config_file: Path | None
output_field: choices.TextOrJson | None
show_code: bool
save_audio: Path | None

def source_options(self) -> SourceOptions:
"""The audio-input subset, in the shape the validation/dispatch helpers read."""
Expand Down Expand Up @@ -245,6 +246,11 @@ def _collect_batch_sources(opts: StreamOptions, *, text_mode: bool) -> list[str]
("--show-code", opts.show_code),
suggestion="--show-code renders one source; pass a single file or URL.",
)
mutually_exclusive(
("--from-stdin", True),
("--save-audio", opts.save_audio is not None),
suggestion="--save-audio tees one stream; run a single source to record it.",
)
mutually_exclusive(
("--llm", bool(opts.llm_prompt)),
("-o text", text_mode),
Expand Down Expand Up @@ -305,12 +311,25 @@ def run_stream(opts: StreamOptions, state: AppState, *, json_mode: bool) -> None
base_flags = opts.base_flags()

if opts.show_code:
if opts.save_audio is not None:
raise UsageError(
"--save-audio cannot be combined with --show-code; the generated SDK "
"code does not tee audio to disk."
)
_print_show_code(opts, sources, base_flags, text_mode=text_mode)
return

# Validate the requested sources (including that a local file exists) before
# credentials, so a typo'd path reads as "file not found" — not as a login.
validate_sources(sources, has_llm=bool(opts.llm_prompt), text_mode=text_mode)
if opts.save_audio is not None:
if sources.from_system_audio:
raise UsageError(
"--save-audio cannot be combined with --system-audio; the mic and system "
"streams can't share one file.",
suggestion="Record a single source (mic, file, URL, or - on stdin).",
)
record.validate_target(opts.save_audio)
if sources.from_file and not sources.from_stdin:
client.resolve_audio_source(sources.source, sample=sources.sample)
api_key = state.resolve_api_key()
Expand All @@ -326,6 +345,7 @@ def run_stream(opts: StreamOptions, state: AppState, *, json_mode: bool) -> None
llm_prompts=llm_prompts,
model=opts.model,
max_tokens=opts.max_tokens,
save_audio=opts.save_audio,
llm_interval=opts.llm_interval,
)
_dispatch(session, sources)
67 changes: 67 additions & 0 deletions aai_cli/streaming/record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Tee streamed PCM to a WAV file — backs `assembly stream --save-audio PATH`.

The whole point is a verbatim recording of exactly the bytes sent to the streaming
API, so a caller (e.g. an ensemble that compares the live turns against an async
re-transcribe) can keep the audio without owning capture itself. The tee never alters
what's transcribed: it writes each chunk to disk and yields it onward unchanged.
"""

from __future__ import annotations

import wave
from collections.abc import Generator, Iterable
from pathlib import Path

from aai_cli.core.errors import CLIError
from aai_cli.streaming.sources import PCM16_SAMPLE_WIDTH_BYTES


def validate_target(path: Path) -> None:
    """Fail fast when the ``--save-audio`` destination has no existing parent directory.

    Intended to run ahead of credential resolution and audio capture, so a mistyped
    path surfaces as a path problem immediately rather than after a session has
    already begun streaming.

    Args:
        path: Where the WAV recording is meant to be written.

    Raises:
        CLIError: ``path.parent`` is not an existing directory (``error_type`` is
            ``"save_audio_path"``, exit code 2).
    """
    parent = path.parent
    if parent.is_dir():
        # Nothing to complain about: the file can at least be created here.
        return
    raise CLIError(
        f"Cannot save audio to {path}: {parent} is not a directory.",
        error_type="save_audio_path",
        exit_code=2,
        suggestion="Create the directory first, or pass a path under an existing one.",
    )


def tee_wav(audio: Iterable[bytes], path: Path, *, rate: int) -> Generator[bytes, None, None]:
    """Pass PCM16 chunks through untouched while mirroring them into a WAV at ``path``.

    The file is written as mono, 16-bit PCM at ``rate``. Each chunk is written to
    disk first and then yielded onward unchanged. Because this is a generator, the
    file is only opened once iteration starts.

    The WAV writer is closed (which patches the header's length fields) both when
    ``audio`` runs out and when the generator is closed early — e.g. Ctrl-C raising
    ``GeneratorExit`` at the ``yield`` — so a partial recording is still a valid WAV.

    Args:
        audio: Source of raw PCM16 byte chunks.
        path: Destination WAV file.
        rate: Sample rate, in Hz, recorded in the WAV header.

    Yields:
        Every chunk from ``audio``, byte-for-byte as received.

    Raises:
        CLIError: ``path`` cannot be opened for writing (``error_type`` is
            ``"save_audio_path"``, exit code 2).
    """
    # Open the file here instead of handing wave.open() a string path: a bad path
    # then fails cleanly at this point, with no half-constructed Wave_write whose
    # __del__ would later emit an "ignored in __del__" warning during GC.
    try:
        sink = path.open("wb")
    except OSError as exc:
        raise CLIError(
            f"Cannot open {path} for writing: {exc}",
            error_type="save_audio_path",
            exit_code=2,
        ) from exc

    # Context managers unwind right-to-left: the Wave_write closes first (flushing
    # and patching the length fields from what was actually written), and only then
    # is our own handle closed — wave never closes a handle it did not open itself.
    with sink, wave.open(sink, "wb") as writer:
        writer.setnchannels(1)
        writer.setsampwidth(PCM16_SAMPLE_WIDTH_BYTES)
        writer.setframerate(rate)
        for pcm in audio:
            writer.writeframesraw(pcm)
            yield pcm
7 changes: 7 additions & 0 deletions aai_cli/streaming/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
UsageError,
mutually_exclusive,
)
from aai_cli.streaming import record
from aai_cli.streaming.render import StreamRenderer, speaker_prefix
from aai_cli.ui import output
from aai_cli.ui.follow import FollowRenderer
Expand Down Expand Up @@ -137,6 +138,9 @@ class StreamSession:
llm_prompts: list[str]
model: str
max_tokens: int
# When set, tee the streamed PCM to this path as a WAV (see record.tee_wav). Only
# the single-source path sets it — the parallel/batch callers reject --save-audio.
save_audio: Path | None = None
# Seconds between --llm summary refreshes; <=0 re-runs the chain on every turn.
llm_interval: float = 0.0
# Monotonic clock, injectable so the interval throttle is deterministic in tests.
Expand Down Expand Up @@ -242,6 +246,9 @@ def _maybe_summarize(self, *, final: bool = False) -> None:
def stream_one(
self, audio: Iterable[bytes], rate: int, *, source_label: str | None = None
) -> None:
if self.save_audio is not None:
# Tee verbatim to disk at the source's true rate before it hits the wire.
audio = record.tee_wav(audio, self.save_audio, rate=rate)
flags = self.base_flags | {"sample_rate": rate}
if source_label == "you":
# The microphone captures you alone, so never diarize it into separate
Expand Down
5 changes: 5 additions & 0 deletions tests/__snapshots__/test_snapshots_help_run.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -712,6 +712,9 @@
│ --system-audio-only macOS only: stream │
│ system/app audio without │
│ the microphone │
│ --save-audio FILE Tee the streamed PCM to │
│ PATH as a 16-bit mono WAV │
│ while transcribing │
╰──────────────────────────────────────────────────────────────────────────────╯
╭─ Model & Language ───────────────────────────────────────────────────────────╮
│ --speech-model [universal-streaming-m Streaming speech model │
Expand Down Expand Up @@ -813,6 +816,8 @@
$ assembly stream --sample
Label speakers in the live transcript
$ assembly stream --speaker-labels
Save a WAV of the audio while streaming
$ assembly stream --save-audio out.wav
Boost domain terms with keyterm prompts
$ assembly stream --keyterms-prompt "AssemblyAI" --keyterms-prompt "Claude"
Summarize action items live as you talk
Expand Down
87 changes: 86 additions & 1 deletion tests/test_stream_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,16 @@
from __future__ import annotations

import dataclasses
import wave
from pathlib import Path

import pytest

from aai_cli.app.context import AppState
from aai_cli.commands.stream import DEFAULT_SPEECH_MODEL
from aai_cli.commands.stream import _exec as stream_exec
from aai_cli.core import config, llm
from aai_cli.core.errors import UsageError
from aai_cli.core.errors import CLIError, UsageError
from aai_cli.streaming.turn_presets import TurnDetectionPreset

# The CLI's flag defaults, as data. Tests override per-case with dataclasses.replace.
Expand Down Expand Up @@ -60,6 +62,7 @@
config_file=None,
output_field=None,
show_code=False,
save_audio=None,
)


Expand Down Expand Up @@ -170,6 +173,7 @@ def test_stream_options_are_immutable():
{"from_stdin": True, "device": 2}, # mic-only capture flags
{"from_stdin": True, "sample_rate": 44100},
{"from_stdin": True, "show_code": True}, # renders one source
{"from_stdin": True, "save_audio": Path("out.wav")}, # tees one stream
],
)
def test_from_stdin_rejects_incompatible_flags(overrides):
Expand Down Expand Up @@ -222,3 +226,84 @@ def fake_stream_batch(sources, *, make_session, open_source, renderer, json_mode
dataclasses.replace(DEFAULTS, from_stdin=True), AppState(), json_mode=True
)
assert seen["sources"] == ["a.wav", "b.wav"]


# --- --save-audio (tee the streamed PCM to a WAV) --------------------------
class RecordingMic(FakeMic):
    """Fake microphone emitting one fixed PCM chunk, so the saved WAV can be checked."""

    # The only chunk this mic ever produces.
    PCM = b"\x01\x02\x03\x04\x05\x06\x07\x08"

    def __iter__(self):
        single_chunk = (self.PCM,)
        return iter(single_chunk)


def test_save_audio_tees_streamed_pcm_to_a_wav(monkeypatch, tmp_path):
    """--save-audio writes the streamed bytes verbatim to a 16-bit mono WAV.

    The recording must carry the source's sample rate, and the streaming API must
    still receive the audio unaltered.
    """
    config.set_api_key("default", "sk_live")
    wav_path = tmp_path / "rec.wav"

    def fake_stream_audio(api_key, source, *, params, **_kwargs):
        # Consuming the iterable is what drives the tee, just as the real SDK would.
        received = b"".join(source)
        assert received == RecordingMic.PCM  # the API still sees the unaltered audio

    monkeypatch.setattr(stream_exec.client, "stream_audio", fake_stream_audio)
    monkeypatch.setattr(stream_exec, "MicrophoneSource", RecordingMic)

    options = dataclasses.replace(DEFAULTS, save_audio=wav_path)
    stream_exec.run_stream(options, AppState(), json_mode=True)

    assert wav_path.is_file()
    with wave.open(str(wav_path), "rb") as reader:
        assert reader.getnchannels() == 1
        assert reader.getsampwidth() == 2
        assert reader.getframerate() == 16000  # FakeMic's reported rate
        assert reader.readframes(reader.getnframes()) == RecordingMic.PCM


def test_save_audio_not_written_when_flag_unset(monkeypatch, tmp_path):
    """Without --save-audio, a default run leaves no stray WAV behind.

    Kills a mutant that tees unconditionally. The run happens with ``tmp_path`` as
    the working directory so that a recording written to a relative path would land
    where the final assertion looks; without the chdir, nothing in the run ever
    refers to ``tmp_path`` and the glob could never find anything.
    """
    config.set_api_key("default", "sk_live")
    monkeypatch.chdir(tmp_path)
    # Drain the source like the real SDK would, so any tee in the pipeline is driven.
    monkeypatch.setattr(stream_exec.client, "stream_audio", lambda *a, **k: b"".join(a[1]))
    monkeypatch.setattr(stream_exec, "MicrophoneSource", RecordingMic)

    stream_exec.run_stream(DEFAULTS, AppState(), json_mode=True)

    assert list(tmp_path.glob("*.wav")) == []


def test_save_audio_rejects_system_audio():
    """--save-audio with --system-audio is a usage error, raised before credentials.

    The mic and system streams can't share one output file.
    """
    options = dataclasses.replace(DEFAULTS, save_audio=Path("rec.wav"), system_audio=True)
    with pytest.raises(UsageError):
        stream_exec.run_stream(options, AppState(), json_mode=False)


def test_save_audio_rejects_show_code():
    """--save-audio with --show-code is rejected: the emitted SDK code doesn't tee audio."""
    options = dataclasses.replace(DEFAULTS, save_audio=Path("rec.wav"), show_code=True)
    with pytest.raises(UsageError):
        stream_exec.run_stream(options, AppState(), json_mode=False)


def test_save_audio_rejects_missing_parent_dir(tmp_path):
    """A target under a nonexistent directory is a clean path error, before auth."""
    config.set_api_key("default", "sk_live")
    missing_dir_target = tmp_path / "nope" / "rec.wav"
    options = dataclasses.replace(DEFAULTS, save_audio=missing_dir_target)
    with pytest.raises(CLIError) as caught:
        stream_exec.run_stream(options, AppState(), json_mode=False)
    assert caught.value.error_type == "save_audio_path"
67 changes: 67 additions & 0 deletions tests/test_streaming_record.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""Unit tests for aai_cli.streaming.record — the --save-audio WAV tee."""

from __future__ import annotations

import wave

import pytest

from aai_cli.core.errors import CLIError
from aai_cli.streaming import record


def _read_wav(path):
    """Read the WAV at ``path`` back as ``(channels, sample_width, frame_rate, frames)``."""
    with wave.open(str(path), "rb") as reader:
        channels = reader.getnchannels()
        sample_width = reader.getsampwidth()
        frame_rate = reader.getframerate()
        frames = reader.readframes(reader.getnframes())
    return channels, sample_width, frame_rate, frames


def test_tee_wav_yields_chunks_unchanged(tmp_path):
    """Every chunk comes out of the tee exactly as it went in."""
    expected = [b"\x01\x02", b"\x03\x04\x05\x06"]
    teed = record.tee_wav(iter(expected), tmp_path / "a.wav", rate=16000)
    assert list(teed) == expected  # the tee must not alter what's streamed onward


def test_tee_wav_writes_a_valid_wav_with_the_source_rate(tmp_path):
    """A fully drained tee leaves a mono 16-bit WAV at the declared rate."""
    target = tmp_path / "a.wav"
    for _chunk in record.tee_wav(iter([b"\x01\x02", b"\x03\x04"]), target, rate=44100):
        pass  # draining the generator is what performs the writes
    channels, width, rate, frames = _read_wav(target)
    assert (channels, width) == (1, 2)
    assert rate == 44100  # the declared source rate, not a hardcoded default
    assert frames == b"\x01\x02\x03\x04"


def test_tee_wav_finalizes_a_valid_wav_on_early_close(tmp_path):
    """Closing the generator mid-stream (as Ctrl-C does) still leaves a valid WAV."""
    target = tmp_path / "a.wav"
    first, second = b"\x01\x02", b"\x03\x04"
    tee = record.tee_wav(iter([first, second]), target, rate=16000)
    assert next(tee) == first  # consume only the first chunk
    tee.close()  # raises GeneratorExit at the yield, which finalizes the WAV
    frames = _read_wav(target)[3]
    assert frames == first  # only the consumed chunk landed


def test_tee_wav_empty_stream_writes_a_zero_length_wav(tmp_path):
    """An empty source yields nothing and still produces a readable, frame-less WAV."""
    target = tmp_path / "a.wav"
    yielded = list(record.tee_wav(iter([]), target, rate=16000))
    assert yielded == []
    frames = _read_wav(target)[3]
    assert frames == b""


def test_tee_wav_unopenable_path_is_a_clean_error(tmp_path):
    """A path that can't be opened for writing raises CLIError, not a raw OSError."""
    # A directory cannot be opened for writing, so tmp_path itself is the bad target.
    tee = record.tee_wav(iter([b"\x01\x02"]), tmp_path, rate=16000)
    with pytest.raises(CLIError) as caught:
        # tee_wav opens lazily on first iteration, so the generator must be started.
        next(tee)
    assert caught.value.error_type == "save_audio_path"


def test_validate_target_accepts_an_existing_directory(tmp_path):
    """A target whose parent directory exists passes validation without raising."""
    target = tmp_path / "rec.wav"
    record.validate_target(target)  # parent exists -> no raise


def test_validate_target_rejects_a_missing_parent_directory(tmp_path):
    """A target under a nonexistent directory raises the save_audio_path CLIError."""
    target = tmp_path / "nope" / "rec.wav"
    with pytest.raises(CLIError) as caught:
        record.validate_target(target)
    error = caught.value
    assert error.error_type == "save_audio_path"
    assert error.exit_code == 2
Loading