Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ dependencies = [
"typing-extensions>=4.13.2",
"more-itertools>=10.7.0",
"json-repair>=0.44.1",
"cattrs>=26.1.0",
]
requires-python = ">=3.11,<3.13"
readme = "README.md"
Expand Down
37 changes: 0 additions & 37 deletions src/lm_saes/backend/attribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,43 +109,6 @@ def full_tensor(self) -> AttributionResult:
qk_trace_results=self.qk_trace_results,
)

_STATE_DICT_VERSION = 1

def state_dict(self) -> dict:
return {
"_version": self._STATE_DICT_VERSION,
"activations": self.activations.state_dict(),
"attribution": self.attribution.state_dict(),
"logits": self.logits,
"probs": self.probs,
"prompt_token_ids": self.prompt_token_ids,
"prompt_tokens": self.prompt_tokens,
"logit_token_ids": self.logit_token_ids,
"logit_tokens": self.logit_tokens,
"qk_trace_results": self.qk_trace_results.state_dict() if self.qk_trace_results is not None else None,
}

@classmethod
def from_state_dict(cls, state: dict, device: torch.device | str = "cpu") -> "AttributionResult":
# version = state.get("_version", 1) # reserved for future migrations
qk_trace_results: Dimensioned[list[Dimensioned[torch.Tensor]]] | None = (
Dimensioned.from_state_dict(state["qk_trace_results"], device=device)
if state["qk_trace_results"] is not None
else None
)
result = cls(
activations=NodeIndexedVector.from_state_dict(state["activations"], device=device),
attribution=NodeIndexedMatrix.from_state_dict(state["attribution"], device=device),
logits=state["logits"].to(device),
probs=state["probs"].to(device),
prompt_token_ids=state["prompt_token_ids"],
prompt_tokens=state["prompt_tokens"],
logit_token_ids=state["logit_token_ids"],
logit_tokens=state["logit_tokens"],
qk_trace_results=qk_trace_results,
)
return result


