import argparse
import json
import math
import os
import types
import torch
import torch.nn.functional as F
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.exir import EdgeCompileConfig, to_edge
from kokoro.model import KModel
from torch import nn
# ---------------------------------------------------------------------------
# Static export wrapper
# ---------------------------------------------------------------------------
class StaticKModelForExport(nn.Module):
    """
    Wraps KModel.forward_with_tokens for torch.export.

    Two static dimensions must be fixed at export time:
      - seq_len: number of input tokens (e.g. 128)
      - max_audio_frames: maximum mel frames in the alignment target (e.g. 512)

    The predicted durations are fitted so that they sum to max_audio_frames
    exactly, giving the export a fully-static shape for the alignment matrix.
    Tokens whose frames would fall past max_audio_frames are truncated (their
    reported duration shrinks, possibly to 0), so the durations returned by
    forward() always describe the alignment that was actually used.
    """

    def __init__(self, kmodel: "KModel", max_audio_frames: int):
        super().__init__()
        self.kmodel = kmodel
        self.max_audio_frames = max_audio_frames

    @staticmethod
    def _zero_state(lstm, like, batch_size):
        """Return explicit all-zero (h0, c0) for *lstm*, allocated like *like*."""
        num_dirs = 2 if lstm.bidirectional else 1
        shape = (lstm.num_layers * num_dirs, batch_size, lstm.hidden_size)
        return like.new_zeros(shape), like.new_zeros(shape)

    @staticmethod
    def _fit_durations(pred_dur, max_audio_frames):
        """
        Fit per-token durations into exactly *max_audio_frames* frames.

        Args:
            pred_dur: 1-D integer tensor of predicted frames per token.
            max_audio_frames: total number of frames the result must cover.

        Returns:
            (starts, ends, durations): 1-D tensors of the same length as
            pred_dur. Token i owns frames [starts[i], ends[i]); durations is
            ends - starts and always sums to max_audio_frames.
        """
        seq_len = pred_dur.shape[0]
        # Each token gets at least one frame and never so many that the
        # remaining tokens could not get one each.
        per_token = torch.clamp(
            pred_dur, min=1, max=max_audio_frames - seq_len + 1
        )
        # Cap the running total at the frame budget: tokens past the budget
        # end up with an empty [start, end) range instead of overflowing.
        ends = torch.cumsum(per_token, dim=0).clamp(max=max_audio_frames)
        # The last token absorbs whatever is left so the total is exact.
        ends = torch.cat(
            [ends[:-1], ends.new_full((1,), max_audio_frames)], dim=0
        )
        starts = torch.cat([ends.new_zeros(1), ends[:-1]], dim=0)
        return starts, ends, ends - starts

    def forward(
        self,
        input_ids: torch.LongTensor,  # (1, seq_len)
        ref_s: torch.FloatTensor,  # (1, 256)
        speed: torch.FloatTensor,  # scalar tensor
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        """
        Run the full token-to-audio pipeline with static shapes.

        Returns:
            (audio, pred_dur_final): the squeezed decoder output and the
            (seq_len,) per-token frame counts used to build the alignment.
        """
        device = input_ids.device
        batch_size = input_ids.shape[0]  # always 1 for export
        seq_len = input_ids.shape[1]

        # --- masks (static because seq_len is fixed) ---
        input_lengths = torch.full(
            (batch_size,), seq_len, device=device, dtype=torch.long
        )
        text_positions = torch.arange(seq_len, device=device).unsqueeze(0)
        text_mask = torch.gt(
            text_positions + 1, input_lengths.unsqueeze(1)
        )  # (1, seq_len), True marks padding

        # --- BERT duration embedding ---
        bert_dur = self.kmodel.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.kmodel.bert_encoder(bert_dur).transpose(
            -1, -2
        )  # (1, hidden, seq_len)
        s = ref_s[:, 128:]  # (1, 128) style vector

        # --- DurationEncoder (text_encoder) ---
        # Uses patched forward: no pack_padded_sequence, explicit h0/c0.
        d = self.kmodel.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        # d: (1, seq_len, d_hid + style_dim) after DurationEncoder

        # --- Duration LSTM ---
        pred_lstm = self.kmodel.predictor.lstm
        x, _ = pred_lstm(
            d, self._zero_state(pred_lstm, d, batch_size)
        )  # (1, seq_len, d_hid)

        # --- Duration projection ---
        duration = self.kmodel.predictor.duration_proj(x)  # (1, seq_len, max_dur)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed  # (1, seq_len)
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze(0)  # (seq_len,)

        # --- Fit durations so total frames == max_audio_frames (static shape) ---
        starts, ends, pred_dur_final = self._fit_durations(
            pred_dur, self.max_audio_frames
        )

        # --- Build alignment target: (1, seq_len, max_audio_frames) ---
        # pred_aln_trg[token_idx, frame_idx] = 1 iff frame belongs to that token
        frame_indices = torch.arange(
            self.max_audio_frames, device=device
        )  # (max_audio_frames,)
        pred_aln_trg = (frame_indices.unsqueeze(0) >= starts.unsqueeze(1)) & (
            frame_indices.unsqueeze(0) < ends.unsqueeze(1)
        )  # (seq_len, max_audio_frames) boolean
        pred_aln_trg = pred_aln_trg.float().unsqueeze(
            0
        )  # (1, seq_len, max_audio_frames)

        # --- en: expand duration-encoded features along time ---
        en = d.transpose(-1, -2) @ pred_aln_trg  # (1, d_hid+sty, max_audio_frames)

        # --- Shared LSTM for F0 / N prediction ---
        shared = self.kmodel.predictor.shared
        shared_x, _ = shared(
            en.transpose(-1, -2), self._zero_state(shared, en, batch_size)
        )  # (1, max_audio_frames, d_hid)

        # --- F0 prediction ---
        F0_pred = shared_x.transpose(-1, -2)  # (1, d_hid, max_audio_frames)
        for block in self.kmodel.predictor.F0:
            F0_pred = block(F0_pred, s)
        F0_pred = self.kmodel.predictor.F0_proj(F0_pred).squeeze(
            1
        )  # (1, max_audio_frames*2) after upsample

        # --- N (noise) prediction ---
        N_pred = shared_x.transpose(-1, -2)
        for block in self.kmodel.predictor.N:
            N_pred = block(N_pred, s)
        N_pred = self.kmodel.predictor.N_proj(N_pred).squeeze(1)

        # --- Text encoder (patched: no pack_padded_sequence) ---
        t_en = self.kmodel.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg  # (1, hidden, max_audio_frames)

        # --- Decoder / vocoder ---
        audio = self.kmodel.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
        return audio, pred_dur_final
# ---------------------------------------------------------------------------
# Pre-export patches
# ---------------------------------------------------------------------------
def prepare_for_export(kmodel: KModel) -> None:
    """Patch *kmodel* in place so that torch.export can trace it."""
    # BERT must run with eager (non-SDPA) attention: SDPA has data-dependent
    # shape branches that torch.export cannot trace through.
    bert = getattr(kmodel, "bert", None)
    config = getattr(bert, "config", None)
    if config is not None:
        for attr in ("attn_implementation", "_attn_implementation"):
            setattr(config, attr, "eager")

    patch_duration_encoder(kmodel.predictor.text_encoder)
    patch_text_encoder(kmodel.text_encoder)
    patch_vocoder_noise(kmodel)
def patch_duration_encoder(module) -> None:
    """
    Install a pack_padded_sequence-free forward on a DurationEncoder.

    The replacement keeps the structure of the stock forward: the style vector
    is appended to the features once up front and again after every
    AdaLayerNorm block, so each LSTM sees features plus style channels. LSTM
    blocks are called on the padded tensor with explicit zero h0/c0, and padded
    positions are zeroed after every block.
    """

    def _zero_state(lstm, batch):
        # Explicit initial hidden/cell state: (layers * directions, batch, hidden).
        directions = 2 if lstm.bidirectional else 1
        shape = (lstm.num_layers * directions, batch.shape[0], lstm.hidden_size)
        return batch.new_zeros(shape), batch.new_zeros(shape)

    def forward(self, x, style, text_lengths, m):
        # x: (1, d_hid, seq_len); m: (1, seq_len) bool, True marks padding.
        step_mask = m.unsqueeze(-1)  # (1, seq_len, 1)
        channel_mask = step_mask.transpose(-1, -2)  # (1, 1, seq_len)

        feats = x.permute(2, 0, 1)  # (seq_len, 1, d_hid)
        style_seq = style.expand(
            feats.shape[0], feats.shape[1], -1
        )  # (seq_len, 1, sty_dim)
        feats = torch.cat([feats, style_seq], dim=-1)  # (seq_len, 1, d_hid+sty_dim)
        feats.masked_fill_(step_mask.transpose(0, 1), 0.0)
        # Switch to channels-first: (1, d_hid+sty_dim, seq_len)
        feats = feats.transpose(0, 1).transpose(-1, -2)

        for block in self.lstms:
            if type(block).__name__ != "AdaLayerNorm":
                # LSTM block: runs time-major-in-batch, i.e. (1, seq_len, channels)
                seq = feats.transpose(-1, -2)
                state = _zero_state(block, seq)
                block.flatten_parameters()
                seq, _ = block(seq, state)
                seq = F.dropout(seq, p=self.dropout, training=False)
                feats = seq.transpose(-1, -2)  # back to channels-first
                feats.masked_fill_(channel_mask, 0.0)
                continue

            # AdaLayerNorm is fed (1, seq_len, channels); convert back afterwards.
            normed = block(feats.transpose(-1, -2), style).transpose(-1, -2)
            # Re-append style so the next LSTM sees d_hid+sty_dim channels.
            feats = torch.cat([normed, style_seq.permute(1, 2, 0)], dim=1)
            feats.masked_fill_(channel_mask, 0.0)

        return feats.transpose(-1, -2)  # (1, seq_len, channels)

    module.forward = types.MethodType(forward, module)
def patch_text_encoder(module) -> None:
    """
    Install a pack_padded_sequence-free forward on a TextEncoder.

    The LSTM is called directly on the padded tensor with explicit zero h0/c0,
    and padded positions are zeroed with the mask instead of going through
    pad_packed_sequence plus manual zero-padding.
    """

    def forward(self, x, input_lengths, m):
        pad_mask = m.unsqueeze(1)  # (1, 1, seq_len), True marks padding

        hidden = self.embedding(x).transpose(1, 2)  # (1, channels, seq_len)
        hidden.masked_fill_(pad_mask, 0.0)
        for conv in self.cnn:
            hidden = conv(hidden)
            hidden.masked_fill_(pad_mask, 0.0)

        hidden = hidden.transpose(1, 2)  # (1, seq_len, channels)
        lstm = self.lstm
        directions = 2 if lstm.bidirectional else 1
        state_shape = (
            lstm.num_layers * directions,
            hidden.shape[0],
            lstm.hidden_size,
        )
        init_state = (hidden.new_zeros(state_shape), hidden.new_zeros(state_shape))
        lstm.flatten_parameters()
        hidden, _ = lstm(hidden, init_state)  # (1, seq_len, channels)

        hidden = hidden.transpose(-1, -2)  # (1, channels, seq_len)
        hidden.masked_fill_(pad_mask, 0.0)
        return hidden

    module.forward = types.MethodType(forward, module)
def patch_vocoder_noise(kmodel: KModel) -> None:
    """
    Swap the random parts of the vocoder source for deterministic ones.

    Every sub-module of *kmodel* is matched by class name and patched in place:
      - SineGen._f02sine: no random initial phase (phase starts at zero).
      - SineGen.forward: no noise is added; the returned noise tensor is zeros.
      - SourceModuleHnNSF.forward: the noise branch returns zeros instead of
        sampling.
    """

    def sine_f02sine(self, f0_values):
        # f0_values: (B, T, dim)
        rad_values = (f0_values / self.sampling_rate) % 1
        if self.flag_for_pulse:
            i_phase = torch.cumsum(rad_values, dim=1)
            return torch.cos(i_phase * 2 * math.pi)

        # Accumulate phase at the downsampled rate, then upsample it back.
        rad_values = F.interpolate(
            rad_values.transpose(1, 2),
            scale_factor=1 / self.upsample_scale,
            mode="linear",
            recompute_scale_factor=False,
        ).transpose(1, 2)
        phase = torch.cumsum(rad_values, dim=1) * 2 * math.pi
        phase = F.interpolate(
            phase.transpose(1, 2) * self.upsample_scale,
            scale_factor=self.upsample_scale,
            mode="linear",
            recompute_scale_factor=False,
        ).transpose(1, 2)
        return torch.sin(phase)

    def sine_forward(self, f0):
        # Harmonic multipliers 1 .. harmonic_num + 1, shaped (1, 1, dim).
        multipliers = torch.arange(
            1, self.harmonic_num + 2, device=f0.device, dtype=f0.dtype
        ).view(1, 1, -1)
        fn = torch.multiply(f0, multipliers)  # (B, T, dim)
        sine_waves = self._f02sine(fn) * self.sine_amp
        uv = self._f02uv(f0)
        noise = torch.zeros_like(sine_waves)  # deterministic: no noise
        return sine_waves * uv, uv, noise  # sine is silenced where uv is 0

    def source_forward(self, x):
        # x: F0 (B, T, 1)
        sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        return sine_merge, torch.zeros_like(uv), uv

    # Class name -> (attribute, replacement function) pairs to bind.
    replacements = {
        "SineGen": (("_f02sine", sine_f02sine), ("forward", sine_forward)),
        "SourceModuleHnNSF": (("forward", source_forward),),
    }
    for module in kmodel.modules():
        for attr, fn in replacements.get(type(module).__name__, ()):
            setattr(module, attr, types.MethodType(fn, module))
# ---------------------------------------------------------------------------
# Export helpers
# ---------------------------------------------------------------------------
def export_main_model(
    kmodel: KModel, output_dir: str, seq_len: int, max_audio_frames: int
) -> None:
    """
    Export the core model to ``<output_dir>/kokoro_main.pte``.

    Steps: patch the model, trace it with torch.export (falling back to
    draft_export if that raises), lower to the ExecuTorch edge dialect,
    try XNNPACK partitioning (best effort), then serialize the buffer.
    """
    kmodel.eval()
    prepare_for_export(kmodel)
    wrapper = StaticKModelForExport(kmodel, max_audio_frames=max_audio_frames)
    wrapper.eval()

    # Example inputs fix the static shapes: token ids, reference style, speed.
    example_args = (
        torch.zeros(1, seq_len, dtype=torch.long),
        torch.randn(1, 256),
        torch.tensor(1.0),
    )

    print(
        f"Exporting model: seq_len={seq_len}, max_audio_frames={max_audio_frames} ..."
    )
    with torch.no_grad():
        try:
            exported_program = torch.export.export(
                wrapper, args=example_args, strict=False
            )
            print("torch.export succeeded.")
        except Exception as e:
            # Deliberate catch-all: any export failure triggers the draft path.
            print(
                f"torch.export failed ({e}), falling back to draft_export for diagnostics..."
            )
            exported_program = torch.export.draft_export(
                wrapper, args=example_args, strict=False
            )
            print(
                "draft_export succeeded (model may have unresolved guards — check output)."
            )

    print("Converting to ExecuTorch edge dialect...")
    edge_config = EdgeCompileConfig(_check_ir_validity=False)
    edge_program = to_edge(exported_program, compile_config=edge_config)

    print("Partitioning with XNNPACK backend...")
    try:
        edge_program = edge_program.to_backend(XnnpackPartitioner())
        print("XNNPACK partitioning succeeded.")
    except Exception as exc:
        # Best effort: keep the unpartitioned program if delegation fails.
        print(f"XNNPACK partitioning failed; continuing with portable kernels: {exc}")

    print("Generating .pte buffer...")
    buffer = edge_program.to_executorch().buffer

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "kokoro_main.pte")
    with open(output_path, "wb") as pte_file:
        pte_file.write(buffer)
    print(f"Exported main model → {output_path} ({len(buffer) / 1024 / 1024:.1f} MB)")
