diff --git a/aai_cli/commands/transcripts.py b/aai_cli/commands/transcripts.py index e144a421..50cde060 100644 --- a/aai_cli/commands/transcripts.py +++ b/aai_cli/commands/transcripts.py @@ -5,8 +5,9 @@ from aai_cli import command_registry, help_panels, options from aai_cli.app.context import AppState, run_command -from aai_cli.core import choices, client, timeparse -from aai_cli.core.errors import APIError +from aai_cli.app.transcribe.run import render_transform_steps +from aai_cli.core import choices, client, llm, stdio, timeparse +from aai_cli.core.errors import APIError, UsageError from aai_cli.ui import output, theme from aai_cli.ui.help_text import examples_epilog @@ -65,20 +66,170 @@ def render(data: list[dict[str, object]]) -> object: run_command(ctx, body, json=json_out) +def _resolve_ids(transcript_id: str | None) -> tuple[list[str], bool]: + """The transcript ids to fetch, and whether this is a stdin batch. + + A positional id stays the single-fetch path (output shape unchanged). With no + id, ids are read from piped stdin so ``transcripts list --json | …`` composes; + running interactively with no id is a usage error rather than a hang. + """ + if transcript_id is not None: + return [transcript_id], False + piped = stdio.piped_stdin_text() + if piped is None: + raise UsageError( + "Give a transcript id, or pipe transcript ids on stdin.", + suggestion="e.g. assembly transcripts list --json | assembly transcripts get -o text", + ) + ids = client.parse_transcript_ids(piped) + if not ids: + raise UsageError( + "No transcript ids found on stdin.", + suggestion="Pipe `assembly transcripts list --json`, or one id per line.", + ) + return ids, True + + +def _emit_transcript(transcript: object, *, json_mode: bool, batch: bool) -> None: + """Render one fetched transcript (no -o field, no --llm) for the chosen output mode.""" + if json_mode and batch: + # One NDJSON record per id so a downstream stage can map over the stream; + # "type" discriminates NDJSON lines CLI-wide (matching `transcribe` batch). + output.emit_ndjson({"type": "transcript", **client.transcript_json_payload(transcript)}) + elif json_mode: + # The full SDK payload, identical to `assembly transcribe … --json`, so the + # same `jq` works whether the transcript is fetched fresh or re-fetched. + output.emit(client.transcript_json_payload(transcript), lambda d: d, json_mode=True) + elif batch: + output.emit_text(str(client.transcript_summary(transcript)["text"])) + else: + output.emit( + client.transcript_summary(transcript), + lambda d: escape(str(d["text"])), + json_mode=False, + ) + + +def _id_of(transcript: object) -> str: + return str(getattr(transcript, "id", "") or "") + + +def _text_of(transcript: object) -> str: + return str(getattr(transcript, "text", "") or "") + + +def _emit_transform( + transcript: object, + model: str, + steps: list[dict[str, str]], + *, + json_mode: bool, + batch: bool, +) -> None: + """Emit a transcript's ``--llm`` chain result: NDJSON per id in batch, else like `transcribe`.""" + record = client.transcript_summary(transcript) | {"transform": {"model": model, "steps": steps}} + if json_mode and batch: + output.emit_ndjson({"type": "transcript", **record}) + else: + output.emit(record, render_transform_steps, json_mode=json_mode) + + +def _deliver_transcript( + transcript: object, + api_key: str, + *, + output_field: choices.TranscriptOutput | None, + chars_per_caption: int | None, + chain: list[str], + model: str, + max_tokens: int, + json_mode: bool, + batch: bool, + suppress: bool, +) -> str: + """Emit one fetched transcript (unless ``suppress``ed for a pending reduce) and return + its ``--llm-reduce`` contribution — the last ``--llm`` output, else the transcript text.""" + if output_field is not None: + # -o wins over the chain, matching `transcribe` deliver_result precedence; a + # pending human reduce suppresses the per-id field so only the aggregate prints. + if not suppress: + output.emit_text( + client.select_transcript_field( + transcript, output_field, chars_per_caption=chars_per_caption + ) + ) + return _text_of(transcript) + if chain: + steps = llm.run_chain_steps( + api_key, chain, transcript_id=_id_of(transcript), model=model, max_tokens=max_tokens + ) + if not suppress: + _emit_transform(transcript, model, steps, json_mode=json_mode, batch=batch) + return steps[-1]["output"] if steps else "" + if not suppress: + _emit_transcript(transcript, json_mode=json_mode, batch=batch) + return _text_of(transcript) + + +def _run_reduce( + api_key: str, + contributions: list[tuple[str, str]], + *, + prompts: list[str], + model: str, + max_tokens: int, + json_mode: bool, +) -> None: + """Run the ``--llm-reduce`` chain once over every fetched transcript; print to stdout. + + Mirrors `transcribe`'s reduce: concatenate each id's contribution under a header, + skip the billable call when there's nothing to reduce, and emit the same additive + ``{"type": "reduce", …}`` NDJSON record under --json. + """ + combined = "\n\n".join(f"### Transcript: {tid}\n{text}" for tid, text in contributions if text) + if not combined: + output.emit_warning( + "Nothing to reduce: no transcript text across ids.", json_mode=json_mode + ) + return + result = llm.run_chain( + api_key, prompts, transcript_text=combined, model=model, max_tokens=max_tokens + ) + if json_mode: + output.emit_ndjson({"type": "reduce", "model": model, "prompts": prompts, "output": result}) + else: + output.emit_text(result) + + @app.command( epilog=examples_epilog( [ ("Fetch a transcript's text by id", "assembly transcripts get 5551234-abcd"), ("Speaker-labeled turns", "assembly transcripts get 5551234-abcd -o utterances"), ("Save SRT subtitles", "assembly transcripts get 5551234-abcd -o srt > captions.srt"), - ("Save VTT subtitles", "assembly transcripts get 5551234-abcd -o vtt > captions.vtt"), ("Get the raw JSON", "assembly transcripts get 5551234-abcd --json"), + ( + "Fetch many at once from a piped list", + "assembly transcripts list --json | assembly transcripts get -o text", + ), + ( + "Summarize each transcript in a piped list", + "assembly transcripts list --json | " + 'assembly transcripts get --llm "Summarize this call"', + ), + ( + "Rank a piped list with one reduce prompt", + "assembly transcripts list --json | " + 'assembly transcripts get --llm-reduce "Rank these worst-to-best"', + ), ] ) ) def get( ctx: typer.Context, - transcript_id: str = typer.Argument(..., help="Transcript id"), + transcript_id: str | None = typer.Argument( + None, help="Transcript id; omit to read ids from stdin" + ), output_field: choices.TranscriptOutput | None = typer.Option( None, "-o", @@ -86,39 +237,87 @@ def get( help="Print one field of the result", ), chars_per_caption: int | None = options.chars_per_caption_option(), + llm_prompt: list[str] | None = typer.Option( + None, + "--llm", + help="Transform each transcript through LLM Gateway. Repeatable: each prompt runs " + "on the previous one's response (a chain), the first on the transcript.", + rich_help_panel=help_panels.OPT_LLM, + ), + llm_reduce: list[str] | None = typer.Option( + None, + "--llm-reduce", + help="Run one LLM-Gateway prompt over all fetched transcripts (a reduce). " + "Repeatable: each runs on the previous one's output. For a single id it " + "extends the --llm chain over that transcript.", + rich_help_panel=help_panels.OPT_LLM, + ), + model: str = typer.Option( + llm.DEFAULT_MODEL, + "--model", + help="LLM Gateway model", + rich_help_panel=help_panels.OPT_LLM, + autocompletion=llm.complete_model, + ), + max_tokens: int = typer.Option( + llm.DEFAULT_MAX_TOKENS, + "--max-tokens", + help="Max tokens", + rich_help_panel=help_panels.OPT_LLM, + ), json_out: bool = options.json_option(), ) -> None: - """Fetch a past transcript by id and print its text""" + """Fetch a past transcript by id and print its text + + Omit the id to read transcript ids from stdin — one per line, or the JSON from + `assembly transcripts list --json`. Add --llm to transform each transcript through + LLM Gateway (a map), or --llm-reduce to run one prompt over them all (a reduce). + """ def body(state: AppState, json_mode: bool) -> None: # Cheap local validation first: a malformed id or flag conflict is a usage # error whether or not the user is signed in, so it must not trigger auth. - client.validate_transcript_id(transcript_id) client.validate_chars_per_caption(chars_per_caption, output_field) + map_prompts = list(llm_prompt or []) + reduce_prompts = list(llm_reduce or []) + ids, batch = _resolve_ids(transcript_id) + for tid in ids: + client.validate_transcript_id(tid) + # A single source has nothing to aggregate, so --llm-reduce just extends the + # --llm chain over that transcript (mirrors `transcribe`); a stdin batch runs + # the reduce separately over every fetched transcript. + per_transcript_chain = map_prompts if batch else map_prompts + reduce_prompts + do_reduce = batch and bool(reduce_prompts) api_key = state.resolve_api_key() - transcript = client.get_transcript(api_key, transcript_id) - if client.status_str(transcript) == "error": - raise APIError( - getattr(transcript, "error", None) or "Transcript failed.", - transcript_id=transcript_id, - ) - if output_field is not None: - # Raw single-field output for pipelines (overrides --json), matching `transcribe`. - output.emit_text( - client.select_transcript_field( - transcript, output_field, chars_per_caption=chars_per_caption + contributions: list[tuple[str, str]] = [] + for tid in ids: + transcript = client.get_transcript(api_key, tid) + if client.status_str(transcript) == "error": + raise APIError( + getattr(transcript, "error", None) or "Transcript failed.", + transcript_id=tid, ) + contribution = _deliver_transcript( + transcript, + api_key, + output_field=output_field, + chars_per_caption=chars_per_caption, + chain=per_transcript_chain, + model=model, + max_tokens=max_tokens, + json_mode=json_mode, + batch=batch, + suppress=do_reduce and not json_mode, ) - return - if json_mode: - # The full SDK payload, identical to `assembly transcribe … --json`, so the - # same `jq` works whether the transcript is fetched fresh or re-fetched. - output.emit(client.transcript_json_payload(transcript), lambda d: d, json_mode=True) - else: - output.emit( - client.transcript_summary(transcript), - lambda d: escape(str(d["text"])), - json_mode=False, + contributions.append((tid, contribution)) + if do_reduce: + _run_reduce( + api_key, + contributions, + prompts=reduce_prompts, + model=model, + max_tokens=max_tokens, + json_mode=json_mode, ) run_command(ctx, body, json=json_out) diff --git a/aai_cli/core/client.py b/aai_cli/core/client.py index f7e333a2..98765267 100644 --- a/aai_cli/core/client.py +++ b/aai_cli/core/client.py @@ -301,6 +301,40 @@ def validate_transcript_id(transcript_id: str) -> str: return transcript_id +def _extract_id(item: object) -> str: + """The transcript id from one parsed stdin item (a mapping's ``id`` or a bare line).""" + mapping = jsonshape.as_mapping(item) + if mapping is not None: + return str(mapping.get("id") or "").strip() + return str(item).strip() + + +def _stdin_items(stripped: str) -> list[object]: + """Items in piped stdin: a JSON array's elements, a single JSON object, or text lines.""" + try: + loaded: object = json.loads(stripped) + except json.JSONDecodeError: + return list(stripped.splitlines()) + mapping = jsonshape.as_mapping(loaded) + return [mapping] if mapping is not None else jsonshape.object_list(loaded) + + +def parse_transcript_ids(text: str) -> list[str]: + """Transcript ids parsed from piped stdin, order-preserving and de-duplicated. + + Accepts the shapes a pipeline naturally produces: the JSON array printed by + ``assembly transcripts list --json`` (objects carrying an ``id``), a single + transcript JSON object (``transcripts get --json``), or plain text with one id + per line (e.g. piped through ``jq -r '.[].id'``). Input that isn't JSON falls + back to the line form, so both the jq-free ``list --json | get`` and the + explicit ``… | jq -r '.[].id' | get`` compose. + """ + stripped = text.strip() + if not stripped: + return [] + return list(dict.fromkeys(id_ for id_ in map(_extract_id, _stdin_items(stripped)) if id_)) + + def get_transcript(api_key: str, transcript_id: str) -> aai.Transcript: validate_transcript_id(transcript_id) _configure(api_key) diff --git a/tests/__snapshots__/test_snapshots_help_history.ambr b/tests/__snapshots__/test_snapshots_help_history.ambr index 7cd73176..39522856 100644 --- a/tests/__snapshots__/test_snapshots_help_history.ambr +++ b/tests/__snapshots__/test_snapshots_help_history.ambr @@ -61,12 +61,19 @@ # name: test_command_help_matches_snapshot[transcripts_get] ''' - Usage: assembly transcripts get [OPTIONS] TRANSCRIPT_ID + Usage: assembly transcripts get [OPTIONS] [TRANSCRIPT_ID] Fetch a past transcript by id and print its text + Omit the id to read transcript ids from stdin — one per line, or the JSON from + `assembly transcripts list --json`. Add --llm to transform each transcript + through + LLM Gateway (a map), or --llm-reduce to run one prompt over them all (a + reduce). + ╭─ Arguments ──────────────────────────────────────────────────────────────────╮ - │ * transcript_id TEXT Transcript id [required] │ + │ transcript_id [TRANSCRIPT_ID] Transcript id; omit to read ids from │ + │ stdin │ ╰──────────────────────────────────────────────────────────────────────────────╯ ╭─ Options ────────────────────────────────────────────────────────────────────╮ │ --output -o [text|id|status|uttera Print one field of the │ @@ -77,6 +84,20 @@ │ --json -j Output raw JSON │ │ --help Show this message and │ │ exit. │ + ╰──────────────────────────────────────────────────────────────────────────────╯ + ╭─ LLM Transform ──────────────────────────────────────────────────────────────╮ + │ --llm TEXT Transform each transcript through LLM Gateway. │ + │ Repeatable: each prompt runs on the previous │ + │ one's response (a chain), the first on the │ + │ transcript. │ + │ --llm-reduce TEXT Run one LLM-Gateway prompt over all fetched │ + │ transcripts (a reduce). Repeatable: each runs │ + │ on the previous one's output. For a single id │ + │ it extends the --llm chain over that │ + │ transcript. │ + │ --model TEXT LLM Gateway model │ + │ [default: claude-haiku-4-5-20251001] │ + │ --max-tokens INTEGER Max tokens [default: 1000] │ ╰──────────────────────────────────────────────────────────────────────────────╯ Examples @@ -86,10 +107,16 @@ $ assembly transcripts get 5551234-abcd -o utterances Save SRT subtitles $ assembly transcripts get 5551234-abcd -o srt > captions.srt - Save VTT subtitles - $ assembly transcripts get 5551234-abcd -o vtt > captions.vtt Get the raw JSON $ assembly transcripts get 5551234-abcd --json + Fetch many at once from a piped list + $ assembly transcripts list --json | assembly transcripts get -o text + Summarize each transcript in a piped list + $ assembly transcripts list --json | assembly transcripts get --llm "Summarize + this call" + Rank a piped list with one reduce prompt + $ assembly transcripts list --json | assembly transcripts get --llm-reduce + "Rank these worst-to-best" diff --git a/tests/test_help_rendering.py b/tests/test_help_rendering.py index cfc100ad..acc5897f 100644 --- a/tests/test_help_rendering.py +++ b/tests/test_help_rendering.py @@ -68,7 +68,7 @@ def _json_error(result): @pytest.mark.parametrize( ("argv", "fragment"), [ - (["transcripts", "get", "--json"], "Missing argument 'TRANSCRIPT_ID'"), + (["sessions", "get", "--json"], "Missing argument 'SESSION_ID'"), (["llm", "hi", "--max-tokens", "abc", "--json"], "not a valid integer"), ], ids=["missing-argument", "bad-option-value"], @@ -86,11 +86,11 @@ def test_parse_error_with_json_emits_error_envelope(argv, fragment): def test_parse_error_without_json_keeps_human_panel(): - result = runner.invoke(app, ["transcripts", "get"]) + result = runner.invoke(app, ["sessions", "get"]) assert result.exit_code == 2 plain = _plain(result.output) - assert "Usage: assembly transcripts get" in plain - assert "Missing argument 'TRANSCRIPT_ID'" in plain + assert "Usage: assembly sessions get" in plain + assert "Missing argument 'SESSION_ID'" in plain assert '{"error"' not in plain diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index a73f4079..63924c54 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -109,6 +109,9 @@ def test_get_json_emits_full_payload(mocker): data = json.loads(result.output) assert data["id"] == "t_42" assert data["text"] == "retrieved text" + # A single positional fetch emits the bare payload, not the batch NDJSON record + # (pins `json_mode and batch`: the single path must not carry a "type" wrapper). + assert "type" not in data def test_get_json_emits_full_sdk_payload_when_present(mocker): diff --git a/tests/test_transcripts_pipeline.py b/tests/test_transcripts_pipeline.py new file mode 100644 index 00000000..3932b6dc --- /dev/null +++ b/tests/test_transcripts_pipeline.py @@ -0,0 +1,289 @@ +"""`assembly transcripts get` stdin batching and the --llm/--llm-reduce map-reduce. + +Split out of test_transcripts.py (which holds the single-fetch and `list` tests) to +keep each file under the 500-line gate. +""" + +from __future__ import annotations + +import json + +from typer.testing import CliRunner + +from aai_cli.core import client, config +from aai_cli.main import app + +runner = CliRunner() + + +def _fake_transcript(mocker, *, id_, text): + fake = mocker.MagicMock() + fake.id = id_ + fake.text = text + fake.status = "completed" + fake.json_response = None + return fake + + +def _dispatch_by_id(mocker, mapping): + def fetch(_api_key, transcript_id): + return _fake_transcript(mocker, id_=transcript_id, text=mapping[transcript_id]) + + return mocker.patch( + "aai_cli.commands.transcripts.client.get_transcript", autospec=True, side_effect=fetch + ) + + +def _patch_chain_steps(mocker, fn): + return mocker.patch("aai_cli.commands.transcripts.llm.run_chain_steps", side_effect=fn) + + +def _patch_reduce(mocker, fn): + return mocker.patch("aai_cli.commands.transcripts.llm.run_chain", side_effect=fn) + + +def test_parse_transcript_ids_reads_list_json_array(): + # The array `transcripts list --json` prints: ids pulled out, order preserved, deduped. + text = json.dumps([{"id": "t1", "status": "completed"}, {"id": "t2"}, {"id": "t1"}, {"id": ""}]) + assert client.parse_transcript_ids(text) == ["t1", "t2"] + + +def test_parse_transcript_ids_reads_single_object_and_string_array(): + assert client.parse_transcript_ids('{"id": "t9", "text": "hi"}') == ["t9"] + assert client.parse_transcript_ids('["t1", "t2"]') == ["t1", "t2"] + + +def test_parse_transcript_ids_falls_back_to_lines(): + # Plain text (e.g. `jq -r '.[].id'`): one id per line, blanks dropped, deduped in order. + assert client.parse_transcript_ids("t1\n\n t2 \nt1\n") == ["t1", "t2"] + + +def test_parse_transcript_ids_empty_input_yields_no_ids(): + assert client.parse_transcript_ids(" \n ") == [] + + +def test_get_reads_ids_from_piped_list_json(mocker): + # The headline pipeline: `transcripts list --json | transcripts get -o text`, jq-free. + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "first text", "t2": "second text"}) + piped = json.dumps([{"id": "t1", "status": "completed"}, {"id": "t2", "status": "completed"}]) + result = runner.invoke(app, ["transcripts", "get", "-o", "text"], input=piped) + assert result.exit_code == 0 + # One transcript's text per line, in the piped order. + assert result.output.splitlines() == ["first text", "second text"] + + +def test_get_batch_json_emits_one_ndjson_record_per_id(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "first", "t2": "second"}) + result = runner.invoke(app, ["transcripts", "get", "--json"], input="t1\nt2\n") + assert result.exit_code == 0 + records = [json.loads(line) for line in result.output.splitlines() if line.strip()] + # NDJSON stream: one record per id, each tagged with the CLI-wide "type" discriminator. + assert [r["type"] for r in records] == ["transcript", "transcript"] + assert [r["id"] for r in records] == ["t1", "t2"] + + +def test_get_batch_human_prints_plain_text_not_json(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "alpha", "t2": "beta"}) + result = runner.invoke(app, ["transcripts", "get"], input="t1\nt2\n") + assert result.exit_code == 0 + # Human batch stays plain text — no NDJSON "type" wrapper leaks in. + assert "alpha" in result.output and "beta" in result.output + assert "type" not in result.output + + +def test_get_no_id_and_no_stdin_is_usage_error(mocker): + config.set_api_key("default", "sk_live") + get = mocker.patch("aai_cli.commands.transcripts.client.get_transcript", autospec=True) + result = runner.invoke(app, ["transcripts", "get"]) + assert result.exit_code == 2 + assert "Give a transcript id" in result.output + get.assert_not_called() + + +def test_get_stdin_without_ids_is_usage_error(mocker): + config.set_api_key("default", "sk_live") + get = mocker.patch("aai_cli.commands.transcripts.client.get_transcript", autospec=True) + result = runner.invoke(app, ["transcripts", "get"], input="[]") + assert result.exit_code == 2 + assert "No transcript ids found on stdin" in result.output + get.assert_not_called() + + +def test_get_single_llm_runs_chain_over_transcript_id(mocker): + config.set_api_key("default", "sk_live") + mocker.patch( + "aai_cli.commands.transcripts.client.get_transcript", + autospec=True, + return_value=_fake_transcript(mocker, id_="t_42", text="hello"), + ) + seen = {} + + def steps(api_key, prompts, *, transcript_id, model, max_tokens): + seen["transcript_id"] = transcript_id + seen["prompts"] = list(prompts) + return [{"prompt": prompts[-1], "output": "SUMMARY"}] + + _patch_chain_steps(mocker, steps) + reduce = _patch_reduce(mocker, lambda *a, **k: "") + result = runner.invoke(app, ["transcripts", "get", "t_42", "--llm", "summarize"]) + assert result.exit_code == 0, result.output + # The map injects the transcript server-side by id, and only the --llm prompt runs. + assert seen == {"transcript_id": "t_42", "prompts": ["summarize"]} + assert "SUMMARY" in result.output + reduce.assert_not_called() + + +def test_get_single_reduce_extends_the_chain(mocker): + # One id: nothing to aggregate, so --llm-reduce becomes extra chain steps (no reduce call). + config.set_api_key("default", "sk_live") + mocker.patch( + "aai_cli.commands.transcripts.client.get_transcript", + autospec=True, + return_value=_fake_transcript(mocker, id_="t1", text="hi"), + ) + seen = {} + + def steps(api_key, prompts, *, transcript_id, model, max_tokens): + seen["prompts"] = list(prompts) + return [{"prompt": prompts[-1], "output": "FINAL"}] + + _patch_chain_steps(mocker, steps) + reduce = _patch_reduce(mocker, lambda *a, **k: "") + result = runner.invoke(app, ["transcripts", "get", "t1", "--llm", "map", "--llm-reduce", "red"]) + assert result.exit_code == 0, result.output + assert seen["prompts"] == ["map", "red"] + reduce.assert_not_called() + assert "FINAL" in result.output + + +def test_get_batch_llm_maps_each_transcript(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "a", "t2": "b"}) + + def steps(api_key, prompts, *, transcript_id, model, max_tokens): + return [{"prompt": prompts[-1], "output": f"SUM-{transcript_id}"}] + + _patch_chain_steps(mocker, steps) + reduce = _patch_reduce(mocker, lambda *a, **k: "") + result = runner.invoke(app, ["transcripts", "get", "--llm", "summarize"], input="t1\nt2\n") + assert result.exit_code == 0, result.output + # One map per transcript, rendered as plain text — not the JSON NDJSON wrapper. + assert "SUM-t1" in result.output and "SUM-t2" in result.output + assert "type" not in result.output + reduce.assert_not_called() + + +def test_get_batch_llm_json_emits_transform_records(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "a", "t2": "b"}) + + def steps(api_key, prompts, *, transcript_id, model, max_tokens): + return [{"prompt": "summarize", "output": f"S-{transcript_id}"}] + + _patch_chain_steps(mocker, steps) + result = runner.invoke( + app, ["transcripts", "get", "--llm", "summarize", "--json"], input="t1\nt2\n" + ) + assert result.exit_code == 0, result.output + records = [json.loads(line) for line in result.output.splitlines() if line.strip()] + assert [r["type"] for r in records] == ["transcript", "transcript"] + outputs = [r["transform"]["steps"][-1]["output"] for r in records] + assert outputs == ["S-t1", "S-t2"] + + +def test_get_batch_reduce_over_transcript_texts(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "alpha", "t2": "beta"}) + steps = mocker.patch("aai_cli.commands.transcripts.llm.run_chain_steps", autospec=True) + captured = {} + + def reduce(api_key, prompts, *, transcript_text, model, max_tokens): + captured["text"] = transcript_text + captured["prompts"] = list(prompts) + return "RANKED" + + _patch_reduce(mocker, reduce) + result = runner.invoke(app, ["transcripts", "get", "--llm-reduce", "rank"], input="t1\nt2\n") + assert result.exit_code == 0, result.output + # No --llm, so each transcript contributes its text under a per-id header. + assert "### Transcript: t1" in captured["text"] + assert "alpha" in captured["text"] and "beta" in captured["text"] + assert captured["prompts"] == ["rank"] + steps.assert_not_called() + # Human reduce keeps stdout clean: only the aggregate prints, not each transcript. + assert "RANKED" in result.output + assert "alpha" not in result.output + + +def test_get_batch_reduce_with_output_field_suppresses_per_item(mocker): + # -o + --llm-reduce in human mode: the per-id field is suppressed like the other + # branches, so only the aggregate reaches stdout — the field still feeds the reduce. + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "alpha", "t2": "beta"}) + captured = {} + + def reduce(api_key, prompts, *, transcript_text, model, max_tokens): + captured["text"] = transcript_text + return "RANKED" + + _patch_reduce(mocker, reduce) + result = runner.invoke( + app, ["transcripts", "get", "-o", "text", "--llm-reduce", "rank"], input="t1\nt2\n" + ) + assert result.exit_code == 0, result.output + assert result.output.strip() == "RANKED" + assert "alpha" in captured["text"] and "beta" in captured["text"] + + +def test_get_batch_reduce_feeds_map_outputs(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "a", "t2": "b"}) + + def steps(api_key, prompts, *, transcript_id, model, max_tokens): + return [{"prompt": prompts[-1], "output": f"JUDGED-{transcript_id}"}] + + _patch_chain_steps(mocker, steps) + captured = {} + + def reduce(api_key, prompts, *, transcript_text, model, max_tokens): + captured["text"] = transcript_text + return "FINAL" + + _patch_reduce(mocker, reduce) + result = runner.invoke( + app, ["transcripts", "get", "--llm", "judge", "--llm-reduce", "rank"], input="t1\nt2\n" + ) + assert result.exit_code == 0, result.output + # The reduce sees the --llm output of each transcript, not the raw text. + assert "JUDGED-t1" in captured["text"] and "JUDGED-t2" in captured["text"] + assert "FINAL" in result.output + + +def test_get_batch_reduce_json_emits_reduce_record_and_per_id_records(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": "alpha"}) + _patch_reduce(mocker, lambda *a, **k: "RANKED") + result = runner.invoke( + app, ["transcripts", "get", "--llm-reduce", "rank", "--json"], input="t1\n" + ) + assert result.exit_code == 0, result.output + records = [json.loads(line) for line in result.output.splitlines() if line.strip()] + reduce_records = [r for r in records if r.get("type") == "reduce"] + assert len(reduce_records) == 1 + assert reduce_records[0]["output"] == "RANKED" + assert reduce_records[0]["prompts"] == ["rank"] + # JSON keeps the per-id stream too (unlike human reduce, which suppresses it). + assert any(r.get("type") == "transcript" for r in records) + + +def test_get_batch_reduce_skips_when_nothing_to_reduce(mocker): + config.set_api_key("default", "sk_live") + _dispatch_by_id(mocker, {"t1": ""}) # empty transcript text + reduce = mocker.patch("aai_cli.commands.transcripts.llm.run_chain", autospec=True) + result = runner.invoke(app, ["transcripts", "get", "--llm-reduce", "rank"], input="t1\n") + assert result.exit_code == 0, result.output + reduce.assert_not_called() # never fire a billable call over empty input + assert "Nothing to reduce" in result.output