From 9562522c40a4bae6a326fdf2f9ba52722493ad07 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 11:38:19 +0000 Subject: [PATCH 01/16] feat(server): validate fields presence according to `google.api.field_behavior` annotations --- .../default_request_handler.py | 15 ++- .../request_handlers/request_handler.py | 43 ++++++- src/a2a/utils/errors.py | 3 +- src/a2a/utils/proto_utils.py | 109 +++++++++++++++++- .../test_default_request_handler.py | 50 +++++--- 5 files changed, 200 insertions(+), 20 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index c641b0f12..99bb81fc2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -18,7 +18,10 @@ InMemoryQueueManager, QueueManager, ) -from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.request_handler import ( + RequestHandler, + validate_request_params, +) from a2a.server.tasks import ( PushNotificationConfigStore, PushNotificationEvent, @@ -118,6 +121,7 @@ def __init__( # noqa: PLR0913 # asyncio tasks and to surface unexpected exceptions. self._background_tasks = set() + @validate_request_params async def on_get_task( self, params: GetTaskRequest, @@ -133,6 +137,7 @@ async def on_get_task( return apply_history_length(task, params) + @validate_request_params async def on_list_tasks( self, params: ListTasksRequest, @@ -154,6 +159,7 @@ async def on_list_tasks( return page + @validate_request_params async def on_cancel_task( self, params: CancelTaskRequest, @@ -317,6 +323,7 @@ async def _send_push_notification_if_needed( ): await self._push_sender.send_notification(task_id, event) + @validate_request_params async def on_message_send( self, params: SendMessageRequest, @@ -386,6 +393,7 @@ async def push_notification_callback(event: Event) -> None: return result + @validate_request_params async def on_message_send_stream( self, params: SendMessageRequest, @@ -474,6 +482,7 @@ async def _cleanup_producer( async with self._running_agents_lock: self._running_agents.pop(task_id, None) + @validate_request_params async def on_create_task_push_notification_config( self, params: TaskPushNotificationConfig, @@ -499,6 +508,7 @@ async def on_create_task_push_notification_config( return params + @validate_request_params async def on_get_task_push_notification_config( self, params: GetTaskPushNotificationConfigRequest, @@ -530,6 +540,7 @@ async def on_get_task_push_notification_config( raise InternalError(message='Push notification config not found') + @validate_request_params async def on_subscribe_to_task( self, params: SubscribeToTaskRequest, @@ -572,6 +583,7 @@ async def on_subscribe_to_task( async for event in result_aggregator.consume_and_emit(consumer): yield event + @validate_request_params async def on_list_task_push_notification_configs( self, params: ListTaskPushNotificationConfigsRequest, @@ -597,6 +609,7 @@ async def on_list_task_push_notification_configs( configs=push_notification_config_list ) + @validate_request_params async def on_delete_task_push_notification_config( self, params: DeleteTaskPushNotificationConfigRequest, diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 120a71e37..6fa68b084 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -1,5 +1,11 @@ +import functools +import inspect + from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable +from typing import Any + +from google.protobuf.message import Message as ProtoMessage from a2a.server.context import ServerCallContext from a2a.server.events.event_queue import Event @@ -19,6 +25,7 @@ TaskPushNotificationConfig, ) from a2a.utils.errors import UnsupportedOperationError +from a2a.utils.proto_utils import validate_proto_required_fields class RequestHandler(ABC): @@ -218,3 +225,37 @@ async def on_delete_task_push_notification_config( Returns: None """ + + +def validate_request_params(method: Callable) -> Callable: + """Decorator for RequestHandler methods to validate required fields on incoming requests.""" + if inspect.isasyncgenfunction(method): + + @functools.wraps(method) + async def async_generator_wrapper( + self: RequestHandler, + params: ProtoMessage, + context: ServerCallContext, + *args: Any, + **kwargs: Any, + ) -> AsyncGenerator: + if params is not None: + validate_proto_required_fields(params) + async for item in method(self, params, context, *args, **kwargs): + yield item + + return async_generator_wrapper + + @functools.wraps(method) + async def async_wrapper( + self: RequestHandler, + params: ProtoMessage, + context: ServerCallContext, + *args: Any, + **kwargs: Any, + ) -> Any: + if params is not None: + validate_proto_required_fields(params) + return await method(self, params, context, *args, **kwargs) + + return async_wrapper diff --git a/src/a2a/utils/errors.py b/src/a2a/utils/errors.py index a16542d97..c87fa7372 100644 --- a/src/a2a/utils/errors.py +++ b/src/a2a/utils/errors.py @@ -21,9 +21,10 @@ class A2AError(Exception): message: str = 'A2A Error' data: dict | None = None - def __init__(self, message: str | None = None): + def __init__(self, message: str | None = None, data: dict | None = None): if message: self.message = message + self.data = data super().__init__(self.message) diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index cdfc306f4..34de6e47a 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -17,11 +17,15 @@ This module provides helper functions for common proto type operations. """ -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypedDict +from google.api.field_behavior_pb2 import FieldBehavior, field_behavior +from google.protobuf.descriptor import FieldDescriptor from google.protobuf.json_format import ParseDict from google.protobuf.message import Message as ProtobufMessage +from a2a.utils.errors import InvalidParamsError + if TYPE_CHECKING: from starlette.datastructures import QueryParams @@ -189,3 +193,106 @@ def parse_params(params: QueryParams, message: ProtobufMessage) -> None: processed[k] = parsed_val ParseDict(processed, message, ignore_unknown_fields=True) + + +class ValidationDetail(TypedDict): + """Structured validation error detail.""" + + field: str + message: str + + +def _check_required_field_violation( + msg: ProtobufMessage, field: FieldDescriptor +) -> ValidationDetail | None: + """Check if a required field is missing or invalid.""" + val = getattr(msg, field.name) + if field.is_repeated: + if not val: + return ValidationDetail( + field=field.name, + message='Field must contain at least one element.', + ) + elif field.has_presence: + if not msg.HasField(field.name): + return ValidationDetail( + field=field.name, message='Field is required.' + ) + elif val == field.default_value: + return ValidationDetail(field=field.name, message='Field is required.') + return None + + +def _append_nested_errors( + errors: list[ValidationDetail], + prefix: str, + sub_errs: list[ValidationDetail], +) -> None: + """Format nested validation errors and append to errors list.""" + for sub in sub_errs: + sub_field = sub['field'] + errors.append( + ValidationDetail( + field=f'{prefix}.{sub_field}' if sub_field else prefix, + message=sub['message'], + ) + ) + + +def _recurse_validation( + msg: ProtobufMessage, field: FieldDescriptor +) -> list[ValidationDetail]: + """Recurse validation for nested messages and map fields.""" + errors: list[ValidationDetail] = [] + if field.type != FieldDescriptor.TYPE_MESSAGE: + return errors + + val = getattr(msg, field.name) + if not field.is_repeated: + if msg.HasField(field.name): + sub_errs = _validate_proto_required_fields_internal(val) + _append_nested_errors(errors, field.name, sub_errs) + elif field.message_type.GetOptions().map_entry: + for k, v in val.items(): + if isinstance(v, ProtobufMessage): + sub_errs = _validate_proto_required_fields_internal(v) + _append_nested_errors(errors, f'{field.name}[{k}]', sub_errs) + else: + for i, item in enumerate(val): + sub_errs = _validate_proto_required_fields_internal(item) + _append_nested_errors(errors, f'{field.name}[{i}]', sub_errs) + return errors + + +def _validate_proto_required_fields_internal( + msg: ProtobufMessage, +) -> list[ValidationDetail]: + """Internal validation that returns a list of error dictionaries.""" + desc = msg.DESCRIPTOR + errors: list[ValidationDetail] = [] + + for field in desc.fields: + options = field.GetOptions() + if FieldBehavior.REQUIRED in options.Extensions[field_behavior]: + violation = _check_required_field_violation(msg, field) + if violation: + errors.append(violation) + errors.extend(_recurse_validation(msg, field)) + return errors + + +def validate_proto_required_fields(msg: ProtobufMessage) -> None: + """Validate that all fields marked as REQUIRED are present on the proto message. + + Args: + msg: The Protobuf message to validate. + + Raises: + InvalidParamsError: If a required field is missing or empty. + """ + errors = _validate_proto_required_fields_internal(msg) + + if errors: + raise InvalidParamsError( + message='Validation failed', data={'errors': errors} + ) diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index ba2627e38..3d22813c6 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -451,7 +451,9 @@ async def test_on_cancel_task_invalid_result_type(): # Mock ResultAggregator to return a Message mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) mock_result_aggregator_instance.consume_all.return_value = Message( - message_id='unexpected_msg', role=Role.ROLE_AGENT, parts=[] + message_id='unexpected_msg', + role=Role.ROLE_AGENT, + parts=[Part(text='Test')], ) request_handler = DefaultRequestHandler( @@ -524,7 +526,7 @@ async def test_on_message_send_with_push_notification(): message=Message( role=Role.ROLE_USER, message_id='msg_push', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -630,7 +632,7 @@ async def test_on_message_send_with_push_notification_in_non_blocking_request(): message=Message( role=Role.ROLE_USER, message_id='msg_non_blocking', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -750,7 +752,11 @@ async def test_on_message_send_with_push_notification_no_existing_Task(): accepted_output_modes=['text/plain'], # Added required field ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg_push', parts=[]), + message=Message( + role=Role.ROLE_USER, + message_id='msg_push', + parts=[Part(text='Test')], + ), configuration=message_config, ) @@ -815,7 +821,11 @@ async def test_on_message_send_no_result_from_aggregator(): request_context_builder=mock_request_context_builder, ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg_no_res', parts=[]) + message=Message( + role=Role.ROLE_USER, + message_id='msg_no_res', + parts=[Part(text='Test')], + ) ) mock_result_aggregator_instance = AsyncMock(spec=ResultAggregator) @@ -863,7 +873,9 @@ async def test_on_message_send_task_id_mismatch(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_id_mismatch', parts=[] + role=Role.ROLE_USER, + message_id='msg_id_mismatch', + parts=[Part(text='Test')], ) ) @@ -1067,7 +1079,9 @@ async def test_on_message_send_interrupted_flow(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_interrupt', parts=[] + role=Role.ROLE_USER, + message_id='msg_interrupt', + parts=[Part(text='Test')], ) ) @@ -1178,7 +1192,7 @@ async def test_on_message_send_stream_with_push_notification(): message=Message( role=Role.ROLE_USER, message_id='msg_stream_push', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ), @@ -1460,7 +1474,7 @@ async def test_stream_disconnect_then_resubscribe_receives_future_events(): message=Message( role=Role.ROLE_USER, message_id='msg_reconn', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1558,7 +1572,7 @@ async def test_on_message_send_stream_client_disconnect_triggers_background_clea message=Message( role=Role.ROLE_USER, message_id='mid', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1698,7 +1712,7 @@ async def cancel( message=Message( role=Role.ROLE_USER, message_id='msg_persist', - parts=[], + parts=[Part(text='Test')], ) ) @@ -1785,7 +1799,7 @@ async def test_background_cleanup_task_is_tracked_and_cleared(): message=Message( role=Role.ROLE_USER, message_id='mid_track', - parts=[], + parts=[Part(text='Test')], task_id=task_id, context_id=context_id, ) @@ -1890,7 +1904,9 @@ async def test_on_message_send_stream_task_id_mismatch(): ) params = SendMessageRequest( message=Message( - role=Role.ROLE_USER, message_id='msg_stream_mismatch', parts=[] + role=Role.ROLE_USER, + message_id='msg_stream_mismatch', + parts=[Part(text='Test')], ) ) @@ -2586,7 +2602,7 @@ async def test_on_message_send_task_in_terminal_state(terminal_state): message=Message( role=Role.ROLE_USER, message_id='msg_terminal', - parts=[], + parts=[Part(text='Test')], task_id=task_id, ) ) @@ -2627,7 +2643,7 @@ async def test_on_message_send_stream_task_in_terminal_state(terminal_state): message=Message( role=Role.ROLE_USER, message_id='msg_terminal_stream', - parts=[], + parts=[Part(text='Test')], task_id=task_id, ) ) @@ -2869,7 +2885,9 @@ async def test_on_message_send_negative_history_length_error(): accepted_output_modes=['text/plain'], ) params = SendMessageRequest( - message=Message(role=Role.ROLE_USER, message_id='msg1', parts=[]), + message=Message( + role=Role.ROLE_USER, message_id='msg1', parts=[Part(text='Test')] + ), configuration=message_config, ) context = create_server_call_context() From 1c648a21dcd7088f0e08e87c3f68016838226fb7 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 14:14:43 +0000 Subject: [PATCH 02/16] WIP --- src/a2a/client/transports/grpc.py | 20 +++++++-- src/a2a/client/transports/jsonrpc.py | 3 +- .../server/request_handlers/grpc_handler.py | 21 ++++++++-- .../request_handlers/jsonrpc_handler.py | 1 + tests/integration/test_end_to_end.py | 34 ++++++++++++++- tests/utils/test_proto_utils.py | 41 +++++++++++++++++++ 6 files changed, 110 insertions(+), 10 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 02c418eb3..0945f3bca 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -61,17 +61,29 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: # Use grpc_status to cleanly extract the rich Status from the call status = rpc_status.from_call(cast('grpc.Call', e)) + data = None if status is not None: + exception_cls = None for detail in status.details: - if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): + if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR): + bad_request = error_details_pb2.BadRequest() + detail.Unpack(bad_request) + errors = [ + {'field': v.field, 'message': v.description} + for v in bad_request.field_violations + ] + data = {'errors': errors} + # Infer InvalidParamsError from BadRequest details + exception_cls = A2A_REASON_TO_ERROR.get('INVALID_PARAMS') + elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) - if error_info.domain == 'a2a-protocol.org': exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) - if exception_cls: - raise exception_cls(status.message) from e + + if exception_cls: + raise exception_cls(status.message, data=data) from e raise A2AClientError(f'gRPC Error {e.code().name}: {e.details()}') from e diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 9854aabb0..eca6c4897 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -318,9 +318,10 @@ def _create_jsonrpc_error(self, error_dict: dict[str, Any]) -> Exception: """Creates the appropriate A2AError from a JSON-RPC error dictionary.""" code = error_dict.get('code') message = error_dict.get('message', str(error_dict)) + data = error_dict.get('data') if isinstance(code, int) and code in _JSON_RPC_ERROR_CODE_TO_A2A_ERROR: - return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message) + return _JSON_RPC_ERROR_CODE_TO_A2A_ERROR[code](message, data=data) # Fallback to general A2AClientError return A2AClientError(f'JSON-RPC Error {code}: {message}') diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 326dea236..05277426d 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -438,16 +438,29 @@ async def abort_context( error.message if hasattr(error, 'message') else str(error) ) - # Create standard Status and pack the ErrorInfo + # Create standard Status status = status_pb2.Status(code=status_code, message=error_msg) - detail = any_pb2.Any() - detail.Pack(error_info) - status.details.append(detail) + + # Exclusive details based on error type: + if error.data and error.data.get('errors'): + bad_request = error_details_pb2.BadRequest() + for err_dict in error.data['errors']: + violation = bad_request.field_violations.add() + violation.field = err_dict.get('field', '') + violation.description = err_dict.get('message', '') + any_bad_request = any_pb2.Any() + any_bad_request.Pack(bad_request) + status.details.append(any_bad_request) + else: + detail = any_pb2.Any() + detail.Pack(error_info) + status.details.append(detail) # Use grpc_status to safely generate standard trailing metadata rich_status = rpc_status.to_status(status) new_metadata: list[tuple[str, str | bytes]] = [] + trailing = context.trailing_metadata() if trailing: for k, v in trailing: diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index e7d5b75ad..0fe6c56bd 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -92,6 +92,7 @@ def _build_error_response( jsonrpc_error = model_class( code=code, message=str(error), + data=error.data, ) else: jsonrpc_error = JSONRPCInternalError(message=str(error)) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index ddf9edbf3..11d1b4562 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator -from typing import NamedTuple +from typing import Any, NamedTuple import grpc import httpx @@ -31,6 +31,7 @@ a2a_pb2_grpc, ) from a2a.utils import TransportProtocol +from a2a.utils.errors import InvalidParamsError def assert_message_matches(message, expected_role, expected_text): @@ -546,3 +547,34 @@ async def test_end_to_end_input_required(transport_setups): ], ) assert_message_matches(task.status.message, Role.ROLE_AGENT, 'done') + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'empty_request, expected_fields', + [ + ( + SendMessageRequest(), + ['message'], + ), + ( + SendMessageRequest(message=Message()), + ['message.message_id', 'message.role', 'message.parts'], + ), + ], +) +async def test_end_to_end_validation_errors( + transport_setups, + empty_request: SendMessageRequest, + expected_fields: list[str], +) -> None: + client = transport_setups.client + + with pytest.raises(InvalidParamsError) as exc_info: + async for _ in client.send_message(request=empty_request): + pass + + errors = exc_info.value.data.get('errors', []) + assert {e['field'] for e in errors} == set(expected_fields) + + await client.close() diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 6a53541f3..e2c760bae 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -239,3 +239,44 @@ def _message_to_rest_params(self, message: ProtobufMessage) -> QueryParams: return httpx.Request( 'GET', 'http://api.example.com', params=rest_dict ).url.params + + +class TestValidateProtoRequiredFields: + """Tests for validate_proto_required_fields function.""" + + def test_valid_required_fields(self): + """Test with all required fields present.""" + msg = Message( + message_id='msg-1', + role=Role.ROLE_USER, + parts=[Part(text='hello')], + ) + proto_utils.validate_proto_required_fields(msg) + + def test_missing_required_fields(self): + """Test with empty message raising InvalidParamsError containing all errors.""" + from a2a.utils.errors import InvalidParamsError + + msg = Message() + with pytest.raises(InvalidParamsError) as exc_info: + proto_utils.validate_proto_required_fields(msg) + + err = exc_info.value + errors = err.data.get('errors', []) if err.data else [] + + assert {e['field'] for e in errors} == {'message_id', 'role', 'parts'} + + def test_nested_required_fields(self): + """Test nested required fields inside TaskStatus.""" + from a2a.utils.errors import InvalidParamsError + + # Task Status requires 'state' + task = Task(id='task-1', status=TaskStatus()) + with pytest.raises(InvalidParamsError) as exc_info: + proto_utils.validate_proto_required_fields(task) + + err = exc_info.value + errors = err.data.get('errors', []) if err.data else [] + + fields = [e['field'] for e in errors] + assert 'status.state' in fields From 73e54452d8d573da87eeeff467fc405e2d0cfb31 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 14:46:11 +0000 Subject: [PATCH 03/16] WIP --- src/a2a/client/transports/grpc.py | 5 ++--- .../server/request_handlers/grpc_handler.py | 3 +-- tests/integration/test_end_to_end.py | 19 +++++++++++++++---- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 0945f3bca..c91b78a0c 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -47,7 +47,7 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER -from a2a.utils.errors import A2A_REASON_TO_ERROR +from a2a.utils.errors import InvalidParamsError from a2a.utils.telemetry import SpanKind, trace_class @@ -74,8 +74,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: for v in bad_request.field_violations ] data = {'errors': errors} - # Infer InvalidParamsError from BadRequest details - exception_cls = A2A_REASON_TO_ERROR.get('INVALID_PARAMS') + exception_cls = InvalidParamsError elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 05277426d..503faadbb 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -441,8 +441,7 @@ async def abort_context( # Create standard Status status = status_pb2.Status(code=status_code, message=error_msg) - # Exclusive details based on error type: - if error.data and error.data.get('errors'): + if isinstance(error, types.InvalidParamsError) and error.data and error.data.get('errors'): bad_request = error_details_pb2.BadRequest() for err_dict in error.data['errors']: violation = bad_request.field_violations.add() diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 11d1b4562..af9be2e83 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -570,11 +570,22 @@ async def test_end_to_end_validation_errors( ) -> None: client = transport_setups.client - with pytest.raises(InvalidParamsError) as exc_info: + try: async for _ in client.send_message(request=empty_request): pass - - errors = exc_info.value.data.get('errors', []) - assert {e['field'] for e in errors} == set(expected_fields) + except Exception as e: + # ASGITransport propagates server-side generator crashes as ExceptionGroups + exc = e + if hasattr(e, 'exceptions') and len(e.exceptions) == 1: + exc = e.exceptions[0] + + if not isinstance(exc, InvalidParamsError): + raise e + + errors = exc.data.get('errors', []) if exc.data else [] + assert {e['field'] for e in errors} == set(expected_fields) + return + + pytest.fail('InvalidParamsError was not raised') await client.close() From b891f344f4f098d8d2b771fae67a1be076422829 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Thu, 19 Mar 2026 14:50:08 +0000 Subject: [PATCH 04/16] Cosmetics --- src/a2a/client/transports/grpc.py | 6 ++++-- src/a2a/client/transports/http_helpers.py | 5 +++++ tests/integration/test_end_to_end.py | 25 +++++++---------------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index c91b78a0c..d4cf35e31 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -47,7 +47,7 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER -from a2a.utils.errors import InvalidParamsError +from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError, InvalidParamsError from a2a.utils.telemetry import SpanKind, trace_class @@ -64,7 +64,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: data = None if status is not None: - exception_cls = None + exception_cls: type[A2AError] | None = None for detail in status.details: if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR): bad_request = error_details_pb2.BadRequest() @@ -75,11 +75,13 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: ] data = {'errors': errors} exception_cls = InvalidParamsError + break elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) if error_info.domain == 'a2a-protocol.org': exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) + break if exception_cls: raise exception_cls(status.message, data=data) from e diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index 0a5721b50..43accadd2 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -40,6 +40,11 @@ def handle_http_exceptions( raise A2AClientError(f'Network communication error: {e}') from e except json.JSONDecodeError as e: raise A2AClientError(f'JSON Decode Error: {e}') from e + except Exception as e: + # ASGITransport propagates local server-side generator crashes as ExceptionGroups + if hasattr(e, 'exceptions') and len(e.exceptions) == 1: + raise e.exceptions[0] from e + raise e def get_http_args(context: ClientCallContext | None) -> dict[str, Any]: diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index af9be2e83..efeea9ad6 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -555,37 +555,26 @@ async def test_end_to_end_input_required(transport_setups): [ ( SendMessageRequest(), - ['message'], + {'message'}, ), ( SendMessageRequest(message=Message()), - ['message.message_id', 'message.role', 'message.parts'], + {'message.message_id', 'message.role', 'message.parts'}, ), ], ) async def test_end_to_end_validation_errors( transport_setups, empty_request: SendMessageRequest, - expected_fields: list[str], + expected_fields: set[str], ) -> None: client = transport_setups.client - try: + with pytest.raises(InvalidParamsError) as exc_info: async for _ in client.send_message(request=empty_request): pass - except Exception as e: - # ASGITransport propagates server-side generator crashes as ExceptionGroups - exc = e - if hasattr(e, 'exceptions') and len(e.exceptions) == 1: - exc = e.exceptions[0] - - if not isinstance(exc, InvalidParamsError): - raise e - - errors = exc.data.get('errors', []) if exc.data else [] - assert {e['field'] for e in errors} == set(expected_fields) - return - - pytest.fail('InvalidParamsError was not raised') + + errors = exc_info.value.data.get('errors', []) + assert {e['field'] for e in errors} == expected_fields await client.close() From 02593433dd1ba09625a0e57c9e409792708e5011 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Fri, 20 Mar 2026 15:13:55 +0000 Subject: [PATCH 05/16] wip --- src/a2a/client/transports/grpc.py | 2 +- src/a2a/server/apps/rest/rest_adapter.py | 2 +- .../server/request_handlers/grpc_handler.py | 6 ++++- .../request_handlers/request_handler.py | 17 +++++++------ .../server/request_handlers/rest_handler.py | 24 ++++++++++++------- 5 files changed, 31 insertions(+), 20 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index d4cf35e31..e8614bbb6 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -76,7 +76,7 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: data = {'errors': errors} exception_cls = InvalidParamsError break - elif detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): + if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) if error_info.domain == 'a2a-protocol.org': diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 6b8abb99e..e44120f3d 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -159,7 +159,7 @@ async def event_generator( yield json.dumps(item) return EventSourceResponse( - event_generator(method(request, call_context)) + event_generator(await method(request, call_context)) ) async def handle_get_agent_card( diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index 4a66e95b7..e9f3f2fe8 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -406,7 +406,11 @@ async def abort_context( # Create standard Status status = status_pb2.Status(code=status_code, message=error_msg) - if isinstance(error, types.InvalidParamsError) and error.data and error.data.get('errors'): + if ( + isinstance(error, types.InvalidParamsError) + and error.data + and error.data.get('errors') + ): bad_request = error_details_pb2.BadRequest() for err_dict in error.data['errors']: violation = bad_request.field_violations.add() diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 6fa68b084..ed1ba6e52 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -229,25 +229,24 @@ async def on_delete_task_push_notification_config( def validate_request_params(method: Callable) -> Callable: """Decorator for RequestHandler methods to validate required fields on incoming requests.""" - if inspect.isasyncgenfunction(method): + if inspect.iscoroutinefunction(method): @functools.wraps(method) - async def async_generator_wrapper( + async def async_wrapper( self: RequestHandler, params: ProtoMessage, context: ServerCallContext, *args: Any, **kwargs: Any, - ) -> AsyncGenerator: + ) -> Any: if params is not None: validate_proto_required_fields(params) - async for item in method(self, params, context, *args, **kwargs): - yield item + return await method(self, params, context, *args, **kwargs) - return async_generator_wrapper + return async_wrapper @functools.wraps(method) - async def async_wrapper( + def sync_wrapper( self: RequestHandler, params: ProtoMessage, context: ServerCallContext, @@ -256,6 +255,6 @@ async def async_wrapper( ) -> Any: if params is not None: validate_proto_required_fields(params) - return await method(self, params, context, *args, **kwargs) + return method(self, params, context, *args, **kwargs) - return async_wrapper + return sync_wrapper diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index af889d9df..50a9f2ac6 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -116,11 +116,14 @@ async def on_message_send_stream( body = await request.body() params = a2a_pb2.SendMessageRequest() Parse(body, params) - async for event in self.request_handler.on_message_send_stream( - params, context - ): - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) + stream = self.request_handler.on_message_send_stream(params, context) + + async def _generator() -> AsyncIterator[dict[str, Any]]: + async for event in stream: + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) + + return _generator() @validate_version(constants.PROTOCOL_VERSION_1_0) async def on_cancel_task( @@ -167,10 +170,15 @@ async def on_subscribe_to_task( JSON serialized objects containing streaming events """ task_id = request.path_params['id'] - async for event in self.request_handler.on_subscribe_to_task( + stream = self.request_handler.on_subscribe_to_task( SubscribeToTaskRequest(id=task_id), context - ): - yield MessageToDict(proto_utils.to_stream_response(event)) + ) + + async def _generator() -> AsyncIterator[dict[str, Any]]: + async for event in stream: + yield MessageToDict(proto_utils.to_stream_response(event)) + + return _generator() @validate_version(constants.PROTOCOL_VERSION_1_0) async def get_push_notification( From 184f8d5802c3009088c0939b1e6a43e7198ea572 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 23 Mar 2026 08:27:15 +0000 Subject: [PATCH 06/16] TMP --- src/a2a/utils/helpers.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index e5b37e5f4..109c85fc4 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -389,8 +389,27 @@ def async_gen_wrapper( return cast('F', async_gen_wrapper) + if inspect.iscoroutinefunction(inspect.unwrap(func)): + + @functools.wraps(func) + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + actual_version = _get_actual_version(args, kwargs) + if not _is_version_compatible(actual_version): + logger.warning( + "Version mismatch: actual='%s', expected='%s'", + actual_version, + expected_version, + ) + raise VersionNotSupportedError( + message=f"A2A version '{actual_version}' is not supported by this handler. " + f"Expected version '{expected_version}'." + ) + return await func(self, *args, **kwargs) + + return cast('F', async_wrapper) + @functools.wraps(func) - async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: actual_version = _get_actual_version(args, kwargs) if not _is_version_compatible(actual_version): logger.warning( @@ -402,8 +421,8 @@ async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: message=f"A2A version '{actual_version}' is not supported by this handler. " f"Expected version '{expected_version}'." ) - return await func(self, *args, **kwargs) + return func(self, *args, **kwargs) - return cast('F', async_wrapper) + return cast('F', sync_wrapper) return decorator From 5d765941695e1a232562e2040680441b94616d9a Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 23 Mar 2026 14:14:14 +0000 Subject: [PATCH 07/16] Revert a few things --- src/a2a/server/apps/rest/rest_adapter.py | 2 +- .../server/request_handlers/rest_handler.py | 24 ++++++------------ src/a2a/utils/helpers.py | 25 +++---------------- 3 files changed, 12 insertions(+), 39 deletions(-) diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index e44120f3d..6b8abb99e 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -159,7 +159,7 @@ async def event_generator( yield json.dumps(item) return EventSourceResponse( - event_generator(await method(request, call_context)) + event_generator(method(request, call_context)) ) async def handle_get_agent_card( diff --git a/src/a2a/server/request_handlers/rest_handler.py b/src/a2a/server/request_handlers/rest_handler.py index 50a9f2ac6..af889d9df 100644 --- a/src/a2a/server/request_handlers/rest_handler.py +++ b/src/a2a/server/request_handlers/rest_handler.py @@ -116,14 +116,11 @@ async def on_message_send_stream( body = await request.body() params = a2a_pb2.SendMessageRequest() Parse(body, params) - stream = self.request_handler.on_message_send_stream(params, context) - - async def _generator() -> AsyncIterator[dict[str, Any]]: - async for event in stream: - response = proto_utils.to_stream_response(event) - yield MessageToDict(response) - - return _generator() + async for event in self.request_handler.on_message_send_stream( + params, context + ): + response = proto_utils.to_stream_response(event) + yield MessageToDict(response) @validate_version(constants.PROTOCOL_VERSION_1_0) async def on_cancel_task( @@ -170,15 +167,10 @@ async def on_subscribe_to_task( JSON serialized objects containing streaming events """ task_id = request.path_params['id'] - stream = self.request_handler.on_subscribe_to_task( + async for event in self.request_handler.on_subscribe_to_task( SubscribeToTaskRequest(id=task_id), context - ) - - async def _generator() -> AsyncIterator[dict[str, Any]]: - async for event in stream: - yield MessageToDict(proto_utils.to_stream_response(event)) - - return _generator() + ): + yield MessageToDict(proto_utils.to_stream_response(event)) @validate_version(constants.PROTOCOL_VERSION_1_0) async def get_push_notification( diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index 109c85fc4..e5b37e5f4 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -389,27 +389,8 @@ def async_gen_wrapper( return cast('F', async_gen_wrapper) - if inspect.iscoroutinefunction(inspect.unwrap(func)): - - @functools.wraps(func) - async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: - actual_version = _get_actual_version(args, kwargs) - if not _is_version_compatible(actual_version): - logger.warning( - "Version mismatch: actual='%s', expected='%s'", - actual_version, - expected_version, - ) - raise VersionNotSupportedError( - message=f"A2A version '{actual_version}' is not supported by this handler. " - f"Expected version '{expected_version}'." - ) - return await func(self, *args, **kwargs) - - return cast('F', async_wrapper) - @functools.wraps(func) - def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: actual_version = _get_actual_version(args, kwargs) if not _is_version_compatible(actual_version): logger.warning( @@ -421,8 +402,8 @@ def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: message=f"A2A version '{actual_version}' is not supported by this handler. " f"Expected version '{expected_version}'." ) - return func(self, *args, **kwargs) + return await func(self, *args, **kwargs) - return cast('F', sync_wrapper) + return cast('F', async_wrapper) return decorator From 2075a2474a1363611c4046c1887df18be065b716 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Mon, 23 Mar 2026 15:30:28 +0000 Subject: [PATCH 08/16] TMP --- .../default_request_handler.py | 22 +++++++++---------- .../request_handlers/request_handler.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 99bb81fc2..d33edddc5 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -20,7 +20,6 @@ ) from a2a.server.request_handlers.request_handler import ( RequestHandler, - validate_request_params, ) from a2a.server.tasks import ( PushNotificationConfigStore, @@ -59,6 +58,7 @@ validate_page_size, ) from a2a.utils.telemetry import SpanKind, trace_class +from a2a.utils.proto_utils import validate_proto_required_fields logger = logging.getLogger(__name__) @@ -121,13 +121,13 @@ def __init__( # noqa: PLR0913 # asyncio tasks and to surface unexpected exceptions. self._background_tasks = set() - @validate_request_params async def on_get_task( self, params: GetTaskRequest, context: ServerCallContext, ) -> Task | None: """Default handler for 'tasks/get'.""" + validate_proto_required_fields(params) validate_history_length(params) task_id = params.id @@ -137,13 +137,13 @@ async def on_get_task( return apply_history_length(task, params) - @validate_request_params async def on_list_tasks( self, params: ListTasksRequest, context: ServerCallContext, ) -> ListTasksResponse: """Default handler for 'tasks/list'.""" + validate_proto_required_fields(params) validate_history_length(params) if params.HasField('page_size'): validate_page_size(params.page_size) @@ -159,7 +159,6 @@ async def on_list_tasks( return page - @validate_request_params async def on_cancel_task( self, params: CancelTaskRequest, @@ -169,6 +168,7 @@ async def on_cancel_task( Attempts to cancel the task managed by the `AgentExecutor`. """ + validate_proto_required_fields(params) task_id = params.id task: Task | None = await self.task_store.get(task_id, context) if not task: @@ -323,7 +323,6 @@ async def _send_push_notification_if_needed( ): await self._push_sender.send_notification(task_id, event) - @validate_request_params async def on_message_send( self, params: SendMessageRequest, @@ -334,6 +333,7 @@ async def on_message_send( Starts the agent execution for the message and waits for the final result (Task or Message). """ + validate_proto_required_fields(params) validate_history_length(params.configuration) ( @@ -393,7 +393,6 @@ async def push_notification_callback(event: Event) -> None: return result - @validate_request_params async def on_message_send_stream( self, params: SendMessageRequest, @@ -404,6 +403,7 @@ async def on_message_send_stream( Starts the agent execution and yields events as they are produced by the agent. """ + validate_proto_required_fields(params) ( _task_manager, task_id, @@ -482,7 +482,6 @@ async def _cleanup_producer( async with self._running_agents_lock: self._running_agents.pop(task_id, None) - @validate_request_params async def on_create_task_push_notification_config( self, params: TaskPushNotificationConfig, @@ -492,6 +491,7 @@ async def on_create_task_push_notification_config( Requires a `PushNotifier` to be configured. """ + validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError @@ -508,7 +508,6 @@ async def on_create_task_push_notification_config( return params - @validate_request_params async def on_get_task_push_notification_config( self, params: GetTaskPushNotificationConfigRequest, @@ -518,6 +517,7 @@ async def on_get_task_push_notification_config( Requires a `PushConfigStore` to be configured. """ + validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError @@ -540,7 +540,6 @@ async def on_get_task_push_notification_config( raise InternalError(message='Push notification config not found') - @validate_request_params async def on_subscribe_to_task( self, params: SubscribeToTaskRequest, @@ -551,6 +550,7 @@ async def on_subscribe_to_task( Allows a client to re-attach to a running streaming task's event stream. Requires the task and its queue to still be active. """ + validate_proto_required_fields(params) task_id = params.id task: Task | None = await self.task_store.get(task_id, context) if not task: @@ -583,7 +583,6 @@ async def on_subscribe_to_task( async for event in result_aggregator.consume_and_emit(consumer): yield event - @validate_request_params async def on_list_task_push_notification_configs( self, params: ListTaskPushNotificationConfigsRequest, @@ -593,6 +592,7 @@ async def on_list_task_push_notification_configs( Requires a `PushConfigStore` to be configured. """ + validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError @@ -609,7 +609,6 @@ async def on_list_task_push_notification_configs( configs=push_notification_config_list ) - @validate_request_params async def on_delete_task_push_notification_config( self, params: DeleteTaskPushNotificationConfigRequest, @@ -619,6 +618,7 @@ async def on_delete_task_push_notification_config( Requires a `PushConfigStore` to be configured. """ + validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index ed1ba6e52..34608fa4e 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -227,7 +227,7 @@ async def on_delete_task_push_notification_config( """ -def validate_request_params(method: Callable) -> Callable: +def _validate_request_params(method: Callable) -> Callable: """Decorator for RequestHandler methods to validate required fields on incoming requests.""" if inspect.iscoroutinefunction(method): From ba6deafe04a158d7a024880d2b9be413dfc6d792 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 11:25:17 +0000 Subject: [PATCH 09/16] WIP --- src/a2a/server/request_handlers/__init__.py | 6 ++++- .../default_request_handler.py | 22 +++++++++---------- .../request_handlers/request_handler.py | 19 +++++++++++++++- 3 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/a2a/server/request_handlers/__init__.py b/src/a2a/server/request_handlers/__init__.py index 43ebc8e25..e8fcc2141 100644 --- a/src/a2a/server/request_handlers/__init__.py +++ b/src/a2a/server/request_handlers/__init__.py @@ -6,7 +6,10 @@ DefaultRequestHandler, ) from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler -from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.request_handlers.request_handler import ( + RequestHandler, + validate_request_params, +) from a2a.server.request_handlers.response_helpers import ( build_error_response, prepare_response_object, @@ -45,4 +48,5 @@ def __init__(self, *args, **kwargs): 'RequestHandler', 'build_error_response', 'prepare_response_object', + 'validate_request_params', ] diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index d33edddc5..99bb81fc2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -20,6 +20,7 @@ ) from a2a.server.request_handlers.request_handler import ( RequestHandler, + validate_request_params, ) from a2a.server.tasks import ( PushNotificationConfigStore, @@ -58,7 +59,6 @@ validate_page_size, ) from a2a.utils.telemetry import SpanKind, trace_class -from a2a.utils.proto_utils import validate_proto_required_fields logger = logging.getLogger(__name__) @@ -121,13 +121,13 @@ def __init__( # noqa: PLR0913 # asyncio tasks and to surface unexpected exceptions. self._background_tasks = set() + @validate_request_params async def on_get_task( self, params: GetTaskRequest, context: ServerCallContext, ) -> Task | None: """Default handler for 'tasks/get'.""" - validate_proto_required_fields(params) validate_history_length(params) task_id = params.id @@ -137,13 +137,13 @@ async def on_get_task( return apply_history_length(task, params) + @validate_request_params async def on_list_tasks( self, params: ListTasksRequest, context: ServerCallContext, ) -> ListTasksResponse: """Default handler for 'tasks/list'.""" - validate_proto_required_fields(params) validate_history_length(params) if params.HasField('page_size'): validate_page_size(params.page_size) @@ -159,6 +159,7 @@ async def on_list_tasks( return page + @validate_request_params async def on_cancel_task( self, params: CancelTaskRequest, @@ -168,7 +169,6 @@ async def on_cancel_task( Attempts to cancel the task managed by the `AgentExecutor`. """ - validate_proto_required_fields(params) task_id = params.id task: Task | None = await self.task_store.get(task_id, context) if not task: @@ -323,6 +323,7 @@ async def _send_push_notification_if_needed( ): await self._push_sender.send_notification(task_id, event) + @validate_request_params async def on_message_send( self, params: SendMessageRequest, @@ -333,7 +334,6 @@ async def on_message_send( Starts the agent execution for the message and waits for the final result (Task or Message). """ - validate_proto_required_fields(params) validate_history_length(params.configuration) ( @@ -393,6 +393,7 @@ async def push_notification_callback(event: Event) -> None: return result + @validate_request_params async def on_message_send_stream( self, params: SendMessageRequest, @@ -403,7 +404,6 @@ async def on_message_send_stream( Starts the agent execution and yields events as they are produced by the agent. """ - validate_proto_required_fields(params) ( _task_manager, task_id, @@ -482,6 +482,7 @@ async def _cleanup_producer( async with self._running_agents_lock: self._running_agents.pop(task_id, None) + @validate_request_params async def on_create_task_push_notification_config( self, params: TaskPushNotificationConfig, @@ -491,7 +492,6 @@ async def on_create_task_push_notification_config( Requires a `PushNotifier` to be configured. """ - validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError @@ -508,6 +508,7 @@ async def on_create_task_push_notification_config( return params + @validate_request_params async def on_get_task_push_notification_config( self, params: GetTaskPushNotificationConfigRequest, @@ -517,7 +518,6 @@ async def on_get_task_push_notification_config( Requires a `PushConfigStore` to be configured. """ - validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError @@ -540,6 +540,7 @@ async def on_get_task_push_notification_config( raise InternalError(message='Push notification config not found') + @validate_request_params async def on_subscribe_to_task( self, params: SubscribeToTaskRequest, @@ -550,7 +551,6 @@ async def on_subscribe_to_task( Allows a client to re-attach to a running streaming task's event stream. Requires the task and its queue to still be active. """ - validate_proto_required_fields(params) task_id = params.id task: Task | None = await self.task_store.get(task_id, context) if not task: @@ -583,6 +583,7 @@ async def on_subscribe_to_task( async for event in result_aggregator.consume_and_emit(consumer): yield event + @validate_request_params async def on_list_task_push_notification_configs( self, params: ListTaskPushNotificationConfigsRequest, @@ -592,7 +593,6 @@ async def on_list_task_push_notification_configs( Requires a `PushConfigStore` to be configured. """ - validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError @@ -609,6 +609,7 @@ async def on_list_task_push_notification_configs( configs=push_notification_config_list ) + @validate_request_params async def on_delete_task_push_notification_config( self, params: DeleteTaskPushNotificationConfigRequest, @@ -618,7 +619,6 @@ async def on_delete_task_push_notification_config( Requires a `PushConfigStore` to be configured. """ - validate_proto_required_fields(params) if not self._push_config_store: raise UnsupportedOperationError diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 34608fa4e..e23f3926e 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -227,8 +227,25 @@ async def on_delete_task_push_notification_config( """ -def _validate_request_params(method: Callable) -> Callable: +def validate_request_params(method: Callable) -> Callable: """Decorator for RequestHandler methods to validate required fields on incoming requests.""" + if inspect.isasyncgenfunction(method): + + @functools.wraps(method) + async def async_gen_wrapper( + self: RequestHandler, + params: ProtoMessage, + context: ServerCallContext, + *args: Any, + **kwargs: Any, + ) -> Any: + if params is not None: + validate_proto_required_fields(params) + async for item in method(self, params, context, *args, **kwargs): + yield item + + return async_gen_wrapper + if inspect.iscoroutinefunction(method): @functools.wraps(method) From c87dac205b18e21e7244139af888ee6158f765c8 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 11:58:49 +0000 Subject: [PATCH 10/16] Updates --- src/a2a/client/transports/grpc.py | 18 +++++--------- .../server/request_handlers/grpc_handler.py | 24 +++++++++---------- .../request_handlers/request_handler.py | 22 +++-------------- src/a2a/utils/proto_utils.py | 23 ++++++++++++++++++ tests/integration/test_end_to_end.py | 2 +- tests/utils/test_proto_utils.py | 9 +++---- 6 files changed, 47 insertions(+), 51 deletions(-) diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index e8614bbb6..24c4b5385 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -47,7 +47,8 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_CURRENT, VERSION_HEADER -from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError, InvalidParamsError +from a2a.utils.errors import A2A_REASON_TO_ERROR, A2AError +from a2a.utils.proto_utils import bad_request_to_validation_errors from a2a.utils.telemetry import SpanKind, trace_class @@ -66,22 +67,15 @@ def _map_grpc_error(e: grpc.aio.AioRpcError) -> NoReturn: if status is not None: exception_cls: type[A2AError] | None = None for detail in status.details: - if detail.Is(error_details_pb2.BadRequest.DESCRIPTOR): - bad_request = error_details_pb2.BadRequest() - detail.Unpack(bad_request) - errors = [ - {'field': v.field, 'message': v.description} - for v in bad_request.field_violations - ] - data = {'errors': errors} - exception_cls = InvalidParamsError - break if detail.Is(error_details_pb2.ErrorInfo.DESCRIPTOR): error_info = error_details_pb2.ErrorInfo() detail.Unpack(error_info) if error_info.domain == 'a2a-protocol.org': exception_cls = A2A_REASON_TO_ERROR.get(error_info.reason) - break + elif detail.Is(error_details_pb2.BadRequest.DESCRIPTOR): + bad_request = error_details_pb2.BadRequest() + detail.Unpack(bad_request) + data = {'errors': bad_request_to_validation_errors(bad_request)} if exception_cls: raise exception_cls(status.message, data=data) from e diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e9f3f2fe8..e4cb39492 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -41,6 +41,7 @@ TaskNotFoundError, ) from a2a.utils.helpers import maybe_await, validate +from a2a.utils.proto_utils import validation_errors_to_bad_request logger = logging.getLogger(__name__) @@ -403,26 +404,23 @@ async def abort_context( error.message if hasattr(error, 'message') else str(error) ) - # Create standard Status + # Create standard Status with ErrorInfo for all A2A errors status = status_pb2.Status(code=status_code, message=error_msg) + error_info_detail = any_pb2.Any() + error_info_detail.Pack(error_info) + status.details.append(error_info_detail) + # Append structured field violations for validation errors if ( isinstance(error, types.InvalidParamsError) and error.data and error.data.get('errors') ): - bad_request = error_details_pb2.BadRequest() - for err_dict in error.data['errors']: - violation = bad_request.field_violations.add() - violation.field = err_dict.get('field', '') - violation.description = err_dict.get('message', '') - any_bad_request = any_pb2.Any() - any_bad_request.Pack(bad_request) - status.details.append(any_bad_request) - else: - detail = any_pb2.Any() - detail.Pack(error_info) - status.details.append(detail) + bad_request_detail = any_pb2.Any() + bad_request_detail.Pack( + validation_errors_to_bad_request(error.data['errors']) + ) + status.details.append(bad_request_detail) # Use grpc_status to safely generate standard trailing metadata rich_status = rpc_status.to_status(status) diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index e23f3926e..8c955cbbe 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -246,24 +246,8 @@ async def async_gen_wrapper( return async_gen_wrapper - if inspect.iscoroutinefunction(method): - - @functools.wraps(method) - async def async_wrapper( - self: RequestHandler, - params: ProtoMessage, - context: ServerCallContext, - *args: Any, - **kwargs: Any, - ) -> Any: - if params is not None: - validate_proto_required_fields(params) - return await method(self, params, context, *args, **kwargs) - - return async_wrapper - @functools.wraps(method) - def sync_wrapper( + async def async_wrapper( self: RequestHandler, params: ProtoMessage, context: ServerCallContext, @@ -272,6 +256,6 @@ def sync_wrapper( ) -> Any: if params is not None: validate_proto_required_fields(params) - return method(self, params, context, *args, **kwargs) + return await method(self, params, context, *args, **kwargs) - return sync_wrapper + return async_wrapper diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 34de6e47a..f77593297 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -23,6 +23,7 @@ from google.protobuf.descriptor import FieldDescriptor from google.protobuf.json_format import ParseDict from google.protobuf.message import Message as ProtobufMessage +from google.rpc import error_details_pb2 from a2a.utils.errors import InvalidParamsError @@ -296,3 +297,25 @@ def validate_proto_required_fields(msg: ProtobufMessage) -> None: raise InvalidParamsError( message='Validation failed', data={'errors': errors} ) + + +def validation_errors_to_bad_request( + errors: list[ValidationDetail], +) -> error_details_pb2.BadRequest: + """Convert validation error details to a gRPC BadRequest proto.""" + bad_request = error_details_pb2.BadRequest() + for err in errors: + violation = bad_request.field_violations.add() + violation.field = err['field'] + violation.description = err['message'] + return bad_request + + +def bad_request_to_validation_errors( + bad_request: error_details_pb2.BadRequest, +) -> list[ValidationDetail]: + """Convert a gRPC BadRequest proto to validation error details.""" + return [ + ValidationDetail(field=v.field, message=v.description) + for v in bad_request.field_violations + ] diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index 9322469d8..cd1a78f6f 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -1,5 +1,5 @@ from collections.abc import AsyncGenerator -from typing import Any, NamedTuple +from typing import NamedTuple import grpc import httpx diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index e2c760bae..6d251660b 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -5,12 +5,13 @@ import httpx import pytest + from google.protobuf.json_format import MessageToDict, Parse from google.protobuf.message import Message as ProtobufMessage from google.protobuf.timestamp_pb2 import Timestamp +from starlette.datastructures import QueryParams from a2a.types.a2a_pb2 import ( - AgentCard, AgentSkill, ListTasksRequest, Message, @@ -23,8 +24,8 @@ TaskStatus, TaskStatusUpdateEvent, ) -from starlette.datastructures import QueryParams from a2a.utils import proto_utils +from a2a.utils.errors import InvalidParamsError class TestToStreamResponse: @@ -255,8 +256,6 @@ def test_valid_required_fields(self): def test_missing_required_fields(self): """Test with empty message raising InvalidParamsError containing all errors.""" - from a2a.utils.errors import InvalidParamsError - msg = Message() with pytest.raises(InvalidParamsError) as exc_info: proto_utils.validate_proto_required_fields(msg) @@ -268,8 +267,6 @@ def test_missing_required_fields(self): def test_nested_required_fields(self): """Test nested required fields inside TaskStatus.""" - from a2a.utils.errors import InvalidParamsError - # Task Status requires 'state' task = Task(id='task-1', status=TaskStatus()) with pytest.raises(InvalidParamsError) as exc_info: From 5e02bcb0a5e6563f399147947a3542fb67b0beed Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 12:16:57 +0000 Subject: [PATCH 11/16] More tests --- tests/integration/test_end_to_end.py | 93 +++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/tests/integration/test_end_to_end.py b/tests/integration/test_end_to_end.py index cd1a78f6f..675856a65 100644 --- a/tests/integration/test_end_to_end.py +++ b/tests/integration/test_end_to_end.py @@ -22,13 +22,18 @@ AgentCapabilities, AgentCard, AgentInterface, + CancelTaskRequest, + DeleteTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, GetTaskRequest, + ListTaskPushNotificationConfigsRequest, ListTasksRequest, Message, Part, Role, SendMessageConfiguration, SendMessageRequest, + SubscribeToTaskRequest, TaskState, a2a_pb2_grpc, ) @@ -273,6 +278,22 @@ def transport_setups(request) -> ClientSetup: return request.getfixturevalue(request.param) +@pytest.fixture( + params=[ + pytest.param('jsonrpc_setup', id='JSON-RPC'), + pytest.param('grpc_setup', id='gRPC'), + ] +) +def rpc_transport_setups(request) -> ClientSetup: + """Parametrized fixture for RPC transports only (excludes REST). + + REST encodes some required fields in URL paths, so empty-field validation + tests hit routing errors before reaching the handler. JSON-RPC and gRPC + send the full request message, allowing server-side validation to work. + """ + return request.getfixturevalue(request.param) + + @pytest.mark.asyncio async def test_end_to_end_send_message_blocking(transport_setups): client = transport_setups.client @@ -569,9 +590,15 @@ async def test_end_to_end_input_required(transport_setups): SendMessageRequest(message=Message()), {'message.message_id', 'message.role', 'message.parts'}, ), + ( + SendMessageRequest( + message=Message(message_id='m1', role=Role.ROLE_USER) + ), + {'message.parts'}, + ), ], ) -async def test_end_to_end_validation_errors( +async def test_end_to_end_send_message_validation_errors( transport_setups, empty_request: SendMessageRequest, expected_fields: set[str], @@ -586,3 +613,67 @@ async def test_end_to_end_validation_errors( assert {e['field'] for e in errors} == expected_fields await client.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'method, invalid_request, expected_fields', + [ + ( + 'get_task', + GetTaskRequest(), + {'id'}, + ), + ( + 'cancel_task', + CancelTaskRequest(), + {'id'}, + ), + ( + 'get_task_push_notification_config', + GetTaskPushNotificationConfigRequest(), + {'task_id', 'id'}, + ), + ( + 'list_task_push_notification_configs', + ListTaskPushNotificationConfigsRequest(), + {'task_id'}, + ), + ( + 'delete_task_push_notification_config', + DeleteTaskPushNotificationConfigRequest(), + {'task_id', 'id'}, + ), + ], +) +async def test_end_to_end_unary_validation_errors( + rpc_transport_setups, + method: str, + invalid_request, + expected_fields: set[str], +) -> None: + client = rpc_transport_setups.client + + with pytest.raises(InvalidParamsError) as exc_info: + await getattr(client, method)(request=invalid_request) + + errors = exc_info.value.data.get('errors', []) + assert {e['field'] for e in errors} == expected_fields + + await client.close() + + +@pytest.mark.asyncio +async def test_end_to_end_subscribe_validation_error( + rpc_transport_setups, +) -> None: + client = rpc_transport_setups.client + + with pytest.raises(InvalidParamsError) as exc_info: + async for _ in client.subscribe(request=SubscribeToTaskRequest()): + pass + + errors = exc_info.value.data.get('errors', []) + assert {e['field'] for e in errors} == {'id'} + + await client.close() From aecfeadb08dc1e007ca68092ec5f21cec5664eab Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 13:05:29 +0000 Subject: [PATCH 12/16] Attempt to fix flakiness --- src/a2a/server/request_handlers/request_handler.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 8c955cbbe..b75e3ca11 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -241,8 +241,17 @@ async def async_gen_wrapper( ) -> Any: if params is not None: validate_proto_required_fields(params) - async for item in method(self, params, context, *args, **kwargs): - yield item + # Explicitly close the inner async generator in a finally block + # so that its cleanup (except/finally) runs deterministically + # during aclose(). On Python < 3.13, async-for does NOT call + # aclose() on the iterator when an exception (e.g. GeneratorExit) + # propagates through the loop body, leaving cleanup to the GC. + inner = method(self, params, context, *args, **kwargs) + try: + async for item in inner: + yield item + finally: + await inner.aclose() return async_gen_wrapper From 3286e36af2a60251010b7fbbe9ff8fdd38f20e42 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 13:38:26 +0000 Subject: [PATCH 13/16] Revert fix, let's see --- src/a2a/server/request_handlers/request_handler.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index b75e3ca11..8c955cbbe 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -241,17 +241,8 @@ async def async_gen_wrapper( ) -> Any: if params is not None: validate_proto_required_fields(params) - # Explicitly close the inner async generator in a finally block - # so that its cleanup (except/finally) runs deterministically - # during aclose(). On Python < 3.13, async-for does NOT call - # aclose() on the iterator when an exception (e.g. GeneratorExit) - # propagates through the loop body, leaving cleanup to the GC. - inner = method(self, params, context, *args, **kwargs) - try: - async for item in inner: - yield item - finally: - await inner.aclose() + async for item in method(self, params, context, *args, **kwargs): + yield item return async_gen_wrapper From 8f66f2a196cd174c5cadb30bc9260c22d1b9bb0f Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 14:40:42 +0000 Subject: [PATCH 14/16] Revert "Revert fix, let's see" This reverts commit 3286e36af2a60251010b7fbbe9ff8fdd38f20e42. --- src/a2a/server/request_handlers/request_handler.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index 8c955cbbe..b75e3ca11 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -241,8 +241,17 @@ async def async_gen_wrapper( ) -> Any: if params is not None: validate_proto_required_fields(params) - async for item in method(self, params, context, *args, **kwargs): - yield item + # Explicitly close the inner async generator in a finally block + # so that its cleanup (except/finally) runs deterministically + # during aclose(). On Python < 3.13, async-for does NOT call + # aclose() on the iterator when an exception (e.g. GeneratorExit) + # propagates through the loop body, leaving cleanup to the GC. + inner = method(self, params, context, *args, **kwargs) + try: + async for item in inner: + yield item + finally: + await inner.aclose() return async_gen_wrapper From 7db2adf0f858c5bc59ae0e300a0be21b51b0d808 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 16:20:43 +0000 Subject: [PATCH 15/16] Change comment --- src/a2a/server/request_handlers/request_handler.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index b75e3ca11..23b0f2b95 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -241,11 +241,11 @@ async def async_gen_wrapper( ) -> Any: if params is not None: validate_proto_required_fields(params) - # Explicitly close the inner async generator in a finally block - # so that its cleanup (except/finally) runs deterministically - # during aclose(). On Python < 3.13, async-for does NOT call - # aclose() on the iterator when an exception (e.g. GeneratorExit) - # propagates through the loop body, leaving cleanup to the GC. + # Ensure the inner async generator is closed explicitly; + # bare async-for does not call aclose() on GeneratorExit, + # which on Python 3.12+ prevents the except/finally blocks + # in on_message_send_stream from running on client disconnect + # (background_consume and cleanup_producer tasks are never created). inner = method(self, params, context, *args, **kwargs) try: async for item in inner: From 8877612c4223e2203f9dc34b6f0b8857b2c201e4 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 16:31:51 +0000 Subject: [PATCH 16/16] Updates --- src/a2a/server/request_handlers/grpc_handler.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/a2a/server/request_handlers/grpc_handler.py b/src/a2a/server/request_handlers/grpc_handler.py index e4cb39492..2ea110e2b 100644 --- a/src/a2a/server/request_handlers/grpc_handler.py +++ b/src/a2a/server/request_handlers/grpc_handler.py @@ -35,11 +35,7 @@ from a2a.types import a2a_pb2 from a2a.types.a2a_pb2 import AgentCard from a2a.utils import proto_utils -from a2a.utils.errors import ( - A2A_ERROR_REASONS, - A2AError, - TaskNotFoundError, -) +from a2a.utils.errors import A2A_ERROR_REASONS, A2AError, TaskNotFoundError from a2a.utils.helpers import maybe_await, validate from a2a.utils.proto_utils import validation_errors_to_bad_request @@ -426,7 +422,6 @@ async def abort_context( rich_status = rpc_status.to_status(status) new_metadata: list[tuple[str, str | bytes]] = [] - trailing = context.trailing_metadata() if trailing: for k, v in trailing: