Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 71 additions & 91 deletions nemo/collections/asr/modules/ssl_modules/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.


import math
from typing import Optional, Union

import torch
Expand Down Expand Up @@ -62,16 +63,16 @@ def __init__(
def input_types(self):
    """Returns definitions of module input types.

    Returns:
        dict: maps each input name to its ``NeuralType``:
            ``input_feats`` with axes ('B', 'D', 'T') and
            ``input_lengths`` with axis ('B',).
    """
    # Each key appears exactly once; a repeated key in a dict literal is
    # silently overwritten by the later entry.
    return {
        "input_feats": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()),
        "input_lengths": NeuralType(tuple("B"), LengthsType()),
    }

@property
def output_types(self):
    """Returns definitions of module output types.

    Returns:
        dict: maps each output name to its ``NeuralType``; both outputs are
            declared with axes ('B', 'D', 'T').
    """
    # NOTE(review): the key "maksed_feats" is misspelled ("masked_feats").
    # It is a runtime identifier, so it is kept unchanged here; renaming it
    # should be done together with any code that looks outputs up by name.
    # Each key appears exactly once; a repeated key in a dict literal is
    # silently overwritten by the later entry.
    return {
        "maksed_feats": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()),
        "masks": NeuralType(("B", "D", "T"), AcousticEncodedRepresentation()),
    }

def forward(self, input_feats: torch.Tensor, input_lengths: torch.Tensor):
Expand All @@ -88,95 +89,74 @@ def forward(self, input_feats: torch.Tensor, input_lengths: torch.Tensor):
else:
return self.forward_without_overlap(input_feats, input_lengths)

