From 98879955f59063f0255f4b6ddbaa3341079b0d2f Mon Sep 17 00:00:00 2001 From: AssemblyAI Date: Thu, 11 Jun 2026 12:48:55 -0600 Subject: [PATCH] Project import generated by Copybara. GitOrigin-RevId: 8bde83cabc9c27cc0f66eb3860b0293fa93800cf --- .github/workflows/test.yml | 6 + README.md | 7 - assemblyai/__init__.py | 2 - assemblyai/__version__.py | 2 +- assemblyai/streaming/v3/__init__.py | 2 + assemblyai/streaming/v3/async_client.py | 87 +++++--- assemblyai/streaming/v3/client.py | 87 +++++--- assemblyai/streaming/v3/models.py | 21 ++ assemblyai/types.py | 16 +- tests/unit/test_streaming.py | 270 +++++++++++++++++++++++- tests/unit/test_streaming_async.py | 166 ++++++++++++++- tests/unit/test_sync.py | 8 +- tox.ini | 31 ++- 13 files changed, 599 insertions(+), 106 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 48f696f..b4d2469 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -12,10 +12,16 @@ jobs: test: name: Python ${{ matrix.py }} on ${{ matrix.os }} runs-on: ${{ matrix.os }} + env: + # pip>=24.1 rejects legacy dep metadata (e.g. pytest-httpx's `pytest (<8.*,>=6.*)`) + # and backtracks through dozens of versions, stalling the matrix 30+ min. + # Seed each tox venv with the last pip that tolerates it. + VIRTUALENV_PIP: "24.0" strategy: fail-fast: false matrix: py: + - "3.12" - "3.11" - "3.10" - "3.9" diff --git a/README.md b/README.md index 3816076..b148540 100644 --- a/README.md +++ b/README.md @@ -18,13 +18,6 @@ With a single API call, get access to AI models built on the latest AI breakthroughs to transcribe and understand audio and speech data securely at large scale. -> **⚠️ WARNING** -> This SDK is intended for **testing and light usage only**. It is not -> recommended for use at scale or with production traffic. For best -> results, we recommend calling the AssemblyAI API directly via HTTP -> request. See our [official documentation](https://www.assemblyai.com/docs) -> for more information, including HTTP code examples. - ## Using with AI coding agents If you're integrating this SDK with Claude Code, Cursor, Copilot, or another AI coding assistant, give your agent current API context so it doesn't generate code against outdated model names or parameters. diff --git a/assemblyai/__init__.py b/assemblyai/__init__.py index 7d92a77..6e66b5b 100644 --- a/assemblyai/__init__.py +++ b/assemblyai/__init__.py @@ -68,7 +68,6 @@ SyncTranscriptError, SyncTranscriptionConfig, SyncTranscriptResponse, - SyncWord, Timestamp, TranscriptError, TranscriptionConfig, @@ -149,7 +148,6 @@ "SyncTranscriptError", "SyncTranscriptionConfig", "SyncTranscriptResponse", - "SyncWord", "Timestamp", "Transcriber", "TranscriptionConfig", diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index 656604c..5c34a51 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.64.11" +__version__ = "0.64.16" diff --git a/assemblyai/streaming/v3/__init__.py b/assemblyai/streaming/v3/__init__.py index 2bcec35..bd9ca87 100644 --- a/assemblyai/streaming/v3/__init__.py +++ b/assemblyai/streaming/v3/__init__.py @@ -6,6 +6,7 @@ EventMessage, LLMGatewayResponseEvent, NoiseSuppressionModel, + SessionConfiguration, SpeakerRevisionEvent, SpeakerRevisionItem, SpeechModel, @@ -35,6 +36,7 @@ "SpeakerRevisionEvent", "SpeakerRevisionItem", "SpeechModel", + "SessionConfiguration", "SpeechStartedEvent", "StreamingClient", "StreamingClientOptions", diff --git a/assemblyai/streaming/v3/async_client.py b/assemblyai/streaming/v3/async_client.py index 6804982..f04c9b6 100644 --- a/assemblyai/streaming/v3/async_client.py +++ b/assemblyai/streaming/v3/async_client.py @@ -132,35 +132,52 @@ async def connect(self, params: StreamingParameters) -> None: uri = _build_uri(self._options.api_host, params) headers = _build_headers(self._options) + options = self._options - try: - self._websocket = await asyncio.wait_for( - websocket_connect_async(uri, additional_headers=headers), - timeout=15, - ) - except websockets.exceptions.InvalidStatus as exc: - status_code = getattr(getattr(exc, "response", None), "status_code", None) - await self._report_connection_closed( - StreamingError( - message=f"WebSocket handshake rejected (HTTP {status_code})", - code=status_code, + for attempt in range(options.max_connection_retries + 1): + try: + self._websocket = await asyncio.wait_for( + websocket_connect_async(uri, additional_headers=headers), + timeout=options.connect_timeout, ) - ) - # Single-use design: a failed handshake terminates the client. Close - # the HTTP client now so users who treat ``on_error`` as the - # terminal signal don't leak the httpx pool. - await self._client.aclose() - return - except ( - websockets.exceptions.InvalidHandshake, - websockets.exceptions.ConnectionClosed, - OSError, - asyncio.TimeoutError, - TimeoutError, - ) as exc: - await self._report_connection_closed(exc) - await self._client.aclose() - return + break + except websockets.exceptions.InvalidStatus as exc: + # HTTP-level rejection (auth, quota, bad request): a retry would + # hit the same response, so fail fast. + status_code = getattr( + getattr(exc, "response", None), "status_code", None + ) + await self._report_connection_closed( + StreamingError( + message=f"WebSocket handshake rejected (HTTP {status_code})", + code=status_code, + ) + ) + # Single-use design: a failed handshake terminates the client. + # Close the HTTP client now so users who treat ``on_error`` as + # the terminal signal don't leak the httpx pool. + await self._client.aclose() + return + except ( + websockets.exceptions.InvalidHandshake, + websockets.exceptions.ConnectionClosed, + OSError, + asyncio.TimeoutError, + TimeoutError, + ) as exc: + if attempt < options.max_connection_retries: + logger.debug( + "WebSocket connect attempt %d/%d failed (%s); retrying", + attempt + 1, + options.max_connection_retries + 1, + exc, + ) + if options.connection_retry_delay > 0: + await asyncio.sleep(options.connection_retry_delay) + continue + await self._report_connection_closed(exc) + await self._client.aclose() + return self._read_task = asyncio.create_task( self._read_loop(), name="AsyncStreamingClient._read_loop" @@ -188,6 +205,22 @@ async def disconnect(self, terminate: bool = False) -> None: # cancel the awaited task on timeout, unlike ``wait_for``. if self._write_task is not None and not self._write_task.done(): await asyncio.wait({self._write_task}, timeout=2.0) + # Don't stop the read task yet — the server sends the final Turn + # and TerminationEvent after receiving Terminate. Every terminal + # path sets ``_stop_event``, so waiting on it here lets those + # messages dispatch before teardown. + if ( + self._read_task is not None + and not self._read_task.done() + and asyncio.current_task() is not self._read_task + ): + try: + await asyncio.wait_for( + self._stop_event.wait(), + timeout=self._options.terminate_timeout, + ) + except asyncio.TimeoutError: + pass self._stop_event.set() diff --git a/assemblyai/streaming/v3/client.py b/assemblyai/streaming/v3/client.py index 65ace7a..8dc5084 100644 --- a/assemblyai/streaming/v3/client.py +++ b/assemblyai/streaming/v3/client.py @@ -2,6 +2,7 @@ import logging import queue import threading +import time from typing import Any, Dict, Generator, Iterable, Optional, Union import httpx @@ -57,8 +58,11 @@ def __init__(self, options: StreamingClientOptions): def connect(self, params: StreamingParameters) -> None: """Open the WebSocket session and start the read/write threads. - Blocks until the handshake completes. If the server rejects the - handshake (auth error, etc.) ``Error`` is dispatched to any + Blocks until the handshake completes. A transient handshake failure + (timeout, network drop) is retried up to + ``options.max_connection_retries`` times before the failure is + reported. If the server rejects the handshake at the HTTP layer (auth + error, etc.) ``Error`` is dispatched to any ``on(StreamingEvents.Error, ...)`` handler rather than raised, so registration order matters: call ``on()`` before ``connect()``. """ @@ -66,30 +70,47 @@ def connect(self, params: StreamingParameters) -> None: uri = _build_uri(self._options.api_host, params) headers = _build_headers(self._options) + options = self._options - try: - self._websocket = websocket_connect( - uri, - additional_headers=headers, - open_timeout=15, - ) - except websockets.exceptions.InvalidStatus as exc: - status_code = getattr(getattr(exc, "response", None), "status_code", None) - self._report_connection_closed( - StreamingError( - message=f"WebSocket handshake rejected (HTTP {status_code})", - code=status_code, + for attempt in range(options.max_connection_retries + 1): + try: + self._websocket = websocket_connect( + uri, + additional_headers=headers, + open_timeout=options.connect_timeout, ) - ) - return - except ( - websockets.exceptions.InvalidHandshake, - websockets.exceptions.ConnectionClosed, - OSError, - TimeoutError, - ) as exc: - self._report_connection_closed(exc) - return + break + except websockets.exceptions.InvalidStatus as exc: + # HTTP-level rejection (auth, quota, bad request): a retry + # would hit the same response, so fail fast. + status_code = getattr( + getattr(exc, "response", None), "status_code", None + ) + self._report_connection_closed( + StreamingError( + message=f"WebSocket handshake rejected (HTTP {status_code})", + code=status_code, + ) + ) + return + except ( + websockets.exceptions.InvalidHandshake, + websockets.exceptions.ConnectionClosed, + OSError, + TimeoutError, + ) as exc: + if attempt < options.max_connection_retries: + logger.debug( + "WebSocket connect attempt %d/%d failed (%s); retrying", + attempt + 1, + options.max_connection_retries + 1, + exc, + ) + if options.connection_retry_delay > 0: + time.sleep(options.connection_retry_delay) + continue + self._report_connection_closed(exc) + return self._write_thread.start() self._read_thread.start() @@ -100,16 +121,26 @@ def disconnect(self, terminate: bool = False) -> None: """Stop the read/write threads and close the WebSocket. Pass ``terminate=True`` for a graceful close — the client sends a - ``TerminateSession`` frame and waits for the server's - ``TerminationEvent`` (which reports total audio duration). Without - ``terminate=True`` the WebSocket is closed without notifying the - server. + ``TerminateSession`` frame and waits up to ``options.terminate_timeout`` + seconds for the server's ``TerminationEvent`` (which reports total + audio duration). Without ``terminate=True`` the WebSocket is closed + without notifying the server. """ # Enqueue Terminate even when stop is already set: `_write_message` # bypasses the stop gate for TerminateSession so the frame still # reaches the server when the write thread is alive. if terminate: self._write_queue.put(TerminateSession()) + # Don't stop the read thread yet — the server sends the final Turn + # and TerminationEvent after receiving Terminate. Every terminal + # path sets `_stop_event` (TerminationEvent via `_handle_message`, + # server close, server error), so waiting on it here lets those + # messages dispatch before teardown. + if ( + self._read_thread.is_alive() + and threading.current_thread() is not self._read_thread + ): + self._stop_event.wait(timeout=self._options.terminate_timeout) self._stop_event.set() diff --git a/assemblyai/streaming/v3/models.py b/assemblyai/streaming/v3/models.py index bb58884..c45a34d 100644 --- a/assemblyai/streaming/v3/models.py +++ b/assemblyai/streaming/v3/models.py @@ -38,10 +38,19 @@ class TurnEvent(BaseModel): speaker_label: Optional[str] = None +class SessionConfiguration(BaseModel): + # `mode` stays a plain str so a new server-side mode value can't fail + # validation and drop the whole BeginEvent. + model: Optional[str] = None + mode: Optional[str] = None + api_version: Optional[str] = None + + class BeginEvent(BaseModel): type: Literal["Begin"] = "Begin" id: str expires_at: datetime + configuration: Optional[SessionConfiguration] = None class TerminationEvent(BaseModel): @@ -285,6 +294,18 @@ class StreamingClientOptions(BaseModel): api_host: str = "streaming.assemblyai.com" api_key: Optional[str] = None token: Optional[str] = None + # Seconds to wait for the WebSocket handshake to complete before treating + # the attempt as failed. + connect_timeout: float = 1.0 + # Additional handshake attempts after the first one fails on a transient + # error (timeout, network drop). 0 disables retries. HTTP-level rejections + # (auth, quota, bad request) are never retried. + max_connection_retries: int = 2 + # Seconds to wait between handshake attempts. + connection_retry_delay: float = 0.5 + # Seconds disconnect(terminate=True) waits for the server's + # TerminationEvent (and any final Turn) before tearing down. + terminate_timeout: float = 5.0 class StreamingError(Exception): diff --git a/assemblyai/types.py b/assemblyai/types.py index 9bada53..bafa3c7 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -3101,24 +3101,13 @@ def _normalize_conversation_context(cls, v): return _normalize_conversation_context(v) -class SyncWord(BaseModel): - """A single word from a synchronous transcript, with timing and confidence.""" - - text: str - start_ms: int - "Word start time in milliseconds." - end_ms: int - "Word end time in milliseconds." - confidence: float - - class SyncTranscriptResponse(BaseModel): """The result of a synchronous transcription request.""" text: str "The full transcript text." - words: List[SyncWord] = Field(default_factory=list) + words: List[Word] = Field(default_factory=list) "Per-word timing and confidence." confidence: float @@ -3127,8 +3116,5 @@ class SyncTranscriptResponse(BaseModel): audio_duration_ms: int "Total audio duration in milliseconds." - inference_time_ms: float - "Model inference time in milliseconds. Excludes auth, decode, and queue wait." - session_id: str "Server-generated UUID for this request. Record it to correlate with support." diff --git a/tests/unit/test_streaming.py b/tests/unit/test_streaming.py index 30960cc..c9cbbe2 100644 --- a/tests/unit/test_streaming.py +++ b/tests/unit/test_streaming.py @@ -11,6 +11,7 @@ from websockets.frames import Close from assemblyai.streaming.v3 import ( + BeginEvent, NoiseSuppressionModel, SpeakerRevisionEvent, SpeechModel, @@ -76,7 +77,7 @@ def mocked_websocket_connect( assert actual_additional_headers["AssemblyAI-Version"] == "2025-05-12" assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"] - assert actual_open_timeout == 15 + assert actual_open_timeout == 1.0 def test_client_connect_with_token(mocker: MockFixture): @@ -118,7 +119,7 @@ def mocked_websocket_connect( assert actual_additional_headers["AssemblyAI-Version"] == "2025-05-12" assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"] - assert actual_open_timeout == 15 + assert actual_open_timeout == 1.0 def test_client_connect_all_parameters(mocker: MockFixture): @@ -168,7 +169,7 @@ def mocked_websocket_connect( assert actual_additional_headers["AssemblyAI-Version"] == "2025-05-12" assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"] - assert actual_open_timeout == 15 + assert actual_open_timeout == 1.0 def test_client_connect_with_redact_pii(mocker: MockFixture): @@ -600,7 +601,7 @@ def mocked_websocket_connect( assert actual_url == f"wss://api.example.com/v3/ws?{urlencode(expected_params)}" assert actual_additional_headers["Authorization"] == "test" - assert actual_open_timeout == 15 + assert actual_open_timeout == 1.0 def test_client_connect_with_u3_pro_and_prompt(mocker: MockFixture): @@ -650,7 +651,7 @@ def mocked_websocket_connect( assert actual_additional_headers["AssemblyAI-Version"] == "2025-05-12" assert "AssemblyAI/1.0" in actual_additional_headers["User-Agent"] - assert actual_open_timeout == 15 + assert actual_open_timeout == 1.0 def test_client_connect_with_speaker_labels(mocker: MockFixture): @@ -1100,6 +1101,64 @@ def test_speech_started_event(): assert event.timestamp == 1280 +def test_begin_event_parses_configuration(): + # Given: a Begin message with the server's configuration object + data = { + "type": "Begin", + "id": "abc", + "expires_at": 1781207829, + "configuration": { + "model": "u3-rt-pro", + "mode": "balanced", + "api_version": "2025-05-12", + }, + } + + # When: parsing the event model + event = BeginEvent(**data) + + # Then: configuration fields are exposed instead of silently dropped + assert event.configuration is not None + assert event.configuration.model == "u3-rt-pro" + assert event.configuration.mode == "balanced" + assert event.configuration.api_version == "2025-05-12" + + +def test_begin_event_without_configuration(): + # Given: a Begin message without configuration (older server) + event = BeginEvent(id="abc", expires_at=1781207829) + + # Then: configuration defaults to None + assert event.configuration is None + + +def test_begin_event_tolerates_unknown_mode_and_null_fields(): + # Given: a Begin message with a mode value the SDK doesn't know yet and a + # null mode variant (non-u3-pro models send mode: null) + event = BeginEvent( + id="abc", + expires_at=1781207829, + configuration={ + "model": "u9-rt", + "mode": "warpspeed", + "api_version": "2099-01-01", + }, + ) + event_null = BeginEvent( + id="abc", + expires_at=1781207829, + configuration={ + "model": "universal-streaming-english", + "mode": None, + "api_version": "2025-05-12", + }, + ) + + # Then: neither fails validation (a new server mode must not drop BeginEvent) + assert event.configuration.mode == "warpspeed" + assert event_null.configuration.mode is None + + class _FakeWebSocket: """Programmable sync websocket stand-in for driving StreamingClient in tests.""" @@ -1466,6 +1525,75 @@ def test_server_error_without_trailing_close_exits_read_loop(mocker: MockFixture client.disconnect(terminate=True) +def test_disconnect_terminate_waits_for_termination_event(mocker: MockFixture): + # Given: a server that delivers the final TerminationEvent only after the + # client's Terminate frame arrives (as in production, where Termination + # lands ~0.5-1.3s after Terminate). + termination_json = json.dumps( + { + "type": "Termination", + "audio_duration_seconds": 12, + "session_duration_seconds": 13, + } + ) + + class _TerminateAwareWS(_FakeWebSocket): + def recv(self, timeout=None): + if not any(isinstance(s, str) and '"Terminate"' in s for s in self.sent): + raise TimeoutError() + return super().recv(timeout) + + fake_ws = _TerminateAwareWS(recv_script=[termination_json]) + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + return_value=fake_ws, + ) + received = [] + client = StreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Termination, lambda c, e: received.append(e)) + client.connect(_default_params()) + + # When: disconnecting gracefully + client.disconnect(terminate=True) + + # Then: the TerminationEvent was dispatched before disconnect returned, + # and both worker threads exited. + assert len(received) == 1 + assert received[0].audio_duration_seconds == 12 + assert not client._read_thread.is_alive() + assert not client._write_thread.is_alive() + + +def test_disconnect_terminate_times_out_when_no_termination(mocker: MockFixture): + # Given: a server that never sends a TerminationEvent and a short + # terminate_timeout. + fake_ws = _FakeWebSocket(recv_script=[]) + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + return_value=fake_ws, + ) + client = StreamingClient( + StreamingClientOptions( + api_key="test", api_host="api.example.com", terminate_timeout=0.3 + ) + ) + client.connect(_default_params()) + + # When: disconnecting gracefully + start = time.monotonic() + client.disconnect(terminate=True) + elapsed = time.monotonic() - start + + # Then: disconnect honored the bounded wait instead of hanging, still sent + # the Terminate frame, and tore down both threads. + assert 0.3 <= elapsed < 2.0 + assert any(isinstance(s, str) and '"Terminate"' in s for s in fake_ws.sent) + assert not client._read_thread.is_alive() + assert not client._write_thread.is_alive() + + def test_disconnect_terminate_enqueues_when_stop_already_set(mocker: MockFixture): # Given: a client whose _stop_event is already set (e.g. after a server # error invoked _report_server_error). Threads were never started, so the @@ -1580,3 +1708,135 @@ def bad_warning_handler(self_, w): assert not client._read_thread.is_alive() client.disconnect() + + +def test_client_connect_uses_configured_timeout(mocker: MockFixture): + # Given: a client configured with a custom connect_timeout. + actual_open_timeout = None + + def mocked_websocket_connect( + url: str, additional_headers: dict, open_timeout: float + ): + nonlocal actual_open_timeout + actual_open_timeout = open_timeout + + mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + new=mocked_websocket_connect, + ) + _disable_rw_threads(mocker) + + client = StreamingClient( + StreamingClientOptions( + api_key="test", api_host="api.example.com", connect_timeout=5.0 + ) + ) + + # When: connect() opens the handshake. + client.connect(_default_params()) + + # Then: the configured timeout is forwarded to the websocket handshake. + assert actual_open_timeout == 5.0 + + +def test_client_connect_retries_transient_failure_then_succeeds(mocker: MockFixture): + # Given: the handshake fails twice transiently, then succeeds (default 2 + # retries → 3 attempts). + fake_ws = object() # rw threads disabled, so the websocket is never driven + connect_mock = mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + side_effect=[TimeoutError(), TimeoutError(), fake_ws], + ) + _disable_rw_threads(mocker) + errors = [] + client = StreamingClient( + StreamingClientOptions( + api_key="test", api_host="api.example.com", connection_retry_delay=0 + ) + ) + client.on(StreamingEvents.Error, lambda s, e: errors.append(e)) + + # When: connect() retries through the transient failures. + client.connect(_default_params()) + + # Then: it makes three attempts, binds the websocket, and reports no error. + assert connect_mock.call_count == 3 + assert client._websocket is fake_ws + assert errors == [] + + +def test_client_connect_exhausts_retries_then_reports_error(mocker: MockFixture): + # Given: the handshake fails transiently on every attempt. + connect_mock = mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + side_effect=TimeoutError(), + ) + _disable_rw_threads(mocker) + errors = [] + client = StreamingClient( + StreamingClientOptions( + api_key="test", + api_host="api.example.com", + max_connection_retries=2, + connection_retry_delay=0, + ) + ) + client.on(StreamingEvents.Error, lambda s, e: errors.append(e)) + + # When: connect() exhausts all attempts. + client.connect(_default_params()) + + # Then: it makes 1 + max_connection_retries attempts and dispatches one error. + assert connect_mock.call_count == 3 + assert len(errors) == 1 + + +def test_client_connect_does_not_retry_invalid_status(mocker: MockFixture): + # Given: the server rejects the handshake at the HTTP layer (401). + response = SimpleNamespace(status_code=401) + connect_mock = mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + side_effect=InvalidStatus(response=response), + ) + _disable_rw_threads(mocker) + errors = [] + client = StreamingClient( + StreamingClientOptions( + api_key="test", + api_host="api.example.com", + max_connection_retries=5, + connection_retry_delay=0, + ) + ) + client.on(StreamingEvents.Error, lambda s, e: errors.append(e)) + + # When: connect() is called. + client.connect(_default_params()) + + # Then: the HTTP rejection is not retried — a single attempt, single error. + assert connect_mock.call_count == 1 + assert len(errors) == 1 + assert errors[0].code == 401 + + +def test_client_connect_retries_disabled(mocker: MockFixture): + # Given: retries disabled (max_connection_retries=0). + connect_mock = mocker.patch( + "assemblyai.streaming.v3.client.websocket_connect", + side_effect=TimeoutError(), + ) + _disable_rw_threads(mocker) + errors = [] + client = StreamingClient( + StreamingClientOptions( + api_key="test", api_host="api.example.com", max_connection_retries=0 + ) + ) + client.on(StreamingEvents.Error, lambda s, e: errors.append(e)) + + # When: connect() fails transiently. + client.connect(_default_params()) + + # Then: exactly one attempt is made and the error is reported. + assert connect_mock.call_count == 1 + assert len(errors) == 1 diff --git a/tests/unit/test_streaming_async.py b/tests/unit/test_streaming_async.py index dc1d002..c06986e 100644 --- a/tests/unit/test_streaming_async.py +++ b/tests/unit/test_streaming_async.py @@ -210,8 +210,12 @@ async def test_disconnect_terminate_sends_terminate_then_closes(mocker: MockFixt fake_ws = _FakeAsyncWebSocket() _patch_connect(mocker, fake_ws) + # Short terminate_timeout: this fake server never replies with a + # TerminationEvent, so disconnect's graceful wait should time out fast. client = AsyncStreamingClient( - StreamingClientOptions(api_key="test", api_host="api.example.com") + StreamingClientOptions( + api_key="test", api_host="api.example.com", terminate_timeout=0.2 + ) ) await client.connect(_default_params()) @@ -224,6 +228,41 @@ async def test_disconnect_terminate_sends_terminate_then_closes(mocker: MockFixt assert fake_ws.close_call_count >= 1 +async def test_disconnect_terminate_waits_for_termination_event(mocker: MockFixture): + # Given: a server that replies to the Terminate frame with a + # TerminationEvent (as in production). + termination_json = json.dumps( + { + "type": "Termination", + "audio_duration_seconds": 12, + "session_duration_seconds": 13, + } + ) + + class _TerminateAwareAsyncWS(_FakeAsyncWebSocket): + async def send(self, data) -> None: + await super().send(data) + if isinstance(data, str) and '"Terminate"' in data: + self.push_message(termination_json) + + fake_ws = _TerminateAwareAsyncWS() + _patch_connect(mocker, fake_ws) + received = [] + + client = AsyncStreamingClient( + StreamingClientOptions(api_key="test", api_host="api.example.com") + ) + client.on(StreamingEvents.Termination, lambda c, e: received.append(e)) + await client.connect(_default_params()) + + # When: disconnecting gracefully + await client.disconnect(terminate=True) + + # Then: the TerminationEvent was dispatched before disconnect returned. + assert len(received) == 1 + assert received[0].audio_duration_seconds == 12 + + async def test_begin_event_dispatched_to_handler(mocker: MockFixture): fake_ws = _FakeAsyncWebSocket() _patch_connect(mocker, fake_ws) @@ -1190,3 +1229,128 @@ def on_error(_client, err): assert received[0].code == 1011 await client.disconnect() + + +_ASYNC_CONNECT_PATH = ( + "assemblyai.streaming.v3.async_client." + "websocket_connect_async" +) + + +async def test_connect_retries_transient_failure_then_succeeds(mocker: MockFixture): + # Given: the handshake fails twice transiently, then succeeds (default 2 + # retries → 3 attempts). + fake_ws = _FakeAsyncWebSocket() + attempts = {"n": 0} + + async def flaky_connect(uri, additional_headers=None, **_kwargs): + attempts["n"] += 1 + if attempts["n"] < 3: + raise TimeoutError() + return fake_ws + + mocker.patch(_ASYNC_CONNECT_PATH, new=flaky_connect) + errors = [] + client = AsyncStreamingClient( + StreamingClientOptions( + api_key="test", api_host="api.example.com", connection_retry_delay=0 + ) + ) + client.on(StreamingEvents.Error, lambda c, e: errors.append(e)) + + # When: connect() retries through the transient failures. + await client.connect(_default_params()) + + # Then: it makes three attempts, binds the websocket, and reports no error. + assert attempts["n"] == 3 + assert client._websocket is fake_ws + assert errors == [] + + await client.disconnect() + + +async def test_connect_exhausts_retries_then_reports_error(mocker: MockFixture): + # Given: the handshake fails transiently on every attempt. + attempts = {"n": 0} + + async def failing_connect(*_args, **_kwargs): + attempts["n"] += 1 + raise TimeoutError() + + mocker.patch(_ASYNC_CONNECT_PATH, new=failing_connect) + errors = [] + client = AsyncStreamingClient( + StreamingClientOptions( + api_key="test", + api_host="api.example.com", + max_connection_retries=2, + connection_retry_delay=0, + ) + ) + client.on(StreamingEvents.Error, lambda c, e: errors.append(e)) + + # When: connect() exhausts all attempts. + await client.connect(_default_params()) + + # Then: it makes 1 + max_connection_retries attempts and dispatches one error. + assert attempts["n"] == 3 + assert len(errors) == 1 + + +async def test_connect_timeout_is_honored_then_retried(mocker: MockFixture): + # Given: a handshake that hangs past the (tiny) configured connect_timeout. + attempts = {"n": 0} + + async def hanging_connect(*_args, **_kwargs): + attempts["n"] += 1 + await asyncio.sleep(10) + + mocker.patch(_ASYNC_CONNECT_PATH, new=hanging_connect) + errors = [] + client = AsyncStreamingClient( + StreamingClientOptions( + api_key="test", + api_host="api.example.com", + connect_timeout=0.01, + max_connection_retries=1, + connection_retry_delay=0, + ) + ) + client.on(StreamingEvents.Error, lambda c, e: errors.append(e)) + + # When: connect() aborts each hung attempt at the timeout and retries once. + await client.connect(_default_params()) + + # Then: both attempts time out and a single error is reported. + assert attempts["n"] == 2 + assert len(errors) == 1 + + +async def test_connect_does_not_retry_invalid_status(mocker: MockFixture): + # Given: the server rejects the handshake at the HTTP layer (401). + attempts = {"n": 0} + response = type("R", (), {"status_code": 401})() + + async def failing_connect(*_args, **_kwargs): + attempts["n"] += 1 + raise InvalidStatus(response=response) + + mocker.patch(_ASYNC_CONNECT_PATH, new=failing_connect) + errors = [] + client = AsyncStreamingClient( + StreamingClientOptions( + api_key="test", + api_host="api.example.com", + max_connection_retries=5, + connection_retry_delay=0, + ) + ) + client.on(StreamingEvents.Error, lambda c, e: errors.append(e)) + + # When: connect() is called. + await client.connect(_default_params()) + + # Then: the HTTP rejection is not retried — a single attempt, single error. + assert attempts["n"] == 1 + assert len(errors) == 1 + assert errors[0].code == 401 diff --git a/tests/unit/test_sync.py b/tests/unit/test_sync.py index ac34928..33e3c5e 100644 --- a/tests/unit/test_sync.py +++ b/tests/unit/test_sync.py @@ -11,12 +11,11 @@ _OK_RESPONSE = { "text": "hello world", "words": [ - {"text": "hello", "start_ms": 0, "end_ms": 200, "confidence": 0.9}, - {"text": "world", "start_ms": 220, "end_ms": 400, "confidence": 0.95}, + {"text": "hello", "start": 0, "end": 200, "confidence": 0.9}, + {"text": "world", "start": 220, "end": 400, "confidence": 0.95}, ], "confidence": 0.92, "audio_duration_ms": 400, - "inference_time_ms": 12.5, "session_id": "eb92c4ff-4bbb-429f-9b99-7279d7fe738f", } @@ -41,7 +40,8 @@ def test_transcribe_bytes_parses_response(httpx_mock: HTTPXMock): assert isinstance(result, aai.SyncTranscriptResponse) assert result.text == "hello world" assert result.session_id == _OK_RESPONSE["session_id"] - assert result.words[0].start_ms == 0 + assert result.words[0].start == 0 + assert result.words[0].end == 200 assert result.words[1].text == "world" diff --git a/tox.ini b/tox.ini index 23daedf..b0adc6f 100644 --- a/tox.ini +++ b/tox.ini @@ -1,25 +1,24 @@ +# Per-axis matrix, not a cartesian product: `latest` envs resolve the newest +# deps via setup.py install_requires; legacy floors are pinned on one Python +# (3.11) so each dependency's lower bound is still exercised without a 480-env +# blow-up. Floors track setup.py (pydantic>=1.10.17, websockets>=11, httpx>=0.19). [tox] -envlist = py{38,39,310,311}-websockets{latest,11.0}-pyaudio{latest,0.2}-httpx{latest,0.24,0.23,0.22,0.21}-pydantic{latest,2,1.10,1.9,1.8,1.7}-typing-extensions +envlist = + py{39,310,311,312}-latest + py311-httpx{0.22,0.24} + py311-pydantic1.10 + py311-websockets11.0 + py311-pyaudio0.2 [testenv] deps = - # library dependencies - websocketslatest: websockets + # back-compat floors (per factor); `latest` envs omit these websockets11.0: websockets>=11.0.0,<12.0.0 - httpxlatest: httpx - httpx0.24: httpx>=0.24.0,<0.25.0 - httpx0.23: httpx>=0.23.0,<0.24.0 httpx0.22: httpx>=0.22.0,<0.23.0 - httpx0.21: httpx>=0.21.0,<0.22.0 - pydanticlatest: pydantic - pydantic2: pydantic>=2 - pydantic1.10: pydantic>=1.10.0,<1.11.0,!=1.10.7 - pydantic1.9: pydantic>=1.9.0,<1.10.0 - pydantic1.8: pydantic>=1.8.0,<1.9.0 - pydantic1.7: pydantic>=1.7.0,<1.8.0 - typing-extensions: typing-extensions>=3.7 - # extra dependencies - pyaudiolatest: pyaudio + httpx0.24: httpx>=0.24.0,<0.25.0 + pydantic1.10: pydantic>=1.10.17,<1.11.0 + # pyaudio is an extra, not in install_requires — install for all envs + pyaudio>=0.2.13 pyaudio0.2: pyaudio>=0.2.13,<0.3.0 # test dependencies pytest