Commit d0fa6cc

Weight Processing Bugs for larger models (GPT OSS, Gemma3 Multimodal, Phi 3) (#1224)

* Fix small flaws in the updated position embeddings attention, and bugs introduced into HookedTransformer by transformers v5
* Gemma3 & base Gemma weight processing bug fixes
* Additional weight processing bug fixes for Mistral and Phi3

1 parent 2c41b6c commit d0fa6cc

4 files changed: 93 additions & 27 deletions

transformer_lens/benchmarks/granular_weight_processing.py (14 additions & 14 deletions)

```diff
@@ -370,20 +370,6 @@ def run_granular_weight_processing_benchmarks(
             for key, value in forward_hooks_result.details.items():
                 print(f"  {key}: {value}")
 
-            # Clean up
-            del bridge
-            del ht_ref
-            # Force garbage collection (multiple passes to break circular references)
-            import gc
-
-            for _ in range(3):
-                gc.collect()
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-            if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
-                torch.mps.synchronize()
-                torch.mps.empty_cache()
-
         except Exception as e:
             # Record failure
             results.append(
@@ -395,6 +381,20 @@ def run_granular_weight_processing_benchmarks(
                     details={"error": str(e), "config": str(config)},
                 )
             )
+        finally:
+            # Always clean up models after each config (success or failure)
+            # to prevent memory leaks on large models
+            import gc
+
+            bridge = None  # type: ignore[assignment]
+            ht_ref = None  # type: ignore[assignment]
+            for _ in range(3):
+                gc.collect()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            if hasattr(torch, "mps") and hasattr(torch.mps, "empty_cache"):
+                torch.mps.synchronize()
+                torch.mps.empty_cache()
 
         # Store results
         all_results[config.name] = results
```
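The substance of this change: the old cleanup ran at the end of the `try` body, so a failing config skipped it entirely and leaked the models. Moving it into `finally` guarantees it runs on both paths, and rebinding the names to `None` (rather than `del`) avoids an `UnboundLocalError` when the failure happened before the models were ever assigned. A minimal, self-contained sketch of the pattern (`run_config` and its stand-in model are illustrative, not the actual benchmark code):

```python
import gc

import torch


def run_config(fail: bool) -> None:
    model = None  # pre-bind so the finally block is always safe
    try:
        model = torch.nn.Linear(1024, 1024)  # stand-in for a large model
        if fail:
            raise RuntimeError("benchmark failed")
    finally:
        # Runs on success *and* failure; dropping the last reference lets
        # the garbage collector reclaim the weights. `del model` here would
        # raise UnboundLocalError if the assignment above never executed.
        model = None
        for _ in range(3):  # multiple passes to break circular references
            gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


try:
    run_config(fail=True)
except RuntimeError:
    pass  # cleanup in the finally block already ran before the error propagated
```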

transformer_lens/loading_from_pretrained.py (0 additions & 2 deletions)

```diff
@@ -71,7 +71,6 @@
     "Qwen/Qwen-",
     "Qwen/Qwen3-",
     "microsoft/phi-2",
-    "microsoft/Phi-3-mini-4k-instruct",
     "microsoft/phi-4",
     "apple/OpenELM",
     "openai/gpt-oss-",
@@ -861,7 +860,6 @@ def convert_hf_model_config(model_name: str, **kwargs: Any) -> dict[str, Any]:
         "initializer_range": hf_config.initializer_range,
         "normalization_type": "RMS",
         "positional_embedding_type": "rotary",
-        "trust_remote_code": True,
         "rotary_base": _get_rope_theta(hf_config),
         "use_attn_scale": True,
         "gated_mlp": True,
```
transformer_lens/model_bridge/supported_architectures/mistral.py (2 additions & 2 deletions)

```diff
@@ -11,8 +11,8 @@
     AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
+    GatedMLPBridge,
     LinearBridge,
-    MLPBridge,
     RMSNormalizationBridge,
     RotaryEmbeddingBridge,
     UnembeddingBridge,
@@ -83,7 +83,7 @@ def __init__(self, cfg: Any) -> None:
                     "o": LinearBridge(name="o_proj"),
                 },
             ),
-            "mlp": MLPBridge(
+            "mlp": GatedMLPBridge(
                 name="mlp",
                 submodules={
                     "gate": LinearBridge(name="gate_proj"),
```

transformer_lens/model_bridge/supported_architectures/phi3.py (77 additions & 9 deletions)

```diff
@@ -13,10 +13,10 @@
 )
 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 from transformer_lens.model_bridge.generalized_components import (
-    AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
     GatedMLPBridge,
+    JointQKVPositionEmbeddingsAttentionBridge,
     LinearBridge,
     RMSNormalizationBridge,
     UnembeddingBridge,
@@ -94,24 +94,19 @@ def __init__(self, cfg: Any) -> None:
             submodules={
                 "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                 "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
-                "attn": AttentionBridge(
+                "attn": JointQKVPositionEmbeddingsAttentionBridge(
                     name="self_attn",
                     config=self.cfg,
-                    requires_position_embeddings=True,
-                    requires_attention_mask=True,
+                    split_qkv_matrix=self._split_phi3_qkv,
                     submodules={
-                        # Phi-3 uses combined qkv_proj, but we still need submodules for hooks
-                        "q": LinearBridge(name="qkv_proj"),
-                        "k": LinearBridge(name="qkv_proj"),
-                        "v": LinearBridge(name="qkv_proj"),
+                        "qkv": LinearBridge(name="qkv_proj"),
                         "o": LinearBridge(name="o_proj"),
                     },
                 ),
                 "mlp": GatedMLPBridge(
                     name="mlp",
                     config=self.cfg,
                     submodules={
-                        # Phi-3 uses joint gate_up_proj, but we need submodules for hooks
                         "gate": LinearBridge(name="gate_up_proj"),
                         "in": LinearBridge(name="gate_up_proj"),
                         "out": LinearBridge(name="down_proj"),
@@ -123,6 +118,79 @@ def __init__(self, cfg: Any) -> None:
             "unembed": UnembeddingBridge(name="lm_head"),
         }
 
+    @staticmethod
+    def _split_gate_up(
+        original_mlp_component: Any,
+    ) -> tuple[torch.nn.Module, torch.nn.Module]:
+        """Split Phi-3's fused gate_up_proj into separate gate and up Linear modules."""
+        fused_weight = original_mlp_component.gate_up_proj.weight
+        gate_w, up_w = torch.tensor_split(fused_weight, 2, dim=0)
+        d_model = fused_weight.shape[1]
+        d_mlp = gate_w.shape[0]
+
+        has_bias = (
+            hasattr(original_mlp_component.gate_up_proj, "bias")
+            and original_mlp_component.gate_up_proj.bias is not None
+        )
+        gate_b: torch.Tensor | None
+        up_b: torch.Tensor | None
+        if has_bias:
+            gate_b, up_b = torch.tensor_split(original_mlp_component.gate_up_proj.bias, 2, dim=0)
+        else:
+            gate_b = up_b = None
+
+        gate_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
+        gate_proj.weight = torch.nn.Parameter(gate_w)
+        if gate_b is not None:
+            gate_proj.bias = torch.nn.Parameter(gate_b)
+
+        up_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
+        up_proj.weight = torch.nn.Parameter(up_w)
+        if up_b is not None:
+            up_proj.bias = torch.nn.Parameter(up_b)
+
+        return gate_proj, up_proj
+
+    def _split_phi3_qkv(
+        self, original_attention_component: Any
+    ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
+        """Split Phi-3's fused qkv_proj into separate Q, K, V linear modules."""
+        qkv_weight = original_attention_component.qkv_proj.weight
+        d_model = qkv_weight.shape[1]
+        # Phi-3 QKV is [3*n_heads*d_head, d_model], split into equal thirds
+        q_weight, k_weight, v_weight = torch.tensor_split(qkv_weight, 3, dim=0)
+
+        has_bias = (
+            hasattr(original_attention_component.qkv_proj, "bias")
+            and original_attention_component.qkv_proj.bias is not None
+        )
+        q_bias: torch.Tensor | None
+        k_bias: torch.Tensor | None
+        v_bias: torch.Tensor | None
+        if has_bias:
+            q_bias, k_bias, v_bias = torch.tensor_split(
+                original_attention_component.qkv_proj.bias, 3, dim=0
+            )
+        else:
+            q_bias = k_bias = v_bias = None
+
+        q_linear = torch.nn.Linear(d_model, q_weight.shape[0], bias=has_bias)
+        q_linear.weight = torch.nn.Parameter(q_weight)
+        if q_bias is not None:
+            q_linear.bias = torch.nn.Parameter(q_bias)
+
+        k_linear = torch.nn.Linear(d_model, k_weight.shape[0], bias=has_bias)
+        k_linear.weight = torch.nn.Parameter(k_weight)
+        if k_bias is not None:
+            k_linear.bias = torch.nn.Parameter(k_bias)
+
+        v_linear = torch.nn.Linear(d_model, v_weight.shape[0], bias=has_bias)
+        v_linear.weight = torch.nn.Parameter(v_weight)
+        if v_bias is not None:
+            v_linear.bias = torch.nn.Parameter(v_bias)
+
+        return q_linear, k_linear, v_linear
+
     def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
         """Fix compatibility issues for Phi-3 models with trust_remote_code=True.
```
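A standalone way to sanity-check this kind of split (not part of the commit): the rows of Phi-3's fused weight are laid out `[Q; K; V]` along the output dimension, so the split modules, concatenated, must reproduce the fused projection exactly. The same check applies to `_split_gate_up`, whose assumed fused layout is `[gate; up]`. A minimal sketch with arbitrary dimensions:

```python
import torch

d_model, n_heads, d_head = 16, 4, 4
fused = torch.nn.Linear(d_model, 3 * n_heads * d_head, bias=False)

# Mirror _split_phi3_qkv: split the fused weight into equal thirds
# along the output (row) dimension and wrap each third in its own Linear.
splits = []
for w in torch.tensor_split(fused.weight, 3, dim=0):
    lin = torch.nn.Linear(d_model, w.shape[0], bias=False)
    lin.weight = torch.nn.Parameter(w)
    splits.append(lin)
q, k, v = splits

x = torch.randn(2, d_model)
recombined = torch.cat([q(x), k(x), v(x)], dim=-1)
assert torch.allclose(fused(x), recombined)  # row-slice equivalence holds
```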