def export_voices(voice_dir: str, output_dir: str) -> None:
    """
    Re-save every voice ``.pt`` file from *voice_dir* into ``<output_dir>/voices``.

    Each file is loaded with ``weights_only=True`` and saved again rather than
    byte-copied. Tensors are mapped to the CPU on load so that voice packs
    saved from a GPU can still be processed on a CPU-only host.

    Args:
        voice_dir: directory containing the source ``.pt`` voice files.
        output_dir: export root; the ``voices`` sub-directory is created here.

    If *voice_dir* is not a directory, a warning is printed and nothing is
    exported (the ``voices`` directory is still created).
    """
    out_voices = os.path.join(output_dir, "voices")
    os.makedirs(out_voices, exist_ok=True)

    if not os.path.isdir(voice_dir):
        # Previously this case was silent, hiding a mistyped --voice_dir.
        print(f"Warning: voice directory not found, no voices exported: {voice_dir}")
        return

    # Sorted so the export log is deterministic across file systems.
    for fname in sorted(os.listdir(voice_dir)):
        if not fname.endswith(".pt"):
            continue
        src = os.path.join(voice_dir, fname)
        dst = os.path.join(out_voices, fname)
        voice = torch.load(src, weights_only=True, map_location="cpu")
        torch.save(voice, dst)
        print(f"Copied voice: {fname}")
