diff --git a/Makefile b/Makefile index c4535adb7f7..40053385fc7 100644 --- a/Makefile +++ b/Makefile @@ -91,7 +91,7 @@ # # ============================================================================== -.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help +.PHONY: voxtral-cuda voxtral-cpu voxtral-metal voxtral_realtime-cuda voxtral_realtime-cpu voxtral_realtime-metal qwen3-tts-cpu whisper-cuda whisper-cuda-debug whisper-cpu whisper-metal parakeet-cuda parakeet-cuda-debug parakeet-cpu parakeet-metal sortformer-cuda sortformer-cpu silero-vad-cpu llama-cuda llama-cuda-debug llama-cpu llava-cpu gemma3-cuda gemma3-cpu clean help help: @echo "This Makefile adds targets to build runners for various models on various backends. Run using \`make \`. Available targets:" @@ -101,6 +101,7 @@ help: @echo " voxtral_realtime-cuda - Build Voxtral Realtime runner with CUDA backend" @echo " voxtral_realtime-cpu - Build Voxtral Realtime runner with CPU backend" @echo " voxtral_realtime-metal - Build Voxtral Realtime runner with Metal backend (macOS only)" + @echo " qwen3-tts-cpu - Build Qwen3-TTS runner with CPU backend" @echo " whisper-cuda - Build Whisper runner with CUDA backend" @echo " whisper-cuda-debug - Build Whisper runner with CUDA backend (debug mode)" @echo " whisper-cpu - Build Whisper runner with CPU backend" @@ -264,6 +265,15 @@ voxtral_realtime-cuda: @echo "✓ Build complete!" @echo " Binary: cmake-out/examples/models/voxtral_realtime/voxtral_realtime_runner" +qwen3-tts-cpu: + @echo "==> Building and installing ExecuTorch..." + cmake --workflow --preset llm-release + @echo "==> Building Qwen3-TTS runner (CPU)..." 
+ cd examples/models/qwen3-tts && cmake --workflow --preset qwen3-tts-cpu + @echo "" + @echo "✓ Build complete!" + @echo " Binary: cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner" + silero-vad-cpu: @echo "==> Building and installing ExecuTorch..." cmake --workflow --preset llm-release diff --git a/examples/models/qwen3-tts/.gitignore b/examples/models/qwen3-tts/.gitignore new file mode 100644 index 00000000000..8d352e6ef83 --- /dev/null +++ b/examples/models/qwen3-tts/.gitignore @@ -0,0 +1,20 @@ +# Local model downloads and caches +qwen3-tts-12Hz-0.6B-Base/ +tokenizer_cache/ + +# Local converted checkpoints and export artifacts +qwen3_tts_artifacts/ +qwen3_tts_exports_talker_8da4w/ +qwen3_tts_exports_talker_8da4w_s256/ +qwen3_tts_exports_unified/*.pte +qwen3_tts_exports_unified_q4emb/*.pte +qwen3_tts_exports_unified_q8emb/*.pte + +# Local experiment outputs +repro_runs/ +output*.wav +metal_test_codes.json +qwen3_runner_exports.txt +qwen3_runner_symbols.txt +decode_codes_so_blob71390_strings.txt +decode_codes_so_blob71390_symbols.txt diff --git a/examples/models/qwen3-tts/CMakeLists.txt b/examples/models/qwen3-tts/CMakeLists.txt new file mode 100644 index 00000000000..8d4a0cbebb7 --- /dev/null +++ b/examples/models/qwen3-tts/CMakeLists.txt @@ -0,0 +1,118 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +cmake_minimum_required(VERSION 3.24) +project(qwen3_tts_runner) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../..) +include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake) + +# Let files say "include " +set(_common_include_directories ${EXECUTORCH_ROOT}/..) 
+ +set(gflags_DIR ${CMAKE_CURRENT_BINARY_DIR}/../../../third-party/gflags) +find_package(gflags REQUIRED) + +list(APPEND CMAKE_FIND_ROOT_PATH ${CMAKE_CURRENT_BINARY_DIR}/../../..) +find_package(executorch CONFIG REQUIRED FIND_ROOT_PATH_BOTH) +executorch_target_link_options_shared_lib(executorch) + +set(_link_libraries executorch gflags) + +# Common ops for all builds. +list(APPEND _link_libraries optimized_native_cpu_ops_lib cpublas eigen_blas) +executorch_target_link_options_shared_lib(optimized_native_cpu_ops_lib) + +# CPU path can require quantized/custom ops when XNNPACK delegates are present. +if(NOT EXECUTORCH_BUILD_CUDA) + list(APPEND _link_libraries quantized_ops_lib custom_ops) + executorch_target_link_options_shared_lib(quantized_ops_lib) + executorch_target_link_options_shared_lib(custom_ops) +endif() + +# XNNPACK +if(TARGET xnnpack_backend) + set(xnnpack_backend_libs xnnpack_backend XNNPACK xnnpack-microkernels-prod) + if(TARGET kleidiai) + list(APPEND xnnpack_backend_libs kleidiai) + endif() + list(APPEND _link_libraries ${xnnpack_backend_libs}) + executorch_target_link_options_shared_lib(xnnpack_backend) +endif() + +# Base extensions needed for module loading + tensors. 
+list( + APPEND + _link_libraries + extension_module + extension_data_loader + extension_tensor + extension_flat_tensor +) + +if(ANDROID) + list(APPEND _link_libraries log) +endif() + +if(EXECUTORCH_BUILD_CUDA) + find_package(CUDAToolkit REQUIRED) + list(APPEND _link_libraries aoti_cuda_backend) + if(NOT MSVC) + executorch_target_link_options_shared_lib(aoti_cuda_backend) + endif() +endif() + +add_executable(qwen3_tts_runner main.cpp qwen3_tts_runner.cpp) +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(qwen3_tts_runner) + if(NOT APPLE AND NOT MSVC) + target_link_options(qwen3_tts_runner PRIVATE "LINKER:-s") + endif() +endif() + +target_include_directories(qwen3_tts_runner PUBLIC ${_common_include_directories}) +target_link_libraries(qwen3_tts_runner PUBLIC ${_link_libraries}) +target_compile_options(qwen3_tts_runner PUBLIC ${_common_compile_options}) + +# Unified runner: single .pte with all methods (text -> audio). +add_executable( + qwen3_tts_unified_runner main_unified.cpp qwen3_tts_unified_runner.cpp +) +if(NOT CMAKE_BUILD_TYPE STREQUAL "Debug") + target_link_options_gc_sections(qwen3_tts_unified_runner) + if(NOT APPLE AND NOT MSVC) + target_link_options(qwen3_tts_unified_runner PRIVATE "LINKER:-s") + endif() +endif() + +target_include_directories( + qwen3_tts_unified_runner PUBLIC ${_common_include_directories} +) +target_link_libraries( + qwen3_tts_unified_runner PUBLIC ${_link_libraries} extension_llm_runner +) + +# Metal/AOTI backend for GPU acceleration. 
+if(EXECUTORCH_BUILD_METAL)
+  target_link_libraries(qwen3_tts_unified_runner PUBLIC metal_backend)
+  executorch_target_link_options_shared_lib(metal_backend)
+endif()
+target_compile_options(
+  qwen3_tts_unified_runner PUBLIC ${_common_compile_options}
+)
+
+if(MSVC AND EXECUTORCH_BUILD_CUDA)
+  add_custom_command(
+    TARGET qwen3_tts_runner
+    POST_BUILD
+    COMMAND ${CMAKE_COMMAND} -E copy_if_different $<TARGET_FILE:aoti_cuda_shims>
+            $<TARGET_FILE_DIR:qwen3_tts_runner>
+    COMMENT "Copying aoti_cuda_shims.dll to qwen3_tts_runner directory"
+  )
+endif()
diff --git a/examples/models/qwen3-tts/CMakePresets.json b/examples/models/qwen3-tts/CMakePresets.json
new file mode 100644
index 00000000000..fe399b0a1bd
--- /dev/null
+++ b/examples/models/qwen3-tts/CMakePresets.json
@@ -0,0 +1,48 @@
+{
+  "version": 6,
+  "configurePresets": [
+    {
+      "name": "qwen3-tts-base",
+      "hidden": true,
+      "binaryDir": "${sourceDir}/../../../cmake-out/examples/models/qwen3-tts",
+      "cacheVariables": {
+        "CMAKE_BUILD_TYPE": "Release",
+        "CMAKE_FIND_ROOT_PATH": "${sourceDir}/../../../cmake-out",
+        "CMAKE_PREFIX_PATH": "${sourceDir}/../../../cmake-out"
+      }
+    },
+    {
+      "name": "qwen3-tts-cpu",
+      "displayName": "Qwen3-TTS runner (CPU)",
+      "inherits": [
+        "qwen3-tts-base"
+      ]
+    }
+  ],
+  "buildPresets": [
+    {
+      "name": "qwen3-tts-cpu",
+      "displayName": "Build Qwen3-TTS runner (CPU)",
+      "configurePreset": "qwen3-tts-cpu",
+      "targets": [
+        "qwen3_tts_unified_runner"
+      ]
+    }
+  ],
+  "workflowPresets": [
+    {
+      "name": "qwen3-tts-cpu",
+      "displayName": "Configure and build Qwen3-TTS runner (CPU)",
+      "steps": [
+        {
+          "type": "configure",
+          "name": "qwen3-tts-cpu"
+        },
+        {
+          "type": "build",
+          "name": "qwen3-tts-cpu"
+        }
+      ]
+    }
+  ]
+}
diff --git a/examples/models/qwen3-tts/CONTEXT.md b/examples/models/qwen3-tts/CONTEXT.md
new file mode 100644
index 00000000000..574b75b81cd
--- /dev/null
+++ b/examples/models/qwen3-tts/CONTEXT.md
@@ -0,0 +1,129 @@
+# Qwen3-TTS Bring-up Context
+
+## Scope
+
+- Target model: `Qwen/Qwen3-TTS-12Hz-0.6B-Base`
+- Target path: 
`examples/models/qwen3-tts` +- Backend: XNNPACK (CPU) + +## Reference patterns used + +### 1) Qwen conversion/export patterns + +- `examples/models/qwen3/convert_weights.py` + - HF checkpoint conversion style with shard handling. +- `examples/models/qwen3_5/convert_weights.py` + - strict key mapping behavior and defensive conversion logic. +- `examples/models/qwen3_5/tests/test_convert_weights.py` + - focused conversion unit tests for mapping and unknown keys. + +### 2) Speech model export/runtime patterns + +- `examples/models/voxtral_realtime/export_voxtral_rt.py` + - multi-method export wrappers. + - backend split and metadata in `constant_methods`. +- `examples/models/voxtral_realtime/voxtral_realtime_runner.cpp` + - custom C++ runner using `executorch::extension::Module`. +- `examples/models/whisper/main.cpp` + - ASR runtime ergonomics and preprocessor handoff. + +### 3) Build integration patterns + +- `examples/models/whisper/CMakeLists.txt` +- `examples/models/whisper/CMakePresets.json` +- top-level `Makefile` + +### 4) Backend support references + +- `examples/models/MODEL_BACKEND_SUPPORT.md` + - confirms XNNPACK as the practical first backend target for CPU bring-up. + - speech model examples currently emphasize CUDA/Metal; this bring-up closes a + gap for CPU-oriented TTS decode execution. + +## Repository observations (examples/models survey) + +- Existing audio examples are STT-focused (`whisper`, `parakeet`, `voxtral_realtime`). +- No first-class generic TTS runner existed before this bring-up. +- Existing reusable primitive for speech output generation is closest in + tokenizer/codec decoder stacks (not yet standardized as a shared TTS runtime). + +## Qwen3-TTS package observations + +- `Qwen3TTSModel.generate_voice_clone(...)` performs: + - text/ref prompt packing, + - talker generation of codec tokens, + - speech tokenizer decode into waveform. 
+- Speech tokenizer decode path for 12Hz variant is represented by + `Qwen3TTSTokenizerV2Decoder` and can run from codebook tokens. +- Full talker generation export to ExecuTorch is significantly larger in scope + (autoregressive + sub-talker generation path and cache/state flow). + +## Bring-up design choice + +To get XNNPACK validation first: + +- Export the **speech-tokenizer decoder** into ExecuTorch. +- Keep **codec generation** in Python helper using upstream `qwen_tts`. +- Add a C++ runner that: + - optionally invokes helper (`text -> codec ids`) + - then decodes codec ids through exported `model.pte` (`codec ids -> wav`). + +This keeps the path runnable and measurable while preserving room to move +talker generation into ExecuTorch in a follow-up phase. + +## Implemented architecture map + +### Conversion layer + +- `convert_weights.py` + - pulls local or remote HF snapshots. + - reads safetensor shards and extracts: + - speech decoder weights (`decoder.*` from `speech_tokenizer/`) + - optional talker weights (`talker.*` from root model) + - writes `decoder_metadata.json` for export/runtime contracts. + +### Export layer + +- `model.py` + - defines `Qwen3TTSSpeechDecoderExport` wrapper. + - computes output lengths from codec tokens and runs decoder forward. +- `export_qwen3_tts.py` + - lowers wrapper to ExecuTorch. + - attaches `constant_methods` metadata: + - `output_sample_rate` + - `decode_upsample_rate` + - `num_quantizers` + - `codebook_size` + - `fixed_codes_len` + - supports fp32/bf16 and optional 8da4w quant for linear layers. + +### Runtime layer + +- `generate_codes.py` + - uses upstream `Qwen3TTSModel` for text->codec generation. + - supports: + - text-only mode (fallback x-vector prompt from generated silence) + - voice clone mode (`ref_audio` + optional `ref_text`) + - emits compact binary codec file consumed by C++ runner. +- `qwen3_tts_runner.cpp` + - loads exported decoder `.pte`. + - optionally invokes helper script for codec generation. 
+ - pads codec sequence to `fixed_codes_len` and decodes waveform. + - writes PCM16 WAV output. + +## Why fixed-length export is used + +- Initial dynamic-shape export failed with `torch.export` constraint violations + on `codes_len` for decoder internals. +- Static export (`fixed_codes_len=1200`) was adopted to unblock XNNPACK + execution. +- Runner-side padding with sentinel `-1` preserves true output trimming through + decoder length metadata. + +## Follow-up work suggested by this bring-up + +1. Move talker autoregressive generation into ExecuTorch methods + (prefill/decode-step style). +2. Investigate BF16 decode runtime stall observed in current experiments. +3. Add Metal backend support for the speech decoder. +4. Replace helper-script dependency with fully in-runner ExecuTorch graph path. diff --git a/examples/models/qwen3-tts/PROGRESS.md b/examples/models/qwen3-tts/PROGRESS.md new file mode 100644 index 00000000000..7eff5143d9a --- /dev/null +++ b/examples/models/qwen3-tts/PROGRESS.md @@ -0,0 +1,1110 @@ +# Qwen3-TTS Bring-up Progress + +This file records commands, outcomes, and observations for the bring-up. + +## 2026-03-14 + +### Environment notes + +- Conda env: `executorch` +- Installed package: + - `qwen-tts==0.1.1` +- Noted dependency conflict warning: + - `optimum-executorch 0.2.0.dev0 requires transformers==5.0.0rc1` + - `qwen-tts` installation pulled `transformers==4.57.3` + +### Status log + +- [x] Scaffolded `examples/models/qwen3-tts` Python + C++ + CMake files. +- [x] Added conversion script for decoder/talker extraction from HF snapshots. +- [x] Added decoder export script for XNNPACK/portable. +- [x] Added helper for codec generation from text and optional clone prompt. +- [x] Added C++ runner to decode codec ids via exported `model.pte`. +- [x] Run conversion/export/build/runtime experiments. +- [x] Add non-quantized -> quantized experiment outcomes. 
+ +--- + +## Experiment log + +### 1) Converter unit tests + +Command: + +```bash +conda run -n executorch python -m pytest examples/models/qwen3-tts/tests/test_convert_weights.py +``` + +Result: **PASS** (`3 passed`, ~1.8s) + +### 2) Convert HF -> local artifacts + +Command: + +```bash +conda run -n executorch python examples/models/qwen3-tts/convert_weights.py \ + Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + examples/models/qwen3-tts/qwen3_tts_artifacts \ + --model-id-or-path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --save-talker +``` + +Result: **PASS** (`elapsed ~83.5s`) + +Artifacts: + +- `qwen3_tts_decoder.pth`: `436M` +- `qwen3_tts_talker.pth`: `1.7G` +- `decoder_metadata.json`: `1.1K` + +### 3) Export attempts (XNNPACK) + +#### 3.1 Dynamic-shape export attempt + +Command: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_qwen3_tts.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --backend xnnpack \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_fp32 +``` + +Result: **FAIL** + +- Failure reason: `ConstraintViolationError` for dynamic `codes_len` guards in `torch.export`. +- Mitigation: switched to static `--fixed-codes-len` export. 
+ +#### 3.2 FP32 export (static length) + +Command: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_qwen3_tts.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --backend xnnpack \ + --fixed-codes-len 1200 \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_fp32 +``` + +Result: **PASS** (`elapsed ~64.4s`) + +Artifact: + +- `qwen3_tts_exports_fp32/model.pte`: `440M` + +#### 3.3 BF16 export (static length) + +Command: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_qwen3_tts.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --backend xnnpack \ + --fixed-codes-len 1200 \ + --dtype bf16 \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_bf16 +``` + +Result: **PASS** (`elapsed ~46.1s`) + +Artifact: + +- `qwen3_tts_exports_bf16/model.pte`: `222M` + +#### 3.4 8da4w quant export (static length) + +Command: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_qwen3_tts.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --backend xnnpack \ + --fixed-codes-len 1200 \ + --qlinear 8da4w \ + --qlinear-group-size 32 \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_8da4w +``` + +Result: **PASS** (`elapsed ~79.8s`) + +Artifact: + +- `qwen3_tts_exports_8da4w/model.pte`: `285M` + +### 4) Build runner + +Command: + +```bash +make qwen3-tts-cpu +``` + +Result: **PASS** (`elapsed ~206.9s`) + +Binary: + +- `cmake-out/examples/models/qwen3-tts/qwen3_tts_runner` + +### 5) Runtime checks + +#### 5.1 FP32 text-only + +Command: + +```bash +conda run -n executorch cmake-out/examples/models/qwen3-tts/qwen3_tts_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_fp32/model.pte \ + --text "Hello from ExecuTorch Qwen3 TTS." 
\ + --language English \ + --model_id_or_path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --helper_script examples/models/qwen3-tts/generate_codes.py \ + --output_wav examples/models/qwen3-tts/output_text.wav +``` + +Result: **PASS** (`elapsed ~104.5s`) + +- output: `output_text.wav` +- sample rate: `24000` +- frames: `76800` +- duration: `3.20s` +- file size: `150K` + +#### 5.2 FP32 voice clone (`ref_audio` + `ref_text`) + +Command: + +```bash +conda run -n executorch cmake-out/examples/models/qwen3-tts/qwen3_tts_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_fp32/model.pte \ + --text "This is a voice clone validation run." \ + --language English \ + --model_id_or_path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --ref_audio poem.wav \ + --ref_text "This poem recording is the voice reference transcript." \ + --helper_script examples/models/qwen3-tts/generate_codes.py \ + --output_wav examples/models/qwen3-tts/output_clone.wav +``` + +Result: **PASS** (`elapsed ~100.6s`) + +- output: `output_clone.wav` +- sample rate: `24000` +- frames: `88320` +- duration: `3.68s` +- file size: `173K` + +#### 5.3 8da4w text-only + +Command: + +```bash +conda run -n executorch cmake-out/examples/models/qwen3-tts/qwen3_tts_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_8da4w/model.pte \ + --text "Quantized decoder run." \ + --language English \ + --model_id_or_path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --helper_script examples/models/qwen3-tts/generate_codes.py \ + --output_wav examples/models/qwen3-tts/output_text_8da4w.wav +``` + +Result: **PASS** (`elapsed ~71.1s`) + +- output: `output_text_8da4w.wav` +- sample rate: `24000` +- frames: `53760` +- duration: `2.24s` +- file size: `105K` + +#### 5.4 BF16 runtime + +Command: + +```bash +conda run -n executorch cmake-out/examples/models/qwen3-tts/qwen3_tts_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_bf16/model.pte \ + --text "BF16 decoder run." 
\ + --language English \ + --model_id_or_path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --helper_script examples/models/qwen3-tts/generate_codes.py \ + --output_wav examples/models/qwen3-tts/output_text_bf16.wav +``` + +Result: **FAIL / TIMEOUT** + +- Process consumed CPU for ~8 minutes with no completed output artifact. +- Command was manually terminated. +- Follow-up needed to profile BF16 runtime behavior on this decoder graph. + +#### 5.5 Performance profiling: padding waste analysis + +Profiling results for 91 real codes (metal_test_codes.bin): + +| Stage | Actual (91 codes) | Padded (1200 codes) | Ratio | +|---|---|---|---| +| quantizer.decode | 0.003s | 0.005s | 1.7x | +| pre_transformer (8-layer attn) | 0.034s | 0.207s | 6x | +| upsample[0-1] (2x each) | 0.058s | 0.181s | 3x | +| **decoder[1-4] (vocoder convs)** | **0.94s** | **16.1s** | **17x** | +| **TOTAL** | **1.1s** | **17.2s** | **15.8x** | + +Root cause: The decoder upsamples codes by 1920x through ConvTranspose1d layers. +Padding 91 codes to 1200 means processing 2.3M samples instead of 175K — a 13x +blowup in vocoder compute that dominates runtime. + +Dynamic shape export fails due to CausalConvNet padding creating `math.ceil` +guard chains incompatible with `torch.export` symbolic shape constraints. + +Solution: Multi-bucket export (`--bucket-sizes 75,150,300,600,1200`) with +nearest-bucket selection at runtime. For 91 codes, the 150 bucket is selected +instead of 1200, giving a proportional ~6-8x decode speedup. 
+ +#### 5.6 Decoder-only sanity runs from precomputed codec ids + +FP32 command: + +```bash +conda run -n executorch cmake-out/examples/models/qwen3-tts/qwen3_tts_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_fp32/model.pte \ + --codes_path /var/folders/.../qwen3_tts_codegen_codes.bin \ + --output_wav examples/models/qwen3-tts/output_from_codes.wav +``` + +Result: **PASS** (`elapsed ~46.8s`, output `71K`) + +8da4w command: + +```bash +conda run -n executorch cmake-out/examples/models/qwen3-tts/qwen3_tts_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_8da4w/model.pte \ + --codes_path /var/folders/.../qwen3_tts_codegen_codes.bin \ + --output_wav examples/models/qwen3-tts/output_from_codes_8da4w.wav +``` + +Result: **PASS** (`elapsed ~49.5s`, output `71K`) + +BF16 command: + +```bash +conda run -n executorch cmake-out/examples/models/qwen3-tts/qwen3_tts_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_bf16/model.pte \ + --codes_path /var/folders/.../qwen3_tts_codegen_codes.bin \ + --output_wav examples/models/qwen3-tts/output_from_codes_bf16.wav +``` + +Result: **FAIL / TIMEOUT** + +- Decoder-only BF16 path also stalled (>5 minutes at ~100% CPU) and was terminated. +- This indicates the issue is likely in BF16 decode execution itself, not helper code generation. + +## 2026-03-18 + +### 6) Decoder multi-bucket export (10.5x speedup) + +Root cause of slow decoder: padding 91 codes to 1200 wastes 13x compute in +the vocoder's 1920x ConvTranspose1d upsample chain. Dynamic shapes fail due +to `math.ceil` guard chains in CausalConvNet. + +Solution: export at multiple fixed `codes_len` values, pick the smallest +bucket >= actual length at runtime. 
+ +Command: + +```bash +python examples/models/qwen3-tts/export_qwen3_tts.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --backend xnnpack --qlinear 8da4w \ + --bucket-sizes 75,150,300,600,1200 \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_8da4w_bucketed +``` + +Result: **PASS** — 5 `.pte` files produced (`model_75.pte` through `model_1200.pte`) + +Benchmark (91 codes → 7.28s audio, 8da4w XNNPACK CPU): + +| Bucket | Decode Time | Speedup vs 1200 | +|--------|-------------|-----------------| +| 75 | N/A | Too small (91 > 75) | +| **150 (selected)** | **3.1s** | **10.5x** | +| 300 | 6.4s | 5.1x | +| 600 | 15.2s | 2.1x | +| 1200 (old default) | 32.4s | 1.0x | + +Scaling is near-linear with bucket size, confirming vocoder cost is +proportional to sequence length. Output quality is identical — 174720 samples +at 24000 Hz (7.28s) in both cases. + +### 7) Talker export to ExecuTorch + +The talker is architecturally identical to Qwen3 0.6B: 28-layer decoder-only +transformer with GQA, SiLU MLP, QK-norm, RoPE. Reused the existing +Llama/Qwen3 export infrastructure directly. + +Actual architecture from weights (differs from HF config defaults): +- dim=1024, n_heads=16, n_kv_heads=8, head_dim=128, hidden_dim=3072 +- Main talker: 28 layers, vocab_size=3072 (codec vocabulary) +- Code predictor: 5 layers, vocab_size=2048, 15 per-group embeddings/heads +- num_code_groups=16 (1 main + 15 sub, matching decoder's num_quantizers=16) + +#### 7.1 Weight conversion + +```bash +python examples/models/qwen3-tts/convert_talker_weights.py \ + --talker-checkpoint examples/models/qwen3-tts/qwen3_tts_artifacts/qwen3_tts_talker.pth \ + --output-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted +``` + +Result: **PASS** + +### 15) XNNPACK confidence gate: metric/accounting fixes + +Goal: + +- Fix the measurement bugs identified during `/autoresearch` before starting MLX + work, then document the best verified warmed XNNPACK path. 
+ +What changed: + +- `qwen3_tts_unified_runner.cpp` + - `codegen_ms` now excludes in-loop streaming decode checkpoints instead of + timing across both code generation and chunk decode. + - Non-streaming `first_audio_ms` is now reported relative to request start, + matching the streaming code path. +- `main_unified.cpp` + - `audio` and `rtf` now use the raw waveform before silence trimming. + - Added `trimmed_audio` and `rtf_trimmed` so post-processing effects stay + visible without polluting the main throughput metric. +- `tests/test_unified_quality_contract.py` + - Added contract coverage for the separated codegen timing and the new raw vs. + trimmed RTF reporting. +- `XNNPACK_CONFIDENCE_STATUS.md` + - Added a dedicated note capturing the current trustworthy XNNPACK status, + the exact warmed benchmark command, and the remaining blockers before we can + claim we beat `mlx-audio`. + +Verification: + +Focused tests: + +```bash +conda run -n executorch python -m unittest \ + examples.models.qwen3-tts.tests.test_unified_runner_contract \ + examples.models.qwen3-tts.tests.test_unified_quality_contract \ + examples.models.qwen3-tts.tests.test_unified_metadata +``` + +Result: **PASS** (`26 tests`) + +Runner rebuild: + +```bash +cmake --build cmake-out/examples/models/qwen3-tts --target qwen3_tts_unified_runner +``` + +Result: **PASS** + +Warmed prompt-set benchmark (checked-in export): + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \ + --repeat 1 \ + --max_new_tokens 128 \ + --temperature 1.0 \ + --top_k 50 \ + --disable_streaming_decoder_surface +``` + +Result: **PASS** + +- Avg raw RTF: `0.51x` +- Avg first audio: `3.57s` +- Avg codegen: `9.16s` +- Avg decode: `1.57s` + +Warmed prompt-set benchmark (temporary 
`max_seq_len=160` export): + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path /tmp/qwen3_tts_exports_unified_s160/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \ + --repeat 1 \ + --max_new_tokens 128 \ + --temperature 1.0 \ + --top_k 50 \ + --disable_streaming_decoder_surface +``` + +Result: **PASS** + +- Avg raw RTF: `0.52x` +- Avg first audio: `3.43s` +- Avg codegen: `8.74s` +- Avg decode: `1.61s` + +Non-streaming sanity check: + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --text "Hello from the non streaming timing check." \ + --max_new_tokens 64 \ + --temperature 1.0 \ + --top_k 50 \ + --non_streaming_mode \ + --disable_streaming_decoder_surface +``` + +Result: **PASS** + +- `first_audio_ms = 7685.4` +- `generation_ms = 7685.4` +- `final_decode_ms = 902.5` + +Interpretation: + +- The timing/accounting layer is now trustworthy enough to start MLX work. +- The best current XNNPACK path remains the overlap-window fallback + (`--disable_streaming_decoder_surface`). +- XNNPACK is still below realtime after load, so we cannot honestly claim it is + faster than `mlx-audio` yet. + +### 14) Streaming parity upgrade + XNNPACK path comparison + +Goal: align the ExecuTorch streaming path more closely with upstream Qwen3-TTS +chunking semantics, add a dedicated streaming decoder export surface, and +measure whether that new surface actually improves XNNPACK first-audio latency. + +Changes landed in source: + +- Added `capture_reference_streaming_contract.py` to record fixed-seed upstream + codec traces, chunk boundaries, and decode pacing semantics. 
+- Reworked `qwen3_tts_unified_runner` streaming decode from cumulative + prefix re-decode to bounded overlap-window decode with delta chunk emission. +- Added a dedicated `decode_audio_stream` export surface plus manifest metadata: + `streaming_decoder_contract_version`, `streaming_decoder_chunk_size`, + `streaming_decoder_left_context_size`, and `streaming_decoder_max_codes`. +- Capability-gated the new surface in the C++ runner, warmed it up explicitly, + and split timing into `first_audio_ms`, `chunk_decode_ms`, and + `final_decode_ms`. +- Added contract/metadata/reference tests for the new export surface and runner + switches. + +Verification: + +Reference contract capture: + +```bash +conda run -n executorch python -u \ + examples/models/qwen3-tts/capture_reference_streaming_contract.py \ + --upstream-repo /Users/younghan/project/executorch-exp/Qwen3-TTS \ + --output-dir /tmp/qwen3_streaming_reference \ + --text "Hello from the streaming benchmark path." \ + --language English \ + --max-new-tokens 256 +``` + +Result: **PASS** + +- Wrote fixed-seed upstream reference codes/audio/contract metadata. 
+ +Focused tests: + +```bash +conda run -n executorch python -m pytest \ + examples/models/qwen3-tts/tests/test_unified_runner_contract.py \ + examples/models/qwen3-tts/tests/test_unified_quality_contract.py \ + examples/models/qwen3-tts/tests/test_unified_metadata.py \ + examples/models/qwen3-tts/tests/test_streaming_reference_contract.py +``` + +Result: **PASS** + +Runner rebuild: + +```bash +cmake --build cmake-out/examples/models/qwen3-tts --target qwen3_tts_unified_runner +``` + +Result: **PASS** + +Fresh unified export with streaming decoder surface: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_unified \ + --backend xnnpack \ + --qlinear 8da4w +``` + +Result: **PASS** (`elapsed ~25 min`) + +- Export includes `decode_audio_stream` and the new `streaming_decoder_*` + metadata fields. + +Streaming benchmark command family (same prompt, warmed process, WAV write on): + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --text "Hello from the streaming benchmark path." 
\ + --language English \ + --max_new_tokens 128 \ + --streaming_interval 2.0 \ + --streaming_chunk_size 300 \ + --streaming_left_context_size 25 \ + --output_wav /tmp/qwen3_streaming.wav +``` + +Results: + +| Mode | Extra flags | First audio | Chunk decode | Generation-only | Audio | RTF | +|------|-------------|-------------|--------------|-----------------|-------|-----| +| Auto streaming surface | none | `8.73s` | `11.41s` | `15.20s` | `2.48s` | `0.16x` | +| Windowed fallback | `--disable_streaming_decoder_surface` | `5.41s` | `1.56s` | `7.39s` | `2.48s` | `0.34x` | +| Legacy cumulative | `--use_legacy_cumulative_streaming_decode` | `5.40s` | `1.60s` | `7.41s` | `2.48s` | `0.33x` | + +Interpretation: + +- The bounded overlap-window decode path is working and materially better than + the old cumulative prefix strategy for first-audio-oriented streaming. +- The new fixed-shape `decode_audio_stream` surface is **not** yet a win on the + current XNNPACK build. It is functionally correct but significantly slower + than the dynamic `decode_audio` fallback on this benchmark. +- The dominant latency is still in talker/code-predictor generation. Streaming + decode remains primarily a first-audio lever, not the main throughput bottleneck. + +Follow-up: + +- Investigate why `decode_audio_stream` regresses on XNNPACK despite the tighter + fixed-shape contract. +- Keep the overlap-window fallback as the preferred current XNNPACK path while + preserving the new export surface for future backend tuning. 
+ +Artifacts: +- `talker_main.pth` (311 keys) — main backbone in Meta/Llama format +- `talker_code_predictor.pth` (56 keys) — code predictor backbone +- `talker_aux.pth` (37 keys) — text_projection, codec_head, per-group embeddings/heads + +#### 7.2 Main talker export (8da4w, max_seq_len=256) + +```bash +python examples/models/qwen3-tts/export_talker.py \ + --checkpoint examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted/talker_main.pth \ + --params examples/models/qwen3-tts/config/talker_config.json \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_talker_8da4w_s256 \ + --backend xnnpack --qlinear 8da4w --max-seq-len 256 +``` + +Result: **PASS** — `talker.pte` (259 MB) + +#### 7.3 Code predictor export (8da4w, max_seq_len=32) + +```bash +python examples/models/qwen3-tts/export_talker.py \ + --checkpoint examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted/talker_code_predictor.pth \ + --params examples/models/qwen3-tts/config/code_predictor_config.json \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_talker_8da4w_s256 \ + --output-name code_predictor.pte \ + --backend xnnpack --qlinear 8da4w --max-seq-len 32 +``` + +Result: **PASS** — `code_predictor.pte` (52 MB) + +Note: `tok_embeddings.weight` and `output.weight` are missing (expected) — +the code predictor has 15 per-group embeddings/heads stored in `talker_aux.pth`. 
+ +#### 7.4 Talker benchmarks + +max_seq_len has a large impact on KV cache attention cost: + +| max_seq_len | Per-step latency | 91 steps total | +|-------------|------------------|----------------| +| 2048 | 269 ms/step | 24.5s | +| **256** | **64 ms/step** | **5.8s** | + +Code predictor (max_seq_len=32): **7.2 ms/step** + +#### 7.5 Projected end-to-end performance + +All stages 8da4w XNNPACK on CPU, 91 codes (7.28s audio): + +| Stage | Steps | Per-step | Total | % of time | +|-------|-------|----------|-------|-----------| +| Main talker | 91 | 64 ms | 5.8s | 31% | +| Code predictor | 1365 (91×15) | 7.2 ms | 9.8s | 53% | +| Decoder (bucket 150) | 1 | — | 3.1s | 16% | +| **Total** | | | **18.7s** | | + +Comparison: + +| Configuration | Total time | Speedup | +|---|---|---| +| Python baseline (all stages) | 58s | 1.0x | +| ExecuTorch 8da4w bucketed (all stages) | **18.7s** | **3.1x** | +| ExecuTorch decoder only (bucket 150 vs 1200) | 3.1s vs 32.4s | 10.5x | + +### 8) Streaming decode (inspired by mlx-audio) + +mlx-audio achieves realtime streaming by decoding audio incrementally every +~25 tokens instead of waiting for all codes. Applied the same approach: + +```bash +python examples/models/qwen3-tts/streaming_generate.py \ + --decoder-dir examples/models/qwen3-tts/qwen3_tts_exports_8da4w_bucketed \ + --codes-path examples/models/qwen3-tts/metal_test_codes.bin \ + --chunk-size 25 +``` + +| Mode | First audio | Total time | RTF | +|------|-------------|------------|-----| +| Streaming (25-code chunks, bucket 75) | **2.15s** | 6.68s | 1.09x RT | +| Non-streaming (all 91, bucket 150) | 3.97s | **3.97s** | 1.84x RT | +| Old baseline (all 91, bucket 1200) | 32.4s | 32.4s | 0.22x RT | + +Streaming gives **2.15s first-audio latency** — user hears audio 1.8s sooner. +Non-streaming is faster total (less padding overhead from fewer decoder calls) +but has higher first-audio latency. 
+ +Key insight from mlx-audio: their streaming decoder maintains conv buffers +across chunks, avoiding redundant computation. Our chunked approach processes +each chunk independently (simpler but less efficient). + +### Remaining work for 3s target + +- [ ] C++ runner integration for talker prefill + decode orchestration +- [ ] Metal/GPU backend export (expected 3-5x speedup over CPU → ~4-6s) +- [ ] Code predictor optimization — currently 53% of total time (1365 steps). + Options: batched/parallel inner loop, model distillation, or fewer code groups +- [ ] Text embedding + text_projection in C++ (currently requires Python) +- [ ] Prefill export (dynamic shape or bucketed) for prompt processing + +## 2026-03-23 + +### 9) Unified text-only prompt-contract rewrite + +Goal: internalize the text-only `generate_codes.py` prompt semantics into +`qwen3_tts_unified_runner` so the unified C++ binary can accept direct text +input with dynamic prompt length instead of the previous fixed 8-slot +approximation. + +Changes landed in source: + +- Added a shared prompt-contract helper (`text_prompt_contract.py`) with tests + for: + - assistant-wrapped prompt formatting + - prompt embedding split (`role`, `first_text`, `trailing + tts_eos`) + - prompt-budget validation (`prefill`, `max_new_tokens`, `max_seq_len`) +- Rewrote `qwen3_tts_unified_runner.cpp` to: + - tokenize the assistant-wrapped prompt instead of raw text + - run `encode_text` over the whole prompt once + - fold the first text token into prefill + - feed trailing text hidden states during autoregressive decode + - enforce prompt-budget guardrails before generation +- Updated the unified runner CLI to: + - reject ambiguous `--codes_path` + `--text` usage + - require `--tokenizer_path` for text mode + - wire `top_p` consistently through the public interface +- Synced unified export manifests and export metadata with the current 7-method + surface, including `cp_generate` and the text-prompt contract fields. 
+- Added `TODO.md` as the explicit no-compromise backlog. + +Verification: + +Prompt/metadata/runner-contract tests: + +```bash +python -m unittest \ + examples.models.qwen3-tts.tests.test_unified_prompt_flow \ + examples.models.qwen3-tts.tests.test_unified_metadata \ + examples.models.qwen3-tts.tests.test_unified_runner_contract +``` + +Result: **PASS** (`11 tests`) + +Unified runner rebuild: + +```bash +cmake --build cmake-out/examples/models/qwen3-tts --target qwen3_tts_unified_runner +``` + +Result: **PASS** + +Tokenizer materialization for text-mode verification: + +```bash +python - <<'PY' +from transformers import AutoTokenizer +from pathlib import Path +out_dir = Path('examples/models/qwen3-tts/tokenizer_cache') +out_dir.mkdir(parents=True, exist_ok=True) +tok = AutoTokenizer.from_pretrained('Qwen/Qwen3-TTS-12Hz-0.6B-Base', trust_remote_code=True) +tok.save_pretrained(out_dir) +print(out_dir / 'tokenizer.json') +PY +``` + +Result: **PASS** (`examples/models/qwen3-tts/tokenizer_cache/tokenizer.json`) + +Text-mode smoke run against the existing checked-in unified artifact: + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/tokenizer_cache/tokenizer.json \ + --text "Hello from ExecuTorch." \ + --max_new_tokens 128 \ + --output_wav /tmp/qwen3_tts_short.wav +``` + +Result: **FAIL (stale artifact)** + +- The updated runner successfully reached the new assistant-wrapped prompt path + and completed talker prefill. +- The existing `model.pte` is stale: it does **not** contain `cp_generate` and + does **not** expose the new prompt-contract constant methods. +- A fresh unified re-export is required before the text-only end-to-end path can + be verified against a current artifact. + +Follow-up: + +- Re-export `qwen3_tts_exports_unified/model.pte` from the updated + `export_unified.py`. 
+- Re-run short and longer text prompts through `qwen3_tts_unified_runner`. +- Confirm the fresh artifact exposes `cp_generate` and the prompt-contract + metadata methods. + +### 10) Fresh unified export + real CLI verification + +Tokenizer source requested for runtime verification: + +- `examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base` + +Tokenizer materialization: + +```bash +python - <<'PY' +from transformers import AutoTokenizer +from pathlib import Path +model_dir = Path('examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base') +tok = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True) +tok.save_pretrained(model_dir) +print(model_dir / 'tokenizer.json') +PY +``` + +Result: **PASS** + +- Wrote `examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json` +- Note: Transformers warns about the Mistral regex path here, but the + ExecuTorch tokenizer loader successfully falls back to PCRE2 at runtime. + +Makefile runner build requested by user: + +```bash +make qwen3-tts-cpu +``` + +Result: **PASS** + +- `CMakePresets.json` was updated so the `qwen3-tts-cpu` workflow builds + `qwen3_tts_unified_runner` instead of the legacy `qwen3_tts_runner`. +- `Makefile` was updated so the success message points at the unified binary. + +Fresh unified export: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_unified \ + --backend xnnpack \ + --qlinear 8da4w +``` + +Result: **PASS** (`elapsed ~24.0 min`) + +- Saved fresh `qwen3_tts_exports_unified/model.pte` (`2378.5 MB`) +- Saved fresh `qwen3_tts_exports_unified/export_manifest.json` +- Verified source export includes `cp_generate` and the prompt-contract fields. 
+ +#### 10.1 Short real CLI run + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --text "Hello from ExecuTorch." \ + --max_new_tokens 128 \ + --output_wav /tmp/qwen3_tts_short.wav +``` + +Result: **PASS** (`elapsed ~28.2s`) + +- Prompt token count: `15` +- Generated codes: `128` +- Output wav: `/tmp/qwen3_tts_short.wav` +- Samples written: `245760` at `24000 Hz` + +#### 10.2 Longer real CLI run + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --text "ExecuTorch now runs the unified Qwen3 TTS path directly in C plus plus, with the assistant prompt built inside the runner, dynamic prompt length handling, and a fused code predictor path for end to end synthesis on XNNPACK." \ + --max_new_tokens 192 \ + --output_wav /tmp/qwen3_tts_long.wav +``` + +Result: **PASS** (`elapsed ~33.5s`) + +- Prompt token count: `58` +- Generated codes: `192` +- Output wav: `/tmp/qwen3_tts_long.wav` +- Samples written: `368640` at `24000 Hz` + +#### 10.3 Prompt-budget guardrail check + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --text "ExecuTorch now runs the unified Qwen3 TTS path directly in C plus plus, with the assistant prompt built inside the runner, dynamic prompt length handling, and a fused code predictor path for end to end synthesis on XNNPACK." 
\ + --max_new_tokens 16 \ + --output_wav /tmp/qwen3_tts_guardrail.wav +``` + +Result: **EXPECTED FAIL** + +- The runner rejected the request with: + `max_new_tokens=16 is too small to consume the trailing prompt budget=50.` +- This confirms the dynamic prompt-budget guardrail is active in the real CLI. + +### 11) Quality remediation: codec IDs, sampler parity, and fixed voice artifact + +Root-cause fixes landed after comparing the unified runner against the MLX +Qwen3-TTS reference: + +- Corrected unified export/runtime metadata to use the real codec control token + IDs (`2148..2157`) instead of the stale `4196..4205` band. +- Updated the C++ runner to extract the last-token talker/code-predictor state + after prefill instead of reusing the first token. +- Suppressed the talker special-token band (`[vocab_size - 1024, vocab_size)`) + during `code_0` sampling while still allowing `codec_eos_id`. +- Removed the silent decoder clamp-to-zero fallback for invalid codec IDs and + now fail loudly if the talker/code-predictor produces an out-of-range code. +- Restored closer MLX sampling parity for the text path: + `temperature=0.9`, `top_k=50`, `top_p=1.0`, `repetition_penalty=1.05`. +- Switched the runtime path away from greedy fused `cp_generate` rollout and + back to the stochastic `code_predictor` + `cp_head` loop for the 15 sub-code + groups. + +Focused regression tests: + +```bash +python -m unittest \ + examples.models.qwen3-tts.tests.test_unified_prompt_flow \ + examples.models.qwen3-tts.tests.test_unified_metadata \ + examples.models.qwen3-tts.tests.test_unified_runner_contract \ + examples.models.qwen3-tts.tests.test_unified_quality_contract +``` + +Result: **PASS** + +- Ran `17` tests, all passing. 
+ +Fresh rebuild: + +```bash +make qwen3-tts-cpu +``` + +Result: **PASS** (`elapsed ~52.9s`) + +- Built `cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner` + +Fresh export after quality fixes: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_unified \ + --backend xnnpack \ + --qlinear 8da4w +``` + +Result: **PASS** (`elapsed ~24.0 min`) + +- Saved fresh `examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte` +- Saved fresh `examples/models/qwen3-tts/qwen3_tts_exports_unified/export_manifest.json` +- Verified manifest now records: + `codec_pad_id=2148`, `codec_bos_id=2149`, `codec_eos_id=2150`, + `codec_nothink_id=2155`, `codec_think_bos_id=2156`, + `codec_think_eos_id=2157` + +Fixed-voice validation artifact: + +```bash +./cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --text "ExecuTorch now runs the unified Qwen3 TTS path directly in C plus plus with corrected codec control tokens, special token suppression, and sampling that is aligned more closely with the reference implementation." 
\ + --max_new_tokens 192 \ + --temperature 0.9 \ + --top_k 50 \ + --top_p 1.0 \ + --repetition_penalty 1.05 \ + --output_wav examples/models/qwen3-tts/output_text_fixed_quality.wav +``` + +Result: **PASS** (`elapsed ~37.9s`) + +- Prompt token count: `49` +- Generated codes: `192` +- Output wav: + `examples/models/qwen3-tts/output_text_fixed_quality.wav` +- Samples written: `368640` at `24000 Hz` + +### 13) Warm XNNPACK benchmark + fused `cp_generate` v2 + +Goal: measure steady-state text-to-voice latency honestly in one warmed process, +then reduce the XNNPACK hot-loop cost without switching backends. + +Changes landed in source: + +- Added a per-prompt `SynthesisSession` API plus `SynthesisTiming` so the + runner can stay loaded/warmed while each request gets fresh RNG and state. +- Updated `main_unified.cpp` with warm benchmark controls: + - `--prompts_path` + - `--repeat` + - `--seed` + - `--output_dir` + - `--disable_fused_cp_generate` +- Split timing into prompt prep, talker prefill, codegen, decode-audio, and + total generation. +- Expanded `warmup_all()` so it actually executes the text path, including + `encode_text`, `talker`, `codec_embed`, `code_predictor`, `cp_head`, + `cp_generate`, and `decode_audio`. +- Replaced the old greedy-only fused `cp_generate` export with a v2 contract + that: + - keeps host-side `code_0` sampling + - samples groups `1..15` inside the fused graph for the XNNPACK fast path + - returns sampled sub-codes plus the fused embedding sum for the next talker step +- Added ABI/version metadata: + - `cp_generate_contract_version = 2` + - `cp_generate_fast_top_k = 50` + - `cp_generate_sampler = cdf_topk50_no_top_p_v2` +- Gated the fast path on exported metadata so older `.pte` artifacts cleanly + fall back to the legacy host-side sub-code loop instead of crashing. +- Aligned host and fused sub-code sampling to the same inverse-CDF categorical + sampler shape for the current fast-path mode (`top_k=50`, top-p disabled). 
+ +Warm benchmark prompt set: + +- `examples/models/qwen3-tts/benchmark_prompts.txt` + +Warm benchmark results (`top_k=50`, `temperature=1.0`, `max_new_tokens=128`, +same warmed process, no WAV writes): + +| Prompt | Legacy generation-only | Fused generation-only | Legacy codegen | Fused codegen | +|--------|-------------------------|-----------------------|----------------|---------------| +| 0 | 3.61s | 5.35s | 3.18s / 20 steps | 4.58s / 37 steps | +| 1 | 12.46s | 12.92s | 10.97s / 81 steps | 11.19s / 88 steps | +| 2 | 21.56s | 14.69s | 18.75s / 128 steps | 12.90s / 95 steps | + +Interpretation: + +- The fused path consistently lowers codegen cost per generated codec step: + - prompt 0: ~159 ms/step -> ~124 ms/step + - prompt 1: ~135 ms/step -> ~127 ms/step + - prompt 2: ~146 ms/step -> ~136 ms/step +- End-to-end warm wall time still depends on sampling trajectory and EOS timing, + so raw prompt latency can move in either direction even when the hot path is + cheaper per step. +- The first-order XNNPACK bottleneck is still the talker/codegen loop, not the + decoder and not startup once warmup is separated out. + +Follow-up evaluation: + +- Talker decode-step specialization remains secondary for now: + warm benchmarks still show `codegen_ms` dominating `decode_audio_ms`. +- Streaming decode is mainly a first-audio latency lever, not the biggest + throughput win for the current single-run warm benchmark. 
+- The next XNNPACK speed work should focus on: + - reducing generated codec step count without hurting quality + - shrinking per-step talker/code-predictor cost further + - only then revisiting talker decode specialization and streaming decode + +Verification: + +Source/contract tests: + +```bash +python -m unittest \ + examples.models.qwen3-tts.tests.test_unified_prompt_flow \ + examples.models.qwen3-tts.tests.test_unified_metadata \ + examples.models.qwen3-tts.tests.test_unified_runner_contract \ + examples.models.qwen3-tts.tests.test_unified_quality_contract +``` + +Result: **PASS** (`28 tests`) + +Runner rebuild: + +```bash +cmake --build cmake-out/examples/models/qwen3-tts --target qwen3_tts_unified_runner +``` + +Result: **PASS** + +Unified export: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_unified \ + --backend xnnpack --qlinear 8da4w +``` + +Result: **PASS** (`model.pte` + manifest updated with `cp_generate` v2 metadata) + +Warm legacy comparison: + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \ + --repeat 1 \ + --max_new_tokens 128 \ + --temperature 1.0 \ + --top_k 50 \ + --disable_fused_cp_generate +``` + +Result: **PASS** + +Warm fused benchmark: + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \ + 
--repeat 1 \ + --max_new_tokens 128 \ + --temperature 1.0 \ + --top_k 50 +``` + +Result: **PASS** diff --git a/examples/models/qwen3-tts/README.md b/examples/models/qwen3-tts/README.md new file mode 100644 index 00000000000..47101adfe73 --- /dev/null +++ b/examples/models/qwen3-tts/README.md @@ -0,0 +1,250 @@ +## Qwen3-TTS + +ExecuTorch implementation of `Qwen/Qwen3-TTS-12Hz-0.6B-Base`. + +Supports three backends: **XNNPACK** (CPU), **Metal/AOTI** (Apple GPU), and **portable** (fallback). + +### Performance + +| Backend | 26 codes decode | Realtime | Export time | +|---------|----------------|----------|-------------| +| XNNPACK (8da4w quantized) | **728 ms** | 2.9x RT | ~20 min | +| Metal/AOTI (fp32) | **728 ms** | 2.9x RT | ~8 min | +| Portable (no backend) | 72,761 ms | 0.03x RT | ~2 min | + +Model load + warmup: ~5-7s (one-time at startup). + +Warm XNNPACK multi-prompt benchmark in one process (`top_k=50`, generation-only, +no WAV writes): + +- Legacy host loop (`--disable_fused_cp_generate`): `3.61s`, `12.46s`, `21.56s` +- Fused `cp_generate` v2: `5.35s`, `12.92s`, `14.69s` +- The stable speed win is in `codegen_ms` per generated codec step: + roughly `159/135/146 ms` down to `124/127/136 ms` on the benchmark prompts. + End-to-end wall time still depends on how many codec steps the sampler emits. 
+ +Warm XNNPACK streaming benchmark on the refreshed unified export (`31` codec +steps, `2.48s` audio, same warmed process): + +| Streaming path | Key flags | First audio | Decode time | Generation-only | RTF | +|----------------|-----------|-------------|-------------|-----------------|-----| +| Auto streaming surface | default (`decode_audio_stream`) | `8.73s` | `11.41s` | `15.20s` | `0.16x` | +| Windowed fallback | `--disable_streaming_decoder_surface` | `5.41s` | `1.56s` | `7.39s` | `0.34x` | +| Legacy cumulative | `--use_legacy_cumulative_streaming_decode` | `5.40s` | `1.60s` | `7.41s` | `0.33x` | + +Today the dedicated fixed-shape `decode_audio_stream` export is functionally +correct, but it is slower than the dynamic `decode_audio` overlap-window path on +XNNPACK. The best current XNNPACK streaming path remains the windowed fallback, +and the main steady-state bottleneck is still code generation rather than audio +decode orchestration. + +The CLI now reports `audio` and `rtf` from the raw waveform before silence +trimming, with `trimmed_audio` and `rtf_trimmed` logged separately. On the +corrected warmed prompt-set benchmark, the checked-in `max_seq_len=256` export +reaches about `0.51x` raw realtime and the best temporary `max_seq_len=160` +export reaches about `0.52x`. See `XNNPACK_CONFIDENCE_STATUS.md` for the exact +commands and benchmark breakdown. 
+ +### Model Sizes + +| Config | Size | +|--------|------| +| XNNPACK 8da4w + 4w embedding | **1,027 MB** | +| XNNPACK 8da4w (no emb quant) | 2,065 MB | +| Metal fp32 (mixed w/ XNNPACK decoder) | 4,636 MB | + +## Prerequisites + +```bash +conda activate executorch +pip install qwen-tts + +# For Metal backend only: +sudo mkdir -p /opt/llvm-openmp/lib +sudo ln -sf /opt/homebrew/Cellar/libomp/*/lib/libomp.dylib /opt/llvm-openmp/lib/libomp.dylib +``` + +## 1) Convert Weights + +```bash +python examples/models/qwen3-tts/convert_weights.py \ + Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + examples/models/qwen3-tts/qwen3_tts_artifacts \ + --model-id-or-path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --save-talker + +python examples/models/qwen3-tts/convert_talker_weights.py \ + --talker-checkpoint examples/models/qwen3-tts/qwen3_tts_artifacts/qwen3_tts_talker.pth \ + --output-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted +``` + +## 2) Export + +### XNNPACK (CPU, quantized — recommended for mobile) + +```bash +python examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_xnnpack \ + --backend xnnpack --qlinear 8da4w +``` + +### Metal/AOTI (Apple GPU — recommended for Mac) + +```bash +python examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_metal \ + --backend metal --dtype fp32 +``` + +Metal exports talker/code predictor to GPU, decoder stays on XNNPACK CPU +(Metal lacks `cumsum` fallback needed by the decoder). 
+
+### Portable (no acceleration — for debugging)
+
+```bash
+python examples/models/qwen3-tts/export_unified.py \
+ --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \
+ --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \
+ --output-dir examples/models/qwen3-tts/qwen3_tts_exports_portable \
+ --backend portable --dtype fp32
+```
+
+## 3) Generate Test Codes
+
+```bash
+python examples/models/qwen3-tts/generate_codes.py \
+ --model-id-or-path Qwen/Qwen3-TTS-12Hz-0.6B-Base \
+ --text "Hello from ExecuTorch." \
+ --output-codes /tmp/hello_codes.bin
+```
+
+## 4) Build Runner
+
+```bash
+cmake --build cmake-out/examples/models/qwen3-tts --target qwen3_tts_unified_runner
+```
+
+The runner automatically links XNNPACK and Metal backends if built.
+
+## 5) Run
+
+### XNNPACK decode
+
+```bash
+cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \
+ --model_path examples/models/qwen3-tts/qwen3_tts_exports_xnnpack/model.pte \
+ --codes_path /tmp/hello_codes.bin \
+ --output_wav /tmp/hello_xnnpack.wav
+```
+
+### XNNPACK text-only end-to-end
+
+```bash
+cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \
+ --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \
+ --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \
+ --text "Hello from ExecuTorch." 
\
+ --max_new_tokens 200 \
+ --output_wav /tmp/hello_text.wav
+```
+
+### Warm XNNPACK multi-prompt benchmark
+
+```bash
+cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \
+ --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \
+ --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \
+ --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \
+ --repeat 1 \
+ --max_new_tokens 128 \
+ --temperature 1.0 \
+ --top_k 50
+```
+
+To compare against the legacy host-side sub-code loop in the same binary, add:
+
+```bash
+ --disable_fused_cp_generate
+```
+
+### Metal decode
+
+```bash
+cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \
+ --model_path examples/models/qwen3-tts/qwen3_tts_exports_metal/model.pte \
+ --codes_path /tmp/hello_codes.bin \
+ --output_wav /tmp/hello_metal.wav
+```
+
+### Play output
+
+```bash
+afplay /tmp/hello_xnnpack.wav
+afplay /tmp/hello_metal.wav
+```
+
+## Architecture
+
+Single `model.pte` with 8 named methods:
+
+| Method | Backend | Purpose |
+|--------|---------|---------|
+| `encode_text` | Metal/XNNPACK | Text tokens → projected embeddings |
+| `talker` | Metal/XNNPACK | 28-layer transformer with KV cache |
+| `code_predictor` | Metal/XNNPACK | 5-layer sub-talker with KV cache |
+| `codec_embed` | Portable | Codec token embedding lookup |
+| `cp_head` | Metal/XNNPACK | Per-group LM head selection |
+| `cp_generate` | Metal/XNNPACK | Fused sampled 15-step code predictor fast path |
+| `decode_audio` | XNNPACK | Vocoder: codes → waveform (dynamic shapes) |
+| `decode_audio_stream` | XNNPACK | Fixed-shape streaming vocoder surface for chunked decode |
+
+The runner calls `decode_audio` for codes→audio (decode-only mode), or orchestrates
+all methods for full text→audio synthesis in text-only mode, following the
+assistant-wrapped prompt contract used by the Python helper.
+ +## Files + +| File | Purpose | +|------|---------| +| `export_unified.py` | Multi-method export (XNNPACK/Metal/portable) | +| `main_unified.cpp` | CLI runner with decode-only and text modes | +| `qwen3_tts_unified_runner.*` | C++ runner with lazy loading and warmup | +| `generate_codes.py` | Python talker: text → codec tokens | +| `convert_weights.py` | HF → ExecuTorch weight conversion | +| `convert_talker_weights.py` | Talker weights to Llama format | +| `model.py` | Export wrappers and binary codec I/O | +| `metal_benchmark.md` | Metal backend benchmark results | +| `single_export.md` | Development progress log | + +## Notes + +- The decoder uses dynamic shapes with patched `CausalConvNet` padding + (`math.ceil` → integer ceiling division for `torch.export` compatibility). +- XNNPACK has a one-time warmup cost on first call. The runner now exercises the + full text path in `warmup_all()` so sequential prompt benchmarking reflects + steady-state generation instead of cold delegate setup. +- Leading silence is automatically trimmed (`--trim_silence`, default on), but + the main `audio` / `rtf` metrics are reported from the raw pre-trim waveform. +- Text-only `--text` mode now runs directly in `qwen3_tts_unified_runner` with + dynamic prompt length, explicit prompt-budget checks, and assistant-wrapped + prompt formatting aligned to `generate_codes.py`. +- Warm benchmark mode supports `--prompts_path`, `--repeat`, `--seed`, optional + batch output writing via `--output_dir`, and `--disable_fused_cp_generate` for + apples-to-apples comparisons. +- The recommended tokenizer path for local bring-up is + `examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json`. +- Unified export manifests now record the text prompt contract, the current + 8-method surface, and streaming decoder metadata (`streaming_decoder_*`). +- The new `decode_audio_stream` surface is capability-gated by manifest + metadata. 
The runner can benchmark it directly, or fall back to the dynamic + `decode_audio` overlap-window path if that is faster on the current backend. +- Text mode still requires an external tokenizer path; tokenizer packaging is + tracked in `TODO.md`. +- Metal/AOTI uses AOTInductor to compile graphs into `.so` with Metal kernels. + Export takes ~8 min but runtime is GPU-accelerated. +- Voice clone / `ref_audio` / `ref_text`, full ICL prompting, and full sampling + parity remain deferred. See `TODO.md` for the no-compromise backlog. diff --git a/examples/models/qwen3-tts/REMEDIATION_HANDOFF.md b/examples/models/qwen3-tts/REMEDIATION_HANDOFF.md new file mode 100644 index 00000000000..d6f4cd75c77 --- /dev/null +++ b/examples/models/qwen3-tts/REMEDIATION_HANDOFF.md @@ -0,0 +1,451 @@ +# Qwen3-TTS Remediation Handoff + +This file is a handoff note for continuing the qwen3-tts remediation work from a different agent. + +Important: + +- This file lives in the main workspace for discoverability. +- The actual in-progress code changes do **not** live in this checkout. +- The active remediation changes are in the worktree: + +```text +/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation +``` + +- The worktree branch is: + +```text +qwen3-tts-red-team-remediation +``` + +- The plan file already exists and should **not** be edited: + +```text +/Users/younghan/.cursor/plans/qwen3-tts_remediation_edc4c5f6.plan.md +``` + +## 1. Resume Here + +If another agent continues this work, start in the remediation worktree: + +```bash +cd "/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation" +git status --short --branch +``` + +Do **not** continue the code work in the main checkout at: + +```text +/Users/younghan/project/executorch +``` + +That main checkout is only where this handoff file was written. + +## 2. 
Current Todo Status + +Plan todo state at the time this file was written: + +- `repro-import-boundaries`: completed +- `decoder-clone-parity`: completed +- `runtime-hardening`: in progress +- `talker-export-validity`: pending +- `talker-end-to-end`: pending +- `streaming-cleanup`: pending +- `tests-and-docs`: pending + +Important nuance: + +- `decoder-clone-parity` was marked completed in the todo tracker because the code changes, unit tests, and runner build were verified. +- However, the original plan still called for a stronger upstream parity harness against real `Qwen3TTSModel.generate_voice_clone()` behavior. That full end-to-end harness has **not** been added yet. +- If strict plan fidelity matters more than the current todo state, consider reopening that verification gap later under `decoder-clone-parity` or folding it into `tests-and-docs`. + +## 3. What Is Already Implemented + +### 3.1 Reproducibility And Import Boundaries + +Implemented in the remediation worktree: + +- Created `examples/models/qwen3-tts/codec_io.py`. +- Created `examples/models/qwen3-tts/runtime_env.py`. +- Moved codec binary read/write helpers out of `model.py` into `codec_io.py`. +- Made `generate_codes.py` lazy-import `qwen_tts` only inside `main()`. +- Made `model.py` lazy-import `qwen_tts` decoder internals only inside `load_decoder_from_metadata()`. +- Made `export_qwen3_tts.py` lazy-import ExecuTorch lowering pieces only inside `lower_to_executorch()`. +- Added explicit environment preflight logic with: + - validated `qwen-tts` version + - validated `transformers` version + - optional SoX check +- Added an explicit BF16 gate so `export_qwen3_tts.py --dtype bf16` fails early with an actionable error instead of emitting a known-bad path. 
+- Updated `README.md` to document: + - the validated Python environment matrix + - that BF16 is intentionally blocked for now + +New tests added: + +- `examples/models/qwen3-tts/tests/test_startup_and_env.py` + +### 3.2 Decoder Semantics And Clone Parity + +Implemented in the remediation worktree: + +- `Qwen3TTSSpeechDecoderExport.forward()` now calls `decoder.chunked_decode(...)` instead of direct `decoder(...)`. +- `generate_codes.py` now has: + - `--allow-silence-bootstrap` + - `_validate_prompt_mode()` + - `_prepare_codes_for_decode()` + - `_build_codes_metadata()` + - `_metadata_output_paths()` +- Clone-mode helper output now preserves prefix/reference codec context by: + - prepending `prompt_dict["ref_code"]` to generated codes when present + - writing `prefix_codes_len` into metadata + - always writing the runner-visible sibling metadata sidecar at `codes_path.with_suffix(".json")` +- `main.cpp` now: + - rejects text-only helper invocation unless either `--ref_audio` is provided or `--allow_silence_bootstrap` is explicitly passed + - forwards `allow_silence_bootstrap` to the helper +- `qwen3_tts_runner.h/.cpp` now: + - reads sibling metadata sidecar JSON for `prefix_codes_len` + - trims decoded waveform proportionally after vocoder execution + - clears waveform instead of erroring when `prefix_codes_len == codes_len` + +New tests added: + +- `examples/models/qwen3-tts/tests/test_decoder_clone_parity.py` + +## 4. Files Currently Modified In The Remediation Worktree + +At the time this note was written, `git status --short --branch` in the remediation worktree showed: + +```text +## qwen3-tts-red-team-remediation + M examples/models/qwen3-tts/README.md + M examples/models/qwen3-tts/export_qwen3_tts.py + M examples/models/qwen3-tts/generate_codes.py + M examples/models/qwen3-tts/main.cpp + M examples/models/qwen3-tts/model.py + M examples/models/qwen3-tts/qwen3_tts_runner.cpp + M examples/models/qwen3-tts/qwen3_tts_runner.h +?? 
examples/models/qwen3-tts/codec_io.py +?? examples/models/qwen3-tts/runtime_env.py +?? examples/models/qwen3-tts/tests/test_decoder_clone_parity.py +?? examples/models/qwen3-tts/tests/test_startup_and_env.py +``` + +These are still uncommitted. + +No runtime-hardening, talker-export-validity, talker-end-to-end, streaming-cleanup, or docs/progress-file changes have been made yet beyond the files listed above. + +## 5. Fresh Verification Evidence + +### 5.1 Python Tests + +The last full targeted Python regression run from the remediation worktree was: + +```bash +conda run -n executorch python -m pytest \ + examples/models/qwen3-tts/tests/test_convert_weights.py \ + examples/models/qwen3-tts/tests/test_startup_and_env.py \ + examples/models/qwen3-tts/tests/test_decoder_clone_parity.py \ + -q +``` + +Result: + +```text +13 passed in 5.62s +``` + +### 5.2 Python Entry Point Startup + +These were verified successfully: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_qwen3_tts.py --help +conda run -n executorch python examples/models/qwen3-tts/generate_codes.py --help +``` + +The current machine environment intentionally still fails the real helper preflight in a friendly way because it is **not** yet in the documented supported state: + +```bash +conda run -n executorch python examples/models/qwen3-tts/generate_codes.py \ + --model-id-or-path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --text hello \ + --output-codes /tmp/qwen3_tts_test_codes.bin +``` + +Observed expected failure reason: + +- `transformers==5.3.0` is installed instead of `4.57.3` +- `sox` is missing from `PATH` + +This is expected with the new preflight code. + +### 5.3 Runner Build + +The qwen3-tts runner build was freshly re-verified in this session, but **not** from the raw remediation worktree path. + +Important build constraint: + +- ExecuTorch top-level CMake currently hard-fails unless the source directory name is exactly `executorch`. 
+- Therefore, direct `make qwen3-tts-cpu` from: + +```text +/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation +``` + +will fail with the repo-name restriction. + +The working build workaround is to create a symlink alias whose final path component is `executorch`. + +Correct command: + +```bash +cd "/Users/younghan/project/executorch" +ln -s "/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation" ".worktrees/executorch" +``` + +Important: + +- Use the absolute target path exactly as above. +- A previous relative symlink attempt was wrong and caused `cd` failures. + +Then run: + +```bash +git -C "/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation" submodule update --init --recursive + +cmake -S "/Users/younghan/project/executorch/.worktrees/executorch" --preset llm-release + +cmake --build "/Users/younghan/project/executorch/.worktrees/executorch/cmake-out" --target install --parallel 8 + +cmake --preset qwen3-tts-cpu -S "/Users/younghan/project/executorch/.worktrees/executorch/examples/models/qwen3-tts" + +cmake --build "/Users/younghan/project/executorch/.worktrees/executorch/cmake-out/examples/models/qwen3-tts" --target qwen3_tts_runner --parallel 8 +``` + +Fresh result from this session: + +- ExecuTorch configure via symlink alias succeeded +- ExecuTorch install build succeeded +- qwen3-tts runner configure succeeded +- qwen3-tts runner build succeeded + +Last meaningful lines: + +```text +[ 33%] Linking CXX executable qwen3_tts_runner +[100%] Built target qwen3_tts_runner +``` + +Warnings observed during the qwen3-tts-specific configure stage that did **not** block the build: + +- duplicate-library warning during link +- optional library warnings such as: + - `aoti_cuda_backend library is not found` + - `flatccrt library is not found` + - `etdump library is not found` + - `bundled_program library is not found` + - `metal_backend library is not found` + - others in the same pattern + 
+These warnings did not prevent `qwen3_tts_runner` from building in the verified path above. + +## 6. Environment And Local State Gotchas + +### 6.1 Helper Environment + +The current `executorch` conda env is **not yet** suitable for real helper/export execution because: + +- `qwen-tts==0.1.1` is expected +- `transformers==4.57.3` is required by the current preflight +- the machine currently has `transformers==5.3.0` +- `sox` is missing from `PATH` + +This means: + +- startup and help commands are now fixed and work +- real `qwen_tts` runs are intentionally blocked by preflight until the environment is corrected + +If the next agent needs real generation/export instead of just unit tests: + +1. install the supported `transformers` version +2. ensure `sox` is installed and on `PATH` +3. re-run the relevant helper/export commands + +### 6.2 Worktree Move Limitation + +Do **not** try to rename the remediation worktree using `git worktree move` after submodules are checked out. + +Observed error: + +```text +fatal: working trees containing submodules cannot be moved or removed +``` + +That is why the symlink alias workaround is used instead of renaming the worktree. + +### 6.3 Symlink Alias Is Local Only + +The `.worktrees/executorch` symlink is only a local build convenience artifact. + +- It is not a committed repo change. +- If it goes missing, recreate it with the exact absolute symlink command shown above. + +## 7. What Is Not Done Yet + +### 7.1 Runtime Hardening + +Status: in progress in the todo tracker, but no substantive code changes have been made yet for this workstream. + +Still outstanding: + +- Replace ad-hoc string scanning of `export_manifest.json` in `qwen3_tts_runner.cpp` with `nlohmann/json`. +- Replace ad-hoc metadata parsing in `read_codes_metadata()` with `nlohmann/json`. 
+- Harden codec file parsing: + - header size validation + - multiplication overflow checks + - payload size checks + - `codes_len * num_quantizers` consistency + - `num_quantizers == exported metadata` + - `0 <= code < codebook_size` +- Replace fixed temp file path in `main.cpp`: + +```text +qwen3_tts_codegen_codes.bin +``` + +with a unique temp file and cleanup flow. +- Stop eager-loading all bucket models in `from_model_dir()`. +- Add bucket-load-aware latency accounting. +- Add negative tests for malformed codec input and malformed manifest input. + +### 7.2 Talker Export Validity + +Status: untouched so far. + +Files not yet updated: + +- `examples/models/qwen3-tts/export_talker.py` +- `examples/models/qwen3-tts/convert_talker_weights.py` +- `examples/models/qwen3-tts/config/talker_config.json` +- `examples/models/qwen3-tts/config/code_predictor_config.json` + +Known outstanding issues from the plan: + +- `export_talker.py` still uses `strict=False` for loading state dicts. +- warning-only invalid exports are still possible +- no strict config-vs-checkpoint validation exists yet +- no reusable talker manifest separating backbone vs aux weights exists yet + +### 7.3 End-To-End ExecuTorch Talker Orchestration + +Status: not started. + +Missing: + +- `examples/models/qwen3-tts/talker_exec.py` +- prefill/decode-step orchestration +- aux weight integration +- greedy parity harness vs upstream Qwen/HF + +### 7.4 Streaming Cleanup + +Status: not started. + +`examples/models/qwen3-tts/streaming_generate.py` is still in the red-team state: + +- requires `--talker-dir` even for decode-only path +- eagerly loads all decoder buckets +- accepts oversize `chunk-size` without proper rejection path +- reports chunk-concatenation as if it were true streaming +- latency accounting is not yet separated into honest metrics + +### 7.5 Docs / Progress / Landing + +Status: not started beyond the initial README env note. 
+ +Still needed: + +- update `README.md` for streaming caveats and current supported flows +- update `PROGRESS.md` to reflect completed remediation steps +- add broader regression coverage +- possibly reopen or supplement `decoder-clone-parity` verification with true upstream harnessing + +## 8. Recommended Exact Next Steps + +If another agent resumes immediately, the safest order is: + +1. Read this file completely. +2. Read the plan file completely. +3. Switch to the remediation worktree. +4. Recreate the symlink alias if it is missing. +5. Re-run the targeted Python regression suite to make sure the starting point still matches this handoff: + +```bash +conda run -n executorch python -m pytest \ + examples/models/qwen3-tts/tests/test_convert_weights.py \ + examples/models/qwen3-tts/tests/test_startup_and_env.py \ + examples/models/qwen3-tts/tests/test_decoder_clone_parity.py \ + -q +``` + +6. If a C++ rebuild is needed, use the symlink-alias path and the exact commands from Section 5.3. +7. Continue with `runtime-hardening` first. + +Suggested first code slice for `runtime-hardening`: + +- create/extend tests first for malformed codec and manifest handling +- replace `export_manifest.json` parsing with `nlohmann/json` +- replace metadata sidecar parsing with `nlohmann/json` +- then harden codec input parsing +- then replace the fixed temp file +- then move to lazy bucket loading + +## 9. 
Suggested Commands For The Next Agent + +### Resume Worktree + +```bash +cd "/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation" +git status --short --branch +``` + +### Recreate Build Alias If Missing + +```bash +cd "/Users/younghan/project/executorch" +ln -s "/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation" ".worktrees/executorch" +``` + +### Sync Submodules In The Remediation Worktree + +```bash +git -C "/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation" submodule update --init --recursive +``` + +### Re-run Verified Python Tests + +```bash +cd "/Users/younghan/project/executorch/.worktrees/qwen3-tts-red-team-remediation" +conda run -n executorch python -m pytest \ + examples/models/qwen3-tts/tests/test_convert_weights.py \ + examples/models/qwen3-tts/tests/test_startup_and_env.py \ + examples/models/qwen3-tts/tests/test_decoder_clone_parity.py \ + -q +``` + +### Rebuild Runner Using Working Path + +```bash +cmake -S "/Users/younghan/project/executorch/.worktrees/executorch" --preset llm-release +cmake --build "/Users/younghan/project/executorch/.worktrees/executorch/cmake-out" --target install --parallel 8 +cmake --preset qwen3-tts-cpu -S "/Users/younghan/project/executorch/.worktrees/executorch/examples/models/qwen3-tts" +cmake --build "/Users/younghan/project/executorch/.worktrees/executorch/cmake-out/examples/models/qwen3-tts" --target qwen3_tts_runner --parallel 8 +``` + +## 10. Final Honesty Notes + +- The code state is real and verified for the completed workstreams listed above. +- The fresh runner build verification is real and was done in this session using the symlink alias workaround. +- The next workstreams are **not** started yet, except for the todo tracker moving `runtime-hardening` to in progress. +- The current environment is still intentionally hostile to real `qwen_tts` execution until `transformers` and `sox` are corrected. 
+- The last code-review subagent for task 2 aborted because the session moved on, not because a code issue was found. The latest concrete state was validated by tests, lints, and runner build instead. diff --git a/examples/models/qwen3-tts/TODO.md b/examples/models/qwen3-tts/TODO.md new file mode 100644 index 00000000000..5d79fbf98f2 --- /dev/null +++ b/examples/models/qwen3-tts/TODO.md @@ -0,0 +1,54 @@ +# Qwen3-TTS No-Compromise Backlog + +This file tracks work we are explicitly deferring while the current milestone +focuses on text-only end-to-end C++ synthesis through +`qwen3_tts_unified_runner`. + +## Must-Fix Semantic Gaps + +- [ ] Fix the talker/codebook token-space mismatch so `codec_eos_id` is actually reachable from the sampled talker vocabulary instead of relying on `max_new_tokens`. +- [ ] Remove the decoder clamp fallback that silently maps out-of-range codec ids to `0`. +- [ ] Close the remaining text-generation parity gap with `generate_codes.py` for `language`, `top_p`, `repetition_penalty`, and `non_streaming_mode`. +- [ ] Verify that the assistant-wrapped prompt string used by the C++ runner matches the upstream `qwen_tts` helper exactly, not just the `mlx-audio` approximation. +- [ ] Add a deterministic parity harness that compares text-only codec traces between the Python helper and the unified C++ path. + +## Deferred Feature Parity + +- [ ] Internalize `ref_audio` + `ref_text` voice-clone prompting into the unified C++ runner. +- [ ] Support x-vector-only speaker prompting in the unified C++ path. +- [ ] Support full ICL prompting with reference speech-token context instead of text-only prompting. +- [ ] Decide whether to export extra primitives for speaker-conditioning flows or keep them in host-side orchestration. +- [ ] Add explicit support for non-English language conditioning instead of logging and falling back to the text-only default. 
+ +## Performance And Realtime Work + +- [ ] Measure the new text-only unified C++ path end to end and compare it against the two-stage `generate_codes.py -> qwen3_tts_unified_runner` baseline. +- [ ] Build a proper realtime scorecard: cold start, warm start, first-audio latency, full decode latency, and realtime factor. +- [ ] Optimize the code-predictor path, which still dominates projected CPU latency. +- [ ] Explore whether the `cp_generate` fused path can reduce per-step latency further without changing semantics. +- [ ] Revisit streaming decode with MLX-style persistent conv-buffer reuse instead of chunked re-decode. +- [ ] Revisit Metal/GPU acceleration once the text-only C++ semantics are stable. + +## Packaging And Reproducibility + +- [ ] Re-export the unified `.pte` artifacts after the prompt-contract metadata changes and verify they load correctly. +- [ ] Package `tokenizer.json` alongside unified export artifacts or define a stable artifact-discovery contract for text mode. +- [ ] Audit all checked-in unified export manifests and generated artifacts for drift against current source. +- [ ] Decide whether checked-in `.pte` artifacts should remain in-tree or move to reproducible export scripts plus manifests only. +- [ ] Capture the exact export command and expected artifact set for the unified text-only path. + +## Validation And Regression Coverage + +- [ ] Add a regression test that fails if a checked-in manifest drops `cp_generate` again. +- [ ] Add a stronger runner-contract test that validates text-mode CLI behavior against the built binary, not just source text. +- [ ] Add export-metadata tests for the new prompt-budget constant methods. +- [ ] Add end-to-end smoke coverage for at least one short prompt and one longer prompt through the unified C++ path. +- [ ] Add explicit coverage for prompt-budget failure cases: prompt too short, `max_new_tokens` too small, and `max_seq_len` overflow. 
+- [ ] Add a reproducible fixture corpus for text-only, multilingual, x-vector-only, and full clone cases. + +## MLX Reference Follow-Ups + +- [ ] Decide which `mlx-audio` prompt-preparation pieces should be mirrored directly and which should remain reference-only. +- [ ] Investigate whether MLX's “first text token folded into prefill” path remains correct under the exact upstream `qwen_tts` tokenizer output. +- [ ] Investigate whether MLX's non-streaming ICL overlay should become the long-term reference for ExecuTorch clone mode. +- [ ] Avoid porting MLX runtime-only details such as cache/eval mechanics and heuristic streaming constants without an ExecuTorch-specific rationale. diff --git a/examples/models/qwen3-tts/XNNPACK_CONFIDENCE_STATUS.md b/examples/models/qwen3-tts/XNNPACK_CONFIDENCE_STATUS.md new file mode 100644 index 00000000000..217e249358e --- /dev/null +++ b/examples/models/qwen3-tts/XNNPACK_CONFIDENCE_STATUS.md @@ -0,0 +1,91 @@ +# XNNPACK Confidence Status + +This note records the measurement fixes we made before starting MLX work and the +best warmed XNNPACK path we can currently defend. + +## Measurement fixes completed + +- `codegen_ms` now excludes in-loop streaming decode checkpoints. Previously the + metric double-counted chunk decode time and overstated the hot loop cost. +- Non-streaming `first_audio_ms` is now anchored to request start instead of the + start of the final decode phase, so it is comparable to streaming runs. +- The CLI now reports `audio` and `rtf` from the raw waveform before silence + trimming. `trimmed_audio` and `rtf_trimmed` are logged separately so + post-processing no longer inflates the main throughput metric. 
+
+## Best verified warmed XNNPACK path
+
+Use the bounded overlap-window decoder, not the fixed-shape
+`decode_audio_stream` surface:
+
+```bash
+cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \
+  --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \
+  --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \
+  --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \
+  --repeat 1 \
+  --max_new_tokens 128 \
+  --temperature 1.0 \
+  --top_k 50 \
+  --disable_streaming_decoder_surface
+```
+
+Warmed prompt-set results after the accounting fixes:
+
+| Export | Max seq len | Avg raw RTF | Avg first audio | Avg codegen | Avg decode |
+|--------|-------------|-------------|-----------------|-------------|------------|
+| Checked-in unified export | 256 | `0.51x` | `3.57s` | `9.16s` | `1.57s` |
+| Tuned experimental export | 160 | `0.52x` | `3.43s` | `8.74s` | `1.61s` |
+
+Notes:
+
+- The `max_seq_len=160` export remains the best measured XNNPACK artifact so far,
+  but only by a small margin after the metric fix.
+- The checked-in `decode_audio_stream` surface is still slower than the dynamic
+  overlap-window fallback on this backend.
+- The hot loop is still dominated by talker/code-predictor generation, not audio
+  decode.
+
+## Additional non-streaming sanity check
+
+Spot check command:
+
+```bash
+cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \
+  --model_path examples/models/qwen3-tts/qwen3_tts_exports_unified/model.pte \
+  --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \
+  --text "Hello from the non streaming timing check." \
+  --max_new_tokens 64 \
+  --temperature 1.0 \
+  --top_k 50 \
+  --non_streaming_mode \
+  --disable_streaming_decoder_surface
+```
+
+Observed result:
+
+- `first_audio_ms = 7685.4`
+- `generation_ms = 7685.4`
+- `final_decode_ms = 902.5`
+
+That matches the expected non-streaming behavior: first audio is only available
+after the full generation + final decode path completes.
+ +## What we can now say with confidence + +- The metric layer is no longer hiding chunk decode time inside `codegen_ms`. +- The raw XNNPACK throughput number is about `0.51x` to `0.52x` realtime on the + current warmed short-prompt benchmark. +- The best current XNNPACK path is `--disable_streaming_decoder_surface`. +- XNNPACK streaming is still below realtime after load. + +## What still blocks a "faster than mlx-audio" claim + +- We still do bounded window re-decode on the decoder side; we do not yet have a + true stateful incremental vocoder path like the MLX reference. +- The tuned `max_seq_len=160` export is reproducible, but it is not yet the + default checked-in artifact or a documented export preset. +- We do not yet have an apples-to-apples benchmark harness that runs our future + MLX path and the upstream `mlx-audio` path on the exact same prompt set. +- `decode_audio_stream` remains a regression on current XNNPACK, so the export + surface intended for streaming still needs backend-specific tuning. diff --git a/examples/models/qwen3-tts/__init__.py b/examples/models/qwen3-tts/__init__.py new file mode 100644 index 00000000000..c918bea7bc8 --- /dev/null +++ b/examples/models/qwen3-tts/__init__.py @@ -0,0 +1,2 @@ +# This directory intentionally uses a dash in its name (`qwen3-tts`), +# so it is not imported as a standard Python package. diff --git a/examples/models/qwen3-tts/benchmark_mlx.py b/examples/models/qwen3-tts/benchmark_mlx.py new file mode 100644 index 00000000000..80e92010014 --- /dev/null +++ b/examples/models/qwen3-tts/benchmark_mlx.py @@ -0,0 +1,281 @@ +#!/usr/bin/env python3 +"""Benchmark local mlx-audio against the cached MLX session backend.""" + +from __future__ import annotations + +import argparse +import time +from pathlib import Path +from statistics import mean + +from mlx_backend import Qwen3TTSMlxBackend + + +DEFAULT_MODEL = "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16" +DEFAULT_REF_TEXT = "This is what my voice sounds like." 
+ + +def _repo_root() -> Path: + return Path(__file__).resolve().parents[3] + + +def _default_ref_audio() -> Path: + return _repo_root() / "poem.wav" + + +def _default_prompts_path() -> Path: + return Path(__file__).with_name("benchmark_prompts.txt") + + +def _load_prompts(path: Path) -> list[str]: + prompts = [line.strip() for line in path.read_text(encoding="utf-8").splitlines()] + return [prompt for prompt in prompts if prompt] + + +def _print_summary(name: str, metrics) -> None: + avg_throughput = mean(metric.throughput_x for metric in metrics) + avg_first_audio = mean(metric.first_audio_s for metric in metrics) + total_audio = sum(metric.audio_s for metric in metrics) + total_elapsed = sum(metric.elapsed_s for metric in metrics) + total_throughput = total_audio / total_elapsed if total_elapsed > 0.0 else 0.0 + print() + print(f"{name} summary") + print(f" Average throughput : {avg_throughput:.3f}x (> 1 = faster than real-time)") + print(f" Total throughput : {total_throughput:.3f}x") + print(f" Average first audio: {avg_first_audio:.2f}s") + print(f" Total audio : {total_audio:.2f}s") + print(f" Total elapsed : {total_elapsed:.2f}s") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--mlx_audio_repo", + type=Path, + default=None, + help="Optional local mlx-audio checkout to prepend to PYTHONPATH.", + ) + parser.add_argument( + "--model", + type=str, + default=DEFAULT_MODEL, + help="MLX Qwen3-TTS model path or repo id.", + ) + parser.add_argument( + "--prompts_path", + type=Path, + default=_default_prompts_path(), + help="Prompt set for warmed sequential benchmarking.", + ) + parser.add_argument( + "--ref_audio", + type=Path, + default=_default_ref_audio(), + help="Reference audio used for base-model ICL prompting.", + ) + parser.add_argument( + "--ref_text", + type=str, + default=DEFAULT_REF_TEXT, + help="Transcript for the reference audio.", + ) + parser.add_argument( + "--mode", + 
choices=("baseline", "cached_session", "both"), + default="both", + help="Which MLX path to benchmark.", + ) + parser.add_argument( + "--stream", + action="store_true", + help="Enable the streaming decoder path.", + ) + parser.add_argument( + "--streaming_interval", + type=float, + default=4.0, + help="Streaming interval in seconds when --stream is enabled.", + ) + parser.add_argument( + "--streaming_context_size", + type=int, + default=25, + help="Streaming left context size for the cached session path.", + ) + parser.add_argument( + "--seed", + type=int, + default=123, + help="Base seed for mx.random; each prompt offsets this by its index.", + ) + parser.add_argument( + "--max_tokens", + type=int, + default=4096, + help="Maximum codec steps to generate.", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.9, + help="Sampling temperature.", + ) + parser.add_argument( + "--top_k", + type=int, + default=50, + help="Top-k for sampling.", + ) + parser.add_argument( + "--top_p", + type=float, + default=1.0, + help="Top-p for sampling.", + ) + parser.add_argument( + "--repetition_penalty", + type=float, + default=1.5, + help="Repetition penalty for ICL generation.", + ) + parser.add_argument( + "--warmup_text", + type=str, + default="Hi.", + help="Warmup prompt run once after model load.", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + prompts = _load_prompts(args.prompts_path) + + print(f"Prompts : {args.prompts_path}") + print(f"Reference audio: {args.ref_audio}") + print(f"Stream : {args.stream}") + print(f"Mode : {args.mode}") + print() + + load_t0 = time.perf_counter() + backend = Qwen3TTSMlxBackend( + model_path=args.model, + mlx_audio_repo=args.mlx_audio_repo, + ) + load_s = time.perf_counter() - load_t0 + print(f"Device : {backend.mx.default_device()}") + print(f"Model load : {load_s:.2f}s") + if backend.repo_path is not None: + print(f"mlx-audio repo : {backend.repo_path}") + print() + + print("Warmup 
baseline generate...") + warmup_baseline = backend.warmup( + text=args.warmup_text, + ref_audio=args.ref_audio, + ref_text=args.ref_text, + stream=args.stream, + streaming_interval=args.streaming_interval, + seed=args.seed, + ) + print( + f"Warmup baseline: elapsed={warmup_baseline.elapsed_s:.2f}s " + f"audio={warmup_baseline.audio_s:.2f}s " + f"throughput={warmup_baseline.throughput_x:.3f}x" + ) + + session = backend.create_icl_session( + ref_audio=args.ref_audio, + ref_text=args.ref_text, + ) + warmup_cached = session.benchmark( + text=args.warmup_text, + stream=args.stream, + streaming_interval=args.streaming_interval, + streaming_context_size=args.streaming_context_size, + seed=args.seed, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=args.repetition_penalty, + max_tokens=args.max_tokens, + ) + print( + f"Warmup cached : elapsed={warmup_cached.elapsed_s:.2f}s " + f"audio={warmup_cached.audio_s:.2f}s " + f"throughput={warmup_cached.throughput_x:.3f}x" + ) + + baseline_metrics = [] + cached_metrics = [] + print() + + for prompt_idx, prompt in enumerate(prompts): + prompt_seed = args.seed + prompt_idx + print(f"Prompt {prompt_idx}: {prompt}") + if args.mode in ("baseline", "both"): + baseline = backend.benchmark_baseline( + text=prompt, + ref_audio=args.ref_audio, + ref_text=args.ref_text, + stream=args.stream, + streaming_interval=args.streaming_interval, + seed=prompt_seed, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=args.repetition_penalty, + max_tokens=args.max_tokens, + ) + baseline_metrics.append(baseline) + print( + " baseline " + f"elapsed={baseline.elapsed_s:.2f}s " + f"audio={baseline.audio_s:.2f}s " + f"throughput={baseline.throughput_x:.3f}x " + f"first_audio={baseline.first_audio_s:.2f}s " + f"chunks={baseline.chunk_count}" + ) + if args.mode in ("cached_session", "both"): + cached = session.benchmark( + text=prompt, + stream=args.stream, + 
streaming_interval=args.streaming_interval, + streaming_context_size=args.streaming_context_size, + seed=prompt_seed, + temperature=args.temperature, + top_k=args.top_k, + top_p=args.top_p, + repetition_penalty=args.repetition_penalty, + max_tokens=args.max_tokens, + ) + cached_metrics.append(cached) + print( + " cached_session " + f"elapsed={cached.elapsed_s:.2f}s " + f"audio={cached.audio_s:.2f}s " + f"throughput={cached.throughput_x:.3f}x " + f"first_audio={cached.first_audio_s:.2f}s " + f"chunks={cached.chunk_count}" + ) + + if baseline_metrics: + _print_summary("Baseline mlx-audio", baseline_metrics) + if cached_metrics: + _print_summary("Cached session backend", cached_metrics) + if baseline_metrics and cached_metrics: + baseline_avg = mean(metric.throughput_x for metric in baseline_metrics) + cached_avg = mean(metric.throughput_x for metric in cached_metrics) + speedup = cached_avg / baseline_avg if baseline_avg > 0.0 else 0.0 + print() + print( + "Cached session speedup: " + f"{speedup:.3f}x over baseline mlx-audio" + ) + + print(f"Peak memory : {backend.mx.get_peak_memory() / 1e9:.2f} GB") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/examples/models/qwen3-tts/benchmark_prompts.txt b/examples/models/qwen3-tts/benchmark_prompts.txt new file mode 100644 index 00000000000..0654cd435c1 --- /dev/null +++ b/examples/models/qwen3-tts/benchmark_prompts.txt @@ -0,0 +1,3 @@ +Hello from ExecuTorch. +Please speak clearly and keep the pacing natural from the first word to the final sentence. +This is a warm benchmark run for sequential text to speech generation on XNNPACK in a single process. 
diff --git a/examples/models/qwen3-tts/capture_reference_streaming_contract.py b/examples/models/qwen3-tts/capture_reference_streaming_contract.py new file mode 100644 index 00000000000..a8882264994 --- /dev/null +++ b/examples/models/qwen3-tts/capture_reference_streaming_contract.py @@ -0,0 +1,276 @@ +import argparse +import json +import random +import struct +import sys +from pathlib import Path + +import numpy as np +import torch +from transformers import modeling_rope_utils as hf_rope_utils +from transformers.utils import generic as hf_generic + +SCRIPT_DIR = Path(__file__).resolve().parent +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + + +STREAMING_CHUNK_SIZE = 300 +STREAMING_LEFT_CONTEXT_SIZE = 25 + + +if not hasattr(hf_generic, "check_model_inputs"): + def _identity_check_model_inputs(*args, **kwargs): + def decorator(fn): + return fn + + return decorator + + hf_generic.check_model_inputs = _identity_check_model_inputs + + +if "default" not in hf_rope_utils.ROPE_INIT_FUNCTIONS: + def _compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None): + if hasattr(config, "standardize_rope_params"): + config.standardize_rope_params() + rope_parameters = getattr(config, "rope_parameters", None) + if rope_parameters is None: + base = getattr(config, "rope_theta", 10000.0) + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + else: + rope_parameters = ( + rope_parameters[layer_type] if layer_type is not None else rope_parameters + ) + base = rope_parameters.get("rope_theta", getattr(config, "rope_theta", 10000.0)) + partial_rotary_factor = rope_parameters.get( + "partial_rotary_factor", + getattr(config, "partial_rotary_factor", 1.0), + ) + head_dim = getattr(config, "head_dim", None) or ( + config.hidden_size // config.num_attention_heads + ) + dim = int(head_dim * partial_rotary_factor) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to( + device=device, 
dtype=torch.float + ) + / dim + ) + ) + return inv_freq, 1.0 + + hf_rope_utils.ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Capture the upstream Qwen3-TTS streaming contract for parity checks." + ) + parser.add_argument( + "--upstream-repo", + type=Path, + default=Path("/Users/younghan/project/executorch-exp/Qwen3-TTS"), + ) + parser.add_argument("--model-id-or-path", required=True, type=str) + parser.add_argument("--output-dir", required=True, type=Path) + parser.add_argument("--text", required=True, type=str) + parser.add_argument("--language", default="English", type=str) + parser.add_argument("--instruct", default="", type=str) + parser.add_argument("--seed", default=42, type=int) + parser.add_argument("--max-new-tokens", default=2048, type=int) + parser.add_argument("--top-k", default=50, type=int) + parser.add_argument("--top-p", default=1.0, type=float) + parser.add_argument("--temperature", default=0.9, type=float) + parser.add_argument("--repetition-penalty", default=1.05, type=float) + parser.add_argument("--streaming-interval", default=2.0, type=float) + parser.add_argument("--non-streaming-mode", action="store_true") + return parser.parse_args() + + +def _default_reference_audio(duration_sec: float = 1.0, sample_rate: int = 24000): + wav = np.zeros(int(duration_sec * sample_rate), dtype=np.float32) + return wav, sample_rate + + +def _set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + + +def _load_model(upstream_repo: Path, model_id_or_path: str): + for module_name in list(sys.modules): + if module_name == "qwen_tts" or module_name.startswith("qwen_tts."): + sys.modules.pop(module_name) + if str(upstream_repo) not in sys.path: + sys.path.insert(0, str(upstream_repo)) + from qwen_tts.core.models.configuration_qwen3_tts import ( # noqa: WPS433 + Qwen3TTSConfig, + 
Qwen3TTSTalkerCodePredictorConfig, + Qwen3TTSTalkerConfig, + ) + from qwen_tts.core.models.modeling_qwen3_tts import ( # noqa: WPS433 + Qwen3TTSForConditionalGeneration, + ) + from qwen_tts.core.models.processing_qwen3_tts import Qwen3TTSProcessor # noqa: WPS433 + + def _patch_pad_token_id(config_cls, fallback_attr: str) -> None: + original_init = config_cls.__init__ + + def _wrapped_init(self, *args, **kwargs): + original_init(self, *args, **kwargs) + if not hasattr(self, "pad_token_id"): + self.pad_token_id = getattr(self, fallback_attr, 0) + + config_cls.__init__ = _wrapped_init + + _patch_pad_token_id(Qwen3TTSTalkerConfig, "tts_pad_token_id") + _patch_pad_token_id(Qwen3TTSTalkerCodePredictorConfig, "pad_token_id") + from qwen_tts import Qwen3TTSModel # noqa: WPS433 + from transformers import AutoConfig, AutoModel, AutoProcessor # noqa: WPS433 + + AutoConfig.register("qwen3_tts", Qwen3TTSConfig) + AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration) + AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor) + + model = AutoModel.from_pretrained( + model_id_or_path, + device_map="cpu", + dtype=torch.float32, + ) + processor = AutoProcessor.from_pretrained(model_id_or_path) + return Qwen3TTSModel( + model=model, + processor=processor, + generate_defaults=model.generate_config, + ) + + +def write_codes_binary(path: Path, codes: torch.Tensor) -> None: + codes_i32 = codes.to(dtype=torch.int32).contiguous().cpu() + t_len, num_q = int(codes_i32.shape[0]), int(codes_i32.shape[1]) + flat_values = [int(v) for v in codes_i32.view(-1).tolist()] + with path.open("wb") as f: + f.write(struct.pack(" None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + _set_seed(args.seed) + + model = _load_model(args.upstream_repo, args.model_id_or_path) + codes = ( + _capture_voice_design_codes(model, args) + if args.instruct + else _capture_base_codes(model, args) + ) + wavs, sample_rate = model.model.speech_tokenizer.decode([{"audio_codes": 
codes}]) + wav = np.asarray(wavs[0], dtype=np.float32) + + decode_upsample_rate = int( + getattr(model.model.speech_tokenizer, "decode_upsample_rate", 1920) + ) + codec_steps_per_second = sample_rate / decode_upsample_rate + interval_steps = max(1, round(args.streaming_interval * codec_steps_per_second)) + chunk_boundaries = list(range(interval_steps, int(codes.shape[0]) + 1, interval_steps)) + if not chunk_boundaries or chunk_boundaries[-1] != int(codes.shape[0]): + chunk_boundaries.append(int(codes.shape[0])) + + codes_path = args.output_dir / "reference_codes.bin" + write_codes_binary(codes_path, codes) + np.save(args.output_dir / "reference_audio.npy", wav) + np.save(args.output_dir / "reference_codes.npy", codes.numpy()) + + contract = { + "model_id_or_path": args.model_id_or_path, + "text": args.text, + "language": args.language, + "instruct": args.instruct, + "seed": args.seed, + "non_streaming_mode": args.non_streaming_mode, + "temperature": args.temperature, + "top_k": args.top_k, + "top_p": args.top_p, + "repetition_penalty": args.repetition_penalty, + "streaming_interval_sec": args.streaming_interval, + "streaming_chunk_size": 300, + "streaming_left_context_size": 25, + "codec_steps_per_second": codec_steps_per_second, + "decode_upsample_rate": decode_upsample_rate, + "num_codec_steps": int(codes.shape[0]), + "num_quantizers": int(codes.shape[1]), + "audio_duration_sec": float(len(wav) / sample_rate), + "eos_position": int(codes.shape[0]), + "chunk_boundaries": chunk_boundaries, + "codec_trace": codes[:, 0].tolist(), + } + with (args.output_dir / "reference_contract.json").open("w", encoding="utf-8") as f: + json.dump(contract, f, indent=2, sort_keys=True) + + print(f"Saved: {codes_path}") + print(f"Saved: {args.output_dir / 'reference_contract.json'}") + print(f"Saved: {args.output_dir / 'reference_audio.npy'}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3-tts/config/code_predictor_config.json 
b/examples/models/qwen3-tts/config/code_predictor_config.json new file mode 100644 index 00000000000..9530eaa7e61 --- /dev/null +++ b/examples/models/qwen3-tts/config/code_predictor_config.json @@ -0,0 +1,17 @@ +{ + "dim": 1024, + "ffn_dim_multiplier": 1, + "hidden_dim": 3072, + "n_heads": 16, + "head_dim": 128, + "n_kv_heads": 8, + "n_layers": 5, + "norm_eps": 1e-06, + "rope_theta": 10000.0, + "use_scaled_rope": false, + "vocab_size": 2048, + "use_hf_rope": true, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true +} diff --git a/examples/models/qwen3-tts/config/model_config.json b/examples/models/qwen3-tts/config/model_config.json new file mode 100644 index 00000000000..4a035276130 --- /dev/null +++ b/examples/models/qwen3-tts/config/model_config.json @@ -0,0 +1,10 @@ +{ + "model_id": "Qwen/Qwen3-TTS-12Hz-0.6B-Base", + "tokenizer_type": "qwen3_tts_tokenizer_v2", + "tts_model_type": "base", + "output_sample_rate": 24000, + "notes": [ + "Text generation uses Qwen3-TTS talker in Python helper.", + "ExecuTorch export in this bring-up targets speech tokenizer decode path first." 
+ ] +} diff --git a/examples/models/qwen3-tts/config/qwen3_tts_xnnpack_8da4w.yaml b/examples/models/qwen3-tts/config/qwen3_tts_xnnpack_8da4w.yaml new file mode 100644 index 00000000000..be5a57f6c9a --- /dev/null +++ b/examples/models/qwen3-tts/config/qwen3_tts_xnnpack_8da4w.yaml @@ -0,0 +1,10 @@ +model: + converted_dir: "qwen3_tts_artifacts" + metadata_file: "decoder_metadata.json" + +export: + backend: "xnnpack" + dtype: "fp32" + fixed_codes_len: 1200 + qlinear: "8da4w" + qlinear_group_size: 32 diff --git a/examples/models/qwen3-tts/config/qwen3_tts_xnnpack_fp32.yaml b/examples/models/qwen3-tts/config/qwen3_tts_xnnpack_fp32.yaml new file mode 100644 index 00000000000..acc7db24c7b --- /dev/null +++ b/examples/models/qwen3-tts/config/qwen3_tts_xnnpack_fp32.yaml @@ -0,0 +1,9 @@ +model: + converted_dir: "qwen3_tts_artifacts" + metadata_file: "decoder_metadata.json" + +export: + backend: "xnnpack" + dtype: "fp32" + fixed_codes_len: 1200 + qlinear: null diff --git a/examples/models/qwen3-tts/config/talker_config.json b/examples/models/qwen3-tts/config/talker_config.json new file mode 100644 index 00000000000..9b373e73b6f --- /dev/null +++ b/examples/models/qwen3-tts/config/talker_config.json @@ -0,0 +1,17 @@ +{ + "dim": 1024, + "ffn_dim_multiplier": 1, + "hidden_dim": 3072, + "n_heads": 16, + "head_dim": 128, + "n_kv_heads": 8, + "n_layers": 28, + "norm_eps": 1e-06, + "rope_theta": 10000.0, + "use_scaled_rope": false, + "vocab_size": 3072, + "use_hf_rope": true, + "attention_qkv_bias": false, + "use_qk_norm": true, + "qk_norm_before_rope": true +} diff --git a/examples/models/qwen3-tts/convert_talker_weights.py b/examples/models/qwen3-tts/convert_talker_weights.py new file mode 100644 index 00000000000..d6f4bb75fc3 --- /dev/null +++ b/examples/models/qwen3-tts/convert_talker_weights.py @@ -0,0 +1,173 @@ +"""Convert Qwen3-TTS talker weights from HF format to ExecuTorch/Meta Llama format. 
+ +Produces two checkpoint files: + - talker_main.pth: main talker backbone (Qwen3 format for Llama infra) + - talker_code_predictor.pth: code predictor backbone (same Qwen3 format) + +Also extracts auxiliary weights (text_projection, codec_head, embeddings) into + - talker_aux.pth: non-transformer weights needed by the C++ runner +""" + +import argparse +import json +from pathlib import Path +from typing import Dict + +import torch + +# Weight key mapping: Meta/Llama format <- HF format +# Same as executorch/examples/models/qwen3/convert_weights.py but adapted +# for the talker checkpoint structure. +_QWEN3_FROM_META = { + "tok_embeddings.weight": "codec_embedding.weight", + "norm.weight": "norm.weight", + "output.weight": "__CODEC_HEAD__", # handled specially + "layers.{}.attention.wk.weight": "layers.{}.self_attn.k_proj.weight", + "layers.{}.attention.k_norm_fn.weight": "layers.{}.self_attn.k_norm.weight", + "layers.{}.attention.wq.weight": "layers.{}.self_attn.q_proj.weight", + "layers.{}.attention.q_norm_fn.weight": "layers.{}.self_attn.q_norm.weight", + "layers.{}.attention.wv.weight": "layers.{}.self_attn.v_proj.weight", + "layers.{}.attention.wo.weight": "layers.{}.self_attn.o_proj.weight", + "layers.{}.attention_norm.weight": "layers.{}.input_layernorm.weight", + "layers.{}.ffn_norm.weight": "layers.{}.post_attention_layernorm.weight", + # Note: gate_proj and up_proj are swapped (same as Qwen3 text models). 
+ "layers.{}.feed_forward.w1.weight": "layers.{}.mlp.gate_proj.weight", + "layers.{}.feed_forward.w2.weight": "layers.{}.mlp.down_proj.weight", + "layers.{}.feed_forward.w3.weight": "layers.{}.mlp.up_proj.weight", +} + + +def _convert_backbone( + hf_state: Dict[str, torch.Tensor], + prefix: str, + codec_head_key: str, +) -> Dict[str, torch.Tensor]: + """Convert a transformer backbone from HF to Meta format.""" + inverted = {v: k for k, v in _QWEN3_FROM_META.items()} + converted = {} + + for hf_key, tensor in hf_state.items(): + if not hf_key.startswith(prefix): + continue + stripped = hf_key[len(prefix):] + + # Try direct match first. + if stripped in inverted: + meta_key = inverted[stripped] + if meta_key == "__CODEC_HEAD__": + continue # Handled separately. + converted[meta_key] = tensor + continue + + # Try layer-pattern match. + matched = False + for meta_pattern, hf_pattern in _QWEN3_FROM_META.items(): + if "{}" not in hf_pattern: + continue + hf_parts = hf_pattern.split("{}") + if stripped.startswith(hf_parts[0]) and stripped.endswith(hf_parts[1]): + layer_str = stripped[len(hf_parts[0]):-len(hf_parts[1]) if hf_parts[1] else len(stripped)] + meta_key = meta_pattern.replace("{}", layer_str) + converted[meta_key] = tensor + matched = True + break + if not matched and stripped not in ("codec_embedding.weight", "text_embedding.weight"): + print(f" Skipping unmapped key: {hf_key}") + + # Map codec_head -> output (lm_head equivalent). 
+ if codec_head_key in hf_state: + converted["output.weight"] = hf_state[codec_head_key] + + return converted + + +def _convert_code_predictor( + hf_state: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + """Convert code predictor backbone to Meta format.""" + return _convert_backbone(hf_state, "code_predictor.model.", "") + + +def _extract_aux_weights( + hf_state: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + """Extract non-transformer weights for the C++ runner.""" + aux = {} + for key in sorted(hf_state.keys()): + if key.startswith("text_projection."): + aux[key] = hf_state[key] + elif key == "codec_head.weight": + aux[key] = hf_state[key] + elif key == "model.text_embedding.weight": + aux[key] = hf_state[key] + elif key == "model.codec_embedding.weight": + aux["main_codec_embedding.weight"] = hf_state[key] + elif key.startswith("code_predictor.model.codec_embedding."): + # e.g., code_predictor.model.codec_embedding.0.weight -> cp_codec_embedding.0.weight + suffix = key[len("code_predictor.model."):] + aux[f"cp_{suffix}"] = hf_state[key] + elif key.startswith("code_predictor.lm_head."): + aux[key] = hf_state[key] + + return aux + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert Qwen3-TTS talker weights to ExecuTorch Llama format." + ) + parser.add_argument( + "--talker-checkpoint", + type=Path, + required=True, + help="Path to qwen3_tts_talker.pth (from convert_weights.py --save-talker).", + ) + parser.add_argument( + "--output-dir", + type=Path, + required=True, + help="Output directory for converted checkpoints.", + ) + args = parser.parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + print(f"Loading talker checkpoint: {args.talker_checkpoint}") + hf_state = torch.load( + args.talker_checkpoint, map_location="cpu", weights_only=True + ) + + # Convert main talker backbone. 
+ print("Converting main talker backbone...") + main_state = _convert_backbone(hf_state, "model.", "codec_head.weight") + main_path = args.output_dir / "talker_main.pth" + torch.save(main_state, main_path) + print(f" Saved {len(main_state)} keys -> {main_path}") + + # Convert code predictor backbone. + print("Converting code predictor backbone...") + cp_state = _convert_backbone(hf_state, "code_predictor.model.", "") + cp_path = args.output_dir / "talker_code_predictor.pth" + torch.save(cp_state, cp_path) + print(f" Saved {len(cp_state)} keys -> {cp_path}") + + # Extract auxiliary weights. + print("Extracting auxiliary weights...") + aux_state = _extract_aux_weights(hf_state) + aux_path = args.output_dir / "talker_aux.pth" + torch.save(aux_state, aux_path) + print(f" Saved {len(aux_state)} keys -> {aux_path}") + + # Write config files alongside. + config_dir = Path(__file__).resolve().parent / "config" + for name in ("talker_config.json", "code_predictor_config.json"): + src = config_dir / name + if src.exists(): + import shutil + shutil.copy2(src, args.output_dir / name) + print(f" Copied {name}") + + print("Done.") + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3-tts/convert_weights.py b/examples/models/qwen3-tts/convert_weights.py new file mode 100644 index 00000000000..4e97c06fdea --- /dev/null +++ b/examples/models/qwen3-tts/convert_weights.py @@ -0,0 +1,210 @@ +import argparse +import json +import re +from pathlib import Path +from typing import Dict, Optional, Tuple + +import torch + + +def _load_sharded_safetensors(input_dir: Path) -> Dict[str, torch.Tensor]: + from safetensors.torch import load_file + + index_path = input_dir / "model.safetensors.index.json" + if index_path.exists(): + with index_path.open("r", encoding="utf-8") as f: + index = json.load(f) + weight_map = index["weight_map"] + shard_to_names = {} + for name, shard in weight_map.items(): + shard_to_names.setdefault(shard, []).append(name) + merged = {} + for 
shard, names in shard_to_names.items(): + shard_state = load_file(str(input_dir / shard)) + for name in names: + merged[name] = shard_state[name] + return merged + + model_path = input_dir / "model.safetensors" + if model_path.exists(): + return load_file(str(model_path)) + + raise FileNotFoundError(f"Could not find safetensors checkpoint under {input_dir}") + + +def _extract_prefixed_state_dict( + state_dict: Dict[str, torch.Tensor], prefix: str +) -> Dict[str, torch.Tensor]: + out = {} + for key, value in state_dict.items(): + if key.startswith(prefix): + out[key[len(prefix) :]] = value + return out + + +def _sanitize_model_id(model_id: str) -> str: + cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", model_id.strip()) + return cleaned.strip("_") or "qwen3_tts_model" + + +def _build_decoder_metadata( + model_id_or_path: str, + root_cfg: Dict, + speech_tokenizer_cfg: Dict, + decoder_checkpoint_name: str, +) -> Dict: + decoder_cfg = speech_tokenizer_cfg.get("decoder_config", {}) + return { + "model_id_or_path": model_id_or_path, + "tokenizer_type": root_cfg.get("tokenizer_type", "qwen3_tts_tokenizer_v2"), + "tts_model_type": root_cfg.get("tts_model_type", "base"), + "decoder_checkpoint": decoder_checkpoint_name, + "output_sample_rate": int(speech_tokenizer_cfg.get("output_sample_rate", 24000)), + "decode_upsample_rate": int(speech_tokenizer_cfg.get("decode_upsample_rate", 1920)), + "num_quantizers": int(decoder_cfg.get("num_quantizers", 16)), + "codebook_size": int(decoder_cfg.get("codebook_size", 2048)), + "decoder_config": decoder_cfg, + } + + +def _resolve_snapshot_dir( + input_ref: str, cache_dir: Optional[str] +) -> Tuple[Path, Optional[str]]: + input_path = Path(input_ref) + if input_path.exists(): + return input_path.resolve(), None + + from huggingface_hub import snapshot_download + + snapshot_path = snapshot_download( + repo_id=input_ref, + cache_dir=cache_dir, + allow_patterns=[ + "config.json", + "model.safetensors*", + "model-*.safetensors*", + 
"speech_tokenizer/*", + ], + ) + return Path(snapshot_path).resolve(), input_ref + + +def convert_weights( + input_ref: str, + output_dir: Path, + model_id_or_path: Optional[str], + save_talker: bool, + cache_dir: Optional[str], +) -> None: + input_dir, downloaded_model_id = _resolve_snapshot_dir( + input_ref=input_ref, cache_dir=cache_dir + ) + output_dir.mkdir(parents=True, exist_ok=True) + + root_cfg_path = input_dir / "config.json" + if not root_cfg_path.exists(): + raise FileNotFoundError(f"Missing root config: {root_cfg_path}") + with root_cfg_path.open("r", encoding="utf-8") as f: + root_cfg = json.load(f) + + speech_tokenizer_dir = input_dir / "speech_tokenizer" + speech_tokenizer_cfg_path = speech_tokenizer_dir / "config.json" + if not speech_tokenizer_cfg_path.exists(): + raise FileNotFoundError(f"Missing speech tokenizer config: {speech_tokenizer_cfg_path}") + with speech_tokenizer_cfg_path.open("r", encoding="utf-8") as f: + speech_tokenizer_cfg = json.load(f) + + print("Loading speech tokenizer checkpoint...") + speech_state = _load_sharded_safetensors(speech_tokenizer_dir) + decoder_state = _extract_prefixed_state_dict(speech_state, "decoder.") + if not decoder_state: + raise RuntimeError( + "Decoder weights were not found in speech tokenizer checkpoint " + "(expected keys prefixed by 'decoder.')." + ) + decoder_ckpt = output_dir / "qwen3_tts_decoder.pth" + torch.save(decoder_state, decoder_ckpt) + print(f"Saved decoder checkpoint: {decoder_ckpt}") + + talker_ckpt_name = None + if save_talker: + print("Loading root model checkpoint for talker extraction...") + root_state = _load_sharded_safetensors(input_dir) + talker_state = _extract_prefixed_state_dict(root_state, "talker.") + if not talker_state: + raise RuntimeError( + "Talker weights were not found in root checkpoint " + "(expected keys prefixed by 'talker.')." 
+ ) + talker_ckpt = output_dir / "qwen3_tts_talker.pth" + torch.save(talker_state, talker_ckpt) + talker_ckpt_name = talker_ckpt.name + print(f"Saved talker checkpoint: {talker_ckpt}") + + if model_id_or_path is None: + if downloaded_model_id is not None: + model_id_or_path = downloaded_model_id + else: + model_id_or_path = _sanitize_model_id(input_dir.name) + + metadata = _build_decoder_metadata( + model_id_or_path=model_id_or_path, + root_cfg=root_cfg, + speech_tokenizer_cfg=speech_tokenizer_cfg, + decoder_checkpoint_name=decoder_ckpt.name, + ) + if talker_ckpt_name is not None: + metadata["talker_checkpoint"] = talker_ckpt_name + + metadata_path = output_dir / "decoder_metadata.json" + with metadata_path.open("w", encoding="utf-8") as f: + json.dump(metadata, f, indent=2, sort_keys=True) + print(f"Saved decoder metadata: {metadata_path}") + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert Qwen3-TTS HF checkpoints into export-ready decoder artifacts." + ) + parser.add_argument( + "input_ref", + type=str, + help=( + "Either a local HF snapshot path or a Hugging Face model id " + "(e.g., Qwen/Qwen3-TTS-12Hz-0.6B-Base)." 
+ ), + ) + parser.add_argument( + "output_dir", + type=Path, + help="Directory where converted artifacts will be written.", + ) + parser.add_argument( + "--model-id-or-path", + type=str, + default=None, + help="Original model id or path to record in metadata.", + ) + parser.add_argument( + "--save-talker", + action="store_true", + help="Also extract and save talker.* weights from root checkpoint.", + ) + parser.add_argument( + "--cache-dir", + type=str, + default=None, + help="Optional huggingface_hub cache directory when input_ref is a model id.", + ) + args = parser.parse_args() + convert_weights( + input_ref=args.input_ref, + output_dir=args.output_dir, + model_id_or_path=args.model_id_or_path, + save_talker=args.save_talker, + cache_dir=args.cache_dir, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3-tts/export_qwen3_tts.py b/examples/models/qwen3-tts/export_qwen3_tts.py new file mode 100644 index 00000000000..6edc1e26249 --- /dev/null +++ b/examples/models/qwen3-tts/export_qwen3_tts.py @@ -0,0 +1,236 @@ +import argparse +import json +import sys +from pathlib import Path + +import torch +from torch.export import export + +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass + +SCRIPT_DIR = Path(__file__).resolve().parent +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + +from model import DecoderExportMetadata, make_decode_export_module, make_sample_codes # noqa: E402 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Export Qwen3-TTS speech tokenizer decoder to ExecuTorch." 
+ ) + parser.add_argument( + "--converted-dir", + type=Path, + required=True, + help="Directory produced by convert_weights.py (contains decoder_metadata.json).", + ) + parser.add_argument( + "--backend", + choices=["portable", "xnnpack"], + default="xnnpack", + help="Backend to target for decoder export.", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("./qwen3_tts_exports"), + help="Output directory for model.pte.", + ) + parser.add_argument( + "--fixed-codes-len", + type=int, + default=1200, + help="Static codec sequence length used for export.", + ) + parser.add_argument( + "--bucket-sizes", + type=str, + default=None, + help="Comma-separated list of bucket sizes to export (e.g. '75,150,300,600,1200'). " + "Each bucket produces a separate model_{size}.pte file. Overrides --fixed-codes-len.", + ) + parser.add_argument( + "--dtype", + choices=["fp32", "bf16"], + default="fp32", + help="Decoder weight dtype for export.", + ) + parser.add_argument( + "--qlinear", + choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"], + default=None, + help="Optional quantization mode for linear layers.", + ) + parser.add_argument( + "--qlinear-group-size", + type=int, + default=32, + help="Group size for linear quantization.", + ) + parser.add_argument( + "--qlinear-packing-format", + choices=["tile_packed_to_4d"], + default=None, + help="Optional packing format for 4w quantization.", + ) + return parser.parse_args() + + +def lower_to_executorch(programs, constant_methods: dict, backend: str): + if backend == "xnnpack": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, + XnnpackPartitioner, + ) + + partitioner = { + "decode_codes": [ + XnnpackDynamicallyQuantizedPartitioner(), + XnnpackPartitioner(), + ] + } + else: + partitioner = [] + + edge_prog = to_edge_transform_and_lower( + programs, + partitioner=partitioner, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + 
_skip_dim_order=True, + ), + constant_methods=constant_methods, + ) + return edge_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + + +def export_single_bucket( + module, + metadata: DecoderExportMetadata, + codes_len: int, + backend: str, + output_dir: Path, + model_filename: str = "model.pte", +) -> Path: + sample_codes = make_sample_codes( + codebook_size=metadata.codebook_size, + num_quantizers=metadata.num_quantizers, + code_len=codes_len, + ) + programs = { + "decode_codes": export( + module, + (sample_codes,), + strict=True, + ) + } + + constant_methods = metadata.to_constant_methods() + constant_methods["fixed_codes_len"] = int(codes_len) + + et_prog = lower_to_executorch( + programs, constant_methods=constant_methods, backend=backend + ) + model_path = output_dir / model_filename + with model_path.open("wb") as f: + et_prog.write_to_file(f) + return model_path + + +def main() -> None: + args = parse_args() + output_dir = args.output_dir.resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + converted_dir = args.converted_dir.resolve() + metadata_path = converted_dir / "decoder_metadata.json" + metadata = DecoderExportMetadata.from_json(metadata_path) + checkpoint_path = converted_dir / metadata.decoder_checkpoint + if not checkpoint_path.exists(): + raise FileNotFoundError(f"Decoder checkpoint not found: {checkpoint_path}") + + dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}[args.dtype] + module = make_decode_export_module( + metadata=metadata, checkpoint_path=checkpoint_path, dtype=dtype + ) + + if args.qlinear is not None: + from executorch.extension.llm.export.quantize import quantize_model_ + + quantize_model_( + module, + qlinear_config=args.qlinear, + qlinear_group_size=args.qlinear_group_size, + qlinear_packing_format=args.qlinear_packing_format, + ) + + bucket_sizes = None + if 
args.bucket_sizes is not None: + bucket_sizes = sorted(int(s.strip()) for s in args.bucket_sizes.split(",")) + + if bucket_sizes is not None: + bucket_manifest = { + "backend": args.backend, + "dtype": args.dtype, + "qlinear": args.qlinear, + "qlinear_group_size": args.qlinear_group_size, + "qlinear_packing_format": args.qlinear_packing_format, + "source_converted_dir": str(converted_dir), + "buckets": [], + } + for size in bucket_sizes: + print(f"\n--- Exporting bucket: codes_len={size} ---") + filename = f"model_{size}.pte" + model_path = export_single_bucket( + module, metadata, size, args.backend, output_dir, filename + ) + bucket_manifest["buckets"].append({ + "codes_len": size, + "model_path": str(model_path), + "model_filename": filename, + }) + print(f"Saved model: {model_path}") + + manifest_path = output_dir / "export_manifest.json" + with manifest_path.open("w", encoding="utf-8") as f: + json.dump(bucket_manifest, f, indent=2, sort_keys=True) + print(f"\nSaved manifest: {manifest_path}") + print(f"Exported {len(bucket_sizes)} buckets: {bucket_sizes}") + else: + model_path = export_single_bucket( + module, metadata, args.fixed_codes_len, args.backend, output_dir + ) + + export_manifest = { + "backend": args.backend, + "dtype": args.dtype, + "qlinear": args.qlinear, + "qlinear_group_size": args.qlinear_group_size, + "qlinear_packing_format": args.qlinear_packing_format, + "fixed_codes_len": args.fixed_codes_len, + "source_converted_dir": str(converted_dir), + "model_path": str(model_path), + "constant_methods": metadata.to_constant_methods(), + } + manifest_path = output_dir / "export_manifest.json" + with manifest_path.open("w", encoding="utf-8") as f: + json.dump(export_manifest, f, indent=2, sort_keys=True) + + print(f"Saved model: {model_path}") + print(f"Saved manifest: {manifest_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3-tts/export_talker.py b/examples/models/qwen3-tts/export_talker.py new file mode 100644 
index 00000000000..5d23e101e45 --- /dev/null +++ b/examples/models/qwen3-tts/export_talker.py @@ -0,0 +1,219 @@ +"""Export Qwen3-TTS talker backbone to ExecuTorch. + +The talker is architecturally identical to Qwen3 0.6B (same attention, MLP, +RMSNorm, QK-norm, RoPE) so we reuse the existing Llama/Qwen3 export +infrastructure directly. + +This exports the main talker as a standard autoregressive LM with KV cache, +producing a .pte that supports prefill + per-token decode. + +Usage: + python export_talker.py \ + --checkpoint examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted/talker_main.pth \ + --params examples/models/qwen3-tts/config/talker_config.json \ + --output-dir examples/models/qwen3-tts/qwen3_tts_exports_talker \ + --backend xnnpack \ + --qlinear 8da4w +""" + +import argparse +import json +import sys +from pathlib import Path + +import torch +from torch.export import export + +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass + +SCRIPT_DIR = Path(__file__).resolve().parent +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Export Qwen3-TTS talker to ExecuTorch." 
+ ) + parser.add_argument( + "--checkpoint", type=Path, required=True, + help="Converted talker checkpoint (talker_main.pth).", + ) + parser.add_argument( + "--params", type=Path, required=True, + help="Model params JSON (talker_config.json).", + ) + parser.add_argument( + "--backend", choices=["portable", "xnnpack"], default="xnnpack", + ) + parser.add_argument( + "--output-dir", type=Path, default=Path("./qwen3_tts_exports_talker"), + ) + parser.add_argument( + "--max-seq-len", type=int, default=2048, + help="Max sequence length for KV cache allocation.", + ) + parser.add_argument( + "--qlinear", choices=["4w", "8w", "8da4w", "8da8w"], default=None, + ) + parser.add_argument("--qlinear-group-size", type=int, default=32) + parser.add_argument( + "--output-name", type=str, default="talker.pte", + help="Output .pte filename.", + ) + parser.add_argument( + "--no-embedding", action="store_true", + help="Don't apply tok_embeddings (model takes hidden states). " + "Used for code_predictor which has per-group embeddings.", + ) + parser.add_argument( + "--no-output", action="store_true", + help="Don't apply output projection (model returns hidden states). " + "Used for code_predictor which has per-group LM heads.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + from executorch.examples.models.llama.model_args import ModelArgs + from executorch.examples.models.llama.llama_transformer import construct_transformer + + # Load config. 
+ with args.params.open("r") as f: + params_dict = json.load(f) + + params_dict["use_kv_cache"] = True + params_dict["max_seq_len"] = args.max_seq_len + params_dict["max_context_len"] = args.max_seq_len + params_dict["max_batch_size"] = 1 + params_dict["generate_full_logits"] = False + if args.no_embedding: + params_dict["apply_embedding"] = False + if args.no_output: + params_dict["apply_output"] = False + + model_args = ModelArgs(**params_dict) + print(f"ModelArgs: dim={model_args.dim}, n_layers={model_args.n_layers}, " + f"n_heads={model_args.n_heads}, n_kv_heads={model_args.n_kv_heads}, " + f"vocab_size={model_args.vocab_size}, max_seq_len={model_args.max_seq_len}") + + # Build model. + model = construct_transformer(model_args) + model.eval() + + # Load weights. + print(f"Loading checkpoint: {args.checkpoint}") + state_dict = torch.load(args.checkpoint, map_location="cpu", weights_only=True) + missing, unexpected = model.load_state_dict(state_dict, strict=False) + if missing: + # Filter out KV cache buffers (expected to be missing). + real_missing = [k for k in missing if "k_cache" not in k and "v_cache" not in k and "mask" not in k] + if real_missing: + print(f"WARNING: Missing keys: {real_missing}") + if unexpected: + print(f"WARNING: Unexpected keys: {unexpected}") + + # Apply quantization. + if args.qlinear is not None: + from executorch.extension.llm.export.quantize import quantize_model_ + print(f"Applying {args.qlinear} quantization (group_size={args.qlinear_group_size})...") + quantize_model_( + model, + qlinear_config=args.qlinear, + qlinear_group_size=args.qlinear_group_size, + ) + + # Disable gradients on all parameters (required for in-place KV cache ops). + for param in model.parameters(): + param.requires_grad_(False) + for buf in model.buffers(): + buf.requires_grad_(False) + + # Export with KV cache: single-token decode mode. 
+ example_attn_options = {"input_pos": torch.tensor([0], dtype=torch.long)} + + if args.no_embedding: + # Code predictor: takes hidden states [1, 1, dim] instead of token ids. + example_h = torch.randn(1, 1, model_args.dim) + example_args = (None, example_attn_options, example_h) + else: + # Main talker: takes token ids [1, 1]. + example_tokens = torch.tensor([[0]], dtype=torch.long) + example_args = (example_tokens, example_attn_options) + + print("Exporting with torch.export...") + with torch.no_grad(): + exported = export( + model, + example_args, + strict=False, + ) + + # Lower to ExecuTorch. + if args.backend == "xnnpack": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, + XnnpackPartitioner, + ) + partitioner = [XnnpackDynamicallyQuantizedPartitioner(), XnnpackPartitioner()] + else: + partitioner = [] + + constant_methods = { + "max_seq_len": args.max_seq_len, + "vocab_size": model_args.vocab_size, + "dim": model_args.dim, + "n_heads": model_args.n_heads, + "n_kv_heads": model_args.n_kv_heads, + "head_dim": model_args.head_dim, + "n_layers": model_args.n_layers, + } + + print("Lowering to ExecuTorch...") + edge_prog = to_edge_transform_and_lower( + {"forward": exported}, + partitioner={"forward": partitioner}, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=constant_methods, + ) + et_prog = edge_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=not args.no_output, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + + model_path = args.output_dir / args.output_name + with model_path.open("wb") as f: + et_prog.write_to_file(f) + print(f"Saved: {model_path}") + + manifest = { + "model_type": "qwen3_tts_talker", + "backend": args.backend, + "qlinear": args.qlinear, + "max_seq_len": args.max_seq_len, + "model_args": params_dict, + 
"constant_methods": constant_methods, + } + manifest_name = args.output_name.replace(".pte", "_manifest.json") + manifest_path = args.output_dir / manifest_name + with manifest_path.open("w") as f: + json.dump(manifest, f, indent=2, sort_keys=True) + print(f"Saved: {manifest_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3-tts/export_unified.py b/examples/models/qwen3-tts/export_unified.py new file mode 100644 index 00000000000..e309ff2e68b --- /dev/null +++ b/examples/models/qwen3-tts/export_unified.py @@ -0,0 +1,947 @@ +"""Export Qwen3-TTS as a single multi-method .pte for mobile deployment. + +Produces one model.pte containing all pipeline stages: + encode_text — text token_ids → projected embeddings [1, S, 1024] + talker — composite embeddings → (logits, hidden) with KV cache + code_predictor — sub-code embeddings → hidden with KV cache + codec_embed — (token_id, group_idx) → embedding [1, 1, 1024] + cp_head — (hidden, head_idx) → logits [1, 2048] + cp_generate — fused sampled 15-step code predictor loop + decode_audio — audio codes [1, T, 16] → (waveform, lengths) + +Follows the Parakeet multi-method export pattern. 
+ +Usage: + python export_unified.py \ + --converted-dir qwen3_tts_artifacts \ + --talker-dir qwen3_tts_artifacts/talker_converted \ + --output-dir qwen3_tts_exports_unified \ + --backend xnnpack \ + --qlinear 8da4w +""" + +import argparse +import json +import math +import sys +from pathlib import Path +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.export import Dim, export + +from executorch.exir import ( + EdgeCompileConfig, + ExecutorchBackendConfig, + to_edge_transform_and_lower, +) +from executorch.exir.passes import MemoryPlanningPass + +SCRIPT_DIR = Path(__file__).resolve().parent +DEFAULT_MODEL_CONFIG_PATH = SCRIPT_DIR / "qwen3-tts-12Hz-0.6B-Base" / "config.json" +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + +from model import DecoderExportMetadata, load_decoder_from_metadata +from text_prompt_contract import ( + MIN_PROMPT_TOKEN_COUNT, + TEXT_ONLY_PREFILL_TOKEN_COUNT, + TEXT_ONLY_PREFILL_TOKEN_COUNT_WITH_LANGUAGE, + TRAILING_TEMPLATE_TOKEN_COUNT, +) + + +STREAMING_DECODER_CHUNK_SIZE = 300 +STREAMING_DECODER_LEFT_CONTEXT_SIZE = 25 + +BACKEND_CODE_PORTABLE = 0 +BACKEND_CODE_XNNPACK = 1 +BACKEND_CODE_METAL = 2 + + +def resolve_backend_runtime_metadata(backend: str) -> tuple[int, int, int]: + backend_code = { + "portable": BACKEND_CODE_PORTABLE, + "xnnpack": BACKEND_CODE_XNNPACK, + "metal": BACKEND_CODE_METAL, + }[backend] + generation_backend_code = backend_code + decoder_backend_code = backend_code + prefer_streaming_decoder_surface = 0 + if backend == "metal": + decoder_backend_code = BACKEND_CODE_XNNPACK + return ( + generation_backend_code, + decoder_backend_code, + prefer_streaming_decoder_surface, + ) + + +def load_runtime_token_ids(model_config_path: Path) -> Dict[str, int]: + with model_config_path.open("r", encoding="utf-8") as f: + config = json.load(f) + + talker_config = config["talker_config"] + return { + "tts_pad_token_id": 
int(config["tts_pad_token_id"]), + "tts_bos_token_id": int(config["tts_bos_token_id"]), + "tts_eod_token_id": int(config["tts_eos_token_id"]), + "codec_pad_id": int(talker_config["codec_pad_id"]), + "codec_bos_id": int(talker_config["codec_bos_id"]), + "codec_eos_id": int(talker_config["codec_eos_token_id"]), + "codec_think_id": int(talker_config["codec_think_id"]), + "codec_language_english_id": int(talker_config["codec_language_id"]["english"]), + "codec_nothink_id": int(talker_config["codec_nothink_id"]), + "codec_think_bos_id": int(talker_config["codec_think_bos_id"]), + "codec_think_eos_id": int(talker_config["codec_think_eos_id"]), + "im_start_token_id": int(config["im_start_token_id"]), + "assistant_token_id": int(config["assistant_token_id"]), + "newline_token_id": 198, + } + + +# --------------------------------------------------------------------------- +# Wrapper modules +# --------------------------------------------------------------------------- + +class EncodeTextExport(nn.Module): + """Text token_ids → projected embeddings [1, S, 1024]. + + Wraps text_embedding (nn.Embedding) + text_projection (2-layer MLP). 
+ """ + + def __init__(self, text_embedding: nn.Embedding, text_projection: nn.Module): + super().__init__() + self.text_embedding = text_embedding + self.text_projection = text_projection + + def forward(self, token_ids: torch.Tensor) -> torch.Tensor: + embeds = self.text_embedding(token_ids) + return self.text_projection(embeds) + + +class TextProjectionMLP(nn.Module): + """2-layer MLP: text_hidden (2048) → intermediate (2048) → talker_dim (1024).""" + + def __init__(self, fc1_weight, fc1_bias, fc2_weight, fc2_bias): + super().__init__() + self.fc1 = nn.Linear(fc1_weight.shape[1], fc1_weight.shape[0]) + self.fc1.weight = nn.Parameter(fc1_weight) + self.fc1.bias = nn.Parameter(fc1_bias) + self.fc2 = nn.Linear(fc2_weight.shape[1], fc2_weight.shape[0]) + self.fc2.weight = nn.Parameter(fc2_weight) + self.fc2.bias = nn.Parameter(fc2_bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc2(F.silu(self.fc1(x))) + + +class TalkerExport(nn.Module): + """Talker transformer wrapper returning both logits and hidden state. + + The transformer runs with apply_output=False (returns normalized hidden). + We apply codec_head manually to produce logits. 
+ """ + + def __init__(self, transformer: nn.Module, codec_head_weight: torch.Tensor): + super().__init__() + self.transformer = transformer + self.codec_head = nn.Linear( + codec_head_weight.shape[1], codec_head_weight.shape[0], bias=False + ) + self.codec_head.weight = nn.Parameter(codec_head_weight) + + def forward( + self, embeds: torch.Tensor, input_pos: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + hidden = self.transformer( + tokens=None, attn_options={"input_pos": input_pos}, h=embeds + ) + if isinstance(hidden, tuple): + hidden = hidden[0] + logits = self.codec_head(hidden) + return logits, hidden + + +class CodePredictorExport(nn.Module): + """Code predictor transformer wrapper (returns hidden state only).""" + + def __init__(self, transformer: nn.Module): + super().__init__() + self.transformer = transformer + + def forward( + self, embeds: torch.Tensor, input_pos: torch.Tensor + ) -> torch.Tensor: + hidden = self.transformer( + tokens=None, attn_options={"input_pos": input_pos}, h=embeds + ) + if isinstance(hidden, tuple): + hidden = hidden[0] + return hidden + + +class CodecEmbedExport(nn.Module): + """All codec embeddings (main + 15 cp) stacked for index_select lookup. 
+ + Main codec: vocab 3072, dim 1024 (group_idx=0) + CP codec 0-14: vocab 2048, dim 1024 (group_idx=1..15) + """ + + def __init__( + self, + main_codec_weight: torch.Tensor, + cp_codec_weights: list, + ): + super().__init__() + vocab_max = main_codec_weight.shape[0] + dim = main_codec_weight.shape[1] + num_groups = 1 + len(cp_codec_weights) + + stacked = torch.zeros(num_groups, vocab_max, dim, dtype=main_codec_weight.dtype) + stacked[0, : main_codec_weight.shape[0]] = main_codec_weight + for i, w in enumerate(cp_codec_weights): + stacked[i + 1, : w.shape[0]] = w + + self.register_buffer("stacked_embeds", stacked) + + def forward( + self, token_id: torch.Tensor, group_idx: torch.Tensor + ) -> torch.Tensor: + table = torch.index_select(self.stacked_embeds, 0, group_idx).squeeze(0) + return F.embedding(token_id, table).unsqueeze(0) + + +class CpHeadExport(nn.Module): + """Code predictor per-group LM heads stacked for index_select. + + 15 heads, each [2048, 1024]. Stacked to [15, 2048, 1024]. + """ + + def __init__(self, head_weights: list): + super().__init__() + stacked = torch.stack(head_weights, dim=0) + self.register_buffer("stacked_heads", stacked) + + def forward( + self, hidden: torch.Tensor, head_idx: torch.Tensor + ) -> torch.Tensor: + head_weight = torch.index_select( + self.stacked_heads, 0, head_idx + ).squeeze(0) + return F.linear(hidden, head_weight) + + +class CpGenerateExport(nn.Module): + """Fused code predictor v2 for the warm XNNPACK fast path. + + The runner still samples `code_0` on the host so it can keep repetition + penalty, suppression, and EOS handling aligned with the talker path. + This fused export handles groups 1..15 using the current common sampler + shape used in this project: + - temperature > 0 + - top_k == 50 + - top_p disabled + + Sampling is made exportable by passing one pre-generated uniform random + value per sub-code group. The graph performs: + 1. per-group LM head + 2. fixed top-k(50) + 3. 
inverse-CDF sampling over softmax(topk logits / temperature) + 4. embedding lookup for the chosen code + 5. code predictor step for the next group + + Returns: + - sampled sub-codes [15] + - fused embedding sum for the next talker step [1024] + """ + + def __init__( + self, + cp_transformer: nn.Module, + cp_head_weights: list, + cp_embed_weights: list, + ): + super().__init__() + self.cp_transformer = cp_transformer + self.num_groups = len(cp_head_weights) + self.register_buffer("stacked_heads", torch.stack(cp_head_weights, dim=0)) + self.register_buffer("stacked_embeds", torch.stack(cp_embed_weights, dim=0)) + + def _cp_forward(self, embeds: torch.Tensor, pos: torch.Tensor) -> torch.Tensor: + hidden = self.cp_transformer( + tokens=None, attn_options={"input_pos": pos}, h=embeds + ) + if isinstance(hidden, tuple): + hidden = hidden[0] + return hidden + + def forward( + self, + talker_hidden: torch.Tensor, + code_0_embed: torch.Tensor, + temperature: torch.Tensor, + sample_uniforms: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Prefill: [talker_hidden, code_0_embed] at positions [0, 1] + cp_input = torch.cat([talker_hidden, code_0_embed], dim=1) + cp_pos = torch.arange(2, dtype=torch.long) + cp_hidden = self._cp_forward(cp_input, cp_pos) + + # Start with code_0 embedding in the sum + embed_sum = code_0_embed.reshape(-1) # [1024] + safe_temperature = torch.clamp(temperature.reshape(()), min=1e-5) + sampled_codes = [] + + for group_idx in range(self.num_groups): + head_weight = self.stacked_heads[group_idx] + logits = F.linear(cp_hidden, head_weight).reshape(-1) + topk_vals, topk_idx = torch.topk(logits, k=50, dim=0) + probs = torch.softmax(topk_vals / safe_temperature, dim=0) + cdf = torch.cumsum(probs, dim=0) + sample = torch.clamp(sample_uniforms[group_idx], min=1e-6, max=1.0 - 1e-6) + choice = torch.argmax((cdf >= sample).to(torch.int64), dim=0) + code = torch.gather(topk_idx, 0, choice.reshape(1)).reshape(()) + sampled_codes.append(code) + + 
embed_weight = self.stacked_embeds[group_idx] + code_embed = F.embedding(code.reshape(1), embed_weight).unsqueeze(0) + embed_sum = embed_sum + code_embed.reshape(-1) + + if group_idx + 1 < self.num_groups: + cp_hidden = self._cp_forward( + code_embed, + torch.tensor([group_idx + 2], dtype=torch.long), + ) + + return torch.stack(sampled_codes, dim=0), embed_sum + + +class DynamicDecoderExport(nn.Module): + """Decoder wrapper with exportable padding (no math.ceil on SymInt).""" + + def __init__(self, decoder, decode_upsample_rate: int): + super().__init__() + self.decoder = decoder + self.decode_upsample_rate = int(decode_upsample_rate) + self._patch_causal_conv_padding() + + def _patch_causal_conv_padding(self): + """Replace math.ceil-based padding with integer arithmetic.""" + for module in self.decoder.modules(): + cls_name = type(module).__name__ + if "CausalConvNet" in cls_name and hasattr(module, "stride"): + _patch_conv_padding(module) + + def forward(self, audio_codes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + audio_lengths = (audio_codes[..., 0] > -1).sum(1) * self.decode_upsample_rate + clamped_codes = torch.clamp(audio_codes, min=0) + wav = self.decoder(clamped_codes.transpose(1, 2)).squeeze(1) + return wav, audio_lengths + + +class StreamingDecoderExport(DynamicDecoderExport): + """Fixed-window decoder surface for overlap-context streaming on XNNPACK.""" + + def __init__( + self, + decoder, + decode_upsample_rate: int, + chunk_size: int, + left_context_size: int, + ): + super().__init__(decoder, decode_upsample_rate) + self.chunk_size = int(chunk_size) + self.left_context_size = int(left_context_size) + self.max_codes = self.chunk_size + self.left_context_size + + def forward(self, audio_codes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if audio_codes.shape[1] != self.max_codes: + raise ValueError( + f"audio_codes must have shape [B, {self.max_codes}, Q], got {tuple(audio_codes.shape)}" + ) + return super().forward(audio_codes) + + 
+def _patch_conv_padding(module): + """Monkey-patch _get_extra_padding_for_conv1d to avoid math.ceil on SymInt.""" + kernel_size = module.kernel_size + stride = module.stride + padding = module.padding + + def _exportable_extra_padding(self, hidden_state): + length = hidden_state.shape[-1] + n_frames_num = length - kernel_size + padding + stride + n_frames_ceil = (n_frames_num + stride - 1) // stride + ideal_length = (n_frames_ceil - 1) * stride + (kernel_size - padding) + return ideal_length - length + + import types + module._get_extra_padding_for_conv1d = types.MethodType( + _exportable_extra_padding, module + ) + + +# --------------------------------------------------------------------------- +# Model loading +# --------------------------------------------------------------------------- + +def load_talker_model(talker_dir: Path, max_seq_len: int): + """Load the talker backbone using Llama infrastructure.""" + from executorch.examples.models.llama.model_args import ModelArgs + from executorch.examples.models.llama.llama_transformer import construct_transformer + + config_path = talker_dir / "talker_config.json" + with config_path.open("r") as f: + params = json.load(f) + + params["use_kv_cache"] = True + params["max_seq_len"] = max_seq_len + params["max_context_len"] = max_seq_len + params["max_batch_size"] = 1 + params["generate_full_logits"] = False + params["apply_embedding"] = False + params["apply_output"] = False + + model_args = ModelArgs(**params) + model = construct_transformer(model_args) + model.eval() + + ckpt = torch.load(talker_dir / "talker_main.pth", map_location="cpu", weights_only=True) + missing, unexpected = model.load_state_dict(ckpt, strict=False) + real_missing = [k for k in missing if "k_cache" not in k and "v_cache" not in k and "mask" not in k] + if real_missing: + print(f"WARNING: Talker missing keys: {real_missing}") + + return model, model_args + + +def load_code_predictor_model(talker_dir: Path, max_seq_len: int = 32): + """Load the 
code predictor backbone.""" + from executorch.examples.models.llama.model_args import ModelArgs + from executorch.examples.models.llama.llama_transformer import construct_transformer + + config_path = talker_dir / "code_predictor_config.json" + with config_path.open("r") as f: + params = json.load(f) + + params["use_kv_cache"] = True + params["max_seq_len"] = max_seq_len + params["max_context_len"] = max_seq_len + params["max_batch_size"] = 1 + params["generate_full_logits"] = False + params["apply_embedding"] = False + params["apply_output"] = False + + model_args = ModelArgs(**params) + model = construct_transformer(model_args) + model.eval() + + ckpt = torch.load( + talker_dir / "talker_code_predictor.pth", map_location="cpu", weights_only=True + ) + missing, unexpected = model.load_state_dict(ckpt, strict=False) + real_missing = [k for k in missing if "k_cache" not in k and "v_cache" not in k and "mask" not in k] + if real_missing: + print(f"WARNING: Code predictor missing keys: {real_missing}") + + return model, model_args + + +def load_aux_weights(talker_dir: Path): + """Load auxiliary weights (embeddings, heads, projections).""" + aux = torch.load(talker_dir / "talker_aux.pth", map_location="cpu", weights_only=True) + return aux + + +# --------------------------------------------------------------------------- +# Export +# --------------------------------------------------------------------------- + +def build_wrapper_modules( + talker_dir: Path, + converted_dir: Path, + metadata: DecoderExportMetadata, + max_seq_len: int, + dtype: torch.dtype, + backend: str, +): + """Build all wrapper modules for multi-method export.""" + aux = load_aux_weights(talker_dir) + + # 1. 
encode_text + text_emb_weight = aux["model.text_embedding.weight"].to(dtype) + text_embedding = nn.Embedding( + text_emb_weight.shape[0], text_emb_weight.shape[1] + ) + text_embedding.weight = nn.Parameter(text_emb_weight) + + text_projection = TextProjectionMLP( + fc1_weight=aux["text_projection.linear_fc1.weight"].to(dtype), + fc1_bias=aux["text_projection.linear_fc1.bias"].to(dtype), + fc2_weight=aux["text_projection.linear_fc2.weight"].to(dtype), + fc2_bias=aux["text_projection.linear_fc2.bias"].to(dtype), + ) + encode_text = EncodeTextExport(text_embedding, text_projection) + encode_text.eval() + + # 2. talker + talker_model, talker_args = load_talker_model(talker_dir, max_seq_len) + if backend == "metal": + from executorch.examples.models.llama.source_transformation.sdpa import ( + replace_causal_mask, + ) + + talker_model = replace_causal_mask(talker_model) + codec_head_weight = aux["codec_head.weight"].to(dtype) + talker = TalkerExport(talker_model, codec_head_weight) + talker.eval() + + # 3. code_predictor (standalone, kept for backward compat) + cp_model, cp_args = load_code_predictor_model(talker_dir, max_seq_len=32) + if backend == "metal": + cp_model = replace_causal_mask(cp_model) + code_predictor = CodePredictorExport(cp_model) + code_predictor.eval() + + # 4. codec_embed + main_codec_weight = aux["main_codec_embedding.weight"].to(dtype) + cp_codec_weights = [] + for i in range(15): + key = f"cp_codec_embedding.{i}.weight" + cp_codec_weights.append(aux[key].to(dtype)) + codec_embed = CodecEmbedExport(main_codec_weight, cp_codec_weights) + codec_embed.eval() + + # 5. cp_head (standalone, kept for backward compat) + cp_head_weights = [] + for i in range(15): + key = f"code_predictor.lm_head.{i}.weight" + cp_head_weights.append(aux[key].to(dtype)) + cp_head = CpHeadExport(cp_head_weights) + cp_head.eval() + + # 6. 
cp_generate (FUSED: 15-step code predictor in one graph) + cp_model_fused, _ = load_code_predictor_model(talker_dir, max_seq_len=32) + cp_generate = CpGenerateExport( + cp_transformer=cp_model_fused, + cp_head_weights=[w.to(dtype) for w in cp_head_weights], + cp_embed_weights=[w.to(dtype) for w in cp_codec_weights], + ) + cp_generate.eval() + + # 7. decode_audio + checkpoint_path = converted_dir / metadata.decoder_checkpoint + decoder = load_decoder_from_metadata(metadata, checkpoint_path, dtype=dtype) + streaming_decoder = load_decoder_from_metadata( + metadata, checkpoint_path, dtype=dtype + ) + decode_audio = DynamicDecoderExport(decoder, metadata.decode_upsample_rate) + decode_audio_stream = StreamingDecoderExport( + streaming_decoder, + metadata.decode_upsample_rate, + chunk_size=STREAMING_DECODER_CHUNK_SIZE, + left_context_size=STREAMING_DECODER_LEFT_CONTEXT_SIZE, + ) + decode_audio.eval() + decode_audio.to(dtype=dtype) + decode_audio_stream.eval() + decode_audio_stream.to(dtype=dtype) + + for mod in [encode_text, talker, code_predictor, codec_embed, cp_head, + cp_generate, decode_audio, decode_audio_stream]: + for p in mod.parameters(): + p.requires_grad_(False) + for b in mod.buffers(): + b.requires_grad_(False) + + return { + "encode_text": encode_text, + "talker": talker, + "code_predictor": code_predictor, + "codec_embed": codec_embed, + "cp_head": cp_head, + "cp_generate": cp_generate, + "decode_audio": decode_audio, + "decode_audio_stream": decode_audio_stream, + }, talker_args, cp_args + + +def export_all( + modules: dict, + talker_args, + cp_args, + metadata: DecoderExportMetadata, + runtime_token_ids: Dict[str, int], + max_seq_len: int, + backend: str, + qlinear: str = None, + qlinear_group_size: int = 32, + qembedding: str = None, +): + """Export all methods into a single .pte.""" + + # Apply quantization before export. 
+ if qlinear is not None or qembedding is not None: + from executorch.extension.llm.export.quantize import quantize_model_ + for name, mod in modules.items(): + if name in ("codec_embed", "cp_head"): + continue + q_linear = qlinear if name not in ("codec_embed",) else None + q_embed = qembedding if name in ("encode_text",) else None + if q_linear or q_embed: + print(f" Quantizing {name} (linear={q_linear}, embedding={q_embed})...") + quantize_model_( + mod, + qlinear_config=q_linear, + qlinear_group_size=qlinear_group_size, + qembedding_config=q_embed, + ) + + programs = {} + + # 1. encode_text — dynamic sequence length + print("Exporting encode_text...") + seq_dim = Dim("seq_len", min=1, max=4096) + sample_ids = torch.zeros(1, 10, dtype=torch.long) + programs["encode_text"] = export( + modules["encode_text"], + (sample_ids,), + dynamic_shapes={"token_ids": {1: seq_dim}}, + strict=False, + ) + + # 2. talker — dynamic sequence length for prefill+decode + print("Exporting talker...") + talker_seq = Dim("talker_seq", min=1, max=max_seq_len) + sample_embeds = torch.randn(1, 4, talker_args.dim) + sample_pos = torch.arange(4, dtype=torch.long) + programs["talker"] = export( + modules["talker"], + (sample_embeds, sample_pos), + dynamic_shapes={ + "embeds": {1: talker_seq}, + "input_pos": {0: talker_seq}, + }, + strict=False, + ) + + # 3. code_predictor — dynamic sequence length + print("Exporting code_predictor...") + cp_seq = Dim("cp_seq", min=1, max=32) + sample_cp_embeds = torch.randn(1, 2, cp_args.dim) + sample_cp_pos = torch.arange(2, dtype=torch.long) + programs["code_predictor"] = export( + modules["code_predictor"], + (sample_cp_embeds, sample_cp_pos), + dynamic_shapes={ + "embeds": {1: cp_seq}, + "input_pos": {0: cp_seq}, + }, + strict=False, + ) + + # 4. 
codec_embed — static shapes + print("Exporting codec_embed...") + sample_tid = torch.tensor([0], dtype=torch.long) + sample_gidx = torch.tensor([0], dtype=torch.long) + programs["codec_embed"] = export( + modules["codec_embed"], + (sample_tid, sample_gidx), + strict=False, + ) + + # 5. cp_head — static shapes + print("Exporting cp_head...") + sample_hidden = torch.randn(1, cp_args.dim) + sample_hidx = torch.tensor([0], dtype=torch.long) + programs["cp_head"] = export( + modules["cp_head"], + (sample_hidden, sample_hidx), + strict=False, + ) + + # 6. cp_generate — fused sampled 15-step code predictor (static shapes) + print("Exporting cp_generate (fused sampled 15-step loop)...") + sample_talker_hidden = torch.randn(1, 1, cp_args.dim) + sample_code0_embed = torch.randn(1, 1, cp_args.dim) + sample_temperature = torch.tensor([1.0], dtype=torch.float32) + sample_uniforms = torch.full((15,), 0.5, dtype=torch.float32) + programs["cp_generate"] = export( + modules["cp_generate"], + ( + sample_talker_hidden, + sample_code0_embed, + sample_temperature, + sample_uniforms, + ), + strict=False, + ) + + # 7. 
decode_audio — dynamic codes length + print("Exporting decode_audio...") + codes_dim = Dim("codes_len", min=1, max=2000) + sample_codes = torch.randint(0, metadata.codebook_size, (1, 10, metadata.num_quantizers), dtype=torch.long) + programs["decode_audio"] = export( + modules["decode_audio"], + (sample_codes,), + dynamic_shapes={"audio_codes": {1: codes_dim}}, + strict=False, + ) + + print("Exporting decode_audio_stream...") + sample_stream_codes = torch.full( + ( + 1, + STREAMING_DECODER_CHUNK_SIZE + STREAMING_DECODER_LEFT_CONTEXT_SIZE, + metadata.num_quantizers, + ), + -1, + dtype=torch.long, + ) + sample_stream_codes[:, :10, :] = torch.randint( + 0, + metadata.codebook_size, + (1, 10, metadata.num_quantizers), + dtype=torch.long, + ) + programs["decode_audio_stream"] = export( + modules["decode_audio_stream"], + (sample_stream_codes,), + strict=False, + ) + + # Build per-method partitioners. + if backend == "xnnpack": + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, + XnnpackPartitioner, + ) + partitioner = {} + for key in programs: + if key in ("codec_embed",): + partitioner[key] = [] + else: + partitioner[key] = [ + XnnpackDynamicallyQuantizedPartitioner(), + XnnpackPartitioner(), + ] + elif backend == "metal": + from executorch.backends.apple.metal.metal_backend import MetalBackend + from executorch.backends.apple.metal.metal_partitioner import MetalPartitioner + + # Linear bias decomposition (following Voxtral pattern). 
+ def _linear_bias_decomposition(input_tensor, weight, bias=None): + out = torch.matmul(input_tensor, weight.t()) + if bias is not None: + out = out + bias + return out + + updated_programs = {} + for key, ep in programs.items(): + if key in ("codec_embed",): + updated_programs[key] = ep + else: + updated_programs[key] = ep.run_decompositions( + {torch.ops.aten.linear.default: _linear_bias_decomposition} + ) + programs = updated_programs + + partitioner = {} + for key in programs: + if key in ("codec_embed",): + partitioner[key] = [] + elif key in ("cp_generate", "decode_audio", "decode_audio_stream"): + # Keep GPU-incompatible methods on XNNPACK for the hybrid Metal path. + # `cp_generate` still lowers through topk/cumsum fallback kernels that + # the current AOTI Metal backend does not provide, and the decoder + # remains on XNNPACK until we have a true Metal vocoder path. + from executorch.backends.xnnpack.partition.xnnpack_partitioner import ( + XnnpackDynamicallyQuantizedPartitioner, + XnnpackPartitioner, + ) + partitioner[key] = [ + XnnpackDynamicallyQuantizedPartitioner(), + XnnpackPartitioner(), + ] + else: + compile_specs = [MetalBackend.generate_method_name_compile_spec(key)] + partitioner[key] = [MetalPartitioner(compile_specs)] + else: + partitioner = {key: [] for key in programs} + + ( + generation_backend_code, + decoder_backend_code, + prefer_streaming_decoder_surface, + ) = resolve_backend_runtime_metadata(backend) + + # Constant methods (metadata). 
+ constant_methods = metadata.to_constant_methods() + constant_methods.update({ + "max_seq_len": max_seq_len, + "talker_vocab_size": talker_args.vocab_size, + "talker_dim": talker_args.dim, + "talker_n_layers": talker_args.n_layers, + "cp_n_layers": cp_args.n_layers, + "num_code_groups": 16, + "text_prompt_min_token_count": MIN_PROMPT_TOKEN_COUNT, + "text_prompt_prefill_token_count": TEXT_ONLY_PREFILL_TOKEN_COUNT, + "text_prompt_prefill_token_count_with_language": TEXT_ONLY_PREFILL_TOKEN_COUNT_WITH_LANGUAGE, + "text_prompt_trailing_template_token_count": TRAILING_TEMPLATE_TOKEN_COUNT, + "cp_generate_contract_version": 2, + "cp_generate_fast_top_k": 50, + "generation_backend_code": generation_backend_code, + "decoder_backend_code": decoder_backend_code, + "prefer_streaming_decoder_surface": prefer_streaming_decoder_surface, + "streaming_decoder_contract_version": 1, + "streaming_decoder_chunk_size": STREAMING_DECODER_CHUNK_SIZE, + "streaming_decoder_left_context_size": STREAMING_DECODER_LEFT_CONTEXT_SIZE, + "streaming_decoder_max_codes": STREAMING_DECODER_CHUNK_SIZE + STREAMING_DECODER_LEFT_CONTEXT_SIZE, + }) + constant_methods.update(runtime_token_ids) + + print("Lowering to ExecuTorch...") + edge_prog = to_edge_transform_and_lower( + programs, + partitioner=partitioner, + compile_config=EdgeCompileConfig( + _check_ir_validity=False, + _skip_dim_order=True, + ), + constant_methods=constant_methods, + ) + et_prog = edge_prog.to_executorch( + config=ExecutorchBackendConfig( + extract_delegate_segments=True, + do_quant_fusion_and_const_prop=True, + memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False), + ) + ) + return et_prog + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Export Qwen3-TTS as single multi-method .pte" + ) + parser.add_argument( + "--converted-dir", type=Path, required=True, + help="Directory with decoder_metadata.json and decoder checkpoint.", + ) + parser.add_argument( + "--talker-dir", type=Path, required=True, 
+ help="Directory with talker_main.pth, talker_code_predictor.pth, talker_aux.pth.", + ) + parser.add_argument( + "--output-dir", type=Path, default=Path("./qwen3_tts_exports_unified"), + ) + parser.add_argument("--backend", choices=["portable", "xnnpack", "metal"], default="xnnpack") + parser.add_argument("--max-seq-len", type=int, default=256) + parser.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32") + parser.add_argument("--qlinear", choices=["4w", "8w", "8da4w", "8da8w"], default=None) + parser.add_argument("--qlinear-group-size", type=int, default=32) + parser.add_argument( + "--qembedding", choices=["4w", "8w"], default=None, + help="Embedding quantization. Reduces text_embedding from ~1.2GB to ~300-600MB.", + ) + parser.add_argument( + "--model-config-path", + type=Path, + default=DEFAULT_MODEL_CONFIG_PATH, + help="Path to the checked-in Qwen3-TTS config.json for runtime token IDs.", + ) + parser.add_argument("--output-name", type=str, default="model.pte") + return parser.parse_args() + + +def main(): + args = parse_args() + args.output_dir.mkdir(parents=True, exist_ok=True) + + converted_dir = args.converted_dir.resolve() + talker_dir = args.talker_dir.resolve() + model_config_path = args.model_config_path.resolve() + metadata = DecoderExportMetadata.from_json(converted_dir / "decoder_metadata.json") + runtime_token_ids = load_runtime_token_ids(model_config_path) + dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}[args.dtype] + + print("Building wrapper modules...") + modules, talker_args, cp_args = build_wrapper_modules( + talker_dir=talker_dir, + converted_dir=converted_dir, + metadata=metadata, + max_seq_len=args.max_seq_len, + dtype=dtype, + backend=args.backend, + ) + + print(f"\nModule summary:") + for name, mod in modules.items(): + n_params = sum(p.numel() for p in mod.parameters()) + n_bufs = sum(b.numel() for b in mod.buffers()) + print(f" {name}: {n_params:,} params, {n_bufs:,} buffer elements") + + et_prog = export_all( + 
modules=modules, + talker_args=talker_args, + cp_args=cp_args, + metadata=metadata, + runtime_token_ids=runtime_token_ids, + max_seq_len=args.max_seq_len, + backend=args.backend, + qlinear=args.qlinear, + qlinear_group_size=args.qlinear_group_size, + qembedding=args.qembedding, + ) + ( + generation_backend_code, + decoder_backend_code, + prefer_streaming_decoder_surface, + ) = resolve_backend_runtime_metadata(args.backend) + + model_path = args.output_dir / args.output_name + with model_path.open("wb") as f: + et_prog.write_to_file(f) + file_size_mb = model_path.stat().st_size / (1024 * 1024) + print(f"\nSaved: {model_path} ({file_size_mb:.1f} MB)") + + manifest = { + "model_type": "qwen3_tts_unified", + "backend": args.backend, + "dtype": args.dtype, + "qlinear": args.qlinear, + "qembedding": args.qembedding, + "max_seq_len": args.max_seq_len, + "methods": list(modules.keys()), + "num_code_groups": 16, + "prompt_contract": "assistant_chat_text_v1", + "requires_tokenizer": True, + "supports_text_only_synthesis": True, + "supports_voice_clone_synthesis": False, + "text_prompt_min_token_count": MIN_PROMPT_TOKEN_COUNT, + "text_prompt_prefill_token_count": TEXT_ONLY_PREFILL_TOKEN_COUNT, + "text_prompt_prefill_token_count_with_language": TEXT_ONLY_PREFILL_TOKEN_COUNT_WITH_LANGUAGE, + "text_prompt_trailing_template_token_count": TRAILING_TEMPLATE_TOKEN_COUNT, + "cp_generate_contract_version": 2, + "cp_generate_fast_top_k": 50, + "cp_generate_sampler": "cdf_topk50_no_top_p_v2", + "generation_backend_code": generation_backend_code, + "decoder_backend_code": decoder_backend_code, + "prefer_streaming_decoder_surface": prefer_streaming_decoder_surface, + "streaming_decoder_contract_version": 1, + "streaming_decoder_chunk_size": STREAMING_DECODER_CHUNK_SIZE, + "streaming_decoder_left_context_size": STREAMING_DECODER_LEFT_CONTEXT_SIZE, + "streaming_decoder_max_codes": STREAMING_DECODER_CHUNK_SIZE + STREAMING_DECODER_LEFT_CONTEXT_SIZE, + **runtime_token_ids, + } + manifest_path 
= args.output_dir / "export_manifest.json" + with manifest_path.open("w") as f: + json.dump(manifest, f, indent=2, sort_keys=True) + print(f"Saved: {manifest_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3-tts/generate_codes.py b/examples/models/qwen3-tts/generate_codes.py new file mode 100644 index 00000000000..4e743eb8c24 --- /dev/null +++ b/examples/models/qwen3-tts/generate_codes.py @@ -0,0 +1,219 @@ +import argparse +import json +import random +import sys +from pathlib import Path +from typing import List, Optional + +import numpy as np +import torch + +SCRIPT_DIR = Path(__file__).resolve().parent +if str(SCRIPT_DIR) not in sys.path: + sys.path.insert(0, str(SCRIPT_DIR)) + +from model import write_codes_binary # noqa: E402 +from qwen_tts import Qwen3TTSModel # noqa: E402 + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description="Generate Qwen3-TTS codec ids for downstream ExecuTorch decode." + ) + parser.add_argument("--model-id-or-path", required=True, type=str) + parser.add_argument("--text", required=True, type=str) + parser.add_argument("--language", default="English", type=str) + parser.add_argument("--output-codes", required=True, type=Path) + parser.add_argument("--output-meta", default=None, type=Path) + parser.add_argument("--cache-dir", default=None, type=str) + parser.add_argument("--ref-audio", default=None, type=str) + parser.add_argument("--ref-text", default=None, type=str) + parser.add_argument( + "--x-vector-only-mode", + action="store_true", + help="Use x-vector only mode for voice clone prompt.", + ) + parser.add_argument("--non-streaming-mode", action="store_true") + parser.add_argument("--dtype", choices=["fp32", "bf16"], default="fp32") + parser.add_argument("--seed", type=int, default=42) + parser.add_argument("--instruct", type=str, default="") + parser.add_argument("--max-new-tokens", type=int, default=None) + parser.add_argument("--top-k", type=int, default=None) 
+ parser.add_argument("--top-p", type=float, default=None) + parser.add_argument("--temperature", type=float, default=None) + parser.add_argument("--repetition-penalty", type=float, default=None) + parser.add_argument( + "--trim-silence", + action="store_true", + help="Trim leading silence codes. In streaming mode, the model generates " + "~N silent codes (where N ≈ text token count) before speech begins. " + "This decodes a small chunk to detect where speech starts and strips " + "the silent prefix, saving compute in the downstream decoder.", + ) + parser.add_argument( + "--trim-threshold", + type=float, + default=0.005, + help="RMS threshold for silence detection (default: 0.005).", + ) + return parser.parse_args() + + +def _default_reference_audio(duration_sec: float = 1.0, sample_rate: int = 24000): + wav = np.zeros(int(duration_sec * sample_rate), dtype=np.float32) + return wav, sample_rate + + +def _trim_silent_prefix( + codes: torch.Tensor, + model, + metadata_decoder_config=None, + threshold: float = 0.005, + upsample_rate: int = 1920, + sample_rate: int = 24000, + chunk_size: int = 5, +) -> torch.Tensor: + """Trim leading silent codes by decoding small chunks and checking RMS energy. + + In streaming mode, the talker generates ~N silent codes (N ≈ text token count) + while absorbing text before producing speech. This function finds where speech + starts by decoding codes in small chunks and checking audio energy. + + Returns the trimmed codes tensor [T', Q] with silent prefix removed. 
+ """ + t_len, n_q = codes.shape + if t_len <= chunk_size: + return codes + + speech_start = 0 + for start in range(0, t_len - chunk_size + 1, chunk_size): + end = min(start + chunk_size, t_len) + chunk = codes[start:end] + chunk_clamped = torch.clamp(chunk, min=0) + with torch.no_grad(): + wav = model.model.speech_tokenizer.decode( + chunk_clamped.unsqueeze(0).transpose(1, 2) + ) + if isinstance(wav, (list, tuple)): + wav = wav[0] + wav = wav.squeeze() + rms = torch.sqrt(torch.mean(wav**2)).item() + if rms > threshold: + speech_start = max(0, start - 1) + break + else: + return codes + + if speech_start > 0: + print( + f"Trimmed {speech_start} silent codes " + f"({speech_start * upsample_rate / sample_rate:.1f}s silence)" + ) + return codes[speech_start:] + + +def _build_ref_ids( + model: Qwen3TTSModel, prompt_items +) -> List[Optional[torch.Tensor]]: + ref_ids = [] + for item in prompt_items: + if item.ref_text is None or item.ref_text == "": + ref_ids.append(None) + continue + ref_tok = model._tokenize_texts([model._build_ref_text(item.ref_text)])[0] + ref_ids.append(ref_tok) + return ref_ids + + +def main() -> None: + args = parse_args() + args.output_codes.parent.mkdir(parents=True, exist_ok=True) + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + dtype = {"fp32": torch.float32, "bf16": torch.bfloat16}[args.dtype] + model = Qwen3TTSModel.from_pretrained( + args.model_id_or_path, + device_map="cpu", + dtype=dtype, + cache_dir=args.cache_dir, + ) + + if args.ref_audio is not None: + prompt_items = model.create_voice_clone_prompt( + ref_audio=args.ref_audio, + ref_text=args.ref_text, + x_vector_only_mode=args.x_vector_only_mode, + ) + else: + silence, sr = _default_reference_audio() + prompt_items = model.create_voice_clone_prompt( + ref_audio=(silence, sr), + ref_text=None, + x_vector_only_mode=True, + ) + + gen_kwargs = model._merge_generate_kwargs( + max_new_tokens=args.max_new_tokens, + top_k=args.top_k, + 
top_p=args.top_p, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + ) + if args.instruct: + input_ids = model._tokenize_texts([model._build_assistant_text(args.text)]) + instruct_ids = [model._tokenize_texts([model._build_instruct_text(args.instruct)])[0]] + talker_codes_list, _ = model.model.generate( + input_ids=input_ids, + instruct_ids=instruct_ids, + languages=[args.language], + non_streaming_mode=args.non_streaming_mode, + **gen_kwargs, + ) + else: + prompt_dict = model._prompt_items_to_voice_clone_prompt(prompt_items) + input_ids = model._tokenize_texts([model._build_assistant_text(args.text)]) + ref_ids = _build_ref_ids(model, prompt_items) + talker_codes_list, _ = model.model.generate( + input_ids=input_ids, + ref_ids=ref_ids, + voice_clone_prompt=prompt_dict, + languages=[args.language], + non_streaming_mode=args.non_streaming_mode, + **gen_kwargs, + ) + codes = talker_codes_list[0].detach().cpu() + + if args.trim_silence: + codes = _trim_silent_prefix( + codes, model, metadata_decoder_config=None, threshold=args.trim_threshold + ) + + write_codes_binary(args.output_codes, codes) + + meta = { + "model_id_or_path": args.model_id_or_path, + "language": args.language, + "text": args.text, + "num_codes": int(codes.shape[0]), + "num_quantizers": int(codes.shape[1]), + "seed": args.seed, + "instruct": args.instruct, + "x_vector_only_mode": bool( + args.x_vector_only_mode or args.ref_audio is None + ), + "ref_audio_provided": args.ref_audio is not None, + "non_streaming_mode": args.non_streaming_mode, + "trim_silence": args.trim_silence, + } + meta_path = args.output_meta or args.output_codes.with_suffix(".json") + with meta_path.open("w", encoding="utf-8") as f: + json.dump(meta, f, indent=2, sort_keys=True) + + print(f"Saved codec ids: {args.output_codes}") + print(f"Saved metadata: {meta_path}") + + +if __name__ == "__main__": + main() diff --git a/examples/models/qwen3-tts/main.cpp b/examples/models/qwen3-tts/main.cpp new file 
mode 100644 index 00000000000..a2c44faa20e --- /dev/null +++ b/examples/models/qwen3-tts/main.cpp @@ -0,0 +1,142 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +#include + +#include "qwen3_tts_runner.h" + +DEFINE_string(model_path, "model.pte", "Path to qwen3-tts decoder model (.pte)."); +DEFINE_string( + model_dir, + "", + "Directory containing bucketed .pte files and export_manifest.json. " + "If set, overrides --model_path and enables multi-bucket mode."); +DEFINE_string( + data_path, + "", + "Path to optional data file (.ptd) for delegate data."); +DEFINE_string( + codes_path, + "", + "Path to pre-generated codec ids (.bin). If omitted, helper script is used."); +DEFINE_string(output_wav, "output.wav", "Path to output wav file."); + +DEFINE_string( + text, + "", + "Text for synthesis. 
Required when --codes_path is not provided."); +DEFINE_string(language, "English", "Language used for generation helper."); +DEFINE_string( + model_id_or_path, + "", + "Model id/path used by the generation helper (required when --codes_path is not provided)."); +DEFINE_string( + helper_script, + "examples/models/qwen3-tts/generate_codes.py", + "Path to Python helper script that generates codec ids."); +DEFINE_string( + python_executable, + "python", + "Python executable used to run the helper script."); +DEFINE_string(ref_audio, "", "Optional reference audio for voice cloning."); +DEFINE_string(ref_text, "", "Optional reference text for voice cloning."); +DEFINE_bool(x_vector_only_mode, false, "Use x-vector-only mode for voice clone."); +DEFINE_bool( + non_streaming_mode, + false, + "Forward non-streaming text mode to helper generation."); + +DEFINE_int32(max_new_tokens, -1, "Optional max_new_tokens forwarded to helper."); +DEFINE_int32(top_k, -1, "Optional top_k forwarded to helper."); +DEFINE_double(top_p, -1.0, "Optional top_p forwarded to helper."); +DEFINE_double(temperature, -1.0, "Optional temperature forwarded to helper."); +DEFINE_double( + repetition_penalty, + -1.0, + "Optional repetition_penalty forwarded to helper."); + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + std::unique_ptr runner_ptr; + if (!FLAGS_model_dir.empty()) { + runner_ptr = qwen3_tts::Qwen3TTSRunner::from_model_dir(FLAGS_model_dir); + if (!runner_ptr) { + ET_LOG(Error, "Failed to load bucketed models from: %s", FLAGS_model_dir.c_str()); + return 1; + } + } else { + runner_ptr = std::make_unique( + FLAGS_model_path, FLAGS_data_path); + } + auto& runner = *runner_ptr; + + std::string codes_path = FLAGS_codes_path; + std::filesystem::path tmp_codes; + if (codes_path.empty()) { + if (FLAGS_text.empty()) { + ET_LOG(Error, "Either --codes_path or --text must be provided."); + return 1; + } + if (FLAGS_model_id_or_path.empty()) { + ET_LOG( + 
Error, + "--model_id_or_path is required when --codes_path is not provided."); + return 1; + } + + tmp_codes = std::filesystem::temp_directory_path() / + "qwen3_tts_codegen_codes.bin"; + codes_path = tmp_codes.string(); + + qwen3_tts::CodeGenerationArgs helper_args; + helper_args.python_executable = FLAGS_python_executable; + helper_args.helper_script = FLAGS_helper_script; + helper_args.model_id_or_path = FLAGS_model_id_or_path; + helper_args.text = FLAGS_text; + helper_args.language = FLAGS_language; + helper_args.output_codes_path = codes_path; + helper_args.ref_audio_path = FLAGS_ref_audio; + helper_args.ref_text = FLAGS_ref_text; + helper_args.x_vector_only_mode = FLAGS_x_vector_only_mode; + helper_args.non_streaming_mode = FLAGS_non_streaming_mode; + helper_args.max_new_tokens = FLAGS_max_new_tokens; + helper_args.top_k = FLAGS_top_k; + helper_args.top_p = static_cast(FLAGS_top_p); + helper_args.temperature = static_cast(FLAGS_temperature); + helper_args.repetition_penalty = static_cast(FLAGS_repetition_penalty); + + if (!runner.run_code_generation(helper_args)) { + return 1; + } + } + + std::vector waveform; + if (!runner.decode_codes_file(codes_path, &waveform)) { + return 1; + } + + if (!runner.write_wav_file(FLAGS_output_wav, waveform)) { + ET_LOG(Error, "Failed to write wav output: %s", FLAGS_output_wav.c_str()); + return 1; + } + + ET_LOG( + Info, + "Wrote %zu samples at %d Hz to %s", + waveform.size(), + runner.output_sample_rate(), + FLAGS_output_wav.c_str()); + return 0; +} diff --git a/examples/models/qwen3-tts/main_unified.cpp b/examples/models/qwen3-tts/main_unified.cpp new file mode 100644 index 00000000000..58e383eb44b --- /dev/null +++ b/examples/models/qwen3-tts/main_unified.cpp @@ -0,0 +1,388 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +// Generated with assistance from Claude. + +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include "qwen3_tts_unified_runner.h" + +DEFINE_string( + model_path, + "model.pte", + "Path to unified qwen3-tts model (.pte)."); +DEFINE_string( + tokenizer_path, + "", + "Path to tokenizer.json (for text-to-audio mode)."); +DEFINE_string( + codes_path, + "", + "Path to pre-generated codec ids (.bin). " + "If provided, runs decode-only mode."); +DEFINE_string(output_wav, "output.wav", "Path to output wav file."); +DEFINE_string( + output_dir, + "", + "Optional directory for per-prompt wav outputs in --prompts_path mode."); +DEFINE_string( + text, + "", + "Text for synthesis (requires --tokenizer_path)."); +DEFINE_string( + prompts_path, + "", + "Optional newline-delimited prompt file for warm multi-prompt benchmarking."); +DEFINE_string(language, "English", "Language for synthesis."); + +DEFINE_int32(max_new_tokens, 200, "Max codec tokens to generate."); +DEFINE_double(temperature, 0.9, "Sampling temperature."); +DEFINE_int32(top_k, 50, "Top-k sampling."); +DEFINE_double(top_p, 1.0, "Top-p sampling. Values <= 0 disable nucleus filtering."); +DEFINE_double(repetition_penalty, 1.05, "Repetition penalty for talker code_0 sampling."); +DEFINE_int32(repeat, 1, "Repeat count for each prompt in --prompts_path mode."); +DEFINE_uint64(seed, 42, "Base RNG seed for text synthesis."); +DEFINE_bool( + disable_fused_cp_generate, + false, + "Force the legacy host-side code predictor loop for validation."); +DEFINE_string( + instruct, + "", + "VoiceDesign instruct text (e.g. 
'A cheerful young female voice')."); +DEFINE_bool( + non_streaming_mode, + false, + "Disable chunk emission during generation and only decode final audio."); +DEFINE_double( + streaming_interval, + 2.0, + "Streaming emit interval in seconds (0 = disabled unless --streaming_chunk_steps is set)."); +DEFINE_int32( + streaming_chunk_steps, + 0, + "Deprecated alias for the emit interval expressed in codec steps."); +DEFINE_int32( + streaming_chunk_size, + 300, + "Maximum codec steps per overlap-context decode window."); +DEFINE_int32( + streaming_left_context_size, + 25, + "Left-context codec steps preserved for overlap-context decode."); +DEFINE_bool( + disable_streaming_decoder_surface, + false, + "Force the runner-side overlap decode path even when decode_audio_stream is available."); +DEFINE_bool( + force_streaming_decoder_surface, + false, + "Override export metadata and force decode_audio_stream when it is available."); +DEFINE_bool( + use_legacy_cumulative_streaming_decode, + false, + "For benchmarking only: re-decode the full accumulated prefix on each chunk."); +DEFINE_bool( + trim_silence, + true, + "Trim leading silence from output audio."); +DEFINE_double( + trim_threshold, + 0.005, + "RMS threshold for silence trimming."); + +namespace { + +bool trim_leading_silence( + std::vector* waveform, + int sample_rate, + double threshold, + double* trimmed_ms) { + if (waveform == nullptr || waveform->empty()) { + if (trimmed_ms != nullptr) { + *trimmed_ms = 0.0; + } + return true; + } + size_t speech_start = 0; + const float threshold_f = static_cast(threshold); + for (size_t i = 0; i < waveform->size(); ++i) { + if (std::abs((*waveform)[i]) > threshold_f) { + const size_t margin = static_cast(0.05 * sample_rate); + speech_start = (i > margin) ? 
i - margin : 0; + break; + } + } + if (trimmed_ms != nullptr) { + *trimmed_ms = 1000.0 * static_cast(speech_start) / sample_rate; + } + if (speech_start > 0) { + waveform->erase(waveform->begin(), waveform->begin() + speech_start); + } + return true; +} + +bool read_prompts_file(const std::string& prompts_path, std::vector* prompts) { + std::ifstream in(prompts_path); + if (!in.good()) { + ET_LOG(Error, "Could not open prompts file: %s", prompts_path.c_str()); + return false; + } + std::string line; + while (std::getline(in, line)) { + if (!line.empty()) { + prompts->push_back(line); + } + } + if (prompts->empty()) { + ET_LOG(Error, "No non-empty prompts found in: %s", prompts_path.c_str()); + return false; + } + return true; +} + +} // namespace + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + if (!FLAGS_codes_path.empty() && + (!FLAGS_text.empty() || !FLAGS_prompts_path.empty())) { + ET_LOG( + Error, + "Provide either --codes_path or text synthesis inputs, not both."); + return 1; + } + if (!FLAGS_text.empty() && !FLAGS_prompts_path.empty()) { + ET_LOG(Error, "Provide either --text or --prompts_path, not both."); + return 1; + } + if (FLAGS_codes_path.empty() && FLAGS_text.empty() && FLAGS_prompts_path.empty()) { + ET_LOG(Error, "Either --codes_path, --text, or --prompts_path must be provided."); + return 1; + } + if ((!FLAGS_text.empty() || !FLAGS_prompts_path.empty()) && + FLAGS_tokenizer_path.empty()) { + ET_LOG(Error, "Text synthesis requires --tokenizer_path."); + return 1; + } + if (FLAGS_repeat <= 0) { + ET_LOG(Error, "--repeat must be positive."); + return 1; + } + + const auto t_construct_start = std::chrono::steady_clock::now(); + qwen3_tts::Qwen3TTSUnifiedRunner runner( + FLAGS_model_path, FLAGS_tokenizer_path); + const auto t_construct_end = std::chrono::steady_clock::now(); + const double construct_ms = + std::chrono::duration( + t_construct_end - t_construct_start) + .count(); + + const auto 
t_warmup_start = std::chrono::steady_clock::now(); + if (!FLAGS_codes_path.empty()) { + runner.warmup_decode(); + } else { + runner.warmup_all(); + } + const auto t_warmup_end = std::chrono::steady_clock::now(); + const double warmup_ms = + std::chrono::duration(t_warmup_end - t_warmup_start) + .count(); + ET_LOG(Info, "Runner construction: %.1f ms", construct_ms); + ET_LOG(Info, "Warmup: %.1f ms", warmup_ms); + + std::vector waveform; + + if (!FLAGS_codes_path.empty()) { + // Decode-only mode: use precomputed codes. + ET_LOG(Info, "Decode-only mode: %s", FLAGS_codes_path.c_str()); + auto t0 = std::chrono::steady_clock::now(); + if (!runner.decode_codes_file(FLAGS_codes_path, &waveform)) { + ET_LOG(Error, "decode_codes_file failed."); + return 1; + } + auto t1 = std::chrono::steady_clock::now(); + double decode_ms = + std::chrono::duration(t1 - t0).count(); + double audio_sec = + static_cast(waveform.size()) / runner.output_sample_rate(); + ET_LOG( + Info, + "Decoded %zu samples (%.2fs audio) in %.1f ms (%.2fx realtime)", + waveform.size(), + audio_sec, + decode_ms, + audio_sec / (decode_ms / 1000.0)); + } else { + std::vector prompts; + if (!FLAGS_text.empty()) { + prompts.push_back(FLAGS_text); + } else if (!read_prompts_file(FLAGS_prompts_path, &prompts)) { + return 1; + } + + qwen3_tts::SynthesizeConfig config; + config.max_new_tokens = FLAGS_max_new_tokens; + config.temperature = static_cast(FLAGS_temperature); + config.top_k = FLAGS_top_k; + config.top_p = static_cast(FLAGS_top_p); + config.repetition_penalty = static_cast(FLAGS_repetition_penalty); + config.seed = FLAGS_seed; + config.use_fused_cp_generate = !FLAGS_disable_fused_cp_generate; + config.instruct = FLAGS_instruct; + config.non_streaming_mode = FLAGS_non_streaming_mode; + config.streaming_interval_sec = static_cast(FLAGS_streaming_interval); + config.streaming_chunk_steps = FLAGS_streaming_chunk_steps; + config.streaming_chunk_size = FLAGS_streaming_chunk_size; + 
config.streaming_left_context_size = FLAGS_streaming_left_context_size; + config.disable_streaming_decoder_surface = + FLAGS_disable_streaming_decoder_surface; + config.force_streaming_decoder_surface = FLAGS_force_streaming_decoder_surface; + config.use_legacy_cumulative_streaming_decode = + FLAGS_use_legacy_cumulative_streaming_decode; + + if (!FLAGS_prompts_path.empty() && !FLAGS_disable_fused_cp_generate && + FLAGS_top_k <= 0) { + config.top_k = 50; + ET_LOG( + Info, + "Benchmark mode defaulting top_k to %d so cp_generate fast path is exercised.", + config.top_k); + } + + if (!FLAGS_output_dir.empty()) { + std::filesystem::create_directories(FLAGS_output_dir); + } + + for (int repeat_idx = 0; repeat_idx < FLAGS_repeat; ++repeat_idx) { + for (size_t prompt_idx = 0; prompt_idx < prompts.size(); ++prompt_idx) { + waveform.clear(); + qwen3_tts::SynthesizeConfig prompt_config = config; + prompt_config.seed = + FLAGS_seed + static_cast(repeat_idx * prompts.size() + prompt_idx); + auto session = runner.create_synthesis_session(prompt_config); + qwen3_tts::SynthesisTiming timing; + int streaming_chunks_received = 0; + qwen3_tts::AudioChunkCallback stream_cb = nullptr; + const bool streaming_enabled = + !prompt_config.non_streaming_mode && + (prompt_config.streaming_chunk_steps > 0 || + prompt_config.streaming_interval_sec > 0.0f); + if (streaming_enabled) { + stream_cb = [&](const std::vector& chunk, bool is_final) { + ++streaming_chunks_received; + double chunk_sec = + static_cast(chunk.size()) / runner.output_sample_rate(); + ET_LOG( + Info, + "Stream chunk %d: %.2fs (%zu samples)%s", + streaming_chunks_received, + chunk_sec, + chunk.size(), + is_final ? 
" [final]" : ""); + }; + } + if (!session->synthesize( + prompts[prompt_idx], FLAGS_language, &waveform, &timing, + std::move(stream_cb))) { + ET_LOG( + Error, + "Synthesis failed for prompt %zu repeat %d.", + prompt_idx, + repeat_idx); + return 1; + } + + const size_t raw_sample_count = waveform.size(); + const double raw_audio_sec = + static_cast(raw_sample_count) / runner.output_sample_rate(); + double postprocess_ms = 0.0; + double trimmed_ms = 0.0; + const auto t_postprocess = std::chrono::steady_clock::now(); + if (FLAGS_trim_silence) { + trim_leading_silence( + &waveform, + runner.output_sample_rate(), + FLAGS_trim_threshold, + &trimmed_ms); + } + std::string output_path; + const bool should_write_single = + prompts.size() == 1 && FLAGS_prompts_path.empty() && + !FLAGS_output_wav.empty(); + const bool should_write_batch = + prompts.size() > 1 && !FLAGS_output_dir.empty(); + if (should_write_single) { + output_path = FLAGS_output_wav; + } else if (should_write_batch) { + output_path = FLAGS_output_dir + "/prompt_" + + std::to_string(prompt_idx) + "_repeat_" + + std::to_string(repeat_idx) + ".wav"; + } + if (!output_path.empty() && !runner.write_wav_file(output_path, waveform)) { + ET_LOG(Error, "Failed to write wav: %s", output_path.c_str()); + return 1; + } + postprocess_ms = + std::chrono::duration( + std::chrono::steady_clock::now() - t_postprocess) + .count(); + + const double trimmed_audio_sec = + static_cast(waveform.size()) / runner.output_sample_rate(); + const double generation_sec = timing.total_generation_ms / 1000.0; + const double raw_rtf = + generation_sec > 0.0 ? raw_audio_sec / generation_sec : 0.0; + const double trimmed_rtf = + generation_sec > 0.0 ? 
trimmed_audio_sec / generation_sec : 0.0; + ET_LOG( + Info, + "prompt=%zu repeat=%d tokens=%d steps=%d audio=%.2fs " + "trimmed_audio=%.2fs " + "prep=%.1fms prefill=%.1fms codegen=%.1fms first_audio=%.1fms " + "chunk_decode=%.1fms final_decode=%.1fms decode=%.1fms " + "generation=%.1fms post=%.1fms trimmed=%.1fms " + "rtf=%.2fx rtf_trimmed=%.2fx", + prompt_idx, + repeat_idx, + timing.prompt_token_count, + timing.generated_codec_steps, + raw_audio_sec, + trimmed_audio_sec, + timing.prompt_prep_ms, + timing.talker_prefill_ms, + timing.codegen_ms, + timing.first_audio_ms, + timing.chunk_decode_ms, + timing.final_decode_ms, + timing.decode_audio_ms, + timing.total_generation_ms, + postprocess_ms, + trimmed_ms, + raw_rtf, + trimmed_rtf); + if (!output_path.empty()) { + ET_LOG(Info, "Wrote wav: %s", output_path.c_str()); + } + } + } + } + + return 0; +} diff --git a/examples/models/qwen3-tts/mermaid_architecture_qwen3_tts_xnnpack.md b/examples/models/qwen3-tts/mermaid_architecture_qwen3_tts_xnnpack.md new file mode 100644 index 00000000000..ba93b6b2344 --- /dev/null +++ b/examples/models/qwen3-tts/mermaid_architecture_qwen3_tts_xnnpack.md @@ -0,0 +1,135 @@ +# Qwen3-TTS XNNPACK Pipeline Architecture + +Copy the code below and paste into: +- **VS Code**: Paste in any `.md` file, press `Ctrl+Shift+V` to preview +- **Mermaid Playground**: https://mermaid.live +- **GitHub**: Renders natively in `.md` files + +## End-to-End Pipeline + +```mermaid +flowchart TD + subgraph Export["Export (Python)"] + direction TB + PyModel["Qwen3-TTS PyTorch Model"] --> ExpScript["export_unified.py"] + ExpScript --> PTE["model.pte (2.3 GB, 8da4w)"] + end + + subgraph Runner["C++ Runner (main_unified.cpp)"] + direction TB + CLI["CLI Input\n--text / --prompts_path"] --> Session["SynthesisSession\nper-session RNG + config"] + Session --> Pipeline + end + + subgraph Pipeline["Synthesis Pipeline (XNNPACK, CPU-only)"] + direction TB + Tokenize["Tokenize Text\n(HuggingFace JSON tokenizer)"] --> 
EncText["encode_text\ntoken_ids → embeddings ∈ ℝ¹ˣˢˣ¹⁰²⁴"] + EncText --> Prefill["talker prefill\ntext embeddings → hidden + logits"] + Prefill --> Loop + + subgraph Loop["⚠️ BOTTLENECK: Autoregressive Codec Loop\n~130-150ms per step on CPU\n~20s total for 11s of audio"] + direction TB + SampleCode0["Sample code_0\n(top-k, rep. penalty, EOS check)"] + SampleCode0 --> Decision{{"use fused\nfast path?"}} + + Decision -- "Yes\n(contract v2, top_k=50)" --> Fused["cp_generate (fused)\n1 XNNPACK call → 15 sub-codes\ninverse-CDF top-k(50) sampling"] + Decision -- "No\n(fallback)" --> Legacy + + subgraph Legacy["Legacy Host Loop"] + direction TB + CP["code_predictor"] --> CPH["cp_head"] + CPH --> CE["codec_embed"] + CE --> CP + end + + Fused --> EmbSum["embedding sum → next talker input"] + Legacy --> EmbSum + EmbSum --> Talker["🔴 talker decode step\ndense matmul on CPU\nhidden + logits for next code_0"] + Talker --> SampleCode0 + end + + Loop -- "EOS or limit" --> Decode["decode_audio\naudio codes → waveform (24 kHz)"] + end + + Decode --> WAV["Output .wav file"] + + PTE -.-> Runner + Pipeline -.-> Timing["SynthesisTiming\nprep | prefill | codegen | decode"] +``` + +### Current Performance + +> | Metric | Legacy (host loop) | Fused cp_generate v2 | +> |--------|-------------------|----------------------| +> | Generation time | 23.9s | 19.6s | +> | Codegen | 21.1s | 17.0s | +> | Per-step cost | ~150ms | ~125ms | +> | Audio output | 11.5s | 10.8s | +> +> Fused `cp_generate` v2 collapsed 15 host round-trips into 1 graph call, achieving ~15-20% codegen speedup. 
+ +## Fused cp_generate v2 Detail + +```mermaid +flowchart LR + subgraph Inputs["Host → Graph"] + H["talker_hidden"] + E0["code_0_embed"] + T["temperature"] + U["sample_uniforms\n(15 uniform randoms)"] + end + + subgraph FusedGraph["cp_generate XNNPACK Graph"] + direction TB + Pre["CP prefill\n(hidden + code_0)"] --> G1 + + subgraph G1["Group 1..15 Loop (unrolled)"] + direction TB + Head["LM Head\nhidden → logits"] --> TopK["top-k(50)"] + TopK --> Softmax["softmax / temperature"] + Softmax --> CDF["cumsum → CDF"] + CDF --> Sample["inverse-CDF sample\nusing uniform random"] + Sample --> Embed["codec embed lookup"] + Embed --> CP_Step["CP transformer step"] + CP_Step --> Head + end + end + + subgraph Outputs["Graph → Host"] + Codes["sampled_subcodes\nint64 × 15"] + ESum["embed_sum\nfloat × 1024"] + end + + Inputs --> FusedGraph --> Outputs +``` + +## Warm Benchmark Session Flow + +```mermaid +sequenceDiagram + participant CLI as main_unified + participant Runner as Qwen3TTSUnifiedRunner + participant Session as SynthesisSession + participant PTE as model.pte (XNNPACK) + + CLI->>Runner: construct(model_path, tokenizer_path) + CLI->>Runner: warmup_all() + Runner->>PTE: load + execute all 7 methods once + + loop For each prompt × repeat + CLI->>Runner: create_synthesis_session(config) + Runner-->>Session: new session (fresh RNG) + CLI->>Session: synthesize(text, language) + Session->>PTE: encode_text → talker → cp_generate loop → decode_audio + Session-->>CLI: waveform + SynthesisTiming + CLI->>CLI: trim silence, write WAV, log timing + end +``` + +## Summary + +These diagrams show the Qwen3-TTS XNNPACK pipeline at three levels: + +1. **End-to-end pipeline**: text input → tokenization → 7-method model execution → WAV output, with the fused/legacy fast-path decision gate +2. **Fused cp_generate v2**: the internal XNNPACK graph that collapses 15 host round-trips into one call using inverse-CDF sampling +3. 
**Warm benchmark session**: how `SynthesisSession` keeps the runner warm across sequential prompts for honest latency measurement diff --git a/examples/models/qwen3-tts/metal-progress.md b/examples/models/qwen3-tts/metal-progress.md new file mode 100644 index 00000000000..f36029be8b6 --- /dev/null +++ b/examples/models/qwen3-tts/metal-progress.md @@ -0,0 +1,220 @@ +# Metal Streaming Progress + +Branch: `qwen3-tts-metal-streaming` + +## Goal + +Build the best practical Metal-backed Qwen3-TTS streaming path in the C++ +runner first, using a hybrid deployment where text generation runs on Metal and +the vocoder remains on XNNPACK until a true Metal decoder path is proven. + +## Transferred lessons + +### From XNNPACK + +- The fixed-shape `decode_audio_stream` export is functionally correct, but its + performance is much more sensitive to emit interval and warm state than the + overlap-window `decode_audio` fallback. +- The metric layer now separates `codegen_ms`, `first_audio_ms`, and raw + realtime factor correctly, so we should reuse the same benchmark discipline + for Metal. + +### From MLX + +- Stateful decoder ideas are important long term, but the first shipping win is + often in orchestration and cached context rather than a wholesale model + rewrite. +- Streaming policy must be validated on a warmed prompt set instead of inferred + from isolated decode-only timings. + +### From `voxtral_realtime` + +- Export metadata should act as a runtime contract. +- Backend choice should be represented explicitly instead of inferred loosely in + the runner. +- Streaming runners should prefer the backend-specific fast path by default, not + just whichever method happens to exist in the export. 
+ +## Current implementation slice + +- Added backend split metadata to unified exports: + - `generation_backend_code` + - `decoder_backend_code` + - `prefer_streaming_decoder_surface` +- For current Metal exports, the intended contract is: + - generation backend = `metal` + - decoder backend = `xnnpack` + - preferred streaming decoder surface = `overlap_window` +- Updated the C++ runner to honor that metadata by default while still exposing + a force flag for experiments: + - `--force_streaming_decoder_surface` +- Kept `cp_generate` on XNNPACK for Metal exports after confirming that the + fused method still needs `topk` and `cumsum` fallback kernels that the + current AOTI Metal backend does not provide. +- Reused the existing Llama MPS fix for bool causal masks by applying + `replace_causal_mask()` to the Metal-exported talker and code-predictor + transformers before export. + +## Why this is the right first step + +Today Qwen3-TTS is not blocked on "no Metal support at all." It already has a +mixed Metal/XNNPACK export path. The real practical issue is that the runner can +still auto-select the slower streaming decoder surface because it only checks +capability, not backend-aware preference. + +Fixing that gives us a better hybrid shipping path immediately and makes the +next benchmark meaningful. 
+ +## Verification completed + +Focused contract tests: + +```bash +conda run -n executorch python -m unittest \ + examples.models.qwen3-tts.tests.test_unified_runner_contract \ + examples.models.qwen3-tts.tests.test_unified_quality_contract \ + examples.models.qwen3-tts.tests.test_unified_metadata +``` + +Result: `PASS` + +Runner rebuild: + +```bash +cmake --build cmake-out/examples/models/qwen3-tts --target qwen3_tts_unified_runner +``` + +Result: `PASS` + +## Verification completed on Metal artifact + +Metal export with the new metadata: + +```bash +conda run -n executorch python examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir /tmp/qwen3_tts_exports_metal_streaming_maskfix \ + --backend metal \ + --dtype fp32 +``` + +Result: `PASS` + +First export attempt failed before lowering completed because `cp_generate` was +still being partitioned to Metal: + +```text +RuntimeError: Method cp_generate missing fallback kernels (2 total): + - at::_ops::cumsum::call + - at::_ops::topk::call +``` + +That failure is now the documented reason the branch keeps `cp_generate` on +XNNPACK for the hybrid Metal path. + +First runtime attempt with the saved Metal artifact exposed the next backend +compatibility issue: + +```text +Unsupported dtype: 11. Supported dtypes: 0 (uint8), 4 (int64), 6 (float32), 15 (bfloat16) +``` + +Root-cause investigation points to the bool causal mask buffer inherited from +the reused Llama attention stack. The current branch now mirrors the working +Llama MPS path and rewrites those masks to float additive masks before Metal +export. + +## Benchmark caveats discovered + +- Process-to-process warm state matters a lot on the hybrid Metal artifact even + after the runner's in-process warmup. 
The same overlap-window benchmark + improved from weighted `RTF=0.0658x` on the first process to `RTF=0.1152x` on + the next process. +- The fixed-surface path is very sensitive to emit cadence. With the same + artifact, `--streaming_interval 2.0` dropped the weighted prompt-set RTF to + `0.0528x`, while `--streaming_interval 4.0` raised it to `0.1010x` on a warm + process. +- Because of those two effects, a single run is not a trustworthy policy signal + for the Metal branch. We should compare warmed prompt-set runs with matched + intervals before changing the export default. + +## Benchmark commands + +Use the same artifact and the same emit interval for both policies: + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path /tmp/qwen3_tts_exports_metal_streaming_maskfix/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \ + --repeat 1 \ + --max_new_tokens 128 \ + --temperature 1.0 \ + --top_k 50 \ + --streaming_interval 4.0 +``` + +```bash +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path /tmp/qwen3_tts_exports_metal_streaming_maskfix/model.pte \ + --tokenizer_path examples/models/qwen3-tts/qwen3-tts-12Hz-0.6B-Base/tokenizer.json \ + --prompts_path examples/models/qwen3-tts/benchmark_prompts.txt \ + --repeat 1 \ + --max_new_tokens 128 \ + --temperature 1.0 \ + --top_k 50 \ + --streaming_interval 4.0 \ + --force_streaming_decoder_surface +``` + +If you want numbers that are comparable to the warm-run table below, do one +throwaway process first and judge the second process, not the very first launch +after export. 
+ +## Benchmark results + +| Run | Policy | Interval | Weighted RTF | Avg first audio | Notes | +| --- | --- | --- | --- | --- | --- | +| First matched pair | overlap-window | 4.0s | `0.0658x` | `49.09s` | Fresh process after export | +| First matched pair | fixed-surface | 4.0s | `0.0699x` | `45.04s` | Fresh process after export | +| Sensitivity probe | fixed-surface | 2.0s | `0.0528x` | `30.22s` | Earlier unfair comparison; included here only to show interval sensitivity | +| Warm rerun | overlap-window | 4.0s | `0.1152x` | `28.97s` | Best current apples-to-apples result | +| Warm rerun | fixed-surface | 4.0s | `0.1010x` | `30.70s` | Still useful for short utterances, but worse overall | + +Warm-run prompt details: + +- Overlap-window: + - prompt 0: `audio=1.68s`, `generation=20.30s`, `rtf=0.08x` + - prompt 1: `audio=6.16s`, `generation=52.40s`, `rtf=0.12x` + - prompt 2: `audio=8.08s`, `generation=65.49s`, `rtf=0.12x` +- Fixed-surface: + - prompt 0: `audio=1.68s`, `generation=18.50s`, `rtf=0.09x` + - prompt 1: `audio=6.16s`, `generation=58.43s`, `rtf=0.11x` + - prompt 2: `audio=8.08s`, `generation=80.74s`, `rtf=0.10x` + +## Current decision + +- Keep `prefer_streaming_decoder_surface = 0` for the current hybrid Metal + export. On the warmed apples-to-apples `4.0s` benchmark, overlap-window beats + fixed-surface on weighted prompt-set throughput (`0.1152x` vs `0.1010x`) and + average first-audio latency (`28.97s` vs `30.70s`). +- Keep `--force_streaming_decoder_surface` as an experiment knob. It can still + help on very short utterances, and it is the right path to compare when we + revisit a true Metal decoder surface later. +- Treat `--streaming_interval 4.0` as the current benchmark baseline for this + branch. `2.0s` is too punitive to the fixed-surface path and obscures the real + policy decision. 
+- Any future claims about Metal streaming speed should use a warmed prompt-set + benchmark and call out whether the result comes from the first or second + process after export. + +## Next decision gate + +The runner/export contract is now stable enough to move to second-stage tuning: + +- explain the large cross-process warm-state delta on the Metal artifact +- benchmark interval and chunk-size tuning around the overlap-window default +- evaluate whether a deeper hybrid pipeline overlap change is worth the added + complexity +- defer true Metal vocoder work until the hybrid baseline is clearly established diff --git a/examples/models/qwen3-tts/metal_benchmark.md b/examples/models/qwen3-tts/metal_benchmark.md new file mode 100644 index 00000000000..df141594f1c --- /dev/null +++ b/examples/models/qwen3-tts/metal_benchmark.md @@ -0,0 +1,61 @@ +# Qwen3-TTS Metal Backend Benchmark + +## Results + +### Metal/AOTI Export ✅ WORKING + +Export command: +```bash +python3 export_unified.py --backend metal --dtype fp32 \ + --converted-dir qwen3_tts_artifacts \ + --talker-dir qwen3_tts_artifacts/talker_converted \ + --output-dir /tmp/qwen3_tts_metal_v2 +``` + +| Metric | Value | +|--------|-------| +| Export time | ~8 min (AOTInductor compile) | +| Model size | 4,636 MB (fp32, no quantization) | +| Methods | 7 (encode_text, talker, code_predictor, codec_embed, cp_head, cp_generate, decode_audio) | +| Metal methods | 5 (everything except codec_embed + decode_audio) | +| decode_audio backend | XNNPACK (Metal lacks cumsum fallback) | + +### Decode Performance (codes → audio) + +| Backend | 26 codes | Realtime | Notes | +|---------|----------|----------|-------| +| Metal + XNNPACK decoder | **728 ms** | **2.86x RT** | Mixed: Metal talker, XNNPACK decoder | +| XNNPACK only | 1,056 ms | 2.42x RT | Previous best | +| Portable CPU (no backend) | 72,761 ms | 0.03x RT | When decoder has no XNNPACK | + +### Audio Quality +- Metal output is correct: 1.59s speech from "Hello from 
ExecuTorch." +- Automatic silence trimming works + +### Known Issues +1. `decode_audio` cannot use Metal (missing `cumsum` fallback kernel) +2. `fpa4w` quantization requires `TORCHAO_BUILD_EXPERIMENTAL_MPS=1` +3. libomp symlink needed: `sudo ln -sf /opt/homebrew/Cellar/libomp/*/lib/libomp.dylib /opt/llvm-openmp/lib/libomp.dylib` + +### How to Run + +```bash +# 1. Export (one time, ~8 min) +python3 examples/models/qwen3-tts/export_unified.py \ + --converted-dir examples/models/qwen3-tts/qwen3_tts_artifacts \ + --talker-dir examples/models/qwen3-tts/qwen3_tts_artifacts/talker_converted \ + --output-dir /tmp/qwen3_tts_metal \ + --backend metal --dtype fp32 + +# 2. Generate codes (Python talker) +python3 examples/models/qwen3-tts/generate_codes.py \ + --model-id-or-path Qwen/Qwen3-TTS-12Hz-0.6B-Base \ + --text "Your text here." \ + --output-codes /tmp/codes.bin + +# 3. Decode (C++ Metal runner) +cmake-out/examples/models/qwen3-tts/qwen3_tts_unified_runner \ + --model_path /tmp/qwen3_tts_metal/model.pte \ + --codes_path /tmp/codes.bin \ + --output_wav output.wav +``` diff --git a/examples/models/qwen3-tts/mlx-progress.md b/examples/models/qwen3-tts/mlx-progress.md new file mode 100644 index 00000000000..c126eb010b2 --- /dev/null +++ b/examples/models/qwen3-tts/mlx-progress.md @@ -0,0 +1,69 @@ +# MLX Progress + +Branch: `qwen3-tts-mlx-realtime` + +## Goal + +Build an MLX backend path in-tree for Qwen3-TTS that is measurably faster than +the plain `mlx-audio` reference implementation on the same warmed prompt set. 
+
+## Benchmark protocol
+
+- Model: `mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16`
+- Prompt set: `examples/models/qwen3-tts/benchmark_prompts.txt`
+- Reference voice:
+  - audio: `poem.wav`
+  - text: `This is what my voice sounds like.`
+- Metric: audio seconds / generation seconds (`> 1` means faster than realtime)
+- Mode: warmed sequential generation after model load
+- Seed: `123 + prompt_idx`
+- Streaming: enabled (`--stream`, tuned `streaming_interval=4.0`; early runs used `2.0`)
+
+## Current status
+
+| Path | Avg throughput | Total throughput | Avg first audio | Result |
+|------|----------------|------------------|-----------------|--------|
+| Baseline `mlx-audio` generate | `0.540x` | `0.546x` | `5.50s` | reference |
+| Cached MLX session backend | `0.556x` | `0.559x` | `5.40s` | `1.030x` faster |
+
+## What changed
+
+Added `mlx_backend.py` with a persistent ICL session that:
+
+- loads the local `mlx-audio` checkout once
+- caches reference-audio speech tokens (`ref_codes`)
+- caches projected reference text embeddings
+- caches the ICL codec/text prefix overlays used before generation
+- reuses upstream `_generate_icl()` by overriding only the expensive
+  `_prepare_icl_generation_inputs()` step
+
+This keeps generation semantics close to the upstream MLX path while removing
+per-prompt re-encoding of the same reference voice context. The current tuned
+streaming interval is `4.0s`, which gave the best throughput on the warmed
+three-prompt benchmark. 
+ +## Latest verification + +```bash +python examples/models/qwen3-tts/benchmark_mlx.py --mode both --stream +``` + +Verified output with the tuned default (`streaming_interval=4.0`): + +- Baseline `mlx-audio`: `avg=0.540x`, `total=0.546x` +- Cached session backend: `avg=0.556x`, `total=0.559x` +- Speedup: `1.030x` + +Focused retest of the same seeded prompt set also showed a better best-observed +point at the same interval (`avg=0.565x`, `total=0.569x`), so the current +speedup range is roughly `1.03x` to `1.09x` depending on warm-state noise. + +## Next experiments + +- Measure non-streaming mode with the same seeded prompt set in case the cached + prefix work matters more when decoder chunking is removed. +- Add an apples-to-apples comparison mode against the current XNNPACK benchmark + output so MLX and XNNPACK use the same reporting format. +- Investigate whether the tokenizer regex fix path changes prompt length, + generation length, or throughput in a way that is worth folding into the + benchmark harness. 
diff --git a/examples/models/qwen3-tts/mlx_backend.py b/examples/models/qwen3-tts/mlx_backend.py new file mode 100644 index 00000000000..2133b133741 --- /dev/null +++ b/examples/models/qwen3-tts/mlx_backend.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +"""MLX helpers for Qwen3-TTS benchmarking and session reuse.""" + +from __future__ import annotations + +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from types import SimpleNamespace +from typing import Optional + + +MLX_AUDIO_REPO_ENV = "MLX_AUDIO_REPO" + + +def _purge_mlx_audio_modules() -> None: + for name in list(sys.modules): + if name == "mlx_audio" or name.startswith("mlx_audio."): + sys.modules.pop(name, None) + + +def _resolve_mlx_audio_repo(explicit_repo: Optional[Path]) -> Optional[Path]: + candidates = [] + if explicit_repo is not None: + candidates.append(Path(explicit_repo).expanduser().resolve()) + env_repo = os.environ.get(MLX_AUDIO_REPO_ENV) + if env_repo: + candidates.append(Path(env_repo).expanduser().resolve()) + + repo_root = Path(__file__).resolve().parents[3] + candidates.append((repo_root.parent / "mlx-audio").resolve()) + + for candidate in candidates: + if (candidate / "mlx_audio").exists(): + return candidate + return None + + +def _load_mlx_symbols(explicit_repo: Optional[Path]): + repo = _resolve_mlx_audio_repo(explicit_repo) + if repo is not None: + repo_str = str(repo) + if not sys.path or sys.path[0] != repo_str: + sys.path.insert(0, repo_str) + _purge_mlx_audio_modules() + + import mlx.core as mx + from mlx_audio.tts.utils import load_model + from mlx_audio.utils import load_audio + + return repo, mx, load_model, load_audio + + +@dataclass +class PromptBenchmark: + elapsed_s: float + audio_s: float + throughput_x: float + first_audio_s: float + chunk_count: int + sample_count: int + + +def _collect_prompt_benchmark(mx, results, elapsed_s: float) -> PromptBenchmark: + if not results: + raise RuntimeError("MLX generate returned no 
results.") + + audio_chunks = [result.audio for result in results] + if len(audio_chunks) == 1: + waveform = audio_chunks[0] + else: + waveform = mx.concatenate(audio_chunks, axis=0) + mx.eval(waveform) + + sample_rate = getattr(results[-1], "sample_rate", 24000) + sample_count = int(waveform.shape[-1]) + audio_s = sample_count / sample_rate if sample_rate > 0 else 0.0 + throughput_x = audio_s / elapsed_s if elapsed_s > 0.0 else 0.0 + if len(results) == 1: + first_audio_s = elapsed_s + else: + first_audio_s = getattr(results[0], "processing_time_seconds", elapsed_s) + + return PromptBenchmark( + elapsed_s=elapsed_s, + audio_s=audio_s, + throughput_x=throughput_x, + first_audio_s=first_audio_s, + chunk_count=len(results), + sample_count=sample_count, + ) + + +class Qwen3TTSMlxBackend: + """Loads the local mlx-audio Qwen3-TTS model once and reuses it.""" + + def __init__( + self, + model_path: str = "mlx-community/Qwen3-TTS-12Hz-0.6B-Base-bf16", + mlx_audio_repo: Optional[Path] = None, + ) -> None: + repo, mx, load_model, load_audio = _load_mlx_symbols(mlx_audio_repo) + self.repo_path = repo + self.mx = mx + self._load_audio = load_audio + self.model_path = model_path + self.model = load_model(model_path) + + @property + def sample_rate(self) -> int: + return int(self.model.sample_rate) + + def warmup( + self, + *, + text: str, + ref_audio, + ref_text: str, + stream: bool, + streaming_interval: float = 4.0, + seed: Optional[int] = None, + ) -> PromptBenchmark: + return self.benchmark_baseline( + text=text, + ref_audio=ref_audio, + ref_text=ref_text, + stream=stream, + streaming_interval=streaming_interval, + seed=seed, + ) + + def benchmark_baseline( + self, + *, + text: str, + ref_audio, + ref_text: str, + stream: bool, + streaming_interval: float = 4.0, + seed: Optional[int] = None, + temperature: float = 0.9, + top_k: int = 50, + top_p: float = 1.0, + repetition_penalty: float = 1.5, + max_tokens: int = 4096, + ) -> PromptBenchmark: + if isinstance(ref_audio, 
Path): + ref_audio = str(ref_audio) + if seed is not None: + self.mx.random.seed(seed) + started_at = time.perf_counter() + results = list( + self.model.generate( + text=text, + ref_audio=ref_audio, + ref_text=ref_text, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + stream=stream, + streaming_interval=streaming_interval, + ) + ) + elapsed_s = time.perf_counter() - started_at + return _collect_prompt_benchmark(self.mx, results, elapsed_s) + + def create_icl_session( + self, + *, + ref_audio, + ref_text: str, + language: str = "auto", + ) -> "Qwen3TTSMlxIclSession": + return Qwen3TTSMlxIclSession( + backend=self, + ref_audio=ref_audio, + ref_text=ref_text, + language=language, + ) + + +class Qwen3TTSMlxIclSession: + """Caches the reference ICL conditioning across prompts.""" + + def __init__( + self, + *, + backend: Qwen3TTSMlxBackend, + ref_audio, + ref_text: str, + language: str = "auto", + ) -> None: + self.backend = backend + self.mx = backend.mx + self.model = backend.model + self.ref_text = ref_text + self.language = language + if isinstance(ref_audio, Path): + ref_audio = str(ref_audio) + if isinstance(ref_audio, str): + self.ref_audio = backend._load_audio(ref_audio, sample_rate=self.model.sample_rate) + else: + self.ref_audio = ref_audio + self._cached = self._build_cached_icl_state() + + def _build_cached_icl_state(self): + if self.model.tokenizer is None: + raise ValueError("Tokenizer not loaded on the MLX model.") + if self.model.speech_tokenizer is None: + raise ValueError("Speech tokenizer not loaded on the MLX model.") + + config = self.model.config.talker_config + ref_audio = self.ref_audio + audio_for_spk = ref_audio + if ref_audio.ndim == 1: + ref_audio = ref_audio[None, None, :] + elif ref_audio.ndim == 2: + ref_audio = ref_audio[None, :] + + ref_codes = self.model.speech_tokenizer.encode(ref_audio) + ref_chat = f"<|im_start|>assistant\n{self.ref_text}<|im_end|>\n" + 
ref_ids = self.mx.array(self.model.tokenizer.encode(ref_chat))[None, :] + ref_text_ids = ref_ids[:, 3:-2] + + tts_tokens = self.mx.array( + [[ + self.model.config.tts_bos_token_id, + self.model.config.tts_eos_token_id, + self.model.config.tts_pad_token_id, + ]] + ) + tts_embeds = self.model.talker.text_projection( + self.model.talker.get_text_embeddings()(tts_tokens) + ) + tts_bos_embed = tts_embeds[:, 0:1, :] + tts_eos_embed = tts_embeds[:, 1:2, :] + tts_pad_embed = tts_embeds[:, 2:3, :] + + ref_text_embed = self.model.talker.text_projection( + self.model.talker.get_text_embeddings()(ref_text_ids) + ) + role_ids = self.mx.array(self.model.tokenizer.encode("<|im_start|>assistant\n"))[ + None, : + ] + role_embed = self.model.talker.text_projection( + self.model.talker.get_text_embeddings()(role_ids) + ) + + first_cb_codes = ref_codes[:, 0, :] + ref_codec_embed = self.model.talker.get_input_embeddings()(first_cb_codes) + for i in range(config.num_code_groups - 1): + cb_codes = ref_codes[:, i + 1, :] + ref_codec_embed = ( + ref_codec_embed + + self.model.talker.code_predictor.codec_embedding[i](cb_codes) + ) + + codec_bos_embed = self.model.talker.get_input_embeddings()( + self.mx.array([[config.codec_bos_id]]) + ) + codec_embed_icl = self.mx.concatenate( + [codec_bos_embed, ref_codec_embed], + axis=1, + ) + codec_pad_embed = self.model.talker.get_input_embeddings()( + self.mx.array([[config.codec_pad_id]]) + ) + codec_with_text_pad = codec_embed_icl + self.mx.broadcast_to( + tts_pad_embed, (1, codec_embed_icl.shape[1], tts_pad_embed.shape[-1]) + ) + ref_text_with_codec_pad = ref_text_embed + self.mx.broadcast_to( + codec_pad_embed, (1, ref_text_embed.shape[1], codec_pad_embed.shape[-1]) + ) + eos_with_codec_pad = tts_eos_embed + codec_pad_embed + + language_id = None + if self.language.lower() != "auto" and config.codec_language_id: + language_id = config.codec_language_id.get(self.language.lower()) + + speaker_embed = None + if self.model.speaker_encoder is not 
None: + speaker_embed = self.model.extract_speaker_embedding(audio_for_spk) + + if language_id is None: + codec_prefill = [ + config.codec_nothink_id, + config.codec_think_bos_id, + config.codec_think_eos_id, + ] + else: + codec_prefill = [ + config.codec_think_id, + config.codec_think_bos_id, + language_id, + config.codec_think_eos_id, + ] + + codec_prefix_embed = self.model.talker.get_input_embeddings()( + self.mx.array([codec_prefill]) + ) + codec_prefix_suffix = self.model.talker.get_input_embeddings()( + self.mx.array([[config.codec_pad_id, config.codec_bos_id]]) + ) + if speaker_embed is not None: + codec_prefix_embed = self.mx.concatenate( + [ + codec_prefix_embed, + speaker_embed.reshape(1, 1, -1), + codec_prefix_suffix, + ], + axis=1, + ) + else: + codec_prefix_embed = self.mx.concatenate( + [codec_prefix_embed, codec_prefix_suffix], + axis=1, + ) + + pad_count = codec_prefix_embed.shape[1] - 2 + pad_embeds = self.mx.broadcast_to( + tts_pad_embed, + (1, pad_count, tts_pad_embed.shape[-1]), + ) + combined_prefix = self.mx.concatenate([pad_embeds, tts_bos_embed], axis=1) + combined_prefix = combined_prefix + codec_prefix_embed[:, :-1, :] + + self.mx.eval( + ref_codes, + ref_text_embed, + role_embed, + tts_eos_embed, + tts_pad_embed, + codec_pad_embed, + codec_with_text_pad, + ref_text_with_codec_pad, + eos_with_codec_pad, + combined_prefix, + ) + + return SimpleNamespace( + ref_codes=ref_codes, + ref_text_embed=ref_text_embed, + role_embed=role_embed, + tts_eos_embed=tts_eos_embed, + tts_pad_embed=tts_pad_embed, + codec_pad_embed=codec_pad_embed, + codec_with_text_pad=codec_with_text_pad, + ref_text_with_codec_pad=ref_text_with_codec_pad, + eos_with_codec_pad=eos_with_codec_pad, + combined_prefix=combined_prefix, + ) + + def _prepare_cached_icl_generation_inputs( + self, + text: str, + ref_audio, + ref_text: str, + language: str = "auto", + ): + del ref_audio, ref_text + if language.lower() != self.language.lower(): + raise ValueError( + "Cached ICL session 
language mismatch: " + f"expected {self.language!r}, got {language!r}." + ) + + target_chat = ( + f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" + ) + target_ids = self.mx.array(self.model.tokenizer.encode(target_chat))[None, :] + text_ids = target_ids[:, 3:-5] + + target_text_embed = self.model.talker.text_projection( + self.model.talker.get_text_embeddings()(text_ids) + ) + target_text_with_codec_pad = target_text_embed + self.mx.broadcast_to( + self._cached.codec_pad_embed, + (1, target_text_embed.shape[1], self._cached.codec_pad_embed.shape[-1]), + ) + text_with_codec_pad = self.mx.concatenate( + [ + self._cached.ref_text_with_codec_pad, + target_text_with_codec_pad, + self._cached.eos_with_codec_pad, + ], + axis=1, + ) + icl_input_embed = self.mx.concatenate( + [text_with_codec_pad, self._cached.codec_with_text_pad], + axis=1, + ) + input_embeds = self.mx.concatenate( + [self._cached.role_embed, self._cached.combined_prefix, icl_input_embed], + axis=1, + ) + return ( + input_embeds, + self._cached.tts_pad_embed, + self._cached.tts_pad_embed, + self._cached.ref_codes, + ) + + def generate( + self, + *, + text: str, + stream: bool, + streaming_interval: float = 4.0, + streaming_context_size: int = 25, + seed: Optional[int] = None, + temperature: float = 0.9, + top_k: int = 50, + top_p: float = 1.0, + repetition_penalty: float = 1.5, + max_tokens: int = 4096, + verbose: bool = False, + ): + if seed is not None: + self.mx.random.seed(seed) + + original_prepare = self.model._prepare_icl_generation_inputs + + def iterator(): + self.model._prepare_icl_generation_inputs = ( + self._prepare_cached_icl_generation_inputs + ) + try: + yield from self.model._generate_icl( + text=text, + ref_audio=self.ref_audio, + ref_text=self.ref_text, + language=self.language, + temperature=temperature, + max_tokens=max_tokens, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + verbose=verbose, + stream=stream, + 
streaming_interval=streaming_interval, + streaming_context_size=streaming_context_size, + ) + finally: + self.model._prepare_icl_generation_inputs = original_prepare + + return iterator() + + def benchmark( + self, + *, + text: str, + stream: bool, + streaming_interval: float = 4.0, + streaming_context_size: int = 25, + seed: Optional[int] = None, + temperature: float = 0.9, + top_k: int = 50, + top_p: float = 1.0, + repetition_penalty: float = 1.5, + max_tokens: int = 4096, + ) -> PromptBenchmark: + started_at = time.perf_counter() + results = list( + self.generate( + text=text, + stream=stream, + streaming_interval=streaming_interval, + streaming_context_size=streaming_context_size, + seed=seed, + temperature=temperature, + top_k=top_k, + top_p=top_p, + repetition_penalty=repetition_penalty, + max_tokens=max_tokens, + ) + ) + elapsed_s = time.perf_counter() - started_at + return _collect_prompt_benchmark(self.mx, results, elapsed_s) diff --git a/examples/models/qwen3-tts/model.py b/examples/models/qwen3-tts/model.py new file mode 100644 index 00000000000..ffb81471f06 --- /dev/null +++ b/examples/models/qwen3-tts/model.py @@ -0,0 +1,205 @@ +import json +import struct +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Tuple + +import torch +import torch.nn as nn +from transformers import modeling_rope_utils as hf_rope_utils +from transformers.utils import generic as hf_generic + +if not hasattr(hf_generic, "check_model_inputs"): + def _identity_check_model_inputs(*args, **kwargs): + def decorator(fn): + return fn + + return decorator + + hf_generic.check_model_inputs = _identity_check_model_inputs + +if "default" not in hf_rope_utils.ROPE_INIT_FUNCTIONS: + def _compute_default_rope_parameters(config, device=None, seq_len=None, layer_type=None): + if hasattr(config, "standardize_rope_params"): + config.standardize_rope_params() + rope_parameters = getattr(config, "rope_parameters", None) + if rope_parameters is None: + base 
= getattr(config, "rope_theta", 10000.0) + partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0) + else: + rope_parameters = ( + rope_parameters[layer_type] if layer_type is not None else rope_parameters + ) + base = rope_parameters.get("rope_theta", getattr(config, "rope_theta", 10000.0)) + partial_rotary_factor = rope_parameters.get( + "partial_rotary_factor", + getattr(config, "partial_rotary_factor", 1.0), + ) + head_dim = getattr(config, "head_dim", None) or ( + config.hidden_size // config.num_attention_heads + ) + dim = int(head_dim * partial_rotary_factor) + inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, dim, 2, dtype=torch.int64).to( + device=device, dtype=torch.float + ) + / dim + ) + ) + return inv_freq, 1.0 + + hf_rope_utils.ROPE_INIT_FUNCTIONS["default"] = _compute_default_rope_parameters + +from qwen_tts.core.tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import ( + Qwen3TTSTokenizerV2DecoderConfig, +) +from qwen_tts.core.tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import ( + Qwen3TTSTokenizerV2Decoder, +) + + +@dataclass +class DecoderExportMetadata: + model_id_or_path: str + tokenizer_type: str + tts_model_type: str + decoder_checkpoint: str + output_sample_rate: int + decode_upsample_rate: int + num_quantizers: int + codebook_size: int + decoder_config: Dict + + def to_constant_methods(self) -> Dict[str, int]: + return { + "output_sample_rate": int(self.output_sample_rate), + "decode_upsample_rate": int(self.decode_upsample_rate), + "num_quantizers": int(self.num_quantizers), + "codebook_size": int(self.codebook_size), + } + + @classmethod + def from_json(cls, path: Path) -> "DecoderExportMetadata": + with path.open("r", encoding="utf-8") as f: + raw = json.load(f) + return cls( + model_id_or_path=raw["model_id_or_path"], + tokenizer_type=raw["tokenizer_type"], + tts_model_type=raw["tts_model_type"], + decoder_checkpoint=raw["decoder_checkpoint"], + output_sample_rate=int(raw["output_sample_rate"]), + 
decode_upsample_rate=int(raw["decode_upsample_rate"]), + num_quantizers=int(raw["num_quantizers"]), + codebook_size=int(raw["codebook_size"]), + decoder_config=raw["decoder_config"], + ) + + def to_json(self, path: Path) -> None: + payload = { + "model_id_or_path": self.model_id_or_path, + "tokenizer_type": self.tokenizer_type, + "tts_model_type": self.tts_model_type, + "decoder_checkpoint": self.decoder_checkpoint, + "output_sample_rate": self.output_sample_rate, + "decode_upsample_rate": self.decode_upsample_rate, + "num_quantizers": self.num_quantizers, + "codebook_size": self.codebook_size, + "decoder_config": self.decoder_config, + } + with path.open("w", encoding="utf-8") as f: + json.dump(payload, f, indent=2, sort_keys=True) + + +class Qwen3TTSSpeechDecoderExport(nn.Module): + """ + Export wrapper for speech tokenizer decode path (audio code -> waveform). + """ + + def __init__(self, decoder: Qwen3TTSTokenizerV2Decoder, decode_upsample_rate: int): + super().__init__() + self.decoder = decoder + self.decode_upsample_rate = int(decode_upsample_rate) + + def forward(self, audio_codes: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + if audio_codes.dim() != 3: + raise ValueError( + f"audio_codes must be rank-3 [B, T, Q], got {tuple(audio_codes.shape)}" + ) + audio_lengths = (audio_codes[..., 0] > -1).sum(1) * self.decode_upsample_rate + # Decoder expects non-negative code ids. 
+ clamped_codes = torch.clamp(audio_codes, min=0) + wav = self.decoder(clamped_codes.transpose(1, 2)).squeeze(1) + return wav, audio_lengths + + +def load_decoder_from_metadata( + metadata: DecoderExportMetadata, checkpoint_path: Path, dtype: torch.dtype +) -> Qwen3TTSTokenizerV2Decoder: + decoder_cfg = Qwen3TTSTokenizerV2DecoderConfig(**metadata.decoder_config) + decoder = Qwen3TTSTokenizerV2Decoder(decoder_cfg) + state = torch.load(checkpoint_path, map_location="cpu", weights_only=True) + decoder.load_state_dict(state, strict=True) + decoder.eval() + decoder.to(dtype=dtype) + return decoder + + +def make_decode_export_module( + metadata: DecoderExportMetadata, checkpoint_path: Path, dtype: torch.dtype +) -> Qwen3TTSSpeechDecoderExport: + decoder = load_decoder_from_metadata(metadata, checkpoint_path, dtype=dtype) + module = Qwen3TTSSpeechDecoderExport( + decoder=decoder, decode_upsample_rate=metadata.decode_upsample_rate + ) + module.eval() + module.to(dtype=dtype) + return module + + +def make_sample_codes( + codebook_size: int, + num_quantizers: int, + code_len: int, + device: str = "cpu", +) -> torch.Tensor: + return torch.randint( + low=0, + high=codebook_size, + size=(1, code_len, num_quantizers), + dtype=torch.long, + device=device, + ) + + +def write_codes_binary(path: Path, codes: torch.Tensor) -> None: + """ + Write codec ids as a simple binary format: + - int32 codes_len + - int32 num_quantizers + - int32[codes_len * num_quantizers] flattened row-major + """ + if codes.dim() != 2: + raise ValueError( + f"codes tensor must be rank-2 [T, Q], got shape={tuple(codes.shape)}" + ) + codes_i32 = codes.to(dtype=torch.int32).contiguous().cpu() + t_len, num_q = int(codes_i32.shape[0]), int(codes_i32.shape[1]) + flat_values: List[int] = [int(v) for v in codes_i32.view(-1).tolist()] + with path.open("wb") as f: + f.write(struct.pack(" torch.Tensor: + with path.open("rb") as f: + header = f.read(8) + if len(header) != 8: + raise ValueError(f"Invalid codes file 
header: {path}") + t_len, num_q = struct.unpack("