22import multiprocessing
33import socket
44from collections .abc import AsyncGenerator , Generator
5- from typing import Any
5+ from typing import Any , cast
66from unittest .mock import AsyncMock , MagicMock , Mock , patch
77from urllib .parse import urlparse
88
2020import mcp .client .sse
2121from mcp import types
2222from mcp .client .session import ClientSession
23- from mcp .client .sse import _extract_session_id_from_endpoint , sse_client
23+ from mcp .client .sse import _extract_session_id_from_endpoint , _resolve_endpoint_url , sse_client
2424from mcp .server import Server , ServerRequestContext
2525from mcp .server .sse import SseServerTransport
2626from mcp .server .transport_security import TransportSecuritySettings
2727from mcp .shared .exceptions import MCPError
28+ from mcp .shared .message import SessionMessage
2829from mcp .types import (
2930 CallToolRequestParams ,
3031 CallToolResult ,
@@ -229,6 +230,50 @@ def test_extract_session_id_from_endpoint(endpoint_url: str, expected: str | Non
229230 assert _extract_session_id_from_endpoint (endpoint_url ) == expected
230231
231232
233+ @pytest .mark .parametrize (
234+ ("sse_url" , "endpoint_data" , "messages_url" , "expected" ),
235+ [
236+ (
237+ "https://example.com/api/v1/sse" ,
238+ "/v1/messages/?session_id=abc123" ,
239+ None ,
240+ "https://example.com/v1/messages/?session_id=abc123" ,
241+ ),
242+ (
243+ "https://example.com/api/v1/sse" ,
244+ "/v1/messages/?session_id=abc123" ,
245+ "https://example.com/api/v1/messages/" ,
246+ "https://example.com/api/v1/messages/?session_id=abc123" ,
247+ ),
248+ (
249+ "https://example.com/api/v1/sse" ,
250+ "/v1/messages/?session_id=abc123" ,
251+ "/api/v1/messages/" ,
252+ "https://example.com/api/v1/messages/?session_id=abc123" ,
253+ ),
254+ (
255+ "https://example.com/api/v1/sse" ,
256+ "/v1/messages/?session_id=abc123" ,
257+ "https://example.com/api/v1/messages/?tenant=blue" ,
258+ "https://example.com/api/v1/messages/?tenant=blue&session_id=abc123" ,
259+ ),
260+ (
261+ "https://example.com/api/v1/sse" ,
262+ "/v1/messages/" ,
263+ "https://example.com/api/v1/messages/" ,
264+ "https://example.com/api/v1/messages/" ,
265+ ),
266+ ],
267+ )
268+ def test_resolve_endpoint_url_with_messages_url_override (
269+ sse_url : str ,
270+ endpoint_data : str ,
271+ messages_url : str | None ,
272+ expected : str ,
273+ ) -> None :
274+ assert _resolve_endpoint_url (sse_url , endpoint_data , messages_url ) == expected
275+
276+
232277@pytest .mark .anyio
233278async def test_sse_client_on_session_created_not_called_when_no_session_id (
234279 server : None , server_url : str , monkeypatch : pytest .MonkeyPatch
@@ -249,6 +294,50 @@ def mock_extract(url: str) -> None:
249294 callback_mock .assert_not_called ()
250295
251296
297+ @pytest .mark .anyio
298+ async def test_sse_client_uses_messages_url_override () -> None :
299+ async def mock_aiter_sse () -> AsyncGenerator [ServerSentEvent , None ]:
300+ yield ServerSentEvent (event = "endpoint" , data = "/v1/messages/?session_id=abc123" )
301+ await anyio .sleep_forever ()
302+
303+ mock_event_source = MagicMock ()
304+ mock_event_source .aiter_sse .return_value = mock_aiter_sse ()
305+ mock_event_source .response = MagicMock ()
306+ mock_event_source .response .raise_for_status = MagicMock ()
307+
308+ mock_aconnect_sse = MagicMock ()
309+ mock_aconnect_sse .__aenter__ = AsyncMock (return_value = mock_event_source )
310+ mock_aconnect_sse .__aexit__ = AsyncMock (return_value = None )
311+
312+ mock_client = MagicMock ()
313+ mock_client .__aenter__ = AsyncMock (return_value = mock_client )
314+ mock_client .__aexit__ = AsyncMock (return_value = None )
315+ mock_client .post = AsyncMock (return_value = MagicMock (status_code = 200 , raise_for_status = MagicMock ()))
316+
317+ def mock_httpx_client_factory (
318+ headers : dict [str , str ] | None = None ,
319+ timeout : httpx .Timeout | None = None ,
320+ auth : httpx .Auth | None = None ,
321+ ) -> httpx .AsyncClient :
322+ _ = (headers , timeout , auth )
323+ return cast (httpx .AsyncClient , mock_client )
324+
325+ with patch ("mcp.client.sse.aconnect_sse" , return_value = mock_aconnect_sse ):
326+ async with sse_client (
327+ "https://example.com/api/v1/sse" ,
328+ httpx_client_factory = mock_httpx_client_factory ,
329+ messages_url = "https://example.com/api/v1/messages/" ,
330+ ) as (_ , write_stream ):
331+ message = types .JSONRPCRequest (jsonrpc = "2.0" , id = 1 , method = "ping" )
332+ await write_stream .send (SessionMessage (message ))
333+ with anyio .fail_after (1 ): # pragma: no branch
334+ while not mock_client .post .await_count :
335+ await anyio .sleep (0.01 )
336+
337+ mock_client .post .assert_awaited ()
338+ assert mock_client .post .await_args .args [0 ] == "https://example.com/api/v1/messages/?session_id=abc123"
339+
340+
252341@pytest .fixture
253342async def initialized_sse_client_session (server : None , server_url : str ) -> AsyncGenerator [ClientSession , None ]:
254343 async with sse_client (server_url + "/sse" , sse_read_timeout = 0.5 ) as streams :
0 commit comments