From e3bb4355fa6548ce50c9d12c6a618a14f62a91af Mon Sep 17 00:00:00 2001 From: AnegasakiNene <990126341@qq.com> Date: Mon, 8 Jun 2026 13:04:22 +0800 Subject: [PATCH 1/2] feat(stt): honor proxy in OpenAI Whisper provider --- .../provider/sources/whisper_api_source.py | 19 ++++++ tests/test_whisper_api_source.py | 63 ++++++++++++++++++- 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index ecdd685e2a..2e1fa940db 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -1,12 +1,15 @@ import os import uuid +from typing import Any +import httpx from openai import NOT_GIVEN, AsyncOpenAI from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.io import download_file from astrbot.core.utils.media_utils import convert_audio_to_wav +from astrbot.core.utils.network_utils import create_proxy_client from astrbot.core.utils.tencent_record_helper import ( convert_to_pcm_wav, tencent_silk_to_wav, @@ -35,10 +38,26 @@ def __init__( api_key=self.chosen_api_key, base_url=provider_config.get("api_base"), timeout=provider_config.get("timeout", NOT_GIVEN), + http_client=self._create_http_client(provider_config), ) self.set_model(provider_config["model"]) + def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient: + proxy = provider_config.get("proxy", "") + httpx_module: Any = httpx + try: + from openai import _base_client as openai_base_client + + httpx_module = getattr(openai_base_client, "httpx", httpx) + except ImportError: + pass + return create_proxy_client( + "OpenAI Whisper", + proxy, + httpx_module=httpx_module, + ) + async def _get_audio_format(self, file_path) -> str | None: # 定义要检测的头部字节 silk_header = b"SILK" diff --git a/tests/test_whisper_api_source.py b/tests/test_whisper_api_source.py index dd4a456a3f..f8f8ca4f8e 100644 --- a/tests/test_whisper_api_source.py +++ b/tests/test_whisper_api_source.py @@ -4,9 +4,18 @@ import pytest +import astrbot.core.provider.sources.whisper_api_source as whisper_api_source from astrbot.core.provider.sources.whisper_api_source import ProviderOpenAIWhisperAPI +class _FakeAsyncOpenAI: + def __init__(self, **kwargs): + self.kwargs = kwargs + + async def close(self): + return None + + def _make_provider() -> ProviderOpenAIWhisperAPI: provider = ProviderOpenAIWhisperAPI( provider_config={ @@ -28,6 +37,56 @@ def _make_provider() -> ProviderOpenAIWhisperAPI: return provider +def test_provider_passes_configured_proxy_to_openai_http_client(monkeypatch): + captured: dict[str, object] = {} + fake_http_client = object() + + def fake_create_proxy_client( + provider_label: str, + proxy: str | None = None, + headers: dict[str, str] | None = None, + verify=None, + httpx_module=None, + ): + captured["provider_label"] = provider_label + captured["proxy"] = proxy + captured["headers"] = headers + captured["httpx_module"] = httpx_module + return fake_http_client + + monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) + monkeypatch.setattr( + whisper_api_source, + "create_proxy_client", + fake_create_proxy_client, + ) + + provider = ProviderOpenAIWhisperAPI( + provider_config={ + "id": "test-whisper-api", + "type": "openai_whisper_api", + "model": "whisper-1", + "api_key": "test-key", + "api_base": "https://api.example.com/v1", + "proxy": "http://127.0.0.1:7890", + "timeout": 30, + }, + provider_settings={}, + ) + + assert provider.client.kwargs["api_key"] == "test-key" + assert provider.client.kwargs["base_url"] == "https://api.example.com/v1" + assert provider.client.kwargs["timeout"] == 30 + assert provider.client.kwargs["http_client"] is fake_http_client + assert captured["provider_label"] == "OpenAI Whisper" + assert captured["proxy"] == "http://127.0.0.1:7890" + assert captured["headers"] is None + + from openai import _base_client as openai_base_client + + assert captured["httpx_module"] is openai_base_client.httpx + + @pytest.mark.asyncio async def test_get_text_converts_opus_files_to_wav_before_transcription( tmp_path: Path, monkeypatch: pytest.MonkeyPatch @@ -38,7 +97,9 @@ async def test_get_text_converts_opus_files_to_wav_before_transcription( conversions: list[tuple[str, str]] = [] - async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None): + async def fake_convert_audio_to_wav( + audio_path: str, output_path: str | None = None + ): assert output_path is not None conversions.append((audio_path, output_path)) Path(output_path).write_bytes(b"fake wav data") From f42e27fb54e6e765181c9bbd3b6a7e41cf169859 Mon Sep 17 00:00:00 2001 From: AnegasakiNene <990126341@qq.com> Date: Mon, 8 Jun 2026 13:19:05 +0800 Subject: [PATCH 2/2] fix(stt): close Whisper HTTP client --- .../provider/sources/whisper_api_source.py | 8 +- tests/test_whisper_api_source.py | 88 +++++++++++++++++-- 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 2e1fa940db..c99b5b81cf 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -33,6 +33,7 @@ def __init__( ) -> None: super().__init__(provider_config, provider_settings) self.chosen_api_key = provider_config.get("api_key", "") + self.http_client: httpx.AsyncClient | None = None self.client = AsyncOpenAI( api_key=self.chosen_api_key, @@ -44,7 +45,7 @@ def __init__( self.set_model(provider_config["model"]) def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient: - proxy = provider_config.get("proxy", "") + proxy = provider_config.get("proxy") httpx_module: Any = httpx try: from openai import _base_client as openai_base_client @@ -52,11 +53,12 @@ def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient: httpx_module = getattr(openai_base_client, "httpx", httpx) except ImportError: pass - return create_proxy_client( + self.http_client = create_proxy_client( "OpenAI Whisper", proxy, httpx_module=httpx_module, ) + return self.http_client async def _get_audio_format(self, file_path) -> str | None: # 定义要检测的头部字节 @@ -152,3 +154,5 @@ async def get_text(self, audio_url: str) -> str: async def terminate(self): if self.client: await self.client.close() + if self.http_client: + await self.http_client.aclose() diff --git a/tests/test_whisper_api_source.py b/tests/test_whisper_api_source.py index f8f8ca4f8e..202125659a 100644 --- a/tests/test_whisper_api_source.py +++ b/tests/test_whisper_api_source.py @@ -11,9 +11,7 @@ class _FakeAsyncOpenAI: def __init__(self, **kwargs): self.kwargs = kwargs - - async def close(self): - return None + self.close = AsyncMock() def _make_provider() -> ProviderOpenAIWhisperAPI: @@ -39,7 +37,7 @@ def _make_provider() -> ProviderOpenAIWhisperAPI: def test_provider_passes_configured_proxy_to_openai_http_client(monkeypatch): captured: dict[str, object] = {} - fake_http_client = object() + fake_http_client = SimpleNamespace(aclose=AsyncMock()) def fake_create_proxy_client( provider_label: str, @@ -78,13 +76,91 @@ def fake_create_proxy_client( assert provider.client.kwargs["base_url"] == "https://api.example.com/v1" assert provider.client.kwargs["timeout"] == 30 assert provider.client.kwargs["http_client"] is fake_http_client + assert set(provider.client.kwargs) == { + "api_key", + "base_url", + "timeout", + "http_client", + } + assert provider.http_client is fake_http_client assert captured["provider_label"] == "OpenAI Whisper" assert captured["proxy"] == "http://127.0.0.1:7890" assert captured["headers"] is None + assert captured["httpx_module"] is not None + + +def test_provider_uses_default_http_client_when_proxy_missing(monkeypatch): + captured: dict[str, object] = {} + fake_http_client = SimpleNamespace(aclose=AsyncMock()) + + def fake_create_proxy_client( + provider_label: str, + proxy: str | None = None, + headers: dict[str, str] | None = None, + verify=None, + httpx_module=None, + ): + captured["provider_label"] = provider_label + captured["proxy"] = proxy + captured["headers"] = headers + captured["httpx_module"] = httpx_module + return fake_http_client + + monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) + monkeypatch.setattr( + whisper_api_source, + "create_proxy_client", + fake_create_proxy_client, + ) + + provider = ProviderOpenAIWhisperAPI( + provider_config={ + "id": "test-whisper-api", + "type": "openai_whisper_api", + "model": "whisper-1", + "api_key": "test-key", + }, + provider_settings={}, + ) + + assert provider.client.kwargs["http_client"] is fake_http_client + assert set(provider.client.kwargs) == { + "api_key", + "base_url", + "timeout", + "http_client", + } + assert provider.http_client is fake_http_client + assert captured["provider_label"] == "OpenAI Whisper" + assert captured["proxy"] is None + assert captured["headers"] is None + + +@pytest.mark.asyncio +async def test_terminate_closes_openai_client_and_custom_http_client(monkeypatch): + fake_http_client = SimpleNamespace(aclose=AsyncMock()) + + monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI) + monkeypatch.setattr( + whisper_api_source, + "create_proxy_client", + lambda *args, **kwargs: fake_http_client, + ) + + provider = ProviderOpenAIWhisperAPI( + provider_config={ + "id": "test-whisper-api", + "type": "openai_whisper_api", + "model": "whisper-1", + "api_key": "test-key", + }, + provider_settings={}, + ) - from openai import _base_client as openai_base_client + await provider.terminate() - assert captured["httpx_module"] is openai_base_client.httpx + provider.client.close.assert_awaited_once() + fake_http_client.aclose.assert_awaited_once() @pytest.mark.asyncio