From b7de51208198a25b35f0ede5d5c0428ccc446e39 Mon Sep 17 00:00:00 2001 From: ThreeFish Date: Mon, 25 May 2026 23:33:13 +0800 Subject: [PATCH 1/2] =?UTF-8?q?feat(zhipu):=20=E6=96=B0=E5=A2=9E=E6=AF=8F?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E5=B9=B6=E5=8F=91=E9=99=90=E5=88=B6=EF=BC=8C?= =?UTF-8?q?=E9=BB=98=E8=AE=A4=203=20=E4=B8=AA=E5=B9=B6=E8=A1=8C=E8=AF=B7?= =?UTF-8?q?=E6=B1=82=20FIFO=20=E6=8E=92=E9=98=9F;?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 ZhipuConcurrencyConfig 与 ModelConcurrencyLimiter,按映射后模型名(如 glm-5v-turbo / glm-5.1 / glm-4.5-air)维护独立 asyncio.Semaphore,槽位满时新请求 FIFO 排队等待; - ZhipuVendor 流式与非流式入口共用同一信号量,并与既有 429 重试机制兼容(重试期间持续占用槽位); - VendorConfig 新增 concurrency 字段,由工厂转发至 ZhipuConfig,未配置时回退默认 default=3,concurrency=None 完全禁用限流; - 同步更新 docs/arch/config-reference.md 与 CHANGELOG.md,新增 18 项专项测试(含配置层、限制器单元、流式/非流式集成与异常释放)。 🤖 Generated with [Claude Code](https://github.com/claude), [CodeX](https://openai.com), [Gemini](https://github.com/apps/gemini-code-assist) Co-Authored-By: Aurelius Huang --- CHANGELOG.md | 2 + docs/arch/config-reference.md | 40 +- src/coding/proxy/config/config.default.yaml | 8 + src/coding/proxy/config/routing.py | 7 + src/coding/proxy/config/schema.py | 2 + src/coding/proxy/config/vendors.py | 18 +- src/coding/proxy/server/factory.py | 16 +- src/coding/proxy/vendors/concurrency.py | 78 +++ src/coding/proxy/vendors/zhipu.py | 119 +++-- tests/test_zhipu_concurrency.py | 557 ++++++++++++++++++++ 10 files changed, 801 insertions(+), 46 deletions(-) create mode 100644 src/coding/proxy/vendors/concurrency.py create mode 100644 tests/test_zhipu_concurrency.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 0fb0f1d..8eb7a1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,8 @@ ## [Unreleased] +- feat(zhipu): 新增每模型并发限制(默认 3,可通过 `vendors[zhipu].concurrency` 配置),基于 `asyncio.Semaphore` 实现 FIFO 公平排队,流式与非流式共用同一槽位,与 429 重试机制兼容。 + ## [v0.4.0](https://github.com/ThreeFish-AI/coding-proxy/releases/tag/v0.4.0) 
— 2026-05-01 > [!IMPORTANT] diff --git a/docs/arch/config-reference.md b/docs/arch/config-reference.md index 24e11e5..1f4460f 100644 --- a/docs/arch/config-reference.md +++ b/docs/arch/config-reference.md @@ -89,12 +89,13 @@ flowchart TD ## 5. VendorConfig 弹性字段 -| 字段 | 类型 | 默认值 | 说明 | -| -------------------- | -------------- | -------------------- | --------------------------- | -| `circuit_breaker` | config \| None | `None` | 熔断器配置(None = 终端层) | -| `retry` | config | `RetryConfig()` | 重试策略配置 | -| `quota_guard` | config | `QuotaGuardConfig()` | 日度配额守卫配置 | -| `weekly_quota_guard` | config | `QuotaGuardConfig()` | 周度配额守卫配置 | +| 字段 | 类型 | 默认值 | 说明 | +| -------------------- | -------------- | -------------------- | ----------------------------------- | +| `circuit_breaker` | config \| None | `None` | 熔断器配置(None = 终端层) | +| `retry` | config | `RetryConfig()` | 重试策略配置 | +| `quota_guard` | config | `QuotaGuardConfig()` | 日度配额守卫配置 | +| `weekly_quota_guard` | config | `QuotaGuardConfig()` | 周度配额守卫配置 | +| `concurrency` | config \| None | `None` | `[zhipu]` 每模型并发限制(详见 5.5) | @@ -143,6 +144,33 @@ flowchart TD | `error_types` | list[str] | `["rate_limit_error", "overloaded_error", "api_error"]` | | `error_message_patterns` | list[str] | `["quota", "limit exceeded", "usage cap", "capacity", "internal network failure"]` | +### 5.5 ZhipuConcurrencyConfig — Zhipu 每模型并发参数 + +仅对 `vendor: zhipu` 生效,基于 `asyncio.Semaphore` 实现 FIFO 公平排队。 + +| 字段 | 类型 | 默认值 | 说明 | +| --------- | -------------- | ------ | -------------------------------------------------------------------------------- | +| `default` | int | `3` | 全局默认并行度(适用于所有未在 `models` 中显式覆盖的模型);取值范围 `[1, 20]` | +| `models` | map[str → int] | `{}` | 按映射后模型名(如 `glm-5v-turbo` / `glm-5.1` / `glm-4.5-air`)自定义并行度上限 | + +YAML 示例: + +```yaml +- vendor: zhipu + concurrency: + default: 3 + models: + glm-5v-turbo: 5 + glm-5.1: 2 +``` + +行为语义: + +- 信号量按**映射后模型名**键控,与上游真实承载模型对齐;流式与非流式请求共用同一槽位。 +- 槽位满时新请求按 FIFO 顺序排队,直到任一在途请求释放槽位才被唤醒。 +- 429 
重试期间持续占用槽位(重试视为同一请求的延续)。 +- 顶层 `concurrency` 字段缺省为 `None` → 工厂不向 `ZhipuConfig` 转发该字段,由其默认值 `default=3` 生效;因此 YAML 中无法关闭限流,仅当以编程方式使 `ZhipuConfig.concurrency` 为 `None`(该字段类型不含 `None`,需绕过校验,如 `model_copy(update={"concurrency": None})`)时 `ZhipuVendor` 才会禁用限流(一般无需操作)。 + --- ## 6. 供应商专属字段 diff --git a/src/coding/proxy/config/config.default.yaml b/src/coding/proxy/config/config.default.yaml index b6987fa..d945125 100644 --- a/src/coding/proxy/config/config.default.yaml +++ b/src/coding/proxy/config/config.default.yaml @@ -119,6 +119,14 @@ vendors: window_hours: 24.0 threshold_percent: 95.0 probe_interval_seconds: 300 + # 每模型并发限制:默认 3 个并行请求;超出则按 FIFO 排队等待 + # 可通过 models 字段覆盖单个模型的限制(如 glm-5.1: 5) + concurrency: + default: 3 + # models: + # glm-5v-turbo: 3 + # glm-5.1: 3 + # glm-4.5-air: 3 # Vendor 4: MiniMax(默认禁用,需手动启用并添加到 tiers) - vendor: minimax diff --git a/src/coding/proxy/config/routing.py b/src/coding/proxy/config/routing.py index 3326a0b..d0b2d48 100644 --- a/src/coding/proxy/config/routing.py +++ b/src/coding/proxy/config/routing.py @@ -9,6 +9,7 @@ from pydantic import BaseModel, BeforeValidator, Field, PrivateAttr, model_validator from .resiliency import CircuitBreakerConfig, QuotaGuardConfig, RetryConfig +from .vendors import ZhipuConcurrencyConfig # ── 价格字段解析($ / ¥ 前缀支持) ────────────────────────── @@ -285,6 +286,12 @@ class VendorConfig(BaseModel): quota_guard: QuotaGuardConfig = Field(default_factory=QuotaGuardConfig) weekly_quota_guard: QuotaGuardConfig = Field(default_factory=QuotaGuardConfig) + # ── Zhipu 专属:每模型并发限制 ─────────────────────────── + concurrency: ZhipuConcurrencyConfig | None = Field( + default=None, + description="[zhipu] 每模型并发限制;None 表示不限并发", + ) + @model_validator(mode="after") def _warn_irrelevant_fields(self) -> VendorConfig: """对非当前 vendor 类型的非空专属字段发出 warning.""" diff --git a/src/coding/proxy/config/schema.py b/src/coding/proxy/config/schema.py index ee21ee7..40e5428 100644 --- a/src/coding/proxy/config/schema.py +++ b/src/coding/proxy/config/schema.py @@ -54,6 +54,7 @@ KimiConfig, MinimaxConfig, XiaomiConfig, + ZhipuConcurrencyConfig, ZhipuConfig, ) @@ 
-318,6 +319,7 @@ def compat_state_path(self) -> Path: "CopilotConfig", "AntigravityConfig", "ZhipuConfig", + "ZhipuConcurrencyConfig", # resiliency "CircuitBreakerConfig", "RetryConfig", diff --git a/src/coding/proxy/config/vendors.py b/src/coding/proxy/config/vendors.py index 4f15531..a1c0280 100644 --- a/src/coding/proxy/config/vendors.py +++ b/src/coding/proxy/config/vendors.py @@ -2,7 +2,21 @@ from __future__ import annotations -from pydantic import BaseModel +from pydantic import BaseModel, Field + + +class ZhipuConcurrencyConfig(BaseModel): + """Zhipu 每模型并发限制配置.""" + + default: int = Field(default=3, ge=1, le=20, description="全局默认并行度") + models: dict[str, int] = Field( + default_factory=dict, + description="按映射后模型名自定义并行度(覆盖 default)", + ) + + def get_limit(self, model: str) -> int: + """获取指定模型的并行度限制.""" + return self.models.get(model, self.default) class AnthropicConfig(BaseModel): @@ -48,6 +62,7 @@ class ZhipuConfig(BaseModel): base_url: str = "https://open.bigmodel.cn/api/anthropic" api_key: str = "" timeout_ms: int = 3000000 + concurrency: ZhipuConcurrencyConfig = Field(default_factory=ZhipuConcurrencyConfig) class MinimaxConfig(BaseModel): @@ -100,6 +115,7 @@ class AlibabaConfig(BaseModel): "CopilotConfig", "AntigravityConfig", "ZhipuConfig", + "ZhipuConcurrencyConfig", "MinimaxConfig", "KimiConfig", "DoubaoConfig", diff --git a/src/coding/proxy/server/factory.py b/src/coding/proxy/server/factory.py index a1f64a3..4e7632d 100644 --- a/src/coding/proxy/server/factory.py +++ b/src/coding/proxy/server/factory.py @@ -156,13 +156,17 @@ def _create_vendor_from_config( cfg = _resolve_antigravity_credentials(cfg, token_store) return AntigravityVendor(cfg, failover_cfg, mapper) case "zhipu": - cfg = ZhipuConfig( - enabled=vendor_cfg.enabled, - base_url=vendor_cfg.base_url + zhipu_kwargs: dict[str, Any] = { + "enabled": vendor_cfg.enabled, + "base_url": vendor_cfg.base_url or "https://open.bigmodel.cn/api/anthropic", - api_key=vendor_cfg.api_key, - 
timeout_ms=vendor_cfg.timeout_ms, - ) + "api_key": vendor_cfg.api_key, + "timeout_ms": vendor_cfg.timeout_ms, + } + # 仅当显式配置了 concurrency 时转发,否则使用 ZhipuConfig 默认值 + if vendor_cfg.concurrency is not None: + zhipu_kwargs["concurrency"] = vendor_cfg.concurrency + cfg = ZhipuConfig(**zhipu_kwargs) return ZhipuVendor(cfg, mapper, failover_cfg) case "minimax": cfg = MinimaxConfig( diff --git a/src/coding/proxy/vendors/concurrency.py b/src/coding/proxy/vendors/concurrency.py new file mode 100644 index 0000000..b4f4df7 --- /dev/null +++ b/src/coding/proxy/vendors/concurrency.py @@ -0,0 +1,78 @@ +"""每模型并发限制器 — 基于 asyncio.Semaphore 的公平排队. + +为每个映射后的模型(如 ``glm-5v-turbo``)独立维护一个 ``asyncio.Semaphore``, +确保同一时间点该模型的并行请求数不超过配置的上限。当所有槽位被占满时, +新请求按 FIFO 顺序排队等待,直到有槽位释放。 + +设计要点: + - **惰性创建**:仅在首次请求到达时才为该模型创建 Semaphore,避免冷启动开销 + - **FIFO 公平**:``asyncio.Semaphore`` 内部使用 FIFO 队列,天然满足排队语义 + - **按映射后模型名键控**:与上游真实承载能力对齐,而非按客户端请求名(如 ``claude-sonnet-*``) +""" + +from __future__ import annotations + +import asyncio +import logging + +from ..config.vendors import ZhipuConcurrencyConfig + +logger = logging.getLogger(__name__) + + +class ModelConcurrencyLimiter: + """按模型名提供独立并发槽位的限制器. + + 用法:: + + limiter = ModelConcurrencyLimiter(config) + sem = await limiter.acquire("glm-5v-turbo") + try: + ... # 执行请求 + finally: + sem.release() + """ + + def __init__(self, config: ZhipuConcurrencyConfig) -> None: + self._config = config + self._semaphores: dict[str, asyncio.Semaphore] = {} + + def _get_semaphore(self, model: str) -> asyncio.Semaphore: + """获取(或惰性创建)指定模型的信号量.""" + sem = self._semaphores.get(model) + if sem is None: + limit = self._config.get_limit(model) + sem = asyncio.Semaphore(limit) + self._semaphores[model] = sem + logger.debug( + "ModelConcurrencyLimiter: created semaphore model=%s limit=%d", + model, + limit, + ) + return sem + + async def acquire(self, model: str) -> asyncio.Semaphore: + """获取指定模型的并发槽位,必要时阻塞排队. 
+ + 返回已获取的 Semaphore 实例,调用方负责在请求完成后调用 ``release()``。 + """ + sem = self._get_semaphore(model) + await sem.acquire() + return sem + + def get_diagnostics(self) -> dict[str, dict[str, int]]: + """返回每个模型的并发状态快照(用于可观测性).""" + snapshot: dict[str, dict[str, int]] = {} + for model, sem in self._semaphores.items(): + limit = self._config.get_limit(model) + # asyncio.Semaphore 内部 _value 表示剩余可用槽位 + available = sem._value # noqa: SLF001 — 公开 API 未暴露 + snapshot[model] = { + "limit": limit, + "in_use": max(limit - available, 0), + "available": max(available, 0), + } + return snapshot + + +__all__ = ["ModelConcurrencyLimiter"] diff --git a/src/coding/proxy/vendors/zhipu.py b/src/coding/proxy/vendors/zhipu.py index e7ed8c7..ff186cd 100644 --- a/src/coding/proxy/vendors/zhipu.py +++ b/src/coding/proxy/vendors/zhipu.py @@ -34,6 +34,7 @@ ) from ..routing.retry import RetryConfig, calculate_delay from .base import VendorResponse +from .concurrency import ModelConcurrencyLimiter from .native_anthropic import NativeAnthropicVendor logger = logging.getLogger(__name__) @@ -68,6 +69,12 @@ def __init__( ) -> None: super().__init__(config, model_mapper, failover_config) self._rl_retry = _RATE_LIMIT_RETRY + # 每模型并发限制器(config.concurrency 为 None 时禁用) + self._concurrency_limiter: ModelConcurrencyLimiter | None = ( + ModelConcurrencyLimiter(config.concurrency) + if config.concurrency is not None + else None + ) # ── 非流式:429 重试 ──────────────────────────────────── @@ -76,7 +83,24 @@ async def send_message( request_body: dict[str, Any], headers: dict[str, str], ) -> VendorResponse: - """非流式请求,429 时自动重试.""" + """非流式请求,429 时自动重试. 
+ + 在 429 重试循环外层套上每模型并发槽位获取,确保同一时间点同一模型的 + 在途请求数不超过配置上限;超过时新请求 FIFO 排队等待。 + """ + sem = await self._maybe_acquire_concurrency_slot(request_body) + try: + return await self._send_message_with_retry(request_body, headers) + finally: + if sem is not None: + sem.release() + + async def _send_message_with_retry( + self, + request_body: dict[str, Any], + headers: dict[str, str], + ) -> VendorResponse: + """原 send_message 主体逻辑(不含并发控制).""" max_attempts = self._rl_retry.max_attempts for attempt in range(max_attempts): @@ -116,42 +140,71 @@ async def send_message_stream( 安全性:429 在 BaseVendor.send_message_stream 中于 status code 检查阶段即 raise(在任何 chunk yield 之前), 因此重试不会导致已发出数据不一致。 + + 在 429 重试循环外层套上每模型并发槽位获取,确保流式请求与非流式请求 + 共用同一信号量,统一限制同一模型的总在途并发数。 """ + sem = await self._maybe_acquire_concurrency_slot(request_body) max_attempts = self._rl_retry.max_attempts - for attempt in range(max_attempts): - try: - # 429 在 status code 检查阶段即 raise(在任何 chunk 之前), - # 因此 __anext__ 安全:要么拿到首个 chunk,要么抛异常。 - ait = super().send_message_stream(request_body, headers) - head = await ait.__anext__() - except StopAsyncIteration: - return - except httpx.HTTPStatusError as exc: - if exc.response is None or exc.response.status_code != 429: - raise - if attempt == max_attempts - 1: - logger.warning( - "Zhipu 429 stream rate limit exhausted after %d attempts", - max_attempts, + try: + for attempt in range(max_attempts): + try: + # 429 在 status code 检查阶段即 raise(在任何 chunk 之前), + # 因此 __anext__ 安全:要么拿到首个 chunk,要么抛异常。 + ait = super().send_message_stream(request_body, headers) + head = await ait.__anext__() + except StopAsyncIteration: + return + except httpx.HTTPStatusError as exc: + if exc.response is None or exc.response.status_code != 429: + raise + if attempt == max_attempts - 1: + logger.warning( + "Zhipu 429 stream rate limit exhausted after %d attempts", + max_attempts, + ) + raise + + delay = self._compute_retry_delay_from_response( + exc.response, attempt ) - raise - - delay = 
self._compute_retry_delay_from_response(exc.response, attempt) - logger.info( - "Zhipu 429 stream rate limit, retry %d/%d in %.1fms", - attempt + 1, - max_attempts - 1, - delay, - ) - await asyncio.sleep(delay / 1000.0) - continue - - # yield 在 try/except 之外,避免捕获外部 athrow 的异常 - yield head - async for chunk in ait: - yield chunk - return + logger.info( + "Zhipu 429 stream rate limit, retry %d/%d in %.1fms", + attempt + 1, + max_attempts - 1, + delay, + ) + await asyncio.sleep(delay / 1000.0) + continue + + # yield 在 try/except 之外,避免捕获外部 athrow 的异常 + yield head + async for chunk in ait: + yield chunk + return + finally: + if sem is not None: + sem.release() + + # ── 并发控制 ──────────────────────────────────────────── + + async def _maybe_acquire_concurrency_slot( + self, + request_body: dict[str, Any], + ) -> asyncio.Semaphore | None: + """按映射后模型名获取并发槽位;未配置 concurrency 时返回 None. + + ``map_model()`` 是纯同步字典查找,在 Semaphore 等待前调用是安全的, + 且能确保排队键与上游真实承载模型对齐。 + """ + if self._concurrency_limiter is None: + return None + raw_model = request_body.get("model", "") if request_body else "" + mapped_model = self.map_model(raw_model) if raw_model else "" + if not mapped_model: + return None + return await self._concurrency_limiter.acquire(mapped_model) # ── 延迟计算 ──────────────────────────────────────────── diff --git a/tests/test_zhipu_concurrency.py b/tests/test_zhipu_concurrency.py new file mode 100644 index 0000000..3c8a97d --- /dev/null +++ b/tests/test_zhipu_concurrency.py @@ -0,0 +1,557 @@ +"""Zhipu 每模型并发限制专项测试. 
+ +验证 ``ModelConcurrencyLimiter`` 与 ``ZhipuVendor`` 集成后的并发控制行为: + - 默认 ``concurrency.default=3`` 时同一模型最多 3 个并发 + - 超出上限时按 FIFO 排队,槽位释放后才唤醒 + - 不同模型彼此独立,互不阻塞 + - 异常路径下 Semaphore 仍能释放,避免泄漏 + - 流式请求与非流式请求共享同一信号量 + - 与 429 重试机制兼容(重试期间持续占用槽位) + - ``concurrency=None`` 时禁用限制(向后兼容) +""" + +from __future__ import annotations + +import asyncio +import json +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from coding.proxy.config.schema import ( + ModelMappingRule, + ZhipuConcurrencyConfig, + ZhipuConfig, +) +from coding.proxy.routing.model_mapper import ModelMapper +from coding.proxy.vendors.concurrency import ModelConcurrencyLimiter +from coding.proxy.vendors.native_anthropic import NativeAnthropicVendor +from coding.proxy.vendors.zhipu import ZhipuVendor + +# ─── 测试工具 ─────────────────────────────────────────────── + + +def _make_mapper() -> ModelMapper: + """构造标准三模型映射的 ModelMapper.""" + return ModelMapper( + [ + ModelMappingRule( + pattern="claude-sonnet-.*", + target="glm-5v-turbo", + is_regex=True, + vendors=["zhipu"], + ), + ModelMappingRule( + pattern="claude-opus-.*", + target="glm-5.1", + is_regex=True, + vendors=["zhipu"], + ), + ModelMappingRule( + pattern="claude-haiku-.*", + target="glm-4.5-air", + is_regex=True, + vendors=["zhipu"], + ), + ] + ) + + +def _make_vendor( + concurrency: ZhipuConcurrencyConfig | None = None, + api_key: str = "test-zhipu-key", +) -> ZhipuVendor: + """构造一个 ZhipuVendor,默认启用并发限制(default=3).""" + cfg_kwargs: dict = {"api_key": api_key} + if concurrency is not None: + cfg_kwargs["concurrency"] = concurrency + return ZhipuVendor(ZhipuConfig(**cfg_kwargs), _make_mapper()) + + +def _make_200_response() -> httpx.Response: + body = json.dumps( + { + "id": "msg_test", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "ok"}], + "model": "glm-5.1", + "usage": {"input_tokens": 1, "output_tokens": 1}, + } + ).encode() + return httpx.Response( + status_code=200, + content=body, + 
headers={"content-type": "application/json"}, + request=httpx.Request( + "POST", "https://open.bigmodel.cn/api/anthropic/v1/messages" + ), + ) + + +def _make_429_response() -> httpx.Response: + return httpx.Response( + status_code=429, + content=b'{"error":{"type":"rate_limit_error","message":"slow down"}}', + headers={}, + request=httpx.Request( + "POST", "https://open.bigmodel.cn/api/anthropic/v1/messages" + ), + ) + + +# ─── 配置层测试 ───────────────────────────────────────────── + + +class TestZhipuConcurrencyConfig: + """ZhipuConcurrencyConfig 配置模型行为.""" + + def test_defaults(self) -> None: + cfg = ZhipuConcurrencyConfig() + assert cfg.default == 3 + assert cfg.models == {} + + def test_get_limit_falls_back_to_default(self) -> None: + cfg = ZhipuConcurrencyConfig(default=5) + assert cfg.get_limit("glm-5.1") == 5 + assert cfg.get_limit("any-unknown-model") == 5 + + def test_get_limit_uses_per_model_override(self) -> None: + cfg = ZhipuConcurrencyConfig(default=3, models={"glm-5v-turbo": 1}) + assert cfg.get_limit("glm-5v-turbo") == 1 + assert cfg.get_limit("glm-5.1") == 3 # 未覆盖时回退 default + + def test_default_must_be_positive(self) -> None: + with pytest.raises(ValueError): + ZhipuConcurrencyConfig(default=0) + + def test_zhipu_config_default_concurrency(self) -> None: + cfg = ZhipuConfig() + assert cfg.concurrency is not None + assert cfg.concurrency.default == 3 + + +# ─── ModelConcurrencyLimiter 单元测试 ────────────────────── + + +class TestModelConcurrencyLimiter: + """ModelConcurrencyLimiter 基础行为.""" + + @pytest.mark.asyncio + async def test_lazy_semaphore_creation(self) -> None: + limiter = ModelConcurrencyLimiter(ZhipuConcurrencyConfig(default=2)) + sem_a = limiter._get_semaphore("model-a") + sem_b = limiter._get_semaphore("model-b") + # 不同模型独立 semaphore + assert sem_a is not sem_b + # 相同模型复用 semaphore + assert limiter._get_semaphore("model-a") is sem_a + + @pytest.mark.asyncio + async def test_acquire_blocks_when_full(self) -> None: + limiter = 
ModelConcurrencyLimiter(ZhipuConcurrencyConfig(default=2)) + + # 占满 2 个槽位 + sem1 = await limiter.acquire("glm-5.1") + sem2 = await limiter.acquire("glm-5.1") + assert sem1 is sem2 # 同一 semaphore + + # 第 3 次 acquire 必须阻塞 + task = asyncio.create_task(limiter.acquire("glm-5.1")) + await asyncio.sleep(0.05) + assert not task.done(), "第三个请求应在排队等待" + + # 释放一个槽位后,等待者被唤醒 + sem1.release() + await asyncio.sleep(0.05) + assert task.done() + (await task).release() + sem2.release() + + @pytest.mark.asyncio + async def test_per_model_independent(self) -> None: + limiter = ModelConcurrencyLimiter( + ZhipuConcurrencyConfig(default=1, models={"glm-5.1": 1}) + ) + # 占满 glm-5.1 + sem_51 = await limiter.acquire("glm-5.1") + # glm-5v-turbo 仍可立即获取 + sem_5v = await asyncio.wait_for(limiter.acquire("glm-5v-turbo"), timeout=0.5) + assert sem_51 is not sem_5v + sem_51.release() + sem_5v.release() + + def test_diagnostics_snapshot(self) -> None: + limiter = ModelConcurrencyLimiter(ZhipuConcurrencyConfig(default=3)) + # 触发 semaphore 创建 + limiter._get_semaphore("glm-5.1") + snap = limiter.get_diagnostics() + assert "glm-5.1" in snap + assert snap["glm-5.1"]["limit"] == 3 + assert snap["glm-5.1"]["available"] == 3 + assert snap["glm-5.1"]["in_use"] == 0 + + +# ─── ZhipuVendor 集成测试:非流式 ──────────────────────────── + + +class TestZhipuVendorNonStreamConcurrency: + """非流式 send_message 的并发限制行为.""" + + @pytest.mark.asyncio + async def test_limits_parallel_requests(self) -> None: + """concurrency.default=2 时,3 个并发请求中只有 2 个同时执行.""" + vendor = _make_vendor(ZhipuConcurrencyConfig(default=2)) + active = 0 + peak = 0 + gate = asyncio.Event() + + async def mock_post(*_, **__) -> httpx.Response: + nonlocal active, peak + active += 1 + peak = max(peak, active) + # 等待外部释放,保证并发观测窗口 + await gate.wait() + active -= 1 + return _make_200_response() + + with patch.object(vendor, "_get_client") as mock_client: + client = AsyncMock() + client.post = mock_post + mock_client.return_value = client + + tasks = [ + 
asyncio.create_task( + vendor.send_message( + {"model": "claude-opus-4-6", "messages": []}, + {}, + ) + ) + for _ in range(3) + ] + # 等待两个请求进入 active 状态 + for _ in range(40): + if active >= 2: + break + await asyncio.sleep(0.01) + + assert active == 2, "应有恰好 2 个请求在执行(第 3 个排队)" + gate.set() + results = await asyncio.gather(*tasks) + assert all(r.status_code == 200 for r in results) + assert peak == 2, "并发峰值不应超过 2" + + @pytest.mark.asyncio + async def test_per_model_independent(self) -> None: + """不同模型的槽位互不影响.""" + cfg = ZhipuConcurrencyConfig( + default=3, + models={"glm-5v-turbo": 1, "glm-5.1": 1}, + ) + vendor = _make_vendor(cfg) + gate = asyncio.Event() + seen_models: list[str] = [] + + async def mock_post(*_args, **kwargs) -> httpx.Response: + body = kwargs.get("json", {}) + seen_models.append(body.get("model", "")) + await gate.wait() + return _make_200_response() + + with patch.object(vendor, "_get_client") as mock_client: + client = AsyncMock() + client.post = mock_post + mock_client.return_value = client + + # claude-opus → glm-5.1, claude-sonnet → glm-5v-turbo, + # 分属两个独立信号量,应同时执行 + task_opus = asyncio.create_task( + vendor.send_message( + {"model": "claude-opus-4-6", "messages": []}, + {}, + ) + ) + task_sonnet = asyncio.create_task( + vendor.send_message( + {"model": "claude-sonnet-4-6", "messages": []}, + {}, + ) + ) + for _ in range(40): + if len(seen_models) >= 2: + break + await asyncio.sleep(0.01) + + assert len(seen_models) == 2, "两个不同模型应并发执行" + assert set(seen_models) == {"glm-5.1", "glm-5v-turbo"} + gate.set() + await asyncio.gather(task_opus, task_sonnet) + + @pytest.mark.asyncio + async def test_semaphore_released_on_exception(self) -> None: + """上游抛异常时 Semaphore 仍应释放,后续请求不阻塞.""" + vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) + call_count = 0 + + async def mock_post(*_, **__) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("upstream boom") + return _make_200_response() + + with 
patch.object(vendor, "_get_client") as mock_client: + client = AsyncMock() + client.post = mock_post + mock_client.return_value = client + + with pytest.raises(RuntimeError): + await vendor.send_message( + {"model": "claude-opus-4-6", "messages": []}, + {}, + ) + + # 槽位应已释放,第二次请求可正常完成 + resp = await asyncio.wait_for( + vendor.send_message( + {"model": "claude-opus-4-6", "messages": []}, + {}, + ), + timeout=1.0, + ) + assert resp.status_code == 200 + + @pytest.mark.asyncio + async def test_429_retry_holds_slot(self) -> None: + """429 重试期间持续占用槽位,重试结束后释放.""" + vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) + call_count = 0 + + async def mock_post(*_, **__) -> httpx.Response: + nonlocal call_count + call_count += 1 + if call_count <= 2: + return _make_429_response() + return _make_200_response() + + with ( + patch.object(vendor, "_get_client") as mock_client, + patch("asyncio.sleep", new_callable=AsyncMock), + ): + client = AsyncMock() + client.post = mock_post + mock_client.return_value = client + + resp = await vendor.send_message( + {"model": "claude-opus-4-6", "messages": []}, + {}, + ) + assert resp.status_code == 200 + assert call_count == 3 # 两次 429 + 一次成功,且共用同一槽位 + + @pytest.mark.asyncio + async def test_no_concurrency_when_config_is_none(self) -> None: + """concurrency=None 时禁用并发限制,行为与旧版完全一致.""" + # 强制构造一个 concurrency=None 的 ZhipuConfig(绕过默认工厂) + cfg = ZhipuConfig(api_key="key") + cfg = cfg.model_copy(update={"concurrency": None}) + vendor = ZhipuVendor(cfg, _make_mapper()) + assert vendor._concurrency_limiter is None + + gate = asyncio.Event() + active = 0 + peak = 0 + + async def mock_post(*_, **__) -> httpx.Response: + nonlocal active, peak + active += 1 + peak = max(peak, active) + await gate.wait() + active -= 1 + return _make_200_response() + + with patch.object(vendor, "_get_client") as mock_client: + client = AsyncMock() + client.post = mock_post + mock_client.return_value = client + + tasks = [ + asyncio.create_task( + vendor.send_message( 
+ {"model": "claude-opus-4-6", "messages": []}, + {}, + ) + ) + for _ in range(5) + ] + for _ in range(40): + if active >= 5: + break + await asyncio.sleep(0.01) + + assert peak == 5, "无并发限制时应全部并行" + gate.set() + await asyncio.gather(*tasks) + + +# ─── ZhipuVendor 集成测试:流式 ────────────────────────────── + + +class TestZhipuVendorStreamConcurrency: + """流式 send_message_stream 的并发限制行为.""" + + @pytest.mark.asyncio + async def test_stream_limits_parallel_requests(self) -> None: + """流式请求遵循并发限制,超出排队等待.""" + vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) + active = 0 + peak = 0 + gate = asyncio.Event() + + async def fake_stream(self, _body, _headers): # noqa: ARG001 + nonlocal active, peak + active += 1 + peak = max(peak, active) + try: + await gate.wait() + yield b'data: {"type":"message_start"}\n\n' + finally: + active -= 1 + + async def consume(model: str) -> int: + chunks: list[bytes] = [] + async for chunk in vendor.send_message_stream( + {"model": model, "messages": []}, {} + ): + chunks.append(chunk) + return len(chunks) + + with patch.object(NativeAnthropicVendor, "send_message_stream", fake_stream): + tasks = [asyncio.create_task(consume("claude-opus-4-6")) for _ in range(3)] + for _ in range(40): + if active >= 1: + break + await asyncio.sleep(0.01) + + assert active == 1, "concurrency=1 时只允许 1 个流式请求并发" + gate.set() + results = await asyncio.gather(*tasks) + assert all(c >= 1 for c in results) + assert peak == 1 + + @pytest.mark.asyncio + async def test_stream_releases_slot_on_completion(self) -> None: + """流式生成器正常耗尽后槽位释放.""" + vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) + + async def fake_stream(self, _body, _headers): # noqa: ARG001 + yield b'data: {"type":"message_start"}\n\n' + yield b'data: {"type":"message_stop"}\n\n' + + with patch.object(NativeAnthropicVendor, "send_message_stream", fake_stream): + # 连续两次流式请求都能完成(说明槽位被释放) + for _ in range(2): + chunks = [] + async for chunk in vendor.send_message_stream( + {"model": 
"claude-opus-4-6", "messages": []}, {} + ): + chunks.append(chunk) + assert len(chunks) == 2 + + # 确认 semaphore 当前完全可用 + assert vendor._concurrency_limiter is not None + sem = vendor._concurrency_limiter._get_semaphore("glm-5.1") + assert sem._value == 1 # noqa: SLF001 + + @pytest.mark.asyncio + async def test_stream_releases_slot_on_error(self) -> None: + """流式请求异常退出时槽位仍释放,后续请求不被阻塞.""" + vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) + call_count = 0 + + async def fake_stream(self, _body, _headers): # noqa: ARG001 + nonlocal call_count + call_count += 1 + if call_count == 1: + resp = httpx.Response( + status_code=500, + content=b'{"error":{"type":"api_error"}}', + request=httpx.Request("POST", "https://example.com"), + ) + raise httpx.HTTPStatusError("500", request=resp.request, response=resp) + yield b"" # 让函数成为 async generator(不可达) + yield b'data: {"type":"message_start"}\n\n' + + with patch.object(NativeAnthropicVendor, "send_message_stream", fake_stream): + with pytest.raises(httpx.HTTPStatusError): + async for _ in vendor.send_message_stream( + {"model": "claude-opus-4-6", "messages": []}, {} + ): + pass + + # 槽位应已释放,第二次请求可正常推进 + chunks = [] + async for chunk in vendor.send_message_stream( + {"model": "claude-opus-4-6", "messages": []}, {} + ): + chunks.append(chunk) + assert chunks == [b'data: {"type":"message_start"}\n\n'] + + @pytest.mark.asyncio + async def test_stream_and_nonstream_share_semaphore(self) -> None: + """流式与非流式请求共用同一信号量(按映射后模型分组).""" + vendor = _make_vendor(ZhipuConcurrencyConfig(default=1)) + gate = asyncio.Event() + active = 0 + + async def fake_stream(self, _body, _headers): # noqa: ARG001 + nonlocal active + active += 1 + try: + await gate.wait() + yield b'data: {"type":"message_start"}\n\n' + finally: + active -= 1 + + async def mock_post(*_, **__) -> httpx.Response: + nonlocal active + active += 1 + active -= 1 + return _make_200_response() + + with ( + patch.object(NativeAnthropicVendor, "send_message_stream", fake_stream), 
+ patch.object(vendor, "_get_client") as mock_client, + ): + client = AsyncMock() + client.post = mock_post + mock_client.return_value = client + + # 启动流式请求并等待它占用槽位 + async def consume_stream() -> None: + async for _ in vendor.send_message_stream( + {"model": "claude-opus-4-6", "messages": []}, {} + ): + pass + + stream_task = asyncio.create_task(consume_stream()) + for _ in range(40): + if active >= 1: + break + await asyncio.sleep(0.01) + assert active == 1 + + # 非流式请求应被同一信号量阻塞 + nonstream_task = asyncio.create_task( + vendor.send_message( + {"model": "claude-opus-4-6", "messages": []}, + {}, + ) + ) + await asyncio.sleep(0.05) + assert not nonstream_task.done(), "非流式请求应等待流式释放槽位" + + # 释放后两者都能完成 + gate.set() + await asyncio.gather(stream_task, nonstream_task) From bcc2a68c9b64228f9dc5b79ce753d672048dbdcc Mon Sep 17 00:00:00 2001 From: ThreeFish Date: Tue, 26 May 2026 00:16:53 +0800 Subject: [PATCH 2/2] =?UTF-8?q?fix(zhipu):=20=E5=B0=86=20concurrency=20?= =?UTF-8?q?=E6=B3=A8=E5=86=8C=E5=88=B0=20=5FZHIPU=5FFIELDS=20=E7=8B=AC?= =?UTF-8?q?=E5=8D=A0=E5=AD=97=E6=AE=B5=E9=9B=86=E5=90=88=EF=BC=8C=E7=A1=AE?= =?UTF-8?q?=E4=BF=9D=E9=9D=9E=20zhipu=20=E4=BE=9B=E5=BA=94=E5=95=86?= =?UTF-8?q?=E8=AF=AF=E9=85=8D=E6=97=B6=E8=A7=A6=E5=8F=91=20warning;?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 🤖 Generated with [Claude Code](https://github.com/claude), [CodeX](https://openai.com), [Gemini](https://github.com/apps/gemini-code-assist) Co-Authored-By: Aurelius Huang --- src/coding/proxy/config/routing.py | 6 +++--- tests/test_schema.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/coding/proxy/config/routing.py b/src/coding/proxy/config/routing.py index d0b2d48..2c29363 100644 --- a/src/coding/proxy/config/routing.py +++ b/src/coding/proxy/config/routing.py @@ -65,13 +65,13 @@ def _detect_currency(v: Any) -> str | None: "api_key", } ) -# 向后兼容别名 -_ZHIPU_FIELDS = _NATIVE_ANTHROPIC_FIELDS +# Zhipu 独占字段:在通用 
api_key 基础上增加每模型并发限制 +_ZHIPU_FIELDS: frozenset[str] = _NATIVE_ANTHROPIC_FIELDS | frozenset({"concurrency"}) _VENDOR_EXCLUSIVE_FIELDS: dict[str, frozenset[str]] = { "copilot": _COPILOT_FIELDS, "antigravity": _ANTIGRAVITY_FIELDS, - "zhipu": _NATIVE_ANTHROPIC_FIELDS, + "zhipu": _ZHIPU_FIELDS, "minimax": _NATIVE_ANTHROPIC_FIELDS, "kimi": _NATIVE_ANTHROPIC_FIELDS, "doubao": _NATIVE_ANTHROPIC_FIELDS, diff --git a/tests/test_schema.py b/tests/test_schema.py index ae7120e..30d691c 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -31,7 +31,8 @@ def test_antigravity_fields_set(): def test_zhipu_fields_set(): assert "api_key" in _ZHIPU_FIELDS - assert len(_ZHIPU_FIELDS) == 1 + assert "concurrency" in _ZHIPU_FIELDS + assert len(_ZHIPU_FIELDS) == 2 def test_vendor_exclusive_fields_mapping_complete():