Skip to content

Commit b7b8967

Browse files
committed
Make ElicitationResult subscriptable so the documented Resolve union form works
Codex review caught that the documented Annotated[ElicitationResult[Login], Resolve(login)] form silently dropped the resolver: ElicitationResult was a collapsed union alias (not subscriptable), so under 'from __future__ import annotations' get_type_hints raised, _type_hints swallowed it, and the parameter stayed client-supplied with the resolver never running. Redefine ElicitationResult via TypeAliasType so ElicitationResult[T] is genuinely subscriptable, and teach _wants_union to unwrap the alias. Update the migration doc to use the clean ElicitationResult[T] form. Add a regression test exercising the postponed-annotations path.
1 parent 58238b1 commit b7b8967

4 files changed

Lines changed: 41 additions & 9 deletions

File tree

docs/migration.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1476,13 +1476,14 @@ async def star_repo(
14761476
return f"starred {repo} as {login.username}" if confirm.ok else "cancelled"
14771477
```
14781478

1479-
The injected type follows the consumer's annotation. Annotating the unwrapped model (`Annotated[Login, Resolve(login)]`) injects the model on accept and aborts the call with an error result on decline or cancel. To branch on the outcome instead, annotate the elicitation result union:
1479+
The injected type follows the consumer's annotation. Annotating the unwrapped model (`Annotated[Login, Resolve(login)]`) injects the model on accept and aborts the call with an error result on decline or cancel. To branch on the outcome instead, annotate `ElicitationResult[Login]` (or an explicit `AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation` union):
14801480

14811481
```python
1482+
from mcp.server.mcpserver import ElicitationResult
1483+
1484+
14821485
@mcp.tool()
1483-
async def whoami(
1484-
login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)],
1485-
) -> str:
1486+
async def whoami(login: Annotated[ElicitationResult[Login], Resolve(login)]) -> str:
14861487
match login:
14871488
case AcceptedElicitation(data=data):
14881489
return f"hi {data.username}"

src/mcp/server/elicitation.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from pydantic import BaseModel, ValidationError
88
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
99
from pydantic_core import core_schema
10+
from typing_extensions import TypeAliasType
1011

1112
from mcp.server.session import ServerSession
1213
from mcp.types import RequestId
@@ -36,7 +37,11 @@ class CancelledElicitation(BaseModel):
3637
action: Literal["cancel"] = "cancel"
3738

3839

39-
ElicitationResult = AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation
40+
ElicitationResult = TypeAliasType(
41+
"ElicitationResult",
42+
AcceptedElicitation[ElicitSchemaModelT] | DeclinedElicitation | CancelledElicitation,
43+
type_params=(ElicitSchemaModelT,),
44+
)
4045

4146

4247
class AcceptedUrlElicitation(BaseModel):

src/mcp/server/mcpserver/resolve.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,16 @@ def find_resolved_parameters(fn: Callable[..., Any]) -> dict[str, tuple[Resolve,
126126

127127

128128
def _wants_union(type_arg: Any) -> bool:
129-
"""True when `type_arg` is an `ElicitationResult` member (or a union of them)."""
129+
"""True when `type_arg` is an `ElicitationResult` member (or a union of them).
130+
131+
Handles the bare `ElicitationResult[T]` alias (a `TypeAliasType` carrying the
132+
union on `__value__`), an explicit `AcceptedElicitation[T] | ... ` union, and a
133+
single member.
134+
"""
135+
origin = get_origin(type_arg)
136+
value = getattr(origin, "__value__", None)
137+
if value is not None:
138+
type_arg = value
130139
members = get_args(type_arg) if get_origin(type_arg) is not None else (type_arg,)
131140
return any(isinstance(m, type) and issubclass(m, _ELICITATION_RESULT_MEMBERS) for m in members)
132141

tests/server/mcpserver/test_resolve.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Context,
1414
DeclinedElicitation,
1515
Elicit,
16+
ElicitationResult,
1617
MCPServer,
1718
Resolve,
1819
)
@@ -30,6 +31,10 @@ class Confirm(BaseModel):
3031
ok: bool
3132

3233

34+
async def _alias_login(ctx: Context) -> Login:
35+
return Login(username="x") # pragma: no cover - only the signature is inspected
36+
37+
3338
def _accept(content: dict[str, str | int | float | bool | list[str] | None]):
3439
async def callback(context: ClientRequestContext, params: ElicitRequestParams) -> ElicitResult:
3540
return ElicitResult(action="accept", content=content)
@@ -92,9 +97,7 @@ async def login(ctx: Context) -> Login | Elicit[Login]:
9297
return Elicit("GitHub username?", Login)
9398

9499
@mcp.tool()
95-
async def whoami(
96-
login: Annotated[AcceptedElicitation[Login] | DeclinedElicitation | CancelledElicitation, Resolve(login)],
97-
) -> str:
100+
async def whoami(login: Annotated[ElicitationResult[Login], Resolve(login)]) -> str:
98101
match login:
99102
case AcceptedElicitation(data=data):
100103
return f"hi {data.username}"
@@ -263,6 +266,20 @@ def fn(x: int) -> int:
263266
assert find_resolved_parameters(fn) == {}
264267

265268

269+
def test_elicitation_result_alias_resolves_under_postponed_annotations():
270+
# Reproduces the case where `from __future__ import annotations` stringifies
271+
# `Annotated[ElicitationResult[Login], Resolve(_alias_login)]`: the alias must be
272+
# subscriptable so the resolver is detected (not silently dropped) and the
273+
# consumer is recognized as wanting the result union.
274+
def tool(login: str) -> str:
275+
return login # pragma: no cover
276+
277+
tool.__annotations__["login"] = "Annotated[ElicitationResult[Login], Resolve(_alias_login)]"
278+
resolved = find_resolved_parameters(tool)
279+
assert "login" in resolved
280+
assert resolved["login"][1] is True # wants_union
281+
282+
266283
def test_unresolvable_resolver_param_raises_at_registration():
267284
async def login(mystery: int) -> Login:
268285
return Login(username="x") # pragma: no cover

0 commit comments

Comments
 (0)