diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index f2af64cb8974..75951f9ae690 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -18,6 +18,9 @@ import torch.nn as nn from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from nemo.core.classes.module import freeze, unfreeze +from nemo.utils.decorators import experimental + flex_attention_compiled = torch.compile(flex_attention, dynamic=True) @@ -151,6 +154,7 @@ def forward(self, x, block_mask=None): return x +@experimental class TransformerEncoder(nn.Module): """Pre-norm Transformer encoder for ASR. @@ -239,3 +243,9 @@ def forward(self, audio_signal, length): x = self.final_norm(x) x = x.transpose(1, 2) # (B, T, D) -> (B, D, T) return x, length + + def freeze(self) -> None: + freeze(self) + + def unfreeze(self, partial: bool = False) -> None: + unfreeze(self, partial=partial) diff --git a/nemo/core/classes/module.py b/nemo/core/classes/module.py index ef80467c8c7a..39d20b2f39ae 100644 --- a/nemo/core/classes/module.py +++ b/nemo/core/classes/module.py @@ -20,7 +20,50 @@ from nemo.core.classes.common import FileIO, Serialization, Typing from nemo.utils import logging -__all__ = ['NeuralModule'] +__all__ = ['NeuralModule', 'freeze', 'unfreeze'] + + +def freeze(module: Module) -> None: + """Freeze all parameters of ``module`` and snapshot their prior ``requires_grad`` state. + + The snapshot is stored on ``module._frozen_grad_map`` so a later call to ``unfreeze(..., partial=True)`` + can restore the pre-freeze state instead of unconditionally enabling gradients. + """ + grad_map = {pname: param.requires_grad for pname, param in module.named_parameters()} + for param in module.parameters(): + param.requires_grad = False + if not hasattr(module, '_frozen_grad_map'): + module._frozen_grad_map = grad_map + else: + module._frozen_grad_map.update(grad_map) + module.eval() + + +def unfreeze(module: Module, partial: bool = False) -> None: + """Unfreeze parameters of ``module``. + + If ``partial=True``, restore each parameter's ``requires_grad`` from the snapshot recorded by + ``freeze(module)``; otherwise enable gradients on every parameter. The snapshot is cleared in + both cases and ``module.train()`` is called. + """ + if partial and not hasattr(module, '_frozen_grad_map'): + raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") + + for pname, param in module.named_parameters(): + if not partial: + param.requires_grad = True + elif pname in module._frozen_grad_map: + param.requires_grad = module._frozen_grad_map[pname] + else: + logging.warning( + f"Parameter {pname} not found in list of previously frozen parameters. Unfreezing this parameter." + ) + param.requires_grad = True + + if hasattr(module, '_frozen_grad_map'): + delattr(module, '_frozen_grad_map') + + module.train() class NeuralModule(Module, Typing, Serialization, FileIO): @@ -53,89 +96,20 @@ def input_example(self, max_batch=None, max_dim=None): return None def freeze(self) -> None: - r""" - Freeze all params for inference. - - This method sets `requires_grad` to False for all parameters of the module. - It also stores the original `requires_grad` state of each parameter in a dictionary, - so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`. - """ - grad_map = {} - - for pname, param in self.named_parameters(): - # Store the original grad state - grad_map[pname] = param.requires_grad - # Freeze the parameter - param.requires_grad = False - - # Store the frozen grad map - if not hasattr(self, '_frozen_grad_map'): - self._frozen_grad_map = grad_map - else: - self._frozen_grad_map.update(grad_map) - - self.eval() + r"""Freeze all params for inference. See :func:`freeze` for details.""" + freeze(self) def unfreeze(self, partial: bool = False) -> None: - """ - Unfreeze all parameters for training. - - Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). - The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were - previously unfrozen prior `freeze()`. + """Unfreeze parameters for training. See :func:`unfreeze` for details. Example: - Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always. - - ```python - model.encoder.freeze() # Freezes all parameters in the encoder explicitly - ``` - - During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method. - This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called, - we should keep the encoder parameters frozen. - ```python - model.freeze() # Freezes all parameters in the model; encoder remains frozen + model.encoder.freeze() # caller freezes encoder + model.freeze() # freezes everything; encoder snapshot preserved + model.unfreeze(partial=True) # decoder unfrozen, encoder stays frozen ``` - - Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling - `unfreeze(partial=True)`. - - ```python - model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen - ``` - - Args: - partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen - when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`. """ - if partial and not hasattr(self, '_frozen_grad_map'): - raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") - - for pname, param in self.named_parameters(): - if not partial: - # Unfreeze all parameters - param.requires_grad = True - else: - # Unfreeze only parameters that were previously frozen - - # Check if the parameter was frozen - if pname in self._frozen_grad_map: - param.requires_grad = self._frozen_grad_map[pname] - else: - # Log a warning if the parameter was not found in the frozen grad map - logging.warning( - f"Parameter {pname} not found in list of previously frozen parameters. " - f"Unfreezing this parameter." - ) - param.requires_grad = True - - # Clean up the frozen grad map - if hasattr(self, '_frozen_grad_map'): - delattr(self, '_frozen_grad_map') - - self.train() + unfreeze(self, partial=partial) @contextmanager def as_frozen(self): @@ -143,9 +117,9 @@ def as_frozen(self): Context manager which temporarily freezes a module, yields control and finally unfreezes the module partially to return to original state. - Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). - The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were - previously unfrozen prior `freeze()`. + Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen + previously with `freeze()`). The `partial` argument is used to determine whether to unfreeze + all parameters or only the parameters that were previously unfrozen prior `freeze()`. Example: with model.as_frozen(): # by default, partial = True diff --git a/tests/collections/asr/test_transformer_encoder.py b/tests/collections/asr/test_transformer_encoder.py index 0cc2f174a1e5..b548ef96edf3 100644 --- a/tests/collections/asr/test_transformer_encoder.py +++ b/tests/collections/asr/test_transformer_encoder.py @@ -126,6 +126,21 @@ def test_invalid_attn_mode(self): with pytest.raises(ValueError, match="not yet supported"): TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, attn_mode="causal") + @pytest.mark.unit + def test_freeze_unfreeze_partial_restores_prior_state(self): + model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2) + for p in model.final_norm.parameters(): + p.requires_grad = False + prior = {n: p.requires_grad for n, p in model.named_parameters()} + + model.freeze() + assert all(not p.requires_grad for p in model.parameters()) + assert not model.training + + model.unfreeze(partial=True) + assert {n: p.requires_grad for n, p in model.named_parameters()} == prior + assert model.training + @pytest.mark.unit def test_forward_cpu(self): """Forward pass on CPU uses unfused FlexAttention fallback."""