9 changes: 9 additions & 0 deletions .github/workflows/checks.yml
@@ -43,6 +43,15 @@ permissions:
actions: write
contents: write

# Cancel in-progress PR runs on new push; non-PR events (release, tags) are exempt.
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: ${{ github.event_name == 'pull_request' }}

# Retry HF 429s in non-pytest invocations; pytest enables via tests/conftest.py.
env:
TRANSFORMERLENS_HF_RETRY: "1"

jobs:
compatibility-checks:
name: Compatibility Checks
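
To make the two concurrency expressions above concrete, here is a minimal Python sketch of how the group key and the cancel flag would evaluate for different events. The helper name `concurrency_settings`, the workflow name `"Checks"`, the PR number, and the tag are illustrative assumptions, not taken from the workflow; the sketch only mirrors the expressions and says nothing about how GitHub Actions evaluates them internally.

```python
def concurrency_settings(workflow: str, ref: str, event_name: str) -> dict:
    """Mirror the two workflow expressions (hypothetical helper, illustration only)."""
    return {
        # group: ${{ github.workflow }}-${{ github.ref }}
        "group": f"{workflow}-{ref}",
        # cancel-in-progress: ${{ github.event_name == 'pull_request' }}
        "cancel_in_progress": event_name == "pull_request",
    }


# Two pushes to the same PR land in the same group, and the flag is True,
# so the newer run cancels the one still in progress.
pr_run = concurrency_settings("Checks", "refs/pull/123/merge", "pull_request")

# A tag push gets its own group and the flag is False,
# so a run already in progress is left to finish.
tag_run = concurrency_settings("Checks", "refs/tags/v3.0.0", "push")

print(pr_run)
print(tag_run)
```

Because the group includes `github.ref`, runs for different PRs or branches never share a group, so they cannot cancel each other regardless of the flag.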
4 changes: 3 additions & 1 deletion README.md
@@ -10,7 +10,7 @@ CD](https://github.com/TransformerLensOrg/TransformerLens/actions/workflows/chec
[![Docs
CD](https://github.com/TransformerLensOrg/TransformerLens/actions/workflows/pages/pages-build-deployment/badge.svg)](https://github.com/TransformerLensOrg/TransformerLens/actions/workflows/pages/pages-build-deployment)

A Library for Mechanistic Interpretability of Generative Language Models. Maintained by [Bryce Meyer](https://github.com/bryce13950) and created by [Neel Nanda](https://neelnanda.io/about)
A Library for Mechanistic Interpretability of Generative Language Models. Maintained by [Bryce Meyer](https://github.com/bryce13950) and [Jonah Larson](https://github.com/jlarson4); created by [Neel Nanda](https://neelnanda.io/about)

[![Read the Docs
Here](https://img.shields.io/badge/-Read%20the%20Docs%20Here-blue?style=for-the-badge&logo=Read-the-Docs&logoColor=white&link=https://TransformerLensOrg.github.io/TransformerLens/)](https://TransformerLensOrg.github.io/TransformerLens/)
@@ -50,6 +50,8 @@ bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
logits, activations = bridge.run_with_cache("Hello World")
```

> Gated models (Llama, Mistral, Gemma, ...) require `HF_TOKEN` in your environment. See [Environment Variables](https://TransformerLensOrg.github.io/TransformerLens/content/getting_started.html#environment-variables) for the full list.

`TransformerBridge` is the recommended 3.0 path and supports 50+ architectures. By default it preserves raw HuggingFace weights – logits and activations match HF, *not* legacy `HookedTransformer` (which folds LayerNorm and centers weights by default). Call `bridge.enable_compatibility_mode()` after booting for HookedTransformer-equivalent numerics. The legacy `HookedTransformer.from_pretrained` API is still available but deprecated — see the [Migrating to TransformerLens 3](https://TransformerLensOrg.github.io/TransformerLens/content/migrating_to_v3.html) guide.

## Key Tutorials
34 changes: 31 additions & 3 deletions docs/source/content/getting_started.md
@@ -35,8 +35,36 @@ The bridge currently covers 50+ architectures spanning Llama, Mistral, Qwen, Gem

The bridge is organized around a small set of generalized components wired together by an architecture adapter, which keeps the model code much easier to navigate than the older unified implementation. For a tour of the bridge's canonical hook names, the component layout, and the expected tensor shapes at each hook point, see the [Model Structure](model_structure.md) page. A small alias layer preserves the older TransformerLens hook names (e.g. `blocks.{i}.hook_resid_pre`) so legacy notebooks keep working — but new code should prefer the canonical names.

## Huggingface Gated Access
## Environment Variables

TransformerLens reads a handful of environment variables. None are required for basic use; each enables a specific opt-in behavior.

### `HF_TOKEN`

Your [HuggingFace access token](https://huggingface.co/settings/tokens). Required for gated models (Llama, Mistral/Mixtral, Gemma families, and others) and used to authenticate any HuggingFace API call TransformerLens makes on your behalf. You will need to accept any model-specific agreements on the HuggingFace Hub before TransformerLens can load a gated model; if you skip this step, the error message will link you directly to the agreement page.

```bash
export HF_TOKEN="hf_..."
```
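
If a script should fail early rather than partway through a download, a check along these lines can run before loading a gated model. This is a minimal sketch: `require_hf_token` is a hypothetical helper, not a TransformerLens function, and it only verifies that the variable is set, not that the token is valid or that the model agreement has been accepted.

```python
import os
from typing import Mapping, Optional


def require_hf_token(env: Optional[Mapping[str, str]] = None) -> str:
    """Return HF_TOKEN from the environment or raise a descriptive error."""
    env = os.environ if env is None else env
    token = env.get("HF_TOKEN", "").strip()
    if not token:
        raise RuntimeError(
            "HF_TOKEN is not set; export it before loading a gated model."
        )
    return token


# Usage with an explicit mapping so the example does not depend on the real environment.
print(require_hf_token({"HF_TOKEN": "hf_example"}))  # hf_example
```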

### `TRANSFORMERLENS_HF_RETRY`

Set to `"1"` to wrap `transformers.AutoConfig.from_pretrained`, `AutoModel.from_pretrained`, `AutoTokenizer.from_pretrained`, `AutoProcessor.from_pretrained`, and `AutoFeatureExtractor.from_pretrained` with a retry-on-429 helper. When HuggingFace returns HTTP 429 (rate-limited), the call is retried up to three times with exponential backoff, honoring the `Retry-After` response header when present.

Some of the models available in TransformerLens require gated access to be used. Luckily TransformerLens provides a way to access those models via the configuration of an environmental variable. Simply configure your [HuggingFace access token](https://huggingface.co/settings/tokens) as `HF_TOKEN` in your environment.
Intended primarily for CI environments where parallel workflow runs can trip HF's rate limits. Off by default so production callers see unmodified `transformers` behavior. The wrapping is idempotent and applied globally to the class methods; see [`enable_hf_retry`](https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/utilities/hf_utils.py) for the implementation. The TransformerLens test suite enables this automatically via `tests/conftest.py`.

```bash
export TRANSFORMERLENS_HF_RETRY=1
```
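
The retry behavior described above (up to three retries, exponential backoff, `Retry-After` taking priority) can be sketched as follows. This is not the `enable_hf_retry` implementation; `RateLimited`, `call_with_retry`, and the one-second base delay are assumptions made for illustration, and the sketch wraps a plain callable rather than patching the `transformers` class methods.

```python
import time
from typing import Callable, Optional, TypeVar

T = TypeVar("T")


class RateLimited(Exception):
    """Hypothetical stand-in for an HTTP 429 error (not a real library class)."""

    def __init__(self, retry_after: Optional[float] = None):
        super().__init__("HTTP 429: rate limited")
        self.retry_after = retry_after  # seconds from a Retry-After header, if any


def call_with_retry(
    fn: Callable[[], T],
    max_retries: int = 3,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call fn, retrying on RateLimited up to max_retries times."""
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except RateLimited as err:
            if attempt == max_retries:
                raise  # retries exhausted; surface the original error
            if err.retry_after is not None:
                delay = err.retry_after  # the server's hint takes priority
            else:
                delay = base_delay * 2**attempt  # 1s, 2s, 4s, ...
            sleep(delay)
    raise AssertionError("unreachable")


# Demo: fail twice, then succeed. Delays are recorded instead of slept.
outcomes = [RateLimited(retry_after=5.0), RateLimited(), "ok"]
delays = []


def flaky() -> str:
    outcome = outcomes.pop(0)
    if isinstance(outcome, Exception):
        raise outcome
    return outcome


print(call_with_retry(flaky, sleep=delays.append))  # ok
print(delays)  # [5.0, 2.0]
```

Passing `sleep` as a parameter keeps the sketch testable without real waiting; a production wrapper would typically also cap the delay.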

### `TRANSFORMERLENS_ALLOW_MPS`

Set to `"1"` to opt in to Apple Silicon (MPS) as a target device for model inference. Off by default because not all PyTorch operations used by TransformerLens have stable MPS implementations across PyTorch versions; if you enable this and hit a backend error, the most reliable fallback is to leave the variable unset and let TransformerLens select CPU instead.

```bash
export TRANSFORMERLENS_ALLOW_MPS=1
```
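
The opt-in logic can be pictured with a small sketch. `pick_device` is a hypothetical function, not TransformerLens's actual device-selection code; the availability flags are passed in as booleans (real code would query PyTorch), and giving CUDA precedence over MPS is an assumption.

```python
import os
from typing import Mapping, Optional


def pick_device(
    cuda_available: bool,
    mps_available: bool,
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Pick a device string; MPS is used only when explicitly opted in."""
    env = os.environ if env is None else env
    if cuda_available:
        return "cuda"
    if mps_available and env.get("TRANSFORMERLENS_ALLOW_MPS") == "1":
        return "mps"
    return "cpu"  # default when MPS is unavailable or not opted in


print(pick_device(False, True, env={}))  # cpu
print(pick_device(False, True, env={"TRANSFORMERLENS_ALLOW_MPS": "1"}))  # mps
```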

## HuggingFace Gated Access

You will need to make sure you accept the agreements for any gated models, but once you do, the models will work with TransformerLens without issue. If you attempt to use one of these models before you have accepted any related agreements, the console output will be very helpful and point you to the URL where you need to accept an agreement. The most popular gated families supported by TransformerLens are the Llama, Mistral/Mixtral, and Gemma models.
Gated-model access depends only on the `HF_TOKEN` variable described above. Once you have set the token and accepted any model-specific agreements on the HuggingFace Hub, gated models load through TransformerLens with no additional configuration. The most popular gated families supported by TransformerLens are the Llama, Mistral/Mixtral, and Gemma models.
29 changes: 26 additions & 3 deletions tests/acceptance/conftest.py
@@ -1,7 +1,7 @@
"""Shared fixtures for acceptance tests.
"""Session fixtures for acceptance tests.

Session-scoped fixtures avoid redundant model loads across test files.
All models used here must be in the CI cache (see .github/workflows/checks.yml).
transformer_lens imports stay inside fixture bodies — jaxtyping's pytest_configure
hook must install before the package is first imported.
"""

import pytest
@@ -13,3 +13,26 @@ def gpt2_model():
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained("gpt2", device="cpu")


@pytest.fixture(scope="session")
def bloom_560m_hooked():
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained(
"bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
)


@pytest.fixture(scope="session")
def bloom_560m_hf_model():
from transformers import AutoModelForCausalLM

return AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")


@pytest.fixture(scope="session")
def bloom_560m_hf_tokenizer():
from transformers import AutoTokenizer

return AutoTokenizer.from_pretrained("bigscience/bloom-560m")
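
The session-scoped fixtures in these conftest files trade isolation for speed: each model is built once and the same object is handed to every test that requests it. The sketch below is a stdlib-only analogy, not pytest itself; it uses `functools.lru_cache` to stand in for pytest's session cache, and `shared_model` and the dict it returns are hypothetical placeholders for a real fixture and model. It shows both the saved load and why such a fixture is safe only for read-only use.

```python
import functools

load_count = 0


@functools.lru_cache(maxsize=None)
def shared_model(name: str) -> dict:
    """Build the object once per name; later calls return the same instance."""
    global load_count
    load_count += 1  # counts how many expensive "loads" actually happen
    return {"name": name, "hooks": []}  # stand-in for a loaded model


first = shared_model("gpt2")
second = shared_model("gpt2")
print(load_count)       # 1
print(first is second)  # True

# Mutating the shared object is visible to every later consumer.
first["hooks"].append("leaked")
print(second["hooks"])  # ['leaked']
```

A test that needs to modify a model would be better served by loading its own copy than by borrowing the shared one.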
19 changes: 13 additions & 6 deletions tests/acceptance/model_bridge/conftest.py
@@ -1,24 +1,25 @@
"""Shared fixtures for model_bridge acceptance tests.
"""Session fixtures for model_bridge acceptance tests.

Session-scoped fixtures avoid redundant model loads across test files.
All models used here must be in the CI cache (see .github/workflows/checks.yml).
transformer_lens imports stay inside fixture bodies — jaxtyping's pytest_configure
hook must install before the package is first imported.
"""

import pytest

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge


@pytest.fixture(scope="session")
def gpt2_bridge():
"""TransformerBridge wrapping gpt2 (no compatibility mode)."""
from transformer_lens.model_bridge import TransformerBridge

return TransformerBridge.boot_transformers("gpt2", device="cpu")


@pytest.fixture(scope="session")
def gpt2_bridge_compat():
"""TransformerBridge wrapping gpt2 with compatibility mode enabled."""
from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
bridge.enable_compatibility_mode()
return bridge
@@ -27,6 +28,8 @@ def gpt2_bridge_compat():
@pytest.fixture(scope="session")
def gpt2_bridge_compat_no_processing():
"""TransformerBridge wrapping gpt2 with compatibility mode but no weight processing."""
from transformer_lens.model_bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
bridge.enable_compatibility_mode(no_processing=True)
return bridge
@@ -35,10 +38,14 @@ def gpt2_bridge_compat_no_processing():
@pytest.fixture(scope="session")
def gpt2_hooked_processed():
"""HookedTransformer gpt2 with default weight processing."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained("gpt2", device="cpu")


@pytest.fixture(scope="session")
def gpt2_hooked_unprocessed():
"""HookedTransformer gpt2 without weight processing."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")
25 changes: 12 additions & 13 deletions tests/acceptance/test_hooked_transformer.py
@@ -216,12 +216,12 @@ def test_from_pretrained_revision():
raise AssertionError("Should have raised an error")


def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
tf_model = HookedTransformer.from_pretrained(
"bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
)
hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
def test_bloom_similarity_with_hf_model_with_kv_cache_activated(
bloom_560m_hooked, bloom_560m_hf_model, bloom_560m_hf_tokenizer
):
tf_model = bloom_560m_hooked
hf_model = bloom_560m_hf_model
hf_tokenizer = bloom_560m_hf_tokenizer

output_tf = tf_model.generate(
text, do_sample=False, use_past_kv_cache=True, verbose=False, max_new_tokens=10
@@ -236,13 +236,12 @@ def test_bloom_similarity_with_hf_model_with_kv_cache_activated():
assert output_tf == output_hf_str


def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream():
tf_model = HookedTransformer.from_pretrained(
"bigscience/bloom-560m", default_prepend_bos=False, device="cpu"
)

hf_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")
hf_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
def test_bloom_similarity_with_hf_model_with_kv_cache_activated_stream(
bloom_560m_hooked, bloom_560m_hf_model, bloom_560m_hf_tokenizer
):
tf_model = bloom_560m_hooked
hf_model = bloom_560m_hf_model
hf_tokenizer = bloom_560m_hf_tokenizer

final_output = ""
for result in tf_model.generate_stream(
24 changes: 24 additions & 0 deletions tests/conftest.py
@@ -50,6 +50,30 @@ def pytest_configure(config):
torch.cuda.manual_seed_all(42)


@pytest.fixture(autouse=True, scope="session")
def _enable_hf_retry_for_tests():
"""Deferred to fixture (not pytest_configure) so jaxtyping installs first."""
from transformer_lens.utilities.hf_utils import enable_hf_retry

enable_hf_retry()
yield


@pytest.fixture(scope="session")
def gpt2_tokenizer():
from transformers import AutoTokenizer

return AutoTokenizer.from_pretrained("gpt2")


@pytest.fixture(scope="session")
def gpt2_hooked_processed():
"""Read-only use only — mutations leak across the session."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained("gpt2", device="cpu")


def pytest_sessionfinish(session, exitstatus):
"""Clean up at the end of test session."""
if torch.cuda.is_available():
29 changes: 23 additions & 6 deletions tests/integration/model_bridge/conftest.py
@@ -1,24 +1,25 @@
"""Shared fixtures for model_bridge integration tests.
"""Session fixtures for model_bridge integration tests.

Session-scoped fixtures avoid redundant model loads across test files.
All models used here must be in the CI cache (see .github/workflows/checks.yml).
transformer_lens imports stay inside fixture bodies — jaxtyping's pytest_configure
hook must install before the package is first imported.
"""

import pytest

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge.bridge import TransformerBridge


@pytest.fixture(scope="session")
def distilgpt2_bridge():
"""TransformerBridge wrapping distilgpt2 (no compatibility mode)."""
from transformer_lens.model_bridge.bridge import TransformerBridge

return TransformerBridge.boot_transformers("distilgpt2", device="cpu")


@pytest.fixture(scope="session")
def distilgpt2_bridge_compat():
"""TransformerBridge wrapping distilgpt2 with compatibility mode enabled."""
from transformer_lens.model_bridge.bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
bridge.enable_compatibility_mode()
return bridge
@@ -27,12 +28,16 @@ def distilgpt2_bridge_compat():
@pytest.fixture(scope="session")
def gpt2_bridge():
"""TransformerBridge wrapping gpt2 (no compatibility mode)."""
from transformer_lens.model_bridge.bridge import TransformerBridge

return TransformerBridge.boot_transformers("gpt2", device="cpu")


@pytest.fixture(scope="session")
def gpt2_bridge_compat():
"""TransformerBridge wrapping gpt2 with compatibility mode enabled."""
from transformer_lens.model_bridge.bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
bridge.enable_compatibility_mode()
return bridge
@@ -41,30 +46,40 @@ def gpt2_bridge_compat():
@pytest.fixture(scope="session")
def gpt2_hooked_processed():
"""HookedTransformer gpt2 with default weight processing."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained("gpt2", device="cpu")


@pytest.fixture(scope="session")
def gpt2_hooked_unprocessed():
"""HookedTransformer gpt2 without weight processing."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained_no_processing("gpt2", device="cpu")


@pytest.fixture(scope="session")
def distilgpt2_hooked_processed():
"""HookedTransformer distilgpt2 with default weight processing."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained("distilgpt2", device="cpu")


@pytest.fixture(scope="session")
def distilgpt2_hooked_unprocessed():
"""HookedTransformer distilgpt2 without weight processing."""
from transformer_lens import HookedTransformer

return HookedTransformer.from_pretrained_no_processing("distilgpt2", device="cpu")


@pytest.fixture(scope="session")
def gpt2_bridge_compat_no_processing():
"""TransformerBridge wrapping gpt2 with compat mode, no weight processing."""
from transformer_lens.model_bridge.bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
bridge.enable_compatibility_mode(no_processing=True)
return bridge
@@ -73,6 +88,8 @@ def distilgpt2_bridge_compat_no_processing():
@pytest.fixture(scope="session")
def distilgpt2_bridge_compat_no_processing():
"""TransformerBridge wrapping distilgpt2 with compat mode, no weight processing."""
from transformer_lens.model_bridge.bridge import TransformerBridge

bridge = TransformerBridge.boot_transformers("distilgpt2", device="cpu")
bridge.enable_compatibility_mode(no_processing=True, disable_warnings=True)
return bridge