From 1fc31e0370695240ca29e5fdefe6ee3af21d96b1 Mon Sep 17 00:00:00 2001 From: nithinraok Date: Mon, 25 May 2026 12:43:49 -0700 Subject: [PATCH 1/4] add freeze, unfreeze methods with experimental tag Signed-off-by: nithinraok --- .../collections/asr/modules/transformer_encoder.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index f2af64cb8974..8d2375de0698 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -18,6 +18,8 @@ import torch.nn as nn from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from nemo.utils.decorators import experimental + flex_attention_compiled = torch.compile(flex_attention, dynamic=True) @@ -151,6 +153,7 @@ def forward(self, x, block_mask=None): return x +@experimental class TransformerEncoder(nn.Module): """Pre-norm Transformer encoder for ASR. @@ -239,3 +242,14 @@ def forward(self, audio_signal, length): x = self.final_norm(x) x = x.transpose(1, 2) # (B, T, D) -> (B, D, T) return x, length + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.eval() + + def unfreeze(self, partial: bool = False): + for p in self.parameters(): + p.requires_grad = True + if not partial: + self.train() From e1f0bc437dcf6b2c32c7441ecefb9bd7518ea446 Mon Sep 17 00:00:00 2001 From: nithinraok Date: Mon, 25 May 2026 13:46:15 -0700 Subject: [PATCH 2/4] Add test case Signed-off-by: nithinraok --- .../asr/modules/transformer_encoder.py | 21 +++++++++++++++---- .../asr/test_transformer_encoder.py | 15 +++++++++++++ 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index 8d2375de0698..f85a1623e632 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -244,12 +244,25 @@ def forward(self, audio_signal, length): return x, length def freeze(self): + """Freeze all parameters, recording prior ``requires_grad`` so ``unfreeze(partial=True)`` can restore it.""" + grad_map = {name: p.requires_grad for name, p in self.named_parameters()} for p in self.parameters(): p.requires_grad = False + if not hasattr(self, "_frozen_grad_map"): + self._frozen_grad_map = grad_map + else: + self._frozen_grad_map.update(grad_map) self.eval() def unfreeze(self, partial: bool = False): - for p in self.parameters(): - p.requires_grad = True - if not partial: - self.train() + """Unfreeze parameters. ``partial=True`` restores the pre-``freeze()`` state from ``_frozen_grad_map``.""" + if partial and not hasattr(self, "_frozen_grad_map"): + raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") + for name, p in self.named_parameters(): + if partial: + p.requires_grad = self._frozen_grad_map.get(name, True) + else: + p.requires_grad = True + if hasattr(self, "_frozen_grad_map"): + del self._frozen_grad_map + self.train() diff --git a/tests/collections/asr/test_transformer_encoder.py b/tests/collections/asr/test_transformer_encoder.py index 0cc2f174a1e5..b548ef96edf3 100644 --- a/tests/collections/asr/test_transformer_encoder.py +++ b/tests/collections/asr/test_transformer_encoder.py @@ -126,6 +126,21 @@ def test_invalid_attn_mode(self): with pytest.raises(ValueError, match="not yet supported"): TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, attn_mode="causal") + @pytest.mark.unit + def test_freeze_unfreeze_partial_restores_prior_state(self): + model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2) + for p in model.final_norm.parameters(): + p.requires_grad = False + prior = {n: p.requires_grad for n, p in model.named_parameters()} + + model.freeze() + assert all(not p.requires_grad for p in model.parameters()) + assert not model.training + + model.unfreeze(partial=True) + assert {n: p.requires_grad for n, p in model.named_parameters()} == prior + assert model.training + @pytest.mark.unit def test_forward_cpu(self): """Forward pass on CPU uses unfused FlexAttention fallback.""" From 2c81b2bcf843dba247e2150786729d77b0c0afd3 Mon Sep 17 00:00:00 2001 From: nithinraok Date: Thu, 28 May 2026 08:48:11 -0700 Subject: [PATCH 3/4] refactor to module level functions Signed-off-by: nithinraok --- .../asr/modules/transformer_encoder.py | 29 +--- nemo/core/classes/module.py | 128 +++++++----------- 2 files changed, 57 insertions(+), 100 deletions(-) diff --git a/nemo/collections/asr/modules/transformer_encoder.py b/nemo/collections/asr/modules/transformer_encoder.py index f85a1623e632..75951f9ae690 100644 --- a/nemo/collections/asr/modules/transformer_encoder.py +++ b/nemo/collections/asr/modules/transformer_encoder.py @@ -18,6 +18,7 @@ import torch.nn as nn from torch.nn.attention.flex_attention import create_block_mask, flex_attention +from nemo.core.classes.module import freeze, unfreeze from nemo.utils.decorators import experimental flex_attention_compiled = torch.compile(flex_attention, dynamic=True) @@ -243,26 +244,8 @@ def forward(self, audio_signal, length): x = x.transpose(1, 2) # (B, T, D) -> (B, D, T) return x, length - def freeze(self): - """Freeze all parameters, recording prior ``requires_grad`` so ``unfreeze(partial=True)`` can restore it.""" - grad_map = {name: p.requires_grad for name, p in self.named_parameters()} - for p in self.parameters(): - p.requires_grad = False - if not hasattr(self, "_frozen_grad_map"): - self._frozen_grad_map = grad_map - else: - self._frozen_grad_map.update(grad_map) - self.eval() - - def unfreeze(self, partial: bool = False): - """Unfreeze parameters. ``partial=True`` restores the pre-``freeze()`` state from ``_frozen_grad_map``.""" - if partial and not hasattr(self, "_frozen_grad_map"): - raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") - for name, p in self.named_parameters(): - if partial: - p.requires_grad = self._frozen_grad_map.get(name, True) - else: - p.requires_grad = True - if hasattr(self, "_frozen_grad_map"): - del self._frozen_grad_map - self.train() + def freeze(self) -> None: + freeze(self) + + def unfreeze(self, partial: bool = False) -> None: + unfreeze(self, partial=partial) diff --git a/nemo/core/classes/module.py b/nemo/core/classes/module.py index ef80467c8c7a..74beeb7a1ec3 100644 --- a/nemo/core/classes/module.py +++ b/nemo/core/classes/module.py @@ -20,7 +20,50 @@ from nemo.core.classes.common import FileIO, Serialization, Typing from nemo.utils import logging -__all__ = ['NeuralModule'] +__all__ = ['NeuralModule', 'freeze', 'unfreeze'] + + +def freeze(module: Module) -> None: + """Freeze all parameters of ``module`` and snapshot their prior ``requires_grad`` state. + + The snapshot is stored on ``module._frozen_grad_map`` so a later call to ``unfreeze(..., partial=True)`` + can restore the pre-freeze state instead of unconditionally enabling gradients. + """ + grad_map = {pname: param.requires_grad for pname, param in module.named_parameters()} + for param in module.parameters(): + param.requires_grad = False + if not hasattr(module, '_frozen_grad_map'): + module._frozen_grad_map = grad_map + else: + module._frozen_grad_map.update(grad_map) + module.eval() + + +def unfreeze(module: Module, partial: bool = False) -> None: + """Unfreeze parameters of ``module``. + + If ``partial=True``, restore each parameter's ``requires_grad`` from the snapshot recorded by + ``freeze(module)``; otherwise enable gradients on every parameter. The snapshot is cleared in + both cases and ``module.train()`` is called. + """ + if partial and not hasattr(module, '_frozen_grad_map'): + raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") + + for pname, param in module.named_parameters(): + if not partial: + param.requires_grad = True + elif pname in module._frozen_grad_map: + param.requires_grad = module._frozen_grad_map[pname] + else: + logging.warning( + f"Parameter {pname} not found in list of previously frozen parameters. Unfreezing this parameter." + ) + param.requires_grad = True + + if hasattr(module, '_frozen_grad_map'): + delattr(module, '_frozen_grad_map') + + module.train() class NeuralModule(Module, Typing, Serialization, FileIO): @@ -53,89 +96,20 @@ def input_example(self, max_batch=None, max_dim=None): return None def freeze(self) -> None: - r""" - Freeze all params for inference. - - This method sets `requires_grad` to False for all parameters of the module. - It also stores the original `requires_grad` state of each parameter in a dictionary, - so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`. - """ - grad_map = {} - - for pname, param in self.named_parameters(): - # Store the original grad state - grad_map[pname] = param.requires_grad - # Freeze the parameter - param.requires_grad = False - - # Store the frozen grad map - if not hasattr(self, '_frozen_grad_map'): - self._frozen_grad_map = grad_map - else: - self._frozen_grad_map.update(grad_map) - - self.eval() + r"""Freeze all params for inference. See :func:`freeze` for details.""" + freeze(self) def unfreeze(self, partial: bool = False) -> None: - """ - Unfreeze all parameters for training. - - Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). - The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were - previously unfrozen prior `freeze()`. + """Unfreeze parameters for training. See :func:`unfreeze` for details. Example: - Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always. - - ```python - model.encoder.freeze() # Freezes all parameters in the encoder explicitly - ``` - - During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method. - This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called, - we should keep the encoder parameters frozen. - ```python - model.freeze() # Freezes all parameters in the model; encoder remains frozen + model.encoder.freeze() # caller freezes encoder + model.freeze() # freezes everything; encoder snapshot preserved + model.unfreeze(partial=True) # decoder unfrozen, encoder stays frozen ``` - - Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling - `unfreeze(partial=True)`. - - ```python - model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen - ``` - - Args: - partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen - when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`. """ - if partial and not hasattr(self, '_frozen_grad_map'): - raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`") - - for pname, param in self.named_parameters(): - if not partial: - # Unfreeze all parameters - param.requires_grad = True - else: - # Unfreeze only parameters that were previously frozen - - # Check if the parameter was frozen - if pname in self._frozen_grad_map: - param.requires_grad = self._frozen_grad_map[pname] - else: - # Log a warning if the parameter was not found in the frozen grad map - logging.warning( - f"Parameter {pname} not found in list of previously frozen parameters. " - f"Unfreezing this parameter." - ) - param.requires_grad = True - - # Clean up the frozen grad map - if hasattr(self, '_frozen_grad_map'): - delattr(self, '_frozen_grad_map') - - self.train() + unfreeze(self, partial=partial) @contextmanager def as_frozen(self): From a6db2622e9f3584d48906e8923126500adaf944b Mon Sep 17 00:00:00 2001 From: nithinraok Date: Thu, 28 May 2026 08:50:13 -0700 Subject: [PATCH 4/4] linting fixes Signed-off-by: nithinraok --- nemo/core/classes/module.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/core/classes/module.py b/nemo/core/classes/module.py index 74beeb7a1ec3..39d20b2f39ae 100644 --- a/nemo/core/classes/module.py +++ b/nemo/core/classes/module.py @@ -117,9 +117,9 @@ def as_frozen(self): Context manager which temporarily freezes a module, yields control and finally unfreezes the module partially to return to original state. - Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`). - The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were - previously unfrozen prior `freeze()`. + Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen + previously with `freeze()`). The `partial` argument is used to determine whether to unfreeze + all parameters or only the parameters that were previously unfrozen prior `freeze()`. Example: with model.as_frozen(): # by default, partial = True