diff --git a/tiktoken/load.py b/tiktoken/load.py index 3c76bcb3..ba94630a 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -63,6 +63,12 @@ def read_file_cached(blobpath: str, expected_hash: str | None = None) -> bytes: except OSError: pass + if os.environ.get("TIKTOKEN_OFFLINE"): + raise ValueError( + f"TIKTOKEN_OFFLINE is set but the file {blobpath} is not in the cache at {cache_dir}. " + f"Please download it first or set TIKTOKEN_CACHE_DIR to a directory containing the file." + ) + contents = read_file(blobpath) if expected_hash and not check_hash(contents, expected_hash): raise ValueError( @@ -107,7 +113,9 @@ def data_gym_to_mergeable_bpe_ranks( # vocab_bpe contains the merges along with associated ranks vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode() - bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]] + bpe_merges = [ + tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1] + ] def decode_data_gym(value: str) -> bytes: return bytes(data_gym_byte_to_byte[b] for b in value) @@ -156,7 +164,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> No f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n") -def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) -> dict[bytes, int]: +def load_tiktoken_bpe( + tiktoken_bpe_file: str, expected_hash: str | None = None +) -> dict[bytes, int]: # NB: do not add caching to this function contents = read_file_cached(tiktoken_bpe_file, expected_hash) ret = {} @@ -167,5 +177,7 @@ def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: str | None = None) token, rank = line.split() ret[base64.b64decode(token)] = int(rank) except Exception as e: - raise ValueError(f"Error parsing line {line!r} in {tiktoken_bpe_file}") from e + raise ValueError( + f"Error parsing line {line!r} in {tiktoken_bpe_file}" + ) from e return ret