diff --git a/nemo/collections/asr/modules/ssl_modules/masking.py b/nemo/collections/asr/modules/ssl_modules/masking.py index c491c56fa829..b89eca9b3974 100644 --- a/nemo/collections/asr/modules/ssl_modules/masking.py +++ b/nemo/collections/asr/modules/ssl_modules/masking.py @@ -13,6 +13,7 @@ # limitations under the License. +import math from typing import Optional, Union import torch @@ -62,16 +63,16 @@ def __init__( def input_types(self): """Returns definitions of module input types""" return { - "input_feats": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), - "input_lengths": NeuralType(tuple('B'), LengthsType()), + "input_feats": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()), + "input_lengths": NeuralType(tuple("B"), LengthsType()), } @property def output_types(self): """Returns definitions of module output types""" return { - "maksed_feats": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), - "masks": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()), + "maksed_feats": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()), + "masks": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()), } def forward(self, input_feats: torch.Tensor, input_lengths: torch.Tensor): @@ -88,95 +89,74 @@ def forward(self, input_feats: torch.Tensor, input_lengths: torch.Tensor): else: return self.forward_without_overlap(input_feats, input_lengths) - def forward_without_overlap(self, input_feats: torch.Tensor, input_lengths: torch.Tensor): - """ - Args: - input_feats (Tensor): input sequence features, shape=(batch, features, time) - input_length (Tensor): length of each sequence in the batch, shape=(batch) - Returns: - masked_feats (Tensor): masked features, shape=(batch, features, time) - masks (Tensor): the generated masks, shape=(batch, features, time) - """ - batch_size = input_feats.size(0) - masks = torch.zeros_like(input_feats) - masked_feats = input_feats - indices = [] - for i in range(batch_size): - if self.block_size >= input_lengths[i] * self.max_mask_ratio: - # handle case where audio is too short - block_size = 8 - num_patches = 1 - patch_indices = torch.tensor([0]) - offset = 0 - else: - num_patches = torch.ceil(input_lengths[i] * self.mask_prob / self.block_size).int() - offset = torch.randint(0, self.block_size, (1,))[0] - block_size = self.block_size - if (num_patches + 1) * self.block_size > input_lengths[i]: - block_size = torch.div(input_lengths[i], (num_patches + 1), rounding_mode='trunc') - max_num_patches = torch.div(input_lengths[i], block_size, rounding_mode='trunc') - patch_indices = torch.randperm(max_num_patches - 1)[:num_patches] - - if num_patches: - starts = patch_indices * block_size + offset - ends = starts + block_size - positions = torch.cat([torch.arange(s, e) for s, e in zip(starts, ends)]).reshape(-1, 1) - batch_index = torch.full((positions.shape[0], 1), i, dtype=positions.dtype) - positions = torch.cat([batch_index, positions], dim=1) - indices.append(positions.unique(dim=0)) - - if indices: - indices = torch.cat(indices, dim=0).unbind(1) - masks = masks.permute(0, 2, 1) - masked_feats = masked_feats.permute(0, 2, 1) - - masks = masks.index_put(indices, values=torch.tensor(1.0)).permute(0, 2, 1) - masked_feats = masked_feats.index_put(indices, values=self.mask_embedding).permute(0, 2, 1) - + def forward_without_overlap(self, input_feats, input_lengths): + batch_size, _, max_time = input_feats.shape + + num_patches = torch.ceil(input_lengths * self.mask_prob / self.block_size).long() + block_sizes = torch.full_like(input_lengths, self.block_size) + needs_shrink = (num_patches + 1) * block_sizes > input_lengths + block_sizes = torch.where( + needs_shrink, + input_lengths // (num_patches + 1), + block_sizes, + ).clamp_min(1) + + num_slots = (input_lengths // block_sizes - 1).clamp_min(0) + num_patches = torch.minimum(num_patches, num_slots) + + slots = torch.arange(max_time, device=input_feats.device) + valid_slots = slots.unsqueeze(0) < num_slots.unsqueeze(1) + scores = torch.rand(batch_size, max_time, device=input_feats.device) + scores.masked_fill_(~valid_slots, float("-inf")) + + max_patches = math.ceil(max_time * self.mask_prob / self.block_size) + selected_slots = scores.topk(max_patches, dim=1).indices + selected = torch.arange(max_patches, device=input_feats.device).unsqueeze(0) < num_patches.unsqueeze(1) + + offsets = (torch.rand(batch_size, device=input_feats.device) * block_sizes).long() + starts = selected_slots * block_sizes.unsqueeze(1) + offsets.unsqueeze(1) + ends = starts + block_sizes.unsqueeze(1) + starts = torch.where(selected, starts, 0) + ends = torch.where(selected, ends, 0) + deltas = selected.int() + + coverage_diff = torch.zeros( + batch_size, + max_time + 1, + dtype=torch.int32, + device=input_feats.device, + ) + coverage_diff.scatter_add_(1, starts, deltas) + coverage_diff.scatter_add_(1, ends, -deltas) + time_mask = coverage_diff[:, :max_time].cumsum(dim=1) > 0 + time_mask = time_mask.unsqueeze(1) + + masked_feats = torch.where( + time_mask, + self.mask_embedding.view(1, -1, 1), + input_feats, + ) + masks = time_mask.to(input_feats.dtype).expand_as(input_feats) return masked_feats, masks - def forward_with_overlap(self, input_feats: torch.Tensor, input_lengths: torch.Tensor): - """ - Args: - input_feats (Tensor): input sequence features, shape=(batch, features, time) - input_length (Tensor): length of each sequence in the batch, shape=(batch) - Returns: - masked_feats (Tensor): masked features, shape=(batch, features, time) - masks (Tensor): the generated masks, shape=(batch, features, time) - """ - batch_size = input_feats.size(0) - masks = torch.zeros_like(input_feats) - masked_feats = input_feats - mask_prob = torch.tensor(self.mask_prob) - indices = [] - for i in range(batch_size): - input_length = input_lengths[i].item() - if self.block_size >= input_length * self.max_mask_ratio: - # handle case where audio is too short - block_size = 8 - num_patches = 1 - patch_indices = torch.tensor([0]) - else: - block_size = self.block_size - count = max(0, input_length - self.block_size) - num_patches = torch.binomial(torch.tensor(count).float(), mask_prob).long() - patch_indices = torch.randperm(count) - patch_indices = patch_indices[:num_patches] - if num_patches: - ends = torch.clamp(patch_indices + block_size, max=input_length) - positions = torch.cat([torch.arange(s, e) for s, e in zip(patch_indices, ends)]).reshape(-1, 1) - batch_index = torch.full((positions.shape[0], 1), i, dtype=positions.dtype) - positions = torch.cat([batch_index, positions], dim=1) - indices.append(positions.unique(dim=0)) - - if indices: - indices = torch.cat(indices, dim=0).unbind(1) - masks = masks.permute(0, 2, 1) - masked_feats = masked_feats.permute(0, 2, 1) - - masks = masks.index_put(indices, values=torch.tensor(1.0)).permute(0, 2, 1) - masked_feats = masked_feats.index_put(indices, values=self.mask_embedding).permute(0, 2, 1) - + def forward_with_overlap(self, input_feats, input_lengths): + batch_size, _, max_time = input_feats.shape + positions = torch.arange(max_time, device=input_feats.device) + + valid_starts = positions + self.block_size <= input_lengths.unsqueeze(1) + start_mask = (torch.rand(batch_size, max_time, device=input_feats.device) < self.mask_prob) & valid_starts + time_mask = torch.nn.functional.max_pool1d( + torch.nn.functional.pad(start_mask.unsqueeze(1).float(), (self.block_size - 1, 0)), + kernel_size=self.block_size, + stride=1, + ).bool() + + masked_feats = torch.where( + time_mask, + self.mask_embedding.view(1, -1, 1), + input_feats, + ) + masks = time_mask.to(input_feats.dtype).expand_as(input_feats) return masked_feats, masks