Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions tests/test_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
"""Tests for ``tiktoken.load``."""

from __future__ import annotations

from typing import Any
from unittest import mock

import pytest

from tiktoken import load


class _FakeResponse:
def __init__(self, content: bytes = b"") -> None:
self.content = content

def raise_for_status(self) -> None:
return None


def test_read_file_https_passes_default_timeout(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Regression: ``read_file`` must pass a non-None timeout to ``requests.get``.

Without it, ``requests.get`` blocks indefinitely on DNS/SYN/TCP-reset
failures and silently hangs ``encoding_for_model`` on first use.
"""
monkeypatch.delenv("TIKTOKEN_HTTP_TIMEOUT", raising=False)
captured: dict[str, Any] = {}

def fake_get(url: str, *, timeout: Any = None, **_: Any) -> _FakeResponse:
captured["url"] = url
captured["timeout"] = timeout
return _FakeResponse(b"data")

fake_requests = mock.Mock()
fake_requests.get = fake_get

with mock.patch.dict("sys.modules", {"requests": fake_requests}):
result = load.read_file("https://example.invalid/path")

assert result == b"data"
assert captured["url"] == "https://example.invalid/path"
assert captured["timeout"] == 60.0


def test_read_file_https_respects_env_override(monkeypatch: pytest.MonkeyPatch) -> None:
"""``TIKTOKEN_HTTP_TIMEOUT`` overrides the default timeout."""
monkeypatch.setenv("TIKTOKEN_HTTP_TIMEOUT", "5.5")
captured: dict[str, Any] = {}

def fake_get(url: str, *, timeout: Any = None, **_: Any) -> _FakeResponse:
captured["timeout"] = timeout
return _FakeResponse(b"")

fake_requests = mock.Mock()
fake_requests.get = fake_get

with mock.patch.dict("sys.modules", {"requests": fake_requests}):
load.read_file("http://example.invalid/path")

assert captured["timeout"] == 5.5


def test_read_file_https_falls_back_on_invalid_env(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""An unparseable ``TIKTOKEN_HTTP_TIMEOUT`` falls back to the default rather than crashing."""
monkeypatch.setenv("TIKTOKEN_HTTP_TIMEOUT", "not-a-number")
captured: dict[str, Any] = {}

def fake_get(url: str, *, timeout: Any = None, **_: Any) -> _FakeResponse:
captured["timeout"] = timeout
return _FakeResponse(b"")

fake_requests = mock.Mock()
fake_requests.get = fake_get

with mock.patch.dict("sys.modules", {"requests": fake_requests}):
load.read_file("http://example.invalid/path")

assert captured["timeout"] == 60.0
21 changes: 17 additions & 4 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ def read_file(blobpath: str) -> bytes:
# avoiding blobfile for public files helps avoid auth issues, like MFA prompts.
import requests

resp = requests.get(blobpath)
# Without an explicit timeout, requests.get blocks indefinitely on
# DNS/SYN/TCP-reset failures, which silently hangs `encoding_for_model`
# on first use. Override via the TIKTOKEN_HTTP_TIMEOUT env var (seconds).
try:
timeout: float | None = float(os.environ.get("TIKTOKEN_HTTP_TIMEOUT", "60"))
except ValueError:
timeout = 60.0
resp = requests.get(blobpath, timeout=timeout)
resp.raise_for_status()
return resp.content

Expand Down Expand Up @@ -107,7 +114,9 @@ def data_gym_to_mergeable_bpe_ranks(

# vocab_bpe contains the merges along with associated ranks
vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]
bpe_merges = [
tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]
]

def decode_data_gym(value: str) -> bytes:
return bytes(data_gym_byte_to_byte[b] for b in value)
Expand Down Expand Up @@ -156,7 +165,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No
f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]:
def load_tiktoken_bpe(
tiktoken_bpe_file: str, expected_hash: str | None = None
) -> dict[bytes, int]:
# NB: do not add caching to this function
contents = read_file_cached(tiktoken_bpe_file, expected_hash)
ret = {}
Expand All @@ -167,5 +178,7 @@ def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None)
token, rank = line.split()
ret[base64.b64decode(token)] = int(rank)
except Exception as e:
raise ValueError(f"Error parsing line {line!r} in {tiktoken_bpe_file}") from e
raise ValueError(
f"Error parsing line {line!r} in {tiktoken_bpe_file}"
) from e
return ret