diff --git a/nemo/collections/common/prompts/__init__.py b/nemo/collections/common/prompts/__init__.py index 55309aa24519..4390fd2e07e6 100644 --- a/nemo/collections/common/prompts/__init__.py +++ b/nemo/collections/common/prompts/__init__.py @@ -15,7 +15,7 @@ from nemo.collections.common.prompts.canary import CanaryPromptFormatter from nemo.collections.common.prompts.canary2 import Canary2PromptFormatter from nemo.collections.common.prompts.formatter import PromptFormatter -from nemo.collections.common.prompts.gemma import GemmaPromptFormatter +from nemo.collections.common.prompts.gemma import GemmaPromptFormatter, Gemma4PromptFormatter from nemo.collections.common.prompts.llama import Llama2PromptFormatter, Llama3PromptFormatter from nemo.collections.common.prompts.mistral import MistralPromptFormatter from nemo.collections.common.prompts.nemotron_h import NemotronHPromptFormatter diff --git a/nemo/collections/common/prompts/gemma.py b/nemo/collections/common/prompts/gemma.py index ee31ba43f3c6..67dff97253ad 100644 --- a/nemo/collections/common/prompts/gemma.py +++ b/nemo/collections/common/prompts/gemma.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -""" -Implemented following the guide at https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format """ +Gemma1 prompt format reference: + https://www.promptingguide.ai/models/gemma#gemma-7b-prompt-format +Gemma4 prompt format reference (multimodal: text + image + audio): + <|turn>user + Describe this image: <|image|> + And translate this audio: <|audio|> + <|turn>model +""" from lhotse.cut import Cut, MixedCut from nemo.collections.common.data.prompt_fn import registered_prompt_format_fn @@ -58,9 +63,61 @@ def gemma1(cut: Cut, prompt: GemmaPromptFormatter): context = cut.question else: context = cut.default_context - turns = [{"role": "user", "slots": {"message": context}}] if (answer := cut.supervisions[0].text) is not None: turns.append({"role": "assistant", "slots": {"message": answer}}) - return prompt.encode_dialog(turns) + + +GEMMA4_BOT = "<|turn>" # beginning-of-turn +GEMMA4_EOT = "" # end-of-turn +GEMMA4_IMAGE = "<|image|>" # image placeholder token +GEMMA4_AUDIO = "<|audio|>" # audio placeholder token + + +class Gemma4PromptFormatter(PromptFormatter): + NAME = "gemma4" + OUTPUT_ROLE = "assistant" + INSERT_BOS = True + INSERT_EOS = True + TEMPLATE = { + "user": { + "template": f"{GEMMA4_BOT}user\n|message|{GEMMA4_EOT}\n{GEMMA4_BOT}model\n", + "slots": { + "message": Modality.Text, + }, + }, + OUTPUT_ROLE: { + "template": f"|message|{GEMMA4_EOT}\n", + "slots": { + "message": Modality.Text, + }, + }, + } + + +@registered_prompt_format_fn(Cut, Gemma4PromptFormatter) +def gemma4(cut: Cut, prompt: Gemma4PromptFormatter): + if isinstance(cut, MixedCut): + cut = cut.first_non_padding_cut + if cut.has_custom("context"): + context = cut.context + elif cut.has_custom("question"): + context = cut.question + else: + context = cut.default_context + parts = [] + if context: + parts.append(context) + if cut.has_custom("image") and cut.image is not None: + parts.append(GEMMA4_IMAGE) + if getattr(cut, "has_recording", False) or cut.has_custom("audio_filepath"): + parts.append(GEMMA4_AUDIO) + if cut.has_custom("extra_audios") and cut.extra_audios: + for _ in cut.extra_audios: + parts.append(GEMMA4_AUDIO) + user_message = "\n".join(parts) + turns = [{"role": "user", "slots": {"message": user_message}}] + if (answer := cut.supervisions[0].text) is not None: + turns.append({"role": "assistant", "slots": {"message": answer}}) + return prompt.encode_dialog(turns) \ No newline at end of file diff --git a/nemo/collections/speechlm2/data/salm_dataset.py b/nemo/collections/speechlm2/data/salm_dataset.py index 765ed00ee181..313dd17c0029 100644 --- a/nemo/collections/speechlm2/data/salm_dataset.py +++ b/nemo/collections/speechlm2/data/salm_dataset.py @@ -11,9 +11,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +""" +SALMDataset with support for two pretraining patterns (see Figure 2). +The relationship between audio and text is encoded purely by position: + + Repetition pattern: + … … + loss_mask=True on transcript tokens only. + The model learns ASR: given Aₙ predict Tₙ. + + Continuation pattern: + … + loss_mask=True on assistant text tokens only. + The model learns cross-modal continuation: given Aₙ predict Tₙ₊₁. + +Manifest entry schema (continuation / mixed): +───────────────────────────────────────────── +{ + "id": "", + "conversations": [ + {"from": "user", "value": "", "type": "text"}, + {"from": "user", "value": "/utt1.wav", "duration": 12.34, "type": "audio"}, + {"from": "assistant", "value": "", "type": "text"}, + {"from": "user", "value": "/utt2.wav", "duration": 13.01, "type": "audio"}, + {"from": "assistant", "value": "", "type": "text"} + ] +} +""" + import logging from itertools import groupby -from typing import Iterable, Union +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -34,63 +63,228 @@ from nemo.collections.speechlm2.data.utils import get_pad_id +# ────────────────────────────────────────────────────────────────────────────── +# Helpers +# ────────────────────────────────────────────────────────────────────────────── + +def _encode_text(tokenizer: AutoTokenizer, text: str) -> torch.Tensor: + """Encode *text* and return a 1-D LongTensor of token ids (no BOS/EOS).""" + ids = tokenizer.text_to_ids(text) + return torch.tensor(ids, dtype=torch.long) + + +def _token_id(tokenizer: AutoTokenizer, token: str) -> int: + """Return the single integer id for a special token string.""" + ids = tokenizer.text_to_ids(token) + if len(ids) != 1: + raise ValueError( + f"Special token '{token}' should map to exactly one id, got {ids}. " + "Make sure it is added to the tokenizer vocabulary." + ) + return ids[0] + + +def _cat(parts: List[torch.Tensor]) -> torch.Tensor: + """Concatenate a list of 1-D tensors; return empty LongTensor when empty.""" + if not parts: + return torch.empty(0, dtype=torch.long) + return torch.cat(parts, dim=0) + + +# ────────────────────────────────────────────────────────────────────────────── +# Core dataset +# ────────────────────────────────────────────────────────────────────────────── + class SALMDataset(torch.utils.data.Dataset): """ - A dataset for Speech-Augmented Language Models (SALM) that processes multimodal conversations - containing both text and audio turns. + Dataset for Speech-Augmented Language Models (SALM). + + Supports two pretraining patterns + The audio–text relationship is encoded purely by sequence position. + + Pattern "repetition": + (loss on transcript only) - This dataset handles NeMoMultimodalConversation objects which combine text messages - and audio segments in a conversational format. It uses audio_locator_tag in the text, - where each such placeholder corresponds to an entire audio segment. + Pattern "continuation": + … + (loss on assistant text turns only) + + Pattern "mixed": + Per-sample "pattern" key in the manifest selects the builder; + falls back to "continuation" when the key is absent. Args: - tokenizer (AutoTokenizer): - Tokenizer for converting text to token IDs and vice versa. Must have a special - audio_locator_tag token that will be replaced with audio embeddings during model's - training step. - - Returns: - A dictionary with the following keys: - - audios: Tensor of audio waveform samples [B_audio, T_samples] - - audio_lens: Tensor of audio lengths [B_audio] - - input_ids: Tensor of text token IDs [B, T_tokens], including audio_locator_tag tokens - - loss_mask: Boolean tensor [B, T_tokens] indicating which tokens are part of the - assistant's responses (True) and should be used for computing loss - - Notes: - - Each audio_locator_tag token in input_ids corresponds to an audio segment in audios - - The SALM model later replaces these audio_locator_tag tokens with encoded audio embeddings - - The loss_mask identifies which tokens are part of the target sequences (assistant responses) - and which are part of the source sequences (user prompts) - - The input_ids and loss_mask will be expanded during model forward pass to account for - the variable-length audio segments that replace each audio_locator_tag token + tokenizer: NeMo tokenizer. Must contain ``audio_locator_tag`` + as a single registered special token. + audio_locator_tag: Placeholder string for audio turns. + pattern: Default pattern: "continuation" | "repetition" | "mixed". """ - def __init__(self, tokenizer: AutoTokenizer) -> None: + def __init__( + self, + tokenizer: AutoTokenizer, + audio_locator_tag: str = "<|audioplaceholder|>", + pattern: str = "continuation", + ) -> None: self.tokenizer = tokenizer self.pad_id = get_pad_id(tokenizer) + self.audio_locator_tag = audio_locator_tag + self.pattern = pattern + + # Only the audio-locator placeholder must be a single registered token. + self._audio_loc_id = _token_id(tokenizer, audio_locator_tag) + + # EOS must be present and supervised at the end of each target sequence; + # otherwise autoregressive generation has no learned stop signal. + self._eos_id = getattr(tokenizer, "eos_id", None) + if self._eos_id is None: + self._eos_id = getattr(tokenizer, "eos_token_id", None) + if self._eos_id is None: + raise ValueError("SALMDataset: tokenizer has neither eos_id nor eos_token_id") + # ------------------------------------------------------------------ # + # Public interface # + # ------------------------------------------------------------------ # - def __getitem__(self, conversations: CutSet) -> dict | None: - # Note: the function call below may filter out some or all conversations due to audio loading issues. - # If all conversations are filtered out, we'll return None, and expect users to wrap this dataset - # in ``nemo.collections.common.data.fallback.FallbackDataset`` to use the previous mini-batch instead. + def __getitem__(self, conversations: CutSet) -> Optional[dict]: + """ + Process a mini-batch of NeMoMultimodalConversation cuts. + + Returns a dict with keys: + audios – FloatTensor [B_audio, T_samples] + audio_lens – LongTensor [B_audio] + input_ids – LongTensor [B, T_tokens] (left-padded) + loss_mask – BoolTensor [B, T_tokens] (True = compute loss) + conversations – CutSet (in-memory data dropped) + """ try: - audios, audio_lens, conversations = collate_conversation_audio_fault_tolerant(conversations) - except Exception as e: - logging.warning(f"Error collating conversations: {e}") + audios, audio_lens, conversations = collate_conversation_audio_fault_tolerant( + conversations + ) + except Exception as exc: + logging.warning(f"Error collating conversations: {exc}") return None + if not conversations: return None + + all_input_ids: List[torch.Tensor] = [] + all_loss_masks: List[torch.Tensor] = [] + + for conv in conversations: + sample_pattern = getattr(conv, "pattern", self.pattern) + if sample_pattern == "repetition": + ids, mask = self._build_repetition(conv) + else: + ids, mask = self._build_continuation(conv) + + all_input_ids.append(ids) + all_loss_masks.append(mask) + return { "audios": audios, "audio_lens": audio_lens, - "input_ids": left_collate_vectors([c.input_ids for c in conversations], padding_value=self.pad_id), - "loss_mask": left_collate_vectors( - [getattr(c, "mask", torch.empty(0)) for c in conversations], padding_value=0 - ).to(torch.bool), + "input_ids": left_collate_vectors(all_input_ids, padding_value=self.pad_id), + "loss_mask": left_collate_vectors(all_loss_masks, padding_value=0).to(torch.bool), "conversations": drop_in_memory_data(conversations), } + # ------------------------------------------------------------------ # + # Pattern builders # + # ------------------------------------------------------------------ # + + def _build_repetition( + self, conv: NeMoMultimodalConversation + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Repetition pattern – no signal token, purely positional: + + [context …] + + loss_mask = True only over tokens. + """ + ids_parts: List[torch.Tensor] = [] + mask_parts: List[torch.Tensor] = [] + + turns = list(conv.turns) + i = 0 + while i < len(turns): + turn = turns[i] + + if isinstance(turn, AudioTurn): + # audio placeholder – no loss + ids_parts.append(torch.tensor([self._audio_loc_id], dtype=torch.long)) + mask_parts.append(torch.zeros(1, dtype=torch.long)) + # immediately following TextTurn is the transcript target + if i + 1 < len(turns) and isinstance(turns[i + 1], TextTurn): + i += 1 + transcript_ids = _encode_text(self.tokenizer, turns[i].value) + ids_parts.append(transcript_ids) + mask_parts.append(torch.ones(len(transcript_ids), dtype=torch.long)) + + elif isinstance(turn, TextTurn): + # context / prompt – no loss + text_ids = _encode_text(self.tokenizer, turn.value) + ids_parts.append(text_ids) + mask_parts.append(torch.zeros(len(text_ids), dtype=torch.long)) + i += 1 + + ids_parts.append(torch.tensor([self._eos_id], dtype=torch.long)) + mask_parts.append(torch.ones(1, dtype=torch.long)) + + return _cat(ids_parts), _cat(mask_parts) + + def _build_continuation( + self, conv: NeMoMultimodalConversation + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Continuation pattern – no signal token, purely positional: + + + … + + Turn roles: + user TextTurn → context, loss_mask = False + AudioTurn → , loss_mask = False + assistant TextTurn → target text, loss_mask = True + """ + ids_parts: List[torch.Tensor] = [] + mask_parts: List[torch.Tensor] = [] + + turns = list(conv.turns) + i = 0 + while i < len(turns): + turn = turns[i] + + if isinstance(turn, AudioTurn): + # audio placeholder – no loss + ids_parts.append(torch.tensor([self._audio_loc_id], dtype=torch.long)) + mask_parts.append(torch.zeros(1, dtype=torch.long)) + + # immediately following assistant TextTurn is the prediction target + if ( + i + 1 < len(turns) + and isinstance(turns[i + 1], TextTurn) + and getattr(turns[i + 1], "role", "assistant") == "assistant" + ): + i += 1 + target_ids = _encode_text(self.tokenizer, turns[i].value) + ids_parts.append(target_ids) + mask_parts.append(torch.ones(len(target_ids), dtype=torch.long)) + + elif isinstance(turn, TextTurn): + # user prompt / context – no loss + text_ids = _encode_text(self.tokenizer, turn.value) + ids_parts.append(text_ids) + mask_parts.append(torch.zeros(len(text_ids), dtype=torch.long)) + i += 1 + ids_parts.append(torch.tensor([self._eos_id], dtype=torch.long)) + mask_parts.append(torch.ones(1, dtype=torch.long)) + return _cat(ids_parts), _cat(mask_parts) + + +# ────────────────────────────────────────────────────────────────────────────── +# Collation / utility functions (public API unchanged) +# ────────────────────────────────────────────────────────────────────────────── def left_collate_vectors( tensors: Iterable[Union[torch.Tensor, np.ndarray]], @@ -113,26 +307,38 @@ def _drop(conversation: NeMoMultimodalConversation) -> NeMoMultimodalConversatio return conversations.map(_drop, apply_fn=None) +# ────────────────────────────────────────────────────────────────────────────── +# Prompt format fn +# ────────────────────────────────────────────────────────────────────────────── + @registered_prompt_format_fn(NeMoMultimodalConversation, Llama2PromptFormatter) def default_multimodal_conversation_prompt_format_fn( - example: NeMoMultimodalConversation, prompt: Llama2PromptFormatter, **prompt_kwargs + example: NeMoMultimodalConversation, + prompt: Llama2PromptFormatter, ): - # Collapse consecutive same-role turns into single turn for proper prompt formatting. - turns = groupby( - [ - { - "role": turn.role, - "slots": {"message": turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag}, - } - for turn in example.turns - ], - key=lambda turn: turn["role"], - ) - turns = [ - {"role": role, "slots": {"message": " ".join(t["slots"]["message"] for t in turn_grp)}} - for role, turn_grp in turns + """Build dialog turns for the prompt formatter (unchanged semantics).""" + raw_turns = [ + { + "role": turn.role, + "slots": { + "message": ( + turn.value if isinstance(turn, TextTurn) else turn.audio_locator_tag + ) + }, + } + for turn in example.turns ] + + collapsed = [ + { + "role": role, + "slots": {"message": " ".join(t["slots"]["message"] for t in grp)}, + } + for role, grp in groupby(raw_turns, key=lambda t: t["role"]) + ] + if hasattr(example, "system_prompt"): - turns[0]["role"] = "system_and_user" - turns[0]["slots"]["system"] = example.system_prompt - return prompt.encode_dialog(turns, **prompt_kwargs) + collapsed[0]["role"] = "system_and_user" + collapsed[0]["slots"]["system"] = example.system_prompt + + return prompt.encode_dialog(collapsed) diff --git a/nemo/collections/speechlm2/models/salm.py b/nemo/collections/speechlm2/models/salm.py index e59819be8efc..41d47d3e1a82 100644 --- a/nemo/collections/speechlm2/models/salm.py +++ b/nemo/collections/speechlm2/models/salm.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import types import warnings from collections import defaultdict from itertools import repeat @@ -50,8 +51,116 @@ from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, MaskType, NeuralType from nemo.utils import logging +def _is_gemma4(llm) -> bool: + """Detect whether the loaded LLM is a Gemma 4 multimodal model. + + Gemma 4 wraps its language model one level deeper: + llm.model.language_model (Gemma4) + vs. the standard layout: + llm.model (everyone else) + """ + return hasattr(llm, "model") and hasattr(llm.model, "language_model") + +def _setup_normal(salm: "SALM") -> None: + """Move embed_tokens out of llm.model to avoid FSDP/TP hook issues.""" + salm.embed_tokens = salm.llm.model.embed_tokens + del salm.llm.model.embed_tokens + +def _setup_gemma4(salm: "SALM") -> None: + """Apply all patches required for Gemma 4's nested multimodal architecture. + + Gemma 4 buries embed_tokens one level deeper than standard HF models, adds + multimodal placeholder routing, and requires per_layer_inputs in several + internal call sites. These patches make SALM's audio-injection path work + without touching Gemma's own forward logic. + """ + lm = salm.llm.model.language_model # shorthand + + # ------------------------------------------------------------------ + # Patch 1 — pull embed_tokens out to avoid duplicate parameters while + # keeping Gemma's own forward pass intact via pre/post hooks. + # ------------------------------------------------------------------ + salm.embed_tokens = lm.embed_tokens + del lm.embed_tokens + + def _pre_forward_hook(module, args, kwargs): + module.model.language_model.embed_tokens = salm.embed_tokens + return args, kwargs + + def _post_forward_hook(module, args, output): + # Remove so PyTorch doesn't see the parameter twice. + del module.model.language_model.embed_tokens + + salm.llm.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True) + salm.llm.register_forward_hook(_post_forward_hook) + + # Redirect get/set_input_embeddings on all three relevant modules so that + # HF internals (e.g. resize_token_embeddings) always resolve to our copy. + def _get_embeddings(inner_self): + return salm.embed_tokens + + def _set_embeddings(inner_self, value): + salm.embed_tokens = value + + for module in (salm.llm, salm.llm.model, lm): + module.get_input_embeddings = types.MethodType(_get_embeddings, module) + module.set_input_embeddings = types.MethodType(_set_embeddings, module) + + # ------------------------------------------------------------------ + # Patch 2 — disable Gemma 4's multimodal placeholder routing. + # SALM injects audio embeddings directly, so the vision/audio slot + # detection logic must be a no-op. + # ------------------------------------------------------------------ + def _noop_placeholder_mask(inner_self, input_ids, inputs_embeds): + batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0] + seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1] + device = inputs_embeds.device if inputs_embeds is not None else input_ids.device + empty = torch.zeros(batch_size, seq_len, dtype=torch.bool, device=device) + return empty, empty, empty + + salm.llm.model.get_placeholder_mask = types.MethodType( + _noop_placeholder_mask, salm.llm.model + ) + + # ------------------------------------------------------------------ + # Patch 3 — provide token-identity per_layer_inputs even when inputs_embeds + # is the only thing passed at the top of SALM's forward. We read the spliced + # input_ids that prepare_inputs stashed on the SALM instance and feed them + # to the original get_per_layer_inputs. Fallback (None) preserves the + # original short-circuit if no stash is available (e.g. generate()). + # ------------------------------------------------------------------ + _orig_get_per_layer_inputs = lm.get_per_layer_inputs + + def _patched_get_per_layer_inputs(inner_self, input_ids, inputs_embeds): + if input_ids is None: + cached = getattr(salm, "_current_input_ids", None) + if cached is not None and cached.shape == inputs_embeds.shape[:2]: + input_ids = cached + else: + return None + return _orig_get_per_layer_inputs(input_ids, inputs_embeds) + + lm.get_per_layer_inputs = types.MethodType(_patched_get_per_layer_inputs, lm) + + # ------------------------------------------------------------------ + # Patch 4 — ensure language_model.forward always receives + # per_layer_inputs (defaults to None when not supplied). + # ------------------------------------------------------------------ + _orig_lm_forward = lm.forward + + def _patched_lm_forward(inner_self, *args, **kwargs): + kwargs.setdefault("per_layer_inputs", None) + return _orig_lm_forward(*args, **kwargs) + + lm.forward = types.MethodType(_patched_lm_forward, lm) class SALM(LightningModule, HFHubMixin): + """Speech-Augmented Language Model. + + Supports both standard HF language models and Gemma 4's nested multimodal + architecture. Model-specific setup is handled transparently at init time. + """ + def __init__(self, cfg) -> None: assert isinstance(cfg, dict), ( "You must pass the config to SALM as a Python dict to support hyperparameter serialization " @@ -62,30 +171,56 @@ def __init__(self, cfg) -> None: self.cfg = DictConfig(cfg) self.audio_locator_tag = self.cfg.audio_locator_tag + # ── Tokenizer ────────────────────────────────────────────────────── self.tokenizer = AutoTokenizer( - self.cfg.pretrained_llm, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False) + self.cfg.pretrained_llm, + use_fast=True, + trust_remote_code=self.cfg.get("trust_remote_code", False), + ) + self.tokenizer.add_special_tokens( + {"additional_special_tokens": [self.audio_locator_tag]} ) - self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]}) + + # ── LLM ──────────────────────────────────────────────────────────── self.llm = load_pretrained_hf( self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights, trust_remote_code=self.cfg.get("trust_remote_code", False), ) - # Note: we have to "move out" the token embedding outside of LLM to avoid - # messing up FSDP/TP hooks. - self.embed_tokens = self.llm.model.embed_tokens - del self.llm.model.embed_tokens + # ── Model-specific setup ─────────────────────────────────────────── + # We detect Gemma 4 by the presence of llm.model.language_model. + # All other models take the original code path. + if _is_gemma4(self.llm): + _setup_gemma4(self) + else: + _setup_normal(self) + + # ── Shared post-setup ────────────────────────────────────────────── maybe_install_lora(self) - # Load the pretrained streaming ASR model and copy its parameters into the audio perception module. setup_speech_encoder(self, pretrained_weights=self.cfg.pretrained_weights) - # Optionally initialize weights from a previous checkpoint (fresh optimizer/scheduler). - # Set model.pretrained_s2s_model or model.pretrained_perception_from_s2s in the config. maybe_load_pretrained_models(self) self._use_fsdp = False self._use_tp = False + # ----------------------------------------------------------------------- + # embed_tokens property + # ----------------------------------------------------------------------- + # Both code paths store embed_tokens as a plain attribute on `self`. + # The property below is intentionally NOT defined so that attribute access + # resolves directly to `self.__dict__["embed_tokens"]`, matching the + # original behaviour for normal models and keeping Gemma4's hooks simple. + # + # If you ever need to expose embed_tokens as a @property (e.g. to always + # read from lm.embed_tokens after Gemma4 hooks restore it), uncomment: + # + # @property + # def embed_tokens(self): + # if _is_gemma4(self.llm): + # return self.llm.model.language_model.embed_tokens + # return self.llm.model.embed_tokens + @property def text_vocab_size(self): """Return the size of the text tokenizer.""" @@ -172,6 +307,11 @@ def prepare_inputs(self, batch: dict): audio_embs = [emb[:emblen] for emb, emblen in zip(audio_embs, audio_emb_lens)] input_ids_to_embed = torch.where(batch["input_ids"] == self.audio_locator_tag_id, 0, batch["input_ids"]) text_embs = self.embed_tokens(input_ids_to_embed) + # If embed_tokens is fp32 (LLM-fp32 workaround) but encoder output is bf16, + # cast audio_embs to match so the splice and downstream LLM forward see one dtype. + if len(audio_embs) > 0 and audio_embs[0].dtype != text_embs.dtype: + audio_embs = [emb.to(text_embs.dtype) for emb in audio_embs] + input_embs, target_ids, attention_mask = replace_placeholders_and_build_targets( input_ids=batch["input_ids"], embeds=text_embs, @@ -180,9 +320,22 @@ def prepare_inputs(self, batch: dict): replacements=audio_embs, target_ids=batch["input_ids"].where(batch["loss_mask"], -100), # CrossEntropyLoss().ignore_index ) + # Build spliced input_ids (audio_locator_tag_id at audio positions, real + # token ids elsewhere) so Gemma 4's per-layer-embeddings lookup has token + # identity to use. Stash on self for the patched get_per_layer_inputs to read. + _, _llm_input_ids, _ = replace_placeholders_and_build_targets( + input_ids=batch["input_ids"], + embeds=text_embs, + padding_id=self.text_pad_id, + placeholder_id=self.audio_locator_tag_id, + replacements=audio_embs, + target_ids=batch["input_ids"], + ) + _llm_input_ids = torch.where(_llm_input_ids == -100, 0, _llm_input_ids) input_embs = input_embs[:, :-1] attention_mask = attention_mask[:, :-1] target_ids = target_ids[:, 1:] + self._current_input_ids = _llm_input_ids[:, :-1] # Combine target audio and text into a single tensor to slice them together. # It will also help us truncate the sequence lengths to be divisible by TP world size, @@ -234,6 +387,7 @@ def training_step(self, batch: dict, batch_idx: int): "target_to_input_ratio": num_frames / (B * T), "padding_ratio": (batch["input_ids"] != self.text_pad_id).long().sum() / batch["input_ids"].numel(), } + # self.log_dict(ans, on_step=True) self.log("loss", loss, on_step=True, prog_bar=True) self.log_dict({k: v for k, v in ans.items() if k != "loss"}, on_step=True) return ans @@ -391,9 +545,11 @@ def generate( if enable_thinking is not None: formatter_kwargs["enable_thinking"] = enable_thinking tokens = left_collate_vectors( + # [formatter.encode_dialog(turns=prompt)["input_ids"] for prompt in prompts], [formatter.encode_dialog(turns=prompt, **formatter_kwargs)["input_ids"] for prompt in prompts], padding_value=self.text_pad_id, ).to(self.device) + if audios is not None: # Audio + text input for generation. # Prepare token embeddings and audio embeddings. @@ -404,6 +560,10 @@ def generate( audio_embeds, audio_embed_lens = self.perception(audios, audio_lens) audio_embeds = [audio_embeds[i, :elen] for i, elen in enumerate(audio_embed_lens)] # Insert audio embeddings into relevant positions in text embeddings. + # Use left-padding for generation (Bug D fix): HF generate() appends new tokens at the + # end of the sequence; with right-pad, shorter rows in a batch end at position + # seq_len < max_seq_length, so the next token lands at position max_seq_length and + # RoPE positions are misaligned for shorter audios -> gibberish at bs>1. input_embeds, _, attention_mask = replace_placeholders_and_build_targets( input_ids=tokens, embeds=token_embeds, @@ -411,6 +571,7 @@ def generate( placeholder_id=self.audio_locator_tag_id, replacements=audio_embeds, target_ids=None, + padding_side="left", ) generation_inputs = {"inputs_embeds": input_embeds, "attention_mask": attention_mask} else: @@ -561,6 +722,7 @@ def replace_placeholders_and_build_targets( placeholder_id: int, replacements: list[torch.Tensor], target_ids: Optional[torch.Tensor] = None, + padding_side: str = "right", ) -> tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: """Replaces each occurrence of the placeholder_id in input_ids with the corresponding tensor from the replacements list in the embeds tensor, and creates corresponding adjusted target_ids. @@ -575,6 +737,11 @@ def replace_placeholders_and_build_targets( placeholder_id (int): an id to be replaced. replacements (list of Tensor): each Tensor has shape (L_i, hidden_dim), with L_i arbitrary. target_ids (Tensor): shape (batch, sequence_length); target token ids. + padding_side (str): "right" (default, for training — bf16 backward stability through + Gemma 4's 35 layers requires real content at low RoPE positions) or "left" + (for autoregressive generation — all batch rows must end at the same position + so the "next token" RoPE position is correct; shorter audios otherwise produce + gibberish at bs>1 because RoPE positions are misaligned, see Bug D notes). Returns: Tuple[Tensor, Tensor, Tensor]: @@ -587,6 +754,7 @@ def replace_placeholders_and_build_targets( - Tensor of shape (batch, max_new_sequence_length) with attention padding masks updated to account for shape changes due to replacements. """ + assert padding_side in ("right", "left"), f"padding_side must be 'right' or 'left', got {padding_side!r}" batch_size, seq_len = input_ids.size() if target_ids is not None: assert target_ids.size() == input_ids.size(), "target_ids must have the same shape as input_ids" @@ -682,10 +850,22 @@ def replace_placeholders_and_build_targets( output_target_ids = repeat(None) for i, (seq, tgt, att) in enumerate(zip(output_sequences, output_target_ids, output_att_masks)): seq_len = seq.size(0) - output[i, -seq_len:] = seq - if tgt is not None: - new_target_ids[i, -seq_len:] = tgt - attention_masks[i, -seq_len:] = att + # padding_side="right" (default, training): real tokens at [0:seq_len], pad at [seq_len:]. + # bf16 backward through Gemma 4's 35 transformer layers requires this layout to avoid + # NaN gradients. See docs/DECISIONS.md (2026-04-30 entry) for the full bisection. + # padding_side="left" (generation): real tokens at [-seq_len:], pad at [:max_seq_length-seq_len]. + # All rows end at the same position, so HF generate()'s "next token" RoPE position is + # consistent across the batch. Required for bs>1 inference (Bug D fix, 2026-05-23). + if padding_side == "left": + output[i, -seq_len:] = seq + if tgt is not None: + new_target_ids[i, -seq_len:] = tgt + attention_masks[i, -seq_len:] = att + else: # "right" + output[i, :seq_len] = seq + if tgt is not None: + new_target_ids[i, :seq_len] = tgt + attention_masks[i, :seq_len] = att return output, new_target_ids, attention_masks diff --git a/nemo/collections/speechlm2/parts/pretrained.py b/nemo/collections/speechlm2/parts/pretrained.py index 871392f77e44..adcaca8ffbe6 100644 --- a/nemo/collections/speechlm2/parts/pretrained.py +++ b/nemo/collections/speechlm2/parts/pretrained.py @@ -160,7 +160,12 @@ def setup_speech_encoder(model: torch.nn.Module, pretrained_weights: bool = True model.cfg.perception.preprocessor = asr.cfg.preprocessor model.cfg.perception.encoder = asr.cfg.encoder if model.llm is not None: - model.cfg.perception.output_dim = model.llm.config.hidden_size + llm = model.llm + if hasattr(llm, 'config') and hasattr(llm.config, 'hidden_size'): + model.cfg.perception.output_dim = llm.config.hidden_size + elif hasattr(llm, 'language_model') and hasattr(llm.language_model, 'config') \ + and hasattr(llm.language_model.config, 'hidden_size'): + model.cfg.perception.output_dim = llm.language_model.config.hidden_size # Override with user-specified encoder parameters, e.g. initializiing a non-causal encoder for causal setup. if user_encoder_config: for key, value in user_encoder_config.items():