From 5e5a503b40bd56698857398141f85beeea2eae6d Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 13 Jun 2026 20:30:47 +0000 Subject: [PATCH 1/2] Type the agent NDJSON events and HF /splits parsing with Pydantic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two stringly-typed boundaries now carry typed models, following the auth/flow.py + account.py precedent: - agent/events.py: the `assembly agent --json` event stream was hand-built `{"type": …}` dicts the type checker only saw as `dict[str, str]`. Each event is now a closed, frozen Pydantic model whose `type` literal and payload are pinned at type-check time, so a renamed key or mistyped `type` can't drift onto the wire. AgentRenderer emits via `model_dump()`; the wire shapes are byte-identical (existing render goldens unchanged). - evaluate/_hf_api.py: `split_entries()` returned bare dicts and subset/split selection read `str(entry.get("config"))`. A `_SplitEntry` model (validated via a module-level TypeAdapter) gives `pick_subset`/`pick_split` typed fields and turns a malformed /splits payload into a clean APIError instead of a stringified "None". https://claude.ai/code/session_01AzkXsmPQSoUJjPgJY6qvGB --- aai_cli/agent/events.py | 70 ++++++++++++++++++++++++++++ aai_cli/agent/render.py | 17 ++++--- aai_cli/commands/evaluate/_hf_api.py | 36 ++++++++++---- tests/test_agent_events.py | 42 +++++++++++++++++ tests/test_eval_data_hf.py | 8 ++++ 5 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 aai_cli/agent/events.py create mode 100644 tests/test_agent_events.py diff --git a/aai_cli/agent/events.py b/aai_cli/agent/events.py new file mode 100644 index 00000000..0f917e15 --- /dev/null +++ b/aai_cli/agent/events.py @@ -0,0 +1,70 @@ +"""Typed Voice Agent NDJSON events. + +The ``assembly agent --json`` stream is a public contract: every line carries a +``type`` discriminator (see docs/cli-reference.md) and a small, fixed payload. +Modelling each event as a Pydantic class — rather than hand-building +``{"type": …}`` dicts the type checker only sees as ``dict[str, str]`` — pins the +``type`` literal and the payload fields at type-check time, so a renamed key or a +mistyped ``type`` value fails before it can drift onto the wire. ``model_dump()`` +is the serialized form the renderer emits. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, ConfigDict + + +class _Event(BaseModel): + """Base for Voice Agent events: a closed, frozen wire model. + + ``extra="forbid"`` keeps a stray field from silently riding along on the + stream, and ``frozen=True`` makes an emitted event immutable. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + +class SessionReady(_Event): + """The agent session connected and is ready for audio.""" + + type: Literal["session.ready"] = "session.ready" + + +class UserDelta(_Event): + """An interim (partial) user transcript.""" + + type: Literal["transcript.user.delta"] = "transcript.user.delta" + text: str + + +class UserFinal(_Event): + """A finalized user transcript turn.""" + + type: Literal["transcript.user"] = "transcript.user" + text: str + + +class ReplyStarted(_Event): + """The agent began generating a reply.""" + + type: Literal["reply.started"] = "reply.started" + + +class AgentTranscript(_Event): + """The agent's reply transcript (``interrupted`` when the user barged in).""" + + type: Literal["transcript.agent"] = "transcript.agent" + text: str + interrupted: bool + + +class ReplyDone(_Event): + """The agent finished, or was interrupted out of, a reply.""" + + type: Literal["reply.done"] = "reply.done" + interrupted: bool + + +Event = SessionReady | UserDelta | UserFinal | ReplyStarted | AgentTranscript | ReplyDone diff --git a/aai_cli/agent/render.py b/aai_cli/agent/render.py index 70a86ecc..e64df83f 100644 --- a/aai_cli/agent/render.py +++ b/aai_cli/agent/render.py @@ -4,6 +4,7 @@ from rich.text import Text +from aai_cli.agent import events from aai_cli.ui.render import BaseRenderer @@ -28,10 +29,14 @@ def __init__(self, *, mic_input: bool = True, **kwargs: Any) -> None: # File-driven runs have no mic, so they skip the "start talking" prompt. self.mic_input = mic_input + def _emit_event(self, event: events.Event) -> None: + """Serialize one typed agent event to the NDJSON stream.""" + self._emit(event.model_dump()) + # --- lifecycle --------------------------------------------------------- def connected(self) -> None: if self.json_mode: - self._emit({"type": "session.ready"}) + self._emit_event(events.SessionReady()) elif not self.mic_input: return elif self.text_mode: @@ -53,13 +58,13 @@ def notice(self, text: str) -> None: # --- user -------------------------------------------------------------- def user_partial(self, text: str) -> None: if self.json_mode: - self._emit({"type": "transcript.user.delta", "text": text}) + self._emit_event(events.UserDelta(text=text)) elif not self.text_mode: # partials are noise for piped text self._update_line(_labeled("you: ", text, style="aai.you")) def user_final(self, text: str) -> None: if self.json_mode: - self._emit({"type": "transcript.user", "text": text}) + self._emit_event(events.UserFinal(text=text)) elif self.text_mode: self._write(f"you: {text}\n") else: @@ -68,11 +73,11 @@ def user_final(self, text: str) -> None: # --- agent ------------------------------------------------------------- def reply_started(self) -> None: if self.json_mode: - self._emit({"type": "reply.started"}) + self._emit_event(events.ReplyStarted()) def agent_transcript(self, text: str, *, interrupted: bool) -> None: if self.json_mode: - self._emit({"type": "transcript.agent", "text": text, "interrupted": interrupted}) + self._emit_event(events.AgentTranscript(text=text, interrupted=interrupted)) elif self.text_mode: self._write(f"agent: {text}\n") else: @@ -81,4 +86,4 @@ def agent_transcript(self, text: str, *, interrupted: bool) -> None: def reply_done(self, *, interrupted: bool) -> None: if self.json_mode: - self._emit({"type": "reply.done", "interrupted": interrupted}) + self._emit_event(events.ReplyDone(interrupted=interrupted)) diff --git a/aai_cli/commands/evaluate/_hf_api.py b/aai_cli/commands/evaluate/_hf_api.py index c2285886..5ad73780 100644 --- a/aai_cli/commands/evaluate/_hf_api.py +++ b/aai_cli/commands/evaluate/_hf_api.py @@ -11,6 +11,7 @@ from http import HTTPStatus import httpx2 as httpx +from pydantic import BaseModel, ConfigDict, TypeAdapter, ValidationError from aai_cli.core import env, jsonshape from aai_cli.core.errors import APIError, UsageError @@ -91,16 +92,37 @@ def fetch_json(endpoint: str, params: dict[str, str | int], *, dataset: str) -> return _checked_payload(resp, dataset=dataset) -def split_entries(dataset: str) -> list[dict[str, object]]: +class _SplitEntry(BaseModel): + """One ``/splits`` entry: the (config, split) pair naming a dataset slice. + + ``extra="allow"`` keeps the sibling fields the server returns (``dataset``, + ``num_examples``) without modelling them, so subset/split selection reads + typed fields instead of ``str(entry.get("config"))`` off a bare dict. + """ + + model_config = ConfigDict(extra="allow") + config: str + split: str + + +_SPLIT_ENTRIES = TypeAdapter(list[_SplitEntry]) + + +def split_entries(dataset: str) -> list[_SplitEntry]: payload = fetch_json("/splits", {"dataset": dataset}, dataset=dataset) - entries = jsonshape.mapping_list(payload.get("splits")) + try: + entries = _SPLIT_ENTRIES.validate_python(payload.get("splits") or []) + except ValidationError as exc: + raise APIError( + f"Hugging Face returned an unexpected /splits payload for '{dataset}'." + ) from exc if not entries: raise APIError(f"Hugging Face reports no splits for '{dataset}'.") return entries -def pick_subset(entries: list[dict[str, object]], subset: str | None, dataset: str) -> str: - configs = list(dict.fromkeys(str(entry.get("config")) for entry in entries)) +def pick_subset(entries: list[_SplitEntry], subset: str | None, dataset: str) -> str: + configs = list(dict.fromkeys(entry.config for entry in entries)) if subset is not None: if subset in configs: return subset @@ -115,10 +137,8 @@ def pick_subset(entries: list[dict[str, object]], subset: str | None, dataset: s ) -def pick_split( - entries: list[dict[str, object]], config: str, split: str | None, dataset: str -) -> str: - splits = [str(entry.get("split")) for entry in entries if str(entry.get("config")) == config] +def pick_split(entries: list[_SplitEntry], config: str, split: str | None, dataset: str) -> str: + splits = [entry.split for entry in entries if entry.config == config] if split is not None: if split in splits: return split diff --git a/tests/test_agent_events.py b/tests/test_agent_events.py new file mode 100644 index 00000000..676d8efe --- /dev/null +++ b/tests/test_agent_events.py @@ -0,0 +1,42 @@ +"""The typed Voice Agent NDJSON events (`aai_cli.agent.events`). + +These pin the wire contract `assembly agent --json` emits: one canonical place +that asserts each event's `type` discriminator and payload, plus the closed/ +frozen model guarantees the renderer relies on. +""" + +import pytest +from pydantic import ValidationError + +from aai_cli.agent import events + + +@pytest.mark.parametrize( + ("event", "expected"), + [ + (events.SessionReady(), {"type": "session.ready"}), + (events.UserDelta(text="typing…"), {"type": "transcript.user.delta", "text": "typing…"}), + (events.UserFinal(text="hello"), {"type": "transcript.user", "text": "hello"}), + (events.ReplyStarted(), {"type": "reply.started"}), + ( + events.AgentTranscript(text="hi back", interrupted=False), + {"type": "transcript.agent", "text": "hi back", "interrupted": False}, + ), + (events.ReplyDone(interrupted=True), {"type": "reply.done", "interrupted": True}), + ], +) +def test_event_wire_shape(event: events.Event, expected: dict[str, object]): + assert event.model_dump() == expected + + +def test_events_are_frozen(): + # An emitted event is immutable: the stream record can't be mutated after the fact. + event = events.UserFinal(text="hello") + with pytest.raises(ValidationError): + event.text = "tampered" + + +def test_events_reject_unknown_fields(): + # extra="forbid": a stray key is a programming error, not a silent passenger. + with pytest.raises(ValidationError): + events.UserFinal.model_validate({"text": "hello", "bogus": 1}) diff --git a/tests/test_eval_data_hf.py b/tests/test_eval_data_hf.py index 47357d6b..43487fc8 100644 --- a/tests/test_eval_data_hf.py +++ b/tests/test_eval_data_hf.py @@ -200,6 +200,14 @@ def test_hf_no_rows(monkeypatch): eval_data.load("org/ds", limit=1) +def test_hf_split_entry_missing_required_field_is_an_api_error(monkeypatch): + # A /splits entry without the required config/split pair is a schema surprise, + # not a usable slice — surface it as a clean AMS-style error, not a KeyError. + _hf_handler(monkeypatch, splits=[{"split": "test"}], rows=[_hf_row()]) + with pytest.raises(APIError, match="unexpected /splits payload"): + eval_data.load("org/ds", limit=1) + + @pytest.mark.parametrize("status", [401, 403]) def test_hf_auth_failure_suggests_hf_token(monkeypatch, status): _patch_transport(monkeypatch, lambda request: httpx.Response(status, json={"error": "gated"})) From db17c37d4b1f11b9045449c0aa3823370e54cb80 Mon Sep 17 00:00:00 2001 From: Claude Date: Sat, 13 Jun 2026 20:43:25 +0000 Subject: [PATCH 2/2] Type the streaming NDJSON events with Pydantic Extends the agent-events pattern to `assembly stream --json`. The begin/turn/ termination records were hand-built dicts assembled with `jsonshape.compact` + `_with_source`; each is now a closed, frozen Pydantic model whose `type` literal and payload are pinned at type-check time. The two presence rules are preserved in a shared `wire()`: the optional annotations `source` (parallel system/you streams) and `speaker` (--speaker-labels diarization) drop out of the record when absent, while the core payload (`id`, `audio_duration_seconds`) stays present even when null. The `id` field is `session_id` in Python with a serialization_alias so the model stays self-contained (no flake8-builtins A003 carve-out). Wire output is byte-identical; the existing render goldens are unchanged. https://claude.ai/code/session_01AzkXsmPQSoUJjPgJY6qvGB --- aai_cli/streaming/events.py | 69 ++++++++++++++++++++++++++++++++++ aai_cli/streaming/render.py | 35 +++++------------ tests/test_streaming_events.py | 64 +++++++++++++++++++++++++++++++ 3 files changed, 142 insertions(+), 26 deletions(-) create mode 100644 aai_cli/streaming/events.py create mode 100644 tests/test_streaming_events.py diff --git a/aai_cli/streaming/events.py b/aai_cli/streaming/events.py new file mode 100644 index 00000000..39501a3a --- /dev/null +++ b/aai_cli/streaming/events.py @@ -0,0 +1,69 @@ +"""Typed streaming NDJSON events. + +The ``assembly stream --json`` stream is a public contract: every line carries a +``type`` discriminator (see docs/cli-reference.md). Modelling each event as a +Pydantic class — rather than hand-building ``{"type": …}`` dicts the type checker +only sees as ``dict[str, object]`` — pins each event's ``type`` literal and +payload at type-check time. + +Two presence rules the renderer relied on are preserved in ``wire()``: the +optional *annotations* ``source`` (parallel system/you streams) and ``speaker`` +(``--speaker-labels`` diarization) drop out of the record when absent, while the +core payload (``id``, ``audio_duration_seconds``) stays present even when null. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field + +# Optional annotations omitted from the record when None; the core payload is kept. +_OMIT_WHEN_NONE = ("speaker", "source") + + +class _StreamEvent(BaseModel): + """Base for streaming events: a closed, frozen wire model. + + ``extra="forbid"`` keeps a stray field off the stream and ``frozen=True`` + makes an emitted event immutable. + """ + + model_config = ConfigDict(extra="forbid", frozen=True) + + def wire(self) -> dict[str, object]: + """The NDJSON record: optional annotations drop out when absent.""" + data: dict[str, object] = self.model_dump(by_alias=True) + for key in _OMIT_WHEN_NONE: + if data.get(key) is None: + data.pop(key, None) + return data + + +class Begin(_StreamEvent): + """The session opened; ``id`` is the streaming session id.""" + + type: Literal["begin"] = "begin" + session_id: str | None = Field(serialization_alias="id") + source: str | None = None + + +class Turn(_StreamEvent): + """A turn transcript: interim while ``end_of_turn`` is False, finalized when True.""" + + type: Literal["turn"] = "turn" + transcript: str + end_of_turn: bool + speaker: str | None = None + source: str | None = None + + +class Termination(_StreamEvent): + """The session closed; ``audio_duration_seconds`` is the total audio processed.""" + + type: Literal["termination"] = "termination" + audio_duration_seconds: float | None + source: str | None = None + + +Event = Begin | Turn | Termination diff --git a/aai_cli/streaming/render.py b/aai_cli/streaming/render.py index 9907374a..e2a720a5 100644 --- a/aai_cli/streaming/render.py +++ b/aai_cli/streaming/render.py @@ -6,7 +6,7 @@ from rich.console import Console from rich.text import Text -from aai_cli.core import jsonshape +from aai_cli.streaming import events from aai_cli.ui import theme from aai_cli.ui.render import BaseRenderer @@ -71,12 +71,6 @@ def __init__( ) self._lock = threading.RLock() - @staticmethod - def _with_source(payload: dict[str, object], source: str | None) -> dict[str, object]: - if source is not None: - payload["source"] = source - return payload - @staticmethod def _label(text: str, source: str | None, speaker: str | None = None) -> str: prefix = speaker_prefix(source, speaker) @@ -99,7 +93,7 @@ def begin(self, event: object, *, source: str | None = None) -> None: with self._lock: if self.json_mode: self._emit( - self._with_source({"type": "begin", "id": getattr(event, "id", None)}, source) + events.Begin(session_id=getattr(event, "id", None), source=source).wire() ) def listening(self) -> None: @@ -116,16 +110,12 @@ def turn(self, event: object, *, source: str | None = None) -> None: speaker = getattr(event, "speaker_label", None) # set when --speaker-labels diarizes with self._lock: if self.json_mode: - # speaker is omitted entirely when undiarized (not null). - payload = jsonshape.compact( - { - "type": "turn", - "transcript": text, - "end_of_turn": end, - "speaker": speaker, - } + # speaker/source are omitted entirely when absent (not null) — see wire(). + self._emit( + events.Turn( + transcript=text, end_of_turn=end, speaker=speaker, source=source + ).wire() ) - self._emit(self._with_source(payload, source)) elif self.text_mode: if end and text: self._write(self._label(text, source, speaker) + "\n") # plain finalized line @@ -137,16 +127,9 @@ def turn(self, event: object, *, source: str | None = None) -> None: def termination(self, event: object, *, source: str | None = None) -> None: with self._lock: if self.json_mode: + duration = getattr(event, "audio_duration_seconds", None) self._emit( - self._with_source( - { - "type": "termination", - "audio_duration_seconds": getattr( - event, "audio_duration_seconds", None - ), - }, - source, - ) + events.Termination(audio_duration_seconds=duration, source=source).wire() ) def stopped(self) -> None: diff --git a/tests/test_streaming_events.py b/tests/test_streaming_events.py new file mode 100644 index 00000000..28f1bd68 --- /dev/null +++ b/tests/test_streaming_events.py @@ -0,0 +1,64 @@ +"""The typed streaming NDJSON events (`aai_cli.streaming.events`). + +One canonical place that asserts each event's `type` discriminator and payload, +the omit-when-absent rule for `source`/`speaker`, and the closed/frozen model +guarantees the renderer relies on. +""" + +import pytest +from pydantic import ValidationError + +from aai_cli.streaming import events + + +@pytest.mark.parametrize( + ("event", "expected"), + [ + # source is omitted when absent, but id stays present even when null. + (events.Begin(session_id="sess_1"), {"type": "begin", "id": "sess_1"}), + (events.Begin(session_id=None), {"type": "begin", "id": None}), + ( + events.Begin(session_id="sess_1", source="system"), + {"type": "begin", "id": "sess_1", "source": "system"}, + ), + # speaker and source both drop out when undiarized / single-stream. + ( + events.Turn(transcript="hi", end_of_turn=True), + {"type": "turn", "transcript": "hi", "end_of_turn": True}, + ), + ( + events.Turn(transcript="hi", end_of_turn=True, speaker="A", source="system"), + { + "type": "turn", + "transcript": "hi", + "end_of_turn": True, + "speaker": "A", + "source": "system", + }, + ), + # audio_duration_seconds stays present even when null. + ( + events.Termination(audio_duration_seconds=12.5), + {"type": "termination", "audio_duration_seconds": 12.5}, + ), + ( + events.Termination(audio_duration_seconds=None), + {"type": "termination", "audio_duration_seconds": None}, + ), + ], +) +def test_wire_record(event: events.Event, expected: dict[str, object]): + assert event.wire() == expected + + +def test_events_are_frozen(): + # An emitted event is immutable: the stream record can't be mutated after the fact. + event = events.Turn(transcript="hi", end_of_turn=True) + with pytest.raises(ValidationError): + event.transcript = "tampered" + + +def test_events_reject_unknown_fields(): + # extra="forbid": a stray key is a programming error, not a silent passenger. + with pytest.raises(ValidationError): + events.Turn.model_validate({"transcript": "hi", "end_of_turn": True, "bogus": 1})