diff --git a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py index 08b9f5d6dc11..b0ab98e7cc77 100644 --- a/nemo/collections/tts/data/audio_codec_dataset_lhotse.py +++ b/nemo/collections/tts/data/audio_codec_dataset_lhotse.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional +import random +from typing import Dict +import numpy as np import torch from lhotse import CutSet -from lhotse.dataset.collation import collate_audio from nemo.utils import logging @@ -28,33 +29,64 @@ class AudioCodecLhotseDataset(torch.utils.data.Dataset): It is a simple dataset that mostly just loads the audio samples. In addition, it performs the following operations: * Resampling to the target sample rate + * Random truncation of each cut's `target_audio` to a fixed duration * Sanity checks on the audio The operations below are handled directly by Lhotse according to the configuration applied in `AudioCodecModel._get_lhotse_dataloader()`: - * Duration filtering + * Minimum duration filtering * Any additional transformations configured in Lhotse during its construction are - applied to the audio as it is loaded in `collate_audio()`. + applied to the audio as it is loaded in `load_audio()`. """ def __init__( self, sample_rate: int, + segment_duration: float, sanity_check_audio: bool = False, - min_samples_for_sanity: Optional[int] = None, ): """ Args: sample_rate: The sample rate to resample the audio to. + segment_duration: Length of each training segment in seconds. A random + segment of this length is taken from each cut's `target_audio` field + (not from the parent `recording`, which may span a much longer duration). sanity_check_audio: If True, perform sanity checks on the loaded audio. - min_samples_for_sanity: cuts should have at least this many samples or an - error will be raised. 
Only used when - `sanity_check_audio` is True. """ super().__init__() self.sample_rate = sample_rate + self.segment_duration = segment_duration + self.segment_samples = int(segment_duration * sample_rate) self.sanity_check_audio = sanity_check_audio - self.min_samples_for_sanity = min_samples_for_sanity + # Error out if audio is suspiciously short (leaving some slack for resampling). + self.min_samples_for_sanity = max(1, self.segment_samples - 5) + + def _load_and_truncate_target_audio(self, cut) -> torch.Tensor: + """ + Load `target_audio`, resample, and return a random segment of length `segment_duration`. + """ + if not cut.has_custom("target_audio"): + raise ValueError(f"Cut {cut.id} is missing custom field 'target_audio'") + + target_audio_recording = cut.target_audio.resample(self.sample_rate) + # Load the target audio, resampling and applying any Lhotse transformation in the process + audio = target_audio_recording.load_audio() + if audio.ndim > 1: + audio = audio.squeeze(0) + + num_samples = audio.shape[-1] + if num_samples < self.segment_samples: + raise ValueError( + f"target_audio is shorter than segment_duration: " + f"cut_id={cut.id}, target_audio_id={target_audio_recording.id}, " + f"num_samples={num_samples}, required={self.segment_samples}, " + f"segment_duration={self.segment_duration}s" + ) + + # Randomly select a segment of the audio + start = random.randint(0, num_samples - self.segment_samples) + segment = audio[start : start + self.segment_samples] + return torch.from_numpy(np.ascontiguousarray(segment, dtype=np.float32)) def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: """ @@ -65,19 +97,15 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]: Returns: A dictionary with the `audio` and `audio_lens` tensors. """ - # Resample the audio to the target sample rate. We need to do this manually - # because Lhotse only resamples its standard `recording` field automatically, - # not custom fields like `target_audio`. 
- for cut in cuts: - cut.target_audio = cut.target_audio.resample(self.sample_rate) - - # Load and collate the audio, applying any transformations that were - # configured in Lhotse in the process. - # Note: fault_tolerant=False for now to avoid masking errors until we are more - # confident in the new loader. - batch_audio, batch_audio_len = collate_audio(cuts, recording_field="target_audio", fault_tolerant=False) - - # Sanity checks on the audio and its length + # Load, resample and truncate the audio + audio_list = [self._load_and_truncate_target_audio(cut) for cut in cuts] + batch_audio = torch.stack(audio_list, dim=0) + batch_audio_len = torch.full( + (len(audio_list),), + self.segment_samples, + dtype=torch.int32, + ) + if self.sanity_check_audio: self._sanity_check_audio(batch_audio, batch_audio_len, cuts) @@ -95,7 +123,7 @@ def _sanity_check_audio(self, audio: torch.Tensor, audio_len: torch.Tensor, cuts # --- Error cases --- # Audio length is unexpectedly short - if self.min_samples_for_sanity is not None and audio_len.min() < self.min_samples_for_sanity: + if audio_len.min() < self.min_samples_for_sanity: raise ValueError( f"Audio length is less than {self.min_samples_for_sanity} samples (min: {audio_len.min()})" ) diff --git a/nemo/collections/tts/models/audio_codec.py b/nemo/collections/tts/models/audio_codec.py index c8e6dfec18a7..8b2996e03288 100644 --- a/nemo/collections/tts/models/audio_codec.py +++ b/nemo/collections/tts/models/audio_codec.py @@ -854,23 +854,35 @@ def _get_lhotse_dataloader(self, cfg): # manually in the dataset class. loader_cfg.sample_rate = self.output_sample_rate - # Set up cut truncation, filtering, and random selection: - # `truncate_duration` and `truncate_offset_type` are interpreted by Lhotse. - # Together, they configure Lhotse to choose a random segment of this length - # from each cut. 
- if loader_cfg.truncate_duration is None: - raise ValueError("`truncate_duration` must be set in the config") - loader_cfg.truncate_offset_type = "random" - # Also filter examples to be at least this long to avoid zero-padding - loader_cfg.min_duration = loader_cfg.truncate_duration + # Random segment selection is done in AudioCodecLhotseDataset on `target_audio`, not via + # Lhotse's `truncate_duration` config (which operates on the parent recording). + if cfg.dataloader_params.get("truncate_duration") is not None: + raise ValueError( + "`truncate_duration` must not be set in `train_ds.dataloader_params`; " + "segment extraction is handled in `AudioCodecLhotseDataset` via `segment_duration`." + ) + segment_duration = dataset_args.get("segment_duration") + if segment_duration is None: + raise ValueError("`segment_duration` must be set in `train_ds.dataset_args` ") + existing_min_duration = cfg.dataloader_params.get("min_duration") + if existing_min_duration is not None and existing_min_duration != -1: + raise ValueError( + "`min_duration` must not be set in `train_ds.dataloader_params`; " + "it is set automatically from `train_ds.dataset_args.segment_duration`." + ) + # Pre-filter to only include cuts whose parent recording is at least as long as + # the training segment duration so the dataset class has enough samples to choose from. 
+ loader_cfg.min_duration = segment_duration + + # Make sure batch_size is set + if loader_cfg.batch_size is None: + raise ValueError("`batch_size` must be set in `train_ds.dataloader_params`.") # --- Create the dataset --- - # Error out if the audio is suspiciously short (half the expected length) - min_samples_for_sanity = loader_cfg.truncate_duration * self.output_sample_rate // 2 - # Create the dataset dataset = AudioCodecLhotseDataset( - sample_rate=self.output_sample_rate, min_samples_for_sanity=min_samples_for_sanity, **dataset_args + sample_rate=self.output_sample_rate, + **dataset_args, ) # Create the dataloader diff --git a/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py b/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py index b8da3cfbfcff..1d61697fabb7 100644 --- a/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py +++ b/tests/collections/tts/data/test_audio_codec_dataset_lhotse.py @@ -88,16 +88,16 @@ def cutset(tmp_path) -> CutSet: def dataset() -> AudioCodecLhotseDataset: return AudioCodecLhotseDataset( sample_rate=TARGET_SAMPLE_RATE, - min_samples_for_sanity=DEFAULT_DURATION * TARGET_SAMPLE_RATE, + segment_duration=DEFAULT_DURATION, ) class TestAudioCodecLhotseDataset: @pytest.mark.unit def test_init(self): - ds = AudioCodecLhotseDataset(sample_rate=22050, min_samples_for_sanity=512) + ds = AudioCodecLhotseDataset(sample_rate=22050, segment_duration=1.0) assert ds.sample_rate == 22050 - assert ds.min_samples_for_sanity == 512 + assert ds.min_samples_for_sanity == 22050 - 5 @pytest.mark.unit def test_getitem_returns_expected_keys_and_shapes(self, dataset, cutset): @@ -137,3 +137,32 @@ def test_getitem_resampling_preserves_frequency(self, dataset, cutset): # FFT bin width is TARGET_SAMPLE_RATE / n; allow ~1 bin of tolerance. 
bin_width_hz = TARGET_SAMPLE_RATE / n assert abs(peak_freq_hz - cutset[i].target_tone_frequency) <= bin_width_hz + + @pytest.mark.unit + def test_getitem_extracts_subset_of_longer_audio(self, tmp_path, dataset, monkeypatch): + # A cut longer than segment_duration should yield a segment of exactly segment_samples + # that is a contiguous slice taken from inside the longer source signal. + # Use the target sample rate as the source rate so no resampling is involved. + cut = _make_cut(tmp_path, "long", duration=3.0, sample_rate=TARGET_SAMPLE_RATE, tone_frequency=440.0) + cuts = CutSet.from_cuts([cut]) + + # Load the full target audio the same way the dataset does (no resampling needed). + full = cut.target_audio.load_audio().squeeze(0) + segment_samples = int(DEFAULT_DURATION * TARGET_SAMPLE_RATE) + + # Pin the random start so we can compare against the exact source slice. + fixed_start = 7000 + monkeypatch.setattr( + "nemo.collections.tts.data.audio_codec_dataset_lhotse.random.randint", + lambda low, high: fixed_start, + ) + + batch = dataset[cuts] + segment = batch["audio"][0].numpy() + + assert segment.shape == (segment_samples,) + assert batch["audio_lens"][0].item() == segment_samples + # Exact match holds only because source rate == target rate (no resampling) and the + # dataset currently applies no augmentation. Once we add augmentation it makes + # sense to remove this assertion. + assert np.allclose(segment, full[fixed_start : fixed_start + segment_samples])