diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index f9af85637..6a21c64c4 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -357,10 +357,11 @@ def _apply_qk_norm( Returns: Normalized tensor with same shape as input """ - # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] - d_head = x.shape[-1] - x_normed = norm_module(x.reshape(-1, d_head)) - return x_normed.reshape(x.shape) + # Reshape from [batch, pos, head_index, d_head] to [batch * pos, head_index, d_head] + # so it matches RMSNorm's expected 3D input shape (batch, pos, length) + batch, pos, n_heads, d_head = x.shape + x_normed = norm_module(x.reshape(batch * pos, n_heads, d_head)) + return x_normed.reshape(batch, pos, n_heads, d_head) def calculate_qkv_matrices( self, diff --git a/transformer_lens/components/grouped_query_attention.py b/transformer_lens/components/grouped_query_attention.py index f305d5329..c6a2bbeff 100644 --- a/transformer_lens/components/grouped_query_attention.py +++ b/transformer_lens/components/grouped_query_attention.py @@ -197,8 +197,9 @@ def _apply_qk_norm( Returns: Normalized tensor with same shape as input """ - # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] + # Reshape from [batch, pos, head_index, d_head] to [batch * pos, head_index, d_head] + # so it matches RMSNorm's expected 3D input shape (batch, pos, length) batch, pos, n_heads, d_head = x.shape - x_reshaped = x.reshape(-1, d_head) + x_reshaped = x.reshape(batch * pos, n_heads, d_head) x_normed = norm_module(x_reshaped) return x_normed.reshape(batch, pos, n_heads, d_head)