def forward_without_overlap(self, input_feats: torch.Tensor, input_lengths: torch.Tensor):
    """
    Mask non-overlapping blocks of frames for the whole batch without a Python loop.

    For every sequence the time axis is divided into slots of ``block_size``
    frames. When ``num_patches + 1`` full blocks do not fit into the sequence,
    the block is shrunk to ``length // (num_patches + 1)`` (at least 1 frame).
    A random subset of slots is chosen per sequence, all chosen slots are
    shifted by one random per-sequence offset smaller than the block size, and
    the covered frames are replaced by ``self.mask_embedding``.

    Sequences for which no slot is available (``length // block_size < 2``
    after shrinking) are returned unmasked.

    Args:
        input_feats (Tensor): input sequence features, shape=(batch, features, time)
        input_lengths (Tensor): length of each sequence in the batch, shape=(batch)
    Returns:
        masked_feats (Tensor): masked features, shape=(batch, features, time)
        masks (Tensor): the generated masks, shape=(batch, features, time)
    """
    device = input_feats.device
    batch_size, _, max_time = input_feats.shape
    # Use int64 lengths on the feature device so that all integer arithmetic
    # below has one dtype and the results can be used as scatter indices.
    input_lengths = input_lengths.to(device=device, dtype=torch.long)

    # Requested number of blocks per sequence.
    num_patches = torch.ceil(input_lengths * self.mask_prob / self.block_size).long()

    # Shrink the block for sequences that cannot hold num_patches + 1 blocks.
    block_sizes = torch.full_like(input_lengths, self.block_size)
    needs_shrink = (num_patches + 1) * block_sizes > input_lengths
    block_sizes = torch.where(
        needs_shrink,
        input_lengths // (num_patches + 1),
        block_sizes,
    ).clamp_min(1)

    # One slot is held back so that the random offset cannot push the last
    # block past the end of the sequence.
    num_slots = (input_lengths // block_sizes - 1).clamp_min(0)
    num_patches = torch.minimum(num_patches, num_slots)

    # Random scores; slots that do not exist for a sequence get -inf so that
    # top-k ranks every valid slot ahead of them.
    slot_ids = torch.arange(max_time, device=device)
    valid_slots = slot_ids.unsqueeze(0) < num_slots.unsqueeze(1)
    scores = torch.rand(batch_size, max_time, device=device)
    scores.masked_fill_(~valid_slots, float("-inf"))

    # Upper bound on blocks for any sequence; top-k cannot take more than
    # max_time candidates.
    max_patches = min(max_time, math.ceil(max_time * self.mask_prob / self.block_size))
    selected_slots = scores.topk(max_patches, dim=1).indices
    # Keep only the first num_patches picks of every row.
    selected = torch.arange(max_patches, device=device).unsqueeze(0) < num_patches.unsqueeze(1)

    offsets = (torch.rand(batch_size, device=device) * block_sizes).long()
    starts = selected_slots * block_sizes.unsqueeze(1) + offsets.unsqueeze(1)
    ends = starts + block_sizes.unsqueeze(1)
    # Unselected picks are redirected to index 0 and contribute a zero delta;
    # indices are bounded by the size of the difference array.
    starts = starts.masked_fill(~selected, 0).clamp_(max=max_time)
    ends = ends.masked_fill(~selected, 0).clamp_(max=max_time)
    deltas = selected.int()

    # Difference array: +1 at every block start, -1 at every block end; the
    # running sum is positive exactly on the frames covered by a block.
    coverage_diff = torch.zeros(
        batch_size,
        max_time + 1,
        dtype=torch.int32,
        device=device,
    )
    coverage_diff.scatter_add_(1, starts, deltas)
    coverage_diff.scatter_add_(1, ends, -deltas)
    time_mask = (coverage_diff[:, :max_time].cumsum(dim=1) > 0).unsqueeze(1)

    # Cast the embedding so the output keeps the dtype of input_feats instead
    # of being promoted by torch.where.
    mask_embedding = self.mask_embedding.to(input_feats.dtype).reshape(1, -1, 1)
    masked_feats = torch.where(time_mask, mask_embedding, input_feats)
    masks = time_mask.to(input_feats.dtype).expand_as(input_feats)
    return masked_feats, masks

def forward_with_overlap(self, input_feats: torch.Tensor, input_lengths: torch.Tensor):
    """
    Mask possibly overlapping blocks of frames for the whole batch without a Python loop.

    Every frame position at which a full block of ``block_size`` frames still
    fits inside the sequence is chosen as a block start independently with
    probability ``mask_prob``. Each chosen start masks the ``block_size``
    frames beginning at it, so blocks may overlap. Sequences shorter than
    ``block_size`` have no valid start and are returned unmasked.

    Args:
        input_feats (Tensor): input sequence features, shape=(batch, features, time)
        input_lengths (Tensor): length of each sequence in the batch, shape=(batch)
    Returns:
        masked_feats (Tensor): masked features, shape=(batch, features, time)
        masks (Tensor): the generated masks, shape=(batch, features, time)
    """
    device = input_feats.device
    batch_size, _, max_time = input_feats.shape
    input_lengths = input_lengths.to(device)
    positions = torch.arange(max_time, device=device)

    # A start is valid only if the whole block ends inside the sequence.
    valid_starts = positions + self.block_size <= input_lengths.unsqueeze(1)
    start_mask = (torch.rand(batch_size, max_time, device=device) < self.mask_prob) & valid_starts

    # Frame t is masked when any start lies in [t - block_size + 1, t].
    # Left-padding with block_size - 1 zeros and max-pooling with stride 1
    # computes that sliding-window "any" and keeps the time length unchanged.
    padded_starts = torch.nn.functional.pad(
        start_mask.unsqueeze(1).float(),
        (self.block_size - 1, 0),
    )
    time_mask = torch.nn.functional.max_pool1d(
        padded_starts,
        kernel_size=self.block_size,
        stride=1,
    ).bool()

    # Cast the embedding so the output keeps the dtype of input_feats instead
    # of being promoted by torch.where.
    mask_embedding = self.mask_embedding.to(input_feats.dtype).reshape(1, -1, 1)
    masked_feats = torch.where(time_mask, mask_embedding, input_feats)
    masks = time_mask.to(input_feats.dtype).expand_as(input_feats)
    return masked_feats, masks


Expand Down
Loading