From 8730c518a2c9f33fe4e88b5d9f6f498501d2a449 Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 14:52:27 +0000 Subject: [PATCH 1/3] fix: handle SSE errors occurred after stream started Currently it'd close the connection. --- src/a2a/client/transports/http_helpers.py | 20 ++- src/a2a/client/transports/jsonrpc.py | 10 +- src/a2a/client/transports/rest.py | 83 +++++++----- src/a2a/compat/v0_3/jsonrpc_adapter.py | 3 +- src/a2a/compat/v0_3/jsonrpc_transport.py | 8 ++ src/a2a/compat/v0_3/rest_adapter.py | 17 ++- src/a2a/compat/v0_3/rest_transport.py | 31 ++++- src/a2a/server/apps/rest/rest_adapter.py | 17 ++- src/a2a/server/routes/jsonrpc_dispatcher.py | 26 +++- src/a2a/utils/error_handlers.py | 83 ++++++------ tests/compat/v0_3/test_jsonrpc_transport.py | 1 + tests/compat/v0_3/test_rest_transport.py | 1 + .../test_client_server_integration.py | 76 +++++++++++ .../server/apps/rest/test_rest_fastapi_app.py | 125 +++++++++++++++++- tests/utils/test_error_handlers.py | 19 +-- 15 files changed, 422 insertions(+), 98 deletions(-) diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index 301782e36..742dab4d2 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -69,11 +69,23 @@ async def send_http_stream_request( httpx_client: httpx.AsyncClient, method: str, url: str, - status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn] - | None = None, + status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn], + sse_error_handler: Callable[[str], NoReturn], **kwargs: Any, ) -> AsyncGenerator[str]: - """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions.""" + """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions. + + Args: + httpx_client: The async HTTP client. + method: The HTTP method (e.g. 'POST', 'GET'). + url: The URL to send the request to. + status_error_handler: Handler for HTTP status errors. Should raise an + appropriate domain-specific exception. + sse_error_handler: Handler for SSE error events. Called with the + raw SSE data string when an ``event: error`` SSE event is received. + Should raise an appropriate domain-specific exception. + **kwargs: Additional keyword arguments forwarded to ``aconnect_sse``. + """ with handle_http_exceptions(status_error_handler): async with aconnect_sse( httpx_client, method, url, **kwargs @@ -97,4 +109,6 @@ async def send_http_stream_request( async for sse in event_source.aiter_sse(): if not sse.data: continue + if sse.event == 'error': + sse_error_handler(sse.data) yield sse.data diff --git a/src/a2a/client/transports/jsonrpc.py b/src/a2a/client/transports/jsonrpc.py index 9854aabb0..cf5933782 100644 --- a/src/a2a/client/transports/jsonrpc.py +++ b/src/a2a/client/transports/jsonrpc.py @@ -1,7 +1,7 @@ import logging from collections.abc import AsyncGenerator -from typing import Any +from typing import Any, NoReturn from uuid import uuid4 import httpx @@ -349,6 +349,7 @@ async def _send_stream_request( 'POST', self.url, None, + self._handle_sse_error, json=rpc_request_payload, **http_kwargs, ): @@ -359,3 +360,10 @@ async def _send_stream_request( json_rpc_response.result, StreamResponse() ) yield response + + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error.""" + json_rpc_response = JSONRPC20Response.from_json(sse_data) + if json_rpc_response.error: + raise self._create_jsonrpc_error(json_rpc_response.error) + raise A2AClientError(f'SSE stream error: {sse_data}') diff --git a/src/a2a/client/transports/rest.py b/src/a2a/client/transports/rest.py index ed40d31c7..3dfe95927 100644 --- a/src/a2a/client/transports/rest.py +++ b/src/a2a/client/transports/rest.py @@ -41,6 +41,47 @@ logger = logging.getLogger(__name__) +def _parse_rest_error( + error_payload: dict[str, Any], + fallback_message: str, +) -> Exception | None: + """Parses a REST error payload and returns the appropriate A2AError. + + Args: + error_payload: The parsed JSON error payload. + fallback_message: Message to use if the payload has no ``message``. + + Returns: + The mapped A2AError if a known reason was found, otherwise ``None``. + """ + error_data = error_payload.get('error', {}) + message = error_data.get('message', fallback_message) + details = error_data.get('details', []) + if not isinstance(details, list): + return None + + # The `details` array can contain multiple different error objects. + # We extract the first `ErrorInfo` object because it contains the + # specific `reason` code needed to map this back to a Python A2AError. + for d in details: + if ( + isinstance(d, dict) + and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo' + ): + reason = d.get('reason') + metadata = d.get('metadata') or {} + if isinstance(reason, str): + exception_cls = A2A_REASON_TO_ERROR.get(reason) + if exception_cls: + exc = exception_cls(message) + if metadata: + exc.data = metadata + return exc + break + + return None + + @trace_class(kind=SpanKind.CLIENT) class RestTransport(ClientTransport): """A REST transport for the A2A client.""" @@ -294,39 +335,12 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP status errors and raises the appropriate A2AError.""" try: error_payload = e.response.json() - error_data = error_payload.get('error', {}) - - message = error_data.get('message', str(e)) - details = error_data.get('details', []) - if not isinstance(details, list): - details = [] - - # The `details` array can contain multiple different error objects. - # We extract the first `ErrorInfo` object because it contains the - # specific `reason` code needed to map this back to a Python A2AError. - error_info = {} - for d in details: - if ( - isinstance(d, dict) - and d.get('@type') - == 'type.googleapis.com/google.rpc.ErrorInfo' - ): - error_info = d - break - reason = error_info.get('reason') - metadata = error_info.get('metadata') or {} - - if isinstance(reason, str): - exception_cls = A2A_REASON_TO_ERROR.get(reason) - if exception_cls: - exc = exception_cls(message) - if metadata: - exc.data = metadata - raise exc from e + mapped = _parse_rest_error(error_payload, str(e)) + if mapped: + raise mapped from e except (json.JSONDecodeError, ValueError): pass - # Fallback mappings for status codes if 'type' is missing or unknown status_code = e.response.status_code if status_code == httpx.codes.NOT_FOUND: raise MethodNotFoundError( @@ -335,6 +349,14 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: raise A2AClientError(f'HTTP Error {status_code}: {e}') from e + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError.""" + error_payload = json.loads(sse_data) + mapped = _parse_rest_error(error_payload, sse_data) + if mapped: + raise mapped + raise A2AClientError(sse_data) + async def _send_stream_request( self, method: str, @@ -352,6 +374,7 @@ async def _send_stream_request( method, f'{self.url}{path}', self._handle_http_error, + self._handle_sse_error, json=json, **http_kwargs, ): diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index 073c7854b..fcd9bf81b 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -306,9 +306,10 @@ async def event_generator( ) ) yield { + 'event': 'error', 'data': err_resp.model_dump_json( by_alias=True, exclude_none=True - ) + ), } return EventSourceResponse(event_generator(stream_gen)) diff --git a/src/a2a/compat/v0_3/jsonrpc_transport.py b/src/a2a/compat/v0_3/jsonrpc_transport.py index 6153ccfc0..93661223a 100644 --- a/src/a2a/compat/v0_3/jsonrpc_transport.py +++ b/src/a2a/compat/v0_3/jsonrpc_transport.py @@ -415,6 +415,13 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP errors for standard requests.""" raise A2AClientError(f'HTTP Error: {e.response.status_code}') from e + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error.""" + data = json.loads(sse_data) + if 'error' in data: + raise self._create_jsonrpc_error(data['error']) + raise A2AClientError(f'SSE stream error: {sse_data}') + async def _send_stream_request( self, json_data: dict[str, Any], @@ -430,6 +437,7 @@ async def _send_stream_request( 'POST', self.url, self._handle_http_error, + self._handle_sse_error, json=json_data, **http_kwargs, ): diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index 8cae6b630..a21e7a84e 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -17,6 +18,7 @@ _package_starlette_installed = True else: try: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -27,6 +29,7 @@ Request = Any JSONResponse = Any Response = Any + ServerSentEvent = Any _package_starlette_installed = False @@ -37,6 +40,7 @@ from a2a.server.context import ServerCallContext from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder from a2a.utils.error_handlers import ( + build_rest_error_payload, rest_error_handler, rest_stream_error_handler, ) @@ -101,9 +105,16 @@ async def _handle_streaming_request( async def event_generator( stream: AsyncIterable[Any], - ) -> AsyncIterator[str]: - async for item in stream: - yield json.dumps(item) + ) -> AsyncIterator[str | ServerSentEvent]: + try: + async for item in stream: + yield json.dumps(item) + except Exception as e: + logger.exception('Error during v0.3 REST SSE stream') + yield ServerSentEvent( + data=json.dumps(build_rest_error_payload(e)), + event='error', + ) return EventSourceResponse( event_generator(method(request, call_context)) diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index 0ba38538d..ee7e52126 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -44,7 +44,11 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER -from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, MethodNotFoundError +from a2a.utils.errors import ( + A2A_REASON_TO_ERROR, + JSON_RPC_ERROR_CODE_MAP, + MethodNotFoundError, +) from a2a.utils.telemetry import SpanKind, trace_class @@ -369,6 +373,30 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: raise A2AClientError(f'HTTP Error {status_code}: {e}') from e + def _handle_sse_error(self, sse_data: str) -> NoReturn: + """Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError.""" + error_payload = json.loads(sse_data) + error_data = error_payload.get('error', {}) + + message = error_data.get('message', sse_data) + details = error_data.get('details', []) + if not isinstance(details, list): + details = [] + + for d in details: + if ( + isinstance(d, dict) + and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo' + ): + reason = d.get('reason') + if isinstance(reason, str): + exception_cls = A2A_REASON_TO_ERROR.get(reason) + if exception_cls: + raise exception_cls(message) + break + + raise A2AClientError(message) + async def _send_stream_request( self, method: str, @@ -386,6 +414,7 @@ async def _send_stream_request( method, f'{self.url}{path}', self._handle_http_error, + self._handle_sse_error, json=json, **http_kwargs, ): diff --git a/src/a2a/server/apps/rest/rest_adapter.py b/src/a2a/server/apps/rest/rest_adapter.py index 2a1ed95c3..10fc5a415 100644 --- a/src/a2a/server/apps/rest/rest_adapter.py +++ b/src/a2a/server/apps/rest/rest_adapter.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -20,6 +21,7 @@ else: try: + from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -30,6 +32,7 @@ Request = Any JSONResponse = Any Response = Any + ServerSentEvent = Any _package_starlette_installed = False @@ -42,6 +45,7 @@ from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder from a2a.types.a2a_pb2 import AgentCard from a2a.utils.error_handlers import ( + build_rest_error_payload, rest_error_handler, rest_stream_error_handler, ) @@ -163,10 +167,17 @@ async def _handle_streaming_request( except StopAsyncIteration: return EventSourceResponse(iter([])) - async def event_generator() -> AsyncIterator[str]: + async def event_generator() -> AsyncIterator[str | ServerSentEvent]: yield json.dumps(first_item) - async for item in stream: - yield json.dumps(item) + try: + async for item in stream: + yield json.dumps(item) + except Exception as e: + logger.exception('Error during REST SSE stream') + yield ServerSentEvent( + data=json.dumps(build_rest_error_payload(e)), + event='error', + ) return EventSourceResponse(event_generator()) diff --git a/src/a2a/server/routes/jsonrpc_dispatcher.py b/src/a2a/server/routes/jsonrpc_dispatcher.py index 1ce5f0fe8..87c026a3f 100644 --- a/src/a2a/server/routes/jsonrpc_dispatcher.py +++ b/src/a2a/server/routes/jsonrpc_dispatcher.py @@ -559,8 +559,30 @@ def _create_response( async def event_generator( stream: AsyncGenerator[dict[str, Any]], ) -> AsyncGenerator[dict[str, str]]: - async for item in stream: - yield {'data': json.dumps(item)} + try: + async for item in stream: + event: dict[str, str] = { + 'data': json.dumps(item), + } + if 'error' in item: + event['event'] = 'error' + yield event + except Exception as e: + logger.exception( + 'Unhandled error during JSON-RPC SSE stream' + ) + rpc_error: A2AError | JSONRPCError = ( + e + if isinstance(e, A2AError | JSONRPCError) + else InternalError(message=str(e)) + ) + error_response = build_error_response( + context.state.get('request_id'), rpc_error + ) + yield { + 'event': 'error', + 'data': json.dumps(error_response), + } return EventSourceResponse( event_generator(handler_result), headers=headers diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py index d21a9e24c..7b22c7dc0 100644 --- a/src/a2a/utils/error_handlers.py +++ b/src/a2a/utils/error_handlers.py @@ -54,16 +54,43 @@ def _build_error_payload( return {'error': payload} -def _create_error_response(error: Exception) -> Response: - """Helper function to create a JSONResponse for an error.""" +def build_rest_error_payload(error: Exception) -> dict[str, Any]: + """Build a REST error payload dict from an exception. + + Returns: + A dict with the error payload in the standard REST error format. + """ if isinstance(error, A2AError): mapping = A2A_REST_ERROR_MAPPING.get( type(error), RestErrorMap(500, 'INTERNAL', 'INTERNAL_ERROR') ) - http_code = mapping.http_code - grpc_status = mapping.grpc_status - reason = mapping.reason + # SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response. + metadata = getattr(error, 'data', None) or {} + return _build_error_payload( + code=mapping.http_code, + status=mapping.grpc_status, + message=getattr(error, 'message', str(error)), + reason=mapping.reason, + metadata=metadata, + ) + if isinstance(error, ParseError): + return _build_error_payload( + code=400, + status='INVALID_ARGUMENT', + message=str(error), + reason='INVALID_REQUEST', + metadata={}, + ) + return _build_error_payload( + code=500, + status='INTERNAL', + message='unknown exception', + ) + +def _create_error_response(error: Exception) -> Response: + """Helper function to create a JSONResponse for an error.""" + if isinstance(error, A2AError): log_level = ( logging.ERROR if isinstance(error, InternalError) @@ -76,42 +103,17 @@ def _create_error_response(error: Exception) -> Response: getattr(error, 'message', str(error)), f', Data={error.data}' if error.data else '', ) - - # SECURITY WARNING: Data attached to A2AError.data is serialized unaltered and exposed publicly to the client in the REST API response. - metadata = getattr(error, 'data', None) or {} - - return JSONResponse( - content=_build_error_payload( - code=http_code, - status=grpc_status, - message=getattr(error, 'message', str(error)), - reason=reason, - metadata=metadata, - ), - status_code=http_code, - media_type='application/json', - ) - if isinstance(error, ParseError): + elif isinstance(error, ParseError): logger.warning('Parse error: %s', str(error)) - return JSONResponse( - content=_build_error_payload( - code=400, - status='INVALID_ARGUMENT', - message=str(error), - reason='INVALID_REQUEST', - metadata={}, - ), - status_code=400, - media_type='application/json', - ) - logger.exception('Unknown error occurred') + else: + logger.exception('Unknown error occurred') + + payload = build_rest_error_payload(error) + # Extract HTTP status code from the payload + http_code = payload.get('error', {}).get('code', 500) return JSONResponse( - content=_build_error_payload( - code=500, - status='INTERNAL', - message='unknown exception', - ), - status_code=500, + content=payload, + status_code=http_code, media_type='application/json', ) @@ -171,9 +173,8 @@ async def error_catching_iterator() -> AsyncGenerator[ try: async for item in original_iterator: yield item - except Exception as stream_error: + except Exception as stream_error: # noqa: BLE001 _log_error(stream_error) - raise stream_error response.body_iterator = error_catching_iterator() diff --git a/tests/compat/v0_3/test_jsonrpc_transport.py b/tests/compat/v0_3/test_jsonrpc_transport.py index 250608014..15b499854 100644 --- a/tests/compat/v0_3/test_jsonrpc_transport.py +++ b/tests/compat/v0_3/test_jsonrpc_transport.py @@ -480,6 +480,7 @@ async def mock_generator(*args, **kwargs): 'POST', 'http://example.com', transport._handle_http_error, + transport._handle_sse_error, json={'some': 'data'}, headers={'a2a-version': '0.3'}, ) diff --git a/tests/compat/v0_3/test_rest_transport.py b/tests/compat/v0_3/test_rest_transport.py index 2bea70f42..a5bf12267 100644 --- a/tests/compat/v0_3/test_rest_transport.py +++ b/tests/compat/v0_3/test_rest_transport.py @@ -634,6 +634,7 @@ async def mock_generator(*args, **kwargs): 'POST', 'http://example.com/test', transport._handle_http_error, + transport._handle_sse_error, json=None, headers={'a2a-version': '0.3'}, ) diff --git a/tests/integration/test_client_server_integration.py b/tests/integration/test_client_server_integration.py index 2df24790b..f2de48567 100644 --- a/tests/integration/test_client_server_integration.py +++ b/tests/integration/test_client_server_integration.py @@ -1129,3 +1129,79 @@ async def test_validate_streaming_disabled( pass await transport.close() + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'error_cls', + [ + TaskNotFoundError, + TaskNotCancelableError, + PushNotificationNotSupportedError, + UnsupportedOperationError, + ContentTypeNotSupportedError, + InvalidAgentResponseError, + ExtendedAgentCardNotConfiguredError, + ExtensionSupportRequiredError, + VersionNotSupportedError, + ], +) +@pytest.mark.parametrize( + 'handler_attr, client_method, request_params', + [ + pytest.param( + 'on_message_send_stream', + 'send_message', + SendMessageRequest( + message=Message( + role=Role.ROLE_USER, + message_id='msg-midstream-test', + parts=[Part(text='Hello, mid-stream test!')], + ) + ), + id='stream', + ), + pytest.param( + 'on_subscribe_to_task', + 'subscribe', + SubscribeToTaskRequest(id='some-id'), + id='subscribe', + ), + ], +) +async def test_client_handles_mid_stream_a2a_errors( + transport_setups, + error_cls, + handler_attr, + client_method, + request_params, +) -> None: + """Integration test for mid-stream errors sent as SSE error events. + + The handler yields one event successfully, then raises an A2AError. + The client must receive the first event and then get the error as the + exact error_cls exception. This mirrors test_client_handles_a2a_errors_streaming + but verifies the error occurs *after* the stream has started producing events. + """ + client = transport_setups.client + handler = transport_setups.handler + + async def mock_generator(*args, **kwargs): + yield TASK_FROM_STREAM + raise error_cls('Mid-stream error') + + getattr(handler, handler_attr).side_effect = mock_generator + + received_events = [] + with pytest.raises(error_cls) as exc_info: + async for event in getattr(client, client_method)( + request=request_params + ): + received_events.append(event) # noqa: PERF401 + + assert 'Mid-stream error' in str(exc_info.value) + assert len(received_events) == 1 + + getattr(handler, handler_attr).side_effect = None + + await client.close() diff --git a/tests/server/apps/rest/test_rest_fastapi_app.py b/tests/server/apps/rest/test_rest_fastapi_app.py index 1c976c94b..e29daf422 100644 --- a/tests/server/apps/rest/test_rest_fastapi_app.py +++ b/tests/server/apps/rest/test_rest_fastapi_app.py @@ -1,5 +1,5 @@ -import logging import json +import logging from typing import Any from unittest.mock import MagicMock @@ -27,6 +27,7 @@ TaskState, TaskStatus, ) +from a2a.utils.errors import InternalError logger = logging.getLogger(__name__) @@ -724,5 +725,127 @@ async def test_global_http_exception_handler_returns_rpc_status( assert 'details' not in error_payload +@pytest.mark.anyio +async def test_streaming_mid_stream_error_emits_sse_error_event( + streaming_client: AsyncClient, request_handler: MagicMock +) -> None: + """Test that mid-stream errors are sent as SSE error events instead of crashing the stream.""" + + async def mock_stream_then_error(): + yield Message( + message_id='stream_msg_1', + role=Role.ROLE_AGENT, + parts=[Part(text='First chunk')], + ) + raise InternalError(message='Something went wrong mid-stream') + + request_handler.on_message_send_stream.return_value = ( + mock_stream_then_error() + ) + + request = a2a_pb2.SendMessageRequest( + message=a2a_pb2.Message( + message_id='test_stream_msg', + role=a2a_pb2.ROLE_USER, + parts=[a2a_pb2.Part(text='Test message')], + ), + ) + + response = await streaming_client.post( + '/message:stream', + headers={'Accept': 'text/event-stream'}, + json=json_format.MessageToDict(request), + ) + + response.raise_for_status() + assert 'text/event-stream' in response.headers.get('content-type', '') + + lines = [line async for line in response.aiter_lines()] + + # First event should be a normal data event with the message + data_lines = [ + json.loads(line[6:]) for line in lines if line.startswith('data: ') + ] + assert len(data_lines) >= 1 + assert 'message' in data_lines[0] + assert data_lines[0]['message']['messageId'] == 'stream_msg_1' + + # Last event should be an SSE error event + error_event_lines = [line for line in lines if line.startswith('event: ')] + assert any('error' in line for line in error_event_lines) + + # Find the data line after the error event + error_data = None + for i, line in enumerate(lines): + if line == 'event: error': + # The next non-empty line should be the error data + for j in range(i + 1, len(lines)): + if lines[j].startswith('data: '): + error_data = json.loads(lines[j][6:]) + break + break + + assert error_data is not None + assert 'error' in error_data + assert error_data['error']['code'] == 500 + assert error_data['error']['status'] == 'INTERNAL' + assert 'Something went wrong mid-stream' in error_data['error']['message'] + + +@pytest.mark.anyio +async def test_streaming_mid_stream_unknown_error_emits_sse_error_event( + streaming_client: AsyncClient, request_handler: MagicMock +) -> None: + """Test that non-A2AError mid-stream errors also produce SSE error events.""" + + async def mock_stream_then_error(): + yield Message( + message_id='stream_msg_1', + role=Role.ROLE_AGENT, + parts=[Part(text='First chunk')], + ) + raise RuntimeError('Unexpected failure') + + request_handler.on_message_send_stream.return_value = ( + mock_stream_then_error() + ) + + request = a2a_pb2.SendMessageRequest( + message=a2a_pb2.Message( + message_id='test_stream_msg', + role=a2a_pb2.ROLE_USER, + parts=[a2a_pb2.Part(text='Test message')], + ), + ) + + response = await streaming_client.post( + '/message:stream', + headers={'Accept': 'text/event-stream'}, + json=json_format.MessageToDict(request), + ) + + response.raise_for_status() + + lines = [line async for line in response.aiter_lines()] + + # Should have an error event + error_event_lines = [line for line in lines if line.startswith('event: ')] + assert any('error' in line for line in error_event_lines) + + # Find the error data + error_data = None + for i, line in enumerate(lines): + if line == 'event: error': + for j in range(i + 1, len(lines)): + if lines[j].startswith('data: '): + error_data = json.loads(lines[j][6:]) + break + break + + assert error_data is not None + assert error_data['error']['code'] == 500 + assert error_data['error']['status'] == 'INTERNAL' + + if __name__ == '__main__': pytest.main([__file__]) diff --git a/tests/utils/test_error_handlers.py b/tests/utils/test_error_handlers.py index 93ad6a7c0..29eabecaf 100644 --- a/tests/utils/test_error_handlers.py +++ b/tests/utils/test_error_handlers.py @@ -125,7 +125,7 @@ async def successful_func(): @pytest.mark.asyncio async def test_rest_stream_error_handler_generator_error(caplog): - """Test rest_stream_error_handler catches error during async generation after first success.""" + """Test rest_stream_error_handler logs error during async generation and ends stream gracefully.""" error = InternalError(message='Stream error during generation') async def failing_generator(): @@ -141,21 +141,18 @@ async def successful_prep_failing_stream(): # Assert it returns successfully assert isinstance(response, MockEventSourceResponse) - # Now consume the stream + # Consume the stream - error should be logged but not re-raised chunks = [] - with ( - caplog.at_level(logging.ERROR), - pytest.raises(InternalError) as exc_info, - ): + with caplog.at_level(logging.ERROR): async for chunk in response.body_iterator: chunks.append(chunk) # noqa: PERF401 assert chunks == ['success chunk 1'] - assert exc_info.value == error + assert 'Stream error during generation' in caplog.text @pytest.mark.asyncio async def test_rest_stream_error_handler_generator_unknown_error(caplog): - """Test rest_stream_error_handler catches unknown error during async generation.""" + """Test rest_stream_error_handler logs unknown error during async generation and ends stream gracefully.""" async def failing_generator(): yield 'success chunk 1' @@ -167,11 +164,9 @@ async def successful_prep_failing_stream(): response = await successful_prep_failing_stream() + # Consume the stream - error should be logged but not re-raised chunks = [] - with ( - caplog.at_level(logging.ERROR), - pytest.raises(RuntimeError, match='Unknown stream failure'), - ): + with caplog.at_level(logging.ERROR): async for chunk in response.body_iterator: chunks.append(chunk) # noqa: PERF401 assert chunks == ['success chunk 1'] From 075a7b72c7bfcaf479fff5c81c3faae34a02122f Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 15:57:31 +0000 Subject: [PATCH 2/3] Updates --- src/a2a/client/transports/http_helpers.py | 6 +++- src/a2a/compat/v0_3/jsonrpc_adapter.py | 3 +- src/a2a/compat/v0_3/jsonrpc_transport.py | 8 ------ src/a2a/compat/v0_3/rest_adapter.py | 17 ++--------- src/a2a/compat/v0_3/rest_transport.py | 31 +-------------------- tests/compat/v0_3/test_jsonrpc_transport.py | 1 - tests/compat/v0_3/test_rest_transport.py | 1 - 7 files changed, 10 insertions(+), 57 deletions(-) diff --git a/src/a2a/client/transports/http_helpers.py b/src/a2a/client/transports/http_helpers.py index 742dab4d2..119dce41f 100644 --- a/src/a2a/client/transports/http_helpers.py +++ b/src/a2a/client/transports/http_helpers.py @@ -12,6 +12,10 @@ from a2a.client.errors import A2AClientError, A2AClientTimeoutError +def _default_sse_error_handler(sse_data: str) -> NoReturn: + raise A2AClientError(f'SSE stream error event received: {sse_data}') + + @contextmanager def handle_http_exceptions( status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn] @@ -70,7 +74,7 @@ async def send_http_stream_request( method: str, url: str, status_error_handler: Callable[[httpx.HTTPStatusError], NoReturn], - sse_error_handler: Callable[[str], NoReturn], + sse_error_handler: Callable[[str], NoReturn] = _default_sse_error_handler, **kwargs: Any, ) -> AsyncGenerator[str]: """Sends a streaming HTTP request, yielding SSE data strings and handling exceptions. diff --git a/src/a2a/compat/v0_3/jsonrpc_adapter.py b/src/a2a/compat/v0_3/jsonrpc_adapter.py index fcd9bf81b..073c7854b 100644 --- a/src/a2a/compat/v0_3/jsonrpc_adapter.py +++ b/src/a2a/compat/v0_3/jsonrpc_adapter.py @@ -306,10 +306,9 @@ async def event_generator( ) ) yield { - 'event': 'error', 'data': err_resp.model_dump_json( by_alias=True, exclude_none=True - ), + ) } return EventSourceResponse(event_generator(stream_gen)) diff --git a/src/a2a/compat/v0_3/jsonrpc_transport.py b/src/a2a/compat/v0_3/jsonrpc_transport.py index 93661223a..6153ccfc0 100644 --- a/src/a2a/compat/v0_3/jsonrpc_transport.py +++ b/src/a2a/compat/v0_3/jsonrpc_transport.py @@ -415,13 +415,6 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: """Handles HTTP errors for standard requests.""" raise A2AClientError(f'HTTP Error: {e.response.status_code}') from e - def _handle_sse_error(self, sse_data: str) -> NoReturn: - """Handles SSE error events by parsing JSON-RPC error payload and raising the appropriate domain error.""" - data = json.loads(sse_data) - if 'error' in data: - raise self._create_jsonrpc_error(data['error']) - raise A2AClientError(f'SSE stream error: {sse_data}') - async def _send_stream_request( self, json_data: dict[str, Any], @@ -437,7 +430,6 @@ async def _send_stream_request( 'POST', self.url, self._handle_http_error, - self._handle_sse_error, json=json_data, **http_kwargs, ): diff --git a/src/a2a/compat/v0_3/rest_adapter.py b/src/a2a/compat/v0_3/rest_adapter.py index a21e7a84e..8cae6b630 100644 --- a/src/a2a/compat/v0_3/rest_adapter.py +++ b/src/a2a/compat/v0_3/rest_adapter.py @@ -7,7 +7,6 @@ if TYPE_CHECKING: - from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -18,7 +17,6 @@ _package_starlette_installed = True else: try: - from sse_starlette.event import ServerSentEvent from sse_starlette.sse import EventSourceResponse from starlette.requests import Request from starlette.responses import JSONResponse, Response @@ -29,7 +27,6 @@ Request = Any JSONResponse = Any Response = Any - ServerSentEvent = Any _package_starlette_installed = False @@ -40,7 +37,6 @@ from a2a.server.context import ServerCallContext from a2a.server.routes import CallContextBuilder, DefaultCallContextBuilder from a2a.utils.error_handlers import ( - build_rest_error_payload, rest_error_handler, rest_stream_error_handler, ) @@ -105,16 +101,9 @@ async def _handle_streaming_request( async def event_generator( stream: AsyncIterable[Any], - ) -> AsyncIterator[str | ServerSentEvent]: - try: - async for item in stream: - yield json.dumps(item) - except Exception as e: - logger.exception('Error during v0.3 REST SSE stream') - yield ServerSentEvent( - data=json.dumps(build_rest_error_payload(e)), - event='error', - ) + ) -> AsyncIterator[str]: + async for item in stream: + yield json.dumps(item) return EventSourceResponse( event_generator(method(request, call_context)) diff --git a/src/a2a/compat/v0_3/rest_transport.py b/src/a2a/compat/v0_3/rest_transport.py index ee7e52126..0ba38538d 100644 --- a/src/a2a/compat/v0_3/rest_transport.py +++ b/src/a2a/compat/v0_3/rest_transport.py @@ -44,11 +44,7 @@ TaskPushNotificationConfig, ) from a2a.utils.constants import PROTOCOL_VERSION_0_3, VERSION_HEADER -from a2a.utils.errors import ( - A2A_REASON_TO_ERROR, - JSON_RPC_ERROR_CODE_MAP, - MethodNotFoundError, -) +from a2a.utils.errors import JSON_RPC_ERROR_CODE_MAP, MethodNotFoundError from a2a.utils.telemetry import SpanKind, trace_class @@ -373,30 +369,6 @@ def _handle_http_error(self, e: httpx.HTTPStatusError) -> NoReturn: raise A2AClientError(f'HTTP Error {status_code}: {e}') from e - def _handle_sse_error(self, sse_data: str) -> NoReturn: - """Handles SSE error events by parsing the REST error payload and raising the appropriate A2AError.""" - error_payload = json.loads(sse_data) - error_data = error_payload.get('error', {}) - - message = error_data.get('message', sse_data) - details = error_data.get('details', []) - if not isinstance(details, list): - details = [] - - for d in details: - if ( - isinstance(d, dict) - and d.get('@type') == 'type.googleapis.com/google.rpc.ErrorInfo' - ): - reason = d.get('reason') - if isinstance(reason, str): - exception_cls = A2A_REASON_TO_ERROR.get(reason) - if exception_cls: - raise exception_cls(message) - break - - raise A2AClientError(message) - async def _send_stream_request( self, method: str, @@ -414,7 +386,6 @@ async def _send_stream_request( method, f'{self.url}{path}', self._handle_http_error, - self._handle_sse_error, json=json, **http_kwargs, ): diff --git a/tests/compat/v0_3/test_jsonrpc_transport.py b/tests/compat/v0_3/test_jsonrpc_transport.py index 15b499854..250608014 100644 --- a/tests/compat/v0_3/test_jsonrpc_transport.py +++ b/tests/compat/v0_3/test_jsonrpc_transport.py @@ -480,7 +480,6 @@ async def mock_generator(*args, **kwargs): 'POST', 'http://example.com', transport._handle_http_error, - transport._handle_sse_error, json={'some': 'data'}, headers={'a2a-version': '0.3'}, ) diff --git a/tests/compat/v0_3/test_rest_transport.py b/tests/compat/v0_3/test_rest_transport.py index a5bf12267..2bea70f42 100644 --- a/tests/compat/v0_3/test_rest_transport.py +++ b/tests/compat/v0_3/test_rest_transport.py @@ -634,7 +634,6 @@ async def mock_generator(*args, **kwargs): 'POST', 'http://example.com/test', transport._handle_http_error, - transport._handle_sse_error, json=None, headers={'a2a-version': '0.3'}, ) From 8601f50eafce4e4d6954c1c66f6a3207db36772d Mon Sep 17 00:00:00 2001 From: Ivan Shymko Date: Tue, 24 Mar 2026 16:03:36 +0000 Subject: [PATCH 3/3] Updates --- src/a2a/utils/error_handlers.py | 1 + tests/utils/test_error_handlers.py | 19 ++++++++++++------- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/a2a/utils/error_handlers.py b/src/a2a/utils/error_handlers.py index 7b22c7dc0..2adc7b172 100644 --- a/src/a2a/utils/error_handlers.py +++ b/src/a2a/utils/error_handlers.py @@ -175,6 +175,7 @@ async def error_catching_iterator() -> AsyncGenerator[ yield item except Exception as stream_error: # noqa: BLE001 _log_error(stream_error) + raise stream_error response.body_iterator = error_catching_iterator() diff --git a/tests/utils/test_error_handlers.py b/tests/utils/test_error_handlers.py index 29eabecaf..93ad6a7c0 100644 --- a/tests/utils/test_error_handlers.py +++ b/tests/utils/test_error_handlers.py @@ -125,7 +125,7 @@ async def successful_func(): @pytest.mark.asyncio async def test_rest_stream_error_handler_generator_error(caplog): - """Test rest_stream_error_handler logs error during async generation and ends stream gracefully.""" + """Test rest_stream_error_handler catches error during async generation after first success.""" error = InternalError(message='Stream error during generation') async def failing_generator(): @@ -141,18 +141,21 @@ async def successful_prep_failing_stream(): # Assert it returns successfully assert isinstance(response, MockEventSourceResponse) - # Consume the stream - error should be logged but not re-raised + # Now consume the stream chunks = [] - with caplog.at_level(logging.ERROR): + with ( + caplog.at_level(logging.ERROR), + pytest.raises(InternalError) as exc_info, + ): async for chunk in response.body_iterator: chunks.append(chunk) # noqa: PERF401 assert chunks == ['success chunk 1'] - assert 'Stream error during generation' in caplog.text + assert exc_info.value == error @pytest.mark.asyncio async def test_rest_stream_error_handler_generator_unknown_error(caplog): - """Test rest_stream_error_handler logs unknown error during async generation and ends stream gracefully.""" + """Test rest_stream_error_handler catches unknown error during async generation.""" async def failing_generator(): yield 'success chunk 1' @@ -164,9 +167,11 @@ async def successful_prep_failing_stream(): response = await successful_prep_failing_stream() - # Consume the stream - error should be logged but not re-raised chunks = [] - with caplog.at_level(logging.ERROR): + with ( + caplog.at_level(logging.ERROR), + pytest.raises(RuntimeError, match='Unknown stream failure'), + ): async for chunk in response.body_iterator: chunks.append(chunk) # noqa: PERF401 assert chunks == ['success chunk 1']