Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions astrbot/core/provider/sources/whisper_api_source.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import os
import uuid
from typing import Any

import httpx
from openai import NOT_GIVEN, AsyncOpenAI

from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.network_utils import create_proxy_client
from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav,
tencent_silk_to_wav,
Expand All @@ -30,15 +33,33 @@ def __init__(
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.http_client: httpx.AsyncClient | None = None

self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base"),
timeout=provider_config.get("timeout", NOT_GIVEN),
http_client=self._create_http_client(provider_config),
)

self.set_model(provider_config["model"])

# Build the HTTP client that __init__ passes to AsyncOpenAI as `http_client`.
# The client is obtained from create_proxy_client and is also stored on
# self.http_client.
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient:
# "proxy" is optional; .get() yields None when the key is absent.
proxy = provider_config.get("proxy")
# Start from the httpx module imported at the top of this file.
httpx_module: Any = httpx
try:
from openai import _base_client as openai_base_client

# Prefer the `httpx` attribute exposed by the SDK's _base_client module,
# falling back to the module-level httpx when that attribute is missing.
# NOTE(review): presumably this keeps the client on the same httpx the SDK
# itself uses - confirm against create_proxy_client.
httpx_module = getattr(openai_base_client, "httpx", httpx)
except ImportError:
# Best effort: keep the module-level httpx if the SDK module cannot be imported.
pass
self.http_client = create_proxy_client(
"OpenAI Whisper",
proxy,
httpx_module=httpx_module,
)
Comment on lines +47 to +60

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

When a custom http_client is passed to AsyncOpenAI, calling self.client.close() does not close the custom client (the OpenAI SDK explicitly skips closing custom clients to avoid lifecycle conflicts). This leads to unclosed client/connection leaks.

To fix this, store the created client in self.http_client so that it can be closed in terminate():

    async def terminate(self):
        if self.client:
            await self.client.close()
        if hasattr(self, "http_client") and self.http_client:
            await self.http_client.aclose()
Suggested change
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient:
proxy = provider_config.get("proxy", "")
httpx_module: Any = httpx
try:
from openai import _base_client as openai_base_client
httpx_module = getattr(openai_base_client, "httpx", httpx)
except ImportError:
pass
return create_proxy_client(
"OpenAI Whisper",
proxy,
httpx_module=httpx_module,
)
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient:
proxy = provider_config.get("proxy", "")
httpx_module: Any = httpx
try:
from openai import _base_client as openai_base_client
httpx_module = getattr(openai_base_client, "httpx", httpx)
except ImportError:
pass
self.http_client = create_proxy_client(
"OpenAI Whisper",
proxy,
httpx_module=httpx_module,
)
return self.http_client

return self.http_client

async def _get_audio_format(self, file_path) -> str | None:
# 定义要检测的头部字节
silk_header = b"SILK"
Expand Down Expand Up @@ -133,3 +154,5 @@ async def get_text(self, audio_url: str) -> str:
async def terminate(self):
    """Release the network resources held by this provider.

    Closes the OpenAI SDK client first and then the HTTP client stored on
    ``self.http_client``. The HTTP client is closed inside a ``finally``
    clause so that an exception raised while closing the SDK client does not
    leave it open; that exception still propagates to the caller.
    """
    try:
        if self.client:
            await self.client.close()
    finally:
        # Always attempt this, even when the SDK client failed to close.
        if self.http_client:
            await self.http_client.aclose()
139 changes: 138 additions & 1 deletion tests/test_whisper_api_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,16 @@

import pytest

import astrbot.core.provider.sources.whisper_api_source as whisper_api_source
from astrbot.core.provider.sources.whisper_api_source import ProviderOpenAIWhisperAPI


class _FakeAsyncOpenAI:
def __init__(self, **kwargs):
self.kwargs = kwargs
self.close = AsyncMock()


