Skip to content

OOM with kokoro models #20178

@maifeeulasad

Description

@maifeeulasad

🐛 Describe the bug

So here is my script:

import argparse
import json
import math
import os
import types

import torch
import torch.nn.functional as F
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, to_edge
from kokoro.model import KModel
from torch import nn

# ---------------------------------------------------------------------------
# Static export wrapper
# ---------------------------------------------------------------------------


class StaticKModelForExport(nn.Module):
    """
    Fixed-shape wrapper around the Kokoro inference pipeline for torch.export.

    Two static dimensions must be fixed at export time:
      - seq_len:          number of input tokens, taken from ``input_ids.shape[1]``
                          of the example input  (e.g. 128)
      - max_audio_frames: number of frames in the alignment target  (e.g. 512)

    The predicted per-token durations are fitted so that the returned durations
    always sum to exactly ``max_audio_frames``: each token is first clamped to
    ``[1, max_audio_frames - seq_len + 1]``, the cumulative end positions are
    capped at ``max_audio_frames``, and the last token is stretched or trimmed
    to end exactly on the final frame.  Tokens that fall entirely beyond the
    frame budget therefore report a duration of 0, which is what the alignment
    matrix gives them anyway.

    Callers are expected to pass ``max_audio_frames >= seq_len``; this class
    does not validate it.
    """

    def __init__(self, kmodel: KModel, max_audio_frames: int):
        super().__init__()
        self.kmodel = kmodel
        self.max_audio_frames = max_audio_frames

    def forward(
        self,
        input_ids: torch.LongTensor,  # (1, seq_len)
        ref_s: torch.FloatTensor,  # (1, 256)
        speed: torch.FloatTensor,  # scalar tensor
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Run the pipeline on one fixed-length token sequence.

        Args:
            input_ids: token ids; the code assumes a batch of one
                (``squeeze(0)`` is applied to the per-token durations).
            ref_s: reference style tensor; columns ``[:128]`` are passed to the
                decoder and columns ``[128:]`` to the duration/prosody predictor.
            speed: divisor applied to the predicted durations.

        Returns:
            ``(audio, pred_dur_final)`` where ``audio`` is the squeezed decoder
            output and ``pred_dur_final`` is a ``(seq_len,)`` long tensor of
            non-negative frame counts summing to ``max_audio_frames``.
        """
        device = input_ids.device
        batch_size = input_ids.shape[0]  # always 1 for export
        seq_len = input_ids.shape[1]

        # --- masks (static because seq_len is fixed) ---
        # Every position is treated as valid: input_lengths == seq_len, so
        # text_mask (True = padding) is all False.
        input_lengths = torch.full(
            (batch_size,), seq_len, device=device, dtype=torch.long
        )
        text_positions = torch.arange(seq_len, device=device).unsqueeze(
            0
        )  # (1, seq_len)
        text_mask = torch.gt(
            text_positions + 1, input_lengths.unsqueeze(1)
        )  # (1, seq_len)

        # --- BERT duration embedding ---
        bert_dur = self.kmodel.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.kmodel.bert_encoder(bert_dur).transpose(
            -1, -2
        )  # expected (1, hidden, seq_len)

        s = ref_s[:, 128:]  # style vector for the predictor

        # --- DurationEncoder (predictor.text_encoder) ---
        # Relies on the patched forward (no pack_padded_sequence).
        d = self.kmodel.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        # --- Duration LSTM ---
        # Explicit zero initial state, sized from the LSTM's own configuration.
        pred_lstm = self.kmodel.predictor.lstm
        num_dirs = 2 if pred_lstm.bidirectional else 1
        h0 = d.new_zeros(
            pred_lstm.num_layers * num_dirs, batch_size, pred_lstm.hidden_size
        )
        c0 = d.new_zeros(
            pred_lstm.num_layers * num_dirs, batch_size, pred_lstm.hidden_size
        )
        x, _ = pred_lstm(d, (h0, c0))

        # --- Duration projection ---
        duration = self.kmodel.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed  # (1, seq_len)
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze(0)  # (seq_len,)

        # --- Fit durations so the total is exactly max_audio_frames ---
        # Per-token cap: leaves room for at least one frame per other token.
        pred_dur_clamped = torch.clamp(
            pred_dur,
            min=1,
            max=self.max_audio_frames - seq_len + 1,
        )
        # End frame (exclusive) of every token, capped at the frame budget so
        # that nothing extends past the fixed-size alignment matrix.
        ends = torch.cumsum(pred_dur_clamped, dim=0).clamp(max=self.max_audio_frames)
        # The last token always ends on the final frame: this pads when the
        # prediction is short and trims when it is long.
        ends = torch.cat(
            [ends[:-1], ends.new_full((1,), self.max_audio_frames)], dim=0
        )  # (seq_len,)
        starts = torch.cat(
            [torch.zeros(1, device=device, dtype=torch.long), ends[:-1]], dim=0
        )
        # Durations consistent with the alignment below; they sum to
        # max_audio_frames by construction (ends is non-decreasing).
        pred_dur_final = ends - starts

        # --- Build alignment target: (1, seq_len, max_audio_frames) ---
        # pred_aln_trg[token_idx, frame_idx] = 1 iff frame belongs to that token
        frame_indices = torch.arange(
            self.max_audio_frames, device=device
        )  # (max_audio_frames,)
        # (seq_len, max_audio_frames) boolean
        pred_aln_trg = (frame_indices.unsqueeze(0) >= starts.unsqueeze(1)) & (
            frame_indices.unsqueeze(0) < ends.unsqueeze(1)
        )
        pred_aln_trg = pred_aln_trg.float().unsqueeze(
            0
        )  # (1, seq_len, max_audio_frames)

        # --- en: expand duration-encoded features along time ---
        en = d.transpose(-1, -2) @ pred_aln_trg  # (1, channels, max_audio_frames)

        # --- Shared LSTM for F0 / N prediction ---
        shared = self.kmodel.predictor.shared
        shared_num_dirs = 2 if shared.bidirectional else 1
        sh0 = en.new_zeros(
            shared.num_layers * shared_num_dirs, batch_size, shared.hidden_size
        )
        sc0 = en.new_zeros(
            shared.num_layers * shared_num_dirs, batch_size, shared.hidden_size
        )
        shared_x, _ = shared(en.transpose(-1, -2), (sh0, sc0))

        # --- F0 prediction ---
        F0_pred = shared_x.transpose(-1, -2)
        for block in self.kmodel.predictor.F0:
            F0_pred = block(F0_pred, s)
        F0_pred = self.kmodel.predictor.F0_proj(F0_pred).squeeze(1)

        # --- N (noise) prediction ---
        N_pred = shared_x.transpose(-1, -2)
        for block in self.kmodel.predictor.N:
            N_pred = block(N_pred, s)
        N_pred = self.kmodel.predictor.N_proj(N_pred).squeeze(1)

        # --- Text encoder (relies on the patched forward) ---
        t_en = self.kmodel.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg  # text features expanded to frame resolution

        # --- Decoder / vocoder ---
        audio = self.kmodel.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()

        return audio, pred_dur_final


# ---------------------------------------------------------------------------
# Pre-export patches
# ---------------------------------------------------------------------------


def prepare_for_export(kmodel: KModel) -> None:
    """Patch ``kmodel`` in place so that torch.export can trace it.

    Steps, in order:
      1. If ``kmodel.bert.config`` exists, pin its attention implementation to
         ``"eager"`` (the script author notes that the SDPA path has
         data-dependent shape branches that torch.export cannot trace).
      2. Replace the forward of the predictor's duration encoder.
      3. Replace the forward of the text encoder.
      4. Replace the stochastic vocoder source modules with deterministic ones.
    """
    bert = getattr(kmodel, "bert", None)
    config = getattr(bert, "config", None)
    if config is not None:
        # Both the public and the underscore-prefixed attribute are set.
        for attr_name in ("attn_implementation", "_attn_implementation"):
            setattr(config, attr_name, "eager")

    duration_encoder = kmodel.predictor.text_encoder
    patch_duration_encoder(duration_encoder)

    patch_text_encoder(kmodel.text_encoder)
    patch_vocoder_noise(kmodel)


def patch_duration_encoder(module) -> None:
    """
    Rebind ``module.forward`` to a variant that avoids pack_padded_sequence.

    The replacement walks ``module.lstms`` and treats each entry as either:
      - an ``AdaLayerNorm`` (matched by class name): it is applied to the
        sequence-major features, after which the style channels are
        concatenated again along the channel axis; or
      - anything else, which is driven as a batch-first LSTM with an explicit
        all-zero initial ``(h0, c0)`` state.

    Padded positions (``m`` is True) are zeroed after every step.  Dropout is
    always evaluated with ``training=False``, i.e. it is a no-op.  The
    ``text_lengths`` argument is accepted for signature compatibility but is
    not used.
    """

    def _zero_state(lstm, batch_first_input):
        # (h0, c0) sized from the LSTM's own configuration and the batch size.
        directions = 2 if lstm.bidirectional else 1
        shape = (
            lstm.num_layers * directions,
            batch_first_input.shape[0],
            lstm.hidden_size,
        )
        return batch_first_input.new_zeros(shape), batch_first_input.new_zeros(shape)

    def forward(self, x, style, text_lengths, m):
        # Two broadcastable views of the padding mask, one per layout used below.
        pad_time_first = m.unsqueeze(-1).transpose(0, 1)  # for (seq, batch, ch)
        pad_channel_first = m.unsqueeze(-1).transpose(-1, -2)  # for (batch, ch, seq)

        # x arrives as (batch, channels, seq); go time-first to attach the style.
        feats = x.permute(2, 0, 1)  # (seq, batch, channels)
        style_seq = style.expand(feats.shape[0], feats.shape[1], -1)
        feats = torch.cat([feats, style_seq], axis=-1)
        feats.masked_fill_(pad_time_first, 0.0)
        feats = feats.permute(1, 2, 0)  # (batch, channels + style, seq)

        # Style laid out channel-first, reused after every AdaLayerNorm.
        style_channels = style_seq.permute(1, 2, 0)

        for layer in self.lstms:
            if layer.__class__.__name__ != "AdaLayerNorm":
                # Recurrent layer: run batch-first with a zero initial state.
                seq_major = feats.transpose(-1, -2)
                layer.flatten_parameters()
                seq_major, _ = layer(seq_major, _zero_state(layer, seq_major))
                seq_major = F.dropout(seq_major, p=self.dropout, training=False)
                feats = seq_major.transpose(-1, -2)
            else:
                # Norm layer sees sequence-major input; style is appended
                # afterwards so the next layer gets the widened channel count.
                feats = layer(feats.transpose(-1, -2), style).transpose(-1, -2)
                feats = torch.cat([feats, style_channels], axis=1)
            feats.masked_fill_(pad_channel_first, 0.0)

        return feats.transpose(-1, -2)  # (batch, seq, channels)

    module.forward = types.MethodType(forward, module)


def patch_text_encoder(module) -> None:
    """
    Rebind ``module.forward`` to a variant that avoids pack_padded_sequence.

    The replacement embeds the token ids, runs every layer in ``self.cnn`` on
    the channel-first features, then runs ``self.lstm`` batch-first with an
    explicit all-zero initial state.  Positions where ``m`` is True are zeroed
    after the embedding, after each CNN layer and after the LSTM.  The
    ``input_lengths`` argument is accepted for signature compatibility but is
    not used.  The result is channel-first: (batch, channels, seq).
    """

    def forward(self, x, input_lengths, m):
        pad = m.unsqueeze(1)  # broadcast over the channel axis

        hidden = self.embedding(x).transpose(1, 2)  # (batch, channels, seq)
        hidden.masked_fill_(pad, 0.0)
        for conv in self.cnn:
            hidden = conv(hidden)
            hidden.masked_fill_(pad, 0.0)

        rnn = self.lstm
        rnn_in = hidden.transpose(1, 2)  # (batch, seq, channels)
        directions = 2 if rnn.bidirectional else 1
        state_shape = (rnn.num_layers * directions, rnn_in.shape[0], rnn.hidden_size)
        initial_state = (rnn_in.new_zeros(state_shape), rnn_in.new_zeros(state_shape))
        rnn.flatten_parameters()
        rnn_out, _ = rnn(rnn_in, initial_state)

        encoded = rnn_out.transpose(-1, -2)  # back to (batch, channels, seq)
        encoded.masked_fill_(pad, 0.0)
        return encoded

    module.forward = types.MethodType(forward, module)


def patch_vocoder_noise(kmodel: KModel) -> None:
    """
    Swap the random parts of the vocoder source modules for deterministic code.

    Every submodule of ``kmodel`` is inspected by class name:
      - ``SineGen``: ``_f02sine`` is replaced by a version without random
        initial phase, and ``forward`` by a version that adds no noise and
        returns an all-zero noise tensor.
      - ``SourceModuleHnNSF``: ``forward`` is replaced by a version that
        returns an all-zero noise tensor instead of sampled noise.

    The replacements are bound per instance; the classes are left untouched.
    """

    def _linear_resample(channels_first, factor):
        # 1-D linear interpolation along the last axis by a fixed scale factor.
        return F.interpolate(
            channels_first,
            scale_factor=factor,
            mode="linear",
            recompute_scale_factor=False,
        )

    def deterministic_f02sine(self, f0_values):
        # Fractional cycles advanced per sample, wrapped into [0, 1).
        cycles = (f0_values / self.sampling_rate) % 1
        if self.flag_for_pulse:
            accumulated = torch.cumsum(cycles, dim=1)
            return torch.cos(accumulated * 2 * math.pi)

        # Accumulate phase at the coarse rate, then interpolate back up.
        coarse = _linear_resample(
            cycles.transpose(1, 2), 1 / self.upsample_scale
        ).transpose(1, 2)
        phase = torch.cumsum(coarse, dim=1) * 2 * math.pi
        phase = _linear_resample(
            phase.transpose(1, 2) * self.upsample_scale, self.upsample_scale
        ).transpose(1, 2)
        return torch.sin(phase)

    def deterministic_sine_forward(self, f0):
        # Multipliers 1 .. harmonic_num + 1, broadcast over the last axis.
        multipliers = torch.arange(
            1, self.harmonic_num + 2, device=f0.device, dtype=f0.dtype
        ).view(1, 1, -1)
        harmonic_f0 = torch.multiply(f0, multipliers)
        waves = self._f02sine(harmonic_f0) * self.sine_amp
        voiced = self._f02uv(f0)
        silent_noise = torch.zeros_like(waves)
        # Multiplying by the voiced/unvoiced flag silences unvoiced regions.
        return waves * voiced, voiced, silent_noise

    def deterministic_source_forward(self, x):
        waves, voiced, _ = self.l_sin_gen(x)
        merged = self.l_tanh(self.l_linear(waves))
        return merged, torch.zeros_like(voiced), voiced

    for submodule in kmodel.modules():
        kind = submodule.__class__.__name__
        if kind == "SourceModuleHnNSF":
            submodule.forward = types.MethodType(
                deterministic_source_forward, submodule
            )
        elif kind == "SineGen":
            submodule._f02sine = types.MethodType(deterministic_f02sine, submodule)
            submodule.forward = types.MethodType(
                deterministic_sine_forward, submodule
            )


# ---------------------------------------------------------------------------
# Export helpers
# ---------------------------------------------------------------------------


def export_main_model(
    kmodel: KModel, output_dir: str, seq_len: int, max_audio_frames: int
) -> None:
    """Export the wrapped model to ``<output_dir>/kokoro_main.pte``.

    Pipeline: patch the model for export, trace it with ``torch.export.export``
    (falling back to ``torch.export.draft_export`` if that raises), convert to
    the ExecuTorch edge dialect with IR validity checking disabled, try to
    delegate to XNNPACK (continuing without delegation if that raises), then
    serialize the program buffer to disk.

    Args:
        kmodel: model to export; it is switched to eval mode and patched in place.
        output_dir: directory for the ``.pte`` file; created if missing.
        seq_len: token count of the example input used for tracing.
        max_audio_frames: frame budget passed to the export wrapper.
    """
    kmodel.eval()
    prepare_for_export(kmodel)

    export_module = StaticKModelForExport(kmodel, max_audio_frames=max_audio_frames)
    export_module.eval()

    # Example inputs that fix the traced shapes: token ids, style, speed.
    example_args = (
        torch.zeros(1, seq_len, dtype=torch.long),
        torch.randn(1, 256),
        torch.tensor(1.0),
    )

    print(
        f"Exporting model: seq_len={seq_len}, max_audio_frames={max_audio_frames} ..."
    )

    with torch.no_grad():
        try:
            exported_program = torch.export.export(
                export_module, args=example_args, strict=False
            )
        except Exception as e:
            print(
                f"torch.export failed ({e}), falling back to draft_export for diagnostics..."
            )
            exported_program = torch.export.draft_export(
                export_module, args=example_args, strict=False
            )
            print(
                "draft_export succeeded (model may have unresolved guards — check output)."
            )
        else:
            print("torch.export succeeded.")

    print("Converting to ExecuTorch edge dialect...")
    edge_config = EdgeCompileConfig(_check_ir_validity=False)
    edge_program = to_edge(exported_program, compile_config=edge_config)

    print("Partitioning with XNNPACK backend...")
    try:
        edge_program = edge_program.to_backend(XnnpackPartitioner())
    except Exception as exc:
        # Best effort: keep the undelegated program when delegation fails.
        print(f"XNNPACK partitioning failed; continuing with portable kernels: {exc}")
    else:
        print("XNNPACK partitioning succeeded.")

    print("Generating .pte buffer...")
    buffer = edge_program.to_executorch().buffer

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "kokoro_main.pte")
    with open(output_path, "wb") as pte_file:
        pte_file.write(buffer)
    print(f"Exported main model → {output_path}  ({len(buffer) / 1024 / 1024:.1f} MB)")


def export_voices(voice_dir: str, output_dir: str) -> None:
    """Re-save every voice ``.pt`` file from ``voice_dir`` into ``<output_dir>/voices``.

    Each file is loaded with ``torch.load(..., weights_only=True)`` and written
    back with ``torch.save``, so only files that load under the restricted
    unpickler are exported.  Files without a ``.pt`` suffix are ignored.

    Files are processed in sorted name order so the log output is reproducible.
    If ``voice_dir`` is not an existing directory, a notice is printed and
    nothing is exported (the ``voices`` output directory is still created).

    Args:
        voice_dir: directory containing the source voice files.
        output_dir: export root; ``voices`` is created beneath it.
    """
    out_voices = os.path.join(output_dir, "voices")
    os.makedirs(out_voices, exist_ok=True)

    if not os.path.isdir(voice_dir):
        # Make a mistyped --voice_dir visible instead of silently skipping it.
        print(f"Voice directory not found, no voices exported: {voice_dir}")
        return

    for fname in sorted(os.listdir(voice_dir)):
        if not fname.endswith(".pt"):
            continue
        src = os.path.join(voice_dir, fname)
        dst = os.path.join(out_voices, fname)
        torch.save(torch.load(src, weights_only=True), dst)
        print(f"Copied voice: {fname}")


def export_config(kmodel: KModel, output_dir: str) -> None:
    """Dump ``kmodel.vocab`` and ``kmodel.context_length`` to ``config.json``.

    The file is written to ``output_dir`` as UTF-8 JSON with a two-space indent
    and without ASCII escaping; per the script author it is consumed by the
    Kotlin inference pipeline.  ``output_dir`` must already exist.
    """
    # Read both attributes before the file is opened, so a missing attribute
    # fails before anything is written.
    vocab = kmodel.vocab
    context_length = kmodel.context_length
    export_data = {"vocab": vocab, "context_length": context_length}

    config_path = os.path.join(output_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as config_file:
        json.dump(export_data, config_file, ensure_ascii=False, indent=2)
    print(f"Exported config → {config_path}")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------


def main() -> None:
    """Command-line entry point: parse options, load the model, run the exports.

    Always exports the main ``.pte`` and the JSON config; voices are exported
    only when ``--voice_dir`` is given.  Exits through ``parser.error`` when
    ``--max_audio_frames`` is smaller than ``--seq_len``.
    """
    parser = argparse.ArgumentParser(description="Export Kokoro to ExecuTorch .pte")

    # Where the model comes from and where the artifacts go.
    parser.add_argument("--repo_id", default="hexgrad/Kokoro-82M")
    parser.add_argument("--model_path", default=None, help="Path to .pth weights file")
    parser.add_argument("--voice_dir", default=None, help="Path to voices/ directory")
    parser.add_argument("--output_dir", default="./exported", help="Output directory")

    # Static shapes baked into the exported program.
    parser.add_argument(
        "--seq_len",
        type=int,
        default=128,
        help="Fixed number of input tokens for the exported .pte",
    )
    parser.add_argument(
        "--max_audio_frames",
        type=int,
        default=512,
        help="Fixed number of mel frames (alignment target length). "
        "Must be >= seq_len.  Larger = longer speech supported.",
    )
    parser.add_argument(
        "--enable_complex",
        action="store_true",
        help="Keep Kokoro's native complex STFT path (not ExecuTorch-friendly; "
        "use only if you have a custom ISTFT kernel on device)",
    )
    args = parser.parse_args()

    # The wrapper needs at least one frame per token.
    if not args.max_audio_frames >= args.seq_len:
        parser.error("--max_audio_frames must be >= --seq_len")

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading KModel from {args.repo_id} ...")
    use_complex = args.enable_complex
    kmodel = KModel(
        repo_id=args.repo_id,
        model=args.model_path,
        disable_complex=not use_complex,
    )

    export_main_model(kmodel, output_dir, args.seq_len, args.max_audio_frames)
    export_config(kmodel, output_dir)

    if args.voice_dir:
        export_voices(args.voice_dir, output_dir)

    print("\nExport complete!")


if __name__ == "__main__":
    main()

I was trying to generate a .pte file of Kokoro for Android, but the export always runs out of memory (OOM). I tried booting up a VM with 64 GB of RAM and 16 cores, but whenever I run ./.venv_kokoro_cpu/bin/python python/convert-kokoro.py, I get an OOM after 10 to 15 minutes. How can I resolve this?

Versions

Collecting environment information...
PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A

OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 4.3.3
Libc version: glibc-2.39

Python version: 3.12.3 (main, Mar 23 2026, 19:04:32) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-124-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.1.115
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 590.48.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 5700G with Radeon Graphics
CPU family: 25
Model: 80
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
CPU(s) scaling MHz: 84%
CPU max MHz: 4673.0000
CPU min MHz: 400.0000
BogoMIPS: 7586.06
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap ibpb_exit_to_user
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Vulnerable: Clear CPU buffers attempted, no microcode
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Mitigation; IBPB before exit to userspace

Versions of relevant libraries:
[pip3] executorch==1.3.1
[pip3] numpy==2.4.6
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] pytorch_tokenizers==1.3.0
[pip3] torch==2.12.0
[pip3] torchao==0.17.0
[pip3] triton==3.7.0
[conda] Could not collect

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions