1+ """Tests for ASR namespace client."""
2+
3+ from unittest .mock import AsyncMock , Mock
4+
5+ import ormsgpack
6+ import pytest
7+
8+ from fishaudio .core import AsyncClientWrapper , ClientWrapper , RequestOptions
9+ from fishaudio .resources .asr import ASRClient , AsyncASRClient
10+ from fishaudio .types import ASRResponse
11+
12+
13+ @pytest .fixture
14+ def mock_client_wrapper (mock_api_key ):
15+ """Mock client wrapper."""
16+ wrapper = Mock (spec = ClientWrapper )
17+ wrapper .api_key = mock_api_key
18+ return wrapper
19+
20+
21+ @pytest .fixture
22+ def async_mock_client_wrapper (mock_api_key ):
23+ """Mock async client wrapper."""
24+ wrapper = Mock (spec = AsyncClientWrapper )
25+ wrapper .api_key = mock_api_key
26+ return wrapper
27+
28+
29+ @pytest .fixture
30+ def asr_client (mock_client_wrapper ):
31+ """ASRClient instance with mocked wrapper."""
32+ return ASRClient (mock_client_wrapper )
33+
34+
35+ @pytest .fixture
36+ def async_asr_client (async_mock_client_wrapper ):
37+ """AsyncASRClient instance with mocked wrapper."""
38+ return AsyncASRClient (async_mock_client_wrapper )
39+
40+
41+ class TestASRClient :
42+ """Test synchronous ASRClient."""
43+
44+ def test_transcribe_basic (
45+ self , asr_client , mock_client_wrapper , sample_asr_response
46+ ):
47+ """Test basic transcription."""
48+ mock_response = Mock ()
49+ mock_response .json .return_value = sample_asr_response
50+ mock_client_wrapper .request .return_value = mock_response
51+
52+ result = asr_client .transcribe (audio = b"fake_audio" )
53+
54+ assert isinstance (result , ASRResponse )
55+ assert result .text == "Hello world"
56+ assert result .duration == 1500.0
57+ assert len (result .segments ) == 2
58+
59+ mock_client_wrapper .request .assert_called_once ()
60+ call_args = mock_client_wrapper .request .call_args
61+ assert call_args [0 ][0 ] == "POST"
62+ assert call_args [0 ][1 ] == "/v1/asr"
63+ assert call_args [1 ]["headers" ]["Content-Type" ] == "application/msgpack"
64+
65+ payload = ormsgpack .unpackb (call_args [1 ]["content" ])
66+ assert payload ["audio" ] == b"fake_audio"
67+ assert payload ["ignore_timestamps" ] is False
68+ assert "language" not in payload
69+
70+ def test_transcribe_with_language (
71+ self , asr_client , mock_client_wrapper , sample_asr_response
72+ ):
73+ """Test transcription with language specified."""
74+ mock_response = Mock ()
75+ mock_response .json .return_value = sample_asr_response
76+ mock_client_wrapper .request .return_value = mock_response
77+
78+ asr_client .transcribe (audio = b"fake_audio" , language = "en" )
79+
80+ call_args = mock_client_wrapper .request .call_args
81+ payload = ormsgpack .unpackb (call_args [1 ]["content" ])
82+ assert payload ["language" ] == "en"
83+
84+ def test_transcribe_without_timestamps (
85+ self , asr_client , mock_client_wrapper , sample_asr_response
86+ ):
87+ """Test transcription with timestamps disabled."""
88+ mock_response = Mock ()
89+ mock_response .json .return_value = sample_asr_response
90+ mock_client_wrapper .request .return_value = mock_response
91+
92+ asr_client .transcribe (audio = b"fake_audio" , include_timestamps = False )
93+
94+ call_args = mock_client_wrapper .request .call_args
95+ payload = ormsgpack .unpackb (call_args [1 ]["content" ])
96+ assert payload ["ignore_timestamps" ] is True
97+
98+ def test_transcribe_with_request_options (
99+ self , asr_client , mock_client_wrapper , sample_asr_response
100+ ):
101+ """Test transcription with custom request options."""
102+ mock_response = Mock ()
103+ mock_response .json .return_value = sample_asr_response
104+ mock_client_wrapper .request .return_value = mock_response
105+
106+ request_options = RequestOptions (timeout = 60.0 )
107+ asr_client .transcribe (audio = b"fake_audio" , request_options = request_options )
108+
109+ call_args = mock_client_wrapper .request .call_args
110+ assert call_args [1 ]["request_options" ] == request_options
111+
112+ def test_transcribe_language_none (
113+ self , asr_client , mock_client_wrapper , sample_asr_response
114+ ):
115+ """Test transcription with language explicitly set to None."""
116+ mock_response = Mock ()
117+ mock_response .json .return_value = sample_asr_response
118+ mock_client_wrapper .request .return_value = mock_response
119+
120+ asr_client .transcribe (audio = b"fake_audio" , language = None )
121+
122+ call_args = mock_client_wrapper .request .call_args
123+ payload = ormsgpack .unpackb (call_args [1 ]["content" ])
124+ assert payload ["language" ] is None
125+
126+
127+ class TestAsyncASRClient :
128+ """Test asynchronous AsyncASRClient."""
129+
130+ @pytest .mark .asyncio
131+ async def test_transcribe_basic (
132+ self , async_asr_client , async_mock_client_wrapper , sample_asr_response
133+ ):
134+ """Test basic transcription (async)."""
135+ mock_response = Mock ()
136+ mock_response .json .return_value = sample_asr_response
137+ async_mock_client_wrapper .request = AsyncMock (return_value = mock_response )
138+
139+ result = await async_asr_client .transcribe (audio = b"fake_audio" )
140+
141+ assert isinstance (result , ASRResponse )
142+ assert result .text == "Hello world"
143+ assert result .duration == 1500.0
144+ assert len (result .segments ) == 2
145+
146+ async_mock_client_wrapper .request .assert_called_once ()
147+ call_args = async_mock_client_wrapper .request .call_args
148+ assert call_args [0 ][0 ] == "POST"
149+ assert call_args [0 ][1 ] == "/v1/asr"
150+
151+ payload = ormsgpack .unpackb (call_args [1 ]["content" ])
152+ assert payload ["audio" ] == b"fake_audio"
153+ assert payload ["ignore_timestamps" ] is False
154+ assert "language" not in payload
155+
156+ @pytest .mark .asyncio
157+ async def test_transcribe_with_language (
158+ self , async_asr_client , async_mock_client_wrapper , sample_asr_response
159+ ):
160+ """Test transcription with language specified (async)."""
161+ mock_response = Mock ()
162+ mock_response .json .return_value = sample_asr_response
163+ async_mock_client_wrapper .request = AsyncMock (return_value = mock_response )
164+
165+ await async_asr_client .transcribe (audio = b"fake_audio" , language = "zh" )
166+
167+ call_args = async_mock_client_wrapper .request .call_args
168+ payload = ormsgpack .unpackb (call_args [1 ]["content" ])
169+ assert payload ["language" ] == "zh"
170+
171+ @pytest .mark .asyncio
172+ async def test_transcribe_without_timestamps (
173+ self , async_asr_client , async_mock_client_wrapper , sample_asr_response
174+ ):
175+ """Test transcription with timestamps disabled (async)."""
176+ mock_response = Mock ()
177+ mock_response .json .return_value = sample_asr_response
178+ async_mock_client_wrapper .request = AsyncMock (return_value = mock_response )
179+
180+ await async_asr_client .transcribe (
181+ audio = b"fake_audio" , include_timestamps = False
182+ )
183+
184+ call_args = async_mock_client_wrapper .request .call_args
185+ payload = ormsgpack .unpackb (call_args [1 ]["content" ])
186+ assert payload ["ignore_timestamps" ] is True
187+
188+ @pytest .mark .asyncio
189+ async def test_transcribe_with_request_options (
190+ self , async_asr_client , async_mock_client_wrapper , sample_asr_response
191+ ):
192+ """Test transcription with custom request options (async)."""
193+ mock_response = Mock ()
194+ mock_response .json .return_value = sample_asr_response
195+ async_mock_client_wrapper .request = AsyncMock (return_value = mock_response )
196+
197+ request_options = RequestOptions (timeout = 60.0 )
198+ await async_asr_client .transcribe (
199+ audio = b"fake_audio" , request_options = request_options
200+ )
201+
202+ call_args = async_mock_client_wrapper .request .call_args
203+ assert call_args [1 ]["request_options" ] == request_options
0 commit comments