diff --git a/src/coding/proxy/config/config.default.yaml b/src/coding/proxy/config/config.default.yaml index 23b9dc0..f91b73e 100644 --- a/src/coding/proxy/config/config.default.yaml +++ b/src/coding/proxy/config/config.default.yaml @@ -685,4 +685,11 @@ native_api: # tiers: ["copilot", "anthropic", "zhipu"] # # 未配置时(默认),所有 Session 使用全局 tiers 顺序。 -session_policies: [] +session_policies: + policies: [] + # 标题前缀 → 供应商自动绑定。 + # 当 Session 标题以指定前缀开头时,自动将该 Session 绑定到对应供应商。 + # 匹配规则按列表顺序求值,首次匹配生效。 + title_vendor_bindings: + - prefix: "# 目标" + vendor: "zhipu" diff --git a/src/coding/proxy/config/schema.py b/src/coding/proxy/config/schema.py index 40e5428..ec8ac5f 100644 --- a/src/coding/proxy/config/schema.py +++ b/src/coding/proxy/config/schema.py @@ -44,7 +44,7 @@ # ── 子模块 re-export ──────────────────────────────────────────── from .server import DatabaseConfig, LoggingConfig, ServerConfig # noqa: F401 -from .session_policy import SessionPoliciesConfig # noqa: F401 +from .session_policy import SessionPoliciesConfig, TitleVendorBinding # noqa: F401 from .vendors import ( # noqa: F401 AlibabaConfig, AnthropicConfig, @@ -350,4 +350,5 @@ def compat_state_path(self) -> Path: "NativeApiConfig", # session policy "SessionPoliciesConfig", + "TitleVendorBinding", ] diff --git a/src/coding/proxy/config/session_policy.py b/src/coding/proxy/config/session_policy.py index cb2c512..c25da7a 100644 --- a/src/coding/proxy/config/session_policy.py +++ b/src/coding/proxy/config/session_policy.py @@ -50,6 +50,22 @@ class SessionPolicy(BaseModel): ) +class TitleVendorBinding(BaseModel): + """标题前缀 → 供应商自动绑定规则.""" + + prefix: str = Field( + min_length=1, + description=( + "标题前缀匹配模式(大小写敏感的 startswith 匹配)。" + "禁止空字符串——空前缀会匹配所有标题,导致全量误绑定。" + ), + ) + vendor: str = Field( + min_length=1, + description="匹配后绑定的目标供应商名称", + ) + + class SessionPoliciesConfig(BaseModel): """顶层 Session 策略配置容器.""" @@ -57,3 +73,11 @@ class SessionPoliciesConfig(BaseModel): default_factory=list, description="Session 路由策略列表,按定义顺序求值,首次匹配生效", ) + title_vendor_bindings: list[TitleVendorBinding] = Field( + default_factory=list, + description=( + "标题前缀 → 供应商自动绑定规则。" + "当 Session 标题以指定前缀开头时,自动绑定到对应供应商。" + "匹配规则按列表顺序求值,首次匹配生效。" + ), + ) diff --git a/src/coding/proxy/routing/executor.py b/src/coding/proxy/routing/executor.py index 20d9c51..9ade9dd 100644 --- a/src/coding/proxy/routing/executor.py +++ b/src/coding/proxy/routing/executor.py @@ -11,10 +11,13 @@ import re import time from collections.abc import AsyncIterator -from typing import Any +from typing import TYPE_CHECKING, Any import httpx +if TYPE_CHECKING: + from ..config.session_policy import TitleVendorBinding + from ..vendors.base import ( NoCompatibleVendorError, RequestCapabilities, @@ -610,6 +613,7 @@ def __init__( session_manager: RouteSessionManager, reauth_coordinator: Any | None = None, session_policy_resolver: SessionPolicyResolver | None = None, + title_vendor_bindings: list[TitleVendorBinding] | None = None, ) -> None: self._router = router self._tiers = tiers @@ -617,6 +621,8 @@ def __init__( self._session_mgr = session_manager self._reauth_coordinator = reauth_coordinator self._policy_resolver = session_policy_resolver or SessionPolicyResolver() + self._title_vendor_bindings = title_vendor_bindings or [] + self._validate_title_vendor_bindings() # Tier 名称 → OAuth provider 名称的映射 self._tier_provider_map: dict[str, str] = { @@ -624,6 +630,26 @@ def __init__( "antigravity": "google", } + def _validate_title_vendor_bindings(self) -> None: + """启动期校验标题绑定引用的 vendor 均存在,缺失则告警. + + 与手动绑定 API(拒绝未知 vendor)的语义对齐:此处不硬失败, + 仅记录警告——避免单条误配置阻断整个代理启动;运行时 + `_resolve_effective_tiers` 会静默跳过未知 vendor 回退默认顺序。 + """ + if not self._title_vendor_bindings: + return + valid = {t.name for t in self._tiers} + for binding in self._title_vendor_bindings: + if binding.vendor not in valid: + logger.warning( + "title_vendor_bindings 引用了未知 vendor %r(前缀 %r);" + "可用 vendor: %s。该绑定将在运行时被静默跳过。", + binding.vendor, + binding.prefix, + sorted(valid), + ) + # ── 公开执行入口 ────────────────────────────────────── def _resolve_effective_tiers(self, session_key: str) -> list[VendorTier]: @@ -650,6 +676,27 @@ def _resolve_effective_tiers(self, session_key: str) -> list[VendorTier]: seen.add(tier.name) return ordered + def _apply_title_based_policy(self, session_key: str, title: str) -> None: + """根据 Session 标题前缀自动绑定供应商. + + 当标题以预配置的前缀开头时,通过 SessionPolicyResolver.upsert() + 将该 Session 绑定到指定供应商,后续请求无需再走默认路由。 + + 仅在新 Session 首次提取标题时调用,避免覆盖手动绑定的策略。 + """ + if not title or not self._title_vendor_bindings: + return + for binding in self._title_vendor_bindings: + if title.startswith(binding.prefix): + self._policy_resolver.upsert(session_key, [binding.vendor]) + logger.info( + "Session title prefix %r matched → auto-bind to %s (session=%s)", + binding.prefix, + binding.vendor, + session_key[:12], + ) + return + def _prepare_body_for_tier( self, body: dict[str, Any], @@ -748,6 +795,7 @@ async def execute_stream( await self._recorder.set_session_title( canonical_request.session_key, title ) + self._apply_title_based_policy(canonical_request.session_key, title) else: # 延迟标题补写: 若 session 尚无标题,尝试从当前请求中提取并回写。 title = _extract_session_title(canonical_request) @@ -934,6 +982,7 @@ async def execute_message( await self._recorder.set_session_title( canonical_request.session_key, title ) + self._apply_title_based_policy(canonical_request.session_key, title) else: # 延迟标题补写: 若 session 尚无标题,尝试从当前请求中提取并回写。 title = _extract_session_title(canonical_request) diff --git a/src/coding/proxy/routing/router.py b/src/coding/proxy/routing/router.py index 32757a8..a9b5379 100644 --- a/src/coding/proxy/routing/router.py +++ b/src/coding/proxy/routing/router.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: + from ..config.session_policy import TitleVendorBinding from ..pricing import PricingTable from .executor import _RouteExecutor @@ -38,6 +39,7 @@ def __init__( reauth_coordinator: Any | None = None, compat_session_store: CompatSessionStore | None = None, session_policy_resolver: SessionPolicyResolver | None = None, + title_vendor_bindings: list[TitleVendorBinding] | None = None, ) -> None: if not tiers: raise ValueError("至少需要一个供应商层级") @@ -56,6 +58,7 @@ def __init__( session_manager=self._session_mgr, reauth_coordinator=reauth_coordinator, session_policy_resolver=session_policy_resolver, + title_vendor_bindings=title_vendor_bindings, ) def set_pricing_table(self, table: PricingTable) -> None: diff --git a/src/coding/proxy/server/app.py b/src/coding/proxy/server/app.py index 5ce8011..396b89f 100644 --- a/src/coding/proxy/server/app.py +++ b/src/coding/proxy/server/app.py @@ -161,6 +161,7 @@ def create_app(config: ProxyConfig | None = None) -> FastAPI: reauth_coordinator, compat_session_store, session_policy_resolver=SessionPolicyResolver(config.session_policies.policies), + title_vendor_bindings=config.session_policies.title_vendor_bindings, ) app = FastAPI(title="coding-proxy", version=__version__, lifespan=lifespan) diff --git a/tests/test_router_executor.py b/tests/test_router_executor.py index d8455cd..2e5c13e 100644 --- a/tests/test_router_executor.py +++ b/tests/test_router_executor.py @@ -19,6 +19,7 @@ CompatibilityStatus, build_canonical_request, ) +from coding.proxy.config.session_policy import TitleVendorBinding from coding.proxy.routing.executor import ( _FALLBACK_TITLE_MAX_LEN, _SESSION_TITLE_MAX_LEN, @@ -36,6 +37,7 @@ _sanitize_user_text, ) from coding.proxy.routing.session_manager import RouteSessionManager +from coding.proxy.routing.session_policy import SessionPolicyResolver from coding.proxy.routing.tier import VendorTier from coding.proxy.routing.usage_recorder import UsageRecorder from coding.proxy.vendors.base import ( @@ -133,6 +135,19 @@ def _executor(tiers: list[VendorTier] | None = None, **kwargs) -> _RouteExecutor ) +def _stub_session_manager(is_new: bool = True) -> MagicMock: + """构造返回指定 is_new 的 session manager stub. + + 默认 RouteSessionManager(无 store) 的 get_or_create_record 恒返回 + is_new=False;测试新 session 路径需显式 stub 返回 is_new=True。 + """ + mgr = MagicMock(spec=RouteSessionManager) + mgr.get_or_create_record = AsyncMock(return_value=(None, is_new)) + mgr.apply_compat_context = MagicMock() + mgr.persist_session = AsyncMock() + return mgr + + # ── _VENDOR_PROTOCOL_LABEL_MAP ─────────────────────────── @@ -2656,3 +2671,278 @@ def test_fallback_cascade_full(self): ] req = self._build_request(messages) assert _extract_session_title(req) == "[Session] test-model" + + +class TestApplyTitleBasedPolicy: + """``_apply_title_based_policy`` 标题前缀自动绑定测试.""" + + def test_prefix_match_triggers_upsert(self): + """标题以配置前缀开头 → 触发 upsert 绑定到目标 vendor.""" + resolver = SessionPolicyResolver() + executor = _executor( + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + executor._apply_title_based_policy("sess-1", "# 目标 (Goal) 实现功能 X") + policy = resolver.resolve("sess-1") + assert policy is not None + assert policy.tiers == ["zhipu"] + assert policy.name == "runtime:sess-1" + + def test_prefix_match_without_parenthesis(self): + """前缀匹配不要求括号后缀,纯 '# 目标' 开头即命中.""" + resolver = SessionPolicyResolver() + executor = _executor( + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + executor._apply_title_based_policy("sess-2", "# 目标 详细计划") + policy = resolver.resolve("sess-2") + assert policy is not None + assert policy.tiers == ["zhipu"] + + def test_non_matching_title_no_binding(self): + """非匹配标题 → 不创建绑定.""" + resolver = SessionPolicyResolver() + executor = _executor( + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + executor._apply_title_based_policy("sess-3", "普通会话标题") + assert resolver.resolve("sess-3") is None + + def test_empty_title_no_binding(self): + """空标题 → 提前返回,不创建绑定.""" + resolver = SessionPolicyResolver() + executor = _executor( + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + executor._apply_title_based_policy("sess-4", "") + assert resolver.resolve("sess-4") is None + + def test_no_bindings_configured_no_binding(self): + """未配置任何绑定规则 → 提前返回,等效禁用.""" + resolver = SessionPolicyResolver() + executor = _executor( + session_policy_resolver=resolver, + title_vendor_bindings=[], + ) + executor._apply_title_based_policy("sess-5", "# 目标 任意标题") + assert resolver.resolve("sess-5") is None + + def test_prefix_in_middle_no_match(self): + """前缀出现在标题中间 → startswith 不匹配,不绑定.""" + resolver = SessionPolicyResolver() + executor = _executor( + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + executor._apply_title_based_policy("sess-6", "前缀 # 目标 在中间") + assert resolver.resolve("sess-6") is None + + def test_multiple_bindings_first_match_wins(self): + """多条规则按顺序匹配,首次命中生效.""" + resolver = SessionPolicyResolver() + executor = _executor( + session_policy_resolver=resolver, + title_vendor_bindings=[ + TitleVendorBinding(prefix="# 目标", vendor="zhipu"), + TitleVendorBinding(prefix="# Review", vendor="anthropic"), + ], + ) + executor._apply_title_based_policy("sess-7", "# Review 代码审查") + policy = resolver.resolve("sess-7") + assert policy is not None + assert policy.tiers == ["anthropic"] + + def test_bound_tier_promoted_to_front(self): + """绑定后 _resolve_effective_tiers 将目标 vendor 提升至首位.""" + resolver = SessionPolicyResolver() + tiers = [ + _make_tier(_mock_vendor("anthropic")), + _make_tier(_mock_vendor("zhipu")), + ] + executor = _executor( + tiers=tiers, + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + executor._apply_title_based_policy("sess-8", "# 目标 实现 X") + effective = executor._resolve_effective_tiers("sess-8") + assert effective[0].name == "zhipu" + # 未提及的 vendor 仍保留在末尾 + assert {t.name for t in effective} == {"zhipu", "anthropic"} + + def test_non_matching_session_uses_default_order(self): + """非匹配 session 的 tier 顺序保持全局默认.""" + resolver = SessionPolicyResolver() + tiers = [ + _make_tier(_mock_vendor("anthropic")), + _make_tier(_mock_vendor("zhipu")), + ] + executor = _executor( + tiers=tiers, + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + executor._apply_title_based_policy("sess-9", "普通标题") + effective = executor._resolve_effective_tiers("sess-9") + assert [t.name for t in effective] == ["anthropic", "zhipu"] + + def test_nonexistent_vendor_skipped_in_resolution(self): + """绑定不存在的 vendor → upsert 成功但 tier 解析跳过该 vendor.""" + resolver = SessionPolicyResolver() + tiers = [ + _make_tier(_mock_vendor("anthropic")), + _make_tier(_mock_vendor("zhipu")), + ] + executor = _executor( + tiers=tiers, + session_policy_resolver=resolver, + title_vendor_bindings=[ + TitleVendorBinding(prefix="# 目标", vendor="nonexistent") + ], + ) + executor._apply_title_based_policy("sess-10", "# 目标 X") + effective = executor._resolve_effective_tiers("sess-10") + # 不存在的 vendor 被跳过,回退到全局默认顺序 + assert [t.name for t in effective] == ["anthropic", "zhipu"] + + @pytest.mark.asyncio + async def test_execute_message_end_to_end_binding(self): + """端到端: 新 session 首请求标题命中前缀 → 创建绑定并路由到 zhipu.""" + resolver = SessionPolicyResolver() + tiers = [ + _make_tier(_mock_vendor("anthropic")), + _make_tier(_mock_vendor("zhipu")), + ] + executor = _executor( + tiers=tiers, + session_mgr=_stub_session_manager(is_new=True), + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + body = { + "model": "test", + "metadata": {"user_id": "session-abc"}, + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "# 目标 实现 X"}]} + ], + } + resp = await executor.execute_message(body, {}) + assert resp.status_code == 200 + # 从 body 解析出的 session_key 应已建立运行时绑定 + canonical = build_canonical_request(body, {}) + policy = resolver.resolve(canonical.session_key) + assert policy is not None + assert policy.tiers == ["zhipu"] + + @pytest.mark.asyncio + async def test_execute_message_existing_session_no_binding(self): + """端到端: 已存在 session(is_new=False) 不触发标题绑定.""" + resolver = SessionPolicyResolver() + tiers = [ + _make_tier(_mock_vendor("anthropic")), + _make_tier(_mock_vendor("zhipu")), + ] + executor = _executor( + tiers=tiers, + session_mgr=_stub_session_manager(is_new=False), + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + body = { + "model": "test", + "metadata": {"user_id": "session-existing"}, + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "# 目标 实现 X"}]} + ], + } + await executor.execute_message(body, {}) + canonical = build_canonical_request(body, {}) + # is_new=False → 不调用 _apply_title_based_policy,无运行时绑定 + assert resolver.resolve(canonical.session_key) is None + + @pytest.mark.asyncio + async def test_execute_stream_end_to_end_binding(self): + """端到端(流式): 新 session 首请求标题命中前缀 → 创建绑定.""" + resolver = SessionPolicyResolver() + zhipu_vendor = _mock_vendor("zhipu") + zhipu_vendor.send_message_stream = MagicMock( + return_value=_async_chunks([b'{"type":"message_stop"}']) + ) + tiers = [ + _make_tier(_mock_vendor("anthropic")), + _make_tier(zhipu_vendor), + ] + executor = _executor( + tiers=tiers, + session_mgr=_stub_session_manager(is_new=True), + session_policy_resolver=resolver, + title_vendor_bindings=[TitleVendorBinding(prefix="# 目标", vendor="zhipu")], + ) + body = { + "model": "test", + "metadata": {"user_id": "session-stream"}, + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "# 目标 流式任务"}], + } + ], + } + chunks = [chunk async for chunk, _ in executor.execute_stream(body, {})] + assert chunks # 有数据返回 + canonical = build_canonical_request(body, {}) + policy = resolver.resolve(canonical.session_key) + assert policy is not None + assert policy.tiers == ["zhipu"] + + def test_empty_prefix_rejected_by_validation(self): + """空 prefix 在模型校验阶段即被拒绝,杜绝全量误绑定.""" + import pydantic + + with pytest.raises(pydantic.ValidationError): + TitleVendorBinding(prefix="", vendor="zhipu") + + def test_empty_vendor_rejected_by_validation(self): + """空 vendor 在模型校验阶段即被拒绝.""" + import pydantic + + with pytest.raises(pydantic.ValidationError): + TitleVendorBinding(prefix="# 目标", vendor="") + + def test_unknown_vendor_warns_at_startup(self, caplog): + """构造时引用未知 vendor → 记录启动告警.""" + import logging as _logging + + tiers = [_make_tier(_mock_vendor("anthropic"))] + with caplog.at_level(_logging.WARNING, logger="coding.proxy.routing.executor"): + _executor( + tiers=tiers, + session_policy_resolver=SessionPolicyResolver(), + title_vendor_bindings=[ + TitleVendorBinding(prefix="# 目标", vendor="nonexistent") + ], + ) + warnings = [r for r in caplog.records if r.levelno == _logging.WARNING] + assert any("nonexistent" in r.message for r in warnings) + + def test_known_vendor_no_startup_warning(self, caplog): + """构造时引用合法 vendor → 不产生告警.""" + import logging as _logging + + tiers = [_make_tier(_mock_vendor("zhipu"))] + with caplog.at_level(_logging.WARNING, logger="coding.proxy.routing.executor"): + _executor( + tiers=tiers, + session_policy_resolver=SessionPolicyResolver(), + title_vendor_bindings=[ + TitleVendorBinding(prefix="# 目标", vendor="zhipu") + ], + ) + binding_warnings = [ + r for r in caplog.records if "title_vendor_bindings" in r.message + ] + assert not binding_warnings