From 378af45b200cecfcd3db27751d6df0ca7c29393d Mon Sep 17 00:00:00 2001 From: Brendan Long Date: Sat, 28 Mar 2026 18:28:55 -0700 Subject: [PATCH] Fix QK norm reshape to match RMSNorm's expected 3D input shape MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit RMSNorm.forward() has a jaxtyping hint expecting a 3D tensor (batch, pos, length), but _apply_qk_norm was reshaping to 2D (batch*pos*heads, d_head), causing a BeartypeCallHintParamViolation on Gemma 3 models. Reshape to 3D (batch*pos, heads, d_head) instead — normalization is identical since RMSNorm only operates on the last dimension. Co-Authored-By: Claude Opus 4.6 (1M context) --- transformer_lens/components/abstract_attention.py | 9 +++++---- transformer_lens/components/grouped_query_attention.py | 5 +++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/transformer_lens/components/abstract_attention.py b/transformer_lens/components/abstract_attention.py index f9af85637..6a21c64c4 100644 --- a/transformer_lens/components/abstract_attention.py +++ b/transformer_lens/components/abstract_attention.py @@ -357,10 +357,11 @@ def _apply_qk_norm( Returns: Normalized tensor with same shape as input """ - # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] - d_head = x.shape[-1] - x_normed = norm_module(x.reshape(-1, d_head)) - return x_normed.reshape(x.shape) + # Reshape from [batch, pos, head_index, d_head] to [batch * pos, head_index, d_head] + # so it matches RMSNorm's expected 3D input shape (batch, pos, length) + batch, pos, n_heads, d_head = x.shape + x_normed = norm_module(x.reshape(batch * pos, n_heads, d_head)) + return x_normed.reshape(batch, pos, n_heads, d_head) def calculate_qkv_matrices( self, diff --git a/transformer_lens/components/grouped_query_attention.py b/transformer_lens/components/grouped_query_attention.py index f305d5329..c6a2bbeff 100644 --- a/transformer_lens/components/grouped_query_attention.py +++ b/transformer_lens/components/grouped_query_attention.py @@ -197,8 +197,9 @@ def _apply_qk_norm( Returns: Normalized tensor with same shape as input """ - # Reshape from [batch, pos, head_index, d_head] to [batch * pos * head_index, d_head] + # Reshape from [batch, pos, head_index, d_head] to [batch * pos, head_index, d_head] + # so it matches RMSNorm's expected 3D input shape (batch, pos, length) batch, pos, n_heads, d_head = x.shape - x_reshaped = x.reshape(-1, d_head) + x_reshaped = x.reshape(batch * pos, n_heads, d_head) x_normed = norm_module(x_reshaped) return x_normed.reshape(batch, pos, n_heads, d_head)