Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions demos/ARENA_Content.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,9 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"`torch_dtype` is deprecated! Use `dtype` instead!\n",
"The following generation flags are not valid and may be ignored: ['output_attentions']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n"
]
}
],
"outputs": [],
"source": [
"# NBVAL_IGNORE_OUTPUT\n",
"\n",
Expand All @@ -76,10 +67,10 @@
" \"gpt2\",\n",
" device=device,\n",
")\n",
"reference_gpt2.enable_compatibility_mode(disable_warnings=True)",
"\n",
"reference_gpt2.enable_compatibility_mode(disable_warnings=True)\n",
"reference_gpt2.set_use_split_qkv_input(True)\n",
"reference_gpt2.set_use_attn_result(True)"
"reference_gpt2.set_use_attn_result(True)\n",
"reference_gpt2.set_use_hook_mlp_in(True)"
]
},
{
Expand Down
9 changes: 9 additions & 0 deletions docs/source/content/migrating_to_v3.md
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ cache["blocks.0.hook_in"] # canonical name — preferred for new code

For the full mapping of legacy → canonical names and the expected tensor shape at each hook point, see the [Model Structure](model_structure.md) page.

### Hook semantic notes

Two semantic differences inside `enable_compatibility_mode()` are worth knowing about if you are porting activation-patching, direct logit attribution (DLA), or attribution-patching code:

- **`blocks.{i}.hook_mlp_in` fires pre-ln2** (matching legacy `HookedTransformer`). Use `bridge.set_use_hook_mlp_in(True)` to enable it — setting `cfg.use_hook_mlp_in = True` directly is honored when blocks share the bridge's `cfg`, but the setter is the supported entry point. The pre-ln2 placement means cached values from one run can be patched into another and re-flow through `ln2 → mlp` consistently across the bridge and `HookedTransformer`.
- **`hook_q_input` / `hook_k_input` / `hook_v_input` / `hook_attn_in`** also fire pre-ln1 in compat mode. On the per-head LN application that follows, the bridge routes through the raw HF norm rather than the `NormalizationBridge` wrapper, so `ln1`'s sub-hooks (`hook_in`, `hook_normalized`, `hook_scale`) do **not** fire once per head the way legacy `LayerNormPre` would. Q/K/V projections downstream still match legacy numerically; only the intermediate LN sub-hook firing is suppressed.

Post-norm architectures (OLMo 2, BERT-style encoders) and MLA blocks (DeepSeek V2/V3/R1) do not participate in the pre-ln1 capture — `MLABlockBridge` does not expose those aliases, and post-norm models would read the post-attention residual instead of the block input.

## APIs that are unchanged

These work identically on `TransformerBridge` and need no migration:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Bridge vs HookedTransformer parity test for cross-run hook_mlp_in patching (#1317).

Parameterized over Pythia (native autograd LN) and GPT-2 (manual LN), and over
``no_processing`` so both folded and unfolded compat-mode setups are covered.
"""
from __future__ import annotations

import pytest
import torch

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge

_MODELS = ("EleutherAI/pythia-14m", "gpt2")
_NO_PROCESSING = (True, False)

_pair_cache: dict[tuple[str, bool], tuple[TransformerBridge, HookedTransformer]] = {}
_baseline_cache: dict[tuple[str, bool, tuple[int, ...]], float] = {}


def _build_pair(model: str, no_processing: bool) -> tuple[TransformerBridge, HookedTransformer]:
    """Return the memoized ``(bridge, HookedTransformer)`` pair for ``model``.

    The pair is built the first time a ``(model, no_processing)`` key is
    requested and stored in ``_pair_cache``; later calls hand back the same
    two objects. Both runtimes are loaded on CPU and have ``hook_mlp_in``
    switched on before being cached.

    Args:
        model: Model identifier passed to both loaders.
        no_processing: Forwarded to ``enable_compatibility_mode`` on the
            bridge and used to choose the matching HookedTransformer loader.

    Returns:
        The cached ``(bridge, ht)`` tuple.
    """
    cache_key = (model, no_processing)
    if cache_key in _pair_cache:
        return _pair_cache[cache_key]

    bridge = TransformerBridge.boot_transformers(model, device="cpu")
    bridge.enable_compatibility_mode(no_processing=no_processing)

    # Pick the legacy loader that mirrors the bridge's processing setting.
    load_legacy = (
        HookedTransformer.from_pretrained_no_processing
        if no_processing
        else HookedTransformer.from_pretrained
    )
    ht = load_legacy(model, device="cpu")

    # The bridge goes through its setter; the legacy model's flag is set on cfg.
    bridge.set_use_hook_mlp_in(True)
    ht.cfg.use_hook_mlp_in = True

    pair = (bridge, ht)
    _pair_cache[cache_key] = pair
    return pair


def _baseline_logit_diff(model: str, no_processing: bool, prompt: torch.Tensor) -> float:
    """Return the max absolute bridge-vs-HT logit gap for an unhooked forward pass.

    Results are memoized in ``_baseline_cache`` under the model, the
    ``no_processing`` flag, and the flattened token ids of ``prompt``, so each
    distinct prompt is only run once per pair.

    Args:
        model: Model identifier used to look up the cached pair.
        no_processing: Selects which cached pair to compare.
        prompt: Token tensor fed to both runtimes.

    Returns:
        ``max(|bridge(prompt) - ht(prompt)|)`` as a Python float.
    """
    token_ids = tuple(prompt.flatten().tolist())
    cache_key = (model, no_processing, token_ids)
    if cache_key in _baseline_cache:
        return _baseline_cache[cache_key]

    bridge, ht = _build_pair(model, no_processing)
    # No gradients are needed for a plain logit comparison.
    with torch.no_grad():
        logit_gap = bridge(prompt) - ht(prompt)
        worst = logit_gap.abs().max().item()
    _baseline_cache[cache_key] = worst
    return worst


@pytest.mark.slow
@pytest.mark.parametrize("no_processing", _NO_PROCESSING)
@pytest.mark.parametrize("model", _MODELS)
@pytest.mark.parametrize("layer", [0, 3])
def test_cross_run_mlp_in_patch_matches_legacy(model: str, layer: int, no_processing: bool) -> None:
    """Patch run A's ``hook_mlp_in`` tensor into run B on both runtimes and compare logits.

    Steps:
      1. Capture ``blocks.{layer}.hook_mlp_in`` on prompt A from the bridge and
         from HookedTransformer, checking the bridge hook fires exactly once
         and that both captures agree.
      2. Replace the same hook's value on prompt B with the captured tensor.
      3. Require the patched logit gap to stay within 10x the unhooked
         baseline gap (floored at 1e-5).
    """
    bridge, ht = _build_pair(model, no_processing)
    hook_name = f"blocks.{layer}.hook_mlp_in"
    tag = f"[{model} no_processing={no_processing}]"

    prompt_a = torch.arange(1, 9).unsqueeze(0)
    prompt_b = torch.arange(10, 18).unsqueeze(0)

    bridge_store: dict = {}
    ht_store: dict = {}
    bridge_fires = 0

    def _record_bridge(tensor: torch.Tensor, hook: object) -> torch.Tensor:
        # Count firings as well as capturing, so a silent miss is detectable.
        nonlocal bridge_fires
        bridge_fires += 1
        bridge_store["v"] = tensor.detach().clone()
        return tensor

    def _record_ht(tensor: torch.Tensor, hook: object) -> torch.Tensor:
        ht_store["v"] = tensor.detach().clone()
        return tensor

    def _replace_with(store: dict) -> object:
        # The lookup happens at hook time, not when the hook is created.
        return lambda tensor, hook: store["v"]

    bridge.run_with_hooks(prompt_a, fwd_hooks=[(hook_name, _record_bridge)])
    ht.run_with_hooks(prompt_a, fwd_hooks=[(hook_name, _record_ht)])

    # Pins down a silent-miss in the ln2 pre-hook (the #1317 bug class).
    assert bridge_fires == 1, (
        f"{tag} bridge hook_mlp_in fired "
        f"{bridge_fires} times, expected exactly 1 (pre-ln2 capture closure)"
    )

    bridge_capture = bridge_store["v"]
    ht_capture = ht_store["v"]
    assert bridge_capture.shape == ht_capture.shape
    captured_diff = (bridge_capture - ht_capture).abs().max().item()
    assert captured_diff < 1e-3, (
        f"{tag} Bridge hook_mlp_in captures "
        f"different values than HT: {captured_diff:.3e}"
    )

    bridge_logits = bridge.run_with_hooks(
        prompt_b, fwd_hooks=[(hook_name, _replace_with(bridge_store))]
    )
    ht_logits = ht.run_with_hooks(prompt_b, fwd_hooks=[(hook_name, _replace_with(ht_store))])

    baseline_diff = _baseline_logit_diff(model, no_processing, prompt_b)
    patched_diff = (bridge_logits - ht_logits).abs().max().item()
    allowed = 10 * max(baseline_diff, 1e-5)
    assert patched_diff < allowed, (
        f"{tag} Bridge vs HT cross-run mlp_in patch "
        f"logits diverge {patched_diff:.3e}, >10x the unhooked baseline {baseline_diff:.3e}"
    )


@pytest.mark.slow
def test_mlp_in_gated_off_does_not_fire() -> None:
    """With ``use_hook_mlp_in`` disabled, a hook registered on ``hook_mlp_in`` must never run.

    Uses a freshly booted GPT-2 bridge (not the shared cached pair) so the
    flag can be turned off without affecting other tests.
    """
    bridge = TransformerBridge.boot_transformers("gpt2", device="cpu")
    bridge.enable_compatibility_mode(no_processing=True)
    bridge.set_use_hook_mlp_in(False)

    calls = 0

    def _count_call(tensor: torch.Tensor, hook: object) -> torch.Tensor:
        nonlocal calls
        calls += 1
        return tensor

    tokens = torch.arange(1, 9).unsqueeze(0)
    bridge.run_with_hooks(tokens, fwd_hooks=[("blocks.0.hook_mlp_in", _count_call)])

    assert calls == 0, (
        f"hook_mlp_in fired {calls} times with use_hook_mlp_in=False; "
        "should not fire when the flag is off"
    )
188 changes: 134 additions & 54 deletions tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,158 @@
"""Bridge vs HookedTransformer parity test for cross-run Q/K/V patching.

Deliberate strict-xfail: the bridge forks Q/K/V inputs post-ln1; legacy TL
forks pre-ln1. Pure ablations (zero, mean) are unaffected by the placement,
but a cross-run patch — copy a cached residual from run A into run B's
`hook_q_input` — lands in Q's projection already normed for run A's
distribution on the bridge, and pre-norm-then-re-normed on legacy. The logits
diverge.

This test makes the divergence a load-bearing CI signal. When someone ships
pre-ln1 placement (see docs/rfcs/FOLLOWUP-pre-ln-split-qkv.md), the strict
xfail forces them to flip this test to passing in the same PR.
"""Bridge vs HookedTransformer parity tests for cross-run Q/K/V/attn_in patching (#1317).

Parameterized over Pythia (native autograd LN) and GPT-2 (manual LN), and over
``no_processing`` so both folded and unfolded compat-mode setups are covered.
"""
from __future__ import annotations

import os

import pytest
import torch

from transformer_lens import HookedTransformer
from transformer_lens.model_bridge import TransformerBridge

_MODEL = "EleutherAI/pythia-14m"


@pytest.mark.slow
@pytest.mark.xfail(
strict=True,
reason=(
"Bridge forks Q/K/V inputs post-ln1; legacy HookedTransformer forks "
"pre-ln1. Cross-run residual patches land in different coordinate "
"systems, so logits diverge. Tracked in "
"docs/rfcs/FOLLOWUP-pre-ln-split-qkv.md — flip to passing in the same "
"PR that ships pre-ln1 placement."
),
)
def test_cross_run_q_input_patch_matches_legacy() -> None:
"""Copy a cached residual from prompt A into hook_q_input on prompt B; bridge and HT logits should match."""
bridge = TransformerBridge.boot_transformers(_MODEL, device="cpu")
bridge.enable_compatibility_mode(no_processing=True)
ht = HookedTransformer.from_pretrained_no_processing(_MODEL, device="cpu")
bridge.set_use_split_qkv_input(True)
ht.set_use_split_qkv_input(True)

_MODELS = ("EleutherAI/pythia-14m", "gpt2")
_NO_PROCESSING = (True, False)

_pair_cache: dict[tuple[str, bool], tuple[TransformerBridge, HookedTransformer]] = {}
_baseline_cache: dict[tuple[str, bool, tuple[int, ...]], float] = {}


def _build_pair(model: str, no_processing: bool) -> tuple[TransformerBridge, HookedTransformer]:
    """Build (once) and return the bridge / HookedTransformer pair for ``model``.

    Pairs are memoized in ``_pair_cache`` by ``(model, no_processing)``, so
    every test using the same key receives the same two objects. Any flag a
    test sets on them stays set for later tests.

    Args:
        model: Model identifier passed to both loaders.
        no_processing: Forwarded to the bridge's ``enable_compatibility_mode``
            and used to choose the matching HookedTransformer loader.

    Returns:
        The cached ``(bridge, ht)`` tuple, both on CPU.
    """
    cache_key = (model, no_processing)
    pair = _pair_cache.get(cache_key)
    if pair is None:
        bridge = TransformerBridge.boot_transformers(model, device="cpu")
        bridge.enable_compatibility_mode(no_processing=no_processing)
        ht = (
            HookedTransformer.from_pretrained_no_processing(model, device="cpu")
            if no_processing
            else HookedTransformer.from_pretrained(model, device="cpu")
        )
        pair = (bridge, ht)
        _pair_cache[cache_key] = pair
    return pair


def _baseline_logit_diff(model: str, no_processing: bool, prompt: torch.Tensor) -> float:
    """Return the memoized max absolute logit gap between bridge and HT with no hooks.

    The cache key is the model, the ``no_processing`` flag, and the flattened
    token ids of ``prompt``. The value is computed under ``torch.no_grad()``
    on first use and reused afterwards.

    Args:
        model: Model identifier used to fetch the cached pair.
        no_processing: Selects which cached pair to compare.
        prompt: Token tensor run through both runtimes.

    Returns:
        The largest absolute element of ``bridge(prompt) - ht(prompt)``.
    """
    flat_tokens = prompt.flatten().tolist()
    cache_key = (model, no_processing, tuple(flat_tokens))
    if cache_key in _baseline_cache:
        return _baseline_cache[cache_key]

    bridge, ht = _build_pair(model, no_processing)
    with torch.no_grad():
        max_gap = (bridge(prompt) - ht(prompt)).abs().max().item()
    _baseline_cache[cache_key] = max_gap
    return max_gap


def _cross_run_patch_parity(
    model: str,
    no_processing: bool,
    bridge_hook_path: str,
    ht_hook_path: str,
    capture_tol: float = 1e-3,
) -> None:
    """Cache hook tensor from prompt A, patch into prompt B on both runtimes; logits should match.

    The pair is taken from ``_build_pair``. Callers must set any hook-enabling
    flags on that same cached pair before calling this helper.

    Args:
        model: Model identifier used to fetch the cached pair.
        no_processing: Selects which cached pair to use.
        bridge_hook_path: Hook name to capture/patch on the bridge.
        ht_hook_path: Hook name to capture/patch on HookedTransformer.
        capture_tol: Maximum allowed absolute difference between the two
            captured tensors.

    Raises:
        AssertionError: If either hook never fires, the captures differ by
            ``capture_tol`` or more, or the patched logit gap is not below
            10x the unhooked baseline gap (floored at 1e-5).
    """
    bridge, ht = _build_pair(model, no_processing)
    prompt_a = torch.arange(1, 9).unsqueeze(0)
    prompt_b = torch.arange(10, 18).unsqueeze(0)

    cache_a_bridge: dict = {}
    cache_a_ht: dict = {}

    def _cap(cache: dict) -> "object":
        def _inner(tensor: torch.Tensor, hook: object) -> torch.Tensor:
            cache["v"] = tensor.detach().clone()
            return tensor

        return _inner

    def _patch(cache: dict) -> "object":
        def _inner(tensor: torch.Tensor, hook: object) -> torch.Tensor:
            return cache["v"]

        return _inner

    bridge.run_with_hooks(prompt_a, fwd_hooks=[(bridge_hook_path, _cap(cache_a_bridge))])
    ht.run_with_hooks(prompt_a, fwd_hooks=[(ht_hook_path, _cap(cache_a_ht))])

    # A hook that never fires would otherwise surface below as an opaque
    # KeyError on cache["v"]; fail with a message that names the hook instead.
    assert "v" in cache_a_bridge, (
        f"[{model} no_processing={no_processing}] Bridge hook {bridge_hook_path} "
        "never fired during the capture run"
    )
    assert "v" in cache_a_ht, (
        f"[{model} no_processing={no_processing}] HT hook {ht_hook_path} "
        "never fired during the capture run"
    )

    captured_diff = (cache_a_bridge["v"] - cache_a_ht["v"]).abs().max().item()
    assert captured_diff < capture_tol, (
        f"[{model} no_processing={no_processing}] Bridge {bridge_hook_path} captures "
        f"different values than HT {ht_hook_path}: max diff {captured_diff:.3e} "
        f"(tol {capture_tol:.0e})"
    )

    bridge_logits = bridge.run_with_hooks(
        prompt_b, fwd_hooks=[(bridge_hook_path, _patch(cache_a_bridge))]
    )
    ht_logits = ht.run_with_hooks(prompt_b, fwd_hooks=[(ht_hook_path, _patch(cache_a_ht))])

    # Compare against the unhooked gap rather than a fixed atol, so ordinary
    # numerical drift between the two runtimes is not reported as a patch bug.
    baseline = _baseline_logit_diff(model, no_processing, prompt_b)
    patched = (bridge_logits - ht_logits).abs().max().item()
    assert patched < 10 * max(baseline, 1e-5), (
        f"[{model} no_processing={no_processing}] Bridge vs HT cross-run patch logits "
        f"diverge {patched:.3e}, >10x the unhooked baseline {baseline:.3e}"
    )


@pytest.mark.slow
@pytest.mark.parametrize("no_processing", _NO_PROCESSING)
@pytest.mark.parametrize("model", _MODELS)
@pytest.mark.parametrize("hook_slot", ["q_input", "k_input", "v_input"])
@pytest.mark.parametrize("layer", [0, 3])
def test_split_qkv_cross_run_patch_matches_legacy(
    model: str, hook_slot: str, layer: int, no_processing: bool
) -> None:
    """Each of Q, K, V at multiple layers, on both native/manual LN paths and folded/unfolded compat.

    The bridge/HT pair is memoized and shared with the attn_in test, which
    leaves ``use_attn_in`` enabled. That flag is cleared here so the result
    does not depend on which test ran first.
    """
    bridge, ht = _build_pair(model, no_processing)
    for runtime in (bridge, ht):
        # Mirror the attn_in test, which explicitly disables split-QKV.
        runtime.set_use_attn_in(False)
        runtime.set_use_split_qkv_input(True)
    _cross_run_patch_parity(
        model,
        no_processing,
        bridge_hook_path=f"blocks.{layer}.attn.hook_{hook_slot}",
        ht_hook_path=f"blocks.{layer}.hook_{hook_slot}",
    )


@pytest.mark.slow
@pytest.mark.parametrize("no_processing", _NO_PROCESSING)
@pytest.mark.parametrize("model", _MODELS)
@pytest.mark.parametrize("layer", [0, 3])
def test_attn_in_cross_run_patch_matches_legacy(
    model: str, layer: int, no_processing: bool
) -> None:
    """Cross-run patch parity for the shared ``hook_attn_in`` fork.

    Split-QKV input is switched off and ``attn_in`` switched on for both
    runtimes, so this exercises the attn_in path rather than the per-Q/K/V
    hooks.
    """
    bridge, ht = _build_pair(model, no_processing)
    # Same call order as configuring each runtime by hand: bridge first, then HT.
    for runtime in (bridge, ht):
        runtime.set_use_split_qkv_input(False)
        runtime.set_use_attn_in(True)

    bridge_hook = f"blocks.{layer}.attn.hook_attn_in"
    legacy_hook = f"blocks.{layer}.hook_attn_in"
    _cross_run_patch_parity(
        model,
        no_processing,
        bridge_hook_path=bridge_hook,
        ht_hook_path=legacy_hook,
    )


@pytest.mark.slow
@pytest.mark.skipif(
    os.getenv("RUN_OLMO2_GAP_TEST", "") != "1",
    reason="Set RUN_OLMO2_GAP_TEST=1 to exercise the OLMo-2 post-norm gap (1B-param download).",
)
@pytest.mark.xfail(
    strict=True,
    reason="OLMo 2 post-norm: ln1 maps to post_attention_layernorm, so pre-ln1 capture "
    "reads post-attention residual. Flip to passing when the carve-out is fixed.",
)
def test_olmo2_pre_ln_capture_known_gap() -> None:
    """Strict-xfail marker for the OLMo-2 ``hook_q_input`` cross-run patch at layer 0.

    Opt-in via ``RUN_OLMO2_GAP_TEST=1``. The pair is obtained from
    ``_build_pair`` so that split-QKV input is enabled on the very objects
    ``_cross_run_patch_parity`` later fetches from the cache. Building a
    separate pair here would leave the helper's pair unconfigured and load the
    model twice.
    """
    model = "allenai/OLMo-2-0425-1B"
    no_processing = True

    bridge, ht = _build_pair(model, no_processing)
    bridge.set_use_split_qkv_input(True)
    ht.set_use_split_qkv_input(True)

    _cross_run_patch_parity(
        model,
        no_processing,
        bridge_hook_path="blocks.0.attn.hook_q_input",
        ht_hook_path="blocks.0.hook_q_input",
    )
Loading
Loading