def export_config(kmodel: KModel, output_dir: str) -> None:
    """
    Dump the model's vocab and context length to ``<output_dir>/config.json``.

    The file is the configuration consumed by the Kotlin inference pipeline.
    """
    config_path = os.path.join(output_dir, "config.json")
    payload = dict(vocab=kmodel.vocab, context_length=kmodel.context_length)
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    with open(config_path, "w", encoding="utf-8") as handle:
        handle.write(serialized)
    print(f"Exported config → {config_path}")
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
def main() -> None:
    """Parse the command line, load the KModel and run every export step."""
    cli = argparse.ArgumentParser(description="Export Kokoro to ExecuTorch .pte")
    cli.add_argument("--repo_id", default="hexgrad/Kokoro-82M")
    cli.add_argument("--model_path", default=None, help="Path to .pth weights file")
    cli.add_argument("--voice_dir", default=None, help="Path to voices/ directory")
    cli.add_argument("--output_dir", default="./exported", help="Output directory")
    cli.add_argument(
        "--seq_len",
        type=int,
        default=128,
        help="Fixed number of input tokens for the exported .pte",
    )
    cli.add_argument(
        "--max_audio_frames",
        type=int,
        default=512,
        help="Fixed number of mel frames (alignment target length). "
        "Must be >= seq_len. Larger = longer speech supported.",
    )
    cli.add_argument(
        "--enable_complex",
        action="store_true",
        help="Keep Kokoro's native complex STFT path (not ExecuTorch-friendly; "
        "use only if you have a custom ISTFT kernel on device)",
    )
    opts = cli.parse_args()

    # Every token needs at least one frame, so the frame budget cannot be
    # smaller than the token count.
    if opts.seq_len > opts.max_audio_frames:
        cli.error("--max_audio_frames must be >= --seq_len")

    os.makedirs(opts.output_dir, exist_ok=True)

    print(f"Loading KModel from {opts.repo_id} ...")
    kmodel = KModel(
        repo_id=opts.repo_id,
        model=opts.model_path,
        disable_complex=not opts.enable_complex,
    )

    export_main_model(kmodel, opts.output_dir, opts.seq_len, opts.max_audio_frames)
    export_config(kmodel, opts.output_dir)
    if opts.voice_dir:
        export_voices(opts.voice_dir, opts.output_dir)

    print("\nExport complete!")


