From 1417c5260d0d165ea851249654738f748eebba5e Mon Sep 17 00:00:00 2001 From: Claude Date: Tue, 16 Jun 2026 18:40:02 +0000 Subject: [PATCH] Add --show-code to agent-cascade MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mirror the --show-code flag the other run commands carry (transcribe / stream / agent) on agent-cascade: print a runnable Python script that wires the three primitives client-side — Streaming STT -> LLM Gateway -> streaming TTS — instead of holding a live conversation. The generated script targets the active environment's hosts (sandbox-only, since streaming TTS has no production host) and reflects the named per-leg knobs (voice, language, system prompt, greeting, model, max tokens, speech model, format-turns); the --stt/--llm/--tts-config escape hatches are not inlined. As with `agent --show-code`, a passed audio source warns on stderr that the script is mic-driven. The code_gen template is split into a brace-free header (filled via str.format with the injected constants) and a static body holding the brace-heavy orchestration, so no literal brace has to be doubled. --- README.md | 6 +- aai_cli/code_gen/__init__.py | 17 ++ aai_cli/code_gen/agent_cascade.py | 119 ++++++++++++ aai_cli/code_gen/agent_cascade_body.py | 173 ++++++++++++++++++ aai_cli/commands/agent_cascade/__init__.py | 10 + aai_cli/commands/agent_cascade/_exec.py | 36 +++- scripts/generated_code_compile_gate.py | 13 ++ .../test_snapshots_help_run.ambr | 5 + tests/test_agent_cascade_command.py | 1 + tests/test_agent_cascade_show_code.py | 111 +++++++++++ tests/test_code_gen_agent_cascade.py | 107 +++++++++++ 11 files changed, 594 insertions(+), 4 deletions(-) create mode 100644 aai_cli/code_gen/agent_cascade.py create mode 100644 aai_cli/code_gen/agent_cascade_body.py create mode 100644 tests/test_agent_cascade_show_code.py create mode 100644 tests/test_code_gen_agent_cascade.py diff --git a/README.md b/README.md index 150fedde..84852b90 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ That's it. Run `assembly onboard` for a guided tour, or see [Installation](#-ins - **🎯 One command for everything**: transcription, real-time streaming, voice agents, LLM prompts, and WER benchmarking — no SDK boilerplate. - **🔌 Built for pipelines**: data goes to stdout, errors to stderr, `--json` gives stable machine-readable output, and `-` reads audio from stdin. - **🔐 Secure by default**: your API key lives in the OS keyring, never in a dotfile — and run commands have no `--api-key` flag, so keys can't leak into `ps` or shell history. -- **🛠️ From demo to deployed app**: `assembly init` scaffolds a runnable FastAPI starter, `assembly dev` / `share` / `deploy` run, tunnel, and ship it, and `--show-code` prints the equivalent Python SDK script for any run command. +- **🛠️ From demo to deployed app**: `assembly init` scaffolds a runnable FastAPI starter, `assembly dev` / `share` / `deploy` run, tunnel, and ship it, and `--show-code` prints the equivalent Python SDK script for any run command (`transcribe` / `stream` / `agent` / `agent-cascade`). - **🤖 Agent-ready**: `assembly setup install` wires your coding agent up with the AssemblyAI docs MCP server and skills. - **📖 Open source**: MIT licensed. @@ -62,7 +62,7 @@ That's it. Run `assembly onboard` for a guided tour, or see [Installation](#-ins | `assembly transcripts` / `sessions` | Browse and fetch past transcripts and streaming sessions | | `assembly keys` / `balance` / `usage` / `limits` / `audit` | Account self-service via browser login | -Add `--show-code` to `transcribe` / `stream` / `agent` to print the equivalent Python SDK script instead of running — the built-in path from CLI experiment to SDK code. +Add `--show-code` to `transcribe` / `stream` / `agent` / `agent-cascade` to print the equivalent Python SDK script instead of running — the built-in path from CLI experiment to SDK code. ## ✨ Things you can do with it @@ -152,7 +152,7 @@ printf '%s\n' \ assembly agent --voice ivy --system-prompt "you're a helpful interviewer" ``` -**Graduate to the SDK** — `--show-code` prints the equivalent Python script for any `transcribe`/`stream`/`agent` run instead of executing it: +**Graduate to the SDK** — `--show-code` prints the equivalent Python script for any `transcribe`/`stream`/`agent`/`agent-cascade` run instead of executing it: ```sh assembly agent --system-prompt "you're a story generator" --show-code > story.py diff --git a/aai_cli/code_gen/__init__.py b/aai_cli/code_gen/__init__.py index cd5f5f96..4bd15304 100644 --- a/aai_cli/code_gen/__init__.py +++ b/aai_cli/code_gen/__init__.py @@ -1,9 +1,15 @@ from __future__ import annotations +from typing import TYPE_CHECKING + from aai_cli.code_gen import agent as _agent +from aai_cli.code_gen import agent_cascade as _agent_cascade from aai_cli.code_gen import stream as _stream from aai_cli.code_gen import transcribe as _transcribe +if TYPE_CHECKING: + from aai_cli.agent_cascade.config import CascadeConfig + def gateway_options( prompts: list[str], model: str, max_tokens: int, *, interval: float = 0.0 @@ -28,6 +34,17 @@ def agent(voice: str, system_prompt: str, greeting: str) -> str: return _agent.render(voice, system_prompt, greeting) +def agent_cascade(config: CascadeConfig, *, speech_model: str) -> str: + """Generate runnable Python that reproduces this terminal cascade session. + + Unlike `agent` (one Voice Agent socket), the cascade wires the three primitives + itself — Streaming STT, the LLM Gateway, and streaming TTS — so the script mirrors + the CLI's client-side orchestration. Sandbox hosts only, since streaming TTS has no + production host. + """ + return _agent_cascade.render(config, speech_model=speech_model) + + def transcribe( merged: dict[str, object], source: str, diff --git a/aai_cli/code_gen/agent_cascade.py b/aai_cli/code_gen/agent_cascade.py new file mode 100644 index 00000000..0a861911 --- /dev/null +++ b/aai_cli/code_gen/agent_cascade.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING +from urllib.parse import urlencode + +from aai_cli.code_gen import agent_cascade_body +from aai_cli.core import environments + +if TYPE_CHECKING: + from aai_cli.agent_cascade.config import CascadeConfig + +# The header carries only the injected constants and the reply-cue predicate, so it +# has no literal braces and is safe to fill with str.format. All the brace-heavy +# orchestration (dict/set literals, the protocol loops) lives in the static body, +# which is never formatted — so no brace has to be doubled. +_HEADER = """\ +# Live voice cascade: Streaming STT -> LLM Gateway -> streaming TTS, wired client-side. +# This is what `assembly --sandbox agent-cascade` runs: it transcribes your speech, +# sends each finalized turn to the LLM Gateway, and speaks the reply through streaming +# TTS — the same three primitives the agent-cascade init template wires server-side. +# Requires audio + websockets: pip install sounddevice websockets openai +# Tip: use headphones — the mic stays open while the agent speaks, so on speakers it +# would hear itself and loop. +import base64 +import json +import os +import queue +import threading + +import sounddevice as sd +from openai import OpenAI +from websockets.sync.client import connect + +# Export your key first: export ASSEMBLYAI_API_KEY="" +API_KEY = os.environ["ASSEMBLYAI_API_KEY"] +STT_URL = {stt_url} +TTS_URL = {tts_url} +GATEWAY_URL = {gateway_url} +MODEL = {model} +MAX_TOKENS = {max_tokens} +MAX_HISTORY = {max_history} +SYSTEM_PROMPT = {system_prompt} +GREETING = {greeting} +RATE = 24000 # one full-duplex rate for mic capture + TTS playback (TTS native PCM16 mono) + + +def is_reply_cue(event): + # The cue to generate a reply. {cue_comment} + return {cue_expr} +""" + + +def _stt_url(speech_model: str, *, format_turns: bool) -> str: + """The Streaming v3 socket URL for the active environment. + + The mic is captured and streamed at 24 kHz (the one full-duplex rate), so the + sample_rate query param matches — a mismatch corrupts the audio server-side. + """ + params = urlencode( + { + "sample_rate": 24000, + "encoding": "pcm_s16le", + "speech_model": speech_model, + "format_turns": "true" if format_turns else "false", + } + ) + return f"wss://{environments.active().streaming_host}/v3/ws?{params}" + + +def _tts_url(voice: str, language: str | None) -> str: + """The streaming-TTS socket URL for the configured voice (sandbox-only host).""" + params: dict[str, str] = {"voice": voice, "sample_rate": "24000"} + if language is not None: + params["language"] = language + return f"wss://{environments.active().streaming_tts_host}/v1/ws/?{urlencode(params)}" + + +def _cue(*, format_turns: bool) -> tuple[str, str]: + """The (comment, predicate) for the reply trigger. + + With formatting on, wait for the *formatted* end-of-turn (better text for the LLM); + with it off the server never sets turn_is_formatted, so a bare end-of-turn is the cue. + """ + if format_turns: + return ( + "With --format-turns, wait for the punctuated end-of-turn.", + 'bool(event.get("end_of_turn")) and bool(event.get("turn_is_formatted"))', + ) + return ( + "With --no-format-turns the server never formats, so a bare end-of-turn is the cue.", + 'bool(event.get("end_of_turn"))', + ) + + +def render(config: CascadeConfig, *, speech_model: str) -> str: + """Generate a runnable terminal cascade script from a cascade config + STT model. + + Hosts come from the active environment, so a sandbox run generates a script that + targets the sandbox its key was minted for. The script mirrors the CLI run path: + one full-duplex mic+speaker stream, one LLM completion per finalized turn, spoken + sentence-by-sentence through a fresh TTS socket, with barge-in on the next turn. + The named per-leg knobs are reflected; the --stt/--llm/--tts-config escape hatches + (config.llm_extra / config.tts_extra) are not. + """ + cue_comment, cue_expr = _cue(format_turns=config.format_turns) + header = _HEADER.format( + stt_url=json.dumps(_stt_url(speech_model, format_turns=config.format_turns)), + tts_url=json.dumps(_tts_url(config.voice, config.language)), + gateway_url=json.dumps(environments.active().llm_gateway_base), + model=json.dumps(config.model), + max_tokens=config.max_tokens, + max_history=config.max_history, + system_prompt=json.dumps(config.system_prompt), + greeting=json.dumps(config.greeting), + cue_comment=cue_comment, + cue_expr=cue_expr, + ) + return header + agent_cascade_body.BODY diff --git a/aai_cli/code_gen/agent_cascade_body.py b/aai_cli/code_gen/agent_cascade_body.py new file mode 100644 index 00000000..a5d68dee --- /dev/null +++ b/aai_cli/code_gen/agent_cascade_body.py @@ -0,0 +1,173 @@ +"""The static body of the generated agent-cascade script. + +Kept separate from the header so the orchestration's many literal braces (dict/set +literals, the STT/TTS protocol loops) stay verbatim — this string is concatenated +onto the formatted header, never passed through str.format itself. +""" + +from __future__ import annotations + +# The constants (API_KEY, STT_URL, TTS_URL, GATEWAY_URL, MODEL, …) and is_reply_cue() +# are defined by the header above this body; everything here references them. +BODY = """ + +gateway = OpenAI(api_key=API_KEY, base_url=GATEWAY_URL) +history = [] # alternating user/assistant turns — the sliding LLM-context window +stop_reply = threading.Event() # set on barge-in to cut a reply short +reply_thread = None + +# ONE full-duplex stream (mic + speaker together) at 24 kHz. Opening two separate +# input/output streams on one device fails on macOS CoreAudio, which silently kills +# capture; a single sd.RawStream callback handles both directions. +mic_queue: queue.Queue = queue.Queue() +play_buffer = bytearray() +buffer_lock = threading.Lock() + + +def on_audio(indata, outdata, _frames, _time, _status): + mic_queue.put_nowait(bytes(indata)) # capture -> queue for STT + # Playback: drain the agent's audio into the output, zero-filling any shortfall. + needed = len(outdata) + with buffer_lock: + take = bytes(play_buffer[:needed]) + del play_buffer[:needed] + outdata[: len(take)] = take + if len(take) < needed: + outdata[len(take):] = b"\\x00" * (needed - len(take)) + + +def enqueue_audio(pcm): + with buffer_lock: + play_buffer.extend(pcm) + + +def flush_audio(): # drop queued-but-unplayed audio (used on barge-in) + with buffer_lock: + play_buffer.clear() + + +def trim_history(): # cap the running history to the most recent MAX_HISTORY messages + if len(history) > MAX_HISTORY: + del history[: len(history) - MAX_HISTORY] + + +def split_sentences(text): + # Split a reply into sentences (each ending in . ! ?) so the first audio can play + # before the whole answer is synthesized; a trailing fragment is kept too. + sentences, start = [], 0 + for i, ch in enumerate(text): + if ch in ".!?": + piece = text[start: i + 1].strip() + if piece: + sentences.append(piece) + start = i + 1 + tail = text[start:].strip() + if tail: + sentences.append(tail) + return sentences + + +def synthesize(text): + # Open a fresh streaming-TTS socket (the voice is fixed at connect time), drive the + # Begin -> Generate -> Flush -> Audio protocol, and return the concatenated PCM. TTS + # authenticates with the raw API key, not a Bearer token (the streaming convention). + pcm = bytearray() + with connect(TTS_URL, additional_headers={"Authorization": API_KEY}, max_size=None) as ws: + if json.loads(ws.recv()).get("type") != "Begin": + return b"" + ws.send(json.dumps({"type": "Generate", "text": text})) + ws.send(json.dumps({"type": "Flush"})) + for raw in ws: + frame = json.loads(raw) + kind = frame.get("type") + if kind == "Audio": + pcm += base64.b64decode(frame.get("audio", "")) + if frame.get("is_final"): + break + elif kind in ("FlushDone", "Error"): + break + ws.send(json.dumps({"type": "Terminate"})) + return bytes(pcm) + + +def speak(text): # show + synthesize one chunk of agent speech, honoring a barge-in + print("agent:", text) + if not stop_reply.is_set(): + enqueue_audio(synthesize(text)) + + +def generate_reply(): + # One LLM completion over the running history, spoken sentence-by-sentence. Record + # what was actually spoken so a barge-in still leaves the history alternating. + messages = [{"role": "system", "content": SYSTEM_PROMPT}, *history] + reply = gateway.chat.completions.create( + model=MODEL, messages=messages, max_tokens=MAX_TOKENS + ).choices[0].message.content or "" + spoken = [] + for sentence in split_sentences(reply): + if stop_reply.is_set(): + break + speak(sentence) + spoken.append(sentence) + said = " ".join(spoken).strip() + if said: + history.append({"role": "assistant", "content": said}) + trim_history() + + +def barge_in(): + # A new user turn cuts off any reply still playing: stop the worker and drop the + # queued audio (the flush is what silences the already-buffered speech). + if reply_thread is not None and reply_thread.is_alive(): + stop_reply.set() + flush_audio() + reply_thread.join() + + +def send_mic(stt): + while True: + chunk = mic_queue.get() + try: + stt.send(chunk) + except Exception: + return # socket closed (session over): end the mic thread quietly + + +stream = sd.RawStream( + samplerate=RATE, channels=1, dtype="int16", blocksize=RATE // 10, callback=on_audio +) +stream.start() + +# Greet first, seeding the opening line into the history so the model has a record of it. +if GREETING: + history.append({"role": "assistant", "content": GREETING}) + speak(GREETING) + +with connect(STT_URL, additional_headers={"Authorization": API_KEY}) as stt: + threading.Thread(target=send_mic, args=(stt,), daemon=True).start() + print("Connected — start talking. (Ctrl-C to stop)") + try: + for raw in stt: + event = json.loads(raw) + if event.get("type") != "Turn": + continue + text = (event.get("transcript") or "").strip() + if not text: + continue + if is_reply_cue(event): + print("you: ", text) + barge_in() + history.append({"role": "user", "content": text}) + trim_history() + stop_reply.clear() + reply_thread = threading.Thread(target=generate_reply, daemon=True) + reply_thread.start() + else: + barge_in() # an interim turn only interrupts a playing reply + except KeyboardInterrupt: + print("\\nStopped.") + finally: + stop_reply.set() + stream.stop() + stream.close() +""" diff --git a/aai_cli/commands/agent_cascade/__init__.py b/aai_cli/commands/agent_cascade/__init__.py index 448520ae..3e99f146 100644 --- a/aai_cli/commands/agent_cascade/__init__.py +++ b/aai_cli/commands/agent_cascade/__init__.py @@ -57,6 +57,10 @@ def _emit_voice_list(_state: AppState, json_mode: bool) -> None: 'assembly --sandbox agent-cascade --system-prompt "You are a terse pirate."', ), ("See available voices", "assembly --sandbox agent-cascade --list-voices"), + ( + "Print equivalent Python instead of running", + "assembly --sandbox agent-cascade --show-code", + ), ] ), ) @@ -159,6 +163,11 @@ def agent_cascade( "--output", help="Output mode: text (you:/agent: lines as plain stdout, pipe-friendly) or json", ), + show_code: bool = typer.Option( + False, + "--show-code", + help="Print the equivalent Python SDK code and exit (does not start a session)", + ), ) -> None: """\\[sandbox] Hold a live voice conversation through a self-wired cascade @@ -201,5 +210,6 @@ def agent_cascade( llm_config=tuple(llm_config or ()), language=language, tts_config=tuple(tts_config or ()), + show_code=show_code, ) run_with_options(ctx, agent_cascade_exec.run_agent_cascade, opts, json=json_out) diff --git a/aai_cli/commands/agent_cascade/_exec.py b/aai_cli/commands/agent_cascade/_exec.py index 3f364f5e..d13cdf1e 100644 --- a/aai_cli/commands/agent_cascade/_exec.py +++ b/aai_cli/commands/agent_cascade/_exec.py @@ -15,10 +15,11 @@ import typer +from aai_cli import code_gen from aai_cli.agent.audio import SAMPLE_RATE, DuplexAudio, NullPlayer from aai_cli.agent.render import AgentRenderer from aai_cli.agent_cascade import engine, voices -from aai_cli.agent_cascade.config import CascadeConfig +from aai_cli.agent_cascade.config import DEFAULT_MAX_HISTORY, CascadeConfig from aai_cli.app.agent_shared import resolve_system_prompt as _resolve_system_prompt from aai_cli.app.context import AppState from aai_cli.core import choices, client, config_builder, llm @@ -27,6 +28,7 @@ from aai_cli.streaming.session import resolve_output_modes from aai_cli.streaming.sources import FileSource from aai_cli.tts import session as tts_session +from aai_cli.ui import output if TYPE_CHECKING: from assemblyai.streaming.v3 import StreamingParameters @@ -70,6 +72,8 @@ class AgentCascadeOptions: # Text-to-speech: language named, any other query param via --tts-config. language: str | None tts_config: tuple[str, ...] + # Print the equivalent Python instead of running a conversation. + show_code: bool def _build_stt_params(opts: AgentCascadeOptions, sample_rate: int) -> StreamingParameters: @@ -135,6 +139,32 @@ def _open_audio( return duplex.mic, duplex.player, SAMPLE_RATE +def _print_show_code(opts: AgentCascadeOptions, system_prompt_text: str) -> None: + """Print the equivalent cascade script and exit without authenticating or opening + audio. Raw stdout for `> script.py`; the named per-leg knobs are reflected, the + --stt/--llm/--tts-config escape hatches are not.""" + if opts.source or opts.sample: + # The generated script is microphone-driven (like the agent snippet); a + # faithful file-driven cascade would need the CLI's ffmpeg-decode + + # exit-after-reply machinery. Say so on stderr so `--show-code > script.py` + # stays byte-clean instead of silently dropping the source. + output.error_console.print( + "[aai.warn]Note:[/aai.warn] the generated script uses the microphone; " + "it does not stream the audio source you passed." + ) + config = CascadeConfig( + voice=opts.voice, + system_prompt=system_prompt_text, + greeting=opts.greeting, + model=opts.model, + max_history=DEFAULT_MAX_HISTORY, + language=opts.language, + max_tokens=opts.max_tokens, + format_turns=opts.format_turns, + ) + output.print_code(code_gen.agent_cascade(config, speech_model=opts.speech_model)) + + def run_agent_cascade(opts: AgentCascadeOptions, state: AppState, *, json_mode: bool) -> None: """Execute one `assembly agent-cascade` cascade from already-parsed flags.""" text_mode, json_mode = resolve_output_modes(opts.output_field, json_mode=json_mode) @@ -147,6 +177,10 @@ def run_agent_cascade(opts: AgentCascadeOptions, state: AppState, *, json_mode: tts_session.require_available("agent-cascade") system_prompt_text = _resolve_system_prompt(opts.system_prompt, opts.system_prompt_file) + if opts.show_code: + _print_show_code(opts, system_prompt_text) + return + from_file = bool(opts.source) or opts.sample if from_file and opts.device is not None: raise UsageError("--device applies only to microphone input.") diff --git a/scripts/generated_code_compile_gate.py b/scripts/generated_code_compile_gate.py index 82346ce8..4a01b9b3 100644 --- a/scripts/generated_code_compile_gate.py +++ b/scripts/generated_code_compile_gate.py @@ -104,6 +104,19 @@ def main() -> int: "--show-code", ), ), + ( + # Sandbox-only: streaming TTS has no prod host, so --sandbox makes the URLs valid. + "agent-cascade-basic", + ( + "--sandbox", + "agent-cascade", + "--voice", + "jane", + "--greeting", + "Hello there", + "--show-code", + ), + ), ) runner = CliRunner() diff --git a/tests/__snapshots__/test_snapshots_help_run.ambr b/tests/__snapshots__/test_snapshots_help_run.ambr index 74c27476..aa83c523 100644 --- a/tests/__snapshots__/test_snapshots_help_run.ambr +++ b/tests/__snapshots__/test_snapshots_help_run.ambr @@ -50,6 +50,9 @@ │ --output -o [text|json] Output mode: text (you:/agent: │ │ lines as plain stdout, │ │ pipe-friendly) or json │ + │ --show-code Print the equivalent Python SDK │ + │ code and exit (does not start a │ + │ session) │ │ --help Show this message and exit. │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Text-to-speech ─────────────────────────────────────────────────────────────╮ @@ -100,6 +103,8 @@ $ assembly --sandbox agent-cascade --system-prompt "You are a terse pirate." See available voices $ assembly --sandbox agent-cascade --list-voices + Print equivalent Python instead of running + $ assembly --sandbox agent-cascade --show-code diff --git a/tests/test_agent_cascade_command.py b/tests/test_agent_cascade_command.py index ad010cda..7b89efde 100644 --- a/tests/test_agent_cascade_command.py +++ b/tests/test_agent_cascade_command.py @@ -47,6 +47,7 @@ llm_config=(), language=None, tts_config=(), + show_code=False, ) diff --git a/tests/test_agent_cascade_show_code.py b/tests/test_agent_cascade_show_code.py new file mode 100644 index 00000000..560d8c33 --- /dev/null +++ b/tests/test_agent_cascade_show_code.py @@ -0,0 +1,111 @@ +"""`assembly agent-cascade --show-code` tests. + +Split from test_agent_cascade_command.py (which holds the run-path wiring) so the +print-only path's many invocations live in their own file. The cascade is +sandbox-only, so the happy paths run under `--sandbox`; the generated code_gen +rendering itself is covered by test_code_gen_agent_cascade.py. +""" + +from __future__ import annotations + +from typer.testing import CliRunner + +from aai_cli.commands.agent_cascade import _exec +from aai_cli.core import config +from aai_cli.main import app + +runner = CliRunner() + + +def test_show_code_prints_sandbox_script_without_running(monkeypatch): + # Print-only: emits the cascade script, never wires deps or opens audio, no auth. + def _boom(**kwargs): + raise AssertionError("must not run a cascade") + + monkeypatch.setattr(_exec.engine, "run_cascade", _boom) + monkeypatch.setattr( + config, "resolve_api_key", lambda **_: (_ for _ in ()).throw(AssertionError("no auth")) + ) + result = runner.invoke( + app, + ["--sandbox", "agent-cascade", "--voice", "jane", "--greeting", "Hi there", "--show-code"], + ) + assert result.exit_code == 0 + # Targets the sandbox the key was minted for — all three legs. + assert "streaming.sandbox000" in result.stdout + assert "streaming-tts.sandbox000" in result.stdout + assert "llm-gateway" in result.stdout + assert "voice=jane" in result.stdout # the chosen voice rides the TTS URL + assert "Hi there" in result.stdout # the greeting is injected + compile(result.stdout, "", "exec") # the script is runnable Python + + +def test_show_code_defaults_off_at_the_argv_seam(monkeypatch): + # Pin the Typer default (omitted -> False, so a bare run holds a conversation) and the + # explicit form, captured at the argv->options seam so the run body never executes. + captured = {} + + def fake_run(opts, state, *, json_mode): + captured["opts"] = opts + + monkeypatch.setattr(_exec, "run_agent_cascade", fake_run) + assert runner.invoke(app, ["agent-cascade"]).exit_code == 0 + assert captured["opts"].show_code is False + assert runner.invoke(app, ["agent-cascade", "--show-code"]).exit_code == 0 + assert captured["opts"].show_code is True + + +def test_show_code_injects_speech_model(monkeypatch): + monkeypatch.setattr(_exec.engine, "run_cascade", lambda **kw: None) + result = runner.invoke( + app, ["--sandbox", "agent-cascade", "--speech-model", "u3-rt-pro", "--show-code"] + ) + assert result.exit_code == 0 + assert "speech_model=u3-rt-pro" in result.stdout + + +def test_show_code_reflects_no_format_turns(monkeypatch): + monkeypatch.setattr(_exec.engine, "run_cascade", lambda **kw: None) + formatted = runner.invoke(app, ["--sandbox", "agent-cascade", "--show-code"]) + bare = runner.invoke(app, ["--sandbox", "agent-cascade", "--no-format-turns", "--show-code"]) + # With formatting on the cue waits for the punctuated turn; off, a bare end-of-turn fires. + assert "turn_is_formatted" in formatted.stdout + assert "turn_is_formatted" not in bare.stdout + assert "format_turns=false" in bare.stdout + + +def test_show_code_threads_model_and_max_tokens(monkeypatch): + monkeypatch.setattr(_exec.engine, "run_cascade", lambda **kw: None) + result = runner.invoke( + app, + ["--sandbox", "agent-cascade", "--model", "claude-x", "--max-tokens", "321", "--show-code"], + ) + assert result.exit_code == 0 + assert "claude-x" in result.stdout + assert "MAX_TOKENS = 321" in result.stdout + + +def test_show_code_file_source_warns_on_stderr(monkeypatch): + # The generated script is mic-driven; a passed source must warn, not be dropped silently. + monkeypatch.setattr( + _exec.engine, "run_cascade", lambda **kw: (_ for _ in ()).throw(AssertionError("no run")) + ) + result = runner.invoke(app, ["--sandbox", "agent-cascade", "clip.wav", "--show-code"]) + assert result.exit_code == 0 + assert "uses the microphone" in result.stderr + assert "uses the microphone" not in result.stdout # stdout stays a clean script + compile(result.stdout, "", "exec") + + +def test_show_code_mic_emits_no_warning(monkeypatch): + monkeypatch.setattr(_exec.engine, "run_cascade", lambda **kw: None) + result = runner.invoke(app, ["--sandbox", "agent-cascade", "--show-code"]) + assert result.exit_code == 0 + assert "uses the microphone" not in result.stderr # mic script matches the run, nothing to warn + + +def test_show_code_in_production_is_rejected_with_sandbox_hint(): + # --show-code still honors the sandbox-only guard, so the generated URLs are valid. + result = runner.invoke(app, ["agent-cascade", "--show-code"]) + assert result.exit_code == 2 + assert "only available in the sandbox" in result.output diff --git a/tests/test_code_gen_agent_cascade.py b/tests/test_code_gen_agent_cascade.py new file mode 100644 index 00000000..8dfbe0dc --- /dev/null +++ b/tests/test_code_gen_agent_cascade.py @@ -0,0 +1,107 @@ +"""Example-based code_gen tests for the agent-cascade scaffold. + +The cascade wires three primitives client-side (Streaming STT -> LLM Gateway -> +streaming TTS), so the generated script is checked for all three legs plus the +session knobs it must inject. Sandbox hosts only — streaming TTS has no prod host. +""" + +from __future__ import annotations + +import ast +import dataclasses +import json + +import pytest + +from aai_cli import code_gen +from aai_cli.agent_cascade.config import CascadeConfig +from aai_cli.core import environments + + +@pytest.fixture(autouse=True) +def _sandbox_env(): + # The cascade is sandbox-only (streaming TTS has no prod host), so generate against it. + environments.set_active(environments.get("sandbox000")) + + +def _render(*, speech_model="u3-rt-pro", **overrides): + config = dataclasses.replace( + CascadeConfig( + voice="jane", + system_prompt="Be terse.", + greeting="Hi there", + model="claude-haiku-4-5-20251001", + max_tokens=1000, + max_history=40, + ), + **overrides, + ) + return code_gen.agent_cascade(config, speech_model=speech_model) + + +def test_render_parses_and_wires_all_three_legs(): + code = _render() + ast.parse(code) + sandbox = environments.get("sandbox000") + # STT, LLM Gateway, and TTS hosts all come from the active (sandbox) environment. + assert f"wss://{sandbox.streaming_host}/v3/ws" in code + assert f"wss://{sandbox.streaming_tts_host}/v1/ws/" in code + assert sandbox.llm_gateway_base in code + assert 'os.environ["ASSEMBLYAI_API_KEY"]' in code + + +def test_render_injects_session_knobs(): + code = _render(model="claude-x", max_tokens=321, max_history=12) + ast.parse(code) + assert "voice=jane" in code # the voice rides the TTS URL + assert "Be terse." in code # the system prompt + assert "Hi there" in code # the greeting + assert '"claude-x"' in code # the LLM model + assert "MAX_TOKENS = 321" in code + assert "MAX_HISTORY = 12" in code + + +def test_render_streams_stt_at_the_full_duplex_rate(): + # One full-duplex stream means the STT sample_rate must match the 24 kHz capture rate; + # a mismatch corrupts the audio server-side. Pin the STT URL's own sample_rate (not just + # any "sample_rate=24000", which the TTS URL also carries) so a drift can't slip through. + code = _render() + ast.parse(code) + assert "/v3/ws?sample_rate=24000&encoding=pcm_s16le" in code + assert "RATE = 24000" in code + + +def test_render_format_turns_waits_for_the_formatted_turn(): + code = _render(format_turns=True) + ast.parse(code) + assert "format_turns=true" in code + assert "turn_is_formatted" in code # the reply cue waits for the punctuated turn + + +def test_render_no_format_turns_fires_on_bare_end_of_turn(): + code = _render(format_turns=False) + ast.parse(code) + assert "format_turns=false" in code + # The server never formats, so a bare end-of-turn is the cue (no turn_is_formatted gate). + assert "turn_is_formatted" not in code + + +def test_render_includes_language_only_when_set(): + assert "language=" not in _render(language=None) + assert "language=de" in _render(language="de") + + +def test_render_uses_single_full_duplex_stream(): + # ONE sd.RawStream (mic + speaker); two separate streams fail on macOS CoreAudio. + code = _render() + ast.parse(code) + assert "sd.RawStream(" in code + assert "RawInputStream" not in code + assert "RawOutputStream" not in code + + +def test_render_escapes_quotes_in_prompt(): + tricky = 'Say "hi"\nand stop' + code = _render(system_prompt=tricky) + ast.parse(code) # valid Python despite embedded quotes/newlines + assert json.dumps(tricky) in code # injected via json.dumps, escaped form appears verbatim