Skip to content
74 changes: 51 additions & 23 deletions nemo/collections/tts/data/audio_codec_dataset_lhotse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional
import random
from typing import Dict

import numpy as np
import torch
from lhotse import CutSet
from lhotse.dataset.collation import collate_audio

from nemo.utils import logging

Expand All @@ -28,33 +29,64 @@ class AudioCodecLhotseDataset(torch.utils.data.Dataset):
It is a simple dataset that mostly just loads the audio samples.
In addition, it performs the following operations:
* Resampling to the target sample rate
* Random truncation of each cut's `target_audio` to a fixed duration
* Sanity checks on the audio

The operations below are handled directly by Lhotse according to the configuration
applied in `AudioCodecModel._get_lhotse_dataloader()`:
* Duration filtering
* Minimum duration filtering
* Any additional transformations configured in Lhotse during its construction are
applied to the audio as it is loaded in `collate_audio()`.
applied to the audio as it is loaded in `load_audio()`.
"""

def __init__(
self,
sample_rate: int,
segment_duration: float,
sanity_check_audio: bool = False,
min_samples_for_sanity: Optional[int] = None,
):
"""
Args:
sample_rate: The sample rate to resample the audio to.
segment_duration: Length of each training segment in seconds. A random
segment of this length is taken from each cut's `target_audio` field
(not from the parent `recording`, which may span a much longer duration).
Comment on lines +52 to +53
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are they shars that have a non-empty recording field?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From examining the cuts, they all seem to have a recording field, at least in two datasets I checked. What varies is whether the recording is the target audio itself or rather a longer recording of which the target audio is a subset (which is only true for certain datasets, but it's something we need to support).

sanity_check_audio: If True, perform sanity checks on the loaded audio.
min_samples_for_sanity: cuts should have at least this many samples or an
error will be raised. Only used when
`sanity_check_audio` is True.
"""
super().__init__()
self.sample_rate = sample_rate
self.segment_duration = segment_duration
self.segment_samples = int(segment_duration * sample_rate)
self.sanity_check_audio = sanity_check_audio
self.min_samples_for_sanity = min_samples_for_sanity
# Error out if audio is suspiciously short (leaving some slack for resampling).
self.min_samples_for_sanity = max(1, self.segment_samples - 5)

def _load_and_truncate_target_audio(self, cut) -> torch.Tensor:
"""
Load `target_audio`, resample, and return a random segment of length `segment_duration`.
"""
if not cut.has_custom("target_audio"):
raise ValueError(f"Cut {cut.id} is missing custom field 'target_audio'")

target_audio_recording = cut.target_audio.resample(self.sample_rate)
# Load the target audio, resampling and applying and Lhotse transformation in the process
audio = target_audio_recording.load_audio()
if audio.ndim > 1:
audio = audio.squeeze(0)

num_samples = audio.shape[-1]
if num_samples < self.segment_samples:
raise ValueError(
f"target_audio is shorter than segment_duration: "
f"cut_id={cut.id}, target_audio_id={target_audio_recording.id}, "
f"num_samples={num_samples}, required={self.segment_samples}, "
f"segment_duration={self.segment_duration}s"
)

# Randomly select a segment of the audio
start = random.randint(0, num_samples - self.segment_samples)
segment = audio[start : start + self.segment_samples]
return torch.from_numpy(np.ascontiguousarray(segment, dtype=np.float32))

def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
"""
Expand All @@ -65,19 +97,15 @@ def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
Returns:
A dictionary with the `audio` and `audio_lens` tensors.
"""
# Resample the audio to the target sample rate. We need to do this manually
# because Lhotse only resamples its standard `recording` field automatically,
# not custom fields like `target_audio`.
for cut in cuts:
cut.target_audio = cut.target_audio.resample(self.sample_rate)

# Load and collate the audio, applying any transformations that were
# configured in Lhotse in the process.
# Note: fault_tolerant=False for now to avoid masking errors until we are more
# confident in the new loader.
batch_audio, batch_audio_len = collate_audio(cuts, recording_field="target_audio", fault_tolerant=False)

# Sanity checks on the audio and its length
# Load, resample and truncate the audio
audio_list = [self._load_and_truncate_target_audio(cut) for cut in cuts]
batch_audio = torch.stack(audio_list, dim=0)
batch_audio_len = torch.full(
(len(audio_list),),
self.segment_samples,
dtype=torch.int32,
)

if self.sanity_check_audio:
self._sanity_check_audio(batch_audio, batch_audio_len, cuts)

Expand All @@ -95,7 +123,7 @@ def _sanity_check_audio(self, audio: torch.Tensor, audio_len: torch.Tensor, cuts
# --- Error cases ---

# Audio length is unexpectedly short
if self.min_samples_for_sanity is not None and audio_len.min() < self.min_samples_for_sanity:
if audio_len.min() < self.min_samples_for_sanity:
raise ValueError(
f"Audio length is less than {self.min_samples_for_sanity} samples (min: {audio_len.min()})"
)
Expand Down
38 changes: 25 additions & 13 deletions nemo/collections/tts/models/audio_codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,23 +854,35 @@ def _get_lhotse_dataloader(self, cfg):
# manually in the dataset class.
loader_cfg.sample_rate = self.output_sample_rate

# Set up cut truncation, filtering, and random selection:
# `truncate_duration` and `truncate_offset_type` are interpreted by Lhotse.
# Together, they configure Lhotse to choose a random segment of this length
# from each cut.
if loader_cfg.truncate_duration is None:
raise ValueError("`truncate_duration` must be set in the config")
loader_cfg.truncate_offset_type = "random"
# Also filter examples to be at least this long to avoid zero-padding
loader_cfg.min_duration = loader_cfg.truncate_duration
# Random segment selection is done in AudioCodecLhotseDataset on `target_audio`, not via
# Lhotse's `truncate_duration` config (which operates on the parent recording).
if cfg.dataloader_params.get("truncate_duration") is not None:
raise ValueError(
"`truncate_duration` must not be set in `train_ds.dataloader_params`; "
"segment extraction is handled in `AudioCodecLhotseDataset` via `segment_duration`."
)
segment_duration = dataset_args.get("segment_duration")
if segment_duration is None:
raise ValueError("`segment_duration` must be set in `train_ds.dataset_args` ")
existing_min_duration = cfg.dataloader_params.get("min_duration")
if existing_min_duration is not None and existing_min_duration != -1:
raise ValueError(
"`min_duration` must not be set in `train_ds.dataloader_params`; "
"it is set automatically from `train_ds.dataset_args.segment_duration`."
)
# Pre-filter to only include cuts whose parent recording is at least as long as
# the training segment duration so the dataset class has enough samples to choose from.
loader_cfg.min_duration = segment_duration

# Make sure batch_size is set
if loader_cfg.batch_size is None:
raise ValueError("`batch_size` must be set in `train_ds.dataloader_params`.")

# --- Create the dataset ---

# Error out if the audio is suspiciously short (half the expected length)
min_samples_for_sanity = loader_cfg.truncate_duration * self.output_sample_rate // 2
# Create the dataset
dataset = AudioCodecLhotseDataset(
sample_rate=self.output_sample_rate, min_samples_for_sanity=min_samples_for_sanity, **dataset_args
sample_rate=self.output_sample_rate,
**dataset_args,
)

# Create the dataloader
Expand Down
35 changes: 32 additions & 3 deletions tests/collections/tts/data/test_audio_codec_dataset_lhotse.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,16 +88,16 @@ def cutset(tmp_path) -> CutSet:
def dataset() -> AudioCodecLhotseDataset:
return AudioCodecLhotseDataset(
sample_rate=TARGET_SAMPLE_RATE,
min_samples_for_sanity=DEFAULT_DURATION * TARGET_SAMPLE_RATE,
segment_duration=DEFAULT_DURATION,
)


class TestAudioCodecLhotseDataset:
@pytest.mark.unit
def test_init(self):
ds = AudioCodecLhotseDataset(sample_rate=22050, min_samples_for_sanity=512)
ds = AudioCodecLhotseDataset(sample_rate=22050, segment_duration=1.0)
assert ds.sample_rate == 22050
assert ds.min_samples_for_sanity == 512
assert ds.min_samples_for_sanity == 22050 - 5

@pytest.mark.unit
def test_getitem_returns_expected_keys_and_shapes(self, dataset, cutset):
Expand Down Expand Up @@ -137,3 +137,32 @@ def test_getitem_resampling_preserves_frequency(self, dataset, cutset):
# FFT bin width is TARGET_SAMPLE_RATE / n; allow ~1 bin of tolerance.
bin_width_hz = TARGET_SAMPLE_RATE / n
assert abs(peak_freq_hz - cutset[i].target_tone_frequency) <= bin_width_hz

@pytest.mark.unit
def test_getitem_extracts_subset_of_longer_audio(self, tmp_path, dataset, monkeypatch):
# A cut longer than segment_duration should yield a segment of exactly segment_samples
# that is a contiguous slice taken from inside the longer source signal.
# Use the target sample rate as the source rate so no resampling is involved.
cut = _make_cut(tmp_path, "long", duration=3.0, sample_rate=TARGET_SAMPLE_RATE, tone_frequency=440.0)
cuts = CutSet.from_cuts([cut])

# Load the full target audio the same way the dataset does (no resampling needed).
full = cut.target_audio.load_audio().squeeze(0)
segment_samples = int(DEFAULT_DURATION * TARGET_SAMPLE_RATE)

# Pin the random start so we can compare against the exact source slice.
fixed_start = 7000
monkeypatch.setattr(
"nemo.collections.tts.data.audio_codec_dataset_lhotse.random.randint",
lambda low, high: fixed_start,
)

batch = dataset[cuts]
segment = batch["audio"][0].numpy()

assert segment.shape == (segment_samples,)
assert batch["audio_lens"][0].item() == segment_samples
# Exact match holds only because source rate == target rate (no resampling) and the
# dataset currently applies no augmentation. Once we add augmentation it makes
# sense to remove this assertion.
assert np.allclose(segment, full[fixed_start : fixed_start + segment_samples])
Loading