Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
685a2cf
fix: use abstopk in feature activation function
dest1n1s Apr 6, 2026
b76ffac
fix: abstopk
dest1n1s Apr 6, 2026
9ec7f47
perf: improve tp attribution
dest1n1s Apr 6, 2026
f1568ee
feat: implement batch_index to replace the torch DTensor indexing
dest1n1s Apr 9, 2026
beaa04a
perf(circuits): use batch index & move source values out of loop
dest1n1s Apr 7, 2026
66145b0
perf(circuits): use multi_batch_index
dest1n1s Apr 8, 2026
f2e0892
fix(circuits): only encode once in apply_saes
dest1n1s Apr 8, 2026
16f52e6
chore(dependencies): add numba to dev dependencies
dest1n1s Apr 8, 2026
c4683a3
chore: remove some timers in lorsa and sae
dest1n1s Apr 8, 2026
623f3b1
fix(circuits): detach ref tensor
dest1n1s Apr 9, 2026
ed5b07d
chore(deps): bump tornado from 6.5.4 to 6.5.5
dependabot[bot] Apr 9, 2026
c55ef4d
chore(deps): bump aiohttp from 3.13.3 to 3.13.4
dependabot[bot] Apr 9, 2026
0abc680
chore(deps): bump nltk from 3.9.3 to 3.9.4
dependabot[bot] Apr 9, 2026
811ded2
chore(deps): bump requests from 2.32.5 to 2.33.0
dependabot[bot] Apr 9, 2026
76237b1
chore(deps): bump pygments from 2.19.2 to 2.20.0
dependabot[bot] Apr 9, 2026
d007472
feat(circuits): integrate QK tracing into attribute() pipeline
Frankstein73 Apr 9, 2026
8ff2b08
feat(examples): add script for QK tracing in attribute() with model l…
Frankstein73 Apr 9, 2026
c4484dc
feat(circuits): support full_tensor and to device in AttributionResult
dest1n1s Apr 9, 2026
b6a9c7c
fix(backend): relax input constraint to allow non-tensor input in dis…
dest1n1s Apr 9, 2026
f53a098
feat(server): support distributed circuit tracing
dest1n1s Apr 9, 2026
3ca1298
fix(circuits): fix full_tensor and to device to transfer all fields p…
dest1n1s Apr 9, 2026
5ac6280
fix(server): preload models on workers & fix distributed function reg…
dest1n1s Apr 9, 2026
ee355c0
refactor(circuits): serialize AttributionResult using torch.save
dest1n1s Apr 10, 2026
15e2150
chore: ignore claude artifacts and drop BaseSAEConfig alias
dest1n1s Apr 10, 2026
d2e54e3
chore: enforce commitizen pre-commit rules
dest1n1s Apr 10, 2026
c3ab301
refactor(attribution): unify QK tracing via generic D container and H…
Frankstein73 Apr 10, 2026
58e3324
refactor(attribution): replace generic D container with Dimensioned f…
Frankstein73 Apr 10, 2026
7890dbe
feat(circuits): support nodes_to_offsets with unregistered nodes (ret…
dest1n1s Apr 10, 2026
4ab4238
feat(server): support host mode
dest1n1s Apr 10, 2026
7632121
refactor(circuits): move Dimensioned to indexed_tensor
dest1n1s Apr 10, 2026
3ba8b8f
feat(server): support qk tracing
dest1n1s Apr 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,7 @@ pyrightconfig.json

.codex
.agents

# Claude Code
CLAUDE.md
.claude/
62 changes: 62 additions & 0 deletions examples/attribute_with_qk_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import os

import torch
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh

from lm_saes.backend.language_model import LanguageModelConfig, TransformerLensLanguageModel
from lm_saes.models.sparse_dictionary import SparseDictionary
from lm_saes.resource_loaders import load_model


def load_replacement_modules(
    layers: list[int], exp_factor: int, topk: int, include_lorsa: bool = True, device_mesh: DeviceMesh | None = None
):
    """Load the replacement sparse dictionaries for the requested layers.

    For each layer, loads a transcoder and (optionally) a LORSA module from the
    OpenMOSS Llama-Scope-2 Qwen3-1.7B release, all on CUDA in float32 with
    activation-scale folding disabled.

    Args:
        layers: Model layer indices to load modules for.
        exp_factor: Expansion factor of the dictionaries (path component).
        topk: Top-k sparsity of the dictionaries (path component).
        include_lorsa: When True, load a LORSA module alongside each transcoder.
        device_mesh: Optional mesh for distributed (tensor-parallel) loading.

    Returns:
        A flat list of loaded ``SparseDictionary`` modules, ordered by layer,
        with the LORSA module (if any) preceding the transcoder for each layer.
    """
    kinds = ["lorsa", "transcoder"] if include_lorsa else ["transcoder"]
    return [
        SparseDictionary.from_pretrained(
            f"OpenMOSS-Team/Llama-Scope-2-Qwen3-1.7B:{kind}/{exp_factor}x/k{topk}/layer{layer}_{kind}_{exp_factor}x_k{topk}",
            device="cuda",
            dtype="torch.float32",
            fold_activation_scale=False,
            device_mesh=device_mesh,
        )
        for layer in layers
        for kind in kinds
    ]


def load_language_model(device_mesh: DeviceMesh | None = None):
    """Load the Qwen3-1.7B language model on CUDA in float32.

    ``prepend_bos`` is disabled so prompts are tokenized without a leading BOS
    token. The optional ``device_mesh`` enables distributed loading.
    """
    config = LanguageModelConfig(
        model_name="Qwen/Qwen3-1.7B",
        device="cuda",
        dtype="torch.float32",
        prepend_bos=False,
    )
    return load_model(config, device_mesh)


if __name__ == "__main__":
    # Pin this process to its local GPU before any CUDA allocations occur.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)

    # One-dimensional mesh over all ranks; "model" names the parallel dim.
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    device_mesh = init_device_mesh("cuda", mesh_shape=(world_size,), mesh_dim_names=("model",))

    model: TransformerLensLanguageModel = load_language_model(device_mesh)
    replacement_modules = load_replacement_modules(
        layers=list(range(model.cfg.n_layers)),
        exp_factor=8,
        topk=64,
        include_lorsa=True,
        device_mesh=device_mesh,
    )

    # Run attribution with QK tracing enabled on a short prompt.
    attribute_result = model.attribute(
        "The National Digital ",
        replacement_modules=replacement_modules,
        max_n_logits=10,
        desired_logit_prob=0.95,
        batch_size=64,
        max_features=4096,
        enable_qk_tracing=True,
        qk_top_fraction=0.6,
    )

    print(attribute_result)
16 changes: 15 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ dev = [
"sqlalchemy>=2.0.44",
"apscheduler>=3.11.1",
"commitizen>=4.11.0",
"numba>=0.65.0",
]
docs = [
"griffe-pydantic>=1.3.1",
Expand Down Expand Up @@ -195,7 +196,7 @@ version = "2.7.4.post1"
requires-dist = ["torch", "einops"]

[tool.commitizen]
name = "cz_conventional_commits"
name = "cz_customize"
tag_format = "v$version"
version_scheme = "pep440"
version_provider = "uv"
Expand All @@ -206,3 +207,16 @@ version_files = [
"docs/index.md:uv add lm-saes==",
"docs/index.md:pip install lm-saes==",
]

[tool.commitizen.customize]
# Enforce restricted conventional commits.
# Scope is optional; if present, it must be one of the pre-defined values listed below.
schema = "<type>(<scope>)?: <subject>"
schema_pattern = "^(feat|fix|docs|style|refactor|perf|test|build|ci|chore|revert)(\\((backend|activation|circuits|analyzer|runners|database|trainer|evaluator|initializer|kernels|utils|cli|server|ui|ui/circuits|ui/dictionaries|ui/features|ui/bookmarks|ui/admin|ui/embed|deps|docs|examples|tests|ci|release)\\))?!?: .+"
change_type_map = {"feat" = "Feat", "fix" = "Fix", "refactor" = "Refactor", "perf" = "Perf"}
info = """Commit messages must follow restricted conventional commits.
Allowed types: feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert.
Scope is optional; if used, must be one of the pre-defined scopes listed in
pyproject.toml under [tool.commitizen.customize] (e.g. backend, circuits, server,
ui, ui/circuits, deps, ...).
"""
35 changes: 30 additions & 5 deletions server/app.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,47 @@
import asyncio
import os
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from torch.distributed.device_mesh import DeviceMesh

from server.logic.loaders import get_dataset, get_model, get_sae
from server.logic.workers import DistributedWorkerRegistry, distributed
from server.routers import admin, bookmarks, circuits, dictionaries
from server.routers.circuits import load_circuit_graph


@distributed
def workers_on_mount(preload_models: list[str], preload_saes: list[str], device_mesh: DeviceMesh | None = None):
    """Warm the worker-side caches by loading the requested models and SAEs.

    Runs on every distributed worker (via ``@distributed``); loading here means
    later requests hit the LRU caches instead of loading from disk.
    """
    for model_name in preload_models:
        get_model(name=model_name, device_mesh=device_mesh)
    for sae_name in preload_saes:
        get_sae(name=sae_name, device_mesh=device_mesh)


@distributed
def workers_on_unmount(device_mesh: DeviceMesh | None = None):
    """Release all cached resources on a worker during shutdown."""
    # Clear every LRU cache so model/SAE/dataset/graph memory is freed.
    for clear_cache in (
        get_model.cache_clear,
        get_dataset.cache_clear,
        get_sae.cache_clear,
        load_circuit_graph.cache_clear,
    ):
        clear_cache()


@asynccontextmanager
async def lifespan(app: FastAPI):
torch.multiprocessing.set_start_method("spawn", force=True)
DistributedWorkerRegistry.initialize(num_workers=int(os.environ["NUM_WORKERS"]))
preload_models = os.environ["PRELOAD_MODELS"].strip().split(",") if os.environ.get("PRELOAD_MODELS") else []
preload_saes = os.environ["PRELOAD_SAES"].strip().split(",") if os.environ.get("PRELOAD_SAES") else []

task = asyncio.create_task(workers_on_mount(preload_models, preload_saes))

for model in preload_models:
get_model(name=model)

preload_saes = os.environ["PRELOAD_SAES"].strip().split(",") if os.environ.get("PRELOAD_SAES") else []
for sae in preload_saes:
get_sae(name=sae)

# Format: "circuit_id:node_threshold:edge_threshold,..."
# Thresholds default to 0.6 and 0.8 if omitted.
preload_circuits = os.environ["PRELOAD_CIRCUITS"].strip().split(",") if os.environ.get("PRELOAD_CIRCUITS") else []
Expand All @@ -31,12 +52,16 @@ async def lifespan(app: FastAPI):
edge_threshold = float(parts[2]) if len(parts) > 2 else 0.8
load_circuit_graph(circuit_id=circuit_id, node_threshold=node_threshold, edge_threshold=edge_threshold)

await task

yield

task = asyncio.create_task(workers_on_unmount())
get_model.cache_clear()
get_dataset.cache_clear()
get_sae.cache_clear()
load_circuit_graph.cache_clear()
await task
DistributedWorkerRegistry.shutdown()


app = FastAPI(lifespan=lifespan)
Expand Down
2 changes: 1 addition & 1 deletion server/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
device = "cuda" if torch.cuda.is_available() else "cpu"
client = MongoClient(MongoDBConfig())
sae_series = os.environ.get("SAE_SERIES", "default")
tokenizer_only = os.environ.get("TOKENIZER_ONLY", "false").lower() == "true"
tokenizer_only = os.environ.get("IS_WORKER") is None and os.environ.get("NUM_WORKERS") != "0"

# LRU cache sizes (configurable via environment variables)
LRU_CACHE_SIZE_SAMPLES = int(os.environ.get("LRU_CACHE_SIZE_SAMPLES", "128"))
Expand Down
9 changes: 5 additions & 4 deletions server/logic/loaders.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from functools import lru_cache

from datasets import Dataset
from torch.distributed.device_mesh import DeviceMesh

from lm_saes import SparseDictionaryConfig
from lm_saes.backend import LanguageModel
Expand All @@ -22,14 +23,14 @@
@synchronized
@lru_cache(maxsize=LRU_CACHE_SIZE_MODELS)
@timer.time("get_model")
def get_model(*, name: str) -> LanguageModel:
def get_model(*, name: str, device_mesh: DeviceMesh | None = None) -> LanguageModel:
"""Load and cache a language model."""
cfg = client.get_model_cfg(name)
if cfg is None:
raise ValueError(f"Model {name} not found")
cfg.tokenizer_only = tokenizer_only
cfg.device = device
return load_model(cfg)
return load_model(cfg, device_mesh=device_mesh)


@synchronized
Expand All @@ -45,11 +46,11 @@ def get_dataset(*, name: str, shard_idx: int = 0, n_shards: int = 1) -> Dataset:
@synchronized
@lru_cache(maxsize=LRU_CACHE_SIZE_SAES)
@timer.time("get_sae")
def get_sae(*, name: str) -> SparseDictionary:
def get_sae(*, name: str, device_mesh: DeviceMesh | None = None) -> SparseDictionary:
"""Load and cache a sparse autoencoder."""
path = client.get_sae_path(name, sae_series)
assert path is not None, f"SAE {name} not found"
sae = SparseDictionary.from_pretrained(path, device=device)
sae = SparseDictionary.from_pretrained(path, device=device, device_mesh=device_mesh)
sae.eval()
return sae

Expand Down
Loading
Loading