Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dataclasses import dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from omegaconf import DictConfig
Expand Down Expand Up @@ -106,9 +106,9 @@ def get_buffered_pred_feat_rnnt(
delay: int,
model_stride_in_secs: int,
batch_size: int,
manifest: str = None,
filepaths: List[list] = None,
target_lang_id: str = None,
manifest: Optional[str] = None,
filepaths: Optional[List[list]] = None,
target_lang_id: Optional[str] = None,
accelerator: Optional[str] = 'cpu',
) -> List[rnnt_utils.Hypothesis]:
"""
Expand Down Expand Up @@ -219,8 +219,8 @@ def get_buffered_pred_feat_multitaskAED(
preprocessor_cfg: DictConfig,
model_stride_in_secs: int,
device: Union[List[int], int],
manifest: str = None,
filepaths: List[list] = None,
manifest: Optional[str] = None,
filepaths: Optional[List[list]] = None,
delay: float = 0.0,
timestamps: bool = False,
) -> List[rnnt_utils.Hypothesis]:
Expand Down Expand Up @@ -390,7 +390,7 @@ def read_and_maybe_sort_manifest(path: str, try_sort: bool = False) -> List[dict
return items


def restore_transcription_order(manifest_path: str, transcriptions: list) -> list:
def restore_transcription_order(manifest_path: str, transcriptions: List) -> Union[List, Tuple]:
with open(manifest_path, encoding='utf-8') as f:
items = [(idx, json.loads(l)) for idx, l in enumerate(f) if l.strip() != ""]
if not all("duration" in item[1] and item[1]["duration"] is not None for item in items):
Expand Down Expand Up @@ -422,7 +422,7 @@ def compute_output_filename(cfg: DictConfig, model_name: str) -> DictConfig:
return cfg


def normalize_timestamp_output(timestamps: dict):
def normalize_timestamp_output(timestamps: Dict) -> Dict:
"""
Normalize the dictionary of timestamp values to JSON serializable values.
Expects the following keys to exist -
Expand All @@ -447,7 +447,7 @@ def write_transcription(
transcriptions: Union[List[rnnt_utils.Hypothesis], List[List[rnnt_utils.Hypothesis]], List[str]],
cfg: DictConfig,
model_name: str,
filepaths: List[str] = None,
filepaths: Optional[List[str]] = None,
compute_langs: bool = False,
timestamps: bool = False,
) -> Tuple[str, str]:
Expand Down Expand Up @@ -549,7 +549,7 @@ def compute_metrics_per_sample(
hypothesis_field: str = "pred_text",
metrics: List[str] = ["wer"],
punctuation_marks: List[str] = [".", ",", "?"],
output_manifest_path: str = None,
output_manifest_path: Optional[str] = None,
) -> dict:
'''
Computes metrics per sample for given manifest
Expand Down Expand Up @@ -637,7 +637,7 @@ def compute_metrics_per_sample(


class PunctuationCapitalization:
def __init__(self, punctuation_marks: str):
def __init__(self, punctuation_marks: str) -> None:
"""
Class for text processing with punctuation and capitalization. Can be used with class TextProcessingConfig.

Expand Down
Loading