From 44348a9f5f3313ba43f59e92bda6c14a6ae0da6a Mon Sep 17 00:00:00 2001 From: MahmoudAshraf97 Date: Wed, 20 May 2026 10:52:45 -0400 Subject: [PATCH] initial commit Co-authored-by: Copilot Signed-off-by: MahmoudAshraf97 --- .../asr/modules/conformer_encoder.py | 70 ++- nemo/collections/asr/parts/mixins/mixins.py | 28 +- .../asr/parts/submodules/conformer_modules.py | 22 +- .../parts/submodules/multi_head_attention.py | 149 ++++++ .../asr/parts/utils/transcribe_utils.py | 2 + tests/collections/asr/test_asr_rope.py | 477 ++++++++++++++++++ 6 files changed, 740 insertions(+), 8 deletions(-) create mode 100644 tests/collections/asr/test_asr_rope.py diff --git a/nemo/collections/asr/modules/conformer_encoder.py b/nemo/collections/asr/modules/conformer_encoder.py index a4ab84c75cb5..d70026915d33 100644 --- a/nemo/collections/asr/modules/conformer_encoder.py +++ b/nemo/collections/asr/modules/conformer_encoder.py @@ -34,6 +34,8 @@ RelPositionalEncoding, RelPositionMultiHeadAttention, RelPositionMultiHeadAttentionLongformer, + RoPEMultiHeadAttention, + RotaryPositionalEncoding, ) from nemo.collections.asr.parts.submodules.subsampling import ( ConvSubsampling, @@ -98,10 +100,16 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin): overlapping chunks. Attention context is determined by att_context_size parameter. 'abs_pos': absolute positional embedding and Transformer + 'rope': + rotary position embedding Default is rel_pos. pos_emb_max_len (int): the maximum length of positional embeddings Defaults to 5000 + rope_base (float): theta base for the rotary position embedding. + Defaults to 10000. + rotary_fraction (float): fraction of the per-head dim to rotate. + Defaults to 1.0. n_heads (int): number of heads in multi-headed attention layers Defaults to 4. att_context_size (List[Union[List[int],int]]): specifies the context sizes on each side. @@ -335,6 +343,8 @@ def __init__( use_pytorch_sdpa: bool = False, use_pytorch_sdpa_backends=None, sync_max_audio_length: bool = True, + rope_base: float = 10000.0, + rotary_fraction: float = 1.0, ): super().__init__() d_ff = d_model * ff_expansion_factor @@ -469,9 +479,17 @@ def __init__( self.pos_enc = PositionalEncoding( d_model=d_model, dropout_rate=dropout_pre_encoder, max_len=pos_emb_max_len, xscale=self.xscale ) + elif self_attention_model == "rope": + self.pos_enc = RotaryPositionalEncoding( + d_k=d_model // n_heads, + rotary_fraction=rotary_fraction, + rope_base=rope_base, + max_len=pos_emb_max_len, + ) else: raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") + layer_pos_enc = self.pos_enc if self_attention_model == 'rope' else None self.layers = nn.ModuleList() for i in range(n_layers): layer = ConformerLayer( @@ -493,6 +511,7 @@ def __init__( use_bias=use_bias, use_pytorch_sdpa=self.use_pytorch_sdpa, use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + pos_enc=layer_pos_enc, ) self.layers.append(layer) @@ -674,7 +693,10 @@ def forward_internal( cache_len = 0 offset = None - audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) + if self.self_attention_model == 'rope': + pos_emb = None + else: + audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) # Create the self-attention and padding masks pad_mask, att_mask = self._create_masks( @@ -733,7 +755,8 @@ def forward_internal( max_audio_length = audio_signal.size(1) # Don't update the audio_signal here because then it will again scale the audio_signal # and cause an increase in the WER - _, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) + if self.self_attention_model != 'rope': + _, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len) pad_mask, att_mask = self._create_masks( att_context_size=cur_att_context_size, padding_length=length, @@ -1130,6 +1153,8 @@ def change_attention_model( att_context_size: List[int] = None, update_config: bool = True, device: torch.device = None, + rope_base: float = None, + rotary_fraction: float = None, ): """ Update the self_attention_model which changes the positional encoding and attention layers. @@ -1147,6 +1172,9 @@ def change_attention_model( 'abs_pos': absolute positional embedding and Transformer + 'rope': + rotary position embedding + If None is provided, the self_attention_model isn't changed. Defaults to None. att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes, or None to keep as it is. Defaults to None. @@ -1154,6 +1182,10 @@ def change_attention_model( Defaults to True. device (torch.device): If provided, new layers will be moved to the device. Defaults to None. + rope_base (float): Theta base for rotary position embedding. Only used when + ``self_attention_model='rope'``. If None, the stored config value is kept. + rotary_fraction (float): Fraction of the per-head dim to rotate. Only used when + ``self_attention_model='rope'``. If None, the stored config value is kept. """ if att_context_size: @@ -1164,6 +1196,11 @@ def change_attention_model( if self_attention_model is None: self_attention_model = self.self_attention_model + if rope_base is None: + rope_base = getattr(self._cfg, 'rope_base', 10000.0) + if rotary_fraction is None: + rotary_fraction = getattr(self._cfg, 'rotary_fraction', 1.0) + if self_attention_model == 'rel_pos_local_attn' and max(att_context_size) <= 0: raise ValueError("When using local attention, context size must be set > 0") @@ -1191,6 +1228,13 @@ def change_attention_model( max_len=self._cfg.pos_emb_max_len, xscale=self.xscale, ) + elif self_attention_model == "rope": + new_pos_enc = RotaryPositionalEncoding( + d_k=self._cfg.d_model // self._cfg.n_heads, + rotary_fraction=rotary_fraction, + rope_base=rope_base, + max_len=self._cfg.pos_emb_max_len, + ) else: raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!") @@ -1236,10 +1280,20 @@ def change_attention_model( use_pytorch_sdpa=self.use_pytorch_sdpa, use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, ) + elif self_attention_model == 'rope': + new_attn = RoPEMultiHeadAttention( + n_head=self._cfg.n_heads, + n_feat=self._cfg.d_model, + dropout_rate=self._cfg.dropout_att, + pos_enc=new_pos_enc, + max_cache_len=att_context_size[0], + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ) else: raise ValueError( f"'{self_attention_model}' is not not a valid value for 'self_attention_model', " - f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']" + f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos', 'rope']" ) if device is not None: new_attn = new_attn.to(device=device) @@ -1252,6 +1306,9 @@ def change_attention_model( with open_dict(self._cfg): self._cfg.self_attention_model = self_attention_model self._cfg.att_context_size = att_context_size + if self_attention_model == 'rope': + self._cfg.rope_base = rope_base + self._cfg.rotary_fraction = rotary_fraction def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): """ @@ -1470,6 +1527,7 @@ class ConformerChangeConfig: 'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using overlapping chunks. Attention context is determined by att_context_size parameter. 'abs_pos': absolute positional embedding and Transformer + 'rope': rotary position embedding """ # If None is provided, self_attention_model is not changed. @@ -1479,3 +1537,9 @@ class ConformerChangeConfig: # corresponding to left and right context, or -1 for full context. # If None is provided, the attention context size isn't changed. att_context_size: Optional[List[int]] = None + + # Rotary position embedding parameters; only used when self_attention_model is + # being set to (or already is) 'rope'. If None, the values from the stored + # config are kept. + rope_base: Optional[float] = None + rotary_fraction: Optional[float] = None diff --git a/nemo/collections/asr/parts/mixins/mixins.py b/nemo/collections/asr/parts/mixins/mixins.py index af973be3cc4c..aa0c418d4c6f 100644 --- a/nemo/collections/asr/parts/mixins/mixins.py +++ b/nemo/collections/asr/parts/mixins/mixins.py @@ -517,7 +517,12 @@ def change_conv_asr_se_context_window(self, context_window: int, update_config: ) def change_attention_model( - self, self_attention_model: str = None, att_context_size: List[int] = None, update_config: bool = True + self, + self_attention_model: str = None, + att_context_size: List[int] = None, + update_config: bool = True, + rope_base: float = None, + rotary_fraction: float = None, ): """ Update the self_attention_model if function is available in encoder. @@ -535,13 +540,20 @@ def change_attention_model( 'abs_pos': absolute positional embedding and Transformer + 'rope': + rotary position embedding + If None is provided, the self_attention_model isn't changed. Defauts to None. att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes, or None to keep as it is. Defauts to None. update_config (bool): Whether to update the config or not with the new attention model. Defaults to True. + rope_base (float): Theta base for rotary position embedding. Only used when + ``self_attention_model='rope'``. If None, the existing value is kept. + rotary_fraction (float): Fraction of the per-head dim to rotate. Only used when + ``self_attention_model='rope'``. If None, the existing value is kept. """ - if self_attention_model is None and att_context_size is None: + if self_attention_model is None and att_context_size is None and rope_base is None and rotary_fraction is None: return if not hasattr(self, 'encoder'): @@ -555,11 +567,21 @@ def change_attention_model( logging.info("Model encoder doesn't have a change_attention_model method ") return - self.encoder.change_attention_model(self_attention_model, att_context_size, update_config, self.device) + self.encoder.change_attention_model( + self_attention_model=self_attention_model, + att_context_size=att_context_size, + update_config=update_config, + device=self.device, + rope_base=rope_base, + rotary_fraction=rotary_fraction, + ) if update_config: with open_dict(self.cfg): self.cfg.encoder.self_attention_model = self_attention_model self.cfg.encoder.att_context_size = att_context_size + if self_attention_model == 'rope': + self.cfg.encoder.rope_base = self.encoder._cfg.rope_base + self.cfg.encoder.rotary_fraction = self.encoder._cfg.rotary_fraction def change_subsampling_conv_chunking_factor( self, subsampling_conv_chunking_factor: int, update_config: bool = True diff --git a/nemo/collections/asr/parts/submodules/conformer_modules.py b/nemo/collections/asr/parts/submodules/conformer_modules.py index b3098ad89ffe..70c7c585c762 100644 --- a/nemo/collections/asr/parts/submodules/conformer_modules.py +++ b/nemo/collections/asr/parts/submodules/conformer_modules.py @@ -24,6 +24,7 @@ MultiHeadAttention, RelPositionMultiHeadAttention, RelPositionMultiHeadAttentionLongformer, + RoPEMultiHeadAttention, ) from nemo.collections.asr.parts.utils.activations import Swish from nemo.collections.common.parts.utils import activation_registry @@ -43,6 +44,7 @@ class ConformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin): 'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using overlapping chunks. Attention context is determined by att_context_size parameter. 'abs_pos': absolute positional embedding and Transformer + 'rope': rotary position embedding Default is rel_pos. global_tokens (int): number of tokens to be used for global attention. Only relevant if self_attention_model is 'rel_pos_local_attn'. @@ -79,6 +81,7 @@ def __init__( use_bias=True, use_pytorch_sdpa=False, use_pytorch_sdpa_backends=None, + pos_enc=None, ): super(ConformerLayer, self).__init__() @@ -144,10 +147,23 @@ def __init__( use_pytorch_sdpa=self.use_pytorch_sdpa, use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, ) + elif self_attention_model == 'rope': + if pos_enc is None: + raise ValueError("'rope' attention requires a RotaryPositionalEncoding via pos_enc.") + self.self_attn = RoPEMultiHeadAttention( + n_head=n_heads, + n_feat=d_model, + dropout_rate=dropout_att, + pos_enc=pos_enc, + max_cache_len=MHA_max_cache_len, + use_bias=use_bias, + use_pytorch_sdpa=self.use_pytorch_sdpa, + use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends, + ) else: raise ValueError( f"'{self_attention_model}' is not not a valid value for 'self_attention_model', " - f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']" + f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos', 'rope']" ) # second feed forward module @@ -181,7 +197,9 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel) elif self.self_attention_model == 'rel_pos_local_attn': x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel) - elif self.self_attention_model == 'abs_pos': + elif self.self_attention_model in ('abs_pos', 'rope'): + # 'rope' rotates Q/K inside MultiHeadAttention via the attached pos_enc, + # so the call site is identical to 'abs_pos' (no additive pos_emb). x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel) else: x = None diff --git a/nemo/collections/asr/parts/submodules/multi_head_attention.py b/nemo/collections/asr/parts/submodules/multi_head_attention.py index 3e6c056bd7b5..3808edaa24f0 100644 --- a/nemo/collections/asr/parts/submodules/multi_head_attention.py +++ b/nemo/collections/asr/parts/submodules/multi_head_attention.py @@ -47,6 +47,8 @@ 'RelPositionMultiHeadAttention', 'RelPositionalEncoding', 'PositionalEncoding', + 'RoPEMultiHeadAttention', + 'RotaryPositionalEncoding', ] INF_VAL = 10000.0 @@ -166,6 +168,7 @@ def forward(self, query, key, value, mask, pos_emb=None, cache=None): # temporary until we solve this more gracefully with avoid_float16_autocast_context(): q, k, v = self.forward_qkv(query, key, value) + q, k = self._apply_pos_emb(q, k) if self.use_pytorch_sdpa: n_batch = value.size(0) @@ -208,6 +211,10 @@ def update_cache(self, key, value, query, cache): cache = torch.cat([cache[:, q_keep_size:, :], query[:, :q_keep_size, :]], dim=1) return key, value, query, cache + def _apply_pos_emb(self, q, k): + """Hook for subclasses to apply a positional transformation to Q/K. No-op by default.""" + return q, k + class RelPositionMultiHeadAttention(MultiHeadAttention): """Multi-Head Attention layer of Transformer-XL with support of relative positional encoding. @@ -990,6 +997,52 @@ def sliding_chunks_matmul_pv(self, prob: torch.Tensor, v: torch.Tensor, w: int): return context.view(bsz, num_heads, seqlen, head_dim).transpose(1, 2) +class RoPEMultiHeadAttention(MultiHeadAttention): + """Multi-Head Attention with rotary position embedding applied to Q and K. + + Args: + n_head (int): number of heads + n_feat (int): size of the features + dropout_rate (float): dropout rate + pos_enc (RotaryPositionalEncoding): rotary position encoding shared across layers + use_bias (bool): whether to remove bias in linear and conv layers + use_pytorch_sdpa (bool): use torch sdpa instead of manual attention + use_pytorch_sdpa_backends list[str]: list of backend names to use in sdpa. + """ + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + pos_enc, + max_cache_len=0, + use_bias=True, + use_pytorch_sdpa=False, + use_pytorch_sdpa_backends=None, + ): + """Construct a RoPEMultiHeadAttention object.""" + super(RoPEMultiHeadAttention, self).__init__( + n_head=n_head, + n_feat=n_feat, + dropout_rate=dropout_rate, + max_cache_len=max_cache_len, + use_bias=use_bias, + use_pytorch_sdpa=use_pytorch_sdpa, + use_pytorch_sdpa_backends=use_pytorch_sdpa_backends, + ) + if pos_enc.d_k != self.d_k: + raise ValueError( + f"RotaryPositionalEncoding d_k ({pos_enc.d_k}) does not match attention " + f"head dim n_feat/n_head ({self.d_k})" + ) + self.pos_enc = pos_enc + + def _apply_pos_emb(self, q, k): + """Rotate Q and K via the attached RotaryPositionalEncoding.""" + return self.pos_enc(q, k) + + class PositionalEncoding(torch.nn.Module): """Fixed sinusoidal positional encoding. Args: @@ -1145,3 +1198,99 @@ def forward(self, x, cache_len=0): if self.dropout_emb: pos_emb = self.dropout_emb(pos_emb) return self.dropout(x), pos_emb + + +class RotaryPositionalEncoding(torch.nn.Module): + """Rotary position embedding. + + Args: + d_k (int): per-head feature dim + rotary_fraction (float): fraction of d_k to rotate + rope_base (float): theta base + max_len (int): maximum input length + """ + + def __init__(self, d_k, rotary_fraction=1.0, rope_base=10000.0, max_len=5000): + """Construct a RotaryPositionalEncoding object.""" + super(RotaryPositionalEncoding, self).__init__() + if not 0 < rotary_fraction <= 1.0: + raise ValueError(f"rotary_fraction must be in (0, 1], got {rotary_fraction}") + d_k_rot = int(d_k * rotary_fraction) + if d_k_rot < 2 or d_k_rot % 2 != 0: + raise ValueError( + f"Effective rotary dim (d_k * rotary_fraction) must be a positive even number, " + f"got {d_k_rot} from d_k={d_k} and rotary_fraction={rotary_fraction}" + ) + self.d_k = d_k + self.d_k_rot = d_k_rot + self.rope_base = rope_base + self.max_len = max_len + + inv_freq = 1.0 / (rope_base ** (torch.arange(0, d_k_rot, 2, dtype=torch.float32) / d_k_rot)) + self.register_buffer('inv_freq', inv_freq, persistent=False) + + def _rotate_half(self, x): + """Split the last dim of x in half and rotate: (x1, x2) -> (-x2, x1).""" + half = x.shape[-1] // 2 + return torch.cat((-x[..., half:], x[..., :half]), dim=-1) + + def _apply_rotary(self, t, cos, sin): + """Apply RoPE to the first d_k_rot dims of the last axis of t.""" + if self.d_k_rot == t.shape[-1]: + return t * cos + self._rotate_half(t) * sin + t_rot = t[..., : self.d_k_rot] + t_pass = t[..., self.d_k_rot :] + t_rot = t_rot * cos + self._rotate_half(t_rot) * sin + return torch.cat((t_rot, t_pass), dim=-1) + + def create_pe(self, positions, dtype): + """Build cos/sin buffers covering ``positions`` (1D fp32 tensor). + + Frequencies are computed in fp32 for numerical stability and cast to + ``dtype`` for storage. The final cast to Q/K runtime dtype happens in + ``forward``. + """ + freqs = torch.outer(positions, self.inv_freq.to(device=positions.device, dtype=torch.float32)) + # Duplicate to align with `_rotate_half`: tail half mirrors the head half. + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype) + sin = emb.sin().to(dtype) + if hasattr(self, 'cos'): + self.cos = cos + self.sin = sin + else: + self.register_buffer('cos', cos, persistent=False) + self.register_buffer('sin', sin, persistent=False) + + def extend_pe(self, length, device, dtype): + """Reset and extend the cos/sin buffers if needed.""" + if hasattr(self, 'cos') and self.cos.size(0) >= length: + return + positions = torch.arange(0, length, dtype=torch.float32, device=device) + self.create_pe(positions=positions, dtype=dtype) + + def forward(self, q, k): + """Rotate Q and K. + + Args: + q (torch.Tensor): (batch, head, time1, d_k) + k (torch.Tensor): (batch, head, time2, d_k); time2 >= time1. + + When time2 > time1 (streaming with KV cache), Q is rotated starting at + ``offset = time2 - time1`` and K from offset 0, so the position difference + seen inside attention scores remains correct. + """ + t_q = q.size(2) + t_k = k.size(2) + cache_len = t_k - t_q + + cos_k = self.cos[:t_k].view(1, 1, t_k, self.d_k_rot) + sin_k = self.sin[:t_k].view(1, 1, t_k, self.d_k_rot) + cos_q = self.cos[cache_len:t_k].view(1, 1, t_q, self.d_k_rot) + sin_q = self.sin[cache_len:t_k].view(1, 1, t_q, self.d_k_rot) + + # Buffers are stored in model dtype; cast to Q/K runtime dtype (which may have + # been upgraded to fp32 by autocast handling in MultiHeadAttention.forward). + q = self._apply_rotary(q, cos_q.to(q.dtype), sin_q.to(q.dtype)) + k = self._apply_rotary(k, cos_k.to(k.dtype), sin_k.to(k.dtype)) + return q, k diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 9c576ce3c093..0999d377dd1e 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -319,6 +319,8 @@ def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel, asr_model.change_attention_model( self_attention_model=cfg.model_change.conformer.get("self_attention_model", None), att_context_size=cfg.model_change.conformer.get("att_context_size", None), + rope_base=cfg.model_change.conformer.get("rope_base", None), + rotary_fraction=cfg.model_change.conformer.get("rotary_fraction", None), ) return asr_model, model_name diff --git a/tests/collections/asr/test_asr_rope.py b/tests/collections/asr/test_asr_rope.py new file mode 100644 index 000000000000..067362a75599 --- /dev/null +++ b/tests/collections/asr/test_asr_rope.py @@ -0,0 +1,477 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch +from omegaconf import OmegaConf + +from nemo.collections.asr.modules.conformer_encoder import ConformerEncoder +from nemo.collections.asr.parts.submodules.multi_head_attention import RoPEMultiHeadAttention, RotaryPositionalEncoding + + +def _build_encoder( + self_attention_model='rope', + n_layers=2, + d_model=64, + n_heads=4, + use_pytorch_sdpa=False, + use_pytorch_sdpa_backends=None, + rotary_fraction=1.0, + rope_base=10000.0, + pos_emb_max_len=256, +): + return ConformerEncoder( + feat_in=80, + n_layers=n_layers, + d_model=d_model, + n_heads=n_heads, + self_attention_model=self_attention_model, + subsampling_factor=4, + subsampling_conv_channels=32, + pos_emb_max_len=pos_emb_max_len, + rotary_fraction=rotary_fraction, + rope_base=rope_base, + use_pytorch_sdpa=use_pytorch_sdpa, + use_pytorch_sdpa_backends=use_pytorch_sdpa_backends, + dropout=0.0, + dropout_att=0.0, + dropout_emb=0.0, + dropout_pre_encoder=0.0, + ).eval() + + +class TestRotaryPositionalEncoding: + @pytest.mark.unit + def test_rejects_invalid_rotary_fraction(self): + with pytest.raises(ValueError): + RotaryPositionalEncoding(d_k=16, rotary_fraction=0.0) + with pytest.raises(ValueError): + RotaryPositionalEncoding(d_k=16, rotary_fraction=1.5) + + @pytest.mark.unit + def test_rejects_odd_effective_dim(self): + # d_k * rotary_fraction = 16 * 0.1875 = 3, which is odd + with pytest.raises(ValueError): + RotaryPositionalEncoding(d_k=16, rotary_fraction=0.1875) + + @pytest.mark.unit + def test_extend_pe_grows_buffers(self): + pe = RotaryPositionalEncoding(d_k=16, max_len=128) + pe.extend_pe(64, device=torch.device('cpu'), dtype=torch.float32) + assert pe.cos.shape == (64, 16) + pe.extend_pe(128, device=torch.device('cpu'), dtype=torch.float32) + assert pe.cos.shape == (128, 16) + # No-op when buffer is already large enough. + prev = pe.cos.data_ptr() + pe.extend_pe(64, device=torch.device('cpu'), dtype=torch.float32) + assert pe.cos.data_ptr() == prev + + @pytest.mark.unit + def test_forward_first_token_is_identity(self): + # Position 0 has zero phase, so cos=1, sin=0 -> rotation is identity. + pe = RotaryPositionalEncoding(d_k=16, rotary_fraction=1.0) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + q = torch.randn(2, 4, 8, 16) + k = torch.randn(2, 4, 8, 16) + q_rot, k_rot = pe(q, k) + assert q_rot.shape == q.shape + assert k_rot.shape == k.shape + assert torch.allclose(q_rot[:, :, 0, :], q[:, :, 0, :], atol=1e-6) + assert torch.allclose(k_rot[:, :, 0, :], k[:, :, 0, :], atol=1e-6) + + @pytest.mark.unit + def test_partial_rotation_leaves_tail_unchanged(self): + pe = RotaryPositionalEncoding(d_k=16, rotary_fraction=0.5) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + q = torch.randn(2, 4, 8, 16) + k = torch.randn(2, 4, 8, 16) + q_rot, k_rot = pe(q, k) + # The last (d_k - d_k_rot) = 8 dims of each head must pass through untouched. + assert torch.allclose(q_rot[..., pe.d_k_rot :], q[..., pe.d_k_rot :]) + assert torch.allclose(k_rot[..., pe.d_k_rot :], k[..., pe.d_k_rot :]) + + @pytest.mark.unit + def test_dot_product_translation_invariance(self): + # The defining property of RoPE: for the same q and k content, + # depends only on the position difference (m - n). Pick two (m, n) pairs + # that share the same difference and assert the dot products agree. + pe = RotaryPositionalEncoding(d_k=16, rotary_fraction=1.0) + pe.extend_pe(64, device=torch.device('cpu'), dtype=torch.float32) + + torch.manual_seed(0) + q_content = torch.randn(1, 1, 1, 16) + k_content = torch.randn(1, 1, 1, 16) + + def dot_at(m, n): + cos_q = pe.cos[m : m + 1].view(1, 1, 1, 16) + sin_q = pe.sin[m : m + 1].view(1, 1, 1, 16) + cos_k = pe.cos[n : n + 1].view(1, 1, 1, 16) + sin_k = pe.sin[n : n + 1].view(1, 1, 1, 16) + q_r = pe._apply_rotary(q_content, cos_q, sin_q) + k_r = pe._apply_rotary(k_content, cos_k, sin_k) + return (q_r * k_r).sum() + + # Three (m, n) pairs with the same difference n - m = 3. + d_a = dot_at(2, 5) + d_b = dot_at(10, 13) + d_c = dot_at(40, 43) + assert torch.allclose(d_a, d_b, atol=1e-5) + assert torch.allclose(d_a, d_c, atol=1e-5) + + # Sanity: a different position difference must yield a different dot product + # (otherwise the rotation is a no-op or degenerate). + d_diff = dot_at(2, 7) # difference 5 + assert not torch.allclose(d_a, d_diff, atol=1e-3) + + @pytest.mark.unit + def test_rotation_is_not_identity(self): + # Confirm RoPE actually mutates Q/K at non-zero positions. + pe = RotaryPositionalEncoding(d_k=16, rotary_fraction=1.0) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + q = torch.randn(1, 1, 8, 16) + k = torch.randn(1, 1, 8, 16) + q_rot, k_rot = pe(q, k) + # Tokens after position 0 must change. + assert not torch.allclose(q_rot[:, :, 1:, :], q[:, :, 1:, :], atol=1e-3) + assert not torch.allclose(k_rot[:, :, 1:, :], k[:, :, 1:, :], atol=1e-3) + + @pytest.mark.unit + def test_norm_preservation(self): + # Rotation is unitary: ||q_rot[..., t, :]||_2 == ||q[..., t, :]||_2 per (batch, head, t). + # Catches scaling bugs in _apply_rotary. + pe = RotaryPositionalEncoding(d_k=16, rotary_fraction=1.0) + pe.extend_pe(64, device=torch.device('cpu'), dtype=torch.float32) + q = torch.randn(2, 4, 16, 16) + k = torch.randn(2, 4, 16, 16) + q_rot, k_rot = pe(q, k) + q_norm_in = torch.linalg.norm(q, dim=-1) + q_norm_out = torch.linalg.norm(q_rot, dim=-1) + k_norm_in = torch.linalg.norm(k, dim=-1) + k_norm_out = torch.linalg.norm(k_rot, dim=-1) + assert torch.allclose(q_norm_in, q_norm_out, atol=1e-5) + assert torch.allclose(k_norm_in, k_norm_out, atol=1e-5) + + @pytest.mark.unit + def test_reference_equivalence(self): + # Slow split-half RoPE reference written in explicit-2D-rotation form + # (no _rotate_half trick, no cat-duplicated cos/sin). Same math as the + # production code expressed via a disjoint code path, so a bug in either + # _rotate_half or the cos/sin layout would surface here. + d_k = 16 + pe = RotaryPositionalEncoding(d_k=d_k, rotary_fraction=1.0) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + + torch.manual_seed(0) + q = torch.randn(1, 1, 8, d_k) + k = torch.randn(1, 1, 8, d_k) + q_rot, k_rot = pe(q, k) + + d_half = d_k // 2 + positions = torch.arange(8, dtype=torch.float32) + theta = positions[:, None] * pe.inv_freq[None, :] # (T, d_half) + c = theta.cos() + s = theta.sin() + + def rope_ref(x): + # Rotate each (x[..., i], x[..., i + d_half]) pair by angle theta[t, i]. + x_a = x[..., :d_half] + x_b = x[..., d_half:] + y_a = x_a * c - x_b * s + y_b = x_a * s + x_b * c + return torch.cat((y_a, y_b), dim=-1) + + assert torch.allclose(q_rot, rope_ref(q), atol=1e-6) + assert torch.allclose(k_rot, rope_ref(k), atol=1e-6) + + @pytest.mark.unit + def test_extend_preserves_existing_positions(self): + # Extending the cos/sin buffers must not change the values at previously + # covered positions, otherwise streaming forward calls would silently + # produce different rotations across the extension boundary. + pe = RotaryPositionalEncoding(d_k=16, max_len=64) + pe.extend_pe(64, device=torch.device('cpu'), dtype=torch.float32) + cos_before = pe.cos[:64].clone() + sin_before = pe.sin[:64].clone() + pe.extend_pe(256, device=torch.device('cpu'), dtype=torch.float32) + assert torch.equal(pe.cos[:64], cos_before) + assert torch.equal(pe.sin[:64], sin_before) + + @pytest.mark.unit + def test_non_contiguous_inputs(self): + # Real-world callers may pass non-contiguous Q/K (e.g. from .transpose()). + # The rotation must produce the same result as on the contiguous version. + pe = RotaryPositionalEncoding(d_k=16, rotary_fraction=1.0) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + + # Build (B, T, H, D) and transpose to (B, H, T, D) -> non-contiguous. + q_btnd = torch.randn(2, 8, 4, 16) + k_btnd = torch.randn(2, 8, 4, 16) + q_nc = q_btnd.transpose(1, 2) + k_nc = k_btnd.transpose(1, 2) + assert not q_nc.is_contiguous() and not k_nc.is_contiguous() + + q_rot_nc, k_rot_nc = pe(q_nc, k_nc) + q_rot_c, k_rot_c = pe(q_nc.contiguous(), k_nc.contiguous()) + assert torch.allclose(q_rot_nc, q_rot_c, atol=1e-6) + assert torch.allclose(k_rot_nc, k_rot_c, atol=1e-6) + + +class TestRoPEMultiHeadAttention: + @pytest.mark.unit + def test_rejects_pos_enc_with_wrong_d_k(self): + # n_feat / n_head = 64 / 4 = 16, but pos_enc was built with d_k=32. + bad_pe = RotaryPositionalEncoding(d_k=32, max_len=64) + with pytest.raises(ValueError): + RoPEMultiHeadAttention(n_head=4, n_feat=64, dropout_rate=0.0, pos_enc=bad_pe) + + @pytest.mark.unit + def test_v_unchanged_by_rotation(self): + # Confirm the rotation hook is called only with (q, k); V must never reach + # the positional encoder. Catches a future regression where someone adds + # V to the rotation hook signature. + pe = RotaryPositionalEncoding(d_k=16, max_len=32) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + attn = RoPEMultiHeadAttention(n_head=4, n_feat=64, dropout_rate=0.0, pos_enc=pe).eval() + + call_args = [] + original_forward = pe.forward + + def spy(q, k): + call_args.append((q.shape, k.shape)) + return original_forward(q, k) + + attn.pos_enc.forward = spy + x = torch.randn(2, 16, 64) + with torch.no_grad(): + _ = attn(query=x, key=x, value=x, mask=None) + assert len(call_args) == 1 + q_shape, k_shape = call_args[0] + # Both tensors have the same length (16); the layout is (B, H, T, d_k). + assert q_shape == (2, 4, 16, 16) + assert k_shape == (2, 4, 16, 16) + + @pytest.mark.unit + def test_backward_smoke(self): + # Forward → loss → backward → every learnable param has a non-NaN, non-zero + # gradient. Mirrors test_transformer_encoder.py::test_backward_pass. + pe = RotaryPositionalEncoding(d_k=16, max_len=32) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + attn = RoPEMultiHeadAttention(n_head=4, n_feat=64, dropout_rate=0.0, pos_enc=pe).train() + x = torch.randn(2, 8, 64, requires_grad=True) + out = attn(query=x, key=x, value=x, mask=None) + loss = out.sum() + loss.backward() + for name, param in attn.named_parameters(): + assert param.grad is not None, f"No gradient for {name}" + assert not torch.isnan(param.grad).any(), f"NaN gradient for {name}" + assert (param.grad != 0).any(), f"All-zero gradient for {name}" + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize( + "backend", ['MATH', 'FLASH_ATTENTION', 'EFFICIENT_ATTENTION', 'CUDNN_ATTENTION'] + ) + def test_sdpa_backend_smoke_gpu(self, backend): + # Each SDPA backend must run with RoPE pre-rotation under bf16 autocast + # (the production training path) without falling back or crashing on + # shape/dtype constraints. FLASH/EFFICIENT/CUDNN require fp16/bf16; + # bf16 satisfies all four. + pe = RotaryPositionalEncoding(d_k=16, max_len=32).to("cuda") + pe.extend_pe(32, device=torch.device('cuda'), dtype=torch.float32) + attn = ( + RoPEMultiHeadAttention( + n_head=4, + n_feat=64, + dropout_rate=0.0, + pos_enc=pe, + use_pytorch_sdpa=True, + use_pytorch_sdpa_backends=[backend], + ) + .to("cuda") + .eval() + ) + x = torch.randn(2, 16, 64, device='cuda') + with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): + out = attn(query=x, key=x, value=x, mask=None) + assert out.shape == (2, 16, 64) + assert torch.isfinite(out).all() + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + def test_autocast_gpu(self): + # Mixed-precision forward (CUDA autocast in bf16) must produce finite output. + # Exercises the interaction between RoPE's .to(q.dtype) cast and the + # avoid_float16_autocast_context wrapper in the base MHA. + pe = RotaryPositionalEncoding(d_k=16, max_len=32).to("cuda") + pe.extend_pe(32, device=torch.device('cuda'), dtype=torch.float32) + attn = RoPEMultiHeadAttention(n_head=4, n_feat=64, dropout_rate=0.0, pos_enc=pe).to("cuda").eval() + x = torch.randn(2, 16, 64, device='cuda') + with torch.no_grad(), torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): + out = attn(query=x, key=x, value=x, mask=None) + assert torch.isfinite(out).all() + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize( + "dtype,atol", + [(torch.float32, 1e-5), (torch.bfloat16, 5e-2), (torch.float16, 1e-2)], + ) + def test_dtype_stability_gpu(self, dtype, atol): + # Forward in low precision must stay close to the fp32 reference. + pe = RotaryPositionalEncoding(d_k=16, max_len=32).to("cuda") + pe.extend_pe(32, device=torch.device('cuda'), dtype=torch.float32) + attn = RoPEMultiHeadAttention(n_head=4, n_feat=64, dropout_rate=0.0, pos_enc=pe).to("cuda").eval() + torch.manual_seed(0) + x = torch.randn(2, 16, 64, device='cuda') + with torch.no_grad(): + out_ref = attn(query=x, key=x, value=x, mask=None) + + # `attn.to(dtype=...)` converts every buffer including pos_enc.cos/sin to `dtype`. + attn_dt = attn.to(dtype=dtype) + x_dt = x.to(dtype=dtype) + with torch.no_grad(): + out_dt = attn_dt(query=x_dt, key=x_dt, value=x_dt, mask=None) + assert torch.isfinite(out_dt).all() + assert torch.allclose(out_dt.float(), out_ref, atol=atol, rtol=atol) + + @pytest.mark.unit + def test_streaming_matches_offline(self): + # The load-bearing test for the cache_len offset logic. Feeding the last + # `new_len` tokens with the first `cache_len` tokens as KV cache must + # reproduce the corresponding slice of the offline forward, because RoPE + # depends only on the (m - n) position difference and the cache layout + # preserves that. + pe = RotaryPositionalEncoding(d_k=16, max_len=64) + pe.extend_pe(32, device=torch.device('cpu'), dtype=torch.float32) + attn = RoPEMultiHeadAttention(n_head=4, n_feat=64, dropout_rate=0.0, pos_enc=pe).eval() + attn.cache_drop_size = 0 # required by update_cache + + torch.manual_seed(7) + full_seq = torch.randn(1, 12, 64) + cache_len = 8 + + with torch.no_grad(): + offline_out = attn(query=full_seq, key=full_seq, value=full_seq, mask=None) + new_query = full_seq[:, cache_len:] + cache = full_seq[:, :cache_len] + streaming_out, _ = attn(query=new_query, key=new_query, value=new_query, mask=None, cache=cache) + + assert torch.allclose(streaming_out, offline_out[:, cache_len:], atol=1e-5) + + +class TestConformerEncoderRoPE: + @pytest.mark.unit + def test_pos_enc_shared_across_layers(self): + # Critical: every layer must hold the same pos_enc instance so that the + # encoder's set_max_audio_length / extend_pe grows the buffers used by + # every layer (not just the first). + enc = _build_encoder() + assert all(layer.self_attn.pos_enc is enc.pos_enc for layer in enc.layers) + # And exercising the shared-extend path: growing the buffer once must be + # visible from every layer. + enc.pos_enc.extend_pe(512, device=torch.device('cpu'), dtype=torch.float32) + assert all(layer.self_attn.pos_enc.cos.size(0) >= 512 for layer in enc.layers) + + @pytest.mark.unit + def test_sdpa_matches_manual(self): + # CPU fp32: SDPA falls back to MATH; verify it matches the manual matmul + # path so RoPE pre-rotation is applied consistently across both code paths. + enc_manual = _build_encoder(use_pytorch_sdpa=False) + enc_sdpa = _build_encoder(use_pytorch_sdpa=True) + enc_sdpa.load_state_dict(enc_manual.state_dict(), strict=False) + x = torch.randn(2, 80, 200) + lens = torch.tensor([200, 150]) + with torch.no_grad(): + o_manual, _ = enc_manual(audio_signal=x, length=lens) + o_sdpa, _ = enc_sdpa(audio_signal=x, length=lens) + assert torch.allclose(o_manual, o_sdpa, atol=1e-4, rtol=1e-4) + + @pytest.mark.run_only_on('GPU') + @pytest.mark.unit + @pytest.mark.parametrize("backend", ['MATH', 'EFFICIENT_ATTENTION', 'CUDNN_ATTENTION']) + def test_sdpa_backend_matches_manual_gpu(self, backend): + # Forward + backward parity vs the manual path under bf16 autocast. + # RoPE applies to Q/K before the SDPA call, so each backend sees the + # same rotated tensors and must agree on outputs and gradients within + # bf16 tolerance. FLASH_ATTENTION is excluded because PyTorch rejects + # any non-null `attn_mask` on the Flash kernel and the encoder always + # emits a padding mask; CUDNN and EFFICIENT both accept bool masks. + # The MHA-level smoke test covers FLASH with mask=None. + enc_manual = _build_encoder(use_pytorch_sdpa=False).to("cuda") + enc_sdpa = _build_encoder(use_pytorch_sdpa=True, use_pytorch_sdpa_backends=[backend]).to("cuda") + enc_sdpa.load_state_dict(enc_manual.state_dict(), strict=False) + + torch.manual_seed(0) + x_base = torch.randn(2, 80, 200, device='cuda') + x_manual = x_base.clone().requires_grad_(True) + x_sdpa = x_base.clone().requires_grad_(True) + lens = torch.tensor([200, 150], device='cuda') + + with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16): + o_manual, _ = enc_manual(audio_signal=x_manual, length=lens) + o_sdpa, _ = enc_sdpa(audio_signal=x_sdpa, length=lens) + + # Forward parity. + assert torch.allclose(o_manual.float(), o_sdpa.float(), atol=5e-2, rtol=5e-2) + + # Backward parity: same loss, compare input grads and weight grads. + o_manual.sum().backward() + o_sdpa.sum().backward() + assert torch.allclose(x_manual.grad.float(), x_sdpa.grad.float(), atol=5e-2, rtol=5e-2) + for (n1, p1), (n2, p2) in zip(enc_manual.named_parameters(), enc_sdpa.named_parameters()): + assert n1 == n2 + assert p1.grad is not None and p2.grad is not None, f"missing grad for {n1}" + assert torch.allclose(p1.grad.float(), p2.grad.float(), atol=5e-2, rtol=5e-2), f"grad mismatch for {n1}" + + @pytest.mark.unit + def test_padding_does_not_leak(self): + # Output for the valid prefix must be invariant to the values in the + # padded suffix. + enc = _build_encoder() + x = torch.randn(1, 80, 200) + valid_len = 120 + x1 = x.clone() + x1[0, :, valid_len:] = torch.randn(80, 200 - valid_len) + x2 = x.clone() + x2[0, :, valid_len:] = torch.randn(80, 200 - valid_len) + lens = torch.tensor([valid_len]) + with torch.no_grad(): + o1, _ = enc(audio_signal=x1, length=lens) + o2, _ = enc(audio_signal=x2, length=lens) + valid_out_len = valid_len // 4 + assert torch.allclose(o1[..., :valid_out_len], o2[..., :valid_out_len], atol=1e-5) + + @pytest.mark.unit + def test_change_attention_model_to_rope(self): + # Build a rel_pos encoder, swap to rope, run forward. + enc = _build_encoder(self_attention_model='rel_pos') + enc._cfg = OmegaConf.create( + { + 'd_model': 64, + 'n_heads': 4, + 'dropout': 0.0, + 'dropout_att': 0.0, + 'dropout_emb': 0.0, + 'pos_emb_max_len': 256, + 'rope_base': 10000.0, + 'rotary_fraction': 1.0, + } + ) + enc.change_attention_model('rope') + assert isinstance(enc.pos_enc, RotaryPositionalEncoding) + assert all(layer.self_attn.pos_enc is enc.pos_enc for layer in enc.layers) + x = torch.randn(2, 80, 200) + lens = torch.tensor([200, 150]) + out, _ = enc(audio_signal=x, length=lens) + assert torch.isfinite(out).all()