Skip to content

Commit 584a285

Browse files
committed
test: add unit tests for ASRClient and AsyncASRClient
1 parent 09e9aba commit 584a285

1 file changed

Lines changed: 203 additions & 0 deletions

File tree

tests/unit/test_asr.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""Tests for ASR namespace client."""
2+
3+
from unittest.mock import AsyncMock, Mock
4+
5+
import ormsgpack
6+
import pytest
7+
8+
from fishaudio.core import AsyncClientWrapper, ClientWrapper, RequestOptions
9+
from fishaudio.resources.asr import ASRClient, AsyncASRClient
10+
from fishaudio.types import ASRResponse
11+
12+
13+
@pytest.fixture
14+
def mock_client_wrapper(mock_api_key):
15+
"""Mock client wrapper."""
16+
wrapper = Mock(spec=ClientWrapper)
17+
wrapper.api_key = mock_api_key
18+
return wrapper
19+
20+
21+
@pytest.fixture
22+
def async_mock_client_wrapper(mock_api_key):
23+
"""Mock async client wrapper."""
24+
wrapper = Mock(spec=AsyncClientWrapper)
25+
wrapper.api_key = mock_api_key
26+
return wrapper
27+
28+
29+
@pytest.fixture
30+
def asr_client(mock_client_wrapper):
31+
"""ASRClient instance with mocked wrapper."""
32+
return ASRClient(mock_client_wrapper)
33+
34+
35+
@pytest.fixture
36+
def async_asr_client(async_mock_client_wrapper):
37+
"""AsyncASRClient instance with mocked wrapper."""
38+
return AsyncASRClient(async_mock_client_wrapper)
39+
40+
41+
class TestASRClient:
42+
"""Test synchronous ASRClient."""
43+
44+
def test_transcribe_basic(
45+
self, asr_client, mock_client_wrapper, sample_asr_response
46+
):
47+
"""Test basic transcription."""
48+
mock_response = Mock()
49+
mock_response.json.return_value = sample_asr_response
50+
mock_client_wrapper.request.return_value = mock_response
51+
52+
result = asr_client.transcribe(audio=b"fake_audio")
53+
54+
assert isinstance(result, ASRResponse)
55+
assert result.text == "Hello world"
56+
assert result.duration == 1500.0
57+
assert len(result.segments) == 2
58+
59+
mock_client_wrapper.request.assert_called_once()
60+
call_args = mock_client_wrapper.request.call_args
61+
assert call_args[0][0] == "POST"
62+
assert call_args[0][1] == "/v1/asr"
63+
assert call_args[1]["headers"]["Content-Type"] == "application/msgpack"
64+
65+
payload = ormsgpack.unpackb(call_args[1]["content"])
66+
assert payload["audio"] == b"fake_audio"
67+
assert payload["ignore_timestamps"] is False
68+
assert "language" not in payload
69+
70+
def test_transcribe_with_language(
71+
self, asr_client, mock_client_wrapper, sample_asr_response
72+
):
73+
"""Test transcription with language specified."""
74+
mock_response = Mock()
75+
mock_response.json.return_value = sample_asr_response
76+
mock_client_wrapper.request.return_value = mock_response
77+
78+
asr_client.transcribe(audio=b"fake_audio", language="en")
79+
80+
call_args = mock_client_wrapper.request.call_args
81+
payload = ormsgpack.unpackb(call_args[1]["content"])
82+
assert payload["language"] == "en"
83+
84+
def test_transcribe_without_timestamps(
85+
self, asr_client, mock_client_wrapper, sample_asr_response
86+
):
87+
"""Test transcription with timestamps disabled."""
88+
mock_response = Mock()
89+
mock_response.json.return_value = sample_asr_response
90+
mock_client_wrapper.request.return_value = mock_response
91+
92+
asr_client.transcribe(audio=b"fake_audio", include_timestamps=False)
93+
94+
call_args = mock_client_wrapper.request.call_args
95+
payload = ormsgpack.unpackb(call_args[1]["content"])
96+
assert payload["ignore_timestamps"] is True
97+
98+
def test_transcribe_with_request_options(
99+
self, asr_client, mock_client_wrapper, sample_asr_response
100+
):
101+
"""Test transcription with custom request options."""
102+
mock_response = Mock()
103+
mock_response.json.return_value = sample_asr_response
104+
mock_client_wrapper.request.return_value = mock_response
105+
106+
request_options = RequestOptions(timeout=60.0)
107+
asr_client.transcribe(audio=b"fake_audio", request_options=request_options)
108+
109+
call_args = mock_client_wrapper.request.call_args
110+
assert call_args[1]["request_options"] == request_options
111+
112+
def test_transcribe_language_none(
113+
self, asr_client, mock_client_wrapper, sample_asr_response
114+
):
115+
"""Test transcription with language explicitly set to None."""
116+
mock_response = Mock()
117+
mock_response.json.return_value = sample_asr_response
118+
mock_client_wrapper.request.return_value = mock_response
119+
120+
asr_client.transcribe(audio=b"fake_audio", language=None)
121+
122+
call_args = mock_client_wrapper.request.call_args
123+
payload = ormsgpack.unpackb(call_args[1]["content"])
124+
assert payload["language"] is None
125+
126+
127+
class TestAsyncASRClient:
128+
"""Test asynchronous AsyncASRClient."""
129+
130+
@pytest.mark.asyncio
131+
async def test_transcribe_basic(
132+
self, async_asr_client, async_mock_client_wrapper, sample_asr_response
133+
):
134+
"""Test basic transcription (async)."""
135+
mock_response = Mock()
136+
mock_response.json.return_value = sample_asr_response
137+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
138+
139+
result = await async_asr_client.transcribe(audio=b"fake_audio")
140+
141+
assert isinstance(result, ASRResponse)
142+
assert result.text == "Hello world"
143+
assert result.duration == 1500.0
144+
assert len(result.segments) == 2
145+
146+
async_mock_client_wrapper.request.assert_called_once()
147+
call_args = async_mock_client_wrapper.request.call_args
148+
assert call_args[0][0] == "POST"
149+
assert call_args[0][1] == "/v1/asr"
150+
151+
payload = ormsgpack.unpackb(call_args[1]["content"])
152+
assert payload["audio"] == b"fake_audio"
153+
assert payload["ignore_timestamps"] is False
154+
assert "language" not in payload
155+
156+
@pytest.mark.asyncio
157+
async def test_transcribe_with_language(
158+
self, async_asr_client, async_mock_client_wrapper, sample_asr_response
159+
):
160+
"""Test transcription with language specified (async)."""
161+
mock_response = Mock()
162+
mock_response.json.return_value = sample_asr_response
163+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
164+
165+
await async_asr_client.transcribe(audio=b"fake_audio", language="zh")
166+
167+
call_args = async_mock_client_wrapper.request.call_args
168+
payload = ormsgpack.unpackb(call_args[1]["content"])
169+
assert payload["language"] == "zh"
170+
171+
@pytest.mark.asyncio
172+
async def test_transcribe_without_timestamps(
173+
self, async_asr_client, async_mock_client_wrapper, sample_asr_response
174+
):
175+
"""Test transcription with timestamps disabled (async)."""
176+
mock_response = Mock()
177+
mock_response.json.return_value = sample_asr_response
178+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
179+
180+
await async_asr_client.transcribe(
181+
audio=b"fake_audio", include_timestamps=False
182+
)
183+
184+
call_args = async_mock_client_wrapper.request.call_args
185+
payload = ormsgpack.unpackb(call_args[1]["content"])
186+
assert payload["ignore_timestamps"] is True
187+
188+
@pytest.mark.asyncio
189+
async def test_transcribe_with_request_options(
190+
self, async_asr_client, async_mock_client_wrapper, sample_asr_response
191+
):
192+
"""Test transcription with custom request options (async)."""
193+
mock_response = Mock()
194+
mock_response.json.return_value = sample_asr_response
195+
async_mock_client_wrapper.request = AsyncMock(return_value=mock_response)
196+
197+
request_options = RequestOptions(timeout=60.0)
198+
await async_asr_client.transcribe(
199+
audio=b"fake_audio", request_options=request_options
200+
)
201+
202+
call_args = async_mock_client_wrapper.request.call_args
203+
assert call_args[1]["request_options"] == request_options

0 commit comments

Comments
 (0)