diff --git a/src/together/error.py b/src/together/error.py index e2883a2..97e33c8 100644 --- a/src/together/error.py +++ b/src/together/error.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +from collections.abc import Mapping from typing import Any, Dict from requests import RequestException @@ -8,6 +9,16 @@ from together.types.error import TogetherErrorResponse +def _json_safe_headers(headers: str | Dict[Any, Any] | None) -> str | Dict[Any, Any]: + if headers is None: + return {} + if isinstance(headers, str): + return headers + if isinstance(headers, Mapping): + return {str(key): value for key, value in headers.items()} + return str(headers) + + class TogetherException(Exception): def __init__( self, @@ -43,8 +54,9 @@ def __repr__(self) -> str: "response": self._message, "status": self.http_status, "request_id": self.request_id, - "headers": self.headers, - } + "headers": _json_safe_headers(self.headers), + }, + default=str, ) return "%s(%r)" % (self.__class__.__name__, repr_message) diff --git a/tests/unit/test_error_repr.py b/tests/unit/test_error_repr.py new file mode 100644 index 0000000..e94dd63 --- /dev/null +++ b/tests/unit/test_error_repr.py @@ -0,0 +1,108 @@ +"""Tests for TogetherException.__repr__ with non-JSON-serializable headers.""" + +from __future__ import annotations + +import json +from unittest.mock import MagicMock + +from multidict import CIMultiDict, CIMultiDictProxy + +from together.error import ( + APIError, + AuthenticationError, + ResponseError, + TogetherException, +) + + +class TestExceptionReprNonSerializable: + """Verify __repr__ doesn't crash on non-JSON-serializable headers (issue #108).""" + + def test_repr_with_dict_headers(self) -> None: + """Normal dict headers should still work fine.""" + exc = TogetherException( + message="test error", + headers={"Content-Type": "application/json"}, + http_status=400, + request_id="req-123", + ) + result = repr(exc) + assert "TogetherException" in result + assert "test error" in result + + def test_repr_with_multidict_proxy_headers(self) -> None: + """Real CIMultiDictProxy headers must not crash repr (issue #108).""" + headers = CIMultiDictProxy( + CIMultiDict( + {"Content-Type": "application/json", "X-Request-Id": "abc"} + ) + ) + exc = TogetherException( + message="server error", + headers=headers, # type: ignore[arg-type] + http_status=500, + request_id="req-456", + ) + # Before fix: TypeError: Object of type CIMultiDictProxy + # is not JSON serializable + result = repr(exc) + assert "TogetherException" in result + assert "server error" in result + parsed = json.loads(result[len("TogetherException(") + 1 : -2]) + assert parsed["headers"] == { + "Content-Type": "application/json", + "X-Request-Id": "abc", + } + + def test_repr_with_none_headers(self) -> None: + """None headers (default) should work.""" + exc = TogetherException(message="no headers") + result = repr(exc) + assert "TogetherException" in result + + def test_repr_with_string_headers(self) -> None: + """String headers should work.""" + exc = TogetherException( + message="string headers", headers="raw-header-string" + ) + result = repr(exc) + assert "TogetherException" in result + + def test_repr_with_nested_non_serializable(self) -> None: + """Dict headers containing non-serializable values should not crash.""" + exc = TogetherException( + message="nested issue", + headers={"key": MagicMock()}, # type: ignore[dict-item] + http_status=502, + ) + result = repr(exc) + assert "TogetherException" in result + + def test_repr_output_is_valid_after_fix(self) -> None: + """repr should produce parseable output (class name + JSON string).""" + exc = TogetherException( + message="validation error", + headers={"Authorization": "Bearer ***"}, + http_status=422, + request_id="req-789", + ) + result = repr(exc) + assert result.startswith("TogetherException(") + # The inner string should be valid JSON + inner = result[len("TogetherException(") + 1 : -2] # strip quotes + parsed = json.loads(inner) + assert parsed["status"] == 422 + assert parsed["request_id"] == "req-789" + + def test_subclass_repr_with_non_serializable_headers(self) -> None: + """Subclasses should also benefit from the fix.""" + headers = CIMultiDictProxy(CIMultiDict({"X-Rate-Limit": "100"})) + + for ExcClass in (AuthenticationError, ResponseError, APIError): + exc = ExcClass( + message="subclass test", + headers=headers, # type: ignore[arg-type] + http_status=429, + ) + result = repr(exc) + assert ExcClass.__name__ in result