[TTS][Magpietts] Added local transformer CFG distillation #15748

Open
artem-gorodetskii wants to merge 1 commit into main from magpietts_cfg_distillation_lt_term

artem-gorodetskii (Contributor) commented Jun 3, 2026

Summary

This PR extends online classifier-free guidance (CFG) distillation for MagpieTTS with optional local transformer (LT) distillation.
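
For readers unfamiliar with the teacher side, here is a minimal sketch of how classifier-free guidance is commonly applied to logits at a single decoding step. The PR description does not state the exact formula used by MagpieTTS, so the function name `cfg_combine`, the `scale` parameter, and the plain-list representation are illustrative assumptions based on the standard CFG formulation.

```python
def cfg_combine(cond_logits, uncond_logits, scale):
    """Blend conditional and unconditional logits with a guidance scale.

    Standard CFG form (assumed here, not taken from the PR):
        guided = uncond + scale * (cond - uncond)
    scale = 1.0 returns the conditional logits; scale = 0.0 returns the
    unconditional logits.
    """
    if len(cond_logits) != len(uncond_logits):
        raise ValueError("logit vectors must have the same length")
    return [u + scale * (c - u) for c, u in zip(cond_logits, uncond_logits)]


# Hypothetical two-token vocabulary, guidance scale of 2.0.
guided = cfg_combine([2.0, 0.0], [1.0, 0.0], scale=2.0)
```

Distillation aims to have the student reproduce the guided output without running both the conditional and the unconditional pass at inference time.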

What’s included

  • Added optional local transformer distillation to OnlineCFGDistillation.
  • Extended teacher rollout generation to optionally produce local transformer codes and logits.
  • Extended student teacher-forced decoding to optionally compute local transformer logits for distillation.
  • Added support for local transformer loss computation using the same distillation objectives: KL-divergence, cross-entropy and normalized RMSE.
  • Added configuration options to control local transformer distillation: an enable/disable flag, the LT loss weight, the LT distillation start step, and the LT loss ramp length.
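
The three distillation objectives listed above can be sketched for a single token position as follows. This is not the code from the PR: the function names, the KL direction (teacher to student), the use of a hard teacher code as the cross-entropy target, and the RMSE normalization by the RMS of the teacher logits are all assumptions made for illustration.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def kl_divergence(teacher_logits, student_logits):
    """KL(teacher || student) between the two softmax distributions."""
    log_p = log_softmax(teacher_logits)
    log_q = log_softmax(student_logits)
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def cross_entropy(teacher_code, student_logits):
    """Cross-entropy of the student against a hard teacher code (assumed target)."""
    return -log_softmax(student_logits)[teacher_code]


def normalized_rmse(teacher_logits, student_logits, eps=1e-8):
    """RMSE between logits divided by the RMS of the teacher logits.

    The normalization choice is an assumption; the PR does not define it.
    """
    n = len(teacher_logits)
    mse = sum((t - s) ** 2 for t, s in zip(teacher_logits, student_logits)) / n
    teacher_rms = math.sqrt(sum(t * t for t in teacher_logits) / n)
    return math.sqrt(mse) / (teacher_rms + eps)
```

Under this reading, the same three functions would apply unchanged to decoder logits and to local transformer logits, which matches the statement that the LT loss reuses the existing objectives.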

Training flow

  • The teacher model is loaded from checkpoint and frozen.
  • The teacher performs autoregressive CFG rollout generation.
  • When enabled, the teacher also generates local transformer targets and logits aligned with the rollout.
  • The generated rollout is fed into the student in teacher-forced mode.
  • The student computes decoder logits and, optionally, local transformer logits.
  • The total training loss combines the main decoder distillation loss, an optional MoE auxiliary loss, and an optional local transformer distillation loss that is mixed into the final loss according to the configured LT schedule.
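
The last step of the flow above can be sketched as follows. The names `lt_loss_weight` and `total_loss`, the linear shape of the ramp, and the additive mixing of the three terms are assumptions for illustration; the PR only states that an LT loss weight, a start step, and a ramp length are configurable.

```python
def lt_loss_weight(step, max_weight, start_step, ramp_steps):
    """LT loss weight at a training step, assuming a linear ramp.

    The weight is 0 before start_step, grows linearly over ramp_steps,
    and stays at max_weight afterwards.
    """
    if step < start_step:
        return 0.0
    if ramp_steps <= 0:
        return max_weight
    progress = min(1.0, (step - start_step) / ramp_steps)
    return max_weight * progress


def total_loss(decoder_loss, step, moe_aux_loss=None, lt_loss=None,
               lt_max_weight=1.0, lt_start_step=0, lt_ramp_steps=0):
    """Sum the decoder loss with the optional MoE and LT terms (assumed additive)."""
    loss = decoder_loss
    if moe_aux_loss is not None:
        loss += moe_aux_loss
    if lt_loss is not None:
        weight = lt_loss_weight(step, lt_max_weight, lt_start_step, lt_ramp_steps)
        loss += weight * lt_loss
    return loss


# Hypothetical values: halfway through a 200-step ramp that starts at step 100.
example = total_loss(1.0, step=200, moe_aux_loss=0.1, lt_loss=2.0,
                     lt_max_weight=0.5, lt_start_step=100, lt_ramp_steps=200)
```

A delayed start followed by a ramp would let the decoder distillation settle before the LT term begins to contribute to the gradient.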

artem-gorodetskii self-assigned this Jun 3, 2026
copy-pr-bot (Bot) commented Jun 3, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

github-actions (Bot) added the TTS label Jun 3, 2026
blisc (Collaborator) left a comment

Looks good from my end.
