diff --git a/demos/ARENA_Content.ipynb b/demos/ARENA_Content.ipynb index 78d6d8dc9..47ae2f5d8 100644 --- a/demos/ARENA_Content.ipynb +++ b/demos/ARENA_Content.ipynb @@ -56,18 +56,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "`torch_dtype` is deprecated! Use `dtype` instead!\n", - "The following generation flags are not valid and may be ignored: ['output_attentions']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" - ] - } - ], + "outputs": [], "source": [ "# NBVAL_IGNORE_OUTPUT\n", "\n", @@ -76,10 +67,10 @@ " \"gpt2\",\n", " device=device,\n", ")\n", - "reference_gpt2.enable_compatibility_mode(disable_warnings=True)", - "\n", + "reference_gpt2.enable_compatibility_mode(disable_warnings=True)\n", "reference_gpt2.set_use_split_qkv_input(True)\n", - "reference_gpt2.set_use_attn_result(True)" + "reference_gpt2.set_use_attn_result(True)\n", + "reference_gpt2.set_use_hook_mlp_in(True)" ] }, { diff --git a/docs/source/content/migrating_to_v3.md b/docs/source/content/migrating_to_v3.md index fc47fceef..29749fff3 100644 --- a/docs/source/content/migrating_to_v3.md +++ b/docs/source/content/migrating_to_v3.md @@ -99,6 +99,15 @@ cache["blocks.0.hook_in"] # canonical name — preferred for new code For the full mapping of legacy → canonical names and the expected tensor shape at each hook point, see the [Model Structure](model_structure.md) page. +### Hook semantic notes + +Two semantic differences inside `enable_compatibility_mode()` worth knowing if you are porting activation-patching, DLA, or attribution-patching code: + +- **`blocks.{i}.hook_mlp_in` fires pre-ln2** (matching legacy `HookedTransformer`). Use `bridge.set_use_hook_mlp_in(True)` to enable it — setting `cfg.use_hook_mlp_in = True` directly is honored when blocks share the bridge's `cfg`, but the setter is the supported entry point. The pre-ln2 placement means cached values from one run can be patched into another and re-flow through `ln2 → mlp` consistently across the bridge and `HookedTransformer`. +- **`hook_q_input` / `hook_k_input` / `hook_v_input` / `hook_attn_in`** also fire pre-ln1 in compat mode. On the per-head LN application that follows, the bridge routes through the raw HF norm rather than the `NormalizationBridge` wrapper, so `ln1`'s sub-hooks (`hook_in`, `hook_normalized`, `hook_scale`) do **not** fire once per head the way legacy `LayerNormPre` would. Q/K/V projections downstream still match legacy numerically; only the intermediate LN sub-hook firing is suppressed. + +Post-norm architectures (OLMo 2, BERT-style encoders) and MLA blocks (DeepSeek V2/V3/R1) do not participate in the pre-ln1 capture — `MLABlockBridge` does not expose those aliases, and post-norm models would read the post-attention residual instead of the block input. + ## APIs that are unchanged These work identically on `TransformerBridge` and need no migration: diff --git a/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_mlp_in_patching.py b/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_mlp_in_patching.py new file mode 100644 index 000000000..e643905a1 --- /dev/null +++ b/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_mlp_in_patching.py @@ -0,0 +1,124 @@ +"""Bridge vs HookedTransformer parity test for cross-run hook_mlp_in patching (#1317). + +Parameterized over Pythia (native autograd LN) and GPT-2 (manual LN), and over +``no_processing`` so both folded and unfolded compat-mode setups are covered. +""" +from __future__ import annotations + +import pytest +import torch + +from transformer_lens import HookedTransformer +from transformer_lens.model_bridge import TransformerBridge + +_MODELS = ("EleutherAI/pythia-14m", "gpt2") +_NO_PROCESSING = (True, False) + +_pair_cache: dict[tuple[str, bool], tuple[TransformerBridge, HookedTransformer]] = {} +_baseline_cache: dict[tuple[str, bool, tuple[int, ...]], float] = {} + + +def _build_pair(model: str, no_processing: bool) -> tuple[TransformerBridge, HookedTransformer]: + key = (model, no_processing) + if key not in _pair_cache: + bridge = TransformerBridge.boot_transformers(model, device="cpu") + bridge.enable_compatibility_mode(no_processing=no_processing) + if no_processing: + ht = HookedTransformer.from_pretrained_no_processing(model, device="cpu") + else: + ht = HookedTransformer.from_pretrained(model, device="cpu") + bridge.set_use_hook_mlp_in(True) + ht.cfg.use_hook_mlp_in = True + _pair_cache[key] = (bridge, ht) + return _pair_cache[key] + + +def _baseline_logit_diff(model: str, no_processing: bool, prompt: torch.Tensor) -> float: + key = (model, no_processing, tuple(prompt.flatten().tolist())) + if key not in _baseline_cache: + bridge, ht = _build_pair(model, no_processing) + with torch.no_grad(): + _baseline_cache[key] = (bridge(prompt) - ht(prompt)).abs().max().item() + return _baseline_cache[key] + + +@pytest.mark.slow +@pytest.mark.parametrize("no_processing", _NO_PROCESSING) +@pytest.mark.parametrize("model", _MODELS) +@pytest.mark.parametrize("layer", [0, 3]) +def test_cross_run_mlp_in_patch_matches_legacy(model: str, layer: int, no_processing: bool) -> None: + """Splice cached resid_mid from run A into run B's hook_mlp_in; logits should match.""" + bridge, ht = _build_pair(model, no_processing) + + prompt_a = torch.arange(1, 9).unsqueeze(0) + prompt_b = torch.arange(10, 18).unsqueeze(0) + + cache_a_bridge: dict = {} + cache_a_ht: dict = {} + bridge_fire_count = {"n": 0} + + def _cap_bridge(tensor: torch.Tensor, hook: object) -> torch.Tensor: + bridge_fire_count["n"] += 1 + cache_a_bridge["v"] = tensor.detach().clone() + return tensor + + def _cap_ht(tensor: torch.Tensor, hook: object) -> torch.Tensor: + cache_a_ht["v"] = tensor.detach().clone() + return tensor + + def _patch(cache: dict) -> "object": + def _inner(tensor: torch.Tensor, hook: object) -> torch.Tensor: + return cache["v"] + + return _inner + + bridge.run_with_hooks(prompt_a, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _cap_bridge)]) + ht.run_with_hooks(prompt_a, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _cap_ht)]) + + # Pins down a silent-miss in the ln2 pre-hook (the #1317 bug class). + assert bridge_fire_count["n"] == 1, ( + f"[{model} no_processing={no_processing}] bridge hook_mlp_in fired " + f"{bridge_fire_count['n']} times, expected exactly 1 (pre-ln2 capture closure)" + ) + + assert cache_a_bridge["v"].shape == cache_a_ht["v"].shape + captured_diff = (cache_a_bridge["v"] - cache_a_ht["v"]).abs().max().item() + assert captured_diff < 1e-3, ( + f"[{model} no_processing={no_processing}] Bridge hook_mlp_in captures " + f"different values than HT: {captured_diff:.3e}" + ) + + bridge_logits = bridge.run_with_hooks( + prompt_b, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _patch(cache_a_bridge))] + ) + ht_logits = ht.run_with_hooks( + prompt_b, fwd_hooks=[(f"blocks.{layer}.hook_mlp_in", _patch(cache_a_ht))] + ) + + baseline_diff = _baseline_logit_diff(model, no_processing, prompt_b) + patched_diff = (bridge_logits - ht_logits).abs().max().item() + assert patched_diff < 10 * max(baseline_diff, 1e-5), ( + f"[{model} no_processing={no_processing}] Bridge vs HT cross-run mlp_in patch " + f"logits diverge {patched_diff:.3e}, >10x the unhooked baseline {baseline_diff:.3e}" + ) + + +@pytest.mark.slow +def test_mlp_in_gated_off_does_not_fire() -> None: + """When ``use_hook_mlp_in`` is False, the bridge pre-ln2 closure must skip firing.""" + bridge = TransformerBridge.boot_transformers("gpt2", device="cpu") + bridge.enable_compatibility_mode(no_processing=True) + bridge.set_use_hook_mlp_in(False) + + fire_count = {"n": 0} + + def _counter(tensor: torch.Tensor, hook: object) -> torch.Tensor: + fire_count["n"] += 1 + return tensor + + prompt = torch.arange(1, 9).unsqueeze(0) + bridge.run_with_hooks(prompt, fwd_hooks=[("blocks.0.hook_mlp_in", _counter)]) + assert fire_count["n"] == 0, ( + f"hook_mlp_in fired {fire_count['n']} times with use_hook_mlp_in=False; " + "should not fire when the flag is off" + ) diff --git a/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py b/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py index 67b42c73f..66a37f13b 100644 --- a/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py +++ b/tests/unit/model_bridge/test_bridge_vs_hooked_transformer_patching.py @@ -1,78 +1,158 @@ -"""Bridge vs HookedTransformer parity test for cross-run Q/K/V patching. - -Deliberate strict-xfail: the bridge forks Q/K/V inputs post-ln1; legacy TL -forks pre-ln1. Pure ablations (zero, mean) are unaffected by the placement, -but a cross-run patch — copy a cached residual from run A into run B's -`hook_q_input` — lands in Q's projection already normed for run A's -distribution on the bridge, and pre-norm-then-re-normed on legacy. The logits -diverge. - -This test makes the divergence a load-bearing CI signal. When someone ships -pre-ln1 placement (see docs/rfcs/FOLLOWUP-pre-ln-split-qkv.md), the strict -xfail forces them to flip this test to passing in the same PR. +"""Bridge vs HookedTransformer parity tests for cross-run Q/K/V/attn_in patching (#1317). + +Parameterized over Pythia (native autograd LN) and GPT-2 (manual LN), and over +``no_processing`` so both folded and unfolded compat-mode setups are covered. """ from __future__ import annotations +import os + import pytest import torch from transformer_lens import HookedTransformer from transformer_lens.model_bridge import TransformerBridge -_MODEL = "EleutherAI/pythia-14m" - - -@pytest.mark.slow -@pytest.mark.xfail( - strict=True, - reason=( - "Bridge forks Q/K/V inputs post-ln1; legacy HookedTransformer forks " - "pre-ln1. Cross-run residual patches land in different coordinate " - "systems, so logits diverge. Tracked in " - "docs/rfcs/FOLLOWUP-pre-ln-split-qkv.md — flip to passing in the same " - "PR that ships pre-ln1 placement." - ), -) -def test_cross_run_q_input_patch_matches_legacy() -> None: - """Copy a cached residual from prompt A into hook_q_input on prompt B; bridge and HT logits should match.""" - bridge = TransformerBridge.boot_transformers(_MODEL, device="cpu") - bridge.enable_compatibility_mode(no_processing=True) - ht = HookedTransformer.from_pretrained_no_processing(_MODEL, device="cpu") - bridge.set_use_split_qkv_input(True) - ht.set_use_split_qkv_input(True) - +_MODELS = ("EleutherAI/pythia-14m", "gpt2") +_NO_PROCESSING = (True, False) + +_pair_cache: dict[tuple[str, bool], tuple[TransformerBridge, HookedTransformer]] = {} +_baseline_cache: dict[tuple[str, bool, tuple[int, ...]], float] = {} + + +def _build_pair(model: str, no_processing: bool) -> tuple[TransformerBridge, HookedTransformer]: + key = (model, no_processing) + if key not in _pair_cache: + bridge = TransformerBridge.boot_transformers(model, device="cpu") + bridge.enable_compatibility_mode(no_processing=no_processing) + if no_processing: + ht = HookedTransformer.from_pretrained_no_processing(model, device="cpu") + else: + ht = HookedTransformer.from_pretrained(model, device="cpu") + _pair_cache[key] = (bridge, ht) + return _pair_cache[key] + + +def _baseline_logit_diff(model: str, no_processing: bool, prompt: torch.Tensor) -> float: + key = (model, no_processing, tuple(prompt.flatten().tolist())) + if key not in _baseline_cache: + bridge, ht = _build_pair(model, no_processing) + with torch.no_grad(): + _baseline_cache[key] = (bridge(prompt) - ht(prompt)).abs().max().item() + return _baseline_cache[key] + + +def _cross_run_patch_parity( + model: str, + no_processing: bool, + bridge_hook_path: str, + ht_hook_path: str, + capture_tol: float = 1e-3, +) -> None: + """Cache hook tensor from prompt A, patch into prompt B on both runtimes; logits should match.""" + bridge, ht = _build_pair(model, no_processing) prompt_a = torch.arange(1, 9).unsqueeze(0) prompt_b = torch.arange(10, 18).unsqueeze(0) - # Cache a residual from run A (pre-ln on HT; the bridge has no pre-ln hook, - # so we cache the same conceptual slot — hook_q_input post-ln) and splice - # it into run B's hook_q_input at layer 0. cache_a_bridge: dict = {} cache_a_ht: dict = {} - def cap_bridge(tensor, hook): - cache_a_bridge["q_in"] = tensor.detach().clone() - return tensor + def _cap(cache: dict) -> "object": + def _inner(tensor: torch.Tensor, hook: object) -> torch.Tensor: + cache["v"] = tensor.detach().clone() + return tensor - def cap_ht(tensor, hook): - cache_a_ht["q_in"] = tensor.detach().clone() - return tensor + return _inner - bridge.run_with_hooks(prompt_a, fwd_hooks=[("blocks.0.attn.hook_q_input", cap_bridge)]) - ht.run_with_hooks(prompt_a, fwd_hooks=[("blocks.0.hook_q_input", cap_ht)]) + def _patch(cache: dict) -> "object": + def _inner(tensor: torch.Tensor, hook: object) -> torch.Tensor: + return cache["v"] - def patch_bridge(tensor, hook): - return cache_a_bridge["q_in"] + return _inner - def patch_ht(tensor, hook): - return cache_a_ht["q_in"] + bridge.run_with_hooks(prompt_a, fwd_hooks=[(bridge_hook_path, _cap(cache_a_bridge))]) + ht.run_with_hooks(prompt_a, fwd_hooks=[(ht_hook_path, _cap(cache_a_ht))]) + + captured_diff = (cache_a_bridge["v"] - cache_a_ht["v"]).abs().max().item() + assert captured_diff < capture_tol, ( + f"[{model} no_processing={no_processing}] Bridge {bridge_hook_path} captures " + f"different values than HT {ht_hook_path}: max diff {captured_diff:.3e} " + f"(tol {capture_tol:.0e})" + ) bridge_logits = bridge.run_with_hooks( - prompt_b, fwd_hooks=[("blocks.0.attn.hook_q_input", patch_bridge)] + prompt_b, fwd_hooks=[(bridge_hook_path, _patch(cache_a_bridge))] ) - ht_logits = ht.run_with_hooks(prompt_b, fwd_hooks=[("blocks.0.hook_q_input", patch_ht)]) + ht_logits = ht.run_with_hooks(prompt_b, fwd_hooks=[(ht_hook_path, _patch(cache_a_ht))]) - assert torch.allclose(bridge_logits, ht_logits, atol=1e-4), ( - f"Bridge vs HT cross-run patch logits diverge: max " - f"{(bridge_logits - ht_logits).abs().max().item():.3e}" + baseline = _baseline_logit_diff(model, no_processing, prompt_b) + patched = (bridge_logits - ht_logits).abs().max().item() + assert patched < 10 * max(baseline, 1e-5), ( + f"[{model} no_processing={no_processing}] Bridge vs HT cross-run patch logits " + f"diverge {patched:.3e}, >10x the unhooked baseline {baseline:.3e}" + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("no_processing", _NO_PROCESSING) +@pytest.mark.parametrize("model", _MODELS) +@pytest.mark.parametrize("hook_slot", ["q_input", "k_input", "v_input"]) +@pytest.mark.parametrize("layer", [0, 3]) +def test_split_qkv_cross_run_patch_matches_legacy( + model: str, hook_slot: str, layer: int, no_processing: bool +) -> None: + """Each of Q, K, V at multiple layers, on both native/manual LN paths and folded/unfolded compat.""" + bridge, ht = _build_pair(model, no_processing) + bridge.set_use_split_qkv_input(True) + ht.set_use_split_qkv_input(True) + _cross_run_patch_parity( + model, + no_processing, + bridge_hook_path=f"blocks.{layer}.attn.hook_{hook_slot}", + ht_hook_path=f"blocks.{layer}.hook_{hook_slot}", + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("no_processing", _NO_PROCESSING) +@pytest.mark.parametrize("model", _MODELS) +@pytest.mark.parametrize("layer", [0, 3]) +def test_attn_in_cross_run_patch_matches_legacy( + model: str, layer: int, no_processing: bool +) -> None: + """The shared attn_in fork uses the same captured pre-LN value, separate from split-QKV.""" + bridge, ht = _build_pair(model, no_processing) + bridge.set_use_split_qkv_input(False) + bridge.set_use_attn_in(True) + ht.set_use_split_qkv_input(False) + ht.set_use_attn_in(True) + _cross_run_patch_parity( + model, + no_processing, + bridge_hook_path=f"blocks.{layer}.attn.hook_attn_in", + ht_hook_path=f"blocks.{layer}.hook_attn_in", + ) + + +@pytest.mark.slow +@pytest.mark.skipif( + os.getenv("RUN_OLMO2_GAP_TEST", "") != "1", + reason="Set RUN_OLMO2_GAP_TEST=1 to exercise the OLMo-2 post-norm gap (1B-param download).", +) +@pytest.mark.xfail( + strict=True, + reason="OLMo 2 post-norm: ln1 maps to post_attention_layernorm, so pre-ln1 capture " + "reads post-attention residual. Flip to passing when the carve-out is fixed.", +) +def test_olmo2_pre_ln_capture_known_gap() -> None: + bridge = TransformerBridge.boot_transformers("allenai/OLMo-2-0425-1B", device="cpu") + bridge.enable_compatibility_mode(no_processing=True) + ht = HookedTransformer.from_pretrained_no_processing("allenai/OLMo-2-0425-1B", device="cpu") + bridge.set_use_split_qkv_input(True) + ht.set_use_split_qkv_input(True) + _cross_run_patch_parity( + "allenai/OLMo-2-0425-1B", + True, + bridge_hook_path="blocks.0.attn.hook_q_input", + ht_hook_path="blocks.0.hook_q_input", ) diff --git a/transformer_lens/model_bridge/bridge.py b/transformer_lens/model_bridge/bridge.py index 91e685204..3e178b17e 100644 --- a/transformer_lens/model_bridge/bridge.py +++ b/transformer_lens/model_bridge/bridge.py @@ -703,6 +703,14 @@ def enable_compatibility_mode( post-processed coordinate system: logit lens, direct logit attribution, residual-stream norms. Also enables legacy hook/component name aliases. + Hook semantic parity (issue #1317): ``hook_q_input``, ``hook_k_input``, + ``hook_v_input``, ``hook_attn_in``, and ``hook_mlp_in`` fire on the + pre-norm residual. Carve-outs: post-norm architectures (OLMo 2, + BERT-style) read the post-attention residual instead, and MLA blocks + (DeepSeek V2/V3/R1) do not expose the split-qkv aliases. ``hook_mlp_in`` + is gated on ``cfg.use_hook_mlp_in``; toggle it via + :py:meth:`set_use_hook_mlp_in`. + Args: disable_warnings: Whether to disable warnings about legacy components/hooks no_processing: Whether to disable ALL pre-processing steps of the model. @@ -731,6 +739,11 @@ def set_compatibility_mode(component: Any) -> None: apply_fn_to_all_components(self, set_compatibility_mode) self.clear_hook_registry() + # Drop pre-ln capture handles from any prior call so they don't accumulate. + if hasattr(self, "blocks"): + for block in self.blocks: + if hasattr(block, "_teardown_pre_ln_capture"): + block._teardown_pre_ln_capture() try: if not no_processing: self.process_weights( @@ -3433,6 +3446,23 @@ def set_use_attn_in(self, use_attn_in: bool): self.cfg.use_attn_in = use_attn_in self._propagate_attention_flag("use_attn_in", use_attn_in) + def set_use_hook_mlp_in(self, use_hook_mlp_in: bool) -> None: + """Toggle the pre-ln2 ``hook_mlp_in`` HookPoint, matching legacy semantics. + + See :py:meth:`HookedTransformer.set_use_hook_mlp_in`. + """ + self.cfg.use_hook_mlp_in = use_hook_mlp_in + if not hasattr(self, "blocks"): + return + for block in self.blocks: + block_cfg = getattr(block, "config", None) + if block_cfg is not None and block_cfg is not self.cfg: + try: + block_cfg.use_hook_mlp_in = use_hook_mlp_in + except Exception: + pass + block._use_hook_mlp_in = use_hook_mlp_in + def _propagate_attention_flag(self, flag_name: str, value: bool) -> None: """Mirror `bridge.cfg.` onto every block's attention config. diff --git a/transformer_lens/model_bridge/generalized_components/attention.py b/transformer_lens/model_bridge/generalized_components/attention.py index 89d6203ab..d7b9d657a 100644 --- a/transformer_lens/model_bridge/generalized_components/attention.py +++ b/transformer_lens/model_bridge/generalized_components/attention.py @@ -36,6 +36,10 @@ class AttentionBridge(GeneralizedComponent): "hook_v": "v.hook_out", "hook_z": "o.hook_in", } + + # Override to False on variants without a pre-LN fork (e.g. MLA); skips + # the split-qkv HookPoints and the BlockBridge pre-ln1 capture. + supports_split_qkv_fork: bool = True property_aliases = { "W_Q": "q.weight", "W_K": "k.weight", @@ -101,16 +105,15 @@ def __init__( # by cfg.use_attn_result; the HookPoint exists unconditionally so # run_with_cache key lookups never miss. self.hook_result = HookPoint() - # Independent residual copies feeding Q / K / V (and the shared - # `use_attn_in` fork). Fire at [batch, pos, H, d_model] only when - # cfg.use_split_qkv_input or cfg.use_attn_in is set. Placement is - # post-ln1 — see test_bridge_vs_hooked_transformer_patching.py - # (strict xfail) for the semantic divergence from legacy TL's pre-LN - # fork and the follow-up work it tracks. - self.hook_attn_in = HookPoint() - self.hook_q_input = HookPoint() - self.hook_k_input = HookPoint() - self.hook_v_input = HookPoint() + # Pre-ln1 fork hooks ([B, S, H, D]) gated by use_split_qkv_input / + # use_attn_in; fall back to post-ln1 if BlockBridge can't wire ln1. See #1317. + if self.supports_split_qkv_fork: + self.hook_attn_in = HookPoint() + self.hook_q_input = HookPoint() + self.hook_k_input = HookPoint() + self.hook_v_input = HookPoint() + self._captured_pre_ln_residual: Optional[torch.Tensor] = None + self._ln1_module: Optional[torch.nn.Module] = None if ( hasattr(config, "positional_embedding_type") and config.positional_embedding_type == "rotary" @@ -138,6 +141,27 @@ def set_original_component(self, original_component: torch.nn.Module) -> None: if layer_idx_raw is not None: self._layer_idx = int(layer_idx_raw) + def _apply_ln1_per_head(self, x: torch.Tensor) -> torch.Tensor: + """Apply ln1 to [B, S, H, D] with H folded into the batch. Identity if ln1 unwired. + + Routes through the raw HF norm to avoid refiring ln1's internal hooks + per-head — deliberate divergence from legacy's *Pre sub-hook firing. + """ + if self._ln1_module is None: + return x + b, s, h, d = x.shape + return self._ln1_module(x.reshape(b * s * h, d)).reshape(b, s, h, d) + + def _fork_and_norm_per_head( + self, source: torch.Tensor, hook: HookPoint, n_heads: int + ) -> torch.Tensor: + """Repeat residual to [B, S, H, D], fire ``hook``, re-LN iff source is pre-LN.""" + forked = einops.repeat(source, "b s d -> b s h d", h=n_heads).contiguous() + forked = hook(forked) + if self._captured_pre_ln_residual is not None: + forked = self._apply_ln1_per_head(forked) + return forked + def setup_hook_compatibility(self) -> None: """Setup hook compatibility transformations to match HookedTransformer behavior. @@ -676,7 +700,12 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: ): hooked = hooked.to(dtype=target_dtype) args = (hooked,) + args[1:] - output = self.original_component(*args, **kwargs) + # try/finally so the captured tensor (and its autograd graph) is + # released even if original_component raises. + try: + output = self.original_component(*args, **kwargs) + finally: + self._captured_pre_ln_residual = None if isinstance(output, tuple) and len(output) >= 2: # output[0] is attention output # output[1] may be attention weights (pattern) or position_bias (T5) diff --git a/transformer_lens/model_bridge/generalized_components/block.py b/transformer_lens/model_bridge/generalized_components/block.py index 02845998e..506107781 100644 --- a/transformer_lens/model_bridge/generalized_components/block.py +++ b/transformer_lens/model_bridge/generalized_components/block.py @@ -6,10 +6,12 @@ import inspect import re -from typing import Any, Callable, Dict, Optional +import weakref +from typing import Any, Callable, Dict, Optional, cast import torch +from transformer_lens.hook_points import HookPoint from transformer_lens.model_bridge.exceptions import StopAtLayerException from transformer_lens.model_bridge.generalized_components.base import ( GeneralizedComponent, @@ -35,11 +37,8 @@ class BlockBridge(GeneralizedComponent): """ is_list_item: bool = True - # Block-level aliases matching HookedTransformer's hook path. hook_attn_in / - # hook_q_input / hook_k_input / hook_v_input forward to four *independent* - # HookPoints on the attention bridge (they used to collapse onto the same - # upstream tensor; that bug is gone — each hook now backs a distinct - # residual fork gated by cfg.use_split_qkv_input / cfg.use_attn_in). + # hook_mlp_in is a direct HookPoint on this class (not aliased) so it can + # fire pre-ln2; see __init__. The post-ln2 mlp input stays at block.mlp.hook_in. hook_aliases = { "hook_resid_pre": "hook_in", "hook_resid_mid": "ln2.hook_in", @@ -49,7 +48,6 @@ class BlockBridge(GeneralizedComponent): "hook_q_input": "attn.hook_q_input", "hook_k_input": "attn.hook_k_input", "hook_v_input": "attn.hook_v_input", - "hook_mlp_in": "mlp.hook_in", "hook_mlp_out": "mlp.hook_out", } @@ -105,6 +103,76 @@ def __init__( ) self._original_block_forward: Optional[Callable[..., Any]] = None + self._pre_ln_capture_wired: bool = False + self._pre_ln_capture_handles: list[torch.utils.hooks.RemovableHandle] = [] + # Fallback for _read_use_hook_mlp_in when block.config is None. + self._use_hook_mlp_in: bool = False + # Fires pre-ln2 when use_hook_mlp_in is set. See #1317. + self.hook_mlp_in = HookPoint() + + def _maybe_wire_pre_ln_capture(self) -> None: + """Install ln1/ln2 forward_pre_hooks that feed the bridge's pre-LN hooks (#1317). + + Hooks register on the NormalizationBridge instance, not on + ``original_component`` — the manual (non-native-autograd) bridge + forward never calls the raw module, so a hook there would silently miss + on most adapters. Idempotent. + """ + if self._pre_ln_capture_wired: + return + from transformer_lens.model_bridge.generalized_components.attention import ( + AttentionBridge, + ) + + ln1 = self.submodules.get("ln1") if self.submodules else None + attn = self.submodules.get("attn") if self.submodules else None + if ( + ln1 is not None + and isinstance(attn, AttentionBridge) + and getattr(attn, "supports_split_qkv_fork", False) + and getattr(ln1, "original_component", None) is not None + ): + attn_ref = cast(AttentionBridge, weakref.proxy(attn)) + + def _capture_pre_ln1(_module: torch.nn.Module, args: tuple) -> None: + if args and isinstance(args[0], torch.Tensor): + attn_ref._captured_pre_ln_residual = args[0] + + handle = ln1.register_forward_pre_hook(_capture_pre_ln1) + self._pre_ln_capture_handles.append(handle) + attn._ln1_module = ln1.original_component + + ln2 = self.submodules.get("ln2") if self.submodules else None + if ln2 is not None and getattr(ln2, "original_component", None) is not None: + hook_mlp_in = self.hook_mlp_in + block_ref = weakref.proxy(self) + + def _capture_pre_ln2(_module: torch.nn.Module, args: tuple) -> Any: + if not block_ref._read_use_hook_mlp_in(): + return None + if args and isinstance(args[0], torch.Tensor): + hooked = hook_mlp_in(args[0]) + return (hooked,) + args[1:] + return None + + handle = ln2.register_forward_pre_hook(_capture_pre_ln2) + self._pre_ln_capture_handles.append(handle) + + self._pre_ln_capture_wired = True + + def _teardown_pre_ln_capture(self) -> None: + """Remove the ln1/ln2 forward_pre_hooks installed by _maybe_wire_pre_ln_capture.""" + for handle in self._pre_ln_capture_handles: + handle.remove() + self._pre_ln_capture_handles.clear() + self._pre_ln_capture_wired = False + + def _read_use_hook_mlp_in(self) -> bool: + """Prefer ``block.config.use_hook_mlp_in``; fall back to the block-local flag.""" + cfg = self.config + if cfg is not None and hasattr(cfg, "use_hook_mlp_in"): + return bool(cfg.use_hook_mlp_in) + return self._use_hook_mlp_in def forward(self, *args: Any, **kwargs: Any) -> Any: """Forward pass through the block bridge. @@ -124,6 +192,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: f"Original component not set for {self.name}. Call set_original_component() first." ) + self._maybe_wire_pre_ln_capture() self._check_stop_at_layer(*args, **kwargs) args, kwargs = self._hook_input_hidden_states(args, kwargs) @@ -284,8 +353,9 @@ class MLABlockBridge(BlockBridge): q_a_proj→q_a_layernorm→q_b_proj, and K/V share a joint kv_a_proj_with_mqa entry point. There is no single HookPoint that represents "input that becomes Q/K/V", so the block-level ``hook_q_input``/``hook_k_input``/ - ``hook_v_input`` aliases do not apply. Type-level distinction means a reader - of the adapter sees ``MLABlockBridge`` and knows those hooks are absent. + ``hook_v_input``/``hook_attn_in`` aliases do not apply. Type-level + distinction means a reader of the adapter sees ``MLABlockBridge`` and + knows those hooks are absent. """ def __init__( @@ -303,7 +373,7 @@ def __init__( ) if self.hook_aliases is BlockBridge.hook_aliases: self.hook_aliases = dict(self.hook_aliases) - for alias in ("hook_q_input", "hook_k_input", "hook_v_input"): + for alias in ("hook_q_input", "hook_k_input", "hook_v_input", "hook_attn_in"): self.hook_aliases.pop(alias, None) diff --git a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py index 41a2952eb..6b5487150 100644 --- a/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py +++ b/transformer_lens/model_bridge/generalized_components/joint_qkv_attention.py @@ -340,16 +340,15 @@ def _split_forward_qkv( n_kv_heads = int(getattr(cfg, "n_key_value_heads", None) or n_heads) d_head = int(getattr(cfg, "d_head", 0) or (int(cfg.d_model) // n_heads)) use_split = bool(getattr(cfg, "use_split_qkv_input", False)) + # #1317: fork pre-LN when available so hook patches match legacy. + captured = self._captured_pre_ln_residual + source = captured if captured is not None else hidden_states if use_split: - q_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous() - k_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous() - v_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous() - q_in = self.hook_q_input(q_in) - k_in = self.hook_k_input(k_in) - v_in = self.hook_v_input(v_in) + q_in = self._fork_and_norm_per_head(source, self.hook_q_input, n_heads) + k_in = self._fork_and_norm_per_head(source, self.hook_k_input, n_kv_heads) + v_in = self._fork_and_norm_per_head(source, self.hook_v_input, n_kv_heads) else: - attn_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous() - attn_in = self.hook_attn_in(attn_in) + attn_in = self._fork_and_norm_per_head(source, self.hook_attn_in, n_heads) q_in = attn_in if n_kv_heads != n_heads: k_in = attn_in[..., :n_kv_heads, :].contiguous() diff --git a/transformer_lens/model_bridge/generalized_components/mla_attention.py b/transformer_lens/model_bridge/generalized_components/mla_attention.py index c394f85a3..18a770480 100644 --- a/transformer_lens/model_bridge/generalized_components/mla_attention.py +++ b/transformer_lens/model_bridge/generalized_components/mla_attention.py @@ -63,6 +63,9 @@ class MLAAttentionBridge(PositionEmbeddingHooksMixin, AttentionBridge): "hook_z": "o.hook_in", } + # MLA's forward never forks the residual pre-LN; suppress dead HookPoints. + supports_split_qkv_fork: bool = False + def __init__( self, name: str, diff --git a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py index 97213402e..21b0e6df9 100644 --- a/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py +++ b/transformer_lens/model_bridge/generalized_components/position_embeddings_attention.py @@ -12,7 +12,6 @@ import weakref from typing import Any, Callable, Dict, Optional -import einops import torch import transformers.models.gemma2.modeling_gemma2 as gemma2_module @@ -298,16 +297,15 @@ def forward(self, *args: Any, **kwargs: Any) -> Any: assert self.config is not None # narrowed by `has_head_count` n_heads = int(self.config.n_heads) n_kv_heads = int(getattr(self.config, "n_key_value_heads", None) or n_heads) + # #1317: fork pre-LN when available so hook patches match legacy. + captured = self._captured_pre_ln_residual + source = captured if captured is not None else hidden_states if use_split_qkv: - q_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous() - k_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous() - v_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_kv_heads).contiguous() - q_in = self.hook_q_input(q_in) - k_in = self.hook_k_input(k_in) - v_in = self.hook_v_input(v_in) + q_in = self._fork_and_norm_per_head(source, self.hook_q_input, n_heads) + k_in = self._fork_and_norm_per_head(source, self.hook_k_input, n_kv_heads) + v_in = self._fork_and_norm_per_head(source, self.hook_v_input, n_kv_heads) else: - attn_in = einops.repeat(hidden_states, "b s d -> b s h d", h=n_heads).contiguous() - attn_in = self.hook_attn_in(attn_in) + attn_in = self._fork_and_norm_per_head(source, self.hook_attn_in, n_heads) q_in = attn_in if n_kv_heads != n_heads: k_in = attn_in[..., :n_kv_heads, :].contiguous()