|
| 1 | +"""Request task-augmented execution, then drive the task lifecycle via `tasks/*`.""" |
| 2 | + |
| 3 | +from typing import cast |
| 4 | + |
| 5 | +import mcp_types as types |
| 6 | + |
| 7 | +from mcp.client import Client |
| 8 | +from mcp.server.tasks import EXTENSION_ID, RELATED_TASK_META_KEY |
| 9 | +from stories._harness import Target, run_client |
| 10 | + |
| 11 | + |
| 12 | +async def main(target: Target, *, mode: str = "auto") -> None: |
| 13 | + async with Client(target, mode=mode) as client: |
| 14 | + # The extensions capability map rides `server/discover` (modern only); a legacy |
| 15 | + # connection (today's stdio) omits it, so assert it only when present. |
| 16 | + if client.server_capabilities.extensions is not None: |
| 17 | + assert client.server_capabilities.extensions == {EXTENSION_ID: {"list": {}, "cancel": {}}} |
| 18 | + |
| 19 | + # `Client` exposes only spec verbs, so task-augmented calls and the |
| 20 | + # `tasks/*` methods drop to `client.session` (see custom_methods/). The |
| 21 | + # casts satisfy the closed `ClientRequest` union; at runtime the body |
| 22 | + # only calls `.model_dump()`. |
| 23 | + session = client.session |
| 24 | + call = types.CallToolRequest( |
| 25 | + params=types.CallToolRequestParams( |
| 26 | + name="echo", arguments={"text": "async"}, task=types.TaskMetadata(ttl=60) |
| 27 | + ) |
| 28 | + ) |
| 29 | + result = await session.send_request(cast("types.ClientRequest", call), types.CallToolResult) |
| 30 | + assert result.meta is not None, result |
| 31 | + task_id = result.meta[RELATED_TASK_META_KEY]["taskId"] |
| 32 | + assert isinstance(result.content[0], types.TextContent) |
| 33 | + assert result.content[0].text == "async", result |
| 34 | + |
| 35 | + get = types.GetTaskRequest(params=types.GetTaskRequestParams(task_id=task_id)) |
| 36 | + status = await session.send_request(cast("types.ClientRequest", get), types.GetTaskResult) |
| 37 | + assert status.status == "completed", status |
| 38 | + |
| 39 | + payload_req = types.GetTaskPayloadRequest(params=types.GetTaskPayloadRequestParams(task_id=task_id)) |
| 40 | + payload = await session.send_request(cast("types.ClientRequest", payload_req), types.CallToolResult) |
| 41 | + assert isinstance(payload.content[0], types.TextContent) |
| 42 | + assert payload.content[0].text == "async", payload |
| 43 | + |
| 44 | + |
| 45 | +if __name__ == "__main__": |
| 46 | + run_client(main) |
0 commit comments