diff --git a/CHANGELOG.md b/CHANGELOG.md index f745ec1..086302b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ ## [Unreleased] +- feat(dashboard): Model Calling 实时监控扩展至全 vendor / 全 model(仅 CC 场景),其他 vendor 在 monitor 模式下仅计数不限流,Zhipu 保留 limited 模式 + FIFO 排队; +- feat(concurrency): 新增 `peak_pending_recent` 最近 10s 排队峰值追踪,瞬时排队释放后前端仍可见"曾排队 N" 余晖徽章; +- perf(dashboard): Model Calling 轮询间隔由 5000ms 缩短至 1500ms,提升瞬时排队可观测性; +- refactor(vendors): `ModelConcurrencyLimiter` 重构为 `ModelConcurrencyController`,统一 monitor / limited 双模式抽象(保留旧名别名);并发控制由 vendor 内部迁移至 executor 层 `track_in_flight` 包裹,行为对所有 vendor 一致; + ## [v0.5.0](https://github.com/ThreeFish-AI/coding-proxy/releases/tag/v0.5.0) - 2026-05-27 > [!IMPORTANT] diff --git a/src/coding/proxy/routing/executor.py b/src/coding/proxy/routing/executor.py index 4c37f02..df63e74 100644 --- a/src/coding/proxy/routing/executor.py +++ b/src/coding/proxy/routing/executor.py @@ -689,15 +689,17 @@ async def execute_stream( tier.name, failed_tier_name, session_record, body ) body_for_tier = self._prepare_body_for_tier(body, tier, source_vendor) - async for chunk in tier.vendor.send_message_stream( - body_for_tier, headers - ): - parse_usage_from_chunk( - chunk, - usage, - vendor_label=_VENDOR_PROTOCOL_LABEL_MAP.get(tier.name), - ) - yield chunk, tier.name + _mapped_model = tier.vendor.map_model(body.get("model", "")) + async with tier.vendor.track_in_flight(_mapped_model): + async for chunk in tier.vendor.send_message_stream( + body_for_tier, headers + ): + parse_usage_from_chunk( + chunk, + usage, + vendor_label=_VENDOR_PROTOCOL_LABEL_MAP.get(tier.name), + ) + yield chunk, tier.name info = self._recorder.build_usage_info(usage) if has_missing_input_usage_signals(info): @@ -863,7 +865,9 @@ async def execute_message( tier.name, failed_tier_name, session_record, body ) body_for_tier = self._prepare_body_for_tier(body, tier, source_vendor) - resp = await tier.vendor.send_message(body_for_tier, headers) + _mapped_model = tier.vendor.map_model(body.get("model", "")) + async with tier.vendor.track_in_flight(_mapped_model): + resp = await tier.vendor.send_message(body_for_tier, headers) if resp.status_code < 400: duration = int((time.monotonic() - start) * 1000) diff --git a/src/coding/proxy/server/dashboard.py b/src/coding/proxy/server/dashboard.py index 75dd812..156c098 100644 --- a/src/coding/proxy/server/dashboard.py +++ b/src/coding/proxy/server/dashboard.py @@ -629,6 +629,10 @@ def _build_favicon() -> bytes: background: rgba(251,146,60,.15); color: #fb923c; } + .mc-badge-peak { + background: rgba(148,163,184,.12); + color: #94a3b8; + } .mc-badge-active { background: rgba(74,222,128,.12); color: #4ade80; @@ -1282,10 +1286,12 @@ def _build_favicon() -> bytes: models.push({ vendor: tier.name, model: model, - limit: d.limit || 0, + mode: d.mode || 'limited', + limit: d.limit, in_use: d.in_use || 0, - available: d.available || 0, + available: d.available, pending: d.pending || 0, + peak_pending_recent: d.peak_pending_recent || 0, }); } } @@ -1298,18 +1304,33 @@ def _build_favicon() -> bytes: var html = '
'; for (var k = 0; k < models.length; k++) { var m = models[k]; - var pct = m.limit > 0 ? Math.round((m.in_use / m.limit) * 100) : 0; - var barClass = pct <= 50 ? 'mc-low' : (pct <= 80 ? 'mc-mid' : 'mc-high'); - - html += '
' - + '' + escapeHtml(m.vendor + '/' + m.model) + '' - + '
' - + '
' - + '' + m.in_use - + '/' + m.limit + '' - + (m.pending > 0 ? '⏳ ' + m.pending + '' : '') - + '
' - + '
'; + + if (m.mode === 'monitor') { + // monitor 模式:纯计数徽章,无 limit/进度条 + html += '
' + + '' + escapeHtml(m.vendor + '/' + m.model) + '' + + '
' + + '
' + + '' + m.in_use + '' + + '
' + + '
'; + } else { + // limited 模式:保留现有渲染(进度条 + limit 编辑) + var limit = m.limit || 0; + var pct = limit > 0 ? Math.round((m.in_use / limit) * 100) : 0; + var barClass = pct <= 50 ? 'mc-low' : (pct <= 80 ? 'mc-mid' : 'mc-high'); + + html += '
' + + '' + escapeHtml(m.vendor + '/' + m.model) + '' + + '
' + + '
' + + '' + m.in_use + + '/' + limit + '' + + (m.pending > 0 ? '⏳ ' + m.pending + '' : '') + + (m.pending === 0 && m.peak_pending_recent > 0 ? '🕘 曾排队 ' + m.peak_pending_recent + '' : '') + + '
' + + '
'; + } } html += '
'; wrap.innerHTML = html; @@ -1325,7 +1346,7 @@ def _build_favicon() -> bytes: }).catch(function() {}); } tick(); - _mcTimer = setInterval(tick, 5000); + _mcTimer = setInterval(tick, 1500); } function stopModelCallingPoll() { if (_mcTimer) { clearInterval(_mcTimer); _mcTimer = null; } diff --git a/src/coding/proxy/server/routes.py b/src/coding/proxy/server/routes.py index 7c13d2f..5ed90dc 100644 --- a/src/coding/proxy/server/routes.py +++ b/src/coding/proxy/server/routes.py @@ -254,16 +254,15 @@ async def update_concurrency(request: Request) -> Response: for tier in router.tiers: if tier.name == tier_name: vendor = tier.vendor - update_fn = getattr(vendor, "update_concurrency", None) - if update_fn is None: + try: + vendor.update_concurrency(model, limit) + except ValueError as exc: return json_error_response( - 400, + 422, error_type="invalid_request_error", - message=f"vendor '{tier_name}' does not support concurrency", + message=str(exc), ) - try: - update_fn(model, limit) - except (ValueError, AttributeError) as exc: + except AttributeError as exc: return json_error_response( 400, error_type="invalid_request_error", message=str(exc) ) diff --git a/src/coding/proxy/vendors/base.py b/src/coding/proxy/vendors/base.py index 0c08248..d1434bc 100644 --- a/src/coding/proxy/vendors/base.py +++ b/src/coding/proxy/vendors/base.py @@ -44,6 +44,7 @@ ) from ..compat.session_store import CompatSessionRecord from ..config.schema import FailoverConfig +from .concurrency import ModelConcurrencyController logger = logging.getLogger(__name__) @@ -63,6 +64,8 @@ def __init__( self._client: httpx.AsyncClient | None = None self._compat_trace: CompatibilityTrace | None = None self._compat_session_record: CompatSessionRecord | None = None + # 默认 monitor 模式(仅计数不限流);子类可覆盖为 limited 模式 + self._concurrency_controller = ModelConcurrencyController(None) def _get_client(self) -> httpx.AsyncClient: if self._client is None or self._client.is_closed: @@ -246,8 +249,30 @@ def get_diagnostics(self) -> dict[str, Any]: diagnostics: dict[str, Any] = {} if self._compat_trace is not None: diagnostics["compat"] = self._compat_trace.to_dict() + concurrency = self._concurrency_controller.get_diagnostics() + if concurrency: + diagnostics["concurrency"] = concurrency return diagnostics + def track_in_flight(self, mapped_model: str): + """返回用于追踪在途请求的异步上下文管理器. + + 空 model name 时返回 no-op context(防御性处理)。 + """ + if not mapped_model: + from contextlib import nullcontext + + return nullcontext() + return self._concurrency_controller.track(mapped_model) + + def update_concurrency(self, model: str, limit: int) -> None: + """运行时更新指定模型的并发限制. + + 默认实现委托给 ``_concurrency_controller.set_limit``。 + monitor 模式下抛 ``ValueError``。 + """ + self._concurrency_controller.set_limit(model, limit) + def should_trigger_failover( self, status_code: int, body: dict[str, Any] | None ) -> bool: diff --git a/src/coding/proxy/vendors/concurrency.py b/src/coding/proxy/vendors/concurrency.py index 7944bdd..a5e5763 100644 --- a/src/coding/proxy/vendors/concurrency.py +++ b/src/coding/proxy/vendors/concurrency.py @@ -1,78 +1,122 @@ -"""每模型并发限制器 — 支持运行时动态调整的公平排队. +"""统一并发控制器 — 支持监控 (monitor) 与限流 (limited) 双模式. -为每个映射后的模型(如 ``glm-5v-turbo``)独立维护一个 ``_ConcurrencySlot`, -确保同一时间点该模型的并行请求数不超过配置的上限。当所有槽位被占满时, -新请求按 FIFO 顺序排队等待,直到有槽位释放。 +为每个映射后的模型(如 ``glm-5v-turbo``)独立维护一个 ``_ConcurrencySlot``, +根据模式提供不同语义: + + **monitor 模式** (config=None) + - 仅计数 ``in_use``,不做排队与限流 + - ``pending`` 恒为 0,``available`` / ``limit`` 为 None + - 所有 vendor 默认使用此模式 + + **limited 模式** (config 非 None) + - ``in_use`` 不超过 ``limit`` 时立即获取,超限时 FIFO 排队 + - ``pending`` 反映当前排队数,``peak_pending_recent`` 记录最近 10s 峰值 + - 由 ZhipuVendor 等需限流的 vendor 启用 设计要点: - **惰性创建**:仅在首次请求到达时才为该模型创建 Slot,避免冷启动开销 - - **FIFO 公平**:``asyncio.Event`` + while 循环天然满足 FIFO 排队语义 + - **FIFO 公平**:``asyncio.Event`` + while 循环天然满足 FIFO 排队语义(limited 模式) - **动态调整**:支持运行时修改 per-model limit,无需重启进程 - **按映射后模型名键控**:与上游真实承载能力对齐,而非按客户端请求名 + - **峰值余晖**:记录 ``peak_pending_recent`` 使瞬时排队可观测 """ from __future__ import annotations import asyncio import logging +import time +from collections import deque +from contextlib import asynccontextmanager +from typing import Any, Literal from ..config.vendors import ZhipuConcurrencyConfig logger = logging.getLogger(__name__) +# peak_pending_recent 滑窗宽度(秒) +_PEAK_WINDOW_SECONDS = 10.0 + class _ConcurrencySlot: - """支持动态 limit 的并发槽位. + """支持双模式的并发槽位. - 使用 ``asyncio.Event`` 作为等待/通知原语,在 ``acquire`` 中 await 等待, - 在 ``release`` / ``set_limit`` 中唤醒。``set_limit`` 修改上限后立即唤醒 - 所有等待者,由它们重新判断是否可获得槽位。 + ``limit=None`` (monitor) 时 acquire 走 fast path,仅计数。 + ``limit>0`` (limited) 时在满槽位后 FIFO 排队等待。 """ - def __init__(self, limit: int) -> None: + def __init__(self, limit: int | None) -> None: self._limit = limit self._in_use: int = 0 self._pending: int = 0 self._wake = asyncio.Event() self._wake.set() + # peak_pending_recent 追踪:存储 (timestamp, pending_value) 元组 + self._peak_samples: deque[tuple[float, int]] = deque() - async def acquire(self) -> _ConcurrencySlot: - """获取一个并发槽位,必要时阻塞排队. + async def acquire(self) -> None: + """获取一个并发槽位. - 返回 ``self``,调用方在请求完成后调用 ``release()``。 + monitor 模式 (limit=None):仅 in_use++,永不排队。 + limited 模式 (limit>0):满槽时阻塞等待。 """ - # Fast path + # monitor 模式:仅计数 + if self._limit is None: + self._in_use += 1 + return + + # limited — fast path if self._in_use < self._limit: self._in_use += 1 - return self - # Slow path — 等待槽位释放 + return + + # limited — slow path: FIFO 排队 self._pending += 1 + self._observe_peak() try: while True: self._wake.clear() await self._wake.wait() if self._in_use < self._limit: self._in_use += 1 - return self + return finally: self._pending -= 1 def release(self) -> None: """释放一个并发槽位.""" self._in_use = max(0, self._in_use - 1) - self._wake.set() + if self._limit is not None: + self._wake.set() def set_limit(self, new_limit: int) -> None: """动态调整并发上限. - 增大 limit 时立即唤醒等待者;缩小时已持有的槽位不受影响, - 新 limit 在后续 acquire 中自然生效。 + 仅 limited 模式有效;monitor 模式调用抛 ValueError。 """ + if self._limit is None: + msg = "Cannot set limit on monitor-only slot" + raise ValueError(msg) self._limit = new_limit self._wake.set() + def _observe_peak(self) -> None: + """记录当前 pending 值作为峰值采样点.""" + now = time.monotonic() + self._peak_samples.append((now, self._pending)) + + def _get_peak_pending_recent(self) -> int: + """获取最近窗口内的 peak pending 值.""" + cutoff = time.monotonic() - _PEAK_WINDOW_SECONDS + # 剔除过期采样 + while self._peak_samples and self._peak_samples[0][0] < cutoff: + self._peak_samples.popleft() + if not self._peak_samples: + return 0 + return max(v for _, v in self._peak_samples) + @property - def limit(self) -> int: + def limit(self) -> int | None: return self._limit @property @@ -80,59 +124,93 @@ def in_use(self) -> int: return self._in_use @property - def available(self) -> int: + def available(self) -> int | None: + if self._limit is None: + return None return max(0, self._limit - self._in_use) @property def pending(self) -> int: return self._pending + @property + def peak_pending_recent(self) -> int: + return self._get_peak_pending_recent() -class ModelConcurrencyLimiter: - """按模型名提供独立并发槽位的限制器. + +class ModelConcurrencyController: + """按模型名提供独立并发槽位的控制器. 用法:: - limiter = ModelConcurrencyLimiter(config) - slot = await limiter.acquire("glm-5v-turbo") - try: + # monitor 模式(默认) + ctrl = ModelConcurrencyController(None) + async with ctrl.track("model-a"): ... # 执行请求 - finally: - slot.release() + + # limited 模式(Zhipu 等) + ctrl = ModelConcurrencyController(config) + async with ctrl.track("glm-5v-turbo"): + ... # 满槽时排队等待 """ - def __init__(self, config: ZhipuConcurrencyConfig) -> None: + def __init__(self, config: ZhipuConcurrencyConfig | None) -> None: self._config = config self._slots: dict[str, _ConcurrencySlot] = {} + @property + def mode(self) -> Literal["monitor", "limited"]: + """当前控制器模式.""" + return "limited" if self._config is not None else "monitor" + def _get_or_create_slot(self, model: str) -> _ConcurrencySlot: """获取(或惰性创建)指定模型的并发槽位.""" slot = self._slots.get(model) if slot is None: - limit = self._config.get_limit(model) + if self._config is not None: + limit = self._config.get_limit(model) + else: + limit = None slot = _ConcurrencySlot(limit) self._slots[model] = slot - logger.debug( - "ModelConcurrencyLimiter: created slot model=%s limit=%d", - model, - limit, - ) + if self._config is not None: + logger.debug( + "ModelConcurrencyController: created slot mode=limited " + "model=%s limit=%d", + model, + limit, + ) + else: + logger.debug( + "ModelConcurrencyController: created slot mode=monitor model=%s", + model, + ) return slot - async def acquire(self, model: str) -> _ConcurrencySlot: - """获取指定模型的并发槽位,必要时阻塞排队. + @asynccontextmanager + async def track(self, model: str): + """异步上下文管理器:获取 → 执行 → 释放. + + 用法:: - 返回已获取的 Slot 实例,调用方负责在请求完成后调用 ``release()``。 + async with controller.track("glm-5v-turbo"): + await vendor.send_message(...) """ slot = self._get_or_create_slot(model) await slot.acquire() - return slot + try: + yield + finally: + slot.release() def set_limit(self, model: str, new_limit: int) -> None: """运行时修改指定模型的并发上限. - 同时更新 config.models 以确保后续惰性创建使用新值。 + 仅 limited 模式支持;monitor 模式抛 ValueError。 """ + if self._config is None: + msg = f"vendor is monitor-only; cannot update limit for model '{model}'" + raise ValueError(msg) slot = self._slots.get(model) if slot is None: slot = _ConcurrencySlot(new_limit) @@ -141,22 +219,33 @@ def set_limit(self, model: str, new_limit: int) -> None: slot.set_limit(new_limit) self._config.models[model] = new_limit logger.info( - "ModelConcurrencyLimiter: updated limit model=%s new_limit=%d", + "ModelConcurrencyController: updated limit model=%s new_limit=%d", model, new_limit, ) - def get_diagnostics(self) -> dict[str, dict[str, int]]: + def get_diagnostics(self) -> dict[str, dict[str, Any]]: """返回每个模型的并发状态快照(用于可观测性).""" - snapshot: dict[str, dict[str, int]] = {} + snapshot: dict[str, dict[str, Any]] = {} + mode = self.mode for model, slot in self._slots.items(): - snapshot[model] = { - "limit": slot.limit, + entry: dict[str, Any] = { + "mode": mode, "in_use": slot.in_use, - "available": slot.available, "pending": slot.pending, + "peak_pending_recent": slot.peak_pending_recent, } + if mode == "limited": + entry["limit"] = slot.limit + entry["available"] = slot.available + else: + entry["limit"] = None + entry["available"] = None + snapshot[model] = entry return snapshot -__all__ = ["ModelConcurrencyLimiter"] +# 向后兼容别名 +ModelConcurrencyLimiter = ModelConcurrencyController + +__all__ = ["ModelConcurrencyController", "ModelConcurrencyLimiter"] diff --git a/src/coding/proxy/vendors/zhipu.py b/src/coding/proxy/vendors/zhipu.py index 64407ba..ef36f57 100644 --- a/src/coding/proxy/vendors/zhipu.py +++ b/src/coding/proxy/vendors/zhipu.py @@ -17,6 +17,11 @@ - max_attempt = 5(1 初始 + 4 重试) - 指数退避 + Full Jitter(1s → 2s → 4s → 8s) - 优先尊重 server retry-after header + +并发限流由 BaseVendor._concurrency_controller 统一管控 +(limited 模式),在 executor 层通过 ``track_in_flight`` 触发, +slot 跨 429 重试自然持有(executor 的 async with 包裹整个 +send_message/send_message_stream 调用链)。 """ from __future__ import annotations @@ -37,7 +42,7 @@ ) from ..routing.retry import RetryConfig, calculate_delay from .base import VendorResponse -from .concurrency import ModelConcurrencyLimiter +from .concurrency import ModelConcurrencyController from .native_anthropic import NativeAnthropicVendor logger = logging.getLogger(__name__) @@ -59,6 +64,7 @@ class ZhipuVendor(NativeAnthropicVendor): 仅替换模型名和认证头,其余原样透传。 429 Rate Limit 时自动重试(指数退避),降低 failover 频率。 + 并发限流由 BaseVendor._concurrency_controller 统一管控。 """ _vendor_name = "zhipu" @@ -72,12 +78,8 @@ def __init__( ) -> None: super().__init__(config, model_mapper, failover_config) self._rl_retry = _RATE_LIMIT_RETRY - # 每模型并发限制器(config.concurrency 为 None 时禁用) - self._concurrency_limiter: ModelConcurrencyLimiter | None = ( - ModelConcurrencyLimiter(config.concurrency) - if config.concurrency is not None - else None - ) + # 覆盖 BaseVendor 默认的 monitor 模式为 limited 模式 + self._concurrency_controller = ModelConcurrencyController(config.concurrency) # ── 首选 tier 参数兼容转换 ──────────────────────────────── @@ -129,24 +131,7 @@ async def send_message( request_body: dict[str, Any], headers: dict[str, str], ) -> VendorResponse: - """非流式请求,429 时自动重试. - - 在 429 重试循环外层套上每模型并发槽位获取,确保同一时间点同一模型的 - 在途请求数不超过配置上限;超过时新请求 FIFO 排队等待。 - """ - sem = await self._maybe_acquire_concurrency_slot(request_body) - try: - return await self._send_message_with_retry(request_body, headers) - finally: - if sem is not None: - sem.release() - - async def _send_message_with_retry( - self, - request_body: dict[str, Any], - headers: dict[str, str], - ) -> VendorResponse: - """原 send_message 主体逻辑(不含并发控制).""" + """非流式请求,429 时自动重试.""" max_attempts = self._rl_retry.max_attempts for attempt in range(max_attempts): @@ -186,87 +171,42 @@ async def send_message_stream( 安全性:429 在 BaseVendor.send_message_stream 中于 status code 检查阶段即 raise(在任何 chunk yield 之前), 因此重试不会导致已发出数据不一致。 - - 在 429 重试循环外层套上每模型并发槽位获取,确保流式请求与非流式请求 - 共用同一信号量,统一限制同一模型的总在途并发数。 """ - sem = await self._maybe_acquire_concurrency_slot(request_body) max_attempts = self._rl_retry.max_attempts - try: - for attempt in range(max_attempts): - try: - # 429 在 status code 检查阶段即 raise(在任何 chunk 之前), - # 因此 __anext__ 安全:要么拿到首个 chunk,要么抛异常。 - ait = super().send_message_stream(request_body, headers) - head = await ait.__anext__() - except StopAsyncIteration: - return - except httpx.HTTPStatusError as exc: - if exc.response is None or exc.response.status_code != 429: - raise - if attempt == max_attempts - 1: - logger.warning( - "Zhipu 429 stream rate limit exhausted after %d attempts", - max_attempts, - ) - raise - - delay = self._compute_retry_delay_from_response( - exc.response, attempt - ) - logger.info( - "Zhipu 429 stream rate limit, retry %d/%d in %.1fms", - attempt + 1, - max_attempts - 1, - delay, - ) - await asyncio.sleep(delay / 1000.0) - continue - - # yield 在 try/except 之外,避免捕获外部 athrow 的异常 - yield head - async for chunk in ait: - yield chunk + for attempt in range(max_attempts): + try: + # 429 在 status code 检查阶段即 raise(在任何 chunk 之前), + # 因此 __anext__ 安全:要么拿到首个 chunk,要么抛异常。 + ait = super().send_message_stream(request_body, headers) + head = await ait.__anext__() + except StopAsyncIteration: return - finally: - if sem is not None: - sem.release() - - # ── 并发控制 ──────────────────────────────────────────── - - async def _maybe_acquire_concurrency_slot( - self, - request_body: dict[str, Any], - ) -> asyncio.Semaphore | None: - """按映射后模型名获取并发槽位;未配置 concurrency 时返回 None. - - ``map_model()`` 是纯同步字典查找,在 Semaphore 等待前调用是安全的, - 且能确保排队键与上游真实承载模型对齐。 - """ - if self._concurrency_limiter is None: - return None - raw_model = request_body.get("model", "") if request_body else "" - mapped_model = self.map_model(raw_model) if raw_model else "" - if not mapped_model: - return None - return await self._concurrency_limiter.acquire(mapped_model) - - # ── 诊断信息 ───────────────────────────────────────────── - - def get_diagnostics(self) -> dict[str, Any]: - """返回供应商运行时诊断信息,包含每模型并发状态.""" - diagnostics = super().get_diagnostics() - if self._concurrency_limiter is not None: - diagnostics["concurrency"] = self._concurrency_limiter.get_diagnostics() - return diagnostics - - def update_concurrency(self, model: str, limit: int) -> None: - """运行时更新指定模型的并发限制.""" - if self._concurrency_limiter is None: - msg = "Concurrency limiter is not enabled for this vendor" - raise ValueError(msg) - self._concurrency_limiter.set_limit(model, limit) + except httpx.HTTPStatusError as exc: + if exc.response is None or exc.response.status_code != 429: + raise + if attempt == max_attempts - 1: + logger.warning( + "Zhipu 429 stream rate limit exhausted after %d attempts", + max_attempts, + ) + raise + + delay = self._compute_retry_delay_from_response(exc.response, attempt) + logger.info( + "Zhipu 429 stream rate limit, retry %d/%d in %.1fms", + attempt + 1, + max_attempts - 1, + delay, + ) + await asyncio.sleep(delay / 1000.0) + continue + + # yield 在 try/except 之外,避免捕获外部 athrow 的异常 + yield head + async for chunk in ait: + yield chunk + return # ── 延迟计算 ──────────────────────────────────────────── diff --git a/tests/test_concurrency_monitor.py b/tests/test_concurrency_monitor.py new file mode 100644 index 0000000..4fce021 --- /dev/null +++ b/tests/test_concurrency_monitor.py @@ -0,0 +1,158 @@ +"""ModelConcurrencyController monitor 模式专项测试. + +验证 ``config=None`` 时的纯计数行为: + - acquire 不阻塞,无 limit / available + - pending 永远为 0 + - set_limit 抛 ValueError + - 100 并发 in_use 峰值正确 + - get_diagnostics 输出 mode="monitor" + limit/available=None +""" + +from __future__ import annotations + +import asyncio + +import pytest + +from coding.proxy.vendors.concurrency import ModelConcurrencyController + + +class TestMonitorMode: + """monitor 模式(config=None)基础行为.""" + + def test_mode_property(self) -> None: + ctrl = ModelConcurrencyController(None) + assert ctrl.mode == "monitor" + + @pytest.mark.asyncio + async def test_acquire_never_blocks(self) -> None: + """monitor 模式下任意数量并发都立即获取槽位.""" + ctrl = ModelConcurrencyController(None) + slot = ctrl._get_or_create_slot("model-x") + # 即使触发多次 acquire 也不阻塞 + for _ in range(10): + await slot.acquire() + assert slot.in_use == 10 + assert slot.pending == 0 + + @pytest.mark.asyncio + async def test_100_concurrent_acquires_no_queue(self) -> None: + """100 并发请求全部立即拿到槽位,无排队.""" + ctrl = ModelConcurrencyController(None) + + gate = asyncio.Event() + max_in_use = 0 + + async def hold(model: str) -> None: + nonlocal max_in_use + async with ctrl.track(model): + slot = ctrl._get_or_create_slot(model) + max_in_use = max(max_in_use, slot.in_use) + await gate.wait() + + tasks = [asyncio.create_task(hold("test-model")) for _ in range(100)] + # 等所有任务进入 track + await asyncio.sleep(0.1) + slot = ctrl._get_or_create_slot("test-model") + assert slot.in_use == 100, "monitor 模式应允许全部并行" + assert slot.pending == 0, "monitor 模式 pending 应恒为 0" + gate.set() + await asyncio.gather(*tasks) + # 释放后归零 + assert slot.in_use == 0 + + @pytest.mark.asyncio + async def test_release_after_track(self) -> None: + """track 上下文退出后 in_use 归零.""" + ctrl = ModelConcurrencyController(None) + slot = ctrl._get_or_create_slot("m") + async with ctrl.track("m"): + assert slot.in_use == 1 + assert slot.in_use == 0 + + def test_set_limit_raises_in_monitor(self) -> None: + """monitor 模式下 set_limit 抛 ValueError.""" + ctrl = ModelConcurrencyController(None) + with pytest.raises(ValueError, match="monitor-only"): + ctrl.set_limit("m", 5) + + def test_diagnostics_monitor_shape(self) -> None: + """monitor 模式 get_diagnostics 输出 mode + limit/available=None.""" + ctrl = ModelConcurrencyController(None) + ctrl._get_or_create_slot("m") + snap = ctrl.get_diagnostics() + assert snap["m"]["mode"] == "monitor" + assert snap["m"]["limit"] is None + assert snap["m"]["available"] is None + assert snap["m"]["in_use"] == 0 + assert snap["m"]["pending"] == 0 + assert snap["m"]["peak_pending_recent"] == 0 + + @pytest.mark.asyncio + async def test_pending_never_increases_in_monitor(self) -> None: + """monitor 模式即使触发大量并发,pending 永远不增.""" + ctrl = ModelConcurrencyController(None) + gate = asyncio.Event() + + async def hold() -> None: + async with ctrl.track("m"): + await gate.wait() + + tasks = [asyncio.create_task(hold()) for _ in range(20)] + await asyncio.sleep(0.05) + slot = ctrl._get_or_create_slot("m") + assert slot.pending == 0 + assert slot.in_use == 20 + gate.set() + await asyncio.gather(*tasks) + + +class TestBaseVendorTrackInFlight: + """BaseVendor.track_in_flight 默认 monitor 行为.""" + + def test_empty_model_returns_noop(self) -> None: + """空 model name 返回 no-op 上下文,不影响 controller 状态.""" + from coding.proxy.config.schema import AnthropicConfig, FailoverConfig + from coding.proxy.vendors.anthropic import AnthropicVendor + + vendor = AnthropicVendor(AnthropicConfig(), FailoverConfig()) + ctx = vendor.track_in_flight("") + # nullcontext 是同步上下文管理器,应有 __enter__/__exit__ + # 我们不进入它,只验证不抛错 + assert ctx is not None + + @pytest.mark.asyncio + async def test_track_in_flight_increments_in_use(self) -> None: + """非空 model name → controller.track 进入,in_use 自增.""" + from coding.proxy.config.schema import AnthropicConfig, FailoverConfig + from coding.proxy.vendors.anthropic import AnthropicVendor + + vendor = AnthropicVendor(AnthropicConfig(), FailoverConfig()) + async with vendor.track_in_flight("claude-test"): + slot = vendor._concurrency_controller._get_or_create_slot("claude-test") + assert slot.in_use == 1 + # 退出后归零 + slot = vendor._concurrency_controller._get_or_create_slot("claude-test") + assert slot.in_use == 0 + + def test_update_concurrency_default_is_monitor_only(self) -> None: + """BaseVendor 默认 monitor → update_concurrency 抛 ValueError.""" + from coding.proxy.config.schema import AnthropicConfig, FailoverConfig + from coding.proxy.vendors.anthropic import AnthropicVendor + + vendor = AnthropicVendor(AnthropicConfig(), FailoverConfig()) + with pytest.raises(ValueError, match="monitor-only"): + vendor.update_concurrency("m", 5) + + def test_get_diagnostics_includes_concurrency_after_use(self) -> None: + """track_in_flight 用过后 get_diagnostics 输出 concurrency 字段.""" + from coding.proxy.config.schema import AnthropicConfig, FailoverConfig + from coding.proxy.vendors.anthropic import AnthropicVendor + + vendor = AnthropicVendor(AnthropicConfig(), FailoverConfig()) + # 触发 slot 创建 + vendor._concurrency_controller._get_or_create_slot("claude-test") + diag = vendor.get_diagnostics() + assert "concurrency" in diag + assert "claude-test" in diag["concurrency"] + assert diag["concurrency"]["claude-test"]["mode"] == "monitor" diff --git a/tests/test_executor_in_flight_tracking.py b/tests/test_executor_in_flight_tracking.py new file mode 100644 index 0000000..9b57e39 --- /dev/null +++ b/tests/test_executor_in_flight_tracking.py @@ -0,0 +1,233 @@ +"""Executor 层 track_in_flight 包裹行为验证. + +验证 ``_RouteExecutor`` 在调用 ``vendor.send_message[_stream]`` 前 +进入 ``vendor.track_in_flight(mapped_model)`` 上下文,在调用结束(包括异常) +后正确退出(释放槽位)。 +""" + +from __future__ import annotations + +from contextlib import asynccontextmanager +from unittest.mock import AsyncMock, MagicMock + +import httpx +import pytest + +from coding.proxy.compat.canonical import ( + CompatibilityDecision, + CompatibilityStatus, +) +from coding.proxy.routing.executor import _RouteExecutor +from coding.proxy.routing.session_manager import RouteSessionManager +from coding.proxy.routing.tier import VendorTier +from coding.proxy.routing.usage_recorder import UsageRecorder +from coding.proxy.vendors.base import ( + BaseVendor, + RequestCapabilities, + UsageInfo, + VendorCapabilities, + VendorResponse, +) + + +class _TrackingProbe: + """共享状态:记录 track_in_flight enter/exit 时序与 send 调用顺序.""" + + def __init__(self) -> None: + self.events: list[str] = [] + self.in_flight: int = 0 + self.peak_in_flight: int = 0 + + def track_factory(self, vendor_name: str): + @asynccontextmanager + async def _track(mapped_model: str): + self.events.append(f"enter:{vendor_name}:{mapped_model}") + self.in_flight += 1 + self.peak_in_flight = max(self.peak_in_flight, self.in_flight) + try: + yield + finally: + self.in_flight -= 1 + self.events.append(f"exit:{vendor_name}:{mapped_model}") + + def _factory(mapped_model: str): + return _track(mapped_model) + + return _factory + + +def _mock_vendor_with_probe( + probe: _TrackingProbe, name: str = "test", **caps_kwargs +) -> BaseVendor: + """创建带 track_in_flight 探针的 mock vendor.""" + vendor = MagicMock(spec=BaseVendor) + vendor.get_name.return_value = name + vendor.map_model.return_value = f"{name}-mapped" + caps = VendorCapabilities(**caps_kwargs) + vendor.get_capabilities.return_value = caps + vendor.get_compatibility_profile.return_value = MagicMock() + vendor.make_compatibility_decision.return_value = CompatibilityDecision( + status=CompatibilityStatus.NATIVE, + ) + vendor.get_compat_trace.return_value = None + + def _supports_request(_caps: RequestCapabilities): + return True, [] + + vendor.supports_request.side_effect = _supports_request + vendor.check_health = AsyncMock(return_value=True) + vendor.close = AsyncMock() + vendor.set_compat_context = MagicMock() + + # 关键:track_in_flight 委托给 probe + vendor.track_in_flight = MagicMock(side_effect=probe.track_factory(name)) + + # send_message 默认返回成功 + async def _send_message(_body, _headers): + probe.events.append(f"send:{name}") + return VendorResponse( + status_code=200, + raw_body=b"{}", + usage=UsageInfo(input_tokens=1, output_tokens=1), + ) + + vendor.send_message = AsyncMock(side_effect=_send_message) + + # send_message_stream 默认产出空流 + async def _send_stream(_body, _headers): + probe.events.append(f"stream_start:{name}") + yield b'data: {"type":"message_start"}\n\n' + probe.events.append(f"stream_end:{name}") + + vendor.send_message_stream = MagicMock(side_effect=_send_stream) + return vendor + + +def _make_executor(vendor: BaseVendor) -> _RouteExecutor: + tier = VendorTier(vendor=vendor) + return _RouteExecutor( + router=MagicMock(), + tiers=[tier], + usage_recorder=UsageRecorder(), + session_manager=RouteSessionManager(), + ) + + +class TestExecuteMessageInFlightTracking: + """非流式调用的 track_in_flight 包裹行为.""" + + @pytest.mark.asyncio + async def test_track_enter_before_send_exit_after(self): + """success path: enter → send → exit 顺序.""" + probe = _TrackingProbe() + vendor = _mock_vendor_with_probe(probe, name="kimi") + exec_inst = _make_executor(vendor) + + resp = await exec_inst.execute_message({"model": "claude-test"}, {}) + assert resp.status_code == 200 + + assert probe.events == [ + "enter:kimi:kimi-mapped", + "send:kimi", + "exit:kimi:kimi-mapped", + ] + vendor.track_in_flight.assert_called_once_with("kimi-mapped") + + @pytest.mark.asyncio + async def test_track_exits_on_http_error(self): + """异常路径:track_in_flight 仍执行 exit.""" + probe = _TrackingProbe() + vendor = _mock_vendor_with_probe(probe, name="kimi") + vendor.send_message = AsyncMock(side_effect=httpx.ConnectError("down")) + exec_inst = _make_executor(vendor) + + with pytest.raises(httpx.ConnectError): + await exec_inst.execute_message({"model": "claude-test"}, {}) + + # enter 与 exit 都应被记录(finally 触发) + assert "enter:kimi:kimi-mapped" in probe.events + assert "exit:kimi:kimi-mapped" in probe.events + # exit 应在 enter 之后 + assert probe.events.index("exit:kimi:kimi-mapped") > probe.events.index( + "enter:kimi:kimi-mapped" + ) + + @pytest.mark.asyncio + async def test_concurrent_message_calls_track_each(self): + """多并发请求每个都触发 enter/exit;in_flight 峰值正确.""" + import asyncio + + probe = _TrackingProbe() + vendor = _mock_vendor_with_probe(probe, name="copilot") + + async def slow_send(_body, _headers): + probe.events.append("send:copilot") + await asyncio.sleep(0.05) + return VendorResponse( + status_code=200, + raw_body=b"{}", + usage=UsageInfo(input_tokens=1, output_tokens=1), + ) + + vendor.send_message = AsyncMock(side_effect=slow_send) + exec_inst = _make_executor(vendor) + + tasks = [ + asyncio.create_task(exec_inst.execute_message({"model": "claude-test"}, {})) + for _ in range(5) + ] + results = await asyncio.gather(*tasks) + assert all(r.status_code == 200 for r in results) + # 5 个 enter + 5 个 send + 5 个 exit + assert sum(1 for e in probe.events if e.startswith("enter:")) == 5 + assert sum(1 for e in probe.events if e.startswith("exit:")) == 5 + # 并发期间 in_flight 峰值应达到 5(monitor 模式不限流) + assert probe.peak_in_flight == 5 + + +class TestExecuteStreamInFlightTracking: + """流式调用的 track_in_flight 包裹行为.""" + + @pytest.mark.asyncio + async def test_stream_track_enter_exit(self): + """成功流式:enter → stream → exit 顺序.""" + probe = _TrackingProbe() + vendor = _mock_vendor_with_probe(probe, name="doubao") + exec_inst = _make_executor(vendor) + + chunks = [] + async for chunk, name in exec_inst.execute_stream({"model": "claude-test"}, {}): + chunks.append((chunk, name)) + + assert len(chunks) >= 1 + # 应包含 enter:doubao:doubao-mapped 与 exit:doubao:doubao-mapped + assert "enter:doubao:doubao-mapped" in probe.events + assert "exit:doubao:doubao-mapped" in probe.events + # exit 在 stream_end 之后 + assert probe.events.index("exit:doubao:doubao-mapped") > probe.events.index( + "stream_end:doubao" + ) + + @pytest.mark.asyncio + async def test_stream_track_exits_on_error(self): + """流式异常退出时 track exit 仍触发.""" + probe = _TrackingProbe() + vendor = _mock_vendor_with_probe(probe, name="minimax") + + async def error_stream(_body, _headers): + yield b'data: {"type":"message_start"}\n\n' + raise httpx.HTTPStatusError( + "500", + request=httpx.Request("POST", "https://example.com"), + response=httpx.Response(500), + ) + + vendor.send_message_stream = MagicMock(side_effect=error_stream) + exec_inst = _make_executor(vendor) + + with pytest.raises(httpx.HTTPStatusError): + async for _ in exec_inst.execute_stream({"model": "claude-test"}, {}): + pass + + assert "enter:minimax:minimax-mapped" in probe.events + assert "exit:minimax:minimax-mapped" in probe.events diff --git a/tests/test_router_executor.py b/tests/test_router_executor.py index 9506e67..070ca95 100644 --- a/tests/test_router_executor.py +++ b/tests/test_router_executor.py @@ -90,6 +90,11 @@ def _supports_request(request_caps: RequestCapabilities): vendor.check_health = AsyncMock(return_value=True) vendor.close = AsyncMock() vendor.set_compat_context = MagicMock() + + # track_in_flight 返回 nullcontext(不影响执行流,仅满足 async with 协议) + from contextlib import nullcontext + + vendor.track_in_flight = MagicMock(return_value=nullcontext()) return vendor diff --git a/tests/test_zhipu_concurrency.py b/tests/test_zhipu_concurrency.py index 7566b24..d7be1ca 100644 --- a/tests/test_zhipu_concurrency.py +++ b/tests/test_zhipu_concurrency.py @@ -1,13 +1,16 @@ """Zhipu 每模型并发限制专项测试. -验证 ``ModelConcurrencyLimiter`` 与 ``ZhipuVendor`` 集成后的并发控制行为: +验证 ``ModelConcurrencyController`` 与 ``ZhipuVendor`` 集成后的并发控制行为: - 默认 ``concurrency.default=3`` 时同一模型最多 3 个并发 - 超出上限时按 FIFO 排队,槽位释放后才唤醒 - 不同模型彼此独立,互不阻塞 - - 异常路径下 Semaphore 仍能释放,避免泄漏 - - 流式请求与非流式请求共享同一信号量 + - 异常路径下槽位仍能释放,避免泄漏 + - 流式请求与非流式请求共享同一槽位 - 与 429 重试机制兼容(重试期间持续占用槽位) - - ``concurrency=None`` 时禁用限制(向后兼容) + +注意:并发限流由 BaseVendor._concurrency_controller 统一管控, +executor 层通过 ``vendor.track_in_flight(mapped_model)`` 上下文管理器 +包裹 send_message[_stream] 调用。本测试用同样的包裹模拟 executor 语义。 """ from __future__ import annotations @@ -25,7 +28,10 @@ ZhipuConfig, ) from coding.proxy.routing.model_mapper import ModelMapper -from coding.proxy.vendors.concurrency import ModelConcurrencyLimiter +from coding.proxy.vendors.concurrency import ( + ModelConcurrencyController, + ModelConcurrencyLimiter, +) from coding.proxy.vendors.native_anthropic import NativeAnthropicVendor from coding.proxy.vendors.zhipu import ZhipuVendor @@ -69,6 +75,29 @@ def _make_vendor( return ZhipuVendor(ZhipuConfig(**cfg_kwargs), _make_mapper()) +async def _send_with_tracking( + vendor: ZhipuVendor, + body: dict, + headers: dict, +): + """模拟 executor 行为:track_in_flight 包裹 send_message 调用.""" + mapped = vendor.map_model(body.get("model", "")) + async with vendor.track_in_flight(mapped): + return await vendor.send_message(body, headers) + + +async def _stream_with_tracking( + vendor: ZhipuVendor, + body: dict, + headers: dict, +): + """模拟 executor 行为:track_in_flight 包裹 send_message_stream 调用.""" + mapped = vendor.map_model(body.get("model", "")) + async with vendor.track_in_flight(mapped): + async for chunk in vendor.send_message_stream(body, headers): + yield chunk + + def _make_200_response() -> httpx.Response: body = json.dumps( { @@ -132,72 +161,128 @@ def test_zhipu_config_default_concurrency(self) -> None: assert cfg.concurrency.default == 3 -# ─── ModelConcurrencyLimiter 单元测试 ────────────────────── +# ─── ModelConcurrencyController 单元测试(limited 模式)───── + +class TestModelConcurrencyControllerLimited: + """ModelConcurrencyController limited 模式基础行为.""" -class TestModelConcurrencyLimiter: - """ModelConcurrencyLimiter 基础行为.""" + def test_alias_compatibility(self) -> None: + """ModelConcurrencyLimiter 别名指向 ModelConcurrencyController.""" + assert ModelConcurrencyLimiter is ModelConcurrencyController + + def test_mode_is_limited(self) -> None: + ctrl = ModelConcurrencyController(ZhipuConcurrencyConfig(default=2)) + assert ctrl.mode == "limited" @pytest.mark.asyncio - async def test_lazy_semaphore_creation(self) -> None: - limiter = ModelConcurrencyLimiter(ZhipuConcurrencyConfig(default=2)) - slot_a = limiter._get_or_create_slot("model-a") - slot_b = limiter._get_or_create_slot("model-b") + async def test_lazy_slot_creation(self) -> None: + ctrl = ModelConcurrencyController(ZhipuConcurrencyConfig(default=2)) + slot_a = ctrl._get_or_create_slot("model-a") + slot_b = ctrl._get_or_create_slot("model-b") # 不同模型独立 slot assert slot_a is not slot_b # 相同模型复用 slot - assert limiter._get_or_create_slot("model-a") is slot_a + assert ctrl._get_or_create_slot("model-a") is slot_a @pytest.mark.asyncio - async def test_acquire_blocks_when_full(self) -> None: - limiter = ModelConcurrencyLimiter(ZhipuConcurrencyConfig(default=2)) + async def test_track_blocks_when_full(self) -> None: + ctrl = ModelConcurrencyController(ZhipuConcurrencyConfig(default=2)) + + # 通过 acquire 占满 2 个槽位 + slot = ctrl._get_or_create_slot("glm-5.1") + await slot.acquire() + await slot.acquire() + assert slot.in_use == 2 - # 占满 2 个槽位 - sem1 = await limiter.acquire("glm-5.1") - sem2 = await limiter.acquire("glm-5.1") - assert sem1 is sem2 # 同一 semaphore + # 第 3 次 track 必须阻塞 + async def third(): + async with ctrl.track("glm-5.1"): + pass - # 第 3 次 acquire 必须阻塞 - task = asyncio.create_task(limiter.acquire("glm-5.1")) + task = asyncio.create_task(third()) await asyncio.sleep(0.05) assert not task.done(), "第三个请求应在排队等待" + # 排队时 pending 应递增 + assert slot.pending == 1 # 释放一个槽位后,等待者被唤醒 - sem1.release() + slot.release() await asyncio.sleep(0.05) assert task.done() - (await task).release() - sem2.release() + slot.release() @pytest.mark.asyncio async def test_per_model_independent(self) -> None: - limiter = ModelConcurrencyLimiter( + ctrl = ModelConcurrencyController( ZhipuConcurrencyConfig(default=1, models={"glm-5.1": 1}) ) - # 占满 glm-5.1 - sem_51 = await limiter.acquire("glm-5.1") - # glm-5v-turbo 仍可立即获取 - sem_5v = await asyncio.wait_for(limiter.acquire("glm-5v-turbo"), timeout=0.5) - assert sem_51 is not sem_5v - sem_51.release() - sem_5v.release() - def test_diagnostics_snapshot(self) -> None: - limiter = ModelConcurrencyLimiter(ZhipuConcurrencyConfig(default=3)) + async def hold(model: str, event: asyncio.Event) -> None: + async with ctrl.track(model): + await event.wait() + + gate_51 = asyncio.Event() + gate_5v = asyncio.Event() + t51 = asyncio.create_task(hold("glm-5.1", gate_51)) + await asyncio.sleep(0.02) + # glm-5v-turbo 仍可立即获取 + t5v = asyncio.create_task(hold("glm-5v-turbo", gate_5v)) + await asyncio.sleep(0.02) + # 两个任务都尚未结束(都在 await event) + assert not t51.done() + assert not t5v.done() + gate_51.set() + gate_5v.set() + await asyncio.gather(t51, t5v) + + def test_diagnostics_snapshot_limited(self) -> None: + ctrl = ModelConcurrencyController(ZhipuConcurrencyConfig(default=3)) # 触发 slot 创建 - limiter._get_or_create_slot("glm-5.1") - snap = limiter.get_diagnostics() + ctrl._get_or_create_slot("glm-5.1") + snap = ctrl.get_diagnostics() assert "glm-5.1" in snap + assert snap["glm-5.1"]["mode"] == "limited" assert snap["glm-5.1"]["limit"] == 3 assert snap["glm-5.1"]["available"] == 3 assert snap["glm-5.1"]["in_use"] == 0 + assert snap["glm-5.1"]["pending"] == 0 + assert "peak_pending_recent" in snap["glm-5.1"] + + @pytest.mark.asyncio + async def test_peak_pending_recent_tracking(self) -> None: + """触发排队时记录 peak,释放后仍可读到余晖.""" + ctrl = ModelConcurrencyController(ZhipuConcurrencyConfig(default=1)) + slot = ctrl._get_or_create_slot("glm-5.1") + + # 占满 + await slot.acquire() + # 触发两个排队 + t1 = asyncio.create_task(slot.acquire()) + t2 = asyncio.create_task(slot.acquire()) + await asyncio.sleep(0.05) + assert slot.pending == 2 + snap = ctrl.get_diagnostics() + assert snap["glm-5.1"]["peak_pending_recent"] == 2 + + # 释放并完成所有任务 + slot.release() + await t1 + slot.release() + await t2 + slot.release() + + # pending 已归零,但 peak_pending_recent 仍记得最近的峰值 + snap2 = ctrl.get_diagnostics() + assert snap2["glm-5.1"]["pending"] == 0 + assert snap2["glm-5.1"]["peak_pending_recent"] == 2 # ─── ZhipuVendor 集成测试:非流式 ──────────────────────────── class TestZhipuVendorNonStreamConcurrency: - """非流式 send_message 的并发限制行为.""" + """非流式 send_message 的并发限制行为(通过 track_in_flight 包裹).""" @pytest.mark.asyncio async def test_limits_parallel_requests(self) -> None: @@ -223,7 +308,8 @@ async def mock_post(*_, **__) -> httpx.Response: tasks = [ asyncio.create_task( - vendor.send_message( + _send_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {}, ) @@ -265,15 +351,17 @@ async def mock_post(*_args, **kwargs) -> httpx.Response: mock_client.return_value = client # claude-opus → glm-5.1, claude-sonnet → glm-5v-turbo, - # 分属两个独立信号量,应同时执行 + # 分属两个独立槽位,应同时执行 task_opus = asyncio.create_task( - vendor.send_message( + _send_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {}, ) ) task_sonnet = asyncio.create_task( - vendor.send_message( + _send_with_tracking( + vendor, {"model": "claude-sonnet-4-6", "messages": []}, {}, ) @@ -289,8 +377,8 @@ async def mock_post(*_args, **kwargs) -> httpx.Response: await asyncio.gather(task_opus, task_sonnet) @pytest.mark.asyncio - async def test_semaphore_released_on_exception(self) -> None: - """上游抛异常时 Semaphore 仍应释放,后续请求不阻塞.""" + async def test_slot_released_on_exception(self) -> None: + """上游抛异常时槽位仍应释放,后续请求不阻塞.""" vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) call_count = 0 @@ -307,14 +395,16 @@ async def mock_post(*_, **__) -> httpx.Response: mock_client.return_value = client with pytest.raises(RuntimeError): - await vendor.send_message( + await _send_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {}, ) # 槽位应已释放,第二次请求可正常完成 resp = await asyncio.wait_for( - vendor.send_message( + _send_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {}, ), @@ -343,7 +433,8 @@ async def mock_post(*_, **__) -> httpx.Response: client.post = mock_post mock_client.return_value = client - resp = await vendor.send_message( + resp = await _send_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {}, ) @@ -351,13 +442,13 @@ async def mock_post(*_, **__) -> httpx.Response: assert call_count == 3 # 两次 429 + 一次成功,且共用同一槽位 @pytest.mark.asyncio - async def test_no_concurrency_when_config_is_none(self) -> None: - """concurrency=None 时禁用并发限制,行为与旧版完全一致.""" - # 强制构造一个 concurrency=None 的 ZhipuConfig(绕过默认工厂) - cfg = ZhipuConfig(api_key="key") - cfg = cfg.model_copy(update={"concurrency": None}) - vendor = ZhipuVendor(cfg, _make_mapper()) - assert vendor._concurrency_limiter is None + async def test_monitor_mode_no_throttling(self) -> None: + """BaseVendor 默认 monitor 模式:高并发不限流,仅计数.""" + # 显式构造一个 monitor 模式的 zhipu vendor(用于验证 BaseVendor 默认行为等价) + vendor = _make_vendor(ZhipuConcurrencyConfig(default=20)) + from coding.proxy.vendors.concurrency import ModelConcurrencyController as MCC + + vendor._concurrency_controller = MCC(None) gate = asyncio.Event() active = 0 @@ -378,7 +469,8 @@ async def mock_post(*_, **__) -> httpx.Response: tasks = [ asyncio.create_task( - vendor.send_message( + _send_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {}, ) @@ -390,7 +482,7 @@ async def mock_post(*_, **__) -> httpx.Response: break await asyncio.sleep(0.01) - assert peak == 5, "无并发限制时应全部并行" + assert peak == 5, "monitor 模式应允许全部并行" gate.set() await asyncio.gather(*tasks) @@ -399,7 +491,7 @@ async def mock_post(*_, **__) -> httpx.Response: class TestZhipuVendorStreamConcurrency: - """流式 send_message_stream 的并发限制行为.""" + """流式 send_message_stream 的并发限制行为(通过 track_in_flight 包裹).""" @pytest.mark.asyncio async def test_stream_limits_parallel_requests(self) -> None: @@ -421,8 +513,8 @@ async def fake_stream(self, _body, _headers): # noqa: ARG001 async def consume(model: str) -> int: chunks: list[bytes] = [] - async for chunk in vendor.send_message_stream( - {"model": model, "messages": []}, {} + async for chunk in _stream_with_tracking( + vendor, {"model": model, "messages": []}, {} ): chunks.append(chunk) return len(chunks) @@ -453,15 +545,14 @@ async def fake_stream(self, _body, _headers): # noqa: ARG001 # 连续两次流式请求都能完成(说明槽位被释放) for _ in range(2): chunks = [] - async for chunk in vendor.send_message_stream( - {"model": "claude-opus-4-6", "messages": []}, {} + async for chunk in _stream_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {} ): chunks.append(chunk) assert len(chunks) == 2 # 确认 slot 当前完全可用 - assert vendor._concurrency_limiter is not None - slot = vendor._concurrency_limiter._get_or_create_slot("glm-5.1") + slot = vendor._concurrency_controller._get_or_create_slot("glm-5.1") assert slot.available == 1 @pytest.mark.asyncio @@ -485,22 +576,22 @@ async def fake_stream(self, _body, _headers): # noqa: ARG001 with patch.object(NativeAnthropicVendor, "send_message_stream", fake_stream): with pytest.raises(httpx.HTTPStatusError): - async for _ in vendor.send_message_stream( - {"model": "claude-opus-4-6", "messages": []}, {} + async for _ in _stream_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {} ): pass # 槽位应已释放,第二次请求可正常推进 chunks = [] - async for chunk in vendor.send_message_stream( - {"model": "claude-opus-4-6", "messages": []}, {} + async for chunk in _stream_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {} ): chunks.append(chunk) assert chunks == [b'data: {"type":"message_start"}\n\n'] @pytest.mark.asyncio - async def test_stream_and_nonstream_share_semaphore(self) -> None: - """流式与非流式请求共用同一信号量(按映射后模型分组).""" + async def test_stream_and_nonstream_share_slot(self) -> None: + """流式与非流式请求共用同一槽位(按映射后模型分组).""" vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) gate = asyncio.Event() active = 0 @@ -530,8 +621,8 @@ async def mock_post(*_, **__) -> httpx.Response: # 启动流式请求并等待它占用槽位 async def consume_stream() -> None: - async for _ in vendor.send_message_stream( - {"model": "claude-opus-4-6", "messages": []}, {} + async for _ in _stream_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {} ): pass @@ -542,9 +633,10 @@ async def consume_stream() -> None: await asyncio.sleep(0.01) assert active == 1 - # 非流式请求应被同一信号量阻塞 + # 非流式请求应被同一槽位阻塞 nonstream_task = asyncio.create_task( - vendor.send_message( + _send_with_tracking( + vendor, {"model": "claude-opus-4-6", "messages": []}, {}, )