Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions nemo/collections/asr/parts/utils/multispk_transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1416,6 +1416,8 @@ def __init__(
self._uppercase_first_letter = uppercase_first_letter
self._speaker_wise_sentences = {}
self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)]
self._prev_token_counts = [0 for _ in range(self.max_num_of_spks)]
self._prev_decoded_lengths = [0 for _ in range(self.max_num_of_spks)]

self.seglsts = []

Expand All @@ -1425,6 +1427,8 @@ def _reset_speaker_wise_sentences(self):
"""
self._speaker_wise_sentences = {}
self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)]
self._prev_token_counts = [0 for _ in range(self.max_num_of_spks)]
self._prev_decoded_lengths = [0 for _ in range(self.max_num_of_spks)]

def reset(self, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
"""
Expand All @@ -1443,6 +1447,8 @@ def reset(self, asr_cache_state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
self.seglsts = []
self._speaker_wise_sentences = {}
self._prev_history_speaker_texts = ["" for _ in range(self.max_num_of_spks)]
self._prev_token_counts = [0 for _ in range(self.max_num_of_spks)]
self._prev_decoded_lengths = [0 for _ in range(self.max_num_of_spks)]

def update_asr_state(
self,
Expand Down Expand Up @@ -1515,7 +1521,10 @@ def _update_last_sentence(self, spk_idx: int, end_time: float, diff_text: str):
diff_text (str): The difference text.
"""
if end_time is not None:
self._speaker_wise_sentences[spk_idx][-1]['end_time'] = end_time
current_start = self._speaker_wise_sentences[spk_idx][-1]['start_time']
self._speaker_wise_sentences[spk_idx][-1]['end_time'] = max(
end_time, current_start + self._frame_len_sec
)
new_words = self._speaker_wise_sentences[spk_idx][-1]['words'] + diff_text
self._speaker_wise_sentences[spk_idx][-1]['words'] = new_words.strip()

Expand All @@ -1536,18 +1545,27 @@ def _is_new_text(self, spk_idx: int, text: str):
else:
return text.strip()

def _compute_hypothesis_timestamps(self, hypothesis: Hypothesis, offset: float) -> Tuple[float, float, bool]:
def _compute_hypothesis_timestamps(
self,
hypothesis: Hypothesis,
offset: float,
prev_token_count: int = 0,
decoded_length_before: int = None,
) -> Tuple[float, float, bool]:
"""
Compute start and end timestamps for a hypothesis based on available timing information.

This method calculates the temporal boundaries of a speech hypothesis, prioritizing
frame-level timestamps when available. When timestamps are not available, it falls
back to computing timing based on the hypothesis length.
frame-level timestamps and decoder state when available. When timestamps are not available,
it falls back to computing timing based on the hypothesis length.

Args:
hypothesis (Hypothesis): The ASR hypothesis object containing either frame-level
hypothesis (Hypothesis): The ASR hypothesis object containing frame-level
timestamps and decoder state.
offset (float): The time offset (in seconds) to add to the computed timestamps,
typically representing the start time of the current audio chunk.
prev_token_count (int): The number of timestamp entries already processed for this speaker.
decoded_length_before (int): The decoded length before the current chunk.

Returns:
Tuple[float, float, bool]: A tuple containing:
Expand All @@ -1561,12 +1579,15 @@ def _compute_hypothesis_timestamps(self, hypothesis: Hypothesis, offset: float)
for the full duration of the final frame.
"""
sep_flag = False
if len(hypothesis.timestamp) > 0:
start_time = offset + (hypothesis.timestamp[0]) * self._frame_len_sec
end_time = offset + (hypothesis.timestamp[-1] + 1) * self._frame_len_sec
new_timestamp_count = len(hypothesis.timestamp) - prev_token_count
if hypothesis.dec_state is not None and new_timestamp_count > 0 and decoded_length_before is not None:
start_local = hypothesis.timestamp[prev_token_count].item() - decoded_length_before
end_local = hypothesis.timestamp[-1].item() - decoded_length_before
start_time = offset + start_local * self._frame_len_sec
end_time = offset + (end_local + 1) * self._frame_len_sec
else:
start_time = offset
end_time = offset + hypothesis.length.item() * self._frame_len_sec
end_time = offset + max(0, hypothesis.length.item() - prev_token_count) * self._frame_len_sec
sep_flag = True

return start_time, end_time, sep_flag
Expand Down Expand Up @@ -1594,9 +1615,16 @@ def update_sessionwise_seglsts_for_parallel(self, offset: float):
if diff_text is not None:

start_time, end_time, sep_flag = self._compute_hypothesis_timestamps(
hypothesis=hypothesis, offset=offset
hypothesis=hypothesis,
offset=offset,
prev_token_count=self._prev_token_counts[spk_idx],
decoded_length_before=self._prev_decoded_lengths[spk_idx],
)

# Update the stored decoded_length for this speaker
if hypothesis.dec_state is not None:
self._prev_decoded_lengths[spk_idx] = hypothesis.dec_state.decoded_length.item()

# Get the last end time of the previous sentence or None if no sentences are present
if len(self._speaker_wise_sentences[spk_idx]) > 0:
last_end_time = self._speaker_wise_sentences[spk_idx][-1]['end_time']
Expand Down Expand Up @@ -1628,6 +1656,7 @@ def update_sessionwise_seglsts_for_parallel(self, offset: float):
# Update the previous history of the speaker text
if hypothesis.text is not None:
self._prev_history_speaker_texts[spk_idx] = hypothesis.text
self._prev_token_counts[spk_idx] = len(hypothesis.timestamp)

self.seglsts = []

Expand Down
Loading