Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions astrbot/core/provider/sources/openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import re

import httpx
from openai import AsyncOpenAI

Expand All @@ -8,6 +10,13 @@
from ..register import register_provider_adapter


def _normalize_api_base(api_base: str) -> str:
api_base = api_base.strip().removesuffix("/").removesuffix("/embeddings")
if api_base and not re.search(r"/v\d+$", api_base):
api_base = api_base + "/v1"
return api_base


@register_provider_adapter(
"openai_embedding",
"OpenAI API Embedding 提供商适配器",
Expand All @@ -24,15 +33,9 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
if proxy:
logger.info(f"[OpenAI Embedding] {provider_id} Using proxy: {proxy}")
http_client = httpx.AsyncClient(proxy=proxy)
api_base = (
api_base = _normalize_api_base(
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
.strip()
.removesuffix("/")
.removesuffix("/embeddings")
)
Comment on lines +36 to 38

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

If embedding_api_base is explicitly configured as None or an empty string in the provider configuration, provider_config.get("embedding_api_base", "https://api.openai.com/v1") will return None or "". This will cause _normalize_api_base to raise an AttributeError (since None has no strip method) or fail to apply the default base URL.

Using or instead of the get method's default argument ensures that we fall back to the default OpenAI API base URL if the configured value is falsy (such as None or "").

Suggested change
api_base = _normalize_api_base(
provider_config.get("embedding_api_base", "https://api.openai.com/v1")
.strip()
.removesuffix("/")
.removesuffix("/embeddings")
)
api_base = _normalize_api_base(
provider_config.get("embedding_api_base") or "https://api.openai.com/v1"
)

if api_base and not api_base.endswith("/v1") and not api_base.endswith("/v4"):
# /v4 see #5699
api_base = api_base + "/v1"
logger.info(f"[OpenAI Embedding] {provider_id} Using API Base: {api_base}")
self.client = AsyncOpenAI(
api_key=provider_config.get("embedding_api_key"),
Expand Down
18 changes: 18 additions & 0 deletions tests/test_openai_embedding_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from astrbot.core.provider.sources.openai_embedding_source import _normalize_api_base


def test_openai_embedding_api_base_keeps_version_suffixes():
assert (
_normalize_api_base("https://ark.cn-beijing.volces.com/api/plan/v3")
== "https://ark.cn-beijing.volces.com/api/plan/v3"
)
assert _normalize_api_base("https://example.test/v4") == "https://example.test/v4"


def test_openai_embedding_api_base_adds_default_version():
assert _normalize_api_base("https://example.test/openai") == (
"https://example.test/openai/v1"
)
assert _normalize_api_base("https://example.test/v1/embeddings") == (
Comment on lines +12 to +16

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Cover behavior for empty or whitespace-only bases and non-version /embeddings bases

The default-version behavior is only partially covered here. Please also add cases for:

  • "https://example.test/embeddings""https://example.test/v1", as described in the PR but not asserted.
  • Empty / whitespace-only input (e.g., "", " "), which currently returns an empty string.

These will lock in the /embeddings trimming and fallback behavior and help prevent regressions.

"https://example.test/v1"
)
Loading