 )
 from transformer_lens.model_bridge.architecture_adapter import ArchitectureAdapter
 from transformer_lens.model_bridge.generalized_components import (
-    AttentionBridge,
     BlockBridge,
     EmbeddingBridge,
     GatedMLPBridge,
+    JointQKVPositionEmbeddingsAttentionBridge,
     LinearBridge,
     RMSNormalizationBridge,
     UnembeddingBridge,
@@ -94,24 +94,19 @@ def __init__(self, cfg: Any) -> None:
                 submodules={
                     "ln1": RMSNormalizationBridge(name="input_layernorm", config=self.cfg),
                     "ln2": RMSNormalizationBridge(name="post_attention_layernorm", config=self.cfg),
-                    "attn": AttentionBridge(
+                    "attn": JointQKVPositionEmbeddingsAttentionBridge(
                         name="self_attn",
                         config=self.cfg,
-                        requires_position_embeddings=True,
-                        requires_attention_mask=True,
+                        split_qkv_matrix=self._split_phi3_qkv,
                         submodules={
-                            # Phi-3 uses combined qkv_proj, but we still need submodules for hooks
-                            "q": LinearBridge(name="qkv_proj"),
-                            "k": LinearBridge(name="qkv_proj"),
-                            "v": LinearBridge(name="qkv_proj"),
+                            "qkv": LinearBridge(name="qkv_proj"),
                             "o": LinearBridge(name="o_proj"),
                         },
                     ),
                     "mlp": GatedMLPBridge(
                         name="mlp",
                         config=self.cfg,
                         submodules={
-                            # Phi-3 uses joint gate_up_proj, but we need submodules for hooks
                             "gate": LinearBridge(name="gate_up_proj"),
                             "in": LinearBridge(name="gate_up_proj"),
                             "out": LinearBridge(name="down_proj"),
@@ -123,6 +118,79 @@ def __init__(self, cfg: Any) -> None:
123118 "unembed" : UnembeddingBridge (name = "lm_head" ),
124119 }
125120
+    @staticmethod
+    def _split_gate_up(
+        original_mlp_component: Any,
+    ) -> tuple[torch.nn.Module, torch.nn.Module]:
+        """Split Phi-3's fused gate_up_proj into separate gate and up Linear modules."""
+        fused_weight = original_mlp_component.gate_up_proj.weight
+        gate_w, up_w = torch.tensor_split(fused_weight, 2, dim=0)
+        d_model = fused_weight.shape[1]
+        d_mlp = gate_w.shape[0]
+
+        has_bias = (
+            hasattr(original_mlp_component.gate_up_proj, "bias")
+            and original_mlp_component.gate_up_proj.bias is not None
+        )
+        gate_b: torch.Tensor | None
+        up_b: torch.Tensor | None
+        if has_bias:
+            gate_b, up_b = torch.tensor_split(original_mlp_component.gate_up_proj.bias, 2, dim=0)
+        else:
+            gate_b = up_b = None
+
+        gate_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
+        gate_proj.weight = torch.nn.Parameter(gate_w)
+        if gate_b is not None:
+            gate_proj.bias = torch.nn.Parameter(gate_b)
+
+        up_proj = torch.nn.Linear(d_model, d_mlp, bias=has_bias)
+        up_proj.weight = torch.nn.Parameter(up_w)
+        if up_b is not None:
+            up_proj.bias = torch.nn.Parameter(up_b)
+
+        return gate_proj, up_proj
+
+    def _split_phi3_qkv(
+        self, original_attention_component: Any
+    ) -> tuple[torch.nn.Module, torch.nn.Module, torch.nn.Module]:
+        """Split Phi-3's fused qkv_proj into separate Q, K, V linear modules."""
+        qkv_weight = original_attention_component.qkv_proj.weight
+        d_model = qkv_weight.shape[1]
+        # Phi-3 QKV is [3*n_heads*d_head, d_model], split into equal thirds
+        q_weight, k_weight, v_weight = torch.tensor_split(qkv_weight, 3, dim=0)
+
+        has_bias = (
+            hasattr(original_attention_component.qkv_proj, "bias")
+            and original_attention_component.qkv_proj.bias is not None
+        )
+        q_bias: torch.Tensor | None
+        k_bias: torch.Tensor | None
+        v_bias: torch.Tensor | None
+        if has_bias:
+            q_bias, k_bias, v_bias = torch.tensor_split(
+                original_attention_component.qkv_proj.bias, 3, dim=0
+            )
+        else:
+            q_bias = k_bias = v_bias = None
+
+        q_linear = torch.nn.Linear(d_model, q_weight.shape[0], bias=has_bias)
+        q_linear.weight = torch.nn.Parameter(q_weight)
+        if q_bias is not None:
+            q_linear.bias = torch.nn.Parameter(q_bias)
+
+        k_linear = torch.nn.Linear(d_model, k_weight.shape[0], bias=has_bias)
+        k_linear.weight = torch.nn.Parameter(k_weight)
+        if k_bias is not None:
+            k_linear.bias = torch.nn.Parameter(k_bias)
+
+        v_linear = torch.nn.Linear(d_model, v_weight.shape[0], bias=has_bias)
+        v_linear.weight = torch.nn.Parameter(v_weight)
+        if v_bias is not None:
+            v_linear.bias = torch.nn.Parameter(v_bias)
+
+        return q_linear, k_linear, v_linear
+
     def prepare_loading(self, model_name: str, model_kwargs: dict) -> None:
         """Fix compatibility issues for Phi-3 models with trust_remote_code=True.
 
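
Not part of the diff: a minimal standalone sketch of the splitting idea used by `_split_phi3_qkv` above, with made-up dimensions. It slices a fused QKV weight into equal thirds along the output dimension, wraps each third in its own `Linear`, and checks that the three modules reproduce the corresponding slices of the fused projection's output.

```python
# Standalone sketch (illustrative shapes, not the adapter's real config):
# verify that splitting a fused QKV projection into three Linear modules
# preserves the computation.
import torch

d_model, n_heads, d_head = 64, 4, 16
fused = torch.nn.Linear(d_model, 3 * n_heads * d_head, bias=False)

# Split the fused weight into equal thirds along the output dimension.
q_w, k_w, v_w = torch.tensor_split(fused.weight, 3, dim=0)

q = torch.nn.Linear(d_model, q_w.shape[0], bias=False)
k = torch.nn.Linear(d_model, k_w.shape[0], bias=False)
v = torch.nn.Linear(d_model, v_w.shape[0], bias=False)
q.weight = torch.nn.Parameter(q_w)
k.weight = torch.nn.Parameter(k_w)
v.weight = torch.nn.Parameter(v_w)

x = torch.randn(2, 5, d_model)
q_ref, k_ref, v_ref = torch.tensor_split(fused(x), 3, dim=-1)

# Each split module should match the corresponding slice of the fused output.
assert torch.allclose(q(x), q_ref)
assert torch.allclose(k(x), k_ref)
assert torch.allclose(v(x), v_ref)
```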