Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions aai_cli/agent/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Typed Voice Agent NDJSON events.

The ``assembly agent --json`` stream is a public contract: every line carries a
``type`` discriminator (see docs/cli-reference.md) and a small, fixed payload.
Modelling each event as a Pydantic class — rather than hand-building
``{"type": …}`` dicts the type checker only sees as ``dict[str, str]`` — pins the
``type`` literal and the payload fields at type-check time, so a renamed key or a
mistyped ``type`` value fails before it can drift onto the wire. ``model_dump()``
is the serialized form the renderer emits.
"""

from __future__ import annotations

from typing import Literal

from pydantic import BaseModel, ConfigDict


class _Event(BaseModel):
"""Base for Voice Agent events: a closed, frozen wire model.

``extra="forbid"`` keeps a stray field from silently riding along on the
stream, and ``frozen=True`` makes an emitted event immutable.
"""

model_config = ConfigDict(extra="forbid", frozen=True)


class SessionReady(_Event):
"""The agent session connected and is ready for audio."""

type: Literal["session.ready"] = "session.ready"


class UserDelta(_Event):
"""An interim (partial) user transcript."""

type: Literal["transcript.user.delta"] = "transcript.user.delta"
text: str


class UserFinal(_Event):
"""A finalized user transcript turn."""

type: Literal["transcript.user"] = "transcript.user"
text: str


class ReplyStarted(_Event):
"""The agent began generating a reply."""

type: Literal["reply.started"] = "reply.started"


class AgentTranscript(_Event):
"""The agent's reply transcript (``interrupted`` when the user barged in)."""

type: Literal["transcript.agent"] = "transcript.agent"
text: str
interrupted: bool


class ReplyDone(_Event):
"""The agent finished, or was interrupted out of, a reply."""

type: Literal["reply.done"] = "reply.done"
interrupted: bool


Event = SessionReady | UserDelta | UserFinal | ReplyStarted | AgentTranscript | ReplyDone
17 changes: 11 additions & 6 deletions aai_cli/agent/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from rich.text import Text

from aai_cli.agent import events
from aai_cli.ui.render import BaseRenderer


Expand All @@ -28,10 +29,14 @@ def __init__(self, *, mic_input: bool = True, **kwargs: Any) -> None:
# File-driven runs have no mic, so they skip the "start talking" prompt.
self.mic_input = mic_input

def _emit_event(self, event: events.Event) -> None:
"""Serialize one typed agent event to the NDJSON stream."""
self._emit(event.model_dump())

# --- lifecycle ---------------------------------------------------------
def connected(self) -> None:
if self.json_mode:
self._emit({"type": "session.ready"})
self._emit_event(events.SessionReady())
elif not self.mic_input:
return
elif self.text_mode:
Expand All @@ -53,13 +58,13 @@ def notice(self, text: str) -> None:
# --- user --------------------------------------------------------------
def user_partial(self, text: str) -> None:
if self.json_mode:
self._emit({"type": "transcript.user.delta", "text": text})
self._emit_event(events.UserDelta(text=text))
elif not self.text_mode: # partials are noise for piped text
self._update_line(_labeled("you: ", text, style="aai.you"))

def user_final(self, text: str) -> None:
if self.json_mode:
self._emit({"type": "transcript.user", "text": text})
self._emit_event(events.UserFinal(text=text))
elif self.text_mode:
self._write(f"you: {text}\n")
else:
Expand All @@ -68,11 +73,11 @@ def user_final(self, text: str) -> None:
# --- agent -------------------------------------------------------------
def reply_started(self) -> None:
if self.json_mode:
self._emit({"type": "reply.started"})
self._emit_event(events.ReplyStarted())

def agent_transcript(self, text: str, *, interrupted: bool) -> None:
if self.json_mode:
self._emit({"type": "transcript.agent", "text": text, "interrupted": interrupted})
self._emit_event(events.AgentTranscript(text=text, interrupted=interrupted))
elif self.text_mode:
self._write(f"agent: {text}\n")
else:
Expand All @@ -81,4 +86,4 @@ def agent_transcript(self, text: str, *, interrupted: bool) -> None:

def reply_done(self, *, interrupted: bool) -> None:
if self.json_mode:
self._emit({"type": "reply.done", "interrupted": interrupted})
self._emit_event(events.ReplyDone(interrupted=interrupted))
36 changes: 28 additions & 8 deletions aai_cli/commands/evaluate/_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from http import HTTPStatus

import httpx2 as httpx
from pydantic import BaseModel, ConfigDict, TypeAdapter, ValidationError