def _make_provider() -> ProviderOpenAIWhisperAPI:
provider = ProviderOpenAIWhisperAPI(
provider_config={
Expand All @@ -28,6 +35,134 @@ def _make_provider() -> ProviderOpenAIWhisperAPI:
return provider


def test_provider_passes_configured_proxy_to_openai_http_client(monkeypatch):
captured: dict[str, object] = {}
fake_http_client = SimpleNamespace(aclose=AsyncMock())

def fake_create_proxy_client(
provider_label: str,
proxy: str | None = None,
headers: dict[str, str] | None = None,
verify=None,
httpx_module=None,
Comment on lines +38 to +47

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Consider adding a complementary test case for when no proxy is configured

Please also add a test where ProviderOpenAIWhisperAPI is constructed without a proxy key in provider_config (e.g. test_provider_uses_default_http_client_when_proxy_missing). That test should verify create_proxy_client is called with the correct provider label and the default proxy value (empty string or None, per the intended contract), so the no-proxy behavior remains covered during refactors.

Suggested implementation:

def test_provider_passes_configured_proxy_to_openai_http_client(monkeypatch):
    captured: dict[str, object] = {}
    fake_http_client = object()

    def fake_create_proxy_client(
        provider_label: str,
        proxy: str | None = None,
        headers: dict[str, str] | None = None,
        verify=None,
        httpx_module=None,
    ):
        captured["provider_label"] = provider_label
        captured["proxy"] = proxy
        captured["headers"] = headers
        captured["verify"] = verify
        captured["httpx_module"] = httpx_module

        return fake_http_client

    # Adjust the target string here to match where `create_proxy_client` is imported/used
    monkeypatch.setattr(
        "whisper_api_source.create_proxy_client",
        fake_create_proxy_client,
    )

    provider = ProviderOpenAIWhisperAPI(
        provider_config={
            # include whatever keys are normally required by your provider_config
            "provider_label": "openai_whisper_api",
            "proxy": "http://configured-proxy.example.com",
        }
    )

    # Trigger the code path that builds the HTTP client. Adjust this call as needed.
    http_client = provider._get_http_client()

    assert http_client is fake_http_client
    assert captured["provider_label"] == "openai_whisper_api"
    assert captured["proxy"] == "http://configured-proxy.example.com"


def test_provider_uses_default_http_client_when_proxy_missing(monkeypatch):
    captured: dict[str, object] = {}
    fake_http_client = object()

    def fake_create_proxy_client(
        provider_label: str,
        proxy: str | None = None,
        headers: dict[str, str] | None = None,
        verify=None,
        httpx_module=None,
    ):
        captured["provider_label"] = provider_label
        captured["proxy"] = proxy
        captured["headers"] = headers
        captured["verify"] = verify
        captured["httpx_module"] = httpx_module

        return fake_http_client

    # Adjust the target string here to match where `create_proxy_client` is imported/used
    monkeypatch.setattr(
        "whisper_api_source.create_proxy_client",
        fake_create_proxy_client,
    )

    # Construct the provider WITHOUT a `proxy` key in provider_config
    provider = ProviderOpenAIWhisperAPI(
        provider_config={
            # include whatever keys are normally required by your provider_config
            "provider_label": "openai_whisper_api",
            # NOTE: no "proxy" key here on purpose
        }
    )

    # Trigger the code path that builds the HTTP client. Adjust this call as needed.
    http_client = provider._get_http_client()

    assert http_client is fake_http_client
    assert captured["provider_label"] == "openai_whisper_api"
    # Depending on your intended contract, assert the default value here:
    # - if default is "", use `== ""`
    # - if default is None, use `is None`
    assert captured["proxy"] in ("", None)

To integrate this cleanly with your existing codebase, you will likely need to:

  1. Adjust the monkeypatch target: Replace "whisper_api_source.create_proxy_client" with the actual import path used in your production code (e.g. "app.providers.whisper_api_source.create_proxy_client").
  2. Align provider_config:
    • Replace "provider_label": "openai_whisper_api" with the real key/value(s) your ProviderOpenAIWhisperAPI expects (e.g. {"label": "whisper_api"} or whatever your config schema is).
    • Ensure all required config keys (API key, base URL, etc.) are included so the provider initializes correctly.
  3. Use the proper provider label assertion:
    • If create_proxy_client is called with some constant like ProviderOpenAIWhisperAPI.PROVIDER_LABEL, assert against that instead of the hardcoded "openai_whisper_api".
  4. Match the HTTP client creation API:
    • If your provider does not expose _get_http_client(), replace that call with whatever actually triggers the creation of the OpenAI HTTP client (e.g. accessing provider._openai_http_client, calling provider._ensure_client(), etc.).
  5. Proxy default contract:
    • If the intended default proxy is always None, change assert captured["proxy"] in ("", None) to assert captured["proxy"] is None.
    • If it is always an empty string, change it to assert captured["proxy"] == "".

):
captured["provider_label"] = provider_label
captured["proxy"] = proxy
captured["headers"] = headers
captured["httpx_module"] = httpx_module
return fake_http_client

monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI)
monkeypatch.setattr(
whisper_api_source,
"create_proxy_client",
fake_create_proxy_client,
)

provider = ProviderOpenAIWhisperAPI(
provider_config={
"id": "test-whisper-api",
"type": "openai_whisper_api",
"model": "whisper-1",
"api_key": "test-key",
"api_base": "https://api.example.com/v1",
"proxy": "http://127.0.0.1:7890",
"timeout": 30,
},
provider_settings={},
)

assert provider.client.kwargs["api_key"] == "test-key"
assert provider.client.kwargs["base_url"] == "https://api.example.com/v1"
assert provider.client.kwargs["timeout"] == 30
assert provider.client.kwargs["http_client"] is fake_http_client
assert set(provider.client.kwargs) == {
"api_key",
"base_url",
"timeout",
"http_client",
}
assert provider.http_client is fake_http_client
assert captured["provider_label"] == "OpenAI Whisper"
assert captured["proxy"] == "http://127.0.0.1:7890"
assert captured["headers"] is None
assert captured["httpx_module"] is not None


def test_provider_uses_default_http_client_when_proxy_missing(monkeypatch):
    """Check client construction when ``provider_config`` has no ``proxy`` key.

    The provider must still obtain its HTTP client from
    ``create_proxy_client`` (called with a ``None`` proxy), pass it to
    ``AsyncOpenAI`` and keep it on ``provider.http_client``.
    """
    recorded: dict[str, object] = {}
    stub_http_client = SimpleNamespace(aclose=AsyncMock())

    def fake_create_proxy_client(
        provider_label: str,
        proxy: str | None = None,
        headers: dict[str, str] | None = None,
        verify=None,
        httpx_module=None,
    ):
        # Record what the provider passed so the assertions below can check it.
        recorded.update(
            provider_label=provider_label,
            proxy=proxy,
            headers=headers,
            httpx_module=httpx_module,
        )
        return stub_http_client

    monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI)
    monkeypatch.setattr(
        whisper_api_source,
        "create_proxy_client",
        fake_create_proxy_client,
    )

    # NOTE: no "proxy" key on purpose.
    config = {
        "id": "test-whisper-api",
        "type": "openai_whisper_api",
        "model": "whisper-1",
        "api_key": "test-key",
    }
    provider = ProviderOpenAIWhisperAPI(
        provider_config=config,
        provider_settings={},
    )

    client_kwargs = provider.client.kwargs
    assert client_kwargs["http_client"] is stub_http_client
    assert set(client_kwargs) == {
        "api_key",
        "base_url",
        "timeout",
        "http_client",
    }
    assert provider.http_client is stub_http_client
    assert recorded["provider_label"] == "OpenAI Whisper"
    assert recorded["proxy"] is None
    assert recorded["headers"] is None