if __name__ == "__main__":
    main()
Collecting environment information...
PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A
Python version: 3.12.3 (main, Mar 23 2026, 19:04:32) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-124-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.1.115
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 590.48.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
🐛 Describe the bug
So here is my script:
I was trying to generate a `.pte` file for Android for Kokoro, but the export always runs out of memory (OOM). I tried booting up a VM with 64 GB of RAM and 16 cores, but whenever I run `./.venv_kokoro_cpu/bin/python python/convert-kokoro.py`, I get an OOM after 10 to 15 minutes. How can I resolve this?

Versions
Collecting environment information...
PyTorch version: 2.12.0+cu130
Is debug build: False
CUDA used to build PyTorch: 13.0
ROCM used to build PyTorch: N/A
OS: Ubuntu 24.04.4 LTS (x86_64)
GCC version: (Ubuntu 13.3.0-6ubuntu2~24.04.1) 13.3.0
Clang version: 18.1.3 (1ubuntu1)
CMake version: version 4.3.3
Libc version: glibc-2.39
Python version: 3.12.3 (main, Mar 23 2026, 19:04:32) [GCC 13.3.0] (64-bit runtime)
Python platform: Linux-6.8.0-124-generic-x86_64-with-glibc2.39
Is CUDA available: True
CUDA runtime version: 13.1.115
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3060
Nvidia driver version: 590.48.01
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Caching allocator config: N/A
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 48 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 5700G with Radeon Graphics
CPU family: 25
Model: 80
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
CPU(s) scaling MHz: 84%
CPU max MHz: 4673.0000
CPU min MHz: 400.0000
BogoMIPS: 7586.06
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibrs ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk clzero irperf xsaveerptr rdpru wbnoinvd cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip pku ospke vaes vpclmulqdq rdpid overflow_recov succor smca fsrm debug_swap ibpb_exit_to_user
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 16 MiB (1 instance)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Indirect target selection: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines; IBPB conditional; IBRS_FW; STIBP always-on; RSB filling; PBRSB-eIBRS Not affected; BHI Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsa: Vulnerable: Clear CPU buffers attempted, no microcode
Vulnerability Tsx async abort: Not affected
Vulnerability Vmscape: Mitigation; IBPB before exit to userspace
Versions of relevant libraries:
[pip3] executorch==1.3.1
[pip3] numpy==2.4.6
[pip3] nvidia-cublas==13.1.1.3
[pip3] nvidia-cuda-cupti==13.0.85
[pip3] nvidia-cuda-nvrtc==13.0.88
[pip3] nvidia-cuda-runtime==13.0.96
[pip3] nvidia-cudnn-cu13==9.20.0.48
[pip3] nvidia-cufft==12.0.0.61
[pip3] nvidia-curand==10.4.0.35
[pip3] nvidia-cusolver==12.0.4.66
[pip3] nvidia-cusparse==12.6.3.3
[pip3] nvidia-cusparselt-cu13==0.8.1
[pip3] nvidia-nccl-cu13==2.29.7
[pip3] nvidia-nvjitlink==13.0.88
[pip3] nvidia-nvtx==13.0.85
[pip3] pytorch_tokenizers==1.3.0
[pip3] torch==2.12.0
[pip3] torchao==0.17.0
[pip3] triton==3.7.0
[conda] Could not collect