from aai_cli.core import env, jsonshape
from aai_cli.core.errors import APIError, UsageError
Expand Down Expand Up @@ -91,16 +92,37 @@ def fetch_json(endpoint: str, params: dict[str, str | int], *, dataset: str) ->
return _checked_payload(resp, dataset=dataset)


def split_entries(dataset: str) -> list[dict[str, object]]:
class _SplitEntry(BaseModel):
"""One ``/splits`` entry: the (config, split) pair naming a dataset slice.

``extra="allow"`` keeps the sibling fields the server returns (``dataset``,
``num_examples``) without modelling them, so subset/split selection reads
typed fields instead of ``str(entry.get("config"))`` off a bare dict.
"""

model_config = ConfigDict(extra="allow")
config: str
split: str


_SPLIT_ENTRIES = TypeAdapter(list[_SplitEntry])


def split_entries(dataset: str) -> list[_SplitEntry]:
payload = fetch_json("/splits", {"dataset": dataset}, dataset=dataset)
entries = jsonshape.mapping_list(payload.get("splits"))
try:
entries = _SPLIT_ENTRIES.validate_python(payload.get("splits") or [])
except ValidationError as exc:
raise APIError(
f"Hugging Face returned an unexpected /splits payload for '{dataset}'."
) from exc
if not entries:
raise APIError(f"Hugging Face reports no splits for '{dataset}'.")
return entries


def pick_subset(entries: list[dict[str, object]], subset: str | None, dataset: str) -> str:
configs = list(dict.fromkeys(str(entry.get("config")) for entry in entries))
def pick_subset(entries: list[_SplitEntry], subset: str | None, dataset: str) -> str:
configs = list(dict.fromkeys(entry.config for entry in entries))
if subset is not None:
if subset in configs:
return subset
Expand All @@ -115,10 +137,8 @@ def pick_subset(entries: list[dict[str, object]], subset: str | None, dataset: s
)


def pick_split(
entries: list[dict[str, object]], config: str, split: str | None, dataset: str
) -> str:
splits = [str(entry.get("split")) for entry in entries if str(entry.get("config")) == config]
def pick_split(entries: list[_SplitEntry], config: str, split: str | None, dataset: str) -> str:
splits = [entry.split for entry in entries if entry.config == config]
if split is not None:
if split in splits:
return split
Expand Down
69 changes: 69 additions & 0 deletions aai_cli/streaming/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""Typed streaming NDJSON events.

The ``assembly stream --json`` stream is a public contract: every line carries a
``type`` discriminator (see docs/cli-reference.md). Modelling each event as a
Pydantic class — rather than hand-building ``{"type": …}`` dicts the type checker
only sees as ``dict[str, object]`` — pins each event's ``type`` literal and
payload at type-check time.

Two presence rules the renderer relied on are preserved in ``wire()``: the
optional *annotations* ``source`` (parallel system/you streams) and ``speaker``
(``--speaker-labels`` diarization) drop out of the record when absent, while the
core payload (``id``, ``audio_duration_seconds``) stays present even when null.
"""

from __future__ import annotations

from typing import Literal

from pydantic import BaseModel, ConfigDict, Field

# Optional annotations omitted from the record when None; the core payload is kept.
_OMIT_WHEN_NONE = ("speaker", "source")


class _StreamEvent(BaseModel):
"""Base for streaming events: a closed, frozen wire model.

