Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions tavily/async_tavily.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import httpx

from .utils import get_max_items_from_list
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError, _parse_retry_after


class AsyncTavilyClient:
Expand Down Expand Up @@ -152,7 +152,7 @@ async def _search(
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -266,7 +266,7 @@ async def _extract(


if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -375,7 +375,7 @@ async def _crawl(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -485,7 +485,7 @@ async def _map(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -679,7 +679,7 @@ async def stream_generator() -> AsyncGenerator[bytes, None]:
error_text = "Unknown error"

if response.status_code == 429:
raise UsageLimitExceededError(error_text)
raise UsageLimitExceededError(error_text, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403,432,433]:
raise ForbiddenError(error_text)
elif response.status_code == 401:
Expand Down Expand Up @@ -715,7 +715,7 @@ async def _make_request():
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -794,7 +794,7 @@ async def get_research(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403,432,433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down
66 changes: 65 additions & 1 deletion tavily/errors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,70 @@
from datetime import datetime, timezone
from email.utils import parsedate_to_datetime
from typing import Mapping, Optional


def _find_header(headers: Mapping[str, str], name: str) -> Optional[str]:
"""Return the value of ``name`` from ``headers`` with case-insensitive lookup."""
target = name.lower()
for key, value in headers.items():
if key.lower() == target:
return value
return None


def _parse_retry_after(headers: Optional[Mapping[str, str]]) -> Optional[float]:
"""Parse an HTTP ``Retry-After`` header value into seconds.

Handles both forms defined by RFC 7231 §7.1.3:

- a non-negative decimal integer of seconds, e.g. ``"120"``
- an HTTP-date, e.g. ``"Wed, 21 Oct 2015 07:28:00 GMT"``

Semantics follow ``urllib3.util.Retry.parse_retry_after``: integer
seconds first, then HTTP-date. Negative or past values clamp to ``0.0``.
Returns ``None`` when the header is absent or cannot be parsed (including
non-integer numerics, ``NaN``/``inf``, and malformed dates).

Accepts any mapping-like ``headers`` object. Case-insensitive header name
lookup is done explicitly so callers passing a plain ``dict`` (not only
``requests``/``httpx`` header containers) work correctly.
"""
if not headers:
return None
raw = _find_header(headers, "Retry-After")
if raw is None:
return None
raw = raw.strip()
if not raw:
return None
try:
return max(float(int(raw)), 0.0)
except ValueError:
pass
Comment thread
stihahi marked this conversation as resolved.
try:
when = parsedate_to_datetime(raw)
except (TypeError, ValueError):
return None
if when is None:
return None
if when.tzinfo is None:
when = when.replace(tzinfo=timezone.utc)
delta = (when - datetime.now(timezone.utc)).total_seconds()
return max(delta, 0.0)


class UsageLimitExceededError(Exception):
    """Raised on HTTP 429 responses from the Tavily API.

    ``retry_after`` carries the server-recommended wait (seconds) parsed from
    the ``Retry-After`` response header when present, so callers can honor the
    server's backoff instead of guessing. ``None`` when the header is absent
    or unparseable.

    Args:
        message: Human-readable error detail; becomes the exception message.
        retry_after: Optional wait in seconds. Defaults to ``None`` so
            existing single-argument call sites keep working.
    """

    def __init__(self, message: str, retry_after: Optional[float] = None):
        super().__init__(message)
        self.retry_after = retry_after


class BadRequestError(Exception):
Expand Down
16 changes: 8 additions & 8 deletions tavily/tavily.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import warnings
from typing import Literal, Sequence, Optional, List, Union, Generator
from .utils import get_max_items_from_list
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError
from .errors import UsageLimitExceededError, InvalidAPIKeyError, MissingAPIKeyError, BadRequestError, ForbiddenError, TimeoutError, _parse_retry_after

class TavilyClient:
"""
Expand Down Expand Up @@ -130,7 +130,7 @@ def _search(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -237,7 +237,7 @@ def _extract(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -340,7 +340,7 @@ def _crawl(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -448,7 +448,7 @@ def _map(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -617,7 +617,7 @@ def _research(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -656,7 +656,7 @@ def stream_generator() -> Generator[bytes, None, None]:
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down Expand Up @@ -728,7 +728,7 @@ def get_research(self,
pass

if response.status_code == 429:
raise UsageLimitExceededError(detail)
raise UsageLimitExceededError(detail, retry_after=_parse_retry_after(response.headers))
elif response.status_code in [403, 432, 433]:
raise ForbiddenError(detail)
elif response.status_code == 401:
Expand Down
126 changes: 126 additions & 0 deletions tests/test_retry_after.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
"""Tests for Retry-After header propagation on 429 responses.

The Tavily API returns a ``Retry-After`` header when it rejects a request with
``429 Too Many Requests``. The SDK must expose that value on
``UsageLimitExceededError`` so callers can honor the server's recommended
wait instead of falling back to a fixed backoff.
"""

import asyncio

import pytest

from tavily.errors import UsageLimitExceededError, _parse_retry_after


RATE_LIMIT_BODY = {"detail": {"error": "rate limit exceeded"}}


def test_sync_429_exposes_retry_after_seconds(sync_interceptor, sync_client):
    """A sync 429 with ``Retry-After: 7`` yields ``retry_after`` of about 7.0."""
    response_headers = {"Retry-After": "7"}
    sync_interceptor.set_response(429, headers=response_headers, json=RATE_LIMIT_BODY)

    with pytest.raises(UsageLimitExceededError) as caught:
        sync_client.search("What is Tavily?")

    error = caught.value
    assert error.retry_after == pytest.approx(7.0)


def test_sync_429_retry_after_absent_is_none(sync_interceptor, sync_client):
    """A sync 429 configured without headers leaves ``retry_after`` as ``None``."""
    sync_interceptor.set_response(429, json=RATE_LIMIT_BODY)

    with pytest.raises(UsageLimitExceededError) as caught:
        sync_client.search("What is Tavily?")

    error = caught.value
    assert error.retry_after is None


def test_sync_429_retry_after_http_date(sync_interceptor, sync_client):
    """An HTTP-date ``Retry-After`` about 30s ahead parses to roughly 30 seconds."""
    from datetime import datetime, timedelta, timezone
    from email.utils import format_datetime

    retry_at = datetime.now(timezone.utc) + timedelta(seconds=30)
    http_date = format_datetime(retry_at, usegmt=True)
    sync_interceptor.set_response(
        429,
        headers={"Retry-After": http_date},
        json=RATE_LIMIT_BODY,
    )

    with pytest.raises(UsageLimitExceededError) as caught:
        sync_client.search("What is Tavily?")

    # Wide tolerance: the parsed value is relative to "now" at parse time.
    assert caught.value.retry_after is not None
    assert 20 <= caught.value.retry_after <= 40


def test_sync_429_retry_after_malformed_is_none(sync_interceptor, sync_client):
    """A ``Retry-After`` that is neither an integer nor a date gives ``None``."""
    bad_headers = {"Retry-After": "not-a-number"}
    sync_interceptor.set_response(429, headers=bad_headers, json=RATE_LIMIT_BODY)

    with pytest.raises(UsageLimitExceededError) as caught:
        sync_client.search("What is Tavily?")

    error = caught.value
    assert error.retry_after is None


def test_async_429_exposes_retry_after_seconds(async_interceptor, async_client):
    """An async 429 with ``Retry-After: 12`` yields ``retry_after`` of about 12.0."""
    response_headers = {"Retry-After": "12"}
    async_interceptor.set_response(429, headers=response_headers, json=RATE_LIMIT_BODY)

    with pytest.raises(UsageLimitExceededError) as caught:
        asyncio.run(async_client.search("What is Tavily?"))

    error = caught.value
    assert error.retry_after == pytest.approx(12.0)


def test_async_429_retry_after_absent_is_none(async_interceptor, async_client):
    """An async 429 configured without headers leaves ``retry_after`` as ``None``."""
    async_interceptor.set_response(429, json=RATE_LIMIT_BODY)

    with pytest.raises(UsageLimitExceededError) as caught:
        asyncio.run(async_client.search("What is Tavily?"))

    error = caught.value
    assert error.retry_after is None


def test_usage_limit_error_default_retry_after_is_none():
    """Constructing the error with only a message leaves ``retry_after`` unset."""
    error = UsageLimitExceededError("boom")
    assert error.retry_after is None


def test_usage_limit_error_accepts_retry_after():
    """The ``retry_after`` keyword is stored verbatim on the exception."""
    error = UsageLimitExceededError("boom", retry_after=3.5)
    assert error.retry_after == 3.5


def test_parse_retry_after_rejects_fractional_seconds():
    """Fractional seconds are treated as unparseable."""
    # RFC 7231 §7.1.3 only defines non-negative integer seconds.
    parsed = _parse_retry_after({"Retry-After": "7.5"})
    assert parsed is None


def test_parse_retry_after_clamps_negative_seconds():
    """A negative integer value is clamped to zero rather than rejected."""
    parsed = _parse_retry_after({"Retry-After": "-10"})
    assert parsed == 0.0


def test_parse_retry_after_rejects_nan_and_inf():
    """The float spellings ``nan`` and ``inf`` are not accepted as seconds."""
    for special in ("nan", "inf"):
        assert _parse_retry_after({"Retry-After": special}) is None


def test_parse_retry_after_case_insensitive_lookup():
    """The header is found in a plain dict regardless of key casing."""
    for header_name in ("retry-after", "RETRY-AFTER"):
        assert _parse_retry_after({header_name: "5"}) == 5.0


def test_parse_retry_after_past_http_date_clamps_to_zero():
    """An HTTP-date already in the past yields a zero-second wait."""
    from datetime import datetime, timedelta, timezone
    from email.utils import format_datetime

    a_minute_ago = datetime.now(timezone.utc) - timedelta(seconds=60)
    http_date = format_datetime(a_minute_ago, usegmt=True)
    assert _parse_retry_after({"Retry-After": http_date}) == 0.0


def test_parse_retry_after_empty_and_none():
    """Missing headers, empty mappings and blank values all parse to ``None``."""
    blank_inputs = (
        None,
        {},
        {"Retry-After": ""},
        {"Retry-After": " "},
    )
    for headers in blank_inputs:
        assert _parse_retry_after(headers) is None