diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index ecdd685e2a..c99b5b81cf 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -1,12 +1,15 @@ import os import uuid +from typing import Any +import httpx from openai import NOT_GIVEN, AsyncOpenAI from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.io import download_file from astrbot.core.utils.media_utils import convert_audio_to_wav +from astrbot.core.utils.network_utils import create_proxy_client from astrbot.core.utils.tencent_record_helper import ( convert_to_pcm_wav, tencent_silk_to_wav, @@ -30,15 +33,33 @@ def __init__( ) -> None: super().__init__(provider_config, provider_settings) self.chosen_api_key = provider_config.get("api_key", "") + self.http_client: httpx.AsyncClient | None = None self.client = AsyncOpenAI( api_key=self.chosen_api_key, base_url=provider_config.get("api_base"), timeout=provider_config.get("timeout", NOT_GIVEN), + http_client=self._create_http_client(provider_config), ) self.set_model(provider_config["model"]) + def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient: + proxy = provider_config.get("proxy") + httpx_module: Any = httpx + try: + from openai import _base_client as openai_base_client + + httpx_module = getattr(openai_base_client, "httpx", httpx) + except ImportError: + pass + self.http_client = create_proxy_client( + "OpenAI Whisper", + proxy, + httpx_module=httpx_module, + ) + return self.http_client + async def _get_audio_format(self, file_path) -> str | None: # 定义要检测的头部字节 silk_header = b"SILK" @@ -133,3 +154,5 @@ async def get_text(self, audio_url: str) -> str: async def terminate(self): if self.client: await self.client.close() + if self.http_client: + await self.http_client.aclose() diff --git a/tests/test_whisper_api_source.py b/tests/test_whisper_api_source.py index dd4a456a3f..202125659a 100644 --- a/tests/test_whisper_api_source.py +++ b/tests/test_whisper_api_source.py @@ -4,9 +4,16 @@ import pytest +import astrbot.core.provider.sources.whisper_api_source as whisper_api_source from astrbot.core.provider.sources.whisper_api_source import ProviderOpenAIWhisperAPI +class _FakeAsyncOpenAI: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.close = AsyncMock() + + def _make_provider() -> ProviderOpenAIWhisperAPI: provider = ProviderOpenAIWhisperAPI( provider_config={ @@ -28,6 +35,134 @@ def _make_provider() -> ProviderOpenAIWhisperAPI: return provider +def test_provider_passes_configured_proxy_to_openai_http_client(monkeypatch): + captured: dict[str, object] = {} + fake_http_client = SimpleNamespace(aclose=AsyncMock()) + + def fake_create_proxy_client( + provider_label: str, + proxy: str | None = None, + headers: dict[str, str] | None = None, + verify=None, + httpx_module=None, + ): + captured["provider_label"] = provider_label + captured["proxy"] = proxy + captured["headers"] = headers + captured["httpx_module"] = httpx_module + return fake_http_client + + monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) + monkeypatch.setattr( + whisper_api_source, + "create_proxy_client", + fake_create_proxy_client, + ) + + provider = ProviderOpenAIWhisperAPI( + provider_config={ + "id": "test-whisper-api", + "type": "openai_whisper_api", + "model": "whisper-1", + "api_key": "test-key", + "api_base": "https://api.example.com/v1", + "proxy": "http://127.0.0.1:7890", + "timeout": 30, + }, + provider_settings={}, + ) + + assert provider.client.kwargs["api_key"] == "test-key" + assert provider.client.kwargs["base_url"] == "https://api.example.com/v1" + assert provider.client.kwargs["timeout"] == 30 + assert provider.client.kwargs["http_client"] is fake_http_client + assert set(provider.client.kwargs) == { + "api_key", + "base_url", + "timeout", + "http_client", + } + assert provider.http_client is fake_http_client + assert captured["provider_label"] == "OpenAI Whisper" + assert captured["proxy"] == "http://127.0.0.1:7890" + assert captured["headers"] is None + assert captured["httpx_module"] is not None + + +def test_provider_uses_default_http_client_when_proxy_missing(monkeypatch): + captured: dict[str, object] = {} + fake_http_client = SimpleNamespace(aclose=AsyncMock()) + + def fake_create_proxy_client( + provider_label: str, + proxy: str | None = None, + headers: dict[str, str] | None = None, + verify=None, + httpx_module=None, + ): + captured["provider_label"] = provider_label + captured["proxy"] = proxy + captured["headers"] = headers + captured["httpx_module"] = httpx_module + return fake_http_client + + monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) + monkeypatch.setattr( + whisper_api_source, + "create_proxy_client", + fake_create_proxy_client, + ) + + provider = ProviderOpenAIWhisperAPI( + provider_config={ + "id": "test-whisper-api", + "type": "openai_whisper_api", + "model": "whisper-1", + "api_key": "test-key", + }, + provider_settings={}, + ) + + assert provider.client.kwargs["http_client"] is fake_http_client + assert set(provider.client.kwargs) == { + "api_key", + "base_url", + "timeout", + "http_client", + } + assert provider.http_client is fake_http_client + assert captured["provider_label"] == "OpenAI Whisper" + assert captured["proxy"] is None + assert captured["headers"] is None + + +@pytest.mark.asyncio +async def test_terminate_closes_openai_client_and_custom_http_client(monkeypatch): + fake_http_client = SimpleNamespace(aclose=AsyncMock()) + + monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) + monkeypatch.setattr( + whisper_api_source, + "create_proxy_client", + lambda *args, **kwargs: fake_http_client, + ) + + provider = ProviderOpenAIWhisperAPI( + provider_config={ + "id": "test-whisper-api", + "type": "openai_whisper_api", + "model": "whisper-1", + "api_key": "test-key", + }, + provider_settings={}, + ) + + await provider.terminate() + + provider.client.close.assert_awaited_once() + fake_http_client.aclose.assert_awaited_once() + + @pytest.mark.asyncio async def test_get_text_converts_opus_files_to_wav_before_transcription( tmp_path: Path, monkeypatch: pytest.MonkeyPatch @@ -38,7 +173,9 @@ async def test_get_text_converts_opus_files_to_wav_before_transcription( conversions: list[tuple[str, str]] = [] - async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None): + async def fake_convert_audio_to_wav( + audio_path: str, output_path: str | None = None + ): assert output_path is not None conversions.append((audio_path, output_path)) Path(output_path).write_bytes(b"fake wav data")