Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion assemblyai/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.64.9"
__version__ = "0.64.11"
1 change: 1 addition & 0 deletions assemblyai/streaming/v3/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ class StreamingParameters(StreamingSessionParameters):
sample_rate: int
encoding: Optional[Encoding] = None
speech_model: Optional[SpeechModel] = None
language_code: Optional[str] = None
language_detection: Optional[bool] = None
domain: Optional[StreamingDomain] = None
inactivity_timeout: Optional[int] = None
Expand Down
54 changes: 50 additions & 4 deletions assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2985,6 +2985,32 @@ class LemurPurgeResponse(BaseModel):
# locally with a clear message instead of a 400 round trip.
_SYNC_MAX_PROMPT_LEN = 4096
_SYNC_MAX_WORD_BOOST_LEN = 2048
_SYNC_MAX_CONVERSATION_CONTEXT_TURNS = 100
_SYNC_MAX_CONVERSATION_CONTEXT_LEN = 4096


def _normalize_conversation_context(v):
"""Coerce a single string to a one-turn list, strip + drop empties, cap.

Shared by the pydantic v1 and v2 validators on ``SyncTranscriptionConfig``.
"""
if v is None:
return None
if isinstance(v, str):
v = [v]
turns = [t.strip() for t in v if t and t.strip()]
if len(turns) > _SYNC_MAX_CONVERSATION_CONTEXT_TURNS:
raise ValueError(
f"conversation_context exceeds {_SYNC_MAX_CONVERSATION_CONTEXT_TURNS} "
f"turns (got {len(turns)})"
)
total = sum(len(t) for t in turns)
if total > _SYNC_MAX_CONVERSATION_CONTEXT_LEN:
raise ValueError(
f"conversation_context exceeds {_SYNC_MAX_CONVERSATION_CONTEXT_LEN} "
f"characters (got {total})"
)
return turns or None


class SyncSpeechModel(str, Enum):
Expand All @@ -2997,10 +3023,11 @@ class SyncTranscriptionConfig(BaseModel):
"""
Options for a synchronous transcription request.

`prompt`, `word_boost`, and `language_code` shape the transcript;
`sample_rate` and `channels` are required only for raw PCM audio (WAV
carries them in its header). `model` selects the sync speech model and is
sent as the `X-AAI-Model` routing header, not in the request body.
`prompt`, `word_boost`, `conversation_context`, and `language_code` shape
the transcript; `sample_rate` and `channels` are required only for raw PCM
audio (WAV carries them in its header). `model` selects the sync speech
model and is sent as the `X-AAI-Model` routing header, not in the request
body.
"""

model: str = SyncSpeechModel.u3_sync_pro.value
Expand All @@ -3012,6 +3039,16 @@ class SyncTranscriptionConfig(BaseModel):
word_boost: Optional[List[str]] = None
"Keyterms biasing the decoder. Whitespace is stripped and empty terms dropped. Max 2048 characters total."

conversation_context: Optional[Union[str, List[str]]] = None
"""Prior turns from the same conversation, in chronological order (oldest
first, most recent last). Gives the model the dialogue that preceded this
audio so it transcribes the clip with better continuity and proper-noun
consistency. Include turns from either side of the conversation (e.g. a
voice agent's replies) as separate entries; entries carry no speaker labels.
A single string is accepted and treated as one turn. Max 100 turns / 4096
characters total; when the prompt exceeds the model token budget the oldest
turns are dropped first, so put the most recent turn last."""

language_code: Optional[Union[str, List[str]]] = None
"""ISO 639-1 language code, or a list of codes for multilingual audio (e.g.
`"es"` or `["en", "es"]`). Steers the default transcription prompt toward
Expand Down Expand Up @@ -3040,6 +3077,11 @@ def _normalize_word_boost(cls, v):
)
return terms or None

@field_validator("conversation_context")
@classmethod
def _normalize_conversation_context(cls, v):
return _normalize_conversation_context(v)

else:

@validator("word_boost")
Expand All @@ -3054,6 +3096,10 @@ def _normalize_word_boost(cls, v):
)
return terms or None

@validator("conversation_context")
def _normalize_conversation_context(cls, v):
return _normalize_conversation_context(v)


class SyncWord(BaseModel):
"""A single word from a synchronous transcript, with timing and confidence."""
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,37 @@ def mocked_websocket_connect(
assert "mode=max_accuracy" in actual_url


def test_client_connect_with_language_code(mocker: MockFixture):
# Given: client + language_code parameter
actual_url = None

def mocked_websocket_connect(
url: str, additional_headers: dict, open_timeout: float
):
nonlocal actual_url
actual_url = url

mocker.patch(
"assemblyai.streaming.v3.client.websocket_connect",
new=mocked_websocket_connect,
)
_disable_rw_threads(mocker)
client = StreamingClient(
StreamingClientOptions(api_key="test", api_host="api.example.com")
)
params = StreamingParameters(
sample_rate=16000,
speech_model=SpeechModel.u3_rt_pro,
language_code="es",
)

# When: connect
client.connect(params)

# Then: the language_code wire param is forwarded
assert "language_code=es" in actual_url


def test_noise_suppression_deprecated_alias_migrates_to_voice_focus(
mocker: MockFixture, caplog: pytest.LogCaptureFixture
):
Expand Down
48 changes: 48 additions & 0 deletions tests/unit/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,54 @@ def test_transcribe_sends_prompt_and_word_boost(httpx_mock: HTTPXMock):
assert b'"model"' not in body


def test_transcribe_sends_conversation_context_list(httpx_mock: HTTPXMock):
# Given a mocked sync endpoint
_mock_ok(httpx_mock)

# When transcribing with prior conversation turns (oldest first)
config = aai.SyncTranscriptionConfig(
conversation_context=[
"I'd like to book a flight to Denver.",
" Sure, what date were you thinking? ",
"",
],
)
aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes", config=config)

# Then the config JSON part carries the turns, stripped with empties dropped
body = httpx_mock.get_requests()[0].read()
assert b'name="config"' in body
assert b'"conversation_context"' in body
assert b"I'd like to book a flight to Denver." in body
assert b"Sure, what date were you thinking?" in body


def test_transcribe_coerces_conversation_context_string(httpx_mock: HTTPXMock):
# Given a mocked sync endpoint
_mock_ok(httpx_mock)

# When conversation_context is a bare string (single prior turn)
config = aai.SyncTranscriptionConfig(
conversation_context="Sure, what date were you thinking?"
)

# Then it is normalized to a one-turn list
assert config.conversation_context == ["Sure, what date were you thinking?"]

# And it ships as a JSON array in the config part
aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes", config=config)
body = httpx_mock.get_requests()[0].read()
assert b'"conversation_context"' in body
assert b'"Sure, what date were you thinking?"' in body


def test_conversation_context_rejects_too_many_chars():
# Given conversation_context whose total length exceeds the cap,
# When/Then constructing the config raises a validation error
with pytest.raises(Exception):
aai.SyncTranscriptionConfig(conversation_context=["a" * 5000])


def test_transcribe_sends_single_language_code(httpx_mock: HTTPXMock):
# Given a mocked sync endpoint
_mock_ok(httpx_mock)
Expand Down
Loading