@pytest.mark.asyncio
async def test_terminate_closes_openai_client_and_custom_http_client(monkeypatch):
    """``terminate`` must await both ``client.close`` and ``http_client.aclose``."""
    stub_http_client = SimpleNamespace(aclose=AsyncMock())

    def build_stub_client(*args, **kwargs):
        # Stand-in for create_proxy_client: ignore the arguments, return the stub.
        return stub_http_client

    monkeypatch.setattr(whisper_api_source, "AsyncOpenAI", _FakeAsyncOpenAI)
    monkeypatch.setattr(
        whisper_api_source,
        "create_proxy_client",
        build_stub_client,
    )

    config = {
        "id": "test-whisper-api",
        "type": "openai_whisper_api",
        "model": "whisper-1",
        "api_key": "test-key",
    }
    provider = ProviderOpenAIWhisperAPI(
        provider_config=config,
        provider_settings={},
    )

    await provider.terminate()

    provider.client.close.assert_awaited_once()
    stub_http_client.aclose.assert_awaited_once()


@pytest.mark.asyncio
async def test_get_text_converts_opus_files_to_wav_before_transcription(
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
Expand All @@ -38,7 +173,9 @@ async def test_get_text_converts_opus_files_to_wav_before_transcription(

conversions: list[tuple[str, str]] = []

async def fake_convert_audio_to_wav(audio_path: str, output_path: str | None = None):
async def fake_convert_audio_to_wav(
audio_path: str, output_path: str | None = None
):
assert output_path is not None
conversions.append((audio_path, output_path))
Path(output_path).write_bytes(b"fake wav data")
Expand Down