``extra="forbid"`` keeps a stray field off the stream and ``frozen=True``
makes an emitted event immutable.
"""

model_config = ConfigDict(extra="forbid", frozen=True)

def wire(self) -> dict[str, object]:
"""The NDJSON record: optional annotations drop out when absent."""
data: dict[str, object] = self.model_dump(by_alias=True)
for key in _OMIT_WHEN_NONE:
if data.get(key) is None:
data.pop(key, None)
return data


class Begin(_StreamEvent):
"""The session opened; ``id`` is the streaming session id."""

type: Literal["begin"] = "begin"
session_id: str | None = Field(serialization_alias="id")
source: str | None = None


class Turn(_StreamEvent):
"""A turn transcript: interim while ``end_of_turn`` is False, finalized when True."""

type: Literal["turn"] = "turn"
transcript: str
end_of_turn: bool
speaker: str | None = None
source: str | None = None


class Termination(_StreamEvent):
"""The session closed; ``audio_duration_seconds`` is the total audio processed."""

type: Literal["termination"] = "termination"
audio_duration_seconds: float | None
source: str | None = None


Event = Begin | Turn | Termination
35 changes: 9 additions & 26 deletions aai_cli/streaming/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from rich.console import Console
from rich.text import Text

from aai_cli.core import jsonshape
from aai_cli.streaming import events
from aai_cli.ui import theme
from aai_cli.ui.render import BaseRenderer

Expand Down Expand Up @@ -71,12 +71,6 @@ def __init__(
)
self._lock = threading.RLock()

@staticmethod
def _with_source(payload: dict[str, object], source: str | None) -> dict[str, object]:
if source is not None:
payload["source"] = source
return payload

@staticmethod
def _label(text: str, source: str | None, speaker: str | None = None) -> str:
prefix = speaker_prefix(source, speaker)
Expand All @@ -99,7 +93,7 @@ def begin(self, event: object, *, source: str | None = None) -> None:
with self._lock:
if self.json_mode:
self._emit(
self._with_source({"type": "begin", "id": getattr(event, "id", None)}, source)
events.Begin(session_id=getattr(event, "id", None), source=source).wire()
)

def listening(self) -> None:
Expand All @@ -116,16 +110,12 @@ def turn(self, event: object, *, source: str | None = None) -> None:
speaker = getattr(event, "speaker_label", None) # set when --speaker-labels diarizes
with self._lock:
if self.json_mode:
# speaker is omitted entirely when undiarized (not null).
payload = jsonshape.compact(
{
"type": "turn",
"transcript": text,
"end_of_turn": end,
"speaker": speaker,
}
# speaker/source are omitted entirely when absent (not null) — see wire().
self._emit(
events.Turn(
transcript=text, end_of_turn=end, speaker=speaker, source=source
).wire()
)
self._emit(self._with_source(payload, source))
elif self.text_mode:
if end and text:
self._write(self._label(text, source, speaker) + "\n") # plain finalized line
Expand All @@ -137,16 +127,9 @@ def turn(self, event: object, *, source: str | None = None) -> None:
def termination(self, event: object, *, source: str | None = None) -> None:
with self._lock:
if self.json_mode:
duration = getattr(event, "audio_duration_seconds", None)
self._emit(
self._with_source(
{
"type": "termination",
"audio_duration_seconds": getattr(
event, "audio_duration_seconds", None
),
},
source,
)
events.Termination(audio_duration_seconds=duration, source=source).wire()
)

def stopped(self) -> None:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_agent_events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""The typed Voice Agent NDJSON events (`aai_cli.agent.events`).

These pin the wire contract `assembly agent --json` emits: one canonical place
that asserts each event's `type` discriminator and payload, plus the closed/
frozen model guarantees the renderer relies on.
"""

import pytest
from pydantic import ValidationError

from aai_cli.agent import events


@pytest.mark.parametrize(
("event", "expected"),
[
(events.SessionReady(), {"type": "session.ready"}),
(events.UserDelta(text="typing…"), {"type": "transcript.user.delta", "text": "typing…"}),
(events.UserFinal(text="hello"), {"type": "transcript.user", "text": "hello"}),
(events.ReplyStarted(), {"type": "reply.started"}),
(
events.AgentTranscript(text="hi back", interrupted=False),
{"type": "transcript.agent", "text": "hi back", "interrupted": False},
),
(events.ReplyDone(interrupted=True), {"type": "reply.done", "interrupted": True}),
],
)
def test_event_wire_shape(event: events.Event, expected: dict[str, object]):
assert event.model_dump() == expected


def test_events_are_frozen():
# An emitted event is immutable: the stream record can't be mutated after the fact.
event = events.UserFinal(text="hello")
with pytest.raises(ValidationError):
event.text = "tampered"


def test_events_reject_unknown_fields():
# extra="forbid": a stray key is a programming error, not a silent passenger.
with pytest.raises(ValidationError):
events.UserFinal.model_validate({"text": "hello", "bogus": 1})
8 changes: 8 additions & 0 deletions tests/test_eval_data_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,14 @@ def test_hf_no_rows(monkeypatch):
eval_data.load("org/ds", limit=1)


def test_hf_split_entry_missing_required_field_is_an_api_error(monkeypatch):
# A /splits entry without the required config/split pair is a schema surprise,
# not a usable slice — surface it as a clean AMS-style error, not a KeyError.
_hf_handler(monkeypatch, splits=[{"split": "test"}], rows=[_hf_row()])
with pytest.raises(APIError, match="unexpected /splits payload"):
eval_data.load("org/ds", limit=1)


@pytest.mark.parametrize("status", [401, 403])
def test_hf_auth_failure_suggests_hf_token(monkeypatch, status):
_patch_transport(monkeypatch, lambda request: httpx.Response(status, json={"error": "gated"}))
Expand Down
Loading
Loading