Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions tensorizer/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,7 @@ def from_io(
zero_hashes: bool = True,
check_crypt_info: bool = False,
long_shape_tensors: frozenset = frozenset(),
max_header_len: Optional[int] = None,
) -> Optional["_TensorHeaderDeserializer"]:
# We read the entire header into memory rather than reading
# it piecewise to avoid the overhead of many small reads,
Expand All @@ -643,6 +644,11 @@ def from_io(
header_len: int = cls.header_len_segment.unpack(header_len_bytes)[0]
if header_len == 0:
return None
if max_header_len is not None and header_len > max_header_len:
raise ValueError(
"Tensor header length exceeds metadata bounds:"
f" {header_len} > {max_header_len}"
)
buffer = bytearray(header_len)
buffer[:offset] = header_len_bytes
with memoryview(buffer) as mv:
Expand Down Expand Up @@ -3070,18 +3076,31 @@ def _copy_thread(
tensor_sizes_by_name: Dict[_TensorPath, int] = {
t.name: t.deserialized_length for t in tensor_items
}
metadata_by_offset: Dict[int, TensorEntry] = {
t.offset: t for t in unsafe_self._metadata.values()
}

# then for each tensor in tensor_items
tensors_read = 0
while tensors_read < len(tensor_items):
if halt:
break

header_offset = file_.tell()
metadata_entry = metadata_by_offset.get(header_offset)
if metadata_entry is None:
raise ValueError(
"Unexpected tensor header offset:"
f" {header_offset}"
)

header = _TensorHeaderDeserializer.from_io(
file_,
zero_hashes=True,
check_crypt_info=unsafe_self._has_crypt_info,
long_shape_tensors=unsafe_self._long_shape_tensors,
max_header_len=metadata_entry.data_offset
- metadata_entry.offset,
)

if header is None:
Expand Down
35 changes: 35 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import os
import re
import secrets
import struct
import sys
import tempfile
import time
Expand Down Expand Up @@ -345,6 +346,40 @@ def test_serialization(self):
finally:
os.unlink(serialized_model)

def test_oversized_tensor_header_is_rejected(self):
    """Check that an inflated per-tensor header length is refused on load.

    A single-tensor file is serialized, then the 8-byte little-endian
    length field at the start of that tensor's header is overwritten
    with a value one byte larger than the span the file's metadata
    reserves for the header (``data_offset - offset``). Loading the
    tensor from the patched file must raise ``ValueError``.
    """
    with temporary_file("wb+") as tensorized_file:
        path = tensorized_file.name

        writer = TensorSerializer(tensorized_file)
        writer.write_state_dict({"tensor": torch.zeros(1)})
        writer.close()

        # Open the pristine file once, only to learn where the tensor's
        # header starts and where its data begins.
        with TensorDeserializer(
            path,
            device="cpu",
            lazy_load=True,
            num_readers=1,
        ) as reader:
            tensor_meta = reader._metadata[("tensor",)]

        header_start = tensor_meta.offset
        # Largest header length consistent with the metadata entry.
        header_budget = tensor_meta.data_offset - header_start

        with open(path, "rb") as source:
            raw = bytearray(source.read())

        # Corrupt the length field so it overshoots the budget by one.
        struct.pack_into("<Q", raw, header_start, header_budget + 1)

        with open(path, "wb") as sink:
            sink.write(raw)

        # The deserializer is constructed outside the assertion context;
        # only the tensor access itself is expected to raise.
        with TensorDeserializer(
            path,
            device="cpu",
            lazy_load=True,
            num_readers=1,
        ) as reader:
            with self.assertRaisesRegex(
                ValueError, "Tensor header length exceeds metadata bounds"
            ):
                reader["tensor"]

def test_large_unbuffered_tensor(self):
shape = (36000, 36000) # 4.828 GiB
dtype = torch.float32
Expand Down