fix(akv_video): preserve encoder logvar in KVAE video VAE (#13652) #13657
Open
Anai-Guo wants to merge 1 commit into huggingface:main from
Conversation
Summary
Fixes Issue 1 in #13652.
`AutoencoderKLKVAEVideo` was discarding the encoder log-variance, so `latent_dist.logvar` was always zero and `sample_posterior=True` sampled from the wrong distribution. The static-image variant `AutoencoderKLKVAE` already keeps the full encoder output and lets `DiagonalGaussianDistribution` split it into mean/logvar; this PR aligns the video variant with that path.

Root cause
`KVAECachedEncoder3D` outputs `2 * z_channels` per chunk (mean and logvar concatenated), matching the standard KL-VAE convention. But `_encode()` was discarding the second half, and `encode()` then padded the missing half with zeros before constructing the posterior. So checkpoint logvar weights were silently ignored, posterior sampling used `exp(0/2) = 1.0` as the std everywhere, and parity with the upstream KVAE 3D VAE was broken. A sketch of the broken data flow follows below.
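For concreteness, a minimal pure-PyTorch sketch of the pre-fix data flow; all names and shapes here are illustrative, not the actual module code:

```python
import torch

z_channels = 4
# KVAECachedEncoder3D emits 2 * z_channels per chunk: [mean | logvar] along dim 1.
enc_out = torch.randn(1, 2 * z_channels, 3, 8, 8)

# Pre-fix: _encode() kept only the mean half, dropping the learned logvar...
kept = enc_out[:, :z_channels]
# ...and encode() padded the missing half back with zeros.
moments = torch.cat([kept, torch.zeros_like(kept)], dim=1)

# Splitting moments the way DiagonalGaussianDistribution does shows the damage:
# the posterior always had logvar == 0, i.e. std == exp(0 / 2) == 1 everywhere.
mean, logvar = torch.chunk(moments, 2, dim=1)
std = torch.exp(0.5 * logvar)
print(logvar.abs().max().item(), std.mean().item())  # 0.0 1.0
```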

Fix
Mirror the `autoencoder_kl_kvae.py` pattern: keep the full encoder output and let `DiagonalGaussianDistribution` do the chunk-and-clamp itself. `DiagonalGaussianDistribution.__init__` does `self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)`, so this is functionally identical to the static-image variant. A sketch of the aligned flow follows.
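A minimal sketch of the aligned flow, assuming a recent diffusers layout for the import (shapes illustrative):

```python
import torch
# Import path assumes a recent diffusers layout; older releases exposed the
# class from diffusers.models.vae instead.
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

z_channels = 4
enc_out = torch.randn(1, 2 * z_channels, 3, 8, 8)  # full [mean | logvar], kept intact

# Post-fix: hand the full tensor over and let the distribution chunk-and-clamp.
posterior = DiagonalGaussianDistribution(enc_out)

mean_half, logvar_half = torch.chunk(enc_out, 2, dim=1)
assert torch.equal(posterior.mean, mean_half)                          # mean is the raw first half
assert torch.allclose(posterior.logvar, logvar_half.clamp(-30.0, 20.0))  # logvar kept, then clamped
assert not torch.all(posterior.logvar == 0)                            # no longer constant zeros
z = posterior.sample()                                                 # same shape/dtype as before
```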

Verification
The reproduction snippet from #13652 now shows non-zero logvar and `posterior.mean` matching the raw encoder mean half; a hedged sketch of the check appears below. `posterior.sample()` shape and dtype are unchanged; the only behavioral change is that `posterior.logvar`, `.std`, and stochastic samples now reflect the actual encoder output instead of constant zeros.
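A sketch of that check; the class name comes from this PR, the top-level import is assumed, and the checkpoint id and input shape are placeholders:

```python
import torch
from diffusers import AutoencoderKLKVAEVideo  # top-level export assumed; adjust if needed

# "<kvae-video-checkpoint>" is a placeholder, not a real model id.
vae = AutoencoderKLKVAEVideo.from_pretrained("<kvae-video-checkpoint>")
video = torch.randn(1, 3, 9, 64, 64)  # (B, C, T, H, W), illustrative

posterior = vae.encode(video).latent_dist
print(posterior.logvar.abs().max())              # non-zero after this fix
print(posterior.std.min(), posterior.std.max())  # no longer constant 1.0
z = posterior.sample()                           # drawn from the actual encoder posterior
```

🤖 Generated with Claude Code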