From 56e901a22feb4b62b1aea5d8db3691a403119d7c Mon Sep 17 00:00:00 2001 From: bgiddwani-ai Date: Wed, 22 Apr 2026 17:33:50 +0530 Subject: [PATCH 1/5] salm.py for gemma4 only Signed-off-by: bgiddwani-ai --- nemo/collections/speechlm2/models/salm.py | 268 ++++++++-------------- 1 file changed, 93 insertions(+), 175 deletions(-) diff --git a/nemo/collections/speechlm2/models/salm.py b/nemo/collections/speechlm2/models/salm.py index e59819be8efc..7d489b5ca53a 100644 --- a/nemo/collections/speechlm2/models/salm.py +++ b/nemo/collections/speechlm2/models/salm.py @@ -41,12 +41,7 @@ from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.lora import maybe_install_lora from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen -from nemo.collections.speechlm2.parts.pretrained import ( - load_pretrained_hf, - maybe_load_pretrained_models, - move_embedding, - setup_speech_encoder, -) +from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, move_embedding, setup_speech_encoder from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import logging @@ -62,30 +57,90 @@ def __init__(self, cfg) -> None: self.cfg = DictConfig(cfg) self.audio_locator_tag = self.cfg.audio_locator_tag - self.tokenizer = AutoTokenizer( - self.cfg.pretrained_llm, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False) - ) + self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True) self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]}) - self.llm = load_pretrained_hf( - self.cfg.pretrained_llm, - pretrained_weights=self.cfg.pretrained_weights, - trust_remote_code=self.cfg.get("trust_remote_code", False), + self.llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights) + + # CHANGED - Start + # ── Patch 1: fix get_input_embeddings so HF internals can find embed_tokens ── + self.embed_tokens = self.llm.model.language_model.embed_tokens + del self.llm.model.language_model.embed_tokens + + # ── Hook: restore embed_tokens into language_model before every llm forward ── + def _pre_forward_hook(module, args, kwargs): + module.model.language_model.embed_tokens = self.embed_tokens + return args, kwargs + + def _post_forward_hook(module, args, output): + # Remove again so PyTorch doesn't see duplicate parameters + del module.model.language_model.embed_tokens + + self.llm.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True) + self.llm.register_forward_hook(_post_forward_hook) + + # ── Patch 1: fix get_input_embeddings ── + import types + salm_ref = self + + def _get_embeddings(inner_self): + return salm_ref.embed_tokens + + def _set_embeddings(inner_self, value): + salm_ref.embed_tokens = value + + for module in (self.llm, self.llm.model, self.llm.model.language_model): + module.get_input_embeddings = types.MethodType(_get_embeddings, module) + module.set_input_embeddings = types.MethodType(_set_embeddings, module) + + # ── Patch 2: disable Gemma4's multimodal placeholder routing ── + def _noop_placeholder_mask(inner_self, input_ids, inputs_embeds): + batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0] + seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] + device = inputs_embeds.device if inputs_embeds is not None else input_ids.device + empty = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) + return empty, empty, empty + + self.llm.model.get_placeholder_mask = types.MethodType( + _noop_placeholder_mask, self.llm.model ) - # Note: we have to "move out" the token embedding outside of LLM to avoid - # messing up FSDP/TP hooks. - self.embed_tokens = self.llm.model.embed_tokens - del self.llm.model.embed_tokens - + + # ── Patch 3: skip per_layer_inputs when inputs_embeds provided ── + _orig_get_per_layer_inputs = self.llm.model.language_model.get_per_layer_inputs + + def _patched_get_per_layer_inputs(inner_self, input_ids, inputs_embeds): + if input_ids is None: + return None + return _orig_get_per_layer_inputs(input_ids, inputs_embeds) + + self.llm.model.language_model.get_per_layer_inputs = types.MethodType( + _patched_get_per_layer_inputs, + self.llm.model.language_model + ) + + # ── Patch 4: handle None per_layer_inputs in language_model.forward ── + _orig_lm_forward = self.llm.model.language_model.forward + + def _patched_lm_forward(inner_self, *args, **kwargs): + kwargs.setdefault('per_layer_inputs', None) + return _orig_lm_forward(*args, **kwargs) + + self.llm.model.language_model.forward = types.MethodType( + _patched_lm_forward, + self.llm.model.language_model + ) + maybe_install_lora(self) - # Load the pretrained streaming ASR model and copy its parameters into the audio perception module. setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights) - # Optionally initialize weights from a previous checkpoint (fresh optimizer/scheduler). - # Set model.pretrained_s2s_model or model.pretrained_perception_from_s2s in the config. - maybe_load_pretrained_models(self) - + self._use_fsdp = False self._use_tp = False + @property + def embed_tokens(self): + """Always read embed_tokens from llm to avoid double parameters and + keep Gemma4's own forward working normally.""" + return self.llm.model.language_model.embed_tokens + @property def text_vocab_size(self): """Return the size of the text tokenizer.""" @@ -132,14 +187,6 @@ def forward( attention_mask: Tensor = None, cache=None, ) -> dict[str, Tensor]: - """ - Implements a fully offline forward pass through the entire model. - The flow is the following: - - |speech and text embeddings| -> |llm| -> |lm_head| -> |token ids| - - """ - # input_embeds and out: (B, T, H) out = self.llm( inputs_embeds=input_embeds, attention_mask=attention_mask, @@ -147,25 +194,12 @@ def forward( use_cache=cache is not None, return_dict=True, ) - ans = {"logits": out['logits']} # (B, T, text_vocab_size) + ans = {"logits": out['logits']} if cache is not None: ans["cache"] = out["past_key_values"] return ans def prepare_inputs(self, batch: dict): - """ - Performs additional processing on the mini-batch collected from dataloader. - Notably: - * Convert source audio to speech representations. - * Convert target audio to target audio tokens. - * Convert target text to embeddings. - * Combine the input audio and target text embeddings. - * Take care of any necessary slicing to align the shapes of source audio, - target audio, and target token ids. - """ - # Source audio encoding. - # Input audio: (B, T_samples) - # Audio embeddings: (B, T, H) audio_embs, audio_emb_lens = self.perception( input_signal=batch["audios"], input_signal_length=batch["audio_lens"] ) @@ -178,21 +212,15 @@ def prepare_inputs(self, batch: dict): padding_id=self.text_pad_id, placeholder_id=self.audio_locator_tag_id, replacements=audio_embs, - target_ids=batch["input_ids"].where(batch["loss_mask"], -100), # CrossEntropyLoss().ignore_index + target_ids=batch["input_ids"].where(batch["loss_mask"], -100), ) input_embs = input_embs[:, :-1] attention_mask = attention_mask[:, :-1] target_ids = target_ids[:, 1:] - # Combine target audio and text into a single tensor to slice them together. - # It will also help us truncate the sequence lengths to be divisible by TP world size, - # when TP is enabled. - # Input ids: (B, T, K+1) if self._use_tp: tp_world_size = self.device_mesh["tensor_parallel"].size() if (remainder := (input_embs.shape[1] - 1) % tp_world_size) != 0: - # Truncate some tokens from the end to make the sequence lenght shape divisible by tensor parallelism - # world size. Otherwise, sequence parallelism will change the input shape making leading to mismatches. input_embs = input_embs[:, :-remainder] attention_mask = attention_mask[:, :-remainder] target_ids = target_ids[:, :-remainder] @@ -214,7 +242,7 @@ def training_step(self, batch: dict, batch_idx: int): with loss_parallel(): loss = ( torch.nn.functional.cross_entropy( - forward_outputs["logits"].flatten(0, 1), # (B, T, Vt) -> (*, Vt) + forward_outputs["logits"].flatten(0, 1), inputs["target_ids"].flatten(0, 1), reduction="sum", ignore_index=-100, @@ -230,12 +258,11 @@ def training_step(self, batch: dict, batch_idx: int): ), "batch_size": B, "sequence_length": T, - "num_frames": num_frames.to(torch.float32), # avoid warning + "num_frames": num_frames.to(torch.float32), "target_to_input_ratio": num_frames / (B * T), "padding_ratio": (batch["input_ids"] != self.text_pad_id).long().sum() / batch["input_ids"].numel(), } - self.log("loss", loss, on_step=True, prog_bar=True) - self.log_dict({k: v for k, v in ans.items() if k != "loss"}, on_step=True) + self.log_dict(ans, on_step=True) return ans def on_validation_epoch_start(self) -> None: @@ -263,7 +290,7 @@ def on_validation_epoch_end(self) -> None: def validation_step(self, batch: dict, batch_idx: int): for name, dataset_batch in batch.items(): if dataset_batch is None: - continue # some dataset is exhausted + continue inputs = self.prepare_inputs(dataset_batch) forward_outputs = self(inputs["input_embeds"], attention_mask=inputs["attention_mask"]) num_frames = (inputs["target_ids"] != -100).long().sum() @@ -307,75 +334,8 @@ def generate( audios: torch.Tensor = None, audio_lens: torch.Tensor = None, generation_config: GenerationConfig = None, - enable_thinking: bool | None = None, **generation_kwargs, ) -> torch.Tensor: - """ - Generate LLM answers given text or mixed text+audio prompts. - - Example 1. High-level API using ``prompts`` to provide both text and audio:: - - >>> answer_ids = model.generate( - ... prompts=[ - ... [ - ... { - ... "role": "user", - ... "content": f"Transcribe the following: {model.audio_locator_tag}", - ... "audio": ["path/to/audio.wav"], - ... } - ... ] - ... ], - ... max_new_tokens=128, - ... ) - - You may also include a ``transformers.GenerationConfig`` object to customize decoding strategy:: - - >>> answer_ids = model.generate(..., generation_config=GenerationConfig(do_sample=True, num_beams=5)) - - Example 2. Lower-level API, using ``prompts`` for the text part, - and pre-loaded ``audio`` and ``audio_lens`` tensors:: - - >>> answer_ids = model.generate( - ... prompts=[ - ... [{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}], - ... [{"role": "user", "content": f"Transcribe the following in Polish: {model.audio_locator_tag}"}], - ... ], - ... audios=audios, # torch.Tensor, float32, of shape (batch, time) - ... audio_lens=audio_lens, # torch.Tensor, int64, of shape (batch,) - ... max_new_tokens=128, - ... ) - - Example 3. Lower-level API, using pre-tokenized and pre-formatted ``prompts`` for the text part, - and pre-loaded ``audio`` and ``audio_lens`` tensors:: - - >>> answer_ids = model.generate( - ... prompts=prompts, # torch.Tensor, int64, of shape (batch, num_tokens) - ... audios=audios, # torch.Tensor, float32, of shape (batch, time) - ... audio_lens=audio_lens, # torch.Tensor, int64, of shape (batch,) - ... max_new_tokens=128, - ... ) - - Inputs: - prompts: batch of prompts Tensor or as list[dict] each in the following format - [ - # batch example id 0 - [{"role": "user"}, "slots": {"message": f"Transcribe the following: {model.audio_locator_tag}"}] - # batch example id 1 - [{"role": "user"}, "slots": {"message": f"Transcribe the following in Polish: {model.audio_locator_tag}"}] - ] - "role" is LLM-specific, you can pass multiple turns as well. - If ``prompts`` is a Tensor, we assume it was already formatted in the relevant chat template - and tokenized with the model's tokenizer. - audios: Optional. Time-domain audio signal zero-padded batch of shape (B, T). - The number of audios must correspond to the number of occurrences of in prompts. - Each prompt can have multiple audios. - audio_lens: Optional. Length of each audio example. - generation_config: Optional HuggingFace GenerationConfig object. - enable_thinking: Optional prompt-formatter hint forwarded to ``encode_dialog``. - Relevant for prompt formats that support thinking/reasoning mode. - generation_kwargs: Keyword arguments passed directly to the underlying LLM's ``generate`` method. - """ - # Encode prompt dicts into int token ids. if isinstance(prompts, torch.Tensor): tokens = prompts else: @@ -387,23 +347,15 @@ def generate( ), "Audios cannot be provided via ``prompts`` and ``audios``/``audio_lens`` arguments simultaneously." audios, audio_lens = maybe_audio formatter = PromptFormatter.resolve(self.cfg.prompt_format)(self.tokenizer) - formatter_kwargs = {} - if enable_thinking is not None: - formatter_kwargs["enable_thinking"] = enable_thinking tokens = left_collate_vectors( - [formatter.encode_dialog(turns=prompt, **formatter_kwargs)["input_ids"] for prompt in prompts], + [formatter.encode_dialog(turns=prompt)["input_ids"] for prompt in prompts], padding_value=self.text_pad_id, ).to(self.device) if audios is not None: - # Audio + text input for generation. - # Prepare token embeddings and audio embeddings. tokens_to_embed = tokens.where(tokens != self.audio_locator_tag_id, 0) token_embeds = self.embed_tokens(tokens_to_embed) - # TODO: temporary workaround to perform batch_size=1 inference for audio encoder - # due to accuracy issues at bs>1 audio_embeds, audio_embed_lens = self.perception(audios, audio_lens) audio_embeds = [audio_embeds[i, :elen] for i, elen in enumerate(audio_embed_lens)] - # Insert audio embeddings into relevant positions in text embeddings. input_embeds, _, attention_mask = replace_placeholders_and_build_targets( input_ids=tokens, embeds=token_embeds, @@ -414,7 +366,6 @@ def generate( ) generation_inputs = {"inputs_embeds": input_embeds, "attention_mask": attention_mask} else: - # Text-only generation. attention_mask = tokens != self.text_pad_id generation_inputs = {"input_ids": tokens, "attention_mask": attention_mask} if generation_config is None: @@ -423,8 +374,6 @@ def generate( eos_token_id=self.text_eos_id, pad_token_id=self.text_pad_id, ) - # Generate the answers using HF Generate API. - # Note: we need to put the text embedding layer back to the LLM for processing. with move_embedding(self): answer_tokens = self.llm.generate( **generation_inputs, @@ -437,7 +386,6 @@ def configure_optimizers(self): return configure_optimizers(self) def configure_model(self) -> None: - # TODO(pzelasko): refactor into separate module re-usable across models device_mesh = self.device_mesh if device_mesh is None: return @@ -449,35 +397,17 @@ def configure_model(self) -> None: if (tp_mesh := device_mesh["tensor_parallel"]).size() > 1: self._use_tp = True - # TODO: Distributing embeddings with TP in this setup is tricky - # because we're adding with the output of a non-parallelized - # speech encoder. - # for m in (self.embed_tokens, self.embed_audio_tokens): - # parallelize_module( - # m, - # tp_mesh, - # ColwiseParallel( - # # input_layouts=Shard(1), - # # # Optional: Shard the output along the class dimension to compute the loss in parallel. - # # # See `loss_parallel` in `train.py` - # # output_layouts=Shard(1), - # # use_local_output=False, - # ), - # ) - - # # Parallelize the first embedding and the last linear out projection plan = { - "layers.0": PrepareModuleInput( - input_layouts=(Replicate(),), # , None) - desired_input_layouts=(Shard(1),), # , None) + "model.language_model.layers.0": PrepareModuleInput( + input_layouts=(Replicate(),), + desired_input_layouts=(Shard(1),), use_local_output=True, ), - "norm": SequenceParallel(), + "model.language_model.norm": SequenceParallel(), } parallelize_module(llm, tp_mesh, plan) - # Parallelize each transformer block - for transformer_block in llm.model.layers: + for transformer_block in llm.model.language_model.layers: plan = { "input_layernorm": SequenceParallel(), "self_attn.q_proj": ColwiseParallel(), @@ -492,11 +422,8 @@ def configure_model(self) -> None: "mlp.gate_proj": ColwiseParallel(), "mlp.up_proj": ColwiseParallel(), "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), - # "pre_feedforward_layernorm": SequenceParallel(), - # "post_feedforward_layernorm": SequenceParallel(), } - # Adjust attention module to use the local number of heads attn_layer = transformer_block.self_attn for attr in ("num_heads", "num_key_value_heads", "hidden_size"): val = getattr(attn_layer, attr) @@ -506,7 +433,6 @@ def configure_model(self) -> None: ) setattr(attn_layer, attr, val // tp_mesh.size()) - # Apply the plan for the current transformer block parallelize_module(transformer_block, tp_mesh, plan) parallelize_module( @@ -514,30 +440,23 @@ def configure_model(self) -> None: tp_mesh, ColwiseParallel( input_layouts=Shard(1), - # Optional: Shard the output along the class dimension to compute the loss in parallel. - # See `loss_parallel` in `train.py` output_layouts=Shard(-1), use_local_output=False, ), ) if (dp_mesh := device_mesh["data_parallel"]).size() > 1: - assert dp_mesh.ndim == 1 # Hybrid-sharding not supported + assert dp_mesh.ndim == 1 self._use_fsdp = True fsdp_config = {"mesh": dp_mesh} - for idx, layer in enumerate(llm.model.layers): - llm.model.layers[idx] = fully_shard(layer, **fsdp_config) - self.embed_tokens = fully_shard(self.embed_tokens, **fsdp_config) + for idx, layer in enumerate(llm.model.language_model.layers): + llm.model.language_model.layers[idx] = fully_shard(layer, **fsdp_config) llm.lm_head = fully_shard(llm.lm_head, **fsdp_config) self.llm = fully_shard(self.llm, **fsdp_config) self.perception = fully_shard(self.perception, **fsdp_config) @property def oomptimizer_schema(self) -> dict: - """ - Return a typing schema for optimal batch size calibration for various - sequence lengths using OOMptimizer. - """ return { "cls": dict, "inputs": [ @@ -553,7 +472,6 @@ def oomptimizer_schema(self) -> dict: ], } - def replace_placeholders_and_build_targets( input_ids: torch.Tensor, embeds: torch.Tensor, From 98b6a0d3eae78abc2257bca6583c3772a172b7a6 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Apr 2026 20:31:49 +0000 Subject: [PATCH 2/5] Gemma4 added Signed-off-by: root --- nemo/collections/common/prompts/__init__.py | 2 +- nemo/collections/common/prompts/gemma.py | 67 ++- nemo/collections/speechlm2/models/salm.py | 395 ++++++++++++++---- .../collections/speechlm2/parts/pretrained.py | 7 +- 4 files changed, 373 insertions(+), 98 deletions(-) diff --git a/nemo/collections/common/prompts/__init__.py b/nemo/collections/common/prompts/__init__.py index 55309aa24519..4390fd2e07e6 100644 --- a/nemo/collections/common/prompts/__init__.py +++ b/nemo/collections/common/prompts/__init__.py @@ -15,7 +15,7 @@ from nemo.collections.common.prompts.canary import CanaryPromptFormatter from nemo.collections.common.prompts.canary2 import Canary2PromptFormatter from nemo.collections.common.prompts.formatter import PromptFormatter -from nemo.collections.common.prompts.gemma import GemmaPromptFormatter +from nemo.collections.common.prompts.gemma import GemmaPromptFormatter, Gemma4PromptFormatter from nemo.collections.common.prompts.llama import Llama2PromptFormatter, Llama3PromptFormatter from nemo.collections.common.prompts.mistral import MistralPromptFormatter from nemo.collections.common.prompts.nemotron_h import NemotronHPromptFormatter diff --git a/nemo/collections/common/prompts/gemma.py b/nemo/collections/common/prompts/gemma.py index ee31ba43f3c6..67dff97253ad 100644 --- a/nemo/collections/common/prompts/gemma.py +++ b/nemo/collections/common/prompts/gemma.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -""" -Implemented following the guide at https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format """ +Gemma1 prompt format reference: + https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format +Gemma4 prompt format reference (multimodal: text + image + audio): + <|turn>user + Describe this image: <|image|> + And translate this audio: <|audio|> + <|turn>model +""" from lhotse.cut import Cut, MixedCut from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn @@ -58,9 +63,61 @@ def gemma1(cut: Cut, prompt: GemmaPromptFormatter): context = cut.question else: context = cut.default_context - turns = [{"role": "user", "slots": {"message": context}}] if (answer := cut.supervisions[0].text) is not None: turns.append({"role": "assistant", "slots": {"message": answer}}) - return prompt.encode_dialog(turns) + + +GEMMA4_BOT = "<|turn>" # beginning-of-turn +GEMMA4_EOT = "" # end-of-turn +GEMMA4_IMAGE = "<|image|>" # image placeholder token +GEMMA4_AUDIO = "<|audio|>" # audio placeholder token + + +class Gemma4PromptFormatter(PromptFormatter): + NAME = "gemma4" + OUTPUT_ROLE = "assistant" + INSERT_BOS = True + INSERT_EOS = True + TEMPLATE = { + "user": { + "template": f"{GEMMA4_BOT}user\n|message|{GEMMA4_EOT}\n{GEMMA4_BOT}model\n", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message|{GEMMA4_EOT}\n", + "slots": { + "message": Modality.Text, + }, + }, + } + + +@registered_prompt_format_fn(Cut, Gemma4PromptFormatter) +def gemma4(cut: Cut, prompt: Gemma4PromptFormatter): + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + context = cut.context + elif cut.has_custom("question"): + context = cut.question + else: + context = cut.default_context + parts = [] + if context: + parts.append(context) + if cut.has_custom("image") and cut.image is not None: + parts.append(GEMMA4_IMAGE) + if getattr(cut, "has_recording", False) or cut.has_custom("audio_filepath"): + parts.append(GEMMA4_AUDIO) + if cut.has_custom("extra_audios") and cut.extra_audios: + for _ in cut.extra_audios: + parts.append(GEMMA4_AUDIO) + user_message = "\n".join(parts) + turns = [{"role": "user", "slots": {"message": user_message}}] + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) + return prompt.encode_dialog(turns) \ No newline at end of file diff --git a/nemo/collections/speechlm2/models/salm.py b/nemo/collections/speechlm2/models/salm.py index 7d489b5ca53a..a5c755be575d 100644 --- a/nemo/collections/speechlm2/models/salm.py +++ b/nemo/collections/speechlm2/models/salm.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import types import warnings from collections import defaultdict from itertools import repeat @@ -41,12 +42,118 @@ from nemo.collections.speechlm2.parts.hf_hub import HFHubMixin from nemo.collections.speechlm2.parts.lora import maybe_install_lora from nemo.collections.speechlm2.parts.optim_setup import configure_optimizers, is_frozen -from nemo.collections.speechlm2.parts.pretrained import load_pretrained_hf, move_embedding, setup_speech_encoder +from nemo.collections.speechlm2.parts.pretrained import ( + load_pretrained_hf, + maybe_load_pretrained_models, + move_embedding, + setup_speech_encoder, +) from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import logging +def _is_gemma4(llm) -> bool: + """Detect whether the loaded LLM is a Gemma 4 multimodal model. + + Gemma 4 wraps its language model one level deeper: + llm.model.language_model (Gemma4) + vs. the standard layout: + llm.model (everyone else) + """ + return hasattr(llm, "model") and hasattr(llm.model, "language_model") + +def _setup_normal(salm: "SALM") -> None: + """Move embed_tokens out of llm.model to avoid FSDP/TP hook issues.""" + salm.embed_tokens = salm.llm.model.embed_tokens + del salm.llm.model.embed_tokens + +def _setup_gemma4(salm: "SALM") -> None: + """Apply all patches required for Gemma 4's nested multimodal architecture. + + Gemma 4 buries embed_tokens one level deeper than standard HF models, adds + multimodal placeholder routing, and requires per_layer_inputs in several + internal call sites. These patches make SALM's audio-injection path work + without touching Gemma's own forward logic. + """ + lm = salm.llm.model.language_model # shorthand + + # ------------------------------------------------------------------ + # Patch 1 — pull embed_tokens out to avoid duplicate parameters while + # keeping Gemma's own forward pass intact via pre/post hooks. + # ------------------------------------------------------------------ + salm.embed_tokens = lm.embed_tokens + del lm.embed_tokens + + def _pre_forward_hook(module, args, kwargs): + module.model.language_model.embed_tokens = salm.embed_tokens + return args, kwargs + + def _post_forward_hook(module, args, output): + # Remove so PyTorch doesn't see the parameter twice. + del module.model.language_model.embed_tokens + + salm.llm.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True) + salm.llm.register_forward_hook(_post_forward_hook) + + # Redirect get/set_input_embeddings on all three relevant modules so that + # HF internals (e.g. resize_token_embeddings) always resolve to our copy. + def _get_embeddings(inner_self): + return salm.embed_tokens + + def _set_embeddings(inner_self, value): + salm.embed_tokens = value + + for module in (salm.llm, salm.llm.model, lm): + module.get_input_embeddings = types.MethodType(_get_embeddings, module) + module.set_input_embeddings = types.MethodType(_set_embeddings, module) + + # ------------------------------------------------------------------ + # Patch 2 — disable Gemma 4's multimodal placeholder routing. + # SALM injects audio embeddings directly, so the vision/audio slot + # detection logic must be a no-op. + # ------------------------------------------------------------------ + def _noop_placeholder_mask(inner_self, input_ids, inputs_embeds): + batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0] + seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] + device = inputs_embeds.device if inputs_embeds is not None else input_ids.device + empty = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) + return empty, empty, empty + + salm.llm.model.get_placeholder_mask = types.MethodType( + _noop_placeholder_mask, salm.llm.model + ) + + # ------------------------------------------------------------------ + # Patch 3 — skip per_layer_inputs computation when inputs_embeds are + # already provided (i.e. during SALM's audio-fused forward pass). + # ------------------------------------------------------------------ + _orig_get_per_layer_inputs = lm.get_per_layer_inputs + + def _patched_get_per_layer_inputs(inner_self, input_ids, inputs_embeds): + if input_ids is None: + return None + return _orig_get_per_layer_inputs(input_ids, inputs_embeds) + + lm.get_per_layer_inputs = types.MethodType(_patched_get_per_layer_inputs, lm) + + # ------------------------------------------------------------------ + # Patch 4 — ensure language_model.forward always receives + # per_layer_inputs (defaults to None when not supplied). + # ------------------------------------------------------------------ + _orig_lm_forward = lm.forward + + def _patched_lm_forward(inner_self, *args, **kwargs): + kwargs.setdefault("per_layer_inputs", None) + return _orig_lm_forward(*args, **kwargs) + + lm.forward = types.MethodType(_patched_lm_forward, lm) class SALM(LightningModule, HFHubMixin): + """Speech-Augmented Language Model. + + Supports both standard HF language models and Gemma 4's nested multimodal + architecture. Model-specific setup is handled transparently at init time. + """ + def __init__(self, cfg) -> None: assert isinstance(cfg, dict), ( "You must pass the config to SALM as a Python dict to support hyperparameter serialization " @@ -57,89 +164,55 @@ def __init__(self, cfg) -> None: self.cfg = DictConfig(cfg) self.audio_locator_tag = self.cfg.audio_locator_tag - self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True) - self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]}) - self.llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights) - - # CHANGED - Start - # ── Patch 1: fix get_input_embeddings so HF internals can find embed_tokens ── - self.embed_tokens = self.llm.model.language_model.embed_tokens - del self.llm.model.language_model.embed_tokens - - # ── Hook: restore embed_tokens into language_model before every llm forward ── - def _pre_forward_hook(module, args, kwargs): - module.model.language_model.embed_tokens = self.embed_tokens - return args, kwargs - - def _post_forward_hook(module, args, output): - # Remove again so PyTorch doesn't see duplicate parameters - del module.model.language_model.embed_tokens - - self.llm.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True) - self.llm.register_forward_hook(_post_forward_hook) - - # ── Patch 1: fix get_input_embeddings ── - import types - salm_ref = self - - def _get_embeddings(inner_self): - return salm_ref.embed_tokens - - def _set_embeddings(inner_self, value): - salm_ref.embed_tokens = value - - for module in (self.llm, self.llm.model, self.llm.model.language_model): - module.get_input_embeddings = types.MethodType(_get_embeddings, module) - module.set_input_embeddings = types.MethodType(_set_embeddings, module) - - # ── Patch 2: disable Gemma4's multimodal placeholder routing ── - def _noop_placeholder_mask(inner_self, input_ids, inputs_embeds): - batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0] - seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] - device = inputs_embeds.device if inputs_embeds is not None else input_ids.device - empty = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) - return empty, empty, empty - - self.llm.model.get_placeholder_mask = types.MethodType( - _noop_placeholder_mask, self.llm.model + # ── Tokenizer ────────────────────────────────────────────────────── + self.tokenizer = AutoTokenizer( + self.cfg.pretrained_llm, + use_fast=True, + trust_remote_code=self.cfg.get("trust_remote_code", False), ) - - # ── Patch 3: skip per_layer_inputs when inputs_embeds provided ── - _orig_get_per_layer_inputs = self.llm.model.language_model.get_per_layer_inputs - - def _patched_get_per_layer_inputs(inner_self, input_ids, inputs_embeds): - if input_ids is None: - return None - return _orig_get_per_layer_inputs(input_ids, inputs_embeds) - - self.llm.model.language_model.get_per_layer_inputs = types.MethodType( - _patched_get_per_layer_inputs, - self.llm.model.language_model + self.tokenizer.add_special_tokens( + {"additional_special_tokens": [self.audio_locator_tag]} ) - - # ── Patch 4: handle None per_layer_inputs in language_model.forward ── - _orig_lm_forward = self.llm.model.language_model.forward - - def _patched_lm_forward(inner_self, *args, **kwargs): - kwargs.setdefault('per_layer_inputs', None) - return _orig_lm_forward(*args, **kwargs) - - self.llm.model.language_model.forward = types.MethodType( - _patched_lm_forward, - self.llm.model.language_model + + # ── LLM ──────────────────────────────────────────────────────────── + self.llm = load_pretrained_hf( + self.cfg.pretrained_llm, + pretrained_weights=self.cfg.pretrained_weights, + trust_remote_code=self.cfg.get("trust_remote_code", False), ) - + + # ── Model-specific setup ─────────────────────────────────────────── + # We detect Gemma 4 by the presence of llm.model.language_model. + # All other models take the original code path. + if _is_gemma4(self.llm): + _setup_gemma4(self) + else: + _setup_normal(self) + + # ── Shared post-setup ────────────────────────────────────────────── maybe_install_lora(self) setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights) - + maybe_load_pretrained_models(self) + self._use_fsdp = False self._use_tp = False - @property - def embed_tokens(self): - """Always read embed_tokens from llm to avoid double parameters and - keep Gemma4's own forward working normally.""" - return self.llm.model.language_model.embed_tokens + # ----------------------------------------------------------------------- + # embed_tokens property + # ----------------------------------------------------------------------- + # Both code paths store embed_tokens as a plain attribute on `self`. + # The property below is intentionally NOT defined so that attribute access + # resolves directly to `self.__dict__["embed_tokens"]`, matching the + # original behaviour for normal models and keeping Gemma4's hooks simple. + # + # If you ever need to expose embed_tokens as a @property (e.g. to always + # read from lm.embed_tokens after Gemma4 hooks restore it), uncomment: + # + # @property + # def embed_tokens(self): + # if _is_gemma4(self.llm): + # return self.llm.model.language_model.embed_tokens + # return self.llm.model.embed_tokens @property def text_vocab_size(self): @@ -187,6 +260,14 @@ def forward( attention_mask: Tensor = None, cache=None, ) -> dict[str, Tensor]: + """ + Implements a fully offline forward pass through the entire model. + The flow is the following: + + |speech and text embeddings| -> |llm| -> |lm_head| -> |token ids| + + """ + # input_embeds and out: (B, T, H) out = self.llm( inputs_embeds=input_embeds, attention_mask=attention_mask, @@ -194,12 +275,25 @@ def forward( use_cache=cache is not None, return_dict=True, ) - ans = {"logits": out['logits']} + ans = {"logits": out['logits']} # (B, T, text_vocab_size) if cache is not None: ans["cache"] = out["past_key_values"] return ans def prepare_inputs(self, batch: dict): + """ + Performs additional processing on the mini-batch collected from dataloader. + Notably: + * Convert source audio to speech representations. + * Convert target audio to target audio tokens. + * Convert target text to embeddings. + * Combine the input audio and target text embeddings. + * Take care of any necessary slicing to align the shapes of source audio, + target audio, and target token ids. + """ + # Source audio encoding. + # Input audio: (B, T_samples) + # Audio embeddings: (B, T, H) audio_embs, audio_emb_lens = self.perception( input_signal=batch["audios"], input_signal_length=batch["audio_lens"] ) @@ -212,15 +306,21 @@ def prepare_inputs(self, batch: dict): padding_id=self.text_pad_id, placeholder_id=self.audio_locator_tag_id, replacements=audio_embs, - target_ids=batch["input_ids"].where(batch["loss_mask"], -100), + target_ids=batch["input_ids"].where(batch["loss_mask"], -100), # CrossEntropyLoss().ignore_index ) input_embs = input_embs[:, :-1] attention_mask = attention_mask[:, :-1] target_ids = target_ids[:, 1:] + # Combine target audio and text into a single tensor to slice them together. + # It will also help us truncate the sequence lengths to be divisible by TP world size, + # when TP is enabled. + # Input ids: (B, T, K+1) if self._use_tp: tp_world_size = self.device_mesh["tensor_parallel"].size() if (remainder := (input_embs.shape[1] - 1) % tp_world_size) != 0: + # Truncate some tokens from the end to make the sequence lenght shape divisible by tensor parallelism + # world size. Otherwise, sequence parallelism will change the input shape making leading to mismatches. input_embs = input_embs[:, :-remainder] attention_mask = attention_mask[:, :-remainder] target_ids = target_ids[:, :-remainder] @@ -242,7 +342,7 @@ def training_step(self, batch: dict, batch_idx: int): with loss_parallel(): loss = ( torch.nn.functional.cross_entropy( - forward_outputs["logits"].flatten(0, 1), + forward_outputs["logits"].flatten(0, 1), # (B, T, Vt) -> (*, Vt) inputs["target_ids"].flatten(0, 1), reduction="sum", ignore_index=-100, @@ -258,11 +358,13 @@ def training_step(self, batch: dict, batch_idx: int): ), "batch_size": B, "sequence_length": T, - "num_frames": num_frames.to(torch.float32), + "num_frames": num_frames.to(torch.float32), # avoid warning "target_to_input_ratio": num_frames / (B * T), "padding_ratio": (batch["input_ids"] != self.text_pad_id).long().sum() / batch["input_ids"].numel(), } - self.log_dict(ans, on_step=True) + # self.log_dict(ans, on_step=True) + self.log("loss", loss, on_step=True, prog_bar=True) + self.log_dict({k: v for k, v in ans.items() if k != "loss"}, on_step=True) return ans def on_validation_epoch_start(self) -> None: @@ -290,7 +392,7 @@ def on_validation_epoch_end(self) -> None: def validation_step(self, batch: dict, batch_idx: int): for name, dataset_batch in batch.items(): if dataset_batch is None: - continue + continue # some dataset is exhausted inputs = self.prepare_inputs(dataset_batch) forward_outputs = self(inputs["input_embeds"], attention_mask=inputs["attention_mask"]) num_frames = (inputs["target_ids"] != -100).long().sum() @@ -334,8 +436,75 @@ def generate( audios: torch.Tensor = None, audio_lens: torch.Tensor = None, generation_config: GenerationConfig = None, + enable_thinking: bool | None = None, **generation_kwargs, ) -> torch.Tensor: + """ + Generate LLM answers given text or mixed text+audio prompts. + + Example 1. High-level API using ``prompts`` to provide both text and audio:: + + >>> answer_ids = model.generate( + ... prompts=[ + ... [ + ... { + ... "role": "user", + ... "content": f"Transcribe the following: {model.audio_locator_tag}", + ... "audio": ["path/to/audio.wav"], + ... } + ... ] + ... ], + ... max_new_tokens=128, + ... ) + + You may also include a ``transformers.GenerationConfig`` object to customize decoding strategy:: + + >>> answer_ids = model.generate(..., generation_config=GenerationConfig(do_sample=True, num_beams=5)) + + Example 2. Lower-level API, using ``prompts`` for the text part, + and pre-loaded ``audio`` and ``audio_lens`` tensors:: + + >>> answer_ids = model.generate( + ... prompts=[ + ... [{"role": "user", "content": f"Transcribe the following: {model.audio_locator_tag}"}], + ... [{"role": "user", "content": f"Transcribe the following in Polish: {model.audio_locator_tag}"}], + ... ], + ... audios=audios, # torch.Tensor, float32, of shape (batch, time) + ... audio_lens=audio_lens, # torch.Tensor, int64, of shape (batch,) + ... max_new_tokens=128, + ... ) + + Example 3. Lower-level API, using pre-tokenized and pre-formatted ``prompts`` for the text part, + and pre-loaded ``audio`` and ``audio_lens`` tensors:: + + >>> answer_ids = model.generate( + ... prompts=prompts, # torch.Tensor, int64, of shape (batch, num_tokens) + ... audios=audios, # torch.Tensor, float32, of shape (batch, time) + ... audio_lens=audio_lens, # torch.Tensor, int64, of shape (batch,) + ... max_new_tokens=128, + ... ) + + Inputs: + prompts: batch of prompts Tensor or as list[dict] each in the following format + [ + # batch example id 0 + [{"role": "user"}, "slots": {"message": f"Transcribe the following: {model.audio_locator_tag}"}] + # batch example id 1 + [{"role": "user"}, "slots": {"message": f"Transcribe the following in Polish: {model.audio_locator_tag}"}] + ] + "role" is LLM-specific, you can pass multiple turns as well. + If ``prompts`` is a Tensor, we assume it was already formatted in the relevant chat template + and tokenized with the model's tokenizer. + audios: Optional. Time-domain audio signal zero-padded batch of shape (B, T). + The number of audios must correspond to the number of occurrences of in prompts. + Each prompt can have multiple audios. + audio_lens: Optional. Length of each audio example. + generation_config: Optional HuggingFace GenerationConfig object. + enable_thinking: Optional prompt-formatter hint forwarded to ``encode_dialog``. + Relevant for prompt formats that support thinking/reasoning mode. + generation_kwargs: Keyword arguments passed directly to the underlying LLM's ``generate`` method. + """ + # Encode prompt dicts into int token ids. if isinstance(prompts, torch.Tensor): tokens = prompts else: @@ -347,15 +516,25 @@ def generate( ), "Audios cannot be provided via ``prompts`` and ``audios``/``audio_lens`` arguments simultaneously." audios, audio_lens = maybe_audio formatter = PromptFormatter.resolve(self.cfg.prompt_format)(self.tokenizer) + formatter_kwargs = {} + if enable_thinking is not None: + formatter_kwargs["enable_thinking"] = enable_thinking tokens = left_collate_vectors( - [formatter.encode_dialog(turns=prompt)["input_ids"] for prompt in prompts], + # [formatter.encode_dialog(turns=prompt)["input_ids"] for prompt in prompts], + [formatter.encode_dialog(turns=prompt, **formatter_kwargs)["input_ids"] for prompt in prompts], padding_value=self.text_pad_id, ).to(self.device) + if audios is not None: + # Audio + text input for generation. + # Prepare token embeddings and audio embeddings. tokens_to_embed = tokens.where(tokens != self.audio_locator_tag_id, 0) token_embeds = self.embed_tokens(tokens_to_embed) + # TODO: temporary workaround to perform batch_size=1 inference for audio encoder + # due to accuracy issues at bs>1 audio_embeds, audio_embed_lens = self.perception(audios, audio_lens) audio_embeds = [audio_embeds[i, :elen] for i, elen in enumerate(audio_embed_lens)] + # Insert audio embeddings into relevant positions in text embeddings. input_embeds, _, attention_mask = replace_placeholders_and_build_targets( input_ids=tokens, embeds=token_embeds, @@ -366,6 +545,7 @@ def generate( ) generation_inputs = {"inputs_embeds": input_embeds, "attention_mask": attention_mask} else: + # Text-only generation. attention_mask = tokens != self.text_pad_id generation_inputs = {"input_ids": tokens, "attention_mask": attention_mask} if generation_config is None: @@ -374,6 +554,8 @@ def generate( eos_token_id=self.text_eos_id, pad_token_id=self.text_pad_id, ) + # Generate the answers using HF Generate API. + # Note: we need to put the text embedding layer back to the LLM for processing. with move_embedding(self): answer_tokens = self.llm.generate( **generation_inputs, @@ -386,6 +568,7 @@ def configure_optimizers(self): return configure_optimizers(self) def configure_model(self) -> None: + # TODO(pzelasko): refactor into separate module re-usable across models device_mesh = self.device_mesh if device_mesh is None: return @@ -397,17 +580,35 @@ def configure_model(self) -> None: if (tp_mesh := device_mesh["tensor_parallel"]).size() > 1: self._use_tp = True + # TODO: Distributing embeddings with TP in this setup is tricky + # because we're adding with the output of a non-parallelized + # speech encoder. + # for m in (self.embed_tokens, self.embed_audio_tokens): + # parallelize_module( + # m, + # tp_mesh, + # ColwiseParallel( + # # input_layouts=Shard(1), + # # # Optional: Shard the output along the class dimension to compute the loss in parallel. + # # # See `loss_parallel` in `train.py` + # # output_layouts=Shard(1), + # # use_local_output=False, + # ), + # ) + + # # Parallelize the first embedding and the last linear out projection plan = { - "model.language_model.layers.0": PrepareModuleInput( - input_layouts=(Replicate(),), - desired_input_layouts=(Shard(1),), + "layers.0": PrepareModuleInput( + input_layouts=(Replicate(),), # , None) + desired_input_layouts=(Shard(1),), # , None) use_local_output=True, ), - "model.language_model.norm": SequenceParallel(), + "norm": SequenceParallel(), } parallelize_module(llm, tp_mesh, plan) - for transformer_block in llm.model.language_model.layers: + # Parallelize each transformer block + for transformer_block in llm.model.layers: plan = { "input_layernorm": SequenceParallel(), "self_attn.q_proj": ColwiseParallel(), @@ -422,8 +623,11 @@ def configure_model(self) -> None: "mlp.gate_proj": ColwiseParallel(), "mlp.up_proj": ColwiseParallel(), "mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), + # "pre_feedforward_layernorm": SequenceParallel(), + # "post_feedforward_layernorm": SequenceParallel(), } + # Adjust attention module to use the local number of heads attn_layer = transformer_block.self_attn for attr in ("num_heads", "num_key_value_heads", "hidden_size"): val = getattr(attn_layer, attr) @@ -433,6 +637,7 @@ def configure_model(self) -> None: ) setattr(attn_layer, attr, val // tp_mesh.size()) + # Apply the plan for the current transformer block parallelize_module(transformer_block, tp_mesh, plan) parallelize_module( @@ -440,23 +645,30 @@ def configure_model(self) -> None: tp_mesh, ColwiseParallel( input_layouts=Shard(1), + # Optional: Shard the output along the class dimension to compute the loss in parallel. + # See `loss_parallel` in `train.py` output_layouts=Shard(-1), use_local_output=False, ), ) if (dp_mesh := device_mesh["data_parallel"]).size() > 1: - assert dp_mesh.ndim == 1 + assert dp_mesh.ndim == 1 # Hybrid-sharding not supported self._use_fsdp = True fsdp_config = {"mesh": dp_mesh} - for idx, layer in enumerate(llm.model.language_model.layers): - llm.model.language_model.layers[idx] = fully_shard(layer, **fsdp_config) + for idx, layer in enumerate(llm.model.layers): + llm.model.layers[idx] = fully_shard(layer, **fsdp_config) + self.embed_tokens = fully_shard(self.embed_tokens, **fsdp_config) llm.lm_head = fully_shard(llm.lm_head, **fsdp_config) self.llm = fully_shard(self.llm, **fsdp_config) self.perception = fully_shard(self.perception, **fsdp_config) @property def oomptimizer_schema(self) -> dict: + """ + Return a typing schema for optimal batch size calibration for various + sequence lengths using OOMptimizer. + """ return { "cls": dict, "inputs": [ @@ -472,6 +684,7 @@ def oomptimizer_schema(self) -> dict: ], } + def replace_placeholders_and_build_targets( input_ids: torch.Tensor, embeds: torch.Tensor, diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 871392f77e44..adcaca8ffbe6 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -160,7 +160,12 @@ def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True model.cfg.perception.preprocessor = asr.cfg.preprocessor model.cfg.perception.encoder = asr.cfg.encoder if model.llm is not None: - model.cfg.perception.output_dim = model.llm.config.hidden_size + llm = model.llm + if hasattr(llm, 'config') and hasattr(llm.config, 'hidden_size'): + model.cfg.perception.output_dim = llm.config.hidden_size + elif hasattr(llm, 'language_model') and hasattr(llm.language_model, 'config') \ + and hasattr(llm.language_model.config, 'hidden_size'): + model.cfg.perception.output_dim = llm.language_model.config.hidden_size # Override with user-specified encoder parameters, e.g. initializiing a non-causal encoder for causal setup. if user_encoder_config: for key, value in user_encoder_config.items(): From 9b34b7a043d0fef3290c267b68f8ce146620a08b Mon Sep 17 00:00:00 2001 From: root Date: Mon, 27 Apr 2026 20:39:51 +0000 Subject: [PATCH 3/5] Continuation / Interleaved Data added Signed-off-by: root --- .../speechlm2/data/salm_dataset.py | 308 ++++++++++++++---- 1 file changed, 252 insertions(+), 56 deletions(-) diff --git a/nemo/collections/speechlm2/data/salm_dataset.py b/nemo/collections/speechlm2/data/salm_dataset.py index 765ed00ee181..3a28c17bd6a9 100644 --- a/nemo/collections/speechlm2/data/salm_dataset.py +++ b/nemo/collections/speechlm2/data/salm_dataset.py @@ -11,9 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +SALMDataset with support for two pretraining patterns (see Figure 2). +The relationship between audio and text is encoded purely by position: + + Repetition pattern: + … … + loss_mask=True on transcript tokens only. + The model learns ASR: given Aₙ predict Tₙ. + + Continuation pattern: + … + loss_mask=True on assistant text tokens only. + The model learns cross-modal continuation: given Aₙ predict Tₙ₊₁. + +Manifest entry schema (continuation / mixed): +───────────────────────────────────────────── +{ + "id": "", + "conversations": [ + {"from": "user", "value": "", "type": "text"}, + {"from": "user", "value": "/utt1.wav", "duration": 12.34, "type": "audio"}, + {"from": "assistant", "value": "", "type": "text"}, + {"from": "user", "value": "/utt2.wav", "duration": 13.01, "type": "audio"}, + {"from": "assistant", "value": "", "type": "text"} + ] +} +""" + import logging from itertools import groupby -from typing import Iterable, Union +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -34,63 +63,218 @@ from nemo.collections.speechlm2.data.utils import get_pad_id +# ────────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────────── + +def _encode_text(tokenizer: AutoTokenizer, text: str) -> torch.Tensor: + """Encode *text* and return a 1-D LongTensor of token ids (no BOS/EOS).""" + ids = tokenizer.text_to_ids(text) + return torch.tensor(ids, dtype=torch.long) + + +def _token_id(tokenizer: AutoTokenizer, token: str) -> int: + """Return the single integer id for a special token string.""" + ids = tokenizer.text_to_ids(token) + if len(ids) != 1: + raise ValueError( + f"Special token '{token}' should map to exactly one id, got {ids}. " + "Make sure it is added to the tokenizer vocabulary." + ) + return ids[0] + + +def _cat(parts: List[torch.Tensor]) -> torch.Tensor: + """Concatenate a list of 1-D tensors; return empty LongTensor when empty.""" + if not parts: + return torch.empty(0, dtype=torch.long) + return torch.cat(parts, dim=0) + + +# ────────────────────────────────────────────────────────────────────────────── +# Core dataset +# ────────────────────────────────────────────────────────────────────────────── + class SALMDataset(torch.utils.data.Dataset): """ - A dataset for Speech-Augmented Language Models (SALM) that processes multimodal conversations - containing both text and audio turns. + Dataset for Speech-Augmented Language Models (SALM). - This dataset handles NeMoMultimodalConversation objects which combine text messages - and audio segments in a conversational format. It uses audio_locator_tag in the text, - where each such placeholder corresponds to an entire audio segment. + Supports two pretraining patterns + The audio–text relationship is encoded purely by sequence position. + + Pattern "repetition": + (loss on transcript only) + + Pattern "continuation": + … + (loss on assistant text turns only) + + Pattern "mixed": + Per-sample "pattern" key in the manifest selects the builder; + falls back to "continuation" when the key is absent. Args: - tokenizer (AutoTokenizer): - Tokenizer for converting text to token IDs and vice versa. Must have a special - audio_locator_tag token that will be replaced with audio embeddings during model's - training step. - - Returns: - A dictionary with the following keys: - - audios: Tensor of audio waveform samples [B_audio, T_samples] - - audio_lens: Tensor of audio lengths [B_audio] - - input_ids: Tensor of text token IDs [B, T_tokens], including audio_locator_tag tokens - - loss_mask: Boolean tensor [B, T_tokens] indicating which tokens are part of the - assistant's responses (True) and should be used for computing loss - - Notes: - - Each audio_locator_tag token in input_ids corresponds to an audio segment in audios - - The SALM model later replaces these audio_locator_tag tokens with encoded audio embeddings - - The loss_mask identifies which tokens are part of the target sequences (assistant responses) - and which are part of the source sequences (user prompts) - - The input_ids and loss_mask will be expanded during model forward pass to account for - the variable-length audio segments that replace each audio_locator_tag token + tokenizer: NeMo tokenizer. Must contain ``audio_locator_tag`` + as a single registered special token. + audio_locator_tag: Placeholder string for audio turns. + pattern: Default pattern: "continuation" | "repetition" | "mixed". """ - def __init__(self, tokenizer: AutoTokenizer) -> None: + def __init__( + self, + tokenizer: AutoTokenizer, + audio_locator_tag: str = "<|audioplaceholder|>", + pattern: str = "continuation", + ) -> None: self.tokenizer = tokenizer self.pad_id = get_pad_id(tokenizer) + self.audio_locator_tag = audio_locator_tag + self.pattern = pattern + + # Only the audio-locator placeholder must be a single registered token. + self._audio_loc_id = _token_id(tokenizer, audio_locator_tag) + + # ------------------------------------------------------------------ # + # Public interface # + # ------------------------------------------------------------------ # + + def __getitem__(self, conversations: CutSet) -> Optional[dict]: + """ + Process a mini-batch of NeMoMultimodalConversation cuts. - def __getitem__(self, conversations: CutSet) -> dict | None: - # Note: the function call below may filter out some or all conversations due to audio loading issues. - # If all conversations are filtered out, we'll return None, and expect users to wrap this dataset - # in ``nemo.collections.common.data.fallback.FallbackDataset`` to use the previous mini-batch instead. + Returns a dict with keys: + audios – FloatTensor [B_audio, T_samples] + audio_lens – LongTensor [B_audio] + input_ids – LongTensor [B, T_tokens] (left-padded) + loss_mask – BoolTensor [B, T_tokens] (True = compute loss) + conversations – CutSet (in-memory data dropped) + """ try: - audios, audio_lens, conversations = collate_conversation_audio_fault_tolerant(conversations) - except Exception as e: - logging.warning(f"Error collating conversations: {e}") + audios, audio_lens, conversations = collate_conversation_audio_fault_tolerant( + conversations + ) + except Exception as exc: + logging.warning(f"Error collating conversations: {exc}") return None + if not conversations: return None + + all_input_ids: List[torch.Tensor] = [] + all_loss_masks: List[torch.Tensor] = [] + + for conv in conversations: + sample_pattern = getattr(conv, "pattern", self.pattern) + if sample_pattern == "repetition": + ids, mask = self._build_repetition(conv) + else: + ids, mask = self._build_continuation(conv) + + all_input_ids.append(ids) + all_loss_masks.append(mask) + return { "audios": audios, "audio_lens": audio_lens, - "input_ids": left_collate_vectors([c.input_ids for c in conversations], padding_value=self.pad_id), - "loss_mask": left_collate_vectors( - [getattr(c, "mask", torch.empty(0)) for c in conversations], padding_value=0 - ).to(torch.bool), + "input_ids": left_collate_vectors(all_input_ids, padding_value=self.pad_id), + "loss_mask": left_collate_vectors(all_loss_masks, padding_value=0).to(torch.bool), "conversations": drop_in_memory_data(conversations), } + # ------------------------------------------------------------------ # + # Pattern builders # + # ------------------------------------------------------------------ # + + def _build_repetition( + self, conv: NeMoMultimodalConversation + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Repetition pattern – no signal token, purely positional: + + [context …] + + loss_mask = True only over tokens. + """ + ids_parts: List[torch.Tensor] = [] + mask_parts: List[torch.Tensor] = [] + + turns = list(conv.turns) + i = 0 + while i < len(turns): + turn = turns[i] + + if isinstance(turn, AudioTurn): + # audio placeholder – no loss + ids_parts.append(torch.tensor([self._audio_loc_id], dtype=torch.long)) + mask_parts.append(torch.zeros(1, dtype=torch.long)) + # immediately following TextTurn is the transcript target + if i + 1 < len(turns) and isinstance(turns[i + 1], TextTurn): + i += 1 + transcript_ids = _encode_text(self.tokenizer, turns[i].value) + ids_parts.append(transcript_ids) + mask_parts.append(torch.ones(len(transcript_ids), dtype=torch.long)) + + elif isinstance(turn, TextTurn): + # context / prompt – no loss + text_ids = _encode_text(self.tokenizer, turn.value) + ids_parts.append(text_ids) + mask_parts.append(torch.zeros(len(text_ids), dtype=torch.long)) + i += 1 + + return _cat(ids_parts), _cat(mask_parts) + + def _build_continuation( + self, conv: NeMoMultimodalConversation + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Continuation pattern – no signal token, purely positional: + + + … + + Turn roles: + user TextTurn → context, loss_mask = False + AudioTurn → , loss_mask = False + assistant TextTurn → target text, loss_mask = True + """ + ids_parts: List[torch.Tensor] = [] + mask_parts: List[torch.Tensor] = [] + + turns = list(conv.turns) + i = 0 + while i < len(turns): + turn = turns[i] + + if isinstance(turn, AudioTurn): + # audio placeholder – no loss + ids_parts.append(torch.tensor([self._audio_loc_id], dtype=torch.long)) + mask_parts.append(torch.zeros(1, dtype=torch.long)) + + # immediately following assistant TextTurn is the prediction target + if ( + i + 1 < len(turns) + and isinstance(turns[i + 1], TextTurn) + and getattr(turns[i + 1], "role", "assistant") == "assistant" + ): + i += 1 + target_ids = _encode_text(self.tokenizer, turns[i].value) + ids_parts.append(target_ids) + mask_parts.append(torch.ones(len(target_ids), dtype=torch.long)) + + elif isinstance(turn, TextTurn): + # user prompt / context – no loss + text_ids = _encode_text(self.tokenizer, turn.value) + ids_parts.append(text_ids) + mask_parts.append(torch.zeros(len(text_ids), dtype=torch.long)) + + i += 1 + + return _cat(ids_parts), _cat(mask_parts) + + +# ────────────────────────────────────────────────────────────────────────────── +# Collation / utility functions (public API unchanged) +# ────────────────────────────────────────────────────────────────────────────── def left_collate_vectors( tensors: Iterable[Union[torch.Tensor, np.ndarray]], @@ -113,26 +297,38 @@ def _drop(conversation: NeMoMultimodalConversation) -> NeMoMultimodalConversatio return conversations.map(_drop, apply_fn=None) +# ────────────────────────────────────────────────────────────────────────────── +# Prompt format fn +# ────────────────────────────────────────────────────────────────────────────── + @registered_prompt_format_fn(NeMoMultimodalConversation, Llama2PromptFormatter) def default_multimodal_conversation_prompt_format_fn( - example: NeMoMultimodalConversation, prompt: Llama2PromptFormatter, **prompt_kwargs + example: NeMoMultimodalConversation, + prompt: Llama2PromptFormatter, ): - # Collapse consecutive same-role turns into single turn for proper prompt formatting. - turns = groupby( - [ - { - "role": turn.role, - "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag}, - } - for turn in example.turns - ], - key=lambda turn: turn["role"], - ) - turns = [ - {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}} - for role, turn_grp in turns + """Build dialog turns for the prompt formatter (unchanged semantics).""" + raw_turns = [ + { + "role": turn.role, + "slots": { + "message": ( + turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag + ) + }, + } + for turn in example.turns + ] + + collapsed = [ + { + "role": role, + "slots": {"message": " ".join(t["slots"]["message"] for t in grp)}, + } + for role, grp in groupby(raw_turns, key=lambda t: t["role"]) ] + if hasattr(example, "system_prompt"): - turns[0]["role"] = "system_and_user" - turns[0]["slots"]["system"] = example.system_prompt - return prompt.encode_dialog(turns, **prompt_kwargs) + collapsed[0]["role"] = "system_and_user" + collapsed[0]["slots"]["system"] = example.system_prompt + + return prompt.encode_dialog(collapsed) \ No newline at end of file From 943e9168fd7973fe95bc13ee93a25d640efce443 Mon Sep 17 00:00:00 2001 From: bgiddwani-ai Date: Sun, 24 May 2026 21:27:19 +0530 Subject: [PATCH 4/5] Update salm_dataset.py Signed-off-by: bgiddwani-ai --- nemo/collections/speechlm2/data/salm_dataset.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/nemo/collections/speechlm2/data/salm_dataset.py b/nemo/collections/speechlm2/data/salm_dataset.py index 3a28c17bd6a9..313dd17c0029 100644 --- a/nemo/collections/speechlm2/data/salm_dataset.py +++ b/nemo/collections/speechlm2/data/salm_dataset.py @@ -134,6 +134,13 @@ def __init__( # Only the audio-locator placeholder must be a single registered token. self._audio_loc_id = _token_id(tokenizer, audio_locator_tag) + # EOS must be present and supervised at the end of each target sequence; + # otherwise autoregressive generation has no learned stop signal. + self._eos_id = getattr(tokenizer, "eos_id", None) + if self._eos_id is None: + self._eos_id = getattr(tokenizer, "eos_token_id", None) + if self._eos_id is None: + raise ValueError("SALMDataset: tokenizer has neither eos_id nor eos_token_id") # ------------------------------------------------------------------ # # Public interface # # ------------------------------------------------------------------ # @@ -220,6 +227,9 @@ def _build_repetition( ids_parts.append(text_ids) mask_parts.append(torch.zeros(len(text_ids), dtype=torch.long)) i += 1 + + ids_parts.append(torch.tensor([self._eos_id], dtype=torch.long)) + mask_parts.append(torch.ones(1, dtype=torch.long)) return _cat(ids_parts), _cat(mask_parts) @@ -266,9 +276,9 @@ def _build_continuation( text_ids = _encode_text(self.tokenizer, turn.value) ids_parts.append(text_ids) mask_parts.append(torch.zeros(len(text_ids), dtype=torch.long)) - i += 1 - + ids_parts.append(torch.tensor([self._eos_id], dtype=torch.long)) + mask_parts.append(torch.ones(1, dtype=torch.long)) return _cat(ids_parts), _cat(mask_parts) @@ -331,4 +341,4 @@ def default_multimodal_conversation_prompt_format_fn( collapsed[0]["role"] = "system_and_user" collapsed[0]["slots"]["system"] = example.system_prompt - return prompt.encode_dialog(collapsed) \ No newline at end of file + return prompt.encode_dialog(collapsed) From f2679c777b37d2ca84ae547961e0a2484aa150aa Mon Sep 17 00:00:00 2001 From: bgiddwani-ai Date: Mon, 25 May 2026 13:29:54 +0530 Subject: [PATCH 5/5] Update salm.py Added left and right pad Signed-off-by: bgiddwani-ai --- nemo/collections/speechlm2/models/salm.py | 63 ++++++++++++++++++++--- 1 file changed, 56 insertions(+), 7 deletions(-) diff --git a/nemo/collections/speechlm2/models/salm.py b/nemo/collections/speechlm2/models/salm.py index a5c755be575d..41d47d3e1a82 100644 --- a/nemo/collections/speechlm2/models/salm.py +++ b/nemo/collections/speechlm2/models/salm.py @@ -123,14 +123,21 @@ def _noop_placeholder_mask(inner_self, input_ids, inputs_embeds): ) # ------------------------------------------------------------------ - # Patch 3 — skip per_layer_inputs computation when inputs_embeds are - # already provided (i.e. during SALM's audio-fused forward pass). + # Patch 3 — provide token-identity per_layer_inputs even when inputs_embeds + # is the only thing passed at the top of SALM's forward. We read the spliced + # input_ids that prepare_inputs stashed on the SALM instance and feed them + # to the original get_per_layer_inputs. Fallback (None) preserves the + # original short-circuit if no stash is available (e.g. generate()). # ------------------------------------------------------------------ _orig_get_per_layer_inputs = lm.get_per_layer_inputs def _patched_get_per_layer_inputs(inner_self, input_ids, inputs_embeds): if input_ids is None: - return None + cached = getattr(salm, "_current_input_ids", None) + if cached is not None and cached.shape == inputs_embeds.shape[:2]: + input_ids = cached + else: + return None return _orig_get_per_layer_inputs(input_ids, inputs_embeds) lm.get_per_layer_inputs = types.MethodType(_patched_get_per_layer_inputs, lm) @@ -300,6 +307,11 @@ def prepare_inputs(self, batch: dict): audio_embs = [emb[:emblen] for emb, emblen in zip(audio_embs, audio_emb_lens)] input_ids_to_embed = torch.where(batch["input_ids"] == self.audio_locator_tag_id, 0, batch["input_ids"]) text_embs = self.embed_tokens(input_ids_to_embed) + # If embed_tokens is fp32 (LLM-fp32 workaround) but encoder output is bf16, + # cast audio_embs to match so the splice and downstream LLM forward see one dtype. + if len(audio_embs) > 0 and audio_embs[0].dtype != text_embs.dtype: + audio_embs = [emb.to(text_embs.dtype) for emb in audio_embs] + input_embs, target_ids, attention_mask = replace_placeholders_and_build_targets( input_ids=batch["input_ids"], embeds=text_embs, @@ -308,9 +320,22 @@ def prepare_inputs(self, batch: dict): replacements=audio_embs, target_ids=batch["input_ids"].where(batch["loss_mask"], -100), # CrossEntropyLoss().ignore_index ) + # Build spliced input_ids (audio_locator_tag_id at audio positions, real + # token ids elsewhere) so Gemma 4's per-layer-embeddings lookup has token + # identity to use. Stash on self for the patched get_per_layer_inputs to read. + _, _llm_input_ids, _ = replace_placeholders_and_build_targets( + input_ids=batch["input_ids"], + embeds=text_embs, + padding_id=self.text_pad_id, + placeholder_id=self.audio_locator_tag_id, + replacements=audio_embs, + target_ids=batch["input_ids"], + ) + _llm_input_ids = torch.where(_llm_input_ids == -100, 0, _llm_input_ids) input_embs = input_embs[:, :-1] attention_mask = attention_mask[:, :-1] target_ids = target_ids[:, 1:] + self._current_input_ids = _llm_input_ids[:, :-1] # Combine target audio and text into a single tensor to slice them together. # It will also help us truncate the sequence lengths to be divisible by TP world size, @@ -535,6 +560,10 @@ def generate( audio_embeds, audio_embed_lens = self.perception(audios, audio_lens) audio_embeds = [audio_embeds[i, :elen] for i, elen in enumerate(audio_embed_lens)] # Insert audio embeddings into relevant positions in text embeddings. + # Use left-padding for generation (Bug D fix): HF generate() appends new tokens at the + # end of the sequence; with right-pad, shorter rows in a batch end at position + # seq_len < max_seq_length, so the next token lands at position max_seq_length and + # RoPE positions are misaligned for shorter audios -> gibberish at bs>1. input_embeds, _, attention_mask = replace_placeholders_and_build_targets( input_ids=tokens, embeds=token_embeds, @@ -542,6 +571,7 @@ def generate( placeholder_id=self.audio_locator_tag_id, replacements=audio_embeds, target_ids=None, + padding_side="left", ) generation_inputs = {"inputs_embeds": input_embeds, "attention_mask": attention_mask} else: @@ -692,6 +722,7 @@ def replace_placeholders_and_build_targets( placeholder_id: int, replacements: list[torch.Tensor], target_ids: Optional[torch.Tensor] = None, + padding_side: str = "right", ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: """Replaces each occurrence of the placeholder_id in input_ids with the corresponding tensor from the replacements list in the embeds tensor, and creates corresponding adjusted target_ids. @@ -706,6 +737,11 @@ def replace_placeholders_and_build_targets( placeholder_id (int): an id to be replaced. replacements (list of Tensor): each Tensor has shape (L_i, hidden_dim), with L_i arbitrary. target_ids (Tensor): shape (batch, sequence_length); target token ids. + padding_side (str): "right" (default, for training — bf16 backward stability through + Gemma 4's 35 layers requires real content at low RoPE positions) or "left" + (for autoregressive generation — all batch rows must end at the same position + so the "next token" RoPE position is correct; shorter audios otherwise produce + gibberish at bs>1 because RoPE positions are misaligned, see Bug D notes). Returns: Tuple[Tensor, Tensor, Tensor]: @@ -718,6 +754,7 @@ def replace_placeholders_and_build_targets( - Tensor of shape (batch, max_new_sequence_length) with attention padding masks updated to account for shape changes due to replacements. """ + assert padding_side in ("right", "left"), f"padding_side must be 'right' or 'left', got {padding_side!r}" batch_size, seq_len = input_ids.size() if target_ids is not None: assert target_ids.size() == input_ids.size(), "target_ids must have the same shape as input_ids" @@ -813,10 +850,22 @@ def replace_placeholders_and_build_targets( output_target_ids = repeat(None) for i, (seq, tgt, att) in enumerate(zip(output_sequences, output_target_ids, output_att_masks)): seq_len = seq.size(0) - output[i, -seq_len:] = seq - if tgt is not None: - new_target_ids[i, -seq_len:] = tgt - attention_masks[i, -seq_len:] = att + # padding_side="right" (default, training): real tokens at [0:seq_len], pad at [seq_len:]. + # bf16 backward through Gemma 4's 35 transformer layers requires this layout to avoid + # NaN gradients. See docs/DECISIONS.md (2026-04-30 entry) for the full bisection. + # padding_side="left" (generation): real tokens at [-seq_len:], pad at [:max_seq_length-seq_len]. + # All rows end at the same position, so HF generate()'s "next token" RoPE position is + # consistent across the batch. Required for bs>1 inference (Bug D fix, 2026-05-23). + if padding_side == "left": + output[i, -seq_len:] = seq + if tgt is not None: + new_target_ids[i, -seq_len:] = tgt + attention_masks[i, -seq_len:] = att + else: # "right" + output[i, :seq_len] = seq + if tgt is not None: + new_target_ids[i, :seq_len] = tgt + attention_masks[i, :seq_len] = att return output, new_target_ids, attention_masks