From 64df657d385571529e183f9046b6b78e33761e93 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Fri, 1 May 2026 17:45:16 -0700 Subject: [PATCH 01/10] Audio data debugging Signed-off-by: Fejgin, Roy --- .../tts/data/audio_codec_dataset_lhotse.py | 99 ++++++++++++++++++- nemo/collections/tts/models/audio_codec.py | 3 +- 2 files changed, 100 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py index 08b9f5d6dc11..bcab53f19a13 100644 --- a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py +++ b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +import re +from pathlib import Path +from typing import Dict, Optional, Union +import soundfile as sf import torch from lhotse import CutSet from lhotse.dataset.collation import collate_audio from nemo.utils import logging +_SAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]+") +_DATASET_IN_SPEAKER = re.compile(r"(?:^|[\s|])Dataset:([^\s|]+)") + class AudioCodecLhotseDataset(torch.utils.data.Dataset): """ @@ -42,6 +48,9 @@ def __init__( sample_rate: int, sanity_check_audio: bool = False, min_samples_for_sanity: Optional[int] = None, + log_audio: bool = False, + log_audio_dir: Union[str, Path] = "logged_audio", + log_audio_num_batches: int = 3, ): """ Args: @@ -50,11 +59,98 @@ def __init__( min_samples_for_sanity: cuts should have at least this many samples or an error will be raised. Only used when `sanity_check_audio` is True. + log_audio: If True, save the original `target_audio` waveforms from + the first few batches before any dataset resampling. + log_audio_dir: Directory where debug wav files will be written. + log_audio_num_batches: Number of initial batches to log per dataset + instance or dataloader worker. """ super().__init__() self.sample_rate = sample_rate self.sanity_check_audio = sanity_check_audio self.min_samples_for_sanity = min_samples_for_sanity + self.log_audio = log_audio + self.log_audio_dir = Path(log_audio_dir) + self.log_audio_num_batches = log_audio_num_batches + self._logged_audio_batches = 0 + + def _maybe_log_audio(self, cuts: CutSet): + """ + Save original target_audio waveforms for the first few batches. + """ + if not self.log_audio or self._logged_audio_batches >= self.log_audio_num_batches: + return + + self._log_target_audio_without_resampling(cuts) + self._logged_audio_batches += 1 + + def _log_target_audio_without_resampling(self, cuts: CutSet): + """ + Save each cut's `target_audio` before `target_audio.resample()` is applied. + + This intentionally uses the custom `target_audio` recording and its own + sampling rate, not `self.sample_rate`. To keep the debug files trustworthy, + fail fast if the recording already has Lhotse audio transforms attached. + """ + self.log_audio_dir.mkdir(parents=True, exist_ok=True) + + for cut in cuts: + recording = cut.target_audio + transform_names = self._recording_transform_names(recording) + if transform_names: + raise RuntimeError( + "Cannot log untransformed target_audio because the recording " + f"already has Lhotse audio transforms attached: {transform_names}. " + f"cut_id={cut.id}, recording_id={recording.id}" + ) + + audio = cut.load_custom("target_audio") + speaker = getattr(cut.supervisions[0], "speaker", None) if cut.supervisions else None + filename = self._recording_id_to_wav_name(recording.id, recording.sampling_rate, speaker=speaker) + path = self.log_audio_dir / filename + sf.write(str(path), self._audio_for_soundfile(audio), samplerate=recording.sampling_rate) + logging.info( + f"Saved original target_audio for cut_id={cut.id}, recording_id={recording.id}, " + f"sampling_rate={recording.sampling_rate}, shape={audio.shape}, path={path}" + ) + + @staticmethod + def _recording_transform_names(recording) -> list[str]: + transform_names = [] + for transform in recording.transforms or []: + if isinstance(transform, dict): + transform_names.append(transform.get("name", str(transform))) + else: + transform_names.append(type(transform).__name__) + return transform_names + + @staticmethod + def _recording_id_to_wav_name(recording_id: str, sampling_rate: int, speaker: Optional[str] = None) -> str: + safe_id = _SAFE_FILENAME_CHARS.sub("_", str(recording_id)).strip("._") + if not safe_id: + safe_id = "recording" + dataset = AudioCodecLhotseDataset._dataset_from_speaker(speaker) + if dataset is not None: + safe_id = f"{dataset}_{safe_id}" + return f"{safe_id}_{sampling_rate}Hz.wav" + + @staticmethod + def _dataset_from_speaker(speaker: Optional[str]) -> Optional[str]: + if speaker is None: + return None + match = _DATASET_IN_SPEAKER.search(speaker) + if match is None: + return None + safe_dataset = _SAFE_FILENAME_CHARS.sub("_", match.group(1)).strip("._") + return safe_dataset or None + + @staticmethod + def _audio_for_soundfile(audio): + if audio.ndim == 1: + return audio + if audio.shape[0] == 1: + return audio[0] + return audio.T def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: """ @@ -65,6 +161,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: Returns: A dictionary with the `audio` and `audio_lens` tensors. """ + self._maybe_log_audio(cuts) # Resample the audio to the target sample rate. We need to do this manually # because Lhotse only resamples its standard `recording` field automatically, # not custom fields like `target_audio`. diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index c8e6dfec18a7..728e6489fb19 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -862,7 +862,8 @@ def _get_lhotse_dataloader(self, cfg): raise ValueError("`truncate_duration` must be set in the config") loader_cfg.truncate_offset_type = "random" # Also filter examples to be at least this long to avoid zero-padding - loader_cfg.min_duration = loader_cfg.truncate_duration + if loader_cfg.min_duration is None: + loader_cfg.min_duration = loader_cfg.truncate_duration # --- Create the dataset --- From 1d7423e082179e4141c70b200b54362d206ea42b Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Sat, 16 May 2026 18:55:20 -0700 Subject: [PATCH 02/10] Fix Lhotse handling when recording and target_audio are of different length Signed-off-by: Fejgin, Roy --- .../tts/data/audio_codec_dataset_lhotse.py | 57 ++++++++++++++----- nemo/collections/tts/models/audio_codec.py | 29 ++++++---- 2 files changed, 60 insertions(+), 26 deletions(-) diff --git a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py index bcab53f19a13..f0de5a2b5158 100644 --- a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py +++ b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random import re from pathlib import Path from typing import Dict, Optional, Union +import numpy as np import soundfile as sf import torch from lhotse import CutSet -from lhotse.dataset.collation import collate_audio from nemo.utils import logging @@ -34,18 +35,20 @@ class AudioCodecLhotseDataset(torch.utils.data.Dataset): It is a simple dataset that mostly just loads the audio samples. In addition, it performs the following operations: * Resampling to the target sample rate + * Random truncation of each cut's `target_audio` to a fixed duration * Sanity checks on the audio The operations below are handled directly by Lhotse according to the configuration applied in `AudioCodecModel._get_lhotse_dataloader()`: * Duration filtering * Any additional transformations configured in Lhotse during its construction are - applied to the audio as it is loaded in `collate_audio()`. + applied to the audio as it is loaded in `load_audio()`. """ def __init__( self, sample_rate: int, + truncate_duration: float, sanity_check_audio: bool = False, min_samples_for_sanity: Optional[int] = None, log_audio: bool = False, @@ -55,6 +58,9 @@ def __init__( """ Args: sample_rate: The sample rate to resample the audio to. + truncate_duration: Length of each training window in seconds. A random + window of this length is taken from each cut's `target_audio` field + (not from the parent `recording`, which may span a much longer file). sanity_check_audio: If True, perform sanity checks on the loaded audio. min_samples_for_sanity: cuts should have at least this many samples or an error will be raised. Only used when @@ -67,6 +73,8 @@ def __init__( """ super().__init__() self.sample_rate = sample_rate + self.truncate_duration = truncate_duration + self.truncate_samples = int(truncate_duration * sample_rate) self.sanity_check_audio = sanity_check_audio self.min_samples_for_sanity = min_samples_for_sanity self.log_audio = log_audio @@ -152,6 +160,31 @@ def _audio_for_soundfile(audio): return audio[0] return audio.T + def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: + """ + Load `target_audio`, resample, and return a random segmentof length `truncate_duration`. + """ + if not cut.has_custom("target_audio"): + raise ValueError(f"Cut {cut.id} is missing custom field 'target_audio'") + + target_audio_recording = cut.target_audio.resample(self.sample_rate) + audio = target_audio_recording.load_audio() + if audio.ndim > 1: + audio = audio.squeeze(0) + + num_samples = audio.shape[-1] + if num_samples < self.truncate_samples: + raise ValueError( + f"target_audio is shorter than truncate_duration: " + f"cut_id={cut.id}, target_audio_id={target_audio_recording.id}, " + f"num_samples={num_samples}, required={self.truncate_samples}, " + f"truncate_duration={self.truncate_duration}s" + ) + + start = random.randint(0, num_samples - self.truncate_samples) + window = audio[start : start + self.truncate_samples] + return torch.from_numpy(np.ascontiguousarray(window, dtype=np.float32)) + def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: """ Loads the specified cuts and performs the operations listed above. @@ -162,19 +195,15 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: A dictionary with the `audio` and `audio_lens` tensors. """ self._maybe_log_audio(cuts) - # Resample the audio to the target sample rate. We need to do this manually - # because Lhotse only resamples its standard `recording` field automatically, - # not custom fields like `target_audio`. - for cut in cuts: - cut.target_audio = cut.target_audio.resample(self.sample_rate) - - # Load and collate the audio, applying any transformations that were - # configured in Lhotse in the process. - # Note: fault_tolerant=False for now to avoid masking errors until we are more - # confident in the new loader. - batch_audio, batch_audio_len = collate_audio(cuts, recording_field="target_audio", fault_tolerant=False) + # Load, resample and truncate the audio + audio_list = [self._load_and_truncate_target_audio(cut) for cut in cuts] + batch_audio = torch.stack(audio_list, dim=0) + batch_audio_len = torch.full( + (len(audio_list),), + self.truncate_samples, + dtype=torch.int32, + ) - # Sanity checks on the audio and its length if self.sanity_check_audio: self._sanity_check_audio(batch_audio, batch_audio_len, cuts) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 728e6489fb19..639ab410f9a1 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -854,24 +854,29 @@ def _get_lhotse_dataloader(self, cfg): # manually in the dataset class. loader_cfg.sample_rate = self.output_sample_rate - # Set up cut truncation, filtering, and random selection: - # `truncate_duration` and `truncate_offset_type` are interpreted by Lhotse. - # Together, they configure Lhotse to choose a random segment of this length - # from each cut. - if loader_cfg.truncate_duration is None: - raise ValueError("`truncate_duration` must be set in the config") - loader_cfg.truncate_offset_type = "random" - # Also filter examples to be at least this long to avoid zero-padding + # Random windowing is done in AudioCodecLhotseDataset on `target_audio`, not via + # Lhotse's cuts.truncate (which operates on the parent recording coordinates). + truncate_duration = dataset_args.get("truncate_duration") + if truncate_duration is None: + raise ValueError("`truncate_duration` must be set in `train_ds.dataset_args` ") + if cfg.dataloader_params.get("truncate_duration") is not None: + raise ValueError( + "`truncate_duration` must not be set in `train_ds.dataloader_params`; " + "set it in `train_ds.dataset_args` instead." + ) + loader_cfg.truncate_duration = None + # Pre-filter cuts whose parent recording is shorter than the training window. if loader_cfg.min_duration is None: - loader_cfg.min_duration = loader_cfg.truncate_duration + loader_cfg.min_duration = truncate_duration # --- Create the dataset --- # Error out if the audio is suspiciously short (half the expected length) - min_samples_for_sanity = loader_cfg.truncate_duration * self.output_sample_rate // 2 - # Create the dataset + min_samples_for_sanity = truncate_duration * self.output_sample_rate // 2 dataset = AudioCodecLhotseDataset( - sample_rate=self.output_sample_rate, min_samples_for_sanity=min_samples_for_sanity, **dataset_args + sample_rate=self.output_sample_rate, + min_samples_for_sanity=min_samples_for_sanity, + **dataset_args, ) # Create the dataloader From 84334c0ba3d20941aeeea9b620a0872a83c030c2 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Sat, 16 May 2026 19:28:04 -0700 Subject: [PATCH 03/10] Fix duration filtering Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/audio_codec.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 639ab410f9a1..c492d4c83e14 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -866,8 +866,13 @@ def _get_lhotse_dataloader(self, cfg): ) loader_cfg.truncate_duration = None # Pre-filter cuts whose parent recording is shorter than the training window. - if loader_cfg.min_duration is None: - loader_cfg.min_duration = truncate_duration + existing_min_duration = cfg.dataloader_params.get("min_duration") + if existing_min_duration is not None and existing_min_duration != -1: + raise ValueError( + "`min_duration` must not be set in `train_ds.dataloader_params` " + "it is set automatically from `train_ds.dataset_args.truncate_duration`." + ) + loader_cfg.min_duration = truncate_duration + 0.01 # add a bit to allow for resampling length mismatch # --- Create the dataset --- From 9e6b6f31d6f528a6540659c2ec42776be9cc5ae7 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Wed, 20 May 2026 20:51:59 -0700 Subject: [PATCH 04/10] Fix batch duration Need to set truncate duration in both the dataset and the loader otherwise Lhotse's tracking of batch duration will be incorrect. Signed-off-by: Fejgin, Roy --- .../tts/data/audio_codec_dataset_lhotse.py | 2 +- nemo/collections/tts/models/audio_codec.py | 15 +++++++++------ 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py index f0de5a2b5158..9b9d39b5e576 100644 --- a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py +++ b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py @@ -162,7 +162,7 @@ def _audio_for_soundfile(audio): def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: """ - Load `target_audio`, resample, and return a random segmentof length `truncate_duration`. + Load `target_audio`, resample, and return a random segment of length `truncate_duration`. """ if not cut.has_custom("target_audio"): raise ValueError(f"Cut {cut.id} is missing custom field 'target_audio'") diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index c492d4c83e14..04d74013ebd0 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -855,16 +855,18 @@ def _get_lhotse_dataloader(self, cfg): loader_cfg.sample_rate = self.output_sample_rate # Random windowing is done in AudioCodecLhotseDataset on `target_audio`, not via - # Lhotse's cuts.truncate (which operates on the parent recording coordinates). + # Lhotse's cuts.truncate (which operates on the parent recording). + # However, to ensure the total batch duration is correct, we need to also set + # the truncate duration in the loader configuration. truncate_duration = dataset_args.get("truncate_duration") if truncate_duration is None: raise ValueError("`truncate_duration` must be set in `train_ds.dataset_args` ") - if cfg.dataloader_params.get("truncate_duration") is not None: + dataloader_truncate_duration = cfg.dataloader_params.get("truncate_duration") + if dataloader_truncate_duration is not None and dataloader_truncate_duration != truncate_duration: raise ValueError( - "`truncate_duration` must not be set in `train_ds.dataloader_params`; " - "set it in `train_ds.dataset_args` instead." + "`truncate_duration` in `train_ds.dataloader_params` must be set to the same value as `train_ds.dataset_args.truncate_duration`." ) - loader_cfg.truncate_duration = None + loader_cfg.truncate_duration = truncate_duration # Pre-filter cuts whose parent recording is shorter than the training window. existing_min_duration = cfg.dataloader_params.get("min_duration") if existing_min_duration is not None and existing_min_duration != -1: @@ -872,7 +874,8 @@ def _get_lhotse_dataloader(self, cfg): "`min_duration` must not be set in `train_ds.dataloader_params` " "it is set automatically from `train_ds.dataset_args.truncate_duration`." ) - loader_cfg.min_duration = truncate_duration + 0.01 # add a bit to allow for resampling length mismatch + # random truncation of the audio + loader_cfg.min_duration = truncate_duration # + 0.01 # add a bit to allow for resampling length mismatch # --- Create the dataset --- From 79cd1533298963fb07ba0e4a85f5dcb2979da361 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Mon, 1 Jun 2026 16:42:12 -0700 Subject: [PATCH 05/10] Switch from `truncate_duration` to `batch_size` Signed-off-by: Fejgin, Roy --- .../tts/data/audio_codec_dataset_lhotse.py | 43 +++++++++---------- nemo/collections/tts/models/audio_codec.py | 28 +++++------- .../data/test_audio_codec_dataset_lhotse.py | 6 +-- 3 files changed, 34 insertions(+), 43 deletions(-) diff --git a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py index 9b9d39b5e576..8c280f10daca 100644 --- a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py +++ b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py @@ -40,7 +40,7 @@ class AudioCodecLhotseDataset(torch.utils.data.Dataset): The operations below are handled directly by Lhotse according to the configuration applied in `AudioCodecModel._get_lhotse_dataloader()`: - * Duration filtering + * Minimum duration filtering * Any additional transformations configured in Lhotse during its construction are applied to the audio as it is loaded in `load_audio()`. """ @@ -48,9 +48,8 @@ class AudioCodecLhotseDataset(torch.utils.data.Dataset): def __init__( self, sample_rate: int, - truncate_duration: float, + segment_duration: float, sanity_check_audio: bool = False, - min_samples_for_sanity: Optional[int] = None, log_audio: bool = False, log_audio_dir: Union[str, Path] = "logged_audio", log_audio_num_batches: int = 3, @@ -58,13 +57,10 @@ def __init__( """ Args: sample_rate: The sample rate to resample the audio to. - truncate_duration: Length of each training window in seconds. A random - window of this length is taken from each cut's `target_audio` field - (not from the parent `recording`, which may span a much longer file). + segment_duration: Length of each training segment in seconds. A random + segment of this length is taken from each cut's `target_audio` field + (not from the parent `recording`, which may span a much longer duration). sanity_check_audio: If True, perform sanity checks on the loaded audio. - min_samples_for_sanity: cuts should have at least this many samples or an - error will be raised. Only used when - `sanity_check_audio` is True. log_audio: If True, save the original `target_audio` waveforms from the first few batches before any dataset resampling. log_audio_dir: Directory where debug wav files will be written. @@ -73,10 +69,11 @@ def __init__( """ super().__init__() self.sample_rate = sample_rate - self.truncate_duration = truncate_duration - self.truncate_samples = int(truncate_duration * sample_rate) + self.segment_duration = segment_duration + self.segment_samples = int(segment_duration * sample_rate) self.sanity_check_audio = sanity_check_audio - self.min_samples_for_sanity = min_samples_for_sanity + # Error out if audio is suspiciously short (leaving some slack for resampling). + self.min_samples_for_sanity = max(1, self.segment_samples - 5) self.log_audio = log_audio self.log_audio_dir = Path(log_audio_dir) self.log_audio_num_batches = log_audio_num_batches @@ -162,7 +159,7 @@ def _audio_for_soundfile(audio): def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: """ - Load `target_audio`, resample, and return a random segment of length `truncate_duration`. + Load `target_audio`, resample, and return a random segment of length `segment_duration`. """ if not cut.has_custom("target_audio"): raise ValueError(f"Cut {cut.id} is missing custom field 'target_audio'") @@ -173,17 +170,17 @@ def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: audio = audio.squeeze(0) num_samples = audio.shape[-1] - if num_samples < self.truncate_samples: + if num_samples < self.segment_samples: raise ValueError( - f"target_audio is shorter than truncate_duration: " + f"target_audio is shorter than segment_duration: " f"cut_id={cut.id}, target_audio_id={target_audio_recording.id}, " - f"num_samples={num_samples}, required={self.truncate_samples}, " - f"truncate_duration={self.truncate_duration}s" + f"num_samples={num_samples}, required={self.segment_samples}, " + f"segment_duration={self.segment_duration}s" ) - - start = random.randint(0, num_samples - self.truncate_samples) - window = audio[start : start + self.truncate_samples] - return torch.from_numpy(np.ascontiguousarray(window, dtype=np.float32)) + # Randomly select a segment of the audio of length `segment_duration`. + start = random.randint(0, num_samples - self.segment_samples) + segment = audio[start : start + self.segment_samples] + return torch.from_numpy(np.ascontiguousarray(segment, dtype=np.float32)) def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: """ @@ -200,7 +197,7 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: batch_audio = torch.stack(audio_list, dim=0) batch_audio_len = torch.full( (len(audio_list),), - self.truncate_samples, + self.segment_samples, dtype=torch.int32, ) @@ -221,7 +218,7 @@ def _sanity_check_audio(self, audio: torch.Tensor, audio_len: torch.Tensor, cuts # --- Error cases --- # Audio length is unexpectedly short - if self.min_samples_for_sanity is not None and audio_len.min() < self.min_samples_for_sanity: + if audio_len.min() < self.min_samples_for_sanity: raise ValueError( f"Audio length is less than {self.min_samples_for_sanity} samples (min: {audio_len.min()})" ) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 04d74013ebd0..4a1c06764796 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -856,34 +856,28 @@ def _get_lhotse_dataloader(self, cfg): # Random windowing is done in AudioCodecLhotseDataset on `target_audio`, not via # Lhotse's cuts.truncate (which operates on the parent recording). - # However, to ensure the total batch duration is correct, we need to also set - # the truncate duration in the loader configuration. - truncate_duration = dataset_args.get("truncate_duration") - if truncate_duration is None: - raise ValueError("`truncate_duration` must be set in `train_ds.dataset_args` ") - dataloader_truncate_duration = cfg.dataloader_params.get("truncate_duration") - if dataloader_truncate_duration is not None and dataloader_truncate_duration != truncate_duration: + segment_duration = dataset_args.get("segment_duration") + if segment_duration is None: + raise ValueError("`segment_duration` must be set in `train_ds.dataset_args` ") + if cfg.dataloader_params.get("truncate_duration") is not None: raise ValueError( - "`truncate_duration` in `train_ds.dataloader_params` must be set to the same value as `train_ds.dataset_args.truncate_duration`." + "`truncate_duration` must not be set in `train_ds.dataloader_params`; " + "segment extraction is handled in `AudioCodecLhotseDataset` via `segment_duration`." ) - loader_cfg.truncate_duration = truncate_duration - # Pre-filter cuts whose parent recording is shorter than the training window. existing_min_duration = cfg.dataloader_params.get("min_duration") if existing_min_duration is not None and existing_min_duration != -1: raise ValueError( - "`min_duration` must not be set in `train_ds.dataloader_params` " - "it is set automatically from `train_ds.dataset_args.truncate_duration`." + "`min_duration` must not be set in `train_ds.dataloader_params`; " + "it is set automatically from `train_ds.dataset_args.segment_duration`." ) - # random truncation of the audio - loader_cfg.min_duration = truncate_duration # + 0.01 # add a bit to allow for resampling length mismatch + # Pre-filter to only include cuts whose parent recording is at least as long as + # the training segment duration so the dataset class has enough samples to choose from. + loader_cfg.min_duration = segment_duration # --- Create the dataset --- - # Error out if the audio is suspiciously short (half the expected length) - min_samples_for_sanity = truncate_duration * self.output_sample_rate // 2 dataset = AudioCodecLhotseDataset( sample_rate=self.output_sample_rate, - min_samples_for_sanity=min_samples_for_sanity, **dataset_args, ) diff --git a/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py b/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py index b8da3cfbfcff..cd4e8850ae4a 100644 --- a/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py +++ b/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py @@ -88,16 +88,16 @@ def cutset(tmp_path) -> CutSet: def dataset() -> AudioCodecLhotseDataset: return AudioCodecLhotseDataset( sample_rate=TARGET_SAMPLE_RATE, - min_samples_for_sanity=DEFAULT_DURATION * TARGET_SAMPLE_RATE, + segment_duration=DEFAULT_DURATION, ) class TestAudioCodecLhotseDataset: @pytest.mark.unit def test_init(self): - ds = AudioCodecLhotseDataset(sample_rate=22050, min_samples_for_sanity=512) + ds = AudioCodecLhotseDataset(sample_rate=22050, segment_duration=1.0) assert ds.sample_rate == 22050 - assert ds.min_samples_for_sanity == 512 + assert ds.min_samples_for_sanity == 22050 - 5 @pytest.mark.unit def test_getitem_returns_expected_keys_and_shapes(self, dataset, cutset): From c9ac8c943cd81a8dba2e2b991f5d3dd8455d498f Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Mon, 1 Jun 2026 16:50:28 -0700 Subject: [PATCH 06/10] Cleanup Signed-off-by: Fejgin, Roy --- .../tts/data/audio_codec_dataset_lhotse.py | 99 +------------------ 1 file changed, 1 insertion(+), 98 deletions(-) diff --git a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py index 8c280f10daca..96ff54bf2bb5 100644 --- a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py +++ b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py @@ -13,20 +13,14 @@ # limitations under the License. import random -import re -from pathlib import Path -from typing import Dict, Optional, Union +from typing import Dict import numpy as np -import soundfile as sf import torch from lhotse import CutSet from nemo.utils import logging -_SAFE_FILENAME_CHARS = re.compile(r"[^A-Za-z0-9._-]+") -_DATASET_IN_SPEAKER = re.compile(r"(?:^|[\s|])Dataset:([^\s|]+)") - class AudioCodecLhotseDataset(torch.utils.data.Dataset): """ @@ -50,9 +44,6 @@ def __init__( sample_rate: int, segment_duration: float, sanity_check_audio: bool = False, - log_audio: bool = False, - log_audio_dir: Union[str, Path] = "logged_audio", - log_audio_num_batches: int = 3, ): """ Args: @@ -61,11 +52,6 @@ def __init__( segment of this length is taken from each cut's `target_audio` field (not from the parent `recording`, which may span a much longer duration). sanity_check_audio: If True, perform sanity checks on the loaded audio. - log_audio: If True, save the original `target_audio` waveforms from - the first few batches before any dataset resampling. - log_audio_dir: Directory where debug wav files will be written. - log_audio_num_batches: Number of initial batches to log per dataset - instance or dataloader worker. """ super().__init__() self.sample_rate = sample_rate @@ -74,88 +60,6 @@ def __init__( self.sanity_check_audio = sanity_check_audio # Error out if audio is suspiciously short (leaving some slack for resampling). self.min_samples_for_sanity = max(1, self.segment_samples - 5) - self.log_audio = log_audio - self.log_audio_dir = Path(log_audio_dir) - self.log_audio_num_batches = log_audio_num_batches - self._logged_audio_batches = 0 - - def _maybe_log_audio(self, cuts: CutSet): - """ - Save original target_audio waveforms for the first few batches. - """ - if not self.log_audio or self._logged_audio_batches >= self.log_audio_num_batches: - return - - self._log_target_audio_without_resampling(cuts) - self._logged_audio_batches += 1 - - def _log_target_audio_without_resampling(self, cuts: CutSet): - """ - Save each cut's `target_audio` before `target_audio.resample()` is applied. - - This intentionally uses the custom `target_audio` recording and its own - sampling rate, not `self.sample_rate`. To keep the debug files trustworthy, - fail fast if the recording already has Lhotse audio transforms attached. - """ - self.log_audio_dir.mkdir(parents=True, exist_ok=True) - - for cut in cuts: - recording = cut.target_audio - transform_names = self._recording_transform_names(recording) - if transform_names: - raise RuntimeError( - "Cannot log untransformed target_audio because the recording " - f"already has Lhotse audio transforms attached: {transform_names}. " - f"cut_id={cut.id}, recording_id={recording.id}" - ) - - audio = cut.load_custom("target_audio") - speaker = getattr(cut.supervisions[0], "speaker", None) if cut.supervisions else None - filename = self._recording_id_to_wav_name(recording.id, recording.sampling_rate, speaker=speaker) - path = self.log_audio_dir / filename - sf.write(str(path), self._audio_for_soundfile(audio), samplerate=recording.sampling_rate) - logging.info( - f"Saved original target_audio for cut_id={cut.id}, recording_id={recording.id}, " - f"sampling_rate={recording.sampling_rate}, shape={audio.shape}, path={path}" - ) - - @staticmethod - def _recording_transform_names(recording) -> list[str]: - transform_names = [] - for transform in recording.transforms or []: - if isinstance(transform, dict): - transform_names.append(transform.get("name", str(transform))) - else: - transform_names.append(type(transform).__name__) - return transform_names - - @staticmethod - def _recording_id_to_wav_name(recording_id: str, sampling_rate: int, speaker: Optional[str] = None) -> str: - safe_id = _SAFE_FILENAME_CHARS.sub("_", str(recording_id)).strip("._") - if not safe_id: - safe_id = "recording" - dataset = AudioCodecLhotseDataset._dataset_from_speaker(speaker) - if dataset is not None: - safe_id = f"{dataset}_{safe_id}" - return f"{safe_id}_{sampling_rate}Hz.wav" - - @staticmethod - def _dataset_from_speaker(speaker: Optional[str]) -> Optional[str]: - if speaker is None: - return None - match = _DATASET_IN_SPEAKER.search(speaker) - if match is None: - return None - safe_dataset = _SAFE_FILENAME_CHARS.sub("_", match.group(1)).strip("._") - return safe_dataset or None - - @staticmethod - def _audio_for_soundfile(audio): - if audio.ndim == 1: - return audio - if audio.shape[0] == 1: - return audio[0] - return audio.T def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: """ @@ -191,7 +95,6 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: Returns: A dictionary with the `audio` and `audio_lens` tensors. """ - self._maybe_log_audio(cuts) # Load, resample and truncate the audio audio_list = [self._load_and_truncate_target_audio(cut) for cut in cuts] batch_audio = torch.stack(audio_list, dim=0) From 7032ce63a8f56e7b9df4c24429924437e789bbe1 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Mon, 1 Jun 2026 17:43:44 -0700 Subject: [PATCH 07/10] Add test for segment extraction Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/audio_codec.py | 2 +- .../data/test_audio_codec_dataset_lhotse.py | 29 +++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 4a1c06764796..9e4290a80d0e 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -854,7 +854,7 @@ def _get_lhotse_dataloader(self, cfg): # manually in the dataset class. loader_cfg.sample_rate = self.output_sample_rate - # Random windowing is done in AudioCodecLhotseDataset on `target_audio`, not via + # Random segment selection is done in AudioCodecLhotseDataset on `target_audio`, not via # Lhotse's cuts.truncate (which operates on the parent recording). segment_duration = dataset_args.get("segment_duration") if segment_duration is None: diff --git a/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py b/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py index cd4e8850ae4a..1d61697fabb7 100644 --- a/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py +++ b/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py @@ -137,3 +137,32 @@ def test_getitem_resampling_preserves_frequency(self, dataset, cutset): # FFT bin width is TARGET_SAMPLE_RATE / n; allow ~1 bin of tolerance. bin_width_hz = TARGET_SAMPLE_RATE / n assert abs(peak_freq_hz - cutset[i].target_tone_frequency) <= bin_width_hz + + @pytest.mark.unit + def test_getitem_extracts_subset_of_longer_audio(self, tmp_path, dataset, monkeypatch): + # A cut longer than segment_duration should yield a segment of exactly segment_samples + # that is a contiguous slice taken from inside the longer source signal. + # Use the target sample rate as the source rate so no resampling is involved. + cut = _make_cut(tmp_path, "long", duration=3.0, sample_rate=TARGET_SAMPLE_RATE, tone_frequency=440.0) + cuts = CutSet.from_cuts([cut]) + + # Load the full target audio the same way the dataset does (no resampling needed). + full = cut.target_audio.load_audio().squeeze(0) + segment_samples = int(DEFAULT_DURATION * TARGET_SAMPLE_RATE) + + # Pin the random start so we can compare against the exact source slice. + fixed_start = 7000 + monkeypatch.setattr( + "nemo.collections.tts.data.audio_codec_dataset_lhotse.random.randint", + lambda low, high: fixed_start, + ) + + batch = dataset[cuts] + segment = batch["audio"][0].numpy() + + assert segment.shape == (segment_samples,) + assert batch["audio_lens"][0].item() == segment_samples + # Exact match holds only because source rate == target rate (no resampling) and the + # dataset currently applies no augmentation. Once we add augmentation it makes + # sense to remove this assertion. + assert np.allclose(segment, full[fixed_start : fixed_start + segment_samples]) From fb6d33a0be9b808b2ca03d4e7693863f8a093846 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Mon, 1 Jun 2026 22:44:25 -0700 Subject: [PATCH 08/10] Lhotse config validation Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/audio_codec.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 9e4290a80d0e..1d3743708950 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -855,7 +855,7 @@ def _get_lhotse_dataloader(self, cfg): loader_cfg.sample_rate = self.output_sample_rate # Random segment selection is done in AudioCodecLhotseDataset on `target_audio`, not via - # Lhotse's cuts.truncate (which operates on the parent recording). + # Lhotse's `truncate_duration` config (which operates on the parent recording). segment_duration = dataset_args.get("segment_duration") if segment_duration is None: raise ValueError("`segment_duration` must be set in `train_ds.dataset_args` ") @@ -874,6 +874,10 @@ def _get_lhotse_dataloader(self, cfg): # the training segment duration so the dataset class has enough samples to choose from. loader_cfg.min_duration = segment_duration + # Make sure batch_size is set + if loader_cfg.batch_size is None: + raise ValueError("`batch_size` must be set in `train_ds.dataloader_params`.") + # --- Create the dataset --- dataset = AudioCodecLhotseDataset( From 5279ed79492a319c1d1f3cdaf9d35fa738209030 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Tue, 2 Jun 2026 12:31:28 -0700 Subject: [PATCH 09/10] Comments Signed-off-by: Fejgin, Roy --- nemo/collections/tts/data/audio_codec_dataset_lhotse.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py index 96ff54bf2bb5..b0ab98e7cc77 100644 --- a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py +++ b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py @@ -69,6 +69,7 @@ def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: raise ValueError(f"Cut {cut.id} is missing custom field 'target_audio'") target_audio_recording = cut.target_audio.resample(self.sample_rate) + # Load the target audio, resampling and applying and Lhotse transformation in the process audio = target_audio_recording.load_audio() if audio.ndim > 1: audio = audio.squeeze(0) @@ -81,7 +82,8 @@ def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: f"num_samples={num_samples}, required={self.segment_samples}, " f"segment_duration={self.segment_duration}s" ) - # Randomly select a segment of the audio of length `segment_duration`. + + # Randomly select a segment of the audio start = random.randint(0, num_samples - self.segment_samples) segment = audio[start : start + self.segment_samples] return torch.from_numpy(np.ascontiguousarray(segment, dtype=np.float32)) From c0e99df2541efe83fbe804aae49afa8d34b503a2 Mon Sep 17 00:00:00 2001 From: "Fejgin, Roy" Date: Tue, 2 Jun 2026 12:33:44 -0700 Subject: [PATCH 10/10] Move code around for clarity Signed-off-by: Fejgin, Roy --- nemo/collections/tts/models/audio_codec.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index 1d3743708950..8b2996e03288 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -856,14 +856,14 @@ def _get_lhotse_dataloader(self, cfg): # Random segment selection is done in AudioCodecLhotseDataset on `target_audio`, not via # Lhotse's `truncate_duration` config (which operates on the parent recording). - segment_duration = dataset_args.get("segment_duration") - if segment_duration is None: - raise ValueError("`segment_duration` must be set in `train_ds.dataset_args` ") if cfg.dataloader_params.get("truncate_duration") is not None: raise ValueError( "`truncate_duration` must not be set in `train_ds.dataloader_params`; " "segment extraction is handled in `AudioCodecLhotseDataset` via `segment_duration`." ) + segment_duration = dataset_args.get("segment_duration") + if segment_duration is None: + raise ValueError("`segment_duration` must be set in `train_ds.dataset_args` ") existing_min_duration = cfg.dataloader_params.get("min_duration") if existing_min_duration is not None and existing_min_duration != -1: raise ValueError(