def get_normalized_matrix(matrix: NodeIndexedMatrix) -> NodeIndexedMatrix:
return NodeIndexedMatrix.from_data(
Expand Down
85 changes: 16 additions & 69 deletions src/lm_saes/backend/indexed_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,19 @@
from torch.distributed.tensor import DTensor, Replicate
from torch.types import Number

from lm_saes.core.serialize import (
override,
register_overrides,
structure,
unstructure,
)
from lm_saes.utils.discrete import DiscreteMapper
from lm_saes.utils.distributed import DimMap, full_tensor
from lm_saes.utils.misc import tensor_id
from lm_saes.utils.timer import timer


@register_overrides(_inv_indices=override(omit=True))
@dataclass(frozen=True)
class Node:
key: Any
Expand Down Expand Up @@ -73,17 +80,6 @@ def to(self, device: torch.device | str | None = None, *, device_mesh: DeviceMes
else:
return self

def state_dict(self) -> dict:
return {
"key": self.key,
"indices": self.indices,
"offsets": self.offsets,
}

@classmethod
def from_state_dict(cls, state: dict) -> "Node":
return cls(key=state["key"], indices=state["indices"], offsets=state["offsets"])


@dataclass
class NodeInfo:
Expand Down Expand Up @@ -124,13 +120,6 @@ def to(self, device: torch.device | str) -> Self:
def full_tensor(self) -> Self:
return replace(self, indices=full_tensor(self.indices))

def state_dict(self) -> dict:
return {"key": self.key, "indices": self.indices}

@classmethod
def from_state_dict(cls, state: dict) -> "NodeInfo":
return cls(key=state["key"], indices=state["indices"])


def compute_inv_indices(indices: torch.Tensor) -> torch.Tensor:
"""Build the inverse-index lookup table for a node's ``indices``.
Expand Down Expand Up @@ -470,17 +459,12 @@ def to(self, device: torch.device | str) -> Self:
def full_tensor(self) -> Self:
return replace(self, device_mesh=None)

def state_dict(self) -> dict:
"""Serialize to a minimal dict — only keys, indices, and offsets per node."""
return {
"nodes": [node.state_dict() for node in self.node_mappings.values()],
}
def __unstructure__(self) -> dict[str, Any]:
return unstructure(self.node_mappings)

@classmethod
def from_state_dict(cls, state: dict, device: torch.device | str = "cpu") -> "Dimension":
"""Reconstruct from state_dict. mapper, caches, and device_mesh are rebuilt."""
node_mappings = {node_state["key"]: Node.from_state_dict(node_state) for node_state in state["nodes"]}
return cls._from_node_mappings(node_mappings=node_mappings, device=device)
def __structure__(cls, data: dict[str, Any]) -> Self:
return cls._from_node_mappings(node_mappings=structure(data, dict[Any, Node]), device=data.get("device", "cpu"))


class NodeIndexedTensor:
Expand Down Expand Up @@ -686,17 +670,13 @@ def full_tensor(self) -> Self:
tuple(dim.full_tensor() for dim in self.dimensions),
)

def state_dict(self) -> dict:
return {
"data": self.data,
"dimensions": [dim.state_dict() for dim in self.dimensions],
}
def __unstructure__(self) -> dict[str, Any]:
return {"data": self.data, "dimensions": [unstructure(d) for d in self.dimensions]}

@classmethod
def from_state_dict(cls, state: dict, device: torch.device | str = "cpu") -> Self:
dimensions = tuple(Dimension.from_state_dict(dim_state, device=device) for dim_state in state["dimensions"])
data = state["data"].to(device)
return cls.from_data(data=data, dimensions=dimensions)
def __structure__(cls, data: dict[str, Any]) -> Self:
dims = tuple(structure(d, Dimension) for d in data["dimensions"])
return cls.from_data(data=data["data"], dimensions=dims)


class NodeIndexedVector(NodeIndexedTensor):
Expand Down Expand Up @@ -874,36 +854,3 @@ def to(self, device: torch.device | str) -> Self:
value=new_value,
dimensions=tuple(d.to(device) for d in self.dimensions),
)

def state_dict(self) -> dict:
return {
"value": _encode_value(self.value),
"dimensions": [d.state_dict() for d in self.dimensions],
}

@classmethod
def from_state_dict(cls, state: dict, device: torch.device | str = "cpu") -> "Dimensioned":
value = _decode_value(state["value"], device=device)
dimensions = tuple(Dimension.from_state_dict(d, device=device) for d in state["dimensions"])
return cls(value=value, dimensions=dimensions)


def _encode_value(v: Any) -> dict:
if isinstance(v, torch.Tensor):
return {"kind": "tensor", "data": v}
if isinstance(v, Dimensioned):
return {"kind": "dimensioned", "data": v.state_dict()}
if isinstance(v, list):
return {"kind": "list", "data": [_encode_value(x) for x in v]}
raise TypeError(f"Dimensioned does not know how to serialize value of type {type(v).__name__}")


def _decode_value(state: dict, device: torch.device | str) -> Any:
kind = state["kind"]
if kind == "tensor":
return cast(torch.Tensor, state["data"]).to(device)
if kind == "dimensioned":
return Dimensioned.from_state_dict(state["data"], device=device)
if kind == "list":
return [_decode_value(x, device=device) for x in state["data"]]
raise ValueError(f"Unknown Dimensioned value kind: {kind!r}")
Empty file added src/lm_saes/core/__init__.py
Empty file.
62 changes: 62 additions & 0 deletions src/lm_saes/core/serialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from __future__ import annotations

from typing import Any, Callable, TypeVar

import cattrs
import torch
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn, override
from cattrs.strategies import use_class_methods

# Version tag embedded in every `dump` blob; bump on breaking format changes.
FORMAT_VERSION = "1"

T = TypeVar("T")  # value type rehydrated by `load`
C = TypeVar("C", bound=type)  # class decorated by `register_overrides`

# Module-wide converter used by all (un)structure helpers below.
converter = cattrs.Converter()

# Pass tensors through untouched in both directions: persisting the raw
# tensor data is delegated to `torch.save`/`torch.load` on the resulting dict.
converter.register_unstructure_hook(torch.Tensor, lambda v: v)
converter.register_structure_hook(torch.Tensor, lambda v, _t: v)

# Let classes customize their own conversion by defining `__structure__` /
# `__unstructure__` class methods (cattrs strategy).
use_class_methods(converter, "__structure__", "__unstructure__")


def register_overrides(**overrides: Any) -> Callable[[C], C]:
    """Class decorator registering dict (un)structure hooks with per-field overrides.

    Keyword arguments map attribute names to ``cattrs.gen.override`` specs
    (e.g. ``omit=True``); the same overrides are applied to both directions
    of conversion on the module-level ``converter``.
    """

    def decorator(target: C) -> C:
        unstructure_fn = make_dict_unstructure_fn(target, converter, **overrides)
        structure_fn = make_dict_structure_fn(target, converter, **overrides)
        converter.register_unstructure_hook(target, unstructure_fn)
        converter.register_structure_hook(target, structure_fn)
        return target

    return decorator


# Convenience aliases so call sites can use `structure(...)`/`unstructure(...)`
# directly instead of going through the module-level `converter` object.
structure = converter.structure

unstructure = converter.unstructure


def dump(obj: Any) -> dict[str, Any]:
    """Serialize *obj* to a version-tagged, ``torch.save``-friendly dict."""
    unstructured = converter.unstructure(obj)
    return {"_version": FORMAT_VERSION, "data": unstructured}


def load(blob: Any, cls: type[T]) -> T:
    """Rehydrate a value of type *cls* from a :func:`dump` blob.

    Raises:
        ValueError: if *blob* is not a dict produced by :func:`dump` carrying
            the current ``FORMAT_VERSION``.
    """
    if not is_current_format(blob):
        # Don't repr the whole payload — attribution blobs contain large
        # tensors and would make the exception message unreadable. Report
        # only the blob's type and the version tag actually found.
        found = blob.get("_version") if isinstance(blob, dict) else None
        raise ValueError(
            f"Expected serialized blob with _version={FORMAT_VERSION!r}, "
            f"got {type(blob).__name__} with _version={found!r}"
        )
    return converter.structure(blob["data"], cls)


def is_current_format(blob: Any) -> bool:
    """Return ``True`` when *blob* looks like a dict produced by :func:`dump`."""
    if not isinstance(blob, dict):
        return False
    return blob.get("_version") == FORMAT_VERSION


# Public API of the serialization layer; kept alphabetical. FORMAT_VERSION is
# exported because it is part of the on-disk contract (embedded in every
# `dump` blob) and callers may need it for compatibility checks.
__all__ = [
    "FORMAT_VERSION",
    "converter",
    "dump",
    "is_current_format",
    "load",
    "override",
    "register_overrides",
    "structure",
    "unstructure",
]
17 changes: 9 additions & 8 deletions src/lm_saes/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@
from lm_saes.backend.attribution import AttributionResult
from lm_saes.backend.language_model import LanguageModelConfig
from lm_saes.config import DatasetConfig
from lm_saes.core.serialize import dump, load
from lm_saes.models.sparse_dictionary import SAE_TYPE_TO_CONFIG_CLASS, SparseDictionaryConfig
from lm_saes.utils.bytes import bytes_to_np, np_to_bytes
from lm_saes.utils.logging import get_distributed_logger
from lm_saes.utils.timer import timer

logger = get_distributed_logger(__name__)


class MongoDBConfig(BaseModel):
mongo_uri: str = Field(default_factory=lambda: os.environ.get("MONGO_URI", "mongodb://localhost:27017/"))
Expand Down Expand Up @@ -1019,13 +1023,12 @@ def update_circuit_status(
return result.modified_count > 0

def store_attribution(self, circuit_id: str, attribution: AttributionResult) -> bool:
"""Store attribution data to GridFS using torch.save with state_dict."""
"""Store attribution data to GridFS via ``lm_saes.core.serialize``."""
assert self.fs is not None

buf = io.BytesIO()
torch.save(attribution.state_dict(), buf)
attribution_bytes = buf.getvalue()
attribution_id = self.fs.put(attribution_bytes, filename=f"circuit_{circuit_id}_attribution")
torch.save(dump(attribution), buf)
attribution_id = self.fs.put(buf.getvalue(), filename=f"circuit_{circuit_id}_attribution")

result = self.circuit_collection.update_one(
{"_id": ObjectId(circuit_id)},
Expand All @@ -1042,10 +1045,8 @@ def load_attribution(self, circuit_id: str, device: torch.device | str = "cpu")
if circuit is None or circuit.get("attribution_id") is None:
return None
attribution_bytes = self.fs.get(circuit["attribution_id"]).read()

buf = io.BytesIO(attribution_bytes)
state = torch.load(buf, map_location=device, weights_only=True)
return AttributionResult.from_state_dict(state, device=device)
state = torch.load(io.BytesIO(attribution_bytes), map_location=device, weights_only=True)
return load(state, AttributionResult)

def get_circuit_status(self, circuit_id: str) -> Optional[dict[str, Any]]:
"""Get just the status information for a circuit."""
Expand Down
Loading
Loading