From ca4ab071cfba181b360edf933294acc45d9068e7 Mon Sep 17 00:00:00 2001 From: AssemblyAI Date: Wed, 10 Jun 2026 11:20:00 -0600 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 824af47f45adfc7515b95266cb4e33e23f29af5c --- assemblyai/__version__.py | 2 +- assemblyai/streaming/v3/models.py | 1 + assemblyai/types.py | 54 ++++++++++++++++++++++++++++--- tests/unit/test_streaming.py | 31 ++++++++++++++++++ tests/unit/test_sync.py | 48 +++++++++++++++++++++++++++ 5 files changed, 131 insertions(+), 5 deletions(-) diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 1540970..656604c 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.64.9" +__version__ = "0.64.11" diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index db76abe..bb58884 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -245,6 +245,7 @@ class StreamingParameters(StreamingSessionParameters): sample_rate: int encoding: Optional[Encoding] = None speech_model: Optional[SpeechModel] = None + language_code: Optional[str] = None language_detection: Optional[bool] = None domain: Optional[StreamingDomain] = None inactivity_timeout: Optional[int] = None diff --git a/assemblyai/types.py b/assemblyai/types.py index 1ce4c7f..9bada53 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -2985,6 +2985,32 @@ class LemurPurgeResponse(BaseModel): # locally with a clear message instead of a 400 round trip. _SYNC_MAX_PROMPT_LEN = 4096 _SYNC_MAX_WORD_BOOST_LEN = 2048 +_SYNC_MAX_CONVERSATION_CONTEXT_TURNS = 100 +_SYNC_MAX_CONVERSATION_CONTEXT_LEN = 4096 + + +def _normalize_conversation_context(v): + """Coerce a single string to a one-turn list, strip + drop empties, cap. + + Shared by the pydantic v1 and v2 validators on ``SyncTranscriptionConfig``. + """ + if v is None: + return None + if isinstance(v, str): + v = [v] + turns = [t.strip() for t in v if t and t.strip()] + if len(turns) > _SYNC_MAX_CONVERSATION_CONTEXT_TURNS: + raise ValueError( + f"conversation_context exceeds {_SYNC_MAX_CONVERSATION_CONTEXT_TURNS} " + f"turns (got {len(turns)})" + ) + total = sum(len(t) for t in turns) + if total > _SYNC_MAX_CONVERSATION_CONTEXT_LEN: + raise ValueError( + f"conversation_context exceeds {_SYNC_MAX_CONVERSATION_CONTEXT_LEN} " + f"characters (got {total})" + ) + return turns or None class SyncSpeechModel(str, Enum): @@ -2997,10 +3023,11 @@ class SyncTranscriptionConfig(BaseModel): """ Options for a synchronous transcription request. - `prompt`, `word_boost`, and `language_code` shape the transcript; - `sample_rate` and `channels` are required only for raw PCM audio (WAV - carries them in its header). `model` selects the sync speech model and is - sent as the `X-AAI-Model` routing header, not in the request body. + `prompt`, `word_boost`, `conversation_context`, and `language_code` shape + the transcript; `sample_rate` and `channels` are required only for raw PCM + audio (WAV carries them in its header). `model` selects the sync speech + model and is sent as the `X-AAI-Model` routing header, not in the request + body. """ model: str = SyncSpeechModel.u3_sync_pro.value @@ -3012,6 +3039,16 @@ class SyncTranscriptionConfig(BaseModel): word_boost: Optional[List[str]] = None "Keyterms biasing the decoder. Whitespace is stripped and empty terms dropped. Max 2048 characters total." + conversation_context: Optional[Union[str, List[str]]] = None + """Prior turns from the same conversation, in chronological order (oldest + first, most recent last). Gives the model the dialogue that preceded this + audio so it transcribes the clip with better continuity and proper-noun + consistency. Include turns from either side of the conversation (e.g. a + voice agent's replies) as separate entries; entries carry no speaker labels. + A single string is accepted and treated as one turn. Max 100 turns / 4096 + characters total; when the prompt exceeds the model token budget the oldest + turns are dropped first, so put the most recent turn last.""" + language_code: Optional[Union[str, List[str]]] = None """ISO 639-1 language code, or a list of codes for multilingual audio (e.g. `"es"` or `["en", "es"]`). Steers the default transcription prompt toward @@ -3040,6 +3077,11 @@ def _normalize_word_boost(cls, v): ) return terms or None + @field_validator("conversation_context") + @classmethod + def _normalize_conversation_context(cls, v): + return _normalize_conversation_context(v) + else: @validator("word_boost") @@ -3054,6 +3096,10 @@ def _normalize_word_boost(cls, v): ) return terms or None + @validator("conversation_context") + def _normalize_conversation_context(cls, v): + return _normalize_conversation_context(v) + class SyncWord(BaseModel): """A single word from a synchronous transcript, with timing and confidence.""" diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index 4a6451d..30960cc 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -277,6 +277,37 @@ def mocked_websocket_connect( assert "mode=max_accuracy" in actual_url +def test_client_connect_with_language_code(mocker: MockFixture): + # Given: client + language_code parameter + actual_url = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_url + actual_url = url + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + params = StreamingParameters( + sample_rate=16000, + speech_model=SpeechModel.u3_rt_pro, + language_code="es", + ) + + # When: connect + client.connect(params) + + # Then: the language_code wire param is forwarded + assert "language_code=es" in actual_url + + def test_noise_suppression_deprecated_alias_migrates_to_voice_focus( mocker: MockFixture, caplog: pytest.LogCaptureFixture ): diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index 0ee739c..ac34928 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -83,6 +83,54 @@ def test_transcribe_sends_prompt_and_word_boost(httpx_mock: HTTPXMock): assert b'"model"' not in body +def test_transcribe_sends_conversation_context_list(httpx_mock: HTTPXMock): + # Given a mocked sync endpoint + _mock_ok(httpx_mock) + + # When transcribing with prior conversation turns (oldest first) + config = aai.SyncTranscriptionConfig( + conversation_context=[ + "I'd like to book a flight to Denver.", + " Sure, what date were you thinking? ", + "", + ], + ) + aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes", config=config) + + # Then the config JSON part carries the turns, stripped with empties dropped + body = httpx_mock.get_requests()[0].read() + assert b'name="config"' in body + assert b'"conversation_context"' in body + assert b"I'd like to book a flight to Denver." in body + assert b"Sure, what date were you thinking?" in body + + +def test_transcribe_coerces_conversation_context_string(httpx_mock: HTTPXMock): + # Given a mocked sync endpoint + _mock_ok(httpx_mock) + + # When conversation_context is a bare string (single prior turn) + config = aai.SyncTranscriptionConfig( + conversation_context="Sure, what date were you thinking?" + ) + + # Then it is normalized to a one-turn list + assert config.conversation_context == ["Sure, what date were you thinking?"] + + # And it ships as a JSON array in the config part + aai.SyncTranscriber().transcribe(b"RIFFfake-wav-bytes", config=config) + body = httpx_mock.get_requests()[0].read() + assert b'"conversation_context"' in body + assert b'"Sure, what date were you thinking?"' in body + + +def test_conversation_context_rejects_too_many_chars(): + # Given conversation_context whose total length exceeds the cap, + # When/Then constructing the config raises a validation error + with pytest.raises(Exception): + aai.SyncTranscriptionConfig(conversation_context=["a" * 5000]) + + def test_transcribe_sends_single_language_code(httpx_mock: HTTPXMock): # Given a mocked sync endpoint _mock_ok(httpx_mock)