diff --git a/nemo/collections/asr/parts/utils/transcribe_utils.py b/nemo/collections/asr/parts/utils/transcribe_utils.py index 9c576ce3c093..e7a3bd1e1a0f 100644 --- a/nemo/collections/asr/parts/utils/transcribe_utils.py +++ b/nemo/collections/asr/parts/utils/transcribe_utils.py @@ -18,7 +18,7 @@ from dataclasses import dataclass from pathlib import Path from tempfile import NamedTemporaryFile -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import torch from omegaconf import DictConfig @@ -106,9 +106,9 @@ def get_buffered_pred_feat_rnnt( delay: int, model_stride_in_secs: int, batch_size: int, - manifest: str = None, - filepaths: List[list] = None, - target_lang_id: str = None, + manifest: Optional[str] = None, + filepaths: Optional[List[list]] = None, + target_lang_id: Optional[str] = None, accelerator: Optional[str] = 'cpu', ) -> List[rnnt_utils.Hypothesis]: """ @@ -219,8 +219,8 @@ def get_buffered_pred_feat_multitaskAED( preprocessor_cfg: DictConfig, model_stride_in_secs: int, device: Union[List[int], int], - manifest: str = None, - filepaths: List[list] = None, + manifest: Optional[str] = None, + filepaths: Optional[List[list]] = None, delay: float = 0.0, timestamps: bool = False, ) -> List[rnnt_utils.Hypothesis]: @@ -390,7 +390,7 @@ def read_and_maybe_sort_manifest(path: str, try_sort: bool = False) -> List[dict return items -def restore_transcription_order(manifest_path: str, transcriptions: list) -> list: +def restore_transcription_order(manifest_path: str, transcriptions: List) -> Union[List, Tuple]: with open(manifest_path, encoding='utf-8') as f: items = [(idx, json.loads(l)) for idx, l in enumerate(f) if l.strip() != ""] if not all("duration" in item[1] and item[1]["duration"] is not None for item in items): @@ -422,7 +422,7 @@ def compute_output_filename(cfg: DictConfig, model_name: str) -> DictConfig: return cfg -def normalize_timestamp_output(timestamps: dict): +def normalize_timestamp_output(timestamps: Dict) -> Dict: """ Normalize the dictionary of timestamp values to JSON serializable values. Expects the following keys to exist - @@ -447,7 +447,7 @@ def write_transcription( transcriptions: Union[List[rnnt_utils.Hypothesis], List[List[rnnt_utils.Hypothesis]], List[str]], cfg: DictConfig, model_name: str, - filepaths: List[str] = None, + filepaths: Optional[List[str]] = None, compute_langs: bool = False, timestamps: bool = False, ) -> Tuple[str, str]: @@ -549,7 +549,7 @@ def compute_metrics_per_sample( hypothesis_field: str = "pred_text", metrics: List[str] = ["wer"], punctuation_marks: List[str] = [".", ",", "?"], - output_manifest_path: str = None, + output_manifest_path: Optional[str] = None, ) -> dict: ''' Computes metrics per sample for given manifest @@ -637,7 +637,7 @@ def compute_metrics_per_sample( class PunctuationCapitalization: - def __init__(self, punctuation_marks: str): + def __init__(self, punctuation_marks: str) -> None: """ Class for text processing with punctuation and capitalization. Can be used with class TextProcessingConfig.