diff --git a/astrbot/core/agent/mcp_client.py b/astrbot/core/agent/mcp_client.py index b75999ea65..3898c04935 100644 --- a/astrbot/core/agent/mcp_client.py +++ b/astrbot/core/agent/mcp_client.py @@ -12,7 +12,7 @@ from tenacity import ( before_sleep_log, retry, - retry_if_exception_type, + retry_if_exception, stop_after_attempt, wait_exponential, ) @@ -92,6 +92,10 @@ } ) _STDIO_ALLOWLIST_ENV = "ASTRBOT_MCP_STDIO_ALLOWED_COMMANDS" +_MCP_RECONNECT_ERROR_MESSAGES = ( + "session terminated", + "session was terminated", +) try: import anyio @@ -110,6 +114,22 @@ ) +def _is_mcp_reconnect_error(exc: BaseException) -> bool: + try: + anyio_module = anyio + except NameError: + anyio_module = None + + closed_resource_error = getattr(anyio_module, "ClosedResourceError", None) + if isinstance(closed_resource_error, type) and isinstance( + exc, closed_resource_error + ): + return True + + message = str(exc).lower() + return any(marker in message for marker in _MCP_RECONNECT_ERROR_MESSAGES) + + def _prepare_config(config: dict) -> dict: """Prepare configuration, handle nested format""" if config.get("mcpServers"): @@ -605,7 +625,7 @@ async def call_tool_with_reconnect( """ @retry( - retry=retry_if_exception_type(anyio.ClosedResourceError), + retry=retry_if_exception(_is_mcp_reconnect_error), stop=stop_after_attempt(2), wait=wait_exponential(multiplier=1, min=1, max=3), before_sleep=before_sleep_log(logger, logging.WARNING), @@ -621,9 +641,15 @@ async def _call_with_retry(): arguments=arguments, read_timeout_seconds=read_timeout_seconds, ) - except anyio.ClosedResourceError: + except Exception as exc: + if not _is_mcp_reconnect_error(exc): + raise + logger.warning( - f"MCP tool {tool_name} call failed (ClosedResourceError), attempting to reconnect..." + "MCP tool %s call failed (%s: %s), attempting to reconnect...", + tool_name, + type(exc).__name__, + exc, ) # Attempt to reconnect await self._reconnect() diff --git a/tests/unit/test_mcp_client_reconnect.py b/tests/unit/test_mcp_client_reconnect.py new file mode 100644 index 0000000000..4ae1dcea44 --- /dev/null +++ b/tests/unit/test_mcp_client_reconnect.py @@ -0,0 +1,103 @@ +from datetime import timedelta + +import anyio +import pytest +from tenacity import wait_none + +from astrbot.core.agent import mcp_client + + +class FlakyMcpSession: + def __init__(self, first_error: Exception | None = None) -> None: + self.calls = 0 + self.first_error = first_error or RuntimeError("Session terminated") + + async def call_tool( + self, + *, + name: str, + arguments: dict, + read_timeout_seconds: timedelta, + ) -> dict[str, object]: + self.calls += 1 + if self.calls == 1: + raise self.first_error + return { + "name": name, + "arguments": arguments, + "timeout": read_timeout_seconds.total_seconds(), + } + + +@pytest.mark.parametrize( + ("error", "expected"), + [ + (RuntimeError("Session terminated"), True), + (RuntimeError("SESSION TERMINATED"), True), + (RuntimeError("session was terminated"), True), + (anyio.ClosedResourceError(), True), + (RuntimeError("business flow terminated normally"), False), + (RuntimeError("terminated"), False), + ], +) +def test_mcp_reconnect_error_detection_is_narrow( + error: BaseException, expected: bool +) -> None: + assert mcp_client._is_mcp_reconnect_error(error) is expected + + +@pytest.mark.asyncio +async def test_call_tool_reconnects_on_session_terminated(monkeypatch) -> None: + monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none()) + + client = mcp_client.MCPClient() + session = FlakyMcpSession() + reconnects = 0 + + async def reconnect() -> None: + nonlocal reconnects + reconnects += 1 + client.session = session + + client.session = session + client._reconnect = reconnect + + result = await client.call_tool_with_reconnect( + tool_name="lookup", + arguments={"url": "https://example.com"}, + read_timeout_seconds=timedelta(seconds=5), + ) + + assert result == { + "name": "lookup", + "arguments": {"url": "https://example.com"}, + "timeout": 5.0, + } + assert session.calls == 2 + assert reconnects == 1 + + +@pytest.mark.asyncio +async def test_call_tool_does_not_reconnect_on_business_error(monkeypatch) -> None: + monkeypatch.setattr(mcp_client, "wait_exponential", lambda **_: wait_none()) + + client = mcp_client.MCPClient() + session = FlakyMcpSession(first_error=ValueError("business logic failed")) + reconnects = 0 + + async def reconnect() -> None: + nonlocal reconnects + reconnects += 1 + + client.session = session + client._reconnect = reconnect + + with pytest.raises(ValueError, match="business logic failed"): + await client.call_tool_with_reconnect( + tool_name="lookup", + arguments={"url": "https://example.com"}, + read_timeout_seconds=timedelta(seconds=5), + ) + + assert session.calls == 1 + assert reconnects == 0