diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 2be0165bb3..126a41fdaf 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -1,3 +1,5 @@ +import re + import httpx from openai import AsyncOpenAI @@ -8,6 +10,13 @@ from ..register import register_provider_adapter +def _normalize_api_base(api_base: str) -> str: + api_base = api_base.strip().removesuffix("/").removesuffix("/embeddings") + if api_base and not re.search(r"/v\d+$", api_base): + api_base = api_base + "/v1" + return api_base + + @register_provider_adapter( "openai_embedding", "OpenAI API Embedding 提供商适配器", @@ -24,15 +33,9 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None: if proxy: logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}") http_client = httpx.AsyncClient(proxy=proxy) - api_base = ( + api_base = _normalize_api_base( provider_config.get("embedding_api_base", "https://api.openai.com/v1") - .strip() - .removesuffix("/") - .removesuffix("/embeddings") ) - if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"): - # /v4 see #5699 - api_base = api_base + "/v1" logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}") self.client = AsyncOpenAI( api_key=provider_config.get("embedding_api_key"), diff --git a/tests/test_openai_embedding_source.py b/tests/test_openai_embedding_source.py new file mode 100644 index 0000000000..39a38d9455 --- /dev/null +++ b/tests/test_openai_embedding_source.py @@ -0,0 +1,18 @@ +from astrbot.core.provider.sources.openai_embedding_source import _normalize_api_base + + +def test_openai_embedding_api_base_keeps_version_suffixes(): + assert ( + _normalize_api_base("https://ark.cn-beijing.volces.com/api/plan/v3") + == "https://ark.cn-beijing.volces.com/api/plan/v3" + ) + assert _normalize_api_base("https://example.test/v4") == "https://example.test/v4" + + +def test_openai_embedding_api_base_adds_default_version(): + assert _normalize_api_base("https://example.test/openai") == ( + "https://example.test/openai/v1" + ) + assert _normalize_api_base("https://example.test/v1/embeddings") == ( + "https://example.test/v1" + )