Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions nemo/collections/asr/modules/transformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import torch.nn as nn
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from nemo.core.classes.module import freeze, unfreeze
from nemo.utils.decorators import experimental

flex_attention_compiled = torch.compile(flex_attention, dynamic=True)


Expand Down Expand Up @@ -151,6 +154,7 @@ def forward(self, x, block_mask=None):
return x


@experimental
class TransformerEncoder(nn.Module):
"""Pre-norm Transformer encoder for ASR.

Expand Down Expand Up @@ -239,3 +243,9 @@ def forward(self, audio_signal, length):
x = self.final_norm(x)
x = x.transpose(1, 2) # (B, T, D) -> (B, D, T)
return x, length

def freeze(self) -> None:
freeze(self)

def unfreeze(self, partial: bool = False) -> None:
unfreeze(self, partial=partial)
134 changes: 54 additions & 80 deletions nemo/core/classes/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,50 @@
from nemo.core.classes.common import FileIO, Serialization, Typing
from nemo.utils import logging

__all__ = ['NeuralModule']
__all__ = ['NeuralModule', 'freeze', 'unfreeze']


def freeze(module: Module) -> None:
"""Freeze all parameters of ``module`` and snapshot their prior ``requires_grad`` state.

The snapshot is stored on ``module._frozen_grad_map`` so a later call to ``unfreeze(..., partial=True)``
can restore the pre-freeze state instead of unconditionally enabling gradients.
"""
grad_map = {pname: param.requires_grad for pname, param in module.named_parameters()}
for param in module.parameters():
param.requires_grad = False
if not hasattr(module, '_frozen_grad_map'):
module._frozen_grad_map = grad_map
else:
module._frozen_grad_map.update(grad_map)
module.eval()


def unfreeze(module: Module, partial: bool = False) -> None:
"""Unfreeze parameters of ``module``.

If ``partial=True``, restore each parameter's ``requires_grad`` from the snapshot recorded by
``freeze(module)``; otherwise enable gradients on every parameter. The snapshot is cleared in
both cases and ``module.train()`` is called.
"""
if partial and not hasattr(module, '_frozen_grad_map'):
raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")

for pname, param in module.named_parameters():
if not partial:
param.requires_grad = True
elif pname in module._frozen_grad_map:
param.requires_grad = module._frozen_grad_map[pname]
else:
logging.warning(
f"Parameter {pname} not found in list of previously frozen parameters. Unfreezing this parameter."
)
param.requires_grad = True

if hasattr(module, '_frozen_grad_map'):
delattr(module, '_frozen_grad_map')

module.train()


class NeuralModule(Module, Typing, Serialization, FileIO):
Expand Down Expand Up @@ -53,99 +96,30 @@ def input_example(self, max_batch=None, max_dim=None):
return None

def freeze(self) -> None:
r"""
Freeze all params for inference.

This method sets `requires_grad` to False for all parameters of the module.
It also stores the original `requires_grad` state of each parameter in a dictionary,
so that `unfreeze()` can restore the original state if `partial=True` is set in `unfreeze()`.
"""
grad_map = {}

for pname, param in self.named_parameters():
# Store the original grad state
grad_map[pname] = param.requires_grad
# Freeze the parameter
param.requires_grad = False

# Store the frozen grad map
if not hasattr(self, '_frozen_grad_map'):
self._frozen_grad_map = grad_map
else:
self._frozen_grad_map.update(grad_map)

self.eval()
r"""Freeze all params for inference. See :func:`freeze` for details."""
freeze(self)

def unfreeze(self, partial: bool = False) -> None:
"""
Unfreeze all parameters for training.

Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were
previously unfrozen prior `freeze()`.
"""Unfreeze parameters for training. See :func:`unfreeze` for details.

Example:
Consider a model that has an encoder and a decoder module. Assume we want the encoder to be frozen always.

```python
model.encoder.freeze() # Freezes all parameters in the encoder explicitly
```

During inference, all parameters of the model should be frozen - we do this by calling the model's freeze method.
This step records that the encoder module parameters were already frozen, and so if partial unfreeze is called,
we should keep the encoder parameters frozen.

```python
model.freeze() # Freezes all parameters in the model; encoder remains frozen
model.encoder.freeze() # caller freezes encoder
model.freeze() # freezes everything; encoder snapshot preserved
model.unfreeze(partial=True) # decoder unfrozen, encoder stays frozen
```

Now, during fine-tuning, we want to unfreeze the decoder but keep the encoder frozen. We can do this by calling
`unfreeze(partial=True)`.

```python
model.unfreeze(partial=True) # Unfreezes only the decoder; encoder remains frozen
```

Args:
partial: If True, only unfreeze parameters that were previously frozen. If the parameter was already frozen
when calling `freeze()`, it will remain frozen after calling `unfreeze(partial=True)`.
"""
if partial and not hasattr(self, '_frozen_grad_map'):
raise ValueError("Cannot unfreeze partially without first freezing the module with `freeze()`")

for pname, param in self.named_parameters():
if not partial:
# Unfreeze all parameters
param.requires_grad = True
else:
# Unfreeze only parameters that were previously frozen

# Check if the parameter was frozen
if pname in self._frozen_grad_map:
param.requires_grad = self._frozen_grad_map[pname]
else:
# Log a warning if the parameter was not found in the frozen grad map
logging.warning(
f"Parameter {pname} not found in list of previously frozen parameters. "
f"Unfreezing this parameter."
)
param.requires_grad = True

# Clean up the frozen grad map
if hasattr(self, '_frozen_grad_map'):
delattr(self, '_frozen_grad_map')

self.train()
unfreeze(self, partial=partial)

@contextmanager
def as_frozen(self):
"""
Context manager which temporarily freezes a module, yields control and finally unfreezes the module partially
to return to original state.

Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen previously with `freeze()`).
The `partial` argument is used to determine whether to unfreeze all parameters or only the parameters that were
previously unfrozen prior `freeze()`.
Allows for either total unfreeze or partial unfreeze (if the module was explicitly frozen
previously with `freeze()`). The `partial` argument is used to determine whether to unfreeze
all parameters or only the parameters that were previously unfrozen prior `freeze()`.

Example:
with model.as_frozen(): # by default, partial = True
Expand Down
15 changes: 15 additions & 0 deletions tests/collections/asr/test_transformer_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ def test_invalid_attn_mode(self):
with pytest.raises(ValueError, match="not yet supported"):
TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2, attn_mode="causal")

@pytest.mark.unit
def test_freeze_unfreeze_partial_restores_prior_state(self):
model = TransformerEncoder(feat_in=80, d_model=64, n_heads=4, n_layers=2)
for p in model.final_norm.parameters():
p.requires_grad = False
prior = {n: p.requires_grad for n, p in model.named_parameters()}

model.freeze()
assert all(not p.requires_grad for p in model.parameters())
assert not model.training

model.unfreeze(partial=True)
assert {n: p.requires_grad for n, p in model.named_parameters()} == prior
assert model.training

@pytest.mark.unit
def test_forward_cpu(self):
"""Forward pass on CPU uses unfused FlexAttention fallback."""
Expand Down
Loading