diff --git a/tavily/async_tavily.py b/tavily/async_tavily.py index dc896b8..17ca2e9 100644 --- a/tavily/async_tavily.py +++ b/tavily/async_tavily.py @@ -6,7 +6,7 @@ import httpx from .utils import get_max_items_from_list -from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError +from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError, _parse_retry_after class AsyncTavilyClient: @@ -152,7 +152,7 @@ async def _search( pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -266,7 +266,7 @@ async def _extract( if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -375,7 +375,7 @@ async def _crawl(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -485,7 +485,7 @@ async def _map(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -679,7 +679,7 @@ async def stream_generator() -> AsyncGenerator[bytes, None]: error_text = "Unknown error" if response.status_code == 429: - raise UsageLimitExceededError(error_text) + raise UsageLimitExceededError(error_text, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403,432,433]: raise ForbiddenError(error_text) elif response.status_code == 401: @@ -715,7 +715,7 @@ async def _make_request(): pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -794,7 +794,7 @@ async def get_research(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403,432,433]: raise ForbiddenError(detail) elif response.status_code == 401: diff --git a/tavily/errors.py b/tavily/errors.py index 45fbbb4..0263576 100644 --- a/tavily/errors.py +++ b/tavily/errors.py @@ -1,6 +1,70 @@ +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime +from typing import Mapping, Optional + + +def _find_header(headers: Mapping[str, str], name: str) -> Optional[str]: + """Return the value of ``name`` from ``headers`` with case-insensitive lookup.""" + target = name.lower() + for key, value in headers.items(): + if key.lower() == target: + return value + return None + + +def _parse_retry_after(headers: Optional[Mapping[str, str]]) -> Optional[float]: + """Parse an HTTP ``Retry-After`` header value into seconds. + + Handles both forms defined by RFC 7231 §7.1.3: + + - a non-negative decimal integer of seconds, e.g. ``"120"`` + - an HTTP-date, e.g. ``"Wed, 21 Oct 2015 07:28:00 GMT"`` + + Semantics follow ``urllib3.util.Retry.parse_retry_after``: integer + seconds first, then HTTP-date. Negative or past values clamp to ``0.0``. + Returns ``None`` when the header is absent or cannot be parsed (including + non-integer numerics, ``NaN``/``inf``, and malformed dates). + + Accepts any mapping-like ``headers`` object. Case-insensitive header name + lookup is done explicitly so callers passing a plain ``dict`` (not only + ``requests``/``httpx`` header containers) work correctly. + """ + if not headers: + return None + raw = _find_header(headers, "Retry-After") + if raw is None: + return None + raw = raw.strip() + if not raw: + return None + try: + return max(float(int(raw)), 0.0) + except ValueError: + pass + try: + when = parsedate_to_datetime(raw) + except (TypeError, ValueError): + return None + if when is None: + return None + if when.tzinfo is None: + when = when.replace(tzinfo=timezone.utc) + delta = (when - datetime.now(timezone.utc)).total_seconds() + return max(delta, 0.0) + + class UsageLimitExceededError(Exception): - def __init__(self, message: str): + """Raised on HTTP 429 responses from the Tavily API. + + ``retry_after`` carries the server-recommended wait (seconds) parsed from + the ``Retry-After`` response header when present, so callers can honor the + server's backoff instead of guessing. ``None`` when the header is absent + or unparseable. + """ + + def __init__(self, message: str, retry_after: Optional[float] = None): super().__init__(message) + self.retry_after = retry_after class BadRequestError(Exception): diff --git a/tavily/tavily.py b/tavily/tavily.py index b059734..306d078 100644 --- a/tavily/tavily.py +++ b/tavily/tavily.py @@ -4,7 +4,7 @@ import warnings from typing import Literal, Sequence, Optional, List, Union, Generator from .utils import get_max_items_from_list -from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError +from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError, _parse_retry_after class TavilyClient: """ @@ -130,7 +130,7 @@ def _search(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -237,7 +237,7 @@ def _extract(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -340,7 +340,7 @@ def _crawl(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -448,7 +448,7 @@ def _map(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -617,7 +617,7 @@ def _research(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -656,7 +656,7 @@ def stream_generator() -> Generator[bytes, None, None]: pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: @@ -728,7 +728,7 @@ def get_research(self, pass if response.status_code == 429: - raise UsageLimitExceededError(detail) + raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers)) elif response.status_code in [403, 432, 433]: raise ForbiddenError(detail) elif response.status_code == 401: diff --git a/tests/test_retry_after.py b/tests/test_retry_after.py new file mode 100644 index 0000000..0e9e4cd --- /dev/null +++ b/tests/test_retry_after.py @@ -0,0 +1,126 @@ +"""Tests for Retry-After header propagation on 429 responses. + +Tavily API returns a ``Retry-After`` header when it rejects a request with +``429 Too Many Requests``. The SDK must expose that value on +``UsageLimitExceededError`` so callers can honor the server's recommended +wait instead of falling back to a fixed backoff. +""" + +import asyncio + +import pytest + +from tavily.errors import UsageLimitExceededError, _parse_retry_after + + +RATE_LIMIT_BODY = {"detail": {"error": "rate limit exceeded"}} + + +def test_sync_429_exposes_retry_after_seconds(sync_interceptor, sync_client): + sync_interceptor.set_response(429, headers={"Retry-After": "7"}, json=RATE_LIMIT_BODY) + + with pytest.raises(UsageLimitExceededError) as exc_info: + sync_client.search("What is Tavily?") + + assert exc_info.value.retry_after == pytest.approx(7.0) + + +def test_sync_429_retry_after_absent_is_none(sync_interceptor, sync_client): + sync_interceptor.set_response(429, json=RATE_LIMIT_BODY) + + with pytest.raises(UsageLimitExceededError) as exc_info: + sync_client.search("What is Tavily?") + + assert exc_info.value.retry_after is None + + +def test_sync_429_retry_after_http_date(sync_interceptor, sync_client): + from email.utils import format_datetime + from datetime import datetime, timezone, timedelta + + future = datetime.now(timezone.utc) + timedelta(seconds=30) + sync_interceptor.set_response( + 429, + headers={"Retry-After": format_datetime(future, usegmt=True)}, + json=RATE_LIMIT_BODY, + ) + + with pytest.raises(UsageLimitExceededError) as exc_info: + sync_client.search("What is Tavily?") + + assert exc_info.value.retry_after is not None + assert 20 <= exc_info.value.retry_after <= 40 + + +def test_sync_429_retry_after_malformed_is_none(sync_interceptor, sync_client): + sync_interceptor.set_response( + 429, headers={"Retry-After": "not-a-number"}, json=RATE_LIMIT_BODY + ) + + with pytest.raises(UsageLimitExceededError) as exc_info: + sync_client.search("What is Tavily?") + + assert exc_info.value.retry_after is None + + +def test_async_429_exposes_retry_after_seconds(async_interceptor, async_client): + async_interceptor.set_response(429, headers={"Retry-After": "12"}, json=RATE_LIMIT_BODY) + + with pytest.raises(UsageLimitExceededError) as exc_info: + asyncio.run(async_client.search("What is Tavily?")) + + assert exc_info.value.retry_after == pytest.approx(12.0) + + +def test_async_429_retry_after_absent_is_none(async_interceptor, async_client): + async_interceptor.set_response(429, json=RATE_LIMIT_BODY) + + with pytest.raises(UsageLimitExceededError) as exc_info: + asyncio.run(async_client.search("What is Tavily?")) + + assert exc_info.value.retry_after is None + + +def test_usage_limit_error_default_retry_after_is_none(): + err = UsageLimitExceededError("boom") + assert err.retry_after is None + + +def test_usage_limit_error_accepts_retry_after(): + err = UsageLimitExceededError("boom", retry_after=3.5) + assert err.retry_after == 3.5 + + +def test_parse_retry_after_rejects_fractional_seconds(): + # RFC 7231 §7.1.3 only defines non-negative integer seconds. + assert _parse_retry_after({"Retry-After": "7.5"}) is None + + +def test_parse_retry_after_clamps_negative_seconds(): + assert _parse_retry_after({"Retry-After": "-10"}) == 0.0 + + +def test_parse_retry_after_rejects_nan_and_inf(): + assert _parse_retry_after({"Retry-After": "nan"}) is None + assert _parse_retry_after({"Retry-After": "inf"}) is None + + +def test_parse_retry_after_case_insensitive_lookup(): + assert _parse_retry_after({"retry-after": "5"}) == 5.0 + assert _parse_retry_after({"RETRY-AFTER": "5"}) == 5.0 + + +def test_parse_retry_after_past_http_date_clamps_to_zero(): + from email.utils import format_datetime + from datetime import datetime, timezone, timedelta + + past = datetime.now(timezone.utc) - timedelta(seconds=60) + result = _parse_retry_after({"Retry-After": format_datetime(past, usegmt=True)}) + assert result == 0.0 + + +def test_parse_retry_after_empty_and_none(): + assert _parse_retry_after(None) is None + assert _parse_retry_after({}) is None + assert _parse_retry_after({"Retry-After": ""}) is None + assert _parse_retry_after({"Retry-After": " "}) is None