From 328944e714eccb6c1fd07508c8eca81552fa1d5b Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 2 Jun 2026 18:20:44 +0200 Subject: [PATCH 01/15] nemo/collections/tts/models/easy_magpietts_inference.py: remove duplicate speaker encoder application Signed-off-by: Viacheslav Klimkov --- .../tts/models/easy_magpietts_inference.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/nemo/collections/tts/models/easy_magpietts_inference.py b/nemo/collections/tts/models/easy_magpietts_inference.py index 67eb7212d6ea..d0951dc56d9a 100644 --- a/nemo/collections/tts/models/easy_magpietts_inference.py +++ b/nemo/collections/tts/models/easy_magpietts_inference.py @@ -958,22 +958,6 @@ def prepare_context_tensors( context_audio_embedded = self.embed_audio_tokens(context_audio_codes) # (B, T', E) batch_size = context_audio_embedded.size(0) - if self.use_speaker_encoder: - if ( - self.training - and batch_size > 1 - and self.train_shuffle_context_embedding_prob > 0 - and random.random() < self.train_shuffle_context_embedding_prob - ): - # Feed shuffled raw context embeddings (without speaker encoder) so - # the decoder cannot rely on direct unencoded speaker identity cues. 
- shift = random.randint(1, batch_size - 1) - context_audio_embedded = context_audio_embedded.roll(shift, dims=0) - else: - context_audio_embedded = self.encode_context_audio_embeddings( - context_audio_embedded=context_audio_embedded, context_audio_lens=context_audio_codes_lens - ) - if self.use_speaker_encoder: if ( self.training From 78404cbcf4dc114639b89ad43bce79390c25583e Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 2 Jun 2026 18:20:44 +0200 Subject: [PATCH 02/15] examples/tts/easymagpie_vllm_omni: initial commit for vllm_omni definition of Easy Magpie Signed-off-by: Viacheslav Klimkov --- ...easy_magpietts_extract_speaker_encoding.py | 164 +++++ .../easy_magpietts_single_infer.py | 141 +++++ .../easymagpie_inference_demo.ipynb | 419 +++++++++++++ .../easymagpie_vllm_omni/__init__.py | 27 + .../easymagpie_vllm_omni/config.py | 158 +++++ .../easymagpie_vllm_omni/easymagpie.py | 586 ++++++++++++++++++ .../easymagpie_vllm_omni/local_transformer.py | 398 ++++++++++++ .../tts/easymagpie_vllm_omni/pyproject.toml | 20 + .../vllm_plugin_easymagpie_omni/__init__.py | 51 ++ 9 files changed, 1964 insertions(+) create mode 100644 examples/tts/easymagpie_vllm_omni/easy_magpietts_extract_speaker_encoding.py create mode 100644 examples/tts/easymagpie_vllm_omni/easy_magpietts_single_infer.py create mode 100644 examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb create mode 100644 examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py create mode 100644 examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py create mode 100644 examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py create mode 100644 examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py create mode 100644 examples/tts/easymagpie_vllm_omni/pyproject.toml create mode 100644 examples/tts/easymagpie_vllm_omni/vllm_plugin_easymagpie_omni/__init__.py diff --git 
a/examples/tts/easymagpie_vllm_omni/easy_magpietts_extract_speaker_encoding.py b/examples/tts/easymagpie_vllm_omni/easy_magpietts_extract_speaker_encoding.py new file mode 100644 index 000000000000..90f47dbd6a49 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/easy_magpietts_extract_speaker_encoding.py @@ -0,0 +1,164 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Standalone speaker-encoder output extractor for EasyMagpieTTS. + +Pre-computes ONLY the speaker-encoded context-audio embedding so it can be fed +to a separate (e.g. vLLM) backbone implementation. Context-text / task +embeddings are intentionally NOT included here -- the caller is expected to +prepend/append those (e.g. inside the vLLM model's ``preprocess``). + +This reproduces the audio branch of +``EasyMagpieTTSInferenceModel.prepare_context_tensors``:: + + audio -> codec codes -> (codec convert) -> add BOS/EOS -> frame stacking + -> per-codebook embedding -> speaker encoder + +and saves the resulting ``(T_audio, embedding_dim)`` tensor to disk. 
+ +Example: + python examples/tts/easymagpie_vllm_omni/easy_magpietts_extract_speaker_encoding.py \\ + --nemo_file /path/to/EMTTS_Pretraining_Qwen_WithCrossLingual_3_5_Delay.nemo \\ + --codec_model_path /path/to/25fps_spectral_codec_with_bandwidth_extension.nemo \\ + --phoneme_tokenizer_path /path/to/bpe_ipa_tokenizer_2048_en_de_es_fr_hi_it_vi_zh.json \\ + --context_audio /path/to/reference_voice.wav \\ + --out_file ./speaker_encoding.pt +""" +from __future__ import annotations + +import argparse + +import torch + +from nemo.collections.tts.modules.magpietts_inference.utils import ModelLoadConfig, load_easy_magpie_model +from nemo.collections.tts.modules.magpietts_modules import add_special_tokens +from nemo.utils import logging + + +def main(): + parser = argparse.ArgumentParser(description="Extract EasyMagpieTTS speaker-encoder output") + parser.add_argument("--nemo_file", required=True, help="Path to the EasyMagpieTTS .nemo checkpoint") + parser.add_argument("--codec_model_path", required=True, help="Path to the audio codec .nemo checkpoint") + parser.add_argument( + "--phoneme_tokenizer_path", + default=None, + help="Override the phoneme (IPA BPE) tokenizer path baked into the checkpoint. " + "Required if the path stored in the .nemo does not exist locally.", + ) + parser.add_argument("--context_audio", required=True, help="Reference/context wav for voice cloning") + parser.add_argument( + "--disable_cas_for_context_text", + action="store_true", + help="Set for legacy checkpoints trained without CAS embeddings on context text", + ) + parser.add_argument("--context_audio_duration", type=float, default=5.0) + parser.add_argument("--device", default="cuda") + parser.add_argument( + "--out_file", + default="./speaker_encoding.pt", + help="Output path. 
A torch .pt file (dict) is written; if it ends with .npy the " + "speaker-encoding tensor is saved as a NumPy array instead.", + ) + + args = parser.parse_args() + + model, ckpt_name = load_easy_magpie_model( + ModelLoadConfig( + nemo_file=args.nemo_file, + codecmodel_path=args.codec_model_path, + phoneme_tokenizer_path=args.phoneme_tokenizer_path, + disable_cas_for_context_text=args.disable_cas_for_context_text, + ), + device=args.device, + ) + logging.info(f"Loaded EasyMagpieTTS checkpoint: {ckpt_name}") + logging.info(f"use_speaker_encoder={getattr(model, 'use_speaker_encoder', False)}") + + device = next(model.parameters()).device + + with torch.inference_mode(): + # Load + trim context audio exactly like EasyMagpieTTSInferenceModel.do_tts. + context_audio = model._load_audio_for_inference(args.context_audio, model.sample_rate) + context_audio = model._adjust_audio_to_duration_for_inference( + context_audio, + model.sample_rate, + args.context_audio_duration, + model.codec_model_samples_per_frame, + ) + context_audio = context_audio.to(device) + context_audio_lens = torch.tensor([context_audio.size(1)], dtype=torch.long, device=device) + context_audio_codes, context_audio_codes_lens = model._codec_helper.audio_to_codes( + context_audio, context_audio_lens + ) + + # --- Audio branch of prepare_context_tensors (no context text / task embedding) --- + if model._codec_converter is not None: + context_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens + ).long() + + context_audio_codes, context_audio_codes_lens = add_special_tokens( + codes=context_audio_codes, + codes_len=context_audio_codes_lens, + bos_id=model.context_audio_bos_id, + eos_id=model.context_audio_eos_id, + ) + + context_audio_codes, context_audio_codes_lens = model.stack_codes( + context_audio_codes, + context_audio_codes_lens, + model.context_audio_bos_id, + model.context_audio_eos_id, + model.frame_stacking_factor, 
+ model.num_audio_codebooks, + ) + + context_audio_embedded = model.embed_audio_tokens(context_audio_codes) # (B, T_audio, E) + + if getattr(model, "use_speaker_encoder", False): + context_audio_embedded = model.encode_context_audio_embeddings( + context_audio_embedded=context_audio_embedded, + context_audio_lens=context_audio_codes_lens, + ) + else: + logging.warning( + "Checkpoint has use_speaker_encoder=False; saving raw per-codebook audio embeddings " + "(no speaker encoder applied)." + ) + + # Strip batch dim (B == 1) -> (T_audio, embedding_dim). + audio_len = int(context_audio_codes_lens[0].item()) + speaker_encoding = context_audio_embedded[0, :audio_len].contiguous().float().detach().cpu() + logging.info(f"Extracted speaker-encoder output: {tuple(speaker_encoding.shape)}") + + if args.out_file.endswith(".npy"): + import numpy as np + + np.save(args.out_file, speaker_encoding.numpy()) + else: + torch.save( + { + "speaker_encoding": speaker_encoding, + "context_audio": args.context_audio, + "embedding_dim": int(speaker_encoding.size(-1)), + "num_frames": int(speaker_encoding.size(0)), + "checkpoint": ckpt_name, + }, + args.out_file, + ) + logging.info(f"Wrote speaker encoding of shape {tuple(speaker_encoding.shape)} to {args.out_file}") + + +if __name__ == "__main__": + main() diff --git a/examples/tts/easymagpie_vllm_omni/easy_magpietts_single_infer.py b/examples/tts/easymagpie_vllm_omni/easy_magpietts_single_infer.py new file mode 100644 index 000000000000..313f4caa7f61 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/easy_magpietts_single_infer.py @@ -0,0 +1,141 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Minimal pure-PyTorch single-utterance inference for EasyMagpieTTS. + +No vLLM, no manifest, no evalset config. Just: one context wav + one text -> one wav. + +Example: + python examples/tts/easymagpie_vllm_omni/easy_magpietts_single_infer.py \\ + --nemo_file /path/to/EMTTS_Pretraining_Qwen_WithCrossLingual_3_5_Delay.nemo \\ + --codec_model_path /path/to/25fps_spectral_codec_with_bandwidth_extension.nemo \\ + --phoneme_tokenizer_path /path/to/bpe_ipa_tokenizer_2048_en_de_es_fr_hi_it_vi_zh.json \\ + --context_audio /path/to/reference_voice.wav \\ + --text "Hello, this is a test of the EasyMagpie text to speech model." \\ + --out_wav ./out.wav +""" +from __future__ import annotations + +import argparse + +import soundfile as sf +import torch + +from nemo.collections.tts.modules.magpietts_inference.utils import ModelLoadConfig, load_easy_magpie_model +from nemo.utils import logging + + +def main(): + parser = argparse.ArgumentParser(description="EasyMagpieTTS single-utterance pure-torch inference") + parser.add_argument("--nemo_file", required=True, help="Path to the EasyMagpieTTS .nemo checkpoint") + parser.add_argument("--codec_model_path", required=True, help="Path to the audio codec .nemo checkpoint") + parser.add_argument( + "--phoneme_tokenizer_path", + default=None, + help="Override the phoneme (IPA BPE) tokenizer path baked into the checkpoint. 
" + "Required if the path stored in the .nemo does not exist locally.", + ) + parser.add_argument("--context_audio", default=None, help="Reference/context wav for voice cloning") + parser.add_argument( + "--context_text", + default=None, + help="Optional style/context text tag. The voice is cloned from --context_audio; this is a " + "separate style/language conditioning string. If omitted, the correct in-distribution " + '"no text context" placeholder is auto-selected to match how the checkpoint was trained ' + "(language tag like [EN] if add_language_to_context_text=True, else [NO TEXT CONTEXT]). " + "Do NOT pass a free-form sentence unless you want it spoken/styled.", + ) + parser.add_argument( + "--language", + default="en", + help="Language of --text; used to build the [LANG] context-text placeholder for checkpoints " + "trained with add_language_to_context_text=True (e.g. en, de, es, fr, it, hi, zh, vi, ko-KR, pt-BR, ar)", + ) + parser.add_argument("--text", required=True, help="Text to synthesize") + parser.add_argument("--out_wav", default="./out.wav", help="Output wav path") + + # Tokenizer selection: defaults to the first text tokenizer in the checkpoint config + # (e.g. nemotron_nano_30b). Override only if your checkpoint has multiple. + parser.add_argument("--main_tokenizer_name", default=None) + + # The legacy Qwen EasyMagpie checkpoint was trained without CAS embeddings on context text. + parser.add_argument( + "--disable_cas_for_context_text", + action="store_true", + help="Set for legacy checkpoints trained without CAS embeddings on context text", + ) + + # Sampling / decoding parameters (defaults mirror the InferEvaluate functional test). 
+ parser.add_argument("--temperature", type=float, default=0.6) + parser.add_argument("--topk", type=int, default=80) + parser.add_argument("--use_cfg", action="store_true", default=True) + parser.add_argument("--no_cfg", dest="use_cfg", action="store_false") + parser.add_argument("--cfg_scale", type=float, default=2.5) + parser.add_argument("--no_local_transformer", dest="use_local_transformer", action="store_false", default=True) + parser.add_argument("--max_steps", type=int, default=500) + parser.add_argument("--context_audio_duration", type=float, default=5.0) + parser.add_argument("--device", default="cuda") + + args = parser.parse_args() + + model, ckpt_name = load_easy_magpie_model( + ModelLoadConfig( + nemo_file=args.nemo_file, + codecmodel_path=args.codec_model_path, + phoneme_tokenizer_path=args.phoneme_tokenizer_path, + disable_cas_for_context_text=args.disable_cas_for_context_text, + ), + device=args.device, + ) + logging.info(f"Loaded EasyMagpieTTS checkpoint: {ckpt_name}") + logging.info(f"Available text tokenizers: {list(model.tokenizer.tokenizers.keys())}") + + # Resolve the context-text placeholder to match the training-time convention. + # The dataset uses "[LANG]" (e.g. "[EN]") when add_language_to_context_text=True, else "[NO TEXT CONTEXT]". + # Passing the wrong placeholder is out-of-distribution and the model may literally speak it + # (e.g. starting the audio with the word "context"). 
+ context_text = args.context_text + if context_text is None: + if getattr(model, "add_language_to_context_text", False): + context_text = f"[{args.language.upper()}]" + else: + context_text = "[NO TEXT CONTEXT]" + logging.info(f"Using context_text={context_text!r}") + + with torch.inference_mode(): + audio, audio_lens = model.do_tts( + transcript=args.text, + context_audio_file_path=args.context_audio, + context_text=context_text, + main_tokenizer_name=args.main_tokenizer_name, + context_audio_duration=args.context_audio_duration, + use_cfg=args.use_cfg, + cfg_scale=args.cfg_scale, + use_local_transformer=args.use_local_transformer, + temperature=args.temperature, + topk=args.topk, + max_steps=args.max_steps, + ) + + audio_len = int(audio_lens[0].item()) + audio_np = audio[0, :audio_len].float().detach().cpu().numpy() + sf.write(args.out_wav, audio_np, model.output_sample_rate) + logging.info( + f"Wrote {audio_len / model.output_sample_rate:.2f}s of audio " + f"({model.output_sample_rate} Hz) to {args.out_wav}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb new file mode 100644 index 000000000000..0e2693776adf --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb @@ -0,0 +1,419 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d5a1129d", + "metadata": {}, + "source": [ + "# EasyMagpieTTS — vLLM-Omni inference demo (dummy weights)\n", + "\n", + "This notebook runs an end-to-end inference pass through the\n", + "[`easymagpie_vllm_omni`](./easymagpie_vllm_omni) model definition using\n", + "**dummy / random weights**, so you can exercise the full engine path\n", + "(prefill -> autoregressive decode -> audio-code extraction) without a converted\n", + "checkpoint.\n", + "\n", + "It follows the same `AsyncOmni` single-stage pattern as the reference\n", + "`qwen3-tts` and `eartts` 
demos:\n", + "\n", + "* **prefill** — the caller supplies a precomputed context embedding via\n", + " `additional_information.prompt_embeds` of shape `(T_ctx, embedding_dim)`, with\n", + " `prompt_token_ids = [0] * T_ctx` (exactly like qwen3-tts `talker_prompt_embeds`\n", + " / eartts `speaker_latent`).\n", + "* **decode** — each step consumes one subword id from the streaming\n", + " `additional_information.text_tokens` list; the local transformer samples all\n", + " `C * S` stacked audio codebooks for the frame.\n", + "* **output** — per-step audio codes are surfaced on\n", + " `OmniOutput.multimodal_outputs[\\\"audio_codes\\\"]` (`BT x num_codebooks`), and the\n", + " engine accumulates them across steps just like eartts, so we trim to the last\n", + " `len(token_ids)` decoded rows.\n", + "\n", + "> **Dummy weights.** We build a tiny `config.json` (small backbone + small\n", + "> codebooks) and start the engine with `load_format=\\\"dummy\\\"`, so vLLM fills all\n", + "> parameters with random values. The emitted codes are therefore meaningless —\n", + "> this is a *smoke test* of the engine wiring, not a real synthesis. 
Point the\n", + "> engine at a real converted checkpoint (and drop `load_format`) to get audio.\n", + "\n", + "> **Environment.** Run this inside the bootstrapped `vllm_omni_env` (vLLM +\n", + "> vLLM-Omni + compatible torch) with the plugin installed:\n", + "> ```bash\n", + "> source /path/to/vllm_omni_env/bin/activate\n", + "> pip install -e examples/tts/easymagpie_vllm_omni\n", + "> ```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c9a71b74", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "# Single-process executor below, but keep spawn semantics consistent with the\n", + "# qwen3-tts / eartts demos in case you switch to a multiproc backend.\n", + "os.environ.setdefault(\"VLLM_WORKER_MULTIPROC_METHOD\", \"spawn\")\n", + "\n", + "import json\n", + "import tempfile\n", + "import uuid\n", + "from pathlib import Path\n", + "\n", + "import torch\n", + "import yaml\n", + "\n", + "from vllm import SamplingParams\n", + "from vllm_omni import AsyncOmni\n", + "\n", + "# Importing the model package is optional (the engine resolves the arch via the\n", + "# `vllm.general_plugins` entry point installed with the package), but doing it\n", + "# here surfaces the arch dataclass we use to size the dummy prompt embedding.\n", + "from easymagpie_vllm_omni.config import EasyMagpieOmniArch\n", + "\n", + "print(\"torch:\", torch.__version__, \"| cuda:\", torch.cuda.is_available())" + ] + }, + { + "cell_type": "markdown", + "id": "f7ff55fe", + "metadata": {}, + "source": [ + "## 1. Build a tiny dummy model directory\n", + "\n", + "The engine only needs a `config.json` that (a) names the registered arch and\n", + "(b) carries the EasyMagpie + Qwen2 scalars. 
We deliberately pick **small** dims\n", + "so the dummy backbone and local transformer are fast to instantiate.\n", + "\n", + "The EasyMagpie-specific scalars (`embedding_dim`, `num_audio_codebooks`,\n", + "`codebook_size`, `frame_stacking_factor`, `local_transformer_*`, …) are read by\n", + "`EasyMagpieOmniArch.from_hf_config`; the standard Qwen2 fields (`hidden_size`,\n", + "`num_hidden_layers`, …) configure the reused `Qwen2Model` backbone. Setting\n", + "`phoneme_vocab_size = 0` disables the optional phoneme branch for simplicity.\n", + "\n", + "With `load_format=\\\"dummy\\\"` (set in the stage config) vLLM never reads weight\n", + "files, so a lone `config.json` is enough — no safetensors, no tokenizer." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e0df89e", + "metadata": {}, + "outputs": [], + "source": [ + "# Small, internally-consistent dummy profile.\n", + "# embedding_dim == hidden_size == audio_embedding_dim == local_transformer_hidden_dim\n", + "# keeps every in/out projection an Identity (fewer dummy params, same code path).\n", + "HIDDEN = 256\n", + "NUM_AUDIO_CODEBOOKS = 4\n", + "CODEBOOK_SIZE = 64\n", + "FRAME_STACKING = 2 # -> num_stacked_codebooks = NUM_AUDIO_CODEBOOKS * FRAME_STACKING = 8\n", + "TEXT_VOCAB = 256\n", + "\n", + "config = {\n", + " # Resolved through the `vllm.general_plugins` entry point registered by the\n", + " # `easymagpie_vllm_omni` package -> EasyMagpieTTSForConditionalGeneration.\n", + " \"architectures\": [\"EasyMagpieTTSForConditionalGeneration\"],\n", + " # Standard Qwen2 backbone fields (consumed by vllm Qwen2Model).\n", + " \"model_type\": \"qwen2\",\n", + " \"hidden_size\": HIDDEN,\n", + " \"intermediate_size\": 4 * HIDDEN,\n", + " \"num_hidden_layers\": 2,\n", + " \"num_attention_heads\": 4,\n", + " \"num_key_value_heads\": 4,\n", + " \"max_position_embeddings\": 4096,\n", + " \"rms_norm_eps\": 1e-6,\n", + " \"rope_theta\": 1000000.0,\n", + " \"vocab_size\": TEXT_VOCAB,\n", + " 
\"tie_word_embeddings\": False,\n", + " \"torch_dtype\": \"float32\",\n", + " # EasyMagpie-specific scalars (read by EasyMagpieOmniArch.from_hf_config).\n", + " \"text_vocab_size\": TEXT_VOCAB,\n", + " \"embedding_dim\": HIDDEN,\n", + " \"audio_embedding_dim\": HIDDEN,\n", + " \"num_audio_codebooks\": NUM_AUDIO_CODEBOOKS,\n", + " \"codebook_size\": CODEBOOK_SIZE,\n", + " \"frame_stacking_factor\": FRAME_STACKING,\n", + " \"phoneme_stacking_factor\": 0, # disable phoneme branch\n", + " \"phoneme_vocab_size\": 0,\n", + " \"local_transformer_n_layers\": 2,\n", + " \"local_transformer_n_heads\": 4,\n", + " \"local_transformer_hidden_dim\": HIDDEN,\n", + "}\n", + "\n", + "MODEL_DIR = Path(tempfile.mkdtemp(prefix=\"easymagpie_dummy_\"))\n", + "(MODEL_DIR / \"config.json\").write_text(json.dumps(config, indent=2))\n", + "print(f\"Dummy model dir: {MODEL_DIR}\")\n", + "\n", + "# Sanity-check the arch the model will derive from this config.\n", + "arch = EasyMagpieOmniArch.from_hf_config(type(\"Cfg\", (), config))\n", + "print(f\"embedding_dim : {arch.embedding_dim}\")\n", + "print(f\"num_stacked_codebooks : {arch.num_stacked_codebooks} (C*S)\")\n", + "print(f\"tokens / codebook : {arch.num_all_tokens_per_codebook} (codebook_size + specials)\")\n", + "print(f\"audio_bos / audio_eos id : {arch.audio_bos_id} / {arch.audio_eos_id}\")" + ] + }, + { + "cell_type": "markdown", + "id": "012df58d", + "metadata": {}, + "source": [ + "## 2. Single-stage `AsyncOmni` engine\n", + "\n", + "A single `llm` stage that runs the EasyMagpie talker, mirroring the eartts demo\n", + "(`worker_type=\\\"ar\\\"`, `OmniARScheduler`). 
The stage declares\n", + "`engine_output_type=\\\"audio\\\"` / `final_output_type=\\\"audio\\\"`: for a single-stage\n", + "AR TTS model these make the runner attach the per-step `audio_codes` multimodal\n", + "payload to the output (with `\\\"latent\\\"` the payload is dropped because nothing\n", + "downstream consumes it, and `multimodal_output[\\\"audio_codes\\\"]` comes back\n", + "`None`). Two extra knobs make this a dummy-weights run with no external assets:\n", + "\n", + "* `load_format: \\\"dummy\\\"` — vLLM initializes random weights instead of reading a\n", + " checkpoint (so `load_weights` / `init_forbidden_mask` are skipped; the\n", + " forbidden-token mask stays all-zeros, i.e. no sampling mask — fine for a smoke\n", + " test).\n", + "* `skip_tokenizer_init: true` — we feed `prompt_token_ids` + `text_tokens`\n", + " directly, so no tokenizer files are needed.\n", + "\n", + "`max_model_len` must cover `T_ctx` (prefill) + the number of decode steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5085e9a4", + "metadata": {}, + "outputs": [], + "source": [ + "T_CTX = 16 # prefill context-embedding length (prompt_token_ids = [0] * T_CTX)\n", + "DECODE_STEPS = 32 # number of audio frames to decode\n", + "MAX_MODEL_LEN = 512\n", + "MAX_NUM_BATCHED_TOKENS = 512\n", + "\n", + "stage_cfg = {\n", + " \"stage_args\": [\n", + " {\n", + " \"stage_id\": 0,\n", + " \"stage_type\": \"llm\",\n", + " \"is_comprehension\": True,\n", + " \"final_output\": True,\n", + " # \"audio\" (not \"latent\") is required for a single-stage AR TTS model:\n", + " # it makes the AR model runner attach the per-step multimodal payload\n", + " # (\"audio_codes\") to the EngineCoreOutput even though no downstream\n", + " # stage consumes it, so the codes reach the client. 
With \"latent\" the\n", + " # payload is dropped and multimodal_output[\"audio_codes\"] is None.\n", + " \"final_output_type\": \"audio\",\n", + " \"runtime\": {\"devices\": \"0\"},\n", + " \"engine_args\": {\n", + " \"model_stage\": \"easymagpie\",\n", + " \"max_num_seqs\": 1,\n", + " \"model_arch\": \"EasyMagpieTTSForConditionalGeneration\",\n", + " \"worker_type\": \"ar\",\n", + " \"scheduler_cls\": \"vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler\",\n", + " \"enforce_eager\": True, # dummy run: skip CUDA-graph capture for a faster start\n", + " \"trust_remote_code\": True,\n", + " \"async_scheduling\": True,\n", + " \"enable_prefix_caching\": False,\n", + " \"engine_output_type\": \"audio\",\n", + " \"gpu_memory_utilization\": 0.6,\n", + " \"distributed_executor_backend\": \"uni\",\n", + " \"max_num_batched_tokens\": MAX_NUM_BATCHED_TOKENS,\n", + " \"max_model_len\": MAX_MODEL_LEN,\n", + " \"dtype\": \"float32\",\n", + " \"attention_backend\": \"TRITON_ATTN\",\n", + " # --- dummy-weights smoke-test knobs ---\n", + " \"load_format\": \"dummy\",\n", + " \"skip_tokenizer_init\": True,\n", + " },\n", + " \"default_sampling_params\": {\n", + " \"temperature\": 0.0,\n", + " \"max_tokens\": DECODE_STEPS,\n", + " \"detokenize\": False,\n", + " \"ignore_eos\": True,\n", + " },\n", + " }\n", + " ],\n", + "}\n", + "\n", + "_tmp = tempfile.NamedTemporaryFile(\n", + " mode=\"w\", suffix=\".yaml\", prefix=\"easymagpie_omni_demo_\", delete=False,\n", + ")\n", + "yaml.dump(stage_cfg, _tmp, sort_keys=False)\n", + "_tmp.close()\n", + "STAGE_CFG_PATH = _tmp.name\n", + "print(f\"Stage config: {STAGE_CFG_PATH}\")\n", + "\n", + "omni = AsyncOmni(\n", + " model=str(MODEL_DIR),\n", + " stage_configs_path=STAGE_CFG_PATH,\n", + " log_stats=False,\n", + " stage_init_timeout=300,\n", + ")\n", + "print(\"Engine ready (single stage: EasyMagpie talker, dummy weights)\")" + ] + }, + { + "cell_type": "markdown", + "id": "2736b86d", + "metadata": {}, + "source": [ + "## 3. 
Build the prompt\n", + "\n", + "Two pieces of per-request input, passed through `additional_information`:\n", + "\n", + "* **`prompt_embeds`** `(T_ctx, embedding_dim)` — the precomputed context\n", + " embedding consumed during prefill. In a real run this is the speaker-encoded\n", + " context audio + context text produced by the caller; here we use random noise.\n", + " `prompt_token_ids = [0] * T_ctx` are placeholders (the model feeds the backbone\n", + " via `inputs_embeds`, never via these ids).\n", + "* **`text_tokens`** `list[int]` — the streaming subword stream; decode step `k`\n", + " consumes `text_tokens[k]`. We provide one id per decode step.\n", + "\n", + "(If the checkpoint had a phoneme branch you'd also stream `phoneme_tokens`; it's\n", + "disabled here via `phoneme_vocab_size = 0`.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "697c74b3", + "metadata": {}, + "outputs": [], + "source": [ + "torch.manual_seed(0)\n", + "\n", + "# Precomputed context embedding (random stand-in for the speaker/text encoder).\n", + "prompt_embeds = torch.randn(T_CTX, arch.embedding_dim, dtype=torch.float32)\n", + "\n", + "# Streaming subword ids: one per decode step (step k consumes text_tokens[k]).\n", + "text_tokens = torch.randint(0, TEXT_VOCAB, (DECODE_STEPS,)).tolist()\n", + "\n", + "additional_information = {\n", + " \"prompt_embeds\": prompt_embeds, # (T_ctx, embedding_dim) tensor\n", + " \"text_tokens\": text_tokens, # list[int], grows by one per step\n", + "}\n", + "\n", + "prompt = {\n", + " \"prompt_token_ids\": [0] * T_CTX, # prefill placeholders\n", + " \"additional_information\": additional_information,\n", + "}\n", + "\n", + "sampling_params = SamplingParams(\n", + " temperature=0.0,\n", + " max_tokens=DECODE_STEPS,\n", + " detokenize=False,\n", + " ignore_eos=True, # dummy logits never emit a meaningful EOS -> run the full budget\n", + ")\n", + "\n", + "print(f\"T_ctx (prefill placeholders) : {T_CTX}\")\n", + 
"print(f\"prompt_embeds : {tuple(prompt_embeds.shape)}\")\n", + "print(f\"decode steps (max_tokens) : {DECODE_STEPS}\")\n", + "print(f\"text_tokens[:8] : {text_tokens[:8]}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3ef8934d", + "metadata": {}, + "source": [ + "## 4. Run inference and extract audio codes\n", + "\n", + "`omni.generate(...)` is an async generator yielding one `RequestOutput` per\n", + "engine step; we keep the last one. As in the eartts demo, the accumulated\n", + "`multimodal_output[\\\"audio_codes\\\"]` holds one row per flat-batch token over the\n", + "whole run (the `T_ctx` prefill frames — codes left zero — plus one frame per\n", + "decode step), so we trim to the last `len(token_ids)` rows to recover just the\n", + "decoded frames." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d0ccbd4", + "metadata": {}, + "outputs": [], + "source": [ + "async def run_request(prompt: dict, sampling_params):\n", + " request_id = f\"easymagpie-{uuid.uuid4().hex[:8]}\"\n", + " final_ro = None\n", + " num_steps = 0\n", + " async for stage_output in omni.generate(\n", + " prompt,\n", + " sampling_params_list=[sampling_params],\n", + " request_id=request_id,\n", + " ):\n", + " final_ro = stage_output\n", + " num_steps += 1\n", + " return final_ro, num_steps\n", + "\n", + "\n", + "final_ro, num_steps = await run_request(prompt, sampling_params)\n", + "assert final_ro is not None, \"no output from engine\"\n", + "\n", + "mm = final_ro.multimodal_output or {}\n", + "audio_codes = mm.get(\"audio_codes\")\n", + "token_ids = final_ro.outputs[0].token_ids if final_ro.outputs else []\n", + "\n", + "print(f\"Engine steps yielded : {num_steps}\")\n", + "print(f\"Layer-0 tokens (token_ids) : {len(token_ids)}\")\n", + "if isinstance(audio_codes, torch.Tensor):\n", + " audio_codes = audio_codes.detach().cpu().to(torch.long)\n", + " print(f\"audio_codes shape (raw) : {tuple(audio_codes.shape)}\")\n", + " # Trim the Tref prefill frames echoed 
during prefill: keep only the decoded\n", + " # frames (the last len(token_ids) rows), exactly like the eartts demo.\n", + " if len(token_ids) > 0:\n", + " audio_codes = audio_codes[-len(token_ids):].contiguous()\n", + " print(f\"audio_codes shape (decode) : {tuple(audio_codes.shape)}\")\n", + " print(f\"audio_codes dtype : {audio_codes.dtype}\")\n", + " print(f\"codes min / max : {int(audio_codes.min())} / {int(audio_codes.max())}\")\n", + " print(f\"first frame codes : {audio_codes[0].tolist()}\")\n", + "else:\n", + " print(f\"audio_codes : {audio_codes!r}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "04196662", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pylab as plt\n", + "\n", + "plt.imshow(audio_codes.T, aspect=\"auto\")\n", + "plt.colorbar()\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a6603b9", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "emp", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py new file mode 100644 index 000000000000..8a37af8454ea --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py @@ -0,0 +1,27 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""EasyMagpieTTS model definition for vLLM-Omni. + +This package provides an inference-only re-implementation of EasyMagpieTTS +(decoder-only, Qwen2 backbone + autoregressive local transformer over the +stacked audio codebooks) that plugs into the vLLM-Omni serving stack via the +standard ``preprocess`` / ``postprocess`` / ``make_omni_output`` hooks. + +The companion ``vllm_plugin_easymagpie_omni`` package registers the model with +vLLM's ``ModelRegistry`` through the ``vllm.general_plugins`` entry point. +""" + +from easymagpie_vllm_omni.config import EASYMAGPIE_QWEN, EasyMagpieOmniArch + +__all__ = ["EASYMAGPIE_QWEN", "EasyMagpieOmniArch"] diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py new file mode 100644 index 000000000000..c569089e32f7 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py @@ -0,0 +1,158 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Architecture constants for the EasyMagpieTTS vLLM-Omni model.

These mirror the values baked into the reference EasyMagpieTTS checkpoint
(``examples/tts/conf/magpietts/easy_magpietts.yaml`` — Qwen2.5-1.5B backbone,
8 codebooks, frame-stacking ×2, 3-layer autoregressive local transformer).

The vLLM-Omni model reads the bulk of its configuration from the
``hf_config`` provided by vLLM at construction time; this dataclass captures
the TTS-specific scalars that are *not* part of a standard HF text-LM config
and provides a single, well-documented default profile.
"""
from __future__ import annotations

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

# Number of trailing special tokens appended to every audio codebook.
# Matches ``len(SpecialAudioToken)`` in
# ``nemo.collections.tts.modules.magpietts_modules`` (BOS, EOS, CONTEXT_BOS,
# CONTEXT_EOS, MASK, RESERVED_1..3).
NUM_SPECIAL_AUDIO_TOKENS: int = 8

# Offsets of the special audio tokens *within* the trailing special-token block
# (i.e. ``codebook_size + <offset>`` is the real embedding-table id).
SPECIAL_AUDIO_BOS: int = 0
SPECIAL_AUDIO_EOS: int = 1
SPECIAL_AUDIO_CONTEXT_BOS: int = 2
SPECIAL_AUDIO_CONTEXT_EOS: int = 3
SPECIAL_AUDIO_MASK: int = 4


@dataclass
class EasyMagpieOmniArch:
    """Static architecture description for an EasyMagpieTTS checkpoint.

    Attributes:
        hidden_dim: Backbone hidden size (``cfg.hidden_dim``).
        embedding_dim: Embedding size feeding the backbone (``cfg.embedding_dim``).
        audio_embedding_dim: Per-codebook audio embedding size
            (``cfg.audio_embedding_dim``); may differ from ``embedding_dim``.
        num_audio_codebooks: Number of codec codebooks (``C``).
        codebook_size: Base codec codebook size (excluding special tokens).
        frame_stacking_factor: Frame stacking factor (``S``). The model treats
            the audio stream as ``C * S`` independent "stacked" codebooks.
        phoneme_stacking_factor: Phoneme stacking factor.
        phoneme_vocab_size: Phoneme tokenizer vocabulary size.
        local_transformer_n_layers / _n_heads / _hidden_dim: local-transformer
            (intra-frame codebook predictor) sizing.
        forced_audio_bos_id / forced_audio_eos_id / forced_mask_token_id:
            When not ``None``, returned verbatim by the matching ``*_id``
            property instead of ``codebook_size + <offset>``.
        extra: Free-form dictionary; nothing in this module reads it and
            :meth:`from_hf_config` always leaves it empty.
    """

    hidden_dim: int = 1536
    embedding_dim: int = 1536
    audio_embedding_dim: int = 1536

    num_audio_codebooks: int = 8
    codebook_size: int = 1024
    frame_stacking_factor: int = 2

    phoneme_stacking_factor: int = 1
    phoneme_vocab_size: int = 2051

    local_transformer_n_layers: int = 3
    local_transformer_n_heads: int = 12
    local_transformer_hidden_dim: int = 1536

    # Optional per-checkpoint overrides for backward compatibility (legacy
    # checkpoints sometimes forced special-token ids).
    forced_audio_bos_id: int | None = None
    forced_audio_eos_id: int | None = None
    forced_mask_token_id: int | None = None

    extra: dict[str, Any] = field(default_factory=dict)

    # ── Derived quantities ───────────────────────────────────────────
    @property
    def num_stacked_codebooks(self) -> int:
        """Number of independent codebooks the model autoregresses over (``C * S``)."""
        return self.num_audio_codebooks * self.frame_stacking_factor

    @property
    def num_all_tokens_per_codebook(self) -> int:
        """Per-codebook vocabulary size including the trailing special tokens."""
        return self.codebook_size + NUM_SPECIAL_AUDIO_TOKENS

    @property
    def audio_bos_id(self) -> int:
        """Embedding-table id of the audio BOS token."""
        if self.forced_audio_bos_id is not None:
            return self.forced_audio_bos_id
        return self.codebook_size + SPECIAL_AUDIO_BOS

    @property
    def audio_eos_id(self) -> int:
        """Embedding-table id of the audio EOS token."""
        if self.forced_audio_eos_id is not None:
            return self.forced_audio_eos_id
        return self.codebook_size + SPECIAL_AUDIO_EOS

    @property
    def mask_token_id(self) -> int:
        """Embedding-table id of the MaskGit MASK token."""
        if self.forced_mask_token_id is not None:
            return self.forced_mask_token_id
        return self.codebook_size + SPECIAL_AUDIO_MASK

    @classmethod
    def from_hf_config(cls, hf_config: Any) -> "EasyMagpieOmniArch":
        """Build an arch description from a vLLM ``hf_config``.

        Each known field is looked up first as a top-level attribute of
        ``hf_config`` and then, as a fallback, inside an optional
        ``easymagpie`` mapping attribute (a nested block in ``config.json``).
        Unknown attributes are ignored and anything not supplied keeps the
        reference Qwen2.5-1.5B default.

        ``None`` is treated as "not set": it neither replaces a default nor
        blocks the ``hidden_size`` fallback below.

        Args:
            hf_config: Any object exposing configuration values as attributes.

        Returns:
            A new :class:`EasyMagpieOmniArch`; ``extra`` is always empty.
        """
        nested = getattr(hf_config, "easymagpie", None)
        if not isinstance(nested, Mapping):
            nested = {}

        overrides: dict[str, Any] = {}
        for name in (
            "hidden_dim",
            "embedding_dim",
            "audio_embedding_dim",
            "num_audio_codebooks",
            "codebook_size",
            "frame_stacking_factor",
            "phoneme_stacking_factor",
            "phoneme_vocab_size",
            "local_transformer_n_layers",
            "local_transformer_n_heads",
            "local_transformer_hidden_dim",
            "forced_audio_bos_id",
            "forced_audio_eos_id",
            "forced_mask_token_id",
        ):
            value = getattr(hf_config, name, None)
            if value is None:
                value = nested.get(name)
            if value is not None:
                overrides[name] = value

        # ``hidden_size`` is the canonical HF name for the backbone width.
        hidden_size = getattr(hf_config, "hidden_size", None)
        if "hidden_dim" not in overrides and hidden_size is not None:
            overrides["hidden_dim"] = hidden_size
            overrides.setdefault("embedding_dim", hidden_size)

        # Dataclass defaults fill every field that was not overridden.
        return cls(**overrides)


# Reference profile: Qwen2.5-1.5B backbone EasyMagpieTTS checkpoint.
EASYMAGPIE_QWEN = EasyMagpieOmniArch()

# ---------------------------------------------------------------------------
# Verbatim patch text that shares these source lines: the diff header and the
# first license lines of the next file in the patch series.
# ---------------------------------------------------------------------------
# diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py
# new file mode 100644
# index 000000000000..bfb76ccb9303
# --- /dev/null
# +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py
# @@ -0,0 +1,586 @@
# +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only EasyMagpieTTS model for vLLM-Omni. + +EasyMagpieTTS is a decoder-only streaming TTS model: a text-LM backbone (the +reference checkpoint uses Qwen2.5-1.5B) consumes a per-frame additive input +embedding (text + phoneme + audio) and emits a per-frame hidden state, from +which a small autoregressive *local transformer* samples all ``C * S`` stacked +audio codebooks for that frame (see :mod:`easymagpie_vllm_omni.local_transformer`). + +This module wires that architecture into vLLM-Omni's +``preprocess`` / ``forward`` / ``compute_logits`` / ``make_omni_output`` / +``postprocess`` contract, following the same conventions as the upstream +qwen3-tts and eartts vLLM-Omni model definitions: + +* **Backbone** — vLLM's :class:`~vllm.model_executor.models.qwen2.Qwen2Model`, + reused wholesale (KV cache + paged attention) the same way the EasyMagpie + vLLM *sidecar* reuses ``NemotronHModel``. Every step feeds the backbone via + ``inputs_embeds``; its own ``embed_tokens`` table is never consumed. +* **Local transformer** — :class:`EasyMagpieCodePredictor`, a from-scratch, + CUDA-graph-capturable re-implementation that runs as a single compiled graph. +* **compute_logits** — returns trivial logits (à la eartts) so vLLM's sampler + always picks index 0; the real audio output is the codes tensor surfaced + through :meth:`make_omni_output` under the ``"audio_codes"`` key. 
+ +Text is embedded via a precomputed per-subword lookup table baked at +checkpoint-conversion time (the reference char-aware subword encoder is +deterministic per subword id, so it is never run inside the engine). + +Per-request I/O (via ``additional_information``): + +* ``prompt_embeds`` (prefill only) — ``(T_ctx, embedding_dim)`` precomputed + context/prompt embedding (speaker-encoded context audio + context text) + produced by the caller, exactly like qwen3-tts ``talker_prompt_embeds`` / + eartts ``speaker_latent``. The user passes ``prompt_token_ids = [0] * T_ctx``. +* ``text_tokens`` — Python ``list[int]`` of subword ids that grows by one per + decode step; step ``k`` consumes ``text_tokens[k]`` (embedded through the + precomputed per-subword table). +* ``phoneme_tokens`` (optional) — same streaming-list contract for the phoneme + channel; if omitted the phoneme branch is skipped. +""" +from __future__ import annotations + +import bisect +from collections.abc import Iterable +from typing import Any, Optional + +import torch +from torch import nn +from vllm.compilation.backends import set_model_tag +from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile +from vllm.config import CUDAGraphMode, VllmConfig +from vllm.forward_context import BatchDescriptor, get_forward_context +from vllm.logger import init_logger +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.models.utils import maybe_prefix +from vllm.sequence import IntermediateTensors + +from vllm_omni.model_executor.models.output_templates import OmniOutput + +from easymagpie_vllm_omni.config import EasyMagpieOmniArch +from easymagpie_vllm_omni.local_transformer import EasyMagpieCodePredictor + +logger = init_logger(__name__) + +# Placeholder token id stuffed into the per-step ``input_ids`` returned by +# ``preprocess`` — the model never consumes ``input_ids`` (decode behaviour is +# driven by the per-token buffers), and ``compute_logits`` 
# returns argmax-at-0 dummy logits, so this only needs to be a valid id.
_DUMMY_TOKEN_ID = 0


# ``dynamic_arg_dims`` is passed explicitly: this file uses
# ``from __future__ import annotations`` (PEP 563), so ``forward``'s annotations
# are strings and vLLM's annotation-based inference would fail with
# "No dynamic dimensions found...". These mirror vLLM's default inference
# (dim 0 for every tensor / IntermediateTensors argument).
@ignore_torch_compile
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        "positions": 0,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    }
)
class EasyMagpieTTSForConditionalGeneration(nn.Module):
    """EasyMagpieTTS talker for vLLM-Omni.

    See the module docstring for the per-step flow and the per-request I/O
    contract. The class exposes the omni hooks (``has_preprocess`` /
    ``has_postprocess`` / ``have_multimodal_outputs``) consumed by the
    ``OmniGPUModelRunner``.
    """

    # Omni runner hooks.
    has_preprocess: bool = True
    has_postprocess: bool = True
    have_multimodal_outputs: bool = True

    # Keep small per-step tensors GPU-resident across steps (no D2H/H2D).
    gpu_resident_buffer_keys: set[str] = {
        "last_audio_codes",
        "last_phoneme_token",
        "last_hidden",
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        """Build the backbone, the code predictor, the embedding heads and the scratch buffers.

        Args:
            vllm_config: vLLM configuration; ``model_config.hf_config`` supplies
                the architecture overrides and ``scheduler_config.max_num_batched_tokens``
                sizes every per-token buffer.
            prefix: Parameter-name prefix forwarded to the submodules.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.hf_config = hf_config
        self.vllm_config = vllm_config
        self.arch = EasyMagpieOmniArch.from_hf_config(hf_config)
        self.model_path = vllm_config.model_config.model

        arch = self.arch
        self.hidden_dim = arch.hidden_dim
        self.embedding_dim = arch.embedding_dim
        self.num_codebooks = arch.num_stacked_codebooks

        # ── Backbone (reused vLLM text LM; fed via inputs_embeds) ───────
        self.backbone = Qwen2Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "backbone"),
        )

        # ── Local transformer (its own compile group / CUDA graph) ──────
        with set_model_tag("local_transformer"):
            self.code_predictor = EasyMagpieCodePredictor(
                vllm_config=vllm_config,
                prefix=maybe_prefix(prefix, "code_predictor"),
            )

        # ── Text + phoneme embedding heads ──────────────────────────────
        # Precomputed per-subword text embedding. The reference model embeds
        # text with a char-aware subword (CAS) encoder + the decoder's subword
        # table; both are deterministic per subword id, so the checkpoint
        # converter bakes their combined result into this single lookup table
        # (one row per subword id). It is fed additively on every decode step;
        # the CAS encoder is never run inside the engine.
        text_vocab_size = int(getattr(hf_config, "text_vocab_size", getattr(hf_config, "vocab_size", 0)))
        self.text_embedding = nn.Embedding(text_vocab_size, self.embedding_dim)

        # Phoneme channel (optional — only built when the checkpoint has one).
        self.has_phoneme = arch.phoneme_vocab_size > 0 and arch.phoneme_stacking_factor > 0
        if self.has_phoneme:
            self.phoneme_embeddings = nn.ModuleList(
                [nn.Embedding(arch.phoneme_vocab_size, self.embedding_dim) for _ in range(arch.phoneme_stacking_factor)]
            )
            self.phoneme_final_proj = nn.Linear(
                self.hidden_dim, arch.phoneme_vocab_size * arch.phoneme_stacking_factor
            )

        # ── Persistent, address-stable scratch buffers ─────────────────
        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        dtype = vllm_config.model_config.dtype
        # Combined per-token input embedding fed into the backbone.
        self._combined_embeddings = torch.zeros(max_num_tokens, self.embedding_dim, dtype=dtype)
        # Per-token decode inputs assembled by ``preprocess``.
        self._dec_text_tokens = torch.zeros(max_num_tokens, dtype=torch.long)
        self._dec_text_mask = torch.zeros(max_num_tokens, dtype=torch.long)
        self._dec_audio_codes = torch.zeros(max_num_tokens, self.num_codebooks, dtype=torch.long)
        self._dec_audio_valid = torch.zeros(max_num_tokens, dtype=torch.long)
        if self.has_phoneme:
            self._dec_phoneme_tokens = torch.zeros(
                max_num_tokens, arch.phoneme_stacking_factor, dtype=torch.long
            )
            self._dec_phoneme_valid = torch.zeros(max_num_tokens, dtype=torch.long)

        self._out_codes = torch.zeros(max_num_tokens, self.num_codebooks, dtype=torch.long)

    # ------------------------------------------------------------------
    # Embedding helpers
    # ------------------------------------------------------------------

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Compatibility shim — unused at runtime (everything goes via inputs_embeds)."""
        return self.text_embedding(input_ids)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Alias of :meth:`get_input_embeddings`."""
        return self.get_input_embeddings(input_ids)

    def _embed_phoneme(self, phoneme_tokens: torch.Tensor) -> torch.Tensor:
        """Average the per-stack phoneme embeddings (``[num_tokens, S] -> [num_tokens, dim]``)."""
        acc = self.phoneme_embeddings[0](phoneme_tokens[:, 0])
        for s in range(1, len(self.phoneme_embeddings)):
            acc = acc + self.phoneme_embeddings[s](phoneme_tokens[:, s])
        return acc / len(self.phoneme_embeddings)

    # ------------------------------------------------------------------
    # Decode-token dispatch (which positions need the local transformer)
    # ------------------------------------------------------------------

    def _get_decode_idxs(self):
        """Return ``(decode_token_indices, num_requests)`` for code-predictor dispatch.

        Mirrors the qwen3-tts / eartts pattern:

        * ``(None, 0)`` → run the local transformer on every token (profile /
          dummy run with no ``attn_metadata``, or a decode-only batch where
          ``max_query_len == 1``), so the captured CUDA graph covers every
          ``cudagraph_capture_sizes`` value.
        * ``(indices, num_requests)`` → run only on the listed decode positions
          (mixed prefill+decode batch). ``indices`` is padded to the next
          captured graph size; ``num_requests`` is the unpadded count.
        """
        ctx = get_forward_context()
        attn_metadata = ctx.attn_metadata
        if attn_metadata is None:
            return None, 0

        # Per-layer metadata arrives as a dict; any entry is used for the batch layout.
        if isinstance(attn_metadata, dict):
            any_layer_meta = next(iter(attn_metadata.values()))
        else:
            any_layer_meta = attn_metadata

        if any_layer_meta.max_query_len == 1:
            return None, 0

        # A request contributing exactly one token to the flat batch is a decode step.
        start_loc = any_layer_meta.query_start_loc
        tokens_per_req = start_loc[1:] - start_loc[:-1]
        is_decode = tokens_per_req == 1
        decode_token_indices = start_loc[:-1][is_decode]

        num_requests = decode_token_indices.shape[0]
        padded_num_requests = num_requests
        if self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE:
            # ``bisect`` is only meaningful on an ascending sequence, so sort a
            # copy (and tolerate ``None``) rather than rely on the config's order.
            sizes = sorted(self.vllm_config.compilation_config.cudagraph_capture_sizes or ())
            idx = bisect.bisect_left(sizes, num_requests)
            if idx < len(sizes):
                padded_num_requests = sizes[idx]
        if padded_num_requests != num_requests:
            # Padding entries are index 0; callers only consume the first
            # ``num_requests`` results.
            decode_token_indices = torch.nn.functional.pad(
                decode_token_indices, (0, padded_num_requests - num_requests)
            )
        return decode_token_indices, num_requests

    # ------------------------------------------------------------------
    # forward
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **_: Any,
    ) -> torch.Tensor:
        """Assemble the per-token embedding, run the backbone, then the codes.

        ``inputs_embeds`` carries the prefill embedding span produced by
        :meth:`preprocess` (zeros at decode positions). For decode positions we
        assemble ``text_emb + phoneme_emb + audio_emb`` in-place from the
        per-token buffers, run the backbone, then sample the codebooks with the
        local transformer (skipping prefill positions).

        Returns:
            The backbone hidden states. The sampled codes are written to
            ``self._out_codes`` as a side effect.
        """
        num_tokens = input_ids.shape[0]
        combined = self._combined_embeddings[:num_tokens]
        if inputs_embeds is not None:
            combined.copy_(inputs_embeds)
        else:
            combined.zero_()

        decode_idx, num_req = self._get_decode_idxs()

        if decode_idx is None:
            # Profile / dummy run or decode-only batch: assemble decode
            # embeddings everywhere so the captured graph sees the full path.
            self._assemble_decode_embeddings(combined, slice(0, num_tokens))
        elif num_req > 0:
            valid = decode_idx[:num_req]
            self._assemble_decode_embeddings(combined, valid)

        hidden_states = self.backbone(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=combined,
        )

        # Sample codes (local transformer) only where needed.
        if decode_idx is None:
            codes = self.code_predictor.generate_codes(hidden_states)
            self._out_codes[:num_tokens].copy_(codes)
            if self.has_phoneme:
                self._predict_phonemes(hidden_states, slice(0, num_tokens))
        elif num_req > 0:
            ctx = get_forward_context()
            orig_bd = ctx.batch_descriptor
            # The code predictor runs on the padded decode subset, so describe
            # that batch size for the duration of the call only.
            ctx.batch_descriptor = BatchDescriptor(num_tokens=decode_idx.shape[0])
            try:
                codes = self.code_predictor.generate_codes(hidden_states[decode_idx])
            finally:
                # Restore even if the predictor raises, so the shared forward
                # context is never left describing the wrong batch.
                ctx.batch_descriptor = orig_bd
            valid = decode_idx[:num_req]
            self._out_codes[valid] = codes[:num_req]
            if self.has_phoneme:
                self._predict_phonemes(hidden_states, valid)

        return hidden_states

    def _assemble_decode_embeddings(self, combined: torch.Tensor, idx) -> None:
        """Add ``text + phoneme + audio`` embeddings into ``combined`` at ``idx``."""
        # Audio: previous-frame codes (gated by validity).
        audio_codes = self._dec_audio_codes[idx]
        audio_emb = self.code_predictor.embed_audio_frame(audio_codes)
        audio_emb = audio_emb * self._dec_audio_valid[idx].unsqueeze(-1).to(audio_emb.dtype)
        combined[idx] += audio_emb

        # Text: current subword token (gated by validity).
        text_emb = self.text_embedding(self._dec_text_tokens[idx])
        text_emb = text_emb * self._dec_text_mask[idx].unsqueeze(-1).to(text_emb.dtype)
        combined[idx] += text_emb

        # Phoneme: previous predicted phoneme (gated by validity).
        if self.has_phoneme:
            phon_emb = self._embed_phoneme(self._dec_phoneme_tokens[idx])
            phon_emb = phon_emb * self._dec_phoneme_valid[idx].unsqueeze(-1).to(phon_emb.dtype)
            combined[idx] += phon_emb

    @torch.no_grad()
    def _predict_phonemes(self, hidden_states: torch.Tensor, idx) -> None:
        """Argmax the phoneme head and stash the prediction for the next step."""
        # Feed the projection in its own parameter dtype (a float32 input against
        # reduced-precision weights is a dtype mismatch), then take the argmax
        # over float32 logits. Identical to the plain ``.float()`` path when the
        # weights are float32.
        proj_dtype = self.phoneme_final_proj.weight.dtype
        logits = self.phoneme_final_proj(hidden_states[idx].to(proj_dtype)).float()
        s = self.arch.phoneme_stacking_factor
        logits = logits.view(-1, s, self.arch.phoneme_vocab_size)
        self._dec_phoneme_tokens[idx] = logits.argmax(dim=-1).long()
        self._dec_phoneme_valid[idx] = 1

    # ------------------------------------------------------------------
    # compute_logits — dummy (real output is the codes tensor)
    # ------------------------------------------------------------------

    def compute_logits(self, hidden_states, sampling_metadata: Any = None) -> Optional[torch.Tensor]:
        """Return zero logits so vLLM's sampler always picks index 0.

        The width is taken from ``hf_config.vocab_size`` so the sampler's
        working buffers match. The sampled id is irrelevant — audio is surfaced
        via :meth:`make_omni_output`.
        """
        if isinstance(hidden_states, OmniOutput):
            hidden_states = hidden_states.text_hidden_states
        if hidden_states is None:
            return None
        batch_size = hidden_states.shape[0]
        return hidden_states.new_zeros(batch_size, int(self.hf_config.vocab_size))

    # ------------------------------------------------------------------
    # multimodal output plumbing
    # ------------------------------------------------------------------

    def make_omni_output(self, model_outputs, **_: Any) -> OmniOutput:
        """Surface the sampled codes (``BT x num_codebooks``) under ``audio_codes``."""
        if isinstance(model_outputs, OmniOutput):
            return model_outputs
        hidden = model_outputs
        num_tokens = int(hidden.shape[0])
        # Clone: ``_out_codes`` is a persistent buffer overwritten on the next step.
        audio_codes = self._out_codes[:num_tokens].clone()
        return OmniOutput(
            text_hidden_states=hidden,
            multimodal_outputs={"audio_codes": audio_codes},
        )

    # ------------------------------------------------------------------
    # preprocess / postprocess
    # ------------------------------------------------------------------

    @staticmethod
    def _unwrap(value: Any) -> Any:
        """Return the first element of a list (``None`` if it is empty); pass anything else through."""
        if isinstance(value, list):
            return value[0] if value else None
        return value

    def preprocess(
        self,
        input_ids: torch.Tensor,
        input_embeds: Optional[torch.Tensor],
        *,
        start: int = 0,
        end: int = 0,
        **info_dict: Any,
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Build per-request ``(input_ids, inputs_embeds)`` for this step.

        Prefill (``span_len > 1``): slice the precomputed ``prompt_embeds``
        context embedding into this chunk and return it; ``input_ids`` are
        placeholders. Decode (``span_len == 1``): write the per-token decode
        inputs (previous codes, current text token, previous phoneme) into the
        model buffers at ``start`` and return a zero embedding that
        :meth:`forward` accumulates into.

        Returns:
            ``(input_ids, inputs_embeds, info_update)``; ``info_update`` carries
            the advanced prefill / decode offsets.
        """
        # Flatten a nested ``additional_information`` dict; top-level keys win.
        nested = info_dict.get("additional_information")
        if isinstance(nested, dict):
            merged = {k: v for k, v in info_dict.items() if k != "additional_information"}
            for k, v in nested.items():
                merged.setdefault(k, v)
            info_dict = merged

        device = input_ids.device
        span_len = int(input_ids.shape[0])
        if span_len <= 0:
            base = input_embeds if input_embeds is not None else self.embed_input_ids(input_ids)
            return input_ids, base, {}

        if span_len > 1:
            return self._preprocess_prefill(input_ids, span_len, device, info_dict)
        return self._preprocess_decode(input_ids, start, device, info_dict)

    def _preprocess_prefill(
        self,
        input_ids: torch.Tensor,
        span_len: int,
        device: torch.device,
        info_dict: dict[str, Any],
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Return the ``span_len`` rows of ``prompt_embeds`` that belong to this prefill chunk.

        Raises:
            ValueError: If ``prompt_embeds`` is missing or is not a 2-D tensor.
        """
        prompt_embeds = self._unwrap(info_dict.get("prompt_embeds"))
        if not isinstance(prompt_embeds, torch.Tensor) or prompt_embeds.ndim != 2:
            raise ValueError(
                "EasyMagpieTTS preprocess requires additional_information.prompt_embeds "
                "of shape (T_ctx, embedding_dim) for prefill."
            )
        prompt_embeds = prompt_embeds.to(device=device, dtype=self._combined_embeddings.dtype)

        # Clamp the chunk window to the rows that actually exist.
        offset = int(info_dict.get("ear_prefill_offset", 0) or 0)
        total = int(prompt_embeds.shape[0])
        s = max(0, min(offset, total))
        e = max(0, min(offset + span_len, total))
        take = prompt_embeds[s:e]
        if int(take.shape[0]) < span_len:
            # Window runs past the prompt: repeat the last row, or use zeros
            # when the window is empty.
            pad_n = span_len - int(take.shape[0])
            pad_rows = (
                take[-1:].expand(pad_n, -1)
                if take.shape[0] > 0
                else prompt_embeds.new_zeros(pad_n, prompt_embeds.shape[-1])
            )
            take = torch.cat([take, pad_rows], dim=0)

        info_update = {
            "ear_prefill_offset": offset + span_len,
            "ear_decode_offset": 0,
        }
        input_ids_out = torch.full_like(input_ids, _DUMMY_TOKEN_ID)
        return input_ids_out, take, info_update

    def _preprocess_decode(
        self,
        input_ids: torch.Tensor,
        start: int,
        device: torch.device,
        info_dict: dict[str, Any],
    ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]:
        """Write this request's decode inputs into row ``start`` of the per-token buffers."""
        decode_offset = int(info_dict.get("ear_decode_offset", 0) or 0)

        # Text channel (streaming list that grows by one per step).
        text_tokens = info_dict.get("text_tokens")
        if isinstance(text_tokens, list) and text_tokens:
            # Once the list is exhausted, keep feeding its last token.
            idx = min(decode_offset, len(text_tokens) - 1)
            self._dec_text_tokens[start] = int(text_tokens[idx])
            self._dec_text_mask[start] = 1
        else:
            self._dec_text_mask[start] = 0

        # Phoneme channel: previous-step prediction stashed by postprocess.
        if self.has_phoneme:
            last_phon = info_dict.get("last_phoneme_token")
            if isinstance(last_phon, torch.Tensor) and last_phon.numel() > 0:
                # Clamp to the row width up front, exactly like the audio branch below.
                p = last_phon.to(device=device, dtype=torch.long).reshape(-1)[: self.arch.phoneme_stacking_factor]
                self._dec_phoneme_tokens[start, : p.shape[0]].copy_(p)
                self._dec_phoneme_valid[start] = 1
            else:
                self._dec_phoneme_valid[start] = 0

        # Audio channel: previous-frame codes (BOS seed on the first step).
        last_codes = info_dict.get("last_audio_codes")
        if isinstance(last_codes, torch.Tensor) and last_codes.numel() > 0:
            c = last_codes.to(device=device, dtype=torch.long).reshape(-1)[: self.num_codebooks]
            self._dec_audio_codes[start, : c.shape[0]].copy_(c)
            self._dec_audio_valid[start] = 1
        else:
            # First decode step after prefill: seed with audio BOS.
            self._dec_audio_codes[start].fill_(self.arch.audio_bos_id)
            self._dec_audio_valid[start] = 1

        inputs_embeds_out = torch.zeros((1, self.embedding_dim), device=device, dtype=self._combined_embeddings.dtype)
        info_update = {"ear_decode_offset": decode_offset + 1}
        return input_ids, inputs_embeds_out, info_update

    def postprocess(self, hidden_states: torch.Tensor, multimodal_outputs: Optional[dict[str, Any]] = None, **_: Any):
        """Stash the last frame's codes (and phoneme) for the next decode step."""
        if hidden_states.numel() == 0:
            return {}
        # Derive the request's first row from the tensor's storage offset.
        # NOTE(review): this presumes ``hidden_states`` is a row-slice view of
        # the flat-batch tensor — confirm against the runner.
        stride0 = hidden_states.stride(0) or 1
        req_start = hidden_states.storage_offset() // stride0
        last = req_start + hidden_states.shape[0] - 1

        out: dict[str, Any] = {}
        audio_codes = (multimodal_outputs or {}).get("audio_codes")
        if isinstance(audio_codes, torch.Tensor) and audio_codes.numel() > 0:
            out["last_audio_codes"] = audio_codes[last : last + 1].detach()
        if self.has_phoneme:
            out["last_phoneme_token"] = self._dec_phoneme_tokens[last : last + 1].detach().clone()
        return out

    # ------------------------------------------------------------------
    # weight loading
    # ------------------------------------------------------------------

    # Checkpoint prefixes (reference EasyMagpieTTS state dict) → in-model paths.
    # ``decoder.*`` is fed to the vLLM backbone loader separately (it understands
    # HF Qwen2 naming + qkv packing). The TTS submodules are copied manually.
    _TTS_PREFIX_MAP = {
        "local_transformer.": "code_predictor.local_transformer.",
        "local_transformer_in_projection.": "code_predictor.local_transformer_in_projection.",
        "local_transformer_audio_out_projection.": "code_predictor.local_transformer_audio_out_projection.",
        "local_transformer_out_projections.": "code_predictor.local_transformer_out_projections.",
        "audio_embeddings.": "code_predictor.audio_embeddings.",
        "audio_in_projection.": "code_predictor.audio_in_projection.",
        "phoneme_embeddings.": "phoneme_embeddings.",
        "phoneme_final_proj.": "phoneme_final_proj.",
        "text_embedding.": "text_embedding.",
    }

    def _remap_tts_key(self, name: str) -> Optional[str]:
        """Map a raw checkpoint key to its in-model parameter path (or ``None``)."""
        for src, dst in self._TTS_PREFIX_MAP.items():
            if name.startswith(src):
                return dst + name[len(src) :]
        return None

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load backbone (Qwen2) + TTS submodule weights from a converted checkpoint.

        The converted checkpoint is expected to use the reference EasyMagpieTTS
        key layout: the backbone under ``decoder.*`` (HF Qwen2 names) and the
        TTS submodules at top level (``audio_embeddings.*``, ``local_transformer.*``,
        ``phoneme_*``, ``text_embedding.*``, projection heads). Backbone weights
        are routed to :meth:`Qwen2Model.load_weights` (which packs qkv / gate-up
        and handles HF naming); TTS weights are copied directly by name.

        Returns:
            The in-model names of every parameter that was loaded.

        Raises:
            RuntimeError: If a TTS tensor's shape differs from its parameter's.
        """
        own_params = dict(self.named_parameters())
        loaded: set[str] = set()
        backbone_weights: list[tuple[str, torch.Tensor]] = []

        for name, tensor in weights:
            if name.startswith("decoder."):
                backbone_weights.append((name[len("decoder.") :], tensor))
                continue
            mapped = self._remap_tts_key(name)
            if mapped is None:
                # Unrelated checkpoint section (codec, speaker encoder, CAS, etc.).
                continue
            target = own_params.get(mapped)
            if target is None:
                logger.warning("EasyMagpieTTS: no parameter for checkpoint key %s -> %s", name, mapped)
                continue
            if target.shape != tensor.shape:
                raise RuntimeError(
                    f"EasyMagpieTTS weight shape mismatch at {mapped!r}: "
                    f"ckpt {tuple(tensor.shape)} vs model {tuple(target.shape)}"
                )
            with torch.no_grad():
                target.data.copy_(tensor.to(target.dtype))
            loaded.add(mapped)

        backbone_loaded = self.backbone.load_weights(backbone_weights)
        loaded |= {f"backbone.{n}" for n in backbone_loaded}

        # Derived runtime state.
        self.code_predictor.init_forbidden_mask()

        # The backbone's vestigial embed_tokens table is never consumed
        # (everything goes through inputs_embeds); don't flag it as missing.
        loaded.add("backbone.embed_tokens.weight")

        logger.info("Loaded %d weights for EasyMagpieTTSForConditionalGeneration", len(loaded))
        return loaded


# ---------------------------------------------------------------------------
# Verbatim patch text that shares these source lines: the diff header, the
# license header and the opening docstring line of the next file in the patch
# series.
# ---------------------------------------------------------------------------
# diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py
# new file mode 100644
# index 000000000000..a72ee6ecd52d
# --- /dev/null
# +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py
# @@ -0,0 +1,398 @@
# +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# +#
# +# Licensed under the Apache License, Version 2.0 (the "License");
# +# you may not use this file except in compliance with the License.
# +# You may obtain a copy of the License at
# +#
# +# http://www.apache.org/licenses/LICENSE-2.0
# +#
# +# Unless required by applicable law or agreed to in writing, software
# +# distributed under the License is distributed on an "AS IS" BASIS,
# +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# +# See the License for the specific language governing permissions and
# +# limitations under the License.
# +"""From-scratch autoregressive local transformer for EasyMagpieTTS on vLLM-Omni.
+ +The reference EasyMagpieTTS model predicts the ``C * S`` stacked audio +codebooks of one frame *autoregressively* with a small causal transformer +(``nemo.collections.tts.modules.transformer_2501.Transformer``) conditioned on +the backbone's per-frame hidden state. The reference implementation re-creates +fresh tensors and (optionally) a KV cache on every codebook step, which is +incompatible with CUDA-graph replay. + +This module re-implements that local transformer from scratch so it can run as +a single compiled CUDA graph: + +* :class:`EasyMagpieLocalTransformer` mirrors the ``transformer_2501`` + layer/weight layout **exactly** (so a stock checkpoint loads 1:1) but uses + ``scaled_dot_product_attention`` and drops the KV cache / padding-mask + plumbing. It is decorated with ``@support_torch_compile`` so vLLM captures + one CUDA graph for the fixed ``(num_tokens, num_stacked_codebooks, hidden)`` + input shape. +* :class:`EasyMagpieCodePredictor` owns the persistent, address-stable scratch + buffers and runs the per-frame autoregressive loop, re-invoking the compiled + transformer once per codebook over the **same** buffer (matching the + qwen3-tts code-predictor trick: replaying one fixed-shape graph N times is + faster and simpler than capturing N separate graphs). + +All sampling is CUDA-graph safe (Gumbel-max + ``topk`` + ``masked_fill`` only; +no host syncs, no ``multinomial`` on possibly-degenerate warmup data). +""" +from __future__ import annotations + +import torch +from torch import nn +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig + +from easymagpie_vllm_omni.config import EasyMagpieOmniArch + + +def _gumbel_argmax(logits: torch.Tensor) -> torch.Tensor: + """Gumbel-max categorical draw — CUDA-graph safe. 
+ + Equivalent to sampling from ``softmax(logits)`` but uses only + ``uniform_`` + ``log`` + ``argmax`` (all legal inside a captured graph) + and degrades gracefully on degenerate warmup logits instead of triggering + a device-side assert the way ``multinomial`` does. + """ + u = torch.empty_like(logits).uniform_(1e-20, 1.0 - 1e-20) + return (logits - torch.log(-torch.log(u))).argmax(dim=-1) + + +def sample_codebook( + logits: torch.Tensor, + *, + temperature: float, + top_k: int, + forbidden_mask: torch.Tensor | None, +) -> torch.Tensor: + """Sample one codebook's tokens from logits (CUDA-graph safe). + + Args: + logits: ``[num_tokens, vocab]`` raw codebook logits. + temperature: Sampling temperature; ``<= 0`` falls back to argmax. + top_k: Top-k truncation width (``<= 0`` disables truncation). + forbidden_mask: Optional ``[vocab]`` bool mask; ``True`` entries are + set to ``-inf`` before sampling (reserved/special tokens). + + Returns: + ``[num_tokens]`` int64 sampled token ids. + """ + if forbidden_mask is not None: + logits = logits.masked_fill(forbidden_mask, float("-inf")) + + if temperature <= 0.0: + return logits.argmax(dim=-1) + + logits = logits / temperature + + if top_k is not None and top_k > 0: + vals, idxs = torch.topk(logits, k=min(top_k, logits.size(-1)), dim=-1) + sampled_in_k = _gumbel_argmax(vals) + return idxs.gather(-1, sampled_in_k.unsqueeze(-1)).squeeze(-1) + + return _gumbel_argmax(logits) + + +class EasyMagpieLTSelfAttention(nn.Module): + """Causal self-attention matching ``transformer_2501.SelfAttention`` weights. + + Same projections (``qkv_net`` fused QKV without bias, ``o_net`` without + bias) and the same ``d_head ** -0.5`` scaling, but computed with + ``scaled_dot_product_attention`` and an ``is_causal=True`` flag instead of + the materialised causal-mask buffer + naive softmax. 
No KV cache: the + autoregressive loop re-runs the full (short, fixed-length) sequence each + step, which is what makes the whole thing CUDA-graph capturable. + """ + + def __init__(self, d_model: int, n_heads: int) -> None: + super().__init__() + assert d_model % n_heads == 0, "d_model must be divisible by n_heads" + self.n_heads = n_heads + self.d_head = d_model // n_heads + self.scale = self.d_head**-0.5 + self.qkv_net = nn.Linear(d_model, 3 * n_heads * self.d_head, bias=False) + self.o_net = nn.Linear(n_heads * self.d_head, d_model, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, t, _ = x.shape + qkv = self.qkv_net(x).reshape(b, t, 3, self.n_heads, self.d_head) + q, k, v = qkv.unbind(dim=2) # each [b, t, nh, dh] + # [b, nh, t, dh] + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + attn = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=True, scale=self.scale + ) + attn = attn.transpose(1, 2).contiguous().view(b, t, -1) + return self.o_net(attn) + + +class EasyMagpieLTFeedForward(nn.Module): + """Positionwise FFN matching ``transformer_2501.PositionwiseConvFF`` weights. + + The reference uses ``Conv1d(kernel_size=1)`` layers named ``proj.conv`` and + ``o_net.conv`` (no bias). A kernel-1 conv is a plain linear over the channel + dim, so we keep the exact ``Conv1d`` submodule names — the checkpoint loads + 1:1 — and apply them with a single transpose, GELU(tanh) in between. + """ + + def __init__(self, d_model: int, d_ffn: int) -> None: + super().__init__() + # Wrap the Conv1d in a tiny container so the parameter path is + # ``proj.conv.weight`` / ``o_net.conv.weight`` exactly as in the + # reference ``ConvolutionLayer``. 
+ self.proj = _Conv1dWrapper(d_model, d_ffn) + self.o_net = _Conv1dWrapper(d_ffn, d_model) + self.act = nn.GELU(approximate="tanh") + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: [b, t, c] -> conv expects [b, c, t] + h = x.transpose(1, 2) + h = self.act(self.proj(h)) + h = self.o_net(h) + return h.transpose(1, 2) + + +class _Conv1dWrapper(nn.Module): + """Holds a kernel-1 ``Conv1d`` under attribute name ``conv`` (no bias).""" + + def __init__(self, in_ch: int, out_ch: int) -> None: + super().__init__() + self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=1, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.conv(x) + + +class EasyMagpieLTLayer(nn.Module): + """One pre-norm transformer layer (self-attn + FFN), bias-free LayerNorms. + + Residual structure matches ``transformer_2501.TransformerLayer`` with an + all-ones ``x_mask`` (inference): ``x = x + attn(norm_self(x))`` then + ``x = x + ff(norm_pos_ff(x))``. The ``x * x_mask`` multiplications are + identities when nothing is padded, so they are dropped. + """ + + def __init__(self, d_model: int, d_ffn: int, n_heads: int) -> None: + super().__init__() + self.norm_self = nn.LayerNorm(d_model, bias=False) + self.self_attention = EasyMagpieLTSelfAttention(d_model, n_heads) + self.norm_pos_ff = nn.LayerNorm(d_model, bias=False) + self.pos_ff = EasyMagpieLTFeedForward(d_model, d_ffn) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.self_attention(self.norm_self(x)) + x = x + self.pos_ff(self.norm_pos_ff(x)) + return x + + +# NOTE: ``dynamic_arg_dims`` is passed explicitly rather than relying on +# vLLM's annotation-based inference. This file uses +# ``from __future__ import annotations`` (PEP 563), so ``forward``'s +# annotations are stored as strings (``"torch.Tensor"``) and vLLM's +# ``v.annotation in [torch.Tensor, ...]`` check would never match, raising +# "No dynamic dimensions found...". 
# NOTE: ``dynamic_arg_dims`` is spelled out rather than inferred from the
# ``forward`` annotations: this file enables
# ``from __future__ import annotations`` (PEP 563), so annotations are plain
# strings at runtime. ``inputs_embeds`` is
# ``[num_tokens, num_codebooks, hidden]`` and dim 0 (num_tokens) is the one
# marked dynamic.
@support_torch_compile(dynamic_arg_dims={"inputs_embeds": 0})
class EasyMagpieLocalTransformer(nn.Module):
    """Causal transformer stack with learnable positional embeddings.

    Holds ``position_embeddings`` (an ``nn.Embedding`` with
    ``num_stacked_codebooks + 2`` rows), a ``layers`` list of
    ``EasyMagpieLTLayer`` modules and a parameter-free ``norm_out``
    (``nn.Identity``). All sizes come from the ``EasyMagpieOmniArch`` derived
    from the HF config; the FFN width is fixed at ``4 * d_model``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        arch = EasyMagpieOmniArch.from_hf_config(vllm_config.model_config.hf_config)
        width = arch.local_transformer_hidden_dim
        ffn_width = 4 * width
        # Two positions of head-room on top of the stacked-codebook count.
        num_positions = arch.num_stacked_codebooks + 2

        self.position_embeddings = nn.Embedding(num_positions, width)
        blocks = []
        for _ in range(arch.local_transformer_n_layers):
            blocks.append(EasyMagpieLTLayer(width, ffn_width, arch.local_transformer_n_heads))
        self.layers = nn.ModuleList(blocks)
        # No output normalisation: kept as an Identity so it owns no parameters.
        self.norm_out = nn.Identity()

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Add positional embeddings and run every layer.

        Args:
            inputs_embeds: ``[num_tokens, seq_len, d_model]`` input buffer.

        Returns:
            Tensor of the same shape as ``inputs_embeds``.
        """
        pos_ids = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device)
        hidden = inputs_embeds + self.position_embeddings(pos_ids).unsqueeze(0)
        for block in self.layers:
            hidden = block(hidden)
        return self.norm_out(hidden)
class EasyMagpieCodePredictor(nn.Module):
    """Autoregressive intra-frame codebook predictor (the "local transformer").

    Given the backbone's per-frame hidden state, predicts all stacked audio
    codebooks one at a time. Owns the per-codebook input embeddings, the
    in/out projections, the per-codebook output heads and two persistent
    scratch tensors (``_buf_inputs`` / ``_out_codes``) that are reused on
    every call.

    Per frame (``generate_codes``):

    1. Position 0 of the input buffer holds ``in_proj(dec_hidden)``.
    2. For codebook ``k`` in ``0 .. N-1``: run the local transformer over the
       whole buffer, read row ``k`` of the output, project to codebook-``k``
       logits, sample, and (if ``k < N-1``) write
       ``in_proj(audio_emb_k(code))`` into buffer row ``k + 1``.

    The buffer is zeroed once per call and filled incrementally. The
    attention layers are causal, so the still-zero rows ``> k`` do not affect
    row ``k``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        arch = EasyMagpieOmniArch.from_hf_config(vllm_config.model_config.hf_config)
        self.arch = arch
        self.num_codebooks = arch.num_stacked_codebooks
        self.num_tokens_per_codebook = arch.num_all_tokens_per_codebook
        self.audio_embedding_dim = arch.audio_embedding_dim
        self.embedding_dim = arch.embedding_dim
        lt_hidden = arch.local_transformer_hidden_dim

        # One embedding table per stacked codebook (``audio_embeddings.{i}``).
        self.audio_embeddings = nn.ModuleList(
            [nn.Embedding(self.num_tokens_per_codebook, self.audio_embedding_dim) for _ in range(self.num_codebooks)]
        )
        # audio_embedding_dim -> embedding_dim (Identity when equal).
        if self.audio_embedding_dim != self.embedding_dim:
            self.audio_in_projection = nn.Linear(self.audio_embedding_dim, self.embedding_dim)
        else:
            self.audio_in_projection = nn.Identity()

        # embedding_dim -> local-transformer hidden (Identity when equal).
        if lt_hidden != self.embedding_dim:
            self.local_transformer_in_projection = nn.Linear(self.embedding_dim, lt_hidden)
        else:
            self.local_transformer_in_projection = nn.Identity()

        self.local_transformer = EasyMagpieLocalTransformer(
            vllm_config=vllm_config, prefix=f"{prefix}.local_transformer"
        )

        # local-transformer hidden -> audio_embedding_dim (Identity when equal).
        if self.audio_embedding_dim != lt_hidden:
            self.local_transformer_audio_out_projection = nn.Linear(lt_hidden, self.audio_embedding_dim)
        else:
            self.local_transformer_audio_out_projection = nn.Identity()

        # Per-codebook output heads.
        self.local_transformer_out_projections = nn.ModuleList(
            [nn.Linear(self.audio_embedding_dim, self.num_tokens_per_codebook) for _ in range(self.num_codebooks)]
        )

        # Forbidden-token mask; all-False until :meth:`init_forbidden_mask`
        # runs. Non-persistent, so it is never part of the state dict.
        self.register_buffer(
            "forbidden_mask",
            torch.zeros(self.num_tokens_per_codebook, dtype=torch.bool),
            persistent=False,
        )

        # Sampling knobs (plain attributes; may be reassigned by the owner).
        self.temperature: float = 0.7
        self.top_k: int = 80

        # Persistent scratch tensors sized for the largest scheduled batch.
        # They are plain attributes (not registered buffers), so they are
        # created on whatever the default device is at construction time.
        max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
        dtype = vllm_config.model_config.dtype
        self._buf_inputs = torch.zeros(max_num_tokens, self.num_codebooks, lt_hidden, dtype=dtype)
        self._out_codes = torch.zeros(max_num_tokens, self.num_codebooks, dtype=torch.long)

    @torch.no_grad()
    def init_forbidden_mask(self) -> None:
        """Rebuild ``forbidden_mask`` from the architecture ids.

        Every id at or above ``arch.codebook_size`` is forbidden, except
        ``arch.audio_eos_id`` (when it is a valid index), which stays
        reachable so generation can terminate.
        """
        mask = torch.zeros(self.num_tokens_per_codebook, dtype=torch.bool, device=self.forbidden_mask.device)
        mask[self.arch.codebook_size :] = True
        eos = self.arch.audio_eos_id
        if 0 <= eos < self.num_tokens_per_codebook:
            mask[eos] = False
        # In-place copy keeps the registered buffer object (and its storage).
        self.forbidden_mask.copy_(mask)

    def embed_codebook(self, codebook_idx: int, codes: torch.Tensor) -> torch.Tensor:
        """Embed a single codebook's tokens (``[num_tokens] -> [num_tokens, audio_dim]``)."""
        return self.audio_embeddings[codebook_idx](codes)

    def embed_audio_frame(self, codes: torch.Tensor) -> torch.Tensor:
        """Embed a full frame of stacked codes into the backbone embedding space.

        Averages the per-codebook embeddings (sum divided by the number of
        codebooks) and then applies ``audio_in_projection``.

        Args:
            codes: ``[num_tokens, num_codebooks]`` int64 codes.

        Returns:
            ``[num_tokens, embedding_dim]`` float embedding.
        """
        acc = self.audio_embeddings[0](codes[:, 0])
        for c in range(1, self.num_codebooks):
            acc = acc + self.audio_embeddings[c](codes[:, c])
        acc = acc / self.num_codebooks
        return self.audio_in_projection(acc)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        """Run the local transformer over the input buffer."""
        return self.local_transformer(inputs_embeds)

    @torch.no_grad()
    def generate_codes(self, dec_hidden: torch.Tensor) -> torch.Tensor:
        """Autoregressively sample every stacked codebook for each frame.

        Args:
            dec_hidden: ``[num_tokens, hidden]`` backbone hidden state (one row
                per frame being decoded).

        Returns:
            ``[num_tokens, num_codebooks]`` int64 sampled codes. This is a
            view into the persistent ``_out_codes`` buffer, so the next call
            overwrites it; clone it if it must outlive this step.
        """
        num_tokens = dec_hidden.shape[0]
        buf = self._buf_inputs[:num_tokens]
        out = self._out_codes[:num_tokens]
        buf.zero_()

        # Row 0: projected backbone hidden state (the AR "prompt").
        buf[:, 0, :] = self.local_transformer_in_projection(dec_hidden)

        # The mask is passed unconditionally. Testing ``forbidden_mask.any()``
        # in Python would force a device-to-host sync on every frame, while
        # ``masked_fill`` with an all-False mask leaves the logits unchanged.
        forbidden = self.forbidden_mask
        for k in range(self.num_codebooks):
            hidden = self(buf)  # local transformer over the fixed buffer
            row = self.local_transformer_audio_out_projection(hidden[:, k, :])
            logits = self.local_transformer_out_projections[k](row)
            code_k = sample_codebook(
                logits,
                temperature=self.temperature,
                top_k=self.top_k,
                forbidden_mask=forbidden,
            )
            out[:, k] = code_k
            if k + 1 < self.num_codebooks:
                # Feed the sampled code back as the next position's input.
                emb = self.audio_in_projection(self.audio_embeddings[k](code_k))
                buf[:, k + 1, :] = self.local_transformer_in_projection(emb)

        return out
"""vLLM plugin: register ``EasyMagpieTTS`` as a model architecture for vLLM-Omni.

Loaded by vLLM in the parent process and each EngineCore subprocess via the
``vllm.general_plugins`` entry point. The lazy ``<module>:<class>`` target means
the (NeMo-free) model module is only imported when vLLM resolves the
architecture, keeping heavy imports out of the parent process.
"""

# Lazy "module:Class" import target handed to the registries.
_TARGET = "easymagpie_vllm_omni.easymagpie:EasyMagpieTTSForConditionalGeneration"
# Every architecture name the model is registered under.
_ARCHS = ("EasyMagpieTTS", "EasyMagpieTTSForConditionalGeneration")


def register() -> None:
    """Register the model class under all supported arch names.

    Registration targets two registries:

    * ``vllm.ModelRegistry`` — always.
    * ``vllm_omni.model_executor.models.OmniModelRegistry`` — when it can be
      imported. Per the original author's note this is a separate registry
      instance consulted by the vLLM-Omni engine, so registering only in the
      stock registry is not enough there.

    Importing the omni registry is best-effort and never raises: a missing
    ``vllm_omni`` package is logged at debug level, any other import failure
    is logged as a warning with its traceback so that a broken install does
    not silently degrade into "architecture not supported". Architectures a
    registry already lists are left untouched, so repeated calls are safe.
    """
    import logging

    from vllm import ModelRegistry

    log = logging.getLogger(__name__)

    registries = [ModelRegistry]
    try:
        from vllm_omni.model_executor.models import OmniModelRegistry
    except Exception as err:  # best-effort: plugin loading must not crash
        not_installed = isinstance(err, ModuleNotFoundError) and err.name == "vllm_omni"
        if not_installed:
            log.debug("vllm_omni is not installed; registering EasyMagpieTTS with stock vLLM only")
        else:
            log.warning(
                "vllm_omni's OmniModelRegistry could not be imported; "
                "EasyMagpieTTS is registered with stock vLLM only",
                exc_info=True,
            )
    else:
        registries.append(OmniModelRegistry)

    for registry in registries:
        # One lookup per registry; the names in _ARCHS are distinct, so the
        # snapshot stays valid while they are being added.
        already_supported = set(registry.get_supported_archs())
        for arch in _ARCHS:
            if arch not in already_supported:
                registry.register_model(arch, _TARGET)
"> codebooks) and start the engine with `load_format=\\\"dummy\\\"`, so vLLM fills all\n", - "> parameters with random values. The emitted codes are therefore meaningless —\n", - "> this is a *smoke test* of the engine wiring, not a real synthesis. Point the\n", - "> engine at a real converted checkpoint (and drop `load_format`) to get audio.\n", + "> **Dummy weights.** We build a `config.json` sized to the real checkpoint\n", + "> (`2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo`) and start the\n", + "> engine with `load_format=\\\"dummy\\\"`, so vLLM fills all parameters with random\n", + "> values. The emitted codes are therefore meaningless — this is a *smoke test*\n", + "> of the engine wiring, not a real synthesis. Point the engine at a real\n", + "> converted checkpoint (and drop `load_format`) to get audio.\n", "\n", "> **Environment.** Run this inside the bootstrapped `vllm_omni_env` (vLLM +\n", "> vLLM-Omni + compatible torch) with the plugin installed:\n", @@ -82,14 +83,24 @@ "## 1. Build a tiny dummy model directory\n", "\n", "The engine only needs a `config.json` that (a) names the registered arch and\n", - "(b) carries the EasyMagpie + Qwen2 scalars. We deliberately pick **small** dims\n", - "so the dummy backbone and local transformer are fast to instantiate.\n", + "(b) carries the EasyMagpie + Nemotron-H scalars. Here we size everything to match\n", + "the real checkpoint\n", + "`2605_NemotronTTS_V0.2/v2/2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo`\n", + "(hidden 1536, 8 codebooks × 1024, frame-stacking ×2, 3-layer local transformer).\n", + "\n", + "The backbone is a **Nemotron-H** hybrid (Mamba2 + attention + MoE) decoder:\n", + "`EasyMagpieTTSForConditionalGeneration` constructs vLLM's `NemotronHModel` and\n", + "implements the hybrid-Mamba interfaces (`HasInnerState` / `IsHybrid` /\n", + "`SupportsMambaPrefixCaching`), exactly like the EasyMagpie vLLM *sidecar*. 
The\n", + "`nemotron_h_config` fields (`hybrid_override_pattern`, `mamba_*`, `n_routed_experts`,\n", + "…) are copied verbatim from the checkpoint.\n", "\n", "The EasyMagpie-specific scalars (`embedding_dim`, `num_audio_codebooks`,\n", "`codebook_size`, `frame_stacking_factor`, `local_transformer_*`, …) are read by\n", - "`EasyMagpieOmniArch.from_hf_config`; the standard Qwen2 fields (`hidden_size`,\n", - "`num_hidden_layers`, …) configure the reused `Qwen2Model` backbone. Setting\n", - "`phoneme_vocab_size = 0` disables the optional phoneme branch for simplicity.\n", + "`EasyMagpieOmniArch.from_hf_config`. The phoneme branch is **enabled**\n", + "(`phoneme_stacking_factor = 1`, `phoneme_vocab_size = 2051`) to match the\n", + "checkpoint; the model self-predicts phonemes, so no phoneme stream needs to be\n", + "supplied in the prompt.\n", "\n", "With `load_format=\\\"dummy\\\"` (set in the stage config) vLLM never reads weight\n", "files, so a lone `config.json` is enough — no safetensors, no tokenizer." @@ -102,32 +113,70 @@ "metadata": {}, "outputs": [], "source": [ - "# Small, internally-consistent dummy profile.\n", + "# Config matching the real checkpoint:\n", + "# 2605_NemotronTTS_V0.2/v2/2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo\n", + "#\n", + "# The backbone is a Nemotron-H hybrid (Mamba2 + attention + MoE) decoder, wired\n", + "# through vLLM's `NemotronHModel` by `EasyMagpieTTSForConditionalGeneration`. The\n", + "# fields below are ported verbatim from the checkpoint's `model_config.yaml`\n", + "# (the `nemotron_h_config` block + the EasyMagpie scalars). 
With\n", + "# `load_format=\"dummy\"` the weights are random — a realistically-sized smoke test.\n", + "#\n", "# embedding_dim == hidden_size == audio_embedding_dim == local_transformer_hidden_dim\n", - "# keeps every in/out projection an Identity (fewer dummy params, same code path).\n", - "HIDDEN = 256\n", - "NUM_AUDIO_CODEBOOKS = 4\n", - "CODEBOOK_SIZE = 64\n", - "FRAME_STACKING = 2 # -> num_stacked_codebooks = NUM_AUDIO_CODEBOOKS * FRAME_STACKING = 8\n", - "TEXT_VOCAB = 256\n", + "# (all 1536 in the checkpoint) keeps every in/out projection an Identity.\n", + "HIDDEN = 1536 # nemotron_h_config.hidden_size / embedding_dim / audio_embedding_dim\n", + "NUM_AUDIO_CODEBOOKS = 8 # vector_quantizer.num_groups\n", + "CODEBOOK_SIZE = 1024 # prod(vector_quantizer.num_levels_per_group) = 4**5\n", + "FRAME_STACKING = 2 # -> num_stacked_codebooks = NUM_AUDIO_CODEBOOKS * FRAME_STACKING = 16\n", + "PHONEME_STACKING = 1 # phoneme_stacking_factor\n", + "PHONEME_VOCAB = 2051 # IPA-BPE 2048 tokenizer + 3 special tokens\n", + "TEXT_VOCAB = 131072 # nemotron_h_config.vocab_size\n", "\n", "config = {\n", " # Resolved through the `vllm.general_plugins` entry point registered by the\n", " # `easymagpie_vllm_omni` package -> EasyMagpieTTSForConditionalGeneration.\n", " \"architectures\": [\"EasyMagpieTTSForConditionalGeneration\"],\n", - " # Standard Qwen2 backbone fields (consumed by vllm Qwen2Model).\n", - " \"model_type\": \"qwen2\",\n", + " # Nemotron-H backbone fields (consumed by vllm NemotronHModel) — copied\n", + " # verbatim from the checkpoint's `nemotron_h_config` block.\n", + " \"model_type\": \"nemotron_h\",\n", " \"hidden_size\": HIDDEN,\n", - " \"intermediate_size\": 4 * HIDDEN,\n", - " \"num_hidden_layers\": 2,\n", - " \"num_attention_heads\": 4,\n", - " \"num_key_value_heads\": 4,\n", - " \"max_position_embeddings\": 4096,\n", - " \"rms_norm_eps\": 1e-6,\n", - " \"rope_theta\": 1000000.0,\n", + " \"num_hidden_layers\": 31,\n", " \"vocab_size\": TEXT_VOCAB,\n", + " 
\"num_attention_heads\": 12,\n", + " \"num_key_value_heads\": 4,\n", + " \"attention_dropout\": 0.0,\n", + " \"attention_bias\": False,\n", + " \"max_position_embeddings\": 8192,\n", + " \"mamba_num_heads\": 64,\n", + " \"mamba_head_dim\": 24,\n", + " \"ssm_state_size\": 128,\n", + " \"conv_kernel\": 4,\n", + " \"n_groups\": 8,\n", + " \"chunk_size\": 256,\n", + " \"mamba_hidden_act\": \"silu\",\n", + " \"use_conv_bias\": True,\n", + " \"use_bias\": False,\n", + " \"intermediate_size\": 4096,\n", + " \"mlp_hidden_act\": \"silu\",\n", + " \"mlp_bias\": False,\n", + " \"n_routed_experts\": 24,\n", + " \"num_experts_per_tok\": 4,\n", + " \"moe_intermediate_size\": 768,\n", + " \"moe_shared_expert_intermediate_size\": 2048,\n", + " \"n_group\": 1,\n", + " \"topk_group\": 1,\n", + " \"routed_scaling_factor\": 2.5,\n", + " \"norm_topk_prob\": True,\n", + " # 31-char layer pattern: M=Mamba2, *=attention, E=MLP/MoE (len == num_hidden_layers).\n", + " \"hybrid_override_pattern\": \"MEMEM*EMEMEM*EMEMEMEM*EMEMEMEME\",\n", + " \"layer_norm_epsilon\": 1e-5,\n", + " \"residual_in_fp32\": False,\n", " \"tie_word_embeddings\": False,\n", - " \"torch_dtype\": \"float32\",\n", + " # bfloat16, not float32: the Nemotron-H MoE layers run vLLM's fused-MoE\n", + " # Triton kernel, whose block sizes are tuned for 16-bit. In float32 the\n", + " # kernel needs ~2x shared memory and overflows the GPU limit\n", + " # (OutOfResources: shared memory). 
bf16 also matches the real checkpoint.\n", + " \"torch_dtype\": \"bfloat16\",\n", " # EasyMagpie-specific scalars (read by EasyMagpieOmniArch.from_hf_config).\n", " \"text_vocab_size\": TEXT_VOCAB,\n", " \"embedding_dim\": HIDDEN,\n", @@ -135,10 +184,10 @@ " \"num_audio_codebooks\": NUM_AUDIO_CODEBOOKS,\n", " \"codebook_size\": CODEBOOK_SIZE,\n", " \"frame_stacking_factor\": FRAME_STACKING,\n", - " \"phoneme_stacking_factor\": 0, # disable phoneme branch\n", - " \"phoneme_vocab_size\": 0,\n", - " \"local_transformer_n_layers\": 2,\n", - " \"local_transformer_n_heads\": 4,\n", + " \"phoneme_stacking_factor\": PHONEME_STACKING,\n", + " \"phoneme_vocab_size\": PHONEME_VOCAB,\n", + " \"local_transformer_n_layers\": 3,\n", + " \"local_transformer_n_heads\": 12,\n", " \"local_transformer_hidden_dim\": HIDDEN,\n", "}\n", "\n", @@ -220,7 +269,9 @@ " \"distributed_executor_backend\": \"uni\",\n", " \"max_num_batched_tokens\": MAX_NUM_BATCHED_TOKENS,\n", " \"max_model_len\": MAX_MODEL_LEN,\n", - " \"dtype\": \"float32\",\n", + " # bf16 (not fp32): the Nemotron-H fused-MoE Triton kernel's block\n", + " # sizes are tuned for 16-bit and overflow shared memory in fp32.\n", + " \"dtype\": \"bfloat16\",\n", " \"attention_backend\": \"TRITON_ATTN\",\n", " # --- dummy-weights smoke-test knobs ---\n", " \"load_format\": \"dummy\",\n", diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py index 8a37af8454ea..074c48463276 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/__init__.py @@ -14,14 +14,15 @@ """EasyMagpieTTS model definition for vLLM-Omni. 
This package provides an inference-only re-implementation of EasyMagpieTTS -(decoder-only, Qwen2 backbone + autoregressive local transformer over the -stacked audio codebooks) that plugs into the vLLM-Omni serving stack via the -standard ``preprocess`` / ``postprocess`` / ``make_omni_output`` hooks. +(decoder-only, Nemotron-H hybrid-Mamba backbone + autoregressive local +transformer over the stacked audio codebooks) that plugs into the vLLM-Omni +serving stack via the standard ``preprocess`` / ``postprocess`` / +``make_omni_output`` hooks. The companion ``vllm_plugin_easymagpie_omni`` package registers the model with vLLM's ``ModelRegistry`` through the ``vllm.general_plugins`` entry point. """ -from easymagpie_vllm_omni.config import EASYMAGPIE_QWEN, EasyMagpieOmniArch +from easymagpie_vllm_omni.config import EASYMAGPIE_SMALLMAMBA, EasyMagpieOmniArch -__all__ = ["EASYMAGPIE_QWEN", "EasyMagpieOmniArch"] +__all__ = ["EASYMAGPIE_SMALLMAMBA", "EasyMagpieOmniArch"] diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/backbone_patches.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/backbone_patches.py new file mode 100644 index 000000000000..efe8421f7af2 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/backbone_patches.py @@ -0,0 +1,64 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Backbone-side patches applied at model ``__init__``. 
class _SiluActivation(nn.Module):
    """Module form of ``F.silu`` so it can be assigned as a submodule."""

    def forward(self, x):
        """Return ``silu(x)`` elementwise."""
        return F.silu(x)


def patch_silu_shared_experts(backbone) -> int:
    """Swap ``shared_experts.act_fn`` for SiLU on every ``NemotronHMoE`` mixer.

    Walks ``backbone.layers``; for each layer whose ``mixer`` attribute has the
    class name ``NemotronHMoE`` and exposes a non-``None`` ``shared_experts``,
    the ``act_fn`` attribute is replaced with a fresh ``_SiluActivation``.
    Only that attribute is touched; the modules' ``forward`` methods are left
    as they are.

    Rationale (from the original patch notes, not verifiable from this file):
    the upstream shared-experts MLP uses a different activation than the SiLU
    the SmallMamba checkpoint was trained with.

    Args:
        backbone: Object exposing an iterable ``layers`` attribute
            (the ``NemotronHModel`` instance in practice).

    Returns:
        Number of layers patched.
    """
    mixers = (getattr(layer, "mixer", None) for layer in backbone.layers)
    num_patched = 0
    for mixer in mixers:
        is_moe = mixer is not None and mixer.__class__.__name__ == "NemotronHMoE"
        if is_moe:
            shared = getattr(mixer, "shared_experts", None)
            if shared is not None:
                shared.act_fn = _SiluActivation()
                num_patched += 1
    logger.info("SiLU shared_experts fix installed on %d layers", num_patched)
    return num_patched
""" defaults = cls() kwargs: dict[str, Any] = {} @@ -154,5 +154,5 @@ def from_hf_config(cls, hf_config: Any) -> "EasyMagpieOmniArch": return cls(**merged) -# Reference profile: Qwen2.5-1.5B backbone EasyMagpieTTS checkpoint. -EASYMAGPIE_QWEN = EasyMagpieOmniArch() +# Reference profile: Nemotron-H SmallMamba EasyMagpieTTS checkpoint. +EASYMAGPIE_SMALLMAMBA = EasyMagpieOmniArch() diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index bfb76ccb9303..e188b9387a7b 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -14,20 +14,27 @@ """Inference-only EasyMagpieTTS model for vLLM-Omni. EasyMagpieTTS is a decoder-only streaming TTS model: a text-LM backbone (the -reference checkpoint uses Qwen2.5-1.5B) consumes a per-frame additive input -embedding (text + phoneme + audio) and emits a per-frame hidden state, from -which a small autoregressive *local transformer* samples all ``C * S`` stacked -audio codebooks for that frame (see :mod:`easymagpie_vllm_omni.local_transformer`). +SmallMamba checkpoint uses a Nemotron-H hybrid Mamba2 + attention + MoE decoder) +consumes a per-frame additive input embedding (text + phoneme + audio) and +emits a per-frame hidden state, from which a small autoregressive *local +transformer* samples all ``C * S`` stacked audio codebooks for that frame +(see :mod:`easymagpie_vllm_omni.local_transformer`). 
This module wires that architecture into vLLM-Omni's ``preprocess`` / ``forward`` / ``compute_logits`` / ``make_omni_output`` / ``postprocess`` contract, following the same conventions as the upstream qwen3-tts and eartts vLLM-Omni model definitions: -* **Backbone** — vLLM's :class:`~vllm.model_executor.models.qwen2.Qwen2Model`, - reused wholesale (KV cache + paged attention) the same way the EasyMagpie - vLLM *sidecar* reuses ``NemotronHModel``. Every step feeds the backbone via - ``inputs_embeds``; its own ``embed_tokens`` table is never consumed. +* **Backbone** — vLLM's + :class:`~vllm.model_executor.models.nemotron_h.NemotronHModel`, reused + wholesale (hybrid Mamba2 state + KV cache + paged attention) exactly like the + EasyMagpie vLLM *sidecar*. Every step feeds the backbone via ``inputs_embeds``; + its own ``embed_tokens`` table is never consumed. Because the backbone is a + hybrid-Mamba model, the class implements vLLM's + :class:`HasInnerState` / :class:`IsHybrid` / :class:`SupportsMambaPrefixCaching` + contracts (mamba-state shape/dtype/copy helpers are delegated to + :class:`NemotronHForCausalLM`), and the SmallMamba SiLU shared-experts fix is + applied at construction (see :mod:`easymagpie_vllm_omni.backbone_patches`). * **Local transformer** — :class:`EasyMagpieCodePredictor`, a from-scratch, CUDA-graph-capturable re-implementation that runs as a single compiled graph. 
* **compute_logits** — returns trivial logits (à la eartts) so vLLM's sampler @@ -59,16 +66,21 @@ import torch from torch import nn from vllm.compilation.backends import set_model_tag -from vllm.compilation.decorators import ignore_torch_compile, support_torch_compile from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger -from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.models.interfaces import ( + HasInnerState, + IsHybrid, + SupportsMambaPrefixCaching, +) +from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM, NemotronHModel from vllm.model_executor.models.utils import maybe_prefix from vllm.sequence import IntermediateTensors from vllm_omni.model_executor.models.output_templates import OmniOutput +from easymagpie_vllm_omni.backbone_patches import patch_silu_shared_experts from easymagpie_vllm_omni.config import EasyMagpieOmniArch from easymagpie_vllm_omni.local_transformer import EasyMagpieCodePredictor @@ -81,21 +93,18 @@ _DUMMY_TOKEN_ID = 0 -# ``dynamic_arg_dims`` is passed explicitly: this file uses -# ``from __future__ import annotations`` (PEP 563), so ``forward``'s annotations -# are strings and vLLM's annotation-based inference would fail with -# "No dynamic dimensions found...". These mirror vLLM's default inference -# (dim 0 for every tensor / IntermediateTensors argument). -@ignore_torch_compile -@support_torch_compile( - dynamic_arg_dims={ - "input_ids": 0, - "positions": 0, - "intermediate_tensors": 0, - "inputs_embeds": 0, - } -) -class EasyMagpieTTSForConditionalGeneration(nn.Module): +# NOTE: unlike the Qwen2 backbone variant, this class is *not* wrapped in +# ``@support_torch_compile``. 
The Nemotron-H backbone is a hybrid-Mamba model +# that manages its own ``torch.compile`` / CUDA-graph capture internally (as +# does :class:`EasyMagpieCodePredictor`), so the outer ``forward`` runs eagerly +# and dispatches into the two self-compiled subgraphs — matching the EasyMagpie +# vLLM sidecar (``EasyMagpieSmallMamba``). +class EasyMagpieTTSForConditionalGeneration( + nn.Module, + HasInnerState, + IsHybrid, + SupportsMambaPrefixCaching, +): """EasyMagpieTTS talker for vLLM-Omni. See the module docstring for the per-step flow and the per-request I/O @@ -104,6 +113,12 @@ class EasyMagpieTTSForConditionalGeneration(nn.Module): ``OmniGPUModelRunner``. """ + # Hybrid-Mamba bookkeeping (delegated to vLLM's NemotronH causal-LM, exactly + # like the EasyMagpie sidecar). vLLM expects these as class attributes. + get_mamba_state_dtype_from_config = NemotronHForCausalLM.get_mamba_state_dtype_from_config + get_mamba_state_shape_from_config = NemotronHForCausalLM.get_mamba_state_shape_from_config + get_mamba_state_copy_func = NemotronHForCausalLM.get_mamba_state_copy_func + # Omni runner hooks. has_preprocess: bool = True has_postprocess: bool = True @@ -129,11 +144,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.embedding_dim = arch.embedding_dim self.num_codebooks = arch.num_stacked_codebooks - # ── Backbone (reused vLLM text LM; fed via inputs_embeds) ─────── - self.backbone = Qwen2Model( + # ── Backbone (reused vLLM Nemotron-H LM; fed via inputs_embeds) ── + self.backbone = NemotronHModel( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone"), ) + # SmallMamba was trained with mlp_hidden_act=silu but vLLM's NemotronHMLP + # hard-codes ReLU² in shared_experts. Restore SiLU (no-op when the + # backbone has no MoE layers). 
+ patch_silu_shared_experts(self.backbone) # ── Local transformer (its own compile group / CUDA graph) ────── with set_model_tag("local_transformer"): @@ -202,6 +221,43 @@ def _embed_phoneme(self, phoneme_tokens: torch.Tensor) -> torch.Tensor: # Decode-token dispatch (which positions need the local transformer) # ------------------------------------------------------------------ + @staticmethod + def _select_query_layout(attn_metadata): + """Return ``(max_query_len, query_start_loc)`` from heterogeneous metadata. + + The Nemotron-H backbone is hybrid, so ``attn_metadata`` is a per-layer + dict mixing two metadata types: + + * **attention** layers carry standard metadata that exposes the + batch-level ``max_query_len`` + ``query_start_loc`` (e.g. + ``TritonAttentionMetadata``); + * **Mamba2** layers carry ``Mamba2AttentionMetadata``, which has *no* + ``max_query_len`` and splits the query layout into ``query_start_loc_p`` + / ``query_start_loc_d`` instead. + + Both are built from the same batch query layout, so we prefer any + attention-layer metadata. As a fallback for a (hypothetical) attention-free + backbone, we infer a decode-only batch from the Mamba2 ``num_prefills`` + counter. Returns ``(None, None)`` when the layout can't be determined. + """ + metas = list(attn_metadata.values()) if isinstance(attn_metadata, dict) else [attn_metadata] + + # Preferred: an attention layer exposes the unified query layout. + for m in metas: + mql = getattr(m, "max_query_len", None) + qsl = getattr(m, "query_start_loc", None) + if mql is not None and qsl is not None: + return int(mql), qsl + + # Fallback: Mamba2-only backbone. We can at least detect a decode-only + # batch (every request contributes a single token) from the counters. 
+ for m in metas: + if hasattr(m, "num_prefills") and hasattr(m, "num_decodes"): + if int(getattr(m, "num_prefills", 0)) == 0: + return 1, None # decode-only -> caller runs the LT everywhere + break + return None, None + def _get_decode_idxs(self): """Return ``(decode_token_indices, num_requests)`` for code-predictor dispatch. @@ -220,15 +276,12 @@ def _get_decode_idxs(self): if attn_metadata is None: return None, 0 - if isinstance(attn_metadata, dict): - any_layer_meta = next(iter(attn_metadata.values())) - else: - any_layer_meta = attn_metadata + max_query_len, start_loc = self._select_query_layout(attn_metadata) - if any_layer_meta.max_query_len == 1: + # Decode-only batch (or layout unavailable) -> run the LT on every token. + if max_query_len is None or max_query_len == 1 or start_loc is None: return None, 0 - start_loc = any_layer_meta.query_start_loc tokens_per_req = start_loc[1:] - start_loc[:-1] is_decode = tokens_per_req == 1 decode_token_indices = start_loc[:-1][is_decode] @@ -284,9 +337,9 @@ def forward( self._assemble_decode_embeddings(combined, valid) hidden_states = self.backbone( - input_ids, - positions, - intermediate_tensors, + input_ids=input_ids, + positions=positions, + intermediate_tensors=intermediate_tensors, inputs_embeds=combined, ) @@ -331,7 +384,10 @@ def _assemble_decode_embeddings(self, combined: torch.Tensor, idx) -> None: @torch.no_grad() def _predict_phonemes(self, hidden_states: torch.Tensor, idx) -> None: """Argmax the phoneme head and stash the prediction for the next step.""" - logits = self.phoneme_final_proj(hidden_states[idx].float()) + # Run in the model dtype (don't force fp32): ``phoneme_final_proj`` weights + # follow ``model_config.dtype`` (e.g. bf16), and argmax is dtype-insensitive, + # so an fp32 upcast here would mismatch the weight dtype in ``F.linear``. 
+ logits = self.phoneme_final_proj(hidden_states[idx]) s = self.arch.phoneme_stacking_factor logits = logits.view(-1, s, self.arch.phoneme_vocab_size) self._dec_phoneme_tokens[idx] = logits.argmax(dim=-1).long() @@ -517,7 +573,8 @@ def postprocess(self, hidden_states: torch.Tensor, multimodal_outputs: Optional[ # Checkpoint prefixes (reference EasyMagpieTTS state dict) → in-model paths. # ``decoder.*`` is fed to the vLLM backbone loader separately (it understands - # HF Qwen2 naming + qkv packing). The TTS submodules are copied manually. + # HF Nemotron-H naming + Mamba/MoE packing). The TTS submodules are copied + # manually. _TTS_PREFIX_MAP = { "local_transformer.": "code_predictor.local_transformer.", "local_transformer_in_projection.": "code_predictor.local_transformer_in_projection.", @@ -538,14 +595,15 @@ def _remap_tts_key(self, name: str) -> Optional[str]: return None def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: - """Load backbone (Qwen2) + TTS submodule weights from a converted checkpoint. + """Load backbone (Nemotron-H) + TTS submodule weights from a converted checkpoint. The converted checkpoint is expected to use the reference EasyMagpieTTS - key layout: the backbone under ``decoder.*`` (HF Qwen2 names) and the - TTS submodules at top level (``audio_embeddings.*``, ``local_transformer.*``, - ``phoneme_*``, ``text_embedding.*``, projection heads). Backbone weights - are routed to :meth:`Qwen2Model.load_weights` (which packs qkv / gate-up - and handles HF naming); TTS weights are copied directly by name. + key layout: the backbone under ``decoder.*`` (HF Nemotron-H names) and + the TTS submodules at top level (``audio_embeddings.*``, + ``local_transformer.*``, ``phoneme_*``, ``text_embedding.*``, projection + heads). Backbone weights are routed to :meth:`NemotronHModel.load_weights` + (which handles HF naming + Mamba/MoE packing); TTS weights are copied + directly by name. 
""" own_params = dict(self.named_parameters()) loaded: set[str] = set() @@ -578,9 +636,5 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: # Derived runtime state. self.code_predictor.init_forbidden_mask() - # The backbone's vestigial embed_tokens table is never consumed - # (everything goes through inputs_embeds); don't flag it as missing. - loaded.add("backbone.embed_tokens.weight") - logger.info("Loaded %d weights for EasyMagpieTTSForConditionalGeneration", len(loaded)) return loaded diff --git a/examples/tts/easymagpie_vllm_omni/pyproject.toml b/examples/tts/easymagpie_vllm_omni/pyproject.toml index 5cd41cead748..c6d4d8942c93 100644 --- a/examples/tts/easymagpie_vllm_omni/pyproject.toml +++ b/examples/tts/easymagpie_vllm_omni/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "easymagpie-vllm-omni" version = "0.1.0" -description = "vLLM-Omni model definition for EasyMagpieTTS (Qwen2 backbone + AR local transformer)" +description = "vLLM-Omni model definition for EasyMagpieTTS (Nemotron-H hybrid-Mamba backbone + AR local transformer)" requires-python = ">=3.10" dependencies = [ # The heavy runtime deps (vllm, vllm-omni, torch) are provided by the From bb8b4276f347b0776ef1815d7d3bc4ef80d9c85f Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 2 Jun 2026 18:20:44 +0200 Subject: [PATCH 04/15] examples/tts/easymagpie_vllm_omni: make sure model runs with cuda graphs Signed-off-by: Viacheslav Klimkov --- .../easymagpie_vllm_omni/easymagpie_inference_demo.ipynb | 2 +- .../easymagpie_vllm_omni/local_transformer.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb index 88af6037e056..50e53a1c1e03 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb +++ 
b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb @@ -260,7 +260,7 @@ " \"model_arch\": \"EasyMagpieTTSForConditionalGeneration\",\n", " \"worker_type\": \"ar\",\n", " \"scheduler_cls\": \"vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler\",\n", - " \"enforce_eager\": True, # dummy run: skip CUDA-graph capture for a faster start\n", + " #\"enforce_eager\": True, # dummy run: skip CUDA-graph capture for a faster start\n", " \"trust_remote_code\": True,\n", " \"async_scheduling\": True,\n", " \"enable_prefix_caching\": False,\n", diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py index a72ee6ecd52d..b48715604530 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py @@ -379,7 +379,11 @@ def generate_codes(self, dec_hidden: torch.Tensor) -> torch.Tensor: # Row 0: projected backbone hidden state (the AR "prompt"). buf[:, 0, :] = self.local_transformer_in_projection(dec_hidden) - forbidden = self.forbidden_mask if self.forbidden_mask.any() else None + # Always pass the mask unconditionally. An all-False mask makes + # ``masked_fill`` a no-op, so there's no need to guard with + # ``forbidden_mask.any()`` — and that guard is a data-dependent + # host sync that is illegal during CUDA-graph capture. 
+ forbidden = self.forbidden_mask for k in range(self.num_codebooks): hidden = self(buf) # compiled transformer over the fixed buffer row = self.local_transformer_audio_out_projection(hidden[:, k, :]) From 999256918b1040e054014c5c591e8cfffab2182b Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 2 Jun 2026 18:20:44 +0200 Subject: [PATCH 05/15] examples/tts/easymagpie_vllm_omni: extend preprocess to take speaker embeddings and prepare prefill embeddings Signed-off-by: Viacheslav Klimkov --- .../easymagpie_inference_demo.ipynb | 106 +++++++--- .../easymagpie_vllm_omni/config.py | 7 + .../easymagpie_vllm_omni/easymagpie.py | 194 ++++++++++++++++-- 3 files changed, 261 insertions(+), 46 deletions(-) diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb index 50e53a1c1e03..dd7322cef37c 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb @@ -16,10 +16,12 @@ "It follows the same `AsyncOmni` single-stage pattern as the reference\n", "`qwen3-tts` and `eartts` demos:\n", "\n", - "* **prefill** — the caller supplies a precomputed context embedding via\n", - " `additional_information.prompt_embeds` of shape `(T_ctx, embedding_dim)`, with\n", - " `prompt_token_ids = [0] * T_ctx` (exactly like qwen3-tts `talker_prompt_embeds`\n", - " / eartts `speaker_latent`).\n", + "* **prefill** — the caller supplies the speaker-encoded context-audio embedding\n", + " via `additional_information.speaker_embedding` `(T_audio, embedding_dim)` plus a\n", + " plain `context_text` string; the model assembles the full prefill context\n", + " (`[task_embedding? | speaker_embedding | context_text_embedded]`) and tokenizes\n", + " `context_text` itself. 
`prompt_token_ids = [0] * prompt_len`, sized with\n", + " `EasyMagpieTTSForConditionalGeneration.estimate_prompt_len(...)`.\n", "* **decode** — each step consumes one subword id from the streaming\n", " `additional_information.text_tokens` list; the local transformer samples all\n", " `C * S` stacked audio codebooks for the frame.\n", @@ -103,7 +105,11 @@ "supplied in the prompt.\n", "\n", "With `load_format=\\\"dummy\\\"` (set in the stage config) vLLM never reads weight\n", - "files, so a lone `config.json` is enough — no safetensors, no tokenizer." + "files, so no safetensors are needed. We do save the checkpoint's\n", + "text-conditioning tokenizer (`TEXT_TOKENIZER`, the Nemotron-H tokenizer that\n", + "matches `TEXT_VOCAB`) into the model dir, since the model tokenizes the\n", + "per-request `context_text` in-engine via\n", + "`AutoTokenizer.from_pretrained(model_path)`." ] }, { @@ -131,6 +137,10 @@ "PHONEME_STACKING = 1 # phoneme_stacking_factor\n", "PHONEME_VOCAB = 2051 # IPA-BPE 2048 tokenizer + 3 special tokens\n", "TEXT_VOCAB = 131072 # nemotron_h_config.vocab_size\n", + "# Text-conditioning tokenizer that matches the checkpoint (SmallMamba uses the\n", + "# Nemotron-H tokenizer, vocab 131072 == TEXT_VOCAB). Point this at the converted\n", + "# checkpoint dir / the checkpoint's tokenizer when running a real model.\n", + "TEXT_TOKENIZER = \"nvidia/Nemotron-H-8B-Base-8K\"\n", "\n", "config = {\n", " # Resolved through the `vllm.general_plugins` entry point registered by the\n", @@ -193,6 +203,14 @@ "\n", "MODEL_DIR = Path(tempfile.mkdtemp(prefix=\"easymagpie_dummy_\"))\n", "(MODEL_DIR / \"config.json\").write_text(json.dumps(config, indent=2))\n", + "\n", + "# The model tokenizes the per-request `context_text` string in-engine via\n", + "# `AutoTokenizer.from_pretrained(model_path)` (qwen3-tts style), so the model dir\n", + "# must ship the checkpoint's text-conditioning tokenizer. 
We save the matching\n", + "# Nemotron-H tokenizer (TEXT_TOKENIZER) into MODEL_DIR.\n", + "from transformers import AutoTokenizer\n", + "\n", + "AutoTokenizer.from_pretrained(TEXT_TOKENIZER, trust_remote_code=True).save_pretrained(MODEL_DIR)\n", "print(f\"Dummy model dir: {MODEL_DIR}\")\n", "\n", "# Sanity-check the arch the model will derive from this config.\n", @@ -235,8 +253,10 @@ "metadata": {}, "outputs": [], "source": [ - "T_CTX = 16 # prefill context-embedding length (prompt_token_ids = [0] * T_CTX)\n", "DECODE_STEPS = 32 # number of audio frames to decode\n", + "# Prefill length is derived at prompt-build time from the speaker embedding +\n", + "# tokenized context_text (see the prompt cell); these just need to be large\n", + "# enough to cover prefill + decode.\n", "MAX_MODEL_LEN = 512\n", "MAX_NUM_BATCHED_TOKENS = 512\n", "\n", @@ -311,18 +331,26 @@ "source": [ "## 3. Build the prompt\n", "\n", - "Two pieces of per-request input, passed through `additional_information`:\n", + "Per-request input, passed through `additional_information`:\n", "\n", - "* **`prompt_embeds`** `(T_ctx, embedding_dim)` — the precomputed context\n", - " embedding consumed during prefill. In a real run this is the speaker-encoded\n", - " context audio + context text produced by the caller; here we use random noise.\n", - " `prompt_token_ids = [0] * T_ctx` are placeholders (the model feeds the backbone\n", - " via `inputs_embeds`, never via these ids).\n", + "* **`speaker_embedding`** `(T_audio, embedding_dim)` — the speaker-encoded\n", + " context-audio embedding (the audio branch of `prepare_context_tensors`),\n", + " loaded here from `eng_speaker_emb.pt` (as written by\n", + " `easy_magpietts_extract_speaker_encoding.py`). The model assembles the full\n", + " prefill context itself as `[task_embedding? | speaker_embedding |\n", + " context_text_embedded]`.\n", + "* **`context_text`** — a plain conditioning string, here `\"[EN]\"`. 
The model\n", + " tokenizes it in-engine and embeds it through the baked `text_embedding` table.\n", "* **`text_tokens`** `list[int]` — the streaming subword stream; decode step `k`\n", " consumes `text_tokens[k]`. We provide one id per decode step.\n", "\n", - "(If the checkpoint had a phoneme branch you'd also stream `phoneme_tokens`; it's\n", - "disabled here via `phoneme_vocab_size = 0`.)" + "`prompt_token_ids = [0] * prompt_len` are placeholders (the model feeds the\n", + "backbone via `inputs_embeds`, never these ids). `prompt_len` must equal the\n", + "assembled context length, so we size it with the model's\n", + "`estimate_prompt_len(...)` — the length-only mirror of the in-engine prefill\n", + "assembly (à la qwen3-tts's `estimate_prompt_len_from_additional_information`).\n", + "\n", + "(If the checkpoint had a phoneme branch you'd also stream `phoneme_tokens`.)" ] }, { @@ -334,33 +362,61 @@ "source": [ "torch.manual_seed(0)\n", "\n", - "# Precomputed context embedding (random stand-in for the speaker/text encoder).\n", - "prompt_embeds = torch.randn(T_CTX, arch.embedding_dim, dtype=torch.float32)\n", + "from transformers import AutoTokenizer\n", + "\n", + "from easymagpie_vllm_omni.easymagpie import EasyMagpieTTSForConditionalGeneration\n", + "\n", + "# Speaker-encoded context audio (audio branch of prepare_context_tensors),\n", + "# produced by easy_magpietts_extract_speaker_encoding.py.\n", + "SPEAKER_EMB_FILE = \"eng_speaker_emb.pt\"\n", + "_loaded = torch.load(SPEAKER_EMB_FILE, map_location=\"cpu\")\n", + "speaker_embedding = _loaded[\"speaker_encoding\"] if isinstance(_loaded, dict) else _loaded\n", + "speaker_embedding = speaker_embedding.to(torch.float32)\n", + "\n", + "# Plain conditioning string; the model tokenizes + embeds it in-engine.\n", + "CONTEXT_TEXT = \"[EN]\"\n", + "\n", + "# Same tokenizer the engine loads from MODEL_DIR — used to size the prefill\n", + "# placeholders so prompt_token_ids length matches the assembled context.\n", + 
"tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)\n", + "prompt_len = EasyMagpieTTSForConditionalGeneration.estimate_prompt_len(\n", + " speaker_embedding,\n", + " tokenize=lambda t: tokenizer.encode(t),\n", + " context_text=CONTEXT_TEXT,\n", + " has_task_embedding=arch.num_task_embeddings > 0,\n", + ")\n", "\n", "# Streaming subword ids: one per decode step (step k consumes text_tokens[k]).\n", "text_tokens = torch.randint(0, TEXT_VOCAB, (DECODE_STEPS,)).tolist()\n", "\n", "additional_information = {\n", - " \"prompt_embeds\": prompt_embeds, # (T_ctx, embedding_dim) tensor\n", - " \"text_tokens\": text_tokens, # list[int], grows by one per step\n", + " \"speaker_embedding\": speaker_embedding, # (T_audio, embedding_dim) tensor\n", + " \"context_text\": CONTEXT_TEXT, # plain string, tokenized in-model\n", + " \"text_tokens\": text_tokens, # list[int], grows by one per step\n", "}\n", "\n", "prompt = {\n", - " \"prompt_token_ids\": [0] * T_CTX, # prefill placeholders\n", + " \"prompt_token_ids\": [0] * prompt_len, # prefill placeholders\n", " \"additional_information\": additional_information,\n", "}\n", "\n", + "assert prompt_len + DECODE_STEPS <= MAX_MODEL_LEN, (\n", + " f\"prompt_len ({prompt_len}) + decode steps ({DECODE_STEPS}) exceeds \"\n", + " f\"MAX_MODEL_LEN ({MAX_MODEL_LEN}); raise MAX_MODEL_LEN / MAX_NUM_BATCHED_TOKENS.\"\n", + ")\n", + "\n", + "print(f\"speaker_embedding : {tuple(speaker_embedding.shape)}\")\n", + "print(f\"context_text : {CONTEXT_TEXT!r} -> {tokenizer.encode(CONTEXT_TEXT)}\")\n", + "print(f\"prompt_len (placeholders) : {prompt_len}\")\n", + "print(f\"decode steps (max_tokens) : {DECODE_STEPS}\")\n", + "print(f\"text_tokens[:8] : {text_tokens[:8]}\")\n", + "\n", "sampling_params = SamplingParams(\n", " temperature=0.0,\n", " max_tokens=DECODE_STEPS,\n", " detokenize=False,\n", " ignore_eos=True, # dummy logits never emit a meaningful EOS -> run the full budget\n", - ")\n", - "\n", - "print(f\"T_ctx (prefill 
placeholders) : {T_CTX}\")\n", - "print(f\"prompt_embeds : {tuple(prompt_embeds.shape)}\")\n", - "print(f\"decode steps (max_tokens) : {DECODE_STEPS}\")\n", - "print(f\"text_tokens[:8] : {text_tokens[:8]}\")" + ")" ] }, { diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py index cdba51613a8c..1b086ec3e562 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py @@ -72,6 +72,12 @@ class EasyMagpieOmniArch: phoneme_stacking_factor: int = 1 phoneme_vocab_size: int = 2051 + # Number of multi-mode task ("service token") embeddings. The reference model + # prepends a single learned per-mode embedding to the prefill context when + # trained with >1 mode (``cfg.training_modes``); 0 disables it (single-mode + # checkpoints have no ``task_embedding`` table). + num_task_embeddings: int = 0 + local_transformer_n_layers: int = 3 local_transformer_n_heads: int = 12 local_transformer_hidden_dim: int = 1536 @@ -136,6 +142,7 @@ def from_hf_config(cls, hf_config: Any) -> "EasyMagpieOmniArch": "frame_stacking_factor", "phoneme_stacking_factor", "phoneme_vocab_size", + "num_task_embeddings", "local_transformer_n_layers", "local_transformer_n_heads", "local_transformer_hidden_dim", diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index e188b9387a7b..cb5e8bd346f4 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -47,10 +47,25 @@ Per-request I/O (via ``additional_information``): -* ``prompt_embeds`` (prefill only) — ``(T_ctx, embedding_dim)`` precomputed - context/prompt embedding (speaker-encoded context audio + context text) - produced by the caller, exactly like qwen3-tts 
``talker_prompt_embeds`` / - eartts ``speaker_latent``. The user passes ``prompt_token_ids = [0] * T_ctx``. +* ``speaker_embedding`` (prefill only) — ``(T_audio, embedding_dim)`` speaker- + encoded context-audio embedding (the audio branch of the reference + ``prepare_context_tensors``), e.g. the tensor saved by + ``easy_magpietts_extract_speaker_encoding.py``. ``preprocess`` assembles the + full prefill context embedding itself as + ``[task_embedding | speaker_embedding | context_text_embedded]`` — the same + layout the reference model builds — so the caller only does the speaker-encoder + math and passes plain context text (the model tokenizes + embeds it and + prepends the per-mode service token). +* ``context_text`` (prefill only, optional) — plain conditioning string (e.g. + ``"[EN]"``); tokenized in-model with the checkpoint's text tokenizer and + embedded through the baked per-subword ``text_embedding`` table. Defaults to + ``"[EN]"`` (``_DEFAULT_CONTEXT_TEXT``) when omitted. +* ``task_mode_id`` (prefill only, optional) — int selecting the per-mode task + ("service token") embedding row; defaults to ``0``. Ignored for single-mode + checkpoints (no ``task_embedding`` table). + + The caller passes ``prompt_token_ids = [0] * T_ctx``, where ``T_ctx`` is the + assembled context length (``[task?] + T_audio + len(tokenize(context_text))``). * ``text_tokens`` — Python ``list[int]`` of subword ids that grows by one per decode step; step ``k`` consumes ``text_tokens[k]`` (embedded through the precomputed per-subword table). @@ -60,7 +75,7 @@ from __future__ import annotations import bisect -from collections.abc import Iterable +from collections.abc import Callable, Iterable from typing import Any, Optional import torch @@ -92,6 +107,9 @@ # argmax-at-0 dummy logits, so this only needs to be a valid id. 
_DUMMY_TOKEN_ID = 0 +# Context text used when the request omits ``context_text`` +_DEFAULT_CONTEXT_TEXT = "[EN]" + # NOTE: unlike the Qwen2 backbone variant, this class is *not* wrapped in # ``@support_torch_compile``. The Nemotron-H backbone is a hybrid-Mamba model @@ -171,6 +189,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: text_vocab_size = int(getattr(hf_config, "text_vocab_size", getattr(hf_config, "vocab_size", 0))) self.text_embedding = nn.Embedding(text_vocab_size, self.embedding_dim) + # Task ("service token") embedding — a single learned per-mode row the + # reference model prepends to the prefill context when trained with >1 + # mode. Built only when the checkpoint carries one; otherwise ``None``. + self.num_task_embeddings = int(arch.num_task_embeddings) + if self.num_task_embeddings > 0: + self.task_embedding = nn.Embedding(self.num_task_embeddings, self.embedding_dim) + else: + self.task_embedding = None + + # Context-text tokenizer, loaded lazily from the model directory (same + # ``AutoTokenizer.from_pretrained(model_path)`` pattern as qwen3-tts). It + # turns the per-request ``context_text`` string (e.g. ``"[EN]"``) into the + # subword ids that the baked ``text_embedding`` table consumes — so the + # caller passes plain text, never pre-tokenized ids. + self._text_tokenizer: Any = None + # Phoneme channel (optional — only built when the checkpoint has one). 
self.has_phoneme = arch.phoneme_vocab_size > 0 and arch.phoneme_stacking_factor > 0 if self.has_phoneme: @@ -432,10 +466,13 @@ def make_omni_output(self, model_outputs, **_: Any) -> OmniOutput: # ------------------------------------------------------------------ @staticmethod - def _unwrap(value: Any) -> Any: + def _first_str(value: Any) -> str: + """Return the first element of a list-wrapped scalar, or the scalar itself, as a string.""" if isinstance(value, list): - return value[0] if value else None - return value + return str(value[0]) if value else "" + if value is None: + return "" + return str(value) def preprocess( self, @@ -448,9 +485,11 @@ def preprocess( ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: """Build per-request ``(input_ids, inputs_embeds)`` for this step. - Prefill (``span_len > 1``): slice the precomputed ``prompt_embeds`` - context embedding into this chunk and return it; ``input_ids`` are - placeholders. Decode (``span_len == 1``): write the per-token decode + Prefill (``span_len > 1``): assemble the full context embedding + (``[task_embedding | speaker_embedding | context_text_embedded]`` from + the per-request inputs; see :meth:`_build_prefill_embeds`), slice this + chunk out of it, and return it; + ``input_ids`` are placeholders. Decode (``span_len == 1``): write the per-token decode inputs (previous codes, current text token, previous phoneme) into the model buffers at ``start`` and return a zero embedding that :meth:`forward` accumulates into. @@ -479,25 +518,19 @@ def _preprocess_prefill( device: torch.device, info_dict: dict[str, Any], ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: - prompt_embeds = self._unwrap(info_dict.get("prompt_embeds")) - if not isinstance(prompt_embeds, torch.Tensor) or prompt_embeds.ndim != 2: - raise ValueError( - "EasyMagpieTTS preprocess requires additional_information.prompt_embeds " - "of shape (T_ctx, embedding_dim) for prefill." 
- ) - prompt_embeds = prompt_embeds.to(device=device, dtype=self._combined_embeddings.dtype) + prefill_embeds = self._build_prefill_embeds(device, info_dict) offset = int(info_dict.get("ear_prefill_offset", 0) or 0) - total = int(prompt_embeds.shape[0]) + total = int(prefill_embeds.shape[0]) s = max(0, min(offset, total)) e = max(0, min(offset + span_len, total)) - take = prompt_embeds[s:e] + take = prefill_embeds[s:e] if int(take.shape[0]) < span_len: pad_n = span_len - int(take.shape[0]) pad_rows = ( take[-1:].expand(pad_n, -1) if take.shape[0] > 0 - else prompt_embeds.new_zeros(pad_n, prompt_embeds.shape[-1]) + else prefill_embeds.new_zeros(pad_n, prefill_embeds.shape[-1]) ) take = torch.cat([take, pad_rows], dim=0) @@ -508,6 +541,121 @@ def _preprocess_prefill( input_ids_out = torch.full_like(input_ids, _DUMMY_TOKEN_ID) return input_ids_out, take, info_update + def _build_prefill_embeds( + self, + device: torch.device, + info_dict: dict[str, Any], + ) -> torch.Tensor: + """Assemble the full ``(T_ctx, embedding_dim)`` prefill context embedding. + + Reproduces the prefill assembly from the reference + ``prepare_context_tensors``:: + + [task_embedding | speaker_embedding | context_text_embedded] + + from the per-request inputs: + + * ``speaker_embedding`` — the speaker-encoded context-audio embedding + (e.g. produced by ``easy_magpietts_extract_speaker_encoding.py``), + required as a 2-D ``(T_audio, embedding_dim)`` tensor. + * ``context_text`` — a plain string (e.g. ``"[EN]"``); tokenized in-model + (see :meth:`_encode_context_text`) and embedded through the baked + per-subword ``text_embedding`` table (which already folds in the CAS + encoder, matching the default ``disable_cas_for_context_text=False`` + training). Defaults to ``_DEFAULT_CONTEXT_TEXT`` when omitted. + * ``task_mode_id`` — selects the per-mode task ("service token") + embedding row; prepended only when the checkpoint has a task table. 
+ + Returns the full context embedding; the per-chunk slicing/padding is done + by :meth:`_preprocess_prefill`. + """ + dtype = self._combined_embeddings.dtype + + speaker_embedding = info_dict.get("speaker_embedding") + assert isinstance(speaker_embedding, torch.Tensor) and speaker_embedding.ndim == 2, ( + "EasyMagpieTTS preprocess expects additional_information.speaker_embedding to be a 2-D " + "(T_audio, embedding_dim) tensor (the speaker-encoded context audio); " + f"got {type(speaker_embedding).__name__}" + + (f" with ndim={speaker_embedding.ndim}" if isinstance(speaker_embedding, torch.Tensor) else "") + ) + + parts: list[torch.Tensor] = [] + + # Task / "service token" embedding (prepended), when present. + if self.task_embedding is not None: + task_mode_id = int(info_dict.get("task_mode_id", 0) or 0) + task_mode_id = max(0, min(task_mode_id, self.num_task_embeddings - 1)) + task_row = self.task_embedding(torch.tensor([task_mode_id], device=device, dtype=torch.long)) + parts.append(task_row.to(dtype)) + + # Speaker-encoded context audio. + parts.append(speaker_embedding.to(device=device, dtype=dtype)) + + # Context text: tokenized in-model and embedded through the baked table. + context_text = self._first_str(info_dict.get("context_text")) or _DEFAULT_CONTEXT_TEXT + ctx_ids = self._encode_context_text(context_text, device) + if ctx_ids.numel() > 0: + parts.append(self.text_embedding(ctx_ids).to(dtype)) + + return torch.cat(parts, dim=0) + + def _get_text_tokenizer(self): + """Lazily load the context-text tokenizer from the model directory. + + Mirrors qwen3-tts: the converted checkpoint ships a HuggingFace + ``AutoTokenizer`` (the model's text-conditioning tokenizer) alongside its + weights, so we load it on first use from ``model_path``. 
+ """ + if self._text_tokenizer is None: + from transformers import AutoTokenizer + + self._text_tokenizer = AutoTokenizer.from_pretrained(self.model_path, trust_remote_code=True) + return self._text_tokenizer + + def _encode_context_text(self, context_text: str, device: torch.device) -> torch.Tensor: + """Tokenize ``context_text`` to subword ids (matching the reference encode path). + + The reference ``AggregatedTTSTokenizer.encode`` calls the underlying + HF tokenizer's ``encode`` (default ``add_special_tokens``) for the + text-conditioning tokenizer, which sits at offset 0 in the aggregate, so + its raw ids index the baked ``text_embedding`` table directly. + """ + tok = self._get_text_tokenizer() + ids = tok.encode(context_text) + return torch.tensor(ids, device=device, dtype=torch.long) + + @staticmethod + def estimate_prompt_len( + speaker_embedding: torch.Tensor, + *, + tokenize: Callable[[str], Iterable[int]], + context_text: str = _DEFAULT_CONTEXT_TEXT, + has_task_embedding: bool = False, + ) -> int: + """Length-only mirror of :meth:`_build_prefill_embeds` (à la qwen3-tts's + ``estimate_prompt_len_from_additional_information``). + + The engine assembles the prefill context as + ``[task_embedding? | speaker_embedding | context_text_embedded]``, so the + caller must pass ``prompt_token_ids = [0] * estimate_prompt_len(...)`` for + the placeholder length to match the assembled embedding length (otherwise + vLLM pads / truncates and quality drops). + + Args: + speaker_embedding: ``(T_audio, embedding_dim)`` speaker-encoded + context-audio embedding (only its length is used). + tokenize: callable turning ``context_text`` into its subword ids + (e.g. ``lambda t: tokenizer.encode(t)``) — must match the + tokenizer the engine loads from ``model_path``. + context_text: conditioning string (default ``_DEFAULT_CONTEXT_TEXT``). + has_task_embedding: whether the checkpoint prepends a task / + "service token" embedding (``num_task_embeddings > 0``). 
+ """ + t_audio = int(speaker_embedding.shape[0]) + ctx_len = len(list(tokenize(context_text or _DEFAULT_CONTEXT_TEXT))) + task_len = 1 if has_task_embedding else 0 + return task_len + t_audio + ctx_len + def _preprocess_decode( self, input_ids: torch.Tensor, @@ -585,6 +733,7 @@ def postprocess(self, hidden_states: torch.Tensor, multimodal_outputs: Optional[ "phoneme_embeddings.": "phoneme_embeddings.", "phoneme_final_proj.": "phoneme_final_proj.", "text_embedding.": "text_embedding.", + "task_embedding.": "task_embedding.", } def _remap_tts_key(self, name: str) -> Optional[str]: @@ -617,6 +766,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: if mapped is None: # Unrelated checkpoint section (codec, speaker encoder, CAS, etc.). continue + if mapped.startswith("task_embedding.") and self.task_embedding is None: + # Single-mode model: checkpoint may still ship an (unused) table. + continue target = own_params.get(mapped) if target is None: logger.warning("EasyMagpieTTS: no parameter for checkpoint key %s -> %s", name, mapped) From 3a8d50b36398cf7d96dcce95f81d9772edd44f95 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 2 Jun 2026 18:20:44 +0200 Subject: [PATCH 06/15] examples/tts/easymagpie_vllm_omni: introduce script to convert the checkpoint to vllm omni one Signed-off-by: Viacheslav Klimkov --- .../easy_magpietts_convert_to_vllm.py | 438 ++++++++++++++++++ .../easymagpie_inference_demo.ipynb | 367 ++++++++------- .../easymagpie_vllm_omni/easymagpie.py | 19 +- 3 files changed, 658 insertions(+), 166 deletions(-) create mode 100644 examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py diff --git a/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py b/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py new file mode 100644 index 000000000000..664a243d7415 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py @@ -0,0 +1,438 @@ +# Copyright 
(c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert an EasyMagpieTTS ``.nemo`` checkpoint to a vLLM-Omni model directory. + +The output directory is self-contained and ready to be passed as ``model=`` +to the ``easymagpie_vllm_omni`` vLLM-Omni model +(:class:`easymagpie_vllm_omni.easymagpie.EasyMagpieTTSForConditionalGeneration`). +It contains: + +* ``config.json`` — the flat HF-style config the vLLM model reads at + construction (the Nemotron-H backbone fields + the EasyMagpie scalars consumed + by :class:`easymagpie_vllm_omni.config.EasyMagpieOmniArch`). +* ``model.safetensors`` (+ ``model.safetensors.index.json``) — the converted + weights using the reference EasyMagpieTTS key layout expected by the vLLM + model's ``load_weights`` (``decoder.*`` backbone + top-level TTS submodules). +* the checkpoint's **text-conditioning tokenizer** saved via + ``AutoTokenizer.save_pretrained`` so the model can tokenize per-request + ``context_text`` in-engine. +* ``speaker_embeddings/{speaker_name}.pt`` (optional) — pre-computed speaker-encoder + outputs for one or more reference audio files, used as the ``speaker_embedding`` + input at inference time. 
+ +Compared to running the reference model, the character-aware subword (CAS) +encoder is collapsed into a single pre-computed lookup table mapping +``subword_id -> embedding`` (the CAS encoder is fully deterministic per subword +id, so it is baked once at conversion time and never run inside the engine). The +``decoder``'s unused token-embedding table is replaced by a tiny dummy (the +backbone is always fed via ``inputs_embeds``). + +Example:: + + python examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py \\ + --nemo_file /path/to/EMTTS_SmallMamba.nemo \\ + --codec_model_path /path/to/25fps_spectral_codec.nemo \\ + --outdir ./easymagpie_vllm_model \\ + --context_audio /path/to/reference_voice.wav --speaker_name eng +""" +from __future__ import annotations + +import argparse +import json +import os + +import torch +import tqdm +from omegaconf import OmegaConf +from safetensors.torch import save_file + +from nemo.collections.tts.modules.magpietts_inference.utils import ModelLoadConfig, load_easy_magpie_model +from nemo.collections.tts.modules.magpietts_modules import add_special_tokens +from nemo.utils import logging + +# Top-level checkpoint key prefixes the vLLM model's ``load_weights`` consumes +# for the TTS submodules (everything else under these names maps 1:1 into the +# vLLM model). ``text_embedding.*`` is intentionally excluded here: it is +# replaced by the pre-computed per-subword lookup table. +_TTS_PREFIXES = ( + "audio_embeddings.", + "audio_in_projection.", + "local_transformer.", + "local_transformer_in_projection.", + "local_transformer_audio_out_projection.", + "local_transformer_out_projections.", + "phoneme_embeddings.", + "phoneme_final_proj.", + "task_embedding.", +) + +# The backbone token-embedding table is never consumed at runtime (the model +# runs off ``inputs_embeds``), so we ship a dummy table. 
It must still be >= 2: +# vLLM's profiling ``_dummy_sampler_run`` sets ``top_k = vocab_size - 1`` and then +# gathers at index ``vocab_size - top_k``, which is out of bounds for a width-1 +# logits tensor (device-side "scatter gather index out of bounds" assert). +_BACKBONE_VOCAB_SIZE = 2 + +# Nemotron-H backbone config fields forwarded into the flat vLLM ``config.json``. +# Names match the HF/vLLM Nemotron-H config (and the NeMo ``NemotronHConfig``). +_NEMOTRON_CONFIG_FIELDS = ( + "hidden_size", + "num_hidden_layers", + "num_attention_heads", + "num_key_value_heads", + "head_dim", + "attention_dropout", + "attention_bias", + "max_position_embeddings", + "mamba_num_heads", + "mamba_head_dim", + "ssm_state_size", + "conv_kernel", + "n_groups", + "chunk_size", + "mamba_hidden_act", + "use_conv_bias", + "use_bias", + "intermediate_size", + "mlp_hidden_act", + "mlp_bias", + "n_routed_experts", + "num_experts_per_tok", + "moe_intermediate_size", + "moe_shared_expert_intermediate_size", + "n_group", + "topk_group", + "routed_scaling_factor", + "norm_topk_prob", + "hybrid_override_pattern", + "layer_norm_epsilon", + "residual_in_fp32", +) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Convert an EasyMagpieTTS .nemo checkpoint to a vLLM-Omni model directory." 
+ ) + parser.add_argument("--nemo_file", required=True, help="Path to the EasyMagpieTTS .nemo checkpoint.") + parser.add_argument("--codec_model_path", required=True, help="Path to the audio codec .nemo checkpoint.") + parser.add_argument("--outdir", required=True, help="Output directory for the vLLM model.") + parser.add_argument( + "--phoneme_tokenizer_path", + default=None, + help="Override the phoneme (IPA BPE) tokenizer path baked into the checkpoint.", + ) + parser.add_argument( + "--disable_cas_for_context_text", + action="store_true", + help="Set for legacy checkpoints trained without CAS embeddings on context text.", + ) + parser.add_argument( + "--text_tokenizer", + default=None, + help="HuggingFace tokenizer name/path to export. Defaults to the checkpoint's " + "text-conditioning AutoTokenizer (`pretrained_model`).", + ) + parser.add_argument( + "--context_audio", + default=None, + help="Optional reference wav for which to pre-compute a speaker embedding.", + ) + parser.add_argument( + "--speaker_name", + default="default", + help="Name for the saved speaker embedding (speaker_embeddings/.pt).", + ) + parser.add_argument("--context_audio_duration", type=float, default=5.0) + parser.add_argument( + "--dtype", + default="bfloat16", + choices=["bfloat16", "float16", "float32"], + help="Saved weight dtype / config torch_dtype. bf16 matches the reference inference setup.", + ) + parser.add_argument( + "--precompute_batch_size", + type=int, + default=1024, + help="Batch size for pre-computing per-subword text embeddings.", + ) + parser.add_argument("--device", default="cuda") + return parser.parse_args() + + +@torch.no_grad() +def precompute_text_embeddings(model, batch_size: int) -> torch.Tensor: + """Bake the per-subword text embedding into a single lookup table. 
+ + Runs ``embed_text_tokens`` (decoder subword embedding + the deterministic + char-aware subword encoder) once per subword id so the vLLM model can replace + the whole text-embedding path with a single ``nn.Embedding`` lookup. + + Returns: + Tensor of shape ``[vocab_size, embedding_dim]`` (float32). + """ + device = next(model.parameters()).device + + # Vocabulary size of the subword id space (decoder text-embedding table when + # present; otherwise the CAS-only id range, which ends at cfg_unk_token_id). + if getattr(model, "text_embedding", None) is not None: + vocab_size = model.text_embedding.num_embeddings + else: + vocab_size = int(model.cfg_unk_token_id) + 1 + embedding_dim = int(model.cfg.embedding_dim) + + table = torch.zeros((vocab_size, embedding_dim), dtype=torch.float32, device=device) + logging.info(f"Pre-computing text embeddings for {vocab_size} subword ids on {device}") + for start in tqdm.tqdm(range(0, vocab_size, batch_size), desc="Pre-computing text embeddings"): + end = min(start + batch_size, vocab_size) + ids = torch.arange(start, end, dtype=torch.long, device=device).unsqueeze(0) # (1, n) + lens = torch.tensor([end - start], dtype=torch.long, device=device) + embeds = model.embed_text_tokens(ids, text_lens=lens, disable_cas_embedding=False) # (1, n, E) + table[start:end] = embeds.squeeze(0).to(torch.float32) + return table.cpu() + + +@torch.no_grad() +def extract_speaker_embedding(model, context_audio_path: str, context_audio_duration: float) -> torch.Tensor: + """Reproduce the audio branch of ``prepare_context_tensors`` for one wav. + + Mirrors ``easy_magpietts_extract_speaker_encoding.py``: encode the (trimmed) + reference audio to codec codes, add special tokens, frame-stack, embed the + per-codebook tokens, and (when enabled) run the speaker encoder. Returns the + ``(T_audio, embedding_dim)`` tensor consumed as the model's ``speaker_embedding``. 
+ """ + device = next(model.parameters()).device + + context_audio = model._load_audio_for_inference(context_audio_path, model.sample_rate) + context_audio = model._adjust_audio_to_duration_for_inference( + context_audio, + model.sample_rate, + context_audio_duration, + model.codec_model_samples_per_frame, + ) + context_audio = context_audio.to(device) + context_audio_lens = torch.tensor([context_audio.size(1)], dtype=torch.long, device=device) + context_audio_codes, context_audio_codes_lens = model._codec_helper.audio_to_codes( + context_audio, context_audio_lens + ) + + if model._codec_converter is not None: + context_audio_codes = model._codec_converter.convert_original_to_new( + audio_tokens=context_audio_codes, audio_lens=context_audio_codes_lens + ).long() + + context_audio_codes, context_audio_codes_lens = add_special_tokens( + codes=context_audio_codes, + codes_len=context_audio_codes_lens, + bos_id=model.context_audio_bos_id, + eos_id=model.context_audio_eos_id, + ) + context_audio_codes, context_audio_codes_lens = model.stack_codes( + context_audio_codes, + context_audio_codes_lens, + model.context_audio_bos_id, + model.context_audio_eos_id, + model.frame_stacking_factor, + model.num_audio_codebooks, + ) + + context_audio_embedded = model.embed_audio_tokens(context_audio_codes) # (B, T_audio, E) + if getattr(model, "use_speaker_encoder", False): + context_audio_embedded = model.encode_context_audio_embeddings( + context_audio_embedded=context_audio_embedded, + context_audio_lens=context_audio_codes_lens, + ) + else: + logging.warning( + "Checkpoint has use_speaker_encoder=False; saving raw per-codebook audio embeddings " + "(no speaker encoder applied)." 
+ ) + + audio_len = int(context_audio_codes_lens[0].item()) + return context_audio_embedded[0, :audio_len].contiguous().float().detach().cpu() + + +def build_config(model, vocab_size: int, torch_dtype: str) -> dict: + """Build the flat vLLM ``config.json`` dict from the loaded NeMo model.""" + from nemo.collections.tts.modules.nemotron_h_decoder import NemotronHConfig + + cfg = model.cfg + if cfg.get("decoder_type", "huggingface") != "nemotron_h": + raise ValueError( + "The easymagpie_vllm_omni model only supports a Nemotron-H backbone " + f"(decoder_type='nemotron_h'); got '{cfg.get('decoder_type')}'." + ) + + hidden_dim = int(cfg.hidden_dim) + embedding_dim = int(cfg.embedding_dim) + + # Resolve the backbone config exactly as NeMo does (fills head_dim, expands + # the hybrid pattern to num_hidden_layers, etc.). + nemotron_dict = dict(OmegaConf.to_container(cfg.nemotron_h_config, resolve=True)) + nemotron_dict.setdefault("hidden_size", embedding_dim) + nemotron_cfg = NemotronHConfig(**nemotron_dict) + + config: dict = {"architectures": ["EasyMagpieTTSForConditionalGeneration"], "model_type": "nemotron_h"} + for field in _NEMOTRON_CONFIG_FIELDS: + if hasattr(nemotron_cfg, field): + config[field] = getattr(nemotron_cfg, field) + config["tie_word_embeddings"] = False + config["torch_dtype"] = torch_dtype + # The backbone token-embedding table is never consumed (inputs_embeds path); + # the dummy logits width follows it. Must be >= 2 (see ``_BACKBONE_VOCAB_SIZE``). + # The text path is driven by ``text_vocab_size`` / the baked ``text_embedding`` + # table instead. 
+ config["vocab_size"] = _BACKBONE_VOCAB_SIZE + + # ── EasyMagpie scalars (read by EasyMagpieOmniArch.from_hf_config) ── + config["text_vocab_size"] = vocab_size + config["embedding_dim"] = embedding_dim + config["audio_embedding_dim"] = int(cfg.get("audio_embedding_dim", hidden_dim)) + config["num_audio_codebooks"] = int(model.num_audio_codebooks) + config["codebook_size"] = int(model.codebook_size) + config["frame_stacking_factor"] = int(model.frame_stacking_factor) + + has_phoneme = getattr(model, "phoneme_tokenizer", None) is not None + config["phoneme_stacking_factor"] = int(getattr(model, "phoneme_stacking_factor", 0)) if has_phoneme else 0 + config["phoneme_vocab_size"] = int(getattr(model, "phoneme_vocab_size", 0)) if has_phoneme else 0 + + config["num_task_embeddings"] = len(model.training_modes) if model.task_embedding is not None else 0 + + config["local_transformer_n_layers"] = int(cfg.get("local_transformer_n_layers", 2)) + config["local_transformer_n_heads"] = int(cfg.get("local_transformer_n_heads", 1)) + config["local_transformer_hidden_dim"] = int(cfg.get("local_transformer_hidden_dim", hidden_dim)) + + # Pin the exact special-token ids (covers legacy ``forced_*`` checkpoints). + config["forced_audio_bos_id"] = int(model.audio_bos_id) + config["forced_audio_eos_id"] = int(model.audio_eos_id) + config["forced_mask_token_id"] = int(model.mask_token_id) + + return config + + +def select_weights(state_dict: dict, hidden_dim: int, dtype: torch.dtype) -> dict: + """Select + rename checkpoint weights into the vLLM ``load_weights`` layout.""" + weights: dict = {} + + # Backbone: keep all ``decoder.*`` except the unused token-embedding table. 
+ for key, value in state_dict.items(): + if not key.startswith("decoder."): + continue + if key == "decoder.embeddings.weight": + continue + if key.endswith(".causal_mask"): + continue + weights[key] = value.to(dtype) if value.is_floating_point() else value + + # Dummy backbone embeddings (size ``_BACKBONE_VOCAB_SIZE``) — never consumed + # at runtime; sized to match ``config.vocab_size``. + weights["decoder.embeddings.weight"] = torch.zeros(_BACKBONE_VOCAB_SIZE, hidden_dim, dtype=dtype) + + # TTS submodules copied 1:1. + for key, value in state_dict.items(): + if key.endswith(".causal_mask"): + continue + if any(key.startswith(prefix) for prefix in _TTS_PREFIXES): + weights[key] = value.to(dtype) if value.is_floating_point() else value + + return weights + + +def save_text_tokenizer(model, outdir: str, override: str | None) -> None: + """Export the checkpoint's text-conditioning tokenizer into ``outdir``.""" + from transformers import AutoTokenizer + + pretrained = override + if pretrained is None: + tok_name = model.text_conditioning_tokenizer_name + tok_cfg = model.cfg.text_tokenizers[tok_name] + if tok_cfg.get("_target_", None) != "AutoTokenizer" or tok_cfg.get("pretrained_model", None) is None: + raise ValueError( + "Could not infer the text-conditioning AutoTokenizer from the checkpoint config. " + "Pass --text_tokenizer explicitly." 
+ ) + pretrained = tok_cfg.pretrained_model + + logging.info(f"Saving text tokenizer '{pretrained}' to {outdir}") + AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True).save_pretrained(outdir) + + +def convert(args) -> None: + os.makedirs(args.outdir, exist_ok=True) + dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[args.dtype] + + model, ckpt_name = load_easy_magpie_model( + ModelLoadConfig( + nemo_file=args.nemo_file, + codecmodel_path=args.codec_model_path, + phoneme_tokenizer_path=args.phoneme_tokenizer_path, + disable_cas_for_context_text=args.disable_cas_for_context_text, + ), + device=args.device, + ) + logging.info(f"Loaded EasyMagpieTTS checkpoint: {ckpt_name}") + + hidden_dim = int(model.cfg.hidden_dim) + + # ── 1. Pre-compute the per-subword text embedding table ────────────── + text_table = precompute_text_embeddings(model, args.precompute_batch_size) + vocab_size = int(text_table.shape[0]) + + # ── 2. config.json ─────────────────────────────────────────────────── + config = build_config(model, vocab_size, args.dtype) + with open(os.path.join(args.outdir, "config.json"), "w") as f: + json.dump(config, f, indent=2) + logging.info("Saved config.json") + + # ── 3. weights ─────────────────────────────────────────────────────── + state_dict = model.state_dict() + weights = select_weights(state_dict, hidden_dim, dtype) + weights["text_embedding.weight"] = text_table.to(dtype) + + safetensors_path = os.path.join(args.outdir, "model.safetensors") + save_file(weights, safetensors_path, metadata={"format": "pt"}) + index = { + "metadata": {"total_size": sum(w.numel() * w.element_size() for w in weights.values())}, + "weight_map": {name: "model.safetensors" for name in weights}, + } + with open(os.path.join(args.outdir, "model.safetensors.index.json"), "w") as f: + json.dump(index, f, indent=2) + logging.info(f"Saved {len(weights)} weights to {safetensors_path}") + + # ── 4. 
text tokenizer ──────────────────────────────────────────────── + save_text_tokenizer(model, args.outdir, args.text_tokenizer) + + # ── 5. optional speaker embedding ──────────────────────────────────── + if args.context_audio is not None: + speaker_dir = os.path.join(args.outdir, "speaker_embeddings") + os.makedirs(speaker_dir, exist_ok=True) + speaker_encoding = extract_speaker_embedding(model, args.context_audio, args.context_audio_duration) + out_path = os.path.join(speaker_dir, f"{args.speaker_name}.pt") + torch.save( + { + "speaker_encoding": speaker_encoding, + "context_audio": args.context_audio, + "embedding_dim": int(speaker_encoding.size(-1)), + "num_frames": int(speaker_encoding.size(0)), + "checkpoint": ckpt_name, + }, + out_path, + ) + logging.info(f"Saved speaker embedding '{args.speaker_name}' {tuple(speaker_encoding.shape)} to {out_path}") + + logging.info(f"Done. vLLM model directory: {args.outdir}") + + +if __name__ == "__main__": + convert(parse_args()) diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb index dd7322cef37c..8cbcc17f4665 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb @@ -5,13 +5,14 @@ "id": "d5a1129d", "metadata": {}, "source": [ - "# EasyMagpieTTS — vLLM-Omni inference demo (dummy weights)\n", + "# EasyMagpieTTS — vLLM-Omni inference demo\n", "\n", "This notebook runs an end-to-end inference pass through the\n", - "[`easymagpie_vllm_omni`](./easymagpie_vllm_omni) model definition using\n", - "**dummy / random weights**, so you can exercise the full engine path\n", - "(prefill -> autoregressive decode -> audio-code extraction) without a converted\n", - "checkpoint.\n", + "[`easymagpie_vllm_omni`](./easymagpie_vllm_omni) model definition using a\n", + "**converted checkpoint directory** produced by\n", + 
"[`easy_magpietts_convert_to_vllm.py`](./easy_magpietts_convert_to_vllm.py)\n", + "(weights + `config.json` + text tokenizer + speaker embeddings). It exercises the\n", + "full engine path: prefill -> autoregressive decode -> audio-code extraction.\n", "\n", "It follows the same `AsyncOmni` single-stage pattern as the reference\n", "`qwen3-tts` and `eartts` demos:\n", @@ -23,19 +24,16 @@ " `context_text` itself. `prompt_token_ids = [0] * prompt_len`, sized with\n", " `EasyMagpieTTSForConditionalGeneration.estimate_prompt_len(...)`.\n", "* **decode** — each step consumes one subword id from the streaming\n", - " `additional_information.text_tokens` list; the local transformer samples all\n", - " `C * S` stacked audio codebooks for the frame.\n", + " `additional_information.text_tokens` list (the tokenized target sentence); the\n", + " local transformer samples all `C * S` stacked audio codebooks for the frame.\n", "* **output** — per-step audio codes are surfaced on\n", " `OmniOutput.multimodal_outputs[\\\"audio_codes\\\"]` (`BT x num_codebooks`), and the\n", " engine accumulates them across steps just like eartts, so we trim to the last\n", " `len(token_ids)` decoded rows.\n", "\n", - "> **Dummy weights.** We build a `config.json` sized to the real checkpoint\n", - "> (`2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo`) and start the\n", - "> engine with `load_format=\\\"dummy\\\"`, so vLLM fills all parameters with random\n", - "> values. The emitted codes are therefore meaningless — this is a *smoke test*\n", - "> of the engine wiring, not a real synthesis. Point the engine at a real\n", - "> converted checkpoint (and drop `load_format`) to get audio.\n", + "> **Converted checkpoint.** Set `MODEL_DIR` below to the directory written by the\n", + "> converter. 
The engine reads the `config.json`, weights, and tokenizer straight\n", + "> from it — no hardcoded config, no dummy weights.\n", "\n", "> **Environment.** Run this inside the bootstrapped `vllm_omni_env` (vLLM +\n", "> vLLM-Omni + compatible torch) with the plugin installed:\n", @@ -71,7 +69,7 @@ "\n", "# Importing the model package is optional (the engine resolves the arch via the\n", "# `vllm.general_plugins` entry point installed with the package), but doing it\n", - "# here surfaces the arch dataclass we use to size the dummy prompt embedding.\n", + "# here surfaces the arch dataclass we use to read scalars from the config.\n", "from easymagpie_vllm_omni.config import EasyMagpieOmniArch\n", "\n", "print(\"torch:\", torch.__version__, \"| cuda:\", torch.cuda.is_available())" @@ -82,34 +80,30 @@ "id": "f7ff55fe", "metadata": {}, "source": [ - "## 1. Build a tiny dummy model directory\n", + "## 1. Point at the converted model directory\n", "\n", - "The engine only needs a `config.json` that (a) names the registered arch and\n", - "(b) carries the EasyMagpie + Nemotron-H scalars. Here we size everything to match\n", - "the real checkpoint\n", - "`2605_NemotronTTS_V0.2/v2/2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo`\n", - "(hidden 1536, 8 codebooks × 1024, frame-stacking ×2, 3-layer local transformer).\n", + "Set `MODEL_DIR` to the directory written by\n", + "[`easy_magpietts_convert_to_vllm.py`](./easy_magpietts_convert_to_vllm.py). 
It\n", + "already contains everything the engine needs:\n", + "\n", + "* `config.json` — the registered arch + Nemotron-H backbone scalars + EasyMagpie\n", + " scalars (read by `EasyMagpieOmniArch.from_hf_config`),\n", + "* `model.safetensors` — the converted weights (backbone + TTS submodules + the\n", + " baked per-subword `text_embedding` table),\n", + "* the text-conditioning tokenizer (`tokenizer.json` / `tokenizer_config.json`),\n", + " loaded in-engine to tokenize the per-request `context_text`,\n", + "* `speaker_embeddings/.pt` — pre-computed speaker embeddings for reference\n", + " voices.\n", "\n", "The backbone is a **Nemotron-H** hybrid (Mamba2 + attention + MoE) decoder:\n", "`EasyMagpieTTSForConditionalGeneration` constructs vLLM's `NemotronHModel` and\n", "implements the hybrid-Mamba interfaces (`HasInnerState` / `IsHybrid` /\n", "`SupportsMambaPrefixCaching`), exactly like the EasyMagpie vLLM *sidecar*. The\n", - "`nemotron_h_config` fields (`hybrid_override_pattern`, `mamba_*`, `n_routed_experts`,\n", - "…) are copied verbatim from the checkpoint.\n", - "\n", - "The EasyMagpie-specific scalars (`embedding_dim`, `num_audio_codebooks`,\n", - "`codebook_size`, `frame_stacking_factor`, `local_transformer_*`, …) are read by\n", - "`EasyMagpieOmniArch.from_hf_config`. The phoneme branch is **enabled**\n", - "(`phoneme_stacking_factor = 1`, `phoneme_vocab_size = 2051`) to match the\n", - "checkpoint; the model self-predicts phonemes, so no phoneme stream needs to be\n", - "supplied in the prompt.\n", - "\n", - "With `load_format=\\\"dummy\\\"` (set in the stage config) vLLM never reads weight\n", - "files, so no safetensors are needed. We do save the checkpoint's\n", - "text-conditioning tokenizer (`TEXT_TOKENIZER`, the Nemotron-H tokenizer that\n", - "matches `TEXT_VOCAB`) into the model dir, since the model tokenizes the\n", - "per-request `context_text` in-engine via\n", - "`AutoTokenizer.from_pretrained(model_path)`." 
+ "phoneme branch is enabled in the converted config; the model self-predicts\n", + "phonemes, so no phoneme stream needs to be supplied in the prompt.\n", + "\n", + "We just read the `config.json` here to surface a few scalars used for building\n", + "the prompt (`text_vocab_size`, the audio EOS id, whether a task embedding exists)." ] }, { @@ -119,106 +113,26 @@ "metadata": {}, "outputs": [], "source": [ - "# Config matching the real checkpoint:\n", - "# 2605_NemotronTTS_V0.2/v2/2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo\n", - "#\n", - "# The backbone is a Nemotron-H hybrid (Mamba2 + attention + MoE) decoder, wired\n", - "# through vLLM's `NemotronHModel` by `EasyMagpieTTSForConditionalGeneration`. The\n", - "# fields below are ported verbatim from the checkpoint's `model_config.yaml`\n", - "# (the `nemotron_h_config` block + the EasyMagpie scalars). With\n", - "# `load_format=\"dummy\"` the weights are random — a realistically-sized smoke test.\n", - "#\n", - "# embedding_dim == hidden_size == audio_embedding_dim == local_transformer_hidden_dim\n", - "# (all 1536 in the checkpoint) keeps every in/out projection an Identity.\n", - "HIDDEN = 1536 # nemotron_h_config.hidden_size / embedding_dim / audio_embedding_dim\n", - "NUM_AUDIO_CODEBOOKS = 8 # vector_quantizer.num_groups\n", - "CODEBOOK_SIZE = 1024 # prod(vector_quantizer.num_levels_per_group) = 4**5\n", - "FRAME_STACKING = 2 # -> num_stacked_codebooks = NUM_AUDIO_CODEBOOKS * FRAME_STACKING = 16\n", - "PHONEME_STACKING = 1 # phoneme_stacking_factor\n", - "PHONEME_VOCAB = 2051 # IPA-BPE 2048 tokenizer + 3 special tokens\n", - "TEXT_VOCAB = 131072 # nemotron_h_config.vocab_size\n", - "# Text-conditioning tokenizer that matches the checkpoint (SmallMamba uses the\n", - "# Nemotron-H tokenizer, vocab 131072 == TEXT_VOCAB). 
Point this at the converted\n", - "# checkpoint dir / the checkpoint's tokenizer when running a real model.\n", - "TEXT_TOKENIZER = \"nvidia/Nemotron-H-8B-Base-8K\"\n", - "\n", - "config = {\n", - " # Resolved through the `vllm.general_plugins` entry point registered by the\n", - " # `easymagpie_vllm_omni` package -> EasyMagpieTTSForConditionalGeneration.\n", - " \"architectures\": [\"EasyMagpieTTSForConditionalGeneration\"],\n", - " # Nemotron-H backbone fields (consumed by vllm NemotronHModel) — copied\n", - " # verbatim from the checkpoint's `nemotron_h_config` block.\n", - " \"model_type\": \"nemotron_h\",\n", - " \"hidden_size\": HIDDEN,\n", - " \"num_hidden_layers\": 31,\n", - " \"vocab_size\": TEXT_VOCAB,\n", - " \"num_attention_heads\": 12,\n", - " \"num_key_value_heads\": 4,\n", - " \"attention_dropout\": 0.0,\n", - " \"attention_bias\": False,\n", - " \"max_position_embeddings\": 8192,\n", - " \"mamba_num_heads\": 64,\n", - " \"mamba_head_dim\": 24,\n", - " \"ssm_state_size\": 128,\n", - " \"conv_kernel\": 4,\n", - " \"n_groups\": 8,\n", - " \"chunk_size\": 256,\n", - " \"mamba_hidden_act\": \"silu\",\n", - " \"use_conv_bias\": True,\n", - " \"use_bias\": False,\n", - " \"intermediate_size\": 4096,\n", - " \"mlp_hidden_act\": \"silu\",\n", - " \"mlp_bias\": False,\n", - " \"n_routed_experts\": 24,\n", - " \"num_experts_per_tok\": 4,\n", - " \"moe_intermediate_size\": 768,\n", - " \"moe_shared_expert_intermediate_size\": 2048,\n", - " \"n_group\": 1,\n", - " \"topk_group\": 1,\n", - " \"routed_scaling_factor\": 2.5,\n", - " \"norm_topk_prob\": True,\n", - " # 31-char layer pattern: M=Mamba2, *=attention, E=MLP/MoE (len == num_hidden_layers).\n", - " \"hybrid_override_pattern\": \"MEMEM*EMEMEM*EMEMEMEM*EMEMEMEME\",\n", - " \"layer_norm_epsilon\": 1e-5,\n", - " \"residual_in_fp32\": False,\n", - " \"tie_word_embeddings\": False,\n", - " # bfloat16, not float32: the Nemotron-H MoE layers run vLLM's fused-MoE\n", - " # Triton kernel, whose block sizes are 
tuned for 16-bit. In float32 the\n", - " # kernel needs ~2x shared memory and overflows the GPU limit\n", - " # (OutOfResources: shared memory). bf16 also matches the real checkpoint.\n", - " \"torch_dtype\": \"bfloat16\",\n", - " # EasyMagpie-specific scalars (read by EasyMagpieOmniArch.from_hf_config).\n", - " \"text_vocab_size\": TEXT_VOCAB,\n", - " \"embedding_dim\": HIDDEN,\n", - " \"audio_embedding_dim\": HIDDEN,\n", - " \"num_audio_codebooks\": NUM_AUDIO_CODEBOOKS,\n", - " \"codebook_size\": CODEBOOK_SIZE,\n", - " \"frame_stacking_factor\": FRAME_STACKING,\n", - " \"phoneme_stacking_factor\": PHONEME_STACKING,\n", - " \"phoneme_vocab_size\": PHONEME_VOCAB,\n", - " \"local_transformer_n_layers\": 3,\n", - " \"local_transformer_n_heads\": 12,\n", - " \"local_transformer_hidden_dim\": HIDDEN,\n", - "}\n", - "\n", - "MODEL_DIR = Path(tempfile.mkdtemp(prefix=\"easymagpie_dummy_\"))\n", - "(MODEL_DIR / \"config.json\").write_text(json.dumps(config, indent=2))\n", + "# Directory produced by easy_magpietts_convert_to_vllm.py.\n", + "MODEL_DIR = Path(\"/home/vklimkov/workspace/emp/NeMo/examples/tts/easymagpie_vllm_omni/easymp_vllm_model\")\n", + "assert (MODEL_DIR / \"config.json\").exists(), f\"No config.json under {MODEL_DIR}; run the converter first.\"\n", "\n", - "# The model tokenizes the per-request `context_text` string in-engine via\n", - "# `AutoTokenizer.from_pretrained(model_path)` (qwen3-tts style), so the model dir\n", - "# must ship the checkpoint's text-conditioning tokenizer. We save the matching\n", - "# Nemotron-H tokenizer (TEXT_TOKENIZER) into MODEL_DIR.\n", - "from transformers import AutoTokenizer\n", + "# Read the converted config to surface a few scalars used when building the\n", + "# prompt. 
The engine itself loads everything from MODEL_DIR; we only peek here.\n", + "config = json.loads((MODEL_DIR / \"config.json\").read_text())\n", + "arch = EasyMagpieOmniArch.from_hf_config(type(\"Cfg\", (), config))\n", "\n", - "AutoTokenizer.from_pretrained(TEXT_TOKENIZER, trust_remote_code=True).save_pretrained(MODEL_DIR)\n", - "print(f\"Dummy model dir: {MODEL_DIR}\")\n", + "# Subword id space of the baked text-embedding table (the streaming text stream\n", + "# indexes into it). The model's text BOS/EOS/CFG-UNK ids are the last 3 rows.\n", + "TEXT_VOCAB = int(config[\"text_vocab_size\"])\n", + "TEXT_EOS_ID = TEXT_VOCAB - 2 # matches EasyMagpieTTSInferenceModel.eos_id\n", "\n", - "# Sanity-check the arch the model will derive from this config.\n", - "arch = EasyMagpieOmniArch.from_hf_config(type(\"Cfg\", (), config))\n", + "print(f\"Model dir : {MODEL_DIR}\")\n", "print(f\"embedding_dim : {arch.embedding_dim}\")\n", "print(f\"num_stacked_codebooks : {arch.num_stacked_codebooks} (C*S)\")\n", "print(f\"tokens / codebook : {arch.num_all_tokens_per_codebook} (codebook_size + specials)\")\n", - "print(f\"audio_bos / audio_eos id : {arch.audio_bos_id} / {arch.audio_eos_id}\")" + "print(f\"audio_bos / audio_eos id : {arch.audio_bos_id} / {arch.audio_eos_id}\")\n", + "print(f\"text_vocab / text_eos : {TEXT_VOCAB} / {TEXT_EOS_ID}\")" ] }, { @@ -234,14 +148,12 @@ "AR TTS model these make the runner attach the per-step `audio_codes` multimodal\n", "payload to the output (with `\\\"latent\\\"` the payload is dropped because nothing\n", "downstream consumes it, and `multimodal_output[\\\"audio_codes\\\"]` comes back\n", - "`None`). Two extra knobs make this a dummy-weights run with no external assets:\n", + "`None`).\n", "\n", - "* `load_format: \\\"dummy\\\"` — vLLM initializes random weights instead of reading a\n", - " checkpoint (so `load_weights` / `init_forbidden_mask` are skipped; the\n", - " forbidden-token mask stays all-zeros, i.e. 
no sampling mask — fine for a smoke\n", - " test).\n", - "* `skip_tokenizer_init: true` — we feed `prompt_token_ids` + `text_tokens`\n", - " directly, so no tokenizer files are needed.\n", + "`skip_tokenizer_init: true` — we feed `prompt_token_ids` + `text_tokens`\n", + "directly, so vLLM doesn't need its own tokenizer for the prompt (the model still\n", + "loads the bundled `AutoTokenizer` from `MODEL_DIR` in-engine to tokenize\n", + "`context_text`).\n", "\n", "`max_model_len` must cover `T_ctx` (prefill) + the number of decode steps." ] @@ -253,12 +165,12 @@ "metadata": {}, "outputs": [], "source": [ - "DECODE_STEPS = 32 # number of audio frames to decode\n", + "DECODE_STEPS = 256 # max number of audio frames to decode (trimmed at audio EOS)\n", "# Prefill length is derived at prompt-build time from the speaker embedding +\n", "# tokenized context_text (see the prompt cell); these just need to be large\n", "# enough to cover prefill + decode.\n", - "MAX_MODEL_LEN = 512\n", - "MAX_NUM_BATCHED_TOKENS = 512\n", + "MAX_MODEL_LEN = 1024\n", + "MAX_NUM_BATCHED_TOKENS = 1024\n", "\n", "stage_cfg = {\n", " \"stage_args\": [\n", @@ -293,8 +205,8 @@ " # sizes are tuned for 16-bit and overflow shared memory in fp32.\n", " \"dtype\": \"bfloat16\",\n", " \"attention_backend\": \"TRITON_ATTN\",\n", - " # --- dummy-weights smoke-test knobs ---\n", - " \"load_format\": \"dummy\",\n", + " # We feed prompt_token_ids + text_tokens directly; the model still\n", + " # loads the bundled AutoTokenizer from MODEL_DIR for context_text.\n", " \"skip_tokenizer_init\": True,\n", " },\n", " \"default_sampling_params\": {\n", @@ -321,7 +233,7 @@ " log_stats=False,\n", " stage_init_timeout=300,\n", ")\n", - "print(\"Engine ready (single stage: EasyMagpie talker, dummy weights)\")" + "print(\"Engine ready (single stage: EasyMagpie talker)\")" ] }, { @@ -335,14 +247,15 @@ "\n", "* **`speaker_embedding`** `(T_audio, embedding_dim)` — the speaker-encoded\n", " context-audio embedding (the audio 
branch of `prepare_context_tensors`),\n", - " loaded here from `eng_speaker_emb.pt` (as written by\n", - " `easy_magpietts_extract_speaker_encoding.py`). The model assembles the full\n", - " prefill context itself as `[task_embedding? | speaker_embedding |\n", - " context_text_embedded]`.\n", + " loaded here from `MODEL_DIR/speaker_embeddings/<speaker_name>.pt` (written by\n", + " the converter). The model assembles the full prefill context itself as\n", + " `[task_embedding? | speaker_embedding | context_text_embedded]`.\n", "* **`context_text`** — a plain conditioning string, here `\"[EN]\"`. The model\n", " tokenizes it in-engine and embeds it through the baked `text_embedding` table.\n", - "* **`text_tokens`** `list[int]` — the streaming subword stream; decode step `k`\n", - " consumes `text_tokens[k]`. We provide one id per decode step.\n", + "* **`text_tokens`** `list[int]` — the streaming subword stream: the target\n", + " sentence tokenized with the bundled tokenizer, ending with the model's text\n", + " EOS id. Decode step `k` consumes `text_tokens[k]`; once exhausted the channel\n", + " is masked off (matching the reference `... encode(transcript) + [eos_id]`).\n", "\n", "`prompt_token_ids = [0] * prompt_len` are placeholders (the model feeds the\n", "backbone via `inputs_embeds`, never these ids).
`prompt_len` must equal the\n", @@ -367,17 +280,21 @@ "from easymagpie_vllm_omni.easymagpie import EasyMagpieTTSForConditionalGeneration\n", "\n", "# Speaker-encoded context audio (audio branch of prepare_context_tensors),\n", - "# produced by easy_magpietts_extract_speaker_encoding.py.\n", - "SPEAKER_EMB_FILE = \"eng_speaker_emb.pt\"\n", - "_loaded = torch.load(SPEAKER_EMB_FILE, map_location=\"cpu\")\n", + "# pre-computed by the converter into MODEL_DIR/speaker_embeddings/<speaker_name>.pt.\n", + "SPEAKER_NAME = \"eng\"\n", + "_loaded = torch.load(MODEL_DIR / \"speaker_embeddings\" / f\"{SPEAKER_NAME}.pt\", map_location=\"cpu\")\n", "speaker_embedding = _loaded[\"speaker_encoding\"] if isinstance(_loaded, dict) else _loaded\n", "speaker_embedding = speaker_embedding.to(torch.float32)\n", "\n", "# Plain conditioning string; the model tokenizes + embeds it in-engine.\n", "CONTEXT_TEXT = \"[EN]\"\n", "\n", - "# Same tokenizer the engine loads from MODEL_DIR — used to size the prefill\n", - "# placeholders so prompt_token_ids length matches the assembled context.\n", + "# Target sentence to synthesize.\n", + "TEXT = \"Hello, this is a test of the EasyMagpie text to speech model.\"\n", + "\n", + "# Same tokenizer the engine loads from MODEL_DIR. Used to (a) size the prefill\n", + "# placeholders so prompt_token_ids length matches the assembled context, and\n", + "# (b) tokenize the target sentence into the streaming text stream.\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)\n", "prompt_len = EasyMagpieTTSForConditionalGeneration.estimate_prompt_len(\n", " speaker_embedding,\n", @@ -386,8 +303,10 @@ " has_task_embedding=arch.num_task_embeddings > 0,\n", ")\n", "\n", - "# Streaming subword ids: one per decode step (step k consumes text_tokens[k]).\n", - "text_tokens = torch.randint(0, TEXT_VOCAB, (DECODE_STEPS,)).tolist()\n", + "# Streaming subword ids consumed one per decode step.
Mirrors the reference\n", + "# `encode(transcript) + [eos_id]` (no BOS; HF special tokens disabled so the ids\n", + "# index the baked text_embedding table directly).\n", + "text_tokens = tokenizer.encode(TEXT, add_special_tokens=False) + [TEXT_EOS_ID]\n", "\n", "additional_information = {\n", " \"speaker_embedding\": speaker_embedding, # (T_audio, embedding_dim) tensor\n", @@ -407,15 +326,16 @@ "\n", "print(f\"speaker_embedding : {tuple(speaker_embedding.shape)}\")\n", "print(f\"context_text : {CONTEXT_TEXT!r} -> {tokenizer.encode(CONTEXT_TEXT)}\")\n", + "print(f\"text : {TEXT!r}\")\n", + "print(f\"text_tokens (len {len(text_tokens):3d}) : {text_tokens[:8]}{' ...' if len(text_tokens) > 8 else ''}\")\n", "print(f\"prompt_len (placeholders) : {prompt_len}\")\n", "print(f\"decode steps (max_tokens) : {DECODE_STEPS}\")\n", - "print(f\"text_tokens[:8] : {text_tokens[:8]}\")\n", "\n", "sampling_params = SamplingParams(\n", - " temperature=0.0,\n", + " temperature=0.0, # backbone token sampler is a no-op (audio is sampled in the local transformer)\n", " max_tokens=DECODE_STEPS,\n", " detokenize=False,\n", - " ignore_eos=True, # dummy logits never emit a meaningful EOS -> run the full budget\n", + " ignore_eos=True, # audio EOS lives in the codes, not the vLLM token stream -> run the budget + trim\n", ")" ] }, @@ -431,7 +351,8 @@ "`multimodal_output[\\\"audio_codes\\\"]` holds one row per flat-batch token over the\n", "whole run (the `T_ctx` prefill frames — codes left zero — plus one frame per\n", "decode step), so we trim to the last `len(token_ids)` rows to recover just the\n", - "decoded frames." + "decoded frames, then trim again at the audio EOS frame (the model signals\n", + "end-of-speech in the codes, not in the vLLM token stream)." 
] }, { @@ -471,6 +392,17 @@ " # frames (the last len(token_ids) rows), exactly like the eartts demo.\n", " if len(token_ids) > 0:\n", " audio_codes = audio_codes[-len(token_ids):].contiguous()\n", + "\n", + " # Trim at the audio EOS: the model signals end-of-speech inside the codes\n", + " # (codebook 0 == audio_eos_id), not via the vLLM token stream.\n", + " eos_frames = (audio_codes[:, 0] == arch.audio_eos_id).nonzero(as_tuple=True)[0]\n", + " if eos_frames.numel() > 0:\n", + " eos_idx = int(eos_frames[0])\n", + " print(f\"audio EOS at frame : {eos_idx} / {audio_codes.shape[0]}\")\n", + " audio_codes = audio_codes[:eos_idx].contiguous()\n", + " else:\n", + " print(f\"no audio EOS within budget ({DECODE_STEPS} frames); using full decode\")\n", + "\n", " print(f\"audio_codes shape (decode) : {tuple(audio_codes.shape)}\")\n", " print(f\"audio_codes dtype : {audio_codes.dtype}\")\n", " print(f\"codes min / max : {int(audio_codes.min())} / {int(audio_codes.max())}\")\n", @@ -493,13 +425,124 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "a32b07d5", + "metadata": {}, + "source": [ + "## 5. Decode audio codes to a waveform\n", + "\n", + "The engine emits **stacked** codebooks: `audio_codes` is `(T, C*S)` where\n", + "`C = num_audio_codebooks` and `S = frame_stacking_factor` (here `C*S = 16`).\n", + "To turn them back into a waveform we mirror what\n", + "[`EasyMagpieTTSInferenceModel.streaming_finalize`](../../../nemo/collections/tts/models/easy_magpietts_inference.py)\n", + "does at the end of inference:\n", + "\n", + "1. **load the codec** — the `.nemo` audio codec used to train the model\n", + " (`AudioCodecModel.restore_from(...)`, discriminator dropped to save memory),\n", + "2. **unstack** `(T, C*S)` -> `(1, C, T*S)` — the inverse of `stack_codes`,\n", + "3. 
**convert codec tokens** — this checkpoint was trained on a *regrouped* FSQ\n", + " index space (8 codebooks of size 1024) that differs from the codec's native\n", + " `GroupFiniteScalarQuantizer` (4 codebooks), so we map the model's tokens back\n", + " to the codec's space (`convert_new_to_original`) before decoding. The\n", + " `vector_quantizer` config is read straight from the source EasyMagpie `.nemo`\n", + " (config only, no weights). If the two spaces match, this step is skipped.\n", + "4. **decode** `codec_model.decode(tokens=..., tokens_len=...)` -> waveform at\n", + " `codec_model.output_sample_rate`.\n", + "\n", + "> Set `CODEC_MODEL_PATH` / `EASYMAGPIE_NEMO` to the **same** `.nemo` files passed\n", + "> to `easy_magpietts_convert_to_vllm.py` (`--codec_model_path` / `--nemo_file`).\n", + "> This step needs NeMo importable in the current environment." + ] + }, { "cell_type": "code", "execution_count": null, - "id": "3a6603b9", + "id": "aa57a573", "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from hydra.utils import instantiate\n", + "from IPython.display import Audio, display\n", + "\n", + "from nemo.collections.tts.models import AudioCodecModel\n", + "from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel\n", + "from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter\n", + "\n", + "# Same .nemo codec passed to easy_magpietts_convert_to_vllm.py --codec_model_path.\n", + "CODEC_MODEL_PATH = \"/home/vklimkov/workspace/emp/ckpt/easymagpietts_NEXT/25fps_spectral_codec_with_bandwidth_extension.nemo\"\n", + "\n", + "# --- load the codec once (drop the discriminator to save memory) ---\n", + "_codec_cfg = AudioCodecModel.restore_from(CODEC_MODEL_PATH, return_config=True)\n", + "if \"use_scl_loss\" in _codec_cfg:\n", + " _codec_cfg.use_scl_loss = False\n", + "codec_model = AudioCodecModel.restore_from(\n", + " CODEC_MODEL_PATH, strict=False, 
override_config_path=_codec_cfg\n", + ")\n", + "if hasattr(codec_model, \"discriminator\"):\n", + " del codec_model.discriminator\n", + "codec_model = codec_model.eval().to(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "codec_device = next(codec_model.parameters()).device\n", + "\n", + "# Source EasyMagpie .nemo (--nemo_file). Only its config is read (no weights),\n", + "# to recover the `vector_quantizer` override the model was trained with.\n", + "EASYMAGPIE_NEMO = \"/home/vklimkov/workspace/emp/ckpt/easymagpietts_NEXT/2605_NemotronTTS_V0.2/v2/2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo\"\n", + "\n", + "# --- optional codec-token converter ---------------------------------------\n", + "# EasyMagpie can be trained on a *regrouped* FSQ index space (here C=8 codebooks\n", + "# of size 1024) that differs from the codec's native quantizer (this codec's\n", + "# GroupFiniteScalarQuantizer has 4 codebooks). When they differ the model's\n", + "# tokens must be mapped back to the codec's space before decoding, exactly as\n", + "# EasyMagpieTTSInferenceModel does via `_codec_converter` /\n", + "# `CodecHelper.codes_to_audio`.\n", + "_em_cfg = EasyMagpieTTSInferenceModel.restore_from(EASYMAGPIE_NEMO, return_config=True)\n", + "_vq_cfg = _em_cfg.get(\"vector_quantizer\")\n", + "if _vq_cfg is not None and instantiate(_vq_cfg).num_codebooks != codec_model.vector_quantizer.num_codebooks:\n", + " codec_converter = VectorQuantizerIndexConverter(\n", + " vector_quantizer_original=codec_model.vector_quantizer,\n", + " vector_quantizer_new=instantiate(_vq_cfg),\n", + " ).to(codec_device)\n", + "else:\n", + " codec_converter = None\n", + "print(f\"codec native codebooks : {codec_model.vector_quantizer.num_codebooks}\")\n", + "print(f\"codec token converter : {'enabled' if codec_converter is not None else 'not needed'}\")\n", + "\n", + "S = arch.frame_stacking_factor # stacking factor (sub-frames per stacked frame)\n", + "C = arch.num_stacked_codebooks // S # 
real codec codebooks\n", + "assert audio_codes.dim() == 2 and audio_codes.size(1) == arch.num_stacked_codebooks, (\n", + " f\"expected audio_codes (T, {arch.num_stacked_codebooks}); got {tuple(audio_codes.shape)}\"\n", + ")\n", + "\n", + "# --- unstack (T, C*S) -> (1, C, T*S): inverse of EasyMagpie stack_codes ---\n", + "stacked = audio_codes.to(codec_device, torch.long).T.unsqueeze(0) # (1, C*S, T)\n", + "T_out = stacked.size(-1)\n", + "codes = stacked.view(1, C, S, T_out).permute(0, 1, 3, 2).reshape(1, C, T_out * S) # (1, C, T*S)\n", + "codes_len = torch.tensor([codes.size(-1)], device=codec_device, dtype=torch.long)\n", + "\n", + "# Pad very short sequences (codec needs a few frames), matching _prepare_codes_for_decode.\n", + "MIN_LEN = 4\n", + "if int(codes_len.min()) < MIN_LEN:\n", + " codes = torch.nn.functional.pad(codes, (0, MIN_LEN - int(codes_len.min())), value=0)\n", + " codes_len = codes_len.clamp(min=MIN_LEN)\n", + "\n", + "# Drop any stray special tokens (BOS/EOS/MASK live at codebook_size..) 
so every\n", + "# index is a valid codec entry before decoding.\n", + "codes = codes.clamp_(0, arch.codebook_size - 1)\n", + "\n", + "# --- decode codes -> waveform (mirrors CodecHelper.codes_to_audio) ---\n", + "with torch.no_grad(), torch.autocast(device_type=codec_device.type, dtype=torch.float32):\n", + " if codec_converter is not None:\n", + " codes = codec_converter.convert_new_to_original(audio_tokens=codes, audio_lens=codes_len)\n", + " audio, audio_len = codec_model.decode(tokens=codes, tokens_len=codes_len)\n", + "\n", + "waveform = audio[0, : int(audio_len[0])].detach().cpu().float().numpy()\n", + "sample_rate = int(codec_model.output_sample_rate)\n", + "\n", + "print(f\"codes (unstacked) shape : {tuple(codes.shape)} (1, C={C}, T*S={codes.size(-1)})\")\n", + "print(f\"waveform samples : {waveform.shape[0]} ({waveform.shape[0] / sample_rate:.2f}s @ {sample_rate} Hz)\")\n", + "\n", + "display(Audio(waveform, rate=sample_rate))" + ] } ], "metadata": { diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index cb5e8bd346f4..fd0ff1b79f13 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -665,11 +665,14 @@ def _preprocess_decode( ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: decode_offset = int(info_dict.get("ear_decode_offset", 0) or 0) - # Text channel (streaming list that grows by one per step). + # Text channel (streaming list, one subword consumed per step). Step k + # consumes text_tokens[k] (the list ends with the text eos id). Once the + # stream is exhausted the channel is masked off (adds nothing) — matching + # the reference ``text_finished`` behaviour, which stops adding text after + # EOS rather than repeating the last token. 
text_tokens = info_dict.get("text_tokens") - if isinstance(text_tokens, list) and text_tokens: - idx = min(decode_offset, len(text_tokens) - 1) - self._dec_text_tokens[start] = int(text_tokens[idx]) + if isinstance(text_tokens, list) and decode_offset < len(text_tokens): + self._dec_text_tokens[start] = int(text_tokens[decode_offset]) self._dec_text_mask[start] = 1 else: self._dec_text_mask[start] = 0 @@ -782,6 +785,14 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: target.data.copy_(tensor.to(target.dtype)) loaded.add(mapped) + # ``NemotronHModel.load_weights`` (the inner model) does *not* apply the + # HF->vLLM renaming that lives on the ``NemotronHForCausalLM`` wrapper, so + # raw HF names such as ``embeddings.weight`` / ``...mixer.A_log`` would not + # match the inner param names (``embed_tokens.weight`` / ``...mixer.A``). + # Apply that mapper here so the converted checkpoint can keep stock HF + # Nemotron-H names. The wrapper's ``backbone -> model`` prefix rule is a + # no-op here because we already stripped the ``decoder.`` prefix. 
+ backbone_weights = list(NemotronHForCausalLM.hf_to_vllm_mapper.apply(backbone_weights)) backbone_loaded = self.backbone.load_weights(backbone_weights) loaded |= {f"backbone.{n}" for n in backbone_loaded} From 85be1282da9d5d4fe3010831f1e1e83163d2a344 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 2 Jun 2026 18:20:44 +0200 Subject: [PATCH 07/15] examples/tts/easymagpie_vllm_omni: clean up, add readme Signed-off-by: Viacheslav Klimkov --- examples/tts/easymagpie_vllm_omni/README.md | 27 +++ .../easymagpie_vllm_omni/easymagpie.py | 166 ++++++++---------- .../easymagpie_vllm_omni/local_transformer.py | 90 ++++------ 3 files changed, 134 insertions(+), 149 deletions(-) create mode 100644 examples/tts/easymagpie_vllm_omni/README.md diff --git a/examples/tts/easymagpie_vllm_omni/README.md b/examples/tts/easymagpie_vllm_omni/README.md new file mode 100644 index 000000000000..27cc34a08d34 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/README.md @@ -0,0 +1,27 @@ +WIP model definition of EasyMP for vllm-omni. Follows footsteps of qwen3tts: +backbone and LT are compiled into a single cuda graph during uniform batch decoding, +piecewise during mixed/prefill. 
+ +Install: +``` +pip install -e ".[all]" +pip install ninja mamba_ssm causal_conv1d --no-build-isolation +# install vllm +pip install vllm==0.21.0 vllm_omni==0.21.0rc1 +# register vllm models +pip install -e examples/tts/easymagpie_vllm_omni/ +``` + +Convert the checkpoint from +https://huggingface.co/nvidia/easymagpietts_NEXT/tree/main/2605_NemotronTTS_V0.2/v2 +``` +python examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py \ + --nemo_file /2605_EMTTS_SmallMamba_Step150k_posttrained_epoch12.nemo \ + --codec_model_path /25fps_spectral_codec_with_bandwidth_extension.nemo \ + --outdir examples/tts/easymagpie_vllm_omni/easymp_vllm_model \ + --context_audio english_sample.wav --speaker_name eng \ + --phoneme_tokenizer_path /bpe_ipa_tokenizer_2048_en_de_es_fr_hi_it_vi_zh_ko-KR_pt-BR_ar.json +``` + +Finally run notebook `examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb` +to predict acoustic tokens \ No newline at end of file diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index fd0ff1b79f13..c1d173a92402 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -13,53 +13,46 @@ # limitations under the License. """Inference-only EasyMagpieTTS model for vLLM-Omni. -EasyMagpieTTS is a decoder-only streaming TTS model: a text-LM backbone (the -SmallMamba checkpoint uses a Nemotron-H hybrid Mamba2 + attention + MoE decoder) -consumes a per-frame additive input embedding (text + phoneme + audio) and -emits a per-frame hidden state, from which a small autoregressive *local -transformer* samples all ``C * S`` stacked audio codebooks for that frame -(see :mod:`easymagpie_vllm_omni.local_transformer`). +EasyMagpieTTS is a decoder-only streaming TTS model.
A Nemotron-H hybrid +(Mamba2 + attention + MoE) text-LM backbone consumes a per-frame additive input +embedding (text + phoneme + audio) and emits a per-frame hidden state. A small +autoregressive *local transformer* then samples all ``C * S`` stacked audio +codebooks for that frame (see :mod:`easymagpie_vllm_omni.local_transformer`). This module wires that architecture into vLLM-Omni's ``preprocess`` / ``forward`` / ``compute_logits`` / ``make_omni_output`` / -``postprocess`` contract, following the same conventions as the upstream -qwen3-tts and eartts vLLM-Omni model definitions: +``postprocess`` contract: * **Backbone** — vLLM's - :class:`~vllm.model_executor.models.nemotron_h.NemotronHModel`, reused - wholesale (hybrid Mamba2 state + KV cache + paged attention) exactly like the - EasyMagpie vLLM *sidecar*. Every step feeds the backbone via ``inputs_embeds``; - its own ``embed_tokens`` table is never consumed. Because the backbone is a - hybrid-Mamba model, the class implements vLLM's - :class:`HasInnerState` / :class:`IsHybrid` / :class:`SupportsMambaPrefixCaching` - contracts (mamba-state shape/dtype/copy helpers are delegated to - :class:`NemotronHForCausalLM`), and the SmallMamba SiLU shared-experts fix is + :class:`~vllm.model_executor.models.nemotron_h.NemotronHModel` is reused + wholesale (hybrid Mamba2 state + KV cache + paged attention). Every step feeds + the backbone via ``inputs_embeds``; its own ``embed_tokens`` table is never + consumed. Because the backbone is a hybrid-Mamba model, the class implements + vLLM's :class:`HasInnerState` / :class:`IsHybrid` / + :class:`SupportsMambaPrefixCaching` contracts (mamba-state helpers are + delegated to :class:`NemotronHForCausalLM`), and a SiLU shared-experts fix is applied at construction (see :mod:`easymagpie_vllm_omni.backbone_patches`). -* **Local transformer** — :class:`EasyMagpieCodePredictor`, a from-scratch, - CUDA-graph-capturable re-implementation that runs as a single compiled graph. 
-* **compute_logits** — returns trivial logits (à la eartts) so vLLM's sampler - always picks index 0; the real audio output is the codes tensor surfaced - through :meth:`make_omni_output` under the ``"audio_codes"`` key. +* **Local transformer** — :class:`EasyMagpieCodePredictor`, a + CUDA-graph-capturable implementation that runs as a single compiled graph. +* **compute_logits** — returns trivial logits so vLLM's sampler always picks + index 0; the real audio output is the codes tensor surfaced through + :meth:`make_omni_output` under the ``"audio_codes"`` key. Text is embedded via a precomputed per-subword lookup table baked at -checkpoint-conversion time (the reference char-aware subword encoder is -deterministic per subword id, so it is never run inside the engine). +checkpoint-conversion time, so the char-aware subword encoder is never run +inside the engine. Per-request I/O (via ``additional_information``): -* ``speaker_embedding`` (prefill only) — ``(T_audio, embedding_dim)`` speaker- - encoded context-audio embedding (the audio branch of the reference - ``prepare_context_tensors``), e.g. the tensor saved by - ``easy_magpietts_extract_speaker_encoding.py``. ``preprocess`` assembles the - full prefill context embedding itself as - ``[task_embedding | speaker_embedding | context_text_embedded]`` — the same - layout the reference model builds — so the caller only does the speaker-encoder - math and passes plain context text (the model tokenizes + embeds it and - prepends the per-mode service token). +* ``speaker_embedding`` (prefill only) — ``(T_audio, embedding_dim)`` + speaker-encoded context-audio embedding. ``preprocess`` assembles the full + prefill context embedding itself as + ``[task_embedding | speaker_embedding | context_text_embedded]``, so the + caller only does the speaker-encoder math and passes plain context text (the + model tokenizes + embeds it and prepends the per-mode service token). 
* ``context_text`` (prefill only, optional) — plain conditioning string (e.g. ``"[EN]"``); tokenized in-model with the checkpoint's text tokenizer and - embedded through the baked per-subword ``text_embedding`` table. Defaults to - ``"[NO TEXT CONTEXT]"`` when omitted. + embedded through the baked per-subword ``text_embedding`` table. * ``task_mode_id`` (prefill only, optional) — int selecting the per-mode task ("service token") embedding row; defaults to ``0``. Ignored for single-mode checkpoints (no ``task_embedding`` table). @@ -111,12 +104,10 @@ _DEFAULT_CONTEXT_TEXT = "[EN]" -# NOTE: unlike the Qwen2 backbone variant, this class is *not* wrapped in -# ``@support_torch_compile``. The Nemotron-H backbone is a hybrid-Mamba model -# that manages its own ``torch.compile`` / CUDA-graph capture internally (as -# does :class:`EasyMagpieCodePredictor`), so the outer ``forward`` runs eagerly -# and dispatches into the two self-compiled subgraphs — matching the EasyMagpie -# vLLM sidecar (``EasyMagpieSmallMamba``). +# This class is not wrapped in ``@support_torch_compile``: the Nemotron-H +# backbone and :class:`EasyMagpieCodePredictor` each manage their own +# ``torch.compile`` / CUDA-graph capture internally, so the outer ``forward`` +# runs eagerly and dispatches into the two self-compiled subgraphs. class EasyMagpieTTSForConditionalGeneration( nn.Module, HasInnerState, @@ -131,8 +122,8 @@ class EasyMagpieTTSForConditionalGeneration( ``OmniGPUModelRunner``. """ - # Hybrid-Mamba bookkeeping (delegated to vLLM's NemotronH causal-LM, exactly - # like the EasyMagpie sidecar). vLLM expects these as class attributes. + # Hybrid-Mamba bookkeeping (delegated to vLLM's NemotronH causal-LM). vLLM + # expects these as class attributes. 
get_mamba_state_dtype_from_config = NemotronHForCausalLM.get_mamba_state_dtype_from_config get_mamba_state_shape_from_config = NemotronHForCausalLM.get_mamba_state_shape_from_config get_mamba_state_copy_func = NemotronHForCausalLM.get_mamba_state_copy_func @@ -167,9 +158,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: vllm_config=vllm_config, prefix=maybe_prefix(prefix, "backbone"), ) - # SmallMamba was trained with mlp_hidden_act=silu but vLLM's NemotronHMLP - # hard-codes ReLU² in shared_experts. Restore SiLU (no-op when the - # backbone has no MoE layers). + # The checkpoint was trained with mlp_hidden_act=silu but vLLM's + # NemotronHMLP hard-codes ReLU² in shared_experts. Restore SiLU (no-op + # when the backbone has no MoE layers). patch_silu_shared_experts(self.backbone) # ── Local transformer (its own compile group / CUDA graph) ────── @@ -180,26 +171,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: ) # ── Text + phoneme embedding heads ────────────────────────────── - # Precomputed per-subword text embedding. The reference model embeds - # text with a char-aware subword (CAS) encoder + the decoder's subword - # table; both are deterministic per subword id, so the checkpoint - # converter bakes their combined result into this single lookup table - # (one row per subword id). It is fed additively on every decode step; - # the CAS encoder is never run inside the engine. + # Precomputed per-subword text embedding (one row per subword id), baked + # at conversion time and fed additively on every decode step. text_vocab_size = int(getattr(hf_config, "text_vocab_size", getattr(hf_config, "vocab_size", 0))) self.text_embedding = nn.Embedding(text_vocab_size, self.embedding_dim) - # Task ("service token") embedding — a single learned per-mode row the - # reference model prepends to the prefill context when trained with >1 - # mode. Built only when the checkpoint carries one; otherwise ``None``. 
+ # Task ("service token") embedding — a single learned per-mode row + # prepended to the prefill context for multi-mode checkpoints. Built only + # when the checkpoint carries one; otherwise ``None``. self.num_task_embeddings = int(arch.num_task_embeddings) if self.num_task_embeddings > 0: self.task_embedding = nn.Embedding(self.num_task_embeddings, self.embedding_dim) else: self.task_embedding = None - # Context-text tokenizer, loaded lazily from the model directory (same - # ``AutoTokenizer.from_pretrained(model_path)`` pattern as qwen3-tts). It + # Context-text tokenizer, loaded lazily from the model directory. It # turns the per-request ``context_text`` string (e.g. ``"[EN]"``) into the # subword ids that the baked ``text_embedding`` table consumes — so the # caller passes plain text, never pre-tokenized ids. @@ -295,8 +281,6 @@ def _select_query_layout(attn_metadata): def _get_decode_idxs(self): """Return ``(decode_token_indices, num_requests)`` for code-predictor dispatch. - Mirrors the qwen3-tts / eartts pattern: - * ``(None, 0)`` → run the local transformer on every token (profile / dummy run with no ``attn_metadata``, or a decode-only batch where ``max_query_len == 1``), so the captured CUDA graph covers every @@ -434,9 +418,9 @@ def _predict_phonemes(self, hidden_states: torch.Tensor, idx) -> None: def compute_logits(self, hidden_states, sampling_metadata: Any = None) -> Optional[torch.Tensor]: """Return zero logits so vLLM's sampler always picks index 0. - The width is taken from ``hf_config.vocab_size`` so the sampler's - working buffers match. The sampled id is irrelevant — audio is surfaced - via :meth:`make_omni_output`. + The width is taken from ``hf_config.vocab_size`` so the sampler's working + buffers match. The sampled id is irrelevant — audio is surfaced via + :meth:`make_omni_output`. 
""" if isinstance(hidden_states, OmniOutput): hidden_states = hidden_states.text_hidden_states @@ -520,7 +504,7 @@ def _preprocess_prefill( ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: prefill_embeds = self._build_prefill_embeds(device, info_dict) - offset = int(info_dict.get("ear_prefill_offset", 0) or 0) + offset = int(info_dict.get("prefill_offset", 0) or 0) total = int(prefill_embeds.shape[0]) s = max(0, min(offset, total)) e = max(0, min(offset + span_len, total)) @@ -535,8 +519,8 @@ def _preprocess_prefill( take = torch.cat([take, pad_rows], dim=0) info_update = { - "ear_prefill_offset": offset + span_len, - "ear_decode_offset": 0, + "prefill_offset": offset + span_len, + "decode_offset": 0, } input_ids_out = torch.full_like(input_ids, _DUMMY_TOKEN_ID) return input_ids_out, take, info_update @@ -546,23 +530,17 @@ def _build_prefill_embeds( device: torch.device, info_dict: dict[str, Any], ) -> torch.Tensor: - """Assemble the full ``(T_ctx, embedding_dim)`` prefill context embedding. - - Reproduces the prefill assembly from the reference - ``prepare_context_tensors``:: + """Assemble the full ``(T_ctx, embedding_dim)`` prefill context embedding:: [task_embedding | speaker_embedding | context_text_embedded] from the per-request inputs: - * ``speaker_embedding`` — the speaker-encoded context-audio embedding - (e.g. produced by ``easy_magpietts_extract_speaker_encoding.py``), + * ``speaker_embedding`` — the speaker-encoded context-audio embedding, required as a 2-D ``(T_audio, embedding_dim)`` tensor. * ``context_text`` — a plain string (e.g. ``"[EN]"``); tokenized in-model (see :meth:`_encode_context_text`) and embedded through the baked - per-subword ``text_embedding`` table (which already folds in the CAS - encoder, matching the default ``disable_cas_for_context_text=False`` - training). Defaults to ``"[NO TEXT CONTEXT]"`` when omitted. + per-subword ``text_embedding`` table. 
* ``task_mode_id`` — selects the per-mode task ("service token") embedding row; prepended only when the checkpoint has a task table. @@ -602,9 +580,9 @@ def _build_prefill_embeds( def _get_text_tokenizer(self): """Lazily load the context-text tokenizer from the model directory. - Mirrors qwen3-tts: the converted checkpoint ships a HuggingFace - ``AutoTokenizer`` (the model's text-conditioning tokenizer) alongside its - weights, so we load it on first use from ``model_path``. + The converted checkpoint ships a HuggingFace ``AutoTokenizer`` (the + model's text-conditioning tokenizer) alongside its weights, so we load it + on first use from ``model_path``. """ if self._text_tokenizer is None: from transformers import AutoTokenizer @@ -613,12 +591,11 @@ def _get_text_tokenizer(self): return self._text_tokenizer def _encode_context_text(self, context_text: str, device: torch.device) -> torch.Tensor: - """Tokenize ``context_text`` to subword ids (matching the reference encode path). + """Tokenize ``context_text`` to subword ids. - The reference ``AggregatedTTSTokenizer.encode`` calls the underlying - HF tokenizer's ``encode`` (default ``add_special_tokens``) for the - text-conditioning tokenizer, which sits at offset 0 in the aggregate, so - its raw ids index the baked ``text_embedding`` table directly. + The text-conditioning tokenizer sits at offset 0 in the model's + tokenizer aggregate, so its raw ids index the baked ``text_embedding`` + table directly. """ tok = self._get_text_tokenizer() ids = tok.encode(context_text) @@ -632,8 +609,7 @@ def estimate_prompt_len( context_text: str = _DEFAULT_CONTEXT_TEXT, has_task_embedding: bool = False, ) -> int: - """Length-only mirror of :meth:`_build_prefill_embeds` (à la qwen3-tts's - ``estimate_prompt_len_from_additional_information``). + """Length-only mirror of :meth:`_build_prefill_embeds`. The engine assembles the prefill context as ``[task_embedding? 
| speaker_embedding | context_text_embedded]``, so the @@ -663,13 +639,12 @@ def _preprocess_decode( device: torch.device, info_dict: dict[str, Any], ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: - decode_offset = int(info_dict.get("ear_decode_offset", 0) or 0) + decode_offset = int(info_dict.get("decode_offset", 0) or 0) # Text channel (streaming list, one subword consumed per step). Step k # consumes text_tokens[k] (the list ends with the text eos id). Once the - # stream is exhausted the channel is masked off (adds nothing) — matching - # the reference ``text_finished`` behaviour, which stops adding text after - # EOS rather than repeating the last token. + # stream is exhausted the channel is masked off (adds nothing) rather than + # repeating the last token. text_tokens = info_dict.get("text_tokens") if isinstance(text_tokens, list) and decode_offset < len(text_tokens): self._dec_text_tokens[start] = int(text_tokens[decode_offset]) @@ -699,7 +674,7 @@ def _preprocess_decode( self._dec_audio_valid[start] = 1 inputs_embeds_out = torch.zeros((1, self.embedding_dim), device=device, dtype=self._combined_embeddings.dtype) - info_update = {"ear_decode_offset": decode_offset + 1} + info_update = {"decode_offset": decode_offset + 1} return input_ids, inputs_embeds_out, info_update def postprocess(self, hidden_states: torch.Tensor, multimodal_outputs: Optional[dict[str, Any]] = None, **_: Any): @@ -722,7 +697,7 @@ def postprocess(self, hidden_states: torch.Tensor, multimodal_outputs: Optional[ # weight loading # ------------------------------------------------------------------ - # Checkpoint prefixes (reference EasyMagpieTTS state dict) → in-model paths. + # Checkpoint prefixes (EasyMagpieTTS state dict) → in-model paths. # ``decoder.*`` is fed to the vLLM backbone loader separately (it understands # HF Nemotron-H naming + Mamba/MoE packing). The TTS submodules are copied # manually. 
@@ -749,13 +724,12 @@ def _remap_tts_key(self, name: str) -> Optional[str]: def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: """Load backbone (Nemotron-H) + TTS submodule weights from a converted checkpoint. - The converted checkpoint is expected to use the reference EasyMagpieTTS - key layout: the backbone under ``decoder.*`` (HF Nemotron-H names) and - the TTS submodules at top level (``audio_embeddings.*``, - ``local_transformer.*``, ``phoneme_*``, ``text_embedding.*``, projection - heads). Backbone weights are routed to :meth:`NemotronHModel.load_weights` - (which handles HF naming + Mamba/MoE packing); TTS weights are copied - directly by name. + The converted checkpoint carries the backbone under ``decoder.*`` (HF + Nemotron-H names) and the TTS submodules at top level + (``audio_embeddings.*``, ``local_transformer.*``, ``phoneme_*``, + ``text_embedding.*``, projection heads). Backbone weights are routed to + :meth:`NemotronHModel.load_weights` (which handles HF naming + Mamba/MoE + packing); TTS weights are copied directly by name. """ own_params = dict(self.named_parameters()) loaded: set[str] = set() diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py index b48715604530..aab5fe2f224b 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/local_transformer.py @@ -11,29 +11,24 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""From-scratch autoregressive local transformer for EasyMagpieTTS on vLLM-Omni. 
- -The reference EasyMagpieTTS model predicts the ``C * S`` stacked audio -codebooks of one frame *autoregressively* with a small causal transformer -(``nemo.collections.tts.modules.transformer_2501.Transformer``) conditioned on -the backbone's per-frame hidden state. The reference implementation re-creates -fresh tensors and (optionally) a KV cache on every codebook step, which is -incompatible with CUDA-graph replay. - -This module re-implements that local transformer from scratch so it can run as -a single compiled CUDA graph: - -* :class:`EasyMagpieLocalTransformer` mirrors the ``transformer_2501`` - layer/weight layout **exactly** (so a stock checkpoint loads 1:1) but uses - ``scaled_dot_product_attention`` and drops the KV cache / padding-mask - plumbing. It is decorated with ``@support_torch_compile`` so vLLM captures - one CUDA graph for the fixed ``(num_tokens, num_stacked_codebooks, hidden)`` - input shape. +"""Autoregressive local transformer for EasyMagpieTTS on vLLM-Omni. + +EasyMagpieTTS predicts the ``C * S`` stacked audio codebooks of one frame +*autoregressively* with a small causal transformer conditioned on the backbone's +per-frame hidden state. This module implements that local transformer so it can +run as a single compiled CUDA graph: + +* :class:`EasyMagpieLocalTransformer` is a causal transformer stack with + learnable positional embeddings, using ``scaled_dot_product_attention`` and no + KV cache. It is decorated with ``@support_torch_compile`` so vLLM captures one + CUDA graph for the fixed ``(num_tokens, num_stacked_codebooks, hidden)`` input + shape. Its layer/weight layout matches the training checkpoint so weights load + 1:1. 
* :class:`EasyMagpieCodePredictor` owns the persistent, address-stable scratch buffers and runs the per-frame autoregressive loop, re-invoking the compiled - transformer once per codebook over the **same** buffer (matching the - qwen3-tts code-predictor trick: replaying one fixed-shape graph N times is - faster and simpler than capturing N separate graphs). + transformer once per codebook over the **same** buffer (replaying one + fixed-shape graph N times is faster and simpler than capturing N separate + graphs). All sampling is CUDA-graph safe (Gumbel-max + ``topk`` + ``masked_fill`` only; no host syncs, no ``multinomial`` on possibly-degenerate warmup data). @@ -96,12 +91,11 @@ def sample_codebook( class EasyMagpieLTSelfAttention(nn.Module): - """Causal self-attention matching ``transformer_2501.SelfAttention`` weights. + """Causal self-attention. - Same projections (``qkv_net`` fused QKV without bias, ``o_net`` without - bias) and the same ``d_head ** -0.5`` scaling, but computed with - ``scaled_dot_product_attention`` and an ``is_causal=True`` flag instead of - the materialised causal-mask buffer + naive softmax. No KV cache: the + Fused QKV projection (``qkv_net``) and output projection (``o_net``), both + bias-free, with ``d_head ** -0.5`` scaling computed via + ``scaled_dot_product_attention`` with ``is_causal=True``. No KV cache: the autoregressive loop re-runs the full (short, fixed-length) sequence each step, which is what makes the whole thing CUDA-graph capturable. """ @@ -131,19 +125,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class EasyMagpieLTFeedForward(nn.Module): - """Positionwise FFN matching ``transformer_2501.PositionwiseConvFF`` weights. + """Positionwise feed-forward network. - The reference uses ``Conv1d(kernel_size=1)`` layers named ``proj.conv`` and - ``o_net.conv`` (no bias). 
A kernel-1 conv is a plain linear over the channel - dim, so we keep the exact ``Conv1d`` submodule names — the checkpoint loads - 1:1 — and apply them with a single transpose, GELU(tanh) in between. + Uses ``Conv1d(kernel_size=1)`` layers named ``proj.conv`` and ``o_net.conv`` + (no bias). A kernel-1 conv is a plain linear over the channel dim, applied + with a single transpose and GELU(tanh) in between. The ``Conv1d`` submodule + names match the training checkpoint so weights load 1:1. """ def __init__(self, d_model: int, d_ffn: int) -> None: super().__init__() - # Wrap the Conv1d in a tiny container so the parameter path is - # ``proj.conv.weight`` / ``o_net.conv.weight`` exactly as in the - # reference ``ConvolutionLayer``. self.proj = _Conv1dWrapper(d_model, d_ffn) self.o_net = _Conv1dWrapper(d_ffn, d_model) self.act = nn.GELU(approximate="tanh") @@ -170,10 +161,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class EasyMagpieLTLayer(nn.Module): """One pre-norm transformer layer (self-attn + FFN), bias-free LayerNorms. - Residual structure matches ``transformer_2501.TransformerLayer`` with an - all-ones ``x_mask`` (inference): ``x = x + attn(norm_self(x))`` then - ``x = x + ff(norm_pos_ff(x))``. The ``x * x_mask`` multiplications are - identities when nothing is padded, so they are dropped. + Residual structure: ``x = x + attn(norm_self(x))`` then + ``x = x + ff(norm_pos_ff(x))``. """ def __init__(self, d_model: int, d_ffn: int, n_heads: int) -> None: @@ -202,9 +191,8 @@ class EasyMagpieLocalTransformer(nn.Module): Decorated with ``@support_torch_compile`` so vLLM captures a single CUDA graph for the fixed ``(num_tokens, num_stacked_codebooks, d_model)`` input - shape. Weight layout mirrors ``transformer_2501.Transformer``: - ``position_embeddings`` (learnable), ``layers.{i}.*`` and a no-op - ``norm_out`` (``apply_norm_out=False`` in the reference, hence ``Identity``). + shape. 
Holds learnable ``position_embeddings``, the stacked ``layers.{i}.*`` + and a no-op ``norm_out``. """ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: @@ -214,15 +202,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: n_heads = arch.local_transformer_n_heads n_layers = arch.local_transformer_n_layers d_ffn = d_model * 4 - # +2 matches the reference ``max_length_causal_mask`` head-room - # (``num_stacked_codebooks + 2``). + # +2 of head-room over ``num_stacked_codebooks`` for the positional table. max_len = arch.num_stacked_codebooks + 2 self.position_embeddings = nn.Embedding(max_len, d_model) self.layers = nn.ModuleList( [EasyMagpieLTLayer(d_model, d_ffn, n_heads) for _ in range(n_layers)] ) - # apply_norm_out=False in the reference config -> no parameters. self.norm_out = nn.Identity() def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor: @@ -267,8 +253,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: lt_hidden = arch.local_transformer_hidden_dim # Per-codebook audio token embeddings (shared with the outer model's - # decode-step input-embedding assembly). Names match the reference - # checkpoint's ``audio_embeddings.{i}``. + # decode-step input-embedding assembly). self.audio_embeddings = nn.ModuleList( [nn.Embedding(self.num_tokens_per_codebook, self.audio_embedding_dim) for _ in range(self.num_codebooks)] ) @@ -321,9 +306,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: def init_forbidden_mask(self) -> None: """Forbid all trailing special tokens except audio EOS. - Mirrors ``SpecialAudioToken.get_forbidden_tokens`` — everything in the - special-token block above ``codebook_size`` is blocked at sampling - time, except ``audio_eos`` which must remain reachable to terminate. + Everything in the special-token block above ``codebook_size`` is blocked + at sampling time, except ``audio_eos`` which must remain reachable to + terminate. 
""" mask = torch.zeros(self.num_tokens_per_codebook, dtype=torch.bool, device=self.forbidden_mask.device) mask[self.arch.codebook_size :] = True @@ -339,10 +324,9 @@ def embed_codebook(self, codebook_idx: int, codes: torch.Tensor) -> torch.Tensor def embed_audio_frame(self, codes: torch.Tensor) -> torch.Tensor: """Embed a full frame of stacked codes into the backbone embedding space. - Averages per-codebook embeddings then applies ``audio_in_projection``, - matching the reference ``embed_audio_tokens`` (which sums and divides by - the number of codebooks). Used by the outer model to build the decode - input embedding from the previous frame's codes. + Averages the per-codebook embeddings then applies ``audio_in_projection``. + Used by the outer model to build the decode input embedding from the + previous frame's codes. Args: codes: ``[num_tokens, num_codebooks]`` int64 codes. From f984ee1850ade1152227dc51376726338d6f251c Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Tue, 2 Jun 2026 18:20:44 +0200 Subject: [PATCH 08/15] examples/tts/easymagpie_vllm_omni: implement delay and proper phoneme prediction processing Signed-off-by: Viacheslav Klimkov --- .../easy_magpietts_convert_to_vllm.py | 28 ++++ .../easymagpie_inference_demo.ipynb | 26 +++- .../easymagpie_vllm_omni/config.py | 38 +++++ .../easymagpie_vllm_omni/easymagpie.py | 131 +++++++++++++++--- 4 files changed, 198 insertions(+), 25 deletions(-) diff --git a/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py b/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py index 664a243d7415..4cb99a08baee 100644 --- a/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py +++ b/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py @@ -306,6 +306,34 @@ def build_config(model, vocab_size: int, torch_dtype: str) -> dict: has_phoneme = getattr(model, "phoneme_tokenizer", None) is not None config["phoneme_stacking_factor"] = int(getattr(model, 
"phoneme_stacking_factor", 0)) if has_phoneme else 0 config["phoneme_vocab_size"] = int(getattr(model, "phoneme_vocab_size", 0)) if has_phoneme else 0 + if has_phoneme: + # Phoneme special-token ids + the confidence→UNK replacement threshold, + # consumed by the in-engine phoneme stream (BOS seeding, EOS-stop, UNK). + config["phoneme_bos_id"] = int(model.phoneme_tokenizer.bos_token_id) + config["phoneme_eos_id"] = int(model.phoneme_tokenizer.eos_token_id) + unk_id = getattr(model.phoneme_tokenizer, "unk_token_id", None) + if unk_id is not None: + config["phoneme_unk_id"] = int(unk_id) + config["phoneme_confidence_unk_threshold"] = float(getattr(model, "phoneme_confidence_unk_threshold", 0.0)) + + # ── Streaming delays from the default inference mode ── + # The reference offsets the text/phoneme/audio streams by these per-mode + # delays; the vLLM model reproduces them in its decode step. A 0/0 (or "full") + # mode runs the three streams in lock-step. + default_mode = model.mode_name_to_mode.get(model.default_inference_mode) + if default_mode is not None: + if default_mode.text_input_mode != "streaming": + logging.warning( + "Converting a checkpoint whose default inference mode is " + f"'{default_mode.text_input_mode}' (not 'streaming'); the vLLM model only " + "implements the streaming-mode delay semantics (audio starts after " + "`streaming_speech_delay` text tokens)." 
+ ) + config["streaming_phonemes_delay"] = int(default_mode.streaming_phonemes_delay) + config["streaming_speech_delay"] = int(default_mode.streaming_speech_delay) + else: + config["streaming_phonemes_delay"] = 0 + config["streaming_speech_delay"] = 0 config["num_task_embeddings"] = len(model.training_modes) if model.task_embedding is not None else 0 diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb index 8cbcc17f4665..ebab7bc05532 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb @@ -192,18 +192,20 @@ " \"model_arch\": \"EasyMagpieTTSForConditionalGeneration\",\n", " \"worker_type\": \"ar\",\n", " \"scheduler_cls\": \"vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler\",\n", - " #\"enforce_eager\": True, # dummy run: skip CUDA-graph capture for a faster start\n", + " \"enforce_eager\": True, # dummy run: skip CUDA-graph capture for a faster start\n", " \"trust_remote_code\": True,\n", " \"async_scheduling\": True,\n", " \"enable_prefix_caching\": False,\n", " \"engine_output_type\": \"audio\",\n", - " \"gpu_memory_utilization\": 0.6,\n", + " \"gpu_memory_utilization\": 0.8,\n", " \"distributed_executor_backend\": \"uni\",\n", " \"max_num_batched_tokens\": MAX_NUM_BATCHED_TOKENS,\n", " \"max_model_len\": MAX_MODEL_LEN,\n", " # bf16 (not fp32): the Nemotron-H fused-MoE Triton kernel's block\n", " # sizes are tuned for 16-bit and overflow shared memory in fp32.\n", - " \"dtype\": \"bfloat16\",\n", + " #\"dtype\": \"bfloat16\",\n", + " \"dtype\": \"float16\",\n", + " \"mamba_ssm_cache_dtype\": \"float32\",\n", " \"attention_backend\": \"TRITON_ATTN\",\n", " # We feed prompt_token_ids + text_tokens directly; the model still\n", " # loads the bundled AutoTokenizer from MODEL_DIR for context_text.\n", @@ -308,10 +310,19 @@ "# index the baked text_embedding table 
directly).\n", "text_tokens = tokenizer.encode(TEXT, add_special_tokens=False) + [TEXT_EOS_ID]\n", "\n", + "# Audio (local-transformer) sampling params. vLLM's SamplingParams.temperature\n", + "# drives only the dummy backbone token sampler, so the *audio* temperature/top-k\n", + "# are forwarded via additional_information. temperature=0.0 == argmax\n", + "# (deterministic; matches the torch reference run with --temperature 0.0 --no_cfg).\n", + "LT_TEMPERATURE = 0.0\n", + "LT_TOPK = 80\n", + "\n", "additional_information = {\n", " \"speaker_embedding\": speaker_embedding, # (T_audio, embedding_dim) tensor\n", " \"context_text\": CONTEXT_TEXT, # plain string, tokenized in-model\n", " \"text_tokens\": text_tokens, # list[int], grows by one per step\n", + " \"temperature\": LT_TEMPERATURE, # audio sampling temperature (local transformer)\n", + " \"top_k\": LT_TOPK, # audio sampling top-k (local transformer)\n", "}\n", "\n", "prompt = {\n", @@ -393,6 +404,15 @@ " if len(token_ids) > 0:\n", " audio_codes = audio_codes[-len(token_ids):].contiguous()\n", "\n", + " # Drop the leading streaming_speech_delay warm-up frames. 
With the streaming\n", + " # delay the audio stream only opens at decode step == speech_delay, so the\n", + " # first speech_delay decoded frames carry no real audio (the audio channel was\n", + " # masked off while the model consumed lookahead text/phonemes).\n", + " speech_delay = int(getattr(arch, \"streaming_speech_delay\", 0) or 0)\n", + " if speech_delay > 0:\n", + " print(f\"dropping {speech_delay} leading speech-delay warm-up frames\")\n", + " audio_codes = audio_codes[speech_delay:].contiguous()\n", + "\n", " # Trim at the audio EOS: the model signals end-of-speech inside the codes\n", " # (codebook 0 == audio_eos_id), not via the vLLM token stream.\n", " eos_frames = (audio_codes[:, 0] == arch.audio_eos_id).nonzero(as_tuple=True)[0]\n", diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py index 1b086ec3e562..b6b76e97be0f 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/config.py @@ -72,6 +72,23 @@ class EasyMagpieOmniArch: phoneme_stacking_factor: int = 1 phoneme_vocab_size: int = 2051 + # ── Streaming delays (per the checkpoint's default inference mode) ── + # The text/phoneme/audio streams are temporally offset: at decode step ``k`` + # the text channel consumes ``text_tokens[k]``, the phoneme channel starts at + # ``k == streaming_phonemes_delay`` (seeded with phoneme BOS), and the audio + # channel starts at ``k == streaming_speech_delay`` (seeded with audio BOS). + # Both default to 0 (lock-step), which reproduces a non-delayed / "full" mode. + streaming_phonemes_delay: int = 0 + streaming_speech_delay: int = 0 + + # Phoneme special-token ids (into the per-stack ``phoneme_embeddings`` table) + # and the confidence→UNK replacement threshold. ``None`` falls back to the + # IPABPETokenizer convention (bos/eos/unk = vocab-3/-2/-1). 
+ phoneme_bos_id: int | None = None + phoneme_eos_id: int | None = None + phoneme_unk_id: int | None = None + phoneme_confidence_unk_threshold: float = 0.0 + # Number of multi-mode task ("service token") embeddings. The reference model # prepends a single learned per-mode embedding to the prefill context when # trained with >1 mode (``cfg.training_modes``); 0 disables it (single-mode @@ -122,6 +139,21 @@ def mask_token_id(self) -> int: return self.forced_mask_token_id return self.codebook_size + SPECIAL_AUDIO_MASK + @property + def resolved_phoneme_bos_id(self) -> int: + """Phoneme BOS id, falling back to the IPABPETokenizer convention (vocab-3).""" + return self.phoneme_bos_id if self.phoneme_bos_id is not None else self.phoneme_vocab_size - 3 + + @property + def resolved_phoneme_eos_id(self) -> int: + """Phoneme EOS id, falling back to the IPABPETokenizer convention (vocab-2).""" + return self.phoneme_eos_id if self.phoneme_eos_id is not None else self.phoneme_vocab_size - 2 + + @property + def resolved_phoneme_unk_id(self) -> int: + """Phoneme UNK id, falling back to the IPABPETokenizer convention (vocab-1).""" + return self.phoneme_unk_id if self.phoneme_unk_id is not None else self.phoneme_vocab_size - 1 + @classmethod def from_hf_config(cls, hf_config: Any) -> "EasyMagpieOmniArch": """Build an arch description from a vLLM ``hf_config``. 
@@ -142,6 +174,12 @@ def from_hf_config(cls, hf_config: Any) -> "EasyMagpieOmniArch": "frame_stacking_factor", "phoneme_stacking_factor", "phoneme_vocab_size", + "streaming_phonemes_delay", + "streaming_speech_delay", + "phoneme_bos_id", + "phoneme_eos_id", + "phoneme_unk_id", + "phoneme_confidence_unk_threshold", "num_task_embeddings", "local_transformer_n_layers", "local_transformer_n_heads", diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index c1d173a92402..8f01eb886bb8 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -62,8 +62,19 @@ * ``text_tokens`` — Python ``list[int]`` of subword ids that grows by one per decode step; step ``k`` consumes ``text_tokens[k]`` (embedded through the precomputed per-subword table). -* ``phoneme_tokens`` (optional) — same streaming-list contract for the phoneme - channel; if omitted the phoneme branch is skipped. +* ``temperature`` / ``top_k`` (prefill only, optional) — audio sampling params + for the local transformer. vLLM's ``SamplingParams.temperature`` drives only + the dummy backbone token sampler, so the *audio* temperature/top-k are passed + here and applied to the code predictor (defaults: ``0.7`` / ``80``). + +Streaming delays: the text, phoneme and audio streams are temporally offset by +the checkpoint's ``streaming_phonemes_delay`` / ``streaming_speech_delay`` (baked +into ``config.json`` by the converter from the default inference mode). The text +stream runs from decode step 0; the phoneme stream opens at step +``phonemes_delay`` (seeded with phoneme BOS) and the audio stream at step +``speech_delay`` (seeded with audio BOS). The leading ``speech_delay`` decoded +frames are warm-up only and must be dropped by the caller. Delays of 0/0 +reproduce a lock-step / non-delayed model. 
""" from __future__ import annotations @@ -191,6 +202,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # caller passes plain text, never pre-tokenized ids. self._text_tokenizer: Any = None + # ── Streaming delays (text leads phoneme by ``phonemes_delay`` and audio + # by ``speech_delay`` decode steps; 0/0 == lock-step). ── + self.phonemes_delay = int(getattr(arch, "streaming_phonemes_delay", 0) or 0) + self.speech_delay = int(getattr(arch, "streaming_speech_delay", 0) or 0) + # Phoneme channel (optional — only built when the checkpoint has one). self.has_phoneme = arch.phoneme_vocab_size > 0 and arch.phoneme_stacking_factor > 0 if self.has_phoneme: @@ -200,6 +216,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.phoneme_final_proj = nn.Linear( self.hidden_dim, arch.phoneme_vocab_size * arch.phoneme_stacking_factor ) + # Phoneme special-token ids + confidence→UNK replacement threshold. + self.phoneme_bos_id = int(arch.resolved_phoneme_bos_id) + self.phoneme_eos_id = int(arch.resolved_phoneme_eos_id) + self.phoneme_unk_id = int(arch.resolved_phoneme_unk_id) + self.phoneme_confidence_unk_threshold = float(arch.phoneme_confidence_unk_threshold) # ── Persistent, address-stable scratch buffers ───────────────── max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens @@ -401,14 +422,34 @@ def _assemble_decode_embeddings(self, combined: torch.Tensor, idx) -> None: @torch.no_grad() def _predict_phonemes(self, hidden_states: torch.Tensor, idx) -> None: - """Argmax the phoneme head and stash the prediction for the next step.""" + """Argmax the phoneme head (with confidence→UNK replacement) and stash it. + + The UNK replacement mirrors the reference: when the max phoneme + probability of any stacked channel falls below + ``phoneme_confidence_unk_threshold`` (and the step is not an EOS step), + the whole step is replaced with the UNK id to curb error propagation. 
+ + This is done here — not in ``preprocess``/``postprocess`` — because this + is the only place the phoneme logits exist (preprocess has no logits, and + postprocess only sees the argmax id). It uses only elementwise ops + + ``torch.where`` (no ``.item()`` / host sync), so it stays CUDA-graph safe. + """ # Run in the model dtype (don't force fp32): ``phoneme_final_proj`` weights # follow ``model_config.dtype`` (e.g. bf16), and argmax is dtype-insensitive, # so an fp32 upcast here would mismatch the weight dtype in ``F.linear``. logits = self.phoneme_final_proj(hidden_states[idx]) s = self.arch.phoneme_stacking_factor logits = logits.view(-1, s, self.arch.phoneme_vocab_size) - self._dec_phoneme_tokens[idx] = logits.argmax(dim=-1).long() + preds = logits.argmax(dim=-1).long() # (n, S) + + if self.phoneme_confidence_unk_threshold > 0.0: + max_probs = torch.softmax(logits.float(), dim=-1).amax(dim=-1) # (n, S) + underconfident = (max_probs < self.phoneme_confidence_unk_threshold).any(dim=1, keepdim=True) + eos_step = (preds == self.phoneme_eos_id).any(dim=1, keepdim=True) + replace = underconfident & (~eos_step) + preds = torch.where(replace, torch.full_like(preds, self.phoneme_unk_id), preds) + + self._dec_phoneme_tokens[idx] = preds self._dec_phoneme_valid[idx] = 1 # ------------------------------------------------------------------ @@ -502,6 +543,13 @@ def _preprocess_prefill( device: torch.device, info_dict: dict[str, Any], ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + # Forward the audio (local-transformer) sampling params from the request. + # vLLM's ``SamplingParams.temperature`` drives only the dummy backbone + # token sampler, so the real audio temperature/top-k are passed via + # ``additional_information`` and applied to the code predictor here (once, + # at prefill — they are scalars that persist across decode steps). 
+ self._maybe_set_lt_sampling_params(info_dict) + prefill_embeds = self._build_prefill_embeds(device, info_dict) offset = int(info_dict.get("prefill_offset", 0) or 0) @@ -577,6 +625,20 @@ def _build_prefill_embeds( return torch.cat(parts, dim=0) + def _maybe_set_lt_sampling_params(self, info_dict: dict[str, Any]) -> None: + """Apply per-request audio sampling params to the local transformer. + + Reads ``temperature`` / ``top_k`` (alias ``topk``) from the request's + ``additional_information`` and stores them on the code predictor. Absent + keys leave the existing defaults untouched. + """ + temperature = info_dict.get("temperature") + if temperature is not None: + self.code_predictor.temperature = float(self._first_str(temperature) or 0.0) + top_k = info_dict.get("top_k", info_dict.get("topk")) + if top_k is not None: + self.code_predictor.top_k = int(float(self._first_str(top_k) or 0)) + def _get_text_tokenizer(self): """Lazily load the context-text tokenizer from the model directory. @@ -640,11 +702,13 @@ def _preprocess_decode( info_dict: dict[str, Any], ) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: decode_offset = int(info_dict.get("decode_offset", 0) or 0) + info_update: dict[str, Any] = {"decode_offset": decode_offset + 1} - # Text channel (streaming list, one subword consumed per step). Step k + # ── Text channel ── (delay 0: one subword per step from step 0). Step k # consumes text_tokens[k] (the list ends with the text eos id). Once the # stream is exhausted the channel is masked off (adds nothing) rather than - # repeating the last token. + # repeating the last token. The text stream leads the phoneme/audio + # streams by their respective delays. 
text_tokens = info_dict.get("text_tokens") if isinstance(text_tokens, list) and decode_offset < len(text_tokens): self._dec_text_tokens[start] = int(text_tokens[decode_offset]) @@ -652,29 +716,52 @@ def _preprocess_decode( else: self._dec_text_mask[start] = 0 - # Phoneme channel: previous-step prediction stashed by postprocess. + # ── Phoneme channel ── opens at decode step == ``phonemes_delay`` (seeded + # with phoneme BOS), then feeds back the previous step's prediction, and + # closes one step after the model emits the phoneme EOS (sticky flag). if self.has_phoneme: - last_phon = info_dict.get("last_phoneme_token") - if isinstance(last_phon, torch.Tensor) and last_phon.numel() > 0: - p = last_phon.to(device=device, dtype=torch.long).reshape(-1) - self._dec_phoneme_tokens[start, : p.shape[0]].copy_(p[: self.arch.phoneme_stacking_factor]) + phoneme_ended = bool(info_dict.get("phoneme_ended", False)) + feed_eos = False + if phoneme_ended or decode_offset < self.phonemes_delay: + self._dec_phoneme_valid[start] = 0 + elif decode_offset == self.phonemes_delay: + self._dec_phoneme_tokens[start].fill_(self.phoneme_bos_id) self._dec_phoneme_valid[start] = 1 else: - self._dec_phoneme_valid[start] = 0 - - # Audio channel: previous-frame codes (BOS seed on the first step). - last_codes = info_dict.get("last_audio_codes") - if isinstance(last_codes, torch.Tensor) and last_codes.numel() > 0: - c = last_codes.to(device=device, dtype=torch.long).reshape(-1)[: self.num_codebooks] - self._dec_audio_codes[start, : c.shape[0]].copy_(c) - self._dec_audio_valid[start] = 1 - else: - # First decode step after prefill: seed with audio BOS. 
+ last_phon = info_dict.get("last_phoneme_token") + if isinstance(last_phon, torch.Tensor) and last_phon.numel() > 0: + p = last_phon.to(device=device, dtype=torch.long).reshape(-1)[: self.arch.phoneme_stacking_factor] + self._dec_phoneme_tokens[start, : p.shape[0]].copy_(p) + self._dec_phoneme_valid[start] = 1 + feed_eos = bool((p == self.phoneme_eos_id).any()) + else: + self._dec_phoneme_valid[start] = 0 + if phoneme_ended or feed_eos: + info_update["phoneme_ended"] = True + + # ── Audio channel ── opens at decode step == ``speech_delay`` (seeded with + # audio BOS), then feeds back the previous frame's codes. For the leading + # ``speech_delay`` steps the channel is masked off (only text/phoneme + # condition the backbone); the local transformer still runs for CUDA-graph + # stability but its codes for those frames are discarded by the caller and + # never fed back here. + if decode_offset < self.speech_delay: + self._dec_audio_valid[start] = 0 + elif decode_offset == self.speech_delay: self._dec_audio_codes[start].fill_(self.arch.audio_bos_id) self._dec_audio_valid[start] = 1 + else: + last_codes = info_dict.get("last_audio_codes") + if isinstance(last_codes, torch.Tensor) and last_codes.numel() > 0: + c = last_codes.to(device=device, dtype=torch.long).reshape(-1)[: self.num_codebooks] + self._dec_audio_codes[start, : c.shape[0]].copy_(c) + self._dec_audio_valid[start] = 1 + else: + # Fallback (should not happen once audio has started): seed BOS. 
+ self._dec_audio_codes[start].fill_(self.arch.audio_bos_id) + self._dec_audio_valid[start] = 1 inputs_embeds_out = torch.zeros((1, self.embedding_dim), device=device, dtype=self._combined_embeddings.dtype) - info_update = {"decode_offset": decode_offset + 1} return input_ids, inputs_embeds_out, info_update def postprocess(self, hidden_states: torch.Tensor, multimodal_outputs: Optional[dict[str, Any]] = None, **_: Any): From 9ab003808797c73aa6913bb9327fefbf35331704 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Wed, 3 Jun 2026 12:27:31 +0200 Subject: [PATCH 09/15] examples/tts/easymagpie_vllm_omni: take text as input instead of tokens Signed-off-by: Viacheslav Klimkov --- .../easymagpie_inference_demo.ipynb | 48 +++++++++---------- .../easymagpie_vllm_omni/easymagpie.py | 46 ++++++++++++++++-- 2 files changed, 65 insertions(+), 29 deletions(-) diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb index ebab7bc05532..192e08e1ab03 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb @@ -23,9 +23,10 @@ " (`[task_embedding? | speaker_embedding | context_text_embedded]`) and tokenizes\n", " `context_text` itself. `prompt_token_ids = [0] * prompt_len`, sized with\n", " `EasyMagpieTTSForConditionalGeneration.estimate_prompt_len(...)`.\n", - "* **decode** — each step consumes one subword id from the streaming\n", - " `additional_information.text_tokens` list (the tokenized target sentence); the\n", - " local transformer samples all `C * S` stacked audio codebooks for the frame.\n", + "* **decode** — the caller passes the plain target sentence as\n", + " `additional_information.text`; the model tokenizes it in-engine (no caller-side\n", + " tokenization) and consumes one subword id per step. 
The local transformer\n", + " samples all `C * S` stacked audio codebooks for the frame.\n", "* **output** — per-step audio codes are surfaced on\n", " `OmniOutput.multimodal_outputs[\\\"audio_codes\\\"]` (`BT x num_codebooks`), and the\n", " engine accumulates them across steps just like eartts, so we trim to the last\n", @@ -150,10 +151,10 @@ "downstream consumes it, and `multimodal_output[\\\"audio_codes\\\"]` comes back\n", "`None`).\n", "\n", - "`skip_tokenizer_init: true` — we feed `prompt_token_ids` + `text_tokens`\n", - "directly, so vLLM doesn't need its own tokenizer for the prompt (the model still\n", - "loads the bundled `AutoTokenizer` from `MODEL_DIR` in-engine to tokenize\n", - "`context_text`).\n", + "`skip_tokenizer_init: true` — we feed `prompt_token_ids` directly, so vLLM\n", + "doesn't need its own tokenizer for the prompt. The model loads the bundled\n", + "`AutoTokenizer` from `MODEL_DIR` in-engine and uses it to tokenize both\n", + "`context_text` and the target `text`.\n", "\n", "`max_model_len` must cover `T_ctx` (prefill) + the number of decode steps." 
] @@ -191,8 +192,8 @@ " \"max_num_seqs\": 1,\n", " \"model_arch\": \"EasyMagpieTTSForConditionalGeneration\",\n", " \"worker_type\": \"ar\",\n", - " \"scheduler_cls\": \"vllm_omni.core.sched.omni_ar_scheduler.OmniARScheduler\",\n", - " \"enforce_eager\": True, # dummy run: skip CUDA-graph capture for a faster start\n", + " \"scheduler_cls\": \"vllm_omni.core.sched.omni_ar_scheduler.OmniARAsyncScheduler\",\n", + " #\"enforce_eager\": True, # dummy run: skip CUDA-graph capture for a faster start\n", " \"trust_remote_code\": True,\n", " \"async_scheduling\": True,\n", " \"enable_prefix_caching\": False,\n", @@ -207,8 +208,8 @@ " \"dtype\": \"float16\",\n", " \"mamba_ssm_cache_dtype\": \"float32\",\n", " \"attention_backend\": \"TRITON_ATTN\",\n", - " # We feed prompt_token_ids + text_tokens directly; the model still\n", - " # loads the bundled AutoTokenizer from MODEL_DIR for context_text.\n", + " # We feed prompt_token_ids directly; the model loads the bundled\n", + " # AutoTokenizer from MODEL_DIR to tokenize context_text + text.\n", " \"skip_tokenizer_init\": True,\n", " },\n", " \"default_sampling_params\": {\n", @@ -254,10 +255,11 @@ " `[task_embedding? | speaker_embedding | context_text_embedded]`.\n", "* **`context_text`** — a plain conditioning string, here `\"[EN]\"`. The model\n", " tokenizes it in-engine and embeds it through the baked `text_embedding` table.\n", - "* **`text_tokens`** `list[int]` — the streaming subword stream: the target\n", - " sentence tokenized with the bundled tokenizer, ending with the model's text\n", - " EOS id. Decode step `k` consumes `text_tokens[k]`; once exhausted the channel\n", - " is masked off (matching the reference `... encode(transcript) + [eos_id]`).\n", + "* **`text`** — the plain target sentence to synthesize. 
The model tokenizes it\n", + " in-engine at prefill (HF special tokens disabled, trailing text-EOS id\n", + " appended — matching the reference `encode(transcript) + [eos_id]`) and streams\n", + " one subword id per decode step; once exhausted the channel is masked off. No\n", + " caller-side tokenization needed.\n", "\n", "`prompt_token_ids = [0] * prompt_len` are placeholders (the model feeds the\n", "backbone via `inputs_embeds`, never these ids). `prompt_len` must equal the\n", @@ -294,9 +296,9 @@ "# Target sentence to synthesize.\n", "TEXT = \"Hello, this is a test of the EasyMagpie text to speech model.\"\n", "\n", - "# Same tokenizer the engine loads from MODEL_DIR. Used to (a) size the prefill\n", - "# placeholders so prompt_token_ids length matches the assembled context, and\n", - "# (b) tokenize the target sentence into the streaming text stream.\n", + "# Same tokenizer the engine loads from MODEL_DIR. Used only to size the prefill\n", + "# placeholders so prompt_token_ids length matches the assembled context (the\n", + "# target text is tokenized in-engine — we just pass the plain string below).\n", "tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, trust_remote_code=True)\n", "prompt_len = EasyMagpieTTSForConditionalGeneration.estimate_prompt_len(\n", " speaker_embedding,\n", @@ -305,11 +307,6 @@ " has_task_embedding=arch.num_task_embeddings > 0,\n", ")\n", "\n", - "# Streaming subword ids consumed one per decode step. Mirrors the reference\n", - "# `encode(transcript) + [eos_id]` (no BOS; HF special tokens disabled so the ids\n", - "# index the baked text_embedding table directly).\n", - "text_tokens = tokenizer.encode(TEXT, add_special_tokens=False) + [TEXT_EOS_ID]\n", - "\n", "# Audio (local-transformer) sampling params. vLLM's SamplingParams.temperature\n", "# drives only the dummy backbone token sampler, so the *audio* temperature/top-k\n", "# are forwarded via additional_information. 
temperature=0.0 == argmax\n", @@ -320,7 +317,7 @@ "additional_information = {\n", " \"speaker_embedding\": speaker_embedding, # (T_audio, embedding_dim) tensor\n", " \"context_text\": CONTEXT_TEXT, # plain string, tokenized in-model\n", - " \"text_tokens\": text_tokens, # list[int], grows by one per step\n", + " \"text\": TEXT, # plain target sentence, tokenized in-model\n", " \"temperature\": LT_TEMPERATURE, # audio sampling temperature (local transformer)\n", " \"top_k\": LT_TOPK, # audio sampling top-k (local transformer)\n", "}\n", @@ -337,8 +334,7 @@ "\n", "print(f\"speaker_embedding : {tuple(speaker_embedding.shape)}\")\n", "print(f\"context_text : {CONTEXT_TEXT!r} -> {tokenizer.encode(CONTEXT_TEXT)}\")\n", - "print(f\"text : {TEXT!r}\")\n", - "print(f\"text_tokens (len {len(text_tokens):3d}) : {text_tokens[:8]}{' ...' if len(text_tokens) > 8 else ''}\")\n", + "print(f\"text : {TEXT!r} (tokenized in-engine)\")\n", "print(f\"prompt_len (placeholders) : {prompt_len}\")\n", "print(f\"decode steps (max_tokens) : {DECODE_STEPS}\")\n", "\n", diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index 8f01eb886bb8..d21f357c36aa 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -59,9 +59,18 @@ The caller passes ``prompt_token_ids = [0] * T_ctx``, where ``T_ctx`` is the assembled context length (``[task?] + T_audio + len(tokenize(context_text))``). -* ``text_tokens`` — Python ``list[int]`` of subword ids that grows by one per - decode step; step ``k`` consumes ``text_tokens[k]`` (embedded through the - precomputed per-subword table). +* ``text`` (prefill only) — the plain target sentence to synthesize. 
This is the + caller's text input: the model tokenizes it in-model at prefill with the + checkpoint's text tokenizer (HF special tokens disabled, trailing text-EOS id + appended), so callers never tokenize themselves. The resulting subword ids are + consumed one per decode step (step ``k`` consumes id ``k``, embedded through + the precomputed per-subword ``text_embedding`` table); once exhausted the text + channel is masked off. + + (Internal: the tokenized ids are stashed as ``text_tokens`` in the per-request + info dict between prefill and decode. A future streaming mode will let the + caller push subword ids gradually instead of one upfront ``text`` string; for + now assume ``text`` is always provided whole at prefill.) * ``temperature`` / ``top_k`` (prefill only, optional) — audio sampling params for the local transformer. vLLM's ``SamplingParams.temperature`` drives only the dummy backbone token sampler, so the *audio* temperature/top-k are passed @@ -187,6 +196,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: text_vocab_size = int(getattr(hf_config, "text_vocab_size", getattr(hf_config, "vocab_size", 0))) self.text_embedding = nn.Embedding(text_vocab_size, self.embedding_dim) + # Text-stream EOS id — the last-but-one row of the text vocab, matching + # the reference ``EasyMagpieTTSInferenceModel.eos_id = num_tokens - 2``. + # Appended to the in-model-tokenized target text stream (see + # :meth:`_encode_text_stream`). + self.text_eos_id = text_vocab_size - 2 + # Task ("service token") embedding — a single learned per-mode row # prepended to the prefill context for multi-mode checkpoints. Built only # when the checkpoint carries one; otherwise ``None``. 
@@ -570,6 +585,17 @@ def _preprocess_prefill( "prefill_offset": offset + span_len, "decode_offset": 0, } + # Tokenize the caller's ``text`` in-model and stash the subword ids in the + # per-request info dict (alongside the offsets) so each decode step + # consumes one id from it without the caller ever running the tokenizer + # (see :meth:`_preprocess_decode`). The caller always passes ``text`` + # whole at prefill; a future streaming mode will instead let the caller + # push ``text_tokens`` ids gradually, which is why an already-present + # ``text_tokens`` list is left untouched here. + if not info_dict.get("text_tokens"): + text = self._first_str(info_dict.get("text")) + if text: + info_update["text_tokens"] = self._encode_text_stream(text) input_ids_out = torch.full_like(input_ids, _DUMMY_TOKEN_ID) return input_ids_out, take, info_update @@ -663,6 +689,20 @@ def _encode_context_text(self, context_text: str, device: torch.device) -> torch ids = tok.encode(context_text) return torch.tensor(ids, device=device, dtype=torch.long) + def _encode_text_stream(self, text: str) -> list[int]: + """Tokenize the target ``text`` into the streaming subword-id list. + + Mirrors the reference ``tokenizer.encode(transcript) + [eos_id]``: HF + special tokens are disabled so the raw ids index the baked + ``text_embedding`` table directly, and the trailing text-EOS id closes + the stream. One id is consumed per decode step (see + :meth:`_preprocess_decode`); once exhausted the text channel is masked + off. 
+ """ + tok = self._get_text_tokenizer() + ids = tok.encode(text, add_special_tokens=False) + return list(ids) + [self.text_eos_id] + @staticmethod def estimate_prompt_len( speaker_embedding: torch.Tensor, From 36ce9a5fca0be2366effc8fc68dde424f5ba7fbf Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Wed, 3 Jun 2026 12:28:16 +0200 Subject: [PATCH 10/15] examples/tts/easymagpie_vllm_omni: add script to benchmark the acoustic token prediction Signed-off-by: Viacheslav Klimkov --- .../benchmark_easymagpie_tts.py | 1082 +++++++++++++++++ 1 file changed, 1082 insertions(+) create mode 100644 examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py diff --git a/examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py b/examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py new file mode 100644 index 000000000000..c8ac80f4e563 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py @@ -0,0 +1,1082 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Benchmark the EasyMagpieTTS talker via a single-stage AsyncOmni engine. + +Runs the EasyMagpie talker (``EasyMagpieTTSForConditionalGeneration``) only — +no codec / code2wav — producing stacked audio codes as output. It mirrors the +reference ``qwen3-tts`` talker benchmark and the +``easymagpie_inference_demo.ipynb`` engine setup. 
+ +Metrics measured under configurable concurrency: + +* **TTFT** — time to first decoded frame (first engine token). +* **ITL** — per-token inter-token latency (excluding the first token). +* **E2E** — end-to-end latency per request (up to the audio-EOS frame). +* **RTX** — real-time factor (generated audio seconds / wall time). Both the + per-request RTX and an overall (concurrency-aware) RTX are reported. +* **Throughput** — frames/s and requests/s. + +The decode loop stops at the audio-EOS frame (the EasyMagpie model signals +end-of-speech inside codebook 0 of the codes, not via the vLLM token stream), +so E2E / RTX reflect the real synthesized length rather than the full token +budget. Audio duration is derived from the number of decoded frames: +``audio_seconds = (frames - speech_delay) * frame_stacking_factor / codec_fps``. + +Reads texts from a file (one utterance per line, optionally tab-separated with +the text in the second column) or uses a small built-in default set. + +Usage: + # Basic benchmark with default prompts + python benchmark_easymagpie_tts.py \\ + --model ./easymp_vllm_model \\ + --num-requests 50 + + # From a text file with a concurrency sweep + python benchmark_easymagpie_tts.py \\ + --model ./easymp_vllm_model \\ + --text-file texts.txt \\ + --num-requests 100 \\ + --concurrency 1 4 8 + + # With torch profiler on the run + python benchmark_easymagpie_tts.py \\ + --model ./easymp_vllm_model \\ + --num-requests 20 --concurrency 1 --profile + + # Save JSON results + python benchmark_easymagpie_tts.py \\ + --model ./easymp_vllm_model \\ + --text-file texts.txt \\ + --num-requests 100 --concurrency 1 4 \\ + --result-dir results/ +""" + +import os + +# Keep spawn semantics consistent with the qwen3-tts / eartts demos in case the +# executor backend is switched to a multiproc one. 
+os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + +import argparse +import asyncio +import json +import logging +import tempfile +import time +import uuid +from dataclasses import asdict, dataclass, field +from datetime import datetime +from pathlib import Path +from typing import Any, Optional + +import numpy as np +import yaml + +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", +) +logger = logging.getLogger(__name__) + +DEFAULT_PROMPTS = [ + "Hello, welcome to the voice synthesis benchmark test.", + "She said she would be here by noon, but nobody showed up.", + "The quick brown fox jumps over the lazy dog near the riverbank.", + "I can't believe how beautiful the sunset looks from up here on the mountain.", + "Please remember to bring your identification documents to the appointment tomorrow morning.", + "Have you ever wondered what it would be like to travel through time and visit ancient civilizations?", + "The restaurant on the corner serves the best pasta I have ever tasted in my entire life.", + "After the meeting, we should discuss the quarterly results and plan for the next phase.", + "Learning a new language takes patience, practice, and a genuine curiosity about other cultures.", + "The train leaves at half past seven, so we need to arrive at the station before then.", + "Could you please turn down the music a little bit, I'm trying to concentrate on my work.", + "It was a dark and stormy night when the old lighthouse keeper heard a knock at the door.", +] + + +# --------------------------------------------------------------------------- +# Stage config generation +# --------------------------------------------------------------------------- + + +def _build_easymagpie_stage_config( + max_num_seqs: int = 1, + profile: bool = False, + torch_profiler_dir: str = "./profiler_traces", + with_stack: bool = False, + record_shapes: bool = False, + gpu_memory_utilization: float = 0.8, + 
max_model_len: int = 1024, + max_num_batched_tokens: int = 1024, + enforce_eager: bool = False, + max_new_tokens: int = 256, + dtype: str = "float16", + distributed_executor_backend: str = "uni", + cudagraph_mode: Optional[str] = None, +) -> dict: + """Build a single-stage YAML dict containing only the EasyMagpie talker. + + Mirrors the engine_args used in ``easymagpie_inference_demo.ipynb``. + + ``cudagraph_mode`` (when set and ``enforce_eager`` is False) selects the + vLLM CUDA-graph capture strategy via ``compilation_config.cudagraph_mode``: + + * ``FULL_AND_PIECEWISE`` (vLLM default) — a single full graph over the whole + forward for uniform/decode-only batches, piecewise (per compile group: + backbone vs local transformer) for mixed/prefill batches. + * ``PIECEWISE`` — always piecewise, so the backbone and local transformer are + captured as *separate* graphs even during decode. This re-introduces a + launch boundary between them (so decode is a touch slower than FULL), but + makes the backbone-vs-LT split visible as two distinct ``cudaGraphLaunch`` + events in a profiler. + * ``FULL`` / ``FULL_DECODE_ONLY`` — full graph (decode only) capture. + * ``NONE`` — no CUDA graphs (equivalent to ``--enforce-eager``). + """ + engine_args: dict[str, Any] = { + "model_stage": "easymagpie", + "max_num_seqs": max_num_seqs, + "model_arch": "EasyMagpieTTSForConditionalGeneration", + "worker_type": "ar", + "scheduler_cls": "vllm_omni.core.sched.omni_ar_scheduler.OmniARAsyncScheduler", + "enforce_eager": enforce_eager, + "trust_remote_code": True, + "async_scheduling": True, + "enable_prefix_caching": False, + "engine_output_type": "audio", + "gpu_memory_utilization": gpu_memory_utilization, + # "uni" runs the worker in-process (no shm_broadcast IPC); use "mp" + # only when TP/PP > 1 or you actually need a separate worker process. 
+ "distributed_executor_backend": distributed_executor_backend, + "max_num_batched_tokens": max_num_batched_tokens, + "max_model_len": max_model_len, + # bf16/fp16 (not fp32): the Nemotron-H fused-MoE Triton kernel's block + # sizes are tuned for 16-bit and overflow shared memory in fp32. + "dtype": dtype, + "mamba_ssm_cache_dtype": "float32", + "attention_backend": "TRITON_ATTN", + # We feed prompt_token_ids directly; the model loads the bundled + # AutoTokenizer from the model dir to tokenize context_text + text. + "skip_tokenizer_init": True, + } + + # CUDA-graph capture strategy. ``enforce_eager`` already disables graphs, so + # only set compilation_config when graphs are enabled (mirrors the sidecar + # server). Passed as a plain dict so it survives YAML serialization; vLLM + # parses it into a CompilationConfig. + if cudagraph_mode is not None and not enforce_eager: + engine_args["compilation_config"] = {"cudagraph_mode": cudagraph_mode} + + if profile: + engine_args["profiler_config"] = { + "profiler": "torch", + "torch_profiler_dir": os.path.abspath(torch_profiler_dir), + "torch_profiler_with_stack": with_stack, + "torch_profiler_record_shapes": record_shapes, + } + + cfg = { + "stage_args": [ + { + "stage_id": 0, + "stage_type": "llm", + "is_comprehension": True, + "final_output": True, + # "audio" (not "latent") is required for a single-stage AR TTS + # model: it makes the AR model runner attach the per-step + # multimodal payload ("audio_codes") to the output so the codes + # reach the client. + "final_output_type": "audio", + "runtime": {"devices": "0"}, + "engine_args": engine_args, + "default_sampling_params": { + # The backbone token sampler is a no-op (audio is sampled in + # the local transformer); the audio temperature/top-k are + # forwarded per-request via additional_information. 
+ "temperature": 0.0, + "max_tokens": max_new_tokens, + "detokenize": False, + # Audio EOS lives in the codes, not the vLLM token stream, so + # let the budget run and stop client-side at the EOS frame. + "ignore_eos": True, + }, + } + ], + } + return cfg + + +def _write_temp_stage_config(cfg: dict) -> str: + """Write stage config dict to a temp YAML file, return its path.""" + tmp = tempfile.NamedTemporaryFile( + mode="w", + suffix=".yaml", + prefix="easymagpie_bench_", + delete=False, + ) + yaml.dump(cfg, tmp, default_flow_style=False, sort_keys=False) + tmp.close() + logger.info("Wrote single-stage config to %s", tmp.name) + return tmp.name + + +# --------------------------------------------------------------------------- +# Model metadata (arch scalars + tokenizer + speaker embedding) +# --------------------------------------------------------------------------- + + +@dataclass +class ModelMeta: + """Scalars + assets needed to build prompts and interpret outputs.""" + + arch: Any + tokenizer: Any + speaker_embedding: Any # torch.Tensor (T_audio, embedding_dim) + prompt_len: int + audio_eos_id: int + speech_delay: int + frame_stacking_factor: int + + +def _load_model_meta( + model_dir: str, + speaker: str, + speaker_embedding_path: Optional[str], + context_text: str, +) -> ModelMeta: + """Read config.json, tokenizer, and the speaker embedding from the model dir. + + Mirrors the prompt-prep cells of ``easymagpie_inference_demo.ipynb``: the + arch scalars come from ``config.json``, the speaker embedding from + ``speaker_embeddings/.pt``, and the prefill placeholder length from + ``EasyMagpieTTSForConditionalGeneration.estimate_prompt_len(...)``. 
+ """ + import torch + from transformers import AutoTokenizer + + from easymagpie_vllm_omni.config import EasyMagpieOmniArch + from easymagpie_vllm_omni.easymagpie import EasyMagpieTTSForConditionalGeneration + + model_path = Path(model_dir) + config = json.loads((model_path / "config.json").read_text()) + arch = EasyMagpieOmniArch.from_hf_config(type("Cfg", (), config)) + + # Speaker-encoded context audio (audio branch of prepare_context_tensors). + if speaker_embedding_path is not None: + emb_path = Path(speaker_embedding_path) + else: + emb_path = model_path / "speaker_embeddings" / f"{speaker}.pt" + if not emb_path.exists(): + raise FileNotFoundError(f"Speaker embedding not found: {emb_path}") + loaded = torch.load(emb_path, map_location="cpu") + speaker_embedding = loaded["speaker_encoding"] if isinstance(loaded, dict) else loaded + speaker_embedding = speaker_embedding.to(torch.float32) + + tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) + + # prompt_len depends only on the speaker embedding + context_text (+ task + # embedding) — NOT on the target text (which is streamed in-engine), so we + # size it once. 
+ prompt_len = EasyMagpieTTSForConditionalGeneration.estimate_prompt_len( + speaker_embedding, + tokenize=lambda t: tokenizer.encode(t), + context_text=context_text, + has_task_embedding=arch.num_task_embeddings > 0, + ) + + return ModelMeta( + arch=arch, + tokenizer=tokenizer, + speaker_embedding=speaker_embedding, + prompt_len=int(prompt_len), + audio_eos_id=int(arch.audio_eos_id), + speech_delay=int(getattr(arch, "streaming_speech_delay", 0) or 0), + frame_stacking_factor=int(arch.frame_stacking_factor), + ) + + +def build_prompt( + text: str, + meta: ModelMeta, + context_text: str, + lt_temperature: float, + lt_topk: int, +) -> dict: + """Build an engine input dict from a target sentence + the shared assets.""" + additional_information = { + "speaker_embedding": meta.speaker_embedding, # (T_audio, embedding_dim) + "context_text": context_text, # plain string, tokenized in-model + "text": text, # plain target sentence, tokenized in-model + "temperature": lt_temperature, # audio sampling temperature (local transformer) + "top_k": lt_topk, # audio sampling top-k (local transformer) + } + return { + "prompt_token_ids": [0] * meta.prompt_len, + "additional_information": additional_information, + } + + +# --------------------------------------------------------------------------- +# Result dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class RequestResult: + success: bool = False + text: str = "" + prompt_len: int = 0 + num_generated: int = 0 # decoded frames (engine tokens) up to EOS + audio_frames: int = 0 # codec frames of real audio (post speech-delay, pre-EOS) + audio_s: float = 0.0 # synthesized audio duration in seconds + steps: int = 0 + eos_reached: bool = False + ttft_s: float = 0.0 + e2e_s: float = 0.0 + rtx: float = 0.0 # audio_s / e2e_s + inter_token_latencies: list = field(default_factory=list) + error: str = "" + + +@dataclass +class BenchmarkResult: + config_name: str = "" + concurrency: int 
= 0 + num_requests: int = 0 + completed: int = 0 + failed: int = 0 + duration_s: float = 0.0 + # TTFT + mean_ttft_ms: float = 0.0 + median_ttft_ms: float = 0.0 + p95_ttft_ms: float = 0.0 + p99_ttft_ms: float = 0.0 + # E2E + mean_e2e_ms: float = 0.0 + median_e2e_ms: float = 0.0 + p95_e2e_ms: float = 0.0 + p99_e2e_ms: float = 0.0 + # ITL (inter-token latency, excluding first token) + mean_itl_ms: float = 0.0 + median_itl_ms: float = 0.0 + p95_itl_ms: float = 0.0 + p99_itl_ms: float = 0.0 + # RTX (real-time factor: synthesized audio seconds / generation seconds) + mean_rtx: float = 0.0 + median_rtx: float = 0.0 + overall_rtx: float = 0.0 # total_audio_s / wall_clock_duration (concurrency-aware) + # Throughput + total_tokens: int = 0 + total_audio_s: float = 0.0 + mean_tokens_per_request: float = 0.0 + token_throughput: float = 0.0 + request_throughput: float = 0.0 + per_request: list = field(default_factory=list) + + +# --------------------------------------------------------------------------- +# Inference +# --------------------------------------------------------------------------- + + +def _extract_request_output(stage_output): + """Return the RequestOutput-like object from a yielded stage output. + + AsyncOmni stages may yield either a wrapper carrying ``.request_output`` + (qwen3-tts style) or the RequestOutput directly (easymagpie demo style). + """ + return getattr(stage_output, "request_output", stage_output) + + +async def run_one_request( + omni, + prompt: dict, + sampling_params, + request_id: str, + meta: ModelMeta, + codec_fps: float, + stop_on_eos: bool, +) -> RequestResult: + """Submit one TTS request, collect per-token timing and audio length. + + Each engine step yields one decoded frame (one layer-0 token). We time the + first token (TTFT) and the gaps between subsequent tokens (ITL). 
The audio + EOS lives in codebook 0 of the accumulated ``audio_codes`` (not in the vLLM + token stream), so we watch the newest decoded frame and stop at the EOS + frame to recover the real synthesized length. + """ + import torch + + result = RequestResult() + t_start = time.perf_counter() + t_last_token = None + prev_num_tokens = 0 + eos_decode_idx = None # 0-based decode-frame index where audio EOS appears + + try: + gen = omni.generate( + prompt, + sampling_params_list=[sampling_params], + request_id=request_id, + ) + async for stage_output in gen: + now = time.perf_counter() + ro = _extract_request_output(stage_output) + result.steps += 1 + + cur_num_tokens = prev_num_tokens + if hasattr(ro, "outputs") and ro.outputs: + out0 = ro.outputs[0] + cum_ids = getattr(out0, "cumulative_token_ids", None) + if cum_ids is not None: + cur_num_tokens = len(cum_ids) + else: + cur_num_tokens = len(getattr(out0, "token_ids", []) or []) + + if cur_num_tokens > prev_num_tokens: + if t_last_token is None: + result.ttft_s = now - t_start + else: + result.inter_token_latencies.append(now - t_last_token) + t_last_token = now + + # Audio-EOS detection on the newest decoded frame. The accumulated + # audio_codes hold (T_ctx prefill + decode) rows; the last row is + # the newest decoded frame. Only meaningful past the speech delay. 
+ mm = getattr(stage_output, "multimodal_output", None) or {} + audio_codes = mm.get("audio_codes") + newest_frame_idx = cur_num_tokens - 1 # 0-based decode-frame index + if ( + eos_decode_idx is None + and newest_frame_idx >= meta.speech_delay + and isinstance(audio_codes, torch.Tensor) + and audio_codes.numel() > 0 + ): + if int(audio_codes[-1, 0]) == meta.audio_eos_id: + eos_decode_idx = newest_frame_idx + result.eos_reached = True + + prev_num_tokens = cur_num_tokens + + if eos_decode_idx is not None and stop_on_eos: + break + + t_end = time.perf_counter() + result.e2e_s = t_end - t_start + result.num_generated = prev_num_tokens + result.success = True + + if result.ttft_s == 0.0 and result.steps > 0: + result.ttft_s = t_end - t_start + + # Real audio length: frames between the start of speech (speech_delay) + # and the EOS frame (or the full decode if no EOS was emitted). + last_audio_frame = eos_decode_idx if eos_decode_idx is not None else prev_num_tokens + result.audio_frames = max(0, last_audio_frame - meta.speech_delay) + if codec_fps > 0: + result.audio_s = result.audio_frames * meta.frame_stacking_factor / codec_fps + result.rtx = result.audio_s / result.e2e_s if result.e2e_s > 0 else 0.0 + + except Exception as exc: + result.e2e_s = time.perf_counter() - t_start + result.error = str(exc) + logger.error("Request %s failed: %s", request_id, exc) + finally: + # Make sure the async generator is closed (aborts the request in the + # engine when we broke out early on EOS). 
+ try: + await gen.aclose() + except Exception: + pass + + return result + + +# --------------------------------------------------------------------------- +# Worker / concurrency +# --------------------------------------------------------------------------- + + +async def worker( + worker_id: int, + omni, + texts: list, + meta: ModelMeta, + context_text: str, + lt_temperature: float, + lt_topk: int, + sampling_params, + codec_fps: float, + stop_on_eos: bool, + results: list, + counter: dict, + lock: asyncio.Lock, +): + """Persistent async worker that picks texts until the quota is exhausted.""" + while True: + async with lock: + if counter["remaining"] <= 0: + break + counter["remaining"] -= 1 + idx = counter["issued"] + counter["issued"] += 1 + + text = texts[idx % len(texts)] + request_id = f"bench-easymp-w{worker_id}-{uuid.uuid4().hex[:8]}" + + prompt = build_prompt( + text=text, + meta=meta, + context_text=context_text, + lt_temperature=lt_temperature, + lt_topk=lt_topk, + ) + + result = await run_one_request( + omni, + prompt, + sampling_params, + request_id, + meta, + codec_fps, + stop_on_eos, + ) + result.text = text + result.prompt_len = len(prompt["prompt_token_ids"]) + + async with lock: + results.append(result) + done = len(results) + + if done % 10 == 0 or done == counter["total"]: + logger.info(" progress: %d / %d", done, counter["total"]) + + +# --------------------------------------------------------------------------- +# Metrics +# --------------------------------------------------------------------------- + + +def _pct(arr, p): + return float(np.percentile(arr, p)) if len(arr) > 0 else 0.0 + + +def compute_and_print_metrics( + results: list, + duration: float, + concurrency: int, + num_requests: int, +) -> BenchmarkResult: + successful = [r for r in results if r.success] + failed = [r for r in results if not r.success] + + bench = BenchmarkResult( + concurrency=concurrency, + num_requests=num_requests, + completed=len(successful), + 
failed=len(failed), + duration_s=duration, + ) + + if not successful: + print("ERROR: No requests completed successfully.") + return bench + + ttfts = [r.ttft_s * 1000 for r in successful] + e2es = [r.e2e_s * 1000 for r in successful] + rtxs = [r.rtx for r in successful] + all_itls = [] + for r in successful: + all_itls.extend([t * 1000 for t in r.inter_token_latencies]) + gen_tokens = [r.num_generated for r in successful] + + bench.mean_ttft_ms = float(np.mean(ttfts)) + bench.median_ttft_ms = float(np.median(ttfts)) + bench.p95_ttft_ms = _pct(ttfts, 95) + bench.p99_ttft_ms = _pct(ttfts, 99) + + bench.mean_e2e_ms = float(np.mean(e2es)) + bench.median_e2e_ms = float(np.median(e2es)) + bench.p95_e2e_ms = _pct(e2es, 95) + bench.p99_e2e_ms = _pct(e2es, 99) + + if all_itls: + bench.mean_itl_ms = float(np.mean(all_itls)) + bench.median_itl_ms = float(np.median(all_itls)) + bench.p95_itl_ms = _pct(all_itls, 95) + bench.p99_itl_ms = _pct(all_itls, 99) + + bench.mean_rtx = float(np.mean(rtxs)) + bench.median_rtx = float(np.median(rtxs)) + + bench.total_tokens = int(sum(gen_tokens)) + bench.total_audio_s = float(sum(r.audio_s for r in successful)) + bench.mean_tokens_per_request = float(np.mean(gen_tokens)) + bench.token_throughput = bench.total_tokens / duration if duration > 0 else 0.0 + bench.request_throughput = len(successful) / duration if duration > 0 else 0.0 + bench.overall_rtx = bench.total_audio_s / duration if duration > 0 else 0.0 + + bench.per_request = [ + { + "ttft_ms": r.ttft_s * 1000, + "e2e_ms": r.e2e_s * 1000, + "rtx": r.rtx, + "num_generated": r.num_generated, + "audio_frames": r.audio_frames, + "audio_s": r.audio_s, + "eos_reached": r.eos_reached, + "steps": r.steps, + "prompt_len": r.prompt_len, + "mean_itl_ms": float(np.mean([t * 1000 for t in r.inter_token_latencies])) + if r.inter_token_latencies + else 0.0, + "text": r.text, + } + for r in successful + ] + + eos_hits = sum(1 for r in successful if r.eos_reached) + + W = 56 + print(f"\n{'=' * W}") + 
print(f"{'Benchmark Result':^{W}}") + print(f"{'=' * W}") + print(f"{'Successful requests:':<42}{bench.completed}") + print(f"{'Failed requests:':<42}{bench.failed}") + print(f"{'Reached audio EOS:':<42}{eos_hits} / {bench.completed}") + print(f"{'Concurrency:':<42}{concurrency}") + print(f"{'Wall-clock duration (s):':<42}{duration:.2f}") + print(f"{'Request throughput (req/s):':<42}{bench.request_throughput:.2f}") + + print(f"\n{'-' * W}") + print(f"{'Time to First Token (TTFT)':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean (ms):':<42}{bench.mean_ttft_ms:.2f}") + print(f"{'Median (ms):':<42}{bench.median_ttft_ms:.2f}") + print(f"{'P95 (ms):':<42}{bench.p95_ttft_ms:.2f}") + print(f"{'P99 (ms):':<42}{bench.p99_ttft_ms:.2f}") + + print(f"\n{'-' * W}") + print(f"{'End-to-End Latency (E2E)':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean (ms):':<42}{bench.mean_e2e_ms:.2f}") + print(f"{'Median (ms):':<42}{bench.median_e2e_ms:.2f}") + print(f"{'P95 (ms):':<42}{bench.p95_e2e_ms:.2f}") + print(f"{'P99 (ms):':<42}{bench.p99_e2e_ms:.2f}") + + print(f"\n{'-' * W}") + print(f"{'Inter-Token Latency (ITL)':^{W}}") + print(f"{'-' * W}") + if all_itls: + print(f"{'Mean (ms):':<42}{bench.mean_itl_ms:.2f}") + print(f"{'Median (ms):':<42}{bench.median_itl_ms:.2f}") + print(f"{'P95 (ms):':<42}{bench.p95_itl_ms:.2f}") + print(f"{'P99 (ms):':<42}{bench.p99_itl_ms:.2f}") + else: + print(f"{'(no inter-token data)':^{W}}") + + print(f"\n{'-' * W}") + print(f"{'Real-Time Factor (RTX = audio_s / gen_s)':^{W}}") + print(f"{'-' * W}") + print(f"{'Mean RTX (per request):':<42}{bench.mean_rtx:.2f}x") + print(f"{'Median RTX (per request):':<42}{bench.median_rtx:.2f}x") + print(f"{'Overall RTX (total audio / wall):':<42}{bench.overall_rtx:.2f}x") + + print(f"\n{'-' * W}") + print(f"{'Throughput':^{W}}") + print(f"{'-' * W}") + print(f"{'Total frames generated:':<42}{bench.total_tokens}") + print(f"{'Total audio generated (s):':<42}{bench.total_audio_s:.2f}") + print(f"{'Mean frames / 
request:':<42}{bench.mean_tokens_per_request:.1f}") + print(f"{'Frame throughput (frames/s):':<42}{bench.token_throughput:.2f}") + print(f"{'=' * W}\n") + + if failed: + print(f" First {min(3, len(failed))} errors:") + for r in failed[:3]: + print(f" {r.error[:200]}") + + return bench + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +async def main(args): + from vllm import SamplingParams + from vllm_omni import AsyncOmni + + model_name = args.model + + # ── Load texts ──────────────────────────────────────────────────────── + if args.text_file: + path = Path(args.text_file) + if not path.exists(): + print(f"ERROR: text file not found: {path}") + return + raw_lines = [line.strip() for line in path.read_text().splitlines() if line.strip()] + texts = [] + for line in raw_lines: + if "\t" in line: + texts.append(line.split("\t", 1)[1].strip()) + else: + texts.append(line) + texts = [t for t in texts if t] + logger.info("Loaded %d texts from %s", len(texts), path) + else: + texts = DEFAULT_PROMPTS + logger.info("Using %d default prompts", len(texts)) + + if not texts: + print("ERROR: no texts available.") + return + + # ── Read arch scalars + tokenizer + speaker embedding ───────────────── + logger.info("Reading model metadata from %s ...", model_name) + meta = _load_model_meta( + model_dir=model_name, + speaker=args.speaker, + speaker_embedding_path=args.speaker_embedding, + context_text=args.context_text, + ) + logger.info( + "prompt_len=%d audio_eos_id=%d speech_delay=%d frame_stacking=%d", + meta.prompt_len, + meta.audio_eos_id, + meta.speech_delay, + meta.frame_stacking_factor, + ) + if meta.prompt_len + args.max_new_tokens > args.max_model_len: + logger.warning( + "prompt_len (%d) + max_new_tokens (%d) exceeds max_model_len (%d); raise --max-model-len.", + meta.prompt_len, + args.max_new_tokens, + args.max_model_len, + ) + + max_concurrency 
= max(args.concurrency) + + # ── Build stage config ──────────────────────────────────────────────── + stage_cfg = _build_easymagpie_stage_config( + max_num_seqs=max_concurrency, + profile=args.profile, + torch_profiler_dir=args.torch_profiler_dir, + with_stack=args.with_stack, + record_shapes=args.record_shapes, + gpu_memory_utilization=args.gpu_memory_utilization, + max_model_len=args.max_model_len, + max_num_batched_tokens=args.max_num_batched_tokens, + enforce_eager=args.enforce_eager, + max_new_tokens=args.max_new_tokens, + dtype=args.dtype, + distributed_executor_backend=args.distributed_executor_backend, + cudagraph_mode=args.cudagraph_mode, + ) + if args.cudagraph_mode is not None and args.enforce_eager: + logger.warning( + "--cudagraph-mode %s is ignored because --enforce-eager disables CUDA graphs.", + args.cudagraph_mode, + ) + elif args.cudagraph_mode is not None: + logger.info("CUDA-graph mode: %s", args.cudagraph_mode) + tmp_config_path = _write_temp_stage_config(stage_cfg) + + sampling_params = SamplingParams( + temperature=0.0, + max_tokens=args.max_new_tokens, + detokenize=False, + ignore_eos=True, + ) + + try: + logger.info("Creating AsyncOmni engine (EasyMagpie talker only) for %s ...", model_name) + omni = AsyncOmni( + model=model_name, + stage_configs_path=tmp_config_path, + log_stats=args.log_stats, + stage_init_timeout=args.stage_init_timeout, + ) + logger.info("Engine ready (single stage: EasyMagpie talker).") + + all_bench_results = [] + + for concurrency in args.concurrency: + logger.info( + "=== concurrency=%d requests=%d ===", + concurrency, + args.num_requests, + ) + + # ── Warmup ──────────────────────────────────────────────────── + warmup_count = 0 if args.no_warmup else args.num_warmups * concurrency + if warmup_count > 0: + logger.info("Warming up with %d requests (concurrency=%d)...", warmup_count, concurrency) + warmup_results: list = [] + warmup_counter = { + "remaining": warmup_count, + "issued": 0, + "total": warmup_count, + } 
+ warmup_lock = asyncio.Lock() + warmup_tasks = [ + asyncio.create_task( + worker( + worker_id=i, + omni=omni, + texts=texts, + meta=meta, + context_text=args.context_text, + lt_temperature=args.lt_temperature, + lt_topk=args.lt_topk, + sampling_params=sampling_params, + codec_fps=args.codec_frame_rate, + stop_on_eos=not args.no_stop_on_eos, + results=warmup_results, + counter=warmup_counter, + lock=warmup_lock, + ) + ) + for i in range(concurrency) + ] + await asyncio.gather(*warmup_tasks) + warmup_ok = sum(1 for r in warmup_results if r.success) + logger.info("Warmup done: %d / %d succeeded.", warmup_ok, warmup_count) + + # ── Benchmark run ───────────────────────────────────────────── + logger.info("Starting benchmark run (%d requests, concurrency=%d)...", args.num_requests, concurrency) + + bench_results: list = [] + counter = { + "remaining": args.num_requests, + "issued": 0, + "total": args.num_requests, + } + lock = asyncio.Lock() + + if args.profile: + logger.info("Starting profiler ...") + await omni.start_profile( + profile_prefix=args.profile_prefix, + stages=[0], + ) + + start_time = time.perf_counter() + try: + tasks = [ + asyncio.create_task( + worker( + worker_id=i, + omni=omni, + texts=texts, + meta=meta, + context_text=args.context_text, + lt_temperature=args.lt_temperature, + lt_topk=args.lt_topk, + sampling_params=sampling_params, + codec_fps=args.codec_frame_rate, + stop_on_eos=not args.no_stop_on_eos, + results=bench_results, + counter=counter, + lock=lock, + ) + ) + for i in range(concurrency) + ] + await asyncio.gather(*tasks) + finally: + if args.profile: + logger.info("Stopping profiler ...") + await omni.stop_profile(stages=[0]) + + duration = time.perf_counter() - start_time + + bench = compute_and_print_metrics( + bench_results, + duration, + concurrency, + args.num_requests, + ) + bench.config_name = args.config_name + all_bench_results.append(asdict(bench)) + + # ── Save results ────────────────────────────────────────────────── + if 
args.result_dir: + result_dir = Path(args.result_dir) + result_dir.mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + result_file = result_dir / f"bench_easymagpie_{args.config_name}_{timestamp}.json" + with open(result_file, "w") as f: + json.dump(all_bench_results, f, indent=2) + logger.info("Results saved to %s", result_file) + + omni.shutdown() + finally: + os.unlink(tmp_config_path) + + logger.info("Done.") + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Benchmark the EasyMagpieTTS talker (AR stage only) via AsyncOmni", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=__doc__, + ) + + model = parser.add_argument_group("model / input") + model.add_argument( + "--model", + type=str, + default="./easymp_vllm_model", + help="Converted EasyMagpie model directory (output of easy_magpietts_convert_to_vllm.py)", + ) + model.add_argument( + "--text-file", + type=str, + default=None, + help="Path to text file (one utterance per line, optionally tab-separated with text in 2nd column)", + ) + model.add_argument( + "--speaker", + type=str, + default="eng", + help="Speaker embedding name under /speaker_embeddings/.pt", + ) + model.add_argument( + "--speaker-embedding", + type=str, + default=None, + help="Explicit path to a speaker embedding .pt (overrides --speaker)", + ) + model.add_argument( + "--context-text", + type=str, + default="[EN]", + help="Conditioning string tokenized + embedded in-engine (e.g. 
'[EN]')", + ) + model.add_argument( + "--lt-temperature", + type=float, + default=0.0, + help="Audio (local-transformer) sampling temperature (0.0 == argmax)", + ) + model.add_argument( + "--lt-topk", + type=int, + default=80, + help="Audio (local-transformer) sampling top-k", + ) + model.add_argument( + "--max-new-tokens", + type=int, + default=256, + help="Max decode frames per request (decode budget; trimmed at audio EOS)", + ) + model.add_argument( + "--codec-frame-rate", + type=float, + default=25.0, + help="Codec frame rate (Hz) used to convert decoded frames to audio seconds " + "(default 25 for the 25fps spectral codec)", + ) + + bench = parser.add_argument_group("benchmark") + bench.add_argument( + "-c", + "--concurrency", + type=int, + nargs="+", + default=[1], + help="Concurrency levels to test (space-separated, default: 1)", + ) + bench.add_argument( + "-n", + "--num-requests", + type=int, + default=50, + help="Total number of requests per concurrency level (default: 50)", + ) + bench.add_argument( + "--num-warmups", + type=int, + default=3, + help="Warmup rounds per concurrency level (total warmup = concurrency * this, default: 3)", + ) + bench.add_argument("--no-warmup", action="store_true", help="Skip warmup") + bench.add_argument( + "--no-stop-on-eos", + action="store_true", + help="Do not stop at the audio-EOS frame; run the full decode budget every request", + ) + bench.add_argument( + "--config-name", + type=str, + default="easymagpie", + help="Label for this run (used in result filenames)", + ) + bench.add_argument( + "--result-dir", + type=str, + default=None, + help="Directory to save JSON results", + ) + + engine = parser.add_argument_group("engine") + engine.add_argument("--gpu-memory-utilization", type=float, default=0.8) + engine.add_argument("--max-model-len", type=int, default=1024) + engine.add_argument("--max-num-batched-tokens", type=int, default=1024) + engine.add_argument("--dtype", type=str, default="float16", help="Model dtype 
(float16 / bfloat16)") + engine.add_argument("--enforce-eager", action="store_true") + engine.add_argument( + "--cudagraph-mode", + type=str, + default=None, + choices=["NONE", "PIECEWISE", "FULL", "FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"], + help="vLLM CUDA-graph capture strategy (compilation_config.cudagraph_mode). " + "Default: unset (vLLM default, FULL_AND_PIECEWISE). Use PIECEWISE to capture the " + "backbone and local transformer as separate graphs during decode so their split is " + "visible in a profiler (slightly slower than the default full decode graph). " + "Ignored when --enforce-eager is set.", + ) + engine.add_argument("--stage-init-timeout", type=int, default=300) + engine.add_argument("--log-stats", action="store_true", default=False) + engine.add_argument( + "--distributed-executor-backend", + type=str, + default="uni", + choices=["uni", "mp", "ray"], + help="vLLM executor backend. 'uni' runs the worker in-process and " + "avoids shm_broadcast IPC round-trips (recommended for TP=1, single " + "GPU). 
Default: uni.", + ) + + prof = parser.add_argument_group("profiling") + prof.add_argument( + "--profile", + action="store_true", + help="Enable torch profiler during the benchmark run", + ) + prof.add_argument("--profile-prefix", type=str, default=None, help="Prefix for profiler trace filenames") + prof.add_argument( + "--torch-profiler-dir", type=str, default="./profiler_traces", help="Directory for torch profiler traces" + ) + prof.add_argument("--with-stack", action="store_true", help="Record Python call stacks in profiler") + prof.add_argument("--record-shapes", action="store_true", help="Record tensor shapes in profiler") + + return parser.parse_args() + + +if __name__ == "__main__": + asyncio.run(main(parse_args())) From f5c06a519c2c942cdc095d21e05df98329f7c582 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Wed, 3 Jun 2026 14:52:04 +0200 Subject: [PATCH 11/15] examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py: do ckpt conversion without precision loss Signed-off-by: Viacheslav Klimkov --- .../tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py b/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py index 4cb99a08baee..af5cbe720f04 100644 --- a/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py +++ b/examples/tts/easymagpie_vllm_omni/easy_magpietts_convert_to_vllm.py @@ -157,7 +157,7 @@ def parse_args(): parser.add_argument("--context_audio_duration", type=float, default=5.0) parser.add_argument( "--dtype", - default="bfloat16", + default="float32", choices=["bfloat16", "float16", "float32"], help="Saved weight dtype / config torch_dtype. 
bf16 matches the reference inference setup.", ) From 8721e54e888ee86dce93ccb6f3e47f3e87128647 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Wed, 3 Jun 2026 16:16:59 +0200 Subject: [PATCH 12/15] examples/tts/easymagpie_vllm_omni/tests: add tests to check equivalence of cudagraph-friendly LT re-implemantation Signed-off-by: Viacheslav Klimkov --- .../easymagpie_vllm_omni/tests/conftest.py | 78 ++++++ .../easymagpie_vllm_omni/tests/test_config.py | 88 +++++++ .../tests/test_local_transformer.py | 242 ++++++++++++++++++ 3 files changed, 408 insertions(+) create mode 100644 examples/tts/easymagpie_vllm_omni/tests/conftest.py create mode 100644 examples/tts/easymagpie_vllm_omni/tests/test_config.py create mode 100644 examples/tts/easymagpie_vllm_omni/tests/test_local_transformer.py diff --git a/examples/tts/easymagpie_vllm_omni/tests/conftest.py b/examples/tts/easymagpie_vllm_omni/tests/conftest.py new file mode 100644 index 000000000000..8170bd62b34c --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/tests/conftest.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Shared pytest fixtures for the EasyMagpieTTS vLLM-Omni tests. + +The model definition (``easymagpie_vllm_omni.local_transformer`` etc.) 
is plain +PyTorch: the ``@support_torch_compile`` decorator short-circuits to eager when +``compilation_config.mode == CompilationMode.NONE``, and the modules only read a +handful of scalars off the ``VllmConfig``. So the whole stack can be exercised as +ordinary PyTorch with a tiny stand-in config — **no model directory, no engine, +no GPU required** — which is what these fixtures provide. + +All heavy imports (torch / vllm) are done lazily inside the fixture so test +collection never fails on machines where those packages are absent; the +dependent tests ``importorskip`` them and are skipped instead. +""" +from __future__ import annotations + +import types + +import pytest + +# A deliberately tiny architecture so the tests run fast on CPU. Dimensions are +# kept equal by default (so the in/out projections collapse to ``nn.Identity``, +# matching the reference SmallMamba checkpoint where everything is 1536-wide). +_DEFAULT_ARCH: dict = dict( + hidden_dim=64, + embedding_dim=64, + audio_embedding_dim=64, + num_audio_codebooks=2, + codebook_size=32, + frame_stacking_factor=2, + local_transformer_n_layers=2, + local_transformer_n_heads=4, + local_transformer_hidden_dim=64, +) + + +def build_vllm_config(**arch_overrides): + """Build a minimal stand-in ``VllmConfig`` for the code predictor. + + Returns a ``types.SimpleNamespace`` exposing exactly the attributes the + EasyMagpie modules touch at construction time: + + * ``model_config.hf_config`` — arch scalars (read via ``from_hf_config``); + * ``model_config.dtype`` — buffer dtype; + * ``scheduler_config.max_num_batched_tokens`` — scratch-buffer length; + * ``compilation_config.mode`` — ``CompilationMode.NONE`` so the + ``@support_torch_compile`` wrapper stays in eager mode. + + Any keyword overrides are merged into the default tiny arch profile. 
+ """ + import torch + from vllm.config import CompilationMode + + arch = {**_DEFAULT_ARCH, **arch_overrides} + hf_config = types.SimpleNamespace(**arch) + return types.SimpleNamespace( + model_config=types.SimpleNamespace(hf_config=hf_config, dtype=torch.float32), + scheduler_config=types.SimpleNamespace(max_num_batched_tokens=128), + compilation_config=types.SimpleNamespace(mode=CompilationMode.NONE), + ) + + +@pytest.fixture +def vllm_config_factory(): + """Fixture returning the :func:`build_vllm_config` factory.""" + return build_vllm_config diff --git a/examples/tts/easymagpie_vllm_omni/tests/test_config.py b/examples/tts/easymagpie_vllm_omni/tests/test_config.py new file mode 100644 index 000000000000..e8955d348328 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/tests/test_config.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pure-Python tests for :class:`EasyMagpieOmniArch`. + +These have no heavy dependencies (no torch / vllm) and validate the derived +quantities and the ``from_hf_config`` merge logic that the rest of the model +relies on for correct vocab sizes and special-token ids. 
+""" +from __future__ import annotations + +import types + +from easymagpie_vllm_omni.config import ( + EASYMAGPIE_SMALLMAMBA, + NUM_SPECIAL_AUDIO_TOKENS, + SPECIAL_AUDIO_EOS, + SPECIAL_AUDIO_MASK, + EasyMagpieOmniArch, +) + + +def test_derived_codebook_counts(): + arch = EasyMagpieOmniArch(num_audio_codebooks=8, frame_stacking_factor=2, codebook_size=1024) + assert arch.num_stacked_codebooks == 16 + assert arch.num_all_tokens_per_codebook == 1024 + NUM_SPECIAL_AUDIO_TOKENS + + +def test_special_token_ids_default_to_codebook_offsets(): + arch = EasyMagpieOmniArch(codebook_size=1024) + assert arch.audio_eos_id == 1024 + SPECIAL_AUDIO_EOS + assert arch.mask_token_id == 1024 + SPECIAL_AUDIO_MASK + # EOS must remain inside the per-codebook vocab so it stays sampleable. + assert arch.audio_eos_id < arch.num_all_tokens_per_codebook + + +def test_forced_special_token_ids_override_defaults(): + arch = EasyMagpieOmniArch( + codebook_size=1024, + forced_audio_bos_id=1024, + forced_audio_eos_id=1025, + forced_mask_token_id=1028, + ) + assert arch.audio_bos_id == 1024 + assert arch.audio_eos_id == 1025 + assert arch.mask_token_id == 1028 + + +def test_phoneme_ids_fall_back_to_tokenizer_convention(): + arch = EasyMagpieOmniArch(phoneme_vocab_size=2051) + assert arch.resolved_phoneme_bos_id == 2048 + assert arch.resolved_phoneme_eos_id == 2049 + assert arch.resolved_phoneme_unk_id == 2050 + + +def test_from_hf_config_overrides_and_ignores_unknown(): + hf_config = types.SimpleNamespace( + num_audio_codebooks=4, + codebook_size=2048, + frame_stacking_factor=1, + local_transformer_n_layers=5, + some_unrelated_field="ignored", + ) + arch = EasyMagpieOmniArch.from_hf_config(hf_config) + assert arch.num_audio_codebooks == 4 + assert arch.codebook_size == 2048 + assert arch.frame_stacking_factor == 1 + assert arch.local_transformer_n_layers == 5 + # Untouched fields keep the default profile. 
+ assert arch.audio_embedding_dim == EASYMAGPIE_SMALLMAMBA.audio_embedding_dim + + +def test_from_hf_config_hidden_size_fallback(): + hf_config = types.SimpleNamespace(hidden_size=999) + arch = EasyMagpieOmniArch.from_hf_config(hf_config) + assert arch.hidden_dim == 999 + # embedding_dim defaults to the same backbone width when not given explicitly. + assert arch.embedding_dim == 999 diff --git a/examples/tts/easymagpie_vllm_omni/tests/test_local_transformer.py b/examples/tts/easymagpie_vllm_omni/tests/test_local_transformer.py new file mode 100644 index 000000000000..21f9b3f92d98 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/tests/test_local_transformer.py @@ -0,0 +1,242 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Validity tests for the vLLM-Omni EasyMagpieTTS local transformer. + +The headline test is a **numerical parity check against the reference NeMo +implementation** (``transformer_2501.Transformer`` + the projection / embedding +heads, exactly as wired in ``EasyMagpieTTSInferenceModel``): random NeMo weights +are copied 1:1 into the vLLM ``EasyMagpieCodePredictor`` and both stacks are run +teacher-forced on identical inputs; the per-codebook logits must match to fp32 +tolerance with identical argmax. This is the pytest port of +``debug_local_transformer.py`` and guards against the re-implementation silently +drifting from the training-time math. 
+ +The remaining tests assert the autoregressive sampler's contract (output shape / +dtype / value range, forbidden-token masking, and seeded determinism). + +Everything runs as plain PyTorch on CPU via the tiny stand-in config from +``conftest.py`` — no model directory, no vLLM engine, no GPU. +""" +from __future__ import annotations + +import pytest + +torch = pytest.importorskip("torch") +pytest.importorskip("vllm") +transformer_2501 = pytest.importorskip("nemo.collections.tts.modules.transformer_2501") + +from conftest import build_vllm_config # noqa: E402 +from easymagpie_vllm_omni.config import EasyMagpieOmniArch # noqa: E402 +from easymagpie_vllm_omni.local_transformer import EasyMagpieCodePredictor # noqa: E402 +from torch import nn # noqa: E402 + +# Two arch profiles: one where all widths are equal (in/out projections are +# Identity, matching the real checkpoint) and one where they differ (projections +# are real Linears) — so the weight-copy + parity covers both code paths. +ARCH_PROFILES = { + "equal_dims": dict( + hidden_dim=64, + embedding_dim=64, + audio_embedding_dim=64, + local_transformer_hidden_dim=64, + local_transformer_n_heads=4, + ), + "mixed_dims": dict( + hidden_dim=64, + embedding_dim=64, + audio_embedding_dim=48, + local_transformer_hidden_dim=80, + local_transformer_n_heads=4, + ), +} + + +class NeMoLocalTransformerStack(nn.Module): + """Reference NeMo local-transformer submodules, named to match the vLLM code predictor. + + Mirrors the wiring in ``EasyMagpieTTSInferenceModel.__init__`` (the + ``local_transformer*`` / ``audio_*`` heads). Attribute names match + :class:`EasyMagpieCodePredictor` so a state-dict copy is 1:1. 
+ """ + + def __init__(self, arch: EasyMagpieOmniArch) -> None: + super().__init__() + self.n_codebooks = arch.num_stacked_codebooks + self.num_all_tokens = arch.num_all_tokens_per_codebook + embedding_dim = arch.embedding_dim + audio_dim = arch.audio_embedding_dim + lt_hidden = arch.local_transformer_hidden_dim + + self.audio_embeddings = nn.ModuleList( + [nn.Embedding(self.num_all_tokens, audio_dim) for _ in range(self.n_codebooks)] + ) + self.audio_in_projection = nn.Linear(audio_dim, embedding_dim) if audio_dim != embedding_dim else nn.Identity() + self.local_transformer_in_projection = ( + nn.Linear(embedding_dim, lt_hidden) if lt_hidden != embedding_dim else nn.Identity() + ) + self.local_transformer = transformer_2501.Transformer( + n_layers=arch.local_transformer_n_layers, + d_model=lt_hidden, + d_ffn=lt_hidden * 4, + sa_n_heads=arch.local_transformer_n_heads, + kernel_size=1, + is_causal=True, + max_length_causal_mask=self.n_codebooks + 2, + use_learnable_pos_emb=True, + ) + self.local_transformer_audio_out_projection = ( + nn.Linear(lt_hidden, audio_dim) if audio_dim != lt_hidden else nn.Identity() + ) + self.local_transformer_out_projections = nn.ModuleList( + [nn.Linear(audio_dim, self.num_all_tokens) for _ in range(self.n_codebooks)] + ) + + @torch.no_grad() + def teacher_forced_logits(self, dec_hidden: torch.Tensor, codes: torch.Tensor) -> torch.Tensor: + """Per-codebook logits given a hidden state and teacher-forced previous codes. + + Replicates ``LocalTransformerHelper.compute_logits`` (AR layout): the input + sequence is ``[dec_hidden, emb(code_0), ..., emb(code_{N-1})]``; row ``k`` of + the causal output predicts codebook ``k``, and the trailing row is dropped. 
+ """ + seq = [dec_hidden] + for k in range(self.n_codebooks): + seq.append(self.audio_in_projection(self.audio_embeddings[k](codes[:, k]))) + x = torch.stack(seq, dim=1) # (T, N+1, embedding_dim) + x = self.local_transformer_in_projection(x) # (T, N+1, lt_hidden) + mask = torch.ones(x.size(0), x.size(1), device=x.device, dtype=x.dtype) + out = self.local_transformer(x, mask)["output"][:, :-1, :] # (T, N, lt_hidden) + out = self.local_transformer_audio_out_projection(out) # (T, N, audio_dim) + logits = [self.local_transformer_out_projections[k](out[:, k, :]) for k in range(self.n_codebooks)] + return torch.stack(logits, dim=1) # (T, N, vocab) + + +@torch.no_grad() +def _vllm_teacher_forced_logits( + cp: EasyMagpieCodePredictor, dec_hidden: torch.Tensor, codes: torch.Tensor +) -> torch.Tensor: + """Per-codebook logits from the vLLM code predictor, teacher-forced. + + Mirrors :meth:`EasyMagpieCodePredictor.generate_codes` buffer layout (``N`` + rows; row 0 = ``in_proj(dec_hidden)``, row ``k+1`` = projected embedding of + ``codes[:, k]``), but reads the logits for every row instead of sampling. 
+ """ + num_tokens = dec_hidden.shape[0] + n = cp.num_codebooks + lt_hidden = cp._buf_inputs.shape[-1] + buf = torch.zeros(num_tokens, n, lt_hidden, dtype=dec_hidden.dtype, device=dec_hidden.device) + buf[:, 0, :] = cp.local_transformer_in_projection(dec_hidden) + for k in range(n - 1): + emb = cp.audio_in_projection(cp.audio_embeddings[k](codes[:, k])) + buf[:, k + 1, :] = cp.local_transformer_in_projection(emb) + hidden = cp.local_transformer(buf) # (T, N, lt_hidden) + logits = [] + for k in range(n): + row = cp.local_transformer_audio_out_projection(hidden[:, k, :]) + logits.append(cp.local_transformer_out_projections[k](row)) + return torch.stack(logits, dim=1) # (T, N, vocab) + + +def _copy_nemo_into_vllm(nemo: NeMoLocalTransformerStack, cp: EasyMagpieCodePredictor) -> None: + """Copy every vLLM code-predictor parameter from the matching NeMo parameter (names align 1:1).""" + nemo_sd = nemo.state_dict() + missing = [] + for name, param in cp.named_parameters(): + if name in nemo_sd: + assert param.shape == nemo_sd[name].shape, f"shape mismatch {name}" + param.data.copy_(nemo_sd[name].to(param.dtype)) + else: + missing.append(name) + assert not missing, f"vLLM params with no NeMo counterpart: {missing}" + + +def _build_pair(profile_kwargs: dict, seed: int = 0): + """Build a (code_predictor, nemo_stack, arch) triple with NeMo weights copied in.""" + cfg = build_vllm_config(**profile_kwargs) + arch = EasyMagpieOmniArch.from_hf_config(cfg.model_config.hf_config) + + cp = EasyMagpieCodePredictor(vllm_config=cfg, prefix="code_predictor").eval() + cp.init_forbidden_mask() + + gen = torch.Generator().manual_seed(seed) + nemo = NeMoLocalTransformerStack(arch).float().eval() + with torch.no_grad(): + for prm in nemo.parameters(): + prm.copy_(torch.empty(prm.shape).normal_(0.0, 0.02, generator=gen)) + _copy_nemo_into_vllm(nemo, cp) + return cp, nemo, arch + + +@pytest.mark.unit +@pytest.mark.parametrize("profile", list(ARCH_PROFILES), ids=list(ARCH_PROFILES)) +def 
test_local_transformer_matches_nemo(profile): + """vLLM re-implementation must equal the NeMo reference in fp32 (teacher-forced).""" + cp, nemo, arch = _build_pair(ARCH_PROFILES[profile]) + + torch.manual_seed(1234) + num_tokens = 6 + dec_hidden = torch.randn(num_tokens, arch.hidden_dim) + codes = torch.randint(0, arch.codebook_size, (num_tokens, arch.num_stacked_codebooks)) + + nemo_logits = nemo.teacher_forced_logits(dec_hidden, codes) + vllm_logits = _vllm_teacher_forced_logits(cp, dec_hidden, codes) + + assert vllm_logits.shape == nemo_logits.shape + max_abs_diff = (vllm_logits - nemo_logits).abs().max().item() + argmax_mismatch = (vllm_logits.argmax(-1) != nemo_logits.argmax(-1)).sum().item() + assert max_abs_diff < 1e-3, f"max abs diff too large: {max_abs_diff:.3e}" + assert argmax_mismatch == 0, f"{argmax_mismatch} argmax mismatches" + + +@pytest.mark.unit +def test_generate_codes_shape_dtype_and_range(): + """``generate_codes`` returns valid (num_tokens, num_codebooks) int64 codes within vocab.""" + cp, _, arch = _build_pair(ARCH_PROFILES["equal_dims"]) + num_tokens = 5 + + torch.manual_seed(0) + codes = cp.generate_codes(torch.randn(num_tokens, arch.hidden_dim)) + + assert codes.shape == (num_tokens, arch.num_stacked_codebooks) + assert codes.dtype == torch.long + assert codes.min().item() >= 0 + assert codes.max().item() < arch.num_all_tokens_per_codebook + + +@pytest.mark.unit +def test_generate_codes_respects_forbidden_mask(): + """With argmax sampling, forbidden special tokens are never emitted (only EOS stays reachable).""" + cp, _, arch = _build_pair(ARCH_PROFILES["equal_dims"]) + cp.temperature = 0.0 # argmax over masked logits + + torch.manual_seed(0) + codes = cp.generate_codes(torch.randn(7, arch.hidden_dim)) + + # Allowed = real codebook tokens [0, codebook_size) plus the audio EOS id. 
+ allowed = (codes < arch.codebook_size) | (codes == arch.audio_eos_id) + assert allowed.all(), f"sampled forbidden tokens: {sorted(set(codes[~allowed].tolist()))}" + + +@pytest.mark.unit +def test_generate_codes_deterministic_with_seed(): + """Same seed + same input ⇒ identical sampled codes (sampler is RNG-driven, no host state).""" + cp, _, arch = _build_pair(ARCH_PROFILES["equal_dims"]) + dec_hidden = torch.randn(4, arch.hidden_dim) + + torch.manual_seed(7) + first = cp.generate_codes(dec_hidden) + torch.manual_seed(7) + second = cp.generate_codes(dec_hidden) + + assert torch.equal(first, second) From 4eda162dccdf74ba90fc8727813dfd41a3cc22ec Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Wed, 3 Jun 2026 16:18:42 +0200 Subject: [PATCH 13/15] examples/tts/easymagpie_vllm_omni: hotfix for nemotron_h in fp16, need scaling Signed-off-by: Viacheslav Klimkov --- .../easymagpie_vllm_omni/backbone_patches.py | 48 +++++++++++++++++++ .../easymagpie_vllm_omni/easymagpie.py | 7 ++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/backbone_patches.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/backbone_patches.py index efe8421f7af2..da57ce6742cc 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/backbone_patches.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/backbone_patches.py @@ -20,6 +20,7 @@ """ from __future__ import annotations +import torch import torch.nn as nn import torch.nn.functional as F from vllm.logger import init_logger @@ -62,3 +63,50 @@ def patch_silu_shared_experts(backbone) -> int: patched += 1 logger.info("SiLU shared_experts fix installed on %d layers", patched) return patched + + +def patch_moe_routed_scale(backbone) -> int: + """Restore ``routed_scaling_factor`` on the NemotronHMoE output in FP16. 
+ + vLLM's ``FusedMoE`` uses an FP16 overflow trick: with + ``apply_routed_scale_to_output=True`` it does **not** multiply the routed + output by ``s`` (=routed_scaling_factor); in FP16 it instead divides the + *shared* output by ``s`` and relies on the decoder layer to keep the whole + residual stream scaled by ``1/s`` (see ``DeepseekV2DecoderLayer.forward``). + NemotronH's decoder layer never applies that compensation, so in FP16 the + MoE block emits ``routed_raw + shared/s == (s*routed + shared)/s`` — the + correct value divided by ``s``. The MoE contribution to the residual ends up + ``s``× too small and the error accumulates across the MoE layers. + + We re-multiply each MoE mixer's output by ``s`` in FP16:: + + s * (routed_raw + shared/s) = s*routed_raw + shared + + which matches the NeMo reference. FP32/BF16 already take the correct + ``fused_output *= s`` branch, so the hook is a no-op there. + + Args: + backbone: the ``NemotronHModel`` instance. + + Returns: + Number of layers patched. + """ + patched = 0 + for layer in backbone.layers: + mixer = getattr(layer, "mixer", None) + if mixer is None or mixer.__class__.__name__ != "NemotronHMoE": + continue + scale = float(getattr(mixer, "routed_scaling_factor", 1.0)) + if scale == 1.0: + continue + + def _scale_output(_mod, _inp, out, _scale=scale): + # FusedMoE only defers the scale in FP16; leave other dtypes alone. 
+ if isinstance(out, torch.Tensor) and out.dtype == torch.float16: + return out * _scale + return out + + mixer.register_forward_hook(_scale_output) + patched += 1 + logger.info("FP16 MoE routed-scale fix installed on %d layers", patched) + return patched diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index d21f357c36aa..184ec50f9e02 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -108,7 +108,7 @@ from vllm_omni.model_executor.models.output_templates import OmniOutput -from easymagpie_vllm_omni.backbone_patches import patch_silu_shared_experts +from easymagpie_vllm_omni.backbone_patches import patch_moe_routed_scale, patch_silu_shared_experts from easymagpie_vllm_omni.config import EasyMagpieOmniArch from easymagpie_vllm_omni.local_transformer import EasyMagpieCodePredictor @@ -182,6 +182,11 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: # NemotronHMLP hard-codes ReLU² in shared_experts. Restore SiLU (no-op # when the backbone has no MoE layers). patch_silu_shared_experts(self.backbone) + # vLLM's FusedMoE defers routed_scaling_factor to the decoder layer in + # FP16, but NemotronH's decoder layer never compensates, so the MoE + # output is under-scaled by routed_scaling_factor. Restore it (no-op in + # fp32/bf16 and when there are no MoE layers). 
+ patch_moe_routed_scale(self.backbone) # ── Local transformer (its own compile group / CUDA graph) ────── with set_model_tag("local_transformer"): From 4c7388ff76b1656f234ca61bb8b6e64611bcbbc6 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Thu, 4 Jun 2026 14:56:35 +0200 Subject: [PATCH 14/15] examples/tts/easymagpie_vllm_omni: introduce EOS forwarding from LT sampled tokens Signed-off-by: Viacheslav Klimkov --- .../benchmark_easymagpie_tts.py | 10 ++- .../easymagpie_inference_demo.ipynb | 35 ++++++--- .../easymagpie_vllm_omni/easymagpie.py | 72 +++++++++++++++++-- 3 files changed, 99 insertions(+), 18 deletions(-) diff --git a/examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py b/examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py index c8ac80f4e563..7abe8b296143 100644 --- a/examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py +++ b/examples/tts/easymagpie_vllm_omni/benchmark_easymagpie_tts.py @@ -246,6 +246,7 @@ class ModelMeta: audio_eos_id: int speech_delay: int frame_stacking_factor: int + stop_token_id: int # backbone stop token the model emits at the audio-EOS frame def _load_model_meta( @@ -302,6 +303,7 @@ def _load_model_meta( audio_eos_id=int(arch.audio_eos_id), speech_delay=int(getattr(arch, "streaming_speech_delay", 0) or 0), frame_stacking_factor=int(arch.frame_stacking_factor), + stop_token_id=EasyMagpieTTSForConditionalGeneration.audio_eos_stop_token_id(type("Cfg", (), config)), ) @@ -462,7 +464,9 @@ async def run_one_request( and isinstance(audio_codes, torch.Tensor) and audio_codes.numel() > 0 ): - if int(audio_codes[-1, 0]) == meta.audio_eos_id: + # audio EOS in ANY codebook (not just codebook 0) — mirrors the + # reference EOS check and the model's own stop signal. 
+ if bool((audio_codes[-1] == meta.audio_eos_id).any()): eos_decode_idx = newest_frame_idx result.eos_reached = True @@ -798,6 +802,10 @@ async def main(args): max_tokens=args.max_new_tokens, detokenize=False, ignore_eos=True, + # The model emits this backbone token at the audio-EOS frame (audio EOS in + # any codebook), so vLLM stops the request there instead of decoding the + # full budget. stop_token_ids is honored even with ignore_eos. + stop_token_ids=[meta.stop_token_id], ) try: diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb index 192e08e1ab03..7aeadfaf7129 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_inference_demo.ipynb @@ -128,12 +128,15 @@ "TEXT_VOCAB = int(config[\"text_vocab_size\"])\n", "TEXT_EOS_ID = TEXT_VOCAB - 2 # matches EasyMagpieTTSInferenceModel.eos_id\n", "\n", + "AUDIO_STOP_TOKEN_ID = max(1, int(config.get(\"vocab_size\", 2)) - 1)\n", + "\n", "print(f\"Model dir : {MODEL_DIR}\")\n", "print(f\"embedding_dim : {arch.embedding_dim}\")\n", "print(f\"num_stacked_codebooks : {arch.num_stacked_codebooks} (C*S)\")\n", "print(f\"tokens / codebook : {arch.num_all_tokens_per_codebook} (codebook_size + specials)\")\n", "print(f\"audio_bos / audio_eos id : {arch.audio_bos_id} / {arch.audio_eos_id}\")\n", - "print(f\"text_vocab / text_eos : {TEXT_VOCAB} / {TEXT_EOS_ID}\")" + "print(f\"text_vocab / text_eos : {TEXT_VOCAB} / {TEXT_EOS_ID}\")\n", + "print(f\"audio-EOS stop token id : {AUDIO_STOP_TOKEN_ID}\")" ] }, { @@ -216,7 +219,9 @@ " \"temperature\": 0.0,\n", " \"max_tokens\": DECODE_STEPS,\n", " \"detokenize\": False,\n", + " # model forwards EOS to dummy output tokens\n", " \"ignore_eos\": True,\n", + " \"stop_token_ids\": [AUDIO_STOP_TOKEN_ID],\n", " },\n", " }\n", " ],\n", @@ -342,7 +347,12 @@ " temperature=0.0, # backbone token sampler is a no-op (audio is sampled in 
the local transformer)\n", " max_tokens=DECODE_STEPS,\n", " detokenize=False,\n", - " ignore_eos=True, # audio EOS lives in the codes, not the vLLM token stream -> run the budget + trim\n", + " ignore_eos=True, # audio EOS lives in the codes, not the vLLM token stream\n", + " # The model emits AUDIO_STOP_TOKEN_ID on the backbone stream at the EOS frame\n", + " # (audio EOS in any codebook), so vLLM ends the request there instead of\n", + " # decoding the full DECODE_STEPS budget. stop_token_ids is honored regardless\n", + " # of ignore_eos.\n", + " stop_token_ids=[AUDIO_STOP_TOKEN_ID],\n", ")" ] }, @@ -409,15 +419,20 @@ " print(f\"dropping {speech_delay} leading speech-delay warm-up frames\")\n", " audio_codes = audio_codes[speech_delay:].contiguous()\n", "\n", - " # Trim at the audio EOS: the model signals end-of-speech inside the codes\n", - " # (codebook 0 == audio_eos_id), not via the vLLM token stream.\n", - " eos_frames = (audio_codes[:, 0] == arch.audio_eos_id).nonzero(as_tuple=True)[0]\n", - " if eos_frames.numel() > 0:\n", - " eos_idx = int(eos_frames[0])\n", - " print(f\"audio EOS at frame : {eos_idx} / {audio_codes.shape[0]}\")\n", - " audio_codes = audio_codes[:eos_idx].contiguous()\n", + " # Trim the trailing audio-EOS frame. 
The engine stops the request the moment\n", + " # the backbone emits AUDIO_STOP_TOKEN_ID (driven high at the audio-EOS frame),\n", + " # so when it finished for that reason the *last* decoded frame is the EOS frame\n", + " # itself — its codes carry audio_eos_id and must not be vocoded.\n", + " # NOTE: we actually expect EOS to be emitted\n", + " co = final_ro.outputs[0] if final_ro.outputs else None\n", + " finish_reason = getattr(co, \"finish_reason\", None)\n", + " stop_reason = getattr(co, \"stop_reason\", None)\n", + " print(f\"finish_reason / stop_reason: {finish_reason} / {stop_reason}\")\n", + " if finish_reason == \"stop\" and stop_reason == AUDIO_STOP_TOKEN_ID and audio_codes.shape[0] > 0:\n", + " print(f\"dropping trailing audio-EOS frame at {audio_codes.shape[0] - 1}\")\n", + " audio_codes = audio_codes[:-1].contiguous()\n", " else:\n", - " print(f\"no audio EOS within budget ({DECODE_STEPS} frames); using full decode\")\n", + " print(f\"no engine EOS stop (finish_reason={finish_reason}); using full decode\")\n", "\n", " print(f\"audio_codes shape (decode) : {tuple(audio_codes.shape)}\")\n", " print(f\"audio_codes dtype : {audio_codes.dtype}\")\n", diff --git a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py index 184ec50f9e02..5c6c29a17d0f 100644 --- a/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py +++ b/examples/tts/easymagpie_vllm_omni/easymagpie_vllm_omni/easymagpie.py @@ -260,10 +260,36 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self._out_codes = torch.zeros(max_num_tokens, self.num_codebooks, dtype=torch.long) + # ── Audio-EOS → engine stop ───────────────────────────────────── + # The model signals end-of-speech inside the audio codebooks. 
+ # To make vLLM terminate the request at the EOS frame, + # we flag decode positions whose codes contain ``audio_eos_id`` and emit the designated + # ``stop_token_id`` for them in ``compute_logits``. + # Callers must pass ``SamplingParams(stop_token_ids=[stop_id])`` with + # ``stop_id = audio_eos_stop_token_id(hf_config)``. + self.audio_eos_id = int(arch.audio_eos_id) + self._stop_token_id = self.audio_eos_stop_token_id(hf_config) + # flags frames in which ``_out_codes`` contains ``audio_eos_id`` + self._token_stop = torch.zeros(max_num_tokens, dtype=torch.bool) + # slice of ``_token_stop`` re-indexed by ``logits_index`` that can be used in + # ``compute_logits`` + self._sample_stop = torch.zeros(max_num_tokens, dtype=torch.bool) + # ------------------------------------------------------------------ # Embedding helpers # ------------------------------------------------------------------ + @staticmethod + def audio_eos_stop_token_id(hf_config: Any) -> int: + """Backbone token id this model emits when audio EOS is reached. + + Audio end-of-speech lives in the codebooks, not the backbone token + stream, so the dummy backbone vocab is repurposed as a 2-way stop + signal: index ``0`` == "continue", the last index == "stop". Callers + must pass ``SamplingParams(stop_token_ids=[this])``. + """ + return max(1, int(getattr(hf_config, "vocab_size", 2)) - 1) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: """Compatibility shim — unused at runtime (everything goes via inputs_embeds).""" return self.text_embedding(input_ids) @@ -368,7 +394,7 @@ def forward( positions: torch.Tensor, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - **_: Any, + **kwargs: Any, ) -> torch.Tensor: """Assemble the per-token embedding, run the backbone, then the codes. 
@@ -385,6 +411,11 @@ def forward( else: combined.zero_() + # Reset per-token stop flags for this step (so prefill / warm-up rows stay + # "continue"); decode positions get set below by :meth:`_flag_audio_eos`. + self._token_stop[:num_tokens].zero_() + logits_index = kwargs.get("logits_index") + decode_idx, num_req = self._get_decode_idxs() if decode_idx is None: @@ -406,6 +437,7 @@ def forward( if decode_idx is None: codes = self.code_predictor.generate_codes(hidden_states) self._out_codes[:num_tokens].copy_(codes) + self._flag_audio_eos(codes, slice(0, num_tokens)) if self.has_phoneme: self._predict_phonemes(hidden_states, slice(0, num_tokens)) elif num_req > 0: @@ -416,11 +448,30 @@ def forward( ctx.batch_descriptor = orig_bd valid = decode_idx[:num_req] self._out_codes[valid] = codes[:num_req] + self._flag_audio_eos(codes[:num_req], valid) if self.has_phoneme: self._predict_phonemes(hidden_states, valid) + # Re-index _token_stop into _sample_stop. + # This only happens for mixed/prefill, since for capture logits_index is None, + # so during decode-only the ``logits_index is None`` branch will be executed. + if logits_index is not None: + self._sample_stop[:logits_index.shape[0]] = self._token_stop[logits_index] + else: + self._sample_stop[:num_tokens].copy_(self._token_stop[:num_tokens]) + return hidden_states + def _flag_audio_eos(self, codes: torch.Tensor, idx) -> None: + """Flag decode positions whose newly sampled frame ends speech. + Checks ``codes`` for EOS and assigns ``_token_stop[idx]``. + + Note: this uses the *sampled* codes. NeMo also checks argmax(logits) == eos_idx, + i.e. checks if EOS is emitted without sampling. Skip for now. 
+ """ + eos = (codes == self.audio_eos_id).any(dim=1) & (self._dec_audio_valid[idx] == 1) + self._token_stop[idx] = eos + def _assemble_decode_embeddings(self, combined: torch.Tensor, idx) -> None: """Add ``text + phoneme + audio`` embeddings into ``combined`` at ``idx``.""" # Audio: previous-frame codes (gated by validity). 
@@ -477,18 +528,25 @@ def _predict_phonemes(self, hidden_states: torch.Tensor, idx) -> None: # ------------------------------------------------------------------ def compute_logits(self, hidden_states, sampling_metadata: Any = None) -> Optional[torch.Tensor]: - """Return zero logits so vLLM's sampler always picks index 0. - - The width is taken from ``hf_config.vocab_size`` so the sampler's working - buffers match. The sampled id is irrelevant — audio is surfaced via - :meth:`make_omni_output`. + f"""Dummy backbone logits, repurposed as a 2-way continue/stop signal. + ``_sample_stop`` indicates which frames contain EOS. We set logits, + based on that: logits[sample_stop == True, stop_token_id] = 30 or -30 otherwise. + SamplingParams should set stop_token_id as EOS token though. """ if isinstance(hidden_states, OmniOutput): hidden_states = hidden_states.text_hidden_states if hidden_states is None: return None batch_size = hidden_states.shape[0] - return hidden_states.new_zeros(batch_size, int(self.hf_config.vocab_size)) + logits = hidden_states.new_zeros(batch_size, int(self.hf_config.vocab_size)) + if self._stop_token_id < logits.shape[1]: + stop_rows = self._sample_stop[:batch_size] + logits[:, self._stop_token_id] = torch.where( + stop_rows, + logits.new_full((), 30.0), + logits.new_full((), -30.0), + ) + return logits # ------------------------------------------------------------------ # multimodal output plumbing From 42105350988dab8068b63c35047e1330f5496df2 Mon Sep 17 00:00:00 2001 From: Viacheslav Klimkov Date: Thu, 4 Jun 2026 18:47:03 +0200 Subject: [PATCH 15/15] examples/tts/easymagpie_vllm_omni: initial version of TTS service Signed-off-by: Viacheslav Klimkov --- examples/tts/easymagpie_vllm_omni/Dockerfile | 23 + .../export_codec_decoder_onnx.py | 283 ++++++++++++ .../export_codec_decoder_trt.py | 119 +++++ .../model_repository/codec/config.pbtxt | 33 ++ .../model_repository/easymp/1/model.py | 420 ++++++++++++++++++ 
.../model_repository/easymp/config.pbtxt | 103 +++++ 6 files changed, 981 insertions(+) create mode 100644 examples/tts/easymagpie_vllm_omni/Dockerfile create mode 100644 examples/tts/easymagpie_vllm_omni/export_codec_decoder_onnx.py create mode 100644 examples/tts/easymagpie_vllm_omni/export_codec_decoder_trt.py create mode 100644 examples/tts/easymagpie_vllm_omni/model_repository/codec/config.pbtxt create mode 100644 examples/tts/easymagpie_vllm_omni/model_repository/easymp/1/model.py create mode 100644 examples/tts/easymagpie_vllm_omni/model_repository/easymp/config.pbtxt diff --git a/examples/tts/easymagpie_vllm_omni/Dockerfile b/examples/tts/easymagpie_vllm_omni/Dockerfile new file mode 100644 index 000000000000..f9b29c757618 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/Dockerfile @@ -0,0 +1,23 @@ +FROM nvcr.io/nvidia/tritonserver:26.02-py3 + +# 1. System dependency for git-based installs +RUN apt-get update && \ + apt-get install -y git sox libsox-fmt-all + +# 2. upstream vllm +RUN pip install --no-cache-dir \ + "vllm==0.21.0" \ + "vllm_omni==0.21.0rc1" + +# 3. TODO install NeMo/examples/tts/easymagpie_vllm_omni + +# 4. Extra python requirements needed to compile the model +RUN pip install --no-cache-dir \ + onnxscript \ + librosa \ + sox \ + onnx-graphsurgeon \ + "tritonclient[grpc]" +RUN pip install --no-cache-dir --force-reinstall --no-deps "numpy==2.3.5" + +WORKDIR /workspace diff --git a/examples/tts/easymagpie_vllm_omni/export_codec_decoder_onnx.py b/examples/tts/easymagpie_vllm_omni/export_codec_decoder_onnx.py new file mode 100644 index 000000000000..85098993daad --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/export_codec_decoder_onnx.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stage 1/2: export the EasyMagpieTTS (25 fps spectral) codec decoder to ONNX. + +The exported graph takes the model's **raw stacked codes** and produces a +waveform, baking the two stateless model->codec glue steps into the graph so the +serving side (e.g. a Triton python backend) needs no NeMo: + + input audio_codes : int64 (batch, frames, num_stacked_codebooks) + output audio_values : float (batch, frames * output_samples_per_frame) + + audio_codes -> clamp(specials) -> unstack -> index-convert -> codec.decode + +``num_stacked_codebooks = num_audio_codebooks * frame_stacking_factor`` (e.g. +``8 * 2 = 16``). With ``--nemo_file`` the wrapper: + +* **clamps** out-of-range special tokens (audio bos/eos/mask) to valid indices, +* **unstacks** ``(B, T, C*S) -> (B, C, T*S)`` (inverse of ``stack_codes``), and +* **index-converts** the model's regrouped FSQ space (e.g. 8 codebooks of 1024) + to the codec's native ``GroupFiniteScalarQuantizer`` space (e.g. 5 codebooks of + 4^8) via ``VectorQuantizerIndexConverter.convert_new_to_original`` -- a lossless + per-frame index remap, read straight from the EasyMagpie ``.nemo``. + +Without ``--nemo_file`` it falls back to the codec's *native* decode (input +``(batch, frames, num_codebooks)``, no unstack / convert). + +For ``25fps_spectral_codec_with_bandwidth_extension.nemo`` the codec emits 882 +output samples / frame (decode emits 22050 Hz; encoder runs at 16000 Hz / 640 +samples per frame); one model frame unstacks to ``frame_stacking_factor`` codec +frames. 
+ +The frame axis is exported as **static** (``--frames``) and only ``batch`` is +dynamic -- this matches the streaming decode usage (a fixed chunk size) and lets +TensorRT pick efficient tactics. Build several engines if you need several chunk +sizes, or pass a frames profile to the TRT builder for a dynamic frame axis. + +Stage 2 (TRT engine build) lives in ``export_codec_decoder_trt.py``. + +Example: + python examples/tts/easymagpie_vllm_omni/export_codec_decoder_onnx.py \\ + --codec_model_path /path/to/25fps_spectral_codec_with_bandwidth_extension.nemo \\ + --nemo_file /path/to/easymagpie.nemo \\ + --onnx-path codec/codec_decoder.onnx \\ + --frames 15 --device cuda +""" +from __future__ import annotations + +import argparse +from pathlib import Path + +import numpy as np +import torch + +# Match ORT's full-FP32 matmul; PyTorch on Ampere+ uses TF32 by default and would +# otherwise diverge from the ONNX/ORT reference during the parity check. +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + +try: + import onnx +except ImportError as exc: # pragma: no cover + raise ImportError("`onnx` is required. Install with: pip install onnx onnxruntime") from exc + +from nemo.collections.tts.models import AudioCodecModel +from nemo.utils import logging + + +class CodecDecoderWrapper(torch.nn.Module): + """Wrap ``AudioCodecModel`` so a single ``(B, T, C)`` int tensor decodes to ``(B, T_audio)``. + + With ``converter``/``stacking`` set, the input is the model's *stacked* codes + ``(B, T, C*S)`` and the wrapper clamps special tokens, unstacks to ``(B, C, T*S)`` + and index-converts to the codec's native space before decoding. Otherwise the + input is the codec's *native* codes ``(B, T, num_codebooks)``. + + The codec's conv layers mask out-of-range positions using a per-batch length. + We bake a *full-length* length tensor (all frames valid) so the mask folds to a + constant at export time and disappears from the graph. 
+ """ + + def __init__( + self, + codec_model: AudioCodecModel, + converter: torch.nn.Module = None, + stacking: int = 1, + clamp_max: int = None, + ): + super().__init__() + self.codec_model = codec_model + self.converter = converter + self.stacking = int(stacking) + self.clamp_max = clamp_max + + def forward(self, audio_codes: torch.Tensor) -> torch.Tensor: + # audio_codes: (B, T, C) -> codec expects (B, C, T) + tokens = audio_codes.transpose(1, 2).contiguous() + bsz = tokens.shape[0] + + if self.stacking > 1: + # Unstack (B, C*S, T) -> (B, C, T*S): inverse of EasyMagpie stack_codes. + cs, t = tokens.shape[1], tokens.shape[2] + c = cs // self.stacking + tokens = tokens.view(bsz, c, self.stacking, t).permute(0, 1, 3, 2).reshape(bsz, c, t * self.stacking) + + if self.clamp_max is not None: + # Drop special tokens (audio bos/eos/mask live above the codebook). + tokens = tokens.clamp(0, self.clamp_max) + + tokens = tokens.contiguous() + frames = tokens.shape[2] + tokens_len = torch.full((bsz,), frames, dtype=torch.long, device=tokens.device) + + if self.converter is not None: + tokens = self.converter.convert_new_to_original(audio_tokens=tokens, audio_lens=tokens_len) + + audio, _ = self.codec_model.decode(tokens=tokens, tokens_len=tokens_len) + return audio + + +def check_onnx_parity(wrapper, onnx_path, audio_codes, device, atol=1e-3): + try: + import onnxruntime as ort + except ImportError: + logging.warning("onnxruntime not installed -- skipping parity check") + return True + + providers = ( + ["CUDAExecutionProvider", "CPUExecutionProvider"] if device.type == "cuda" else ["CPUExecutionProvider"] + ) + sess = ort.InferenceSession(str(onnx_path), providers=providers) + + with torch.inference_mode(): + ref = wrapper(audio_codes).detach().cpu().float().numpy() + ort_out = sess.run(None, {"audio_codes": audio_codes.cpu().numpy()})[0] + max_diff = float(np.abs(ref - ort_out).max()) + ok = max_diff <= atol + logging.info( + f"ONNX parity ({sess.get_providers()[0]}): 
max_abs_diff={max_diff:.6f} " + f"atol={atol} {'PASSED' if ok else 'FAILED'}" + ) + return ok + + +def load_codec_decoder(codec_model_path: str, device: torch.device) -> AudioCodecModel: + """Restore the codec in FP32/eval and strip the (unused at inference) discriminator.""" + codec_cfg = AudioCodecModel.restore_from(codec_model_path, return_config=True) + if "use_scl_loss" in codec_cfg: + codec_cfg.use_scl_loss = False + codec = AudioCodecModel.restore_from(codec_model_path, strict=False, override_config_path=codec_cfg) + if hasattr(codec, "discriminator"): + del codec.discriminator + codec = codec.to(device).eval().float() + codec.freeze() + # Fuse weight-norm reparameterizations into plain conv weights for a clean graph. + if hasattr(codec, "audio_decoder") and hasattr(codec.audio_decoder, "remove_weight_norm"): + codec.audio_decoder.remove_weight_norm() + return codec + + +def load_index_converter(codec: AudioCodecModel, nemo_file: str, device: torch.device): + """Build the model->codec index converter + stacking factor from an EasyMagpie .nemo. + + Reads only the EasyMagpie config (no weights): the ``vector_quantizer`` override + the model was trained with and its ``frame_stacking_factor``. Returns + ``(converter_or_None, stacking, new_codebook_size)``. ``converter`` is None when + the model and codec already share the same FSQ grouping. 
+ """ + from hydra.utils import instantiate + + from nemo.collections.tts.models.easy_magpietts_inference import EasyMagpieTTSInferenceModel + from nemo.collections.tts.modules.audio_codec_modules import VectorQuantizerIndexConverter + + em_cfg = EasyMagpieTTSInferenceModel.restore_from(nemo_file, return_config=True) + stacking = int(em_cfg.get("frame_stacking_factor", 1)) + vq_cfg = em_cfg.get("vector_quantizer") + if vq_cfg is None: + return None, stacking, None + + vq_new = instantiate(vq_cfg).to(device).eval() + new_codebook_size = int(vq_new.codebook_size) + if vq_new.num_codebooks == codec.vector_quantizer.num_codebooks: + return None, stacking, new_codebook_size + + converter = VectorQuantizerIndexConverter( + vector_quantizer_original=codec.vector_quantizer, + vector_quantizer_new=vq_new, + ).to(device).eval() + return converter, stacking, new_codebook_size + + +def parse_args(): + p = argparse.ArgumentParser(description="Export the EasyMagpieTTS codec decoder to ONNX") + p.add_argument("--codec_model_path", required=True, help="Path to the audio codec .nemo checkpoint") + p.add_argument( + "--nemo_file", + default=None, + help="EasyMagpie .nemo: bakes unstack + index conversion in (input becomes stacked model codes). 
" + "Omit to export the codec's native decode.", + ) + p.add_argument("--onnx-path", default="codec_decoder.onnx") + p.add_argument("--frames", type=int, default=30, help="Static frame count baked into the graph (chunk size)") + p.add_argument("--batch-size", type=int, default=1, help="Dummy batch size used for export/parity") + p.add_argument("--opset", type=int, default=18) + p.add_argument("--device", default="cpu", choices=["cpu", "cuda"]) + p.add_argument("--atol", type=float, default=1e-3) + return p.parse_args() + + +def main(): + args = parse_args() + device = torch.device(args.device) + + codec = load_codec_decoder(args.codec_model_path, device) + + converter, stacking, new_codebook_size = (None, 1, None) + if args.nemo_file is not None: + converter, stacking, new_codebook_size = load_index_converter(codec, args.nemo_file, device) + + if args.nemo_file is not None: + # Input is the model's stacked codes; clamp specials, unstack, convert. + model_codebooks = ( + converter.vector_quantizer_new.num_codebooks if converter is not None else int(codec.num_codebooks) + ) + codebook_size = new_codebook_size if new_codebook_size is not None else int(codec.codebook_size) + nq = model_codebooks * stacking # num_stacked_codebooks (e.g. 16) + clamp_max = codebook_size - 1 + else: + # Input is the codec's native codes. 
+ nq = int(codec.num_codebooks) + codebook_size = int(codec.codebook_size) + clamp_max = None + + wrapper = CodecDecoderWrapper(codec, converter=converter, stacking=stacking, clamp_max=clamp_max).to(device).eval() + + logging.info( + f"codec: sample_rate={codec.sample_rate} output_sample_rate={codec.output_sample_rate} " + f"samples_per_frame={codec.samples_per_frame} native_codebooks={int(codec.num_codebooks)} " + f"| input num_codebooks={nq} stacking={stacking} convert={converter is not None}" + ) + + dummy = torch.randint(0, codebook_size, (args.batch_size, args.frames, nq), dtype=torch.long, device=device) + + onnx_path = Path(args.onnx_path) + onnx_path.parent.mkdir(parents=True, exist_ok=True) + + with torch.inference_mode(): + torch.onnx.export( + wrapper, + (dummy,), + str(onnx_path), + dynamo=False, + export_params=True, + opset_version=args.opset, + do_constant_folding=True, + input_names=["audio_codes"], + output_names=["audio_values"], + dynamic_axes={ + "audio_codes": {0: "batch"}, + "audio_values": {0: "batch"}, + }, + ) + logging.info(f"ONNX exported to {onnx_path}") + + onnx.checker.check_model(str(onnx_path)) + + if not check_onnx_parity(wrapper, onnx_path, dummy, device, atol=args.atol): + raise RuntimeError("ONNX vs PyTorch parity failed -- export is broken.") + + +if __name__ == "__main__": + main() diff --git a/examples/tts/easymagpie_vllm_omni/export_codec_decoder_trt.py b/examples/tts/easymagpie_vllm_omni/export_codec_decoder_trt.py new file mode 100644 index 000000000000..67d2b0174107 --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/export_codec_decoder_trt.py @@ -0,0 +1,119 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Stage 2/2: build a TensorRT engine from the codec-decoder ONNX. + +Consumes the ONNX produced by ``export_codec_decoder_onnx.py`` and runs +``trtexec`` to build an engine with a dynamic ``batch`` (and optionally dynamic +``frames``) shape profile. + +The input tensor is ``audio_codes`` with shape ``(batch, frames, num_codebooks)``. +``num_codebooks`` is read from the ONNX graph; ``batch``/``frames`` come from the +profile flags. + +Example (FP16 is the default; pass ``--fp32`` for a pure FP32 engine): + python examples/tts/easymagpie_vllm_omni/export_codec_decoder_trt.py \\ + --onnx-path codec/codec_decoder.onnx \\ + --trt-path codec/codec_decoder.plan \\ + --batch-profile 1 8 32 \\ + --frames-profile 30 30 30 + +Notes +----- +* The frame axis is usually static (export with a fixed ``--frames`` and use the + same value for min/opt/max). A dynamic frame axis works too if the ONNX was + exported with ``frames`` dynamic.
+""" +from __future__ import annotations + +import argparse +import shutil +import subprocess +from pathlib import Path + +import onnx + + +def _infer_num_quantizers(onnx_path): + model = onnx.load(str(onnx_path)) + for inp in model.graph.input: + if inp.name != "audio_codes": + continue + dims = inp.type.tensor_type.shape.dim + if len(dims) >= 3 and dims[2].dim_value > 0: + return int(dims[2].dim_value) + raise RuntimeError( + f"could not infer num_quantizers from {onnx_path} (audio_codes dim 2 is not a static positive integer)" + ) + + +def convert_to_trt(onnx_path, trt_path, trtexec_bin, nq, batch_prof, frames_prof, fp32): + exe = shutil.which(trtexec_bin) if "/" not in trtexec_bin else trtexec_bin + if exe is None: + raise FileNotFoundError(f"trtexec not found: {trtexec_bin}") + trt_path.parent.mkdir(parents=True, exist_ok=True) + + def s(b, f): + return f"{b}x{f}x{nq}" + + cmd = [ + exe, + f"--onnx={onnx_path}", + f"--saveEngine={trt_path}", + f"--minShapes=audio_codes:{s(batch_prof[0], frames_prof[0])}", + f"--optShapes=audio_codes:{s(batch_prof[1], frames_prof[1])}", + f"--maxShapes=audio_codes:{s(batch_prof[2], frames_prof[2])}", + ] + if not fp32: + cmd.append("--fp16") + print("Running:", " ".join(cmd)) + subprocess.run(cmd, check=True) + print(f"TensorRT engine saved to {trt_path}") + + +def parse_args(): + p = argparse.ArgumentParser(description="Build a TensorRT engine from the codec-decoder ONNX") + p.add_argument("--onnx-path", required=True) + p.add_argument("--trt-path", required=True) + p.add_argument("--trtexec-bin", default="/usr/src/tensorrt/bin/trtexec") + p.add_argument("--batch-profile", nargs=3, type=int, default=[1, 8, 32], metavar=("MIN", "OPT", "MAX")) + p.add_argument("--frames-profile", nargs=3, type=int, default=[15, 15, 15], metavar=("MIN", "OPT", "MAX")) + p.add_argument("--fp32", action="store_true", help="Build pure FP32 engine (default: FP16).") + return p.parse_args() + + +def main(): + args = parse_args() + onnx_path = 
Path(args.onnx_path) + trt_path = Path(args.trt_path) + + if not onnx_path.is_file(): + raise FileNotFoundError(f"ONNX not found: {onnx_path}") + + nq = _infer_num_quantizers(onnx_path) + print(f"num_quantizers={nq} (from {onnx_path})") + + convert_to_trt( + onnx_path, + trt_path, + args.trtexec_bin, + nq=nq, + batch_prof=tuple(args.batch_profile), + frames_prof=tuple(args.frames_profile), + fp32=args.fp32, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/tts/easymagpie_vllm_omni/model_repository/codec/config.pbtxt b/examples/tts/easymagpie_vllm_omni/model_repository/codec/config.pbtxt new file mode 100644 index 000000000000..a575964c4dba --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/model_repository/codec/config.pbtxt @@ -0,0 +1,33 @@ +name: "codec" +platform: "tensorrt_plan" +max_batch_size: 32 + +# Stacked model codes (clamp + unstack + index-convert are baked into the engine). +# Frame axis is static at the exported chunk size (15 model frames). +input [ + { + name: "audio_codes" + data_type: TYPE_INT64 + dims: [ 15, 16 ] + } +] + +output [ + { + name: "audio_values" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +dynamic_batching { + max_queue_delay_microseconds: 1000 + preferred_batch_size: [ 32 ] +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] diff --git a/examples/tts/easymagpie_vllm_omni/model_repository/easymp/1/model.py b/examples/tts/easymagpie_vllm_omni/model_repository/easymp/1/model.py new file mode 100644 index 000000000000..237a7125e5ee --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/model_repository/easymp/1/model.py @@ -0,0 +1,420 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Triton Python backend for EasyMagpieTTS driven by vllm-omni's AsyncOmni engine. + +Wraps ``EasyMagpieTTSForConditionalGeneration`` (the vLLM-Omni talker, same model +used by the inference demo / benchmark): it streams stacked codec frames, which we +chunk-decode (overlap-save) through the ``codec`` TensorRT model. + +Pipeline: + 1. Build ``additional_information`` from ``{speaker_embedding, context_text, text, + temperature, top_k}`` and a placeholder ``prompt_token_ids`` of length + ``estimate_prompt_len(...)``. + 2. Submit one request to ``AsyncOmni.generate()``. Each step yields the + *cumulative* ``audio_codes`` tensor ``(T_total, C*S)`` (prefill rows + one row + per decode step) and per-step (delta) backbone ``token_ids``; we slice the decoded + rows, drop the leading ``speech_delay`` warm-up frames, and stop at the audio + EOS frame. + 3. New frames are streamed out in fixed ``codec_chunk_size``-frame windows (with a + trimmed ``codec_left_context``) through the ``codec`` BLS, which unstacks + + index-converts + decodes them to 22.05 kHz audio chunks.
+""" + +from __future__ import annotations + +import asyncio +import concurrent.futures +import json +import logging +import os +import queue +import tempfile +import threading +import time +import uuid +from pathlib import Path + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +import yaml + +logging.basicConfig( + format="%(asctime)s [%(levelname)s]: %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) +logger = logging.getLogger("easymp_triton") + + +def _require_param(parameters: dict, key: str) -> str: + val = parameters.get(key) + if isinstance(val, dict): + val = val.get("string_value") + if val is None: + raise KeyError(f"Missing required model parameter: {key!r}") + return str(val) + + +class TritonPythonModel: + def initialize(self, args): + os.environ.setdefault("VLLM_WORKER_MULTIPROC_METHOD", "spawn") + + self.model_config = json.loads(args["model_config"]) + params = self.model_config.get("parameters", {}) + + self.vllm_model_path = _require_param(params, "vllm_model_path") + self.default_speaker = _require_param(params, "default_speaker") + self.default_context_text = _require_param(params, "default_context_text") + + self.max_model_len = int(_require_param(params, "max_model_len")) + self.max_num_seqs = int(_require_param(params, "max_num_seqs")) + self.max_num_batched_tokens = int(_require_param(params, "max_num_batched_tokens")) + self.max_new_tokens = int(_require_param(params, "max_new_tokens")) + self.gpu_memory_utilization = float(_require_param(params, "gpu_memory_utilization")) + + self.codec_chunk_size = int(_require_param(params, "codec_chunk_size")) + self.codec_left_context = int(_require_param(params, "codec_left_context")) + self.first_chunk_frames = int(_require_param(params, "first_chunk_frames")) + + self.lt_temperature = float(_require_param(params, "lt_temperature")) + self.lt_top_k = int(_require_param(params, "lt_top_k")) + + self._load_arch_and_tokenizer() + self._speaker_cache: 
dict = {} + # Inferred from the first codec decode (audio_len / codec_chunk_size). + self._spf: int | None = None + + self._loop = asyncio.new_event_loop() + self._loop_thread = threading.Thread(target=self._loop.run_forever, daemon=True) + self._loop_thread.start() + + # One thread per in-flight request serializes its codec decode + + # response_sender.send calls, off the asyncio loop and overlapping with + # vLLM generation; Triton dynamic batching then groups the codec calls. + self._codec_pool = concurrent.futures.ThreadPoolExecutor( + max_workers=max(1, self.max_num_seqs), + thread_name_prefix="easymp_codec", + ) + + self._start_omni_engine() + logger.info("EasyMagpie initialized (default_speaker=%s)", self.default_speaker) + + def _load_arch_and_tokenizer(self): + from transformers import AutoTokenizer + + from easymagpie_vllm_omni.config import EasyMagpieOmniArch + from easymagpie_vllm_omni.easymagpie import EasyMagpieTTSForConditionalGeneration + + config = json.loads((Path(self.vllm_model_path) / "config.json").read_text()) + cfg_obj = type("Cfg", (), config) + arch = EasyMagpieOmniArch.from_hf_config(cfg_obj) + + self.audio_eos_id = int(arch.audio_eos_id) + self.speech_delay = int(getattr(arch, "streaming_speech_delay", 0) or 0) + self.num_stacked_codebooks = int(arch.num_stacked_codebooks) + self.has_task_embedding = arch.num_task_embeddings > 0 + self.stop_token_id = EasyMagpieTTSForConditionalGeneration.audio_eos_stop_token_id(cfg_obj) + + self.tokenizer = AutoTokenizer.from_pretrained(self.vllm_model_path, trust_remote_code=True) + self._estimate_prompt_len = EasyMagpieTTSForConditionalGeneration.estimate_prompt_len + + def _build_stage_config_file(self) -> str: + stage_cfg = { + "stage_args": [ + { + "stage_id": 0, + "stage_type": "llm", + "is_comprehension": True, + "final_output": True, + "final_output_type": "audio", + "runtime": {"devices": "0"}, + "engine_args": { + "model_stage": "easymagpie", + "max_num_seqs": self.max_num_seqs, + 
"model_arch": "EasyMagpieTTSForConditionalGeneration", + "worker_type": "ar", + "scheduler_cls": "vllm_omni.core.sched.omni_ar_scheduler.OmniARAsyncScheduler", + "enforce_eager": False, + "trust_remote_code": True, + "async_scheduling": True, + "enable_prefix_caching": False, + "engine_output_type": "audio", + "gpu_memory_utilization": self.gpu_memory_utilization, + "distributed_executor_backend": "uni", + "max_num_batched_tokens": self.max_num_batched_tokens, + "max_model_len": self.max_model_len, + # bf16 overflows the Nemotron-H fused-MoE Triton kernel's + # fp32 shared memory; fp16 backbone + fp32 mamba cache. + "dtype": "float16", + "mamba_ssm_cache_dtype": "float32", + "attention_backend": "TRITON_ATTN", + # We feed prompt_token_ids directly; the model loads the + # bundled tokenizer to tokenize context_text + text. + "skip_tokenizer_init": True, + }, + "default_sampling_params": { + # Backbone token sampler is a no-op (audio is sampled in the + # local transformer via additional_information temperature/top_k). + "temperature": 0.0, + "max_tokens": self.max_new_tokens, + "detokenize": False, + # Audio EOS lives in the codes; the model emits stop_token_id + # on the backbone stream at the EOS frame. 
+ "ignore_eos": True, + "stop_token_ids": [self.stop_token_id], + }, + } + ], + } + tmp = tempfile.NamedTemporaryFile(mode="w", suffix=".yaml", prefix="easymp_triton_", delete=False) + yaml.dump(stage_cfg, tmp, sort_keys=False) + tmp.close() + return tmp.name + + def _start_omni_engine(self): + from vllm_omni import AsyncOmni + + self._stage_cfg_path = self._build_stage_config_file() + self.omni = AsyncOmni( + model=self.vllm_model_path, + stage_configs_path=self._stage_cfg_path, + log_stats=False, + stage_init_timeout=300, + ) + + def _get_speaker_embedding(self, speaker: str) -> torch.Tensor: + if speaker not in self._speaker_cache: + emb_path = Path(self.vllm_model_path) / "speaker_embeddings" / f"{speaker}.pt" + if not emb_path.exists(): + raise FileNotFoundError(f"Speaker embedding not found: {emb_path}") + loaded = torch.load(emb_path, map_location="cpu") + emb = loaded["speaker_encoding"] if isinstance(loaded, dict) else loaded + self._speaker_cache[speaker] = emb.to(torch.float32) + return self._speaker_cache[speaker] + + def _build_prompt(self, text: str, context_text: str, speaker: str) -> dict: + speaker_embedding = self._get_speaker_embedding(speaker) + prompt_len = self._estimate_prompt_len( + speaker_embedding, + tokenize=lambda t: self.tokenizer.encode(t), + context_text=context_text, + has_task_embedding=self.has_task_embedding, + ) + return { + "prompt_token_ids": [0] * prompt_len, + "additional_information": { + "speaker_embedding": speaker_embedding, + "context_text": context_text, + "text": text, + "temperature": self.lt_temperature, + "top_k": self.lt_top_k, + }, + } + + def _decode_codec(self, codes: torch.Tensor, left_context_frames: int) -> np.ndarray: + """Decode one ``(<=codec_chunk_size, C*S)`` window, trim left context + pad.""" + codes_np = codes.detach().cpu().to(torch.int64).numpy() + pad = self.codec_chunk_size - codes_np.shape[0] + if pad > 0: + codes_np = np.pad(codes_np, ((0, pad), (0, 0))) + + response = 
pb_utils.InferenceRequest( + model_name="codec", + requested_output_names=["audio_values"], + inputs=[pb_utils.Tensor("audio_codes", codes_np[np.newaxis])], + ).exec() + if response.has_error(): + raise RuntimeError(f"Codec decode failed: {response.error().message()}") + + audio_tensor = pb_utils.get_output_tensor_by_name(response, "audio_values") + audio = ( + audio_tensor.as_numpy() + if audio_tensor.is_cpu() + else torch.from_dlpack(audio_tensor.to_dlpack()).cpu().numpy() + ) + if audio.ndim > 1: + audio = audio[0] + + if self._spf is None: + self._spf = audio.shape[-1] // self.codec_chunk_size + left = left_context_frames * self._spf + right = pad * self._spf + return audio[left:-right] if right > 0 else audio[left:] + + def _send_audio(self, response_sender, audio: np.ndarray, final: bool): + response_sender.send( + pb_utils.InferenceResponse(output_tensors=[pb_utils.Tensor("audio", audio.astype(np.float32))]), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL if final else 0, + ) + + def _send_error(self, response_sender, err: Exception): + try: + response_sender.send( + pb_utils.InferenceResponse(output_tensors=[], error=pb_utils.TritonError(str(err))), + flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL, + ) + except Exception: + pass + + def _codec_worker(self, codec_q: queue.Queue, response_sender, state: dict) -> None: + """Pop ``(chunk, ctx, is_final)`` tuples; ``None`` == send empty final + exit.""" + finalized = False + try: + while True: + item = codec_q.get() + if item is None: + self._send_audio(response_sender, np.array([], dtype=np.float32), final=True) + finalized = True + return + chunk, ctx, is_final = item + audio = self._decode_codec(chunk, ctx) + self._send_audio(response_sender, audio, final=is_final) + if state["t_first_audio"] is None: + state["t_first_audio"] = time.perf_counter() + if is_final: + finalized = True + return + except Exception as e: + state["error"] = e + if not finalized: + self._send_error(response_sender, e) + + 
async def _synthesize(self, text: str, context_text: str, speaker: str, response_sender): + t_start = time.perf_counter() + request_id = f"easymp-{uuid.uuid4().hex[:8]}" + prompt = self._build_prompt(text, context_text, speaker) + prompt_len = len(prompt["prompt_token_ids"]) + + codec_q: queue.Queue = queue.Queue() + state: dict = {"t_first_audio": None, "error": None} + codec_future = self._codec_pool.submit(self._codec_worker, codec_q, response_sender, state) + + L = self.codec_left_context + sent = 0 # real frames (post speech-delay, pre EOS) already queued + threshold = self.first_chunk_frames + real: torch.Tensor | None = None + real_count = 0 + eos_found = False + + try: + async for out in self.omni.generate(prompt, request_id=request_id): + if state["error"] is not None: + break + mm = getattr(out, "multimodal_output", None) or {} + audio_codes = mm.get("audio_codes") + if not isinstance(audio_codes, torch.Tensor): + continue + # audio_codes accumulates one row per flat-batch token: prompt_len + # prefill rows + one per decode step. Count decoded frames from the + # tensor (token_ids on a streaming step is a delta, not cumulative). + num_decoded = audio_codes.shape[0] - prompt_len + if num_decoded <= self.speech_delay: + continue + + # Decoded rows are everything after prefill; drop the leading + # speech-delay warm-up frames. 
+ real = audio_codes[prompt_len + self.speech_delay :] + eos_rows = (real == self.audio_eos_id).any(dim=1).nonzero() + if eos_rows.numel() > 0: + real_count = int(eos_rows[0].item()) # exclude the EOS frame + eos_found = True + else: + real_count = real.shape[0] + + while real_count - sent >= threshold: + ctx = min(sent, L) + chunk = real[sent - ctx : sent + threshold] + codec_q.put((chunk, ctx, False)) + sent += threshold + threshold = self.codec_chunk_size - L + if eos_found: + break + + if state["error"] is None: + if real is not None and real_count > sent: + ctx = min(sent, L) + codec_q.put((real[sent - ctx : real_count], ctx, True)) + else: + codec_q.put(None) + + await asyncio.wrap_future(codec_future) + if state["error"] is not None: + raise state["error"] + + t_end = time.perf_counter() + ttfa_ms = ((state["t_first_audio"] or t_end) - t_start) * 1000 + logger.info( + "rid=%s ttfa=%.1fms total=%.1fms frames=%d speaker=%s text=%r", + request_id, + ttfa_ms, + (t_end - t_start) * 1000, + sent, + speaker, + text[:120], + ) + except Exception as e: + logger.error("rid=%s failed: %s", request_id, e, exc_info=True) + try: + await self.omni.abort(request_id) + except Exception: + pass + if not codec_future.done(): + codec_q.put(None) + try: + await asyncio.wrap_future(codec_future) + except Exception: + pass + self._send_error(response_sender, e) + + @staticmethod + def _read_str(request, name: str, default: str) -> str: + tensor = pb_utils.get_input_tensor_by_name(request, name) + if tensor is None: + return default + return tensor.as_numpy().flatten()[0].decode("utf-8") + + def execute(self, requests): + for request in requests: + response_sender = request.get_response_sender() + try: + text = self._read_str(request, "text", "") + context_text = self._read_str(request, "context_text", self.default_context_text) + speaker = self._read_str(request, "speaker", self.default_speaker) + asyncio.run_coroutine_threadsafe( + self._synthesize(text, context_text, speaker, 
response_sender), + self._loop, + ) + except Exception as e: + logger.error("Request parse failed: %s", e, exc_info=True) + self._send_error(response_sender, e) + return None + + def finalize(self): + if hasattr(self, "omni"): + try: + self.omni.shutdown() + except Exception: + pass + if hasattr(self, "_loop") and self._loop.is_running(): + self._loop.call_soon_threadsafe(self._loop.stop) + if hasattr(self, "_loop_thread"): + self._loop_thread.join(timeout=10) + if hasattr(self, "_codec_pool"): + self._codec_pool.shutdown(wait=False) + if getattr(self, "_stage_cfg_path", None): + try: + os.unlink(self._stage_cfg_path) + except OSError: + pass diff --git a/examples/tts/easymagpie_vllm_omni/model_repository/easymp/config.pbtxt b/examples/tts/easymagpie_vllm_omni/model_repository/easymp/config.pbtxt new file mode 100644 index 000000000000..f4266539dc7b --- /dev/null +++ b/examples/tts/easymagpie_vllm_omni/model_repository/easymp/config.pbtxt @@ -0,0 +1,103 @@ +name: "easymp" +backend: "python" +max_batch_size: 32 + +input [ + { + name: "text" + data_type: TYPE_STRING + dims: [ 1 ] + }, + { + name: "context_text" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + }, + { + name: "speaker" + data_type: TYPE_STRING + dims: [ 1 ] + optional: true + } +] + +output [ + { + name: "audio" + data_type: TYPE_FP32 + dims: [ -1 ] + } +] + +model_transaction_policy { + decoupled: true +} + +instance_group [ + { + count: 1 + kind: KIND_GPU + } +] + +# Converted EasyMagpie checkpoint dir (config.json + weights + tokenizer + +# speaker_embeddings/<speaker>.pt), produced by easy_magpietts_convert_to_vllm.py.
+parameters { + key: "vllm_model_path" + value: { string_value: "/workspace/examples/tts/easymagpie_vllm_omni/easymp_vllm_model" } +} +parameters { + key: "default_speaker" + value: { string_value: "eng" } +} +parameters { + key: "default_context_text" + value: { string_value: "[EN]" } +} + +parameters { + key: "max_model_len" + value: { string_value: "1024" } +} +parameters { + key: "max_num_seqs" + value: { string_value: "32" } +} +parameters { + key: "max_num_batched_tokens" + value: { string_value: "1024" } +} +parameters { + key: "max_new_tokens" + value: { string_value: "512" } +} +parameters { + key: "gpu_memory_utilization" + value: { string_value: "0.5" } +} + +# Codec streaming (overlap-save) in MODEL frames. codec_chunk_size must equal the +# codec engine's static frame axis (15). new frames / chunk = chunk - left_context. +parameters { + key: "codec_chunk_size" + value: { string_value: "15" } +} +parameters { + key: "codec_left_context" + value: { string_value: "12" } +} +parameters { + key: "first_chunk_frames" + value: { string_value: "2" } +} + +# Audio (local-transformer) sampling, forwarded via additional_information. +parameters { + key: "lt_temperature" + value: { string_value: "0.7" } +} +parameters { + key: "lt_top_k" + value: { string_value: "80" } +}