Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,14 @@ assembly init voice-agent && assembly deploy --prod
assembly eval librispeech --speech-model universal-3-pro --limit 50
```

Add `--llm` to run an LLM-Gateway chain over each transcript (the WER score still
uses the raw transcript), and `--llm-reduce` to run one prompt over every item's
result and summarize the errors across the whole run:

```sh
assembly eval tedlium --limit 50 --llm-reduce "Summarize the common error patterns"
```

## 📦 Installation

Requires Python 3.12+ (Homebrew brings its own; for pipx/uv see the `--python` hint below).
Expand Down
6 changes: 6 additions & 0 deletions REFERENCE.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,9 @@ output printed to stdout (the progress table is routed to stderr so stdout stays
clean for piping). `--llm-reduce` is repeatable, each prompt running on the
previous one's output; for a single source it extends the `--llm` chain over
that transcript.

`assembly eval` takes the same `--llm`/`--llm-reduce` flags but emits a single
JSON object (not NDJSON): `--llm` runs a chain over each transcript and attaches
`{"model","steps"}` under the row's `llm` key (the WER score still uses the raw
transcript), and `--llm-reduce` runs its prompt chain once over every item's
result (the item's last `--llm` output, or its transcript when `--llm` is not
used) and adds a top-level `reduce` (`{"model","prompts","output"}`) to the
object. The `reduce` key is omitted when no item produced any text to reduce.
41 changes: 41 additions & 0 deletions aai_cli/commands/evaluate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from aai_cli.app.context import run_with_options
from aai_cli.commands.evaluate import _exec as evaluate_exec
from aai_cli.commands.evaluate._exec import EvalSpeechModel
from aai_cli.core import llm
from aai_cli.ui.help_text import examples_epilog

app = typer.Typer()
Expand Down Expand Up @@ -45,6 +46,10 @@
"Evaluate non-English audio",
"assembly eval commonvoice --subset fr --language-code fr",
),
(
"Summarize error patterns across the set",
'assembly eval tedlium --llm-reduce "Summarize the common error patterns"',
),
]
),
)
Expand Down Expand Up @@ -79,6 +84,34 @@ def evaluate(
min=1,
help="How many items to transcribe at once (sequential by default)",
),
llm_prompt: list[str] | None = typer.Option(
None,
"--llm",
help="Transform each transcript through LLM Gateway before reporting (the WER "
"score still uses the raw transcript). Repeatable: each prompt runs on the "
"previous one's response, the first on the transcript.",
rich_help_panel=help_panels.OPT_LLM,
),
llm_reduce: list[str] | None = typer.Option(
None,
"--llm-reduce",
help="Run one LLM-Gateway prompt over every item's result (a reduce). "
"Repeatable: each runs on the previous one's output.",
rich_help_panel=help_panels.OPT_LLM,
),
model: str = typer.Option(
llm.DEFAULT_MODEL,
"--model",
help="LLM Gateway model",
rich_help_panel=help_panels.OPT_LLM,
autocompletion=llm.complete_model,
),
max_tokens: int = typer.Option(
llm.DEFAULT_MAX_TOKENS,
"--max-tokens",
help="Max tokens",
rich_help_panel=help_panels.OPT_LLM,
),
json_out: bool = options.json_option("Output the rows and summary as one JSON object"),
) -> None:
"""Transcribe a dataset and score WER against its reference texts
Expand All @@ -99,6 +132,10 @@ def evaluate(
(English; --subset fr etc. for its 98 other locales), voxpopuli
(parliament speech), switchboard (phone calls), expresso (expressive
speech), loquacious, and callhome (phone calls).

--llm runs an LLM-Gateway chain over each transcript (the WER score still
uses the raw transcript); --llm-reduce then runs one prompt over every
item's result to summarize patterns across the run.
"""
opts = evaluate_exec.EvalOptions(
dataset=dataset,
Expand All @@ -110,5 +147,9 @@ def evaluate(
speech_model=speech_model,
language_code=language_code,
concurrency=concurrency,
llm_prompt=llm_prompt,
llm_reduce=llm_reduce,
model=model,
max_tokens=max_tokens,
)
run_with_options(ctx, evaluate_exec.run_evaluate, opts, json=json_out)
163 changes: 159 additions & 4 deletions aai_cli/commands/evaluate/_exec.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@

import assemblyai as aai
from rich.console import RenderableType
from rich.markup import escape

from aai_cli.app.context import AppState
from aai_cli.commands.evaluate import _data as eval_data
from aai_cli.core import client, jsonshape, wer
from aai_cli.core import llm as gateway
from aai_cli.core.errors import CLIError, NotAuthenticated
from aai_cli.ui import output

Expand All @@ -50,6 +52,31 @@ class EvalOptions:
speech_model: EvalSpeechModel | None
language_code: str | None
concurrency: int
llm_prompt: list[str] | None
llm_reduce: list[str] | None
model: str
max_tokens: int

def llm_options(self) -> _LlmOptions:
    """Bundle the ``--llm`` / ``--llm-reduce`` settings into an ``_LlmOptions``.

    An unset (``None``) prompt flag becomes an empty list, and a set one is
    copied, so the returned object never shares a list with this instance.
    """
    map_prompts = [*self.llm_prompt] if self.llm_prompt else []
    reduce_prompts = [*self.llm_reduce] if self.llm_reduce else []
    return _LlmOptions(
        prompts=map_prompts,
        reduce_prompts=reduce_prompts,
        model=self.model,
        max_tokens=self.max_tokens,
    )


@dataclass(frozen=True)
class _LlmOptions:
    """The post-transcription LLM-Gateway transform: the per-item ``--llm`` chain
    (a *map*) and the across-items ``--llm-reduce`` chain (a *reduce*), plus the
    gateway model + token budget both run under.

    Built by ``EvalOptions.llm_options``, which substitutes an empty list for an
    unset flag, so both prompt lists are always real lists (never ``None``).
    """

    # ``--llm`` prompts, in chain order; the map step only runs when this is non-empty.
    prompts: list[str]
    # ``--llm-reduce`` prompts, in chain order; the reduce step only runs when non-empty.
    reduce_prompts: list[str]
    # LLM Gateway model name passed to both the map and the reduce calls.
    model: str
    # ``max_tokens`` value passed to both the map and the reduce calls.
    max_tokens: int


def _pct(value: object) -> str:
Expand All @@ -75,11 +102,16 @@ def _percentile(values: list[float], q: float) -> float:

@dataclass(frozen=True)
class _ItemResult:
"""One scored row: the emitted dict plus the score and latency kept for pooling."""
"""One scored row: the emitted dict plus the score and latency kept for pooling.

``hypothesis`` is the transcript text (``None`` for a failed row) — kept so the
optional ``--llm`` map / ``--llm-reduce`` reduce can run over it after scoring.
"""

row: dict[str, object]
words: wer.Score | None
latency: float
hypothesis: str | None = None


def _failed_result(item: eval_data.EvalItem, err: CLIError, latency: float) -> _ItemResult:
Expand All @@ -94,15 +126,16 @@ def _failed_result(item: eval_data.EvalItem, err: CLIError, latency: float) -> _
def _score_item(
    item: eval_data.EvalItem, transcript: aai.Transcript, latency: float
) -> _ItemResult:
    """Score one transcript against the item's reference text and build its row.

    The transcript text is scored exactly once and is also carried on the
    returned ``_ItemResult`` as ``hypothesis``, so the later ``--llm`` /
    ``--llm-reduce`` steps can run over it.
    """
    # A missing transcript text is scored as the empty string.
    hypothesis = str(transcript.text or "")
    words = wer.score(item.reference, hypothesis)
    row: dict[str, object] = {
        "item": item.item_id,
        "words": words.words,
        "errors": words.errors,
        "wer": words.wer,
        "latency": latency,
    }
    return _ItemResult(row=row, words=words, latency=latency, hypothesis=hypothesis)


def _pooled_metrics(results: list[_ItemResult]) -> dict[str, object]:
Expand Down Expand Up @@ -204,6 +237,87 @@ def _transcripts(
)


def _run_llm_map(
    api_key: str,
    results: list[_ItemResult],
    llm_opts: _LlmOptions,
    *,
    json_mode: bool,
    quiet: bool,
) -> None:
    """Apply the ``--llm`` prompt chain to every row that has a transcript.

    This is the *map* step: each row's transcript text is sent through the
    chain, and the outcome is stored on the row as ``{"model", "steps"}`` under
    the ``llm`` key. The WER fields on the row are not touched. Rows whose
    ``hypothesis`` is ``None`` are left alone.
    """
    transcribed = [entry for entry in results if entry.hypothesis is not None]
    banner = f"Running --llm over {len(transcribed)} transcripts…"
    with output.status(banner, json_mode=json_mode, quiet=quiet):
        for entry in transcribed:
            chain_steps = gateway.run_chain_steps(
                api_key,
                llm_opts.prompts,
                transcript_text=entry.hypothesis,
                model=llm_opts.model,
                max_tokens=llm_opts.max_tokens,
            )
            # The row dict is mutated in place; the frozen dataclass only
            # prevents rebinding ``entry.row`` itself.
            entry.row["llm"] = {"model": llm_opts.model, "steps": chain_steps}


def _reduce_input(result: _ItemResult) -> str:
    """Pick the text a row feeds into the reduce.

    That is the ``output`` of the row's last ``--llm`` step when the row has a
    non-empty ``steps`` list; otherwise the transcript (``""`` if there is none).
    """
    transcript = result.hypothesis or ""
    chain = jsonshape.as_mapping(result.row.get("llm"))
    if chain is None:
        return transcript
    chain_steps = jsonshape.mapping_list(chain.get("steps"))
    if not chain_steps:
        return transcript
    final_output = chain_steps[-1].get("output", "")
    return str(final_output or "")


def _gather_reduce_inputs(results: list[_ItemResult]) -> str:
    """Join every transcribed row's reduce input into one text.

    Each contributing row becomes a ``### Item: <id>`` header followed by its
    text, and the sections are separated by a blank line. Rows without a
    transcript, and rows whose reduce input is empty, contribute nothing.
    """
    sections = [
        f"### Item: {entry.row.get('item')}\n{body}"
        for entry in results
        if entry.hypothesis is not None and (body := _reduce_input(entry))
    ]
    return "\n\n".join(sections)


def _run_reduce(
    api_key: str,
    results: list[_ItemResult],
    llm_opts: _LlmOptions,
    *,
    json_mode: bool,
    quiet: bool,
) -> dict[str, object] | None:
    """Run the ``--llm-reduce`` chain once over the combined per-item text.

    Returns the ``{"model", "prompts", "output"}`` entry for the payload. When
    no row contributes any text, a warning is emitted and ``None`` is returned
    without calling the gateway, so the caller can leave the payload key out.
    """
    reduce_text = _gather_reduce_inputs(results)
    if reduce_text:
        with output.status(
            "Running --llm-reduce over all items…", json_mode=json_mode, quiet=quiet
        ):
            aggregate = gateway.run_chain(
                api_key,
                llm_opts.reduce_prompts,
                transcript_text=reduce_text,
                model=llm_opts.model,
                max_tokens=llm_opts.max_tokens,
            )
        return {
            "model": llm_opts.model,
            "prompts": llm_opts.reduce_prompts,
            "output": aggregate,
        }
    output.emit_warning(
        "Nothing to reduce: no transcript text across items.", json_mode=json_mode
    )
    return None


def _payload(
label: str, speech_model: EvalSpeechModel | None, results: list[_ItemResult]
) -> dict[str, object]:
Expand Down Expand Up @@ -249,6 +363,36 @@ def _secs_cell(row: dict[str, object], key: str) -> str:
return _secs(row[key]) if key in row else ""


def _final_llm_output(row: dict[str, object]) -> str | None:
    """Return the ``output`` of a row's last ``--llm`` step.

    ``None`` means the row carries no ``llm`` mapping; ``""`` means it has one
    but its ``steps`` list is empty or the last step's output is empty.
    """
    chain = jsonshape.as_mapping(row.get("llm"))
    if chain is None:
        return None
    chain_steps = jsonshape.mapping_list(chain.get("steps"))
    if not chain_steps:
        return ""
    last_step = chain_steps[-1]
    return str(last_step.get("output", "") or "")


def _llm_block(payload: dict[str, object]) -> str | None:
    """Render the per-item ``--llm`` outputs as a heading plus ``item: output`` lines.

    Item ids and outputs are passed through ``escape`` before rendering.
    Returns ``None`` when no row in the payload carries an ``llm`` entry.
    """
    rendered: list[str] = []
    for row in jsonshape.mapping_list(payload.get("rows")):
        final = _final_llm_output(row)
        if final is None:
            continue
        item_label = escape(str(row.get("item")))
        rendered.append(f"{item_label}: {escape(final)}")
    if rendered:
        return "\n".join([output.heading("--llm"), *rendered])
    return None


def _reduce_block(payload: dict[str, object]) -> str | None:
    """Render the ``--llm-reduce`` result as a heading followed by its output.

    The output is passed through ``escape`` before rendering. Returns ``None``
    when the payload has no ``reduce`` mapping.
    """
    aggregate = jsonshape.as_mapping(payload.get("reduce"))
    if aggregate is None:
        return None
    title = output.heading("--llm-reduce")
    body = escape(str(aggregate.get("output", "")))
    return f"{title}\n{body}"


def _render(payload: dict[str, object]) -> RenderableType:
has_wer = "wer" in payload
has_failed = "failed" in payload
Expand All @@ -271,7 +415,11 @@ def _render(payload: dict[str, object]) -> RenderableType:
table.add_row(*cells)
model = payload.get("speech_model") or "default model"
return output.stack(
output.muted(f"{payload.get('dataset')} · {model}"), table, _summary(payload)
output.muted(f"{payload.get('dataset')} · {model}"),
table,
_summary(payload),
_llm_block(payload),
_reduce_block(payload),
)


Expand Down Expand Up @@ -310,7 +458,14 @@ def run_evaluate(opts: EvalOptions, state: AppState, *, json_mode: bool) -> None
strict=True, # pragma: no mutate (defensive invariant; _transcripts returns one outcome per item)
)
]
llm_opts = opts.llm_options()
if llm_opts.prompts:
_run_llm_map(api_key, results, llm_opts, json_mode=json_mode, quiet=state.quiet)
payload = _payload(data.label, opts.speech_model, results)
if llm_opts.reduce_prompts:
reduce = _run_reduce(api_key, results, llm_opts, json_mode=json_mode, quiet=state.quiet)
if reduce is not None:
payload["reduce"] = reduce
output.emit(payload, _render, json_mode=json_mode)
failed = jsonshape.as_int(payload.get("failed"))
if failed:
Expand Down
19 changes: 19 additions & 0 deletions tests/__snapshots__/test_snapshots_help_run.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@
(parliament speech), switchboard (phone calls), expresso (expressive
speech), loquacious, and callhome (phone calls).

--llm runs an LLM-Gateway chain over each transcript (the WER score still
uses the raw transcript); --llm-reduce then runs one prompt over every
item's result to summarize patterns across the run.

╭─ Arguments ──────────────────────────────────────────────────────────────────╮
│ * dataset TEXT Hugging Face dataset id, or a local .csv/.jsonl │
│ manifest with audio + text columns │
Expand Down Expand Up @@ -527,6 +531,19 @@
│ object │
│ --help Show this message and │
│ exit. │
╰──────────────────────────────────────────────────────────────────────────────╯
╭─ LLM Transform ──────────────────────────────────────────────────────────────╮
│ --llm TEXT Transform each transcript through LLM Gateway │
│ before reporting (the WER score still uses the │
│ raw transcript). Repeatable: each prompt runs │
│ on the previous one's response, the first on │
│ the transcript. │
│ --llm-reduce TEXT Run one LLM-Gateway prompt over every item's │
│ result (a reduce). Repeatable: each runs on the │
│ previous one's output. │
│ --model TEXT LLM Gateway model │
│ [default: claude-haiku-4-5-20251001] │
│ --max-tokens INTEGER Max tokens [default: 1000] │
╰──────────────────────────────────────────────────────────────────────────────╯

Examples
Expand All @@ -538,6 +555,8 @@
$ assembly eval librispeech --limit 50 --concurrency 4
Evaluate non-English audio
$ assembly eval commonvoice --subset fr --language-code fr
Summarize error patterns across the set
$ assembly eval tedlium --llm-reduce "Summarize the common error patterns"



Expand Down
Loading
Loading