From 20b19ff9724328e359af0743e0b2343608764910 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 21 May 2026 08:21:46 -0700 Subject: [PATCH 1/4] Support encoder input chunking for SALM vLLM inference MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/speechlm2/vllm/salm/audio.py | 61 +++++++++++++++++-- .../collections/speechlm2/vllm/salm/config.py | 5 ++ nemo/collections/speechlm2/vllm/salm/model.py | 21 +++++-- .../test_vllm_audio_token_estimator.py | 52 ++++++++++++++++ .../collections/speechlm2/test_vllm_plugin.py | 26 +++++++- 5 files changed, 153 insertions(+), 12 deletions(-) diff --git a/nemo/collections/speechlm2/vllm/salm/audio.py b/nemo/collections/speechlm2/vllm/salm/audio.py index 3bc45283287f..2da564e56fc1 100644 --- a/nemo/collections/speechlm2/vllm/salm/audio.py +++ b/nemo/collections/speechlm2/vllm/salm/audio.py @@ -64,6 +64,15 @@ _SAMPLING_RATE = 16000 _AUDIO_CHANNELS = 1 _DUMMY_AUDIO_DURATION_S = 40.0 +# FastConformer preprocessor hop length, used to derive the smallest +# chunk that produces ≥ 2 feature frames (per-feature normalization +# breaks on a single frame). Mirrors +# ``encoder_chunking._get_min_chunk_size_samples`` for the canonical +# preprocessor we ship; the chunking helper probes the live featurizer +# at training time, but the prompt processor here runs before the +# perception module is loaded, so we use the same constant the helper +# would derive. +_MIN_CHUNK_SIZE_SAMPLES = 320 # ── Helpers ───────────────────────────────────────────────────────── @@ -125,9 +134,18 @@ def get_data_parser(self) -> MultiModalDataParser: def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"audio": 1} + def _get_encoder_chunk_size_seconds(self) -> float | None: + """Return the per-encoder-call chunk size baked into the checkpoint. + + Mirrors the training-time ``model.encoder_chunk_size_seconds`` field + (see ``encode_audio_with_optional_chunking``). ``None`` means the + encoder runs once over the full audio, matching legacy checkpoints. + """ + return getattr(self.get_hf_config(), "encoder_chunk_size_seconds", None) + @staticmethod - def _estimate_audio_tokens(audio_length_samples: int) -> int: - """Predict the encoder's output frame count for an audio of N samples. + def _estimate_audio_tokens_single_pass(audio_length_samples: int) -> int: + """Predict the encoder's output frame count for one perception forward. Mirrors the FastConformer preprocessing chain used by ``AudioPerceptionModule``: STFT (n_fft=512, hop_length=160) followed @@ -151,6 +169,39 @@ def _estimate_audio_tokens(audio_length_samples: int) -> int: length = (length + add_pad) / stride + 1.0 return max(1, int(length)) + @classmethod + def _estimate_audio_tokens( + cls, + audio_length_samples: int, + chunk_size_seconds: float | None = None, + ) -> int: + """Predict the encoder's total output frame count for an audio of N samples. + + When ``chunk_size_seconds`` is ``None`` or the audio fits in a single + chunk, returns the single-pass estimate. Otherwise mirrors + ``encode_audio_with_optional_chunking``'s split (with the same + tail-folding rule) and sums the per-chunk frame counts so the + placeholder count matches what the model emits at forward time. + """ + if chunk_size_seconds is None or audio_length_samples <= 0: + return cls._estimate_audio_tokens_single_pass(audio_length_samples) + if chunk_size_seconds <= 0.0: + raise ValueError("encoder_chunk_size_seconds must be positive when set.") + chunk_size_samples = max(1, int(round(chunk_size_seconds * _SAMPLING_RATE))) + chunk_size_samples = max(chunk_size_samples, _MIN_CHUNK_SIZE_SAMPLES) + if audio_length_samples <= chunk_size_samples: + return cls._estimate_audio_tokens_single_pass(audio_length_samples) + + spans: list[tuple[int, int]] = [] + for begin in range(0, audio_length_samples, chunk_size_samples): + end = min(begin + chunk_size_samples, audio_length_samples) + spans.append((begin, end)) + if len(spans) > 1 and spans[-1][1] - spans[-1][0] < _MIN_CHUNK_SIZE_SAMPLES: + spans[-2] = (spans[-2][0], spans[-1][1]) + spans.pop() + + return sum(cls._estimate_audio_tokens_single_pass(end - begin) for begin, end in spans) + class NeMoSpeechLMMultiModalProcessor( BaseMultiModalProcessor[NeMoSpeechLMProcessingInfo], @@ -182,10 +233,11 @@ def _get_prompt_updates( out_mm_kwargs: MultiModalKwargsItems, ) -> list[PromptUpdate]: audios = mm_items.get_items("audio", AudioProcessorItems) + chunk_size_seconds = self.info._get_encoder_chunk_size_seconds() def get_replacement(item_idx: int): audio = audios.get(item_idx) - n_tokens = self.info._estimate_audio_tokens(audio.shape[-1]) + n_tokens = self.info._estimate_audio_tokens(audio.shape[-1], chunk_size_seconds) repl_full = _AUDIO_PLACEHOLDER * n_tokens return PromptUpdateDetails.select_text(repl_full, _AUDIO_PLACEHOLDER) @@ -210,6 +262,7 @@ def _call_hf_processor( audios = mm_data.pop("audios", []) if audios: + chunk_size_seconds = self.info._get_encoder_chunk_size_seconds() audio_list: list[torch.Tensor] = [] audio_lengths: list[int] = [] parts = re.split(f"({re.escape(_AUDIO_PLACEHOLDER)})", prompt) @@ -229,7 +282,7 @@ def _call_hf_processor( ) if audio_tensor.dim() > 1: audio_tensor = audio_tensor.squeeze() - n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1]) + n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1], chunk_size_seconds) parts[i] = _AUDIO_PLACEHOLDER * n_tokens audio_list.append(audio_tensor) audio_lengths.append(audio_tensor.shape[-1]) diff --git a/nemo/collections/speechlm2/vllm/salm/config.py b/nemo/collections/speechlm2/vllm/salm/config.py index 24da38329282..6d9f55d1b3fc 100644 --- a/nemo/collections/speechlm2/vllm/salm/config.py +++ b/nemo/collections/speechlm2/vllm/salm/config.py @@ -76,6 +76,7 @@ def __init__( prompt_format: str | None = None, pretrained_weights: bool | None = None, lora: dict | None = None, + encoder_chunk_size_seconds: float | None = None, **kwargs, ): required_fields = { @@ -88,6 +89,7 @@ def __init__( is_default_init = ( perception is None and lora is None + and encoder_chunk_size_seconds is None and not kwargs and all(value is None for value in required_fields.values()) ) @@ -112,6 +114,7 @@ def __init__( self.prompt_format = None self.pretrained_weights = None self.lora = None + self.encoder_chunk_size_seconds = None return for name, value in required_fields.items(): @@ -137,6 +140,7 @@ def __init__( self.prompt_format = prompt_format self.pretrained_weights = pretrained_weights self.lora = lora + self.encoder_chunk_size_seconds = encoder_chunk_size_seconds self.text_config = AutoConfig.from_pretrained(pretrained_llm, trust_remote_code=True) @@ -214,6 +218,7 @@ def __getattr__(self, name): "text_config", "lora", "is_hybrid", + "encoder_chunk_size_seconds", ): raise AttributeError(name) alias = self._ATTR_ALIASES.get(name, name) if self.is_hybrid else name diff --git a/nemo/collections/speechlm2/vllm/salm/model.py b/nemo/collections/speechlm2/vllm/salm/model.py index cf509e3489cc..5665a2cd7e4d 100644 --- a/nemo/collections/speechlm2/vllm/salm/model.py +++ b/nemo/collections/speechlm2/vllm/salm/model.py @@ -49,7 +49,9 @@ from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors +from nemo.collections.speechlm2.parts.encoder_chunking import encode_audio_with_optional_chunking from nemo.collections.speechlm2.vllm.salm.audio import ( + _SAMPLING_RATE, NeMoSpeechLMAudioInputs, NeMoSpeechLMDummyInputsBuilder, NeMoSpeechLMMultiModalProcessor, @@ -84,6 +86,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config + self.encoder_chunk_size_seconds = getattr(config, "encoder_chunk_size_seconds", None) backend = make_backend(config) self._backend = backend @@ -140,15 +143,21 @@ def _process_audio(self, audio_input: NeMoSpeechLMAudioInputs) -> tuple[torch.Te audio_signal = audio_signal.to(device=device, dtype=torch.float32) audio_lengths = audio_input.audio_signal_length.to(device=device) + # Mirrors training (``encode_audio_with_optional_chunking``): when the + # checkpoint was trained with a chunked encoder (e.g. SALMAutomodel + # default 30 s), long audios are split into chunks before the perception + # forward and the per-chunk embeddings are concatenated. ``None`` + # disables chunking and runs a single forward over the full batch. with torch.no_grad(): - audio_embeds, audio_embed_lens = self.perception( - input_signal=audio_signal, - input_signal_length=audio_lengths, + audio_embeds = encode_audio_with_optional_chunking( + self.perception, + audio_signal, + audio_lengths, + chunk_size_seconds=self.encoder_chunk_size_seconds, + sampling_rate=_SAMPLING_RATE, ) - audio_embeds = audio_embeds.to(torch.bfloat16) - - return tuple(audio_embeds[i, : audio_embed_lens[i]] for i in range(audio_embeds.shape[0])) + return tuple(emb.to(torch.bfloat16) for emb in audio_embeds) def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: audio_input = self._parse_audio_input(**kwargs) diff --git a/tests/collections/speechlm2/test_vllm_audio_token_estimator.py b/tests/collections/speechlm2/test_vllm_audio_token_estimator.py index c691c2cbc482..86588204b2ca 100644 --- a/tests/collections/speechlm2/test_vllm_audio_token_estimator.py +++ b/tests/collections/speechlm2/test_vllm_audio_token_estimator.py @@ -80,3 +80,55 @@ def test_estimator_matches_calc_length(samples: int) -> None: def test_estimator_min_one() -> None: """Even for very short audio the estimator must return at least 1.""" assert NeMoSpeechLMProcessingInfo._estimate_audio_tokens(1) >= 1 + + +def test_estimator_chunking_disabled_matches_single_pass() -> None: + """``chunk_size_seconds=None`` must match the legacy single-pass estimate.""" + samples = 30 * 16_000 + assert NeMoSpeechLMProcessingInfo._estimate_audio_tokens( + samples, chunk_size_seconds=None + ) == NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(samples) + + +def test_estimator_short_audio_falls_back_to_single_pass() -> None: + """Audio shorter than the chunk size collapses to a single forward.""" + samples = 5 * 16_000 + assert NeMoSpeechLMProcessingInfo._estimate_audio_tokens( + samples, chunk_size_seconds=30.0 + ) == NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(samples) + + +def test_estimator_chunked_sums_per_chunk_frames() -> None: + """Long audio is split into chunks and per-chunk frame counts are summed, + matching ``encode_audio_with_optional_chunking``'s concat behavior.""" + samples = 90 * 16_000 + chunk_size_seconds = 30.0 + chunk_samples = int(round(chunk_size_seconds * 16_000)) + expected = sum( + NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(min(chunk_samples, samples - i)) + for i in range(0, samples, chunk_samples) + ) + assert ( + NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds=chunk_size_seconds) + == expected + ) + + +def test_estimator_chunked_tail_folded_into_previous_chunk() -> None: + """A tiny tail (< min chunk size) is folded into the previous chunk so + the total token count matches the runtime helper instead of producing a + spurious single-frame chunk that the audio preprocessor would reject.""" + chunk_size_seconds = 30.0 + chunk_samples = int(round(chunk_size_seconds * 16_000)) + samples = chunk_samples + 100 # 100 sample tail < min_chunk_size_samples (320) + # Folded: one chunk of `samples` samples (no split). + expected = NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(samples) + assert ( + NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds=chunk_size_seconds) + == expected + ) + + +def test_estimator_negative_chunk_size_raises() -> None: + with pytest.raises(ValueError, match="encoder_chunk_size_seconds"): + NeMoSpeechLMProcessingInfo._estimate_audio_tokens(16_000, chunk_size_seconds=-1.0) diff --git a/tests/collections/speechlm2/test_vllm_plugin.py b/tests/collections/speechlm2/test_vllm_plugin.py index 1f89a1268924..8f6d5c817fe8 100644 --- a/tests/collections/speechlm2/test_vllm_plugin.py +++ b/tests/collections/speechlm2/test_vllm_plugin.py @@ -209,6 +209,26 @@ def test_unknown_attr_raises(self): with pytest.raises(AttributeError): _ = cfg.nonexistent_attribute_xyz + def test_encoder_chunk_size_seconds_default_none(self): + """Legacy checkpoints without a chunk size keep the single-pass encoder path.""" + cfg = NeMoSpeechLMConfig(**_DEFAULT_CONFIG_KWARGS) + assert cfg.encoder_chunk_size_seconds is None + + def test_encoder_chunk_size_seconds_round_trips(self): + """Chunk size set in config.json (e.g. SALMAutomodel default 30 s) survives load.""" + cfg = NeMoSpeechLMConfig( + **{ + **_DEFAULT_CONFIG_KWARGS, + "encoder_chunk_size_seconds": 30.0, + } + ) + assert cfg.encoder_chunk_size_seconds == 30.0 + + def test_encoder_chunk_size_seconds_default_init_inert(self): + """No-arg default init must still expose ``encoder_chunk_size_seconds=None``.""" + cfg = NeMoSpeechLMConfig() + assert cfg.encoder_chunk_size_seconds is None + @pytest.mark.skipif(not (_HAS_CONFIG and _HAS_VLLM), reason="NeMoSpeechLMConfig or vLLM not available") class TestBackendSelection: @@ -332,7 +352,8 @@ def test_call_hf_processor_requires_matching_placeholder_count(self): processor = object.__new__(NeMoSpeechLMMultiModalProcessor) processor.info = SimpleNamespace( get_tokenizer=_FakeTokenizer, - _estimate_audio_tokens=lambda samples: 2, + _estimate_audio_tokens=lambda samples, chunk_size_seconds=None: 2, + _get_encoder_chunk_size_seconds=lambda: None, ) with pytest.raises(ValueError, match="placeholders"): @@ -351,7 +372,8 @@ def test_call_hf_processor_emits_true_audio_lengths(self): processor = object.__new__(NeMoSpeechLMMultiModalProcessor) processor.info = SimpleNamespace( get_tokenizer=_FakeTokenizer, - _estimate_audio_tokens=lambda samples: 2, + _estimate_audio_tokens=lambda samples, chunk_size_seconds=None: 2, + _get_encoder_chunk_size_seconds=lambda: None, ) result = processor._call_hf_processor( From 0c3e9dd73c5527932e40955beb368ff786e7b746 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 21 May 2026 08:53:22 -0700 Subject: [PATCH 2/4] Fixes for multi-audio turns, audiobench MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/speechlm2/vllm/salm/audio.py | 40 ++++++++++++++++++- .../test_vllm_audio_token_estimator.py | 6 +-- 2 files changed, 41 insertions(+), 5 deletions(-) diff --git a/nemo/collections/speechlm2/vllm/salm/audio.py b/nemo/collections/speechlm2/vllm/salm/audio.py index 2da564e56fc1..bf31eba73491 100644 --- a/nemo/collections/speechlm2/vllm/salm/audio.py +++ b/nemo/collections/speechlm2/vllm/salm/audio.py @@ -132,7 +132,7 @@ def get_data_parser(self) -> MultiModalDataParser: ) def get_supported_mm_limits(self) -> Mapping[str, int | None]: - return {"audio": 1} + return {"audio": None} def _get_encoder_chunk_size_seconds(self) -> float | None: """Return the per-encoder-call chunk size baked into the checkpoint. @@ -202,6 +202,31 @@ def _estimate_audio_tokens( return sum(cls._estimate_audio_tokens_single_pass(end - begin) for begin, end in spans) + @classmethod + def _samples_for_audio_tokens(cls, target_tokens: int, chunk_size_seconds: float | None = None) -> int: + """Return the smallest sample count estimated to produce ``target_tokens``. + + vLLM sizes the multimodal encoder cache from dummy inputs. The SALM + plugin supports arbitrarily long audio by chunking the encoder forward, + but the decoder still receives the concatenated full-audio embedding + sequence. This inverse estimator lets ``--limit-mm-per-prompt`` audio + length hints reserve cache for that full sequence without hard-coding a + single maximum call duration. + """ + target_tokens = max(1, int(target_tokens)) + max_samples = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE) + lo, hi = 1, min(_SAMPLING_RATE, max_samples) + while hi < max_samples and cls._estimate_audio_tokens(hi, chunk_size_seconds) < target_tokens: + hi = min(hi * 2, max_samples) + + while lo < hi: + mid = (lo + hi) // 2 + if cls._estimate_audio_tokens(mid, chunk_size_seconds) >= target_tokens: + hi = mid + else: + lo = mid + 1 + return lo + class NeMoSpeechLMMultiModalProcessor( BaseMultiModalProcessor[NeMoSpeechLMProcessingInfo], @@ -310,6 +335,19 @@ def get_dummy_mm_data( ) -> MultiModalDataDict: num_audios = mm_counts.get("audio", 0) dummy_audio_len = int(_DUMMY_AUDIO_DURATION_S * _SAMPLING_RATE) + audio_options = mm_options.get("audio") if mm_options else None + requested_audio_len = getattr(audio_options, "length", None) + if requested_audio_len: + chunk_size_seconds = self.info._get_encoder_chunk_size_seconds() + if seq_len > _DUMMY_AUDIO_TEXT_TOKEN_RESERVE: + max_audio_tokens = seq_len - _DUMMY_AUDIO_TEXT_TOKEN_RESERVE + max_audio_len = NeMoSpeechLMProcessingInfo._samples_for_audio_tokens( + max_audio_tokens, + chunk_size_seconds, + ) + else: + max_audio_len = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE) + dummy_audio_len = min(int(requested_audio_len), max_audio_len) return { "audio": self._get_dummy_audios( length=dummy_audio_len, diff --git a/tests/collections/speechlm2/test_vllm_audio_token_estimator.py b/tests/collections/speechlm2/test_vllm_audio_token_estimator.py index 86588204b2ca..af6e9f938358 100644 --- a/tests/collections/speechlm2/test_vllm_audio_token_estimator.py +++ b/tests/collections/speechlm2/test_vllm_audio_token_estimator.py @@ -109,8 +109,7 @@ def test_estimator_chunked_sums_per_chunk_frames() -> None: for i in range(0, samples, chunk_samples) ) assert ( - NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds=chunk_size_seconds) - == expected + NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds=chunk_size_seconds) == expected ) @@ -124,8 +123,7 @@ def test_estimator_chunked_tail_folded_into_previous_chunk() -> None: # Folded: one chunk of `samples` samples (no split). expected = NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(samples) assert ( - NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds=chunk_size_seconds) - == expected + NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds=chunk_size_seconds) == expected ) From 1ff85ad9179e0b2c99c9364a621a1be2fbc3b49d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 21 May 2026 13:42:59 -0700 Subject: [PATCH 3/4] Add missing constants MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Piotr Żelasko --- nemo/collections/speechlm2/vllm/salm/audio.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nemo/collections/speechlm2/vllm/salm/audio.py b/nemo/collections/speechlm2/vllm/salm/audio.py index bf31eba73491..1f5c3a5afc65 100644 --- a/nemo/collections/speechlm2/vllm/salm/audio.py +++ b/nemo/collections/speechlm2/vllm/salm/audio.py @@ -64,6 +64,8 @@ _SAMPLING_RATE = 16000 _AUDIO_CHANNELS = 1 _DUMMY_AUDIO_DURATION_S = 40.0 +_DUMMY_AUDIO_MAX_DURATION_S = 3600.0 +_DUMMY_AUDIO_TEXT_TOKEN_RESERVE = 64 # FastConformer preprocessor hop length, used to derive the smallest # chunk that produces ≥ 2 feature frames (per-feature normalization # breaks on a single frame). Mirrors From 25bb66b99dc10b0d581ea7d86016e5eedbeb6d13 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Thu, 28 May 2026 09:23:35 -0700 Subject: [PATCH 4/4] Address SALM vLLM review feedback --- nemo/collections/speechlm2/vllm/salm/audio.py | 20 +++++- nemo/collections/speechlm2/vllm/salm/model.py | 2 +- .../test_vllm_audio_token_estimator.py | 39 ++++++++++- .../collections/speechlm2/test_vllm_plugin.py | 68 +++++++++++++++++++ 4 files changed, 124 insertions(+), 5 deletions(-) diff --git a/nemo/collections/speechlm2/vllm/salm/audio.py b/nemo/collections/speechlm2/vllm/salm/audio.py index 1f5c3a5afc65..c892364c0a56 100644 --- a/nemo/collections/speechlm2/vllm/salm/audio.py +++ b/nemo/collections/speechlm2/vllm/salm/audio.py @@ -198,7 +198,7 @@ def _estimate_audio_tokens( for begin in range(0, audio_length_samples, chunk_size_samples): end = min(begin + chunk_size_samples, audio_length_samples) spans.append((begin, end)) - if len(spans) > 1 and spans[-1][1] - spans[-1][0] < _MIN_CHUNK_SIZE_SAMPLES: + if spans[-1][1] - spans[-1][0] < _MIN_CHUNK_SIZE_SAMPLES: spans[-2] = (spans[-2][0], spans[-1][1]) spans.pop() @@ -221,6 +221,14 @@ def _samples_for_audio_tokens(cls, target_tokens: int, chunk_size_seconds: float while hi < max_samples and cls._estimate_audio_tokens(hi, chunk_size_seconds) < target_tokens: hi = min(hi * 2, max_samples) + hi_tokens = cls._estimate_audio_tokens(hi, chunk_size_seconds) + if hi_tokens < target_tokens: + raise ValueError( + f"Cannot produce {target_tokens} audio tokens within the " + f"{_DUMMY_AUDIO_MAX_DURATION_S:g} s dummy-audio cap; " + f"maximum is {hi_tokens}." + ) + while lo < hi: mid = (lo + hi) // 2 if cls._estimate_audio_tokens(mid, chunk_size_seconds) >= target_tokens: @@ -343,10 +351,16 @@ def get_dummy_mm_data( chunk_size_seconds = self.info._get_encoder_chunk_size_seconds() if seq_len > _DUMMY_AUDIO_TEXT_TOKEN_RESERVE: max_audio_tokens = seq_len - _DUMMY_AUDIO_TEXT_TOKEN_RESERVE - max_audio_len = NeMoSpeechLMProcessingInfo._samples_for_audio_tokens( - max_audio_tokens, + max_audio_len = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE) + max_supported_audio_tokens = NeMoSpeechLMProcessingInfo._estimate_audio_tokens( + max_audio_len, chunk_size_seconds, ) + if max_audio_tokens < max_supported_audio_tokens: + max_audio_len = NeMoSpeechLMProcessingInfo._samples_for_audio_tokens( + max_audio_tokens, + chunk_size_seconds, + ) else: max_audio_len = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE) dummy_audio_len = min(int(requested_audio_len), max_audio_len) diff --git a/nemo/collections/speechlm2/vllm/salm/model.py b/nemo/collections/speechlm2/vllm/salm/model.py index db81bd6c3692..cffc19a4977d 100644 --- a/nemo/collections/speechlm2/vllm/salm/model.py +++ b/nemo/collections/speechlm2/vllm/salm/model.py @@ -159,7 +159,7 @@ def _process_audio(self, audio_input: NeMoSpeechLMAudioInputs) -> tuple[torch.Te sampling_rate=_SAMPLING_RATE, ) - return tuple(emb.to(torch.bfloat16) for emb in audio_embeds) + return tuple(emb.to(_PERCEPTION_DTYPE) for emb in audio_embeds) def embed_multimodal(self, **kwargs) -> MultiModalEmbeddings: audio_input = self._parse_audio_input(**kwargs) diff --git a/tests/collections/speechlm2/test_vllm_audio_token_estimator.py b/tests/collections/speechlm2/test_vllm_audio_token_estimator.py index af6e9f938358..cc37292d1cbd 100644 --- a/tests/collections/speechlm2/test_vllm_audio_token_estimator.py +++ b/tests/collections/speechlm2/test_vllm_audio_token_estimator.py @@ -33,7 +33,12 @@ pytest.importorskip("vllm") from nemo.collections.asr.parts.submodules.subsampling import calc_length -from nemo.collections.speechlm2.vllm.salm.audio import NeMoSpeechLMProcessingInfo +from nemo.collections.speechlm2.vllm.salm.audio import ( + _DUMMY_AUDIO_MAX_DURATION_S, + _MIN_CHUNK_SIZE_SAMPLES, + _SAMPLING_RATE, + NeMoSpeechLMProcessingInfo, +) def _reference(audio_length_samples: int) -> int: @@ -127,6 +132,38 @@ def test_estimator_chunked_tail_folded_into_previous_chunk() -> None: ) +def test_estimator_clamps_tiny_chunk_size_to_min_samples() -> None: + assert _MIN_CHUNK_SIZE_SAMPLES == 320 + + chunk_size_seconds = 1 / _SAMPLING_RATE + samples = 2 * _MIN_CHUNK_SIZE_SAMPLES + 100 + expected = NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass( + _MIN_CHUNK_SIZE_SAMPLES + ) + NeMoSpeechLMProcessingInfo._estimate_audio_tokens_single_pass(_MIN_CHUNK_SIZE_SAMPLES + 100) + + assert ( + NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds=chunk_size_seconds) == expected + ) + + def test_estimator_negative_chunk_size_raises() -> None: with pytest.raises(ValueError, match="encoder_chunk_size_seconds"): NeMoSpeechLMProcessingInfo._estimate_audio_tokens(16_000, chunk_size_seconds=-1.0) + + +@pytest.mark.parametrize("chunk_size_seconds", [None, 30.0]) +def test_samples_for_audio_tokens_returns_minimum_sample_count(chunk_size_seconds: float | None) -> None: + target_tokens = 17 + + samples = NeMoSpeechLMProcessingInfo._samples_for_audio_tokens(target_tokens, chunk_size_seconds) + + assert NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples, chunk_size_seconds) >= target_tokens + assert NeMoSpeechLMProcessingInfo._estimate_audio_tokens(samples - 1, chunk_size_seconds) < target_tokens + + +def test_samples_for_audio_tokens_rejects_unreachable_target() -> None: + max_samples = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE) + max_tokens = NeMoSpeechLMProcessingInfo._estimate_audio_tokens(max_samples) + + with pytest.raises(ValueError, match="Cannot produce"): + NeMoSpeechLMProcessingInfo._samples_for_audio_tokens(max_tokens + 1) diff --git a/tests/collections/speechlm2/test_vllm_plugin.py b/tests/collections/speechlm2/test_vllm_plugin.py index 8f6d5c817fe8..565ee5f08935 100644 --- a/tests/collections/speechlm2/test_vllm_plugin.py +++ b/tests/collections/speechlm2/test_vllm_plugin.py @@ -346,6 +346,74 @@ def test_dummy_inputs_use_profiling_audio_length(self): assert result["audio"][0].shape[-1] == 40 * 16000 + def test_dummy_inputs_use_requested_audio_length(self, monkeypatch): + from nemo.collections.speechlm2.vllm.salm.audio import NeMoSpeechLMDummyInputsBuilder + + builder = object.__new__(NeMoSpeechLMDummyInputsBuilder) + builder.info = SimpleNamespace(_get_encoder_chunk_size_seconds=lambda: None) + monkeypatch.setattr( + builder, + "_get_dummy_audios", + lambda length, num_audios: [SimpleNamespace(length=length) for _ in range(num_audios)], + ) + + result = builder.get_dummy_mm_data( + seq_len=0, + mm_counts={"audio": 1}, + mm_options={"audio": SimpleNamespace(length=12345)}, + ) + + assert result["audio"][0].length == 12345 + + def test_dummy_inputs_cap_requested_audio_length_to_text_budget(self, monkeypatch): + from nemo.collections.speechlm2.vllm.salm.audio import ( + _DUMMY_AUDIO_TEXT_TOKEN_RESERVE, + NeMoSpeechLMDummyInputsBuilder, + NeMoSpeechLMProcessingInfo, + ) + + target_audio_tokens = 4 + max_audio_len = NeMoSpeechLMProcessingInfo._samples_for_audio_tokens(target_audio_tokens) + builder = object.__new__(NeMoSpeechLMDummyInputsBuilder) + builder.info = SimpleNamespace(_get_encoder_chunk_size_seconds=lambda: None) + monkeypatch.setattr( + builder, + "_get_dummy_audios", + lambda length, num_audios: [SimpleNamespace(length=length) for _ in range(num_audios)], + ) + + result = builder.get_dummy_mm_data( + seq_len=_DUMMY_AUDIO_TEXT_TOKEN_RESERVE + target_audio_tokens, + mm_counts={"audio": 1}, + mm_options={"audio": SimpleNamespace(length=max_audio_len + 16000)}, + ) + + assert result["audio"][0].length == max_audio_len + + def test_dummy_inputs_large_seq_len_uses_max_audio_cap(self, monkeypatch): + from nemo.collections.speechlm2.vllm.salm.audio import ( + _DUMMY_AUDIO_MAX_DURATION_S, + _SAMPLING_RATE, + NeMoSpeechLMDummyInputsBuilder, + ) + + max_audio_len = int(_DUMMY_AUDIO_MAX_DURATION_S * _SAMPLING_RATE) + builder = object.__new__(NeMoSpeechLMDummyInputsBuilder) + builder.info = SimpleNamespace(_get_encoder_chunk_size_seconds=lambda: None) + monkeypatch.setattr( + builder, + "_get_dummy_audios", + lambda length, num_audios: [SimpleNamespace(length=length) for _ in range(num_audios)], + ) + + result = builder.get_dummy_mm_data( + seq_len=10_000_000, + mm_counts={"audio": 1}, + mm_options={"audio": SimpleNamespace(length=max_audio_len + 16000)}, + ) + + assert result["audio"][0].length == max_audio_len + def test_call_hf_processor_requires_matching_placeholder_count(self): from nemo.collections.speechlm2.vllm.salm.audio import NeMoSpeechLMMultiModalProcessor