byeolki/PolyFlow-TTS

PolyFlow-TTS

Zero-shot voice cloning TTS for Korean, English, and Japanese. Fine-tuned from the F5-TTS base checkpoint using Conditional Flow Matching with a Diffusion Transformer backbone.

Overview

PolyFlow-TTS clones any speaker's voice from a short reference clip and synthesizes natural speech in Korean, English, or Japanese — without speaker-specific training. The system is built on three core ideas:

  • OT-CFM (Optimal Transport Conditional Flow Matching) as the generative objective, replacing the diffusion noising process with a straight probability path that can be integrated in fewer inference steps.
  • Dual-Level LID Injection to condition the model on language identity at both the timestep (AdaLN) and text representation (FiLM) levels without disturbing the pretrained F5-TTS weights at initialization.
  • Asymmetric CFG that applies acoustic and linguistic guidance with separate schedules, allowing the model to balance speaker fidelity and text intelligibility independently.

Architecture

System Overview

flowchart TD
    subgraph TRAIN["Training"]
        direction TB
        tx1["Target Mel  x₁"] --> interp["ψₜ = (1−t)·x₀ + t·x₁"]
        tx0["Gaussian Noise  x₀"] --> interp
        tt["t ~ Uniform(0,1)"] --> interp
        interp --> xt["Noisy Sample  xₜ"]

        tref["Reference Mel"] --> cat["concat  [ref | xₜ]"]
        xt --> cat

        ttext["Input Text"] --> g2p["G2P → IPA + Filler Insertion"]
        g2p --> tenc["Text Encoder  (ConvNeXt V2)"]
        tenc --> et["Text Repr  e_T'"]

        tlang["Language ID"] --> lid["LID Injection"]
        tt --> temb["Timestep Embedding  h_t"]
        temb --> lid

        cat --> dit["DiT  (22 blocks)"]
        et --> dit
        lid --> dit

        dit --> vpred["Predicted Vector Field  v_θ"]
        vpred --> loss["MSE Loss\n‖v_θ − (x₁ − x₀)‖²  ·  target region mask"]
    end

    subgraph INFER["Inference"]
        direction TB
        noise["x₀ ~ N(0, I)"] --> ode["Euler ODE  ×16 steps  (sway sampling  α=−1.0)"]

        iref["Reference Mel"] --> cfg["Asymmetric CFG\nv + wA(t)·ΔAcoustic + wL(t)·ΔLinguistic"]
        itext["Input Text"] --> ig2p["G2P → IPA + Filler Insertion"]
        ig2p --> ienc["Text Encoder"]
        ienc --> cfg
        ilang["Language ID"] --> cfg

        cfg --> ode
        ode --> mel["Generated Mel  (100 bands · 24 kHz)"]
        mel --> vocos["Vocos Vocoder  (frozen)"]
        vocos --> wav["Output WAV  24 kHz"]
    end

During training, a random timestep t ~ Uniform(0,1) is drawn and the target mel x₁ is interpolated toward Gaussian noise x₀ to produce a noisy sample xₜ. The DiT receives the concatenation of the reference mel and xₜ, together with the text representation and language ID, and predicts the vector field v_θ that points from noise to data. The loss is computed only on the target region of the output (not the reference).

During inference, the model starts from pure noise and integrates the learned vector field forward in time using a 16-step Euler ODE. Asymmetric CFG combines three parallel forward passes at each step to steer generation toward the correct speaker identity and text content simultaneously.


1. Text Processing Pipeline

Text goes through three stages before reaching the DiT: G2P conversion, IPA normalization, and filler token insertion.

G2P → IPA

Each language uses a dedicated G2P backend:

| Language | Tool | Notes |
| --- | --- | --- |
| Korean | g2pk2 | Applies phonological rules (liaison, nasalization, tensification), then maps jamo to IPA via a static table |
| English | phonemizer + espeak-ng | Directly outputs IPA with stress markers |
| Japanese | pyopenjtalk | Produces romaji-style phoneme labels, then maps to IPA |

Language-Prefixed Vocab

Korean, English, and Japanese share many surface-identical IPA symbols (e.g. a, i, n) that are phonetically distinct across languages. To prevent the model from conflating them, every IPA token is prefixed with its source language at preprocessing time:

| Language | Prefix | Example tokens |
| --- | --- | --- |
| Korean | ko_ | ko_a, ko_k, ko_tɕ |
| English | en_ | en_a, en_æ, en_θ |
| Japanese | ja_ | ja_a, ja_ɯ, ja_ɴ |

Special tokens — <PAD>, <UNK>, <FILLER>, <BOS>, <EOS> — carry no prefix and are shared across all languages. The merged token set is saved as vocab.json at the end of preprocessing. IPA embeddings for Korean and Japanese that are absent from the original F5-TTS checkpoint are randomly initialized and trained from scratch.
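As a rough illustration of the prefixing scheme, the sketch below prefixes IPA tokens and merges them into one vocabulary. The helper names (`prefix_tokens`, `build_vocab`) and the ID ordering (special tokens first, then languages in sorted order) are assumptions made for this example, not the repository's actual preprocessing code.

```python
SPECIAL_TOKENS = ["<PAD>", "<UNK>", "<FILLER>", "<BOS>", "<EOS>"]

def prefix_tokens(ipa_tokens, lang):
    """Prefix each IPA token with its language code; special tokens stay shared."""
    return [tok if tok in SPECIAL_TOKENS else f"{lang}_{tok}" for tok in ipa_tokens]

def build_vocab(corpora):
    """corpora maps a language code to a list of token lists; returns token -> id."""
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    for lang in sorted(corpora):  # hypothetical ordering, chosen for determinism
        for sentence in corpora[lang]:
            for tok in prefix_tokens(sentence, lang):
                vocab.setdefault(tok, len(vocab))
    return vocab

corpora = {
    "ko": [["a", "k", "tɕ"]],
    "en": [["a", "æ", "θ"]],
    "ja": [["a", "ɯ", "ɴ"]],
}
vocab = build_vocab(corpora)
```

The surface symbol a ends up with three separate IDs (ko_a, en_a, ja_a), so each language gets its own embedding row.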

Filler Token Insertion

The DiT expects the input text sequence to have the same length as the target mel sequence, so that the text encoder output can be added to the mel input frame by frame without interpolation. A phone sequence is much shorter than its mel sequence, and the ratio varies with speaking rate (fewer frames per phone for fast speakers, more for slow ones), so we insert <FILLER> tokens to bridge the gap.

The insertion is word-proportional:

sentence: "hello world"
phones:    [h ɛ l oʊ]  [w ɜː l d]     total_phones = 8
mel_len:   50 frames                   total_filler = 50 − 8 = 42

word_filler_budget:
  "hello": round(42 × 4/8) = 21   →  [h ɛ l oʊ <F> <F> ... ×21]
  "world": 42 − 21 = 21            →  [w ɜː l d <F> <F> ... ×21]

final sequence length: 4+21 + 4+21 = 50 ✓

Rules:

  • Each word gets a minimum of 1 filler appended after its last phone, guaranteeing at least one phone-to-filler boundary per word.
  • Any remainder after proportional distribution is appended at the sequence end.
  • Samples where mel_len < n_phones (physically impossible to fill) are dropped during data loading.

This strategy means the filler distribution implicitly encodes local speaking rate — words with longer mel durations receive more fillers, and the model learns to associate phone clusters with appropriate durational context.
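The rules above can be sketched in a few lines of Python. This is one possible reading, not the repository's implementation: the function name `insert_fillers` is hypothetical, each word is first given its guaranteed filler, the rest is split proportionally with floor division, and the sample is rejected whenever the one-filler-per-word minimum cannot be met (slightly stricter than mel_len < n_phones).

```python
FILLER = "<FILLER>"

def insert_fillers(words, mel_len):
    """words: list of phone lists, one per word.

    Returns a token list of exactly mel_len items, or None if the sample
    cannot satisfy the one-filler-per-word minimum.
    """
    n_phones = sum(len(w) for w in words)
    total_filler = mel_len - n_phones
    if n_phones == 0 or total_filler < len(words):
        return None
    extra = total_filler - len(words)  # budget left after the guaranteed fillers
    out, used = [], 0
    for phones in words:
        share = 1 + extra * len(phones) // n_phones  # floor keeps the sum in budget
        out.extend(phones)
        out.extend([FILLER] * share)
        used += share
    out.extend([FILLER] * (total_filler - used))  # remainder goes at the sequence end
    return out

tokens = insert_fillers([["h", "ɛ", "l", "oʊ"], ["w", "ɜː", "l", "d"]], mel_len=50)
```

For "hello world" with 50 mel frames, each word receives 21 fillers and no remainder is left over.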


2. Text Encoder — ConvNeXt V2

flowchart TD
    In["Input\nB × C × T"] --> DWConv["Depthwise Conv1d  kernel=7\n(per-channel local context)"]
    DWConv --> LN[LayerNorm]
    LN --> Expand["Linear  C → 4C"]
    Expand --> Split["Split  4C → x₁ (2C)  +  x₂ (2C)"]

    Split -->|x₁| GELU[GELU]
    Split -->|x₂| Mul["⊗"]
    GELU --> Mul

    Mul --> GRN["GRN  (Global Response Normalization)\nnormalize each channel by its global L₂ norm"]
    GRN --> Contract["Linear  2C → C"]
    Contract --> Add((" + "))
    In --> Add
    Add --> Out["Output\nB × C × T"]

    subgraph Stack["Full Text Encoder  (C = 512)"]
        direction LR
        B1["Block 1"] --> B2["Block 2"] --> B3["..."] --> B6["Block 6"]
        B6 --> ProjD["Linear  C → D\n(D = 1024)"]
        ProjD --> ET["Text Repr\nB × T × 1024"]
    end

Each block applies a depthwise Conv1d (kernel size 7) for local context, followed by LayerNorm, a pointwise expansion (C → 4C), a GELU gating split (x = GELU(x₁) ⊗ x₂), Global Response Normalization, and a pointwise projection back to C. The stack of 6 blocks feeds into a final linear projection to D=1024, aligned to the DiT's hidden dimension. The GRN (Global Response Normalization) layer normalizes each channel relative to its global L2 norm across the time axis, suppressing redundant activations and encouraging sparse feature selectivity — a key improvement in ConvNeXt V2 over V1.

The encoder output is added directly to the projected mel input inside the DiT. There is no cross-attention between text and mel; the text representation is spatially aligned with the mel frames through the filler strategy.
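For readers unfamiliar with GRN, here is a dependency-free sketch of the operation as defined for ConvNeXt V2, applied to a single sample stored as nested lists. The function name `grn` and the list-based layout are illustrative assumptions; the actual encoder presumably operates on batched tensors.

```python
import math

def grn(x, gamma, beta, eps=1e-6):
    """Global Response Normalization for one sample.

    x: list of C channels, each a list of T floats.
    gamma, beta: learnable per-channel scale and shift (lists of length C).
    """
    # Global L2 norm of each channel across the time axis.
    norms = [math.sqrt(sum(v * v for v in ch)) for ch in x]
    mean_norm = sum(norms) / len(norms)
    # Each channel's norm relative to the average channel norm.
    scale = [n / (mean_norm + eps) for n in norms]
    # Rescale, shift, and add the residual input.
    return [
        [g * v * s + b + v for v in ch]
        for ch, s, g, b in zip(x, scale, gamma, beta)
    ]

x = [[3.0, 4.0], [0.0, 0.0]]
identity_out = grn(x, gamma=[0.0, 0.0], beta=[0.0, 0.0])
scaled_out = grn(x, gamma=[1.0, 1.0], beta=[0.0, 0.0])
```

With gamma and beta at zero the layer passes its input through unchanged; with gamma set to one, the high-energy channel is amplified while the silent channel stays at zero.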


3. Dual-Level Language ID (LID) Injection

Language conditioning is injected at two distinct points so that the model can distinguish phonological characteristics (text level) and prosodic/acoustic characteristics (time level) separately.

flowchart TD
    LangID["Language ID\n0 = Korean · 1 = English · 2 = Japanese"]
    LangID --> Emb["Language Embedding Table\ne_LID ∈ ℝ¹²⁸"]

    Emb -->|Time Branch| Concat["Concat  [ h_t ‖ e_LID ]\nℝ^(D+128)"]
    HT["Timestep Embedding\nh_t ∈ ℝ^D"] --> Concat
    Concat --> TimeProj["SiLU(Linear  D+128 → D)\nzero-initialized weights & bias"]
    TimeProj --> HTPrime["h_t'  ∈ ℝ^D"]
    HTPrime --> DiTBlocks["AdaLN conditioning\nin all 22 DiT blocks"]

    Emb -->|Text Branch| Gamma["film_γ  Linear(128 → D)\nzero-initialized"]
    Emb --> Beta["film_β  Linear(128 → D)\nzero-initialized"]
    ET["Text Encoder Output  e_T"] --> FiLM["FiLM:  γ ⊗ e_T + β"]
    Gamma --> FiLM
    Beta --> FiLM
    FiLM --> ETPrime["e_T'  added to DiT input\n(identity at step 0: γ=0, β=0)"]

A language embedding e_LID ∈ ℝ¹²⁸ is looked up from a 3-entry table and routed through two independent branches:

  • Time-level: e_LID is concatenated with the timestep embedding h_t and projected back to D via SiLU(Linear(D+128 → D)). The result h_t' replaces h_t as the AdaLN conditioning vector fed into every DiT block.
  • Text-level: e_LID drives a FiLM layer — γ = Linear(128→D)(e_LID) and β = Linear(128→D)(e_LID) — that modulates the text encoder output as e_T' = γ ⊗ e_T + β before it is added to the mel input.

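Zero-initialization: All three linear layers in the LID module (time_proj, film_gamma, film_beta) are initialized to zero weights and zero biases. At step 0 both branches therefore output zero: SiLU(0) = 0 for the time projection, and γ = β = 0 for FiLM. For the model to behave identically to the pretrained F5-TTS at that point, these zero outputs must act as residual corrections, i.e. h_t' = h_t + SiLU(Linear([h_t ‖ e_LID])) and e_T' = (1 + γ) ⊗ e_T + β; taken literally, the non-residual forms above would zero out the timestep and text conditioning instead of preserving them. With the residual form, language-specific features emerge gradually over the first gradient updates, reducing the risk of catastrophic forgetting.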
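Whether zero-initialized FiLM acts as an identity depends on how its output is combined with e_T. The sketch below contrasts the literal form γ ⊗ e_T + β with a residual form (1 + γ) ⊗ e_T + β when γ = β = 0. The `film` helper and its `residual` flag are hypothetical, introduced only for this comparison; which form the repository uses is not stated here.

```python
def film(e_t, gamma, beta, residual):
    """Apply FiLM feature-wise.

    residual=False: gamma * e + beta
    residual=True:  (1 + gamma) * e + beta
    """
    scale = [(1.0 + g) if residual else g for g in gamma]
    return [s * e + b for s, e, b in zip(scale, e_t, beta)]

e_t = [0.5, -1.2, 3.0]
zeros = [0.0, 0.0, 0.0]  # what a zero-initialized Linear(128 -> D) outputs

plain = film(e_t, zeros, zeros, residual=False)    # text features are wiped out
identity = film(e_t, zeros, zeros, residual=True)  # text features pass through
```

Only the residual variant leaves the pretrained text representation untouched at initialization.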


4. Diffusion Transformer (DiT)

flowchart TD
    x["Input  x\nB × T × D"] --> xAdd((" + "))
    eT["Text Repr  e_T'\n(zero-padded for ref region)"] --> xAdd
    xAdd --> xCombined["Combined Input"]

    HTPrime["h_t'  (from LID Injection)"] --> AdaLN["AdaLN-Zero\n→  α₁, β₁, γ₁,  α₂, β₂, γ₂\n(output proj zero-initialized)"]

    xCombined --> LN1[LayerNorm]
    LN1 --> SS1["scale & shift  (α₁, β₁)"]
    AdaLN --> SS1
    SS1 --> Attn["Multi-Head Self-Attention\n16 heads  ·  RoPE on Q, K"]
    Attn --> Gate1["× γ₁"]
    AdaLN --> Gate1
    Gate1 --> Add1((" + "))
    xCombined --> Add1

    Add1 --> LN2[LayerNorm]
    LN2 --> SS2["scale & shift  (α₂, β₂)"]
    AdaLN --> SS2
    SS2 --> FFN["Feed-Forward Network\nD → 4D → D"]
    FFN --> Gate2["× γ₂"]
    AdaLN --> Gate2
    Gate2 --> Add2((" + "))
    Add1 --> Add2

    Add2 --> Out["Output\nB × T × D"]

The input is x_cat = [ref_mel | xₜ] projected to [B, R+T, D] via a linear layer, then summed with the text representation e_T' (zero-padded over the reference region). This combined tensor passes through 22 identical DiT blocks before a zero-initialized output projection produces the velocity field v_θ ∈ ℝ^{100×(R+T)}.

Attention: Standard multi-head self-attention (16 heads, head dim = 64) with Rotary Position Embeddings (RoPE). RoPE is applied to queries and keys only, encoding relative positional bias directly into the attention logits without learned position embeddings.
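To make the relative-position property concrete, the following sketch applies RoPE to small query and key vectors and shows that their dot product depends only on the position offset. It uses the interleaved-pair convention and a base of 10000; both are assumptions for illustration, and the repository may pair dimensions differently.

```python
import math

def rope(vec, pos, base=10000.0):
    """Rotate consecutive pairs (vec[i], vec[i+1]) by the angle pos * base**(-i/d)."""
    d = len(vec)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [vec[i] * c - vec[i + 1] * s, vec[i] * s + vec[i + 1] * c]
    return out

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.3]

# Same offset (2 positions apart), different absolute positions.
logit_a = dot(rope(q, 3), rope(k, 1))
logit_b = dot(rope(q, 7), rope(k, 5))
```

Both logits agree up to floating-point error, which is why no learned position table is needed.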

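AdaLN-Zero: The output projection of the AdaLN parameter network is zero-initialized, so all gates γ start at 0. Each block's attention and feed-forward branches then contribute nothing to the residual stream, making every block an identity mapping at initialization. This applies to parameters that start from scratch; when an F5-TTS checkpoint is loaded, parameters with matching names and shapes take their pretrained values instead.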

Ref/Target masking in text repr: The text encoder output e_T' is zero-padded for the first R positions (the reference region) and contains the actual phone+filler representation only for positions R through R+T-1 (the target region). This ensures the DiT cannot "cheat" by reading text features from the reference segment — it must extract speaker characteristics from the mel frames of the reference audio directly.

F5-TTS checkpoint loading: Weights are transferred by matching parameter names and shapes. Parameters whose shapes differ (new IPA embedding rows, LID layers, final output projection if vocabulary sizes differ) are kept at their initialized values.


5. Optimal Transport CFM (OT-CFM)

OT-CFM replaces DDPM's stochastic noising schedule with a straight-line probability path between noise and data. This produces a simpler vector field that is easier to learn and can be accurately integrated with very few Euler steps.

Probability path:

ψ_t(x₀, x₁) = (1 − t)·x₀ + t·x₁,   t ∈ [0, 1]

x₀ is sampled from N(0, I) and x₁ is the ground-truth mel spectrogram. At t=0 the sample is pure noise; at t=1 it is the clean mel.

Training objective:

L = E_{t, x₀, x₁} [ ‖ v_θ(ψ_t(x₀, x₁), t) − (x₁ − x₀) ‖² ]

The target vector field (x₁ − x₀) is constant along the straight-line path, which means the model only needs to learn a single direction per sample rather than a time-varying score function.
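A minimal sketch of how one training pair could be built from the path and objective above, using flat Python lists in place of mel tensors. The function name `cfm_training_pair` is hypothetical.

```python
import random

def cfm_training_pair(x1, rng):
    """Return (t, x_t, target) for one sample; x1 is a flat list of mel values."""
    t = rng.random()                                     # t ~ Uniform(0, 1)
    x0 = [rng.gauss(0.0, 1.0) for _ in x1]               # x0 ~ N(0, I)
    x_t = [(1 - t) * a + t * b for a, b in zip(x0, x1)]  # point on the straight path
    target = [b - a for a, b in zip(x0, x1)]             # constant velocity x1 - x0
    return t, x_t, target

rng = random.Random(0)
x1 = [1.0, 2.0, 3.0]
t, x_t, target = cfm_training_pair(x1, rng)

# Following the target velocity for the remaining time (1 - t) lands on x1.
reached = [p + (1 - t) * v for p, v in zip(x_t, target)]
```

Because the velocity is constant along the path, a single step of size 1 − t from x_t recovers the clean mel exactly.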

Loss is computed only on the target region, not the reference:

loss_mask[i, :, ref_len : ref_len + mel_len] = 1.0
loss = ((v_pred - target_full) ** 2 * loss_mask).sum() / mask_sum
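The masking idea can also be written without tensors. The sketch below computes the mean squared error over the target columns only, for a single sample; the function name `masked_mse` and normalization by the number of masked elements are assumptions.

```python
def masked_mse(v_pred, target, ref_len, mel_len):
    """v_pred, target: [n_mels][ref_len + mel_len] nested lists for one sample."""
    total, count = 0.0, 0
    for pred_row, tgt_row in zip(v_pred, target):
        for j in range(ref_len, ref_len + mel_len):  # target region only
            total += (pred_row[j] - tgt_row[j]) ** 2
            count += 1
    return total / count

# One mel band, 2 reference frames followed by 2 target frames.
v_pred = [[9.0, 9.0, 1.0, 2.0]]
target = [[0.0, 0.0, 1.0, 4.0]]
loss = masked_mse(v_pred, target, ref_len=2, mel_len=2)
```

The large errors in the reference frames are ignored; only the two target frames contribute.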

Inference — Euler ODE with sway sampling:

dx/dt = v_θ(x, t)

Euler step:  x_{t+Δt} = xₜ + Δt · v_θ(xₜ, t)

Standard uniform timesteps space ODE steps evenly, but the vector field tends to change more rapidly near t=0 (noisy) and t=1 (clean). Sway sampling redistributes timesteps to concentrate more steps where the field is most curved:

t_sway = t + α · (cos(π/2 · t) − 1 + t),   α = −1.0

With α = −1.0 the mapping simplifies to t_sway = 1 − cos(π/2 · t), so the schedule is denser near t=0 (the noisy end) and sparser near t=1, matching typical OT-CFM field curvature. With this schedule, NFE = 16 is sufficient for high quality output.
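The schedule and the Euler update can be combined as follows. This is a sketch under the assumption that sway sampling is applied to a uniform grid of NFE + 1 points; the function names are illustrative, and `v` stands in for the guided DiT forward pass.

```python
import math

def sway(t, alpha=-1.0):
    """Warp a uniform time t in [0, 1] with the sway sampling formula."""
    return t + alpha * (math.cos(math.pi / 2 * t) - 1 + t)

def sway_schedule(nfe=16, alpha=-1.0):
    """nfe + 1 time points from 0 to 1 after warping."""
    return [sway(i / nfe, alpha) for i in range(nfe + 1)]

def euler_integrate(v, x, times):
    """Integrate dx/dt = v(x, t) over the given time grid with explicit Euler."""
    for t0, t1 in zip(times, times[1:]):
        x = x + (t1 - t0) * v(x, t0)
    return x

times = sway_schedule()
first_step = times[1] - times[0]
last_step = times[-1] - times[-2]

# Toy field with constant velocity 2: integrating from 0 to 1 should give 2.
result = euler_integrate(lambda x, t: 2.0, 0.0, times)
```

The first step is far smaller than the last, confirming that evaluations cluster near the noisy end of the trajectory.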


6. Asymmetric CFG (Classifier-Free Guidance)

Standard CFG applies a single guidance weight uniformly across all timesteps. This is suboptimal for TTS because the model needs different amounts of text and acoustic guidance at different phases of generation — early steps establish coarse structure, late steps refine fine acoustic detail.

Three forward passes are made at each inference step:

| Pass | Conditioning | Interpretation |
| --- | --- | --- |
| v(ψ; A, T, L) | ref mel + text + lang | Full conditioning |
| v(ψ; T, L) | no ref mel + text + lang | Text/language only (no acoustic) |
| v(ψ; ∅) | nothing | Unconditional |

Guided vector field:

v_DCFG = v(A,T,L)
       + wA(t) · [v(A,T,L) − v(T,L)]      ← acoustic guidance
       + wL(t) · [v(T,L)   − v(∅)]         ← linguistic guidance

The two guidance weights follow independent time schedules:

Acoustic guidance wA(t) — controls speaker identity fidelity:

         ┌ 2.5                                               if t ≤ 0.6
wA(t) =  │
         └ 2.5 · (1 − (t − 0.6) / 0.4)²                    if t > 0.6

Acoustic guidance stays at full strength throughout most of the trajectory (where the coarse speaker timbre is being established) and is smoothly turned off at the end (where fine spectral details are filled in and over-constraining the speaker would hurt naturalness).

Linguistic guidance wL(t) — controls text intelligibility:

         ┌ 4.0 · (t / 0.01)                                 if t ≤ 0.01
wL(t) =  │ 4.0                                              if 0.01 < t ≤ 0.6
         └ 4.0 · (1 − (t − 0.6) / 0.4)²                    if t > 0.6

The brief linear warmup over t ∈ [0, 0.01] prevents a large linguistic guidance term at the very start, when the sample is nearly pure noise and offers no meaningful structure for the text signal to push against. Full linguistic guidance runs through the bulk of generation, then tapers off as the mel spectrogram is refined into its final acoustic form.
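The two schedules and the guided combination translate directly into code. The sketch below follows the piecewise definitions above; the function names and the list-based velocity representation are assumptions for illustration.

```python
def w_acoustic(t, w=2.5, t_decay=0.6):
    """Constant until t_decay, then quadratic decay to 0 at t = 1."""
    if t <= t_decay:
        return w
    return w * (1 - (t - t_decay) / (1 - t_decay)) ** 2

def w_linguistic(t, w=4.0, t_warm=0.01, t_decay=0.6):
    """Linear warmup, plateau, then the same quadratic decay."""
    if t <= t_warm:
        return w * t / t_warm
    if t <= t_decay:
        return w
    return w * (1 - (t - t_decay) / (1 - t_decay)) ** 2

def guided_velocity(v_full, v_text, v_uncond, t):
    """Combine the three forward passes element-wise.

    v_full:   v(A, T, L), v_text: v(T, L), v_uncond: v(∅)
    """
    wa, wl = w_acoustic(t), w_linguistic(t)
    return [
        f + wa * (f - x) + wl * (x - u)
        for f, x, u in zip(v_full, v_text, v_uncond)
    ]

mid = guided_velocity([1.0], [0.5], [0.0], t=0.3)  # both weights at full strength
end = guided_velocity([1.0], [0.5], [0.0], t=1.0)  # both weights decayed to zero
```

At t = 1 both weights vanish and the guided field reduces to the fully conditioned prediction.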


7. Vocos Vocoder

The model operates entirely in the 100-band log-mel spectrogram domain at 24000 Hz. Conversion to a waveform uses a pretrained, frozen Vocos model from charactr/vocos-mel-24khz (HuggingFace Hub), which internally uses a ConvNeXt backbone and an ISTFT head to reconstruct the waveform at 24000 Hz. The mel parameters (24 kHz, 100 bands) are chosen to match Vocos's native input specification exactly, eliminating any resampling at vocoder decode time.

Vocos is not trained or fine-tuned during PolyFlow-TTS training — all gradients stop at the mel spectrogram boundary. This means vocoder quality is bounded by the upstream Vocos checkpoint, but it also means the TTS model sees a stable, fixed target representation during training.


Model Scale

| Component | Parameters |
| --- | --- |
| Text Encoder (ConvNeXt V2, 6 layers, C=512) | ~14M |
| DiT (22 layers, D=1024, 16 heads) | ~330M |
| LID Injection | ~0.4M |
| Total (trainable) | ~345M |
| Vocos (frozen) | ~13M |

Dataset

| Language | Source | Hours | Batch share |
| --- | --- | --- | --- |
| Korean | Emilia-KO | ~500h | 50% |
| English | LibriTTS-R | ~300h | 25% |
| Japanese | ReazonSpeech | ~200h | 25% |

Requirements

torch>=2.0
torchaudio
transformers>=4.35.0
phonemizer
g2pk2
pyopenjtalk
fugashi
unidic-lite
silero-vad
dnsmos
jiwer
wandb
hydra-core>=1.3.0
omegaconf
vocos
einops
scipy
numpy
pandas

Install espeak-ng (required by phonemizer for English G2P):

# macOS
brew install espeak-ng

# Ubuntu / Debian
sudo apt-get install espeak-ng

Install Python dependencies:

pip install -r requirements.txt

Data Preparation

Place raw data under data/ with the following layout:

data/
├── korean/
│   ├── audio/        # Korean .wav files
│   └── metadata.csv  # columns: filename, text
├── english/
│   ├── audio/        # English .wav files
│   └── metadata.csv
└── japanese/
    ├── audio/        # Japanese .wav files
    └── metadata.csv

metadata.csv format:

filename,text
sample_001.wav,안녕하세요 반갑습니다
sample_002.wav,오늘 날씨가 맑습니다

Usage

1. Preprocessing

python preprocess.py --data_dir data/

This runs the full pipeline:

  1. Resample to 24000 Hz, convert to mono
  2. Trim silence with Silero VAD
  3. Filter by duration (0.5s – 15s)
  4. Filter by DNSMOS P.835 score (≥ 3.5)
  5. Transcribe with Whisper large-v3, remove samples with CER > 5%
  6. Remove speaking rate outliers per language (IQR × 1.5)
  7. Convert text to IPA via language-specific G2P, apply language prefix (ko_*, en_*, ja_*)
  8. Extract 100-band log-mel spectrograms and save as .npy
  9. Build language-prefixed IPA vocabulary and save vocab.json
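Step 6 can be pictured with a small example. The sketch below flags speaking-rate outliers using the common 1.5 × IQR fence; measuring the rate in phones per second and using Python's default quartile method are assumptions, since the exact definitions used by preprocess.py are not given here.

```python
import statistics

def iqr_keep_mask(rates, k=1.5):
    """True for samples whose rate lies within [Q1 - k*IQR, Q3 + k*IQR]."""
    q1, _, q3 = statistics.quantiles(rates, n=4)
    iqr = q3 - q1
    lo, hi = q1 - k * iqr, q3 + k * iqr
    return [lo <= r <= hi for r in rates]

# Hypothetical per-utterance speaking rates for one language (phones per second).
rates = [10, 11, 12, 13, 14, 15, 16, 40]
keep = iqr_keep_mask(rates)
```

Here the 40 phones/s utterance falls outside the fence and would be dropped, while the rest are kept. The filter would be run separately per language, since typical rates differ between Korean, English, and Japanese.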

To skip the Whisper transcription step (faster, uses less memory):

python preprocess.py --data_dir data/ --no_whisper

Preprocessing is resumable — already-processed files are detected and skipped on restart.

Output layout:

data_processed/
├── korean/
│   ├── mels/         # .npy files (100 × T)
│   ├── ipa/          # .txt files (language-prefixed IPA tokens)
│   └── metadata.csv  # filename, ipa, mel_len, text_len, duration
├── english/
├── japanese/
└── vocab.json        # {"<PAD>": 0, ..., "ko_a": 6, "en_æ": 7, ...}

2. Training

python train.py --data_dir data_processed/

Fine-tunes from the F5-TTS base checkpoint specified in configs/config.yaml (f5_checkpoint). If no checkpoint is set, trains from scratch.

Key training details:

  • DiT blocks are frozen for the first 10K steps while LID layers and new IPA embeddings warm up
  • AdamW optimizer, lr = 2e-5 with 20K-step linear warmup and linear decay
  • bfloat16 mixed precision, DDP multi-GPU supported
  • Checkpoints saved every 1000 steps to checkpoints/
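The learning-rate schedule described above might look like the following. The total step count of 500000 is borrowed from the Hydra override example and is an assumption, as is decaying all the way to zero; the function name `lr_at` is hypothetical.

```python
def lr_at(step, peak_lr=2e-5, warmup_steps=20_000, total_steps=500_000):
    """Linear warmup to peak_lr, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / (total_steps - warmup_steps)

halfway_warmup = lr_at(10_000)  # half of the peak rate
peak = lr_at(20_000)            # warmup finished
final = lr_at(500_000)          # fully decayed
```

Note that the DiT freeze (10K steps) ends halfway through warmup, so the backbone starts updating at roughly half the peak rate under these assumptions.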

Resume training:

python train.py resume=checkpoints/latest.pt

Multi-GPU (DDP):

torchrun --nproc_per_node=4 train.py

With WandB logging:

python train.py wandb=true wandb_project=my-project

Override any config value via Hydra:

python train.py training.learning_rate=1e-5 training.total_steps=500000

3. Inference

python inference.py \
  --text "안녕하세요, 반갑습니다." \
  --lang ko \
  --ref_audio ref.wav \
  --ref_text "안녕하세요 저는 아나운서입니다." \
  --checkpoint checkpoints/step_0100000.pt \
  --output output.wav

If --ref_text is omitted, Whisper large-v3 automatically transcribes ref_audio to estimate the reference speaking rate.

| Argument | Description |
| --- | --- |
| --text | Input text to synthesize |
| --lang | Language code: ko, en, ja |
| --ref_audio | Reference WAV for voice cloning (3–10s recommended) |
| --ref_text | Transcript of ref_audio. Omit to auto-transcribe with Whisper |
| --checkpoint | Path to trained checkpoint |
| --output | Output WAV path (default: output.wav) |
| --nfe | Number of ODE steps (default: 16) |
| --sway | Sway sampling coefficient (default: -1.0) |
| --no_cfg | Disable asymmetric CFG guidance |
| --fast | Use single-weight CFG (2 DiT passes) instead of asymmetric CFG (3 passes) |

Target length estimation: the model estimates how many mel frames to generate based on the reference speaker's pace:

mel_len = ref_mel_len × (target_phone_count / ref_phone_count)
mel_len = clamp(mel_len, target_phone_count × 3, target_phone_count × 20)

Providing an accurate --ref_text gives a better speaking rate estimate and therefore more natural prosody. If Whisper is used as a fallback, expect about 5 s of extra latency for transcription, plus a one-time model download on first use.
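The length rule above is easy to check by hand. The sketch below implements it as a standalone function; the name `estimate_mel_len` and the final rounding to an integer frame count are assumptions.

```python
def estimate_mel_len(ref_mel_len, ref_phone_count, target_phone_count):
    """Scale the reference pace to the target text, then clamp per-phone duration."""
    raw = ref_mel_len * target_phone_count / ref_phone_count
    lo = target_phone_count * 3    # at least 3 frames per phone
    hi = target_phone_count * 20   # at most 20 frames per phone
    return int(round(min(max(raw, lo), hi)))

typical = estimate_mel_len(ref_mel_len=300, ref_phone_count=30, target_phone_count=45)
too_slow = estimate_mel_len(ref_mel_len=1000, ref_phone_count=10, target_phone_count=5)
too_fast = estimate_mel_len(ref_mel_len=20, ref_phone_count=20, target_phone_count=10)
```

A reference at 10 frames per phone yields 450 frames for 45 target phones; the other two calls hit the upper and lower clamps (100 and 30 frames).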

Configuration

All hyperparameters are managed through configs/config.yaml (Hydra). Key sections:

audio:
    sample_rate: 24000
    n_mels: 100
    hop_length: 256

model:
    hidden_dim: 1024
    n_layers: 22
    n_heads: 16

cfm:
    nfe: 16
    sway_coef: -1.0

cfg:
    w_acoustic: 2.5
    w_linguistic: 4.0
    t_decay: 0.6

training:
    learning_rate: 2.0e-5
    warmup_steps: 20000
    dit_freeze_steps: 10000

f5_checkpoint: /path/to/f5_base.pt

Logging & Notifications

Training uses ASHi for unified logging. Copy .env.example to .env and fill in whichever backends you want — any key left blank is silently skipped.

cp .env.example .env
# Discord / Slack / Telegram — fill in the ones you want, leave the rest blank
DISCORD_WEBHOOK=
SLACK_WEBHOOK=
TELEGRAM_TOKEN=
TELEGRAM_CHAT_ID=

# Weights & Biases — setting WANDB_API_KEY activates WandbBackend automatically
WANDB_API_KEY=
WANDB_PROJECT=polyflow-tts

ASHi sends notifications at training start/end, every checkpoint save, and on error. At least one backend must be configured for notifications to work; if all keys are blank, logging falls back to stdout only.

Mel Spectrogram Parameters

| Parameter | Value |
| --- | --- |
| Sample rate | 24000 Hz |
| FFT size | 1024 |
| Hop length | 256 |
| Mel bins | 100 |
| Vocoder output SR | 24000 Hz (Vocos) |
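These parameters fix the relationship between audio duration and mel length. As a quick sketch (the helper names are hypothetical, and the "+ 1" assumes center-padded STFT framing, which may not match the repository's extractor exactly):

```python
def frames_per_second(sample_rate=24000, hop_length=256):
    """Mel frame rate implied by the sample rate and hop length."""
    return sample_rate / hop_length

def mel_frames(duration_s, sample_rate=24000, hop_length=256):
    """Approximate frame count for a clip, assuming center-padded framing."""
    return int(duration_s * sample_rate) // hop_length + 1

rate = frames_per_second()   # 93.75 frames per second
one_second = mel_frames(1.0)
```

At 93.75 frames per second, the 3 to 20 frames-per-phone clamp used for length estimation corresponds to roughly 32 ms to 213 ms per phone.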

Future Work

  • Speaking Rate Predictor. The current mel length estimation uses a simple proportional rule derived from the reference speaker's pace (mel_len = ref_mel_len × target_phones / ref_phones). This is sufficient for short references with consistent speaking rate but breaks down when the reference clip has variable tempo or is too short to be representative. A dedicated Speaking Rate Predictor module — trained to predict per-phone duration from the text and speaker embedding — would allow finer-grained control over pacing and improve naturalness, especially for expressive or conversational speech styles.

References

  • Chen, S. et al. F5-TTS: A Fairytaler that Fakes Fluent and Faithful Speech with Flow Matching. arXiv:2410.06885 (2025). https://arxiv.org/abs/2410.06885

  • Zhao, Y. et al. X-Voice: Enabling Everyone to Speak 30 Languages via Zero-Shot Cross-Lingual Voice Cloning. arXiv:2605.05611 (2026). https://arxiv.org/abs/2605.05611

License

This project builds on F5-TTS (MIT License), Vocos (MIT License), and datasets under their respective terms. See individual dataset licenses before commercial use.
