diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 43ebc8e25..688dbeccd 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -5,7 +5,6 @@ from a2a.server.request_handlers.default_request_handler import ( DefaultRequestHandler, ) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, @@ -40,7 +39,6 @@ def __init__(self, *args, **kwargs): __all__ = [ 'DefaultRequestHandler', 'GrpcHandler', - 'JSONRPCHandler', 'RESTHandler', 'RequestHandler', 'build_error_response', diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py deleted file mode 100644 index 06188e412..000000000 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ /dev/null @@ -1,488 +0,0 @@ -"""JSON-RPC handler for A2A server requests.""" - -import logging - -from collections.abc import AsyncIterable, Awaitable, Callable -from typing import Any - -from google.protobuf.json_format import MessageToDict -from jsonrpc.jsonrpc2 import JSONRPC20Response - -from a2a.server.context import ServerCallContext -from a2a.server.jsonrpc_models import ( - InternalError as JSONRPCInternalError, -) -from a2a.server.jsonrpc_models import ( - JSONRPCError, -) -from a2a.server.request_handlers.request_handler import RequestHandler -from a2a.types.a2a_pb2 import ( - AgentCard, - CancelTaskRequest, - DeleteTaskPushNotificationConfigRequest, - GetExtendedAgentCardRequest, - GetTaskPushNotificationConfigRequest, - GetTaskRequest, - ListTaskPushNotificationConfigsRequest, - ListTasksRequest, - SendMessageRequest, - SendMessageResponse, - SubscribeToTaskRequest, - Task, - TaskPushNotificationConfig, -) -from a2a.utils import constants, proto_utils -from a2a.utils.errors import ( - JSON_RPC_ERROR_CODE_MAP, - A2AError, - ContentTypeNotSupportedError, - ExtendedAgentCardNotConfiguredError, - ExtensionSupportRequiredError, - InternalError, - InvalidAgentResponseError, - InvalidParamsError, - InvalidRequestError, - MethodNotFoundError, - PushNotificationNotSupportedError, - TaskNotCancelableError, - TaskNotFoundError, - UnsupportedOperationError, - VersionNotSupportedError, -) -from a2a.utils.helpers import ( - maybe_await, - validate, - validate_version, -) -from a2a.utils.telemetry import SpanKind, trace_class - - -logger = logging.getLogger(__name__) - - -EXCEPTION_MAP: dict[type[A2AError], type[JSONRPCError]] = { - TaskNotFoundError: JSONRPCError, - TaskNotCancelableError: JSONRPCError, - PushNotificationNotSupportedError: JSONRPCError, - UnsupportedOperationError: JSONRPCError, - ContentTypeNotSupportedError: JSONRPCError, - InvalidAgentResponseError: JSONRPCError, - ExtendedAgentCardNotConfiguredError: JSONRPCError, - InternalError: JSONRPCInternalError, - InvalidParamsError: JSONRPCError, - InvalidRequestError: JSONRPCError, - MethodNotFoundError: JSONRPCError, - ExtensionSupportRequiredError: JSONRPCError, - VersionNotSupportedError: JSONRPCError, -} - - -def _build_success_response( - request_id: str | int | None, result: Any -) -> dict[str, Any]: - """Build a JSON-RPC success response dict.""" - return JSONRPC20Response(result=result, _id=request_id).data - - -def _build_error_response( - request_id: str | int | None, error: Exception -) -> dict[str, Any]: - """Build a JSON-RPC error response dict.""" - jsonrpc_error: JSONRPCError - if isinstance(error, A2AError): - error_type = type(error) - model_class = EXCEPTION_MAP.get(error_type, JSONRPCInternalError) - code = JSON_RPC_ERROR_CODE_MAP.get(error_type, -32603) - jsonrpc_error = model_class( - code=code, - message=str(error), - ) - else: - jsonrpc_error = JSONRPCInternalError(message=str(error)) - - error_dict = jsonrpc_error.model_dump(exclude_none=True) - return JSONRPC20Response(error=error_dict, _id=request_id).data - - -@trace_class(kind=SpanKind.SERVER) -class JSONRPCHandler: - """Maps incoming JSON-RPC requests to the appropriate request handler method and formats responses.""" - - def __init__( - self, - agent_card: AgentCard, - request_handler: RequestHandler, - extended_agent_card: AgentCard | None = None, - extended_card_modifier: Callable[ - [AgentCard, ServerCallContext], Awaitable[AgentCard] | AgentCard - ] - | None = None, - card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] - | None = None, - ): - """Initializes the JSONRPCHandler. - - Args: - agent_card: The AgentCard describing the agent's capabilities. - request_handler: The underlying `RequestHandler` instance to delegate requests to. - extended_agent_card: An optional, distinct Extended AgentCard to be served - extended_card_modifier: An optional callback to dynamically modify - the extended agent card before it is served. It receives the - call context. - card_modifier: An optional callback to dynamically modify the public - agent card before it is served. - """ - self.agent_card = agent_card - self.request_handler = request_handler - self.extended_agent_card = extended_agent_card - self.extended_card_modifier = extended_card_modifier - self.card_modifier = card_modifier - - def _get_request_id( - self, context: ServerCallContext | None - ) -> str | int | None: - """Get the JSON-RPC request ID from the context.""" - if context is None: - return None - return context.state.get('request_id') - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_message_send( - self, - request: SendMessageRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'message/send' JSON-RPC method. - - Args: - request: The incoming `SendMessageRequest` proto message. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task_or_message = await self.request_handler.on_message_send( - request, context - ) - if isinstance(task_or_message, Task): - response = SendMessageResponse(task=task_or_message) - else: - response = SendMessageResponse(message=task_or_message) - - result = MessageToDict(response) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_message_send_stream( - self, - request: SendMessageRequest, - context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: - """Handles the 'message/stream' JSON-RPC method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `SendMessageRequest` object (for streaming). - context: Context provided by the server. - - Yields: - Dict representations of JSON-RPC responses containing streaming events. - """ - try: - async for event in self.request_handler.on_message_send_stream( - request, context - ): - # Wrap the event in StreamResponse for consistent client parsing - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield _build_success_response( - self._get_request_id(context), result - ) - except A2AError as e: - yield _build_error_response( - self._get_request_id(context), - e, - ) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_cancel_task( - self, - request: CancelTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/cancel' JSON-RPC method. - - Args: - request: The incoming `CancelTaskRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task = await self.request_handler.on_cancel_task(request, context) - except A2AError as e: - return _build_error_response(request_id, e) - - if task: - result = MessageToDict(task, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - - return _build_error_response(request_id, TaskNotFoundError()) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.streaming, - 'Streaming is not supported by the agent', - ) - async def on_subscribe_to_task( - self, - request: SubscribeToTaskRequest, - context: ServerCallContext, - ) -> AsyncIterable[dict[str, Any]]: - """Handles the 'SubscribeToTask' JSON-RPC method. - - Yields response objects as they are produced by the underlying handler's stream. - - Args: - request: The incoming `SubscribeToTaskRequest` object. - context: Context provided by the server. - - Yields: - Dict representations of JSON-RPC responses containing streaming events. - """ - try: - async for event in self.request_handler.on_subscribe_to_task( - request, context - ): - # Wrap the event in StreamResponse for consistent client parsing - stream_response = proto_utils.to_stream_response(event) - result = MessageToDict( - stream_response, preserving_proto_field_name=False - ) - yield _build_success_response( - self._get_request_id(context), result - ) - except A2AError as e: - yield _build_error_response( - self._get_request_id(context), - e, - ) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def get_push_notification_config( - self, - request: GetTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/get' JSON-RPC method. - - Args: - request: The incoming `GetTaskPushNotificationConfigRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - config = ( - await self.request_handler.on_get_task_push_notification_config( - request, context - ) - ) - result = MessageToDict(config, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - @validate( - lambda self: self.agent_card.capabilities.push_notifications, - 'Push notifications are not supported by the agent', - ) - async def set_push_notification_config( - self, - request: TaskPushNotificationConfig, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/set' JSON-RPC method. - - Requires the agent to support push notifications. - - Args: - request: The incoming `TaskPushNotificationConfig` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - - Raises: - UnsupportedOperationError: If push notifications are not supported by the agent - (due to the `@validate` decorator). - """ - request_id = self._get_request_id(context) - try: - # Pass the full request to the handler - result_config = await self.request_handler.on_create_task_push_notification_config( - request, context - ) - result = MessageToDict( - result_config, preserving_proto_field_name=False - ) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def on_get_task( - self, - request: GetTaskRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/get' JSON-RPC method. - - Args: - request: The incoming `GetTaskRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - task = await self.request_handler.on_get_task(request, context) - except A2AError as e: - return _build_error_response(request_id, e) - - if task: - result = MessageToDict(task, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - - return _build_error_response(request_id, TaskNotFoundError()) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def list_tasks( - self, - request: ListTasksRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/list' JSON-RPC method. - - Args: - request: The incoming `ListTasksRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - response = await self.request_handler.on_list_tasks( - request, context - ) - result = MessageToDict( - response, - preserving_proto_field_name=False, - always_print_fields_with_no_presence=True, - ) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def list_push_notification_configs( - self, - request: ListTaskPushNotificationConfigsRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'ListTaskPushNotificationConfigs' JSON-RPC method. - - Args: - request: The incoming `ListTaskPushNotificationConfigsRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - response = await self.request_handler.on_list_task_push_notification_configs( - request, context - ) - # response is a ListTaskPushNotificationConfigsResponse proto - result = MessageToDict(response, preserving_proto_field_name=False) - return _build_success_response(request_id, result) - except A2AError as e: - return _build_error_response(request_id, e) - - @validate_version(constants.PROTOCOL_VERSION_1_0) - async def delete_push_notification_config( - self, - request: DeleteTaskPushNotificationConfigRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'tasks/pushNotificationConfig/delete' JSON-RPC method. - - Args: - request: The incoming `DeleteTaskPushNotificationConfigRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - try: - await self.request_handler.on_delete_task_push_notification_config( - request, context - ) - return _build_success_response(request_id, None) - except A2AError as e: - return _build_error_response(request_id, e) - - async def get_authenticated_extended_card( - self, - request: GetExtendedAgentCardRequest, - context: ServerCallContext, - ) -> dict[str, Any]: - """Handles the 'agent/authenticatedExtendedCard' JSON-RPC method. - - Args: - request: The incoming `GetExtendedAgentCardRequest` object. - context: Context provided by the server. - - Returns: - A dict representing the JSON-RPC response. - """ - request_id = self._get_request_id(context) - if not self.agent_card.capabilities.extended_agent_card: - raise ExtendedAgentCardNotConfiguredError( - message='The agent does not have an extended agent card configured' - ) - - base_card = self.extended_agent_card - if base_card is None: - base_card = self.agent_card - - card_to_serve = base_card - if self.extended_card_modifier and context: - card_to_serve = await maybe_await( - self.extended_card_modifier(base_card, context) - ) - elif self.card_modifier: - card_to_serve = await maybe_await(self.card_modifier(base_card)) - - result = MessageToDict(card_to_serve, preserving_proto_field_name=False) - return _build_success_response(request_id, result) diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 1ce5f0fe8..fd7b226bb 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -9,8 +9,8 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from typing import TYPE_CHECKING, Any -from google.protobuf.json_format import ParseDict -from jsonrpc.jsonrpc2 import JSONRPC20Request +from google.protobuf.json_format import MessageToDict, ParseDict +from jsonrpc.jsonrpc2 import JSONRPC20Request, JSONRPC20Response from a2a.auth.user import UnauthenticatedUser from a2a.auth.user import User as A2AUser @@ -28,7 +28,6 @@ JSONRPCError, MethodNotFoundError, ) -from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import ( build_error_response, @@ -44,13 +43,20 @@ ListTaskPushNotificationConfigsRequest, ListTasksRequest, SendMessageRequest, + SendMessageResponse, SubscribeToTaskRequest, + Task, TaskPushNotificationConfig, ) +from a2a.utils import constants, proto_utils from a2a.utils.errors import ( A2AError, + ExtendedAgentCardNotConfiguredError, + TaskNotFoundError, UnsupportedOperationError, ) +from a2a.utils.helpers import maybe_await, validate, validate_version +from a2a.utils.telemetry import SpanKind, trace_class INTERNAL_ERROR_CODE = -32603 @@ -161,6 +167,7 @@ def build(self, request: Request) -> ServerCallContext: ) +@trace_class(kind=SpanKind.SERVER) class JsonRpcDispatcher: """Base class for A2A JSONRPC applications. @@ -189,7 +196,7 @@ class JsonRpcDispatcher: def __init__( # noqa: PLR0913 self, agent_card: AgentCard, - http_handler: RequestHandler, + request_handler: RequestHandler, extended_agent_card: AgentCard | None = None, context_builder: CallContextBuilder | None = None, card_modifier: Callable[[AgentCard], Awaitable[AgentCard] | AgentCard] @@ -204,12 +211,12 @@ def __init__( # noqa: PLR0913 Args: agent_card: The AgentCard describing the agent's capabilities. - http_handler: The handler instance responsible for processing A2A + request_handler: The handler instance responsible for processing A2A requests via http. extended_agent_card: An optional, distinct AgentCard to be served at the authenticated extended card endpoint. context_builder: The CallContextBuilder used to construct the - ServerCallContext passed to the http_handler. If None, no + ServerCallContext passed to the request_handler. If None, no ServerCallContext is passed. card_modifier: An optional callback to dynamically modify the public agent card before it is served. @@ -226,15 +233,10 @@ def __init__( # noqa: PLR0913 ) self.agent_card = agent_card + self.request_handler = request_handler self.extended_agent_card = extended_agent_card self.card_modifier = card_modifier self.extended_card_modifier = extended_card_modifier - self.handler = JSONRPCHandler( - agent_card=agent_card, - request_handler=http_handler, - extended_agent_card=extended_agent_card, - extended_card_modifier=extended_card_modifier, - ) self._context_builder = context_builder or DefaultCallContextBuilder() self.enable_v0_3_compat = enable_v0_3_compat self._v03_adapter: JSONRPC03Adapter | None = None @@ -242,7 +244,7 @@ def __init__( # noqa: PLR0913 if self.enable_v0_3_compat: self._v03_adapter = JSONRPC03Adapter( agent_card=agent_card, - http_handler=http_handler, + http_handler=request_handler, extended_agent_card=extended_agent_card, context_builder=self._context_builder, card_modifier=card_modifier, @@ -393,13 +395,20 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, # Route streaming requests by method name if method in ('SendStreamingMessage', 'SubscribeToTask'): - return await self._process_streaming_request( + handler_result = await self._process_streaming_request( request_id, specific_request, call_context ) - - return await self._process_non_streaming_request( - request_id, specific_request, call_context - ) + else: + try: + raw_result = await self._process_non_streaming_request( + request_id, specific_request, call_context + ) + handler_result = JSONRPC20Response( + result=raw_result, _id=request_id + ).data + except A2AError as e: + handler_result = build_error_response(request_id, e) + return self._create_response(call_context, handler_result) except json.decoder.JSONDecodeError as e: traceback.print_exc() return self._generate_error_response( @@ -420,12 +429,17 @@ async def handle_requests(self, request: Request) -> Response: # noqa: PLR0911, request_id, InternalError(message=str(e)) ) + @validate_version(constants.PROTOCOL_VERSION_1_0) + @validate( + lambda self: self.agent_card.capabilities.streaming, + 'Streaming is not supported by the agent', + ) async def _process_streaming_request( self, request_id: str | int | None, request_obj: A2ARequest, context: ServerCallContext, - ) -> Response: + ) -> AsyncGenerator[dict[str, Any], None]: """Processes streaming requests (SendStreamingMessage or SubscribeToTask). Args: @@ -434,30 +448,152 @@ async def _process_streaming_request( context: The ServerCallContext for the request. Returns: - An `EventSourceResponse` object to stream results to the client. + An `AsyncGenerator` object to stream results to the client. """ - handler_result: Any = None - # Check for streaming message request (same type as SendMessage, but handled differently) - if isinstance( - request_obj, - SendMessageRequest, - ): - handler_result = self.handler.on_message_send_stream( + stream: AsyncGenerator | None = None + if isinstance(request_obj, SendMessageRequest): + stream = self.request_handler.on_message_send_stream( request_obj, context ) elif isinstance(request_obj, SubscribeToTaskRequest): - handler_result = self.handler.on_subscribe_to_task( + stream = self.request_handler.on_subscribe_to_task( + request_obj, context + ) + + if stream is None: + raise UnsupportedOperationError(message='Stream not supported') + + async def _wrap_stream( + st: AsyncGenerator, + ) -> AsyncGenerator[dict[str, Any], None]: + try: + async for event in st: + stream_response = proto_utils.to_stream_response(event) + result = MessageToDict( + stream_response, preserving_proto_field_name=False + ) + yield JSONRPC20Response(result=result, _id=request_id).data + except A2AError as e: + yield build_error_response(request_id, e) + + return _wrap_stream(stream) + + async def _handle_send_message( + self, request_obj: SendMessageRequest, context: ServerCallContext + ) -> dict[str, Any]: + task_or_message = await self.request_handler.on_message_send( + request_obj, context + ) + if isinstance(task_or_message, Task): + return MessageToDict(SendMessageResponse(task=task_or_message)) + return MessageToDict(SendMessageResponse(message=task_or_message)) + + async def _handle_cancel_task( + self, request_obj: CancelTaskRequest, context: ServerCallContext + ) -> dict[str, Any]: + task = await self.request_handler.on_cancel_task(request_obj, context) + if task: + return MessageToDict(task, preserving_proto_field_name=False) + raise TaskNotFoundError + + async def _handle_get_task( + self, request_obj: GetTaskRequest, context: ServerCallContext + ) -> dict[str, Any]: + task = await self.request_handler.on_get_task(request_obj, context) + if task: + return MessageToDict(task, preserving_proto_field_name=False) + raise TaskNotFoundError + + async def _handle_list_tasks( + self, request_obj: ListTasksRequest, context: ServerCallContext + ) -> dict[str, Any]: + tasks_response = await self.request_handler.on_list_tasks( + request_obj, context + ) + return MessageToDict( + tasks_response, + preserving_proto_field_name=False, + always_print_fields_with_no_presence=True, + ) + + @validate( + lambda self: self.agent_card.capabilities.push_notifications, + 'Push notifications are not supported by the agent', + ) + async def _handle_create_task_push_notification_config( + self, + request_obj: TaskPushNotificationConfig, + context: ServerCallContext, + ) -> dict[str, Any]: + result_config = ( + await self.request_handler.on_create_task_push_notification_config( + request_obj, context + ) + ) + return MessageToDict(result_config, preserving_proto_field_name=False) + + async def _handle_get_task_push_notification_config( + self, + request_obj: GetTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + config = ( + await self.request_handler.on_get_task_push_notification_config( request_obj, context ) + ) + return MessageToDict(config, preserving_proto_field_name=False) - return self._create_response(context, handler_result) + async def _handle_list_task_push_notification_configs( + self, + request_obj: ListTaskPushNotificationConfigsRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + configs_response = ( + await self.request_handler.on_list_task_push_notification_configs( + request_obj, context + ) + ) + return MessageToDict( + configs_response, preserving_proto_field_name=False + ) - async def _process_non_streaming_request( + async def _handle_delete_task_push_notification_config( + self, + request_obj: DeleteTaskPushNotificationConfigRequest, + context: ServerCallContext, + ) -> None: + await self.request_handler.on_delete_task_push_notification_config( + request_obj, context + ) + + async def _handle_get_extended_agent_card( + self, + request_obj: GetExtendedAgentCardRequest, + context: ServerCallContext, + ) -> dict[str, Any]: + if not self.agent_card.capabilities.extended_agent_card: + raise ExtendedAgentCardNotConfiguredError( + message='The agent does not have an extended agent card configured' + ) + base_card = self.extended_agent_card or self.agent_card + card_to_serve = base_card + if self.extended_card_modifier and context: + card_to_serve = await maybe_await( + self.extended_card_modifier(base_card, context) + ) + elif self.card_modifier: + card_to_serve = await maybe_await(self.card_modifier(base_card)) + + return MessageToDict(card_to_serve, preserving_proto_field_name=False) + + @validate_version(constants.PROTOCOL_VERSION_1_0) + async def _process_non_streaming_request( # noqa: PLR0911 self, request_id: str | int | None, request_obj: A2ARequest, context: ServerCallContext, - ) -> Response: + ) -> dict[str, Any] | None: """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). Args: @@ -466,71 +602,44 @@ async def _process_non_streaming_request( context: The ServerCallContext for the request. Returns: - A `JSONResponse` object containing the result or error. + A dict containing the result or error. """ - handler_result: Any = None match request_obj: case SendMessageRequest(): - handler_result = await self.handler.on_message_send( - request_obj, context - ) + return await self._handle_send_message(request_obj, context) case CancelTaskRequest(): - handler_result = await self.handler.on_cancel_task( - request_obj, context - ) + return await self._handle_cancel_task(request_obj, context) case GetTaskRequest(): - handler_result = await self.handler.on_get_task( - request_obj, context - ) + return await self._handle_get_task(request_obj, context) case ListTasksRequest(): - handler_result = await self.handler.list_tasks( - request_obj, context - ) + return await self._handle_list_tasks(request_obj, context) case TaskPushNotificationConfig(): - handler_result = ( - await self.handler.set_push_notification_config( - request_obj, - context, - ) + return await self._handle_create_task_push_notification_config( + request_obj, context ) case GetTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.get_push_notification_config( - request_obj, - context, - ) + return await self._handle_get_task_push_notification_config( + request_obj, context ) case ListTaskPushNotificationConfigsRequest(): - handler_result = ( - await self.handler.list_push_notification_configs( - request_obj, - context, - ) + return await self._handle_list_task_push_notification_configs( + request_obj, context ) case DeleteTaskPushNotificationConfigRequest(): - handler_result = ( - await self.handler.delete_push_notification_config( - request_obj, - context, - ) + return await self._handle_delete_task_push_notification_config( + request_obj, context ) case GetExtendedAgentCardRequest(): - handler_result = ( - await self.handler.get_authenticated_extended_card( - request_obj, - context, - ) + return await self._handle_get_extended_agent_card( + request_obj, context ) case _: logger.error( 'Unhandled validated request type: %s', type(request_obj) ) - error = UnsupportedOperationError( + raise UnsupportedOperationError( message=f'Request type {type(request_obj).__name__} is unknown.' ) - return self._generate_error_response(request_id, error) - - return self._create_response(context, handler_result) def _create_response( self, diff --git a/src/a2a/server/routes/jsonrpc_routes.py b/src/a2a/server/routes/jsonrpc_routes.py index 9138ed8ea..8d1a67bbd 100644 --- a/src/a2a/server/routes/jsonrpc_routes.py +++ b/src/a2a/server/routes/jsonrpc_routes.py @@ -72,7 +72,7 @@ def create_jsonrpc_routes( # noqa: PLR0913 dispatcher = JsonRpcDispatcher( agent_card=agent_card, - http_handler=request_handler, + request_handler=request_handler, extended_agent_card=extended_agent_card, context_builder=context_builder, card_modifier=card_modifier, diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index a1198878a..8884a5dd8 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -57,7 +57,11 @@ TaskStatus, TaskStatusUpdateEvent, ) -from a2a.utils.constants import TransportProtocol +from a2a.utils.constants import ( + PROTOCOL_VERSION_CURRENT, + VERSION_HEADER, + TransportProtocol, +) from a2a.utils.errors import ( ContentTypeNotSupportedError, ExtendedAgentCardNotConfiguredError, @@ -705,7 +709,10 @@ async def test_json_transport_get_signed_base_card( rpc_url='/', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) - httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + headers={VERSION_HEADER: PROTOCOL_VERSION_CURRENT}, + ) agent_url = agent_card.supported_interfaces[0].url signature_verifier = create_signature_verifier( @@ -776,7 +783,10 @@ async def test_client_get_signed_extended_card( rpc_url='/', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) - httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + headers={VERSION_HEADER: PROTOCOL_VERSION_CURRENT}, + ) transport = JsonRpcTransport( httpx_client=httpx_client, @@ -847,7 +857,10 @@ async def test_client_get_signed_base_and_extended_cards( rpc_url='/', ) app = Starlette(routes=[*agent_card_routes, *jsonrpc_routes]) - httpx_client = httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) + httpx_client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + headers={VERSION_HEADER: PROTOCOL_VERSION_CURRENT}, + ) agent_url = agent_card.supported_interfaces[0].url signature_verifier = create_signature_verifier( diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py deleted file mode 100644 index 81b23126c..000000000 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ /dev/null @@ -1,1505 +0,0 @@ -import asyncio -import unittest -import unittest.async_case - -from collections.abc import AsyncGenerator -from typing import Any, NoReturn -from unittest.mock import ANY, AsyncMock, MagicMock, call, patch - -import httpx -import pytest - -from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.agent_execution.request_context_builder import ( - RequestContextBuilder, -) -from a2a.server.context import ServerCallContext -from a2a.server.events import QueueManager -from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers import DefaultRequestHandler, JSONRPCHandler -from a2a.server.tasks import ( - BasePushNotificationSender, - InMemoryPushNotificationConfigStore, - PushNotificationConfigStore, - PushNotificationSender, - TaskStore, -) -from a2a.types import ( - InternalError, - TaskNotFoundError, - UnsupportedOperationError, -) -from a2a.types.a2a_pb2 import ( - AgentCapabilities, - AgentCard, - AgentInterface, - Artifact, - CancelTaskRequest, - DeleteTaskPushNotificationConfigRequest, - GetExtendedAgentCardRequest, - GetTaskPushNotificationConfigRequest, - GetTaskRequest, - ListTaskPushNotificationConfigsRequest, - ListTaskPushNotificationConfigsResponse, - ListTasksResponse, - Message, - Part, - TaskPushNotificationConfig, - Role, - SendMessageConfiguration, - SendMessageRequest, - TaskPushNotificationConfig, - SubscribeToTaskRequest, - Task, - TaskArtifactUpdateEvent, - TaskPushNotificationConfig, - TaskState, - TaskStatus, - TaskStatusUpdateEvent, -) - - -# Helper function to create a minimal Task proto -def create_task( - task_id: str = 'task_123', context_id: str = 'session-xyz' -) -> Task: - return Task( - id=task_id, - context_id=context_id, - status=TaskStatus(state=TaskState.TASK_STATE_SUBMITTED), - ) - - -# Helper function to create a Message proto -def create_message( - message_id: str = '111', - role: Role = Role.ROLE_AGENT, - text: str = 'test message', - task_id: str | None = None, - context_id: str | None = None, -) -> Message: - msg = Message( - message_id=message_id, - role=role, - parts=[Part(text=text)], - ) - if task_id: - msg.task_id = task_id - if context_id: - msg.context_id = context_id - return msg - - -# Helper functions for checking JSON-RPC response structure -def is_success_response(response: dict[str, Any]) -> bool: - """Check if response is a successful JSON-RPC response.""" - return 'result' in response and 'error' not in response - - -def is_error_response(response: dict[str, Any]) -> bool: - """Check if response is an error JSON-RPC response.""" - return 'error' in response - - -def get_error_code(response: dict[str, Any]) -> int | None: - """Get error code from JSON-RPC error response.""" - if 'error' in response: - return response['error'].get('code') - return None - - -def get_error_message(response: dict[str, Any]) -> str | None: - """Get error message from JSON-RPC error response.""" - if 'error' in response: - return response['error'].get('message') - return None - - -class TestJSONRPCtHandler(unittest.async_case.IsolatedAsyncioTestCase): - @pytest.fixture(autouse=True) - def init_fixtures(self) -> None: - self.mock_agent_card = MagicMock( - spec=AgentCard, - ) - self.mock_agent_card.capabilities = MagicMock(spec=AgentCapabilities) - self.mock_agent_card.capabilities.extended_agent_card = True - self.mock_agent_card.capabilities.streaming = True - self.mock_agent_card.capabilities.push_notifications = True - - # Mock supported_interfaces list - interface = MagicMock(spec=AgentInterface) - interface.url = 'http://agent.example.com/api' - self.mock_agent_card.supported_interfaces = [interface] - - def _ctx(self, state: dict[str, Any] | None = None) -> ServerCallContext: - full_state = {'headers': {'A2A-Version': '1.0'}} - if state: - full_state.update(state) - return ServerCallContext(state=full_state) - - async def test_on_get_task_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - call_context = ServerCallContext( - state={ - 'foo': 'bar', - 'request_id': '1', - 'headers': {'A2A-Version': '1.0'}, - } - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - request = GetTaskRequest(id=f'{task_id}') - response = await handler.on_get_task(request, call_context) - # Response is now a dict with 'result' key for success - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - assert response['result']['id'] == task_id - mock_task_store.get.assert_called_once_with(f'{task_id}', ANY) - - async def test_on_get_task_not_found(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = GetTaskRequest(id='nonexistent_id') - call_context = ServerCallContext( - state={ - 'foo': 'bar', - 'request_id': '1', - 'headers': {'A2A-Version': '1.0'}, - } - ) - response = await handler.on_get_task(request, call_context) - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32001 - - async def test_on_list_tasks_success(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task1 = create_task() - task2 = create_task() - task2.id = 'task_456' - mock_result = ListTasksResponse( - next_page_token='123', - tasks=[task1, task2], - ) - request_handler.on_list_tasks.return_value = mock_result - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest( - page_size=10, - page_token='token', - ) - call_context = self._ctx({'foo': 'bar'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertIn('tasks', response['result']) - self.assertEqual(len(response['result']['tasks']), 2) - self.assertEqual(response['result']['nextPageToken'], '123') - - async def test_on_list_tasks_error(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - request_handler.on_list_tasks.side_effect = InternalError( - message='DB down' - ) - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest(page_size=10) - call_context = self._ctx({'request_id': '2'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['message'], 'DB down') - - async def test_on_list_tasks_empty(self) -> None: - request_handler = AsyncMock(spec=DefaultRequestHandler) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_result = ListTasksResponse(page_size=10) - request_handler.on_list_tasks.return_value = mock_result - from a2a.types.a2a_pb2 import ListTasksRequest - - request = ListTasksRequest(page_size=10) - call_context = self._ctx({'foo': 'bar'}) - - response = await handler.list_tasks(request, call_context) - - request_handler.on_list_tasks.assert_awaited_once() - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertIn('tasks', response['result']) - self.assertEqual(len(response['result']['tasks']), 0) - self.assertIn('nextPageToken', response['result']) - self.assertEqual(response['result']['nextPageToken'], '') - self.assertIn('pageSize', response['result']) - self.assertEqual(response['result']['pageSize'], 10) - self.assertIn('totalSize', response['result']) - self.assertEqual(response['result']['totalSize'], 0) - - async def test_on_cancel_task_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - mock_agent_executor.cancel.return_value = None - call_context = ServerCallContext( - state={ - 'foo': 'bar', - 'request_id': '1', - 'headers': {'A2A-Version': '1.0'}, - } - ) - - async def streaming_coro(): - mock_task.status.state = TaskState.TASK_STATE_CANCELED - yield mock_task - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = CancelTaskRequest(id=f'{task_id}') - response = await handler.on_cancel_task(request, call_context) - assert mock_agent_executor.cancel.call_count == 1 - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result is converted to dict for JSON serialization - assert response['result']['id'] == task_id # type: ignore - assert ( - response['result']['status']['state'] == 'TASK_STATE_CANCELED' - ) # type: ignore - mock_agent_executor.cancel.assert_called_once() - - async def test_on_cancel_task_not_supported(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - task_id = 'test_task_id' - mock_task = create_task(task_id=task_id) - mock_task_store.get.return_value = mock_task - mock_agent_executor.cancel.return_value = None - call_context = ServerCallContext( - state={ - 'foo': 'bar', - 'request_id': '1', - 'headers': {'A2A-Version': '1.0'}, - } - ) - - async def streaming_coro(): - raise UnsupportedOperationError() - yield - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = CancelTaskRequest(id=f'{task_id}') - response = await handler.on_cancel_task(request, call_context) - assert mock_agent_executor.cancel.call_count == 1 - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32004 - mock_agent_executor.cancel.assert_called_once() - - async def test_on_cancel_task_not_found(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = CancelTaskRequest(id='nonexistent_id') - call_context = self._ctx({'request_id': '1'}) - response = await handler.on_cancel_task(request, call_context) - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32001 - mock_task_store.get.assert_called_once_with('nonexistent_id', ANY) - mock_agent_executor.cancel.assert_not_called() - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_new_message_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message( - task_id='task_123', context_id='session-xyz' - ), - ) - response = await handler.on_message_send( - request, - self._ctx(), - ) - # execute is called asynchronously in background task - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - - async def test_on_message_new_message_with_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - response = await handler.on_message_send( - request, - self._ctx(), - ) - # execute is called asynchronously in background task - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - - async def test_on_message_error(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - - async def streaming_coro(): - raise UnsupportedOperationError() - yield - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, context_id=mock_task.context_id - ), - ) - response = await handler.on_message_send( - request, - self._ctx(), - ) - - # Allow the background event loop to start the execution_task - import asyncio - - await asyncio.sleep(0) - - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - assert response['error']['code'] == -32004 - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_stream_new_message_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - # Latch to ensure background execute is scheduled before asserting - execute_called = asyncio.Event() - - async def exec_side_effect(*args, **kwargs): - execute_called.set() - - mock_agent_executor.execute.side_effect = exec_side_effect - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message( - task_id='task_123', context_id='session-xyz' - ), - ) - response = handler.on_message_send_stream( - request, - self._ctx(), - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == len(events) - await asyncio.wait_for(execute_called.wait(), timeout=0.1) - mock_agent_executor.execute.assert_called_once() - - async def test_on_message_stream_new_message_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_WORKING), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - # Latch to ensure background execute is scheduled before asserting - execute_called = asyncio.Event() - - async def exec_side_effect(*args, **kwargs): - execute_called.set() - - mock_agent_executor.execute.side_effect = exec_side_effect - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - response = handler.on_message_send_stream( - request, - self._ctx(), - ) - assert isinstance(response, AsyncGenerator) - collected_events = [item async for item in response] - assert len(collected_events) == len(events) - await asyncio.wait_for(execute_called.wait(), timeout=0.1) - mock_agent_executor.execute.assert_called_once() - assert mock_task.history is not None and len(mock_task.history) == 1 - - async def test_set_push_notification_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_push_notification_store = AsyncMock( - spec=PushNotificationConfigStore - ) - - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=mock_push_notification_store, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - ) - context = self._ctx() - response = await handler.set_push_notification_config(request, context) - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, request, context - ) - - async def test_get_push_notification_success(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - push_notification_store = InMemoryPushNotificationConfigStore() - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=push_notification_store, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - mock_task_store.get.return_value = mock_task - push_config = TaskPushNotificationConfig( - id='default', url='http://example.com' - ) - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - id='default', - ) - await handler.set_push_notification_config( - request, - self._ctx(), - ) - - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='default', - ) - get_response = await handler.get_push_notification_config( - get_request, - self._ctx(), - ) - self.assertIsInstance(get_response, dict) - self.assertTrue(is_success_response(get_response)) - - @patch( - 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' - ) - async def test_on_message_stream_new_message_send_push_notification_success( - self, _mock_builder_build: AsyncMock - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) - push_notification_store = InMemoryPushNotificationConfigStore() - push_notification_sender = BasePushNotificationSender( - mock_httpx_client, - push_notification_store, - self._ctx(), - ) - request_handler = DefaultRequestHandler( - mock_agent_executor, - mock_task_store, - push_config_store=push_notification_store, - push_sender=push_notification_sender, - ) - self.mock_agent_card.capabilities = AgentCapabilities( - streaming=True, push_notifications=True - ) - _mock_builder_build.return_value = RequestContext( - request=MagicMock(), - task_id='task_123', - context_id='session-xyz', - task=None, - related_tasks=None, - ) - - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - mock_task, - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = None - mock_agent_executor.execute.return_value = None - mock_httpx_client.post.return_value = httpx.Response(200) - request = SendMessageRequest( - message=create_message(), - configuration=SendMessageConfiguration( - accepted_output_modes=['text'], - task_push_notification_config=TaskPushNotificationConfig( - url='http://example.com' - ), - ), - ) - response = handler.on_message_send_stream( - request, - self._ctx(), - ) - assert isinstance(response, AsyncGenerator) - - collected_events = [item async for item in response] - assert len(collected_events) == len(events) - - async def test_on_resubscribe_existing_task_success( - self, - ) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_queue_manager = AsyncMock(spec=QueueManager) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store, mock_queue_manager - ) - self.mock_agent_card = MagicMock(spec=AgentCard) - self.mock_agent_card.capabilities = MagicMock(spec=AgentCapabilities) - self.mock_agent_card.capabilities.streaming = True - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - events: list[Any] = [ - TaskArtifactUpdateEvent( - task_id='task_123', - context_id='session-xyz', - artifact=Artifact(artifact_id='11', parts=[Part(text='text')]), - ), - TaskStatusUpdateEvent( - task_id='task_123', - context_id='session-xyz', - status=TaskStatus(state=TaskState.TASK_STATE_COMPLETED), - ), - ] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = mock_task - mock_queue_manager.tap.return_value = EventQueue() - request = SubscribeToTaskRequest(id=f'{mock_task.id}') - response = handler.on_subscribe_to_task( - request, - self._ctx(), - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert ( - len(collected_events) == len(events) + 1 - ) # First event is task itself - assert mock_task.history is not None and len(mock_task.history) == 0 - - async def test_on_subscribe_no_existing_task_error(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task_store.get.return_value = None - request = SubscribeToTaskRequest(id='nonexistent_id') - response = handler.on_subscribe_to_task( - request, - self._ctx(), - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == 1 - self.assertIsInstance(collected_events[0], dict) - self.assertTrue(is_error_response(collected_events[0])) - assert collected_events[0]['error']['code'] == -32001 - - async def test_streaming_not_supported_error( - self, - ) -> None: - """Test that on_message_send_stream raises an error when streaming not supported.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - # Create agent card with streaming capability disabled - self.mock_agent_card.capabilities = AgentCapabilities(streaming=False) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Act & Assert - request = SendMessageRequest( - message=create_message(), - ) - - # Should raise UnsupportedOperationError about streaming not supported - with self.assertRaises(UnsupportedOperationError) as context: - async for _ in handler.on_message_send_stream( - request, - self._ctx(), - ): - pass - - self.assertEqual( - str(context.exception.message), - 'Streaming is not supported by the agent', - ) - - async def test_push_notifications_not_supported_error(self) -> None: - """Test that set_push_notification raises an error when push notifications not supported.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - # Create agent card with push notifications capability disabled - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=False, streaming=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Act & Assert - request = TaskPushNotificationConfig( - task_id='task_123', - url='http://example.com', - ) - - # Should raise UnsupportedOperationError about push notifications not supported - with self.assertRaises(UnsupportedOperationError) as context: - await handler.set_push_notification_config( - request, - self._ctx(), - ) - - self.assertEqual( - str(context.exception.message), - 'Push notifications are not supported by the agent', - ) - - async def test_on_get_push_notification_no_push_config_store(self) -> None: - """Test get_push_notification with no push notifier configured.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - # Create request handler without a push notifier - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Act - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='default', - ) - response = await handler.get_push_notification_config( - get_request, - self._ctx(), - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_set_push_notification_no_push_config_store(self) -> None: - """Test set_push_notification with no push notifier configured.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - # Create request handler without a push notifier - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Act - request = TaskPushNotificationConfig( - task_id=mock_task.id, - url='http://example.com', - ) - response = await handler.set_push_notification_config( - request, - self._ctx(), - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_message_send_internal_error(self) -> None: - """Test on_message_send with an internal error.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Make the request handler raise an Internal error without specifying an error type - async def raise_server_error(*args, **kwargs) -> NoReturn: - raise InternalError(message='Internal Error') - - # Patch the method to raise an error - with patch.object( - request_handler, 'on_message_send', side_effect=raise_server_error - ): - # Act - request = SendMessageRequest( - message=create_message(), - ) - response = await handler.on_message_send( - request, - self._ctx(), - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_message_stream_internal_error(self) -> None: - """Test on_message_send_stream with an internal error.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Make the request handler raise an Internal error without specifying an error type - async def raise_server_error(*args, **kwargs): - raise InternalError(message='Internal Error') - yield # Need this to make it an async generator - - # Patch the method to raise an error - with patch.object( - request_handler, - 'on_message_send_stream', - return_value=raise_server_error(), - ): - # Act - request = SendMessageRequest( - message=create_message(), - ) - - # Get the single error response - responses = [] - async for response in handler.on_message_send_stream( - request, - self._ctx(), - ): - responses.append(response) - - # Assert - self.assertEqual(len(responses), 1) - self.assertIsInstance(responses[0], dict) - self.assertTrue(is_error_response(responses[0])) - self.assertEqual(responses[0]['error']['code'], -32603) - - async def test_default_request_handler_with_custom_components(self) -> None: - """Test DefaultRequestHandler initialization with custom components.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - mock_queue_manager = AsyncMock(spec=QueueManager) - mock_push_config_store = AsyncMock(spec=PushNotificationConfigStore) - mock_push_sender = AsyncMock(spec=PushNotificationSender) - mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) - - # Act - handler = DefaultRequestHandler( - agent_executor=mock_agent_executor, - task_store=mock_task_store, - queue_manager=mock_queue_manager, - push_config_store=mock_push_config_store, - push_sender=mock_push_sender, - request_context_builder=mock_request_context_builder, - ) - - # Assert - self.assertEqual(handler.agent_executor, mock_agent_executor) - self.assertEqual(handler.task_store, mock_task_store) - self.assertEqual(handler._queue_manager, mock_queue_manager) - self.assertEqual(handler._push_config_store, mock_push_config_store) - self.assertEqual(handler._push_sender, mock_push_sender) - self.assertEqual( - handler._request_context_builder, mock_request_context_builder - ) - - async def test_on_message_send_error_handling(self) -> None: - """Test error handling in on_message_send when consuming raises A2AError.""" - # Arrange - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - - # Let task exist - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Set up consume_and_break_on_interrupt to raise UnsupportedOperationError - async def consume_raises_error(*args, **kwargs) -> NoReturn: - raise UnsupportedOperationError() - - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - side_effect=consume_raises_error, - ): - # Act - request = SendMessageRequest( - message=create_message( - task_id=mock_task.id, - context_id=mock_task.context_id, - ), - ) - - response = await handler.on_message_send( - request, - self._ctx(), - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_on_message_send_task_id_mismatch(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - mock_task = create_task() - # Mock returns task with different ID than what will be generated - mock_task_store.get.return_value = None # No existing task - mock_agent_executor.execute.return_value = None - - # Task returned has task_id='task_123' but request_context will have generated UUID - with patch( - 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', - return_value=(mock_task, False, None), - ): - request = SendMessageRequest( - message=create_message(), # No task_id, so UUID is generated - ) - response = await handler.on_message_send( - request, - self._ctx(), - ) - # The task ID mismatch should cause an error - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_message_stream_task_id_mismatch(self) -> None: - mock_agent_executor = AsyncMock(spec=AgentExecutor) - mock_task_store = AsyncMock(spec=TaskStore) - request_handler = DefaultRequestHandler( - mock_agent_executor, mock_task_store - ) - - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - events: list[Any] = [create_task()] - - async def streaming_coro(): - for event in events: - yield event - - with patch( - 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', - return_value=streaming_coro(), - ): - mock_task_store.get.return_value = None - mock_agent_executor.execute.return_value = None - request = SendMessageRequest( - message=create_message(), - ) - response = handler.on_message_send_stream( - request, - self._ctx(), - ) - assert isinstance(response, AsyncGenerator) - collected_events: list[Any] = [] - async for event in response: - collected_events.append(event) - assert len(collected_events) == 1 - self.assertIsInstance(collected_events[0], dict) - self.assertTrue(is_error_response(collected_events[0])) - self.assertEqual(collected_events[0]['error']['code'], -32603) - - async def test_on_get_push_notification(self) -> None: - """Test get_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - task_push_config = TaskPushNotificationConfig( - task_id=mock_task.id, id='config1', url='http://example.com' - ) - request_handler.on_get_task_push_notification_config.return_value = ( - task_push_config - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - get_request = GetTaskPushNotificationConfigRequest( - task_id=mock_task.id, - id='config1', - ) - response = await handler.get_push_notification_config( - get_request, - self._ctx(), - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result is converted to dict for JSON serialization - self.assertEqual( - response['result']['id'], - 'config1', - ) - self.assertEqual( - response['result']['taskId'], - mock_task.id, - ) - - async def test_on_list_push_notification(self) -> None: - """Test list_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - task_push_config = TaskPushNotificationConfig( - task_id=mock_task.id, id='default', url='http://example.com' - ) - request_handler.on_list_task_push_notification_configs.return_value = ( - ListTaskPushNotificationConfigsResponse(configs=[task_push_config]) - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - list_request = ListTaskPushNotificationConfigsRequest( - task_id=mock_task.id, - ) - response = await handler.list_push_notification_configs( - list_request, - self._ctx(), - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - # Result contains the response dict with configs field - self.assertIsInstance(response['result'], dict) - - async def test_on_list_push_notification_error(self) -> None: - """Test list_push_notification_config handling""" - mock_task_store = AsyncMock(spec=TaskStore) - - mock_task = create_task() - mock_task_store.get.return_value = mock_task - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - # throw server error - request_handler.on_list_task_push_notification_configs.side_effect = ( - InternalError() - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - list_request = ListTaskPushNotificationConfigsRequest( - task_id=mock_task.id, - ) - response = await handler.list_push_notification_configs( - list_request, - self._ctx(), - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32603) - - async def test_on_delete_push_notification(self) -> None: - """Test delete_push_notification_config handling""" - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - request_handler.on_delete_task_push_notification_config.return_value = ( - None - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - delete_request = DeleteTaskPushNotificationConfigRequest( - task_id='task1', - id='config1', - ) - response = await handler.delete_push_notification_config( - delete_request, - self._ctx(), - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['result'], None) - - async def test_on_delete_push_notification_error(self) -> None: - """Test delete_push_notification_config error handling""" - - # Create request handler without a push notifier - request_handler = AsyncMock(spec=DefaultRequestHandler) - # throw server error - request_handler.on_delete_task_push_notification_config.side_effect = ( - UnsupportedOperationError() - ) - - self.mock_agent_card.capabilities = AgentCapabilities( - push_notifications=True - ) - handler = JSONRPCHandler(self.mock_agent_card, request_handler) - delete_request = DeleteTaskPushNotificationConfigRequest( - task_id='task1', - id='config1', - ) - response = await handler.delete_push_notification_config( - delete_request, - self._ctx(), - ) - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_error_response(response)) - self.assertEqual(response['error']['code'], -32004) - - async def test_get_authenticated_extended_card_success(self) -> None: - """Test successful retrieval of the authenticated extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_extended_card = AgentCard( - name='Extended Card', - description='More details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.1', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_extended_card, - extended_card_modifier=None, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-1'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-1') - # Result is the agent card proto - - async def test_get_authenticated_extended_card_not_configured(self) -> None: - """Test error when authenticated extended agent card is not configured.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - # We need a proper card here because agent_card_to_dict accesses multiple fields - card = AgentCard( - name='TestAgent', - version='1.0.0', - supported_interfaces=[ - AgentInterface( - url='http://localhost', - protocol_binding='JSONRPC', - protocol_version='1.0.0', - ) - ], - capabilities=AgentCapabilities(extended_agent_card=True), - ) - - handler = JSONRPCHandler( - card, - mock_request_handler, - extended_agent_card=None, - extended_card_modifier=None, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-2'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - # Authenticated Extended Card flag is set with no extended card, - # returns base card in this case. - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-2') - - async def test_get_authenticated_extended_card_with_modifier(self) -> None: - """Test successful retrieval of a dynamically modified extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_base_card = AgentCard( - name='Base Card', - description='Base details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - - async def modifier( - card: AgentCard, context: ServerCallContext - ) -> AgentCard: - modified_card = AgentCard() - modified_card.CopyFrom(card) - modified_card.name = 'Modified Card' - modified_card.description = ( - f'Modified for context: {context.state.get("foo")}' - ) - return modified_card - - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_base_card, - extended_card_modifier=modifier, - ) - request = GetExtendedAgentCardRequest() - call_context = self._ctx({'foo': 'bar'}) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertFalse(is_error_response(response)) - from google.protobuf.json_format import ParseDict - - modified_card = ParseDict( - response['result'], AgentCard(), ignore_unknown_fields=True - ) - self.assertEqual(modified_card.name, 'Modified Card') - self.assertEqual(modified_card.description, 'Modified for context: bar') - self.assertEqual(modified_card.version, '1.0') - - async def test_get_authenticated_extended_card_with_modifier_sync( - self, - ) -> None: - """Test successful retrieval of a synchronously dynamically modified extended agent card.""" - # Arrange - mock_request_handler = AsyncMock(spec=DefaultRequestHandler) - mock_base_card = AgentCard( - name='Base Card', - description='Base details', - supported_interfaces=[ - AgentInterface( - protocol_binding='HTTP+JSON', - url='http://agent.example.com/api', - ) - ], - version='1.0', - capabilities=AgentCapabilities(), - default_input_modes=['text/plain'], - default_output_modes=['application/json'], - skills=[], - ) - - def modifier(card: AgentCard, context: ServerCallContext) -> AgentCard: - # Copy the card by creating a new one with the same fields - from copy import deepcopy - - modified_card = AgentCard() - modified_card.CopyFrom(card) - modified_card.name = 'Modified Card' - modified_card.description = ( - f'Modified for context: {context.state.get("foo")}' - ) - return modified_card - - handler = JSONRPCHandler( - self.mock_agent_card, - mock_request_handler, - extended_agent_card=mock_base_card, - extended_card_modifier=modifier, - ) - request = GetExtendedAgentCardRequest() - call_context = ServerCallContext( - state={'foo': 'bar', 'request_id': 'ext-card-req-mod'} - ) - - # Act - response = await handler.get_authenticated_extended_card( - request, call_context - ) - - # Assert - self.assertIsInstance(response, dict) - self.assertTrue(is_success_response(response)) - self.assertEqual(response['id'], 'ext-card-req-mod') - # Result is converted to dict for JSON serialization - modified_card_dict = response['result'] - self.assertEqual(modified_card_dict['name'], 'Modified Card') - self.assertEqual( - modified_card_dict['description'], 'Modified for context: bar' - ) - self.assertEqual(modified_card_dict['version'], '1.0') diff --git a/tests/server/routes/test_jsonrpc_dispatcher.py b/tests/server/routes/test_jsonrpc_dispatcher.py index 586486b01..1242bee23 100644 --- a/tests/server/routes/test_jsonrpc_dispatcher.py +++ b/tests/server/routes/test_jsonrpc_dispatcher.py @@ -126,7 +126,7 @@ def mock_app_params(self) -> dict: mock_handler = MagicMock(spec=RequestHandler) mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = 'http://example.com' - return {'agent_card': mock_agent_card, 'http_handler': mock_handler} + return {'agent_card': mock_agent_card, 'request_handler': mock_handler} @pytest.fixture(scope='class') def mark_pkg_starlette_not_installed(self):