Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions nemo/collections/asr/modules/conformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
RelPositionalEncoding,
RelPositionMultiHeadAttention,
RelPositionMultiHeadAttentionLongformer,
RoPEMultiHeadAttention,
RotaryPositionalEncoding,
)
from nemo.collections.asr.parts.submodules.subsampling import (
ConvSubsampling,
Expand Down Expand Up @@ -98,10 +100,16 @@ class ConformerEncoder(NeuralModule, StreamingEncoder, Exportable, AccessMixin):
overlapping chunks. Attention context is determined by att_context_size parameter.
'abs_pos':
absolute positional embedding and Transformer
'rope':
rotary position embedding

Default is rel_pos.
pos_emb_max_len (int): the maximum length of positional embeddings
Defaults to 5000
rope_base (float): theta base for the rotary position embedding.
Defaults to 10000.
rotary_fraction (float): fraction of the per-head dim to rotate.
Defaults to 1.0.
n_heads (int): number of heads in multi-headed attention layers
Defaults to 4.
att_context_size (List[Union[List[int],int]]): specifies the context sizes on each side.
Expand Down Expand Up @@ -335,6 +343,8 @@ def __init__(
use_pytorch_sdpa: bool = False,
use_pytorch_sdpa_backends=None,
sync_max_audio_length: bool = True,
rope_base: float = 10000.0,
rotary_fraction: float = 1.0,
):
super().__init__()
d_ff = d_model * ff_expansion_factor
Expand Down Expand Up @@ -469,9 +479,17 @@ def __init__(
self.pos_enc = PositionalEncoding(
d_model=d_model, dropout_rate=dropout_pre_encoder, max_len=pos_emb_max_len, xscale=self.xscale
)
elif self_attention_model == "rope":
self.pos_enc = RotaryPositionalEncoding(
d_k=d_model // n_heads,
rotary_fraction=rotary_fraction,
rope_base=rope_base,
max_len=pos_emb_max_len,
)
else:
raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

layer_pos_enc = self.pos_enc if self_attention_model == 'rope' else None
self.layers = nn.ModuleList()
for i in range(n_layers):
layer = ConformerLayer(
Expand All @@ -493,6 +511,7 @@ def __init__(
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
pos_enc=layer_pos_enc,
)
self.layers.append(layer)

Expand Down Expand Up @@ -674,7 +693,10 @@ def forward_internal(
cache_len = 0
offset = None

audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
if self.self_attention_model == 'rope':
pos_emb = None
Comment on lines +696 to +697
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: For other attention models (rel_pos, abs_pos), self.pos_enc(x=audio_signal, ...) applies dropout_pre_encoder to the audio signal (via the self.dropout(x) inside RelPositionalEncoding.forward / PositionalEncoding.forward). When self_attention_model == 'rope', this call is skipped entirely, so any configured dropout_pre_encoder > 0 is silently ignored.

This means switching from rel_pos to rope (or using change_attention_model) with a non-zero dropout_pre_encoder will silently change regularization behavior.

Consider applying the dropout separately here, e.g.:

if self.self_attention_model == 'rope':
    pos_emb = None
    if hasattr(self.pos_enc, 'dropout'):
        audio_signal = self.pos_enc.dropout(audio_signal)

Or add an explicit nn.Dropout to RotaryPositionalEncoding / the encoder for this purpose.

else:
audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)

# Create the self-attention and padding masks
pad_mask, att_mask = self._create_masks(
Expand Down Expand Up @@ -733,7 +755,8 @@ def forward_internal(
max_audio_length = audio_signal.size(1)
# Don't update the audio_signal here because then it will again scale the audio_signal
# and cause an increase in the WER
_, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
if self.self_attention_model != 'rope':
_, pos_emb = self.pos_enc(x=audio_signal, cache_len=cache_len)
pad_mask, att_mask = self._create_masks(
att_context_size=cur_att_context_size,
padding_length=length,
Expand Down Expand Up @@ -1130,6 +1153,8 @@ def change_attention_model(
att_context_size: List[int] = None,
update_config: bool = True,
device: torch.device = None,
rope_base: float = None,
rotary_fraction: float = None,
):
"""
Update the self_attention_model which changes the positional encoding and attention layers.
Expand All @@ -1147,13 +1172,20 @@ def change_attention_model(
'abs_pos':
absolute positional embedding and Transformer

'rope':
rotary position embedding

If None is provided, the self_attention_model isn't changed. Defaults to None.
att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes,
or None to keep as it is. Defaults to None.
update_config (bool): Whether to update the config or not with the new attention model.
Defaults to True.
device (torch.device): If provided, new layers will be moved to the device.
Defaults to None.
rope_base (float): Theta base for rotary position embedding. Only used when
``self_attention_model='rope'``. If None, the stored config value is kept.
rotary_fraction (float): Fraction of the per-head dim to rotate. Only used when
``self_attention_model='rope'``. If None, the stored config value is kept.
"""

if att_context_size:
Expand All @@ -1164,6 +1196,11 @@ def change_attention_model(
if self_attention_model is None:
self_attention_model = self.self_attention_model

if rope_base is None:
rope_base = getattr(self._cfg, 'rope_base', 10000.0)
if rotary_fraction is None:
rotary_fraction = getattr(self._cfg, 'rotary_fraction', 1.0)

if self_attention_model == 'rel_pos_local_attn' and max(att_context_size) <= 0:
raise ValueError("When using local attention, context size must be set > 0")

Expand Down Expand Up @@ -1191,6 +1228,13 @@ def change_attention_model(
max_len=self._cfg.pos_emb_max_len,
xscale=self.xscale,
)
elif self_attention_model == "rope":
new_pos_enc = RotaryPositionalEncoding(
d_k=self._cfg.d_model // self._cfg.n_heads,
rotary_fraction=rotary_fraction,
rope_base=rope_base,
max_len=self._cfg.pos_emb_max_len,
)
else:
raise ValueError(f"Not valid self_attention_model: '{self_attention_model}'!")

Expand Down Expand Up @@ -1236,10 +1280,20 @@ def change_attention_model(
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
elif self_attention_model == 'rope':
new_attn = RoPEMultiHeadAttention(
n_head=self._cfg.n_heads,
n_feat=self._cfg.d_model,
dropout_rate=self._cfg.dropout_att,
pos_enc=new_pos_enc,
max_cache_len=att_context_size[0],
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
else:
raise ValueError(
f"'{self_attention_model}' is not not a valid value for 'self_attention_model', "
f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']"
f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos', 'rope']"
)
if device is not None:
new_attn = new_attn.to(device=device)
Expand All @@ -1252,6 +1306,9 @@ def change_attention_model(
with open_dict(self._cfg):
self._cfg.self_attention_model = self_attention_model
self._cfg.att_context_size = att_context_size
if self_attention_model == 'rope':
self._cfg.rope_base = rope_base
self._cfg.rotary_fraction = rotary_fraction

def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int):
"""
Expand Down Expand Up @@ -1470,6 +1527,7 @@ class ConformerChangeConfig:
'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using
overlapping chunks. Attention context is determined by att_context_size parameter.
'abs_pos': absolute positional embedding and Transformer
'rope': rotary position embedding
"""

# If None is provided, self_attention_model is not changed.
Expand All @@ -1479,3 +1537,9 @@ class ConformerChangeConfig:
# corresponding to left and right context, or -1 for full context.
# If None is provided, the attention context size isn't changed.
att_context_size: Optional[List[int]] = None

# Rotary position embedding parameters; only used when self_attention_model is
# being set to (or already is) 'rope'. If None, the values from the stored
# config are kept.
rope_base: Optional[float] = None
rotary_fraction: Optional[float] = None
28 changes: 25 additions & 3 deletions nemo/collections/asr/parts/mixins/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,12 @@ def change_conv_asr_se_context_window(self, context_window: int, update_config:
)

def change_attention_model(
self, self_attention_model: str = None, att_context_size: List[int] = None, update_config: bool = True
self,
self_attention_model: str = None,
att_context_size: List[int] = None,
update_config: bool = True,
rope_base: float = None,
rotary_fraction: float = None,
):
"""
Update the self_attention_model if function is available in encoder.
Expand All @@ -535,13 +540,20 @@ def change_attention_model(
'abs_pos':
absolute positional embedding and Transformer

'rope':
rotary position embedding

If None is provided, the self_attention_model isn't changed. Defauts to None.
att_context_size (List[int]): List of 2 ints corresponding to left and right attention context sizes,
or None to keep as it is. Defauts to None.
update_config (bool): Whether to update the config or not with the new attention model.
Defaults to True.
rope_base (float): Theta base for rotary position embedding. Only used when
``self_attention_model='rope'``. If None, the existing value is kept.
rotary_fraction (float): Fraction of the per-head dim to rotate. Only used when
``self_attention_model='rope'``. If None, the existing value is kept.
"""
if self_attention_model is None and att_context_size is None:
if self_attention_model is None and att_context_size is None and rope_base is None and rotary_fraction is None:
return

if not hasattr(self, 'encoder'):
Expand All @@ -555,11 +567,21 @@ def change_attention_model(
logging.info("Model encoder doesn't have a change_attention_model method ")
return

self.encoder.change_attention_model(self_attention_model, att_context_size, update_config, self.device)
self.encoder.change_attention_model(
self_attention_model=self_attention_model,
att_context_size=att_context_size,
update_config=update_config,
device=self.device,
rope_base=rope_base,
rotary_fraction=rotary_fraction,
)
if update_config:
with open_dict(self.cfg):
self.cfg.encoder.self_attention_model = self_attention_model
self.cfg.encoder.att_context_size = att_context_size
if self_attention_model == 'rope':
self.cfg.encoder.rope_base = self.encoder._cfg.rope_base
self.cfg.encoder.rotary_fraction = self.encoder._cfg.rotary_fraction

def change_subsampling_conv_chunking_factor(
self, subsampling_conv_chunking_factor: int, update_config: bool = True
Expand Down
22 changes: 20 additions & 2 deletions nemo/collections/asr/parts/submodules/conformer_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MultiHeadAttention,
RelPositionMultiHeadAttention,
RelPositionMultiHeadAttentionLongformer,
RoPEMultiHeadAttention,
)
from nemo.collections.asr.parts.utils.activations import Swish
from nemo.collections.common.parts.utils import activation_registry
Expand All @@ -43,6 +44,7 @@ class ConformerLayer(torch.nn.Module, AttentionAdapterModuleMixin, AccessMixin):
'rel_pos_local_attn': relative positional embedding and Transformer-XL with local attention using
overlapping chunks. Attention context is determined by att_context_size parameter.
'abs_pos': absolute positional embedding and Transformer
'rope': rotary position embedding
Default is rel_pos.
global_tokens (int): number of tokens to be used for global attention.
Only relevant if self_attention_model is 'rel_pos_local_attn'.
Expand Down Expand Up @@ -79,6 +81,7 @@ def __init__(
use_bias=True,
use_pytorch_sdpa=False,
use_pytorch_sdpa_backends=None,
pos_enc=None,
):
super(ConformerLayer, self).__init__()

Expand Down Expand Up @@ -144,10 +147,23 @@ def __init__(
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
elif self_attention_model == 'rope':
if pos_enc is None:
raise ValueError("'rope' attention requires a RotaryPositionalEncoding via pos_enc.")
self.self_attn = RoPEMultiHeadAttention(
n_head=n_heads,
n_feat=d_model,
dropout_rate=dropout_att,
pos_enc=pos_enc,
max_cache_len=MHA_max_cache_len,
use_bias=use_bias,
use_pytorch_sdpa=self.use_pytorch_sdpa,
use_pytorch_sdpa_backends=self.use_pytorch_sdpa_backends,
)
else:
raise ValueError(
f"'{self_attention_model}' is not not a valid value for 'self_attention_model', "
f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos']"
f"valid values can be from ['rel_pos', 'rel_pos_local_attn', 'abs_pos', 'rope']"
)

# second feed forward module
Expand Down Expand Up @@ -181,7 +197,9 @@ def forward(self, x, att_mask=None, pos_emb=None, pad_mask=None, cache_last_chan
x = self.self_attn(query=x, key=x, value=x, mask=att_mask, pos_emb=pos_emb, cache=cache_last_channel)
elif self.self_attention_model == 'rel_pos_local_attn':
x = self.self_attn(query=x, key=x, value=x, pad_mask=pad_mask, pos_emb=pos_emb, cache=cache_last_channel)
elif self.self_attention_model == 'abs_pos':
elif self.self_attention_model in ('abs_pos', 'rope'):
# 'rope' rotates Q/K inside MultiHeadAttention via the attached pos_enc,
# so the call site is identical to 'abs_pos' (no additive pos_emb).
x = self.self_attn(query=x, key=x, value=x, mask=att_mask, cache=cache_last_channel)
else:
x = None
Expand Down
Loading
Loading